在项目中,我们经常需要对接外部厂商,或者向外部厂商提供接口。这时就需要保障接口的安全性,防止接口被恶意调用或攻击,其中一种常用手段就是限制接口的调用频次。

1、自定义一个注解

其中 count 为一个时间段内允许的最大调用次数(默认 Integer.MAX_VALUE,即相当于不限制);
time 为统计时间段,单位为毫秒(默认 60000,即一分钟)

import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;

import java.lang.annotation.*;

/**
 * Declares a call-frequency limit for a controller method or a whole controller class:
 * at most {@link #count()} calls are allowed within each window of {@link #time()} milliseconds.
 *
 * <p>Retained at runtime so an interceptor can read it reflectively; {@code @Inherited} lets a
 * class-level limit apply to subclasses as well.
 *
 * <p>NOTE(review): {@code @Order} on an annotation type does not order interceptors; Spring may
 * instead pick it up as a meta-annotation on classes annotated with {@code @RequestLimit}.
 * Confirm that this is intended.
 */
@Documented
@Inherited
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Order(Ordered.HIGHEST_PRECEDENCE) // highest precedence
public @interface RequestLimit {

    /**
     * Maximum number of calls permitted per window.
     * Defaults to {@link Integer#MAX_VALUE}, i.e. effectively unlimited.
     */
    int count() default Integer.MAX_VALUE;

    /**
     * Window length in milliseconds. Defaults to one minute.
     */
    long time() default 60_000L;
}

2、使用注解

	/**
	 * Searches SMS send records matching the given signature that were created after the given time.
	 *
	 * <p>Limited to 10 calls per window (default window from {@link RequestLimit#time()}) when
	 * RequestLimitInterceptor is mapped to this path. NOTE(review): confirm the interceptor's
	 * mvc mapping actually covers this controller's URL.
	 *
	 * @param messageRecord query parameters; signName and createTime are required
	 * @param request       the current request, used to resolve and authorize the caller's IP
	 * @return a success model wrapping the matching records, or a failure model describing the problem
	 */
	@RequestLimit(count = 10)
	@RequestMapping("/sendMessageResultSearch")
	public ResponseModel sendMessageResultSearch(@RequestBody MessageRecord messageRecord, HttpServletRequest request) {
		// Resolve the caller's IP and reject callers that are not on the allow list.
		String remoteAddr = RequestUtil.getRemoteAddr(request);
		if (!authIp.contains(remoteAddr)) {
			return RespModelFactory.fail("您的ip暂未授权,请联系管理员");
		}
		if (null == messageRecord || StringUtils.isEmpty(messageRecord.getSignName()) || null == messageRecord.getCreateTime()) {
			return RespModelFactory.fail("查询参数不能为空");
		}
		try {
			// Query all SMS send results created after the given time.
			List<MessageRecord> list = this.service.sendMessageResultSearch(messageRecord);
			if (CollectionUtils.isNotEmpty(list)) {
				for (MessageRecord record : list) {
					String createTimeStr = DateUtil.formatDate(record.getCreateTime(), "yyyy-MM-dd HH:mm:ss");
					record.setCreateTimeStr(createTimeStr);
				}
			}
			return RespModelFactory.success(list);
		} catch (Exception e) {
			// A failed query is an error; pass the exception as the throwable so the stack trace is logged.
			logger.error("查询短信报错", e);
			return RespModelFactory.fail("查询短信报错");
		}
	}

3、重写拦截器实现接口拦截

这里需要用缓存来记录每个 IP 对每个接口的调用次数,具体可根据项目实际情况自行选择,一般可以使用 Redis。我这里使用的是自定义的本地缓存工具类(注意:本地缓存只在当前 JVM 内有效,集群部署时各节点会分别计数)。

import com.alibaba.fastjson.JSON;
import com.google.common.collect.Maps;
import com.paas.common.annotion.RequestLimit;
import com.paas.common.util.ResponseCodeEnum;
import com.paas.common.web.util.ConcurrentHashMapCacheUtils;
import com.paas.common.web.util.RequestUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;


