import logging
import traceback
from os import getenv
from typing import Any, Callable, cast

from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
from django.core.exceptions import (
    BadRequest,
    DisallowedHost,
    PermissionDenied,
    ValidationError,
)
from django.http import (
    HttpRequest,
    HttpResponse,
    HttpResponseBadRequest,
    HttpResponseForbidden,
    HttpResponsePermanentRedirect,
    JsonResponse,
    QueryDict,
)
from django.middleware.common import CommonMiddleware
from django.middleware.locale import LocaleMiddleware
from django.utils import translation
from graphql import GraphQLResolveInfo
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import InvalidToken
from sentry_sdk import capture_exception

from evibes.settings.drf import JSON_UNDERSCOREIZE
from evibes.utils.misc import RatelimitedError
from evibes.utils.parsers import underscoreize

logger = logging.getLogger(__name__)


class CustomCommonMiddleware(CommonMiddleware):
    """``CommonMiddleware`` that redirects instead of failing on a bad host.

    When the base implementation raises ``DisallowedHost``, the request is
    answered with a permanent redirect to ``https://api.<EVIBES_BASE_DOMAIN>``
    (``localhost`` when the environment variable is unset).
    """

    def process_request(
        self, request: HttpRequest
    ) -> HttpResponsePermanentRedirect | None:
        try:
            return super().process_request(request)
        except DisallowedHost:
            # Return a permanent redirect to match the base class return type.
            return HttpResponsePermanentRedirect(
                f"https://api.{getenv('EVIBES_BASE_DOMAIN', 'localhost')}"
            )


class CustomLocaleMiddleware(LocaleMiddleware):
    """Activate the request language using a lower-case, hyphenated code.

    Unlike the base class, this does not call ``super().process_request``; it
    only resolves the language, normalizes it and activates it.
    """

    def process_request(self, request: HttpRequest) -> None:
        lang = translation.get_language_from_request(request)
        # Normalize every code the same way ("en_US" -> "en-us",
        # "sr_Latn_RS" -> "sr-latn-rs"). Previously only two-part codes had
        # "_" replaced by "-".
        normalized = lang.replace("_", "-").lower()
        translation.activate(normalized)
        request.LANGUAGE_CODE = normalized


class GrapheneJWTAuthorizationMiddleware:
    """Graphene middleware that sets ``info.context.user`` from a JWT.

    Graphene invokes ``resolve`` once per resolved field. The authentication
    result is therefore cached on the request object, so the token is decoded
    (and the user looked up) only once per request rather than once per field.
    """

    # Private attribute on the request used to cache the authenticated user.
    _CACHE_ATTR = "_evibes_graphene_jwt_user"

    def resolve(
        self,
        next: Callable[..., Any],
        root: Any,
        info: GraphQLResolveInfo,
        **args: Any,
    ) -> Any:
        context = info.context
        user = getattr(context, self._CACHE_ATTR, None)
        if user is None:
            user = self.get_jwt_user(cast(HttpRequest, context))
            setattr(context, self._CACHE_ATTR, user)
        # As before, always overwrite whatever user an earlier (e.g. session)
        # middleware attached, so every field sees the JWT-derived user.
        context.user = user
        return next(root, info, **args)

    @staticmethod
    def get_jwt_user(request: HttpRequest) -> AbstractBaseUser | AnonymousUser:
        """Return the user authenticated by the request's JWT.

        Returns ``AnonymousUser`` when no credentials are present, when the
        token is invalid, or when authentication fails for any other reason.
        Unexpected failures are logged at WARNING, with the traceback at DEBUG.
        """
        try:
            result = JWTAuthentication().authenticate(request)  # type: ignore[arg-type]
        except InvalidToken:
            return AnonymousUser()
        except Exception as e:
            logger.warning("Could not authenticate user: %s", str(e))
            logger.debug(traceback.format_exc())
            return AnonymousUser()

        # authenticate() returns None when the request carries no usable
        # credentials; check explicitly instead of catching a TypeError from
        # unpacking None (which also hid unrelated TypeErrors).
        if result is None:
            return AnonymousUser()

        user_obj, _ = result
        return cast(AbstractBaseUser, user_obj)


