schon/engine/vibes_auth/serializers.py

277 lines
10 KiB
Python

import logging
from collections.abc import Collection
from contextlib import suppress
from hmac import compare_digest
from typing import Any
from constance import config
from django.contrib.auth import authenticate
from django.contrib.auth.models import update_last_login
from django.contrib.auth.password_validation import validate_password
from django.core.validators import validate_email
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema_field
from rest_framework.exceptions import AuthenticationFailed, ValidationError
from rest_framework.fields import (
BooleanField,
CharField,
EmailField,
ListField,
SerializerMethodField,
JSONField,
)
from rest_framework.serializers import ModelSerializer, Serializer
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.serializers import AuthUser, PasswordField
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.token_blacklist.models import BlacklistedToken
from rest_framework_simplejwt.tokens import RefreshToken, Token, UntypedToken
from engine.core.models import Product
from engine.core.serializers import ProductSimpleSerializer
from engine.core.utils.security import is_safe_key
from django.conf import settings
from engine.vibes_auth.models import User
from engine.vibes_auth.validators import validate_phone_number
# Module-level logger named after this module (standard logging.getLogger(__name__) pattern).
logger = logging.getLogger(__name__)
class UserSerializer(ModelSerializer):
    """Serialize :class:`User` profiles for reading and for create/update.

    Besides the model fields listed in ``Meta.fields`` this exposes two computed,
    read-only fields (``avatar_url`` and ``recently_viewed``) and two write-only
    password fields (``password`` / ``confirm_password``) that must be sent
    together and must match.
    """

    avatar_url = SerializerMethodField(required=False, read_only=True)
    password = CharField(write_only=True, required=False)
    confirm_password = CharField(write_only=True, required=False)
    is_staff = BooleanField(read_only=True)
    recently_viewed = SerializerMethodField(required=False, read_only=True)
    attributes = JSONField(required=False)

    @staticmethod
    def get_avatar_url(obj) -> str:
        """Return the absolute URL of the user's avatar, or of the default person image."""
        if obj.avatar:
            return f"https://api.{config.BASE_DOMAIN}/media/{obj.avatar!s}"
        return f"https://api.{config.BASE_DOMAIN}/static/person.png"

    class Meta:
        model = User
        fields = [
            "uuid",
            "email",
            "avatar_url",
            "is_staff",
            "recently_viewed",
            "attributes",
            "first_name",
            "last_name",
            "password",
            "confirm_password",
            "phone_number",
            "is_subscribed",
            "modified",
            "created",
        ]

    def create(self, validated_data: dict[str, Any]):
        """Create a user with a hashed password and copy any other safe attributes.

        Raises:
            ValidationError: if no ``password`` was supplied.
        """
        # confirm_password is only consumed by validate(); it is not a model field.
        validated_data.pop("confirm_password", None)
        if "password" not in validated_data:
            # Without this check the pop below raises an unhandled KeyError.
            raise ValidationError(
                {"password": [self.fields["password"].error_messages["required"]]},
                code="required",
            )
        password = validated_data.pop("password")

        user = User.objects.create(
            email=validated_data.pop("email"),
            first_name=validated_data.pop("first_name", ""),
            last_name=validated_data.pop("last_name", ""),
        )
        user.set_password(password)
        # Remaining fields are copied only when is_safe_key() allows them.
        for attr, value in validated_data.items():
            if is_safe_key(attr):
                setattr(user, attr, value)
        user.save()
        return user

    def update(self, instance: User, validated_data: dict[str, Any]):
        """Apply validated changes to ``instance`` and save it.

        ``password`` is always stored through ``set_password`` (as in ``create``),
        never assigned raw. Other keys are copied only when is_safe_key() allows them.
        """
        validated_data.pop("confirm_password", None)
        password = validated_data.pop("password", None)
        if password is not None:
            instance.set_password(password)
        for attr, value in validated_data.items():
            if is_safe_key(attr):
                setattr(instance, attr, value)
        instance.save()
        return instance

    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
        """Cross-field validation for attributes, passwords, phone number and email.

        Raises:
            ValidationError: on a malformed ``attributes`` payload, a weak or
                mismatched password (including only one of the two being sent),
                or a phone number / email already used by another user.
        """
        if "attributes" in attrs:
            attributes = attrs["attributes"]
            if not isinstance(attributes, dict):
                raise ValidationError(_("attributes must be a dictionary"))
            if attributes.get("is_business") and not attributes.get("business_identificator"):
                raise ValidationError(_("business identificator is required when registering as a business"))

        if "password" in attrs or "confirm_password" in attrs:
            password = attrs.get("password")
            confirm_password = attrs.get("confirm_password")
            # Strength errors take priority over the mismatch error. When the two
            # values match they are identical, so validating one of them is enough.
            validate_password(password if password is not None else confirm_password)
            # compare_digest only accepts ASCII-only str, so compare UTF-8 bytes instead.
            if (
                password is None
                or confirm_password is None
                or not compare_digest(password.encode("utf-8"), confirm_password.encode("utf-8"))
            ):
                raise ValidationError(_("passwords do not match"))

        if "phone_number" in attrs:
            phone_number = attrs["phone_number"]
            validate_phone_number(phone_number)
            if (
                isinstance(self.instance, User)
                and User.objects.filter(phone_number=phone_number).exclude(uuid=self.instance.uuid).exists()
            ):
                # NOTE(review): a duplicate is reported as "malformed", presumably so the
                # response does not reveal that the number is registered; confirm.
                # Interpolate after the gettext lookup so the message can be translated.
                raise ValidationError(_("malformed phone number: %(phone_number)s") % {"phone_number": phone_number})

        if "email" in attrs:
            email = attrs["email"]
            validate_email(email)
            if (
                isinstance(self.instance, User)
                and User.objects.filter(email=email).exclude(uuid=self.instance.uuid).exists()
            ):
                raise ValidationError(_("malformed email: %(email)s") % {"email": email})

        return attrs

    @extend_schema_field(ProductSimpleSerializer(many=True))
    def get_recently_viewed(self, obj: User) -> Collection[Any]:
        """Serialize the active products whose UUIDs are in ``obj.recently_viewed``.

        Result order follows the queryset, not necessarily the order of
        ``recently_viewed``. An empty/None ``recently_viewed`` yields an empty list.
        """
        # noinspection PyTypeChecker
        return ProductSimpleSerializer(
            instance=Product.objects.filter(uuid__in=obj.recently_viewed or [], is_active=True),
            many=True,
        ).data
