Compare commits
6 Commits
main
...
feat/add-r
Author | SHA1 | Date | |
---|---|---|---|
![]() |
3622691f38 | ||
![]() |
52e6f458be | ||
![]() |
703aefbd17 | ||
![]() |
cc84d07765 | ||
![]() |
4ea468b52a | ||
![]() |
796f7d4d29 |
@ -277,7 +277,7 @@ class QdrantVector(BaseVector):
|
||||
from qdrant_client.http import models
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self._group_id),
|
||||
),
|
||||
|
9
api/core/rag/test.py
Normal file
9
api/core/rag/test.py
Normal file
@ -0,0 +1,9 @@
|
||||
from llama_index.core import DocumentSummaryIndex
|
||||
|
||||
# Build a document-summary index from the city documents with progress
# reporting switched on.
# NOTE(review): city_docs, chatgpt, splitter and response_synthesizer are not
# defined or imported in the visible part of this file — confirm where they
# are supposed to come from before this module is imported.
doc_summary_index = DocumentSummaryIndex.from_documents(
    city_docs,
    llm=chatgpt,
    transformations=[splitter],
    response_synthesizer=response_synthesizer,
    show_progress=True,
)
|
64
api/tests/integration_tests/rag/__mock/milvus_function.py
Normal file
64
api/tests/integration_tests/rag/__mock/milvus_function.py
Normal file
@ -0,0 +1,64 @@
|
||||
from ctypes import Union
|
||||
from typing import List
|
||||
|
||||
|
||||
class MockMilvusClass(object):
    """Canned return values that stand in for Milvus client calls in tests.

    Every stub accepts and ignores arbitrary positional and keyword
    arguments, so it can be called with whatever the real client method
    would receive (including a bound ``self``).

    The return annotations use ``List[int]`` because the stubs only ever
    return integer primary keys; the ``Union`` name imported by this module
    comes from ``ctypes`` and cannot be subscripted, so it is not used here.
    """

    # Primary key reported by the insert/delete stubs and carried by the
    # single canned search/query hit.
    _PRIMARY_KEY = 447829498067199697

    @staticmethod
    def _hit() -> dict:
        """Build a fresh copy of the canned hit so callers may mutate it."""
        return {
            'id': MockMilvusClass._PRIMARY_KEY,
            'distance': 0.8776655793190002,
            'entity': {
                'page_content': 'Dify is a company that provides a platform for the development of AI models.',
                'metadata':
                    {
                        'doc_id': '327d1cb8-15ce-4934-bede-936a13c19ace',
                        'doc_hash': '7ee3cf010e606bb768c3bca7b1397ff651fd008ef10e56a646c537d2c8afb319',
                        'document_id': '6c4619dd-2169-4879-b05a-b8937c98c80c',
                        'dataset_id': 'a2f4f4eb-75eb-4432-8c5f-788100533454'
                    }
            }
        }

    @staticmethod
    def insert(*args, **kwargs) -> List[int]:
        """Return the primary keys of the (pretend) inserted rows."""
        return [MockMilvusClass._PRIMARY_KEY]

    @staticmethod
    def delete(*args, **kwargs) -> List[int]:
        """Return the primary keys of the (pretend) deleted rows."""
        return [MockMilvusClass._PRIMARY_KEY]

    @staticmethod
    def search(*args, **kwargs) -> List[dict]:
        """Return a one-element list holding the canned hit."""
        return [MockMilvusClass._hit()]

    @staticmethod
    def query(*args, **kwargs) -> List[dict]:
        """Return a one-element list holding the canned hit."""
        return [MockMilvusClass._hit()]

    @staticmethod
    def create_collection_with_schema(*args, **kwargs):
        """Do nothing; collection creation is a no-op in the mock."""
        pass

    @staticmethod
    def has_collection(*args, **kwargs) -> bool:
        """Report that the collection always exists."""
        return True
