diff --git a/core/admin.py b/core/admin.py index 209ab8e8..928f6943 100644 --- a/core/admin.py +++ b/core/admin.py @@ -145,6 +145,13 @@ class AttributeValueInline(TabularInline): # type: ignore [type-arg] verbose_name_plural = _("attribute values") icon = "fa-solid fa-list-ul" + def get_queryset(self, request): + return ( + super() + .get_queryset(request) + .select_related("attribute", "product") + ) + class ProductImageInline(TabularInline): # type: ignore [type-arg] model = ProductImage @@ -154,6 +161,9 @@ class ProductImageInline(TabularInline): # type: ignore [type-arg] verbose_name_plural = _("images") icon = "fa-regular fa-images" + def get_queryset(self, request): + return super().get_queryset(request).select_related("product") + class StockInline(TabularInline): # type: ignore [type-arg] model = Stock @@ -163,6 +173,13 @@ class StockInline(TabularInline): # type: ignore [type-arg] verbose_name_plural = _("stocks") icon = "fa-solid fa-boxes-stacked" + def get_queryset(self, request): + return ( + super() + .get_queryset(request) + .select_related("vendor", "product") + ) + class OrderProductInline(TabularInline): # type: ignore [type-arg] model = OrderProduct @@ -439,6 +456,20 @@ class ProductAdmin(FieldsetsMixin, ActivationActionsMixin, ModelAdmin): # type: "tags", ] + def get_queryset(self, request): + # Optimize product change page to avoid N+1 queries + return ( + super() + .get_queryset(request) + .select_related("category", "brand") + .prefetch_related( + "tags", + "images", + "stocks__vendor", + "attributes__attribute", + ) + ) + @register(ProductTag) class ProductTagAdmin(FieldsetsMixin, ActivationActionsMixin, ModelAdmin): # type: ignore [misc, type-arg] diff --git a/core/graphene/object_types.py b/core/graphene/object_types.py index 8d97f7b9..4767a948 100644 --- a/core/graphene/object_types.py +++ b/core/graphene/object_types.py @@ -1,10 +1,10 @@ import logging +from contextlib import suppress from typing import Any from constance import config from django.core.cache import cache from django.db.models import Max, Min, QuerySet -from django.db.models.functions import Length from django.utils.translation import gettext_lazy as _ from graphene import ( UUID, @@ -212,7 +212,6 @@ class CategoryType(DjangoObjectType): # type: ignore [misc] fields = ( "uuid", "markup_percent", - "attributes", "children", "name", "slug", @@ -223,11 +222,16 @@ class CategoryType(DjangoObjectType): # type: ignore [misc] filter_fields = ["uuid"] description = _("categories") - def resolve_children(self, info) -> TreeQuerySet: + def resolve_children(self, info) -> TreeQuerySet | list[Category]: categories = Category.objects.filter(parent=self) - if info.context.user.has_perm("core.view_category"): - return categories - return categories.filter(is_active=True) + if not info.context.user.has_perm("core.view_category"): + categories = categories.filter(is_active=True) + result = categories + with suppress(Exception): + items = list(categories) + Category.bulk_prefetch_filterable_attributes(items) + result = items + return result def resolve_image(self: Category, info) -> str: return info.context.build_absolute_uri(self.image.url) if self.image else "" @@ -238,38 +242,7 @@ class CategoryType(DjangoObjectType): # type: ignore [misc] return 0.0 def resolve_filterable_attributes(self: Category, info): - filterable_results = cache.get(f"{self.uuid}_filterable_results", []) - - if len(filterable_results) > 0: - return filterable_results - - for attr in ( - self.attributes.all() - if info.context.user.has_perm("view_attribute") - else self.attributes.filter(is_active=True) - ): - distinct_vals = ( - AttributeValue.objects.annotate(value_length=Length("value")) - .filter(attribute=attr, attribute__is_filterable=True, product__category=self, value_length__lte=30) - .values_list("value", flat=True) - .distinct() - ) - - distinct_vals_list = list(distinct_vals) - - if len(distinct_vals_list) <= 128: - filterable_results.append( - { - "attribute_name": attr.name, - "possible_values": distinct_vals_list, - } - ) - else: - pass - - cache.set(f"{self.uuid}_filterable_results", filterable_results, 86400) - - return filterable_results + return self.filterable_attributes def resolve_min_max_prices(self: Category, _info): min_max_prices = cache.get(key=f"{self.name}_min_max_prices", default={}) diff --git a/core/models.py b/core/models.py index 27b9f0ee..5617bf92 100644 --- a/core/models.py +++ b/core/models.py @@ -3,7 +3,7 @@ import json import logging import traceback from contextlib import suppress -from typing import Any, Optional, Self +from typing import Any, Optional, Self, Iterable from constance import config from django.contrib.gis.db.models import PointField @@ -37,6 +37,7 @@ from django.db.models import ( URLField, ) from django.db.models.indexes import Index +from django.db.models.functions import Length from django.http import Http404 from django.utils import timezone from django.utils.encoding import force_bytes @@ -52,6 +53,7 @@ from core.abstract import NiceModel from core.choices import ORDER_PRODUCT_STATUS_CHOICES, ORDER_STATUS_CHOICES from core.errors import DisabledCommerceError, NotEnoughMoneyError from core.managers import AddressManager, ProductManager +from core.typing import FilterableAttribute from core.utils import ( generate_human_readable_id, generate_human_readable_token, @@ -340,6 +342,81 @@ class Category(ExportModelOperationsMixin("category"), NiceModel, MPTTModel): # return 0 return self.get_descendants().aggregate(max_depth=Max("level"))["max_depth"] - self.get_level() + @classmethod + def bulk_prefetch_filterable_attributes(cls, categories: Iterable["Category"]) -> None: + cat_list = [c for c in categories] + if not cat_list: + return + cat_ids = [c.id for c in cat_list if c.id] + if not cat_ids: + return + + rows = ( + AttributeValue.objects.annotate(value_length=Length("value")) + .filter( + product__category_id__in=cat_ids, + attribute__is_filterable=True, + value_length__lte=30, + ) + .values_list( + "product__category_id", + "attribute_id", + "attribute__name", + "attribute__value_type", + "value", + ) + .distinct() + ) + + per_cat: dict[int, dict[int, dict]] = {} + for cat_id, attr_id, attr_name, value_type, value in rows: + cat_bucket = per_cat.get(cat_id) + if cat_bucket is None: + cat_bucket = {} + per_cat[cat_id] = cat_bucket + bucket = cat_bucket.get(attr_id) + if bucket is None: + bucket = { + "attribute_name": attr_name, + "possible_values": [], + "value_type": value_type, + } + cat_bucket[attr_id] = bucket + if len(bucket["possible_values"]) < 128 and value not in bucket["possible_values"]: + bucket["possible_values"].append(value) + + for c in cat_list: + data = list(per_cat.get(c.id, {}).values()) + c.__dict__["filterable_attributes"] = data + + @cached_property + def filterable_attributes(self) -> list[FilterableAttribute]: + rows = ( + AttributeValue.objects.annotate(value_length=Length("value")) + .filter( + product__category=self, + attribute__is_filterable=True, + value_length__lte=30, + ) + .values_list("attribute_id", "attribute__name", "attribute__value_type", "value") + .distinct() + ) + + by_attr: dict[int, dict] = {} + for attr_id, attr_name, value_type, value in rows: + bucket = by_attr.get(attr_id) + if bucket is None: + bucket = { + "attribute_name": attr_name, + "possible_values": [], + "value_type": value_type, + } + by_attr[attr_id] = bucket + if len(bucket["possible_values"]) < 128 and value not in bucket["possible_values"]: + bucket["possible_values"].append(value) + + return list(by_attr.values()) # type: ignore [arg-type] + class Meta: verbose_name = _("category") verbose_name_plural = _("categories") @@ -627,13 +704,6 @@ class Attribute(ExportModelOperationsMixin("attribute"), NiceModel): # type: ig ) is_publicly_visible = True - categories = ManyToManyField( - "core.Category", - related_name="attributes", - help_text=_("category of this attribute"), - verbose_name=_("categories"), - ) - group = ForeignKey( "core.AttributeGroup", on_delete=CASCADE, diff --git a/core/serializers/detail.py b/core/serializers/detail.py index f9c9ef90..dd0f605a 100644 --- a/core/serializers/detail.py +++ b/core/serializers/detail.py @@ -1,12 +1,9 @@ import logging -from collections import defaultdict from contextlib import suppress from typing import Collection, Any -from django.core.cache import cache -from django.db.models.functions import Length from rest_framework.fields import JSONField, SerializerMethodField -from rest_framework.serializers import ModelSerializer +from rest_framework.serializers import ModelSerializer, ListSerializer from rest_framework_recursive.fields import RecursiveField from core.models import ( @@ -28,6 +25,7 @@ from core.models import ( Wishlist, ) from core.serializers.simple import CategorySimpleSerializer, ProductSimpleSerializer +from core.typing import FilterableAttribute from core.serializers.utility import AddressSerializer logger = logging.getLogger("django") @@ -47,6 +45,15 @@ class AttributeGroupDetailSerializer(ModelSerializer): ] +class CategoryDetailListSerializer(ListSerializer): + + def to_representation(self, data): # type: ignore[override] + items = list(data) + with suppress(Exception): + Category.bulk_prefetch_filterable_attributes(items) + return super().to_representation(items) + + class CategoryDetailSerializer(ModelSerializer): children = SerializerMethodField() image = SerializerMethodField() @@ -54,6 +61,7 @@ class CategoryDetailSerializer(ModelSerializer): class Meta: model = Category + list_serializer_class = CategoryDetailListSerializer fields = [ "uuid", "name", @@ -72,46 +80,8 @@ class CategoryDetailSerializer(ModelSerializer): return obj.image.url return None - def get_filterable_attributes(self, obj: Category) -> list[dict]: - cache_key = f"{obj.uuid}_filterable_results" - filterable_results = cache.get(cache_key) - if filterable_results is not None: - return filterable_results - - attrs_qs = obj.attributes.filter(is_active=True, is_filterable=True) - attributes = list(attrs_qs) - - attr_ids = [a.id for a in attributes] - raw_vals = ( - AttributeValue.objects.annotate(value_length=Length("value")) - .filter( - attribute_id__in=attr_ids, - product__category=obj, - value_length__lte=30, - ) - .values_list("attribute_id", "value") - .distinct() - ) - - grouped = defaultdict(list) - for attr_id, val in raw_vals: - grouped[attr_id].append(val) - - filterable_results = [] - for attr in attributes: - vals = grouped.get(attr.id, []) # type: ignore - slice_vals = vals[:128] if len(vals) > 128 else vals - filterable_results.append( - { - "attribute_name": attr.name, - "possible_values": slice_vals, - "value_type": attr.value_type, - } - ) - - cache.set(cache_key, filterable_results, 3600) - - return filterable_results + def get_filterable_attributes(self, obj: Category) -> list[FilterableAttribute]: + return obj.filterable_attributes def get_children(self, obj) -> Collection[Any]: request = self.context.get("request") diff --git a/core/typing/__init__.py b/core/typing/__init__.py new file mode 100644 index 00000000..a686422e --- /dev/null +++ b/core/typing/__init__.py @@ -0,0 +1,6 @@ +from .models import AttributeValueTypeLiteral, FilterableAttribute + +__all__ = [ + "AttributeValueTypeLiteral", + "FilterableAttribute", +] diff --git a/core/typing/models.py b/core/typing/models.py new file mode 100644 index 00000000..5acc80ed --- /dev/null +++ b/core/typing/models.py @@ -0,0 +1,16 @@ +from typing import Literal, TypedDict + +AttributeValueTypeLiteral = Literal[ + "string", + "integer", + "float", + "boolean", + "array", + "object", +] + + +class FilterableAttribute(TypedDict): + attribute_name: str + possible_values: list[str] + value_type: AttributeValueTypeLiteral diff --git a/vibes_auth/messaging/services.py b/vibes_auth/messaging/services.py index 82cd970c..7184fb8b 100644 --- a/vibes_auth/messaging/services.py +++ b/vibes_auth/messaging/services.py @@ -54,7 +54,6 @@ def create_anon_thread(email: str) -> ChatThread: def send_message(thread: ChatThread, *, sender_user: Optional[User], sender_type: SenderType, text: str) -> ChatMessage: if not text or len(text) > 1028: raise ValidationError({"text": _("Message must be 1..1028 characters.")}) - # Permission rules: non-staff may only write to their own thread or anon thread they initiated if sender_user and not sender_user.is_staff: if thread.user_id != sender_user.pk: raise PermissionDenied @@ -88,7 +87,6 @@ def send_message(thread: ChatThread, *, sender_user: Optional[User], sender_type def auto_reply(thread: ChatThread) -> None: - # Localizable text, do not translate here text = _("We're searching for the operator to answer you already, hold by!") msg = ChatMessage.objects.create( thread=thread, @@ -119,7 +117,6 @@ def claim_thread(thread: ChatThread, staff_user: User) -> ChatThread: if not staff_user.is_staff: raise PermissionDenied if thread.assigned_to_id and not staff_user.is_superuser: - # already assigned, cannot reassign/unassign raise PermissionDenied thread.assigned_to = staff_user thread.save(update_fields=["assigned_to", "modified"]) @@ -144,9 +141,8 @@ def reassign_thread(thread: ChatThread, superuser: User, new_staff: User) -> Cha def close_thread(thread: ChatThread, actor: User | None) -> ChatThread: if actor and actor.is_staff: - pass # allowed + pass elif actor and not actor.is_staff: - # non-staff allowed to close own thread? Keep simple: allowed only for staff for now raise PermissionDenied thread.status = ThreadStatus.CLOSED thread.save(update_fields=["status", "modified"])