通过Filter,对request 和 response 进行处理

1、首先Filter的实现基础

/**
 * Minimal filter body: hands the request and response it received, unchanged,
 * to the next element of the chain.
 */
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
                     FilterChain filterChain) throws IOException, ServletException {
    // Whatever objects are passed here are the ones the rest of the chain works with.
    filterChain.doFilter(servletRequest, servletResponse);
}
用户请求到达,经过Filter到后台,后台处理完成,到Filter,返回给用户
doFilter(servletRequest, servletResponse)方法一直传递servletRequest,servletResponse

Controller是怎么获取参数和返回参数的呢?

/**
 * Abridged excerpt of the servlet request interface, listing only the
 * parameter accessors discussed in this article.
 */
public interface ServletRequest {

    /** Single-value lookup; the argument is the parameter name. */
    String getParameter(String var1);

    /** Multi-value lookup; the argument is the parameter name. */
    String[] getParameterValues(String var1);

    /** All parameters at once, keyed by name, each with an array of values. */
    Map<String, String[]> getParameterMap();
}

Controller主要通过这三个方法获取参数

/**
 * Abridged excerpt of the servlet response interface, listing only the two
 * output accessors discussed in this article.
 */
public interface ServletResponse {
    /** Byte-oriented output channel of the response. */
    ServletOutputStream getOutputStream() throws IOException;

    /** Character-oriented output channel of the response. */
    PrintWriter getWriter() throws IOException;
}
Controller主要通过这两个流输出结果到前端
因此可以重写ServletRequest的方法,让Controller在取参数时得到的是我们修改过的参数

重写ServletResponse的方法让Controller在往前端写结果时写到我们的重写类里面,然后处理这些数据,再重新写到前端

ServletRequest重写


package com.spring.demo.filter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.regex.Pattern;

/**
 * Request wrapper that sanitizes parameter and header values before the
 * application code sees them.
 *
 * <p>Every value returned by {@link #getParameter(String)},
 * {@link #getParameterValues(String)} and {@link #getHeader(String)} is passed
 * through {@link #replaceXSS(String)}, which strips a fixed list of
 * script-related patterns and then HTML-escapes the remaining special
 * characters via {@link #filter(String)}.
 */
public class RequestWrapper extends HttpServletRequestWrapper {

    /** Flags shared by the patterns that must match across line breaks. */
    private static final int MULTILINE_FLAGS =
            Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL;

    /**
     * Patterns removed from the input, applied in declaration order.
     * Compiled once: {@link Pattern} instances are immutable and safe to share.
     */
    private static final Pattern[] XSS_PATTERNS = {
            // Anything between script tags
            Pattern.compile("<script>(.*?)</script>", Pattern.CASE_INSENSITIVE),
            // Anything in a src='...' expression
            Pattern.compile("src[\r\n]*=[\r\n]*\\\'(.*?)\\\'", MULTILINE_FLAGS),
            // Anything in a src="..." expression
            Pattern.compile("src[\r\n]*=[\r\n]*\\\"(.*?)\\\"", MULTILINE_FLAGS),
            // Any lonesome </script> tag
            Pattern.compile("</script>", Pattern.CASE_INSENSITIVE),
            // Any lonesome <script ...> tag
            Pattern.compile("<script(.*?)>", MULTILINE_FLAGS),
            // eval(...) expressions
            Pattern.compile("eval\\((.*?)\\)", MULTILINE_FLAGS),
            // expression(...) expressions
            Pattern.compile("expression\\((.*?)\\)", MULTILINE_FLAGS),
            // javascript: pseudo-protocol
            Pattern.compile("javascript:", Pattern.CASE_INSENSITIVE),
            // The word "alert"
            Pattern.compile("alert", Pattern.CASE_INSENSITIVE),
            // onload= handlers
            Pattern.compile("onload(.*?)=", MULTILINE_FLAGS),
            // vbscript: pseudo-protocol
            Pattern.compile("vbscript[\r\n| | ]*:[\r\n| | ]*", Pattern.CASE_INSENSITIVE)
    };

