Features: 1) Add retrieve_user flag to serializers for optional user data retrieval; 2) Allow flexibility in handling user data as UUID or serialized object;
Fixes: 1) Adjust `mutate` methods in GraphQL mutations to support `retrieve_user` flag; 2) Ensure proper fallback to UUID handling when `retrieve_user` is set to `False`; Extra: 1) Refactor serializers and mutations for improved customization and reduced overhead; 2) Add comments and clean up redundant code.
This commit is contained in:
parent
c5fa95d77f
commit
ecd6e5ad9f
2 changed files with 30 additions and 14 deletions
|
|
@ -201,11 +201,11 @@ class ObtainJSONWebToken(BaseMutation):
|
|||
access_token = String(required=True)
|
||||
|
||||
def mutate(self, info, email, password):
|
||||
serializer = TokenObtainPairSerializer(data={"email": email, "password": password})
|
||||
serializer = TokenObtainPairSerializer(data={"email": email, "password": password}, retrieve_user=False)
|
||||
try:
|
||||
serializer.is_valid(raise_exception=True)
|
||||
return ObtainJSONWebToken(
|
||||
user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"]),
|
||||
user=User.objects.get(uuid=serializer.validated_data["user"]),
|
||||
refresh_token=serializer.validated_data["refresh"],
|
||||
access_token=serializer.validated_data["access"],
|
||||
)
|
||||
|
|
@ -222,11 +222,11 @@ class RefreshJSONWebToken(BaseMutation):
|
|||
refresh_token = String()
|
||||
|
||||
def mutate(self, info, refresh_token):
|
||||
serializer = TokenRefreshSerializer(data={"refresh": refresh_token})
|
||||
serializer = TokenRefreshSerializer(data={"refresh": refresh_token}, retrieve_user=False)
|
||||
try:
|
||||
serializer.is_valid(raise_exception=True)
|
||||
return RefreshJSONWebToken(
|
||||
user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"]),
|
||||
user=User.objects.get(uuid=serializer.validated_data["user"]),
|
||||
access_token=serializer.validated_data["access"],
|
||||
refresh_token=serializer.validated_data["refresh"],
|
||||
)
|
||||
|
|
@ -243,12 +243,12 @@ class VerifyJSONWebToken(BaseMutation):
|
|||
detail = String()
|
||||
|
||||
def mutate(self, info, token):
|
||||
serializer = TokenVerifySerializer(data={"token": token})
|
||||
serializer = TokenVerifySerializer(data={"token": token}, retrieve_user=False)
|
||||
with suppress(Exception):
|
||||
serializer.is_valid(raise_exception=True)
|
||||
# noinspection PyTypeChecker
|
||||
return VerifyJSONWebToken(
|
||||
token_is_valid=True, user=User.objects.get(uuid=serializer.validated_data["user"]["uuid"])
|
||||
token_is_valid=True, user=User.objects.get(uuid=serializer.validated_data["user"])
|
||||
)
|
||||
detail = traceback.format_exc() if settings.DEBUG else ""
|
||||
# noinspection PyTypeChecker
|
||||
|
|
|
|||
|
|
@ -137,6 +137,8 @@ class TokenObtainSerializer(Serializer):
|
|||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.user: User | None = None
|
||||
self.user_uuid: str | None = None
|
||||
self.retrieve_user: bool = kwargs.get("retrieve_user", True)
|
||||
self.fields[self.username_field] = CharField(write_only=True)
|
||||
self.fields["password"] = PasswordField()
|
||||
|
||||
|
|
@ -179,7 +181,7 @@ class TokenObtainPairSerializer(TokenObtainSerializer):
|
|||
data["refresh"] = str(refresh)
|
||||
# noinspection PyUnresolvedReferences
|
||||
data["access"] = str(refresh.access_token) # type: ignore [attr-defined]
|
||||
data["user"] = UserSerializer(self.user).data
|
||||
data["user"] = UserSerializer(self.user).data if self.retrieve_user else self.user.pk
|
||||
|
||||
if api_settings.UPDATE_LAST_LOGIN:
|
||||
if not self.user:
|
||||
|
|
@ -195,6 +197,11 @@ class TokenRefreshSerializer(Serializer):
|
|||
access = CharField(read_only=True)
|
||||
token_class = RefreshToken
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.retrieve_user: bool = kwargs.get("retrieve_user", True)
|
||||
|
||||
def validate(self, attrs: dict[str, Any]) -> dict[str, str]:
|
||||
refresh = self.token_class(attrs["refresh"])
|
||||
|
||||
|
|
@ -210,9 +217,12 @@ class TokenRefreshSerializer(Serializer):
|
|||
refresh.set_iat()
|
||||
|
||||
data["refresh"] = str(refresh)
|
||||
user = User.objects.get(uuid=refresh.payload["user_uuid"])
|
||||
# noinspection PyTypeChecker
|
||||
data["user"] = UserSerializer(user).data # type: ignore [assignment]
|
||||
data["user"] = (
|
||||
UserSerializer(User.objects.get(uuid=refresh.payload["user_uuid"])).data
|
||||
if self.retrieve_user
|
||||
else refresh.payload["user_uuid"]
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
|
@ -220,6 +230,11 @@ class TokenRefreshSerializer(Serializer):
|
|||
class TokenVerifySerializer(Serializer):
|
||||
token = CharField(write_only=True)
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.retrieve_user: bool = kwargs.get("retrieve_user", True)
|
||||
|
||||
def validate(self, attrs: dict[str, None]) -> dict[Any, Any]:
|
||||
token = UntypedToken(attrs["token"])
|
||||
|
||||
|
|
@ -237,15 +252,16 @@ class TokenVerifySerializer(Serializer):
|
|||
raise ValidationError(_("invalid token")) from te
|
||||
|
||||
try:
|
||||
user_uuid = payload["user_uuid"]
|
||||
user = User.objects.get(uuid=user_uuid)
|
||||
# noinspection PyTypeChecker
|
||||
attrs["user"] = (
|
||||
UserSerializer(User.objects.get(uuid=payload["user_uuid"])).data
|
||||
if self.retrieve_user
|
||||
else payload["user_uuid"]
|
||||
)
|
||||
except KeyError as ke:
|
||||
raise ValidationError(_("no user uuid claim present in token")) from ke
|
||||
except User.DoesNotExist as dne:
|
||||
raise ValidationError(_("user does not exist")) from dne
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
attrs["user"] = UserSerializer(user).data # type: ignore [assignment]
|
||||
return attrs
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue