feat: support huawei cloud vector database
This commit is contained in:
parent
fa275b9fdc
commit
8efd12b199
@ -3,6 +3,9 @@ import logging
|
|||||||
import ssl
|
import ssl
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.vdb.field import Field
|
from core.rag.datasource.vdb.field import Field
|
||||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||||
@ -10,10 +13,8 @@ from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
|||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.embedding.embedding_base import Embeddings
|
from core.rag.embedding.embedding_base import Embeddings
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from elasticsearch import Elasticsearch
|
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from pydantic import BaseModel, model_validator
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -49,10 +50,9 @@ class HuaweiCloudVectorConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class HuaweiCloudVector(BaseVector):
|
class HuaweiCloudVector(BaseVector):
|
||||||
def __init__(self, index_name: str, config: HuaweiCloudVectorConfig, attributes: list):
|
def __init__(self, index_name: str, config: HuaweiCloudVectorConfig):
|
||||||
super().__init__(index_name.lower())
|
super().__init__(index_name.lower())
|
||||||
self._client = Elasticsearch(**config.to_elasticsearch_params())
|
self._client = Elasticsearch(**config.to_elasticsearch_params())
|
||||||
self._attributes = attributes
|
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return VectorType.HUAWEI_CLOUD
|
return VectorType.HUAWEI_CLOUD
|
||||||
@ -103,7 +103,7 @@ class HuaweiCloudVector(BaseVector):
|
|||||||
"topk": top_k,
|
"topk": top_k,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
results = self._client.search(index=self._collection_name, body=query, request_timeout=120)
|
results = self._client.search(index=self._collection_name, body=query, request_timeout=120)
|
||||||
@ -152,10 +152,10 @@ class HuaweiCloudVector(BaseVector):
|
|||||||
self.add_texts(texts, embeddings, **kwargs)
|
self.add_texts(texts, embeddings, **kwargs)
|
||||||
|
|
||||||
def create_collection(
|
def create_collection(
|
||||||
self,
|
self,
|
||||||
embeddings: list[list[float]],
|
embeddings: list[list[float]],
|
||||||
metadatas: Optional[list[dict[Any, Any]]] = None,
|
metadatas: Optional[list[dict[Any, Any]]] = None,
|
||||||
index_params: Optional[dict] = None,
|
index_params: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||||
with redis_client.lock(lock_name, timeout=20):
|
with redis_client.lock(lock_name, timeout=20):
|
||||||
@ -176,7 +176,7 @@ class HuaweiCloudVector(BaseVector):
|
|||||||
"algorithm": "GRAPH",
|
"algorithm": "GRAPH",
|
||||||
"metric": "cosine",
|
"metric": "cosine",
|
||||||
"neighbors": 32,
|
"neighbors": 32,
|
||||||
"efc": 128
|
"efc": 128,
|
||||||
},
|
},
|
||||||
Field.METADATA_KEY.value: {
|
Field.METADATA_KEY.value: {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -186,9 +186,7 @@ class HuaweiCloudVector(BaseVector):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
settings = {
|
settings = {"index.vector": True}
|
||||||
"index.vector": True
|
|
||||||
}
|
|
||||||
self._client.indices.create(index=self._collection_name, mappings=mappings, settings=settings)
|
self._client.indices.create(index=self._collection_name, mappings=mappings, settings=settings)
|
||||||
|
|
||||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||||
@ -211,5 +209,4 @@ class HuaweiCloudVectorFactory(AbstractVectorFactory):
|
|||||||
username=dify_config.HUAWEI_CLOUD_USER,
|
username=dify_config.HUAWEI_CLOUD_USER,
|
||||||
password=dify_config.HUAWEI_CLOUD_PASSWORD,
|
password=dify_config.HUAWEI_CLOUD_PASSWORD,
|
||||||
),
|
),
|
||||||
attributes=[],
|
|
||||||
)
|
)
|
||||||
|
@ -0,0 +1,88 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
|
from api.core.rag.datasource.vdb.field import Field
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
|
||||||
|
|
||||||
|
class MockIndicesClient:
    """In-memory stand-in for the ``indices`` attribute of the Elasticsearch client.

    Every index-management call succeeds immediately without any network
    access: ``create``, ``refresh`` and ``delete`` return a canned
    acknowledgement dict and ``exists`` always reports ``True``.
    """

    @staticmethod
    def _ack():
        """Build a fresh acknowledgement payload so callers never share one dict."""
        return {"acknowledge": True}

    def create(self, index, mappings, settings):
        """Pretend to create ``index``; all arguments are accepted and ignored."""
        return self._ack()

    def refresh(self, index):
        """Pretend to refresh ``index``."""
        return self._ack()

    def delete(self, index):
        """Pretend to delete ``index``."""
        return self._ack()

    def exists(self, index):
        """Report every index as already existing."""
        return True
