Spring实现的基本思路
仅用 DispatcherServlet 这一个类来演示 Spring 的核心原理
/**
 * Minimal front controller that demonstrates, in a single class, the core ideas behind Spring:
 * loading configuration, scanning a package, instantiating annotated classes into an IoC
 * container, injecting dependencies, mapping URLs to handler methods and dispatching requests
 * to them via reflection.
 *
 * <p>All containers below are populated in {@link #init(ServletConfig)}; request handling only
 * reads them.
 */
public class DispatcherServlet extends HttpServlet {

    /** Properties loaded from the file named by the {@code contextConfigLocation} init-param. */
    private final Properties contextConfiguration = new Properties();

    /** Fully-qualified names of every {@code .class} file found under the scanned package. */
    private final List<String> classNames = new ArrayList<String>();

    /**
     * IoC container. Controllers are keyed by their lower-camel simple class name; services by
     * their alias (or lower-camel simple class name) and additionally by the fully-qualified
     * name of each interface they implement.
     */
    private final Map<String, Object> ioc = new HashMap<String, Object>();

    /** Context-relative URL -> handler method. */
    private final Map<String, Method> handlerMapping = new HashMap<String, Method>();

    /** Context-relative URL -> controller instance the handler method is invoked on. */
    private final Map<String, Object> handlerInstances = new HashMap<String, Object>();