|
||||
|
58
api/tests/integration_tests/rag/__mock/milvus_mock.py
Normal file
58
api/tests/integration_tests/rag/__mock/milvus_mock.py
Normal file
@ -0,0 +1,58 @@
|
||||
import os
|
||||
from typing import Callable, List, Literal
|
||||
|
||||
import pytest
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from pymilvus import Connections, MilvusClient
|
||||
from pymilvus.orm import utility
|
||||
from qdrant_client import QdrantClient
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
from unstructured.partition.md import partition_md
|
||||
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
|
||||
from tests.integration_tests.rag.__mock.milvus_function import MockMilvusClass
|
||||
from tests.integration_tests.rag.__mock.qdrant_function import MockQdrantClass
|
||||
from tests.integration_tests.rag.__mock.unstructured_function import MockUnstructuredClass
|
||||
|
||||
|
||||
def mock_milvus(
    monkeypatch: MonkeyPatch,
    methods: List[Literal["connect", "has_collection", "insert", "query", "delete", "search",
                          "create_collection_with_schema"]],
) -> Callable[[], None]:
    """
    mock milvus module

    Replace the requested pymilvus entry points with stubs that return the
    canned values from MockMilvusClass.

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: names of the entry points to replace; other names are ignored
    :return: unpatch function
    """
    def unpatch() -> None:
        monkeypatch.undo()

    def canned(factory):
        # The patched attribute must itself be callable. Installing the
        # factory's return value (as the previous code did) leaves a plain
        # list/bool/None in place of the method, which fails as soon as the
        # code under test calls it. The stub swallows whatever arguments the
        # real method would receive, including a bound ``self``.
        def stub(*args, **kwargs):
            return factory()
        return stub

    # name in ``methods`` -> (object to patch, factory producing the canned value)
    targets = {
        # NOTE(review): "connect" reuses the delete stub's return value, as in
        # the original code — confirm that this is intended.
        "connect": (Connections, MockMilvusClass.delete),
        "has_collection": (utility, MockMilvusClass.has_collection),
        "insert": (MilvusClient, MockMilvusClass.insert),
        "query": (MilvusClient, MockMilvusClass.query),
        "delete": (MilvusClient, MockMilvusClass.delete),
        "search": (MilvusClient, MockMilvusClass.search),
        "create_collection_with_schema": (MilvusClient, MockMilvusClass.create_collection_with_schema),
    }
    for name, (owner, factory) in targets.items():
        if name in methods:
            monkeypatch.setattr(owner, name, canned(factory))

    return unpatch
|
||||
|
||||
|
||||
# Mocking is enabled only when the MOCK_SWITCH environment variable equals
# "true" (compared case-insensitively); it defaults to "false".
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_milvus_mock(request, monkeypatch):
    """Patch Milvus for the duration of a test when MOCK is enabled.

    The method names to patch are taken from ``request.param`` when the
    fixture is parametrized, otherwise no methods are requested. The patches
    are undone after the test body has run.
    """
    requested = getattr(request, 'param', [])

    if not MOCK:
        # Nothing to patch and nothing to undo.
        yield
        return

    undo = mock_milvus(monkeypatch, methods=requested)
    yield
    undo()
|
||||
|
68
api/tests/integration_tests/rag/__mock/qdrant_function.py
Normal file
68
api/tests/integration_tests/rag/__mock/qdrant_function.py
Normal file
@ -0,0 +1,68 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from qdrant_client.conversions import common_types as types
|
||||
|
||||
|
||||
class MockQdrantClass(object):
    """Canned Qdrant responses that stand in for QdrantClient calls in tests.

    Every stub accepts and ignores arbitrary positional and keyword
    arguments, so it can be called with whatever the real client method
    would receive (including a bound ``self``).
    """

    # Point id shared by the canned record and the canned scored point.
    _POINT_ID = 'd48632d7-c972-484a-8ed9-262490919c79'

    @staticmethod
    def _payload() -> dict:
        """Build a fresh copy of the payload stored with the canned point."""
        return {'group_id': '06798db6-1f99-489a-b599-dd386a043f2d',
                'metadata': {'dataset_id': '06798db6-1f99-489a-b599-dd386a043f2d',
                             'doc_hash': '85197672a2c2b05d2c8690cb7f1eedc78fe5f0ca7b8ae8a301f64eb8d959b436',
                             'doc_id': 'd48632d7-c972-484a-8ed9-262490919c79',
                             'document_id': '1518a57d-9049-426e-99ae-5a6d479175c0'},
                'page_content': 'Dify is a company that provides a platform for the development of AI models.'}

    @staticmethod
    def _update_result() -> types.UpdateResult:
        """Build the result shared by every write-style stub.

        UpdateResult is described by an operation id and a status; the
        previous ``updated=1`` keyword is not one of its fields.
        """
        return types.UpdateResult(operation_id=1, status='completed')

    @staticmethod
    def get_collections(*args, **kwargs) -> types.CollectionsResponse:
        """Return a response listing a single collection named "test"."""
        # Each entry is a collection description (a mapping with a ``name``),
        # not a bare string.
        return types.CollectionsResponse(collections=[{'name': 'test'}])

    @staticmethod
    def recreate_collection(*args, **kwargs) -> bool:
        """Report that the collection was recreated."""
        return True

    @staticmethod
    def create_payload_index(*args, **kwargs) -> types.UpdateResult:
        """Return a completed update result."""
        return MockQdrantClass._update_result()

    @staticmethod
    def upsert(*args, **kwargs) -> types.UpdateResult:
        """Return a completed update result."""
        return MockQdrantClass._update_result()

    @staticmethod
    def delete(*args, **kwargs) -> types.UpdateResult:
        """Return a completed update result."""
        return MockQdrantClass._update_result()

    @staticmethod
    def scroll(*args, **kwargs) -> Tuple[List[types.Record], Optional[types.PointId]]:
        """Return one canned record together with its id as the next offset."""
        record = types.Record(
            id=MockQdrantClass._POINT_ID,
            payload=MockQdrantClass._payload(),
            vector=[0.23333 for _ in range(233)]
        )
        return [record], MockQdrantClass._POINT_ID

    @staticmethod
    def search(*args, **kwargs) -> List[types.ScoredPoint]:
        """Return a single canned scored point."""
        result = types.ScoredPoint(
            id=MockQdrantClass._POINT_ID,
            payload=MockQdrantClass._payload(),
            # The field is ``version``; the previous ``vision`` keyword was a typo.
            version=999,
            vector=[0.23333 for _ in range(233)],
            score=0.99
        )
        return [result]
