用一个类手写Spring核心原理

Spring实现的基本思路

一个类手写Spring核心原理

仅通过DispatcherServlet这一个类来演示Spring的核心原理

/**
 * A minimal, single-class imitation of Spring's DispatcherServlet that demonstrates the core ideas
 * behind Spring: configuration loading, classpath scanning, an IoC container, dependency injection
 * and URL-to-method request dispatching.
 *
 * <p>All maps are filled once in {@link #init(ServletConfig)} and are only read afterwards.
 * Any configuration problem found during startup is reported as a {@link ServletException}
 * instead of leaving a half-initialized container behind.
 */
public class DispatcherServlet extends HttpServlet {
    /** Optional prefix accepted (and ignored) in front of the contextConfigLocation value. */
    private static final String CLASSPATH_PREFIX = "classpath:";

    // Properties loaded from the file named by the contextConfigLocation init-param.
    private final Properties contextConfiguration = new Properties();
    // Fully qualified names of every class found under the scanned package.
    private final List<String> classNames = new ArrayList<String>();
    // IoC container: bean name -> bean instance. One instance may be registered under several names.
    private final Map<String, Object> ioc = new HashMap<String, Object>();
    // URL -> handler method.
    private final Map<String, Method> handlerMapping = new HashMap<String, Method>();
    // URL -> controller instance the handler method is invoked on.
    private final Map<String, Object> handlerBeans = new HashMap<String, Object>();

