"""Custom Django and Graphene middleware: host checks, locale, JWT auth, logging."""

import logging
import traceback
from os import getenv

from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import (
    BadRequest,
    DisallowedHost,
    PermissionDenied,
    ValidationError,
)
from django.http import HttpResponseBadRequest, HttpResponseForbidden, JsonResponse
from django.middleware.common import CommonMiddleware
from django.middleware.locale import LocaleMiddleware
from django.shortcuts import redirect
from django.utils import translation
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import InvalidToken
from sentry_sdk import capture_exception

from evibes.utils.misc import RatelimitedError

logger = logging.getLogger(__name__)


class CustomCommonMiddleware(CommonMiddleware):
    """CommonMiddleware that redirects disallowed hosts to the API domain."""

    def process_request(self, request):
        """Run the stock processing; on ``DisallowedHost`` redirect instead of failing.

        The redirect target is ``https://api.<EVIBES_BASE_DOMAIN>`` where the
        domain comes from the environment (default ``localhost``).
        """
        try:
            return super().process_request(request)
        except DisallowedHost:
            return redirect(f"https://api.{getenv('EVIBES_BASE_DOMAIN', 'localhost')}")


class CustomLocaleMiddleware(LocaleMiddleware):
    """LocaleMiddleware that activates a lower-cased language code."""

    def process_request(self, request):
        """Detect the request language, normalise it and activate it.

        A two-part code (``en_GB`` / ``en-GB``) becomes ``en-gb``; any other
        shape is only lower-cased. The result is stored on
        ``request.LANGUAGE_CODE``.
        """
        lang = translation.get_language_from_request(request)
        parts = lang.replace("_", "-").split("-")

        if len(parts) == 2:
            lang_code = parts[0].lower()
            region = parts[1].lower()
            normalized = f"{lang_code}-{region}"
        else:
            normalized = lang.lower()

        translation.activate(normalized)
        request.LANGUAGE_CODE = normalized


# noinspection PyShadowingBuiltins
class GrapheneJWTAuthorizationMiddleware:
    """Graphene middleware that attaches the JWT-authenticated user to the context."""

    def resolve(self, next, root, info, **args):
        """Set ``info.context.user`` from the JWT, then continue the resolver chain."""
        info.context.user = self.get_jwt_user(info.context)
        return next(root, info, **args)

    @staticmethod
    def get_jwt_user(request):
        """Return the user authenticated by the request's JWT, or ``AnonymousUser``.

        Any authentication failure falls back to ``AnonymousUser``; unexpected
        errors are additionally logged (warning + debug traceback).
        """
        jwt_authenticator = JWTAuthentication()
        try:
            result = jwt_authenticator.authenticate(request)
        except (InvalidToken, TypeError):
            return AnonymousUser()
        except Exception as e:
            logger.warning("Could not authenticate user: %s", e)
            logger.debug(traceback.format_exc())
            return AnonymousUser()

        # ``authenticate`` yields no (user, token) pair when the request carries
        # no usable credentials; handle that explicitly rather than relying on
        # a TypeError from tuple-unpacking ``None``.
        if result is None:
            return AnonymousUser()

        user, _ = result
        return user


class BlockInvalidHostMiddleware:
    """Reject requests whose ``Host`` header is not explicitly allowed."""

    # Internal service addresses that are always accepted.
    INTERNAL_HOSTS = (
        "app:8000",
        "worker:8000",
        "beat:8000",
        "localhost:8000",
        "127.0.0.1:8000",
    )

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        """Return 400/403 for malformed or disallowed requests, else pass through.

        With ``DEBUG`` truthy (default ``"1"``) every host is accepted.
        Otherwise the whitespace-separated ``ALLOWED_HOSTS`` environment
        variable extends the internal host list; a ``*`` entry allows any host.
        """
        allowed_hosts = list(self.INTERNAL_HOSTS)

        if bool(int(getenv("DEBUG", "1"))):
            allowed_hosts.append("*")
        else:
            # Default to "" so an unset variable does not crash on ``None.split``;
            # ``split()`` without arguments also drops empty entries.
            allowed_hosts += getenv("ALLOWED_HOSTS", "").split()

        if not hasattr(request, "META"):
            # Must be a response object: ``BadRequest`` is an exception class and
            # returning an instance of it is not a valid middleware result.
            return HttpResponseBadRequest("Invalid Request")

        if "*" not in allowed_hosts and request.META.get("HTTP_HOST") not in allowed_hosts:
            return HttpResponseForbidden("Invalid Host Header")

        return self.get_response(request)


# noinspection PyShadowingBuiltins
class GrapheneLoggingErrorsDebugMiddleware:
    """Graphene middleware that logs resolver exceptions before re-raising them."""

    # Exception types logged at WARNING level and not reported to Sentry.
    WARNING_ONLY_ERRORS = [
        BadRequest,
        PermissionDenied,
        DisallowedHost,
        ValidationError,
    ]

    def resolve(self, next, root, info, **args):
        """Call the next resolver; log (and for unexpected errors, report) failures.

        The exception is always re-raised after logging.
        """
        try:
            return next(root, info, **args)
        except Exception as e:
            if isinstance(e, tuple(self.WARNING_ONLY_ERRORS)):
                logger.warning(str(e))
            else:
                logger.error(str(e), exc_info=True)
                capture_exception(e)
            # Bare ``raise`` re-raises the active exception unchanged.
            raise


class RateLimitMiddleware:
    """Translate ``RatelimitedError`` raised by views into a JSON 429 response."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        return self.get_response(request)

    # noinspection PyUnusedLocal
    def process_exception(self, request, exception):
        """Return a 429 ``JsonResponse`` for ``RatelimitedError``; ``None`` otherwise.

        The payload carries the error message and the exception's ``code``
        attribute (``"rate_limited"`` when absent).
        """
        if not isinstance(exception, RatelimitedError):
            return None

        return JsonResponse(
            {
                "error": str(exception),
                "code": getattr(exception, "code", "rate_limited"),
            },
            status=429,
        )