|
55
api/tests/integration_tests/rag/__mock/qdrant_mock.py
Normal file
55
api/tests/integration_tests/rag/__mock/qdrant_mock.py
Normal file
@ -0,0 +1,55 @@
|
||||
import os
|
||||
from typing import Callable, List, Literal
|
||||
|
||||
import pytest
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from qdrant_client import QdrantClient
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
from unstructured.partition.md import partition_md
|
||||
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
|
||||
from tests.integration_tests.rag.__mock.qdrant_function import MockQdrantClass
|
||||
from tests.integration_tests.rag.__mock.unstructured_function import MockUnstructuredClass
|
||||
|
||||
|
||||
def mock_qdrant(
    monkeypatch: MonkeyPatch,
    methods: List[Literal["get_collections", "delete", "recreate_collection", "create_payload_index",
                          "upsert", "scroll", "search"]],
) -> Callable[[], None]:
    """
    mock qdrant module

    Replace the requested QdrantClient methods with stubs that return the
    canned values from MockQdrantClass.

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: names of the QdrantClient methods to replace; other names are ignored
    :return: unpatch function
    """
    def unpatch() -> None:
        monkeypatch.undo()

    def canned(factory):
        # The patched attribute must itself be callable. Installing the
        # factory's return value (as the previous code did) leaves a response
        # object in place of the method, which fails as soon as the code
        # under test calls it. The stub swallows whatever arguments the real
        # method would receive, including the bound client instance.
        def stub(*args, **kwargs):
            return factory()
        return stub

    # QdrantClient method name -> factory producing the canned value
    factories = {
        "delete": MockQdrantClass.delete,
        "get_collections": MockQdrantClass.get_collections,
        "recreate_collection": MockQdrantClass.recreate_collection,
        "create_payload_index": MockQdrantClass.create_payload_index,
        "upsert": MockQdrantClass.upsert,
        "scroll": MockQdrantClass.scroll,
        "search": MockQdrantClass.search,
    }
    for name, factory in factories.items():
        if name in methods:
            monkeypatch.setattr(QdrantClient, name, canned(factory))

    return unpatch
|
||||
|
||||
|
||||
# Mocking is enabled only when the MOCK_SWITCH environment variable equals
# "true" (compared case-insensitively); it defaults to "false".
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_qdrant_mock(request, monkeypatch):
    """Patch QdrantClient for the duration of a test when MOCK is enabled.

    The method names to patch are taken from ``request.param`` when the
    fixture is parametrized, otherwise no methods are requested. The patches
    are undone after the test body has run.
    """
    requested = getattr(request, 'param', [])

    if not MOCK:
        # Nothing to patch and nothing to undo.
        yield
        return

    undo = mock_qdrant(monkeypatch, methods=requested)
    yield
    undo()
