自定义 MyBatis 拦截器

/**
 * MyBatis interceptor that helps generate audit logs.
 * <p>
 * Intercepts {@code StatementHandler.prepare}. For statements whose command type is configured
 * (INSERT, UPDATE and DELETE by default) and which touch an audited table, it prepends a SQL
 * block comment carrying the operator id and the trace id (see SQL_TEMPLATE).
 * <p>
 * Configuration properties:
 * <ul>
 *   <li>{@code auditTables} (required): comma-separated {@code schema.table} entries. Each entry
 *   matches a table either literally or as a Java regex, case-insensitively.</li>
 *   <li>{@code sqlType} (optional): Druid db type used to parse the SQL, default MySQL.</li>
 *   <li>{@code commandTypes} (optional): comma-separated {@code SqlCommandType} names.</li>
 * </ul>
 * Setting the system property {@code auditlog.interceptor.enable=false} disables interception.
 *
 * @date 2019-07-05
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare",
        args = {Connection.class, Integer.class})})
public class AuditLogInterceptor implements Interceptor {

    private static final Logger logger = LoggerFactory.getLogger(AuditLogInterceptor.class);

    /**
     * Template for the rewritten SQL. The first two arguments are the operator id and trace id,
     * and they are wrapped in a leading block comment. The third is the original statement.
     */
    private static final String SQL_TEMPLATE = "/*@%s,%s@*/ %s";

    /**
     * Operator id written when no customer id is available.
     */
    private static final String UNKNOWN_OPERATOR = "unknown";

    /**
     * Regex flags for auditTables entries. Matching is case-insensitive, replacing the old
     * "upper-case the regex" approach, which corrupted escapes such as \w into \W.
     */
    private static final int AUDIT_TABLE_PATTERN_FLAGS =
            java.util.regex.Pattern.CASE_INSENSITIVE | java.util.regex.Pattern.UNICODE_CASE;

    /**
     * Shared reflector factory. DefaultReflectorFactory caches per-class reflection metadata, so
     * creating a new one for every statement threw that cache away each time.
     */
    private static final DefaultReflectorFactory REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /**
     * Whether interception is active. On by default; turned off by the system property
     * auditlog.interceptor.enable=false (read in the constructor).
     */
    private boolean enable = true;

    /**
     * Upper-cased auditTables entries (Locale.ROOT), compared literally against SCHEMA.TABLE.
     */
    private List<String> auditTables;

    /**
     * The same auditTables entries, compiled once as case-insensitive regexes.
     */
    private List<java.util.regex.Pattern> auditTablePatterns;

    /**
     * SqlCommandType names (upper case) whose statements are intercepted.
     */
    private List<String> commandTypes = Arrays.asList(SqlCommandType.INSERT.toString(),
            SqlCommandType.UPDATE.toString(), SqlCommandType.DELETE.toString());

    /**
     * Druid db type used to parse statements. It must stay lower case to match Druid's constants.
     */
    private String sqlType = JdbcConstants.MYSQL;


    /**
     * Creates the interceptor. It is enabled unless the system property
     * auditlog.interceptor.enable is set to a value other than "true" (case-insensitive).
     */
    public AuditLogInterceptor() {
        String interceptorEnabled = System.getProperty("auditlog.interceptor.enable");
        if (interceptorEnabled != null && !Boolean.parseBoolean(interceptorEnabled)) {
            this.enable = false;
        }
    }

    /**
     * Prepends the audit comment to the statement when the interceptor is enabled, the command
     * type is audited, and one of the referenced tables is audited. Otherwise the call simply
     * proceeds.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (!enable || !isAuditCommandType(invocation)) {
            return invocation.proceed();
        }

        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        if (isAuditTable(invocation, boundSql.getSql())) {
            modifySql(boundSql);
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads auditTables (required), sqlType and commandTypes.
     *
     * @param properties interceptor configuration
     * @throws IllegalArgumentException if auditTables is missing, blank or has no entries
     */
    @Override
    public void setProperties(Properties properties) {
        String auditTablesProperty = properties.getProperty("auditTables");
        if (StringUtils.isBlank(auditTablesProperty)) {
            throw new IllegalArgumentException("Property auditTables is required");
        }
        List<String> entries = splitList(auditTablesProperty);
        if (entries.isEmpty()) {
            throw new IllegalArgumentException("Property auditTables is required");
        }

        this.auditTables = entries.stream()
                .map(entry -> entry.toUpperCase(java.util.Locale.ROOT))
                .collect(Collectors.toList());
        this.auditTablePatterns = entries.stream()
                .map(AuditLogInterceptor::compileAuditTablePattern)
                .collect(Collectors.toList());

        String sqlTypeProperty = properties.getProperty("sqlType");
        if (StringUtils.isNotBlank(sqlTypeProperty)) {
            // Druid's db type constants (e.g. JdbcConstants.MYSQL is "mysql") are lower case and
            // compared with equals(), so an upper-cased value no longer matched its dialect.
            this.sqlType = sqlTypeProperty.trim().toLowerCase(java.util.Locale.ROOT);
        }

        String commandTypesProperty = properties.getProperty("commandTypes");
        if (StringUtils.isNotBlank(commandTypesProperty)) {
            // Entries are trimmed: "insert, update" previously produced " UPDATE", which never matched.
            // Locale.ROOT keeps the result comparable with SqlCommandType names in any default locale.
            this.commandTypes = splitList(commandTypesProperty).stream()
                    .map(type -> type.toUpperCase(java.util.Locale.ROOT))
                    .collect(Collectors.toList());
        }
    }

    /**
     * Splits a comma-separated property value into trimmed, non-empty entries.
     */
    private static List<String> splitList(String value) {
        return Arrays.stream(value.split(","))
                .map(String::trim)
                .filter(entry -> !entry.isEmpty())
                .collect(Collectors.toList());
    }

    /**
     * Compiles one auditTables entry as a case-insensitive regex. An entry that is not a valid
     * regex falls back to a literal match instead of making every later table check fail.
     */
    private static java.util.regex.Pattern compileAuditTablePattern(String entry) {
        try {
            return java.util.regex.Pattern.compile(entry, AUDIT_TABLE_PATTERN_FLAGS);
        } catch (java.util.regex.PatternSyntaxException e) {
            logger.warn("auditTables entry '{}' is not a valid regex, matching it literally", entry);
            return java.util.regex.Pattern.compile(java.util.regex.Pattern.quote(entry),
                    AUDIT_TABLE_PATTERN_FLAGS);
        }
    }

    /**
     * Decides whether the SQL touches any audited table.
     *
     * @param invocation the intercepted prepare call; its first argument is the Connection
     * @param sql        the SQL about to be prepared
     * @return true if some referenced SCHEMA.TABLE equals, or fully matches, an auditTables entry.
     *         Returns false when nothing is configured or parsing fails (the failure is logged).
     */
    private boolean isAuditTable(Invocation invocation, String sql) {
        if (CollectionUtils.isEmpty(auditTables)) {
            return false;
        }
        try {
            String schema = getSchema(invocation);
            List<String> tables = getTables(schema, sql);
            for (String table : tables) {
                String normalized = table.toUpperCase(java.util.Locale.ROOT);
                if (auditTables.contains(normalized)) {
                    return true;
                }
                for (java.util.regex.Pattern pattern : auditTablePatterns) {
                    if (pattern.matcher(normalized).matches()) {
                        return true;
                    }
                }
            }
            logger.debug("filtered tables: {}", tables);
            return false;
        } catch (Exception e) {
            logger.error("Failed check audit table", e);
            return false;
        }
    }



    /**
     * Checks whether the intercepted statement's SqlCommandType is one of commandTypes.
     *
     * @param invocation the intercepted prepare call
     * @return true if the command type is audited; false otherwise or if it cannot be determined
     */
    private boolean isAuditCommandType(Invocation invocation) {
        try {
            StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
            MetaObject metaObject = MetaObject.forObject(statementHandler,
                    SystemMetaObject.DEFAULT_OBJECT_FACTORY,
                    SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, REFLECTOR_FACTORY);
            // NOTE(review): assumes the target exposes "delegate.mappedStatement" (MyBatis'
            // RoutingStatementHandler). If another plugin wrapped the handler in a proxy, this
            // lookup presumably throws and the statement is not audited. Confirm plugin ordering.
            MappedStatement mappedStatement =
                    (MappedStatement) metaObject.getValue("delegate.mappedStatement");
            SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
            return sqlCommandType != null && commandTypes.contains(sqlCommandType.toString());
        } catch (Exception e) {
            logger.error("Failed check sql command type", e);
            return false;
        }
    }


    /**
     * Prepends a comment with the operator id and trace id to the SQL. The new SQL is written
     * back into the BoundSql's private "sql" field by reflection. Failures are logged and the
     * original SQL is left untouched.
     *
     * @param boundSql the statement's bound SQL, modified in place
     */
    private void modifySql(BoundSql boundSql) {
        try {
            Long customerId = CustomerIdGetterFactory.getCustomerId();
            String operatorId = customerId != null ? String.valueOf(customerId) : UNKNOWN_OPERATOR;
            String traceId = sanitizeCommentValue(String.valueOf(TraceId.get()));
            String updatedSql = String.format(SQL_TEMPLATE, operatorId, traceId, boundSql.getSql());
            logger.debug("updatedSql: {}", updatedSql);

            // BoundSql has no setter for its SQL. Look the field up on BoundSql.class itself so a
            // subclass instance does not make getDeclaredField fail.
            Field field = BoundSql.class.getDeclaredField("sql");
            field.setAccessible(true);
            field.set(boundSql, updatedSql);
        } catch (Exception e) {
            logger.error("Failed modify sql", e);
        }
    }

    /**
     * Removes every '*' from a value embedded in the leading block comment. Without '*', the value
     * can neither close the comment early nor open a nested one.
     * NOTE(review): trace ids are commonly propagated from request headers (untrusted). Confirm
     * the source of TraceId and whether stricter filtering is wanted.
     */
    private static String sanitizeCommentValue(String value) {
        return value.replace("*", "");
    }

    /**
     * Reads the current catalog (used as the schema name) from the connection being prepared on.
     *
     * @param invocation the intercepted prepare call; argument 0 is the Connection
     * @return the catalog name, or null if it cannot be read (the failure is logged)
     */
    private String getSchema(Invocation invocation) {
        Connection conn = (Connection) invocation.getArgs()[0];
        try {
            return conn.getMetaData().getConnection().getCatalog();
        } catch (SQLException e) {
            logger.error("Failed get scheme", e);
            return null;
        }
    }

    /**
     * Parses the SQL with Druid and collects each statement's current table. A table name is
     * qualified with the schema when it has no "." and a schema is known. Statements for which
     * the visitor reports no table are skipped; previously they caused a NullPointerException.
     *
     * @param schema schema used to qualify unqualified names, may be null
     * @param sql    SQL to parse
     * @return the table names, one per statement that has one
     */
    private List<String> getTables(String schema, String sql) {
        List<SQLStatement> statements = SQLUtils.parseStatements(sql, sqlType);
        return statements.stream()
                .map(statement -> {
                    // NOTE(review): the MySQL visitor is used whatever sqlType is configured.
                    MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
                    statement.accept(visitor);
                    return visitor.getCurrentTable();
                })
                .filter(table -> table != null)
                .map(table -> schema == null || table.contains(".") ? table : schema + "." + table)
                .collect(Collectors.toList());
    }
}

 

 

 

@Configuration
public class TestLogConfig {

    /**
     * Exposes the audit-log interceptor as a Spring bean. It is configured to audit tables
     * matching "test.*".
     * NOTE(review): the entry is treated as a regex, so "." matches any character and schemas
     * such as "testing" match too. Consider "test\\..*" if only schema "test" is meant.
     *
     * @return the configured interceptor
     */
    @Bean
    public AuditLogInterceptor auditLogInterceptor() {
        Properties config = new Properties();
        config.setProperty("auditTables", "test.*");

        AuditLogInterceptor interceptor = new AuditLogInterceptor();
        interceptor.setProperties(config);
        return interceptor;
    }
}
上一篇:dubbo源码分析- 集群容错之Cluster(一)


下一篇:为何在参数列表求值后执行空检查?