import logging
from collections.abc import Collection
from contextlib import suppress
from hmac import compare_digest
from typing import Any

from constance import config
from django.contrib.auth import authenticate
from django.contrib.auth.models import update_last_login
from django.contrib.auth.password_validation import validate_password
from django.core.validators import validate_email
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema_field
from rest_framework.exceptions import AuthenticationFailed, ValidationError
from rest_framework.fields import (
    BooleanField,
    CharField,
    EmailField,
    ListField,
    SerializerMethodField,
    JSONField,
)
from rest_framework.serializers import ModelSerializer, Serializer
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.serializers import AuthUser, PasswordField
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.token_blacklist.models import BlacklistedToken
from rest_framework_simplejwt.tokens import RefreshToken, Token, UntypedToken

from core.models import Product
from core.serializers import ProductSimpleSerializer
from core.utils.security import is_safe_key
from evibes import settings
from vibes_auth.models import User
from vibes_auth.validators import validate_phone_number

logger = logging.getLogger("django")

# Write-only serializer fields that must never be copied verbatim onto a User:
# the password is stored via set_password(), the confirmation is discarded.
_PASSWORD_FIELDS = frozenset({"password", "confirm_password"})


def _passwords_match(password: str, confirm_password: str) -> bool:
    """Return True when both passwords are equal, compared in constant time.

    ``hmac.compare_digest`` raises ``TypeError`` for ``str`` arguments that
    contain non-ASCII characters, so both values are encoded to UTF-8 first.
    """
    return compare_digest(password.encode("utf-8"), confirm_password.encode("utf-8"))


class UserSerializer(ModelSerializer):
    """Read/write serializer for ``User`` accounts.

    Used for registration (``create``), profile edits (``update``) and as the
    ``user`` payload embedded in token responses. The password is accepted
    write-only together with a confirmation and is always stored hashed.
    """

    avatar_url = SerializerMethodField(required=False, read_only=True)
    password = CharField(write_only=True, required=False)
    confirm_password = CharField(write_only=True, required=False)
    is_staff = BooleanField(read_only=True)
    recently_viewed = SerializerMethodField(required=False, read_only=True)
    attributes = JSONField(required=False)

    @staticmethod
    def get_avatar_url(obj) -> str:
        """Return the absolute URL of the user's avatar, or the default placeholder image."""
        if obj.avatar:
            return f"https://api.{config.BASE_DOMAIN}/media/{obj.avatar!s}"
        return f"https://api.{config.BASE_DOMAIN}/static/person.png"

    class Meta:
        model = User
        fields = [
            "uuid",
            "email",
            "avatar_url",
            "is_staff",
            "recently_viewed",
            "attributes",
            "first_name",
            "last_name",
            "password",
            "confirm_password",
            "phone_number",
            "is_subscribed",
            "modified",
            "created",
        ]

    def create(self, validated_data):
        """Create a user, hash its password and copy the remaining safe fields.

        ``validate`` guarantees that ``password`` is present when creating.
        """
        user = User.objects.create(
            email=validated_data.pop("email"),
            first_name=validated_data.pop("first_name", ""),
            last_name=validated_data.pop("last_name", ""),
        )
        user.set_password(validated_data.pop("password"))
        for attr, value in validated_data.items():
            if attr not in _PASSWORD_FIELDS and is_safe_key(attr):
                setattr(user, attr, value)
        user.save()
        return user

    def update(self, instance, validated_data):
        """Apply validated changes to ``instance``; a new password is stored hashed."""
        for attr, value in validated_data.items():
            # Password fields are handled below and never assigned as plain attributes.
            if attr not in _PASSWORD_FIELDS and is_safe_key(attr):
                setattr(instance, attr, value)
        if "password" in validated_data:
            instance.set_password(validated_data["password"])
        instance.save()
        return instance

    def validate(self, attrs):
        """Cross-field validation.

        Checks the ``attributes`` dictionary, the password/confirmation pair,
        and the phone number and email (including uniqueness against other
        users when updating an existing instance).

        Raises:
            ValidationError: if any check fails.
        """
        if "attributes" in attrs:
            attributes = attrs["attributes"]
            if not isinstance(attributes, dict):
                raise ValidationError(_("attributes must be a dictionary"))
            if attributes.get("is_business") and not attributes.get("business_identificator"):
                raise ValidationError(_("business identificator is required when registering as a business"))

        password = attrs.get("password")
        confirm_password = attrs.get("confirm_password")
        if self.instance is None and password is None:
            # create() needs a password to hash; reject instead of crashing there.
            raise ValidationError(_("password is required"))
        if password is not None or confirm_password is not None:
            # Both halves of the pair must be supplied together.
            if password is None or confirm_password is None:
                raise ValidationError(_("passwords do not match"))
            validate_password(password)
            if not _passwords_match(password, confirm_password):
                raise ValidationError(_("passwords do not match"))

        if "phone_number" in attrs:
            phone_number = attrs["phone_number"]
            validate_phone_number(phone_number)
            if (
                self.instance
                and User.objects.filter(phone_number=phone_number).exclude(uuid=self.instance.uuid).exists()
            ):
                # NOTE(review): a number taken by another user is reported as
                # "malformed", presumably to avoid revealing registered numbers — confirm.
                raise ValidationError(_("malformed phone number: %(phone_number)s") % {"phone_number": phone_number})

        if "email" in attrs:
            email = attrs["email"]
            validate_email(email)
            if self.instance and User.objects.filter(email=email).exclude(uuid=self.instance.uuid).exists():
                # Same "malformed" wording as for phone numbers above.
                raise ValidationError(_("malformed email: %(email)s") % {"email": email})

        return attrs

    @extend_schema_field(ProductSimpleSerializer(many=True))
    def get_recently_viewed(self, obj) -> Collection[Any]:
        """
        Returns a list of serialized ProductSimpleSerializer representations
        for the active products whose UUIDs are in obj.recently_viewed.
        """
        # `or []` keeps an unset (None) history from breaking the `__in` lookup.
        # noinspection PyTypeChecker
        return ProductSimpleSerializer(
            instance=Product.objects.filter(uuid__in=obj.recently_viewed or [], is_active=True),
            many=True,
        ).data


