mock tcvectordb's RPCVectorDBClient

This commit is contained in:
wlleiiwang 2025-03-21 14:03:54 +08:00
parent 3440d6cde7
commit 4e2e8b8476

View File

@@ -1,58 +1,66 @@
import os import os
from typing import Optional from typing import Optional, List, Dict, Union
import pytest import pytest
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from tcvectordb import VectorDBClient # type: ignore from tcvectordb import RPCVectorDBClient # type: ignore
from tcvectordb.model.collection import FilterIndexConfig
from tcvectordb.model.database import Collection, Database # type: ignore from tcvectordb.model.database import Collection, Database # type: ignore
from tcvectordb.model.document import Document, Filter # type: ignore from tcvectordb.model.document import Document, Filter # type: ignore
from tcvectordb.model.enum import ReadConsistency # type: ignore from tcvectordb.model.enum import ReadConsistency # type: ignore
from tcvectordb.model.index import Index # type: ignore from tcvectordb.model.index import Index, IndexField # type: ignore
from tcvectordb.rpc.model.collection import RPCCollection
from tcvectordb.rpc.model.database import RPCDatabase
from xinference_client.types import Embedding # type: ignore from xinference_client.types import Embedding # type: ignore
class MockTcvectordbClass: class MockTcvectordbClass:
def mock_vector_db_client( def mock_vector_db_client(
self, self,
url=None, url: str,
username="", username='',
key="", key='',
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY, read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
timeout=5, timeout=10,
adapter: HTTPAdapter = None, adapter: HTTPAdapter = None,
pool_size: int = 2,
proxies: Optional[dict] = None,
password: Optional[str] = None,
**kwargs
): ):
self._conn = None self._conn = None
self._read_consistency = read_consistency self._read_consistency = read_consistency
def list_databases(self) -> list[Database]: def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
return [ return RPCDatabase(
Database( name="dify",
conn=self._conn, read_consistency=self._read_consistency,
read_consistency=self._read_consistency, )
name="dify",
)
]
def list_collections(self, timeout: Optional[float] = None) -> list[Collection]: def exists_collection(self, database_name: str, collection_name: str) -> bool:
return [] return True
def drop_collection(self, name: str, timeout: Optional[float] = None):
return {"code": 0, "msg": "operation success"}
def create_collection( def create_collection(
self, self,
name: str, database_name: str,
collection_name: str,
shard: int, shard: int,
replicas: int, replicas: int,
description: str, description: str = None,
index: Index, index: Index = None,
embedding: Embedding = None, embedding: Embedding = None,
timeout: Optional[float] = None, timeout: float = None,
) -> Collection: ttl_config: dict = None,
return Collection( filter_index_config: FilterIndexConfig = None,
self, indexes: List[IndexField] = None,
name, ) -> RPCCollection:
return RPCCollection(
RPCDatabase(
name="dify",
read_consistency=self._read_consistency,
),
collection_name,
shard, shard,
replicas, replicas,
description, description,
@@ -60,19 +68,26 @@ class MockTcvectordbClass:
embedding=embedding, embedding=embedding,
read_consistency=self._read_consistency, read_consistency=self._read_consistency,
timeout=timeout, timeout=timeout,
ttl_config=ttl_config,
filter_index_config=filter_index_config,
indexes=indexes,
) )
def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout)
return collection
def collection_upsert( def collection_upsert(
self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs self,
database_name: str,
collection_name: str,
documents: List[Union[Document, Dict]],
timeout: Optional[float] = None,
build_index: bool = True,
**kwargs
): ):
return {"code": 0, "msg": "operation success"} return {"code": 0, "msg": "operation success"}
def collection_search( def collection_search(
self, self,
database_name: str,
collection_name: str,
vectors: list[list[float]], vectors: list[list[float]],
filter: Filter = None, filter: Filter = None,
params=None, params=None,
@@ -81,10 +96,12 @@ class MockTcvectordbClass:
output_fields: Optional[list[str]] = None, output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
) -> list[list[dict]]: ) -> list[list[dict]]:
return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]] return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
def collection_query( def collection_query(
self, self,
database_name: str,
collection_name: str,
document_ids: Optional[list] = None, document_ids: Optional[list] = None,
retrieve_vector: bool = False, retrieve_vector: bool = False,
limit: Optional[int] = None, limit: Optional[int] = None,
@@ -97,12 +114,20 @@ class MockTcvectordbClass:
def collection_delete( def collection_delete(
self, self,
database_name: str,
collection_name: str,
document_ids: Optional[list[str]] = None, document_ids: Optional[list[str]] = None,
filter: Filter = None, filter: Filter = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
): ):
return {"code": 0, "msg": "operation success"} return {"code": 0, "msg": "operation success"}
def drop_collection(self,
database_name: str,
collection_name: str,
timeout: Optional[float] = None) -> Dict:
return {"code": 0, "msg": "operation success"}
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@@ -110,16 +135,17 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture @pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
if MOCK: if MOCK:
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client) monkeypatch.setattr(RPCVectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) monkeypatch.setattr(RPCVectorDBClient,
monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) "create_database_if_not_exists",
monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) MockTcvectordbClass.create_database_if_not_exists)
monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection) monkeypatch.setattr(RPCVectorDBClient, "exists_collection", MockTcvectordbClass.exists_collection)
monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection) monkeypatch.setattr(RPCVectorDBClient, "create_collection", MockTcvectordbClass.create_collection)
monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert) monkeypatch.setattr(RPCVectorDBClient, "upsert", MockTcvectordbClass.collection_upsert)
monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search) monkeypatch.setattr(RPCVectorDBClient, "search", MockTcvectordbClass.collection_search)
monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query) monkeypatch.setattr(RPCVectorDBClient, "query", MockTcvectordbClass.collection_query)
monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete) monkeypatch.setattr(RPCVectorDBClient, "delete", MockTcvectordbClass.collection_delete)
monkeypatch.setattr(RPCVectorDBClient, "drop_collection", MockTcvectordbClass.drop_collection)
yield yield