Spring 6.0 RestTemplate 配置（基于 Apache HttpClient 5）

RestTemplateConfig 

package cn.sitc.pc.config;

import cn.sitc.pc.interceptor.LoggingInterceptor;
import cn.sitc.pc.wrapper.APIResponseWrapper;
import lombok.extern.slf4j.Slf4j;
import org.apache.hc.client5.http.config.RequestConfig;
import org.apache.hc.client5.http.impl.classic.CloseableHttpClient;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManagerBuilder;
import org.apache.hc.client5.http.ssl.*;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.web.client.RestTemplate;

import javax.net.ssl.TrustManager;
import javax.net.ssl.SSLContext;
import javax.net.ssl.X509TrustManager;
import java.security.cert.X509Certificate;
import java.time.LocalDateTime;
import java.util.Collections;
import java.util.concurrent.TimeUnit;

/**
 * @date 2024/10/24
 **/
@Slf4j
@Configuration
public class RestTemplateConfig {

    @Bean
    public RestTemplate apiRestTemplate() throws Exception {
        RestTemplate restTemplate = new RestTemplate(factory());
        restTemplate.setInterceptors(Collections.singletonList(bpmInterceptor()));
//        MappingJackson2HttpMessageConverter messageConverter = restTemplate.getMessageConverters().stream().filter(MappingJackson2HttpMessageConverter.class::isInstance)
//                .map(MappingJackson2HttpMessageConverter.class::cast).findFirst().orElseThrow(() -> new RuntimeException("MappingJackson2HttpMessageConverter not found"));
//        messageConverter.setObjectMapper(new ObjectMapper());
//        //防止响应中文乱码
//        restTemplate.getMessageConverters().stream().filter(StringHttpMessageConverter.class::isInstance).map(StringHttpMessageConverter.class::cast).forEach(a -> {
//            a.setWriteAcceptCharset(false);
//            a.setDefaultCharset(StandardCharsets.UTF_8);
//        });
        return restTemplate;
    }

    public ClientHttpRequestFactory factory() throws Exception {
        // 创建一个HttpComponentsClientHttpRequestFactory,并设置HttpClient
        HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory();
        requestFactory.setHttpClient(createHttpClient());
        return requestFactory;
    }

    public CloseableHttpClient createHttpClient() throws Exception {
        TlsSocketStrategy tlsStrategy = new DefaultClientTlsStrategy(createSslContext(), NoopHostnameVerifier.INSTANCE);
        // 使用自定义的 SSLConnectionSocketFactory 构建连接管理器
        PoolingHttpClientConnectionManager connManager = PoolingHttpClientConnectionManagerBuilder.create()
                .setTlsSocketStrategy(tlsStrategy)
                .build();
        RequestConfig requestConfig = RequestConfig.custom()
                .setConnectionRequestTimeout(5L, TimeUnit.SECONDS)
                .setResponseTimeout(5L, TimeUnit.SECONDS).build();

        // 创建 HttpClient
        return HttpClients.custom()
                .setConnectionManager(connManager)
                .setDefaultRequestConfig(requestConfig)
                .build();
    }

    public SSLContext createSslContext() throws Exception {
        // 创建一个信任所有证书的信任管理器
        TrustManager[] trustAllCerts = new TrustManager[]{
                new X509TrustManager() {
                    @Override
                    public X509Certificate[] getAcceptedIssuers() {
                        return null;
                    }
                    @Override
                    public void checkClientTrusted(X509Certificate[] certs, String authType) {
                    }
                    @Override
                    public void checkServerTrusted(X509Certificate[] certs, String authType) {
                    }
                }
        };
        // 使用我们的信任管理器初始化SSL上下文
        SSLContext sslContext = SSLContext.getInstance("TLS");
        sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
        return sslContext;
    }

    @Bean
    public LoggingInterceptor loggingInterceptor() {
        return new LoggingInterceptor();
    }

