From 0620fa3094578e93093adf6ff41bd3bedd7863bd Mon Sep 17 00:00:00 2001
From: Jyong <76649700+JohnJyong@users.noreply.github.com>
Date: Mon, 26 Feb 2024 19:47:29 +0800
Subject: [PATCH] Feat/vdb migrate command (#2562)

Co-authored-by: jyong
---
 api/commands.py                               | 159 ++++++++++++------
 api/core/indexing_runner.py                   |   1 +
 .../datasource/vdb/milvus/milvus_vector.py    |   8 +-
 api/core/rag/datasource/vdb/vector_factory.py |  18 ++
 .../vdb/weaviate/weaviate_vector.py           |   5 +-
 5 files changed, 134 insertions(+), 57 deletions(-)

diff --git a/api/commands.py b/api/commands.py
index 91b50445e6..62b3552dff 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,20 +1,21 @@
 import base64
 import json
 import secrets
+from typing import cast

 import click
 from flask import current_app
 from werkzeug.exceptions import NotFound

-from core.embedding.cached_embedding import CacheEmbedding
-from core.model_manager import ModelManager
-from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models.account import Tenant
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
+from models.dataset import Document as DatasetDocument
 from models.model import Account
 from models.provider import Provider, ProviderModel
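
The import block above trades the Qdrant-specific plumbing (CacheEmbedding, ModelManager, ModelType) for the storage-agnostic Vector factory, the rag Document model, and the dataset models the new command queries. A minimal sketch of how the two new imports are meant to be used together, assuming dataset is a models.dataset.Dataset row loaded elsewhere and that the factory picks the backend from the configured VECTOR_STORE setting:

    # Sketch only: the vdb-migrate round trip for one dataset.
    # `dataset` is assumed to be an existing models.dataset.Dataset row.
    from core.rag.datasource.vdb.vector_factory import Vector
    from core.rag.models.document import Document

    vector = Vector(dataset)                 # backend resolved from the configured vector store
    vector.delete()                          # drop the old collection, if any
    docs = [Document(page_content="segment text",
                     metadata={"doc_id": "node-1", "doc_hash": "hash-1"})]
    vector.create(docs)                      # embed and write into the new store
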
""" - click.echo(click.style('Start create qdrant indexes.', fg='green')) + click.echo(click.style('Start migrate vector db.', fg='green')) create_count = 0 - + config = cast(dict, current_app.config) + vector_type = config.get('VECTOR_STORE') page = 1 while True: try: @@ -140,54 +142,101 @@ def create_qdrant_indexes(): except NotFound: break - model_manager = ModelManager() - page += 1 for dataset in datasets: - if dataset.index_struct_dict: - if dataset.index_struct_dict['type'] != 'qdrant': - try: - click.echo('Create dataset qdrant index: {}'.format(dataset.id)) - try: - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - - ) - except Exception: - continue - embeddings = CacheEmbedding(embedding_model) - - from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex - - index = QdrantVectorIndex( - dataset=dataset, - config=QdrantConfig( - endpoint=current_app.config.get('QDRANT_URL'), - api_key=current_app.config.get('QDRANT_API_KEY'), - root_path=current_app.root_path - ), - embeddings=embeddings - ) - if index: - index.create_qdrant_dataset(dataset) - index_struct = { - "type": 'qdrant', - "vector_store": { - "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']} - } - dataset.index_struct = json.dumps(index_struct) - db.session.commit() - create_count += 1 - else: - click.echo('passed.') - except Exception as e: - click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + try: + click.echo('Create dataset vdb index: {}'.format(dataset.id)) + if dataset.index_struct_dict: + if dataset.index_struct_dict['type'] == vector_type: continue + if vector_type == "weaviate": + dataset_id = dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + index_struct_dict = { + "type": 'weaviate', + "vector_store": {"class_prefix": collection_name} + } + dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == "qdrant": + if dataset.collection_binding_id: + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.id == dataset.collection_binding_id). 
@@ -196,4 +245,4 @@ def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(reset_encrypt_key_pair)
-    app.cli.add_command(create_qdrant_indexes)
+    app.cli.add_command(vdb_migrate)
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index d2d04c984b..68bb294a18 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -664,6 +664,7 @@ class IndexingRunner:
         )
         # load index
         index_processor.load(dataset, chunk_documents)
+        db.session.add(dataset)

         document_ids = [document.metadata['doc_id'] for document in chunk_documents]
         db.session.query(DocumentSegment).filter(
diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index 9a251ede97..bb12ef1b56 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -127,9 +127,15 @@ class MilvusVector(BaseVector):
         self._client.delete(collection_name=self._collection_name, pks=doc_ids)

     def delete(self) -> None:
+        alias = uuid4().hex
+        if self._client_config.secure:
+            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        else:
+            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
+        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)

         from pymilvus import utility
-        utility.drop_collection(self._collection_name, None)
+        utility.drop_collection(self._collection_name, None, using=alias)

     def text_exists(self, id: str) -> bool:
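
The Milvus delete() now opens its own connection under a throwaway alias instead of relying on whichever default connection happens to be registered, and drops the collection through that alias. A rough standalone sketch of the same pymilvus pattern; host, credentials and the collection name are placeholders, and the explicit disconnect at the end is an addition this patch itself does not make:

    # Sketch of the alias-per-operation pattern; connection values are placeholders.
    from uuid import uuid4
    from pymilvus import connections, utility

    alias = uuid4().hex                                   # unique alias for this one operation
    connections.connect(alias=alias, uri="http://127.0.0.1:19530",
                        user="root", password="Milvus")
    try:
        utility.drop_collection("Vector_index_example_Node", None, using=alias)
    finally:
        connections.disconnect(alias)                     # cleanup not done in the patch itself
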
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index dd8fc93041..619f7d6487 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -1,3 +1,4 @@
+import json
 from typing import Any, cast

 from flask import current_app
@@ -39,6 +40,11 @@ class Vector:
             else:
                 dataset_id = self._dataset.id
                 collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                index_struct_dict = {
+                    "type": 'weaviate',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
             return WeaviateVector(
                 collection_name=collection_name,
                 config=WeaviateConfig(
@@ -66,6 +72,13 @@
             dataset_id = self._dataset.id
             collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

+            if not self._dataset.index_struct_dict:
+                index_struct_dict = {
+                    "type": 'qdrant',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
+
             return QdrantVector(
                 collection_name=collection_name,
                 group_id=self._dataset.id,
@@ -84,6 +97,11 @@
             else:
                 dataset_id = self._dataset.id
                 collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                index_struct_dict = {
+                    "type": 'milvus',
+                    "vector_store": {"class_prefix": collection_name}
+                }
+                self._dataset.index_struct = json.dumps(index_struct_dict)
             return MilvusVector(
                 collection_name=collection_name,
                 config=MilvusConfig(
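
The net effect of the factory changes is that whichever backend gets initialised, the dataset row ends up carrying a small JSON blob describing the collection it was written to, so later lookups resolve the same collection. Roughly what that persisted value looks like for a Qdrant-backed dataset (the class_prefix shown is illustrative):

    # Approximate shape of dataset.index_struct after the Vector factory
    # initialises a Qdrant-backed dataset; the class_prefix is illustrative.
    import json

    index_struct = json.dumps({
        "type": "qdrant",
        "vector_store": {"class_prefix": "Vector_index_0b6a8c2f_1d2e_4c3a_9f4b_5a6d7e8f9a0b_Node"}
    })
    print(index_struct)
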
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index 5c3a810fbf..78033379d6 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -127,7 +127,10 @@ class WeaviateVector(BaseVector):
         )

     def delete(self):
-        self._client.schema.delete_class(self._collection_name)
+        # check whether the index already exists
+        schema = self._default_schema(self._collection_name)
+        if self._client.schema.contains(schema):
+            self._client.schema.delete_class(self._collection_name)

     def text_exists(self, id: str) -> bool:
         collection_name = self._collection_name
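
The Weaviate guard leans on the v3 Python client: schema.contains() is given the class schema dict that _default_schema builds and schema.delete_class() takes the class name, so delete() becomes a no-op when the class was never created. A standalone sketch under those assumptions; the URL, class name, and the single-class schema shape are all placeholders inferred from the patch rather than taken from it:

    # Standalone sketch of the same guard; URL, class name and schema shape are assumptions.
    import weaviate

    client = weaviate.Client("http://localhost:8080")
    class_name = "Vector_index_example_Node"
    schema = {"class": class_name, "properties": [{"name": "text", "dataType": ["text"]}]}
    if client.schema.contains(schema):            # skip deletion if the class does not exist
        client.schema.delete_class(class_name)
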