    /**
     * Runs the five initialization steps: load config, scan classes, instantiate beans,
     * inject dependencies and build the URL mapping.
     *
     * @param config servlet configuration; must carry the {@code contextConfigLocation} init-param
     * @throws ServletException if any initialization step fails
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        // Let GenericServlet store the config so getServletConfig(), getServletContext() and log() work.
        super.init(config);
        try {
            // 1. Load the configuration file.
            doLoadConfig(config.getInitParameter("contextConfigLocation"));
            // 2. Collect the classes under the configured package.
            doScanner(contextConfiguration.getProperty("scanPackage"));
            // 3. Instantiate annotated classes and cache them in the IoC container.
            doInstance();
            // 4. Inject @MyAutowired fields.
            doAutowired();
            // 5. Map URLs to handler methods.
            initHandlerMapping();
        } catch (RuntimeException e) {
            throw new ServletException("DispatcherServlet initialization failed", e);
        }
        System.out.println("Spring framework is init.");
    }

    /** Handles GET exactly like POST. */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        this.doPost(req, resp);
    }

    /**
     * Dispatches the request to its mapped handler. Any failure is logged and, when the response
     * is not yet committed, answered with HTTP 500 instead of being silently swallowed.
     */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        try {
            doDispatch(req, resp);
        } catch (Exception e) {
            log("Failed to dispatch " + req.getRequestURI(), e);
            if (!resp.isCommitted()) {
                // Drop any partial handler output so the error body is not appended to it.
                resp.resetBuffer();
                resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
                resp.getWriter().write("500 Internal Server Error");
            }
        }
    }

    /**
     * Resolves the handler for the request's context-relative URL, binds its arguments and
     * invokes it. Answers 404 when no handler is mapped.
     *
     * @throws Exception anything thrown while binding arguments or by the handler itself
     */
    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        String url = req.getRequestURI();
        String contextPath = req.getContextPath();
        // The context path is a literal prefix: strip it once instead of treating it as a regex
        // and removing every occurrence anywhere in the URI.
        if (contextPath != null && url.startsWith(contextPath)) {
            url = url.substring(contextPath.length());
        }
        // Mappings are stored with single slashes (see initHandlerMapping), so normalize the same way.
        url = url.replaceAll("/+", "/");

        Method method = handlerMapping.get(url);
        if (method == null) {
            resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
            resp.getWriter().write("404 Not Found");
            return;
        }
        Object[] paramValues = resolveArguments(method, req, resp);
        method.invoke(handlerInstances.get(url), paramValues);
    }

    /**
     * Builds the argument array for a handler method. HttpServletRequest/HttpServletResponse
     * parameters receive the current request/response. Parameters annotated with a non-blank
     * {@link MyRequestParam} receive the matching request parameter (multiple values joined
     * with ","), converted to the parameter type where supported. Everything else is null.
     */
    private Object[] resolveArguments(Method method, HttpServletRequest req, HttpServletResponse resp) {
        Class<?>[] parameterTypes = method.getParameterTypes();
        Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        Map<String, String[]> parameterMap = req.getParameterMap();
        Object[] paramValues = new Object[parameterTypes.length];
        for (int i = 0; i < parameterTypes.length; i++) {
            Class<?> type = parameterTypes[i];
            if (type == HttpServletRequest.class) {
                paramValues[i] = req;
            } else if (type == HttpServletResponse.class) {
                paramValues[i] = resp;
            } else {
                String paramName = requestParamName(parameterAnnotations[i]);
                // Request parameters the handler does not declare are simply ignored.
                if (paramName != null && parameterMap.containsKey(paramName)) {
                    paramValues[i] = convert(joinValues(parameterMap.get(paramName)), type);
                }
            }
        }
        return paramValues;
    }

    /** Returns the trimmed {@link MyRequestParam} name among the annotations, or null if absent/blank. */
    private String requestParamName(Annotation[] annotations) {
        for (Annotation anno : annotations) {
            if (anno instanceof MyRequestParam) {
                String name = ((MyRequestParam) anno).value().trim();
                return "".equals(name) ? null : name;
            }
        }
        return null;
    }

    /** Joins multi-valued request parameters with "," without altering the values themselves. */
    private static String joinValues(String[] values) {
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                joined.append(',');
            }
            joined.append(values[i]);
        }
        return joined.toString();
    }

    /**
     * Converts a raw request value to common simple types (String, int/Integer, long/Long,
     * double/Double, boolean/Boolean). Other types receive the raw String unchanged.
     *
     * @throws NumberFormatException if a numeric target cannot be parsed
     */
    private static Object convert(String value, Class<?> targetType) {
        if (targetType == int.class || targetType == Integer.class) {
            return Integer.valueOf(value.trim());
        }
        if (targetType == long.class || targetType == Long.class) {
            return Long.valueOf(value.trim());
        }
        if (targetType == double.class || targetType == Double.class) {
            return Double.valueOf(value.trim());
        }
        if (targetType == boolean.class || targetType == Boolean.class) {
            return Boolean.valueOf(value.trim());
        }
        return value;
    }

    /**
     * Maps "/" + class-level path + "/" + method-level path (repeated slashes collapsed) to each
     * public {@link MyRequestMapping} method of every {@link MyController} bean. It also records
     * the owning controller so dispatch does not depend on the method's declaring class.
     */
    private void initHandlerMapping() {
        for (Object controller : ioc.values()) {
            Class<?> clazz = controller.getClass();
            if (!clazz.isAnnotationPresent(MyController.class)) {
                continue;
            }
            String controllerUrl = "";
            if (clazz.isAnnotationPresent(MyRequestMapping.class)) {
                controllerUrl = clazz.getAnnotation(MyRequestMapping.class).value();
            }
            // getMethods() yields public methods only, including inherited ones.
            for (Method method : clazz.getMethods()) {
                if (!method.isAnnotationPresent(MyRequestMapping.class)) {
                    continue;
                }
                String methodUrl = method.getAnnotation(MyRequestMapping.class).value();
                // Users may or may not write a leading '/', so always add one and collapse duplicates.
                String url = ("/" + controllerUrl + "/" + methodUrl).replaceAll("/+", "/");
                handlerMapping.put(url, method);
                handlerInstances.put(url, controller);
                System.out.println("Mapped " + url + " --> " + method);
            }
        }
    }

    /**
     * Injects every {@link MyAutowired} field declared directly on each bean's class (superclass
     * fields are not visited). The bean name comes from the annotation value; when blank, the
     * dependency is looked up by the field type's fully-qualified name. If that fails, it is
     * looked up by the type's lower-camel simple name, provided the found bean has a compatible type.
     */
    private void doAutowired() {
        for (Object bean : ioc.values()) {
            for (Field field : bean.getClass().getDeclaredFields()) {
                if (!field.isAnnotationPresent(MyAutowired.class)) {
                    continue;
                }
                String beanName = field.getAnnotation(MyAutowired.class).value().trim();
                Object dependency;
                if ("".equals(beanName)) {
                    dependency = ioc.get(field.getType().getName());
                    if (dependency == null) {
                        // Services are also keyed by lower-camel simple class name (see doInstance),
                        // which lets concrete-class fields be injected too.
                        Object candidate = ioc.get(toLowerCaseFirstChar(field.getType().getSimpleName()));
                        if (field.getType().isInstance(candidate)) {
                            dependency = candidate;
                        }
                    }
                } else {
                    dependency = ioc.get(beanName);
                }
                // Allow writing non-public fields.
                field.setAccessible(true);
                try {
                    field.set(bean, dependency);
                } catch (IllegalAccessException e) {
                    log("Cannot inject field " + field, e);
                }
            }
        }
    }

    /**
     * Instantiates every scanned class annotated with {@link MyController} or {@link MyService}
     * and caches it in the IoC container. A class that cannot be loaded or instantiated is
     * logged and skipped so it does not prevent the remaining classes from being registered.
     */
    private void doInstance() {
        for (String className : classNames) {
            Object instance = createBean(className);
            if (instance == null) {
                continue;
            }
            Class<?> clazz = instance.getClass();
            if (clazz.isAnnotationPresent(MyController.class)) {
                ioc.put(toLowerCaseFirstChar(clazz.getSimpleName()), instance);
            } else {
                registerService(clazz, instance);
            }
        }
    }

    /** Returns a new instance of an annotated bean class, or null if unannotated or on failure. */
    private Object createBean(String className) {
        try {
            Class<?> clazz = Class.forName(className);
            if (!clazz.isAnnotationPresent(MyController.class) && !clazz.isAnnotationPresent(MyService.class)) {
                // Interfaces and plain classes are not beans; never try to instantiate them.
                return null;
            }
            return clazz.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            log("Cannot instantiate " + className, e);
            return null;
        }
    }

    /**
     * Registers a service under its alias (or lower-camel simple class name) and under the
     * fully-qualified name of each implemented interface, so it can be injected by interface.
     *
     * @throws IllegalStateException if another bean is already bound to one of its interfaces
     */
    private void registerService(Class<?> clazz, Object instance) {
        String beanName = clazz.getAnnotation(MyService.class).value().trim();
        if ("".equals(beanName)) {
            beanName = toLowerCaseFirstChar(clazz.getSimpleName());
        }
        ioc.put(beanName, instance);
        for (Class<?> inte : clazz.getInterfaces()) {
            if (ioc.containsKey(inte.getName())) {
                throw new IllegalStateException("The " + inte.getName() + " is exists, please use alias.");
            }
            ioc.put(inte.getName(), instance);
        }
    }

    /** Lower-cases the first character of a class's simple name (null/empty returned unchanged). */
    private String toLowerCaseFirstChar(String simpleName) {
        if (simpleName == null || simpleName.length() == 0) {
            return simpleName;
        }
        char[] chars = simpleName.toCharArray();
        // Character.toLowerCase is a no-op for already-lowercase or non-letter characters,
        // unlike blindly adding 32 to the code point.
        chars[0] = Character.toLowerCase(chars[0]);
        return String.valueOf(chars);
    }

    /**
     * Recursively records the fully-qualified name (e.g. com.example.UserController) of every
     * {@code .class} file found under the given package directory on the classpath.
     *
     * <p>NOTE(review): only exploded directories are supported (not jars), and
     * {@code URL.getFile()} is not URL-decoded, so paths containing spaces are not handled.
     *
     * @throws IllegalStateException if {@code scanPackage} is unset or not found on the classpath
     */
    private void doScanner(String scanPackage) {
        if (scanPackage == null || "".equals(scanPackage.trim())) {
            throw new IllegalStateException("Property 'scanPackage' is not configured");
        }
        // ClassLoader resource names are relative to the classpath root: no leading '/'.
        URL url = this.getClass().getClassLoader().getResource(scanPackage.replace('.', '/'));
        if (url == null) {
            throw new IllegalStateException("Package not found on classpath: " + scanPackage);
        }
        File[] files = new File(url.getFile()).listFiles();
        if (files == null) {
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + fileName);
            } else if (fileName.endsWith(".class")) {
                String simpleName = fileName.substring(0, fileName.length() - ".class".length());
                // Class.forName needs "package.ClassName" during instantiation.
                classNames.add(scanPackage + "." + simpleName);
            }
        }
    }

    /**
     * Loads the properties file at the given classpath location into contextConfiguration.
     *
     * @throws IllegalStateException if the location is unset, missing or unreadable
     */
    private void doLoadConfig(String contextConfigLocation) {
        if (contextConfigLocation == null) {
            throw new IllegalStateException("Init-param 'contextConfigLocation' is not set");
        }
        InputStream in = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation);
        if (in == null) {
            throw new IllegalStateException("Config file not found on classpath: " + contextConfigLocation);
        }
        try {
            contextConfiguration.load(in);
        } catch (IOException e) {
            throw new IllegalStateException("Cannot read config file " + contextConfigLocation, e);
        } finally {
            try {
                in.close();
            } catch (IOException e) {
                log("Cannot close config file " + contextConfigLocation, e);
            }
        }
    }
}
自定义Spring注解类
/**
 * Marks a field that the dispatcher should fill with a bean from its IoC container.
 *
 * <p>When {@link #value()} is left empty, DispatcherServlet looks the dependency up by the
 * field type's fully-qualified name.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface MyAutowired {

    /** Explicit name of the bean to inject; empty means "derive from the field type". */
    String value() default "";
}
/**
 * Marks a class as a web controller whose {@code MyRequestMapping} methods are exposed as
 * request handlers.
 *
 * <p>The DispatcherServlet in these notes does not read {@link #value()}; it names controller
 * beans after their lower-camel simple class name.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MyController {

    /** Optional bean name (currently unused by the dispatcher). */
    String value() default "";
}
/**
 * Marks a class as a service bean to be instantiated into the dispatcher's IoC container.
 *
 * <p>DispatcherServlet consults a non-empty {@link #value()} as an alias for the bean name,
 * which helps when same-named classes live in different packages.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface MyService {

    /** Optional bean alias; empty means "use the default name". */
    String value() default "";
}
/**
 * Marks a class as a repository (data-access) component.
 *
 * <p>NOTE(review): the name is a misspelling of "Repository"; it is kept as-is because callers
 * may already reference it. The dispatcher shown in these notes never processes this annotation.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface MyRespository {

    /** Optional bean name. */
    String value() default "";
}
/**
 * Maps a URL path to a controller (class level) or to a handler method (method level).
 *
 * <p>The dispatcher builds the full URL as "/" + class path + "/" + method path and collapses
 * repeated slashes, so a leading '/' in {@link #value()} is optional.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({
        ElementType.TYPE,
        ElementType.METHOD
})
public @interface MyRequestMapping {

    /** Path segment contributed by the annotated class or method. */
    String value() default "";
}
/**
 * Binds a handler-method parameter to the request parameter named by {@link #value()}.
 *
 * <p>A blank name is ignored by the dispatcher, which leaves such a parameter unbound.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface MyRequestParam {

    /** Name of the request parameter to bind. */
    String value() default "";
}
用于测试用户请求访问的 Controller 类
/**
 * Sample controller used to exercise the hand-written dispatcher. It is marked as a controller
 * bean, and its single handler maps to {@code /user/query}.
 */
@MyController
@MyRequestMapping("/user")
public class UserController {

    /** Marked for injection by the container; not used by {@link #query}. */
    @MyAutowired
    UserService userService;

    /**
     * Writes a greeting built from the {@code name} request parameter to the response.
     *
     * @param req  the current request (unused)
     * @param resp the response the greeting is written to
     * @param name value bound from the {@code name} request parameter; when {@code null}, the
     *             literal text "null" ends up in the greeting
     */
    @MyRequestMapping("/query")
    public void query(HttpServletRequest req, HttpServletResponse resp, @MyRequestParam("name") String name) {
        final StringBuilder greeting = new StringBuilder("Hello, I am ").append(name);
        try {
            resp.getWriter().write(greeting.toString());
        } catch (IOException ioe) {
            ioe.printStackTrace();
        }
    }
}