class TokenObtainSerializer(Serializer):
    """Authenticate a user from ``USERNAME_FIELD`` + password credentials.

    Subclasses set ``token_class`` and build the token payload in ``validate``.
    """

    username_field = User.USERNAME_FIELD
    token_class: type[Token] | None = None

    default_error_messages = {"no_active_account": _("no active account")}

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.user: User | None = None
        self.fields[self.username_field] = CharField(write_only=True)
        self.fields["password"] = PasswordField()

    def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]:
        """Authenticate the credentials and store the result on ``self.user``.

        Raises:
            AuthenticationFailed: if ``USER_AUTHENTICATION_RULE`` rejects the user.
        """
        authenticate_kwargs = {
            self.username_field: attrs[self.username_field],
            "password": attrs["password"],
        }
        # The request is optional context; pass it to backends when available.
        with suppress(KeyError):
            authenticate_kwargs["request"] = self.context["request"]

        self.user: User | None = authenticate(**authenticate_kwargs)

        if not api_settings.USER_AUTHENTICATION_RULE(self.user):
            raise AuthenticationFailed(
                self.error_messages["no_active_account"],
                str(_("no active account")),
            )

        return {}

    @classmethod
    def get_token(cls, user: AuthUser) -> Token:
        """Issue a token of ``token_class`` for ``user``."""
        if cls.token_class is not None:
            return cls.token_class.for_user(user)
        raise RuntimeError(_("must set token_class attribute on class."))


class TokenObtainPairSerializer(TokenObtainSerializer):
    """Return a refresh/access token pair plus the serialized user on login."""

    token_class = RefreshToken

    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
        """Authenticate, issue tokens and optionally update ``last_login``.

        Returns:
            dict with ``refresh``, ``access`` and ``user`` keys.

        Raises:
            ValidationError: if no authenticated user is available.
        """
        data = super().validate(attrs)

        logger.debug("Data validated")

        if self.user is None:
            raise ValidationError(_("no active account"))

        refresh = self.get_token(self.user)
        data["refresh"] = str(refresh)
        # noinspection PyUnresolvedReferences
        data["access"] = str(refresh.access_token)  # type: ignore [attr-defined]
        data["user"] = UserSerializer(self.user).data

        logger.debug("Data formed")

        if api_settings.UPDATE_LAST_LOGIN:
            # self.user was checked for None above.
            # noinspection PyTypeChecker
            update_last_login(User, self.user)
            logger.debug("Updated last login")

        logger.debug("Returning data")
        return data