class BlockInvalidHostMiddleware:
    """Reject requests whose ``Host`` header is not in an allow-list.

    The internal service hosts in ``INTERNAL_HOSTS`` are always allowed. If
    the ``DEBUG`` environment variable parses as a non-zero integer (default
    ``"1"``), every host is allowed. Otherwise, the whitespace-separated
    ``ALLOWED_HOSTS`` environment variable is added to the list. The
    environment is re-read on every request.
    """

    INTERNAL_HOSTS: tuple[str, ...] = (
        "app:8000",
        "worker:8000",
        "beat:8000",
        "localhost:8000",
        "127.0.0.1:8000",
    )

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    @classmethod
    def _allowed_hosts(cls) -> set[str]:
        """Build the allow-list (may contain the wildcard ``"*"``)."""
        hosts = set(cls.INTERNAL_HOSTS)
        # NOTE(review): DEBUG defaults to "1", so an unset DEBUG allows every
        # host. Confirm that production deployments always set DEBUG=0.
        if bool(int(getenv("DEBUG", "1"))):
            hosts.add("*")
        else:
            # split() with no argument drops the empty strings that
            # split(" ") produced for repeated/leading/trailing spaces; those
            # "" entries let an empty Host header through.
            hosts.update(getenv("ALLOWED_HOSTS", "").split())
        return hosts

    def __call__(self, request: HttpRequest) -> HttpResponse:
        allowed_hosts = self._allowed_hosts()

        if not hasattr(request, "META"):
            return HttpResponseBadRequest("Invalid Request")

        if (
            "*" not in allowed_hosts
            and request.META.get("HTTP_HOST") not in allowed_hosts
        ):
            return HttpResponseForbidden("Invalid Host Header")

        return self.get_response(request)


class GrapheneLoggingErrorsDebugMiddleware:
    """Graphene middleware that logs resolver errors and re-raises them.

    Errors whose type is listed in ``WARNING_ONLY_ERRORS`` are logged at
    WARNING. Any other error is logged at ERROR with its traceback and sent
    to Sentry. In both cases the original exception propagates unchanged.
    """

    # Kept as a list so subclasses that concatenate to it keep working.
    WARNING_ONLY_ERRORS = [
        BadRequest,
        PermissionDenied,
        DisallowedHost,
        ValidationError,
    ]

    def resolve(
        self,
        next: Callable[..., Any],
        root: Any,
        info: GraphQLResolveInfo,
        **args: Any,
    ) -> Any:
        try:
            return next(root, info, **args)
        except Exception as e:
            if isinstance(e, tuple(self.WARNING_ONLY_ERRORS)):
                logger.warning(str(e))
            else:
                logger.error(str(e), exc_info=True)
                capture_exception(e)
            # Bare raise keeps the original traceback without adding this frame.
            raise


class RateLimitMiddleware:
    """Translate ``RatelimitedError`` into an HTTP 429 JSON response."""

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    def __call__(self, request: HttpRequest) -> HttpResponse:
        return self.get_response(request)

    # noinspection PyUnusedLocal
    def process_exception(
        self, request: HttpRequest, exception: Exception
    ) -> JsonResponse | None:
        """Return a 429 response for rate-limit errors; ignore anything else.

        The body is ``{"error": str(exception), "code": exception.code}``.
        ``code`` falls back to ``"rate_limited"`` when the exception has no
        ``code`` attribute.
        """
        if isinstance(exception, RatelimitedError):
            return JsonResponse(
                {
                    "error": str(exception),
                    "code": getattr(exception, "code", "rate_limited"),
                },
                status=429,
            )
        return None


class CamelCaseMiddleWare:
    """Rewrite ``request.GET`` so its keys pass through ``underscoreize``.

    The conversion is configured by ``JSON_UNDERSCOREIZE`` (presumably
    camelCase -> snake_case; confirm in ``evibes.settings.drf``). Every value
    of a multi-valued parameter is preserved, and the resulting ``QueryDict``
    is made immutable again, like the original ``request.GET``.
    """

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    def __call__(self, request: HttpRequest) -> HttpResponse:
        underscoreized_get = underscoreize(
            dict(request.GET.lists()),
            **JSON_UNDERSCOREIZE,
        )

        # Keep the original QueryDict's encoding instead of silently falling
        # back to DEFAULT_CHARSET.
        new_get = QueryDict(mutable=True, encoding=request.GET.encoding)
        for key, value in underscoreized_get.items():
            if isinstance(value, list):
                for val in value:
                    new_get.appendlist(key, val)
            else:
                new_get[key] = value
        # QueryDict has no public API to freeze an instance after building it.
        new_get._mutable = False

        request.GET = new_get
        return self.get_response(request)