/**
 * Spring MVC interceptor that enforces {@link RequestLimit} call-frequency limits per client IP
 * and request URL.
 *
 * <p>A method-level {@code @RequestLimit} takes precedence over a class-level one. Handlers
 * without the annotation, and handlers that are not controller methods, are passed through.
 * When the limit is exceeded, a JSON body with {@code retCode}/{@code retDesc} is written and the
 * request is not dispatched to the handler.
 *
 * <p>Counters live in the JVM-wide {@link ConcurrentHashMapCacheUtils} singleton, so limits are
 * per JVM, not per cluster.
 */
public class RequestLimitInterceptor extends HandlerInterceptorAdapter {

    private static final Logger logger = LoggerFactory.getLogger("RequestLimitInterceptor");

    /**
     * Serializes the read-increment-write on the shared counter cache so concurrent requests cannot
     * lose increments. It is static because the cache itself is a JVM-wide singleton.
     */
    private static final Object COUNTER_LOCK = new Object();

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        // Only controller methods can carry @RequestLimit; static resources etc. are not HandlerMethods.
        if (!(handler instanceof HandlerMethod)) {
            return true;
        }
        RequestLimit limit = findLimit((HandlerMethod) handler);
        if (limit == null || validateCode(request, limit.count(), limit.time())) {
            return true;
        }
        writeLimitExceeded(response);
        return false;
    }

    /**
     * Returns the method-level annotation if present, otherwise the class-level one, otherwise null.
     */
    private RequestLimit findLimit(HandlerMethod handlerMethod) {
        RequestLimit methodAnnotation = handlerMethod.getMethodAnnotation(RequestLimit.class);
        if (methodAnnotation != null) {
            return methodAnnotation;
        }
        // getBeanType() resolves the handler's user class even when the bean is proxied or referenced by name.
        return handlerMethod.getBeanType().getAnnotation(RequestLimit.class);
    }

    /**
     * Writes the "too many requests" JSON body to the response.
     *
     * @throws IOException if the response writer cannot be obtained; the error is logged and rethrown
     */
    private void writeLimitExceeded(HttpServletResponse response) throws IOException {
        Map<String, Object> resultMap = Maps.newHashMap();
        resultMap.put("retCode", ResponseCodeEnum.REQUESTFULL.getRetCode());
        resultMap.put("retDesc", ResponseCodeEnum.REQUESTFULL.getRetDesc());
        response.setCharacterEncoding("UTF-8");
        response.setContentType("application/json;charset=UTF-8");
        try (PrintWriter pw = response.getWriter()) {
            pw.write(JSON.toJSONString(resultMap));
            pw.flush();
        } catch (IOException e) {
            logger.error("返回页面数据出错!" + e.getMessage(), e);
            throw e;
        }
    }

    /**
     * Records one call for the current client IP and URL and checks it against the limit.
     *
     * <p>The first call opens a window of {@code timeOut} ms. Later calls within that window increment
     * the counter without extending the window. Once the window has elapsed, a new one is opened.
     * Any unexpected error fails open, i.e. the call is allowed.
     *
     * @param request the current request
     * @param maxSize maximum calls allowed per window
     * @param timeOut window length in milliseconds
     * @return true if the call is within the limit, false if it exceeds it
     */
    private boolean validateCode(HttpServletRequest request, int maxSize, long timeOut) {
        boolean resultCode = true;
        try {
            String ip = RequestUtil.getRemoteAddr(request);
            String url = request.getRequestURL().toString();
            // The '|' separator keeps different url/ip pairs from concatenating to the same key.
            String key = "req_limit_" + url + "|" + ip;

            int count;
            ConcurrentHashMapCacheUtils hashMapCache = ConcurrentHashMapCacheUtils.getHashMapCache();
            synchronized (COUNTER_LOCK) {
                hashMapCache.deleteTimeOut(); // purge expired counters so the map does not grow unbounded
                Integer cacheCount = (Integer) hashMapCache.get(key);
                long remaining = hashMapCache.getExpire(key);
                if (cacheCount == null || remaining <= 0) {
                    // No live window. Start a new one; never re-store with 0 ms, which the cache treats as "never expire".
                    count = 1;
                    hashMapCache.set(key, count, timeOut);
                } else {
                    // Keep the existing window's end time by re-storing with the remaining lifetime.
                    count = cacheCount + 1;
                    hashMapCache.set(key, count, remaining);
                }
            }
            if (count > maxSize) {
                logger.info("用户IP[{}]访问地址[{}]超过了限定的次数[{}]", ip, url, maxSize);
                resultCode = false;
            }
        } catch (Exception e) {
            logger.error("发生异常: ", e);
        }
        return resultCode;
    }
}

