根据请求参数中的关键字段实现对单位时间内某特定用户的单接口访问频率限制。
如果用户标识是通过请求头中的 Token 携带的，只需将切面中读取标识的逻辑改为从请求头获取即可。
频率注解:
package com.seliote.fr.annotation;
import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;
/**
 * Limits how often an API may be called; intended to be used together with an aspect.
 * The annotated method must declare exactly one parameter.
 *
 * @author seliote
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Documented
public @interface ApiFrequency {

    /**
     * Names of the request-argument properties whose values are used to tell callers apart;
     * join several names with {@code &&}.
     */
    String key();

    /** Maximum number of calls allowed per time window. */
    int frequency();

    /** Length of the time window, expressed in {@link #timeUnit()}. */
    int time();

    /** Unit of {@link #time()}. */
    TimeUnit timeUnit();
}
主要切面代码:
package com.seliote.fr.config.api;
import com.seliote.fr.annotation.stereotype.ApiComponent;
import com.seliote.fr.exception.FrequencyException;
import com.seliote.fr.service.RedisService;
import com.seliote.fr.util.CommonUtils;
import lombok.extern.log4j.Log4j2;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.annotation.Order;
import org.springframework.lang.Nullable;
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.InvocationTargetException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import static com.seliote.fr.util.ReflectUtils.getClassName;
/**
 * API call-frequency limiting aspect.
 * <p>
 * Intercepts public controller methods annotated with {@link com.seliote.fr.annotation.ApiFrequency}
 * and counts calls per fixed time window in Redis. The counter key is built from
 * {@link #redisNameSpace}, the values of the annotation's key properties read from the method's single
 * argument, and the start (epoch seconds) of the current window. Calls whose count exceeds the
 * configured frequency are rejected with {@link FrequencyException}. When the count cannot be
 * determined (bad configuration or arguments) the check fails open and the call proceeds.
 *
 * @author seliote
 */
@Log4j2
@Order(1)
@ApiComponent
@Aspect
public class ApiFrequency {
    // Redis namespace (first key segment) for all frequency counters
    public static final String redisNameSpace = "frequency";

    // Separator between property names in ApiFrequency#key()
    private static final String KEY_SEPARATOR = "&&";

    private final RedisService redisService;

    @Autowired
    public ApiFrequency(RedisService redisService) {
        this.redisService = redisService;
        log.debug("Construct {}", getClassName(ApiFrequency.class));
    }

    /**
     * Enforces the API call-frequency limit declared on the intercepted method.
     *
     * @param proceedingJoinPoint AOP ProceedingJoinPoint object
     * @return whatever the intercepted controller method returns
     * @throws Throwable FrequencyException when the limit is exceeded, or anything the controller throws
     */
    @Around("execution(public * com.seliote.fr.controller..*.*(..)) && @annotation(com.seliote.fr.annotation.ApiFrequency)")
    public Object apiFrequency(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
        String uri = CommonUtils.getUri().orElse(null);
        var method = ((MethodSignature) proceedingJoinPoint.getSignature()).getMethod();
        var annotation = method.getAnnotation(com.seliote.fr.annotation.ApiFrequency.class);
        var args = proceedingJoinPoint.getArgs();
        var current = getFrequency(uri, annotation.key(), annotation.time(), annotation.timeUnit(), args);
        if (current <= annotation.frequency()) {
            log.info("Pass frequency check for: {}, current: {}, args: {}",
                    uri, current, Arrays.toString(args));
            return proceedingJoinPoint.proceed();
        }
        var msg = String.format("Frequency too high, reject request: %s, args: %s, current: %s",
                uri, Arrays.toString(args), current);
        // Pass msg as an argument so braces from request args are not treated as placeholders
        log.warn("{}", msg);
        throw new FrequencyException(msg);
    }

