diff --git a/api/tests/integration_tests/vdb/__mock/tcvectordb.py b/api/tests/integration_tests/vdb/__mock/tcvectordb.py index 68a1e290ad..c1d53296fa 100644 --- a/api/tests/integration_tests/vdb/__mock/tcvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/tcvectordb.py @@ -1,58 +1,66 @@ import os -from typing import Optional +from typing import Optional, List, Dict, Union import pytest from _pytest.monkeypatch import MonkeyPatch from requests.adapters import HTTPAdapter -from tcvectordb import VectorDBClient # type: ignore +from tcvectordb import RPCVectorDBClient # type: ignore +from tcvectordb.model.collection import FilterIndexConfig from tcvectordb.model.database import Collection, Database # type: ignore from tcvectordb.model.document import Document, Filter # type: ignore from tcvectordb.model.enum import ReadConsistency # type: ignore -from tcvectordb.model.index import Index # type: ignore +from tcvectordb.model.index import Index, IndexField # type: ignore +from tcvectordb.rpc.model.collection import RPCCollection +from tcvectordb.rpc.model.database import RPCDatabase from xinference_client.types import Embedding # type: ignore class MockTcvectordbClass: def mock_vector_db_client( self, - url=None, - username="", - key="", + url: str, + username='', + key='', read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, - timeout=5, + timeout=10, adapter: HTTPAdapter = None, + pool_size: int = 2, + proxies: Optional[dict] = None, + password: Optional[str] = None, + **kwargs ): self._conn = None self._read_consistency = read_consistency - def list_databases(self) -> list[Database]: - return [ - Database( - conn=self._conn, - read_consistency=self._read_consistency, - name="dify", - ) - ] + def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase: + return RPCDatabase( + name="dify", + read_consistency=self._read_consistency, + ) - def list_collections(self, timeout: Optional[float] = None) -> list[Collection]: - return [] - - def drop_collection(self, name: str, timeout: Optional[float] = None): - return {"code": 0, "msg": "operation success"} + def exists_collection(self, database_name: str, collection_name: str) -> bool: + return True def create_collection( self, - name: str, + database_name: str, + collection_name: str, shard: int, replicas: int, - description: str, - index: Index, + description: str = None, + index: Index = None, embedding: Embedding = None, - timeout: Optional[float] = None, - ) -> Collection: - return Collection( - self, - name, + timeout: float = None, + ttl_config: dict = None, + filter_index_config: FilterIndexConfig = None, + indexes: List[IndexField] = None, + ) -> RPCCollection: + return RPCCollection( + RPCDatabase( + name="dify", + read_consistency=self._read_consistency, + ), + collection_name, shard, replicas, description, @@ -60,19 +68,26 @@ class MockTcvectordbClass: embedding=embedding, read_consistency=self._read_consistency, timeout=timeout, + ttl_config=ttl_config, + filter_index_config=filter_index_config, + indexes=indexes, ) - def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection: - collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout) - return collection - def collection_upsert( - self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs + self, + database_name: str, + collection_name: str, + documents: List[Union[Document, Dict]], + timeout: Optional[float] = None, + build_index: bool = True, + **kwargs ): return {"code": 0, "msg": "operation success"} def collection_search( self, + database_name: str, + collection_name: str, vectors: list[list[float]], filter: Filter = None, params=None, @@ -81,10 +96,12 @@ class MockTcvectordbClass: output_fields: Optional[list[str]] = None, timeout: Optional[float] = None, ) -> list[list[dict]]: - return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]] + return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]] def collection_query( self, + database_name: str, + collection_name: str, document_ids: Optional[list] = None, retrieve_vector: bool = False, limit: Optional[int] = None, @@ -97,12 +114,20 @@ class MockTcvectordbClass: def collection_delete( self, + database_name: str, + collection_name: str, document_ids: Optional[list[str]] = None, filter: Filter = None, timeout: Optional[float] = None, ): return {"code": 0, "msg": "operation success"} + def drop_collection(self, + database_name: str, + collection_name: str, + timeout: Optional[float] = None) -> Dict: + return {"code": 0, "msg": "operation success"} + MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @@ -110,16 +135,17 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @pytest.fixture def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): if MOCK: - monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) - monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) - monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) - monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) - monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection) - monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection) - monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert) - monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search) - monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query) - monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete) + monkeypatch.setattr(RPCVectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) + monkeypatch.setattr(RPCVectorDBClient, + "create_database_if_not_exists", + MockTcvectordbClass.create_database_if_not_exists) + monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection) + monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection) + monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert) + monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search) + monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query) + monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete) + monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection) yield