Features: 1) Add bulk_prefetch_filterable_attributes for optimized attribute fetching; 2) Introduce FilterableAttribute typing for standardizing attribute data across serializers and models; 3) Enhance CategoryDetailSerializer with custom list_serializer_class for bulk operations.

Fixes: 1) Replace redundant `get_filterable_attributes` logic with `filterable_attributes` property to avoid duplications; 2) Remove unnecessary imports and redundant comments in various modules.

Extra: Refactor admin queryset methods for optimization; remove unused `categories` field in `Attribute` model; improve clarity and maintainability of Graphene resolvers and related logic.
This commit is contained in:
Egor Pavlovich Gorbunov 2025-10-26 16:21:34 +03:00
parent fa46e3ad9c
commit 2114c8bb76
7 changed files with 157 additions and 95 deletions

View file

@ -145,6 +145,13 @@ class AttributeValueInline(TabularInline): # type: ignore [type-arg]
verbose_name_plural = _("attribute values")
icon = "fa-solid fa-list-ul"
def get_queryset(self, request):
return (
super()
.get_queryset(request)
.select_related("attribute", "product")
)
class ProductImageInline(TabularInline): # type: ignore [type-arg]
model = ProductImage
@ -154,6 +161,9 @@ class ProductImageInline(TabularInline): # type: ignore [type-arg]
verbose_name_plural = _("images")
icon = "fa-regular fa-images"
def get_queryset(self, request):
return super().get_queryset(request).select_related("product")
class StockInline(TabularInline): # type: ignore [type-arg]
model = Stock
@ -163,6 +173,13 @@ class StockInline(TabularInline): # type: ignore [type-arg]
verbose_name_plural = _("stocks")
icon = "fa-solid fa-boxes-stacked"
def get_queryset(self, request):
return (
super()
.get_queryset(request)
.select_related("vendor", "product")
)
class OrderProductInline(TabularInline): # type: ignore [type-arg]
model = OrderProduct
@ -439,6 +456,20 @@ class ProductAdmin(FieldsetsMixin, ActivationActionsMixin, ModelAdmin): # type:
"tags",
]
def get_queryset(self, request):
# Optimize product change page to avoid N+1 queries
return (
super()
.get_queryset(request)
.select_related("category", "brand")
.prefetch_related(
"tags",
"images",
"stocks__vendor",
"attributes__attribute",
)
)
@register(ProductTag)
class ProductTagAdmin(FieldsetsMixin, ActivationActionsMixin, ModelAdmin): # type: ignore [misc, type-arg]

View file

