Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1342,17 +1342,17 @@ async def your_background_function(

### 5.11 Rate Limiting

To limit how many times a user can make a request in a certain interval of time (very useful to create subscription plans or just to protect your API against DDOS), you may just use the `rate_limiter` dependency:
To limit how many times a user can make a request in a certain interval of time (very useful to create subscription plans or just to protect your API against DDOS), you may just use the `rate_limiter_dependency` dependency:

```python
from fastapi import Depends

from app.api.dependencies import rate_limiter
from app.api.dependencies import rate_limiter_dependency
from app.core.utils import queue
from app.schemas.job import Job


@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter)])
@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter_dependency)])
async def create_task(message: str):
job = await queue.pool.enqueue_job("sample_background_task", message)
return {"id": job.job_id}
Expand Down Expand Up @@ -1446,7 +1446,7 @@ curl -X POST 'http://127.0.0.1:8000/api/v1/tasks/task?message=test' \
```

> \[!TIP\]
> Since the `rate_limiter` dependency uses the `get_optional_user` dependency instead of `get_current_user`, it will not require authentication to be used, but will behave accordingly if the user is authenticated (and token is passed in header). If you want to ensure authentication, also use `get_current_user` if you need.
> Since the `rate_limiter_dependency` dependency uses the `get_optional_user` dependency instead of `get_current_user`, it will not require authentication to be used, but will behave accordingly if the user is authenticated (and token is passed in header). If you want to ensure authentication, also use `get_current_user` if you need.

To change a user's tier, you may just use the `PATCH api/v1/user/{username}/tier` endpoint.
Note that for flexibility (since this is a boilerplate), it's not necessary to previously inform a tier_id to create a user, but you probably should set every user to a certain tier (let's say `free`) once they are created.
Expand Down
9 changes: 6 additions & 3 deletions src/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..core.exceptions.http_exceptions import ForbiddenException, RateLimitException, UnauthorizedException
from ..core.logger import logging
from ..core.security import oauth2_scheme, verify_token
from ..core.utils.rate_limit import is_rate_limited
from ..core.utils.rate_limit import rate_limiter
from ..crud.crud_rate_limit import crud_rate_limits
from ..crud.crud_tier import crud_tiers
from ..crud.crud_users import crud_users
Expand Down Expand Up @@ -72,9 +72,12 @@ async def get_current_superuser(current_user: Annotated[dict, Depends(get_curren
return current_user


async def rate_limiter(
async def rate_limiter_dependency(
request: Request, db: Annotated[AsyncSession, Depends(async_get_db)], user: User | None = Depends(get_optional_user)
) -> None:
if hasattr(request.app.state, "initialization_complete"):
await request.app.state.initialization_complete.wait()

path = sanitize_path(request.url.path)
if user:
user_id = user["id"]
Expand All @@ -96,6 +99,6 @@ async def rate_limiter(
user_id = request.client.host
limit, period = DEFAULT_LIMIT, DEFAULT_PERIOD

is_limited = await is_rate_limited(db=db, user_id=user_id, path=path, limit=limit, period=period)
is_limited = await rate_limiter.is_rate_limited(db=db, user_id=user_id, path=path, limit=limit, period=period)
if is_limited:
raise RateLimitException("Rate limit exceeded.")
2 changes: 1 addition & 1 deletion src/app/api/v1/rate_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ...api.dependencies import get_current_superuser
from ...core.db.database import async_get_db
from ...core.exceptions.http_exceptions import DuplicateValueException, NotFoundException, RateLimitException
from ...core.exceptions.http_exceptions import DuplicateValueException, NotFoundException
from ...crud.crud_rate_limit import crud_rate_limits
from ...crud.crud_tier import crud_tiers
from ...schemas.rate_limit import RateLimitCreate, RateLimitCreateInternal, RateLimitRead, RateLimitUpdate
Expand Down
4 changes: 2 additions & 2 deletions src/app/api/v1/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from arq.jobs import Job as ArqJob
from fastapi import APIRouter, Depends

from ...api.dependencies import rate_limiter
from ...api.dependencies import rate_limiter_dependency
from ...core.utils import queue
from ...schemas.job import Job

router = APIRouter(prefix="/tasks", tags=["tasks"])


@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter)])
@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter_dependency)])
async def create_task(message: str) -> dict[str, str]:
"""Create a new background task.

Expand Down
46 changes: 27 additions & 19 deletions src/app/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from fastapi.openapi.utils import get_openapi

from ..api.dependencies import get_current_superuser
from ..core.utils.rate_limit import rate_limiter
from ..middleware.client_cache_middleware import ClientCacheMiddleware
from ..models import *
from .config import (
AppSettings,
ClientSideCacheSettings,
Expand All @@ -24,9 +26,10 @@
RedisRateLimiterSettings,
settings,
)
from .db.database import Base, async_engine as engine
from .db.database import Base
from .db.database import async_engine as engine
from .utils import cache, queue, rate_limit
from ..models import *


# -------------- database --------------
async def create_tables() -> None:
Expand Down Expand Up @@ -55,8 +58,7 @@ async def close_redis_queue_pool() -> None:

# -------------- rate limit --------------
async def create_redis_rate_limit_pool() -> None:
rate_limit.pool = redis.ConnectionPool.from_url(settings.REDIS_RATE_LIMIT_URL)
rate_limit.client = redis.Redis.from_pool(rate_limit.pool) # type: ignore
rate_limiter.initialize(settings.REDIS_RATE_LIMIT_URL) # type: ignore


async def close_redis_rate_limit_pool() -> None:
Expand Down Expand Up @@ -85,30 +87,36 @@ def lifespan_factory(

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
from asyncio import Event

initialization_complete = Event()
app.state.initialization_complete = initialization_complete

await set_threadpool_tokens()

if isinstance(settings, DatabaseSettings) and create_tables_on_start:
await create_tables()
try:
if isinstance(settings, RedisCacheSettings):
await create_redis_cache_pool()

if isinstance(settings, RedisCacheSettings):
await create_redis_cache_pool()
if isinstance(settings, RedisQueueSettings):
await create_redis_queue_pool()

if isinstance(settings, RedisQueueSettings):
await create_redis_queue_pool()
if isinstance(settings, RedisRateLimiterSettings):
await create_redis_rate_limit_pool()

if isinstance(settings, RedisRateLimiterSettings):
await create_redis_rate_limit_pool()
initialization_complete.set()

yield
yield

if isinstance(settings, RedisCacheSettings):
await close_redis_cache_pool()
finally:
if isinstance(settings, RedisCacheSettings):
await close_redis_cache_pool()

if isinstance(settings, RedisQueueSettings):
await close_redis_queue_pool()
if isinstance(settings, RedisQueueSettings):
await close_redis_queue_pool()

if isinstance(settings, RedisRateLimiterSettings):
await close_redis_rate_limit_pool()
if isinstance(settings, RedisRateLimiterSettings):
await close_redis_rate_limit_pool()

return lifespan

Expand Down
63 changes: 43 additions & 20 deletions src/app/core/utils/rate_limit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import UTC, datetime
from typing import Optional

from redis.asyncio import ConnectionPool, Redis
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -8,31 +9,53 @@

logger = logging.getLogger(__name__)

pool: ConnectionPool | None = None
client: Redis | None = None

class RateLimiter:
_instance: Optional["RateLimiter"] = None
pool: Optional[ConnectionPool] = None
client: Optional[Redis] = None

async def is_rate_limited(db: AsyncSession, user_id: int, path: str, limit: int, period: int) -> bool:
if client is None:
logger.error("Redis client is not initialized.")
raise Exception("Redis client is not initialized.")
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

current_timestamp = int(datetime.now(UTC).timestamp())
window_start = current_timestamp - (current_timestamp % period)
@classmethod
def initialize(cls, redis_url: str) -> None:
instance = cls()
if instance.pool is None:
instance.pool = ConnectionPool.from_url(redis_url)
instance.client = Redis(connection_pool=instance.pool)

sanitized_path = sanitize_path(path)
key = f"ratelimit:{user_id}:{sanitized_path}:{window_start}"
@classmethod
def get_client(cls) -> Redis:
instance = cls()
if instance.client is None:
logger.error("Redis client is not initialized.")
raise Exception("Redis client is not initialized.")
return instance.client

try:
current_count = await client.incr(key)
if current_count == 1:
await client.expire(key, period)
async def is_rate_limited(self, db: AsyncSession, user_id: int, path: str, limit: int, period: int) -> bool:
client = self.get_client()
current_timestamp = int(datetime.now(UTC).timestamp())
window_start = current_timestamp - (current_timestamp % period)

if current_count > limit:
return True
sanitized_path = sanitize_path(path)
key = f"ratelimit:{user_id}:{sanitized_path}:{window_start}"

except Exception as e:
logger.exception(f"Error checking rate limit for user {user_id} on path {path}: {e}")
raise e
try:
current_count = await client.incr(key)
if current_count == 1:
await client.expire(key, period)

return False
if current_count > limit:
return True

except Exception as e:
logger.exception(f"Error checking rate limit for user {user_id} on path {path}: {e}")
raise e

return False


rate_limiter = RateLimiter()