Merge branch 'refs/heads/main' into storefront

This commit is contained in:
Alexandr SaVBaD Waltz 2025-05-29 23:08:19 +03:00
commit 9220de5e63
16 changed files with 191 additions and 61 deletions

View file

@ -20,6 +20,7 @@ from core.viewsets import (
CategoryViewSet, CategoryViewSet,
FeedbackViewSet, FeedbackViewSet,
OrderViewSet, OrderViewSet,
ProductTagViewSet,
ProductViewSet, ProductViewSet,
PromoCodeViewSet, PromoCodeViewSet,
PromotionViewSet, PromotionViewSet,
@ -41,6 +42,7 @@ core_router.register(r"stocks", StockViewSet, basename="stocks")
core_router.register(r"promo_codes", PromoCodeViewSet, basename="promo_codes") core_router.register(r"promo_codes", PromoCodeViewSet, basename="promo_codes")
core_router.register(r"promotions", PromotionViewSet, basename="promotions") core_router.register(r"promotions", PromotionViewSet, basename="promotions")
core_router.register(r"addresses", AddressViewSet, basename="addresses") core_router.register(r"addresses", AddressViewSet, basename="addresses")
core_router.register(r"product_tags", ProductTagViewSet, basename="product_tags")
sitemaps = { sitemaps = {
"products": ProductSitemap, "products": ProductSitemap,

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

View file

@ -1,5 +1,6 @@
import json import json
import logging import logging
import uuid
from django.db.models import Avg, FloatField, OuterRef, Q, Subquery, Value from django.db.models import Avg, FloatField, OuterRef, Q, Subquery, Value
from django.db.models.functions import Coalesce from django.db.models.functions import Coalesce
@ -251,7 +252,7 @@ class WishlistFilter(FilterSet):
class CategoryFilter(FilterSet): class CategoryFilter(FilterSet):
uuid = UUIDFilter(field_name="uuid", lookup_expr="exact") uuid = UUIDFilter(field_name="uuid", lookup_expr="exact")
name = CharFilter(field_name="name", lookup_expr="icontains") name = CharFilter(field_name="name", lookup_expr="icontains")
parent_uuid = UUIDFilter(field_name="parent__uuid", lookup_expr="exact") parent_uuid = CharFilter(method="filter_parent_uuid")
slug = CharFilter(field_name="slug", lookup_expr="exact") slug = CharFilter(field_name="slug", lookup_expr="exact")
order_by = OrderingFilter( order_by = OrderingFilter(
@ -264,7 +265,22 @@ class CategoryFilter(FilterSet):
class Meta: class Meta:
model = Category model = Category
fields = ["uuid", "name"] fields = ["uuid", "name", "parent_uuid", "slug"]
def filter_parent_uuid(self, queryset, name, value):
"""
If ?parent_uuid= or ?parent_uuid=null, return items with parent=None.
Otherwise treat `value` as a real UUID and filter parent__uuid=value.
"""
if value in ("", "null", "None"):
return queryset.filter(parent=None)
try:
uuid_val = uuid.UUID(value)
except (ValueError, TypeError):
return queryset
return queryset.filter(parent__uuid=uuid_val)
class BrandFilter(FilterSet): class BrandFilter(FilterSet):

View file

@ -21,6 +21,7 @@ from core.models import (
OrderProduct, OrderProduct,
Product, Product,
ProductImage, ProductImage,
ProductTag,
PromoCode, PromoCode,
Promotion, Promotion,
Stock, Stock,
@ -140,7 +141,7 @@ class CategoryType(DjangoObjectType):
if depth <= 0: if depth <= 0:
return Category.objects.none() return Category.objects.none()
categories = Category.objects.language(info.context.locale).filter(parent=self) categories = Category.objects.filter(parent=self)
if info.context.user.has_perm("core.view_category"): if info.context.user.has_perm("core.view_category"):
return categories return categories
return categories.filter(is_active=True) return categories.filter(is_active=True)
@ -455,6 +456,17 @@ class WishlistType(DjangoObjectType):
description = _("wishlists") description = _("wishlists")
class ProductTagType(DjangoObjectType):
product_set = DjangoFilterConnectionField(ProductType, description=_("tagged products"))
class Meta:
model = ProductTag
interfaces = (relay.Node,)
fields = ("uuid", "tag_name", "name", "product_set")
filter_fields = ["uuid", "tag_name", "name"]
description = _("product tags")
class ConfigType(ObjectType): class ConfigType(ObjectType):
project_name = String(description=_("project name")) project_name = String(description=_("project name"))
base_domain = String(description=_("company email")) base_domain = String(description=_("company email"))
@ -511,5 +523,5 @@ class SearchResultsType(ObjectType):
class BulkActionOrderProductInput(InputObjectType): class BulkActionOrderProductInput(InputObjectType):
id = UUID(required=True) uuid = UUID(required=True)
attributes = GenericScalar(required=False) attributes = GenericScalar(required=False)

View file

@ -48,6 +48,7 @@ from core.graphene.object_types import (
OrderProductType, OrderProductType,
OrderType, OrderType,
ProductImageType, ProductImageType,
ProductTagType,
ProductType, ProductType,
PromoCodeType, PromoCodeType,
PromotionType, PromotionType,
@ -64,6 +65,7 @@ from core.models import (
OrderProduct, OrderProduct,
Product, Product,
ProductImage, ProductImage,
ProductTag,
PromoCode, PromoCode,
Promotion, Promotion,
Stock, Stock,
@ -108,6 +110,7 @@ class Query(ObjectType):
product_images = DjangoFilterConnectionField(ProductImageType) product_images = DjangoFilterConnectionField(ProductImageType)
stocks = DjangoFilterConnectionField(StockType) stocks = DjangoFilterConnectionField(StockType)
wishlists = DjangoFilterConnectionField(WishlistType, filterset_class=WishlistFilter) wishlists = DjangoFilterConnectionField(WishlistType, filterset_class=WishlistFilter)
product_tags = DjangoFilterConnectionField(ProductTagType)
promotions = DjangoFilterConnectionField(PromotionType) promotions = DjangoFilterConnectionField(PromotionType)
promocodes = DjangoFilterConnectionField(PromoCodeType) promocodes = DjangoFilterConnectionField(PromoCodeType)
brands = DjangoFilterConnectionField(BrandType, filterset_class=BrandFilter) brands = DjangoFilterConnectionField(BrandType, filterset_class=BrandFilter)
@ -184,7 +187,7 @@ class Query(ObjectType):
@staticmethod @staticmethod
def resolve_categories(_parent, info, **kwargs): def resolve_categories(_parent, info, **kwargs):
categories = Category.objects.filter(parent=None) categories = Category.objects.all()
if info.context.user.has_perm("core.view_category"): if info.context.user.has_perm("core.view_category"):
return categories return categories
return categories.filter(is_active=True) return categories.filter(is_active=True)
@ -280,6 +283,12 @@ class Query(ObjectType):
return promocodes.filter(user__uuid=kwargs.get("user_uuid")) or promocodes.all() return promocodes.filter(user__uuid=kwargs.get("user_uuid")) or promocodes.all()
return promocodes.filter(is_active=True, user=info.context.user) return promocodes.filter(is_active=True, user=info.context.user)
@staticmethod
def resolve_product_tags(_parent, info, **kwargs):
if info.context.user.has_perm("core.view_producttag"):
return ProductTag.objects.all()
return ProductTag.objects.filter(is_active=True)
class Mutation(ObjectType): class Mutation(ObjectType):
search = Search.Field() search = Search.Field()

View file

@ -22,7 +22,7 @@ class AddressManager(models.Manager):
"addressdetails": 1, "addressdetails": 1,
"q": raw_data, "q": raw_data,
} }
resp = requests.get(config.NOMINATIM_URL, params=params) resp = requests.get(config.NOMINATIM_URL.rstrip("/") + "/search", params=params)
resp.raise_for_status() resp.raise_for_status()
results = resp.json() results = resp.json()
if not results: if not results:
@ -31,7 +31,7 @@ class AddressManager(models.Manager):
# Parse address components # Parse address components
addr = data.get("address", {}) addr = data.get("address", {})
street = addr.get("road") or addr.get("pedestrian") or "" street = f"{addr.get('road', '') or addr.get('pedestrian', '')}, {addr.get('house_number', '')}"
district = addr.get("city_district") or addr.get("suburb") or "" district = addr.get("city_district") or addr.get("suburb") or ""
city = addr.get("city") or addr.get("town") or addr.get("village") or "" city = addr.get("city") or addr.get("town") or addr.get("village") or ""
region = addr.get("state") or addr.get("region") or "" region = addr.get("state") or addr.get("region") or ""
@ -49,6 +49,7 @@ class AddressManager(models.Manager):
# Create the model instance, storing both the input string and full API response # Create the model instance, storing both the input string and full API response
return super().create( return super().create(
raw_data=raw_data, raw_data=raw_data,
address_line=f"{kwargs.get('address_line_1')}, {kwargs.get('address_line_2')}",
street=street, street=street,
district=district, district=district,
city=city, city=city,

View file

@ -0,0 +1,18 @@
# Generated by Django 5.2 on 2025-05-28 19:06
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('core', '0022_category_slug'),
]
operations = [
migrations.AddField(
model_name='address',
name='address_line',
field=models.TextField(blank=True, help_text='address line for the customer', null=True,
verbose_name='address line'),
),
]

View file

@ -1259,6 +1259,12 @@ class Documentary(NiceModel):
class Address(NiceModel): class Address(NiceModel):
is_publicly_visible = False is_publicly_visible = False
address_line = TextField( # noqa: DJ001
blank=True,
null=True,
help_text=_("address line for the customer"),
verbose_name=_("address line"),
)
street = CharField(_("street"), max_length=255, null=True) # noqa: DJ001 street = CharField(_("street"), max_length=255, null=True) # noqa: DJ001
district = CharField(_("district"), max_length=255, null=True) # noqa: DJ001 district = CharField(_("district"), max_length=255, null=True) # noqa: DJ001
city = CharField(_("city"), max_length=100, null=True) # noqa: DJ001 city = CharField(_("city"), max_length=100, null=True) # noqa: DJ001

View file

@ -22,6 +22,8 @@ class EvibesPermission(permissions.BasePermission):
- Standard model perms ('add', 'view', 'change', 'delete') are enforced for all other actions, - Standard model perms ('add', 'view', 'change', 'delete') are enforced for all other actions,
including for staff users. including for staff users.
- Publicly visible models allow anonymous list/retrieve. - Publicly visible models allow anonymous list/retrieve.
- If an instance or queryset has a "user" attribute, ensure that the request.user is the same,
unless the user is an admin with the required django permission.
""" """
ACTION_PERM_MAP = { ACTION_PERM_MAP = {
@ -34,6 +36,8 @@ class EvibesPermission(permissions.BasePermission):
} }
USER_SCOPED_ACTIONS = { USER_SCOPED_ACTIONS = {
"list",
"retrieve",
"buy", "buy",
"buy_unregistered", "buy_unregistered",
"current", "current",
@ -64,19 +68,56 @@ class EvibesPermission(permissions.BasePermission):
if request.user.has_perm(f"{app_label}.{codename}"): if request.user.has_perm(f"{app_label}.{codename}"):
return True return True
return bool(action in ("list", "retrieve") and getattr(model, "is_publicly_visible", False)) return bool(
action in ("list", "retrieve")
and getattr(model, "is_publicly_visible", False)
)
def has_object_permission(self, request, view, obj):
if request.method in permissions.SAFE_METHODS:
return True
if hasattr(obj, "user"):
if obj.user == request.user:
return True
# Allow admins who hold the required model permission
app_label = obj._meta.app_label
model_name = obj._meta.model_name
action = getattr(view, "action", None)
perm_prefix = self.ACTION_PERM_MAP.get(action)
return bool(perm_prefix and request.user.has_perm(f"{app_label}.{perm_prefix}_{model_name}"))
model = view.queryset.model
app_label = model._meta.app_label
model_name = model._meta.model_name
action = getattr(view, "action", None)
perm_prefix = self.ACTION_PERM_MAP.get(action)
return bool(perm_prefix and request.user.has_perm(f"{app_label}.{perm_prefix}_{model_name}"))
def has_queryset_permission(self, request, view, queryset): def has_queryset_permission(self, request, view, queryset):
""" """
Filter the base queryset according to the action and user. Filter the base queryset according to the action and user.
Staff users still require view permissions to see records. For models with a "user" field, restrict access to records belonging to the request user
unless the admin holds the needed permissions.
""" """
model = view.queryset.model model = view.queryset.model
app_label = model._meta.app_label app_label = model._meta.app_label
model_name = model._meta.model_name model_name = model._meta.model_name
if view.action in self.USER_SCOPED_ACTIONS: if hasattr(model, "user"):
return queryset.filter(user=request.user) if view.action in self.USER_SCOPED_ACTIONS:
return queryset.filter(user=request.user)
if view.action in ("list", "retrieve"):
if request.user.has_perm(f"{app_label}.view_{model_name}"):
return queryset
return queryset.none()
base = queryset.filter(is_active=True, user=request.user)
if request.user.is_staff and request.user.has_perm(
f"{app_label}.{self.ACTION_PERM_MAP.get(view.action)}_{model_name}"
):
return queryset.filter(is_active=True)
return base
if view.action in ("list", "retrieve"): if view.action in ("list", "retrieve"):
if request.user.has_perm(f"{app_label}.view_{model_name}"): if request.user.has_perm(f"{app_label}.view_{model_name}"):
@ -87,10 +128,7 @@ class EvibesPermission(permissions.BasePermission):
base = queryset.filter(is_active=True) base = queryset.filter(is_active=True)
match view.action: match view.action:
case "update": case "update" | "partial_update":
if request.user.has_perm(f"{app_label}.change_{model_name}"):
return base
case "partial_update":
if request.user.has_perm(f"{app_label}.change_{model_name}"): if request.user.has_perm(f"{app_label}.change_{model_name}"):
return base return base
case "destroy": case "destroy":

View file

@ -158,10 +158,24 @@ class AddressCreateSerializer(ModelSerializer): # noqa: F405
write_only=True, write_only=True,
max_length=512, max_length=512,
) )
address_line_1 = CharField(
write_only=True,
max_length=128,
required=False
)
address_line_2 = CharField(
write_only=True,
max_length=128,
required=False
)
class Meta: class Meta:
model = Address model = Address
fields = ["raw_data"] fields = [
"raw_data",
"address_line_1",
"address_line_2"
]
def create(self, validated_data): def create(self, validated_data):
raw = validated_data.pop("raw_data") raw = validated_data.pop("raw_data")

View file

@ -89,20 +89,19 @@ class CategoryDetailSerializer(ModelSerializer):
.distinct() .distinct()
) )
distinct_vals_list = list(distinct_vals) distinct_vals_list = list(distinct_vals)[0:128] if len(list(distinct_vals)) > 128 else list(distinct_vals)
if len(distinct_vals_list) <= 256: filterable_results.append(
filterable_results.append( {
{ "attribute_name": attr.name,
"attribute_name": attr.name, "possible_values": distinct_vals_list,
"possible_values": distinct_vals_list, "value_type": attr.value_type,
"value_type": attr.value_type, }
} )
)
else: if not user.has_perm("view_attribute"):
continue cache.set(f"{obj.uuid}_filterable_results", filterable_results, 86400)
cache.set(f"{obj.uuid}_filterable_results", filterable_results, 86400)
return filterable_results return filterable_results
def get_children(self, obj) -> list[dict]: def get_children(self, obj) -> list[dict]:
@ -123,7 +122,7 @@ class CategoryDetailSerializer(ModelSerializer):
class BrandDetailSerializer(ModelSerializer): class BrandDetailSerializer(ModelSerializer):
categories = CategoryDetailSerializer(many=True) categories = CategorySimpleSerializer(many=True)
small_logo = SerializerMethodField() small_logo = SerializerMethodField()
big_logo = SerializerMethodField() big_logo = SerializerMethodField()

View file

@ -73,13 +73,18 @@ from core.serializers import (
OrderProductSimpleSerializer, OrderProductSimpleSerializer,
OrderSimpleSerializer, OrderSimpleSerializer,
ProductDetailSerializer, ProductDetailSerializer,
ProductImageDetailSerializer,
ProductImageSimpleSerializer, ProductImageSimpleSerializer,
ProductSimpleSerializer, ProductSimpleSerializer,
ProductTagDetailSerializer,
ProductTagSimpleSerializer, ProductTagSimpleSerializer,
PromoCodeDetailSerializer,
PromoCodeSimpleSerializer, PromoCodeSimpleSerializer,
PromotionDetailSerializer,
PromotionSimpleSerializer, PromotionSimpleSerializer,
RemoveOrderProductSerializer, RemoveOrderProductSerializer,
RemoveWishlistProductSerializer, RemoveWishlistProductSerializer,
StockDetailSerializer,
StockSimpleSerializer, StockSimpleSerializer,
VendorSimpleSerializer, VendorSimpleSerializer,
WishlistDetailSerializer, WishlistDetailSerializer,
@ -364,21 +369,11 @@ class OrderProductViewSet(EvibesViewSet):
} }
class ProductTagViewSet(EvibesViewSet):
queryset = ProductTag.objects.all()
filter_backends = [DjangoFilterBackend]
filterset_fields = ["tag_name", "is_active"]
serializer_class = AttributeGroupDetailSerializer
action_serializer_classes = {
"list": ProductTagSimpleSerializer,
}
class ProductImageViewSet(EvibesViewSet): class ProductImageViewSet(EvibesViewSet):
queryset = ProductImage.objects.all() queryset = ProductImage.objects.all()
filter_backends = [DjangoFilterBackend] filter_backends = [DjangoFilterBackend]
filterset_fields = ["product", "priority", "is_active"] filterset_fields = ["product", "priority", "is_active"]
serializer_class = AttributeGroupDetailSerializer serializer_class = ProductImageDetailSerializer
action_serializer_classes = { action_serializer_classes = {
"list": ProductImageSimpleSerializer, "list": ProductImageSimpleSerializer,
} }
@ -388,7 +383,7 @@ class PromoCodeViewSet(EvibesViewSet):
queryset = PromoCode.objects.all() queryset = PromoCode.objects.all()
filter_backends = [DjangoFilterBackend] filter_backends = [DjangoFilterBackend]
filterset_fields = ["code", "discount_amount", "discount_percent", "start_time", "end_time", "used_on", "is_active"] filterset_fields = ["code", "discount_amount", "discount_percent", "start_time", "end_time", "used_on", "is_active"]
serializer_class = AttributeGroupDetailSerializer serializer_class = PromoCodeDetailSerializer
action_serializer_classes = { action_serializer_classes = {
"list": PromoCodeSimpleSerializer, "list": PromoCodeSimpleSerializer,
} }
@ -398,7 +393,7 @@ class PromotionViewSet(EvibesViewSet):
queryset = Promotion.objects.all() queryset = Promotion.objects.all()
filter_backends = [DjangoFilterBackend] filter_backends = [DjangoFilterBackend]
filterset_fields = ["name", "discount_percent", "is_active"] filterset_fields = ["name", "discount_percent", "is_active"]
serializer_class = AttributeGroupDetailSerializer serializer_class = PromotionDetailSerializer
action_serializer_classes = { action_serializer_classes = {
"list": PromotionSimpleSerializer, "list": PromotionSimpleSerializer,
} }
@ -408,7 +403,7 @@ class StockViewSet(EvibesViewSet):
queryset = Stock.objects.all() queryset = Stock.objects.all()
filter_backends = [DjangoFilterBackend] filter_backends = [DjangoFilterBackend]
filterset_fields = ["vendor", "product", "sku", "is_active"] filterset_fields = ["vendor", "product", "sku", "is_active"]
serializer_class = AttributeGroupDetailSerializer serializer_class = StockDetailSerializer
action_serializer_classes = { action_serializer_classes = {
"list": StockSimpleSerializer, "list": StockSimpleSerializer,
} }
@ -419,7 +414,7 @@ class WishlistViewSet(EvibesViewSet):
queryset = Wishlist.objects.all() queryset = Wishlist.objects.all()
filter_backends = [DjangoFilterBackend] filter_backends = [DjangoFilterBackend]
filterset_fields = ["user", "is_active"] filterset_fields = ["user", "is_active"]
serializer_class = AttributeGroupDetailSerializer serializer_class = WishlistDetailSerializer
action_serializer_classes = { action_serializer_classes = {
"list": WishlistSimpleSerializer, "list": WishlistSimpleSerializer,
} }
@ -535,3 +530,13 @@ class AddressViewSet(EvibesViewSet):
) )
return Response(suggestions, status=status.HTTP_200_OK) return Response(suggestions, status=status.HTTP_200_OK)
class ProductTagViewSet(EvibesViewSet):
queryset = ProductTag.objects.all()
filter_backends = [DjangoFilterBackend]
filterset_fields = ["tag_name", "is_active"]
serializer_class = ProductTagDetailSerializer
action_serializer_classes = {
"list": ProductTagSimpleSerializer,
}

View file

@ -2,7 +2,7 @@ import logging
from os import getenv from os import getenv
from pathlib import Path from pathlib import Path
EVIBES_VERSION = "2.7.0" EVIBES_VERSION = "2.7.1"
BASE_DIR = Path(__file__).resolve().parent.parent.parent BASE_DIR = Path(__file__).resolve().parent.parent.parent

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "eVibes" name = "eVibes"
version = "2.7.0" version = "2.7.1"
description = "eVibes is an open-source eCommerce backend service built with Django. Its designed for flexibility, making it ideal for various use cases and learning Django skills. The project is easy to customize, allowing for straightforward editing and extension." description = "eVibes is an open-source eCommerce backend service built with Django. Its designed for flexibility, making it ideal for various use cases and learning Django skills. The project is easy to customize, allowing for straightforward editing and extension."
authors = ["fureunoir <contact@fureunoir.com>"] authors = ["fureunoir <contact@fureunoir.com>"]
readme = "README.md" readme = "README.md"

View file

@ -3,7 +3,7 @@ from hmac import compare_digest
from django.contrib.auth.password_validation import validate_password from django.contrib.auth.password_validation import validate_password
from django.contrib.auth.tokens import PasswordResetTokenGenerator from django.contrib.auth.tokens import PasswordResetTokenGenerator
from django.core.exceptions import BadRequest, PermissionDenied from django.core.exceptions import BadRequest, PermissionDenied, ValidationError
from django.db import IntegrityError from django.db import IntegrityError
from django.http import Http404 from django.http import Http404
from django.utils.http import urlsafe_base64_decode from django.utils.http import urlsafe_base64_decode
@ -11,7 +11,6 @@ from django.utils.translation import gettext_lazy as _
from graphene import UUID, Boolean, Field, List, String from graphene import UUID, Boolean, Field, List, String
from graphene.types.generic import GenericScalar from graphene.types.generic import GenericScalar
from graphene_file_upload.scalars import Upload from graphene_file_upload.scalars import Upload
from rest_framework.exceptions import ValidationError
from core.graphene import BaseMutation from core.graphene import BaseMutation
from core.utils.messages import permission_denied_message from core.utils.messages import permission_denied_message
@ -123,8 +122,8 @@ class UpdateUser(BaseMutation):
password = kwargs.get("password", "") password = kwargs.get("password", "")
confirm_password = kwargs.get("confirm_password", "") confirm_password = kwargs.get("confirm_password", "")
if compare_digest(password.lower(), email.lower()): if password:
raise BadRequest(_("password too weak")) validate_password(password=password, user=user)
if not compare_digest(password, "") and compare_digest(password, confirm_password): if not compare_digest(password, "") and compare_digest(password, confirm_password):
user.set_password(password) user.set_password(password)
@ -314,13 +313,15 @@ class ConfirmResetPassword(BaseMutation):
if not password_reset_token.check_token(user, token): if not password_reset_token.check_token(user, token):
raise BadRequest(_("token is invalid!")) raise BadRequest(_("token is invalid!"))
validate_password(password=password, user=user)
user.set_password(password) user.set_password(password)
user.save() user.save()
return ConfirmResetPassword(success=True) return ConfirmResetPassword(success=True)
except (TypeError, ValueError, OverflowError, User.DoesNotExist) as e: except (TypeError, ValueError, OverflowError, ValidationError, User.DoesNotExist) as e:
raise BadRequest(_(f"something went wrong: {e!s}")) raise BadRequest(_(f"something went wrong: {e!s}"))

View file

@ -3,7 +3,9 @@ import traceback
from contextlib import suppress from contextlib import suppress
from secrets import compare_digest from secrets import compare_digest
from django.contrib.auth.password_validation import validate_password
from django.contrib.auth.tokens import PasswordResetTokenGenerator from django.contrib.auth.tokens import PasswordResetTokenGenerator
from django.core.exceptions import ValidationError
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.utils.http import urlsafe_base64_decode from django.utils.http import urlsafe_base64_decode
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -20,7 +22,6 @@ from evibes.settings import DEBUG
from vibes_auth.docs.drf.viewsets import USER_SCHEMA from vibes_auth.docs.drf.viewsets import USER_SCHEMA
from vibes_auth.models import User from vibes_auth.models import User
from vibes_auth.serializers import ( from vibes_auth.serializers import (
ConfirmPasswordResetSerializer,
UserSerializer, UserSerializer,
) )
from vibes_auth.utils.emailing import send_reset_password_email_task from vibes_auth.utils.emailing import send_reset_password_email_task
@ -64,29 +65,34 @@ class UserViewSet(
@action(detail=False, methods=["post"]) @action(detail=False, methods=["post"])
@method_decorator(ratelimit(key="ip", rate="2/h" if not DEBUG else "888/h")) @method_decorator(ratelimit(key="ip", rate="2/h" if not DEBUG else "888/h"))
def confirm_password_reset(self): def confirm_password_reset(self, request, *args, **kwargs):
try: try:
data = ConfirmPasswordResetSerializer(self.request.data).data
if not compare_digest(data.get("password"), data.get("confirm_password")): if not compare_digest(request.data.get("password"), request.data.get("confirm_password")):
return Response( return Response(
{"error": _("passwords do not match")}, {"error": _("passwords do not match")},
status=status.HTTP_400_BAD_REQUEST, status=status.HTTP_400_BAD_REQUEST,
) )
uuid = urlsafe_base64_decode(data.get("uidb64")).decode() uuid = urlsafe_base64_decode(request.data.get("uidb64")).decode()
user = User.objects.get(pk=uuid) user = User.objects.get(pk=uuid)
validate_password(password=request.data.get("password"), user=user)
password_reset_token = PasswordResetTokenGenerator() password_reset_token = PasswordResetTokenGenerator()
if not password_reset_token.check_token(user, data.get("token")): if not password_reset_token.check_token(user, request.data.get("token")):
return Response({"error": _("token is invalid!")}, status=status.HTTP_400_BAD_REQUEST) return Response({"error": _("token is invalid!")}, status=status.HTTP_400_BAD_REQUEST)
user.set_password(data.get("password")) user.set_password(request.data.get("password"))
user.save() user.save()
return Response({"message": _("password reset successfully")}, status=status.HTTP_200_OK) return Response({"message": _("password reset successfully")}, status=status.HTTP_200_OK)
except (TypeError, ValueError, OverflowError, User.DoesNotExist) as e: except (TypeError, ValueError, OverflowError, ValidationError, User.DoesNotExist) as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST) data = {"error": str(e)}
if DEBUG:
data["detail"] = str(traceback.format_exc())
data["received"] = str(request.data)
return Response(data, status=status.HTTP_400_BAD_REQUEST)
@method_decorator(ratelimit(key="ip", rate="3/h" if not DEBUG else "888/h")) @method_decorator(ratelimit(key="ip", rate="3/h" if not DEBUG else "888/h"))
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
@ -142,6 +148,9 @@ class UserViewSet(
return Response(serializer.data) return Response(serializer.data)
def update(self, request, pk=None, *args, **kwargs): def update(self, request, pk=None, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
instance = serializer.update(instance=self.get_object(), validated_data=request.data)
return Response( return Response(
self.get_serializer(self.get_object()).update(instance=self.get_object(), validated_data=request.data).data self.get_serializer(instance.data)
) )