From 8efd12b199eb5d5c978b536d55a4ccac81863a52 Mon Sep 17 00:00:00 2001 From: elvisliu <719880851@qq.com> Date: Wed, 19 Mar 2025 23:44:40 +0800 Subject: [PATCH] feat: support huawei cloud vector database --- .../vdb/huawei/huawei_cloud_vector.py | 25 +++--- .../vdb/__mock/huaweicloudvectordb.py | 88 +++++++++++++++++++ .../integration_tests/vdb/huawei/__init__.py | 0 .../vdb/huawei/test_huawei_cloud.py | 28 ++++++ dev/pytest/pytest_vdb.sh | 1 + 5 files changed, 128 insertions(+), 14 deletions(-) create mode 100644 api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py create mode 100644 api/tests/integration_tests/vdb/huawei/__init__.py create mode 100644 api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py index b8b1aebb96..00d651150a 100644 --- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py @@ -3,6 +3,9 @@ import logging import ssl from typing import Any, Optional +from elasticsearch import Elasticsearch +from pydantic import BaseModel, model_validator + from configs import dify_config from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -10,10 +13,8 @@ from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document -from elasticsearch import Elasticsearch from extensions.ext_redis import redis_client from models.dataset import Dataset -from pydantic import BaseModel, model_validator logger = logging.getLogger(__name__) @@ -49,10 +50,9 @@ class HuaweiCloudVectorConfig(BaseModel): class HuaweiCloudVector(BaseVector): - def __init__(self, index_name: str, config: HuaweiCloudVectorConfig, attributes: list): + def __init__(self, index_name: str, config: HuaweiCloudVectorConfig): super().__init__(index_name.lower()) self._client = Elasticsearch(**config.to_elasticsearch_params()) - self._attributes = attributes def get_type(self) -> str: return VectorType.HUAWEI_CLOUD @@ -103,7 +103,7 @@ class HuaweiCloudVector(BaseVector): "topk": top_k, } } - } + }, } results = self._client.search(index=self._collection_name, body=query, request_timeout=120) @@ -152,10 +152,10 @@ class HuaweiCloudVector(BaseVector): self.add_texts(texts, embeddings, **kwargs) def create_collection( - self, - embeddings: list[list[float]], - metadatas: Optional[list[dict[Any, Any]]] = None, - index_params: Optional[dict] = None, + self, + embeddings: list[list[float]], + metadatas: Optional[list[dict[Any, Any]]] = None, + index_params: Optional[dict] = None, ): lock_name = f"vector_indexing_lock_{self._collection_name}" with redis_client.lock(lock_name, timeout=20): @@ -176,7 +176,7 @@ class HuaweiCloudVector(BaseVector): "algorithm": "GRAPH", "metric": "cosine", "neighbors": 32, - "efc": 128 + "efc": 128, }, Field.METADATA_KEY.value: { "type": "object", @@ -186,9 +186,7 @@ class HuaweiCloudVector(BaseVector): }, } } - settings = { - "index.vector": True - } + settings = {"index.vector": True} self._client.indices.create(index=self._collection_name, mappings=mappings, settings=settings) redis_client.set(collection_exist_cache_key, 1, ex=3600) @@ -211,5 +209,4 @@ class HuaweiCloudVectorFactory(AbstractVectorFactory): username=dify_config.HUAWEI_CLOUD_USER, password=dify_config.HUAWEI_CLOUD_PASSWORD, ), - attributes=[], ) diff --git a/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py b/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py new file mode 100644 index 0000000000..e1aba4e2c1 --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/huaweicloudvectordb.py @@ -0,0 +1,88 @@ +import os + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from api.core.rag.datasource.vdb.field import Field +from elasticsearch import Elasticsearch + + +class MockIndicesClient: + def __init__(self): + pass + + def create(self, index, mappings, settings): + return {"acknowledge": True} + + def refresh(self, index): + return {"acknowledge": True} + + def delete(self, index): + return {"acknowledge": True} + + def exists(self, index): + return True + + +class MockClient: + def __init__(self, **kwargs): + self.indices = MockIndicesClient() + + def index(self, **kwargs): + return {"acknowledge": True} + + def exists(self, **kwargs): + return True + + def delete(self, **kwargs): + return {"acknowledge": True} + + def search(self, **kwargs): + return { + "took": 1, + "hits": { + "hits": [ + { + "_source": { + Field.CONTENT_KEY.value: "abcdef", + Field.VECTOR.value: [1, 2], + Field.METADATA_KEY.value: {}, + }, + "_score": 1.0, + }, + { + "_source": { + Field.CONTENT_KEY.value: "123456", + Field.VECTOR.value: [2, 2], + Field.METADATA_KEY.value: {}, + }, + "_score": 0.9, + }, + { + "_source": { + Field.CONTENT_KEY.value: "a1b2c3", + Field.VECTOR.value: [3, 2], + Field.METADATA_KEY.value: {}, + }, + "_score": 0.8, + }, + ] + }, + } + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_client_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(Elasticsearch, "__init__", MockClient.__init__) + monkeypatch.setattr(Elasticsearch, "index", MockClient.index) + monkeypatch.setattr(Elasticsearch, "exists", MockClient.exists) + monkeypatch.setattr(Elasticsearch, "delete", MockClient.delete) + monkeypatch.setattr(Elasticsearch, "search", MockClient.search) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/huawei/__init__.py b/api/tests/integration_tests/vdb/huawei/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py b/api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py new file mode 100644 index 0000000000..943b2bc877 --- /dev/null +++ b/api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py @@ -0,0 +1,28 @@ +from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig +from tests.integration_tests.vdb.__mock.huaweicloudvectordb import setup_client_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + + +class HuaweiCloudVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = HuaweiCloudVector( + "dify", + HuaweiCloudVectorConfig( + hosts="https://127.0.0.1:9200", + username="dify", + password="dify", + ), + ) + + def search_by_vector(self): + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 3 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 3 + + +def test_huawei_cloud_vector(setup_mock_redis, setup_client_mock): + HuaweiCloudVectorTest().run_all_tests() diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index c68a94c79b..dd03ca3514 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -15,3 +15,4 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/couchbase \ api/tests/integration_tests/vdb/oceanbase \ api/tests/integration_tests/vdb/tidb_vector \ + api/tests/integration_tests/vdb/huawei \