Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions aikido_zen/middleware/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ async def __call__(self, scope, receive, send):
message = "You are rate limited by Zen."
if result["trigger"] == "ip" and result["ip"]:
message += " (Your IP: " + result["ip"] + ")"
return await send_status_code_and_text(send, (message, 429))
extra_headers = [
(b"retry-after", str(result["retry_after_seconds"]).encode())
]
return await send_status_code_and_text(send, (message, 429), extra_headers)

if result["type"] == "blocked":
return await send_status_code_and_text(
Expand All @@ -30,13 +33,16 @@ async def __call__(self, scope, receive, send):
return await self.app(scope, receive, send)


async def send_status_code_and_text(send, pre_response):
async def send_status_code_and_text(send, pre_response, extra_headers=None):
"""Sends a status code and text"""
headers = [(b"content-type", b"text/plain")]
if extra_headers:
headers = headers + extra_headers
await send(
{
"type": "http.response.start",
"status": pre_response[1],
"headers": [(b"content-type", b"text/plain")],
"headers": headers,
}
)
await send(
Expand Down
4 changes: 3 additions & 1 deletion aikido_zen/middleware/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def __call__(self, request):
message = "You are rate limited by Zen."
if result["trigger"] == "ip" and result["ip"]:
message += " (Your IP: " + result["ip"] + ")"
return self.HttpResponse(message, content_type="text/plain", status=429)
response = self.HttpResponse(message, content_type="text/plain", status=429)
response["Retry-After"] = str(result["retry_after_seconds"])
return response

if result["type"] == "blocked":
return self.HttpResponse(
Expand Down
1 change: 1 addition & 0 deletions aikido_zen/middleware/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __call__(self, environ, start_response):
if result["trigger"] == "ip" and result["ip"]:
message += " (Your IP: " + result["ip"] + ")"
res = self.Response(message, mimetype="text/plain", status=429)
res.headers["Retry-After"] = str(result["retry_after_seconds"])
return res(environ, start_response)

if result["type"] == "blocked":
Expand Down
3 changes: 2 additions & 1 deletion aikido_zen/middleware/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,14 @@ def test_cache_comms_with_endpoints():

mock_comms.send_data_to_bg_process.return_value = {
"success": True,
"data": {"block": True, "trigger": "my_trigger"},
"data": {"block": True, "trigger": "my_trigger", "retry_after_seconds": 10},
}
assert thread_cache.stats.rate_limited_hits == 0
assert should_block_request() == {
"block": True,
"ip": "1.1.1.1",
"type": "ratelimited",
"trigger": "my_trigger",
"retry_after_seconds": 10,
}
assert thread_cache.stats.rate_limited_hits == 1
1 change: 1 addition & 0 deletions aikido_zen/middleware/should_block_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def should_block_request():
"type": "ratelimited",
"trigger": ratelimit_res["data"]["trigger"],
"ip": context.remote_address,
"retry_after_seconds": ratelimit_res["data"]["retry_after_seconds"],
}
except Exception as e:
logger.debug("Exception occurred in should_block_request: %s", e)
Expand Down
30 changes: 21 additions & 9 deletions aikido_zen/ratelimiting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,46 @@ def should_ratelimit_request(
windows_size_in_ms = int(endpoint["rateLimiting"]["windowSizeInMS"])

if group:
allowed = connection_manager.rate_limiter.is_allowed(
result = connection_manager.rate_limiter.is_allowed(
get_key_for_group(endpoint, group),
windows_size_in_ms,
max_requests,
)
if not allowed:
return {"block": True, "trigger": "group"}
if not result["allowed"]:
return {
"block": True,
"trigger": "group",
"retry_after_seconds": result["retry_after_seconds"],
}

# Do not check IP or user rate limit if group is set
return {"block": False}
if user:
allowed = connection_manager.rate_limiter.is_allowed(
result = connection_manager.rate_limiter.is_allowed(
get_key_for_user(endpoint, user),
windows_size_in_ms,
max_requests,
)
if not allowed:
return {"block": True, "trigger": "user"}
if not result["allowed"]:
return {
"block": True,
"trigger": "user",
"retry_after_seconds": result["retry_after_seconds"],
}
# Do not check IP rate limit if user is set
return {"block": False}
if remote_address:
allowed = connection_manager.rate_limiter.is_allowed(
result = connection_manager.rate_limiter.is_allowed(
get_key_for_ip(endpoint, remote_address),
windows_size_in_ms,
max_requests,
)
if not allowed:
return {"block": True, "trigger": "ip"}
if not result["allowed"]:
return {
"block": True,
"trigger": "ip",
"retry_after_seconds": result["retry_after_seconds"],
}

return {"block": False}

Expand Down
103 changes: 53 additions & 50 deletions aikido_zen/ratelimiting/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def test_rate_limits_by_ip():
assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == {
"block": False
}
assert should_ratelimit_request(route_metadata, "1.2.3.4", None, cm) == {
"block": True,
"trigger": "ip",
}
result = should_ratelimit_request(route_metadata, "1.2.3.4", None, cm)
assert result["block"] is True
assert result["trigger"] == "ip"
assert result["retry_after_seconds"] >= 0


def test_rate_limiting_ip_allowed():
Expand Down Expand Up @@ -126,10 +126,10 @@ def test_rate_limiting_by_user(user):
assert should_ratelimit_request(route_metadata, "1.2.3.6", user, cm) == {
"block": False
}
assert should_ratelimit_request(route_metadata, "1.2.3.7", user, cm) == {
"block": True,
"trigger": "user",
}
result = should_ratelimit_request(route_metadata, "1.2.3.7", user, cm)
assert result["block"] is True
assert result["trigger"] == "user"
assert result["retry_after_seconds"] >= 0


def test_rate_limiting_with_wildcard():
Expand Down Expand Up @@ -160,9 +160,12 @@ def test_rate_limiting_with_wildcard():
) == {"block": False}

# This request should trigger the rate limit
assert should_ratelimit_request(
result = should_ratelimit_request(
create_route_metadata(route="/api/login"), "1.2.3.4", None, cm
) == {"block": True, "trigger": "ip"}
)
assert result["block"] is True
assert result["trigger"] == "ip"
assert result["retry_after_seconds"] >= 0


def test_rate_limiting_with_wildcard2():
Expand Down Expand Up @@ -193,10 +196,10 @@ def test_rate_limiting_with_wildcard2():

# This request should trigger the rate limit
metadata = create_route_metadata(route="/api/login", method="GET")
assert should_ratelimit_request(metadata, "1.2.3.4", None, cm) == {
"block": True,
"trigger": "ip",
}
result = should_ratelimit_request(metadata, "1.2.3.4", None, cm)
assert result["block"] is True
assert result["trigger"] == "ip"
assert result["retry_after_seconds"] >= 0


def test_rate_limiting_by_user_with_same_ip():
Expand Down Expand Up @@ -228,10 +231,10 @@ def test_rate_limiting_by_user_with_same_ip():
}

# This request should trigger the rate limit
assert should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm) == {
"block": True,
"trigger": "user",
}
result = should_ratelimit_request(metadata, "1.2.3.4", {"id": "123"}, cm)
assert result["block"] is True
assert result["trigger"] == "user"
assert result["retry_after_seconds"] >= 0


def test_rate_limiting_by_user_with_different_ips():
Expand Down Expand Up @@ -267,10 +270,10 @@ def test_rate_limiting_by_user_with_different_ips():
}

# This request from second IP should trigger the rate limit
assert should_ratelimit_request(metadata, "4.3.2.1", {"id": "123"}, cm) == {
"block": True,
"trigger": "user",
}
result = should_ratelimit_request(metadata, "4.3.2.1", {"id": "123"}, cm)
assert result["block"] is True
assert result["trigger"] == "user"
assert result["retry_after_seconds"] >= 0


def test_rate_limiting_same_ip_different_users():
Expand Down Expand Up @@ -385,10 +388,10 @@ def test_rate_limits_by_user_with_different_ips():
"block": False
}
# This request should trigger the rate limit by group
assert should_ratelimit_request(route_metadata, "4.3.2.1", user, cm, "group1") == {
"block": True,
"trigger": "group",
}
result = should_ratelimit_request(route_metadata, "4.3.2.1", user, cm, "group1")
assert result["block"] is True
assert result["trigger"] == "group"
assert result["retry_after_seconds"] >= 0


def test_rate_limits_different_users_in_same_group():
Expand Down Expand Up @@ -420,12 +423,12 @@ def test_rate_limits_different_users_in_same_group():
route_metadata, "1.2.3.4", {"id": "789"}, cm, "group1"
) == {"block": False}
# This request should trigger the rate limit by group
assert should_ratelimit_request(
result = should_ratelimit_request(
route_metadata, "4.3.2.1", {"id": "101112"}, cm, "group1"
) == {
"block": True,
"trigger": "group",
}
)
assert result["block"] is True
assert result["trigger"] == "group"
assert result["retry_after_seconds"] >= 0


