186 lines
5.6 KiB
Python
186 lines
5.6 KiB
Python
import logging
|
|
import traceback
|
|
from os import getenv
|
|
from typing import Any, Callable, cast
|
|
|
|
from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
|
|
from django.core.exceptions import (
|
|
BadRequest,
|
|
DisallowedHost,
|
|
PermissionDenied,
|
|
ValidationError,
|
|
)
|
|
from django.http import (
|
|
HttpRequest,
|
|
HttpResponse,
|
|
HttpResponseBadRequest,
|
|
HttpResponseForbidden,
|
|
HttpResponsePermanentRedirect,
|
|
JsonResponse,
|
|
)
|
|
from django.middleware.common import CommonMiddleware
|
|
from django.middleware.locale import LocaleMiddleware
|
|
from django.utils import translation
|
|
from graphql import GraphQLResolveInfo
|
|
from rest_framework_simplejwt.authentication import JWTAuthentication
|
|
from rest_framework_simplejwt.exceptions import InvalidToken
|
|
from sentry_sdk import capture_exception
|
|
|
|
from evibes.settings.drf import JSON_UNDERSCOREIZE
|
|
from evibes.utils.misc import RatelimitedError
|
|
from evibes.utils.parsers import underscoreize
|
|
|
|
# Module-level logger used by the middleware classes in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CustomCommonMiddleware(CommonMiddleware):
    """``CommonMiddleware`` that answers a disallowed Host with a redirect.

    When the base implementation raises ``DisallowedHost``, the client gets
    a 301 to ``https://api.<EVIBES_BASE_DOMAIN>`` (``localhost`` when the
    variable is unset) instead of an error. The original path and query
    string are not carried over to the redirect target.
    """

    def process_request(
        self, request: HttpRequest
    ) -> HttpResponsePermanentRedirect | None:
        try:
            return super().process_request(request)
        except DisallowedHost:
            # A permanent redirect keeps the return type compatible with the
            # base class signature.
            base_domain = getenv("EVIBES_BASE_DOMAIN", "localhost")
            return HttpResponsePermanentRedirect(f"https://api.{base_domain}")
|
|
|
|
|
|
class CustomLocaleMiddleware(LocaleMiddleware):
    """``LocaleMiddleware`` that activates a lowercase, hyphen-separated code.

    The language chosen by ``translation.get_language_from_request`` is
    normalized so that e.g. ``"en_US"``, ``"EN-us"`` and ``"en-US"`` all
    become ``"en-us"``. The normalized code is activated and stored on
    ``request.LANGUAGE_CODE``.

    NOTE(review): this override does not call ``super().process_request``,
    so any other base-class request handling is skipped — confirm intended.
    """

    def process_request(self, request: HttpRequest) -> None:
        lang = translation.get_language_from_request(request)

        # Normalize every "_" separator to "-" and lowercase the whole code.
        # Previously "_" was only rewritten for two-part codes, so a value
        # like "zh_Hant_TW" stayed underscore-separated.
        normalized = lang.replace("_", "-").lower()

        translation.activate(normalized)
        request.LANGUAGE_CODE = normalized
|
|
|
|
|
|
class GrapheneJWTAuthorizationMiddleware:
    """Graphene middleware that sets ``info.context.user`` from a JWT.

    ``resolve`` is invoked by Graphene for each resolved field. The user is
    therefore authenticated once per context object and memoized on it, so
    JWT validation (which may involve a user lookup) is not repeated for
    every field of the same request.
    """

    # Private attribute on the GraphQL context holding the memoized user.
    _USER_CACHE_ATTR = "_graphene_jwt_user"

    def resolve(
        self,
        next: Callable[..., Any],
        root: Any,
        info: GraphQLResolveInfo,
        **args: Any,
    ) -> Any:
        """Attach the JWT user to ``info.context`` and continue resolution.

        Args:
            next: The next resolver in the Graphene middleware chain.
            root: Parent value for the field being resolved.
            info: Resolve info; ``info.context`` is treated as the request.
            **args: Field arguments, forwarded unchanged.

        Returns:
            Whatever ``next`` returns.
        """
        context = info.context

        user = getattr(context, self._USER_CACHE_ATTR, None)
        if user is None:
            user = self.get_jwt_user(cast(HttpRequest, context))
            setattr(context, self._USER_CACHE_ATTR, user)

        # Re-assign on every field so the context user always reflects the
        # JWT result, as it did before memoization was added.
        info.context.user = user

        return next(root, info, **args)

    @staticmethod
    def get_jwt_user(request: HttpRequest) -> AbstractBaseUser | AnonymousUser:
        """Authenticate ``request`` via ``JWTAuthentication``.

        Returns:
            The authenticated user, or ``AnonymousUser`` when no token is
            supplied, the token is invalid, or authentication fails for any
            other reason (the latter is logged).
        """
        jwt_authenticator = JWTAuthentication()
        try:
            result = jwt_authenticator.authenticate(request)  # type: ignore
        except InvalidToken:
            return AnonymousUser()
        except Exception as e:
            logger.warning("Could not authenticate user: %s", str(e))
            logger.debug(traceback.format_exc())
            return AnonymousUser()

        # ``authenticate`` returns None when there is no usable auth header.
        # This was previously handled by catching the TypeError raised when
        # unpacking None, which also hid unrelated TypeErrors.
        if result is None:
            return AnonymousUser()

        user_obj, _ = result
        return cast(AbstractBaseUser, user_obj)