4、IP限制

在上述代码中可以看到 RequestUtil 中的 getRemoteAddr 方法。通过这个方法我们可以获取远程调用方的 IP 地址,从而据此做出限制。需要注意:X-Forwarded-For 等请求头可以由客户端任意伪造,只有当可信的反向代理会覆盖这些请求头时,取到的 IP 才可信;否则基于 IP 的白名单和限流都可能被绕过。

	/**
	 * Returns the client's remote address.
	 *
	 * <p>When the request came through a proxy, the proxy-supplied headers are consulted in this order:
	 * X-Forwarded-For, Proxy-Client-IP, WL-Proxy-Client-IP, HTTP_CLIENT_IP, HTTP_X_FORWARDED_FOR.
	 * For each header, the first comma-separated entry that is neither blank nor "unknown" is returned
	 * (trimmed). If no header yields a usable value, {@link HttpServletRequest#getRemoteAddr()} is used.
	 *
	 * <p>NOTE(security): these headers are sent by the client and can be forged unless a trusted
	 * reverse proxy strips or overwrites them. Do not rely on this value alone for IP allow-listing or
	 * rate limiting in that case.
	 *
	 * @param request the current request
	 * @return the best-effort client IP
	 */
	public static String getRemoteAddr(HttpServletRequest request) {
		String[] headers = {
				"X-Forwarded-For",
				"Proxy-Client-IP",
				"WL-Proxy-Client-IP",
				"HTTP_CLIENT_IP",
				"HTTP_X_FORWARDED_FOR"
		};
		for (String header : headers) {
			String ip = firstUsableIpFromHeader(request.getHeader(header));
			if (ip != null) {
				return ip;
			}
		}
		return request.getRemoteAddr();
	}

	/**
	 * Returns the first trimmed entry of a comma-separated header value that is neither blank nor
	 * "unknown" (case-insensitive), or null if the header is absent or has no such entry.
	 */
	private static String firstUsableIpFromHeader(String headerValue) {
		if (headerValue == null) {
			return null;
		}
		for (String candidate : headerValue.split(",")) {
			String ip = candidate.trim();
			if (!ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
				return ip;
			}
		}
		return null;
	}

5、自定义缓存工具类

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;


/**
 * A minimal in-process key/value cache with optional per-entry expiry, backed by one static
 * {@link ConcurrentHashMap} shared by the whole JVM and exposed as a singleton via
 * {@link #getHashMapCache()}.
 *
 * <p>Entries stored without an expiry, or with an expiry {@code <= 0}, never expire. Expired entries
 * are not removed automatically: {@link #get(String)} only hides them, and {@link #deleteTimeOut()}
 * must be called periodically to reclaim memory.
 *
 * <p>Each individual operation is thread-safe, but compound sequences such as get-then-set are not
 * atomic; callers needing atomic read-modify-write must synchronize externally.
 */
public class ConcurrentHashMapCacheUtils {

    /** Value returned by {@link #getExpire(String)} for entries that never expire. */
    private static final long NO_EXPIRY_REMAINING = 999999999999999L;

