From 703aefbd17033a981074155417d00a982490a832 Mon Sep 17 00:00:00 2001
From: jyong
Date: Wed, 6 Mar 2024 13:50:26 +0800
Subject: [PATCH] add rag test

---
 .../rag/__mock/milvus_function.py             | 103 ++++++++--------
 .../rag/__mock/milvus_mock.py                 |  18 +--
 .../test_paragraph_index_processor.py         | 113 ++++++++++++++++++
 ...{test_vector_factory.py => test_qdrant.py} |   0
 4 files changed, 169 insertions(+), 65 deletions(-)
 create mode 100644 api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py
 rename api/tests/integration_tests/rag/vector/{test_vector_factory.py => test_qdrant.py} (100%)

diff --git a/api/tests/integration_tests/rag/__mock/milvus_function.py b/api/tests/integration_tests/rag/__mock/milvus_function.py
index ac8f88518e..1cdb9cf88a 100644
--- a/api/tests/integration_tests/rag/__mock/milvus_function.py
+++ b/api/tests/integration_tests/rag/__mock/milvus_function.py
@@ -1,73 +1,64 @@
-from ctypes import Union
-from typing import List, Optional, Tuple
-from qdrant_client.conversions import common_types as types
+# `Union` must come from `typing`; `ctypes.Union` is a C union base class and
+# breaks the `List[Union[str, int]]` annotations below.
+from typing import List, Union
 
 
 class MockMilvusClass(object):
-    @staticmethod
-    def get_collections() -> types.CollectionsResponse:
-        collections_response = types.CollectionsResponse(
-            collections=["test"]
-        )
-        return collections_response
-
-    @staticmethod
-    def recreate_collection() -> bool:
-        return True
-
-    @staticmethod
-    def create_payload_index() -> types.UpdateResult:
-        update_result = types.UpdateResult(
-            updated=1
-        )
-        return update_result
-
-    @staticmethod
-    def upsert() -> types.UpdateResult:
-        update_result = types.UpdateResult(
-            updated=1
-        )
-        return update_result
-
     @staticmethod
     def insert() -> List[Union[str, int]]:
-        result = ['d48632d7-c972-484a-8ed9-262490919c79']
+        # Milvus auto-generated primary keys are integers, not UUID strings.
+        result = [447829498067199697]
         return result
 
     @staticmethod
     def delete() -> List[Union[str, int]]:
-        result = ['d48632d7-c972-484a-8ed9-262490919c79']
+        result = [447829498067199697]
         return result
 
     @staticmethod
-    def scroll() -> Tuple[List[types.Record], Optional[types.PointId]]:
-
-        record = types.Record(
-            id='d48632d7-c972-484a-8ed9-262490919c79',
-            payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                     'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                                  'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
-                                  'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
-                                  'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
-                     'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
-            vector=[0.23333 for _ in range(233)]
-        )
-        return [record], 'd48632d7-c972-484a-8ed9-262490919c79'
+    def search() -> List[dict]:
+        result = [
+            {
+                'id': 447829498067199697,
+                'distance': 0.8776655793190002,
+                'entity': {
+                    'page_content': 'Dify is a company that provides a platform for the development of AI models.',
+                    'metadata': {
+                        'doc_id': '327d1cb8-15ce-4934-bede-936a13c19ace',
+                        'doc_hash': '7ee3cf010e606bb768c3bca7b1397ff651fd008ef10e56a646c537d2c8afb319',
+                        'document_id': '6c4619dd-2169-4879-b05a-b8937c98c80c',
+                        'dataset_id': 'a2f4f4eb-75eb-4432-8c5f-788100533454'
+                    }
+                }
+            }
+        ]
+        return result
 
     @staticmethod
-    def search() -> List[types.ScoredPoint]:
-        result = types.ScoredPoint(
-            id='d48632d7-c972-484a-8ed9-262490919c79',
-            payload={'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                     'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
-                                  'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
-                                  'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
-                                  'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
-                     'page_content': 'Dify is a company that provides a platform for the development of AI models.'},
-            vision=999,
-            vector=[0.23333 for _ in range(233)],
-            score=0.99
-        )
-        return [result]
+    def query() -> List[dict]:
+        result = [
+            {
+                'id': 447829498067199697,
+                'distance': 0.8776655793190002,
+                'entity': {
+                    'page_content': 'Dify is a company that provides a platform for the development of AI models.',
+                    'metadata': {
+                        'doc_id': '327d1cb8-15ce-4934-bede-936a13c19ace',
+                        'doc_hash': '7ee3cf010e606bb768c3bca7b1397ff651fd008ef10e56a646c537d2c8afb319',
+                        'document_id': '6c4619dd-2169-4879-b05a-b8937c98c80c',
+                        'dataset_id': 'a2f4f4eb-75eb-4432-8c5f-788100533454'
+                    }
+                }
+            }
+        ]
+        return result
+
+    @staticmethod
+    def create_collection_with_schema():
+        pass
+
+    @staticmethod
+    def has_collection() -> bool:
+        return True
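+
+# Quick sanity sketch (illustrative only, not part of the mocked surface):
+# a test reads the canned hits back through exactly the keys defined above.
+#
+#     hits = MockMilvusClass.search()
+#     entity = hits[0]['entity']
+#     assert entity['metadata']['dataset_id'] == 'a2f4f4eb-75eb-4432-8c5f-788100533454'
+#     assert hits[0]['distance'] > 0.5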
diff --git a/api/tests/integration_tests/rag/__mock/milvus_mock.py b/api/tests/integration_tests/rag/__mock/milvus_mock.py
index db564795ee..b5967da888 100644
--- a/api/tests/integration_tests/rag/__mock/milvus_mock.py
+++ b/api/tests/integration_tests/rag/__mock/milvus_mock.py
@@ -27,18 +27,18 @@ def mock_milvus(monkeypatch: MonkeyPatch, methods: List[Literal["get_collections
     if "connect" in methods:
-        monkeypatch.setattr(Connections, "connect", MockMilvusClass.delete())
-    if "get_collections" in methods:
-        monkeypatch.setattr(utility, "has_collection", MockMilvusClass.get_collections())
+        # Connecting is pure side effect for these tests, so a no-op stub
+        # suffices; mocking "connect" with delete() was a copy-paste slip.
+        monkeypatch.setattr(Connections, "connect", lambda *args, **kwargs: None)
+    if "has_collection" in methods:
+        monkeypatch.setattr(utility, "has_collection", lambda *args, **kwargs: MockMilvusClass.has_collection())
     if "insert" in methods:
-        monkeypatch.setattr(MilvusClient, "insert", MockMilvusClass.insert())
-    if "create_payload_index" in methods:
-        monkeypatch.setattr(QdrantClient, "create_payload_index", MockMilvusClass.create_payload_index())
-    if "upsert" in methods:
-        monkeypatch.setattr(QdrantClient, "upsert", MockMilvusClass.upsert())
-    if "scroll" in methods:
-        monkeypatch.setattr(QdrantClient, "scroll", MockMilvusClass.scroll())
+        # setattr needs a callable replacement; wrap the canned value in a
+        # lambda instead of assigning the call's result.
+        monkeypatch.setattr(MilvusClient, "insert", lambda *args, **kwargs: MockMilvusClass.insert())
+    if "query" in methods:
+        monkeypatch.setattr(MilvusClient, "query", lambda *args, **kwargs: MockMilvusClass.query())
+    if "delete" in methods:
+        monkeypatch.setattr(MilvusClient, "delete", lambda *args, **kwargs: MockMilvusClass.delete())
     if "search" in methods:
-        monkeypatch.setattr(QdrantClient, "search", MockMilvusClass.search())
+        monkeypatch.setattr(MilvusClient, "search", lambda *args, **kwargs: MockMilvusClass.search())
+    if "create_collection_with_schema" in methods:
+        monkeypatch.setattr(MilvusClient, "create_collection_with_schema",
+                            lambda *args, **kwargs: MockMilvusClass.create_collection_with_schema())
     return unpatch
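+
+
+# Usage sketch (an assumption, not part of this patch): wire the mock up in a
+# test via pytest's `monkeypatch` fixture. The method names match the branches
+# above, and the `unpatch` callable returned by `mock_milvus` presumably
+# restores the real clients afterwards.
+#
+#     def test_milvus_vector(monkeypatch):
+#         unpatch = mock_milvus(monkeypatch, methods=["connect", "has_collection",
+#                                                     "insert", "search", "delete"])
+#         ...  # exercise the Milvus-backed vector store
+#         unpatch()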
diff --git a/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py b/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py
new file mode 100644
index 0000000000..35781f4b10
--- /dev/null
+++ b/api/tests/integration_tests/rag/index_processor/test_paragraph_index_processor.py
@@ -0,0 +1,113 @@
+"""Test paragraph index processor."""
+import datetime
+import uuid
+from typing import Optional
+
+from core.rag.cleaner.clean_processor import CleanProcessor
+from core.rag.datasource.keyword.keyword_factory import Keyword
+from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.extractor.extract_processor import ExtractProcessor
+from core.rag.index_processor.index_processor_base import BaseIndexProcessor
+from core.rag.models.document import Document
+from libs import helper
+from models.dataset import Dataset
+from models.model import UploadFile
+
+
+class ParagraphIndexProcessor(BaseIndexProcessor):
+
+    def extract(self) -> list[Document]:
+        file_detail = UploadFile(
+            tenant_id='test',
+            storage_type='local',
+            key='test.txt',
+            name='test.txt',
+            size=1024,
+            extension='txt',
+            mime_type='text/plain',
+            created_by='test',
+            created_at=datetime.datetime.utcnow(),
+            used=True,
+            used_by='d48632d7-c972-484a-8ed9-262490919c79',
+            used_at=datetime.datetime.utcnow()
+        )
+        extract_setting = ExtractSetting(
+            datasource_type="upload_file",
+            upload_file=file_detail,
+            document_model='text_model'
+        )
+
+        text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
+                                             is_automatic=False)
+
+        return text_docs
+
+    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+        # Split the text documents into nodes.
+        splitter = self._get_splitter(processing_rule=kwargs.get('process_rule'),
+                                      embedding_model_instance=kwargs.get('embedding_model_instance'))
+        all_documents = []
+        for document in documents:
+            # Clean the document text.
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get('process_rule'))
+            document.page_content = document_text
+            # Parse the document into nodes.
+            document_nodes = splitter.split_documents([document])
+            split_documents = []
+            for document_node in document_nodes:
+                if document_node.page_content.strip():
+                    doc_id = str(uuid.uuid4())
+                    doc_hash = helper.generate_text_hash(document_node.page_content)
+                    document_node.metadata['doc_id'] = doc_id
+                    document_node.metadata['doc_hash'] = doc_hash
+                    # Strip a leading splitter character left over from chunking.
+                    page_content = document_node.page_content
+                    if page_content.startswith(".") or page_content.startswith("。"):
+                        page_content = page_content[1:]
+                    document_node.page_content = page_content
+                    split_documents.append(document_node)
+            all_documents.extend(split_documents)
+        return all_documents
+
+    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            vector.create(documents)
+        if with_keywords:
+            keyword = Keyword(dataset)
+            keyword.create(documents)
+
+    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
+        if dataset.indexing_technique == 'high_quality':
+            vector = Vector(dataset)
+            if node_ids:
+                vector.delete_by_ids(node_ids)
+            else:
+                vector.delete()
+        if with_keywords:
+            keyword = Keyword(dataset)
+            if node_ids:
+                keyword.delete_by_ids(node_ids)
+            else:
+                keyword.delete()
+
+    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
+                 score_threshold: float, reranking_model: dict) -> list[Document]:
+        # Set search parameters.
+        results = RetrievalService.retrieve(retrival_method=retrival_method, dataset_id=dataset.id, query=query,
+                                            top_k=top_k, score_threshold=score_threshold,
+                                            reranking_model=reranking_model)
+        # Organize results: keep only hits scoring above the threshold.
+        docs = []
+        for result in results:
+            metadata = result.metadata
+            metadata['score'] = result.score
+            if result.score > score_threshold:
+                doc = Document(page_content=result.page_content, metadata=metadata)
+                docs.append(doc)
+        return docs
diff --git a/api/tests/integration_tests/rag/vector/test_vector_factory.py b/api/tests/integration_tests/rag/vector/test_qdrant.py
similarity index 100%
rename from api/tests/integration_tests/rag/vector/test_vector_factory.py
rename to api/tests/integration_tests/rag/vector/test_qdrant.py