schon/evibes/middleware.py
Egor fureunoir Gorbunov dc7f8be926 Features: 1) None;
Fixes: 1) Add `# ty: ignore` comments to suppress type errors in multiple files; 2) Correct method argument annotations and definitions to align with type hints; 3) Fix cases of invalid or missing imports and unresolved attributes;

Extra: Refactor method definitions to use tuple-based method declarations; replace custom type aliases with `Any`; improve caching utility and error handling logic in utility scripts.
2025-12-19 16:43:39 +03:00

199 lines
6 KiB
Python

import logging
import traceback
from os import getenv
from typing import Any, Callable, cast
from django.contrib.auth.models import AnonymousUser
from django.core.exceptions import (
BadRequest,
DisallowedHost,
PermissionDenied,
ValidationError,
)
from django.http import (
HttpRequest,
HttpResponse,
HttpResponseBadRequest,
HttpResponseForbidden,
HttpResponsePermanentRedirect,
JsonResponse,
QueryDict,
)
from django.middleware.common import CommonMiddleware
from django.middleware.locale import LocaleMiddleware
from django.utils import translation
from graphql import GraphQLResolveInfo
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import InvalidToken
from sentry_sdk import capture_exception
from engine.vibes_auth.models import User
from evibes.settings.drf import JSON_UNDERSCOREIZE
from evibes.utils.misc import RatelimitedError
from evibes.utils.parsers import underscoreize
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
class CustomCommonMiddleware(CommonMiddleware):
    """``CommonMiddleware`` variant that redirects instead of failing on bad hosts.

    When the parent ``process_request`` raises ``DisallowedHost``, a permanent
    redirect to ``https://api.<EVIBES_BASE_DOMAIN>`` is returned instead.
    ``localhost`` is used as the domain when the environment variable is unset.
    """

    def process_request(
        self, request: HttpRequest
    ) -> HttpResponsePermanentRedirect | None:
        """Run the standard ``CommonMiddleware`` checks, mapping ``DisallowedHost`` to a redirect."""
        try:
            response = super().process_request(request)
        except DisallowedHost:
            base_domain = getenv("EVIBES_BASE_DOMAIN", "localhost")
            # A permanent-redirect response keeps the return type aligned with
            # the base class signature.
            return HttpResponsePermanentRedirect(f"https://api.{base_domain}")
        return response
class CustomLocaleMiddleware(LocaleMiddleware):
    """``LocaleMiddleware`` that activates a normalized language code.

    The language detected for the request is lower-cased and every ``_`` is
    replaced with ``-`` (e.g. ``"en_GB"`` / ``"en-GB"`` -> ``"en-gb"``). The
    result is activated for the current thread's translations and stored on
    ``request.LANGUAGE_CODE``.
    """

    def process_request(self, request: HttpRequest) -> None:
        """Detect, normalize and activate the request's language.

        Args:
            request: The incoming Django request.
        """
        lang = translation.get_language_from_request(request)
        # Apply the same normalization to codes of any length, so that
        # multi-part codes such as "zh_Hant_TW" also end up hyphenated
        # ("zh-hant-tw"), just like two-part codes ("en_GB" -> "en-gb").
        normalized = lang.replace("_", "-").lower()
        translation.activate(normalized)
        request.LANGUAGE_CODE = normalized
class GrapheneJWTAuthorizationMiddleware:
    """Graphene middleware that sets ``info.context.user`` from the request's JWT.

    Authentication is delegated to simplejwt's ``JWTAuthentication``. Whenever it
    yields no user (no credentials, invalid token, or any other failure),
    ``AnonymousUser`` is used instead.
    """

    # Private attribute on the context used to cache the authenticated user.
    # Graphene calls middleware ``resolve`` once per resolved field, so without
    # this cache the token would be re-validated (and the user re-fetched) for
    # every field of the same request.
    _CACHED_USER_ATTR = "_graphene_jwt_user"

    def resolve(
        self,
        next: Callable[..., Any],
        root: Any,
        info: GraphQLResolveInfo,
        **args: Any,
    ) -> Any:
        """Attach the JWT user to the context, then continue the resolver chain.

        Args:
            next: The next resolver in the middleware chain.
            root: Parent value passed to the resolver.
            info: GraphQL resolve info; ``info.context`` is treated as the request.
            **args: Field arguments, forwarded unchanged.

        Returns:
            Whatever ``next`` returns.
        """
        context = info.context
        user = getattr(context, self._CACHED_USER_ATTR, None)
        if user is None:
            user = self.get_jwt_user(cast(HttpRequest, context))
            setattr(context, self._CACHED_USER_ATTR, user)
        # Set via attribute assignment; the context type does not declare ``user``.
        info.context.user = user
        return next(root, info, **args)

    @staticmethod
    def get_jwt_user(request: HttpRequest) -> User | AnonymousUser:
        """Return the user authenticated by the request's JWT, or ``AnonymousUser``.

        This never raises. ``InvalidToken`` and a ``None`` result from
        ``authenticate()`` yield ``AnonymousUser`` silently. Any other exception
        is logged (message at WARNING, traceback at DEBUG) and also yields
        ``AnonymousUser``.

        Args:
            request: Request carrying the (optional) JWT credentials.
        """
        jwt_authenticator = JWTAuthentication()
        try:
            auth_result = jwt_authenticator.authenticate(request)
            if auth_result is None:
                # No credentials were found to authenticate with; check this
                # explicitly rather than relying on a TypeError from unpacking
                # None, which would also hide real TypeErrors.
                return AnonymousUser()
            user_obj, _ = auth_result
        except InvalidToken:
            return AnonymousUser()
        except Exception as e:
            logger.warning("Could not authenticate user: %s", str(e))
            logger.debug(traceback.format_exc())
            return AnonymousUser()
        return cast(User, user_obj)
