schon/evibes/middleware.py
Egor fureunoir Gorbunov 29fb56be89 Features: 1) Add async and sync capabilities to CamelCaseMiddleWare; 2) Include OpenAPI support for Enum name overrides in DRF settings; 3) Integrate OpenAPI types in DRF views for improved schema accuracy.
Fixes: 1) Correct `lookup_field` to `uuid` in various viewsets; 2) Replace `type=str` with `OpenApiTypes.STR` in path parameters of multiple DRF endpoints; 3) Add missing import `iscoroutinefunction` and `markcoroutinefunction`.

Extra: 1) Refactor `__call__` method in `CamelCaseMiddleWare` to separate sync and async logic; 2) Enhance documentation schema responses with precise types in multiple DRF views.
2025-12-19 17:27:36 +03:00

214 lines
6.5 KiB
Python

import logging
import traceback
from os import getenv
from typing import Any, Callable, cast
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import (
BadRequest,
DisallowedHost,
PermissionDenied,
ValidationError,
)
from django.http import (
HttpRequest,
HttpResponse,
HttpResponseBadRequest,
HttpResponseForbidden,
HttpResponsePermanentRedirect,
JsonResponse,
QueryDict,
)
from django.middleware.common import CommonMiddleware
from django.middleware.locale import LocaleMiddleware
from django.utils import translation
from graphql import GraphQLResolveInfo
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import InvalidToken
from sentry_sdk import capture_exception
from engine.vibes_auth.models import User
from evibes.settings.drf import JSON_UNDERSCOREIZE
from evibes.utils.misc import RatelimitedError
from evibes.utils.parsers import underscoreize
# Module-level logger shared by every middleware defined in this file.
logger: logging.Logger = logging.getLogger(__name__)
class CustomCommonMiddleware(CommonMiddleware):
    """CommonMiddleware variant that redirects instead of failing on ``DisallowedHost``."""

    def process_request(
        self, request: HttpRequest
    ) -> HttpResponsePermanentRedirect | None:
        """Run ``CommonMiddleware.process_request``, redirecting on ``DisallowedHost``.

        Returns whatever the base implementation returns. If it raises
        ``DisallowedHost``, a permanent redirect to ``https://api.<EVIBES_BASE_DOMAIN>``
        is returned instead (``localhost`` when the environment variable is unset).
        """
        try:
            outcome = super().process_request(request)
        except DisallowedHost:
            base_domain = getenv("EVIBES_BASE_DOMAIN", "localhost")
            # A permanent redirect keeps the return type compatible with the base class.
            return HttpResponsePermanentRedirect(f"https://api.{base_domain}")
        return outcome
class CustomLocaleMiddleware(LocaleMiddleware):
    """LocaleMiddleware variant that activates a normalized, lowercase language code."""

    def process_request(self, request: HttpRequest) -> None:
        """Detect the request language, normalize it and activate it.

        The code from ``translation.get_language_from_request`` is lowercased and
        every ``_`` separator becomes ``-`` (e.g. ``en_GB`` -> ``en-gb``). The result is
        passed to ``translation.activate`` and stored on ``request.LANGUAGE_CODE``.
        """
        lang = translation.get_language_from_request(request)
        # Normalize all codes identically. Previously only two-part codes had "_"
        # replaced, so e.g. "sr_Latn_RS" stayed underscore-separated while two-part
        # codes came out hyphenated. Two-part results are unchanged.
        normalized = lang.replace("_", "-").lower()
        translation.activate(normalized)
        request.LANGUAGE_CODE = normalized
class GrapheneJWTAuthorizationMiddleware:
    """Graphene middleware that sets ``info.context.user`` from the request's JWT."""

    # Attribute set on the GraphQL context once the JWT user has been resolved, so
    # authentication runs once per context instead of once per resolved field.
    _RESOLVED_FLAG = "_jwt_user_resolved"

    def resolve(
        self,
        next: Callable[..., Any],
        root: Any,
        info: GraphQLResolveInfo,
        **args: Any,
    ) -> Any:
        """Attach the JWT user to the context (once), then continue the resolver chain.

        Args:
            next: The next resolver in the Graphene middleware chain.
            root: Parent value passed through to ``next``.
            info: Resolve info; ``info.context`` is treated as the HTTP request.
            **args: Field arguments passed through to ``next``.

        Returns:
            Whatever ``next`` returns.
        """
        context = info.context
        # Middleware is invoked for every field; without this guard the token would be
        # re-authenticated per field and a user set earlier in the request overwritten.
        if not getattr(context, self._RESOLVED_FLAG, False):
            context.user = self.get_jwt_user(cast(HttpRequest, context))
            setattr(context, self._RESOLVED_FLAG, True)
        return next(root, info, **args)

    @staticmethod
    def get_jwt_user(request: HttpRequest) -> User | AnonymousUser:
        """Return the user authenticated by the request's JWT, or ``AnonymousUser``.

        ``AnonymousUser`` is returned when no credentials are presented, when the token
        is rejected with ``InvalidToken``, or when authentication raises any other
        exception (logged as a warning, traceback at debug level).
        """
        jwt_authenticator = JWTAuthentication()
        try:
            result = jwt_authenticator.authenticate(request)
        except InvalidToken:
            return AnonymousUser()
        except Exception as e:
            logger.warning("Could not authenticate user: %s", str(e))
            logger.debug(traceback.format_exc())
            return AnonymousUser()
        # authenticate() returns None when no JWT is presented. Check it explicitly
        # rather than catching TypeError from unpacking None, which also hid real
        # TypeErrors raised inside authenticate().
        if result is None:
            return AnonymousUser()
        user_obj, _ = result
        return cast(User, user_obj)
class BlockInvalidHostMiddleware:
    """Reject requests whose ``Host`` header is not on an allow-list.

    Internal service hosts are always allowed. When the ``DEBUG`` environment
    variable (default ``"1"``) parses to a non-zero integer every host is accepted;
    otherwise the whitespace-separated ``ALLOWED_HOSTS`` variable extends the list.
    """

    # Internal service hosts that are always accepted.
    _INTERNAL_HOSTS = (
        "app:8000",
        "worker:8000",
        "beat:8000",
        "localhost:8000",
        "127.0.0.1:8000",
    )

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    def _allowed_hosts(self) -> list[str]:
        """Build the lowercase allow-list from the environment (read on every call)."""
        hosts = list(self._INTERNAL_HOSTS)
        if bool(int(getenv("DEBUG", "1"))):
            hosts.append("*")
        else:
            # split() without an argument drops empty entries; split(" ") produced ""
            # (always, when ALLOWED_HOSTS is unset), letting an empty Host header pass.
            hosts.extend(host.lower() for host in getenv("ALLOWED_HOSTS", "").split())
        return hosts

    def __call__(self, request: HttpRequest) -> HttpResponse:
        """Return 400 without ``META``, 403 for a disallowed host, else the response."""
        allowed_hosts = self._allowed_hosts()
        if not hasattr(request, "META"):
            return HttpResponseBadRequest("Invalid Request")
        if "*" not in allowed_hosts:
            host = request.META.get("HTTP_HOST")
            # Host names are case-insensitive, so compare in lowercase.
            if host is None or host.lower() not in allowed_hosts:
                return HttpResponseForbidden("Invalid Host Header")
        return self.get_response(request)
class GrapheneLoggingErrorsDebugMiddleware:
    """Graphene middleware that logs resolver exceptions before re-raising them."""

    # Exception types logged at WARNING level and not reported to Sentry.
    WARNING_ONLY_ERRORS = [
        BadRequest,
        PermissionDenied,
        DisallowedHost,
        ValidationError,
    ]

    def resolve(
        self,
        next: Callable[..., Any],
        root: Any,
        info: GraphQLResolveInfo,
        **args: Any,
    ) -> Any:
        """Call the next resolver, logging any exception it raises.

        Exceptions matching ``WARNING_ONLY_ERRORS`` are logged as warnings. All others
        are logged at ERROR level with traceback and sent to ``capture_exception``. The
        exception is always re-raised unchanged.
        """
        try:
            return next(root, info, **args)
        except Exception as e:
            if isinstance(e, tuple(self.WARNING_ONLY_ERRORS)):
                logger.warning(str(e))
            else:
                logger.exception(str(e))
                capture_exception(e)
            # Bare raise re-raises the active exception with its original traceback.
            raise
class RateLimitMiddleware:
    """Turn ``RatelimitedError`` into an HTTP 429 JSON response via ``process_exception``."""

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    def __call__(self, request: HttpRequest) -> HttpResponse:
        # Pass-through: this middleware only acts in process_exception.
        return self.get_response(request)

    # noinspection PyUnusedLocal
    def process_exception(
        self, request: HttpRequest, exception: Exception
    ) -> JsonResponse | None:
        """Return a 429 JSON response for ``RatelimitedError``, otherwise ``None``.

        The body is ``{"error": str(exception), "code": exception.code}``, with
        ``"rate_limited"`` used when the exception has no ``code`` attribute.
        """
        if not isinstance(exception, RatelimitedError):
            return None
        payload = {
            "error": str(exception),
            "code": getattr(exception, "code", "rate_limited"),
        }
        return JsonResponse(payload, status=429)
class CamelCaseMiddleWare:
async_capable = True
sync_capable = True
def __init__(self, get_response):
self.get_response = get_response
if iscoroutinefunction(get_response):
markcoroutinefunction(self) # ty:ignore[invalid-argument-type]
async def __call__(self, request):
if iscoroutinefunction(self.get_response):
self._underscoreize_request(request)
response = await self.get_response(request)
return response
return self._sync_call(request)
def _sync_call(self, request):
self._underscoreize_request(request)
response = self.get_response(request)
return response
def _underscoreize_request(self, request):
underscoreized_get = underscoreize(
{k: v for k, v in request.GET.lists()},
**JSON_UNDERSCOREIZE,
)
new_get = QueryDict(mutable=True)
for key, value in underscoreized_get.items():
if isinstance(value, list):
for val in value:
new_get.appendlist(key, val)
else:
new_get[key] = value
new_get._mutable = False
request.GET = new_get