    @Bean
    public ClientHttpRequestInterceptor bpmInterceptor() {
        return (request, body, execution) -> {
            LocalDateTime startTime = LocalDateTime.now();
            try {
                ClientHttpResponse response = execution.execute(request, body);
                return new APIResponseWrapper(response, "BPM", request, body, startTime);
            } catch (Exception e) {
                log.error(e.getMessage(), e);
                return new APIResponseWrapper(e, "BPM", request, body, startTime);
            }
        };
    }
}

APIResponseWrapper（ClientHttpResponse 包装类）

package cn.sitc.pc.wrapper;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.hibernate.service.spi.ServiceException;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.lang.NonNull;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;

/**
 * {@link ClientHttpResponse} decorator that buffers the whole response body, so it can be
 * logged and then read again by RestTemplate's message converters.
 *
 * <p>Two modes:
 * <ul>
 *   <li>wrapping a real response: the body bytes are read once in the constructor; status,
 *       headers and {@link #close()} are delegated to the original response;</li>
 *   <li>representing a failed call ({@code originalResponse == null}): the body is the
 *       literal {@code "ERROR"}, the status is 500 and the headers are empty.</li>
 * </ul>
 *
 * @date 2024/10/24
 **/
@Slf4j
public class APIResponseWrapper implements ClientHttpResponse {

    /** Maximum number of response-body characters written to the log. */
    private static final int MAX_LOGGED_RESPONSE_CHARS = 10000;

    /** The wrapped response, or {@code null} when this wrapper represents a failed call. */
    private final ClientHttpResponse originalResponse;
    /** Raw response body bytes, buffered once in the constructor. */
    private final byte[] responseBodyByte;
    /** Stream handed out by {@link #getBody()}, created on first access. */
    private ByteArrayInputStream responseInputStream;

    /**
     * Wraps {@code originalResponse} and buffers its body.
     *
     * @param originalResponse the real response, or {@code null} to create the error placeholder
     * @throws ServiceException if the response body cannot be read
     */
    public APIResponseWrapper(ClientHttpResponse originalResponse) {
        this.originalResponse = originalResponse;
        if (originalResponse == null) {
            // The call itself failed, so there is no response; expose a placeholder body.
            this.responseBodyByte = "ERROR".getBytes(StandardCharsets.UTF_8);
        } else {
            try {
                // Keep the bytes exactly as received. Decoding and re-encoding would depend on
                // the platform charset and corrupt non-UTF-8 or binary bodies.
                this.responseBodyByte = originalResponse.getBody().readAllBytes();
            } catch (Exception e1) {
                throw new ServiceException("There was a problem reading/decoding the response coming from the service ", e1);
            }
        }
    }

    /**
     * Wraps a successful response and logs request, response and call duration.
     *
     * @param response  the real response
     * @param system    name of the called system, used in the timing log line
     * @param request   the request that produced {@code response}
     * @param body      the request body
     * @param startTime when the call started
     * @throws IOException if the response status cannot be read for logging
     */
    public APIResponseWrapper(ClientHttpResponse response, String system, HttpRequest request, byte[] body, LocalDateTime startTime) throws IOException {
        this(response);
        logIntegration(request, body, response, this.responseBodyByte, system, startTime);
    }

    /**
     * Creates the error placeholder for a failed call and logs the request and call duration.
     * The {@code exception} itself is not logged here.
     *
     * @param exception the failure (currently unused)
     * @param system    name of the called system, used in the timing log line
     * @param request   the request that failed
     * @param body      the request body
     * @param startTime when the call started
     */
    public APIResponseWrapper(Exception exception, String system, HttpRequest request, byte[] body, LocalDateTime startTime) {
        this(null);
        logRequest(request, body);
        callTime(system, startTime);
    }

    private void logIntegration(HttpRequest request, byte[] requestBody, ClientHttpResponse response, byte[] responseBody, String system, LocalDateTime startTime) throws IOException {
        logRequest(request, requestBody);
        logResponse(response, responseBody);
        callTime(system, startTime);
    }

