milvus 2.4
This commit is contained in:
parent
c255a20d7c
commit
51f5796908
@ -8,3 +8,4 @@ class Field(Enum):
|
||||
VECTOR = "vector"
|
||||
TEXT_KEY = "text"
|
||||
PRIMARY_KEY = "id"
|
||||
BM25_KEY = "bm25_ef"
|
||||
|
@ -2,8 +2,12 @@ import logging
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from langdetect import detect, LangDetectException
|
||||
from milvus_model.sparse import BM25EmbeddingFunction
|
||||
from milvus_model.sparse.bm25.tokenizers import build_default_analyzer
|
||||
from pydantic import BaseModel, root_validator
|
||||
from pymilvus import MilvusClient, MilvusException, connections
|
||||
from pymilvus.milvus_client import IndexParams
|
||||
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
@ -12,6 +16,19 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)

# Translation table from the language codes produced by langdetect (keys) to
# the language identifiers passed to build_default_analyzer (values).
# Codes that are absent here are handled by the caller's fallback.
LANGUAGE_CODE_MAP = {
    "en": "en",      # English
    "de": "de",      # German
    "fr": "fr",      # French
    "ru": "ru",      # Russian
    "es": "sp",      # Spanish: detector says "es", analyzer is named "sp"
    "it": "it",      # Italian
    "pt": "pt",      # Portuguese
    "zh-cn": "zh",   # Chinese (Simplified)
    "zh-tw": "zh",   # Chinese (Traditional)
    "ja": "jp",      # Japanese
    "ko": "kr",      # Korean
}
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
host: str
|
||||
@ -70,9 +87,11 @@ class MilvusVector(BaseVector):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
insert_dict_list = []
|
||||
for i in range(len(documents)):
|
||||
bm25_ef = self._bm25_document_encode(documents[i].page_content)
|
||||
insert_dict = {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i],
|
||||
Field.BM25_KEY.value: bm25_ef,
|
||||
Field.METADATA_KEY.value: documents[i].metadata
|
||||
}
|
||||
insert_dict_list.append(insert_dict)
|
||||
@ -171,7 +190,6 @@ class MilvusVector(BaseVector):
|
||||
result = self._client.query(collection_name=self._collection_name,
|
||||
filter=f'metadata["doc_id"] == "{id}"',
|
||||
output_fields=["id"])
|
||||
|
||||
return len(result) > 0
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
@ -180,6 +198,7 @@ class MilvusVector(BaseVector):
|
||||
results = self._client.search(collection_name=self._collection_name,
|
||||
data=[query_vector],
|
||||
limit=kwargs.get('top_k', 4),
|
||||
anns_field=Field.VECTOR.value,
|
||||
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
||||
)
|
||||
# Organize results.
|
||||
@ -195,7 +214,18 @@ class MilvusVector(BaseVector):
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Search the collection's BM25 sparse-vector field for ``query``.

    The query text is turned into a sparse vector by
    ``_bm25_query_encode`` and searched against ``Field.BM25_KEY`` using
    the inner-product metric.

    :param query: raw query text.
    :param kwargs: ``top_k`` (default 4) caps the number of hits requested.
    :return: one ``Document`` per hit, in the order Milvus returned them;
        an empty list when the search yields nothing.
    """
    sparse_query = self._bm25_query_encode(query)
    search_params = {
        "metric_type": "IP",
        # Ratio of the smallest vector values ignored during search.
        "params": {"drop_ratio_search": 0.2},
    }
    results = self._client.search(
        collection_name=self._collection_name,
        data=sparse_query,
        limit=kwargs.get('top_k', 4),
        anns_field=Field.BM25_KEY.value,
        output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
        search_params=search_params,
    )
    if not results:
        return []

    # Previously the hits were discarded and [] was always returned.
    # NOTE(review): assumes MilvusClient.search returns one hit list per
    # query vector, each hit being a dict with 'entity' and 'distance',
    # and that Document accepts page_content/metadata keywords - confirm.
    docs = []
    for hit in results[0]:
        entity = hit['entity']
        # Copy so the score annotation never mutates the client's payload.
        metadata = dict(entity.get(Field.METADATA_KEY.value) or {})
        metadata['score'] = hit['distance']
        docs.append(
            Document(
                page_content=entity.get(Field.CONTENT_KEY.value),
                metadata=metadata,
            )
        )
    return docs
|
||||
|
||||
def create_collection(
|
||||
@ -229,6 +259,10 @@ class MilvusVector(BaseVector):
|
||||
fields.append(
|
||||
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
|
||||
)
|
||||
# Create the bm25 field
|
||||
fields.append(
|
||||
FieldSchema(Field.BM25_KEY.value, DataType.SPARSE_FLOAT_VECTOR)
|
||||
)
|
||||
# Create the primary key field
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
@ -247,11 +281,22 @@ class MilvusVector(BaseVector):
|
||||
self._fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
self._fields.remove(Field.PRIMARY_KEY.value)
|
||||
|
||||
# Create Index params for the collection
|
||||
index_params_obj = IndexParams()
|
||||
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
||||
index_params_obj.add_index(
|
||||
field_name=Field.BM25_KEY.value,
|
||||
index_name=f'{self._collection_name}_sparse_index',
|
||||
index_type="SPARSE_INVERTED_INDEX",
|
||||
# the type of index to be created. set to `SPARSE_INVERTED_INDEX` or `SPARSE_WAND`.
|
||||
metric_type="IP",
|
||||
# the metric type to be used for the index. Currently, only `IP` (Inner Product) is supported.
|
||||
params={"drop_ratio_build": 0.2}, # the ratio of small vector values to be dropped during indexing.
|
||||
)
|
||||
# Create the collection
|
||||
collection_name = self._collection_name
|
||||
self._client.create_collection_with_schema(collection_name=collection_name,
|
||||
schema=schema, index_param=index_params,
|
||||
self._client.create_collection(collection_name=collection_name,
|
||||
schema=schema, index_params=index_params_obj,
|
||||
consistency_level=self._consistency_level)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
def _init_client(self, config) -> MilvusClient:
|
||||
@ -261,3 +306,25 @@ class MilvusVector(BaseVector):
|
||||
uri = "http://" + str(config.host) + ":" + str(config.port)
|
||||
client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database)
|
||||
return client
|
||||
|
||||
def _bm25_document_encode(self, text: str):
    """Encode one document's text with a freshly built BM25 embedding function.

    The analyzer language is chosen by ``_get_text_language``; the encoder
    is fitted on this single text and then used to encode it.

    NOTE(review): the encoder is fitted on ``[text]`` alone, so its corpus
    statistics come from this one document - confirm that is intended.

    :param text: document content to encode.
    :return: whatever ``BM25EmbeddingFunction.encode_documents`` yields for
        a one-element input.
    """
    corpus = [text]
    encoder = BM25EmbeddingFunction(
        build_default_analyzer(language=self._get_text_language(text))
    )
    encoder.fit(corpus)
    return encoder.encode_documents(corpus)
|
||||
|
||||
def _bm25_query_encode(self, text: str):
    """Encode a query string with a freshly built BM25 embedding function.

    The analyzer language is chosen by ``_get_text_language``.

    NOTE(review): unlike the document path, the encoder is not fitted
    before ``encode_queries`` is called - confirm this works as intended.

    :param text: query text to encode.
    :return: whatever ``BM25EmbeddingFunction.encode_queries`` yields for a
        one-element input.
    """
    query_language = self._get_text_language(text)
    encoder = BM25EmbeddingFunction(
        build_default_analyzer(language=query_language)
    )
    return encoder.encode_queries([text])
|
||||
|
||||
def _get_text_language(self, text: str) -> str:
    """Return the analyzer language code for ``text``.

    Runs language detection and maps the detected code through
    ``LANGUAGE_CODE_MAP``. Falls back to ``"en"`` when detection raises
    ``LangDetectException`` or the detected code has no mapping.
    """
    try:
        detected = detect(text)
    except LangDetectException:
        return "en"
    else:
        return LANGUAGE_CODE_MAP.get(detected, "en")
|
||||
|
@ -56,7 +56,7 @@ xinference-client==0.9.4
|
||||
safetensors~=0.4.3
|
||||
zhipuai==1.0.7
|
||||
werkzeug~=3.0.1
|
||||
pymilvus==2.3.1
|
||||
pymilvus==2.4.1
|
||||
qdrant-client==1.7.3
|
||||
cohere~=5.2.4
|
||||
pyyaml~=6.0.1
|
||||
|
@ -38,7 +38,7 @@ services:
|
||||
|
||||
milvus-standalone:
|
||||
container_name: milvus-standalone
|
||||
image: milvusdb/milvus:v2.3.1
|
||||
image: milvusdb/milvus:v2.4.1
|
||||
command: ["milvus", "run", "standalone"]
|
||||
environment:
|
||||
ETCD_ENDPOINTS: etcd:2379
|
||||
|
Loading…
Reference in New Issue
Block a user