class BlockInvalidHostMiddleware:
    """Reject requests whose ``Host`` header is not in an allow-list.

    The allow-list always contains the internal addresses in ``_INTERNAL_HOSTS``.
    If the ``DEBUG`` environment variable parses as a non-zero integer (the
    default is ``"1"``), every host is accepted. Otherwise the whitespace-separated
    ``ALLOWED_HOSTS`` environment variable extends the list, and an entry of
    ``"*"`` accepts every host.
    """

    # Hosts that are always accepted (presumably internal service addresses —
    # verify against the deployment configuration).
    _INTERNAL_HOSTS = (
        "app:8000",
        "worker:8000",
        "beat:8000",
        "localhost:8000",
        "127.0.0.1:8000",
    )

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    @classmethod
    def _allowed_hosts(cls) -> set[str]:
        """Build the allow-list from the environment (read on every call)."""
        hosts = set(cls._INTERNAL_HOSTS)
        if bool(int(getenv("DEBUG", "1"))):
            hosts.add("*")
        else:
            # Use split() rather than split(" "). With split(" "), an unset or
            # empty ALLOWED_HOSTS, or repeated spaces, would add "" to the list
            # and accept requests with an empty Host header.
            hosts.update(getenv("ALLOWED_HOSTS", "").split())
        return hosts

    def __call__(self, request: HttpRequest) -> HttpResponse:
        """Return 400 for requests without ``META``, 403 for disallowed hosts.

        Allowed requests are passed on to the wrapped handler.
        """
        allowed_hosts = self._allowed_hosts()
        if not hasattr(request, "META"):
            return HttpResponseBadRequest("Invalid Request")
        if (
            "*" not in allowed_hosts
            and request.META.get("HTTP_HOST") not in allowed_hosts
        ):
            return HttpResponseForbidden("Invalid Host Header")
        return self.get_response(request)
class GrapheneLoggingErrorsDebugMiddleware:
    """Graphene middleware that logs resolver exceptions and reports them to Sentry.

    Exceptions whose type is listed in ``WARNING_ONLY_ERRORS`` are logged at
    WARNING level (message only). Any other exception is logged at ERROR level
    with its traceback and sent to Sentry via ``capture_exception``. The
    exception is always re-raised.
    """

    # Exception types that are only logged as warnings and not sent to Sentry.
    WARNING_ONLY_ERRORS = [
        BadRequest,
        PermissionDenied,
        DisallowedHost,
        ValidationError,
    ]

    def resolve(
        self,
        next: Callable[..., Any],
        root: Any,
        info: GraphQLResolveInfo,
        **args: Any,
    ) -> Any:
        """Call the next resolver, logging and re-raising any exception it raises.

        Args:
            next: The next resolver in the middleware chain.
            root: Parent value passed to the resolver.
            info: GraphQL resolve info, forwarded unchanged.
            **args: Field arguments, forwarded unchanged.

        Returns:
            Whatever ``next`` returns.

        Raises:
            Exception: Any exception raised by ``next``, unchanged.
        """
        try:
            return next(root, info, **args)
        except Exception as e:
            if isinstance(e, tuple(self.WARNING_ONLY_ERRORS)):
                logger.warning(str(e))
            else:
                logger.error(str(e), exc_info=True)
                capture_exception(e)
            # Bare ``raise`` re-raises the active exception with its original
            # traceback and without adding this frame.
            raise
class RateLimitMiddleware:
    """Turn ``RatelimitedError`` into an HTTP 429 JSON response.

    Requests pass through untouched. Conversion happens only in the
    ``process_exception`` hook.
    """

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response

    def __call__(self, request: HttpRequest) -> HttpResponse:
        # Pure pass-through; nothing to do on the request path.
        return self.get_response(request)

    # noinspection PyUnusedLocal
    def process_exception(
        self, request: HttpRequest, exception: Exception
    ) -> JsonResponse | None:
        """Return a 429 response for ``RatelimitedError``; otherwise ``None``.

        The JSON body carries the exception message under ``"error"``. It also
        carries the exception's ``code`` attribute (``"rate_limited"`` when
        absent) under ``"code"``.
        """
        if not isinstance(exception, RatelimitedError):
            return None
        payload = {
            "error": str(exception),
            "code": getattr(exception, "code", "rate_limited"),
        }
        return JsonResponse(payload, status=429)
class CamelCaseMiddleWare:
    """Rewrite the keys of ``request.GET`` through ``underscoreize``.

    This presumably converts camelCase query parameter names to snake_case;
    the exact rules come from ``underscoreize`` and ``JSON_UNDERSCOREIZE``.
    The result replaces ``request.GET`` with a new, immutable ``QueryDict``.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        # Map each key to its full list of values so multi-valued params survive.
        converted = underscoreize(dict(request.GET.lists()), **JSON_UNDERSCOREIZE)

        rebuilt = QueryDict(mutable=True)
        for name, values in converted.items():
            if not isinstance(values, list):
                rebuilt[name] = values
                continue
            for item in values:
                rebuilt.appendlist(name, item)
        # Freeze the QueryDict, matching the read-only request.GET it replaces.
        rebuilt._mutable = False

        request.GET = rebuilt
        return self.get_response(request)