|
|
|
|
|
|
class BlockInvalidHostMiddleware:
    """Reject requests whose raw ``Host`` header is not allowed.

    The header value (port included) is compared against a fixed set of
    internal/local hosts plus the whitespace-separated ``ALLOWED_HOSTS``
    environment variable. Every host is accepted when ``DEBUG`` is truthy or
    when ``ALLOWED_HOSTS`` contains ``*``.

    NOTE(review): ``DEBUG`` defaults to ``"1"``, so an unset variable allows
    every host — confirm this default is intended for production.
    """

    # Hosts accepted regardless of configuration.
    _BUILTIN_HOSTS = (
        "app:8000",
        "worker:8000",
        "beat:8000",
        "localhost:8000",
        "127.0.0.1:8000",
    )

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    @staticmethod
    def _debug_enabled() -> bool:
        """Interpret ``DEBUG``: numeric values by ``int`` truthiness, else true/yes/on."""
        raw = getenv("DEBUG", "1").strip()
        try:
            return bool(int(raw))
        except ValueError:
            # Values like "True" previously raised on every request.
            return raw.lower() in {"true", "yes", "on"}

    def _host_allowed(self, host: str | None) -> bool:
        """Return True if ``host`` (the raw Host header) may be served."""
        if self._debug_enabled():
            return True
        # split() (not split(" ")) avoids "" entries that would whitelist an
        # empty Host header when the variable is empty or double-spaced.
        configured = getenv("ALLOWED_HOSTS", "").split()
        if "*" in configured:
            return True
        return host in self._BUILTIN_HOSTS or host in configured

    def __call__(self, request: HttpRequest) -> HttpResponse:
        if not hasattr(request, "META"):
            return HttpResponseBadRequest("Invalid Request")
        if not self._host_allowed(request.META.get("HTTP_HOST")):
            return HttpResponseForbidden("Invalid Host Header")
        return self.get_response(request)
|
|
|
|
|
|
class GrapheneLoggingErrorsDebugMiddleware:
    """Graphene middleware that logs resolver exceptions, then re-raises them.

    Exceptions whose type is listed in ``WARNING_ONLY_ERRORS`` are logged as
    warnings. Any other exception is logged at ERROR level with its
    traceback and sent to Sentry via ``capture_exception``.

    NOTE(review): only exceptions raised synchronously by ``next`` are seen
    here; if resolvers return awaitables, their errors bypass this handler.
    """

    # Exception types considered expected; logged at WARNING level only.
    WARNING_ONLY_ERRORS = [
        BadRequest,
        PermissionDenied,
        DisallowedHost,
        ValidationError,
    ]

    def resolve(
        self,
        next: Callable[..., Any],
        root: Any,
        info: GraphQLResolveInfo,
        **args: Any,
    ) -> Any:
        """Run ``next`` and log any exception it raises before propagating it."""
        try:
            return next(root, info, **args)
        except Exception as e:
            if isinstance(e, tuple(self.WARNING_ONLY_ERRORS)):
                logger.warning("%s", e)
            else:
                # logger.exception == logger.error(..., exc_info=True).
                logger.exception("%s", e)
                capture_exception(e)
            # Bare raise preserves the original traceback unchanged.
            raise
|
|
|
|
|
|
class RateLimitMiddleware:
    """Convert a ``RatelimitedError`` from a view into an HTTP 429 JSON reply.

    The JSON body carries ``error`` (the exception text) and ``code`` (the
    exception's ``code`` attribute, or ``"rate_limited"`` when absent).
    """

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    def __call__(self, request: HttpRequest) -> HttpResponse:
        # Pass-through; the translation happens in ``process_exception``.
        return self.get_response(request)

    # noinspection PyUnusedLocal
    def process_exception(
        self, request: HttpRequest, exception: Exception
    ) -> JsonResponse | None:
        if not isinstance(exception, RatelimitedError):
            # Not a rate-limit error: leave handling to the rest of Django.
            return None

        payload = {
            "error": str(exception),
            "code": getattr(exception, "code", "rate_limited"),
        }
        return JsonResponse(payload, status=429)
|
|
|
|
|
|
class CamelCaseMiddleWare:
    """Rewrite ``request.GET`` through ``underscoreize`` before the view runs.

    Presumably converts camelCase query-parameter keys to snake_case, using
    the options in ``JSON_UNDERSCOREIZE`` — confirm against
    ``evibes.utils.parsers.underscoreize``.
    """

    def __init__(self, get_response):
        # Next callable in the Django middleware chain.
        self.get_response = get_response

    def __call__(self, request):
        converted_query = underscoreize(request.GET, **JSON_UNDERSCOREIZE)
        request.GET = converted_query
        return self.get_response(request)
|