    /**
     * Wraps the given request.
     *
     * @param request the request whose parameters and headers are sanitized
     */
    public RequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Returns the sanitized value of a request parameter.
     *
     * @param name parameter name; it is sanitized before the lookup
     * @return the sanitized value, or {@code null} when the wrapped request returns {@code null}
     */
    @Override
    public String getParameter(String name) {
        // replaceXSS(null) yields null, so a missing parameter stays null.
        return replaceXSS(super.getParameter(replaceXSS(name)));
    }

    /**
     * Returns the sanitized values of a request parameter.
     *
     * @param name parameter name; it is sanitized before the lookup
     * @return a new array holding the sanitized values, or {@code null} when
     *         the wrapped request returns {@code null}
     */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = super.getParameterValues(replaceXSS(name));
        if (values == null) {
            return null;
        }
        // Sanitize into a fresh array: the returned array may be shared with the
        // wrapped request, and escaping it in place would escape the same values
        // again on the next call.
        String[] sanitized = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            sanitized[i] = replaceXSS(values[i]);
        }
        return sanitized;
    }

    /**
     * Returns the sanitized value of a request header.
     *
     * @param name header name; it is sanitized before the lookup
     * @return the sanitized value, or {@code null} when the wrapped request returns {@code null}
     */
    @Override
    public String getHeader(String name) {
        // NOTE(review): header values are HTML-escaped too, so characters such as ';'
        // inside a header come back as entities - confirm that callers expect this.
        return replaceXSS(super.getHeader(replaceXSS(name)));
    }

    /**
     * Removes script-related fragments from a value and HTML-escapes the result.
     *
     * <p>The value is first URL-decoded (a literal {@code +} is preserved rather
     * than turned into a space), NUL characters are dropped, each entry of
     * {@link #XSS_PATTERNS} is removed in order, and finally
     * {@link #filter(String)} escapes the remaining special characters.
     *
     * @param value the raw value, may be {@code null}
     * @return the sanitized value, or {@code null} when {@code value} is {@code null}
     */
    public static String replaceXSS(String value) {
        if (value == null) {
            return null;
        }

        String cleaned = value;
        try {
            // Protect '+' so that decoding does not turn it into a space. The result is
            // only adopted when decoding succeeds, so a failed decode leaves the input
            // untouched instead of keeping the "%2B" substitution.
            cleaned = URLDecoder.decode(value.replace("+", "%2B"), "utf-8");
        } catch (UnsupportedEncodingException ignored) {
            // Every Java platform is required to support UTF-8; keep the raw value.
        } catch (IllegalArgumentException ignored) {
            // Malformed percent-escape: the value is not URL-encoded, keep it as is.
        }

        // Drop NUL characters.
        cleaned = cleaned.replace("\0", "");

        for (Pattern pattern : XSS_PATTERNS) {
            cleaned = pattern.matcher(cleaned).replaceAll("");
        }
        return filter(cleaned);
    }

    /**
     * Replaces HTML-significant characters with their entity equivalents.
     *
     * @param value the text to escape, may be {@code null}
     * @return the escaped text, or {@code null} when {@code value} is {@code null}
     */
    public static String filter(String value) {
        if (value == null) {
            return null;
        }
        StringBuilder result = new StringBuilder(value.length());
        for (int i = 0; i < value.length(); ++i) {
            char c = value.charAt(i);
            switch (c) {
                case '<':
                    result.append("&lt;");
                    break;
                case '>':
                    result.append("&gt;");
                    break;
                case '"':
                    result.append("&quot;");
                    break;
                case '\'':
                    result.append("&#39;");
                    break;
                case '%':
                    result.append("&#37;");
                    break;
                case ';':
                    result.append("&#59;");
                    break;
                case '(':
                    result.append("&#40;");
                    break;
                case ')':
                    result.append("&#41;");
                    break;
                case '&':
                    result.append("&amp;");
                    break;
                case '+':
                    result.append("&#43;");
                    break;
                default:
                    result.append(c);
                    break;
            }
        }
        return result.toString();
    }
}
ServletResponse重写

