diff --git a/api/.env.example b/api/.env.example index 3749010244..2b0a7be216 100644 --- a/api/.env.example +++ b/api/.env.example @@ -63,6 +63,13 @@ WEAVIATE_BATCH_SIZE=100 QDRANT_URL=http://localhost:6333 QDRANT_API_KEY=difyai123456 +# Milvus configuration +MILVUS_HOST=127.0.0.1 +MILVUS_PORT=19530 +MILVUS_USER=root +MILVUS_PASSWORD=Milvus +MILVUS_SECURE=false + # Mail configuration, support: resend MAIL_TYPE= MAIL_DEFAULT_SEND_FROM=no-reply diff --git a/api/config.py b/api/config.py index 5c875b02f4..4c28f00c5a 100644 --- a/api/config.py +++ b/api/config.py @@ -135,6 +135,14 @@ class Config: self.QDRANT_URL = get_env('QDRANT_URL') self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') + # milvus setting + self.MILVUS_HOST = get_env('MILVUS_HOST') + self.MILVUS_PORT = get_env('MILVUS_PORT') + self.MILVUS_USER = get_env('MILVUS_USER') + self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD') + self.MILVUS_SECURE = get_env('MILVUS_SECURE') + + # cors settings self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL) diff --git a/api/core/index/vector_index/milvus.py b/api/core/index/vector_index/milvus.py new file mode 100644 index 0000000000..067d48b5a6 --- /dev/null +++ b/api/core/index/vector_index/milvus.py @@ -0,0 +1,860 @@ +"""Wrapper around the Milvus vector database.""" +from __future__ import annotations + +import logging +from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence +from uuid import uuid4 + +import numpy as np + +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.utils import maximal_marginal_relevance + +logger = logging.getLogger(__name__) + +DEFAULT_MILVUS_CONNECTION = { + "host": "localhost", + "port": "19530", + "user": "", + "password": "", + "secure": False, +} + + +class Milvus(VectorStore): + """Initialize wrapper around the milvus vector database. + + In order to use this you need to have `pymilvus` installed and a + running Milvus + + See the following documentation for how to run a Milvus instance: + https://milvus.io/docs/install_standalone-docker.md + + If looking for a hosted Milvus, take a look at this documentation: + https://zilliz.com/cloud and make use of the Zilliz vectorstore found in + this project, + + IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. + + Args: + embedding_function (Embeddings): Function used to embed the text. + collection_name (str): Which Milvus collection to use. Defaults to + "LangChainCollection". + connection_args (Optional[dict[str, any]]): The connection args used for + this class comes in the form of a dict. + consistency_level (str): The consistency level to use for a collection. + Defaults to "Session". + index_params (Optional[dict]): Which index params to use. Defaults to + HNSW/AUTOINDEX depending on service. + search_params (Optional[dict]): Which search params to use. Defaults to + default of index. + drop_old (Optional[bool]): Whether to drop the current collection. Defaults + to False. + + The connection args used for this class comes in the form of a dict, + here are a few of the options: + address (str): The actual address of Milvus + instance. Example address: "localhost:19530" + uri (str): The uri of Milvus instance. Example uri: + "http://randomwebsite:19530", + "tcp:foobarsite:19530", + "https://ok.s3.south.com:19530". + host (str): The host of Milvus instance. Default at "localhost", + PyMilvus will fill in the default host if only port is provided. + port (str/int): The port of Milvus instance. Default at 19530, PyMilvus + will fill in the default port if only host is provided. + user (str): Use which user to connect to Milvus instance. If user and + password are provided, we will add related header in every RPC call. + password (str): Required when user is provided. The password + corresponding to the user. + secure (bool): Default is false. If set to true, tls will be enabled. + client_key_path (str): If use tls two-way authentication, need to + write the client.key path. + client_pem_path (str): If use tls two-way authentication, need to + write the client.pem path. + ca_pem_path (str): If use tls two-way authentication, need to write + the ca.pem path. + server_pem_path (str): If use tls one-way authentication, need to + write the server.pem path. + server_name (str): If use tls, need to write the common name. + + Example: + .. code-block:: python + + from langchain import Milvus + from langchain.embeddings import OpenAIEmbeddings + + embedding = OpenAIEmbeddings() + # Connect to a milvus instance on localhost + milvus_store = Milvus( + embedding_function = Embeddings, + collection_name = "LangChainCollection", + drop_old = True, + ) + + Raises: + ValueError: If the pymilvus python package is not installed. + """ + + def __init__( + self, + embedding_function: Embeddings, + collection_name: str = "LangChainCollection", + connection_args: Optional[dict[str, Any]] = None, + consistency_level: str = "Session", + index_params: Optional[dict] = None, + search_params: Optional[dict] = None, + drop_old: Optional[bool] = False, + ): + """Initialize the Milvus vector store.""" + try: + from pymilvus import Collection, utility + except ImportError: + raise ValueError( + "Could not import pymilvus python package. " + "Please install it with `pip install pymilvus`." + ) + + # Default search params when one is not provided. + self.default_search_params = { + "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, + "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, + "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, + "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, + "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, + "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, + "AUTOINDEX": {"metric_type": "L2", "params": {}}, + } + + self.embedding_func = embedding_function + self.collection_name = collection_name + self.index_params = index_params + self.search_params = search_params + self.consistency_level = consistency_level + + # In order for a collection to be compatible, pk needs to be auto'id and int + self._primary_field = "id" + # In order for compatibility, the text field will need to be called "text" + self._text_field = "page_content" + # In order for compatibility, the vector field needs to be called "vector" + self._vector_field = "vectors" + # In order for compatibility, the metadata field will need to be called "metadata" + self._metadata_field = "metadata" + self.fields: list[str] = [] + # Create the connection to the server + if connection_args is None: + connection_args = DEFAULT_MILVUS_CONNECTION + self.alias = self._create_connection_alias(connection_args) + self.col: Optional[Collection] = None + + # Grab the existing collection if it exists + if utility.has_collection(self.collection_name, using=self.alias): + self.col = Collection( + self.collection_name, + using=self.alias, + ) + # If need to drop old, drop it + if drop_old and isinstance(self.col, Collection): + self.col.drop() + self.col = None + + # Initialize the vector store + self._init() + + @property + + + def embeddings(self) -> Embeddings: + return self.embedding_func + + def _create_connection_alias(self, connection_args: dict) -> str: + """Create the connection to the Milvus server.""" + from pymilvus import MilvusException, connections + + # Grab the connection arguments that are used for checking existing connection + host: str = connection_args.get("host", None) + port: Union[str, int] = connection_args.get("port", None) + address: str = connection_args.get("address", None) + uri: str = connection_args.get("uri", None) + user = connection_args.get("user", None) + + # Order of use is host/port, uri, address + if host is not None and port is not None: + given_address = str(host) + ":" + str(port) + elif uri is not None: + given_address = uri.split("https://")[1] + elif address is not None: + given_address = address + else: + given_address = None + logger.debug("Missing standard address type for reuse atttempt") + + # User defaults to empty string when getting connection info + if user is not None: + tmp_user = user + else: + tmp_user = "" + + # If a valid address was given, then check if a connection exists + if given_address is not None: + for con in connections.list_connections(): + addr = connections.get_connection_addr(con[0]) + if ( + con[1] + and ("address" in addr) + and (addr["address"] == given_address) + and ("user" in addr) + and (addr["user"] == tmp_user) + ): + logger.debug("Using previous connection: %s", con[0]) + return con[0] + + # Generate a new connection if one doesn't exist + alias = uuid4().hex + try: + connections.connect(alias=alias, **connection_args) + logger.debug("Created new connection using: %s", alias) + return alias + except MilvusException as e: + logger.error("Failed to create new connection using: %s", alias) + raise e + + def _init( + self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None + ) -> None: + if embeddings is not None: + self._create_collection(embeddings, metadatas) + self._extract_fields() + self._create_index() + self._create_search_params() + self._load() + + def _create_collection( + self, embeddings: list, metadatas: Optional[list[dict]] = None + ) -> None: + from pymilvus import ( + Collection, + CollectionSchema, + DataType, + FieldSchema, + MilvusException, + ) + from pymilvus.orm.types import infer_dtype_bydata + + # Determine embedding dim + dim = len(embeddings[0]) + fields = [] + # Determine metadata schema + # if metadatas: + # # Create FieldSchema for each entry in metadata. + # for key, value in metadatas[0].items(): + # # Infer the corresponding datatype of the metadata + # dtype = infer_dtype_bydata(value) + # # Datatype isn't compatible + # if dtype == DataType.UNKNOWN or dtype == DataType.NONE: + # logger.error( + # "Failure to create collection, unrecognized dtype for key: %s", + # key, + # ) + # raise ValueError(f"Unrecognized datatype for {key}.") + # # Dataype is a string/varchar equivalent + # elif dtype == DataType.VARCHAR: + # fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) + # else: + # fields.append(FieldSchema(key, dtype)) + if metadatas: + fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535)) + + # Create the text field + fields.append( + FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) + ) + # Create the primary key field + fields.append( + FieldSchema( + self._primary_field, DataType.INT64, is_primary=True, auto_id=True + ) + ) + # Create the vector field, supports binary or float vectors + fields.append( + FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) + ) + + # Create the schema for the collection + schema = CollectionSchema(fields) + + # Create the collection + try: + self.col = Collection( + name=self.collection_name, + schema=schema, + consistency_level=self.consistency_level, + using=self.alias, + ) + except MilvusException as e: + logger.error( + "Failed to create collection: %s error: %s", self.collection_name, e + ) + raise e + + def _extract_fields(self) -> None: + """Grab the existing fields from the Collection""" + from pymilvus import Collection + + if isinstance(self.col, Collection): + schema = self.col.schema + for x in schema.fields: + self.fields.append(x.name) + # Since primary field is auto-id, no need to track it + self.fields.remove(self._primary_field) + + def _get_index(self) -> Optional[dict[str, Any]]: + """Return the vector index information if it exists""" + from pymilvus import Collection + + if isinstance(self.col, Collection): + for x in self.col.indexes: + if x.field_name == self._vector_field: + return x.to_dict() + return None + + def _create_index(self) -> None: + """Create a index on the collection""" + from pymilvus import Collection, MilvusException + + if isinstance(self.col, Collection) and self._get_index() is None: + try: + # If no index params, use a default HNSW based one + if self.index_params is None: + self.index_params = { + "metric_type": "IP", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + } + + try: + self.col.create_index( + self._vector_field, + index_params=self.index_params, + using=self.alias, + ) + + # If default did not work, most likely on Zilliz Cloud + except MilvusException: + # Use AUTOINDEX based index + self.index_params = { + "metric_type": "L2", + "index_type": "AUTOINDEX", + "params": {}, + } + self.col.create_index( + self._vector_field, + index_params=self.index_params, + using=self.alias, + ) + logger.debug( + "Successfully created an index on collection: %s", + self.collection_name, + ) + + except MilvusException as e: + logger.error( + "Failed to create an index on collection: %s", self.collection_name + ) + raise e + + def _create_search_params(self) -> None: + """Generate search params based on the current index type""" + from pymilvus import Collection + + if isinstance(self.col, Collection) and self.search_params is None: + index = self._get_index() + if index is not None: + index_type: str = index["index_param"]["index_type"] + metric_type: str = index["index_param"]["metric_type"] + self.search_params = self.default_search_params[index_type] + self.search_params["metric_type"] = metric_type + + def _load(self) -> None: + """Load the collection if available.""" + from pymilvus import Collection + + if isinstance(self.col, Collection) and self._get_index() is not None: + self.col.load() + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + timeout: Optional[int] = None, + batch_size: int = 1000, + **kwargs: Any, + ) -> List[str]: + """Insert text data into Milvus. + + Inserting data when the collection has not be made yet will result + in creating a new Collection. The data of the first entity decides + the schema of the new collection, the dim is extracted from the first + embedding and the columns are decided by the first metadata dict. + Metada keys will need to be present for all inserted values. At + the moment there is no None equivalent in Milvus. + + Args: + texts (Iterable[str]): The texts to embed, it is assumed + that they all fit in memory. + metadatas (Optional[List[dict]]): Metadata dicts attached to each of + the texts. Defaults to None. + timeout (Optional[int]): Timeout for each batch insert. Defaults + to None. + batch_size (int, optional): Batch size to use for insertion. + Defaults to 1000. + + Raises: + MilvusException: Failure to add texts + + Returns: + List[str]: The resulting keys for each inserted element. + """ + from pymilvus import Collection, MilvusException + + texts = list(texts) + + try: + embeddings = self.embedding_func.embed_documents(texts) + except NotImplementedError: + embeddings = [self.embedding_func.embed_query(x) for x in texts] + + if len(embeddings) == 0: + logger.debug("Nothing to insert, skipping.") + return [] + + # If the collection hasn't been initialized yet, perform all steps to do so + if not isinstance(self.col, Collection): + self._init(embeddings, metadatas) + + # Dict to hold all insert columns + insert_dict: dict[str, list] = { + self._text_field: texts, + self._vector_field: embeddings, + } + + # Collect the metadata into the insert dict. + # if metadatas is not None: + # for d in metadatas: + # for key, value in d.items(): + # if key in self.fields: + # insert_dict.setdefault(key, []).append(value) + if metadatas is not None: + for d in metadatas: + insert_dict.setdefault(self._metadata_field, []).append(d) + + # Total insert count + vectors: list = insert_dict[self._vector_field] + total_count = len(vectors) + + pks: list[str] = [] + + assert isinstance(self.col, Collection) + for i in range(0, total_count, batch_size): + # Grab end index + end = min(i + batch_size, total_count) + # Convert dict to list of lists batch for insertion + insert_list = [insert_dict[x][i:end] for x in self.fields] + # Insert into the collection. + try: + res: Collection + res = self.col.insert(insert_list, timeout=timeout, **kwargs) + pks.extend(res.primary_keys) + except MilvusException as e: + logger.error( + "Failed to insert batch starting at entity: %s/%s", i, total_count + ) + raise e + return pks + + def similarity_search( + self, + query: str, + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a similarity search against the query string. + + Args: + query (str): The text to search. + k (int, optional): How many results to return. Defaults to 4. + param (dict, optional): The search params for the index type. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + res = self.similarity_search_with_score( + query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs + ) + return [doc for doc, _ in res] + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a similarity search against the query string. + + Args: + embedding (List[float]): The embedding vector to search. + k (int, optional): How many results to return. Defaults to 4. + param (dict, optional): The search params for the index type. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + res = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs + ) + return [doc for doc, _ in res] + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Perform a search on a query string and return results with score. + + For more information about the search parameters, take a look at the pymilvus + documentation found here: + https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md + + Args: + query (str): The text being searched. + k (int, optional): The amount of results to return. Defaults to 4. + param (dict): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[float], List[Tuple[Document, any, any]]: + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + # Embed the query text. + embedding = self.embedding_func.embed_query(query) + + res = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs + ) + return res + + def _similarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and relevance scores in the range [0, 1]. + + 0 is dissimilar, 1 is most similar. + + Args: + query: input text + k: Number of Documents to return. Defaults to 4. + **kwargs: kwargs to be passed to similarity search. Should include: + score_threshold: Optional, a floating point value between 0 to 1 to + filter the resulting set of retrieved docs + + Returns: + List of Tuples of (doc, similarity_score) + """ + return self.similarity_search_with_score(query, k, **kwargs) + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Perform a search on a query string and return results with score. + + For more information about the search parameters, take a look at the pymilvus + documentation found here: + https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md + + Args: + embedding (List[float]): The embedding vector being searched. + k (int, optional): The amount of results to return. Defaults to 4. + param (dict): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Tuple[Document, float]]: Result doc and score. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + if param is None: + param = self.search_params + + # Determine result metadata fields. + output_fields = self.fields[:] + output_fields.remove(self._vector_field) + + # Perform the search. + res = self.col.search( + data=[embedding], + anns_field=self._vector_field, + param=param, + limit=k, + expr=expr, + output_fields=output_fields, + timeout=timeout, + **kwargs, + ) + # Organize results. + ret = [] + for result in res[0]: + meta = {x: result.entity.get(x) for x in output_fields} + doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata')) + pair = (doc, result.score) + ret.append(pair) + + return ret + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a search and return results that are reordered by MMR. + + Args: + query (str): The text being searched. + k (int, optional): How many results to give. Defaults to 4. + fetch_k (int, optional): Total results to select k from. + Defaults to 20. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5 + param (dict, optional): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + embedding = self.embedding_func.embed_query(query) + + return self.max_marginal_relevance_search_by_vector( + embedding=embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + param=param, + expr=expr, + timeout=timeout, + **kwargs, + ) + + def max_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a search and return results that are reordered by MMR. + + Args: + embedding (str): The embedding vector being searched. + k (int, optional): How many results to give. Defaults to 4. + fetch_k (int, optional): Total results to select k from. + Defaults to 20. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5 + param (dict, optional): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + if param is None: + param = self.search_params + + # Determine result metadata fields. + output_fields = self.fields[:] + output_fields.remove(self._vector_field) + + # Perform the search. + res = self.col.search( + data=[embedding], + anns_field=self._vector_field, + param=param, + limit=fetch_k, + expr=expr, + output_fields=output_fields, + timeout=timeout, + **kwargs, + ) + # Organize results. + ids = [] + documents = [] + scores = [] + for result in res[0]: + meta = {x: result.entity.get(x) for x in output_fields} + doc = Document(page_content=meta.pop(self._text_field), metadata=meta) + documents.append(doc) + scores.append(result.score) + ids.append(result.id) + + vectors = self.col.query( + expr=f"{self._primary_field} in {ids}", + output_fields=[self._primary_field, self._vector_field], + timeout=timeout, + ) + # Reorganize the results from query to match search order. + vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors} + + ordered_result_embeddings = [vectors[x] for x in ids] + + # Get the new order of results. + new_ordering = maximal_marginal_relevance( + np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult + ) + + # Reorder the values and return. + ret = [] + for x in new_ordering: + # Function can return -1 index + if x == -1: + break + else: + ret.append(documents[x]) + return ret + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = "LangChainCollection", + connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION, + consistency_level: str = "Session", + index_params: Optional[dict] = None, + search_params: Optional[dict] = None, + drop_old: bool = False, + batch_size: int = 100, + ids: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> Milvus: + """Create a Milvus collection, indexes it with HNSW, and insert data. + + Args: + texts (List[str]): Text data. + embedding (Embeddings): Embedding function. + metadatas (Optional[List[dict]]): Metadata for each text if it exists. + Defaults to None. + collection_name (str, optional): Collection name to use. Defaults to + "LangChainCollection". + connection_args (dict[str, Any], optional): Connection args to use. Defaults + to DEFAULT_MILVUS_CONNECTION. + consistency_level (str, optional): Which consistency level to use. Defaults + to "Session". + index_params (Optional[dict], optional): Which index_params to use. Defaults + to None. + search_params (Optional[dict], optional): Which search params to use. + Defaults to None. + drop_old (Optional[bool], optional): Whether to drop the collection with + that name if it exists. Defaults to False. + batch_size: + How many vectors upload per-request. + Default: 100 + ids: Optional[Sequence[str]] = None, + + Returns: + Milvus: Milvus Vector Store + """ + vector_db = cls( + embedding_function=embedding, + collection_name=collection_name, + connection_args=connection_args, + consistency_level=consistency_level, + index_params=index_params, + search_params=search_params, + drop_old=drop_old, + **kwargs, + ) + vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size) + return vector_db diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py index abf57f5529..27109b602c 100644 --- a/api/core/index/vector_index/milvus_vector_index.py +++ b/api/core/index/vector_index/milvus_vector_index.py @@ -9,30 +9,46 @@ from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.milvus_vector_store import MilvusVectorStore from core.vector_store.weaviate_vector_store import WeaviateVectorStore -from models.dataset import Dataset +from extensions.ext_database import db +from models.dataset import Dataset, DatasetCollectionBinding class MilvusConfig(BaseModel): - endpoint: str + host: str + port: int user: str password: str + secure: bool batch_size: int = 100 @root_validator() def validate_config(cls, values: dict) -> dict: - if not values['endpoint']: - raise ValueError("config MILVUS_ENDPOINT is required") + if not values['host']: + raise ValueError("config MILVUS_HOST is required") + if not values['port']: + raise ValueError("config MILVUS_PORT is required") + if not values['secure']: + raise ValueError("config MILVUS_SECURE is required") if not values['user']: raise ValueError("config MILVUS_USER is required") if not values['password']: raise ValueError("config MILVUS_PASSWORD is required") return values + def to_milvus_params(self): + return { + 'host': self.host, + 'port': self.port, + 'user': self.user, + 'password': self.password, + 'secure': self.secure + } + class MilvusVectorIndex(BaseVectorIndex): def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings): super().__init__(dataset, embeddings) - self._client = self._init_client(config) + self._client_config = config def get_type(self) -> str: return 'milvus' @@ -49,7 +65,6 @@ class MilvusVectorIndex(BaseVectorIndex): dataset_id = dataset.id return "Vector_index_" + dataset_id.replace("-", "_") + '_Node' - def to_index_struct(self) -> dict: return { "type": self.get_type(), @@ -58,26 +73,29 @@ class MilvusVectorIndex(BaseVectorIndex): def create(self, texts: list[Document], **kwargs) -> BaseIndex: uuids = self._get_uuids(texts) - self._vector_store = WeaviateVectorStore.from_documents( + index_params = { + 'metric_type': 'IP', + 'index_type': "HNSW", + 'params': {"M": 8, "efConstruction": 64} + } + self._vector_store = MilvusVectorStore.from_documents( texts, self._embeddings, - client=self._client, - index_name=self.get_index_name(self.dataset), - uuids=uuids, - by_text=False + collection_name=self.get_index_name(self.dataset), + connection_args=self._client_config.to_milvus_params(), + index_params=index_params ) return self def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex: uuids = self._get_uuids(texts) - self._vector_store = WeaviateVectorStore.from_documents( + self._vector_store = MilvusVectorStore.from_documents( texts, self._embeddings, - client=self._client, - index_name=collection_name, - uuids=uuids, - by_text=False + collection_name=collection_name, + ids=uuids, + content_payload_key='page_content' ) return self @@ -86,42 +104,53 @@ class MilvusVectorIndex(BaseVectorIndex): """Only for created index.""" if self._vector_store: return self._vector_store - attributes = ['doc_id', 'dataset_id', 'document_id'] - if self._is_origin(): - attributes = ['doc_id'] - return WeaviateVectorStore( - client=self._client, - index_name=self.get_index_name(self.dataset), - text_key='text', - embedding=self._embeddings, - attributes=attributes, - by_text=False + return MilvusVectorStore( + collection_name=self.get_index_name(self.dataset), + embedding_function=self._embeddings, + connection_args=self._client_config.to_milvus_params() ) def _get_vector_store_class(self) -> type: return MilvusVectorStore def delete_by_document_id(self, document_id: str): - if self._is_origin(): - self.recreate_dataset(self.dataset) - return + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + ids = vector_store.get_ids_by_document_id(document_id) + if ids: + vector_store.del_texts({ + 'filter': f'id in {ids}' + }) + + def delete_by_ids(self, doc_ids: list[str]) -> None: + + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) + ids = vector_store.get_ids_by_doc_ids(doc_ids) + vector_store.del_texts({ + 'filter': f' id in {ids}' + }) + + def delete_by_group_id(self, group_id: str) -> None: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) - vector_store.del_texts({ - "operator": "Equal", - "path": ["document_id"], - "valueText": document_id - }) + vector_store.delete() - def _is_origin(self): - if self.dataset.index_struct_dict: - class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] - if not class_prefix.endswith('_Node'): - # original class_prefix - return True + def delete(self) -> None: + vector_store = self._get_vector_store() + vector_store = cast(self._get_vector_store_class(), vector_store) - return False + from qdrant_client.http import models + vector_store.del_texts(models.Filter( + must=[ + models.FieldCondition( + key="group_id", + match=models.MatchValue(value=self.dataset.id), + ), + ], + )) diff --git a/api/core/index/vector_index/vector_index.py b/api/core/index/vector_index/vector_index.py index ffc7aa17b6..dd3ab272e0 100644 --- a/api/core/index/vector_index/vector_index.py +++ b/api/core/index/vector_index/vector_index.py @@ -47,6 +47,20 @@ class VectorIndex: ), embeddings=embeddings ) + elif vector_type == "milvus": + from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig + + return MilvusVectorIndex( + dataset=dataset, + config=MilvusConfig( + host=config.get('MILVUS_HOST'), + port=config.get('MILVUS_PORT'), + user=config.get('MILVUS_USER'), + password=config.get('MILVUS_PASSWORD'), + secure=config.get('MILVUS_SECURE'), + ), + embeddings=embeddings + ) else: raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") diff --git a/api/core/vector_store/milvus_vector_store.py b/api/core/vector_store/milvus_vector_store.py index cc84459c12..a70445dd4c 100644 --- a/api/core/vector_store/milvus_vector_store.py +++ b/api/core/vector_store/milvus_vector_store.py @@ -1,4 +1,4 @@ -from langchain.vectorstores import Milvus +from core.index.vector_index.milvus import Milvus class MilvusVectorStore(Milvus): @@ -6,33 +6,41 @@ class MilvusVectorStore(Milvus): if not where_filter: raise ValueError('where_filter must not be empty') - self._client.batch.delete_objects( - class_name=self._index_name, - where=where_filter, - output='minimal' - ) + self.col.delete(where_filter.get('filter')) def del_text(self, uuid: str) -> None: - self._client.data_object.delete( - uuid, - class_name=self._index_name - ) + expr = f"id == {uuid}" + self.col.delete(expr) def text_exists(self, uuid: str) -> bool: - result = self._client.query.get(self._index_name).with_additional(["id"]).with_where({ - "path": ["doc_id"], - "operator": "Equal", - "valueText": uuid, - }).with_limit(1).do() + result = self.col.query( + expr=f'metadata["doc_id"] == "{uuid}"', + output_fields=["id"] + ) - if "errors" in result: - raise ValueError(f"Error during query: {result['errors']}") + return len(result) > 0 - entries = result["data"]["Get"][self._index_name] - if len(entries) == 0: - return False + def get_ids_by_document_id(self, document_id: str): + result = self.col.query( + expr=f'metadata["document_id"] == "{document_id}"', + output_fields=["id"] + ) + if result: + return [item["id"] for item in result] + else: + return None - return True + def get_ids_by_doc_ids(self, doc_ids: list): + result = self.col.query( + expr=f'metadata["doc_id"] in {doc_ids}', + output_fields=["id"] + ) + if result: + return [item["id"] for item in result] + else: + return None def delete(self): - self._client.schema.delete_class(self._index_name) + from pymilvus import utility + utility.drop_collection(self.collection_name, None, self.alias) + diff --git a/api/requirements.txt b/api/requirements.txt index a35a1b172b..a1ca193ff8 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -52,4 +52,5 @@ pandas==1.5.3 xinference==0.5.2 safetensors==0.3.2 zhipuai==1.0.7 -werkzeug==2.3.7 \ No newline at end of file +werkzeug==2.3.7 +pymilvus==2.3.0 \ No newline at end of file diff --git a/docker/milvus-standalone-docker-compose.yml b/docker/milvus-standalone-docker-compose.yml new file mode 100644 index 0000000000..ae2846c817 --- /dev/null +++ b/docker/milvus-standalone-docker-compose.yml @@ -0,0 +1,64 @@ +version: '3.5' + +services: + etcd: + container_name: milvus-etcd + image: quay.io/coreos/etcd:v3.5.5 + environment: + - ETCD_AUTO_COMPACTION_MODE=revision + - ETCD_AUTO_COMPACTION_RETENTION=1000 + - ETCD_QUOTA_BACKEND_BYTES=4294967296 + - ETCD_SNAPSHOT_COUNT=50000 + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/etcd:/etcd + command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd + healthcheck: + test: ["CMD", "etcdctl", "endpoint", "health"] + interval: 30s + timeout: 20s + retries: 3 + + minio: + container_name: milvus-minio + image: minio/minio:RELEASE.2023-03-20T20-16-18Z + environment: + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin + ports: + - "9001:9001" + - "9000:9000" + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data + command: minio server /minio_data --console-address ":9001" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 30s + timeout: 20s + retries: 3 + + standalone: + container_name: milvus-standalone + image: milvusdb/milvus:v2.3.1 + command: ["milvus", "run", "standalone"] + environment: + ETCD_ENDPOINTS: etcd:2379 + MINIO_ADDRESS: minio:9000 + common.security.authorizationEnabled: true + volumes: + - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/milvus:/var/lib/milvus + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] + interval: 30s + start_period: 90s + timeout: 20s + retries: 3 + ports: + - "19530:19530" + - "9091:9091" + depends_on: + - "etcd" + - "minio" + +networks: + default: + name: milvus