class TokenObtainSerializer(Serializer):
    """Authenticate credentials and keep the authenticated user on ``self.user``.

    Subclasses set ``token_class`` and build the token payload in their own
    ``validate``; this base only returns an empty dict.
    """

    username_field = User.USERNAME_FIELD
    token_class: type[Token] | None = None

    default_error_messages = {"no_active_account": _("no active account")}

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.user: User | None = None
        # The login field's name comes from User.USERNAME_FIELD, so it is added at runtime.
        self.fields[self.username_field] = CharField(write_only=True)
        self.fields["password"] = PasswordField()

    def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]:
        """Authenticate the credentials in ``attrs`` and store the user.

        Raises:
            AuthenticationFailed: (code ``"no_active_account"``) when the user
                fails ``USER_AUTHENTICATION_RULE``.
        """
        authenticate_kwargs = {
            self.username_field: attrs[self.username_field],
            "password": attrs["password"],
        }
        # Forward the request to the auth backends when the view put one in the context.
        with suppress(KeyError):
            authenticate_kwargs["request"] = self.context["request"]

        self.user = authenticate(**authenticate_kwargs)

        if not api_settings.USER_AUTHENTICATION_RULE(self.user):
            # AuthenticationFailed(detail, code): the code must be a stable
            # machine-readable identifier, not a translated message.
            raise AuthenticationFailed(
                self.error_messages["no_active_account"],
                "no_active_account",
            )
        return {}

    @classmethod
    def get_token(cls, user: AuthUser) -> Token:
        """Build a token for ``user`` using the subclass's ``token_class``.

        Raises:
            RuntimeError: if ``token_class`` was not set on the class.
        """
        if cls.token_class is not None:
            return cls.token_class.for_user(user)
        else:
            raise RuntimeError(_("must set token_class attribute on class."))
class TokenObtainPairSerializer(TokenObtainSerializer):
    """Issue a refresh/access token pair plus the serialized user on login."""

    token_class = RefreshToken

    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
        """Authenticate, then return ``refresh``, ``access`` and ``user`` entries.

        Raises:
            ValidationError: if authentication produced no user.
        """
        data = super().validate(attrs)
        # Guard in case USER_AUTHENTICATION_RULE lets a None user through; after this
        # check self.user is non-None for the rest of the method.
        if self.user is None:
            raise ValidationError(_("no active account"))

        refresh = self.get_token(self.user)
        data["refresh"] = str(refresh)
        # noinspection PyUnresolvedReferences
        data["access"] = str(refresh.access_token)  # type: ignore [attr-defined]
        data["user"] = UserSerializer(self.user).data

        if api_settings.UPDATE_LAST_LOGIN:
            # noinspection PyTypeChecker
            update_last_login(User, self.user)
        return data