|
||||
|
@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
|
||||
from unstructured.documents.elements import Element
|
||||
|
||||
|
||||
class MockUnstructuredClass(object):
    """Canned element lists standing in for unstructured's partition/chunk calls.

    Every stub accepts and ignores arbitrary positional and keyword
    arguments, so it can be called the same way as the function it replaces.
    All three stubs return the same single-element list.
    """

    @staticmethod
    def _elements() -> List[Element]:
        """Build a fresh one-element list holding the canned element."""
        # NOTE(review): these keyword arguments are carried over unchanged
        # from the original stubs — confirm that the installed unstructured
        # version's Element accepts them.
        element = Element(
            category="title",
            embeddings=[],
            id="test",
            metadata={},
            text="test"
        )
        return [element]

    @staticmethod
    def partition_md(*args, **kwargs) -> List[Element]:
        """Return the canned element list in place of markdown partitioning."""
        return MockUnstructuredClass._elements()

    @staticmethod
    def partition_text(*args, **kwargs) -> List[Element]:
        """Return the canned element list in place of text partitioning."""
        return MockUnstructuredClass._elements()

    @staticmethod
    def chunk_by_title(*args, **kwargs) -> List[Element]:
        """Return the canned element list in place of title-based chunking."""
        return MockUnstructuredClass._elements()
|
45
api/tests/integration_tests/rag/__mock/unstructured_mock.py
Normal file
45
api/tests/integration_tests/rag/__mock/unstructured_mock.py
Normal file
@ -0,0 +1,45 @@
|
||||
import os
|
||||
from typing import Callable, List, Literal
|
||||
|
||||
import pytest
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from unstructured.chunking import title
|
||||
from unstructured.partition import md, text
|
||||
from tests.integration_tests.rag.__mock.unstructured_function import MockUnstructuredClass
|
||||
|
||||
|
||||
def mock_unstructured(
    monkeypatch: MonkeyPatch,
    methods: List[Literal["partition_md", "partition_text", "chunk_by_title"]],
) -> Callable[[], None]:
    """
    mock unstructured module

    Replace the requested unstructured functions with stubs that return the
    canned values from MockUnstructuredClass.

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: names of the functions to replace; other names are ignored
    :return: unpatch function
    """
    def unpatch() -> None:
        monkeypatch.undo()

    def canned(factory):
        # The patched attribute must itself be callable. Installing the
        # factory's return value (as the previous code did) leaves a list in
        # place of the function, which fails as soon as the code under test
        # calls it. The stub swallows whatever arguments the real function
        # would receive.
        def stub(*args, **kwargs):
            return factory()
        return stub

    # function name -> (module to patch, factory producing the canned value)
    targets = {
        "partition_md": (md, MockUnstructuredClass.partition_md),
        "partition_text": (text, MockUnstructuredClass.partition_text),
        "chunk_by_title": (title, MockUnstructuredClass.chunk_by_title),
    }
    for name, (module, factory) in targets.items():
        if name in methods:
            monkeypatch.setattr(module, name, canned(factory))

    return unpatch
|
||||
|
||||
|
||||
# Mocking is enabled only when the MOCK_SWITCH environment variable equals
# "true" (compared case-insensitively); it defaults to "false".
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_unstructured_mock(request, monkeypatch):
    """Patch unstructured for the duration of a test when MOCK is enabled.

    The function names to patch are taken from ``request.param`` when the
    fixture is parametrized, otherwise no functions are requested. The
    patches are undone after the test body has run.
    """
    requested = getattr(request, 'param', [])

    if not MOCK:
        # Nothing to patch and nothing to undo.
        yield
        return

    undo = mock_unstructured(monkeypatch, methods=requested)
    yield
    undo()
|
||||
|
39
api/tests/integration_tests/rag/__mock/weaviate_function.py
Normal file
39
api/tests/integration_tests/rag/__mock/weaviate_function.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import List, Optional, Tuple
|
||||
from qdrant_client.conversions import common_types as types
|
||||
|
||||
|
||||
class MockWeaviateClass(object):
    """Canned return values that stand in for Weaviate client calls in tests.

    Every stub accepts and ignores arbitrary positional and keyword
    arguments, so it can be called with whatever the real client method
    would receive (including a bound ``self``).
    """

    @staticmethod
    def contains(*args, **kwargs) -> bool:
        """Report that the schema is always present."""
        return True

    @staticmethod
    def add_data_object(*args, **kwargs) -> str:
        """Return the id of the (pretend) stored object."""
        return 'd48632d7-c972-484a-8ed9-262490919c79'

    @staticmethod
    def delete_class(*args, **kwargs) -> None:
        """Do nothing; class deletion is a no-op in the mock."""
        return None

    @staticmethod
    def do(*args, **kwargs) -> dict:
        """Return a canned query response holding a single node."""
        node = {
            '_additional': {
                'distance': 0.10660946,
                'vector': [0.23333] * 233
            },
            'dataset_id': 'a5f66ab4-cc83-4061-85a5-cb775933d52a',
            'doc_hash': '52c3c8889c34d2d7b50bb04ca4d77081b1b4b625bc69c82294abfbdf7e918c21',
            'doc_id': 'b3fdec03-99ad-4a7c-a565-94d02dcde05e',
            'document_id': '71ec7e68-c45a-4d8b-886b-6077730a83ee',
            'text': '1、你知道孙悟空是从哪里生出来的吗?'
        }
        return {
            'Get': {
                'Vector_index_a5f66ab4_cc83_4061_85a5_cb775933d52a_Node': [node]
            }
        }
