Feat/add milvus vector db (#1302)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
parent
875dfbbf0e
commit
07aab5e868
@ -63,6 +63,13 @@ WEAVIATE_BATCH_SIZE=100
|
|||||||
QDRANT_URL=http://localhost:6333
|
QDRANT_URL=http://localhost:6333
|
||||||
QDRANT_API_KEY=difyai123456
|
QDRANT_API_KEY=difyai123456
|
||||||
|
|
||||||
|
# Milvus configuration
|
||||||
|
MILVUS_HOST=127.0.0.1
|
||||||
|
MILVUS_PORT=19530
|
||||||
|
MILVUS_USER=root
|
||||||
|
MILVUS_PASSWORD=Milvus
|
||||||
|
MILVUS_SECURE=false
|
||||||
|
|
||||||
# Mail configuration, support: resend
|
# Mail configuration, support: resend
|
||||||
MAIL_TYPE=
|
MAIL_TYPE=
|
||||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
||||||
|
@ -135,6 +135,14 @@ class Config:
|
|||||||
self.QDRANT_URL = get_env('QDRANT_URL')
|
self.QDRANT_URL = get_env('QDRANT_URL')
|
||||||
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
|
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
|
||||||
|
|
||||||
|
# milvus setting
|
||||||
|
self.MILVUS_HOST = get_env('MILVUS_HOST')
|
||||||
|
self.MILVUS_PORT = get_env('MILVUS_PORT')
|
||||||
|
self.MILVUS_USER = get_env('MILVUS_USER')
|
||||||
|
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
|
||||||
|
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
|
||||||
|
|
||||||
|
|
||||||
# cors settings
|
# cors settings
|
||||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||||
|
860
api/core/index/vector_index/milvus.py
Normal file
860
api/core/index/vector_index/milvus.py
Normal file
@ -0,0 +1,860 @@
|
|||||||
|
"""Wrapper around the Milvus vector database."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_MILVUS_CONNECTION = {
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "19530",
|
||||||
|
"user": "",
|
||||||
|
"password": "",
|
||||||
|
"secure": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Milvus(VectorStore):
|
||||||
|
"""Initialize wrapper around the milvus vector database.
|
||||||
|
|
||||||
|
In order to use this you need to have `pymilvus` installed and a
|
||||||
|
running Milvus
|
||||||
|
|
||||||
|
See the following documentation for how to run a Milvus instance:
|
||||||
|
https://milvus.io/docs/install_standalone-docker.md
|
||||||
|
|
||||||
|
If looking for a hosted Milvus, take a look at this documentation:
|
||||||
|
https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
|
||||||
|
this project,
|
||||||
|
|
||||||
|
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_function (Embeddings): Function used to embed the text.
|
||||||
|
collection_name (str): Which Milvus collection to use. Defaults to
|
||||||
|
"LangChainCollection".
|
||||||
|
connection_args (Optional[dict[str, any]]): The connection args used for
|
||||||
|
this class comes in the form of a dict.
|
||||||
|
consistency_level (str): The consistency level to use for a collection.
|
||||||
|
Defaults to "Session".
|
||||||
|
index_params (Optional[dict]): Which index params to use. Defaults to
|
||||||
|
HNSW/AUTOINDEX depending on service.
|
||||||
|
search_params (Optional[dict]): Which search params to use. Defaults to
|
||||||
|
default of index.
|
||||||
|
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
|
||||||
|
to False.
|
||||||
|
|
||||||
|
The connection args used for this class comes in the form of a dict,
|
||||||
|
here are a few of the options:
|
||||||
|
address (str): The actual address of Milvus
|
||||||
|
instance. Example address: "localhost:19530"
|
||||||
|
uri (str): The uri of Milvus instance. Example uri:
|
||||||
|
"http://randomwebsite:19530",
|
||||||
|
"tcp:foobarsite:19530",
|
||||||
|
"https://ok.s3.south.com:19530".
|
||||||
|
host (str): The host of Milvus instance. Default at "localhost",
|
||||||
|
PyMilvus will fill in the default host if only port is provided.
|
||||||
|
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
|
||||||
|
will fill in the default port if only host is provided.
|
||||||
|
user (str): Use which user to connect to Milvus instance. If user and
|
||||||
|
password are provided, we will add related header in every RPC call.
|
||||||
|
password (str): Required when user is provided. The password
|
||||||
|
corresponding to the user.
|
||||||
|
secure (bool): Default is false. If set to true, tls will be enabled.
|
||||||
|
client_key_path (str): If use tls two-way authentication, need to
|
||||||
|
write the client.key path.
|
||||||
|
client_pem_path (str): If use tls two-way authentication, need to
|
||||||
|
write the client.pem path.
|
||||||
|
ca_pem_path (str): If use tls two-way authentication, need to write
|
||||||
|
the ca.pem path.
|
||||||
|
server_pem_path (str): If use tls one-way authentication, need to
|
||||||
|
write the server.pem path.
|
||||||
|
server_name (str): If use tls, need to write the common name.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain import Milvus
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
|
embedding = OpenAIEmbeddings()
|
||||||
|
# Connect to a milvus instance on localhost
|
||||||
|
milvus_store = Milvus(
|
||||||
|
embedding_function = Embeddings,
|
||||||
|
collection_name = "LangChainCollection",
|
||||||
|
drop_old = True,
|
||||||
|
)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the pymilvus python package is not installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embedding_function: Embeddings,
|
||||||
|
collection_name: str = "LangChainCollection",
|
||||||
|
connection_args: Optional[dict[str, Any]] = None,
|
||||||
|
consistency_level: str = "Session",
|
||||||
|
index_params: Optional[dict] = None,
|
||||||
|
search_params: Optional[dict] = None,
|
||||||
|
drop_old: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
"""Initialize the Milvus vector store."""
|
||||||
|
try:
|
||||||
|
from pymilvus import Collection, utility
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import pymilvus python package. "
|
||||||
|
"Please install it with `pip install pymilvus`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default search params when one is not provided.
|
||||||
|
self.default_search_params = {
|
||||||
|
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||||
|
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||||
|
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||||
|
"HNSW": {"metric_type": "L2", "params": {"ef": 10}},
|
||||||
|
"RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
|
||||||
|
"RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
|
||||||
|
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
|
||||||
|
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
|
||||||
|
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
|
||||||
|
"AUTOINDEX": {"metric_type": "L2", "params": {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
self.embedding_func = embedding_function
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.index_params = index_params
|
||||||
|
self.search_params = search_params
|
||||||
|
self.consistency_level = consistency_level
|
||||||
|
|
||||||
|
# In order for a collection to be compatible, pk needs to be auto'id and int
|
||||||
|
self._primary_field = "id"
|
||||||
|
# In order for compatibility, the text field will need to be called "text"
|
||||||
|
self._text_field = "page_content"
|
||||||
|
# In order for compatibility, the vector field needs to be called "vector"
|
||||||
|
self._vector_field = "vectors"
|
||||||
|
# In order for compatibility, the metadata field will need to be called "metadata"
|
||||||
|
self._metadata_field = "metadata"
|
||||||
|
self.fields: list[str] = []
|
||||||
|
# Create the connection to the server
|
||||||
|
if connection_args is None:
|
||||||
|
connection_args = DEFAULT_MILVUS_CONNECTION
|
||||||
|
self.alias = self._create_connection_alias(connection_args)
|
||||||
|
self.col: Optional[Collection] = None
|
||||||
|
|
||||||
|
# Grab the existing collection if it exists
|
||||||
|
if utility.has_collection(self.collection_name, using=self.alias):
|
||||||
|
self.col = Collection(
|
||||||
|
self.collection_name,
|
||||||
|
using=self.alias,
|
||||||
|
)
|
||||||
|
# If need to drop old, drop it
|
||||||
|
if drop_old and isinstance(self.col, Collection):
|
||||||
|
self.col.drop()
|
||||||
|
self.col = None
|
||||||
|
|
||||||
|
# Initialize the vector store
|
||||||
|
self._init()
|
||||||
|
|
||||||
|
@property
|
||||||
|
|
||||||
|
|
||||||
|
def embeddings(self) -> Embeddings:
|
||||||
|
return self.embedding_func
|
||||||
|
|
||||||
|
def _create_connection_alias(self, connection_args: dict) -> str:
|
||||||
|
"""Create the connection to the Milvus server."""
|
||||||
|
from pymilvus import MilvusException, connections
|
||||||
|
|
||||||
|
# Grab the connection arguments that are used for checking existing connection
|
||||||
|
host: str = connection_args.get("host", None)
|
||||||
|
port: Union[str, int] = connection_args.get("port", None)
|
||||||
|
address: str = connection_args.get("address", None)
|
||||||
|
uri: str = connection_args.get("uri", None)
|
||||||
|
user = connection_args.get("user", None)
|
||||||
|
|
||||||
|
# Order of use is host/port, uri, address
|
||||||
|
if host is not None and port is not None:
|
||||||
|
given_address = str(host) + ":" + str(port)
|
||||||
|
elif uri is not None:
|
||||||
|
given_address = uri.split("https://")[1]
|
||||||
|
elif address is not None:
|
||||||
|
given_address = address
|
||||||
|
else:
|
||||||
|
given_address = None
|
||||||
|
logger.debug("Missing standard address type for reuse atttempt")
|
||||||
|
|
||||||
|
# User defaults to empty string when getting connection info
|
||||||
|
if user is not None:
|
||||||
|
tmp_user = user
|
||||||
|
else:
|
||||||
|
tmp_user = ""
|
||||||
|
|
||||||
|
# If a valid address was given, then check if a connection exists
|
||||||
|
if given_address is not None:
|
||||||
|
for con in connections.list_connections():
|
||||||
|
addr = connections.get_connection_addr(con[0])
|
||||||
|
if (
|
||||||
|
con[1]
|
||||||
|
and ("address" in addr)
|
||||||
|
and (addr["address"] == given_address)
|
||||||
|
and ("user" in addr)
|
||||||
|
and (addr["user"] == tmp_user)
|
||||||
|
):
|
||||||
|
logger.debug("Using previous connection: %s", con[0])
|
||||||
|
return con[0]
|
||||||
|
|
||||||
|
# Generate a new connection if one doesn't exist
|
||||||
|
alias = uuid4().hex
|
||||||
|
try:
|
||||||
|
connections.connect(alias=alias, **connection_args)
|
||||||
|
logger.debug("Created new connection using: %s", alias)
|
||||||
|
return alias
|
||||||
|
except MilvusException as e:
|
||||||
|
logger.error("Failed to create new connection using: %s", alias)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _init(
|
||||||
|
self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
|
||||||
|
) -> None:
|
||||||
|
if embeddings is not None:
|
||||||
|
self._create_collection(embeddings, metadatas)
|
||||||
|
self._extract_fields()
|
||||||
|
self._create_index()
|
||||||
|
self._create_search_params()
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
def _create_collection(
|
||||||
|
self, embeddings: list, metadatas: Optional[list[dict]] = None
|
||||||
|
) -> None:
|
||||||
|
from pymilvus import (
|
||||||
|
Collection,
|
||||||
|
CollectionSchema,
|
||||||
|
DataType,
|
||||||
|
FieldSchema,
|
||||||
|
MilvusException,
|
||||||
|
)
|
||||||
|
from pymilvus.orm.types import infer_dtype_bydata
|
||||||
|
|
||||||
|
# Determine embedding dim
|
||||||
|
dim = len(embeddings[0])
|
||||||
|
fields = []
|
||||||
|
# Determine metadata schema
|
||||||
|
# if metadatas:
|
||||||
|
# # Create FieldSchema for each entry in metadata.
|
||||||
|
# for key, value in metadatas[0].items():
|
||||||
|
# # Infer the corresponding datatype of the metadata
|
||||||
|
# dtype = infer_dtype_bydata(value)
|
||||||
|
# # Datatype isn't compatible
|
||||||
|
# if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
|
||||||
|
# logger.error(
|
||||||
|
# "Failure to create collection, unrecognized dtype for key: %s",
|
||||||
|
# key,
|
||||||
|
# )
|
||||||
|
# raise ValueError(f"Unrecognized datatype for {key}.")
|
||||||
|
# # Dataype is a string/varchar equivalent
|
||||||
|
# elif dtype == DataType.VARCHAR:
|
||||||
|
# fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
|
||||||
|
# else:
|
||||||
|
# fields.append(FieldSchema(key, dtype))
|
||||||
|
if metadatas:
|
||||||
|
fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
|
||||||
|
|
||||||
|
# Create the text field
|
||||||
|
fields.append(
|
||||||
|
FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
|
||||||
|
)
|
||||||
|
# Create the primary key field
|
||||||
|
fields.append(
|
||||||
|
FieldSchema(
|
||||||
|
self._primary_field, DataType.INT64, is_primary=True, auto_id=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Create the vector field, supports binary or float vectors
|
||||||
|
fields.append(
|
||||||
|
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the schema for the collection
|
||||||
|
schema = CollectionSchema(fields)
|
||||||
|
|
||||||
|
# Create the collection
|
||||||
|
try:
|
||||||
|
self.col = Collection(
|
||||||
|
name=self.collection_name,
|
||||||
|
schema=schema,
|
||||||
|
consistency_level=self.consistency_level,
|
||||||
|
using=self.alias,
|
||||||
|
)
|
||||||
|
except MilvusException as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to create collection: %s error: %s", self.collection_name, e
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _extract_fields(self) -> None:
|
||||||
|
"""Grab the existing fields from the Collection"""
|
||||||
|
from pymilvus import Collection
|
||||||
|
|
||||||
|
if isinstance(self.col, Collection):
|
||||||
|
schema = self.col.schema
|
||||||
|
for x in schema.fields:
|
||||||
|
self.fields.append(x.name)
|
||||||
|
# Since primary field is auto-id, no need to track it
|
||||||
|
self.fields.remove(self._primary_field)
|
||||||
|
|
||||||
|
def _get_index(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""Return the vector index information if it exists"""
|
||||||
|
from pymilvus import Collection
|
||||||
|
|
||||||
|
if isinstance(self.col, Collection):
|
||||||
|
for x in self.col.indexes:
|
||||||
|
if x.field_name == self._vector_field:
|
||||||
|
return x.to_dict()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_index(self) -> None:
|
||||||
|
"""Create a index on the collection"""
|
||||||
|
from pymilvus import Collection, MilvusException
|
||||||
|
|
||||||
|
if isinstance(self.col, Collection) and self._get_index() is None:
|
||||||
|
try:
|
||||||
|
# If no index params, use a default HNSW based one
|
||||||
|
if self.index_params is None:
|
||||||
|
self.index_params = {
|
||||||
|
"metric_type": "IP",
|
||||||
|
"index_type": "HNSW",
|
||||||
|
"params": {"M": 8, "efConstruction": 64},
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.col.create_index(
|
||||||
|
self._vector_field,
|
||||||
|
index_params=self.index_params,
|
||||||
|
using=self.alias,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If default did not work, most likely on Zilliz Cloud
|
||||||
|
except MilvusException:
|
||||||
|
# Use AUTOINDEX based index
|
||||||
|
self.index_params = {
|
||||||
|
"metric_type": "L2",
|
||||||
|
"index_type": "AUTOINDEX",
|
||||||
|
"params": {},
|
||||||
|
}
|
||||||
|
self.col.create_index(
|
||||||
|
self._vector_field,
|
||||||
|
index_params=self.index_params,
|
||||||
|
using=self.alias,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Successfully created an index on collection: %s",
|
||||||
|
self.collection_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
except MilvusException as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to create an index on collection: %s", self.collection_name
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _create_search_params(self) -> None:
|
||||||
|
"""Generate search params based on the current index type"""
|
||||||
|
from pymilvus import Collection
|
||||||
|
|
||||||
|
if isinstance(self.col, Collection) and self.search_params is None:
|
||||||
|
index = self._get_index()
|
||||||
|
if index is not None:
|
||||||
|
index_type: str = index["index_param"]["index_type"]
|
||||||
|
metric_type: str = index["index_param"]["metric_type"]
|
||||||
|
self.search_params = self.default_search_params[index_type]
|
||||||
|
self.search_params["metric_type"] = metric_type
|
||||||
|
|
||||||
|
def _load(self) -> None:
|
||||||
|
"""Load the collection if available."""
|
||||||
|
from pymilvus import Collection
|
||||||
|
|
||||||
|
if isinstance(self.col, Collection) and self._get_index() is not None:
|
||||||
|
self.col.load()
|
||||||
|
|
||||||
|
def add_texts(
|
||||||
|
self,
|
||||||
|
texts: Iterable[str],
|
||||||
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
batch_size: int = 1000,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Insert text data into Milvus.
|
||||||
|
|
||||||
|
Inserting data when the collection has not be made yet will result
|
||||||
|
in creating a new Collection. The data of the first entity decides
|
||||||
|
the schema of the new collection, the dim is extracted from the first
|
||||||
|
embedding and the columns are decided by the first metadata dict.
|
||||||
|
Metada keys will need to be present for all inserted values. At
|
||||||
|
the moment there is no None equivalent in Milvus.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts (Iterable[str]): The texts to embed, it is assumed
|
||||||
|
that they all fit in memory.
|
||||||
|
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
|
||||||
|
the texts. Defaults to None.
|
||||||
|
timeout (Optional[int]): Timeout for each batch insert. Defaults
|
||||||
|
to None.
|
||||||
|
batch_size (int, optional): Batch size to use for insertion.
|
||||||
|
Defaults to 1000.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MilvusException: Failure to add texts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: The resulting keys for each inserted element.
|
||||||
|
"""
|
||||||
|
from pymilvus import Collection, MilvusException
|
||||||
|
|
||||||
|
texts = list(texts)
|
||||||
|
|
||||||
|
try:
|
||||||
|
embeddings = self.embedding_func.embed_documents(texts)
|
||||||
|
except NotImplementedError:
|
||||||
|
embeddings = [self.embedding_func.embed_query(x) for x in texts]
|
||||||
|
|
||||||
|
if len(embeddings) == 0:
|
||||||
|
logger.debug("Nothing to insert, skipping.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# If the collection hasn't been initialized yet, perform all steps to do so
|
||||||
|
if not isinstance(self.col, Collection):
|
||||||
|
self._init(embeddings, metadatas)
|
||||||
|
|
||||||
|
# Dict to hold all insert columns
|
||||||
|
insert_dict: dict[str, list] = {
|
||||||
|
self._text_field: texts,
|
||||||
|
self._vector_field: embeddings,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Collect the metadata into the insert dict.
|
||||||
|
# if metadatas is not None:
|
||||||
|
# for d in metadatas:
|
||||||
|
# for key, value in d.items():
|
||||||
|
# if key in self.fields:
|
||||||
|
# insert_dict.setdefault(key, []).append(value)
|
||||||
|
if metadatas is not None:
|
||||||
|
for d in metadatas:
|
||||||
|
insert_dict.setdefault(self._metadata_field, []).append(d)
|
||||||
|
|
||||||
|
# Total insert count
|
||||||
|
vectors: list = insert_dict[self._vector_field]
|
||||||
|
total_count = len(vectors)
|
||||||
|
|
||||||
|
pks: list[str] = []
|
||||||
|
|
||||||
|
assert isinstance(self.col, Collection)
|
||||||
|
for i in range(0, total_count, batch_size):
|
||||||
|
# Grab end index
|
||||||
|
end = min(i + batch_size, total_count)
|
||||||
|
# Convert dict to list of lists batch for insertion
|
||||||
|
insert_list = [insert_dict[x][i:end] for x in self.fields]
|
||||||
|
# Insert into the collection.
|
||||||
|
try:
|
||||||
|
res: Collection
|
||||||
|
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
|
||||||
|
pks.extend(res.primary_keys)
|
||||||
|
except MilvusException as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to insert batch starting at entity: %s/%s", i, total_count
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
return pks
|
||||||
|
|
||||||
|
def similarity_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
param: Optional[dict] = None,
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Perform a similarity search against the query string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The text to search.
|
||||||
|
k (int, optional): How many results to return. Defaults to 4.
|
||||||
|
param (dict, optional): The search params for the index type.
|
||||||
|
Defaults to None.
|
||||||
|
expr (str, optional): Filtering expression. Defaults to None.
|
||||||
|
timeout (int, optional): How long to wait before timeout error.
|
||||||
|
Defaults to None.
|
||||||
|
kwargs: Collection.search() keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Document]: Document results for search.
|
||||||
|
"""
|
||||||
|
if self.col is None:
|
||||||
|
logger.debug("No existing collection to search.")
|
||||||
|
return []
|
||||||
|
res = self.similarity_search_with_score(
|
||||||
|
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||||
|
)
|
||||||
|
return [doc for doc, _ in res]
|
||||||
|
|
||||||
|
def similarity_search_by_vector(
|
||||||
|
self,
|
||||||
|
embedding: List[float],
|
||||||
|
k: int = 4,
|
||||||
|
param: Optional[dict] = None,
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Perform a similarity search against the query string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding (List[float]): The embedding vector to search.
|
||||||
|
k (int, optional): How many results to return. Defaults to 4.
|
||||||
|
param (dict, optional): The search params for the index type.
|
||||||
|
Defaults to None.
|
||||||
|
expr (str, optional): Filtering expression. Defaults to None.
|
||||||
|
timeout (int, optional): How long to wait before timeout error.
|
||||||
|
Defaults to None.
|
||||||
|
kwargs: Collection.search() keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Document]: Document results for search.
|
||||||
|
"""
|
||||||
|
if self.col is None:
|
||||||
|
logger.debug("No existing collection to search.")
|
||||||
|
return []
|
||||||
|
res = self.similarity_search_with_score_by_vector(
|
||||||
|
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||||
|
)
|
||||||
|
return [doc for doc, _ in res]
|
||||||
|
|
||||||
|
def similarity_search_with_score(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
param: Optional[dict] = None,
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
"""Perform a search on a query string and return results with score.
|
||||||
|
|
||||||
|
For more information about the search parameters, take a look at the pymilvus
|
||||||
|
documentation found here:
|
||||||
|
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The text being searched.
|
||||||
|
k (int, optional): The amount of results to return. Defaults to 4.
|
||||||
|
param (dict): The search params for the specified index.
|
||||||
|
Defaults to None.
|
||||||
|
expr (str, optional): Filtering expression. Defaults to None.
|
||||||
|
timeout (int, optional): How long to wait before timeout error.
|
||||||
|
Defaults to None.
|
||||||
|
kwargs: Collection.search() keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float], List[Tuple[Document, any, any]]:
|
||||||
|
"""
|
||||||
|
if self.col is None:
|
||||||
|
logger.debug("No existing collection to search.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Embed the query text.
|
||||||
|
embedding = self.embedding_func.embed_query(query)
|
||||||
|
|
||||||
|
res = self.similarity_search_with_score_by_vector(
|
||||||
|
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _similarity_search_with_relevance_scores(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
"""Return docs and relevance scores in the range [0, 1].
|
||||||
|
|
||||||
|
0 is dissimilar, 1 is most similar.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: input text
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||||
|
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||||
|
filter the resulting set of retrieved docs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Tuples of (doc, similarity_score)
|
||||||
|
"""
|
||||||
|
return self.similarity_search_with_score(query, k, **kwargs)
|
||||||
|
|
||||||
|
def similarity_search_with_score_by_vector(
|
||||||
|
self,
|
||||||
|
embedding: List[float],
|
||||||
|
k: int = 4,
|
||||||
|
param: Optional[dict] = None,
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
"""Perform a search on a query string and return results with score.
|
||||||
|
|
||||||
|
For more information about the search parameters, take a look at the pymilvus
|
||||||
|
documentation found here:
|
||||||
|
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding (List[float]): The embedding vector being searched.
|
||||||
|
k (int, optional): The amount of results to return. Defaults to 4.
|
||||||
|
param (dict): The search params for the specified index.
|
||||||
|
Defaults to None.
|
||||||
|
expr (str, optional): Filtering expression. Defaults to None.
|
||||||
|
timeout (int, optional): How long to wait before timeout error.
|
||||||
|
Defaults to None.
|
||||||
|
kwargs: Collection.search() keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple[Document, float]]: Result doc and score.
|
||||||
|
"""
|
||||||
|
if self.col is None:
|
||||||
|
logger.debug("No existing collection to search.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if param is None:
|
||||||
|
param = self.search_params
|
||||||
|
|
||||||
|
# Determine result metadata fields.
|
||||||
|
output_fields = self.fields[:]
|
||||||
|
output_fields.remove(self._vector_field)
|
||||||
|
|
||||||
|
# Perform the search.
|
||||||
|
res = self.col.search(
|
||||||
|
data=[embedding],
|
||||||
|
anns_field=self._vector_field,
|
||||||
|
param=param,
|
||||||
|
limit=k,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=output_fields,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# Organize results.
|
||||||
|
ret = []
|
||||||
|
for result in res[0]:
|
||||||
|
meta = {x: result.entity.get(x) for x in output_fields}
|
||||||
|
doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
|
||||||
|
pair = (doc, result.score)
|
||||||
|
ret.append(pair)
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
fetch_k: int = 20,
|
||||||
|
lambda_mult: float = 0.5,
|
||||||
|
param: Optional[dict] = None,
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Perform a search and return results that are reordered by MMR.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The text being searched.
|
||||||
|
k (int, optional): How many results to give. Defaults to 4.
|
||||||
|
fetch_k (int, optional): Total results to select k from.
|
||||||
|
Defaults to 20.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5
|
||||||
|
param (dict, optional): The search params for the specified index.
|
||||||
|
Defaults to None.
|
||||||
|
expr (str, optional): Filtering expression. Defaults to None.
|
||||||
|
timeout (int, optional): How long to wait before timeout error.
|
||||||
|
Defaults to None.
|
||||||
|
kwargs: Collection.search() keyword arguments.
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Document]: Document results for search.
|
||||||
|
"""
|
||||||
|
if self.col is None:
|
||||||
|
logger.debug("No existing collection to search.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
embedding = self.embedding_func.embed_query(query)
|
||||||
|
|
||||||
|
return self.max_marginal_relevance_search_by_vector(
|
||||||
|
embedding=embedding,
|
||||||
|
k=k,
|
||||||
|
fetch_k=fetch_k,
|
||||||
|
lambda_mult=lambda_mult,
|
||||||
|
param=param,
|
||||||
|
expr=expr,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_by_vector(
|
||||||
|
self,
|
||||||
|
embedding: list[float],
|
||||||
|
k: int = 4,
|
||||||
|
fetch_k: int = 20,
|
||||||
|
lambda_mult: float = 0.5,
|
||||||
|
param: Optional[dict] = None,
|
||||||
|
expr: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Perform a search and return results that are reordered by MMR.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding (str): The embedding vector being searched.
|
||||||
|
k (int, optional): How many results to give. Defaults to 4.
|
||||||
|
fetch_k (int, optional): Total results to select k from.
|
||||||
|
Defaults to 20.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5
|
||||||
|
param (dict, optional): The search params for the specified index.
|
||||||
|
Defaults to None.
|
||||||
|
expr (str, optional): Filtering expression. Defaults to None.
|
||||||
|
timeout (int, optional): How long to wait before timeout error.
|
||||||
|
Defaults to None.
|
||||||
|
kwargs: Collection.search() keyword arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Document]: Document results for search.
|
||||||
|
"""
|
||||||
|
if self.col is None:
|
||||||
|
logger.debug("No existing collection to search.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if param is None:
|
||||||
|
param = self.search_params
|
||||||
|
|
||||||
|
# Determine result metadata fields.
|
||||||
|
output_fields = self.fields[:]
|
||||||
|
output_fields.remove(self._vector_field)
|
||||||
|
|
||||||
|
# Perform the search.
|
||||||
|
res = self.col.search(
|
||||||
|
data=[embedding],
|
||||||
|
anns_field=self._vector_field,
|
||||||
|
param=param,
|
||||||
|
limit=fetch_k,
|
||||||
|
expr=expr,
|
||||||
|
output_fields=output_fields,
|
||||||
|
timeout=timeout,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# Organize results.
|
||||||
|
ids = []
|
||||||
|
documents = []
|
||||||
|
scores = []
|
||||||
|
for result in res[0]:
|
||||||
|
meta = {x: result.entity.get(x) for x in output_fields}
|
||||||
|
doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
|
||||||
|
documents.append(doc)
|
||||||
|
scores.append(result.score)
|
||||||
|
ids.append(result.id)
|
||||||
|
|
||||||
|
vectors = self.col.query(
|
||||||
|
expr=f"{self._primary_field} in {ids}",
|
||||||
|
output_fields=[self._primary_field, self._vector_field],
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
# Reorganize the results from query to match search order.
|
||||||
|
vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
|
||||||
|
|
||||||
|
ordered_result_embeddings = [vectors[x] for x in ids]
|
||||||
|
|
||||||
|
# Get the new order of results.
|
||||||
|
new_ordering = maximal_marginal_relevance(
|
||||||
|
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reorder the values and return.
|
||||||
|
ret = []
|
||||||
|
for x in new_ordering:
|
||||||
|
# Function can return -1 index
|
||||||
|
if x == -1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
ret.append(documents[x])
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_texts(
|
||||||
|
cls,
|
||||||
|
texts: List[str],
|
||||||
|
embedding: Embeddings,
|
||||||
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
collection_name: str = "LangChainCollection",
|
||||||
|
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
|
||||||
|
consistency_level: str = "Session",
|
||||||
|
index_params: Optional[dict] = None,
|
||||||
|
search_params: Optional[dict] = None,
|
||||||
|
drop_old: bool = False,
|
||||||
|
batch_size: int = 100,
|
||||||
|
ids: Optional[Sequence[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Milvus:
|
||||||
|
"""Create a Milvus collection, indexes it with HNSW, and insert data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts (List[str]): Text data.
|
||||||
|
embedding (Embeddings): Embedding function.
|
||||||
|
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
|
||||||
|
Defaults to None.
|
||||||
|
collection_name (str, optional): Collection name to use. Defaults to
|
||||||
|
"LangChainCollection".
|
||||||
|
connection_args (dict[str, Any], optional): Connection args to use. Defaults
|
||||||
|
to DEFAULT_MILVUS_CONNECTION.
|
||||||
|
consistency_level (str, optional): Which consistency level to use. Defaults
|
||||||
|
to "Session".
|
||||||
|
index_params (Optional[dict], optional): Which index_params to use. Defaults
|
||||||
|
to None.
|
||||||
|
search_params (Optional[dict], optional): Which search params to use.
|
||||||
|
Defaults to None.
|
||||||
|
drop_old (Optional[bool], optional): Whether to drop the collection with
|
||||||
|
that name if it exists. Defaults to False.
|
||||||
|
batch_size:
|
||||||
|
How many vectors upload per-request.
|
||||||
|
Default: 100
|
||||||
|
ids: Optional[Sequence[str]] = None,
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Milvus: Milvus Vector Store
|
||||||
|
"""
|
||||||
|
vector_db = cls(
|
||||||
|
embedding_function=embedding,
|
||||||
|
collection_name=collection_name,
|
||||||
|
connection_args=connection_args,
|
||||||
|
consistency_level=consistency_level,
|
||||||
|
index_params=index_params,
|
||||||
|
search_params=search_params,
|
||||||
|
drop_old=drop_old,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
|
||||||
|
return vector_db
|
@ -9,30 +9,46 @@ from core.index.base import BaseIndex
|
|||||||
from core.index.vector_index.base import BaseVectorIndex
|
from core.index.vector_index.base import BaseVectorIndex
|
||||||
from core.vector_store.milvus_vector_store import MilvusVectorStore
|
from core.vector_store.milvus_vector_store import MilvusVectorStore
|
||||||
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
||||||
from models.dataset import Dataset
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Dataset, DatasetCollectionBinding
|
||||||
|
|
||||||
|
|
||||||
class MilvusConfig(BaseModel):
|
class MilvusConfig(BaseModel):
|
||||||
endpoint: str
|
host: str
|
||||||
|
port: int
|
||||||
user: str
|
user: str
|
||||||
password: str
|
password: str
|
||||||
|
secure: bool
|
||||||
batch_size: int = 100
|
batch_size: int = 100
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_config(cls, values: dict) -> dict:
|
def validate_config(cls, values: dict) -> dict:
|
||||||
if not values['endpoint']:
|
if not values['host']:
|
||||||
raise ValueError("config MILVUS_ENDPOINT is required")
|
raise ValueError("config MILVUS_HOST is required")
|
||||||
|
if not values['port']:
|
||||||
|
raise ValueError("config MILVUS_PORT is required")
|
||||||
|
if not values['secure']:
|
||||||
|
raise ValueError("config MILVUS_SECURE is required")
|
||||||
if not values['user']:
|
if not values['user']:
|
||||||
raise ValueError("config MILVUS_USER is required")
|
raise ValueError("config MILVUS_USER is required")
|
||||||
if not values['password']:
|
if not values['password']:
|
||||||
raise ValueError("config MILVUS_PASSWORD is required")
|
raise ValueError("config MILVUS_PASSWORD is required")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
def to_milvus_params(self):
|
||||||
|
return {
|
||||||
|
'host': self.host,
|
||||||
|
'port': self.port,
|
||||||
|
'user': self.user,
|
||||||
|
'password': self.password,
|
||||||
|
'secure': self.secure
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorIndex(BaseVectorIndex):
|
class MilvusVectorIndex(BaseVectorIndex):
|
||||||
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
|
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
|
||||||
super().__init__(dataset, embeddings)
|
super().__init__(dataset, embeddings)
|
||||||
self._client = self._init_client(config)
|
self._client_config = config
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
return 'milvus'
|
return 'milvus'
|
||||||
@ -49,7 +65,6 @@ class MilvusVectorIndex(BaseVectorIndex):
|
|||||||
dataset_id = dataset.id
|
dataset_id = dataset.id
|
||||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||||
|
|
||||||
|
|
||||||
def to_index_struct(self) -> dict:
|
def to_index_struct(self) -> dict:
|
||||||
return {
|
return {
|
||||||
"type": self.get_type(),
|
"type": self.get_type(),
|
||||||
@ -58,26 +73,29 @@ class MilvusVectorIndex(BaseVectorIndex):
|
|||||||
|
|
||||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||||
uuids = self._get_uuids(texts)
|
uuids = self._get_uuids(texts)
|
||||||
self._vector_store = WeaviateVectorStore.from_documents(
|
index_params = {
|
||||||
|
'metric_type': 'IP',
|
||||||
|
'index_type': "HNSW",
|
||||||
|
'params': {"M": 8, "efConstruction": 64}
|
||||||
|
}
|
||||||
|
self._vector_store = MilvusVectorStore.from_documents(
|
||||||
texts,
|
texts,
|
||||||
self._embeddings,
|
self._embeddings,
|
||||||
client=self._client,
|
collection_name=self.get_index_name(self.dataset),
|
||||||
index_name=self.get_index_name(self.dataset),
|
connection_args=self._client_config.to_milvus_params(),
|
||||||
uuids=uuids,
|
index_params=index_params
|
||||||
by_text=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||||
uuids = self._get_uuids(texts)
|
uuids = self._get_uuids(texts)
|
||||||
self._vector_store = WeaviateVectorStore.from_documents(
|
self._vector_store = MilvusVectorStore.from_documents(
|
||||||
texts,
|
texts,
|
||||||
self._embeddings,
|
self._embeddings,
|
||||||
client=self._client,
|
collection_name=collection_name,
|
||||||
index_name=collection_name,
|
ids=uuids,
|
||||||
uuids=uuids,
|
content_payload_key='page_content'
|
||||||
by_text=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
@ -86,42 +104,53 @@ class MilvusVectorIndex(BaseVectorIndex):
|
|||||||
"""Only for created index."""
|
"""Only for created index."""
|
||||||
if self._vector_store:
|
if self._vector_store:
|
||||||
return self._vector_store
|
return self._vector_store
|
||||||
|
|
||||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||||
if self._is_origin():
|
|
||||||
attributes = ['doc_id']
|
|
||||||
|
|
||||||
return WeaviateVectorStore(
|
return MilvusVectorStore(
|
||||||
client=self._client,
|
collection_name=self.get_index_name(self.dataset),
|
||||||
index_name=self.get_index_name(self.dataset),
|
embedding_function=self._embeddings,
|
||||||
text_key='text',
|
connection_args=self._client_config.to_milvus_params()
|
||||||
embedding=self._embeddings,
|
|
||||||
attributes=attributes,
|
|
||||||
by_text=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_vector_store_class(self) -> type:
|
def _get_vector_store_class(self) -> type:
|
||||||
return MilvusVectorStore
|
return MilvusVectorStore
|
||||||
|
|
||||||
def delete_by_document_id(self, document_id: str):
|
def delete_by_document_id(self, document_id: str):
|
||||||
if self._is_origin():
|
|
||||||
self.recreate_dataset(self.dataset)
|
vector_store = self._get_vector_store()
|
||||||
return
|
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||||
|
ids = vector_store.get_ids_by_document_id(document_id)
|
||||||
|
if ids:
|
||||||
|
vector_store.del_texts({
|
||||||
|
'filter': f'id in {ids}'
|
||||||
|
})
|
||||||
|
|
||||||
|
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
||||||
|
|
||||||
|
vector_store = self._get_vector_store()
|
||||||
|
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||||
|
ids = vector_store.get_ids_by_doc_ids(doc_ids)
|
||||||
|
vector_store.del_texts({
|
||||||
|
'filter': f' id in {ids}'
|
||||||
|
})
|
||||||
|
|
||||||
|
def delete_by_group_id(self, group_id: str) -> None:
|
||||||
|
|
||||||
vector_store = self._get_vector_store()
|
vector_store = self._get_vector_store()
|
||||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||||
|
|
||||||
vector_store.del_texts({
|
vector_store.delete()
|
||||||
"operator": "Equal",
|
|
||||||
"path": ["document_id"],
|
|
||||||
"valueText": document_id
|
|
||||||
})
|
|
||||||
|
|
||||||
def _is_origin(self):
|
def delete(self) -> None:
|
||||||
if self.dataset.index_struct_dict:
|
vector_store = self._get_vector_store()
|
||||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||||
if not class_prefix.endswith('_Node'):
|
|
||||||
# original class_prefix
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
from qdrant_client.http import models
|
||||||
|
vector_store.del_texts(models.Filter(
|
||||||
|
must=[
|
||||||
|
models.FieldCondition(
|
||||||
|
key="group_id",
|
||||||
|
match=models.MatchValue(value=self.dataset.id),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
))
|
||||||
|
@ -47,6 +47,20 @@ class VectorIndex:
|
|||||||
),
|
),
|
||||||
embeddings=embeddings
|
embeddings=embeddings
|
||||||
)
|
)
|
||||||
|
elif vector_type == "milvus":
|
||||||
|
from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig
|
||||||
|
|
||||||
|
return MilvusVectorIndex(
|
||||||
|
dataset=dataset,
|
||||||
|
config=MilvusConfig(
|
||||||
|
host=config.get('MILVUS_HOST'),
|
||||||
|
port=config.get('MILVUS_PORT'),
|
||||||
|
user=config.get('MILVUS_USER'),
|
||||||
|
password=config.get('MILVUS_PASSWORD'),
|
||||||
|
secure=config.get('MILVUS_SECURE'),
|
||||||
|
),
|
||||||
|
embeddings=embeddings
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from langchain.vectorstores import Milvus
|
from core.index.vector_index.milvus import Milvus
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorStore(Milvus):
|
class MilvusVectorStore(Milvus):
|
||||||
@ -6,33 +6,41 @@ class MilvusVectorStore(Milvus):
|
|||||||
if not where_filter:
|
if not where_filter:
|
||||||
raise ValueError('where_filter must not be empty')
|
raise ValueError('where_filter must not be empty')
|
||||||
|
|
||||||
self._client.batch.delete_objects(
|
self.col.delete(where_filter.get('filter'))
|
||||||
class_name=self._index_name,
|
|
||||||
where=where_filter,
|
|
||||||
output='minimal'
|
|
||||||
)
|
|
||||||
|
|
||||||
def del_text(self, uuid: str) -> None:
|
def del_text(self, uuid: str) -> None:
|
||||||
self._client.data_object.delete(
|
expr = f"id == {uuid}"
|
||||||
uuid,
|
self.col.delete(expr)
|
||||||
class_name=self._index_name
|
|
||||||
)
|
|
||||||
|
|
||||||
def text_exists(self, uuid: str) -> bool:
|
def text_exists(self, uuid: str) -> bool:
|
||||||
result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({
|
result = self.col.query(
|
||||||
"path": ["doc_id"],
|
expr=f'metadata["doc_id"] == "{uuid}"',
|
||||||
"operator": "Equal",
|
output_fields=["id"]
|
||||||
"valueText": uuid,
|
)
|
||||||
}).with_limit(1).do()
|
|
||||||
|
|
||||||
if "errors" in result:
|
return len(result) > 0
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
|
||||||
|
|
||||||
entries = result["data"]["Get"][self._index_name]
|
def get_ids_by_document_id(self, document_id: str):
|
||||||
if len(entries) == 0:
|
result = self.col.query(
|
||||||
return False
|
expr=f'metadata["document_id"] == "{document_id}"',
|
||||||
|
output_fields=["id"]
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return [item["id"] for item in result]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
return True
|
def get_ids_by_doc_ids(self, doc_ids: list):
|
||||||
|
result = self.col.query(
|
||||||
|
expr=f'metadata["doc_id"] in {doc_ids}',
|
||||||
|
output_fields=["id"]
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return [item["id"] for item in result]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def delete(self):
|
def delete(self):
|
||||||
self._client.schema.delete_class(self._index_name)
|
from pymilvus import utility
|
||||||
|
utility.drop_collection(self.collection_name, None, self.alias)
|
||||||
|
|
||||||
|
@ -52,4 +52,5 @@ pandas==1.5.3
|
|||||||
xinference==0.5.2
|
xinference==0.5.2
|
||||||
safetensors==0.3.2
|
safetensors==0.3.2
|
||||||
zhipuai==1.0.7
|
zhipuai==1.0.7
|
||||||
werkzeug==2.3.7
|
werkzeug==2.3.7
|
||||||
|
pymilvus==2.3.0
|
64
docker/milvus-standalone-docker-compose.yml
Normal file
64
docker/milvus-standalone-docker-compose.yml
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
version: '3.5'
|
||||||
|
|
||||||
|
services:
|
||||||
|
etcd:
|
||||||
|
container_name: milvus-etcd
|
||||||
|
image: quay.io/coreos/etcd:v3.5.5
|
||||||
|
environment:
|
||||||
|
- ETCD_AUTO_COMPACTION_MODE=revision
|
||||||
|
- ETCD_AUTO_COMPACTION_RETENTION=1000
|
||||||
|
- ETCD_QUOTA_BACKEND_BYTES=4294967296
|
||||||
|
- ETCD_SNAPSHOT_COUNT=50000
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd
|
||||||
|
command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "etcdctl", "endpoint", "health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 20s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
minio:
|
||||||
|
container_name: milvus-minio
|
||||||
|
image: minio/minio:RELEASE.2023-03-20T20-16-18Z
|
||||||
|
environment:
|
||||||
|
MINIO_ACCESS_KEY: minioadmin
|
||||||
|
MINIO_SECRET_KEY: minioadmin
|
||||||
|
ports:
|
||||||
|
- "9001:9001"
|
||||||
|
- "9000:9000"
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data
|
||||||
|
command: minio server /minio_data --console-address ":9001"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 20s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
standalone:
|
||||||
|
container_name: milvus-standalone
|
||||||
|
image: milvusdb/milvus:v2.3.1
|
||||||
|
command: ["milvus", "run", "standalone"]
|
||||||
|
environment:
|
||||||
|
ETCD_ENDPOINTS: etcd:2379
|
||||||
|
MINIO_ADDRESS: minio:9000
|
||||||
|
common.security.authorizationEnabled: true
|
||||||
|
volumes:
|
||||||
|
- ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
|
||||||
|
interval: 30s
|
||||||
|
start_period: 90s
|
||||||
|
timeout: 20s
|
||||||
|
retries: 3
|
||||||
|
ports:
|
||||||
|
- "19530:19530"
|
||||||
|
- "9091:9091"
|
||||||
|
depends_on:
|
||||||
|
- "etcd"
|
||||||
|
- "minio"
|
||||||
|
|
||||||
|
networks:
|
||||||
|
default:
|
||||||
|
name: milvus
|
Loading…
Reference in New Issue
Block a user