def test_works_with_multiple_rate_limit_groups_and_different_users():
Expand Down Expand Up @@ -457,30 +460,30 @@ def test_works_with_multiple_rate_limit_groups_and_different_users():
route_metadata, "4.3.2.1", {"id": "101112"}, cm, "group2"
) == {"block": False}
# This request should trigger the rate limit for group1
assert should_ratelimit_request(
result = should_ratelimit_request(
route_metadata, "1.2.3.4", {"id": "789"}, cm, "group1"
) == {
"block": True,
"trigger": "group",
}
)
assert result["block"] is True
assert result["trigger"] == "group"
assert result["retry_after_seconds"] >= 0
# This request should also trigger the rate limit for group1
assert should_ratelimit_request(
result = should_ratelimit_request(
route_metadata, "1.2.3.4", {"id": "4321"}, cm, "group1"
) == {
"block": True,
"trigger": "group",
}
)
assert result["block"] is True
assert result["trigger"] == "group"
assert result["retry_after_seconds"] >= 0
# First request from user 953, group2
assert should_ratelimit_request(
route_metadata, "4.3.2.1", {"id": "953"}, cm, "group2"
) == {"block": False}
# This request should trigger the rate limit for group2
assert should_ratelimit_request(
result = should_ratelimit_request(
route_metadata, "4.3.2.1", {"id": "1563"}, cm, "group2"
) == {
"block": True,
"trigger": "group",
}
)
assert result["block"] is True
assert result["trigger"] == "group"
assert result["retry_after_seconds"] >= 0