|
||||
|
53
api/tests/integration_tests/rag/__mock/weaviate_mock.py
Normal file
53
api/tests/integration_tests/rag/__mock/weaviate_mock.py
Normal file
@ -0,0 +1,53 @@
|
||||
import os
|
||||
from typing import Callable, List, Literal
|
||||
|
||||
import pytest
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from qdrant_client import QdrantClient
|
||||
from unstructured.chunking.title import chunk_by_title
|
||||
from unstructured.partition.md import partition_md
|
||||
from weaviate.batch import Batch
|
||||
from weaviate.gql.get import GetBuilder
|
||||
from weaviate.schema import Schema
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
|
||||
from tests.integration_tests.rag.__mock.qdrant_function import MockQdrantClass
|
||||
from tests.integration_tests.rag.__mock.weaviate_function import MockWeaviateClass
|
||||
|
||||
|
||||
def mock_weaviate(
    monkeypatch: MonkeyPatch,
    methods: List[Literal["delete", "contains", "add_data_object", "do"]],
) -> Callable[[], None]:
    """
    mock weaviate module

    Replace the requested Weaviate client methods with stubs that return the
    canned values from MockWeaviateClass.

    :param monkeypatch: pytest monkeypatch fixture
    :param methods: names of the operations to replace; other names are ignored
    :return: unpatch function
    """
    def unpatch() -> None:
        monkeypatch.undo()

    def canned(factory):
        # The patched attribute must itself be callable. Installing the
        # factory's return value (as the previous code did) leaves a plain
        # value in place of the method, which fails as soon as the code under
        # test calls it. The stub swallows whatever arguments the real method
        # would receive, including a bound ``self``.
        def stub(*args, **kwargs):
            return factory()
        return stub

    # name in ``methods`` -> (class to patch, attribute to replace, factory)
    targets = {
        # NOTE(review): this patches Schema.delete_class, matching the stub's
        # name; the original targeted an attribute called "delete" — confirm
        # against the installed weaviate client.
        "delete": (Schema, "delete_class", MockWeaviateClass.delete_class),
        "contains": (Schema, "contains", MockWeaviateClass.contains),
        "add_data_object": (Batch, "add_data_object", MockWeaviateClass.add_data_object),
        "do": (GetBuilder, "do", MockWeaviateClass.do),
    }
    for name, (owner, attribute, factory) in targets.items():
        if name in methods:
            monkeypatch.setattr(owner, attribute, canned(factory))

    return unpatch
|
||||
|
||||
|
||||
# Mocking is enabled only when the MOCK_SWITCH environment variable equals
# "true" (compared case-insensitively); it defaults to "false".
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_weaviate_mock(request, monkeypatch):
    """Patch the Weaviate client for the duration of a test when MOCK is enabled.

    The operation names to patch are taken from ``request.param`` when the
    fixture is parametrized, otherwise no operations are requested. The
    patches are undone after the test body has run.
    """
    requested = getattr(request, 'param', [])

    if not MOCK:
        # Nothing to patch and nothing to undo.
        yield
        return

    undo = mock_weaviate(monkeypatch, methods=requested)
    yield
    undo()
|
||||
|
@ -0,0 +1,20 @@
|
||||
import os
|
||||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_xlsx():
    """ExcelExtractor turns the test.xlsx asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.xlsx')

    documents = ExcelExtractor(test_file_path).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
22
api/tests/integration_tests/rag/extractor/test_test_csv.py
Normal file
22
api/tests/integration_tests/rag/extractor/test_test_csv.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_csv():
    """CSVExtractor turns the test.csv asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.csv')

    documents = CSVExtractor(test_file_path, autodetect_encoding=True).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