    /**
     * Boots the mini container.
     *
     * @param config servlet configuration; must provide the contextConfigLocation init-param
     * @throws ServletException if the configuration is missing or invalid, or if a bean cannot be
     *     created, injected or mapped
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // Let GenericServlet keep the config so getServletConfig(), getServletContext() and log() work.
        super.init(config);
        // 1. Load the configuration file.
        doLoadConfig(config.getInitParameter("contextConfigLocation"));
        // 2. Scan the configured package for classes.
        doScanner(contextConfiguration.getProperty("scanPackage"));
        // 3. Instantiate the annotated classes and cache them in the IoC container.
        doInstance();
        // 4. Inject dependencies into @MyAutowired fields.
        doAutowired();
        // 5. Build the URL -> handler mapping.
        initHandlerMapping();
        System.out.println("Spring framework is init.");
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        // 6. Delegate the request to the handler method mapped to its URL.
        try {
            doDispatch(req, resp);
        } catch (Exception e) {
            // Report the failure to the client instead of silently answering with an empty 200.
            log("Failed to dispatch " + req.getRequestURI(), e);
            if (!resp.isCommitted()) {
                resp.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            }
        }
    }

    /**
     * Finds the handler for the request URL, builds its arguments and invokes it.
     * Unknown URLs get a 404 response.
     */
    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        // Remove only the leading context path. String.replaceAll would treat it as a regex
        // and strip every occurrence.
        if (contextPath != null && contextPath.length() > 0 && url.startsWith(contextPath)) {
            url = url.substring(contextPath.length());
        }
        url = url.replaceAll("/+", "/");
        if (url.length() == 0) {
            url = "/";
        }
        Method method = handlerMapping.get(url);
        if (method == null) {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().write("404 Not Found");
            return;
        }
        Object[] paramValues = resolveArguments(method, req, resp);
        // Invoke on the controller instance recorded with the mapping. This also works for handler
        // methods inherited from a superclass, whose declaring class is not the bean's class.
        method.invoke(handlerBeans.get(url), paramValues);
    }

    /**
     * Builds the argument array for a handler method.
     * Request and response parameters are filled by type; @MyRequestParam parameters are filled
     * by request-parameter name.
     * Parameters without a matching request value stay null, so a primitive parameter with no
     * value makes Method.invoke fail.
     */
    private Object[] resolveArguments(Method method, HttpServletRequest req, HttpServletResponse resp) {
        Class<?>[] parameterTypes = method.getParameterTypes();
        Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        Object[] paramValues = new Object[parameterTypes.length];
        for (int i = 0; i < parameterTypes.length; i++) {
            Class<?> parameterType = parameterTypes[i];
            if (parameterType == HttpServletRequest.class) {
                paramValues[i] = req;
            } else if (parameterType == HttpServletResponse.class) {
                paramValues[i] = resp;
            } else {
                String paramName = requestParamName(parameterAnnotations[i]);
                if (paramName == null) {
                    continue;
                }
                String[] values = req.getParameterValues(paramName);
                if (values != null) {
                    paramValues[i] = convert(joinValues(values), parameterType);
                }
            }
        }
        return paramValues;
    }

    /** Returns the non-blank @MyRequestParam name among the given annotations, or null if there is none. */
    private String requestParamName(Annotation[] annotations) {
        for (Annotation anno : annotations) {
            if (anno instanceof MyRequestParam) {
                String paramName = ((MyRequestParam) anno).value().trim();
                if (paramName.length() > 0) {
                    return paramName;
                }
            }
        }
        return null;
    }

    /** Joins multi-valued request parameters with ',' and keeps each value untouched. */
    private String joinValues(String[] values) {
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                joined.append(',');
            }
            joined.append(values[i]);
        }
        return joined.toString();
    }

    /**
     * Converts a raw request value to the handler parameter's type.
     * Supports String, int/Integer, long/Long, double/Double and boolean/Boolean; any other type
     * receives the raw String unchanged.
     *
     * @throws NumberFormatException if a numeric parameter receives a non-numeric value
     */
    private Object convert(String value, Class<?> type) {
        if (type == int.class || type == Integer.class) {
            return Integer.valueOf(value.trim());
        }
        if (type == long.class || type == Long.class) {
            return Long.valueOf(value.trim());
        }
        if (type == double.class || type == Double.class) {
            return Double.valueOf(value.trim());
        }
        if (type == boolean.class || type == Boolean.class) {
            return Boolean.valueOf(value.trim());
        }
        return value;
    }

    /**
     * Maps "/" + class-level @MyRequestMapping + "/" + method-level @MyRequestMapping to each public
     * handler method of every @MyController bean.
     *
     * @throws ServletException if two handler methods map to the same URL
     */
    private void initHandlerMapping() throws ServletException {
        for (Object bean : ioc.values()) {
            Class<?> clazz = bean.getClass();
            if (!clazz.isAnnotationPresent(MyController.class)) {
                continue;
            }
            String controllerUrl = "";
            MyRequestMapping classMapping = clazz.getAnnotation(MyRequestMapping.class);
            if (classMapping != null) {
                controllerUrl = classMapping.value();
            }
            // getMethods() only returns public methods (including inherited ones).
            for (Method method : clazz.getMethods()) {
                MyRequestMapping methodMapping = method.getAnnotation(MyRequestMapping.class);
                if (methodMapping == null) {
                    continue;
                }
                // Users may or may not start their paths with '/', so always join with '/'
                // and then collapse runs of slashes into a single one.
                String url = ("/" + controllerUrl + "/" + methodMapping.value()).replaceAll("/+", "/");
                if (handlerMapping.containsKey(url)) {
                    throw new ServletException("Ambiguous mapping: " + url + " is mapped to both "
                            + handlerMapping.get(url) + " and " + method);
                }
                handlerMapping.put(url, method);
                handlerBeans.put(url, bean);
                System.out.println("Mapped " + url + " --> " + method);
            }
        }
    }

    /**
     * Injects every @MyAutowired field.
     * An explicit annotation value names the bean; otherwise the field's fully qualified type name
     * is used.
     *
     * @throws ServletException if the referenced bean does not exist or the field cannot be set
     */
    private void doAutowired() throws ServletException {
        // A bean registered under several names is visited once per name; re-injecting the same
        // values is harmless.
        for (Map.Entry<String, Object> entry : ioc.entrySet()) {
            Object bean = entry.getValue();
            for (Field field : bean.getClass().getDeclaredFields()) {
                MyAutowired myAutowired = field.getAnnotation(MyAutowired.class);
                if (myAutowired == null) {
                    continue;
                }
                String beanName = myAutowired.value().trim();
                if (beanName.length() == 0) {
                    // e.g. com.lucifer.project.service.UserService
                    beanName = field.getType().getName();
                }
                Object dependency = ioc.get(beanName);
                if (dependency == null) {
                    throw new ServletException("No bean named '" + beanName + "' to inject into " + field);
                }
                // Private fields must be made accessible before reflection can write to them.
                field.setAccessible(true);
                try {
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    throw new ServletException("Cannot inject " + field, e);
                }
            }
        }
    }

    /**
     * Instantiates every scanned @MyController and @MyService class and caches it in the IoC
     * container. Classes without one of those annotations, including the annotation types
     * themselves, are skipped.
     *
     * @throws ServletException if a class cannot be loaded or instantiated, or if a bean name clashes
     */
    private void doInstance() throws ServletException {
        for (String className : classNames) {
            Class<?> clazz;
            try {
                clazz = Class.forName(className);
            } catch (ClassNotFoundException e) {
                throw new ServletException("Cannot load scanned class " + className, e);
            }
            if (clazz.isAnnotationPresent(MyController.class)) {
                // Controllers are named after their simple class name with a lower-case first letter.
                registerBean(toLowerCaseFirstChar(clazz.getSimpleName()), instantiate(clazz));
            } else if (clazz.isAnnotationPresent(MyService.class)) {
                Object instance = instantiate(clazz);
                // 1. The default name is the simple class name with a lower-case first letter.
                // 2. An explicit alias wins, e.g. when several packages contain the same simple class name.
                String alias = clazz.getAnnotation(MyService.class).value().trim();
                registerBean(alias.length() > 0 ? alias : toLowerCaseFirstChar(clazz.getSimpleName()), instance);
                // 3. Also register under the fully qualified names of the class and of every interface
                //    it implements, because @MyAutowired without a name resolves by field type name.
                registerBean(clazz.getName(), instance);
                for (Class<?> inte : clazz.getInterfaces()) {
                    registerBean(inte.getName(), instance);
                }
            }
        }
    }

    /** Creates an instance through the class's no-arg constructor. */
    private Object instantiate(Class<?> clazz) throws ServletException {
        try {
            return clazz.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new ServletException("Cannot instantiate " + clazz.getName()
                    + " (an accessible no-arg constructor is required)", e);
        }
    }

    /** Stores a bean under the given name and rejects duplicate names. */
    private void registerBean(String beanName, Object instance) throws ServletException {
        if (ioc.containsKey(beanName)) {
            throw new ServletException("The bean name '" + beanName + "' is already registered, please use alias.");
        }
        ioc.put(beanName, instance);
    }

    /** Lower-cases the first character of a class name, e.g. "UserController" -> "userController". */
    private String toLowerCaseFirstChar(String simpleName) {
        if (simpleName == null || simpleName.length() == 0) {
            return simpleName;
        }
        char[] chars = simpleName.toCharArray();
        // Character.toLowerCase leaves already-lower-case and non-letter characters untouched.
        chars[0] = Character.toLowerCase(chars[0]);
        return String.valueOf(chars);
    }

    /**
     * Records the fully qualified name of every .class file under the given package.
     * Only packages located in a classpath directory are supported; packages inside a jar are not.
     *
     * @param scanPackage dotted package name, e.g. com.lucifer.project
     * @throws ServletException if the package is not configured or cannot be located
     */
    private void doScanner(String scanPackage) throws ServletException {
        if (scanPackage == null || scanPackage.trim().length() == 0) {
            throw new ServletException("Property 'scanPackage' is missing from the configuration");
        }
        String packageName = scanPackage.trim();
        // ClassLoader resource names are relative to the classpath root and must not start with '/'.
        URL url = this.getClass().getClassLoader().getResource(packageName.replace('.', '/'));
        if (url == null) {
            throw new ServletException("Package '" + packageName + "' was not found on the classpath");
        }
        if (!"file".equals(url.getProtocol())) {
            throw new ServletException("Only directory-based classpaths can be scanned, got: " + url);
        }
        File directory;
        try {
            // toURI() decodes escapes such as %20, which URL.getFile() would leave in the path.
            directory = new File(url.toURI());
        } catch (java.net.URISyntaxException e) {
            throw new ServletException("Cannot resolve the directory of package " + packageName, e);
        }
        scanDirectory(directory, packageName);
    }

    /** Recursively adds "packageName.ClassName" for every .class file below the directory. */
    private void scanDirectory(File directory, String packageName) {
        File[] files = directory.listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                scanDirectory(file, packageName + "." + fileName);
            } else if (fileName.endsWith(".class")) {
                // e.g. com.lucifer.project.UserController; later loaded with Class.forName(className).
                classNames.add(packageName + "." + fileName.substring(0, fileName.length() - ".class".length()));
            }
        }
    }

    /**
     * Loads the properties file named by contextConfigLocation from the classpath.
     * An optional "classpath:" prefix and leading '/' characters are ignored.
     *
     * @throws ServletException if the location is not set, not found or cannot be read
     */
    private void doLoadConfig(String contextConfigLocation) throws ServletException {
        if (contextConfigLocation == null || contextConfigLocation.trim().length() == 0) {
            throw new ServletException("Init parameter 'contextConfigLocation' is not set");
        }
        String location = contextConfigLocation.trim();
        if (location.startsWith(CLASSPATH_PREFIX)) {
            location = location.substring(CLASSPATH_PREFIX.length());
        }
        while (location.startsWith("/")) {
            location = location.substring(1);
        }
        InputStream in = this.getClass().getClassLoader().getResourceAsStream(location);
        if (in == null) {
            throw new ServletException("Configuration file '" + location + "' was not found on the classpath");
        }
        try {
            contextConfiguration.load(in);
        } catch (IOException e) {
            throw new ServletException("Cannot read configuration file '" + location + "'", e);
        } finally {
            try {
                in.close();
            } catch (IOException e) {
                // Nothing to recover once the properties have been read; just record the failure.
                log("Failed to close configuration file '" + location + "'", e);
            }
        }
    }
}

