diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index dc400dafbb..a81e7dedb9 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -8,3 +8,4 @@ class Field(Enum): VECTOR = "vector" TEXT_KEY = "text" PRIMARY_KEY = "id" + BM25_KEY = "bm25_ef" diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 0586e279d3..a958b91199 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -2,8 +2,12 @@ import logging from typing import Any, Optional from uuid import uuid4 +from langdetect import detect, LangDetectException +from milvus_model.sparse import BM25EmbeddingFunction +from milvus_model.sparse.bm25.tokenizers import build_default_analyzer from pydantic import BaseModel, root_validator from pymilvus import MilvusClient, MilvusException, connections +from pymilvus.milvus_client import IndexParams from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector @@ -12,6 +16,19 @@ from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) +LANGUAGE_CODE_MAP = { + 'en': 'en', # English + 'de': 'de', # German + 'fr': 'fr', # French + 'ru': 'ru', # Russian + 'es': 'sp', # Spanish (langdetect returns 'es' for Spanish) + 'it': 'it', # Italian + 'pt': 'pt', # Portuguese + 'zh-cn': 'zh', # Chinese (Simplified) + 'zh-tw': 'zh', # Chinese (Traditional) + 'ja': 'jp', # Japanese + 'ko': 'kr', # Korean +} class MilvusConfig(BaseModel): host: str @@ -70,9 +87,11 @@ class MilvusVector(BaseVector): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): insert_dict_list = [] for i in range(len(documents)): + bm25_ef = self._bm25_document_encode(documents[i].page_content) insert_dict = { Field.CONTENT_KEY.value: documents[i].page_content, Field.VECTOR.value: embeddings[i], + Field.BM25_KEY.value: bm25_ef, Field.METADATA_KEY.value: documents[i].metadata } insert_dict_list.append(insert_dict) @@ -171,7 +190,6 @@ class MilvusVector(BaseVector): result = self._client.query(collection_name=self._collection_name, filter=f'metadata["doc_id"] == "{id}"', output_fields=["id"]) - return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: @@ -180,6 +198,7 @@ class MilvusVector(BaseVector): results = self._client.search(collection_name=self._collection_name, data=[query_vector], limit=kwargs.get('top_k', 4), + anns_field=Field.VECTOR.value, output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], ) # Organize results. @@ -195,7 +214,18 @@ class MilvusVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - # milvus/zilliz doesn't support bm25 search + bm25_ef = self._bm25_query_encode(query) + search_params = { + "metric_type": "IP", + "params": {"drop_ratio_search": 0.2}, # the ratio of small vector values to be dropped during search. + } + results = self._client.search(collection_name=self._collection_name, + data=bm25_ef, + limit=kwargs.get('top_k', 4), + anns_field=Field.BM25_KEY.value, + output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value], + search_params=search_params + ) return [] def create_collection( @@ -229,6 +259,10 @@ class MilvusVector(BaseVector): fields.append( FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) ) + # Create the bm25 field + fields.append( + FieldSchema(Field.BM25_KEY.value, DataType.SPARSE_FLOAT_VECTOR) + ) # Create the primary key field fields.append( FieldSchema( @@ -247,12 +281,23 @@ class MilvusVector(BaseVector): self._fields.append(x.name) # Since primary field is auto-id, no need to track it self._fields.remove(Field.PRIMARY_KEY.value) - + # Create Index params for the collection + index_params_obj = IndexParams() + index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params) + index_params_obj.add_index( + field_name=Field.BM25_KEY.value, + index_name=f'{self._collection_name}_sparse_index', + index_type="SPARSE_INVERTED_INDEX", + # the type of index to be created. set to `SPARSE_INVERTED_INDEX` or `SPARSE_WAND`. + metric_type="IP", + # the metric type to be used for the index. Currently, only `IP` (Inner Product) is supported. + params={"drop_ratio_build": 0.2}, # the ratio of small vector values to be dropped during indexing. + ) # Create the collection collection_name = self._collection_name - self._client.create_collection_with_schema(collection_name=collection_name, - schema=schema, index_param=index_params, - consistency_level=self._consistency_level) + self._client.create_collection(collection_name=collection_name, + schema=schema, index_params=index_params_obj, + consistency_level=self._consistency_level) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _init_client(self, config) -> MilvusClient: if config.secure: @@ -261,3 +306,25 @@ class MilvusVector(BaseVector): uri = "http://" + str(config.host) + ":" + str(config.port) client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database) return client + + def _bm25_document_encode(self, text: str): + language = self._get_text_language(text) + analyzer = build_default_analyzer(language=language) + bm25_ef = BM25EmbeddingFunction(analyzer) + bm25_ef.fit([text]) + docs_embeddings = bm25_ef.encode_documents([text]) + return docs_embeddings + + def _bm25_query_encode(self, text: str): + language = self._get_text_language(text) + analyzer = build_default_analyzer(language=language) + bm25_ef = BM25EmbeddingFunction(analyzer) + docs_embeddings = bm25_ef.encode_queries([text]) + return docs_embeddings + + def _get_text_language(self, text: str) -> str: + try: + detected_language = detect(text) + return LANGUAGE_CODE_MAP.get(detected_language, "en") + except LangDetectException: + return "en" diff --git a/api/requirements.txt b/api/requirements.txt index 39cbfaad99..194809b5d9 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -56,7 +56,7 @@ xinference-client==0.9.4 safetensors~=0.4.3 zhipuai==1.0.7 werkzeug~=3.0.1 -pymilvus==2.3.1 +pymilvus==2.4.1 qdrant-client==1.7.3 cohere~=5.2.4 pyyaml~=6.0.1 diff --git a/docker/docker-compose.milvus.yaml b/docker/docker-compose.milvus.yaml index c422efbf4b..7b4af8f965 100644 --- a/docker/docker-compose.milvus.yaml +++ b/docker/docker-compose.milvus.yaml @@ -38,7 +38,7 @@ services: milvus-standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.3.1 + image: milvusdb/milvus:v2.4.1 command: ["milvus", "run", "standalone"] environment: ETCD_ENDPOINTS: etcd:2379