25
api/tests/integration_tests/rag/extractor/test_test_docx.py
Normal file
25
api/tests/integration_tests/rag/extractor/test_test_docx.py
Normal file
@ -0,0 +1,25 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_docx():
    """WordExtractor turns the test.docx asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.docx')

    documents = WordExtractor(test_file_path).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
22
api/tests/integration_tests/rag/extractor/test_test_html.py
Normal file
22
api/tests/integration_tests/rag/extractor/test_test_html.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_html():
    """HtmlExtractor turns the test.html asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.html')

    documents = HtmlExtractor(test_file_path).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_markdown():
    """MarkdownExtractor turns the test.md asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.md')

    documents = MarkdownExtractor(test_file_path, autodetect_encoding=True).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
24
api/tests/integration_tests/rag/extractor/test_test_pdf.py
Normal file
24
api/tests/integration_tests/rag/extractor/test_test_pdf.py
Normal file
@ -0,0 +1,24 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_pdf():
    """PdfExtractor turns the test.pdf asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.pdf')

    documents = PdfExtractor(test_file_path).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
21
api/tests/integration_tests/rag/extractor/test_test_text.py
Normal file
21
api/tests/integration_tests/rag/extractor/test_test_text.py
Normal file
@ -0,0 +1,21 @@
|
||||
import os
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_text():
    """TextExtractor turns the test.txt asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.txt')

    documents = TextExtractor(test_file_path, autodetect_encoding=True).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
@ -0,0 +1,27 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor
|
||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_unstructured_docx():
    """UnstructuredWordExtractor turns the test.docx asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.docx')

    # The API URL is read from the environment and is None when it is unset.
    api_url = os.getenv('UNSTRUCTURED_API_URL')
    documents = UnstructuredWordExtractor(test_file_path, api_url).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
@ -0,0 +1,25 @@
|
||||
import os
|
||||
import pytest
|
||||
from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.rag.__mock.unstructured_mock import setup_unstructured_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
def test_extract_unstructured_markdown(setup_unstructured_mock):
    """UnstructuredMarkdownExtractor turns the test.md asset into a list of Document objects.

    The fixture is asked to mock ``partition_md`` and ``chunk_by_title``.
    """
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.md')

    # The API URL is read from the environment and is None when it is unset.
    api_url = os.getenv('UNSTRUCTURED_API_URL')
    documents = UnstructuredMarkdownExtractor(test_file_path, api_url).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
@ -0,0 +1,102 @@
|
||||
"""test paragraph index processor."""
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Optional
|
||||
import pytest
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import Document
|
||||
from libs import helper
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
@pytest.mark.parametrize('setup_unstructured_mock', [['partition_md', 'chunk_by_title']], indirect=True)
def extract():
    """Run the extract, transform and load steps for a sample text upload.

    Builds an UploadFile record for ``test.txt``, extracts documents from it,
    transforms them with the ``text_model`` index processor and finally hands
    them to a Vector for creation.

    NOTE(review): the parametrize mark names the ``setup_unstructured_mock``
    fixture, but this function takes no such argument and its name has no
    ``test_`` prefix — confirm whether it is meant to be collected by pytest.
    """

    index_processor = IndexProcessorFactory('text_model').init_index_processor()

    # extract
    file_detail = UploadFile(
        tenant_id='test',
        storage_type='local',
        key='test.txt',
        name='test.txt',
        size=1024,
        extension='txt',
        mime_type='text/plain',
        created_by='test',
        created_at=datetime.datetime.utcnow(),
        used=True,
        used_by='d48632d7-c972-484a-8ed9-262490919c79',
        used_at=datetime.datetime.utcnow()
    )
    extract_setting = ExtractSetting(
        datasource_type="upload_file",
        upload_file=file_detail,
        document_model='text_model'
    )

    text_docs = ExtractProcessor.extract(extract_setting=extract_setting,
                                         is_automatic=True)
    assert isinstance(text_docs, list)
    for text_doc in text_docs:
        assert isinstance(text_doc, Document)

    # transform
    process_rule = {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': False}
        ],
        'segmentation': {
            'delimiter': '\n',
            'max_tokens': 500,
            'chunk_overlap': 50
        }
    }
    documents = index_processor.transform(text_docs, embedding_model_instance=None,
                                          process_rule=process_rule)
    for document in documents:
        assert isinstance(document, Document)

    # load
    # NOTE(review): `dataset` is not defined in the visible part of this
    # file — confirm where it is supposed to come from.
    vector = Vector(dataset)
    vector.create(documents)