class TokenRefreshSerializer(Serializer):
    """Exchange a refresh token for a new access token (and rotated refresh token)."""

    refresh = CharField()
    access = CharField(read_only=True)
    token_class = RefreshToken

    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
        """Issue a new access token and, with rotation enabled, a new refresh token.

        Returns:
            dict with ``access``, optionally ``refresh``, and ``user`` keys.

        Raises:
            ValidationError: if the token has no ``user_uuid`` claim or the user
                no longer exists.
        """
        refresh = self.token_class(attrs["refresh"])

        # Resolve the owner first, so a token for a missing user is rejected
        # cleanly (400) instead of crashing, and before it gets blacklisted.
        try:
            user = User.objects.get(uuid=refresh.payload["user_uuid"])
        except KeyError as ke:
            raise ValidationError(_("no user uuid claim present in token")) from ke
        except User.DoesNotExist as dne:
            raise ValidationError(_("user does not exist")) from dne

        data: dict[str, Any] = {"access": str(refresh.access_token)}

        if api_settings.ROTATE_REFRESH_TOKENS:
            if api_settings.BLACKLIST_AFTER_ROTATION:
                # Tokens lack blacklist() when the blacklist app is not installed.
                with suppress(AttributeError):
                    refresh.blacklist()

            refresh.set_jti()
            refresh.set_exp()
            refresh.set_iat()

            data["refresh"] = str(refresh)

        # noinspection PyTypeChecker
        data["user"] = UserSerializer(user).data
        return data


class TokenVerifySerializer(Serializer):
    """Verify a token, reject blacklisted ones and attach the serialized user."""

    token = CharField(write_only=True)

    def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]:
        """Validate ``attrs["token"]`` and add ``attrs["user"]``.

        Raises:
            TokenError: from ``UntypedToken`` when the token cannot be decoded
                (propagated unchanged, as before).
            ValidationError: if the token is blacklisted, lacks a ``user_uuid``
                claim, or its user no longer exists.
        """
        # Decode once. A TokenError here propagates as it always has (the old
        # second decode inside try/except was unreachable); presumably the view
        # converts it into an InvalidToken response — confirm.
        token = UntypedToken(attrs["token"])

        if (
            api_settings.BLACKLIST_AFTER_ROTATION
            and "rest_framework_simplejwt.token_blacklist" in settings.INSTALLED_APPS
        ):
            jti = token.get(api_settings.JTI_CLAIM)
            if BlacklistedToken.objects.filter(token__jti=jti).exists():
                raise ValidationError(_("token_blacklisted"))

        try:
            user = User.objects.get(uuid=token.payload["user_uuid"])
        except KeyError as ke:
            raise ValidationError(_("no user uuid claim present in token")) from ke
        except User.DoesNotExist as dne:
            raise ValidationError(_("user does not exist")) from dne

        # noinspection PyTypeChecker
        attrs["user"] = UserSerializer(user).data  # type: ignore [assignment]
        return attrs


class ConfirmPasswordResetSerializer(Serializer):
    """Input for confirming a password reset (uid/token pair plus the new password)."""

    uidb64 = CharField(write_only=True, required=True)
    token = CharField(write_only=True, required=True)
    password = CharField(write_only=True, required=True)
    confirm_password = CharField(write_only=True, required=True)


class ResetPasswordSerializer(Serializer):
    """Input for requesting a password reset by email."""

    email = EmailField(write_only=True, required=True)


class ActivateEmailSerializer(Serializer):
    """Input for email activation (uid/token pair)."""

    uidb64 = CharField(required=True)
    token = CharField(required=True)


class MergeRecentlyViewedSerializer(Serializer):
    """Input carrying a list of product UUIDs to merge into the recently-viewed history."""

    product_uuids = ListField(required=True, child=CharField(required=True))