def test_rate_limits_by_group_if_user_is_not_set():
Expand Down Expand Up @@ -512,10 +515,10 @@ def test_rate_limits_by_group_if_user_is_not_set():
"block": False
}
# This request should trigger the rate limit by group
assert should_ratelimit_request(route_metadata, "4.3.2.1", None, cm, "group1") == {
"block": True,
"trigger": "group",
}
result = should_ratelimit_request(route_metadata, "4.3.2.1", None, cm, "group1")
assert result["block"] is True
assert result["trigger"] == "group"
assert result["retry_after_seconds"] >= 0


def test_does_not_rate_limit_excluded_users():
Expand Down
16 changes: 13 additions & 3 deletions aikido_zen/ratelimiting/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Mostly exports the class RateLimiter
"""

import math
from aikido_zen.helpers.get_current_unixtime_ms import get_unixtime_ms
from .lru_cache import LRUCache

Expand All @@ -18,7 +19,8 @@ def __init__(self, max_items, time_to_live_in_ms):

def is_allowed(self, key, window_size_in_ms, max_requests):
"""
Checks if the request is allowed given the history
Checks if the request is allowed given the history.
Returns {"allowed": True} or {"allowed": False, "retry_after_seconds": int}.
"""
current_time = get_unixtime_ms()
request_timestamps = self.rate_limited_items.get(key) or []
Expand All @@ -39,5 +41,13 @@ def is_allowed(self, key, window_size_in_ms, max_requests):
request_timestamps.append(current_time)
self.rate_limited_items.set(key, request_timestamps)

# if the total amount of requests in the current window exceeds max requests, we rate-limit
return len(request_timestamps) <= max_requests
if len(request_timestamps) <= max_requests:
return {"allowed": True}

retry_after_ms = max(
0, request_timestamps[0] + window_size_in_ms - current_time
)
return {
"allowed": False,
"retry_after_seconds": math.ceil(retry_after_ms / 1000),
}
Loading
Loading