class TokenRefreshSerializer(Serializer):
    """Exchange a refresh token for a new access token (and, with rotation, a new refresh token)."""

    refresh = CharField()
    access = CharField(read_only=True)
    token_class = RefreshToken

    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
        """Return ``access``, optionally a rotated ``refresh``, and the serialized ``user``.

        Raises:
            ValidationError: if the token lacks a ``user_uuid`` claim or the user
                no longer exists.
            AuthenticationFailed: (code ``"no_active_account"``) if the user fails
                ``USER_AUTHENTICATION_RULE``.
        """
        refresh = self.token_class(attrs["refresh"])

        # Resolve and check the owner before rotating/blacklisting, so a token for a
        # missing or deactivated account is rejected without being consumed.
        try:
            user = User.objects.get(uuid=refresh.payload["user_uuid"])
        except KeyError as ke:
            raise ValidationError(_("no user uuid claim present in token")) from ke
        except User.DoesNotExist as dne:
            raise ValidationError(_("user does not exist")) from dne
        if not api_settings.USER_AUTHENTICATION_RULE(user):
            raise AuthenticationFailed(_("no active account"), "no_active_account")

        data: dict[str, Any] = {"access": str(refresh.access_token)}

        if api_settings.ROTATE_REFRESH_TOKENS:
            if api_settings.BLACKLIST_AFTER_ROTATION:
                # blacklist() may be absent from the token class; that case is skipped.
                with suppress(AttributeError):
                    refresh.blacklist()
            # Re-issue the same token object with a fresh id and lifetime.
            refresh.set_jti()
            refresh.set_exp()
            refresh.set_iat()
            data["refresh"] = str(refresh)

        # noinspection PyTypeChecker
        data["user"] = UserSerializer(user).data
        return data
class TokenVerifySerializer(Serializer):
    """Verify a token and attach the serialized owner as ``attrs["user"]``."""

    token = CharField(write_only=True)

    def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]:
        """Validate ``attrs["token"]`` and add the serialized ``user`` to ``attrs``.

        Raises:
            ValidationError: if the token is invalid, blacklisted, lacks a
                ``user_uuid`` claim, or its user does not exist.
        """
        # Parse exactly once, inside the try, so an invalid token becomes a
        # ValidationError instead of an unhandled TokenError.
        try:
            token = UntypedToken(attrs["token"])
        except TokenError as te:
            raise ValidationError(_("invalid token")) from te

        if (
            api_settings.BLACKLIST_AFTER_ROTATION
            and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS
        ):
            jti = token.get(api_settings.JTI_CLAIM)
            if BlacklistedToken.objects.filter(token__jti=jti).exists():
                raise ValidationError(_("token_blacklisted"))

        try:
            user = User.objects.get(uuid=token.payload["user_uuid"])
        except KeyError as ke:
            raise ValidationError(_("no user uuid claim present in token")) from ke
        except User.DoesNotExist as dne:
            raise ValidationError(_("user does not exist")) from dne

        # noinspection PyTypeChecker
        attrs["user"] = UserSerializer(user).data  # type: ignore [assignment]
        return attrs
class ConfirmPasswordResetSerializer(Serializer):
    """Input for confirming a password reset.

    All four fields are required and write-only. No cross-field validation
    (for example, checking that the two passwords match) is done here.
    """

    uidb64 = CharField(required=True, write_only=True)
    token = CharField(required=True, write_only=True)
    password = CharField(required=True, write_only=True)
    confirm_password = CharField(required=True, write_only=True)
class ResetPasswordSerializer(Serializer):
    """Input for requesting a password reset: a single required, write-only email."""

    email = EmailField(required=True, write_only=True)
class ActivateEmailSerializer(Serializer):
    """Input for email activation: required ``uidb64`` and ``token`` strings (readable and writable)."""

    uidb64 = CharField(required=True)
    token = CharField(required=True)
class MergeRecentlyViewedSerializer(Serializer):
    """Input carrying a required list of product UUIDs.

    Each item is validated only as a required string, not as a UUID.
    """

    product_uuids = ListField(child=CharField(required=True), required=True)