    /**
     * Returns the access count for the current window, including this call
     * (the Redis counter is incremented as a side effect).
     * <p>
     * Returns 0, so the check passes, when the count cannot be determined: missing URI, blank key,
     * a window that is not a positive whole number of seconds (at most {@code Integer.MAX_VALUE}),
     * an argument count other than one, a null argument, or a key property that is missing, has no
     * getter, or is null.
     *
     * @param uri      the requested URI
     * @param key      request-argument property names identifying the caller, joined with {@code &&}
     * @param time     window length
     * @param timeUnit unit of {@code time}
     * @param args     the intercepted method's arguments; exactly one is expected
     * @return the access count in the current window, or 0 if it could not be determined
     */
    private long getFrequency(@Nullable String uri, @Nullable String key, int time,
                              @Nullable TimeUnit timeUnit, @Nullable Object... args) {
        // time2Seconds truncates sub-second windows to 0, which would make `now % seconds` throw
        // ArithmeticException below; negative windows would give setex a negative TTL, and setex
        // only accepts an int TTL
        var seconds = timeUnit == null ? 0L : CommonUtils.time2Seconds(time, timeUnit);
        if (uri == null || key == null || key.isBlank() || seconds <= 0 || seconds > Integer.MAX_VALUE
                || args == null || args.length != 1 || args[0] == null) {
            log.warn("Get frequency args incorrect, uri: {}, key: {}, time: {}, timeUnit: {}, args: {}",
                    uri, key, time, timeUnit, Arrays.toString(args));
            return 0;
        }
        var arg = args[0];
        var properties = key.split(KEY_SEPARATOR);
        // [0] is the Redis namespace, [1..n] the key property values, [last] the window start
        var identifiers = new String[properties.length + 2];
        identifiers[0] = redisNameSpace;
        for (var i = 0; i < properties.length; ++i) {
            // Tolerate whitespace around the separator, e.g. "countryCode && telNo"
            var property = properties[i].strip();
            Object value;
            try {
                value = readProperty(arg, property);
            } catch (IntrospectionException | IllegalAccessException | IllegalArgumentException
                    | InvocationTargetException exception) {
                log.warn("Get frequency occur: {}, message: {}, uri: {}, key: {}, " +
                                "time: {}, timeUnit: {}, args: {}, exception occur at {}",
                        getClassName(exception), exception.getMessage(), uri, key, time,
                        timeUnit, Arrays.toString(args), property);
                return 0;
            }
            if (value == null) {
                log.warn("Get frequency property missing or null, uri: {}, key: {}, args: {}, property: {}",
                        uri, key, Arrays.toString(args), property);
                return 0;
            }
            identifiers[i + 1] = value.toString();
        }
        var now = Instant.now().getEpochSecond();
        identifiers[identifiers.length - 1] = String.valueOf(now - now % seconds);
        var redisKey = redisService.formatKey(identifiers);
        // NOTE(review): exists -> setex -> incr is not atomic. Two concurrent first requests can both
        // see the key as missing, and the second setex can reset a counter the first one already
        // incremented, letting extra requests through. Also, if the key expires between exists and
        // incr, incr recreates it without a TTL. An atomic INCR + EXPIRE (MULTI/EXEC or a Lua script)
        // would fix this, but needs RedisService support that is not visible here.
        if (!redisService.exists(redisKey)) {
            redisService.setex(redisKey, (int) seconds, "0");
        }
        return redisService.incr(redisKey);
    }

    /**
     * Reads a bean property from {@code bean} through its getter.
     * Unlike {@code new PropertyDescriptor(name, beanClass)}, this does not require a setter,
     * so read-only request objects are supported.
     *
     * @param bean the object to read from
     * @param name the bean property name
     * @return the property value, or {@code null} if there is no readable property with that name
     *         or its value is null
     */
    @Nullable
    private static Object readProperty(Object bean, String name)
            throws IntrospectionException, IllegalAccessException, InvocationTargetException {
        for (var descriptor : Introspector.getBeanInfo(bean.getClass()).getPropertyDescriptors()) {
            var getter = descriptor.getReadMethod();
            if (getter != null && descriptor.getName().equals(name)) {
                return getter.invoke(bean);
            }
        }
        return null;
    }
}
用到的一个工具方法:
package com.seliote.fr.util;
import com.seliote.fr.exception.UtilException;
import lombok.extern.log4j.Log4j2;
import org.springframework.lang.NonNull;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SecureRandom;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import static com.seliote.fr.util.ReflectUtils.getClassName;
/**
* 通用工具
*
* @author seliote
*/
@Log4j2
public class CommonUtils {
...
/**
 * Expresses a duration in whole seconds.
 * Finer units are truncated toward zero, as with {@link TimeUnit#toSeconds(long)}
 * (for example, 1500 milliseconds becomes 1 second and 500 milliseconds becomes 0).
 *
 * @param time     the duration
 * @param timeUnit unit of {@code time}
 * @return {@code time} converted to seconds
 */
public static long time2Seconds(int time, TimeUnit timeUnit) {
    final long seconds = timeUnit.toSeconds(time);
    log.trace("CommonUtils.time2Seconds(int, TimeUnit) for {} {}, result {}", time, timeUnit, seconds);
    return seconds;
}
...
实际使用:
...
/**
 * Gets the image captcha for the SMS flow.
 * Rate limited to one call per minute for each (countryCode, telNo) pair.
 *
 * @param ci CI
 * @return CO
 */
@ApiFrequency(key = "countryCode&&telNo", frequency = 1, time = 1, timeUnit = TimeUnit.MINUTES)
@RequestMapping("sms")
@ResponseBody
public Co<SmsCo> sms(@Valid @RequestBody SmsCi ci) {
    var si = BeanUtils.copy(ci, SmsSi.class);
    var captchaBytes = captchaService.sms(si);
    if (!captchaBytes.isEmpty()) {
        // The captcha bytes are returned to the client base64-encoded
        var encoded = TextUtils.base64Encode(captchaBytes.get());
        return Co.cco(new SmsCo(encoded));
    }
    log.error("Get captcha for sms, service return empty");
    throw new ServiceException("service return empty");
}
...