'    :单引号,用于包裹字符串类型的搜索条件,需转换为 \'

% :用于代替任意数目的任意字符,需转换为 \%

_   :用于代替一个任意字符,需转换为 \_

\    :转义符号,需转换为 \\ (在 Java 字符串字面量中写作 "\\\\")

\t   :  制表符,需转换为 \\t

\n :  换行符,需转换为 \\n

* : 需转换为 \*

str.replaceAll("\\\\", "\\\\\\\\")
   .replace("\'", "\\'")
   .replace("%", "\\%")
   .replace("_", "\\_")
   .replace("\t", "\\\\t")
   .replace("\n", "\\\\n")
   .replace("*", "\\*");

可以使用 AOP 切面,或者 MyBatis / MyBatis-Plus 的处理器,统一进行替换。下面以切面方式为例:

/**
 * Marks a controller method whose request parameter (GET) or request-body field (POST)
 * should have SQL special characters escaped by the aspect before the method body runs.
 *
 * <p>NOTE(review): the aspect and controller examples in this article refer to this
 * annotation as {@code SqlCharTrans}; the declared name must match what they use.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface CharTrans {

    /** Name of the method parameter holding the object whose field is escaped (POST only). */
    String target() default "";

    /** Name of the field (POST) or of the method parameter (GET) whose value is escaped. */
    String field() default "";

    /** Request type; selects how {@link #field()} and {@link #target()} are interpreted. */
    RequestMethod reqType();

}
/**
 * Aspect that escapes SQL special characters in one configured request value of methods carrying
 * the {@code SqlCharTrans} annotation, then invokes the method with the modified arguments.
 *
 * <p>NOTE(review): the pointcut below names {@code CharTrans} (with a placeholder package path),
 * while the advice reads {@code SqlCharTrans}; both must denote the same annotation type, and the
 * pointcut must use that type's fully qualified name.
 */
@Configuration
@Aspect
@Slf4j
public class SqlCharTransAspect {

    /** Pointcut: every method annotated with the escaping annotation. */
    @Pointcut("@annotation(com.***.**.**.CharTrans)")
    public void logPointCut() {
    }

    /**
     * Escapes the configured value in the intercepted call's arguments, then proceeds with them.
     *
     * @param joinPoint the intercepted method invocation
     * @return whatever the intercepted method returns
     * @throws Throwable anything thrown by the intercepted method or by reflective field access
     */
    @Around("logPointCut()")
    public Object transBefor(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature();
        Method method = methodSignature.getMethod();
        String[] names = methodSignature.getParameterNames();
        Object[] args = joinPoint.getArgs();
        SqlCharTrans sqlCharTrans = method.getAnnotation(SqlCharTrans.class);
        // getMethod() may resolve to an interface method (JDK proxies) that lacks the annotation,
        // and parameter names may be unavailable without -parameters/debug info. Without either,
        // nothing can be matched, so run the method unchanged instead of failing with an NPE.
        if (sqlCharTrans == null || names == null) {
            log.warn("SqlCharTrans escaping skipped for {}: annotation or parameter names unavailable", method);
            return joinPoint.proceed(args);
        }
        switch (sqlCharTrans.reqType()) {
            case GET:
                dealGetParam(names, args, sqlCharTrans);
                break;
            case POST:
                dealPostData(names, args, sqlCharTrans);
                break;
            default:
                break;
        }
        return joinPoint.proceed(args);
    }