|
||||
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
    """Remove index entries for a dataset.

    The vector index is cleaned only when the dataset's indexing technique is
    ``'high_quality'``; the keyword index is cleaned only when
    ``with_keywords`` is true. For each index, the given ``node_ids`` are
    deleted when the list is non-empty, otherwise everything is deleted.

    :param dataset: dataset whose index entries are removed
    :param node_ids: ids to delete; a falsy value deletes the whole index
    :param with_keywords: whether the keyword index is cleaned as well
    """
    def _purge(store):
        # A falsy node_ids (None or an empty list) wipes the whole store.
        if not node_ids:
            store.delete()
        else:
            store.delete_by_ids(node_ids)

    if dataset.indexing_technique == 'high_quality':
        _purge(Vector(dataset))

    if with_keywords:
        _purge(Keyword(dataset))
|
||||
|
||||
|
||||
def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
             score_threshold: float, reranking_model: dict) -> list[Document]:
    """Query the retrieval service and keep the hits above the score threshold.

    Each hit's metadata gets a ``'score'`` entry, whether or not the hit is
    kept. Hits whose score is strictly greater than ``score_threshold`` are
    returned as new Document objects.

    :param retrival_method: retrieval method forwarded to RetrievalService
    :param query: query text
    :param dataset: dataset whose ``id`` is searched
    :param top_k: forwarded to RetrievalService
    :param score_threshold: forwarded to RetrievalService and used as the cut-off here
    :param reranking_model: forwarded to RetrievalService
    :return: documents built from the hits that pass the cut-off
    """
    hits = RetrievalService.retrieve(
        retrival_method=retrival_method,
        dataset_id=dataset.id,
        query=query,
        top_k=top_k,
        score_threshold=score_threshold,
        reranking_model=reranking_model,
    )

    kept = []
    for hit in hits:
        hit_metadata = hit.metadata
        # The score is recorded on the hit's metadata before filtering.
        hit_metadata['score'] = hit.score
        if not hit.score > score_threshold:
            continue
        kept.append(Document(page_content=hit.page_content, metadata=hit_metadata))
    return kept
|
32
api/tests/integration_tests/rag/vector/test_qdrant.py
Normal file
32
api/tests/integration_tests/rag/vector/test_qdrant.py
Normal file
@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVector, QdrantConfig
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
# 'delete' is requested as well because the test ends with
# qdrant_vector.delete(); without it that call is not covered by the mock.
# NOTE(review): setup_qdrant_mock is not imported in the visible part of this
# file — confirm that the fixture is made available (import or conftest).
@pytest.mark.parametrize('setup_qdrant_mock',
                         [['get_collections', 'recreate_collection',
                           'create_payload_index', 'upsert', 'scroll',
                           'search', 'delete']],
                         indirect=True)
def test_qdrant(setup_qdrant_mock):
    """Create, search and delete one document through QdrantVector."""
    document = Document(page_content="test", metadata={"test": "test"})
    qdrant_vector = QdrantVector(
        collection_name="test",
        group_id='test',
        config=QdrantConfig(
            endpoint="http://localhost:6333",
            api_key="test",
            root_path="test",
            timeout=10
        )
    )
    # create
    qdrant_vector.create(texts=[document], embeddings=[[0.23333 for _ in range(233)]])
    # search
    result = qdrant_vector.search_by_vector(query_vector=[0.23333 for _ in range(233)])
    for item in result:
        assert isinstance(item, Document)
    # delete
    qdrant_vector.delete()
