blob: d754b0cc72457f9a4a74c33e2c2ee9897d80c60d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.catalina.filters;
import jakarta.servlet.FilterChain;
import jakarta.servlet.FilterConfig;
import jakarta.servlet.ServletException;
import org.apache.catalina.Context;
import org.apache.catalina.filters.TestRemoteIpFilter.MockFilterChain;
import org.apache.catalina.filters.TestRemoteIpFilter.MockHttpServletRequest;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.startup.TomcatBaseTest;
import org.apache.tomcat.unittest.TesterResponse;
import org.apache.tomcat.util.descriptor.web.FilterDef;
import org.apache.tomcat.util.descriptor.web.FilterMap;
import org.junit.Assert;
import org.junit.Test;
import java.io.IOException;
import java.time.Instant;
public class TestRateLimitFilter extends TomcatBaseTest {
@Test
public void testRateLimitWith4Clients() throws Exception {
int bucketRequests = 40;
int bucketDuration = 4;
FilterDef filterDef = new FilterDef();
filterDef.addInitParameter(RateLimitFilter.PARAM_BUCKET_REQUESTS, String.valueOf(bucketRequests));
filterDef.addInitParameter(RateLimitFilter.PARAM_BUCKET_DURATION, String.valueOf(bucketDuration));
Tomcat tomcat = getTomcatInstance();
Context root = tomcat.addContext("", TEMP_DIR);
tomcat.start();
MockFilterChain filterChain = new MockFilterChain();
RateLimitFilter rateLimitFilter = testRateLimitFilter(filterDef, root);
int allowedRequests = (int) Math.round(rateLimitFilter.bucketCounter.getRatio() * bucketRequests);
long sleepTime = rateLimitFilter.bucketCounter.getMillisUntilNextBucket();
System.out.printf("Sleeping %d millis for the next time bucket to start\n", Long.valueOf(sleepTime));
Thread.sleep(sleepTime);
TestClient tc1 = new TestClient(rateLimitFilter, filterChain, "10.20.20.5", 200, 5);
TestClient tc2 = new TestClient(rateLimitFilter, filterChain, "10.20.20.10", 200, 10);
TestClient tc3 = new TestClient(rateLimitFilter, filterChain, "10.20.20.20", 200, 20);
TestClient tc4 = new TestClient(rateLimitFilter, filterChain, "10.20.20.40", 200, 40);
// Sleep for up to 10s for clients to complete
int count = 0;
while (count < 100 && (tc1.results[24] == 0 || tc2.results[49] == 0 || tc3.results[allowedRequests - 1] == 0 ||
tc3.results[allowedRequests] == 0 || tc4.results[allowedRequests - 1] == 0 ||
tc4.results[allowedRequests] == 0)) {
Thread.sleep(100);
count++;
}
Assert.assertEquals(200, tc1.results[24]); // only 25 requests made, all allowed
Assert.assertEquals(200, tc2.results[49]); // only 25 requests made, all allowed
Assert.assertEquals(200, tc3.results[allowedRequests - 1]); // first allowedRequests allowed
Assert.assertEquals(429, tc3.results[allowedRequests]); // subsequent requests dropped
Assert.assertEquals(200, tc4.results[allowedRequests - 1]); // first allowedRequests allowed
Assert.assertEquals(429, tc4.results[allowedRequests]); // subsequent requests dropped
}
private RateLimitFilter testRateLimitFilter(FilterDef filterDef, Context root) throws ServletException {
RateLimitFilter rateLimitFilter = new RateLimitFilter();
filterDef.setFilterClass(RateLimitFilter.class.getName());
filterDef.setFilter(rateLimitFilter);
filterDef.setFilterName(RateLimitFilter.class.getName());
root.addFilterDef(filterDef);
FilterMap filterMap = new FilterMap();
filterMap.setFilterName(RateLimitFilter.class.getName());
filterMap.addURLPatternDecoded("*");
root.addFilterMap(filterMap);
FilterConfig filterConfig = TesterFilterConfigs.generateFilterConfig(filterDef);
rateLimitFilter.init(filterConfig);
return rateLimitFilter;
}
static class TestClient extends Thread {
RateLimitFilter filter;
FilterChain filterChain;
String ip;
int requests;
int sleep;
int[] results;
TestClient(RateLimitFilter filter, FilterChain filterChain, String ip, int requests, int rps) {
this.filter = filter;
this.filterChain = filterChain;
this.ip = ip;
this.requests = requests;
this.sleep = 1000 / rps;
this.results = new int[requests];
super.setDaemon(true);
super.start();
}
@Override
public void run() {
try {
for (int i = 0; i < requests; i++) {
MockHttpServletRequest request = new MockHttpServletRequest();
request.setRemoteAddr(ip);
TesterResponse response = new TesterResponseWithStatus();
response.setRequest(request);
filter.doFilter(request, response, filterChain);
results[i] = response.getStatus();
System.out.printf("%s %s: %s %d\n", ip, Instant.now(), Integer.valueOf(i + 1),
Integer.valueOf(response.getStatus()));
Thread.sleep(sleep);
}
} catch (Exception ex) {
ex.printStackTrace();
}
}
}
static class TesterResponseWithStatus extends TesterResponse {
int status = 200;
String message = "OK";
@Override
public void sendError(int status, String message) throws IOException {
this.status = status;
this.message = message;
}
@Override
public int getStatus() {
return status;
}
}
}