    private void logRequest(HttpRequest request, byte[] requestBody) {
        log.info("=========================== request begin ===========================");
        log.info("Request method: {}", request.getMethod());
        log.info("Request URI: {}", request.getURI());
        // NOTE(review): headers and body are logged verbatim and may contain credentials.
        log.info("Request headers: {}", request.getHeaders());
        if (requestBody != null && requestBody.length > 0) {
            log.info("Request body: {}", decodeRequestBodyForLog(requestBody));
        }
        log.info("=========================== request end ===========================");
    }

    /**
     * Renders a request body for logging. The body is URL-decoded when possible (form
     * bodies); otherwise the raw UTF-8 text is returned. {@link URLDecoder} rejects malformed
     * escapes such as a bare {@code %}, which is common in JSON, and logging must never break
     * the HTTP call.
     */
    private static String decodeRequestBodyForLog(byte[] requestBody) {
        String raw = new String(requestBody, StandardCharsets.UTF_8);
        try {
            return URLDecoder.decode(raw, StandardCharsets.UTF_8);
        } catch (IllegalArgumentException e) {
            return raw;
        }
    }

    private void logResponse(ClientHttpResponse response, byte[] responseBody) throws IOException {
        log.info("=========================== response begin ===========================");
        log.info("Response status code: {}", response.getStatusCode());
        log.info("Response headers: {}", response.getHeaders());
        String responseString = new String(responseBody, StandardCharsets.UTF_8);
        // Truncate long bodies to keep log lines bounded.
        String truncatedResponse = responseString.substring(0, Math.min(MAX_LOGGED_RESPONSE_CHARS, responseString.length()));
        log.info("Response body: {}", truncatedResponse);
        log.info("=========================== response end ===========================");
    }

    private void callTime(String system, LocalDateTime startTime) {
        log.info("===========================invoke {} call time:{}ms", system, ChronoUnit.MILLIS.between(startTime, LocalDateTime.now()));
    }

    @NonNull
    @Override
    public HttpStatusCode getStatusCode() throws IOException {
        return originalResponse != null ? originalResponse.getStatusCode() : HttpStatus.INTERNAL_SERVER_ERROR;
    }

    @NonNull
    @Override
    public String getStatusText() throws IOException {
        return originalResponse != null ? originalResponse.getStatusText() : HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase();
    }

    /**
     * Closes the wrapped response. This is a no-op for the error placeholder, which has no
     * underlying response. RestTemplate always calls {@code close()}, and an exception here
     * would mask the real error.
     */
    @Override
    public void close() {
        if (originalResponse != null) {
            originalResponse.close();
        }
    }

    /**
     * Returns the buffered body. The same stream instance is returned on every call.
     */
    @NonNull
    @Override
    public InputStream getBody() throws IOException {
        if (responseInputStream == null) {
            responseInputStream = new ByteArrayInputStream(processedBody());
        }
        return responseInputStream;
    }

    /**
     * Produces the bytes handed back to RestTemplate. A body that is a JSON object is parsed
     * and re-serialized as UTF-8, which is the hook point for post-processing. Any other body
     * (empty, the "ERROR" placeholder, arrays, non-JSON) is returned unchanged, as is a body
     * that fails to parse.
     */
    private byte[] processedBody() {
        String text = new String(responseBodyByte, StandardCharsets.UTF_8);
        if (!text.strip().startsWith("{")) {
            return responseBodyByte;
        }
        try {
            JSONObject json = JSON.parseObject(text);
            // Hook: inspect or adjust `json` here (e.g. based on its "code" field) before it is
            // returned to the caller.
            return json.toJSONString().getBytes(StandardCharsets.UTF_8);
        } catch (Exception e) {
            log.error("系统异常", e);
            return responseBodyByte;
        }
    }

    @NonNull
    @Override
    public HttpHeaders getHeaders() {
        return originalResponse != null ? originalResponse.getHeaders() : new HttpHeaders();
    }

}

上一篇:Linux的基本指令(一)