Writing a Simple SpringMVC by Hand

I. Preface

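This is a small demo I wrote that implements a simple version of SpringMVC's core functionality.

The intended behavior is this: you pass a name parameter through the browser's address bar, and the page prints "my name is " followed by that name. SpringMVC itself is not used. Instead, I define a few annotations of my own and implement the core functionality of DispatcherServlet. Working through this demo deepens my understanding of the SpringMVC source code.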

First, here is what it looks like in action:

(Screenshot: browser output when the name parameter is passed)

(Screenshot: browser output when no parameter is passed)

II. DispatcherServlet Workflow

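  1. Load the configuration file
  2. Scan all relevant classes
  3. Instantiate all relevant classes
  4. Inject dependencies (autowiring)
  5. Initialize the HandlerMapping
  6. Wait for requests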
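Steps 3 and 4 together form a tiny IoC container. The following is a minimal, self-contained sketch of just those two steps. It uses plain reflection and needs no servlet container. The names `MiniContainer`, `Greeter`, `GreeterImpl` and `Client` and the nested `Service`/`Autowired` annotations are hypothetical stand-ins. For simplicity, beans are keyed by fully qualified class name rather than by the lower-cased name the article's code uses.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class MiniContainer {
    // Hypothetical stand-ins for the article's @Service and @Autowired annotations
    @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) @interface Service {}
    @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.FIELD) @interface Autowired {}

    interface Greeter { String greet(String name); }

    @Service
    static class GreeterImpl implements Greeter {
        public String greet(String name) { return "my name is " + name; }
    }

    @Service
    static class Client {
        @Autowired Greeter greeter; // injected by type (the interface name)
    }

    // Step 3: instantiate @Service classes; register each under its class name and every interface name
    static Map<String, Object> instantiate(List<Class<?>> classes) throws Exception {
        Map<String, Object> ioc = new HashMap<>();
        for (Class<?> clazz : classes) {
            if (!clazz.isAnnotationPresent(Service.class)) {
                continue;
            }
            Object instance = clazz.getDeclaredConstructor().newInstance();
            ioc.put(clazz.getName(), instance);
            for (Class<?> itf : clazz.getInterfaces()) {
                ioc.put(itf.getName(), instance);
            }
        }
        return ioc;
    }

    // Step 4: fill every @Autowired field with the bean registered under the field's type name
    static void autowire(Map<String, Object> ioc) throws IllegalAccessException {
        for (Object bean : ioc.values()) {
            for (Field field : bean.getClass().getDeclaredFields()) {
                if (!field.isAnnotationPresent(Autowired.class)) {
                    continue;
                }
                field.setAccessible(true);
                field.set(bean, ioc.get(field.getType().getName()));
            }
        }
    }

    // Runs steps 3 and 4 and returns the wired Client bean
    static Client bootstrap() {
        try {
            Map<String, Object> ioc = instantiate(Arrays.asList(GreeterImpl.class, Client.class));
            autowire(ioc);
            return (Client) ioc.get(Client.class.getName());
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(bootstrap().greeter.greet("Tom")); // my name is Tom
    }
}
```

Registering each instance under its interface names is what lets a field typed as an interface be injected by type. The article's doInstance relies on the same trick for `IDemoService`.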

III. Code Walkthrough

1. First, the dependencies in the POM file:

```xml
<dependencies>
    <dependency>
        <groupId>javax.servlet</groupId>
        <artifactId>servlet-api</artifactId>
        <version>2.5</version>
    </dependency>
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-lang3</artifactId>
        <version>3.10</version>
    </dependency>
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <version>1.18.12</version>
    </dependency>
    <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-core</artifactId>
        <version>1.2.3</version>
    </dependency>
    <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-classic</artifactId>
        <version>1.2.3</version>
    </dependency>
</dependencies>
```

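There are very few dependencies and no Spring dependency at all. The main one is the Servlet API. commons-lang3 provides StringUtils, Lombok provides @Slf4j, and Logback handles logging.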

2. Configuration files:

2.1. The application.properties file:

```properties
scanPackage=com.qunar.framework.demo
```

This tells the DispatcherServlet which package to scan for classes.
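The DispatcherServlet reads this file with `java.util.Properties` and turns the package name into a directory path to scan. The following is a minimal sketch of those two steps. It assumes the properties text is supplied as a string rather than read from the real file, and the class and method names are hypothetical.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.Properties;

public class ScanPackageConfig {
    // Parses properties text with Properties.load, the same call loadConfig makes on the file's InputStream
    static String readScanPackage(String propertiesText) {
        Properties props = new Properties();
        try {
            props.load(new StringReader(propertiesText));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return props.getProperty("scanPackage");
    }

    // Converts a package name into the '/'-separated resource path a ClassLoader expects
    static String toResourcePath(String scanPackage) {
        return scanPackage.replace('.', '/');
    }

    public static void main(String[] args) {
        String pkg = readScanPackage("scanPackage=com.qunar.framework.demo\n");
        System.out.println(pkg);                 // com.qunar.framework.demo
        System.out.println(toResourcePath(pkg)); // com/qunar/framework/demo
    }
}
```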

2.2. The web.xml file:

```xml
<!DOCTYPE web-app PUBLIC
        "-//Sun Microsystems, Inc.//DTD Web Application 2.3//EN"
        "http://java.sun.com/dtd/web-app_2_3.dtd">
<web-app>
    <display-name>MySpringMVC</display-name>
    <servlet>
        <servlet-name>mvc</servlet-name>
        <servlet-class>com.qunar.framework.webmvc.DispatcherServlet</servlet-class>
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>/application.properties</param-value>
        </init-param>
        <load-on-startup>1</load-on-startup>
    </servlet>
    <servlet-mapping>
        <servlet-name>mvc</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
</web-app>
```

3. The directory structure of the whole project:

(Screenshot: project directory structure)

4. Custom annotations:

@Controller:

```java
import java.lang.annotation.*;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Controller {
    String value() default "";
}
```

@Service:

```java
import java.lang.annotation.*;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Service {
    String value() default "";
}
```

@Autowired:

```java
import java.lang.annotation.*;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Autowired {
    String value() default "";
}
```

@RequestMapping:

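```java
import java.lang.annotation.*;

// Used on controller classes (base URL) and on handler methods (method URL)
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestMapping {
    String value() default "";
}
```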

@RequestParam:

```java
import java.lang.annotation.*;

@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestParam {
    String value() default "";
}
```
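All of these annotations use `RetentionPolicy.RUNTIME` because the DispatcherServlet discovers them through reflection while the application is running. The following sketch shows why that matters. An annotation with `CLASS` retention is recorded in the class file but is invisible to `isAnnotationPresent`. All names in it (`RetentionDemo`, `RuntimeMarker`, `ClassOnlyMarker`, `Annotated`) are hypothetical.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

public class RetentionDemo {
    // Hypothetical annotation kept at run time, like the article's @Controller
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.RUNTIME)
    @interface RuntimeMarker { String value() default ""; }

    // Hypothetical annotation recorded in the class file but not visible to reflection
    @Target(ElementType.TYPE)
    @Retention(RetentionPolicy.CLASS)
    @interface ClassOnlyMarker {}

    @RuntimeMarker("demoController")
    @ClassOnlyMarker
    static class Annotated {}

    static boolean seesRuntime() { return Annotated.class.isAnnotationPresent(RuntimeMarker.class); }

    static boolean seesClassOnly() { return Annotated.class.isAnnotationPresent(ClassOnlyMarker.class); }

    // Reads the annotation's value(), as doInstance does to pick a bean name
    static String beanName() { return Annotated.class.getAnnotation(RuntimeMarker.class).value(); }

    public static void main(String[] args) {
        System.out.println(seesRuntime());   // true
        System.out.println(seesClassOnly()); // false
        System.out.println(beanName());      // demoController
    }
}
```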

5. The custom Handler class:

```java
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.lang3.StringUtils;
// RequestParam is the custom annotation defined above (its package is not shown in the article)

public class Handler {
    protected Object controller;
    protected Method method;
    protected Pattern pattern;
    protected Map<String, Integer> paramIndexMap;

    public Handler(Object controller, Method method, Pattern pattern) {
        this.controller = controller;
        this.method = method;
        this.pattern = pattern;
        this.paramIndexMap = new HashMap<>();
        putParamIndexMapping(method);
    }

    private void putParamIndexMapping(Method method) {
        // Record the position of each parameter annotated with @RequestParam
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            for (Annotation annotation : annotations[i]) {
                if (annotation instanceof RequestParam) {
                    String paramName = ((RequestParam) annotation).value();
                    if (!StringUtils.isBlank(paramName)) {
                        paramIndexMap.put(paramName, i);
                    }
                }
            }
        }
        // Record the positions of the HttpServletRequest and HttpServletResponse parameters
        Class<?>[] paramTypes = method.getParameterTypes();
        for (int i = 0; i < paramTypes.length; i++) {
            Class<?> paramType = paramTypes[i];
            if (paramType == HttpServletRequest.class || paramType == HttpServletResponse.class) {
                paramIndexMap.put(paramType.getName(), i);
            }
        }
    }
}
```
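The parameter index map is what lets doDispatch put each query value into the right argument slot before calling `Method.invoke`. The following minimal sketch runs that round trip without a servlet container. It assumes the query parameters arrive as a plain `Map<String, String>`, and it leaves out the request/response slots. `ParamIndexDemo`, its nested `RequestParam` and `Controller`, and the `age` parameter are hypothetical.

```java
import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

public class ParamIndexDemo {
    // Hypothetical stand-in for the article's @RequestParam
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.PARAMETER)
    @interface RequestParam { String value() default ""; }

    // Hypothetical controller with two request parameters
    public static class Controller {
        public String get(@RequestParam("name") String name, @RequestParam("age") Integer age) {
            return "my name is " + name + ", age " + age;
        }
    }

    static Method getMethod() {
        try {
            return Controller.class.getMethod("get", String.class, Integer.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    // Maps each @RequestParam name to its parameter position, like Handler.putParamIndexMapping
    static Map<String, Integer> indexParams(Method method) {
        Map<String, Integer> indexMap = new HashMap<>();
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            for (Annotation a : annotations[i]) {
                if (a instanceof RequestParam) {
                    indexMap.put(((RequestParam) a).value(), i);
                }
            }
        }
        return indexMap;
    }

    // Builds the argument array from query parameters and invokes the method, like doDispatch
    static Object dispatch(Object target, Method method, Map<String, String> query) {
        Map<String, Integer> indexMap = indexParams(method);
        Class<?>[] types = method.getParameterTypes();
        Object[] args = new Object[types.length];
        for (Map.Entry<String, String> e : query.entrySet()) {
            Integer i = indexMap.get(e.getKey());
            if (i == null) {
                continue; // the method does not declare this parameter
            }
            args[i] = types[i] == Integer.class ? Integer.valueOf(e.getValue()) : e.getValue();
        }
        try {
            return method.invoke(target, args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Map<String, String> query = new HashMap<>();
        query.put("name", "Tom");
        query.put("age", "18");
        System.out.println(dispatch(new Controller(), getMethod(), query)); // my name is Tom, age 18
    }
}
```

Parameters that are missing from the query stay `null`. That is why the article's page prints "my name is null" when no name is passed.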

6. The custom DispatcherServlet:

```java
package com.qunar.framework.webmvc;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.ServletConfig;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
// The custom annotations (Controller, Service, Autowired, RequestMapping) and Handler
// must also be imported from your own packages; the article does not show their package names.

@Slf4j
public class DispatcherServlet extends HttpServlet {
    private static final long serialVersionUID = 1L;
    private Properties contextConfig = new Properties();
    private List<String> classNames = new ArrayList<>();
    private Map<String, Object> iocMap = new HashMap<>();
    private List<Handler> handlerMapping = new ArrayList<>();

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        this.doPost(req, resp);
    }

    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        // Step 6: handle incoming requests
        try {
            doDispatch(req, resp);
        } catch (Exception exception) {
            resp.getWriter().write("500 Exception");
            log.error("500 Exception", exception);
        }
    }

    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
        Handler handler = getHandler(req);
        if (handler == null) {
            // No mapping matched: 404
            log.info("404 Not Found");
            resp.getWriter().write("404 Not Found");
            return;
        }
        // Parameter types of the target method
        Class<?>[] parameterTypes = handler.method.getParameterTypes();
        // Argument values to pass to the method, filled in below
        Object[] parameterValues = new Object[parameterTypes.length];
        Map<String, String[]> parameterMap = req.getParameterMap();
        for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
            // Join multi-valued parameters with commas
            String value = String.join(",", entry.getValue());
            log.info(value);
            // Only fill parameters that the method declares via @RequestParam
            if (!handler.paramIndexMap.containsKey(entry.getKey())) {
                continue;
            }
            Integer index = handler.paramIndexMap.get(entry.getKey());
            parameterValues[index] = convert(parameterTypes[index], value);
        }
        // Pass the request and response objects only if the method declares them
        Integer reqIndex = handler.paramIndexMap.get(HttpServletRequest.class.getName());
        if (reqIndex != null) {
            parameterValues[reqIndex] = req;
        }
        Integer respIndex = handler.paramIndexMap.get(HttpServletResponse.class.getName());
        if (respIndex != null) {
            parameterValues[respIndex] = resp;
        }
        handler.method.invoke(handler.controller, parameterValues);
    }

    private Object convert(Class<?> parameterType, String value) {
        if (parameterType == Integer.class) {
            return Integer.valueOf(value);
        }
        return value;
    }

    private Handler getHandler(HttpServletRequest req) {
        if (handlerMapping.isEmpty()) {
            return null;
        }
        String requestURI = req.getRequestURI();
        String contextPath = req.getContextPath();
        requestURI = requestURI.replace(contextPath, "").replaceAll("/+", "/");
        for (Handler handler : handlerMapping) {
            Matcher matcher = handler.pattern.matcher(requestURI);
            if (matcher.matches()) {
                return handler;
            }
        }
        return null;
    }

    @Override
    public void init(ServletConfig config) {
        // Startup begins here
        // Step 1: load the configuration file
        loadConfig(config.getInitParameter("contextConfigLocation"));
        // Step 2: scan the relevant classes
        doScanner(contextConfig.getProperty("scanPackage"));
        // Step 3: instantiate the relevant classes
        try {
            doInstance();
        } catch (Exception exception) {
            log.error("Execute doInstance method fail.", exception);
        }
        // Step 4: inject dependencies
        doAutowired();
        // Step 5: initialize the HandlerMapping
        initHandlerMapping();
    }

    private void initHandlerMapping() {
        if (iocMap.isEmpty()) {
            return;
        }
        for (Map.Entry<String, Object> entry : iocMap.entrySet()) {
            Class<?> clazz = entry.getValue().getClass();
            if (!clazz.isAnnotationPresent(Controller.class)) {
                continue;
            }
            String baseUrl = "";
            if (clazz.isAnnotationPresent(RequestMapping.class)) {
                baseUrl = clazz.getAnnotation(RequestMapping.class).value();
            }
            // Scan all public methods
            for (Method method : clazz.getMethods()) {
                if (!method.isAnnotationPresent(RequestMapping.class)) {
                    continue;
                }
                RequestMapping requestMapping = method.getAnnotation(RequestMapping.class);
                String regex = ("/" + baseUrl + requestMapping.value()).replaceAll("/+", "/");
                Pattern pattern = Pattern.compile(regex);
                handlerMapping.add(new Handler(entry.getValue(), method, pattern));
                log.info("Mapping: {}.{}", regex, method);
            }
        }
    }

    private void doAutowired() {
        if (iocMap.isEmpty()) {
            return;
        }
        // Walk every bean and assign its @Autowired fields
        for (Map.Entry<String, Object> entry : iocMap.entrySet()) {
            Field[] fields = entry.getValue().getClass().getDeclaredFields();
            for (Field field : fields) {
                if (!field.isAnnotationPresent(Autowired.class)) {
                    continue;
                }
                // Use the explicit bean name if given, otherwise the field's type name
                String beanName = field.getAnnotation(Autowired.class).value().trim();
                if (StringUtils.isBlank(beanName)) {
                    beanName = field.getType().getName();
                }
                field.setAccessible(true);
                try {
                    field.set(entry.getValue(), iocMap.get(beanName));
                } catch (IllegalAccessException e) {
                    log.error("AutoWired fail, beanName: {}", beanName, e);
                }
            }
        }
    }

    private void doInstance() throws Exception {
        if (classNames.isEmpty()) {
            return;
        }
        for (String className : classNames) {
            Class<?> clazz = Class.forName(className);
            // Use the custom name if one is given; otherwise use the fully qualified class name in
            // lower case (the usual "lower-case first letter" convention is not applied here)
            if (clazz.isAnnotationPresent(Controller.class)) {
                String beanName = clazz.getAnnotation(Controller.class).value();
                if (StringUtils.isBlank(beanName)) {
                    beanName = clazz.getName().toLowerCase();
                }
                iocMap.put(beanName, clazz.getDeclaredConstructor().newInstance());
            } else if (clazz.isAnnotationPresent(Service.class)) {
                String beanName = clazz.getAnnotation(Service.class).value();
                if (StringUtils.isBlank(beanName)) {
                    beanName = clazz.getName().toLowerCase();
                }
                Object instance = clazz.getDeclaredConstructor().newInstance();
                iocMap.put(beanName, instance);
                // Also register the instance under each interface it implements, for injection by type
                for (Class<?> clazzInterface : clazz.getInterfaces()) {
                    iocMap.put(clazzInterface.getName(), instance);
                }
            }
        }
    }

    private void doScanner(String scanPackage) {
        // ClassLoader.getResource expects a '/'-separated path without a leading "/"
        URL url = this.getClass().getClassLoader().getResource(scanPackage.replace('.', '/'));
        File classDir = new File(url.getFile());
        for (File file : classDir.listFiles()) {
            if (file.isDirectory()) {
                doScanner(scanPackage + "." + file.getName());
            } else if (file.getName().endsWith(".class")) {
                String className = scanPackage + "." + file.getName().replace(".class", "");
                classNames.add(className);
            }
        }
    }

    private void loadConfig(String location) {
        // Class.getResourceAsStream treats a leading "/" as the root of the classpath
        try (InputStream inputStream = this.getClass().getResourceAsStream(location)) {
            contextConfig.load(inputStream);
        } catch (IOException e) {
            log.error("Load fail, location: {}", location, e);
        }
    }
}
```

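This is the core class and does the job of SpringMVC's DispatcherServlet. At startup it loads the configuration, builds the bean container and the URL mappings. At run time it dispatches each request to the matching controller method.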
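URL matching comes down to two string operations. initHandlerMapping joins the class-level and method-level paths into a regex, and getHandler normalizes the incoming URI before matching. The following sketch isolates both operations. The context path `/myapp` is an assumed example, and `UrlMatchDemo` is a hypothetical name.

```java
import java.util.regex.Pattern;

public class UrlMatchDemo {
    // Builds the handler pattern as initHandlerMapping does: "/" + baseUrl + methodUrl, repeated slashes collapsed
    static Pattern compile(String baseUrl, String methodUrl) {
        return Pattern.compile(("/" + baseUrl + methodUrl).replaceAll("/+", "/"));
    }

    // Normalizes the request URI as getHandler does: drop the context path, collapse repeated slashes
    static String normalize(String requestUri, String contextPath) {
        return requestUri.replace(contextPath, "").replaceAll("/+", "/");
    }

    public static void main(String[] args) {
        Pattern pattern = compile("/demo", "/get");
        String path = normalize("/myapp//demo/get", "/myapp");
        System.out.println(pattern.pattern());               // /demo/get
        System.out.println(path);                            // /demo/get
        System.out.println(pattern.matcher(path).matches()); // true
    }
}
```

There is one caveat worth noting. `String.replace` removes every occurrence of the context path, not just the prefix. A context path of `/demo` would therefore also strip the controller's `/demo` segment, turning `/demo/demo/get` into `/get`. Removing only the leading prefix (for example, `requestUri.substring(contextPath.length())`) avoids this. Also, because mapping values are compiled as regexes, characters such as `.` act as metacharacters.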

7. Now it is time to verify that the homemade SpringMVC works, so I wrote a service and a controller:

7.1 service:

```java
// @Service is required: doInstance only instantiates classes annotated with @Controller or @Service
@Service
public class DemoServiceImpl implements IDemoService {
    @Override
    public String get(String name) {
        return "my name is " + name;
    }
}
```

7.2 controller:

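```java
import java.io.IOException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
// Controller, RequestMapping, Autowired and RequestParam are the custom annotations above;
// IDemoService is the service interface (its source is not shown in the article)

@Controller
@RequestMapping("/demo")
@Slf4j
public class DemoController {
    @Autowired
    IDemoService service;

    @RequestMapping("/get")
    public void get(HttpServletRequest req, HttpServletResponse resp, @RequestParam("name") String name) {
        String res = service.get(name);
        try {
            resp.setContentType("text/html;charset=UTF-8");
            resp.getWriter().println(res);
        } catch (IOException e) {
            log.info(e.getMessage());
            e.printStackTrace();
        }
    }
}
```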

Together with the screenshots at the beginning, this confirms that the homemade SpringMVC works.

IV. Closing Thoughts

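This demo implements only the most basic SpringMVC functionality. Its purpose is to deepen my understanding of how SpringMVC maps requests to handlers. The real SpringMVC is, of course, far more complex.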

GitHub repository for the demo: https://github.com/Happy-Ape/Spring
