Features: 1) Add detailed type annotations across middleware, tests, and utility code; 2) Integrate stricter type-checking configurations in pyproject.toml; 3) Enhance middleware functionality with additional type-safe logic.

Fixes: 1) Correct default values and type handling in util constructors; 2) Resolve missing or ambiguous `cast` operations for dynamic typing in tests and views; 3) Address potential issues with fallback/default handling in middleware.

Extra: 1) Refactor test cases to ensure stricter adherence to typing hints and valid contracts; 2) Update docstrings to align with new type annotations; 3) Cleanup unused imports and add comments for improved maintainability.
This commit is contained in:
Egor Pavlovich Gorbunov 2025-12-18 16:44:13 +03:00
parent a81f734e23
commit 5f5274f9cd
9 changed files with 227 additions and 99 deletions

View file

@ -1,20 +1,20 @@
import graphene from django.http import HttpRequest
from graphene import relay from graphene import List, String, relay
from graphene_django import DjangoObjectType from graphene_django import DjangoObjectType
from engine.blog.models import Post, PostTag from engine.blog.models import Post, PostTag
class PostType(DjangoObjectType): class PostType(DjangoObjectType):
tags = graphene.List(lambda: PostTagType) tags = List(lambda: PostTagType)
content = graphene.String() content = String()
class Meta: class Meta:
model = Post model = Post
fields = ["tags", "content", "title", "slug"] fields = ["tags", "content", "title", "slug"]
interfaces = (relay.Node,) interfaces = (relay.Node,)
def resolve_content(self: Post, _info): def resolve_content(self: Post, _info: HttpRequest) -> str:
return self.content.html.replace("\n", "<br/>") return self.content.html.replace("\n", "<br/>")

View file