|
||||||
|
|
||||||
|
|
||||||
|
class MockClient:
    """Canned replacement for the ``Elasticsearch`` client used by the vector-store tests.

    The methods are also patched one by one onto ``Elasticsearch`` by the
    ``setup_client_mock`` fixture, so ``self`` may be an ``Elasticsearch``
    instance rather than a ``MockClient``. For that reason no method relies on
    helpers or state defined on this class.
    """

    def __init__(self, **kwargs):
        """Accept and ignore any connection parameters; only ``indices`` is set up."""
        self.indices = MockIndicesClient()

    def index(self, **kwargs):
        """Pretend to index a document."""
        return {"acknowledge": True}

    def exists(self, **kwargs):
        """Report every document as existing."""
        return True

    def delete(self, **kwargs):
        """Pretend to delete a document."""
        return {"acknowledge": True}

    def search(self, **kwargs):
        """Return a fixed response containing three hits, regardless of the query.

        Each hit carries a ``_source`` with content, vector and (empty)
        metadata fields, plus a ``_score``; the hits are listed in descending
        score order.
        """
        # (content, vector, score) for each canned hit.
        canned_rows = (
            ("abcdef", [1, 2], 1.0),
            ("123456", [2, 2], 0.9),
            ("a1b2c3", [3, 2], 0.8),
        )
        hits = [
            {
                "_source": {
                    Field.CONTENT_KEY.value: content,
                    Field.VECTOR.value: vector,
                    Field.METADATA_KEY.value: {},
                },
                "_score": score,
            }
            for content, vector, score in canned_rows
        ]
        return {
            "took": 1,
            "hits": {
                "hits": hits,
            },
        }
|
||||||
|
|
||||||
|
|
||||||
|
# True only when the MOCK_SWITCH environment variable is set to "true"
# (case-insensitive); defaults to False when the variable is unset.
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def setup_client_mock(request, monkeypatch: MonkeyPatch):
    """Replace the Elasticsearch client with ``MockClient`` behaviour when ``MOCK`` is on.

    When ``MOCK`` is false the fixture is a no-op and the test uses the real
    ``Elasticsearch`` class. ``request`` is accepted but not used.
    """
    if MOCK:
        # The constructor is patched so that no connection is attempted and
        # ``indices`` becomes a MockIndicesClient.
        monkeypatch.setattr(Elasticsearch, "__init__", MockClient.__init__)
        for method_name in ("index", "exists", "delete", "search"):
            monkeypatch.setattr(Elasticsearch, method_name, getattr(MockClient, method_name))

    yield

    # No explicit monkeypatch.undo() here: pytest reverts the patches itself
    # during teardown, and calling undo() from a fixture would also revert
    # patches made by any other user of the same monkeypatch fixture.
|
0
api/tests/integration_tests/vdb/huawei/__init__.py
Normal file
0
api/tests/integration_tests/vdb/huawei/__init__.py
Normal file
28
api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py
Normal file
28
api/tests/integration_tests/vdb/huawei/test_huawei_cloud.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
|
||||||
|
from tests.integration_tests.vdb.__mock.huaweicloudvectordb import setup_client_mock
|
||||||
|
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||||
|
|
||||||
|
|
||||||
|
class HuaweiCloudVectorTest(AbstractVectorTest):
    """Vector-store test case bound to a ``HuaweiCloudVector`` named ``"dify"``."""

    def __init__(self):
        super().__init__()
        # Local endpoint and credentials used by the integration test setup.
        connection = HuaweiCloudVectorConfig(
            hosts="https://127.0.0.1:9200",
            username="dify",
            password="dify",
        )
        self.vector = HuaweiCloudVector("dify", connection)

    def search_by_vector(self):
        """Vector search with the example embedding must yield exactly three hits."""
        matches = self.vector.search_by_vector(query_vector=self.example_embedding)
        assert len(matches) == 3

    def search_by_full_text(self):
        """Full-text search with the example text must yield exactly three hits."""
        matches = self.vector.search_by_full_text(query=get_example_text())
        assert len(matches) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_huawei_cloud_vector(setup_mock_redis, setup_client_mock):
    """Run every step of ``HuaweiCloudVectorTest`` with the redis and client fixtures active."""
    suite = HuaweiCloudVectorTest()
    suite.run_all_tests()
|
@ -15,3 +15,4 @@ pytest api/tests/integration_tests/vdb/chroma \
|
|||||||
api/tests/integration_tests/vdb/couchbase \
|
api/tests/integration_tests/vdb/couchbase \
|
||||||
api/tests/integration_tests/vdb/oceanbase \
|
api/tests/integration_tests/vdb/oceanbase \
|
||||||
api/tests/integration_tests/vdb/tidb_vector \
|
api/tests/integration_tests/vdb/tidb_vector \
|
||||||
|
api/tests/integration_tests/vdb/huawei \
|
||||||
|
Loading…
Reference in New Issue
Block a user