【Implementing Custom Gateway Filters with spring-cloud-gateway】


spring cloud gateway custom starter

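A custom non-blocking, reactive gateway service that integrates authentication, rate limiting, response enhancement, and more.
  • Requirements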
    <properties>
          <java.version>17</java.version>
          <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
          <maven.compiler.target>17</maven.compiler.target>
          <maven.compiler.source>17</maven.compiler.source>
          <java.source.version>17</java.source.version>
          <java.target.version>17</java.target.version>
          <spring-boot.version>3.1.12</spring-boot.version>
          <spring-cloud.version>2022.0.5</spring-cloud.version>
          <commons.pool2.version>2.12.0</commons.pool2.version>
          <redisson.version>3.34.1</redisson.version>
          <com.fastjson.jackson.version>2.17.2</com.fastjson.jackson.version>
          <commons.lang3.version>3.16.0</commons.lang3.version>
          <lombok.version>1.18.32</lombok.version>
    </properties>
    
  • GatewayFilter route filters
    • TokenFilterGatewayFilterFactory: authentication
    • Configuration example:
    spring:
      cloud:
        gateway:
          routes:
            - id: testRoute
              uri: http://127.0.0.1:8080
              predicates:
                - Path=/test/**
              filters:
                - name: TokenFilter
                  args:
                    requestHeaderKey: auth
    
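    • RLimterGatewayFilterFactory: rate limiting with a custom key. For example, to apply a different limit per user, our own filter factory can set limterKey to a user identifier taken from a request header and configure the limit accordingly. This is somewhat easier than overriding spring-cloud-gateway's own KeyResolver and RedisRateLimiter.
      • Configuration example: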
      spring:
          cloud:
              gateway:
                routes:
                  - id: route_test
                    uri: http://192.168.1.1:8091
                    predicates:
                      - Path=/from/requestIds/to/appNames
                    filters:
                      - name: RLimter
                        args:
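                          limterKey: v1_10 # key identifying the rate-limit bucket
                          rate: 5  # tokens added to the bucket per second (sustained requests per second)
                          crust: 0 # token bucket capacity (maximum burst size)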
      
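      • More route filters are in progress...
Custom gateway filter factories: TokenFilterGatewayFilterFactory and RLimterGatewayFilterFactory. The walkthrough below uses the custom rate limiter RLimterGatewayFilterFactory as the example.
  • First, define the rate limiter's functional interface, which takes the limit key, the rate, and the bucket size. It is adapted from spring-cloud-gateway's built-in rate limiter: a Lua script executed in Redis implements the token bucket algorithm.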
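The source of TokenFilterGatewayFilterFactory is not shown in this article. As a rough, hypothetical sketch of the decision such a filter makes with requestHeaderKey, the pure-JDK class below reads the configured header and returns 401 when it is missing, blank, or rejected by a validator. The class name TokenCheckSketch and the tokenValidator hook are assumptions for illustration, not the article's code.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical sketch of the check a TokenFilter could perform.
// Not the article's TokenFilterGatewayFilterFactory: Spring types are left out
// so the decision logic runs on the plain JDK.
class TokenCheckSketch {

    static final int OK = 200;
    static final int UNAUTHORIZED = 401;

    // headers: request headers (name -> values). A plain Map is case-sensitive,
    //          whereas Spring's HttpHeaders looks names up case-insensitively.
    // requestHeaderKey: header name from the route args, e.g. "auth".
    // tokenValidator: assumed hook that verifies the token value.
    static int authorize(Map<String, List<String>> headers,
                         String requestHeaderKey,
                         Predicate<String> tokenValidator) {
        List<String> values = headers.get(requestHeaderKey);
        if (values == null || values.isEmpty()) {
            return UNAUTHORIZED;                 // header missing
        }
        String token = values.get(0);
        if (token == null || token.isBlank()) {
            return UNAUTHORIZED;                 // header present but empty
        }
        return tokenValidator.test(token) ? OK : UNAUTHORIZED;
    }

    public static void main(String[] args) {
        Predicate<String> validator = t -> t.startsWith("valid-"); // stand-in validator
        System.out.println(authorize(Map.of("auth", List.of("valid-123")), "auth", validator)); // 200
        System.out.println(authorize(Map.of(), "auth", validator));                              // 401
        System.out.println(authorize(Map.of("auth", List.of("bogus")), "auth", validator));     // 401
    }
}
```

Inside a real filter factory, a 401 would be written much like RLimterGatewayFilterFactory rejects requests: exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED) followed by setComplete().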
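import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import reactor.core.publisher.Mono;

public interface RLimter {

    // limitKey: bucket key; rate: tokens added per second; crust: bucket capacity (burst size)
    Mono<Response> isAllowed(String limitKey, String rate, String crust);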

    @Setter
    @Getter
    @NoArgsConstructor
    @AllArgsConstructor
    class Response {
        private boolean allowed;
    }
}
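  • Next, declare the rate-limiting components as beans: a RedisScript that loads our Lua script, and the StringRedisTemplate used to run it. Because the whole gateway is packaged as a starter, the limiter implementation class RRedisRateLimiter is registered as a Spring bean here as well.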
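import java.util.List;

import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;

@Configuration
public class RLimterAutoConfiguration {

    // Loads script/request_rate_limiter.lua; the script returns {allowed_num, new_tokens}.
    // Declared as RedisScript<List<Long>> so it matches the injection point below.
    @Bean(name = "rredisRequestRateLimiterScript")
    @SuppressWarnings("unchecked")
    public RedisScript<List<Long>> redisRequestRateLimiterScript() {
        DefaultRedisScript<List<Long>> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("script/request_rate_limiter.lua")));
        redisScript.setResultType((Class<List<Long>>) (Class<?>) List.class);
        return redisScript;
    }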

    @Bean(name = "rredisRateLimiter")
    public RRedisRateLimiter redisRateLimiter(@Qualifier("rredisRequestRateLimiterScript") RedisScript<List<Long>> redisRequestRateLimiterScript, StringRedisTemplate redisTemplate) {
        return new RRedisRateLimiter(redisTemplate, redisRequestRateLimiterScript);
    }

    @Bean
    public RLimterGatewayFilterFactory rLimterGatewayFilterFactory(@Qualifier("rredisRateLimiter") RRedisRateLimiter redisRateLimiter) {
        return new RLimterGatewayFilterFactory(redisRateLimiter);
    }
}
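  • The script referenced above by new ClassPathResource("script/request_rate_limiter.lua") only needs to be placed under the project's resources directory, i.e. resources/script/request_rate_limiter.lua. It is copied unchanged from spring-cloud-gateway's built-in rate limiter script.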
redis.replicate_commands()

local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
-- redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

local fill_time = capacity / rate
local ttl = math.floor(fill_time * 2)

-- for testing, it should use redis system time in production
if now == nil then
  now = redis.call('TIME')[1]
end

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. now)
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)

if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now)
end

-- return { allowed_num, new_tokens, capacity, filled_tokens, requested, new_tokens }
return { allowed_num, new_tokens }
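To make the script's arithmetic concrete, here is a pure-Java port for illustration only; it is not part of the gateway, and the name TokenBucketSim is hypothetical. It mirrors the Lua logic step by step but keeps the two Redis keys in fields (with SETEX expiry approximated in whole seconds), so the bucket behaviour can be checked without a Redis server.

```java
// Pure-Java port of the arithmetic in request_rate_limiter.lua, for illustration only.
// One instance stands in for the two Redis keys (".tokens" and ".timestamp") of one limit key.
// Time is in whole seconds, like redis.call('TIME')[1]; SETEX expiry is approximated.
class TokenBucketSim {

    record Result(boolean allowed, double newTokens) {}

    private boolean hasState;      // whether the two keys currently exist
    private double storedTokens;   // value of the ".tokens" key
    private long storedTimestamp;  // value of the ".timestamp" key
    private long expiresAt;        // second at which both keys expire

    Result isAllowed(double rate, double capacity, long now, double requested) {
        double fillTime = capacity / rate;               // seconds to refill an empty bucket
        long ttl = (long) Math.floor(fillTime * 2);      // keys live for twice the fill time

        boolean present = hasState && now < expiresAt;   // emulate key expiry
        double lastTokens = present ? storedTokens : capacity;  // missing key -> full bucket
        long lastRefreshed = present ? storedTimestamp : 0L;    // missing key -> 0

        long delta = Math.max(0, now - lastRefreshed);
        double filledTokens = Math.min(capacity, lastTokens + delta * rate);
        boolean allowed = filledTokens >= requested;
        double newTokens = allowed ? filledTokens - requested : filledTokens;

        if (ttl > 0) {                                   // same guard as the script
            hasState = true;
            storedTokens = newTokens;
            storedTimestamp = now;
            expiresAt = now + ttl;
        }
        return new Result(allowed, newTokens);
    }

    public static void main(String[] args) {
        TokenBucketSim bucket = new TokenBucketSim();
        // rate = 1 token/s, capacity = 2: two requests pass in the same second, the third is rejected
        for (int i = 0; i < 3; i++) {
            System.out.println(bucket.isAllowed(1, 2, 100, 1));
        }
        System.out.println(bucket.isAllowed(1, 2, 101, 1)); // one token refilled after 1 s
    }
}
```

Two consequences of the script as written show up in this port. With crust (capacity) set to 0, filled_tokens is always 0 and every request is rejected, which is the 429 test described later. And because ttl = floor(2 * capacity / rate), a capacity below half the rate gives ttl = 0, the setex calls are skipped, and every request sees a full bucket; keeping crust at least equal to rate (as in the defaults rate=1, crust=1) avoids this.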

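  • Next comes the rate limiter implementation. It just passes the four arguments the Lua script above expects: rate, capacity (crust), the current time, and requested. The current time is passed as an empty string, so tonumber returns nil and the script falls back to Redis's own TIME. requested is the number of tokens taken from the bucket per request; it is 1 here and we leave it at that.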
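import java.util.Arrays;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

public class RRedisRateLimiter implements RLimter {

    private static final Logger logger = LoggerFactory.getLogger(RRedisRateLimiter.class);
    private final StringRedisTemplate redisTemplate;
    private final RedisScript<List<Long>> script;

    public RRedisRateLimiter(StringRedisTemplate redisTemplate, RedisScript<List<Long>> script) {
        this.redisTemplate = redisTemplate;
        this.script = script;
    }

    // KEYS[1] (tokens) and KEYS[2] (timestamp) for the Lua script
    static List<String> getKeys(String id) {
        String prefix = "request_rate_limiter.{" + id;
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    @Override
    public Mono<Response> isAllowed(String key, String rate, String crust) {
        List<String> keys = getKeys(key);
        // StringRedisTemplate is blocking, so run the call on boundedElastic instead of the Netty event loop
        return Mono.fromCallable(() -> {
                    // ARGV: rate, capacity, now ("" -> script falls back to Redis TIME), requested = 1
                    List<Long> exec = this.redisTemplate.execute(this.script, keys, rate, crust, "", "1");
                    // Fail open if Redis returns nothing, consistent with the error fallback below
                    return exec == null || exec.isEmpty() || exec.get(0) == 1L;
                })
                .subscribeOn(Schedulers.boundedElastic())
                .onErrorResume(throwable -> {
                    logger.error("Error calling rate limiter lua", throwable);
                    return Mono.just(true); // fail open: let the request through if Redis is unavailable
                })
                .map(Response::new);
    }
}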
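The braces in getKeys matter in Redis Cluster: only the substring inside {...} (the hash tag) is hashed, so the .tokens and .timestamp keys of one limit key land in the same slot, which is required because the Lua script reads and writes both keys in one call. The pure-JDK sketch below copies getKeys and adds two hypothetical helpers, hashTag and slot, written from the rules in the Redis Cluster specification (CRC16 of the tag, modulo 16384), to show that both keys map to the same slot.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

// Illustration of why getKeys wraps the id in braces (Redis Cluster hash tags).
// getKeys is copied from RRedisRateLimiter; hashTag, crc16 and slot are hypothetical
// helpers following the key-distribution rules of the Redis Cluster specification.
class HashTagDemo {

    static List<String> getKeys(String id) {
        String prefix = "request_rate_limiter.{" + id;
        return Arrays.asList(prefix + "}.tokens", prefix + "}.timestamp");
    }

    // Only the part between the first '{' and the next '}' is hashed, if it is non-empty
    static String hashTag(String key) {
        int open = key.indexOf('{');
        if (open == -1) return key;
        int close = key.indexOf('}', open + 1);
        if (close == -1 || close == open + 1) return key;
        return key.substring(open + 1, close);
    }

    // CRC16-XMODEM (polynomial 0x1021, initial value 0), the CRC variant Redis Cluster uses
    static int crc16(byte[] bytes) {
        int crc = 0;
        for (byte b : bytes) {
            crc ^= (b & 0xFF) << 8;
            for (int i = 0; i < 8; i++) {
                crc = (crc & 0x8000) != 0 ? (crc << 1) ^ 0x1021 : crc << 1;
                crc &= 0xFFFF;
            }
        }
        return crc;
    }

    static int slot(String key) {
        return crc16(hashTag(key).getBytes(StandardCharsets.UTF_8)) % 16384;
    }

    public static void main(String[] args) {
        for (String key : getKeys("v1_10")) {
            System.out.println(key + " -> tag " + hashTag(key) + ", slot " + slot(key));
        }
    }
}
```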
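  • Finally, define our own rate-limiting gateway filter factory. It extends AbstractGatewayFilterFactory, a spring-cloud-gateway base class through which the gateway loads custom filter factories. The base class accepts a custom configuration class (Config2 here): it is declared as the generic type parameter, and its Class object is passed to the super constructor.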
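import lombok.Getter;
import lombok.Setter;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.HasRouteId;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.http.HttpStatus;

public class RLimterGatewayFilterFactory extends AbstractGatewayFilterFactory<RLimterGatewayFilterFactory.Config2> {

    private final RRedisRateLimiter redisRateLimiter;

    public RLimterGatewayFilterFactory(RRedisRateLimiter redisRateLimiter) {
        // Tells the base class which type to bind the route's args to
        super(Config2.class);
        this.redisRateLimiter = redisRateLimiter;
    }

    @Override
    public GatewayFilter apply(Config2 config) {
        return (exchange, chain) -> redisRateLimiter.isAllowed(config.limterKey, config.rate, config.crust).flatMap(response -> {
            if (response.isAllowed()) {
                // Token acquired: pass the request on down the filter chain
                return chain.filter(exchange);
            }
            // Bucket empty: set the configured status (429 by default) and end the exchange
            ServerWebExchangeUtils.setResponseStatus(exchange, config.getStatusCode());
            return exchange.getResponse().setComplete();
        });
    }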

    @Setter
    @Getter
    public static class Config2 implements HasRouteId {
        private String routeId;
        private String limterKey = "default";
        private String rate = "1";
        private String crust = "1";
        private HttpStatus statusCode = HttpStatus.TOO_MANY_REQUESTS;

        @Override
        public String getRouteId() {
            return routeId;
        }

        @Override
        public void setRouteId(String routeId) {
            this.routeId = routeId;
        }
    }
}
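The name used under filters in the YAML is not declared anywhere in the class. By default spring-cloud-gateway derives it from the factory's simple class name minus the GatewayFilterFactory suffix (GatewayFilterFactory#name(), via NameUtils.normalizeFilterFactoryName), which is why this factory is referenced as RLimter and the token filter as TokenFilter; the name must match exactly. The args entries are then bound onto Config2 by property name (limterKey, rate, crust). The snippet below is a hypothetical re-implementation of just the naming rule, for illustration.

```java
import java.util.List;

// Hypothetical re-implementation of how spring-cloud-gateway derives a filter's
// configuration name: the factory's simple class name minus "GatewayFilterFactory".
// The real logic lives in the gateway; this copy only illustrates the rule.
class FilterNameDemo {

    static final String SUFFIX = "GatewayFilterFactory";

    static String filterName(String factorySimpleName) {
        return factorySimpleName.replace(SUFFIX, "");
    }

    public static void main(String[] args) {
        for (String cls : List.of("RLimterGatewayFilterFactory", "TokenFilterGatewayFilterFactory")) {
            System.out.println(cls + " -> name: " + filterName(cls));
        }
    }
}
```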
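  • This is only a simple example. Truly custom rate limiting also requires the filter factory to read the request data (for example headers or the body) from the exchange, resolve the limit key from it, and apply business-specific rules. For configuration, see the YAML example at the beginning. To test, set the token bucket capacity (crust) to 0: every request is then rejected with the 429 Too Many Requests status code.
Global filter example
  • First, define our own request context object
import lombok.Getter;
import lombok.Setter;

@Setter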
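A minimal sketch of the per-user key resolution described above, under stated assumptions: the user identifier arrives in a hypothetical X-User-Id header, and anonymous traffic falls back to the client IP. The class LimitKeyResolver is illustrative only. Inside apply(), the resolved key would replace config.limterKey in the isAllowed call, with the header value read via exchange.getRequest().getHeaders().getFirst(...).

```java
import java.util.Map;

// Hypothetical per-user key resolution for the rate limiter.
// Assumes the user id arrives in an "X-User-Id" header (an assumption, not from the article)
// and falls back to the client IP so anonymous traffic still gets its own bucket.
class LimitKeyResolver {

    static final String USER_HEADER = "X-User-Id"; // hypothetical header name

    static String resolve(String configuredKey, Map<String, String> headers, String clientIp) {
        String userId = headers.get(USER_HEADER);
        String subject = (userId != null && !userId.isBlank())
                ? "user:" + userId.trim()
                : "ip:" + clientIp;
        // e.g. "v1_10:user:42": one token bucket per route key and per user
        return configuredKey + ":" + subject;
    }

    public static void main(String[] args) {
        System.out.println(resolve("v1_10", Map.of(USER_HEADER, "42"), "10.0.0.1"));
        System.out.println(resolve("v1_10", Map.of(), "10.0.0.1"));
    }
}
```

Because the resolved key is passed through getKeys, each user ends up in its own hash tag, e.g. request_rate_limiter.{v1_10:user:42}.tokens.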
@Getter
public class RequestContext {
    private String requestId;
    private long requestStartTime;
    private String requestIp;
}

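  • Global filter implementation. It is a WebFlux WebFilter, so it wraps every request the gateway handles; the RequestContext it writes into the Reactor Context can be read by the filters and handlers that run inside chain.filter(exchange).
import java.util.UUID;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
// OrderConstant, SystemClock and IpUtils are project helper classes not shown in this article

@Component
public class RequestContextFilter implements WebFilter, Ordered {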

    private static final Logger logger = LoggerFactory.getLogger(RequestContextFilter.class);

    @Override
    public int getOrder() {
        return OrderConstant.REQUEST_CONTEXT_ORDER;
    }

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        long requestStartTime = SystemClock.now();
        String requestId = generateRequestId();
        ServerHttpRequest request = exchange.getRequest();
        String uri = request.getURI().getRawPath();
        String requestIp = IpUtils.tryGetRealIp(request);
        exchange.getResponse().getHeaders().add("requestId", requestId);
        logger.info("request start,requestId:{}, requestUri:{},ip:{},requestStartTime:{}", requestId, uri, requestIp, requestStartTime);
        RequestContext context = new RequestContext();
        context.setRequestId(requestId);
        context.setRequestStartTime(requestStartTime);
        context.setRequestIp(requestIp);
        // Store the context in the Reactor Context so downstream filters can read it
        return chain.filter(exchange).contextWrite(Context.of(RequestContext.class, context))
                .doOnEach(signal -> {
                    long requestEndTime = SystemClock.now();
                    if (signal.isOnComplete()) {
                        logger.info("request end,requestId:{},response:{},requestEndTime:{},elapsedMs:{}", requestId, exchange.getResponse().getStatusCode(), requestEndTime, (requestEndTime - requestStartTime));
                    }
                    if (signal.isOnError()) {
                        // Throwable as the last argument: SLF4J logs it with its stack trace
                        logger.error("request end,requestId:{},requestEndTime:{},elapsedMs:{}", requestId, requestEndTime, (requestEndTime - requestStartTime), signal.getThrowable());
                    }
                });
    }

    private String generateRequestId() {
        return UUID.randomUUID().toString().replaceAll("-", "");
    }
}
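IpUtils.tryGetRealIp is a project helper whose source is not shown. As a hypothetical illustration only, a common approach behind a reverse proxy is to take the left-most X-Forwarded-For entry, then X-Real-IP, then the TCP peer address. The sketch below works on plain strings so it runs on the JDK; the class name and header precedence are assumptions, not the article's implementation. Note that X-Forwarded-For can be forged by clients unless a trusted edge proxy overwrites it.

```java
import java.util.Map;

// Hypothetical sketch of an IpUtils.tryGetRealIp-style lookup; the article does not show
// its IpUtils, so the header names and their precedence here are assumptions.
class RealIpSketch {

    static String tryGetRealIp(Map<String, String> headers, String remoteAddress) {
        // X-Forwarded-For: "client, proxy1, proxy2" -> the left-most entry is the original client
        String forwarded = headers.get("X-Forwarded-For");
        if (forwarded != null && !forwarded.isBlank()) {
            int comma = forwarded.indexOf(',');
            String first = (comma >= 0 ? forwarded.substring(0, comma) : forwarded).trim();
            if (!first.isEmpty() && !"unknown".equalsIgnoreCase(first)) {
                return first;
            }
        }
        // X-Real-IP: commonly set by an nginx reverse proxy
        String realIp = headers.get("X-Real-IP");
        if (realIp != null && !realIp.isBlank() && !"unknown".equalsIgnoreCase(realIp.trim())) {
            return realIp.trim();
        }
        // No usable proxy header: use the TCP peer address
        return remoteAddress;
    }

    public static void main(String[] args) {
        System.out.println(tryGetRealIp(Map.of("X-Forwarded-For", "203.0.113.7, 10.0.0.2"), "10.0.0.2"));
        System.out.println(tryGetRealIp(Map.of(), "192.0.2.10"));
    }
}
```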