From ecd6e5ad9fa80dbfc92c05e72c76a8fe3c65ba90 Mon Sep 17 00:00:00 2001 From: Egor fureunoir Gorbunov Date: Mon, 24 Nov 2025 16:13:42 +0300 Subject: [PATCH] Features: 1) Add `retrieve_user` flag to serializers for optional user data retrieval; 2) Allow flexibility in handling user data as UUID or serialized object; Fixes: 1) Adjust `mutate` methods in GraphQL mutations to support `retrieve_user` flag; 2) Ensure proper fallback to UUID handling when `retrieve_user` is set to `False`; Extra: 1) Refactor serializers and mutations for improved customization and reduced overhead; 2) Add comments and clean up redundant code. --- engine/vibes_auth/graphene/mutations.py | 12 +++++----- engine/vibes_auth/serializers.py | 32 ++++++++++++++++++------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/engine/vibes_auth/graphene/mutations.py b/engine/vibes_auth/graphene/mutations.py index 4172362f..11274918 100644 --- a/engine/vibes_auth/graphene/mutations.py +++ b/engine/vibes_auth/graphene/mutations.py @@ -201,11 +201,11 @@ class ObtainJSONWebToken(BaseMutation): access_token = String(required=True) def mutate(self, info, email, password): - serializer = TokenObtainPairSerializer(data={"email": email, "password": password}) + serializer = TokenObtainPairSerializer(data={"email": email, "password": password}, retrieve_user=False) try: serializer.is_valid(raise_exception=True) return ObtainJSONWebToken( - user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"]), + user=User.objects.get(uuid=serializer.validated_data["user"]), refresh_token=serializer.validated_data["refresh"], access_token=serializer.validated_data["access"], ) @@ -222,11 +222,11 @@ class RefreshJSONWebToken(BaseMutation): refresh_token = String() def mutate(self, info, refresh_token): - serializer = TokenRefreshSerializer(data={"refresh": refresh_token}) + serializer = TokenRefreshSerializer(data={"refresh": refresh_token}, retrieve_user=False) try: serializer.is_valid(raise_exception=True) return RefreshJSONWebToken( - user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"]), + user=User.objects.get(uuid=serializer.validated_data["user"]), access_token=serializer.validated_data["access"], refresh_token=serializer.validated_data["refresh"], ) @@ -243,12 +243,12 @@ class VerifyJSONWebToken(BaseMutation): detail = String() def mutate(self, info, token): - serializer = TokenVerifySerializer(data={"token": token}) + serializer = TokenVerifySerializer(data={"token": token}, retrieve_user=False) with suppress(Exception): serializer.is_valid(raise_exception=True) # noinspection PyTypeChecker return VerifyJSONWebToken( - token_is_valid=True, user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"]) + token_is_valid=True, user=User.objects.get(uuid=serializer.validated_data["user"]) ) detail = traceback.format_exc() if settings.DEBUG else "" # noinspection PyTypeChecker diff --git a/engine/vibes_auth/serializers.py b/engine/vibes_auth/serializers.py index cee31913..f795ef2c 100644 --- a/engine/vibes_auth/serializers.py +++ b/engine/vibes_auth/serializers.py @@ -137,6 +137,8 @@ class TokenObtainSerializer(Serializer): super().__init__(*args, **kwargs) self.user: User | None = None + self.user_uuid: str | None = None + self.retrieve_user: bool = kwargs.get("retrieve_user", True) self.fields[self.username_field] = CharField(write_only=True) self.fields["password"] = PasswordField() @@ -179,7 +181,7 @@ class TokenObtainPairSerializer(TokenObtainSerializer): data["refresh"] = str(refresh) # noinspection PyUnresolvedReferences data["access"] = str(refresh.access_token) # type: ignore [attr-defined] - data["user"] = UserSerializer(self.user).data + data["user"] = UserSerializer(self.user).data if self.retrieve_user else self.user.pk if api_settings.UPDATE_LAST_LOGIN: if not self.user: @@ -195,6 +197,11 @@ class TokenRefreshSerializer(Serializer): access = CharField(read_only=True) token_class = RefreshToken + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.retrieve_user: bool = kwargs.get("retrieve_user", True) + def validate(self, attrs: dict[str, Any]) -> dict[str, str]: refresh = self.token_class(attrs["refresh"]) @@ -210,9 +217,12 @@ class TokenRefreshSerializer(Serializer): refresh.set_iat() data["refresh"] = str(refresh) - user = User.objects.get(uuid=refresh.payload["user_uuid"]) # noinspection PyTypeChecker - data["user"] = UserSerializer(user).data # type: ignore [assignment] + data["user"] = ( + UserSerializer(User.objects.get(uuid=refresh.payload["user_uuid"])).data + if self.retrieve_user + else refresh.payload["user_uuid"] + ) return data @@ -220,6 +230,11 @@ class TokenRefreshSerializer(Serializer): class TokenVerifySerializer(Serializer): token = CharField(write_only=True) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.retrieve_user: bool = kwargs.get("retrieve_user", True) + def validate(self, attrs: dict[str, None]) -> dict[Any, Any]: token = UntypedToken(attrs["token"]) @@ -237,15 +252,16 @@ class TokenVerifySerializer(Serializer): raise ValidationError(_("invalid token")) from te try: - user_uuid = payload["user_uuid"] - user = User.objects.get(uuid=user_uuid) + # noinspection PyTypeChecker + attrs["user"] = ( + UserSerializer(User.objects.get(uuid=payload["user_uuid"])).data + if self.retrieve_user + else payload["user_uuid"] + ) except KeyError as ke: raise ValidationError(_("no user uuid claim present in token")) from ke except User.DoesNotExist as dne: raise ValidationError(_("user does not exist")) from dne - - # noinspection PyTypeChecker - attrs["user"] = UserSerializer(user).data # type: ignore [assignment] return attrs