@ -1,10 +1,10 @@
import logging
from contextlib import suppress
from typing import Any
from constance import config
from django.core.cache import cache
from django.db.models import Max, Min, QuerySet
from django.db.models.functions import Length
from django.utils.translation import gettext_lazy as _
from graphene import (
UUID,
@ -212,7 +212,6 @@ class CategoryType(DjangoObjectType): # type: ignore [misc]
fields = (
"uuid",
"markup_percent",
"attributes",
"children",
"name",
"slug",
@ -223,11 +222,16 @@ class CategoryType(DjangoObjectType): # type: ignore [misc]
filter_fields = ["uuid"]
description = _("categories")
def resolve_children(self, info) -> TreeQuerySet:
def resolve_children(self, info) -> TreeQuerySet | list[Category]:
categories = Category.objects.filter(parent=self)
if info.context.user.has_perm("core.view_category"):
return categories
return categories.filter(is_active=True)
if not info.context.user.has_perm("core.view_category"):
categories = categories.filter(is_active=True)
result = categories
with suppress(Exception):
items = list(categories)
Category.bulk_prefetch_filterable_attributes(items)
result = items
return result
def resolve_image(self: Category, info) -> str:
return info.context.build_absolute_uri(self.image.url) if self.image else ""
@ -238,38 +242,7 @@ class CategoryType(DjangoObjectType): # type: ignore [misc]
return 0.0
def resolve_filterable_attributes(self: Category, info):
filterable_results = cache.get(f"{self.uuid}_filterable_results", [])
if len(filterable_results) > 0:
return filterable_results
for attr in (
self.attributes.all()
if info.context.user.has_perm("view_attribute")
else self.attributes.filter(is_active=True)
):
distinct_vals = (
AttributeValue.objects.annotate(value_length=Length("value"))
.filter(attribute=attr, attribute__is_filterable=True, product__category=self, value_length__lte=30)
.values_list("value", flat=True)
.distinct()
)
distinct_vals_list = list(distinct_vals)
if len(distinct_vals_list) <= 128:
filterable_results.append(
{
"attribute_name": attr.name,
"possible_values": distinct_vals_list,
}
)
else:
pass
cache.set(f"{self.uuid}_filterable_results", filterable_results, 86400)
return filterable_results
return self.filterable_attributes
def resolve_min_max_prices(self: Category, _info):
min_max_prices = cache.get(key=f"{self.name}_min_max_prices", default={})

View file

@ -3,7 +3,7 @@ import json
import logging
import traceback
from contextlib import suppress
from typing import Any, Optional, Self
from typing import Any, Optional, Self, Iterable
from constance import config
from django.contrib.gis.db.models import PointField
@ -37,6 +37,7 @@ from django.db.models import (
URLField,
)
from django.db.models.indexes import Index
from django.db.models.functions import Length
from django.http import Http404
from django.utils import timezone
from django.utils.encoding import force_bytes
@ -52,6 +53,7 @@ from core.abstract import NiceModel
from core.choices import ORDER_PRODUCT_STATUS_CHOICES, ORDER_STATUS_CHOICES
from core.errors import DisabledCommerceError, NotEnoughMoneyError
from core.managers import AddressManager, ProductManager
from core.typing import FilterableAttribute
from core.utils import (
generate_human_readable_id,
generate_human_readable_token,
@ -340,6 +342,81 @@ class Category(ExportModelOperationsMixin("category"), NiceModel, MPTTModel): #
return 0
return self.get_descendants().aggregate(max_depth=Max("level"))["max_depth"] - self.get_level()
@classmethod
def bulk_prefetch_filterable_attributes(cls, categories: Iterable["Category"]) -> None:
cat_list = [c for c in categories]
if not cat_list:
return
cat_ids = [c.id for c in cat_list if c.id]
if not cat_ids:
return
rows = (
AttributeValue.objects.annotate(value_length=Length("value"))
.filter(
product__category_id__in=cat_ids,
attribute__is_filterable=True,
value_length__lte=30,
)
.values_list(
"product__category_id",
"attribute_id",
"attribute__name",
"attribute__value_type",
"value",
)
.distinct()
)
per_cat: dict[int, dict[int, dict]] = {}
for cat_id, attr_id, attr_name, value_type, value in rows:
cat_bucket = per_cat.get(cat_id)
if cat_bucket is None:
cat_bucket = {}
per_cat[cat_id] = cat_bucket
bucket = cat_bucket.get(attr_id)
if bucket is None:
bucket = {
"attribute_name": attr_name,
"possible_values": [],
"value_type": value_type,
}
cat_bucket[attr_id] = bucket
if len(bucket["possible_values"]) < 128 and value not in bucket["possible_values"]:
bucket["possible_values"].append(value)
for c in cat_list:
data = list(per_cat.get(c.id, {}).values())
c.__dict__["filterable_attributes"] = data
@cached_property
def filterable_attributes(self) -> list[FilterableAttribute]:
rows = (
AttributeValue.objects.annotate(value_length=Length("value"))
.filter(
product__category=self,
attribute__is_filterable=True,
value_length__lte=30,
)
.values_list("attribute_id", "attribute__name", "attribute__value_type", "value")
.distinct()
)
by_attr: dict[int, dict] = {}
for attr_id, attr_name, value_type, value in rows:
bucket = by_attr.get(attr_id)
if bucket is None:
bucket = {
"attribute_name": attr_name,
"possible_values": [],
"value_type": value_type,
}
by_attr[attr_id] = bucket
if len(bucket["possible_values"]) < 128 and value not in bucket["possible_values"]:
bucket["possible_values"].append(value)
return list(by_attr.values()) # type: ignore [arg-type]
class Meta:
verbose_name = _("category")
verbose_name_plural = _("categories")
@ -627,13 +704,6 @@ class Attribute(ExportModelOperationsMixin("attribute"), NiceModel): # type: ig
)
is_publicly_visible = True
categories = ManyToManyField(
"core.Category",
related_name="attributes",
help_text=_("category of this attribute"),
verbose_name=_("categories"),
)
group = ForeignKey(
"core.AttributeGroup",
on_delete=CASCADE,

View file

@ -1,12 +1,9 @@
import logging
from collections import defaultdict
from contextlib import suppress
from typing import Collection, Any
from django.core.cache import cache
from django.db.models.functions import Length
from rest_framework.fields import JSONField, SerializerMethodField
from rest_framework.serializers import ModelSerializer
from rest_framework.serializers import ModelSerializer, ListSerializer
from rest_framework_recursive.fields import RecursiveField
from core.models import (
@ -28,6 +25,7 @@ from core.models import (
Wishlist,
)
from core.serializers.simple import CategorySimpleSerializer, ProductSimpleSerializer
from core.typing import FilterableAttribute
from core.serializers.utility import AddressSerializer
logger = logging.getLogger("django")
@ -47,6 +45,15 @@ class AttributeGroupDetailSerializer(ModelSerializer):
]
class CategoryDetailListSerializer(ListSerializer):
def to_representation(self, data): # type: ignore[override]
items = list(data)
with suppress(Exception):
Category.bulk_prefetch_filterable_attributes(items)
return super().to_representation(items)
class CategoryDetailSerializer(ModelSerializer):
children = SerializerMethodField()
image = SerializerMethodField()
@ -54,6 +61,7 @@ class CategoryDetailSerializer(ModelSerializer):
class Meta:
model = Category
list_serializer_class = CategoryDetailListSerializer
fields = [
"uuid",
"name",
@ -72,46 +80,8 @@ class CategoryDetailSerializer(ModelSerializer):
return obj.image.url
return None
def get_filterable_attributes(self, obj: Category) -> list[dict]:
cache_key = f"{obj.uuid}_filterable_results"
filterable_results = cache.get(cache_key)
if filterable_results is not None:
return filterable_results
attrs_qs = obj.attributes.filter(is_active=True, is_filterable=True)
attributes = list(attrs_qs)
attr_ids = [a.id for a in attributes]
raw_vals = (
AttributeValue.objects.annotate(value_length=Length("value"))
.filter(
attribute_id__in=attr_ids,
product__category=obj,
value_length__lte=30,
)
.values_list("attribute_id", "value")
.distinct()
)
grouped = defaultdict(list)
for attr_id, val in raw_vals:
grouped[attr_id].append(val)
filterable_results = []
for attr in attributes:
vals = grouped.get(attr.id, []) # type: ignore
slice_vals = vals[:128] if len(vals) > 128 else vals
filterable_results.append(
{
"attribute_name": attr.name,
"possible_values": slice_vals,
"value_type": attr.value_type,
}
)
cache.set(cache_key, filterable_results, 3600)
return filterable_results
def get_filterable_attributes(self, obj: Category) -> list[FilterableAttribute]:
return obj.filterable_attributes
def get_children(self, obj) -> Collection[Any]:
request = self.context.get("request")

6
core/typing/__init__.py Normal file
View file

@ -0,0 +1,6 @@
from .models import AttributeValueTypeLiteral, FilterableAttribute
__all__ = [
"AttributeValueTypeLiteral",
"FilterableAttribute",
]

16
core/typing/models.py Normal file
View file

@ -0,0 +1,16 @@
from typing import Literal, TypedDict
AttributeValueTypeLiteral = Literal[
"string",
"integer",
"float",
"boolean",
"array",
"object",
]
class FilterableAttribute(TypedDict):
attribute_name: str
possible_values: list[str]
value_type: AttributeValueTypeLiteral

View file

@ -54,7 +54,6 @@ def create_anon_thread(email: str) -> ChatThread:
def send_message(thread: ChatThread, *, sender_user: Optional[User], sender_type: SenderType, text: str) -> ChatMessage:
if not text or len(text) > 1028:
raise ValidationError({"text": _("Message must be 1..1028 characters.")})
# Permission rules: non-staff may only write to their own thread or anon thread they initiated
if sender_user and not sender_user.is_staff:
if thread.user_id != sender_user.pk:
raise PermissionDenied
@ -88,7 +87,6 @@ def send_message(thread: ChatThread, *, sender_user: Optional[User], sender_type
def auto_reply(thread: ChatThread) -> None:
# Localizable text, do not translate here
text = _("We're searching for the operator to answer you already, hold by!")
msg = ChatMessage.objects.create(
thread=thread,
@ -119,7 +117,6 @@ def claim_thread(thread: ChatThread, staff_user: User) -> ChatThread:
if not staff_user.is_staff:
raise PermissionDenied
if thread.assigned_to_id and not staff_user.is_superuser:
# already assigned, cannot reassign/unassign
raise PermissionDenied
thread.assigned_to = staff_user
thread.save(update_fields=["assigned_to", "modified"])
@ -144,9 +141,8 @@ def reassign_thread(thread: ChatThread, superuser: User, new_staff: User) -> Cha
def close_thread(thread: ChatThread, actor: User | None) -> ChatThread:
if actor and actor.is_staff:
pass # allowed
pass
elif actor and not actor.is_staff:
# non-staff allowed to close own thread? Keep simple: allowed only for staff for now
raise PermissionDenied
thread.status = ThreadStatus.CLOSED
thread.save(update_fields=["status", "modified"])