From d97d3ff5fc43a066c605a04c4ac5d522155527dd Mon Sep 17 00:00:00 2001 From: zhuhao <37029601+hwzhuhao@users.noreply.github.com> Date: Sat, 12 Oct 2024 23:58:41 +0800 Subject: [PATCH] chore: add abstract decorator and output log when query embedding fails (#9264) --- api/core/embedding/cached_embedding.py | 7 ++++++- api/core/rag/datasource/keyword/keyword_base.py | 2 ++ .../datasource/vdb/elasticsearch/elasticsearch_vector.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 75219051cd..31d2171e72 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -5,6 +5,7 @@ from typing import Optional, cast import numpy as np from sqlalchemy.exc import IntegrityError +from configs import dify_config from core.embedding.embedding_constant import EmbeddingInputType from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey @@ -110,6 +111,8 @@ class CacheEmbedding(Embeddings): embedding_results = embedding_result.embeddings[0] embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() except Exception as ex: + if dify_config.DEBUG: + logging.exception(f"Failed to embed query text: {ex}") raise ex try: @@ -122,6 +125,8 @@ class CacheEmbedding(Embeddings): encoded_str = encoded_vector.decode("utf-8") redis_client.setex(embedding_cache_key, 600, encoded_str) except Exception as ex: - logging.exception("Failed to add embedding to redis %s", ex) + if dify_config.DEBUG: + logging.exception("Failed to add embedding to redis %s", ex) + raise ex return embedding_results diff --git a/api/core/rag/datasource/keyword/keyword_base.py b/api/core/rag/datasource/keyword/keyword_base.py index 4b9ec460e6..be00687abd 100644 --- a/api/core/rag/datasource/keyword/keyword_base.py +++ b/api/core/rag/datasource/keyword/keyword_base.py @@ -27,9 +27,11 @@ class BaseKeyword(ABC): def delete_by_ids(self, ids: list[str]) -> None: raise NotImplementedError + @abstractmethod def delete(self) -> None: raise NotImplementedError + @abstractmethod def search(self, query: str, **kwargs: Any) -> list[Document]: raise NotImplementedError diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f585e12b2e..66bc31a4bf 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -77,7 +77,7 @@ class ElasticSearchVector(BaseVector): raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") def get_type(self) -> str: - return "elasticsearch" + return VectorType.ELASTICSEARCH def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): uuids = self._get_uuids(documents)