Compare commits

...

1 Commits

Author SHA1 Message Date
jyong
51f5796908 milvus 2.4 2024-05-21 21:47:27 +08:00
4 changed files with 76 additions and 8 deletions

View File

@ -8,3 +8,4 @@ class Field(Enum):
VECTOR = "vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
BM25_KEY = "bm25_ef"

View File

@ -2,8 +2,12 @@ import logging
from typing import Any, Optional
from uuid import uuid4
from langdetect import detect, LangDetectException
from milvus_model.sparse import BM25EmbeddingFunction
from milvus_model.sparse.bm25.tokenizers import build_default_analyzer
from pydantic import BaseModel, root_validator
from pymilvus import MilvusClient, MilvusException, connections
from pymilvus.milvus_client import IndexParams
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
@ -12,6 +16,19 @@ from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
LANGUAGE_CODE_MAP = {
'en': 'en', # English
'de': 'de', # German
'fr': 'fr', # French
'ru': 'ru', # Russian
'es': 'sp', # Spanish (langdetect returns 'es' for Spanish)
'it': 'it', # Italian
'pt': 'pt', # Portuguese
'zh-cn': 'zh', # Chinese (Simplified)
'zh-tw': 'zh', # Chinese (Traditional)
'ja': 'jp', # Japanese
'ko': 'kr', # Korean
}
class MilvusConfig(BaseModel):
host: str
@ -70,9 +87,11 @@ class MilvusVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
insert_dict_list = []
for i in range(len(documents)):
bm25_ef = self._bm25_document_encode(documents[i].page_content)
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.BM25_KEY.value: bm25_ef,
Field.METADATA_KEY.value: documents[i].metadata
}
insert_dict_list.append(insert_dict)
@ -171,7 +190,6 @@ class MilvusVector(BaseVector):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
@ -180,6 +198,7 @@ class MilvusVector(BaseVector):
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
anns_field=Field.VECTOR.value,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
@ -195,7 +214,18 @@ class MilvusVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
bm25_ef = self._bm25_query_encode(query)
search_params = {
"metric_type": "IP",
"params": {"drop_ratio_search": 0.2}, # the ratio of small vector values to be dropped during search.
}
results = self._client.search(collection_name=self._collection_name,
data=bm25_ef,
limit=kwargs.get('top_k', 4),
anns_field=Field.BM25_KEY.value,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
search_params=search_params
)
return []
def create_collection(
@ -229,6 +259,10 @@ class MilvusVector(BaseVector):
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
# Create the bm25 field
fields.append(
FieldSchema(Field.BM25_KEY.value, DataType.SPARSE_FLOAT_VECTOR)
)
# Create the primary key field
fields.append(
FieldSchema(
@ -247,12 +281,23 @@ class MilvusVector(BaseVector):
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
self._fields.remove(Field.PRIMARY_KEY.value)
# Create Index params for the collection
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
index_params_obj.add_index(
field_name=Field.BM25_KEY.value,
index_name=f'{self._collection_name}_sparse_index',
index_type="SPARSE_INVERTED_INDEX",
# the type of index to be created. set to `SPARSE_INVERTED_INDEX` or `SPARSE_WAND`.
metric_type="IP",
# the metric type to be used for the index. Currently, only `IP` (Inner Product) is supported.
params={"drop_ratio_build": 0.2}, # the ratio of small vector values to be dropped during indexing.
)
# Create the collection
collection_name = self._collection_name
self._client.create_collection_with_schema(collection_name=collection_name,
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
self._client.create_collection(collection_name=collection_name,
schema=schema, index_params=index_params_obj,
consistency_level=self._consistency_level)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
if config.secure:
@ -261,3 +306,25 @@ class MilvusVector(BaseVector):
uri = "http://" + str(config.host) + ":" + str(config.port)
client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database)
return client
def _bm25_document_encode(self, text: str):
language = self._get_text_language(text)
analyzer = build_default_analyzer(language=language)
bm25_ef = BM25EmbeddingFunction(analyzer)
bm25_ef.fit([text])
docs_embeddings = bm25_ef.encode_documents([text])
return docs_embeddings
def _bm25_query_encode(self, text: str):
language = self._get_text_language(text)
analyzer = build_default_analyzer(language=language)
bm25_ef = BM25EmbeddingFunction(analyzer)
docs_embeddings = bm25_ef.encode_queries([text])
return docs_embeddings
def _get_text_language(self, text: str) -> str:
try:
detected_language = detect(text)
return LANGUAGE_CODE_MAP.get(detected_language, "en")
except LangDetectException:
return "en"

View File

@ -56,7 +56,7 @@ xinference-client==0.9.4
safetensors~=0.4.3
zhipuai==1.0.7
werkzeug~=3.0.1
pymilvus==2.3.1
pymilvus==2.4.1
qdrant-client==1.7.3
cohere~=5.2.4
pyyaml~=6.0.1

View File

@ -38,7 +38,7 @@ services:
milvus-standalone:
container_name: milvus-standalone
image: milvusdb/milvus:v2.3.1
image: milvusdb/milvus:v2.4.1
command: ["milvus", "run", "standalone"]
environment:
ETCD_ENDPOINTS: etcd:2379