由于 Request 和 Response 中的数据是以流的形式传递的,请求体只能被读取一次。Tomcat 中已有 SavedRequest 类,但没有 SavedResponse 类,因此我们创建两个包装类分别承载 Request/Response,再编写一个过滤器(Filter)拦截请求,把请求包装进 RequestWrapper 中,使请求体可以被重复读取。
RequestWrapper:
import com.baomidou.mybatisplus.core.toolkit.ObjectUtils; import com.longshine.luxicrmboot.commons.utils.ApplicationUtils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.springframework.util.StreamUtils; import org.springframework.web.util.HtmlUtils; import javax.servlet.ReadListener; import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import java.io.*; import java.nio.charset.StandardCharsets; import java.util.Objects; /** * Request包装类 * <p> * 1.预防xss攻击 * 2.拓展requestbody无限获取(HttpServletRequestWrapper只能获取一次) * </p> * * @author Caratacus */ @Slf4j public class RequestWrapper extends HttpServletRequestWrapper { /** * 存储requestBody byte[] */ private final byte[] body; public RequestWrapper(HttpServletRequest request) { super(request); byte[] body = new byte[0]; try { body = StreamUtils.copyToByteArray(request.getInputStream()); } catch (IOException e) { log.error("Error: Get RequestBody byte[] fail," + e); } this.body = body; } @Override public BufferedReader getReader() { ServletInputStream inputStream = getInputStream(); return Objects.isNull(inputStream) ? 
null : new BufferedReader(new InputStreamReader(inputStream)); } @Override public ServletInputStream getInputStream() { if (ObjectUtils.isEmpty(body)) { return null; } final ByteArrayInputStream bais = new ByteArrayInputStream(body); return new ServletInputStream() { @Override public boolean isFinished() { return false; } @Override public boolean isReady() { return false; } @Override @SuppressWarnings("EmptyMethod") public void setReadListener(ReadListener readListener) { } @Override public int read() { return bais.read(); } }; } @Override public String[] getParameterValues(String name) { String[] values = super.getParameterValues(name); if (values == null) { return null; } int count = values.length; String[] encodedValues = new String[count]; for (int i = 0; i < count; i++) { encodedValues[i] = htmlEscape(values[i]); } return encodedValues; } @Override public String getParameter(String name) { String value = super.getParameter(name); if (value == null) { return null; } return htmlEscape(value); } @Override public Object getAttribute(String name) { Object value = super.getAttribute(name); if (value instanceof String) { htmlEscape((String) value); } return value; } @Override public String getHeader(String name) { String value = super.getHeader(name); if (value == null) { return null; } return htmlEscape(value); } @Override public String getQueryString() { String value = super.getQueryString(); if (value == null) { return null; } return htmlEscape(value); } /** * 使用spring HtmlUtils 转义html标签达到预防xss攻击效果 * * @param str * @see org.springframework.web.util.HtmlUtils#htmlEscape */ protected String htmlEscape(String str) { return HtmlUtils.htmlEscape(str); } }
ResponseWrapper:
import com.alibaba.fastjson.JSON; import com.google.common.base.Throwables; import com.longshine.luxicrmboot.commons.msg.AjaxResult; import com.longshine.luxicrmboot.commons.msg.ErrorCode; import io.swagger.annotations.ApiResponses; import lombok.extern.slf4j.Slf4j; import org.springframework.util.MimeTypeUtils; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponseWrapper; import java.io.IOException; import java.io.PrintWriter; import java.nio.charset.StandardCharsets; import java.util.Objects; /** * response包装类 * * @author Caratacus */ @Slf4j public class ResponseWrapper extends HttpServletResponseWrapper { private ErrorCode errorcode; public ResponseWrapper(HttpServletResponse response) { super(response); } public ResponseWrapper(HttpServletResponse response, ErrorCode errorcode) { super(response); setErrorCode(errorcode); } /** * 获取ErrorCode * * @return */ public ErrorCode getErrorCode() { return errorcode; } /** * 设置ErrorCode * * @param errorCode */ public void setErrorCode(ErrorCode errorCode) { if (Objects.nonNull(errorCode)) { this.errorcode = errorCode; super.setStatus(this.errorcode.getHttpCode()); } } /** * 向外输出错误信息 * * @param e * @throws IOException */ public void writerErrorMsg(Exception e) { if (Objects.isNull(errorcode)) { log.warn("Warn: ErrorCodeEnum cannot be null, Skip the implementation of the method."); return; } printWriterApiResponses(AjaxResult.failure(this.getErrorCode(), e)); } /** * 设置ApiErrorMsg */ public void writerErrorMsg() { writerErrorMsg(null); } /** * 向外输出AjaxResult * * @param ajaxResult */ public void printWriterApiResponses(AjaxResult ajaxResult) { writeValueAsJson(ajaxResult); } /** * 向外输出json对象 * * @param obj */ public void writeValueAsJson(Object obj) { if (super.isCommitted()) { log.warn("Warn: Response isCommitted, Skip the implementation of the method."); return; } super.setContentType(MimeTypeUtils.APPLICATION_JSON_VALUE); super.setCharacterEncoding(StandardCharsets.UTF_8.name()); try 
(PrintWriter writer = super.getWriter()) { writer.print(JSON.toJSONString(obj)); writer.flush(); } catch (IOException e) { log.warn("Error: Response printJson faild, stackTrace: {}", Throwables.getStackTraceAsString(e)); } } }
过滤器:
import com.longshine.luxicrmboot.commons.wrapper.RequestWrapper; import org.springframework.stereotype.Component; import javax.servlet.*; import javax.servlet.annotation.WebFilter; import javax.servlet.http.HttpServletRequest; import java.io.IOException; /** * 记住Request/Response 过滤器 * 解决Request/Response不能重复使用问题 * * @author Caratacus */ @Component @WebFilter(filterName = "crownFilter", urlPatterns = "/*") public class MemoryReqResFilter implements Filter { @Override @SuppressWarnings("EmptyMethod") public void destroy() { } @Override public void doFilter(ServletRequest request, ServletResponse res, FilterChain chain) throws ServletException, IOException { HttpServletRequest req = (HttpServletRequest) request; chain.doFilter(new RequestWrapper(req), res); } @Override @SuppressWarnings("EmptyMethod") public void init(FilterConfig config) { } }