    /**
     * POST: finds the argument whose parameter name equals {@code target()} and escapes its
     * String field named {@code field()} in place (fields of superclasses are searched too).
     */
    private void dealPostData(String[] names, Object[] args, SqlCharTrans sqlCharTrans) throws IllegalAccessException {
        String field = sqlCharTrans.field();
        String target = sqlCharTrans.target();
        if (StrUtil.isBlank(target) || StrUtil.isBlank(field)) {
            return;
        }
        for (int i = 0; i < names.length; i++) {
            if (!target.equals(names[i])) {
                continue;
            }
            Object dto = args[i];
            if (dto == null) {
                // e.g. an optional request body that was not sent: nothing to escape
                return;
            }
            Field f = findField(dto.getClass(), field);
            if (f != null) {
                // the field is typically private, so reflective access must be enabled
                f.setAccessible(true);
                Object value = f.get(dto);
                if (value instanceof String) {
                    String newFieldValue = judgeChar((String) value);
                    if (StrUtil.isNotBlank(newFieldValue)) {
                        f.set(dto, newFieldValue);
                    }
                } else if (value != null) {
                    log.warn("SqlCharTrans field {} is not a String, skipping escaping", field);
                }
            }
            return;
        }
    }

    /** Returns the field named {@code name} declared on {@code clz} or a superclass, or null. */
    private static Field findField(Class<?> clz, String name) {
        for (Class<?> c = clz; c != null && c != Object.class; c = c.getSuperclass()) {
            for (Field f : c.getDeclaredFields()) {
                if (name.equals(f.getName())) {
                    return f;
                }
            }
        }
        return null;
    }

    /**
     * GET: escapes the String argument whose parameter name equals {@code field()}.
     */
    private void dealGetParam(String[] names, Object[] args, SqlCharTrans sqlCharTrans) {
        String field = sqlCharTrans.field();
        if (StrUtil.isBlank(field)) {
            return;
        }
        for (int i = 0; i < names.length; i++) {
            if (field.equals(names[i])) {
                // non-String parameters (numbers, DTOs) cannot be escaped; leave them untouched
                if (args[i] instanceof String) {
                    String newFieldValue = judgeChar((String) args[i]);
                    if (StrUtil.isNotBlank(newFieldValue)) {
                        args[i] = newFieldValue;
                    }
                }
                return;
            }
        }
    }

    /**
     * Escapes characters that are special in SQL string literals / LIKE patterns.
     *
     * <p>The backslash is escaped first so that backslashes added by the later replacements are
     * not doubled again.
     *
     * <p>NOTE(review): if the value is bound as a prepared-statement parameter (MyBatis #{...}),
     * only backslash, {@code %} and {@code _} are special in a MySQL LIKE pattern; turning a tab or
     * newline into backslash-backslash-t/n then matches a literal backslash followed by 't'/'n'
     * rather than the control character. Confirm against how the query is built.
     *
     * @param fieldValue raw value, may be null
     * @return the escaped value, or "" when {@code fieldValue} is null
     */
    private String judgeChar(String fieldValue) {
        if (fieldValue == null) {
            return "";
        }
        // literal replace(): no regex/replacement-string escaping needed
        return fieldValue
                .replace("\\", "\\\\")
                .replace("'", "\\'")
                .replace("%", "\\%")
                .replace("_", "\\_")
                .replace("\t", "\\\\t")
                .replace("\n", "\\\\n")
                .replace("*", "\\*");
    }
}
    /**
     * Example POST endpoint whose request-body field {@code bookName} is escaped by the aspect.
     *
     * <p>{@code target} must be the method parameter's name ("book"), not its type name ("Book"),
     * because the aspect compares it against the method's parameter names.
     */
    @PostMapping("/getList")
    @SqlCharTrans(target = "book", field = "bookName", reqType = RequestMethod.POST)
    public ResponseObject getList(@RequestBody @Validated({Book.ListJobView.class}) Book book) {
        // placeholder: method body omitted in this example
        pass
    }


    /**
     * Example GET endpoint: the optional {@code bookName} request parameter is escaped by the
     * aspect before this method body runs.
     */
    @GetMapping("/getById")
    @SqlCharTrans(reqType = RequestMethod.GET, field = "bookName")
    public ResponseObject getById(@RequestParam(required = false, value = "bookName") String bookName) {
        // placeholder: method body omitted in this example
        pass
    }

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