过滤器验证Token

之前的 Token 验证实现一直不太理想,今天对它做一次优化;旧项目的代码就不贴出来了。

第一步

首先将过滤器注册到 Bean 容器,拦截相关请求。本来想通过实现 ApplicationContextAware 接口的 setApplicationContext 方法获取 Spring 上下文,但容器启动时报错,才发现父类 WebMvcConfigurationSupport 已经实现了该接口,所以这里直接调用 getApplicationContext() 取用。new Filter 时要通过构造方法把 applicationContext 传入,以便过滤器在初始化时使用。

package com.mbuyy.config;

import com.mbuyy.common.CommonInterceptor;
import com.mbuyy.filter.LoginFilter;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurationSupport;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * Spring MVC configuration for the application.
 *
 * <p>Registers a UTF-8 {@link StringHttpMessageConverter} bean, applies {@link CommonInterceptor}
 * to every path, and registers {@link LoginFilter} for the token-protected URL trees
 * ({@code /payment/*}, {@code /merchant/*}, {@code /mobile/*}).
 *
 * <p>NOTE(review): because this class extends {@link WebMvcConfigurationSupport}, Spring Boot's
 * MVC auto-configuration backs off. Confirm that the {@code responseBodyConverter} bean is really
 * used by MVC. It may need to be added in {@code configureMessageConverters} instead.
 */
@Configuration
public class WebConfig extends WebMvcConfigurationSupport {

    /**
     * Creates a String message converter that writes response bodies as UTF-8.
     *
     * @return the UTF-8 string converter
     */
    @Bean
    public HttpMessageConverter<String> responseBodyConverter() {
        return new StringHttpMessageConverter(StandardCharsets.UTF_8);
    }

    /**
     * Adds {@link CommonInterceptor} for all request paths, then delegates to the superclass.
     *
     * @param registry the MVC interceptor registry
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new CommonInterceptor()).addPathPatterns("/**");
        super.addInterceptors(registry);
    }

    /**
     * Registers the token-verifying {@link LoginFilter}.
     *
     * @return the filter registration covering the payment, merchant and mobile URL trees
     */
    @Bean
    public FilterRegistrationBean filterRegist() {
        FilterRegistrationBean frBean = new FilterRegistrationBean();
        // WebMvcConfigurationSupport already implements ApplicationContextAware, so the context is
        // available here. LoginFilter needs it in init() to look up @TokenVerify beans.
        frBean.setFilter(new LoginFilter(this.getApplicationContext()));
        frBean.addUrlPatterns("/payment/*", "/merchant/*", "/mobile/*");
        return frBean;
    }
}

第二步

编写过滤器 LoginFilter(通过构造器注入 ApplicationContext)

package com.mbuyy.filter;

import com.mbuyy.annotation.TokenIgnoreUtils;
import com.mbuyy.annotation.TokenVerify;
import com.mbuyy.util.JWTUtils;
import com.mbuyy.util.RedisUtils;
import com.mbuyy.util.ValidateUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.Enumeration;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;

import static com.mbuyy.constants.Constants.*;

public class LoginFilter implements Filter {
    private static final Logger logger = LoggerFactory.getLogger(LoginFilter.class);

    private TokenIgnoreUtils tokenIgnoreUtils;
    private ApplicationContext applicationContext;

    Pattern pattern = Pattern.compile("^[-\\+]?[\\d]*$");

    public LoginFilter() {
    }

    /**
     * 构造器注入spring的ApplicationContext上下文对象
     * @param applicationContext
     */
    public LoginFilter(ApplicationContext applicationContext) {
        this.applicationContext = applicationContext;
    }

    @Override
    public void destroy() {

    }

    @Override
    public void init(FilterConfig arg0) {
        tokenIgnoreUtils = new TokenIgnoreUtils();

        // 获取所有添加了TokenVerifyAnnotion注解的类,注册并添加
        Map<String, Object> beansWithAnnotationMap = this.applicationContext.getBeansWithAnnotation(TokenVerify.class);

        for (Object tokenVerifyBean : beansWithAnnotationMap.keySet()){
            tokenIgnoreUtils.registerController(tokenVerifyBean.getClass());
        }
    }


    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        //拿到对象
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse rep = (HttpServletResponse) response;

        setCrossDomian(req, rep);

        // 获取请求路径
        String requestURI = req.getRequestURI();

        // 0客户端token,1商户端token,2代驾端token
        int tokenType = getTokenType(req);
        String token = getToken(req);

        req.setAttribute(TOKEN_TYPE, tokenType);

        System.out.println("token:" + token);
        System.out.println("tokenType:" + tokenType);
        // TODO: 2019/1/10 这里有万能token,本地测试使用,别忘了去掉!!!!
//        if (ValidateUtils.isNotEmpty(token) && pattern.matcher(token).matches()) {
//            req.setAttribute(TOKEN_ID, Integer.parseInt(token));
//            chain.doFilter(request, response);
//            return;
//        }

        try {
            verify(req, token, tokenType,requestURI);
        } catch (Exception e) {
            e.printStackTrace();
            failed(rep, e.getMessage());
            return;
        }

        chain.doFilter(request, response);
    }

    private String getToken(HttpServletRequest req) {
        String token = "";
        if (ValidateUtils.isNotEmpty(req.getHeader("token_merchant"))){
            token = req.getHeader("token_merchant");
        } else {
            token = req.getHeader("token");
        }
        return token;
    }

    /**
     * 判断是来自哪个端的请求,哪个token有取哪一个,不同的token之间互斥
     * @param req
     * @return
     */
    private int getTokenType(HttpServletRequest req) {
        int tokenType = 0;
        Enumeration<String> headerNames = req.getHeaderNames();
        String headName = headerNames.nextElement();
        while (ValidateUtils.isNotEmpty(headName)){
            if (headName.equals("token_merchant")){
                tokenType = 1;
                break;
            } else if (headName.equals("token")){
                tokenType = 0;
                break;
            }
            headName = headerNames.nextElement();
        }
        return tokenType;
    }


    /**
     * 验证token
     * @param req
     * @param token
     * @param tokenType
     * @param requestURI
     */
    private void verify(HttpServletRequest req, String token, int tokenType, String requestURI) throws Exception {
        //如果是登录就不需要验证
        boolean isNeedToken = tokenIgnoreUtils.startCheck(requestURI);
        if (!isNeedToken) {
            //如果不需要Token
            if (ValidateUtils.isNotEmpty(token)) {
                // 解析token
                parseToken(req, token, tokenType);
            }
        }else{
            if (ValidateUtils.isNotEmpty(token)) {
                // 解析token
                parseToken(req, token, tokenType);
            }else{
                //token没传 直接失败
                throw new RuntimeException("请输入验证Token");
            }

        }
    }

    /**
     * 设置跨域问题
     * @param req
     * @param rep
     * @throws UnsupportedEncodingException
     */
    private void setCrossDomian(HttpServletRequest req, HttpServletResponse rep) throws UnsupportedEncodingException {
        req.setCharacterEncoding("utf-8");

        // 设置允许跨域访问的域,*表示支持所有的来源
        rep.setHeader("Access-Control-Allow-Origin", "*"); //Access-Control-Allow-Origin
        System.out.println(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>");
        // 设置允许跨域访问的方法
        rep.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE");
        rep.setHeader("Access-Control-Max-Age", "3600");
        rep.setHeader("Access-Control-Allow-Headers", "*");
    }

    /**
     * 解析token得到tokenId
     *
     * @param req
     * @param token
     * @param tokenType
     * @return
     * @throws Exception
     */
    private void parseToken(HttpServletRequest req, String token, int tokenType) throws Exception {
        //存在token使用JWT鉴权,鉴权失败以无token处理
        Map<String, String> tokenMap;
        try {
            tokenMap = JWTUtils.getInstance().verifyToken(token);
        } catch (JWTUtils.TokenException e) {
            throw new RuntimeException("鉴权异常");
        }

        // 根据token获取token_id
        Integer tokenId = Integer.parseInt(tokenMap.get(TOKEN_ID));
        if (tokenId == null){
            throw new RuntimeException("TokenId获取失败");
        }

        switch (tokenType){
            case 1 :
                parseMerchantToken(tokenId, token);
                break;
            default:
                parseCustomerToken(tokenId, token);
                break;
        }


        //放行
        //添加token里的user_id
        req.setAttribute(TOKEN_ID, tokenId);

        logger.info("token_id = " + tokenId);

    }

    private void parseCustomerToken(Integer tokenId, String token) throws Exception {
        logger.info("解析客户端token:" + token);

        // redis缓存中是否存在token
        if(!RedisUtils.exists(TOKEN + tokenId)){
            throw new RuntimeException("redis match error");
        }

        String redisToken = RedisUtils.get(TOKEN + tokenId);
        if (!Objects.equals(token, redisToken) && !("-1".equals(redisToken))) {
            throw new RuntimeException("Token Error");
        }


//        int result = userPunishService.verifyUser(tokenId);
//        switch (result){
//            case 1: throw new RuntimeException("Account has been deleted");
//            case 2: throw new RuntimeException("The account has been disabled");
////            case 3: throw new RuntimeException("Account has been deleted"); break;
//        }
    }


    private void parseMerchantToken(Integer tokenId, String token) throws Exception {
        logger.info("解析商户端token:" + token);

        // redis缓存中是否存在token
        if(!RedisUtils.exists(TOKEN_SHOP + tokenId)){
            throw new RuntimeException("匹配异常");
        }

        String redisToken = RedisUtils.get(TOKEN_SHOP + tokenId);
        if (!Objects.equals(token, redisToken) && !("-1".equals(redisToken))) {
            throw new RuntimeException("Token错误");
        }

//        int result = userPunishService.verifyMerchant(tokenId);
//        switch (result){
//            case 1: throw new RuntimeException("Account has been deleted");
//            case 2: throw new RuntimeException("The account has been disabled");
////            case 3: throw new RuntimeException("Account has been deleted"); break;
//        }
    }

    private void failed(HttpServletResponse rep, String msg) throws IOException {
        PrintWriter w = rep.getWriter();
        w.write("{\"status\": 401,\"msg\": \"" + msg + "\"}");
        w.flush();
        w.close();
    }
}

第三步

在业务代码中获取当前登录用户的 ID(过滤器验证成功后已将其写入 request 的 TOKEN_ID 属性):

Integer currentUserId = (Integer) request.getAttribute(TOKEN_ID);

笔者水平有限,要走的路还很长,如有不妥之处,还请各位不吝赐教。

 

上一篇:手写Pascal解释器(一)


下一篇:render与vue组件和注册