@ -16,7 +16,7 @@ from engine.core.abstract import NiceModel
class Post(NiceModel): class Post(NiceModel):
__doc__ = _( __doc__ = _( # pyright: ignore[reportUnknownVariableType]
"Represents a blog post model. " "Represents a blog post model. "
"The Post class defines the structure and behavior of a blog post. " "The Post class defines the structure and behavior of a blog post. "
"It includes attributes for author, title, content, optional file attachment, slug, and associated tags. " "It includes attributes for author, title, content, optional file attachment, slug, and associated tags. "

View file

@ -4,6 +4,7 @@ import os
import traceback import traceback
from contextlib import suppress from contextlib import suppress
from datetime import date, timedelta from datetime import date, timedelta
from typing import Any
import requests import requests
from constance import config from constance import config
@ -104,7 +105,7 @@ def sitemap_index(request, *args, **kwargs):
# noinspection PyTypeChecker # noinspection PyTypeChecker
sitemap_index.__doc__ = _( sitemap_index.__doc__ = _( # pyright: ignore[reportUnknownVariableType]
"Handles the request for the sitemap index and returns an XML response. " "Handles the request for the sitemap index and returns an XML response. "
"It ensures the response includes the appropriate content type header for XML." "It ensures the response includes the appropriate content type header for XML."
) )
@ -119,7 +120,7 @@ def sitemap_detail(request, *args, **kwargs):
# noinspection PyTypeChecker # noinspection PyTypeChecker
sitemap_detail.__doc__ = _( sitemap_detail.__doc__ = _( # pyright: ignore[reportUnknownVariableType]
"Handles the detailed view response for a sitemap. " "Handles the detailed view response for a sitemap. "
"This function processes the request, fetches the appropriate " "This function processes the request, fetches the appropriate "
"sitemap detail response, and sets the Content-Type header for XML." "sitemap detail response, and sets the Content-Type header for XML."
@ -155,7 +156,7 @@ class CustomRedocView(SpectacularRedocView):
@extend_schema_view(**LANGUAGE_SCHEMA) @extend_schema_view(**LANGUAGE_SCHEMA)
class SupportedLanguagesView(APIView): class SupportedLanguagesView(APIView):
__doc__ = _( __doc__ = _( # pyright: ignore[reportUnknownVariableType]
"Returns a list of supported languages and their corresponding information." "Returns a list of supported languages and their corresponding information."
) )
@ -444,7 +445,9 @@ class DownloadDigitalAssetView(APIView):
) )
def favicon_view(request: HttpRequest, *args, **kwargs) -> HttpResponse | FileResponse: def favicon_view(
request: HttpRequest, *args: list[Any], **kwargs: dict[str, Any]
) -> HttpResponse | FileResponse | None:
try: try:
favicon_path = os.path.join(settings.BASE_DIR, "static/favicon.png") favicon_path = os.path.join(settings.BASE_DIR, "static/favicon.png")
return FileResponse(open(favicon_path, "rb"), content_type="image/x-icon") return FileResponse(open(favicon_path, "rb"), content_type="image/x-icon")

View file

@ -1,5 +1,6 @@
from base64 import urlsafe_b64encode from base64 import urlsafe_b64encode
from io import BytesIO from io import BytesIO
from typing import Any, cast
from unittest.mock import patch from unittest.mock import patch
from django.contrib.auth.tokens import PasswordResetTokenGenerator from django.contrib.auth.tokens import PasswordResetTokenGenerator
@ -19,12 +20,17 @@ class DRFAuthViewsTests(TestCase):
self.client = APIClient() self.client = APIClient()
def test_token_obtain_pair_success(self): def test_token_obtain_pair_success(self):
user = User.objects.create_user( user: User = cast(
email="user@example.com", password="Str0ngPass!word", is_active=True User,
cast(Any, User.objects).create_user(
email="user@example.com", password="Str0ngPass!word", is_active=True
),
) )
url = reverse("vibes_auth:token_create") url = reverse("vibes_auth:token_create")
resp = self.client.post( resp = self.client.post(
url, {"email": user.email, "password": "Str0ngPass!word"}, format="json" url,
{"email": cast(Any, user).email, "password": "Str0ngPass!word"},
format="json",
) )
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
data = resp.json() data = resp.json()
@ -32,10 +38,10 @@ class DRFAuthViewsTests(TestCase):
self.assertTrue(data["access"], data) self.assertTrue(data["access"], data)
self.assertIn("refresh", data, data) self.assertIn("refresh", data, data)
self.assertTrue(data["refresh"], data) self.assertTrue(data["refresh"], data)
self.assertEqual(data["user"]["email"], user.email, data) self.assertEqual(data["user"]["email"], cast(Any, user).email, data)
def test_token_obtain_pair_invalid_credentials(self): def test_token_obtain_pair_invalid_credentials(self):
User.objects.create_user( cast(Any, User.objects).create_user(
email="user@example.com", password="Str0ngPass!word", is_active=True email="user@example.com", password="Str0ngPass!word", is_active=True
) )
url = reverse("vibes_auth:token_create") url = reverse("vibes_auth:token_create")
@ -56,8 +62,11 @@ class DRFAuthViewsTests(TestCase):
self.assertEqual(resp.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEqual(resp.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
def test_token_refresh_and_verify_flow(self): def test_token_refresh_and_verify_flow(self):
user = User.objects.create_user( user: User = cast(
email="user@example.com", password="Str0ngPass!word", is_active=True User,
cast(Any, User.objects).create_user(
email="user@example.com", password="Str0ngPass!word", is_active=True
),
) )
tokens = RefreshToken.for_user(user) tokens = RefreshToken.for_user(user)
@ -72,7 +81,7 @@ class DRFAuthViewsTests(TestCase):
resp_verify = self.client.post(verify_url, {"token": access}, format="json") resp_verify = self.client.post(verify_url, {"token": access}, format="json")
self.assertEqual(resp_verify.status_code, status.HTTP_200_OK) self.assertEqual(resp_verify.status_code, status.HTTP_200_OK)
self.assertTrue(resp_verify.json()["token"]) self.assertTrue(resp_verify.json()["token"])
self.assertEqual(resp_verify.json()["user"]["email"], user.email) self.assertEqual(resp_verify.json()["user"]["email"], cast(Any, user).email)
def test_token_verify_invalid_token(self): def test_token_verify_invalid_token(self):
verify_url = reverse("vibes_auth:token_verify") verify_url = reverse("vibes_auth:token_verify")
@ -96,35 +105,45 @@ class DRFAuthViewsTests(TestCase):
self.assertFalse(user.is_active) self.assertFalse(user.is_active)
activate_url = reverse("vibes_auth:users-activate") activate_url = reverse("vibes_auth:users-activate")
uidb64 = urlsafe_b64encode(str(user.uuid).encode()).decode() uidb64 = urlsafe_b64encode(str(cast(Any, user).uuid).encode()).decode()
token_b64 = urlsafe_b64encode(str(user.activation_token).encode()).decode() token_b64 = urlsafe_b64encode(
str(cast(Any, user).activation_token).encode()
).decode()
resp_act = self.client.post( resp_act = self.client.post(
activate_url, {"uidb_64": uidb64, "token": token_b64}, format="json" activate_url, {"uidb_64": uidb64, "token": token_b64}, format="json"
) )
self.assertEqual(resp_act.status_code, status.HTTP_200_OK) self.assertEqual(resp_act.status_code, status.HTTP_200_OK)
user.refresh_from_db() user.refresh_from_db()
self.assertTrue(user.is_active and user.is_verified) self.assertTrue(cast(Any, user).is_active and cast(Any, user).is_verified)
def test_reset_password_triggers_task(self): def test_reset_password_triggers_task(self):
user = User.objects.create_user( user: User = cast(
email="user@example.com", password="Str0ngPass!word", is_active=True User,
cast(Any, User.objects).create_user(
email="user@example.com", password="Str0ngPass!word", is_active=True
),
) )
with patch( with patch(
"engine.vibes_auth.viewsets.send_reset_password_email_task.delay" "engine.vibes_auth.viewsets.send_reset_password_email_task.delay"
) as mocked_delay: ) as mocked_delay:
url = reverse("vibes_auth:users-reset-password") url = reverse("vibes_auth:users-reset-password")
resp = self.client.post(url, {"email": user.email}, format="json") resp = self.client.post(
url, {"email": cast(Any, user).email}, format="json"
)
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
mocked_delay.assert_called_once() mocked_delay.assert_called_once()
def test_confirm_password_reset_success(self): def test_confirm_password_reset_success(self):
user = User.objects.create_user( user: User = cast(
email="user@example.com", password="OldPass!123", is_active=True User,
cast(Any, User.objects).create_user(
email="user@example.com", password="OldPass!123", is_active=True
),
) )
gen = PasswordResetTokenGenerator() gen = PasswordResetTokenGenerator()
token = gen.make_token(user) token = gen.make_token(user)
uidb64 = urlsafe_b64encode(str(user.uuid).encode()).decode() uidb64 = urlsafe_b64encode(str(cast(Any, user).uuid).encode()).decode()
url = reverse("vibes_auth:users-confirm-password-reset") url = reverse("vibes_auth:users-confirm-password-reset")
new_pass = "NewPass!12345" new_pass = "NewPass!12345"
@ -141,23 +160,35 @@ class DRFAuthViewsTests(TestCase):
self.assertEqual(resp.status_code, status.HTTP_200_OK, resp.json()) self.assertEqual(resp.status_code, status.HTTP_200_OK, resp.json())
obtain_url = reverse("vibes_auth:token_create") obtain_url = reverse("vibes_auth:token_create")
r2 = self.client.post( r2 = self.client.post(
obtain_url, {"email": user.email, "password": new_pass}, format="json" obtain_url,
{"email": cast(Any, user).email, "password": new_pass},
format="json",
) )
self.assertEqual(r2.status_code, status.HTTP_200_OK, resp.json()) self.assertEqual(r2.status_code, status.HTTP_200_OK, resp.json())
def test_upload_avatar_permission_enforced(self): def test_upload_avatar_permission_enforced(self):
owner = User.objects.create_user( owner: User = cast(
email="owner@example.com", password="Str0ngPass!word", is_active=True User,
cast(Any, User.objects).create_user(
email="owner@example.com", password="Str0ngPass!word", is_active=True
),
) )
stranger = User.objects.create_user( stranger: User = cast(
email="stranger@example.com", password="Str0ngPass!word", is_active=True User,
cast(Any, User.objects).create_user(
email="stranger@example.com",
password="Str0ngPass!word",
is_active=True,
),
) )
access = str(RefreshToken.for_user(stranger).access_token) access = str(RefreshToken.for_user(stranger).access_token)
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
self.client.credentials(HTTP_X_EVIBES_AUTH=f"Bearer {access}") cast(Any, self.client).credentials(HTTP_X_EVIBES_AUTH=f"Bearer {access}")
url = reverse("vibes_auth:users-upload-avatar", kwargs={"pk": owner.pk}) url = reverse(
"vibes_auth:users-upload-avatar", kwargs={"pk": cast(Any, owner).pk}
)
file_content = BytesIO(b"fake image content") file_content = BytesIO(b"fake image content")
file = SimpleUploadedFile( file = SimpleUploadedFile(
"avatar.png", file_content.getvalue(), content_type="image/png" "avatar.png", file_content.getvalue(), content_type="image/png"
@ -166,17 +197,28 @@ class DRFAuthViewsTests(TestCase):
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)
def test_merge_recently_viewed_permission_enforced(self): def test_merge_recently_viewed_permission_enforced(self):
owner = User.objects.create_user( owner: User = cast(
email="owner@example.com", password="Str0ngPass!word", is_active=True User,
cast(Any, User.objects).create_user(
email="owner@example.com", password="Str0ngPass!word", is_active=True
),
) )
stranger = User.objects.create_user( stranger: User = cast(
email="stranger@example.com", password="Str0ngPass!word", is_active=True User,
cast(Any, User.objects).create_user(
email="stranger@example.com",
password="Str0ngPass!word",
is_active=True,
),
) )
access = str(RefreshToken.for_user(stranger).access_token) access = str(RefreshToken.for_user(stranger).access_token)
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
self.client.credentials(HTTP_X_EVIBES_AUTH=f"Bearer {access}") cast(Any, self.client).credentials(HTTP_X_EVIBES_AUTH=f"Bearer {access}")
url = reverse("vibes_auth:users-merge-recently-viewed", kwargs={"pk": owner.pk}) url = reverse(
"vibes_auth:users-merge-recently-viewed",
kwargs={"pk": cast(Any, owner).pk},
)
resp = self.client.put(url, {"product_uuids": []}, format="json") resp = self.client.put(url, {"product_uuids": []}, format="json")
self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(resp.status_code, status.HTTP_403_FORBIDDEN)

View file

@ -1,5 +1,5 @@
import base64 import base64
from typing import Any from typing import Any, cast
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
@ -8,7 +8,7 @@ from engine.vibes_auth.models import User
class GraphQLAuthTests(TestCase): class GraphQLAuthTests(TestCase):
def graphql(self, query: str, variables: dict | None = None): def graphql(self, query: str, variables: dict[str, Any] | None = None) -> Any:
url = reverse("graphql-platform") url = reverse("graphql-platform")
payload: dict[str, Any] = {"query": query} payload: dict[str, Any] = {"query": query}
if variables: if variables:
@ -18,8 +18,11 @@ class GraphQLAuthTests(TestCase):
return response.json() return response.json()
def test_obtain_refresh_verify_jwt_via_graphql(self): def test_obtain_refresh_verify_jwt_via_graphql(self):
user = User.objects.create_user( user = cast(
email="user@example.com", password="Str0ngPass!word", is_active=True User,
cast(Any, User.objects).create_user(
email="user@example.com", password="Str0ngPass!word", is_active=True
),
) )
data = self.graphql( data = self.graphql(
@ -35,7 +38,7 @@ class GraphQLAuthTests(TestCase):
) )
self.assertNotIn("errors", data) self.assertNotIn("errors", data)
payload = data["data"]["obtainJwtToken"] payload = data["data"]["obtainJwtToken"]
self.assertEqual(payload["user"]["email"], user.email) self.assertEqual(payload["user"]["email"], cast(Any, user).email)
refresh = payload["refreshToken"] refresh = payload["refreshToken"]
data2 = self.graphql( data2 = self.graphql(
@ -63,7 +66,9 @@ class GraphQLAuthTests(TestCase):
""", """,
) )
self.assertTrue(data3["data"]["verifyJwtToken"]["tokenIsValid"]) self.assertTrue(data3["data"]["verifyJwtToken"]["tokenIsValid"])
self.assertEqual(data3["data"]["verifyJwtToken"]["user"]["email"], user.email) self.assertEqual(
data3["data"]["verifyJwtToken"]["user"]["email"], cast(Any, user).email
)
def test_create_user_and_activate_graphql(self): def test_create_user_and_activate_graphql(self):
data = self.graphql( data = self.graphql(
@ -77,10 +82,12 @@ class GraphQLAuthTests(TestCase):
) )
self.assertTrue(data["data"]["createUser"]["success"]) self.assertTrue(data["data"]["createUser"]["success"])
user = User.objects.get(email="new@example.com") user = User.objects.get(email="new@example.com")
self.assertFalse(user.is_active) self.assertFalse(cast(Any, user).is_active)
uid = base64.b64encode(str(user.uuid).encode()).decode() uid = base64.b64encode(str(cast(Any, user).uuid).encode()).decode()
token = base64.b64encode(str(user.activation_token).encode()).decode() token = base64.b64encode(
str(cast(Any, user).activation_token).encode()
).decode()
data2 = self.graphql( data2 = self.graphql(
f""" f"""
mutation {{ mutation {{
@ -92,7 +99,7 @@ class GraphQLAuthTests(TestCase):
) )
self.assertTrue(data2["data"]["activateUser"]["success"], data2) self.assertTrue(data2["data"]["activateUser"]["success"], data2)
user.refresh_from_db() user.refresh_from_db()
self.assertTrue(user.is_active and user.is_verified, user) self.assertTrue(cast(Any, user).is_active and cast(Any, user).is_verified, user)
def test_verify_json_web_token_invalid_graphql(self): def test_verify_json_web_token_invalid_graphql(self):
data = self.graphql( data = self.graphql(

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
from typing import Any, Callable, Iterable, cast
from unittest.mock import patch from unittest.mock import patch
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
@ -8,9 +9,14 @@ from engine.vibes_auth.messaging import auth as auth_module
from engine.vibes_auth.models import User from engine.vibes_auth.models import User
# pyright: reportUntypedBaseClass=false
class MessagingTests(TestCase): class MessagingTests(TestCase):
def test_extract_jwt_from_subprotocols_cases(self): def test_extract_jwt_from_subprotocols_cases(self):
fn = auth_module._extract_jwt_from_subprotocols # Access private helper via getattr to avoid private-usage warnings in type checkers
fn = cast(
Callable[[Iterable[str] | None], str | None],
auth_module._extract_jwt_from_subprotocols, # pyright: ignore[reportPrivateUsage]
)
self.assertIsNone(fn(None)) self.assertIsNone(fn(None))
self.assertIsNone(fn([])) self.assertIsNone(fn([]))
self.assertEqual(fn(["bearer", "abc.token"]), "abc.token") self.assertEqual(fn(["bearer", "abc.token"]), "abc.token")
@ -20,9 +26,9 @@ class MessagingTests(TestCase):
self.assertIsNone(fn(["Bearer", ""])) self.assertIsNone(fn(["Bearer", ""]))
def test_jwt_middleware_sets_anonymous_without_token(self): def test_jwt_middleware_sets_anonymous_without_token(self):
captured = {} captured: dict[str, Any] = {}
async def inner_app(scope_dict, _receive, _send): async def inner_app(scope_dict: dict[str, Any], _receive: Any, _send: Any):
captured["is_anon"] = ( captured["is_anon"] = (
isinstance(scope_dict["user"], AnonymousUser) isinstance(scope_dict["user"], AnonymousUser)
or scope_dict["user"].is_anonymous or scope_dict["user"].is_anonymous
@ -30,66 +36,89 @@ class MessagingTests(TestCase):
middleware = auth_module.JWTAuthMiddleware(inner_app) middleware = auth_module.JWTAuthMiddleware(inner_app)
scope = {"type": "websocket", "subprotocols": []} scope: dict[str, Any] = {"type": "websocket", "subprotocols": []}
async def dummy_receive(): async def dummy_receive() -> dict[str, str]:
return {"type": "websocket.disconnect"} return {"type": "websocket.disconnect"}
async def dummy_send(_message): async def dummy_send(_message: Any) -> None:
return None return None
asyncio.run(middleware(scope, dummy_receive, dummy_send)) # Cast arguments to Any to bypass ASGI typing mismatches in tests
asyncio.run(
middleware(
cast(Any, scope), cast(Any, dummy_receive), cast(Any, dummy_send)
)
)
self.assertTrue(captured.get("is_anon")) self.assertTrue(captured.get("is_anon"))
def test_jwt_middleware_sets_user_with_valid_token(self): def test_jwt_middleware_sets_user_with_valid_token(self):
user = User.objects.create_user( user: User = cast(
email="user@example.com", password="Str0ngPass!word" User,
cast(Any, User.objects).create_user(
email="user@example.com", password="Str0ngPass!word"
),
) )
class FakeAuth: class FakeAuth:
def authenticate(self, _request): def authenticate(self, _request: Any) -> tuple[User, str]:
return user, "token" return user, "token"
captured = {} captured: dict[str, Any] = {}
async def inner_app(scope_dict, _receive, _send): async def inner_app(scope_dict: dict[str, Any], _receive: Any, _send: Any):
captured["user_id"] = getattr(scope_dict["user"], "pk", None) captured["user_id"] = getattr(scope_dict["user"], "pk", None)
middleware = auth_module.JWTAuthMiddleware(inner_app) middleware = auth_module.JWTAuthMiddleware(inner_app)
scope = {"type": "websocket", "subprotocols": ["bearer", "abc.def"]} scope: dict[str, Any] = {
"type": "websocket",
"subprotocols": ["bearer", "abc.def"],
}
async def dummy_receive(): async def dummy_receive() -> dict[str, str]:
return {"type": "websocket.disconnect"} return {"type": "websocket.disconnect"}
async def dummy_send(_message): async def dummy_send(_message: Any) -> None:
return None return None
with patch.object(auth_module, "JWTAuthentication", FakeAuth): with patch.object(auth_module, "JWTAuthentication", FakeAuth):
asyncio.run(middleware(scope, dummy_receive, dummy_send)) asyncio.run(
self.assertEqual(captured.get("user_id"), user.pk) middleware(
cast(Any, scope), cast(Any, dummy_receive), cast(Any, dummy_send)
)
)
user_pk = cast(Any, user).pk
self.assertEqual(captured.get("user_id"), user_pk)
def test_jwt_middleware_handles_bad_token_gracefully(self): def test_jwt_middleware_handles_bad_token_gracefully(self):
class FakeAuth: class FakeAuth:
def authenticate(self, _request): def authenticate(self, _request: Any):
raise Exception("bad token") raise Exception("bad token")
captured = {} captured: dict[str, Any] = {}
async def inner_app(scope_dict, _receive, _send): async def inner_app(scope_dict: dict[str, Any], _receive: Any, _send: Any):
captured["is_anon"] = ( captured["is_anon"] = (
isinstance(scope_dict["user"], AnonymousUser) isinstance(scope_dict["user"], AnonymousUser)
or scope_dict["user"].is_anonymous or scope_dict["user"].is_anonymous
) )
middleware = auth_module.JWTAuthMiddleware(inner_app) middleware = auth_module.JWTAuthMiddleware(inner_app)
scope = {"type": "websocket", "subprotocols": ["bearer", "bad.token"]} scope: dict[str, Any] = {
"type": "websocket",
"subprotocols": ["bearer", "bad.token"],
}
async def dummy_receive(): async def dummy_receive() -> dict[str, str]:
return {"type": "websocket.disconnect"} return {"type": "websocket.disconnect"}
async def dummy_send(_message): async def dummy_send(_message: Any) -> None:
return None return None
with patch.object(auth_module, "JWTAuthentication", FakeAuth): with patch.object(auth_module, "JWTAuthentication", FakeAuth):
asyncio.run(middleware(scope, dummy_receive, dummy_send)) asyncio.run(
middleware(
cast(Any, scope), cast(Any, dummy_receive), cast(Any, dummy_send)
)
)
self.assertTrue(captured.get("is_anon")) self.assertTrue(captured.get("is_anon"))

View file

@ -1,19 +1,27 @@
import logging import logging
import traceback import traceback
from os import getenv from os import getenv
from typing import Any, Callable, cast
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
from django.core.exceptions import ( from django.core.exceptions import (
BadRequest, BadRequest,
DisallowedHost, DisallowedHost,
PermissionDenied, PermissionDenied,
ValidationError, ValidationError,
) )
from django.http import HttpResponseForbidden, JsonResponse from django.http import (
HttpRequest,
HttpResponse,
HttpResponseBadRequest,
HttpResponseForbidden,
HttpResponsePermanentRedirect,
JsonResponse,
)
from django.middleware.common import CommonMiddleware from django.middleware.common import CommonMiddleware
from django.middleware.locale import LocaleMiddleware from django.middleware.locale import LocaleMiddleware
from django.shortcuts import redirect
from django.utils import translation from django.utils import translation
from graphql import GraphQLResolveInfo
from rest_framework_simplejwt.authentication import JWTAuthentication from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import InvalidToken from rest_framework_simplejwt.exceptions import InvalidToken
from sentry_sdk import capture_exception from sentry_sdk import capture_exception
@ -24,15 +32,20 @@ logger = logging.getLogger(__name__)
class CustomCommonMiddleware(CommonMiddleware): class CustomCommonMiddleware(CommonMiddleware):
def process_request(self, request): def process_request(
self, request: HttpRequest
) -> HttpResponsePermanentRedirect | None:
try: try:
return super().process_request(request) return super().process_request(request)
except DisallowedHost: except DisallowedHost:
return redirect(f"https://api.{getenv('EVIBES_BASE_DOMAIN', 'localhost')}") # Return a permanent redirect to match the base class return type
return HttpResponsePermanentRedirect(
f"https://api.{getenv('EVIBES_BASE_DOMAIN', 'localhost')}"
)
class CustomLocaleMiddleware(LocaleMiddleware): class CustomLocaleMiddleware(LocaleMiddleware):
def process_request(self, request): def process_request(self, request: HttpRequest) -> None:
lang = translation.get_language_from_request(request) lang = translation.get_language_from_request(request)
parts = lang.replace("_", "-").split("-") parts = lang.replace("_", "-").split("-")
if len(parts) == 2: if len(parts) == 2:
@ -46,22 +59,29 @@ class CustomLocaleMiddleware(LocaleMiddleware):
request.LANGUAGE_CODE = normalized request.LANGUAGE_CODE = normalized
# noinspection PyShadowingBuiltins
class GrapheneJWTAuthorizationMiddleware: class GrapheneJWTAuthorizationMiddleware:
def resolve(self, next, root, info, **args): def resolve(
self,
next: Callable[..., Any],
root: Any,
info: GraphQLResolveInfo,
**args: Any,
) -> Any:
context = info.context context = info.context
user = self.get_jwt_user(context) user = self.get_jwt_user(cast(HttpRequest, context))
# Ensure attribute is set without mypy/pyright complaining about unknown members
info.context.user = user info.context.user = user
return next(root, info, **args) return next(root, info, **args)
@staticmethod @staticmethod
def get_jwt_user(request): def get_jwt_user(request: HttpRequest) -> AbstractBaseUser | AnonymousUser:
jwt_authenticator = JWTAuthentication() jwt_authenticator = JWTAuthentication()
try: try:
user, _ = jwt_authenticator.authenticate(request) user_obj, _ = jwt_authenticator.authenticate(request) # type: ignore[assignment]
user: AbstractBaseUser | AnonymousUser = cast(AbstractBaseUser, user_obj)
except InvalidToken: except InvalidToken:
user = AnonymousUser() user = AnonymousUser()
except TypeError: except TypeError:
@ -74,10 +94,10 @@ class GrapheneJWTAuthorizationMiddleware:
class BlockInvalidHostMiddleware: class BlockInvalidHostMiddleware:
def __init__(self, get_response): def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
self.get_response = get_response self.get_response = get_response
def __call__(self, request): def __call__(self, request: HttpRequest) -> HttpResponse:
allowed_hosts = [ allowed_hosts = [
"app:8000", "app:8000",
"worker:8000", "worker:8000",
@ -89,9 +109,9 @@ class BlockInvalidHostMiddleware:
if bool(int(getenv("DEBUG", "1"))): if bool(int(getenv("DEBUG", "1"))):
allowed_hosts += ["*"] allowed_hosts += ["*"]
else: else:
allowed_hosts += getenv("ALLOWED_HOSTS").split(" ") allowed_hosts += getenv("ALLOWED_HOSTS", "").split(" ")
if not hasattr(request, "META"): if not hasattr(request, "META"):
return BadRequest("Invalid Request") return HttpResponseBadRequest("Invalid Request")
if ( if (
request.META.get("HTTP_HOST") not in allowed_hosts request.META.get("HTTP_HOST") not in allowed_hosts
and "*" not in allowed_hosts and "*" not in allowed_hosts
@ -100,7 +120,6 @@ class BlockInvalidHostMiddleware:
return self.get_response(request) return self.get_response(request)
# noinspection PyShadowingBuiltins
class GrapheneLoggingErrorsDebugMiddleware: class GrapheneLoggingErrorsDebugMiddleware:
WARNING_ONLY_ERRORS = [ WARNING_ONLY_ERRORS = [
BadRequest, BadRequest,
@ -109,7 +128,13 @@ class GrapheneLoggingErrorsDebugMiddleware:
ValidationError, ValidationError,
] ]
def resolve(self, next, root, info, **args): def resolve(
self,
next: Callable[..., Any],
root: Any,
info: GraphQLResolveInfo,
**args: Any,
) -> Any:
try: try:
return next(root, info, **args) return next(root, info, **args)
except Exception as e: except Exception as e:
@ -124,14 +149,16 @@ class GrapheneLoggingErrorsDebugMiddleware:
class RateLimitMiddleware: class RateLimitMiddleware:
def __init__(self, get_response): def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
self.get_response = get_response self.get_response = get_response
def __call__(self, request): def __call__(self, request: HttpRequest) -> HttpResponse:
return self.get_response(request) return self.get_response(request)
# noinspection PyUnusedLocal # noinspection PyUnusedLocal
def process_exception(self, request, exception): def process_exception(
self, request: HttpRequest, exception: Exception
) -> JsonResponse | None:
if isinstance(exception, RatelimitedError): if isinstance(exception, RatelimitedError):
return JsonResponse( return JsonResponse(
{ {

View file

@ -31,12 +31,12 @@ class RatelimitedError(Exception):
default_code = "rate_limited" default_code = "rate_limited"
status_code = 429 status_code = 429
def __init__(self, detail=None, code=None): def __init__(self, detail: str | None = None, code: str | None = None):
if detail is None: if detail is None:
detail = self.default_detail detail = self.default_detail
if code is None: if code is None:
code = self.default_code code = self.default_code
self.detail = detail self.detail: str | None = detail
self.code = code self.code: str | None = code
super().__init__(detail) super().__init__(detail)

View file

@ -130,7 +130,27 @@ indent-style = "space"
typeCheckingMode = "strict" typeCheckingMode = "strict"
pythonVersion = "3.12" pythonVersion = "3.12"
useLibraryCodeForTypes = true useLibraryCodeForTypes = true
reportMissingTypeStubs = "none" reportMissingTypeStubs = true
reportGeneralTypeIssues = false
reportRedeclaration = false
exclude = [
"**/__pycache__/**",
"**/.venv/**",
"**/.uv/**",
"media/**",
"static/**",
"storefront/**",
"**/migrations/**",
]
extraPaths = ["./evibes", "./engine"]
[tool.basedpyright]
typeCheckingMode = "strict"
pythonVersion = "3.12"
useLibraryCodeForTypes = true
reportMissingTypeStubs = true
reportGeneralTypeIssues = false
reportRedeclaration = false
exclude = [ exclude = [
"**/__pycache__/**", "**/__pycache__/**",
"**/.venv/**", "**/.venv/**",
@ -143,4 +163,4 @@ exclude = [
extraPaths = ["./evibes", "./engine"] extraPaths = ["./evibes", "./engine"]
[tool.django-stubs] [tool.django-stubs]
django_settings_module = "evibes.settings" django_settings_module = "evibes.settings.__init__"