Features: 1) Add retrieve_user flag to serializers for optional user data retrieval; 2) Allow flexibility in handling user data as UUID or serialized object;

Fixes: 1) Adjust `mutate` methods in GraphQL mutations to support `retrieve_user` flag; 2) Ensure proper fallback to UUID handling when `retrieve_user` is set to `False`;

Extra: 1) Refactor serializers and mutations for improved customization and reduced overhead; 2) Add comments and clean up redundant code.
This commit is contained in:
Egor Pavlovich Gorbunov 2025-11-24 16:13:42 +03:00
parent c5fa95d77f
commit ecd6e5ad9f
2 changed files with 30 additions and 14 deletions

View file

@ -201,11 +201,11 @@ class ObtainJSONWebToken(BaseMutation):
access_token = String(required=True) access_token = String(required=True)
def mutate(self, info, email, password): def mutate(self, info, email, password):
serializer = TokenObtainPairSerializer(data={"email": email, "password": password}) serializer = TokenObtainPairSerializer(data={"email": email, "password": password}, retrieve_user=False)
try: try:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
return ObtainJSONWebToken( return ObtainJSONWebToken(
user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"]), user=User.objects.get(uuid=serializer.validated_data["user"]),
refresh_token=serializer.validated_data["refresh"], refresh_token=serializer.validated_data["refresh"],
access_token=serializer.validated_data["access"], access_token=serializer.validated_data["access"],
) )
@ -222,11 +222,11 @@ class RefreshJSONWebToken(BaseMutation):
refresh_token = String() refresh_token = String()
def mutate(self, info, refresh_token): def mutate(self, info, refresh_token):
serializer = TokenRefreshSerializer(data={"refresh": refresh_token}) serializer = TokenRefreshSerializer(data={"refresh": refresh_token}, retrieve_user=False)
try: try:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
return RefreshJSONWebToken( return RefreshJSONWebToken(
user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"]), user=User.objects.get(uuid=serializer.validated_data["user"]),
access_token=serializer.validated_data["access"], access_token=serializer.validated_data["access"],
refresh_token=serializer.validated_data["refresh"], refresh_token=serializer.validated_data["refresh"],
) )
@ -243,12 +243,12 @@ class VerifyJSONWebToken(BaseMutation):
detail = String() detail = String()
def mutate(self, info, token): def mutate(self, info, token):
serializer = TokenVerifySerializer(data={"token": token}) serializer = TokenVerifySerializer(data={"token": token}, retrieve_user=False)
with suppress(Exception): with suppress(Exception):
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
# noinspection PyTypeChecker # noinspection PyTypeChecker
return VerifyJSONWebToken( return VerifyJSONWebToken(
token_is_valid=True, user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"]) token_is_valid=True, user=User.objects.get(uuid=serializer.validated_data["user"])
) )
detail = traceback.format_exc() if settings.DEBUG else "" detail = traceback.format_exc() if settings.DEBUG else ""
# noinspection PyTypeChecker # noinspection PyTypeChecker

View file

@ -137,6 +137,8 @@ class TokenObtainSerializer(Serializer):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.user: User | None = None self.user: User | None = None
self.user_uuid: str | None = None
self.retrieve_user: bool = kwargs.get("retrieve_user", True)
self.fields[self.username_field] = CharField(write_only=True) self.fields[self.username_field] = CharField(write_only=True)
self.fields["password"] = PasswordField() self.fields["password"] = PasswordField()
@ -179,7 +181,7 @@ class TokenObtainPairSerializer(TokenObtainSerializer):
data["refresh"] = str(refresh) data["refresh"] = str(refresh)
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
data["access"] = str(refresh.access_token) # type: ignore [attr-defined] data["access"] = str(refresh.access_token) # type: ignore [attr-defined]
data["user"] = UserSerializer(self.user).data data["user"] = UserSerializer(self.user).data if self.retrieve_user else self.user.pk
if api_settings.UPDATE_LAST_LOGIN: if api_settings.UPDATE_LAST_LOGIN:
if not self.user: if not self.user:
@ -195,6 +197,11 @@ class TokenRefreshSerializer(Serializer):
access = CharField(read_only=True) access = CharField(read_only=True)
token_class = RefreshToken token_class = RefreshToken
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.retrieve_user: bool = kwargs.get("retrieve_user", True)
def validate(self, attrs: dict[str, Any]) -> dict[str, str]: def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
refresh = self.token_class(attrs["refresh"]) refresh = self.token_class(attrs["refresh"])
@ -210,9 +217,12 @@ class TokenRefreshSerializer(Serializer):
refresh.set_iat() refresh.set_iat()
data["refresh"] = str(refresh) data["refresh"] = str(refresh)
user = User.objects.get(uuid=refresh.payload["user_uuid"])
# noinspection PyTypeChecker # noinspection PyTypeChecker
data["user"] = UserSerializer(user).data # type: ignore [assignment] data["user"] = (
UserSerializer(User.objects.get(uuid=refresh.payload["user_uuid"])).data
if self.retrieve_user
else refresh.payload["user_uuid"]
)
return data return data
@ -220,6 +230,11 @@ class TokenRefreshSerializer(Serializer):
class TokenVerifySerializer(Serializer): class TokenVerifySerializer(Serializer):
token = CharField(write_only=True) token = CharField(write_only=True)
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.retrieve_user: bool = kwargs.get("retrieve_user", True)
def validate(self, attrs: dict[str, None]) -> dict[Any, Any]: def validate(self, attrs: dict[str, None]) -> dict[Any, Any]:
token = UntypedToken(attrs["token"]) token = UntypedToken(attrs["token"])
@ -237,15 +252,16 @@ class TokenVerifySerializer(Serializer):
raise ValidationError(_("invalid token")) from te raise ValidationError(_("invalid token")) from te
try: try:
user_uuid = payload["user_uuid"] # noinspection PyTypeChecker
user = User.objects.get(uuid=user_uuid) attrs["user"] = (
UserSerializer(User.objects.get(uuid=payload["user_uuid"])).data
if self.retrieve_user
else payload["user_uuid"]
)
except KeyError as ke: except KeyError as ke:
raise ValidationError(_("no user uuid claim present in token")) from ke raise ValidationError(_("no user uuid claim present in token")) from ke
except User.DoesNotExist as dne: except User.DoesNotExist as dne:
raise ValidationError(_("user does not exist")) from dne raise ValidationError(_("user does not exist")) from dne
# noinspection PyTypeChecker
attrs["user"] = UserSerializer(user).data # type: ignore [assignment]
return attrs return attrs