schon/core/elasticsearch/__init__.py
Egor fureunoir Gorbunov c263182414 Features: 1) Add seen_keys mechanism to avoid duplicate hits in Elasticsearch query results; 2) Introduce _collect_hits helper function for processing and storing hits; 3) Add exact-match queries for categories, brands, and products to improve search accuracy.
Fixes: 1) Prevent duplicate entries in hit processing by checking `seen_keys`.

Extra: Refactor query-building logic for consistency and readability; minor performance optimizations in query execution.
2025-10-16 15:19:13 +03:00

597 lines
20 KiB
Python

import re
from typing import Any
from blib2to3.pgen2.parse import Callable
from django.conf import settings
from django.db.models import QuerySet
from django.http import Http404
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from django_elasticsearch_dsl import fields
from django_elasticsearch_dsl.registries import registry
from elasticsearch import NotFoundError
from elasticsearch.dsl import Q, Search
from rest_framework.request import Request
from core.models import Brand, Category, Product
SMART_FIELDS = [
"name^6",
"name.ngram^6",
"name.phonetic^4",
"name.translit^5",
"title^4",
"title.ngram^5",
"title.phonetic^3",
"title.translit^4",
"description^2",
"description.ngram^3",
"description.phonetic^2",
"description.translit^3",
"brand_name^4",
"brand_name.ngram^3",
"brand_name.auto^4",
"brand_name.translit^4",
"category_name^3",
"category_name.ngram^3",
"category_name.auto^3",
"category_name.translit^3",
"sku^7",
"sku.ngram^5",
"sku.auto^6",
"partnumber^8",
"partnumber.ngram^6",
"partnumber.auto^7",
]
functions = [
{
"filter": Q("term", **{"_index": "products"}),
"field_value_factor": {
"field": "brand_priority",
"modifier": "log1p",
"factor": 0.15,
"missing": 0,
},
"weight": 0.35,
},
{
"filter": Q("term", **{"_index": "products"}),
"field_value_factor": {
"field": "rating",
"modifier": "log1p",
"factor": 0.10,
"missing": 0,
},
"weight": 0.3,
},
{
"filter": Q("term", **{"_index": "products"}),
"field_value_factor": {
"field": "total_orders",
"modifier": "log1p",
"factor": 0.18,
"missing": 0,
},
"weight": 0.4,
},
{
"filter": Q("term", **{"_index": "products"}),
"field_value_factor": {
"field": "category_priority",
"modifier": "log1p",
"factor": 0.15,
"missing": 0,
},
"weight": 0.35,
},
{
"filter": Q("term", **{"_index": "categories"}),
"field_value_factor": {
"field": "priority",
"modifier": "log1p",
"factor": 0.18,
"missing": 0,
},
"weight": 0.45,
},
{
"filter": Q("term", **{"_index": "brands"}),
"field_value_factor": {
"field": "priority",
"modifier": "log1p",
"factor": 0.18,
"missing": 0,
},
"weight": 0.45,
},
]
def process_query(
query: str = "",
request: Request | None = None,
indexes: tuple[str, ...] = ("categories", "brands", "products"),
use_transliteration: bool = True,
) -> dict[str, list[dict[str, Any]]] | None:
if not query:
raise ValueError(_("no search term provided."))
query = query.strip()
try:
exact_shoulds = [
Q("term", **{"name.raw": {"value": query, "boost": 2.0}}),
Q("term", **{"slug": {"value": slugify(query), "boost": 1.5}}),
Q("term", **{"sku.raw": {"value": query.lower(), "boost": 6.0}}),
Q("term", **{"partnumber.raw": {"value": query.lower(), "boost": 7.0}}),
]
lang = ""
if request and hasattr(request, "LANGUAGE_CODE") and request.LANGUAGE_CODE:
lang = request.LANGUAGE_CODE.lower()
base = lang.split("-")[0] if lang else ""
is_cjk = base in {"ja", "zh"}
is_rtl_or_indic = base in {"ar", "hi"}
fields_all = SMART_FIELDS[:]
if not use_transliteration:
fields_all = [f for f in fields_all if ".translit" not in f]
if is_cjk or is_rtl_or_indic:
fields_all = [f for f in fields_all if ".phonetic" not in f]
fields_all = [
f.replace("name.ngram^6", "name.ngram^8")
.replace("title.ngram^5", "title.ngram^7")
.replace("description.ngram^3", "description.ngram^4")
for f in fields_all
]
fuzzy = None if (is_cjk or is_rtl_or_indic) else "AUTO:5,8"
is_code_like = bool(re.search(r"[0-9]", query)) and " " not in query
text_shoulds = [
Q(
"multi_match",
query=query,
fields=fields_all,
operator="and",
type="most_fields",
tie_breaker=0.2,
**({"fuzziness": fuzzy} if fuzzy else {}),
),
Q(
"multi_match",
query=query,
fields=[f for f in fields_all if f.endswith(".auto")],
type="bool_prefix",
),
]
if is_code_like:
text_shoulds.extend(
[
Q("term", **{"sku.raw": {"value": query.lower(), "boost": 10.0}}),
Q("term", **{"partnumber.raw": {"value": query.lower(), "boost": 12.0}}),
Q("prefix", **{"sku.raw": {"value": query.lower(), "boost": 5.0}}),
Q("prefix", **{"partnumber.raw": {"value": query.lower(), "boost": 6.0}}),
]
)
query_base = Q(
"bool",
should=exact_shoulds + text_shoulds,
minimum_should_match=1,
)
def build_search(idxs: list[str], size: int) -> Search:
return (
Search(index=idxs)
.query(query_base)
.extra(
rescore={
"window_size": 200,
"query": {
"rescore_query": Q(
"function_score",
query=Q("match_all"),
functions=functions,
boost_mode="sum",
score_mode="sum",
max_boost=1.2,
).to_dict(),
"query_weight": 1.0,
"rescore_query_weight": 0.6,
},
}
)
.extra(size=size, track_total_hits=True)
)
resp_cats = None
if "categories" in indexes:
search_cats = build_search(["categories"], size=22)
resp_cats = search_cats.execute()
resp_brands = None
if "brands" in indexes:
search_brands = build_search(["brands"], size=22)
resp_brands = search_brands.execute()
resp_products = None
if "products" in indexes:
search_products = build_search(["products"], size=44)
resp_products = search_products.execute()
results: dict[str, list[dict[str, Any]]] = {"products": [], "categories": [], "brands": [], "posts": []}
uuids_by_index: dict[str, list[dict[str, Any]]] = {"products": [], "categories": [], "brands": []}
hit_cache: list[Any] = []
seen_keys: set[tuple[str, str]] = set()
def _hit_key(hittee: Any) -> tuple[str, str]:
return hittee.meta.index, str(getattr(hittee, "uuid", None) or hittee.meta.id)
def _collect_hits(hits: list[Any]) -> None:
for hh in hits:
key = _hit_key(hh)
if key in seen_keys:
continue
hit_cache.append(hh)
seen_keys.add(key)
if getattr(hh, "uuid", None):
uuids_by_index.setdefault(hh.meta.index, []).append({"uuid": str(hh.uuid)})
exact_queries_by_index: dict[str, list[Any]] = {
"categories": [
Q("term", **{"name.raw": {"value": query}}),
Q("term", **{"slug": {"value": slugify(query)}}),
],
"brands": [
Q("term", **{"name.raw": {"value": query}}),
Q("term", **{"slug": {"value": slugify(query)}}),
],
"products": [
Q("term", **{"name.raw": {"value": query}}),
Q("term", **{"slug": {"value": slugify(query)}}),
Q("term", **{"sku.raw": {"value": query.lower()}}),
Q("term", **{"partnumber.raw": {"value": query.lower()}}),
],
}
for idx_name in ("categories", "brands", "products"):
if idx_name in indexes:
shoulds = exact_queries_by_index[idx_name]
s_exact = (
Search(index=[idx_name])
.query(Q("bool", should=shoulds, minimum_should_match=1))
.extra(size=5, track_total_hits=False)
)
try:
resp_exact = s_exact.execute()
except NotFoundError:
resp_exact = None
if resp_exact is not None and getattr(resp_exact, "hits", None):
_collect_hits(list(resp_exact.hits))
for h in (
list(resp_cats.hits[:12] if resp_cats else [])
+ list(resp_brands.hits[:12] if resp_brands else [])
+ list(resp_products.hits[:26] if resp_products else [])
):
k = _hit_key(h)
if k in seen_keys:
continue
hit_cache.append(h)
seen_keys.add(k)
if getattr(h, "uuid", None):
uuids_by_index.setdefault(h.meta.index, []).append({"uuid": str(h.uuid)})
products_by_uuid = {}
brands_by_uuid = {}
cats_by_uuid = {}
if request:
if uuids_by_index.get("products"):
products_by_uuid = {
str(p.uuid): p
for p in Product.objects.filter(uuid__in=uuids_by_index["products"])
.select_related("brand", "category")
.prefetch_related("images")
}
if uuids_by_index.get("brands"):
brands_by_uuid = {str(b.uuid): b for b in Brand.objects.filter(uuid__in=uuids_by_index["brands"])}
if uuids_by_index.get("categories"):
cats_by_uuid = {str(c.uuid): c for c in Category.objects.filter(uuid__in=uuids_by_index["categories"])}
for hit in hit_cache:
obj_uuid = getattr(hit, "uuid", None) or hit.meta.id
obj_name = getattr(hit, "name", None) or getattr(hit, "title", None) or "N/A"
obj_slug = getattr(hit, "slug", "") or (
slugify(obj_name) if hit.meta.index in {"brands", "categories"} else ""
)
image_url = None
idx = hit.meta.index
if idx == "products" and request:
prod = products_by_uuid.get(str(obj_uuid))
if prod:
first = prod.images.order_by("priority").first()
if first and first.image:
image_url = request.build_absolute_uri(first.image.url)
elif idx == "brands" and request:
brand = brands_by_uuid.get(str(obj_uuid))
if brand and brand.small_logo:
image_url = request.build_absolute_uri(brand.small_logo.url)
elif idx == "categories" and request:
cat = cats_by_uuid.get(str(obj_uuid))
if cat and cat.image:
image_url = request.build_absolute_uri(cat.image.url)
hit_result = {
"uuid": str(obj_uuid),
"name": obj_name,
"slug": obj_slug,
"image": image_url,
}
if settings.DEBUG:
if idx == "products":
hit_result["rating_debug"] = getattr(hit, "rating", 0)
hit_result["total_orders_debug"] = getattr(hit, "total_orders", 0)
hit_result["brand_priority_debug"] = getattr(hit, "brand_priority", 0)
hit_result["category_priority_debug"] = getattr(hit, "category_priority", 0)
if idx in ("brands", "categories"):
hit_result["priority_debug"] = getattr(hit, "priority", 0)
results[idx].append(hit_result)
return results
except NotFoundError as nfe:
raise Http404 from nfe
LANGUAGE_ANALYZER_MAP = {
"cs": "czech",
"da": "danish",
"de": "german",
"en": "english",
"es": "spanish",
"fr": "french",
"it": "italian",
"nl": "dutch",
"pt": "portuguese",
"ro": "romanian",
"ja": "cjk_search",
"zh": "cjk_search",
"ar": "arabic_search",
"hi": "indic_search",
"ru": "russian",
"pl": "standard",
"kk": "standard",
}
def _lang_analyzer(lang_code: str) -> str:
base = lang_code.split("-")[0].lower()
return LANGUAGE_ANALYZER_MAP.get(base, "icu_query")
class ActiveOnlyMixin:
def get_queryset(self) -> QuerySet[Any]:
return super().get_queryset().filter(is_active=True) # type: ignore [no-any-return, misc]
def should_index_object(self, obj) -> bool: # type: ignore [no-untyped-def]
return getattr(obj, "is_active", False)
COMMON_ANALYSIS = {
"char_filter": {
"icu_nfkc_cf": {"type": "icu_normalizer", "name": "nfkc_cf"},
},
"filter": {
"edge_ngram_filter": {"type": "edge_ngram", "min_gram": 1, "max_gram": 20},
"ngram_filter": {"type": "ngram", "min_gram": 2, "max_gram": 20},
"cjk_bigram": {"type": "cjk_bigram"},
"icu_folding": {"type": "icu_folding"},
"double_metaphone": {"type": "phonetic", "encoder": "double_metaphone", "replace": False},
"arabic_norm": {"type": "arabic_normalization"},
"indic_norm": {"type": "indic_normalization"},
"icu_any_latin": {"type": "icu_transform", "id": "Any-Latin"},
"icu_latin_ascii": {"type": "icu_transform", "id": "Latin-ASCII"},
"icu_ru_latin_bgn": {"type": "icu_transform", "id": "Russian-Latin/BGN"},
},
"analyzer": {
"icu_query": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": ["lowercase", "icu_folding"],
},
"autocomplete": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": ["lowercase", "icu_folding", "edge_ngram_filter"],
},
"autocomplete_search": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": ["lowercase", "icu_folding"],
},
"name_ngram": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": ["lowercase", "icu_folding", "ngram_filter"],
},
"name_phonetic": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": ["lowercase", "icu_folding", "double_metaphone"],
},
"cjk_search": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": ["lowercase", "icu_folding", "cjk_bigram"],
},
"arabic_search": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": ["lowercase", "icu_folding", "arabic_norm"],
},
"indic_search": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": ["lowercase", "icu_folding", "indic_norm"],
},
"translit_index": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": [
"icu_any_latin",
"icu_ru_latin_bgn",
"icu_latin_ascii",
"lowercase",
"icu_folding",
"double_metaphone",
],
},
"translit_query": {
"type": "custom",
"char_filter": ["icu_nfkc_cf"],
"tokenizer": "icu_tokenizer",
"filter": [
"icu_any_latin",
"icu_ru_latin_bgn",
"icu_latin_ascii",
"lowercase",
"icu_folding",
"double_metaphone",
],
},
},
"normalizer": {
"lc_norm": {
"type": "custom",
"filter": ["lowercase", "icu_folding"],
}
},
}
def add_multilang_fields(cls: Any) -> None:
for code, _lang in settings.LANGUAGES:
lc = code.replace("-", "_").lower()
name_field = f"name_{lc}"
setattr(
cls,
name_field,
fields.TextField(
attr=name_field,
analyzer=_lang_analyzer(code),
copy_to="name",
fields={
"raw": fields.KeywordField(ignore_above=256),
"ngram": fields.TextField(analyzer="name_ngram", search_analyzer="icu_query"),
"phonetic": fields.TextField(analyzer="name_phonetic"),
"translit": fields.TextField(analyzer="translit_index", search_analyzer="translit_query"),
},
),
)
def make_prepare(attr: str) -> Callable[[Any, Any], str]:
return lambda self, instance: getattr(instance, attr, "") or ""
setattr(cls, f"prepare_{name_field}", make_prepare(name_field))
desc_field = f"description_{lc}"
setattr(
cls,
desc_field,
fields.TextField(
attr=desc_field,
analyzer=_lang_analyzer(code),
copy_to="description",
fields={
"raw": fields.KeywordField(ignore_above=256),
"ngram": fields.TextField(analyzer="name_ngram", search_analyzer="icu_query"),
"phonetic": fields.TextField(analyzer="name_phonetic"),
"translit": fields.TextField(analyzer="translit_index", search_analyzer="translit_query"),
},
),
)
setattr(cls, f"prepare_{desc_field}", make_prepare(desc_field))
def populate_index() -> None:
for doc in registry.get_documents(set(registry.get_models())):
qs = doc().get_indexing_queryset()
doc().update(qs, parallel=True, refresh=True)
return None
def process_system_query(
query: str,
*,
indexes: tuple[str, ...] = ("categories", "brands", "products"),
size_per_index: int = 25,
language_code: str | None = None,
use_transliteration: bool = True,
) -> dict[str, list[dict[str, Any]]]:
if not query:
raise ValueError(_("no search term provided."))
q = query.strip()
base = (language_code or "").split("-")[0].lower() if language_code else ""
is_cjk = base in {"ja", "zh"}
is_rtl_or_indic = base in {"ar", "hi"}
fields_all = [f for f in SMART_FIELDS if not f.startswith(("sku", "partnumber"))]
if not use_transliteration:
fields_all = [f for f in fields_all if ".translit" not in f]
if is_cjk or is_rtl_or_indic:
fields_all = [f for f in fields_all if ".phonetic" not in f]
fields_all = [
f.replace("ngram^6", "ngram^8").replace("ngram^5", "ngram^7").replace("ngram^3", "ngram^4")
for f in fields_all
]
fuzzy = None if (is_cjk or is_rtl_or_indic) else "AUTO:5,8"
mm = Q(
"multi_match",
query=q,
fields=fields_all,
operator="and",
type="most_fields",
tie_breaker=0.2,
**({"fuzziness": fuzzy} if fuzzy else {}),
)
results: dict[str, list[dict[str, Any]]] = {idx: [] for idx in indexes}
for idx in indexes:
s = Search(index=[idx]).query(mm).extra(size=size_per_index, track_total_hits=False)
resp = s.execute()
for h in resp.hits:
name = getattr(h, "name", None) or getattr(h, "title", None) or "N/A"
results[idx].append(
{
"id": getattr(h, "uuid", None) or h.meta.id,
"name": name,
"slug": getattr(h, "slug", ""),
"score": getattr(h.meta, "score", None),
}
)
return results