|
|
@ -2,8 +2,12 @@ import logging
|
|
|
|
from typing import Any, Optional
|
|
|
|
from typing import Any, Optional
|
|
|
|
from uuid import uuid4
|
|
|
|
from uuid import uuid4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langdetect import detect, LangDetectException
|
|
|
|
|
|
|
|
from milvus_model.sparse import BM25EmbeddingFunction
|
|
|
|
|
|
|
|
from milvus_model.sparse.bm25.tokenizers import build_default_analyzer
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
from pymilvus import MilvusClient, MilvusException, connections
|
|
|
|
from pymilvus import MilvusClient, MilvusException, connections
|
|
|
|
|
|
|
|
from pymilvus.milvus_client import IndexParams
|
|
|
|
|
|
|
|
|
|
|
|
from core.rag.datasource.vdb.field import Field
|
|
|
|
from core.rag.datasource.vdb.field import Field
|
|
|
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
|
|
|
from core.rag.datasource.vdb.vector_base import BaseVector
|
|
|
@ -12,6 +16,19 @@ from extensions.ext_redis import redis_client
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lookup table keyed by the language code that langdetect's ``detect`` reports;
# each value is the language name handed to ``build_default_analyzer``.
# Codes that are absent here fall back to English at the lookup site.
LANGUAGE_CODE_MAP = {
    # Codes that are passed through unchanged.
    "en": "en",  # English
    "de": "de",  # German
    "fr": "fr",  # French
    "ru": "ru",  # Russian
    # Codes that are renamed for the analyzer.
    "es": "sp",  # Spanish (langdetect reports 'es')
    "it": "it",  # Italian
    "pt": "pt",  # Portuguese
    # Both Chinese variants share a single analyzer language.
    "zh-cn": "zh",  # Chinese (Simplified)
    "zh-tw": "zh",  # Chinese (Traditional)
    "ja": "jp",  # Japanese
    "ko": "kr",  # Korean
}
|
|
|
|
|
|
|
|
|
|
|
|
class MilvusConfig(BaseModel):
|
|
|
|
class MilvusConfig(BaseModel):
|
|
|
|
host: str
|
|
|
|
host: str
|
|
|
@ -70,9 +87,11 @@ class MilvusVector(BaseVector):
|
|
|
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
|
|
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
|
|
|
insert_dict_list = []
|
|
|
|
insert_dict_list = []
|
|
|
|
for i in range(len(documents)):
|
|
|
|
for i in range(len(documents)):
|
|
|
|
|
|
|
|
bm25_ef = self._bm25_document_encode(documents[i].page_content)
|
|
|
|
insert_dict = {
|
|
|
|
insert_dict = {
|
|
|
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
|
|
|
Field.CONTENT_KEY.value: documents[i].page_content,
|
|
|
|
Field.VECTOR.value: embeddings[i],
|
|
|
|
Field.VECTOR.value: embeddings[i],
|
|
|
|
|
|
|
|
Field.BM25_KEY.value: bm25_ef,
|
|
|
|
Field.METADATA_KEY.value: documents[i].metadata
|
|
|
|
Field.METADATA_KEY.value: documents[i].metadata
|
|
|
|
}
|
|
|
|
}
|
|
|
|
insert_dict_list.append(insert_dict)
|
|
|
|
insert_dict_list.append(insert_dict)
|
|
|
@ -171,7 +190,6 @@ class MilvusVector(BaseVector):
|
|
|
|
result = self._client.query(collection_name=self._collection_name,
|
|
|
|
result = self._client.query(collection_name=self._collection_name,
|
|
|
|
filter=f'metadata["doc_id"] == "{id}"',
|
|
|
|
filter=f'metadata["doc_id"] == "{id}"',
|
|
|
|
output_fields=["id"])
|
|
|
|
output_fields=["id"])
|
|
|
|
|
|
|
|
|
|
|
|
return len(result) > 0
|
|
|
|
return len(result) > 0
|
|
|
|
|
|
|
|
|
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
|
|
@ -180,6 +198,7 @@ class MilvusVector(BaseVector):
|
|
|
|
results = self._client.search(collection_name=self._collection_name,
|
|
|
|
results = self._client.search(collection_name=self._collection_name,
|
|
|
|
data=[query_vector],
|
|
|
|
data=[query_vector],
|
|
|
|
limit=kwargs.get('top_k', 4),
|
|
|
|
limit=kwargs.get('top_k', 4),
|
|
|
|
|
|
|
|
anns_field=Field.VECTOR.value,
|
|
|
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
|
|
|
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
|
|
|
|
)
|
|
|
|
)
|
|
|
|
# Organize results.
|
|
|
|
# Organize results.
|
|
|
@ -195,7 +214,18 @@ class MilvusVector(BaseVector):
|
|
|
|
return docs
|
|
|
|
return docs
|
|
|
|
|
|
|
|
|
|
|
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Run a BM25 sparse-vector search against the collection.

    The query text is encoded with ``_bm25_query_encode`` and searched
    against the ``Field.BM25_KEY`` sparse field using the inner-product
    metric.

    :param query: raw query text.
    :param kwargs: ``top_k`` caps the number of hits (defaults to 4).
    :return: one ``Document`` per hit, in the order the client returned
        them, with the hit distance stored under ``metadata['score']``;
        an empty list when the search yields nothing.
    """
    query_embedding = self._bm25_query_encode(query)
    search_params = {
        "metric_type": "IP",
        # The ratio of small vector values to be dropped during search.
        "params": {"drop_ratio_search": 0.2},
    }
    results = self._client.search(
        collection_name=self._collection_name,
        data=query_embedding,
        limit=kwargs.get('top_k', 4),
        anns_field=Field.BM25_KEY.value,
        output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
        search_params=search_params,
    )
    # Previously the hits were computed and then thrown away by an
    # unconditional ``return []``; convert them into documents instead.
    if not results:
        return []

    docs = []
    # A single query vector was sent, so only the first result set matters.
    for hit in results[0]:
        entity = hit['entity']
        # Copy so the score annotation never mutates the client's payload.
        metadata = dict(entity.get(Field.METADATA_KEY.value) or {})
        metadata['score'] = hit['distance']
        # NOTE(review): assumes Document accepts page_content/metadata as
        # keyword arguments (both are read as attributes elsewhere in this
        # file) -- confirm against the Document model.
        docs.append(
            Document(
                page_content=entity.get(Field.CONTENT_KEY.value),
                metadata=metadata,
            )
        )
    return docs
|
|
|
|
|
|
|
|
|
|
|
|
def create_collection(
|
|
|
|
def create_collection(
|
|
|
@ -229,6 +259,10 @@ class MilvusVector(BaseVector):
|
|
|
|
fields.append(
|
|
|
|
fields.append(
|
|
|
|
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
|
|
|
|
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
# Create the bm25 field
|
|
|
|
|
|
|
|
fields.append(
|
|
|
|
|
|
|
|
FieldSchema(Field.BM25_KEY.value, DataType.SPARSE_FLOAT_VECTOR)
|
|
|
|
|
|
|
|
)
|
|
|
|
# Create the primary key field
|
|
|
|
# Create the primary key field
|
|
|
|
fields.append(
|
|
|
|
fields.append(
|
|
|
|
FieldSchema(
|
|
|
|
FieldSchema(
|
|
|
@ -247,12 +281,23 @@ class MilvusVector(BaseVector):
|
|
|
|
self._fields.append(x.name)
|
|
|
|
self._fields.append(x.name)
|
|
|
|
# Since primary field is auto-id, no need to track it
|
|
|
|
# Since primary field is auto-id, no need to track it
|
|
|
|
self._fields.remove(Field.PRIMARY_KEY.value)
|
|
|
|
self._fields.remove(Field.PRIMARY_KEY.value)
|
|
|
|
|
|
|
|
# Create Index params for the collection
|
|
|
|
|
|
|
|
index_params_obj = IndexParams()
|
|
|
|
|
|
|
|
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)
|
|
|
|
|
|
|
|
index_params_obj.add_index(
|
|
|
|
|
|
|
|
field_name=Field.BM25_KEY.value,
|
|
|
|
|
|
|
|
index_name=f'{self._collection_name}_sparse_index',
|
|
|
|
|
|
|
|
index_type="SPARSE_INVERTED_INDEX",
|
|
|
|
|
|
|
|
# the type of index to be created. set to `SPARSE_INVERTED_INDEX` or `SPARSE_WAND`.
|
|
|
|
|
|
|
|
metric_type="IP",
|
|
|
|
|
|
|
|
# the metric type to be used for the index. Currently, only `IP` (Inner Product) is supported.
|
|
|
|
|
|
|
|
params={"drop_ratio_build": 0.2}, # the ratio of small vector values to be dropped during indexing.
|
|
|
|
|
|
|
|
)
|
|
|
|
# Create the collection
|
|
|
|
# Create the collection
|
|
|
|
collection_name = self._collection_name
|
|
|
|
collection_name = self._collection_name
|
|
|
|
self._client.create_collection_with_schema(collection_name=collection_name,
|
|
|
|
self._client.create_collection(collection_name=collection_name,
|
|
|
|
schema=schema, index_param=index_params,
|
|
|
|
schema=schema, index_params=index_params_obj,
|
|
|
|
consistency_level=self._consistency_level)
|
|
|
|
consistency_level=self._consistency_level)
|
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
|
|
|
def _init_client(self, config) -> MilvusClient:
|
|
|
|
def _init_client(self, config) -> MilvusClient:
|
|
|
|
if config.secure:
|
|
|
|
if config.secure:
|
|
|
@ -261,3 +306,25 @@ class MilvusVector(BaseVector):
|
|
|
|
uri = "http://" + str(config.host) + ":" + str(config.port)
|
|
|
|
uri = "http://" + str(config.host) + ":" + str(config.port)
|
|
|
|
client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database)
|
|
|
|
client = MilvusClient(uri=uri, user=config.user, password=config.password,db_name=config.database)
|
|
|
|
return client
|
|
|
|
return client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _bm25_document_encode(self, text: str):
    """Encode a single document as a BM25 embedding.

    An analyzer is built for the language detected in ``text``; a fresh
    ``BM25EmbeddingFunction`` is fitted on that one text and then used to
    encode it.

    :param text: document content to encode.
    :return: the result of ``encode_documents`` for the one-element
        corpus ``[text]``.
    """
    # NOTE(review): the model is fitted on this text alone, so its corpus
    # statistics come from a single document -- confirm this is intended.
    encoder = BM25EmbeddingFunction(
        build_default_analyzer(language=self._get_text_language(text))
    )
    corpus = [text]
    encoder.fit(corpus)
    return encoder.encode_documents(corpus)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _bm25_query_encode(self, text: str):
    """Encode a query string as a BM25 embedding.

    An analyzer is built for the language detected in ``text`` and wrapped
    in a fresh ``BM25EmbeddingFunction`` that encodes the query.

    :param text: query text to encode.
    :return: the result of ``encode_queries`` for the one-element list
        ``[text]``.
    """
    # NOTE(review): unlike the document encoder, this model is never
    # fitted before encoding -- verify that query encoding works without
    # corpus statistics.
    encoder = BM25EmbeddingFunction(
        build_default_analyzer(language=self._get_text_language(text))
    )
    return encoder.encode_queries([text])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_text_language(self, text: str) -> str:
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
detected_language = detect(text)
|
|
|
|
|
|
|
|
return LANGUAGE_CODE_MAP.get(detected_language, "en")
|
|
|
|
|
|
|
|
except LangDetectException:
|
|
|
|
|
|
|
|
return "en"
|
|
|
|