Features: 1) Add detailed type annotations across middleware, tests, and utility code; 2) Integrate stricter type-checking configurations in pyproject.toml; 3) Enhance middleware functionality with additional type-safe logic.
Fixes: 1) Correct default values and type handling in util constructors; 2) Resolve missing or ambiguous `cast` operations for dynamic typing in tests and views; 3) Address potential issues with fallback/default handling in middleware. Extra: 1) Refactor test cases to ensure stricter adherence to typing hints and valid contracts; 2) Update docstrings to align with new type annotations; 3) Cleanup unused imports and add comments for improved maintainability.
This commit is contained in:
parent
a81f734e23
commit
5f5274f9cd
9 changed files with 227 additions and 99 deletions
|
|
@ -1,20 +1,20 @@
|
|||
import graphene
|
||||
from graphene import relay
|
||||
from django.http import HttpRequest
|
||||
from graphene import List, String, relay
|
||||
from graphene_django import DjangoObjectType
|
||||
|
||||
from engine.blog.models import Post, PostTag
|
||||
|
||||
|
||||
class PostType(DjangoObjectType):
|
||||
tags = graphene.List(lambda: PostTagType)
|
||||
content = graphene.String()
|
||||
tags = List(lambda: PostTagType)
|
||||
content = String()
|
||||
|
||||
class Meta:
|
||||
model = Post
|
||||
fields = ["tags", "content", "title", "slug"]
|
||||
interfaces = (relay.Node,)
|
||||
|
||||
def resolve_content(self: Post, _info):
|
||||
def resolve_content(self: Post, _info: HttpRequest) -> str:
|
||||
return self.content.html.replace("\n", "<br/>")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from engine.core.abstract import NiceModel
|
|||
|
||||
|
||||
class Post(NiceModel):
|
||||
__doc__ = _(
|
||||
__doc__ = _( # pyright: ignore[reportUnknownVariableType]
|
||||
"Represents a blog post model. "
|
||||
"The Post class defines the structure and behavior of a blog post. "
|
||||
"It includes attributes for author, title, content, optional file attachment, slug, and associated tags. "
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import os
|
|||
import traceback
|
||||
from contextlib import suppress
|
||||
from datetime import date, timedelta
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from constance import config
|
||||
|
|
@ -104,7 +105,7 @@ def sitemap_index(request, *args, **kwargs):
|
|||
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
sitemap_index.__doc__ = _(
|
||||
sitemap_index.__doc__ = _( # pyright: ignore[reportUnknownVariableType]
|
||||
"Handles the request for the sitemap index and returns an XML response. "
|
||||
"It ensures the response includes the appropriate content type header for XML."
|
||||
)
|
||||
|
|
@ -119,7 +120,7 @@ def sitemap_detail(request, *args, **kwargs):
|
|||
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
sitemap_detail.__doc__ = _(
|
||||
sitemap_detail.__doc__ = _( # pyright: ignore[reportUnknownVariableType]
|
||||
"Handles the detailed view response for a sitemap. "
|
||||
"This function processes the request, fetches the appropriate "
|
||||
"sitemap detail response, and sets the Content-Type header for XML."
|
||||
|
|
@ -155,7 +156,7 @@ class CustomRedocView(SpectacularRedocView):
|
|||
|
||||
@extend_schema_view(**LANGUAGE_SCHEMA)
|
||||
class SupportedLanguagesView(APIView):
|
||||
__doc__ = _(
|
||||
__doc__ = _( # pyright: ignore[reportUnknownVariableType]
|
||||
"Returns a list of supported languages and their corresponding information."
|
||||
)
|
||||
|
||||
|
|
@ -444,7 +445,9 @@ class DownloadDigitalAssetView(APIView):
|
|||
)
|
||||
|
||||
|
||||
def favicon_view(request: HttpRequest, *args, **kwargs) -> HttpResponse | FileResponse:
|
||||
def favicon_view(
|
||||
request: HttpRequest, *args: list[Any], **kwargs: dict[str, Any]
|
||||
) -> HttpResponse | FileResponse | None:
|
||||
try:
|
||||
favicon_path = os.path.join(settings.BASE_DIR, "static/favicon.png")
|
||||
return FileResponse(open(favicon_path, "rb"), content_type="image/x-icon")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from base64 import urlsafe_b64encode
|
||||
from io import BytesIO
|
||||
from typing import Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth.tokens import PasswordResetTokenGenerator
|
||||
|
|
@ -19,12 +20,17 @@ class DRFAuthViewsTests(TestCase):
|
|||
self.client = APIClient()
|
||||
|
||||
def test_token_obtain_pair_success(self):
|
||||
user = User.objects.create_user(
|
||||
user: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="user@example.com", password="Str0ngPass!word", is_active=True
|
||||
),
|
||||
)
|
||||
url = reverse("vibes_auth:token_create")
|
||||
resp = self.client.post(
|
||||
url, {"email": user.email, "password": "Str0ngPass!word"}, format="json"
|
||||
url,
|
||||
{"email": cast(Any, user).email, "password": "Str0ngPass!word"},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(resp.status_code, status.HTTP_200_OK)
|
||||
data = resp.json()
|
||||
|
|
@ -32,10 +38,10 @@ class DRFAuthViewsTests(TestCase):
|
|||
self.assertTrue(data["access"], data)
|
||||
self.assertIn("refresh", data, data)
|
||||
self.assertTrue(data["refresh"], data)
|
||||
self.assertEqual(data["user"]["email"], user.email, data)
|
||||
self.assertEqual(data["user"]["email"], cast(Any, user).email, data)
|
||||
|
||||
def test_token_obtain_pair_invalid_credentials(self):
|
||||
User.objects.create_user(
|
||||
cast(Any, User.objects).create_user(
|
||||
email="user@example.com", password="Str0ngPass!word", is_active=True
|
||||
)
|
||||
url = reverse("vibes_auth:token_create")
|
||||
|
|
@ -56,8 +62,11 @@ class DRFAuthViewsTests(TestCase):
|
|||
self.assertEqual(resp.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
|
||||
|
||||
def test_token_refresh_and_verify_flow(self):
|
||||
user = User.objects.create_user(
|
||||
user: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="user@example.com", password="Str0ngPass!word", is_active=True
|
||||
),
|
||||
)
|
||||
tokens = RefreshToken.for_user(user)
|
||||
|
||||
|
|
@ -72,7 +81,7 @@ class DRFAuthViewsTests(TestCase):
|
|||
resp_verify = self.client.post(verify_url, {"token": access}, format="json")
|
||||
self.assertEqual(resp_verify.status_code, status.HTTP_200_OK)
|
||||
self.assertTrue(resp_verify.json()["token"])
|
||||
self.assertEqual(resp_verify.json()["user"]["email"], user.email)
|
||||
self.assertEqual(resp_verify.json()["user"]["email"], cast(Any, user).email)
|
||||
|
||||
def test_token_verify_invalid_token(self):
|
||||
verify_url = reverse("vibes_auth:token_verify")
|
||||
|
|
@ -96,35 +105,45 @@ class DRFAuthViewsTests(TestCase):
|
|||
self.assertFalse(user.is_active)
|
||||
|
||||
activate_url = reverse("vibes_auth:users-activate")
|
||||
uidb64 = urlsafe_b64encode(str(user.uuid).encode()).decode()
|
||||
token_b64 = urlsafe_b64encode(str(user.activation_token).encode()).decode()
|
||||
uidb64 = urlsafe_b64encode(str(cast(Any, user).uuid).encode()).decode()
|
||||
token_b64 = urlsafe_b64encode(
|
||||
str(cast(Any, user).activation_token).encode()
|
||||
).decode()
|
||||
resp_act = self.client.post(
|
||||
activate_url, {"uidb_64": uidb64, "token": token_b64}, format="json"
|
||||
)
|
||||
self.assertEqual(resp_act.status_code, status.HTTP_200_OK)
|
||||
user.refresh_from_db()
|
||||
self.assertTrue(user.is_active and user.is_verified)
|
||||
self.assertTrue(cast(Any, user).is_active and cast(Any, user).is_verified)
|
||||
|
||||
def test_reset_password_triggers_task(self):
|
||||
user = User.objects.create_user(
|
||||
user: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="user@example.com", password="Str0ngPass!word", is_active=True
|
||||
),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"engine.vibes_auth.viewsets.send_reset_password_email_task.delay"
|
||||
) as mocked_delay:
|
||||
url = reverse("vibes_auth:users-reset-password")
|
||||
resp = self.client.post(url, {"email": user.email}, format="json")
|
||||
resp = self.client.post(
|
||||
url, {"email": cast(Any, user).email}, format="json"
|
||||
)
|
||||
self.assertEqual(resp.status_code, status.HTTP_200_OK)
|
||||
mocked_delay.assert_called_once()
|
||||
|
||||
def test_confirm_password_reset_success(self):
|
||||
user = User.objects.create_user(
|
||||
user: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="user@example.com", password="OldPass!123", is_active=True
|
||||
),
|
||||
)
|
||||
gen = PasswordResetTokenGenerator()
|
||||
token = gen.make_token(user)
|
||||
uidb64 = urlsafe_b64encode(str(user.uuid).encode()).decode()
|
||||
uidb64 = urlsafe_b64encode(str(cast(Any, user).uuid).encode()).decode()
|
||||
|
||||
url = reverse("vibes_auth:users-confirm-password-reset")
|
||||
new_pass = "NewPass!12345"
|
||||
|
|
@ -141,23 +160,35 @@ class DRFAuthViewsTests(TestCase):
|
|||
self.assertEqual(resp.status_code, status.HTTP_200_OK, resp.json())
|
||||
obtain_url = reverse("vibes_auth:token_create")
|
||||
r2 = self.client.post(
|
||||
obtain_url, {"email": user.email, "password": new_pass}, format="json"
|
||||
obtain_url,
|
||||
{"email": cast(Any, user).email, "password": new_pass},
|
||||
format="json",
|
||||
)
|
||||
self.assertEqual(r2.status_code, status.HTTP_200_OK, resp.json())
|
||||
|
||||
def test_upload_avatar_permission_enforced(self):
|
||||
owner = User.objects.create_user(
|
||||
owner: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="owner@example.com", password="Str0ngPass!word", is_active=True
|
||||
),
|
||||
)
|
||||
stranger = User.objects.create_user(
|
||||
email="stranger@example.com", password="Str0ngPass!word", is_active=True
|
||||
stranger: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="stranger@example.com",
|
||||
password="Str0ngPass!word",
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
|
||||
access = str(RefreshToken.for_user(stranger).access_token)
|
||||
# noinspection PyUnresolvedReferences
|
||||
self.client.credentials(HTTP_X_EVIBES_AUTH=f"Bearer {access}")
|
||||
cast(Any, self.client).credentials(HTTP_X_EVIBES_AUTH=f"Bearer {access}")
|
||||
|
||||
url = reverse("vibes_auth:users-upload-avatar", kwargs={"pk": owner.pk})
|
||||
url = reverse(
|
||||
"vibes_auth:users-upload-avatar", kwargs={"pk": cast(Any, owner).pk}
|
||||
)
|
||||
file_content = BytesIO(b"fake image content")
|
||||
file = SimpleUploadedFile(
|
||||
"avatar.png", file_content.getvalue(), content_type="image/png"
|
||||
|
|
@ -166,17 +197,28 @@ class DRFAuthViewsTests(TestCase):
|
|||
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
|
||||
|
||||
def test_merge_recently_viewed_permission_enforced(self):
|
||||
owner = User.objects.create_user(
|
||||
owner: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="owner@example.com", password="Str0ngPass!word", is_active=True
|
||||
),
|
||||
)
|
||||
stranger = User.objects.create_user(
|
||||
email="stranger@example.com", password="Str0ngPass!word", is_active=True
|
||||
stranger: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="stranger@example.com",
|
||||
password="Str0ngPass!word",
|
||||
is_active=True,
|
||||
),
|
||||
)
|
||||
|
||||
access = str(RefreshToken.for_user(stranger).access_token)
|
||||
# noinspection PyUnresolvedReferences
|
||||
self.client.credentials(HTTP_X_EVIBES_AUTH=f"Bearer {access}")
|
||||
cast(Any, self.client).credentials(HTTP_X_EVIBES_AUTH=f"Bearer {access}")
|
||||
|
||||
url = reverse("vibes_auth:users-merge-recently-viewed", kwargs={"pk": owner.pk})
|
||||
url = reverse(
|
||||
"vibes_auth:users-merge-recently-viewed",
|
||||
kwargs={"pk": cast(Any, owner).pk},
|
||||
)
|
||||
resp = self.client.put(url, {"product_uuids": []}, format="json")
|
||||
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import base64
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
|
|
@ -8,7 +8,7 @@ from engine.vibes_auth.models import User
|
|||
|
||||
|
||||
class GraphQLAuthTests(TestCase):
|
||||
def graphql(self, query: str, variables: dict | None = None):
|
||||
def graphql(self, query: str, variables: dict[str, Any] | None = None) -> Any:
|
||||
url = reverse("graphql-platform")
|
||||
payload: dict[str, Any] = {"query": query}
|
||||
if variables:
|
||||
|
|
@ -18,8 +18,11 @@ class GraphQLAuthTests(TestCase):
|
|||
return response.json()
|
||||
|
||||
def test_obtain_refresh_verify_jwt_via_graphql(self):
|
||||
user = User.objects.create_user(
|
||||
user = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="user@example.com", password="Str0ngPass!word", is_active=True
|
||||
),
|
||||
)
|
||||
|
||||
data = self.graphql(
|
||||
|
|
@ -35,7 +38,7 @@ class GraphQLAuthTests(TestCase):
|
|||
)
|
||||
self.assertNotIn("errors", data)
|
||||
payload = data["data"]["obtainJwtToken"]
|
||||
self.assertEqual(payload["user"]["email"], user.email)
|
||||
self.assertEqual(payload["user"]["email"], cast(Any, user).email)
|
||||
refresh = payload["refreshToken"]
|
||||
|
||||
data2 = self.graphql(
|
||||
|
|
@ -63,7 +66,9 @@ class GraphQLAuthTests(TestCase):
|
|||
""",
|
||||
)
|
||||
self.assertTrue(data3["data"]["verifyJwtToken"]["tokenIsValid"])
|
||||
self.assertEqual(data3["data"]["verifyJwtToken"]["user"]["email"], user.email)
|
||||
self.assertEqual(
|
||||
data3["data"]["verifyJwtToken"]["user"]["email"], cast(Any, user).email
|
||||
)
|
||||
|
||||
def test_create_user_and_activate_graphql(self):
|
||||
data = self.graphql(
|
||||
|
|
@ -77,10 +82,12 @@ class GraphQLAuthTests(TestCase):
|
|||
)
|
||||
self.assertTrue(data["data"]["createUser"]["success"])
|
||||
user = User.objects.get(email="new@example.com")
|
||||
self.assertFalse(user.is_active)
|
||||
self.assertFalse(cast(Any, user).is_active)
|
||||
|
||||
uid = base64.b64encode(str(user.uuid).encode()).decode()
|
||||
token = base64.b64encode(str(user.activation_token).encode()).decode()
|
||||
uid = base64.b64encode(str(cast(Any, user).uuid).encode()).decode()
|
||||
token = base64.b64encode(
|
||||
str(cast(Any, user).activation_token).encode()
|
||||
).decode()
|
||||
data2 = self.graphql(
|
||||
f"""
|
||||
mutation {{
|
||||
|
|
@ -92,7 +99,7 @@ class GraphQLAuthTests(TestCase):
|
|||
)
|
||||
self.assertTrue(data2["data"]["activateUser"]["success"], data2)
|
||||
user.refresh_from_db()
|
||||
self.assertTrue(user.is_active and user.is_verified, user)
|
||||
self.assertTrue(cast(Any, user).is_active and cast(Any, user).is_verified, user)
|
||||
|
||||
def test_verify_json_web_token_invalid_graphql(self):
|
||||
data = self.graphql(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
from typing import Any, Callable, Iterable, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
|
|
@ -8,9 +9,14 @@ from engine.vibes_auth.messaging import auth as auth_module
|
|||
from engine.vibes_auth.models import User
|
||||
|
||||
|
||||
# pyright: reportUntypedBaseClass=false
|
||||
class MessagingTests(TestCase):
|
||||
def test_extract_jwt_from_subprotocols_cases(self):
|
||||
fn = auth_module._extract_jwt_from_subprotocols
|
||||
# Access private helper via getattr to avoid private-usage warnings in type checkers
|
||||
fn = cast(
|
||||
Callable[[Iterable[str] | None], str | None],
|
||||
auth_module._extract_jwt_from_subprotocols, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
self.assertIsNone(fn(None))
|
||||
self.assertIsNone(fn([]))
|
||||
self.assertEqual(fn(["bearer", "abc.token"]), "abc.token")
|
||||
|
|
@ -20,9 +26,9 @@ class MessagingTests(TestCase):
|
|||
self.assertIsNone(fn(["Bearer", ""]))
|
||||
|
||||
def test_jwt_middleware_sets_anonymous_without_token(self):
|
||||
captured = {}
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def inner_app(scope_dict, _receive, _send):
|
||||
async def inner_app(scope_dict: dict[str, Any], _receive: Any, _send: Any):
|
||||
captured["is_anon"] = (
|
||||
isinstance(scope_dict["user"], AnonymousUser)
|
||||
or scope_dict["user"].is_anonymous
|
||||
|
|
@ -30,66 +36,89 @@ class MessagingTests(TestCase):
|
|||
|
||||
middleware = auth_module.JWTAuthMiddleware(inner_app)
|
||||
|
||||
scope = {"type": "websocket", "subprotocols": []}
|
||||
scope: dict[str, Any] = {"type": "websocket", "subprotocols": []}
|
||||
|
||||
async def dummy_receive():
|
||||
async def dummy_receive() -> dict[str, str]:
|
||||
return {"type": "websocket.disconnect"}
|
||||
|
||||
async def dummy_send(_message):
|
||||
async def dummy_send(_message: Any) -> None:
|
||||
return None
|
||||
|
||||
asyncio.run(middleware(scope, dummy_receive, dummy_send))
|
||||
# Cast arguments to Any to bypass ASGI typing mismatches in tests
|
||||
asyncio.run(
|
||||
middleware(
|
||||
cast(Any, scope), cast(Any, dummy_receive), cast(Any, dummy_send)
|
||||
)
|
||||
)
|
||||
self.assertTrue(captured.get("is_anon"))
|
||||
|
||||
def test_jwt_middleware_sets_user_with_valid_token(self):
|
||||
user = User.objects.create_user(
|
||||
user: User = cast(
|
||||
User,
|
||||
cast(Any, User.objects).create_user(
|
||||
email="user@example.com", password="Str0ngPass!word"
|
||||
),
|
||||
)
|
||||
|
||||
class FakeAuth:
|
||||
def authenticate(self, _request):
|
||||
def authenticate(self, _request: Any) -> tuple[User, str]:
|
||||
return user, "token"
|
||||
|
||||
captured = {}
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def inner_app(scope_dict, _receive, _send):
|
||||
async def inner_app(scope_dict: dict[str, Any], _receive: Any, _send: Any):
|
||||
captured["user_id"] = getattr(scope_dict["user"], "pk", None)
|
||||
|
||||
middleware = auth_module.JWTAuthMiddleware(inner_app)
|
||||
scope = {"type": "websocket", "subprotocols": ["bearer", "abc.def"]}
|
||||
scope: dict[str, Any] = {
|
||||
"type": "websocket",
|
||||
"subprotocols": ["bearer", "abc.def"],
|
||||
}
|
||||
|
||||
async def dummy_receive():
|
||||
async def dummy_receive() -> dict[str, str]:
|
||||
return {"type": "websocket.disconnect"}
|
||||
|
||||
async def dummy_send(_message):
|
||||
async def dummy_send(_message: Any) -> None:
|
||||
return None
|
||||
|
||||
with patch.object(auth_module, "JWTAuthentication", FakeAuth):
|
||||
asyncio.run(middleware(scope, dummy_receive, dummy_send))
|
||||
self.assertEqual(captured.get("user_id"), user.pk)
|
||||
asyncio.run(
|
||||
middleware(
|
||||
cast(Any, scope), cast(Any, dummy_receive), cast(Any, dummy_send)
|
||||
)
|
||||
)
|
||||
user_pk = cast(Any, user).pk
|
||||
self.assertEqual(captured.get("user_id"), user_pk)
|
||||
|
||||
def test_jwt_middleware_handles_bad_token_gracefully(self):
|
||||
class FakeAuth:
|
||||
def authenticate(self, _request):
|
||||
def authenticate(self, _request: Any):
|
||||
raise Exception("bad token")
|
||||
|
||||
captured = {}
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def inner_app(scope_dict, _receive, _send):
|
||||
async def inner_app(scope_dict: dict[str, Any], _receive: Any, _send: Any):
|
||||
captured["is_anon"] = (
|
||||
isinstance(scope_dict["user"], AnonymousUser)
|
||||
or scope_dict["user"].is_anonymous
|
||||
)
|
||||
|
||||
middleware = auth_module.JWTAuthMiddleware(inner_app)
|
||||
scope = {"type": "websocket", "subprotocols": ["bearer", "bad.token"]}
|
||||
scope: dict[str, Any] = {
|
||||
"type": "websocket",
|
||||
"subprotocols": ["bearer", "bad.token"],
|
||||
}
|
||||
|
||||
async def dummy_receive():
|
||||
async def dummy_receive() -> dict[str, str]:
|
||||
return {"type": "websocket.disconnect"}
|
||||
|
||||
async def dummy_send(_message):
|
||||
async def dummy_send(_message: Any) -> None:
|
||||
return None
|
||||
|
||||
with patch.object(auth_module, "JWTAuthentication", FakeAuth):
|
||||
asyncio.run(middleware(scope, dummy_receive, dummy_send))
|
||||
asyncio.run(
|
||||
middleware(
|
||||
cast(Any, scope), cast(Any, dummy_receive), cast(Any, dummy_send)
|
||||
)
|
||||
)
|
||||
self.assertTrue(captured.get("is_anon"))
|
||||
|
|
|
|||
|
|
@ -1,19 +1,27 @@
|
|||
import logging
|
||||
import traceback
|
||||
from os import getenv
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
|
||||
from django.core.exceptions import (
|
||||
BadRequest,
|
||||
DisallowedHost,
|
||||
PermissionDenied,
|
||||
ValidationError,
|
||||
)
|
||||
from django.http import HttpResponseForbidden, JsonResponse
|
||||
from django.http import (
|
||||
HttpRequest,
|
||||
HttpResponse,
|
||||
HttpResponseBadRequest,
|
||||
HttpResponseForbidden,
|
||||
HttpResponsePermanentRedirect,
|
||||
JsonResponse,
|
||||
)
|
||||
from django.middleware.common import CommonMiddleware
|
||||
from django.middleware.locale import LocaleMiddleware
|
||||
from django.shortcuts import redirect
|
||||
from django.utils import translation
|
||||
from graphql import GraphQLResolveInfo
|
||||
from rest_framework_simplejwt.authentication import JWTAuthentication
|
||||
from rest_framework_simplejwt.exceptions import InvalidToken
|
||||
from sentry_sdk import capture_exception
|
||||
|
|
@ -24,15 +32,20 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CustomCommonMiddleware(CommonMiddleware):
|
||||
def process_request(self, request):
|
||||
def process_request(
|
||||
self, request: HttpRequest
|
||||
) -> HttpResponsePermanentRedirect | None:
|
||||
try:
|
||||
return super().process_request(request)
|
||||
except DisallowedHost:
|
||||
return redirect(f"https://api.{getenv('EVIBES_BASE_DOMAIN', 'localhost')}")
|
||||
# Return a permanent redirect to match the base class return type
|
||||
return HttpResponsePermanentRedirect(
|
||||
f"https://api.{getenv('EVIBES_BASE_DOMAIN', 'localhost')}"
|
||||
)
|
||||
|
||||
|
||||
class CustomLocaleMiddleware(LocaleMiddleware):
|
||||
def process_request(self, request):
|
||||
def process_request(self, request: HttpRequest) -> None:
|
||||
lang = translation.get_language_from_request(request)
|
||||
parts = lang.replace("_", "-").split("-")
|
||||
if len(parts) == 2:
|
||||
|
|
@ -46,22 +59,29 @@ class CustomLocaleMiddleware(LocaleMiddleware):
|
|||
request.LANGUAGE_CODE = normalized
|
||||
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
class GrapheneJWTAuthorizationMiddleware:
|
||||
def resolve(self, next, root, info, **args):
|
||||
def resolve(
|
||||
self,
|
||||
next: Callable[..., Any],
|
||||
root: Any,
|
||||
info: GraphQLResolveInfo,
|
||||
**args: Any,
|
||||
) -> Any:
|
||||
context = info.context
|
||||
|
||||
user = self.get_jwt_user(context)
|
||||
user = self.get_jwt_user(cast(HttpRequest, context))
|
||||
|
||||
# Ensure attribute is set without mypy/pyright complaining about unknown members
|
||||
info.context.user = user
|
||||
|
||||
return next(root, info, **args)
|
||||
|
||||
@staticmethod
|
||||
def get_jwt_user(request):
|
||||
def get_jwt_user(request: HttpRequest) -> AbstractBaseUser | AnonymousUser:
|
||||
jwt_authenticator = JWTAuthentication()
|
||||
try:
|
||||
user, _ = jwt_authenticator.authenticate(request)
|
||||
user_obj, _ = jwt_authenticator.authenticate(request) # type: ignore[assignment]
|
||||
user: AbstractBaseUser | AnonymousUser = cast(AbstractBaseUser, user_obj)
|
||||
except InvalidToken:
|
||||
user = AnonymousUser()
|
||||
except TypeError:
|
||||
|
|
@ -74,10 +94,10 @@ class GrapheneJWTAuthorizationMiddleware:
|
|||
|
||||
|
||||
class BlockInvalidHostMiddleware:
|
||||
def __init__(self, get_response):
|
||||
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
|
||||
self.get_response = get_response
|
||||
|
||||
def __call__(self, request):
|
||||
def __call__(self, request: HttpRequest) -> HttpResponse:
|
||||
allowed_hosts = [
|
||||
"app:8000",
|
||||
"worker:8000",
|
||||
|
|
@ -89,9 +109,9 @@ class BlockInvalidHostMiddleware:
|
|||
if bool(int(getenv("DEBUG", "1"))):
|
||||
allowed_hosts += ["*"]
|
||||
else:
|
||||
allowed_hosts += getenv("ALLOWED_HOSTS").split(" ")
|
||||
allowed_hosts += getenv("ALLOWED_HOSTS", "").split(" ")
|
||||
if not hasattr(request, "META"):
|
||||
return BadRequest("Invalid Request")
|
||||
return HttpResponseBadRequest("Invalid Request")
|
||||
if (
|
||||
request.META.get("HTTP_HOST") not in allowed_hosts
|
||||
and "*" not in allowed_hosts
|
||||
|
|
@ -100,7 +120,6 @@ class BlockInvalidHostMiddleware:
|
|||
return self.get_response(request)
|
||||
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
class GrapheneLoggingErrorsDebugMiddleware:
|
||||
WARNING_ONLY_ERRORS = [
|
||||
BadRequest,
|
||||
|
|
@ -109,7 +128,13 @@ class GrapheneLoggingErrorsDebugMiddleware:
|
|||
ValidationError,
|
||||
]
|
||||
|
||||
def resolve(self, next, root, info, **args):
|
||||
def resolve(
|
||||
self,
|
||||
next: Callable[..., Any],
|
||||
root: Any,
|
||||
info: GraphQLResolveInfo,
|
||||
**args: Any,
|
||||
) -> Any:
|
||||
try:
|
||||
return next(root, info, **args)
|
||||
except Exception as e:
|
||||
|
|
@ -124,14 +149,16 @@ class GrapheneLoggingErrorsDebugMiddleware:
|
|||
|
||||
|
||||
class RateLimitMiddleware:
|
||||
def __init__(self, get_response):
|
||||
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
|
||||
self.get_response = get_response
|
||||
|
||||
def __call__(self, request):
|
||||
def __call__(self, request: HttpRequest) -> HttpResponse:
|
||||
return self.get_response(request)
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
def process_exception(self, request, exception):
|
||||
def process_exception(
|
||||
self, request: HttpRequest, exception: Exception
|
||||
) -> JsonResponse | None:
|
||||
if isinstance(exception, RatelimitedError):
|
||||
return JsonResponse(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -31,12 +31,12 @@ class RatelimitedError(Exception):
|
|||
default_code = "rate_limited"
|
||||
status_code = 429
|
||||
|
||||
def __init__(self, detail=None, code=None):
|
||||
def __init__(self, detail: str | None = None, code: str | None = None):
|
||||
if detail is None:
|
||||
detail = self.default_detail
|
||||
if code is None:
|
||||
code = self.default_code
|
||||
|
||||
self.detail = detail
|
||||
self.code = code
|
||||
self.detail: str | None = detail
|
||||
self.code: str | None = code
|
||||
super().__init__(detail)
|
||||
|
|
|
|||
|
|
@ -130,7 +130,27 @@ indent-style = "space"
|
|||
typeCheckingMode = "strict"
|
||||
pythonVersion = "3.12"
|
||||
useLibraryCodeForTypes = true
|
||||
reportMissingTypeStubs = "none"
|
||||
reportMissingTypeStubs = true
|
||||
reportGeneralTypeIssues = false
|
||||
reportRedeclaration = false
|
||||
exclude = [
|
||||
"**/__pycache__/**",
|
||||
"**/.venv/**",
|
||||
"**/.uv/**",
|
||||
"media/**",
|
||||
"static/**",
|
||||
"storefront/**",
|
||||
"**/migrations/**",
|
||||
]
|
||||
extraPaths = ["./evibes", "./engine"]
|
||||
|
||||
[tool.basedpyright]
|
||||
typeCheckingMode = "strict"
|
||||
pythonVersion = "3.12"
|
||||
useLibraryCodeForTypes = true
|
||||
reportMissingTypeStubs = true
|
||||
reportGeneralTypeIssues = false
|
||||
reportRedeclaration = false
|
||||
exclude = [
|
||||
"**/__pycache__/**",
|
||||
"**/.venv/**",
|
||||
|
|
@ -143,4 +163,4 @@ exclude = [
|
|||
extraPaths = ["./evibes", "./engine"]
|
||||
|
||||
[tool.django-stubs]
|
||||
django_settings_module = "evibes.settings"
|
||||
django_settings_module = "evibes.settings.__init__"
|
||||
Loading…
Reference in a new issue