"""Django Channels middleware that authenticates websockets with a simplejwt token.

The token is taken from the websocket subprotocols requested by the client,
either as the pair ``("bearer", "<token>")`` or as a single ``"<token>"``
entry. Connections without a usable token keep an anonymous user.
"""

from __future__ import annotations

import logging
from contextlib import suppress
from typing import Iterable, Optional

from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.contrib.auth.models import AnonymousUser
from django.utils.functional import LazyObject
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import AuthenticationFailed

logger = logging.getLogger(__name__)


class _LazyUser(LazyObject):
    """Lazily-initialised placeholder that resolves to ``AnonymousUser``."""

    def _setup(self):
        self._wrapped = AnonymousUser()


def _extract_jwt_from_subprotocols(subprotocols: Optional[Iterable[str]]) -> Optional[str]:
    """Return the JWT carried in the websocket subprotocols, or ``None``.

    Two layouts are accepted:

    * ``["bearer", "<token>", ...]`` -- the first entry is ``"bearer"``
      (case-insensitive) and the second, non-empty entry is the token;
    * ``["<token>"]`` -- exactly one non-empty entry, used as the token.

    Anything else (including ``None`` or an empty iterable) yields ``None``.
    """
    if not subprotocols:
        return None
    items = list(subprotocols)
    if len(items) >= 2 and items[0].lower() == "bearer" and items[1]:
        return items[1]
    if len(items) == 1 and items[0]:
        return items[0]
    return None


@database_sync_to_async
def _authenticate_token(raw_token: str):
    """Validate *raw_token* with simplejwt and return the user it identifies.

    simplejwt's user lookup is a synchronous ORM query, which Django refuses to
    run directly inside an event loop, so this runs in a worker thread via
    ``database_sync_to_async``. Validating the raw token directly (instead of
    faking an ``Authorization`` header) keeps working when the project
    customises simplejwt's ``AUTH_HEADER_NAME`` / ``AUTH_HEADER_TYPES``.

    Raises:
        AuthenticationFailed: the token is invalid or expired (``InvalidToken``
            is a subclass), or the user it names cannot be authenticated.
    """
    jwt_auth = JWTAuthentication()
    validated_token = jwt_auth.get_validated_token(raw_token)
    return jwt_auth.get_user(validated_token)


class JWTAuthMiddleware(BaseMiddleware):
    """Populate ``scope["user"]`` from a JWT passed in the websocket subprotocols.

    ``scope["user"]`` is always set: to the authenticated user when the token
    validates, otherwise to a lazy ``AnonymousUser``.

    NOTE(review): when a client offers subprotocols, the consumer presumably has
    to accept the connection with one of them (e.g. ``"bearer"``) or browsers
    will drop it -- confirm in the consumer.
    """

    async def __call__(self, scope, receive, send):
        # Work on a copy so the user we attach does not leak into the caller's
        # scope (the same convention channels' own AuthMiddleware follows).
        scope = dict(scope)
        scope["user"] = _LazyUser()

        token = _extract_jwt_from_subprotocols(scope.get("subprotocols"))
        if token:
            try:
                # A rejected token is an expected outcome: stay anonymous.
                with suppress(AuthenticationFailed):
                    scope["user"] = await _authenticate_token(token)
            except Exception:
                # Authentication is best-effort, so never fail the connection
                # because of it -- but do not hide unexpected errors either.
                logger.exception("Unexpected error while authenticating websocket JWT")

        return await super().__call__(scope, receive, send)


def JWTAuthMiddlewareStack(inner):
    """Wrap the ASGI application *inner* in :class:`JWTAuthMiddleware`."""
    return JWTAuthMiddleware(inner)