自定义Spring注解类

/**
 * Marks a field that the dispatcher servlet fills in from its IoC container.
 * The optional value names the bean to inject; when it is left empty, the field's fully
 * qualified type name is used as the bean name.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface MyAutowired {
    /** Bean name to inject; empty means "use the field type's name". */
    String value() default "";
}

/**
 * Marks a class as a web controller whose @MyRequestMapping methods handle requests.
 * Kept at runtime so that the dispatcher can discover it through reflection.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MyController {
    /** Optional name for the controller; empty by default. */
    String value() default "";
}

/**
 * Marks a class as a service bean that the dispatcher servlet instantiates and caches.
 * The optional value is an alias used as the bean name instead of the default derived from
 * the simple class name.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MyService {
    /** Bean-name alias; empty means "derive the name from the class". */
    String value() default "";
}

/**
 * Marks a class as a repository (data-access) component.
 * NOTE(review): "Respository" is a misspelling of "Repository"; the name is kept as-is because
 * renaming the annotation would break existing usages.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface MyRespository {
    /** Optional name for the component; empty by default. */
    String value() default "";
}

/**
 * Maps a controller class or a handler method to a URL path segment.
 * A class-level value acts as a prefix that is joined with the method-level value; the
 * dispatcher collapses repeated '/' characters in the result.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(value = {ElementType.TYPE, ElementType.METHOD})
public @interface MyRequestMapping {
    /** Path segment, with or without a leading '/'. */
    String value() default "";
}

/**
 * Binds a handler-method parameter to the HTTP request parameter with the given name.
 * A blank value is ignored by the dispatcher, so the parameter is then left unbound.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface MyRequestParam {
    /** Name of the request parameter to bind. */
    String value() default "";
}

用于测试用户请求访问的Controller类

/**
 * Demo controller whose handlers live under the "/user" path prefix.
 */
@MyRequestMapping("/user")
@MyController
public class UserController {
    // Filled in by the dispatcher; with no explicit name it is looked up by the field's type name.
    @MyAutowired
    UserService userService;

    /**
     * Handles "/user/query" by writing a greeting to the response.
     *
     * @param req  the current request (not used)
     * @param resp the response the greeting is written to
     * @param name value of the "name" request parameter
     */
    @MyRequestMapping("/query")
    public void query(HttpServletRequest req, HttpServletResponse resp, @MyRequestParam("name") String name) {
        try {
            resp.getWriter().write("Hello, I am " + name);
        } catch (IOException ioe) {
            ioe.printStackTrace();
        }
    }
}

一个类手写Spring核心原理

上一篇:new.target的用处


下一篇:Java 反射机制