//重定向输出流写到DataOutputStream

package com.spring.demo.filter;

import javax.servlet.ServletOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
//重定向输出流写到DataOutputStream 
/**
 * Servlet output stream that forwards every write to a
 * {@link DataOutputStream} wrapped around the stream given at construction.
 */
public class FilterServletOutputStream extends ServletOutputStream {

    /** Destination of all writes. */
    DataOutputStream output;

    /**
     * @param output the stream that ultimately receives the written bytes
     */
    public FilterServletOutputStream(OutputStream output) {
        this.output = new DataOutputStream(output);
    }

    /** Writes a whole byte array to the destination. */
    @Override
    public void write(byte[] buffer) throws IOException {
        output.write(buffer, 0, buffer.length);
    }

    /** Writes {@code length} bytes of {@code buffer}, starting at {@code offset}. */
    @Override
    public void write(byte[] buffer, int offset, int length) throws IOException {
        output.write(buffer, offset, length);
    }

    /** Writes a single byte to the destination. */
    @Override
    public void write(int singleByte) throws IOException {
        output.write(singleByte);
    }
}

//重定向输出流写到ByteArrayOutputStream 

//ByteArrayOutputStream 接受Controller写入的数据,并以byte[]形式返回给Filter

package com.spring.demo.filter;

import io.netty.handler.codec.http.HttpResponseStatus;

import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;

/**
 * Response wrapper that captures what downstream code writes into an
 * in-memory buffer instead of sending it to the client, so that a filter can
 * read the body back with {@link #getDataStream()}.
 *
 * <p>Both output channels are captured: the byte stream from
 * {@link #getOutputStream()} and the character stream from
 * {@link #getWriter()}.
 */
public class ResponseWrapper extends HttpServletResponseWrapper {

    /** In-memory buffer receiving everything written through this wrapper. */
    ByteArrayOutputStream output;

    /** Lazily created byte-stream view over {@link #output}. */
    FilterServletOutputStream filterOutput;

    /** Initialised to OK; not read anywhere in this class. */
    HttpResponseStatus status = HttpResponseStatus.OK;

    /** Lazily created character view over {@link #output}. */
    private PrintWriter writer;

    /**
     * @param response the real response being wrapped
     */
    public ResponseWrapper(HttpServletResponse response) {
        super(response);
        output = new ByteArrayOutputStream();
    }

    /**
     * Returns a stream that writes into the internal buffer.
     *
     * @return the same buffering stream on every call
     */
    @Override
    public ServletOutputStream getOutputStream() throws IOException {
        if (filterOutput == null) {
            filterOutput = new FilterServletOutputStream(output);
        }
        return filterOutput;
    }

    /**
     * Returns a writer that writes into the internal buffer.
     *
     * <p>Without this override the inherited implementation would hand out the
     * wrapped response's own writer, and text written through it would bypass
     * the buffer.
     *
     * @return the same buffering writer on every call
     * @throws IOException if the response's character encoding is not supported
     */
    @Override
    public PrintWriter getWriter() throws IOException {
        if (writer == null) {
            String encoding = getCharacterEncoding();
            if (encoding == null) {
                // ISO-8859-1 is the servlet default when no encoding was specified.
                encoding = "ISO-8859-1";
            }
            writer = new PrintWriter(new OutputStreamWriter(output, encoding));
        }
        return writer;
    }

    /**
     * Returns a copy of everything written so far.
     *
     * @return the buffered response body
     */
    public byte[] getDataStream() {
        if (writer != null) {
            // Push characters still held by the writer into the byte buffer.
            writer.flush();
        }
        return output.toByteArray();
    }
}




package com.spring.demo.filter;

import java.io.Serializable;

/**
 * Envelope for a response: a numeric status, a human-readable message and an
 * arbitrary payload.
 */
public class RestResponse implements Serializable {