    // Backing store. Declared as ConcurrentHashMap so the atomic conditional remove(key, value) is available.
    private static final ConcurrentHashMap<String, ExpireData> CACHE_OBJECT_MAP = new ConcurrentHashMap<>();

    /**
     * Stores a value that never expires.
     *
     * @param key   key; null is rejected
     * @param value value (may be null)
     * @return true on success, false if the key is null
     */
    public boolean set(String key, Object value) {
        return put(key, new ExpireData(key, value));
    }

    /**
     * Stores a value with an expiry.
     *
     * @param key               key; null is rejected
     * @param value             value (may be null)
     * @param expireMillisecond lifetime in milliseconds; a value {@code <= 0} stores the entry without expiry
     * @return true on success, false if the key is null
     */
    public boolean set(String key, Object value, long expireMillisecond) {
        if (expireMillisecond <= 0) {
            return set(key, value);
        }
        return put(key, new ExpireData(key, value, expireMillisecond));
    }

    // Inserts an entry; ConcurrentHashMap rejects null keys, which is reported as false instead of throwing.
    private boolean put(String key, ExpireData data) {
        if (key == null) {
            return false;
        }
        CACHE_OBJECT_MAP.put(key, data);
        return true;
    }

    /**
     * Returns the value for a key.
     *
     * @param key key (must not be null)
     * @return the value, or null if the key is absent or its entry has expired
     * @throws NullPointerException if key is null
     */
    public Object get(String key) {
        ExpireData expireData = CACHE_OBJECT_MAP.get(key);
        if (expireData == null || expireData.isExpired(System.currentTimeMillis())) {
            return null;
        }
        return expireData.getValue();
    }

    /**
     * Removes a key.
     *
     * @param key key
     * @return true if the key was removed or was absent, false if the key is null
     */
    public boolean delete(String key) {
        if (key == null) {
            return false;
        }
        CACHE_OBJECT_MAP.remove(key);
        return true;
    }

    /**
     * Removes every entry.
     *
     * @return always true
     */
    public boolean flush() {
        CACHE_OBJECT_MAP.clear();
        return true;
    }

    /**
     * Tells whether a key currently maps to a non-null, unexpired value.
     *
     * @param key key
     * @return true if present, false otherwise (including for a null key)
     */
    public boolean hasKey(String key) {
        return key != null && get(key) != null;
    }

    /**
     * Returns the remaining lifetime of a key.
     *
     * @param key key (must not be null)
     * @return remaining milliseconds; 0 if absent or expired; 999999999999999 if the entry never expires
     * @throws NullPointerException if key is null
     */
    public long getExpire(String key) {
        ExpireData expireData = CACHE_OBJECT_MAP.get(key);
        if (expireData == null) {
            return 0;
        }
        if (expireData.getExpireMillisecond() == 0) {
            return NO_EXPIRY_REMAINING;
        }
        return Math.max(0L, expireData.getEndTime() - System.currentTimeMillis());
    }

    /**
     * Resets the lifetime of an existing entry, keeping its value.
     *
     * @param key               key
     * @param expireMillisecond new lifetime in milliseconds; a value {@code <= 0} makes the entry never expire
     * @return true on success, false if the key is null or absent
     */
    public boolean expire(String key, long expireMillisecond) {
        if (key == null) {
            return false;
        }
        ExpireData expireData = CACHE_OBJECT_MAP.get(key);
        if (expireData == null) {
            return false;
        }
        return set(key, expireData.getValue(), expireMillisecond);
    }

    /**
     * Returns the keys that currently satisfy {@link #hasKey(String)}.
     *
     * @return a new, mutable list of keys
     */
    public List<String> getKeys() {
        List<String> list = new ArrayList<String>();
        for (String key : CACHE_OBJECT_MAP.keySet()) {
            if (hasKey(key)) {
                list.add(key);
            }
        }
        return list;
    }