|
||||
|
20
api/tests/unit_tests/rag/extractor/test_extract_xlsx.py
Normal file
20
api/tests/unit_tests/rag/extractor/test_extract_xlsx.py
Normal file
@ -0,0 +1,20 @@
|
||||
import os
|
||||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_xlsx():
    """ExcelExtractor turns the test.xlsx asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.xlsx')

    documents = ExcelExtractor(test_file_path).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
22
api/tests/unit_tests/rag/extractor/test_test_csv.py
Normal file
22
api/tests/unit_tests/rag/extractor/test_test_csv.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_csv():
    """CSVExtractor turns the test.csv asset into a list of Document objects."""
    # The assets directory sits next to the directory that holds this file.
    here = os.path.dirname(os.path.abspath(__file__))
    test_file_path = os.path.join(os.path.dirname(here), 'assets', 'test.csv')

    documents = CSVExtractor(test_file_path, autodetect_encoding=True).extract()

    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
25
api/tests/unit_tests/rag/extractor/test_test_docx.py
Normal file
25
api/tests/unit_tests/rag/extractor/test_test_docx.py
Normal file
@ -0,0 +1,25 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_docx():
    """Check that WordExtractor yields a list of Document objects.

    The fixture ``test.docx`` is read from the ``assets`` directory located
    one level above the directory containing this test file. Note that an
    empty list also satisfies the assertions below.
    """
    # Directory that holds this test module.
    here = os.path.dirname(os.path.abspath(__file__))
    # Fixture path: <parent of this directory>/assets/test.docx
    docx_path = os.path.join(os.path.dirname(here), 'assets', 'test.docx')

    documents = WordExtractor(docx_path).extract()

    # The extractor must return a list whose entries are all Documents.
    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
22
api/tests/unit_tests/rag/extractor/test_test_html.py
Normal file
22
api/tests/unit_tests/rag/extractor/test_test_html.py
Normal file
@ -0,0 +1,22 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_html():
    """Check that HtmlExtractor yields a list of Document objects.

    The fixture ``test.html`` is read from the ``assets`` directory located
    one level above the directory containing this test file. Note that an
    empty list also satisfies the assertions below.
    """
    # Directory that holds this test module.
    here = os.path.dirname(os.path.abspath(__file__))
    # Fixture path: <parent of this directory>/assets/test.html
    html_path = os.path.join(os.path.dirname(here), 'assets', 'test.html')

    documents = HtmlExtractor(html_path).extract()

    # The extractor must return a list whose entries are all Documents.
    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
23
api/tests/unit_tests/rag/extractor/test_test_markdown.py
Normal file
23
api/tests/unit_tests/rag/extractor/test_test_markdown.py
Normal file
@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_markdown():
    """Check that MarkdownExtractor yields a list of Document objects.

    The fixture ``test.md`` is read from the ``assets`` directory located
    one level above the directory containing this test file, with encoding
    autodetection switched on. Note that an empty list also satisfies the
    assertions below.
    """
    # Directory that holds this test module.
    here = os.path.dirname(os.path.abspath(__file__))
    # Fixture path: <parent of this directory>/assets/test.md
    md_path = os.path.join(os.path.dirname(here), 'assets', 'test.md')

    documents = MarkdownExtractor(md_path, autodetect_encoding=True).extract()

    # The extractor must return a list whose entries are all Documents.
    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
24
api/tests/unit_tests/rag/extractor/test_test_pdf.py
Normal file
24
api/tests/unit_tests/rag/extractor/test_test_pdf.py
Normal file
@ -0,0 +1,24 @@
|
||||
import os
|
||||
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.pdf_extractor import PdfExtractor
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_pdf():
    """Check that PdfExtractor yields a list of Document objects.

    The fixture ``test.pdf`` is read from the ``assets`` directory located
    one level above the directory containing this test file. Note that an
    empty list also satisfies the assertions below.
    """
    # Directory that holds this test module.
    here = os.path.dirname(os.path.abspath(__file__))
    # Fixture path: <parent of this directory>/assets/test.pdf
    pdf_path = os.path.join(os.path.dirname(here), 'assets', 'test.pdf')

    documents = PdfExtractor(pdf_path).extract()

    # The extractor must return a list whose entries are all Documents.
    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
21
api/tests/unit_tests/rag/extractor/test_test_text.py
Normal file
21
api/tests/unit_tests/rag/extractor/test_test_text.py
Normal file
@ -0,0 +1,21 @@
|
||||
import os
|
||||
from core.rag.extractor.text_extractor import TextExtractor
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
def test_extract_text():
    """Check that TextExtractor yields a list of Document objects.

    The fixture ``test.txt`` is read from the ``assets`` directory located
    one level above the directory containing this test file, with encoding
    autodetection switched on. Note that an empty list also satisfies the
    assertions below.
    """
    # Directory that holds this test module.
    here = os.path.dirname(os.path.abspath(__file__))
    # Fixture path: <parent of this directory>/assets/test.txt
    txt_path = os.path.join(os.path.dirname(here), 'assets', 'test.txt')

    documents = TextExtractor(txt_path, autodetect_encoding=True).extract()

    # The extractor must return a list whose entries are all Documents.
    assert isinstance(documents, list)
    assert all(isinstance(doc, Document) for doc in documents)
|
Loading…
Reference in New Issue
Block a user