    /** Numeric status carried by the envelope. */
    private int status;
    /** Text accompanying the status. */
    private String message;
    /** Payload of the response. */
    private Object data;

    /**
     * Creates an envelope with all three parts set.
     *
     * @param status  numeric status
     * @param message text accompanying the status
     * @param data    payload
     */
    public RestResponse(int status, String message, Object data) {
        this.data = data;
        this.message = message;
        this.status = status;
    }

    // ---- readers ----

    /** @return the numeric status */
    public int getStatus() {
        return this.status;
    }

    /** @return the text accompanying the status */
    public String getMessage() {
        return this.message;
    }

    /** @return the payload */
    public Object getData() {
        return this.data;
    }

    // ---- writers ----

    /** @param status the new numeric status */
    public void setStatus(int status) {
        this.status = status;
    }

    /** @param message the new text accompanying the status */
    public void setMessage(String message) {
        this.message = message;
    }

    /** @param data the new payload */
    public void setData(Object data) {
        this.data = data;
    }
}


//Filter重写request和response并以chain的形式传递给Controller,Controller获取或输出数据都将调用RequestWrapper,ResponseWrapper


package com.spring.demo.filter;


import com.alibaba.fastjson.JSONObject;
import org.apache.htrace.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
 * Filter that hands wrapped request and response objects to the rest of the
 * chain, reads back the body the chain produced, wraps it in a
 * {@link RestResponse} and sends the serialized envelope to the client.
 */
public class SessionFilter implements Filter {
    protected final Logger logger = LoggerFactory.getLogger(SessionFilter.class);

    /** Shared serializer; built once instead of once per request. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        logger.info("SessionFilter init" );
    }

    /**
     * Runs the chain against the wrappers, then rewrites the captured body.
     *
     * @throws IOException      if reading the captured body or writing the envelope fails
     * @throws ServletException if the downstream chain fails; the failure is propagated
     *                          instead of being answered with a success envelope
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse,
                         FilterChain filterChain) throws IOException, ServletException {
        logger.info("doFilter start" );
        RequestWrapper requestWrapper = new RequestWrapper((HttpServletRequest) servletRequest);
        ResponseWrapper responseWrapper = new ResponseWrapper((HttpServletResponse) servletResponse);

        // A failure in the chain is left to propagate: swallowing it here would go on
        // to wrap an empty or partial body in an "OK" envelope.
        filterChain.doFilter(requestWrapper, responseWrapper);

        // Decode with an explicit charset; the envelope below is encoded as UTF-8 as well.
        String responseContent = new String(responseWrapper.getDataStream(), "UTF-8");
        logger.info("responseContent({})", responseContent);

        // The captured body can be post-processed here before it is wrapped.
        RestResponse fullResponse = new RestResponse(205, "OK-MESSAGE", parsePayload(responseContent));
        byte[] responseToSend = restResponseBytes(fullResponse);

        // The body length has changed, so any length announced downstream is stale.
        servletResponse.setContentLength(responseToSend.length);
        servletResponse.getOutputStream().write(responseToSend);
        logger.info("doFilter end" );
    }

    @Override
    public void destroy() {
        // Nothing to release.
    }

    /**
     * Parses the captured body as a JSON object; when the parser rejects it,
     * the raw text is used as the payload so the request does not fail.
     */
    private Object parsePayload(String responseContent) {
        try {
            return JSONObject.parseObject(responseContent);
        } catch (RuntimeException e) {
            logger.warn("response body is not a JSON object, wrapping it as raw text", e);
            return responseContent;
        }
    }

    /** Serializes the envelope to JSON and encodes it as UTF-8 bytes. */
    private byte[] restResponseBytes(RestResponse response) throws IOException {
        String serialized = OBJECT_MAPPER.writeValueAsString(response);
        return serialized.getBytes("UTF-8");
    }
}








上一篇:Java干货神总结,程序员面试技巧


下一篇:【TcaplusDB知识库】[Generic表]插入数据示例代码