    /**
     * Removes all expired entries. Call this periodically (or per request, as the interceptor does);
     * otherwise expired entries accumulate.
     */
    public void deleteTimeOut() {
        long now = System.currentTimeMillis();
        for (Map.Entry<String, ExpireData> entry : CACHE_OBJECT_MAP.entrySet()) {
            if (entry.getValue().isExpired(now)) {
                // Conditional remove: only drop the exact entry inspected, never one a concurrent set() just stored.
                CACHE_OBJECT_MAP.remove(entry.getKey(), entry.getValue());
            }
        }
    }

    /**
     * Immutable stored entry: value plus expiry bookkeeping. An expireMillisecond of 0 means "never expires".
     */
    private static final class ExpireData {
        private final String key;
        private final Object value;
        private final long expireMillisecond; // lifetime; 0 = no expiry
        private final long startTime;         // creation time (0 for non-expiring entries)
        private final long endTime;           // expiry instant (0 for non-expiring entries)

        ExpireData(String key, Object value) {
            this.key = key;
            this.value = value;
            this.expireMillisecond = 0L;
            this.startTime = 0L;
            this.endTime = 0L;
        }

        ExpireData(String key, Object value, long expireMillisecond) {
            this.key = key;
            this.value = value;
            this.expireMillisecond = expireMillisecond;
            this.startTime = System.currentTimeMillis();
            long end = startTime + expireMillisecond;
            // Saturate on overflow so a huge lifetime does not wrap into an already-expired instant.
            this.endTime = end < startTime ? Long.MAX_VALUE : end;
        }

        boolean isExpired(long now) {
            return expireMillisecond != 0 && now >= endTime;
        }

        String getKey() {
            return key;
        }

        Object getValue() {
            return value;
        }

        long getExpireMillisecond() {
            return expireMillisecond;
        }

        long getStartTime() {
            return startTime;
        }

        long getEndTime() {
            return endTime;
        }
    }

    /**
     * Private constructor so the class can only be instantiated by {@link SingletonHolder}.
     */
    private ConcurrentHashMapCacheUtils() {
        // Reject a second instance created via reflection once the singleton exists.
        if (SingletonHolder.hashMapCache != null) {
            throw new RuntimeException();
        }
    }

    /**
     * Lazily-initialized holder for the singleton instance.
     */
    private static class SingletonHolder {
        private static final ConcurrentHashMapCacheUtils hashMapCache = new ConcurrentHashMapCacheUtils();
    }

    /**
     * Returns the shared cache instance.
     *
     * @return the singleton
     */
    public static ConcurrentHashMapCacheUtils getHashMapCache() {
        return SingletonHolder.hashMapCache;
    }

    /**
     * Small demo: increments a counter under key "1" with a 6 s window, printing remaining lifetime and count.
     */
    public static void main(String[] args) throws InterruptedException {
        ConcurrentHashMapCacheUtils hashMapCache = ConcurrentHashMapCacheUtils.getHashMapCache();

        for (int i = 0; i < 4; i++) {
            Integer count = 0;
            Integer cacheCount = (Integer) hashMapCache.get("1");
            if (null == cacheCount) {
                cacheCount = 0;
            }
            count = cacheCount + 1;
            if (count == 1) {
                hashMapCache.set("1", count, 6000);
            } else {
                hashMapCache.set("1", count, hashMapCache.getExpire("1"));
            }

            Thread.sleep(1000);
            System.out.println(hashMapCache.getExpire("1"));

            System.out.println(count);
        }
    }

}

6、在mvc拦截器配置文件中添加

        <!-- Call-frequency limiting for externally exposed APIs: RequestLimitInterceptor runs for every
             request under /pub/handleMessage/**, but only handlers annotated with @RequestLimit
             (on the method or its class) are actually limited; all others pass through. -->
        <mvc:interceptor>
            <mvc:mapping path="/pub/handleMessage/**" />
            <bean class="com.paas.common.web.filter.RequestLimitInterceptor" />
        </mvc:interceptor>

坚持、细心!!!

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