fix rerank mode is none
This commit is contained in:
parent
cb70e12827
commit
0724640bbb
Binary file not shown.
@ -1,118 +0,0 @@
|
|||||||
"""Abstract interface for document loader implementations."""
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from docx import Document as DocxDocument
|
|
||||||
|
|
||||||
from core.rag.extractor.extractor_base import BaseExtractor
|
|
||||||
from core.rag.models.document import Document
|
|
||||||
|
|
||||||
|
|
||||||
class WordExtractorTest(BaseExtractor):
|
|
||||||
"""Load docx files.
|
|
||||||
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the file to load.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, file_path: str):
|
|
||||||
"""Initialize with file path."""
|
|
||||||
self.file_path = file_path
|
|
||||||
if "~" in self.file_path:
|
|
||||||
self.file_path = os.path.expanduser(self.file_path)
|
|
||||||
|
|
||||||
# If the file is a web path, download it to a temporary file, and use that
|
|
||||||
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
|
|
||||||
r = requests.get(self.file_path)
|
|
||||||
|
|
||||||
if r.status_code != 200:
|
|
||||||
raise ValueError(
|
|
||||||
f"Check the url of your file; returned status code {r.status_code}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.web_path = self.file_path
|
|
||||||
self.temp_file = tempfile.NamedTemporaryFile()
|
|
||||||
self.temp_file.write(r.content)
|
|
||||||
self.file_path = self.temp_file.name
|
|
||||||
elif not os.path.isfile(self.file_path):
|
|
||||||
raise ValueError(f"File path {self.file_path} is not a valid file or url")
|
|
||||||
|
|
||||||
def __del__(self) -> None:
|
|
||||||
if hasattr(self, "temp_file"):
|
|
||||||
self.temp_file.close()
|
|
||||||
|
|
||||||
def extract(self) -> list[Document]:
|
|
||||||
"""Load given path as single page."""
|
|
||||||
from docx import Document as docx_Document
|
|
||||||
|
|
||||||
document = docx_Document(self.file_path)
|
|
||||||
doc_texts = [paragraph.text for paragraph in document.paragraphs]
|
|
||||||
content = '\n'.join(doc_texts)
|
|
||||||
|
|
||||||
return [Document(
|
|
||||||
page_content=content,
|
|
||||||
metadata={"source": self.file_path},
|
|
||||||
)]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_valid_url(url: str) -> bool:
|
|
||||||
"""Check if the url is valid."""
|
|
||||||
parsed = urlparse(url)
|
|
||||||
return bool(parsed.netloc) and bool(parsed.scheme)
|
|
||||||
|
|
||||||
def _extract_images_from_docx(self, doc, image_folder):
|
|
||||||
image_count = 0
|
|
||||||
image_paths = []
|
|
||||||
|
|
||||||
for rel in doc.part.rels.values():
|
|
||||||
if "image" in rel.target_ref:
|
|
||||||
image_count += 1
|
|
||||||
image_ext = rel.target_ref.split('.')[-1]
|
|
||||||
image_name = f"image{image_count}.{image_ext}"
|
|
||||||
image_path = os.path.join(image_folder, image_name)
|
|
||||||
with open(image_path, "wb") as img_file:
|
|
||||||
img_file.write(rel.target_part.blob)
|
|
||||||
image_paths.append(f"")
|
|
||||||
|
|
||||||
return image_paths
|
|
||||||
|
|
||||||
def _table_to_html(self, table):
|
|
||||||
html = "<table border='1'>"
|
|
||||||
for row in table.rows:
|
|
||||||
html += "<tr>"
|
|
||||||
for cell in row.cells:
|
|
||||||
html += f"<td>{cell.text}</td>"
|
|
||||||
html += "</tr>"
|
|
||||||
html += "</table>"
|
|
||||||
return html
|
|
||||||
|
|
||||||
def parse_docx(self, docx_path, image_folder):
|
|
||||||
doc = DocxDocument(docx_path)
|
|
||||||
os.makedirs(image_folder, exist_ok=True)
|
|
||||||
|
|
||||||
content = []
|
|
||||||
|
|
||||||
image_index = 0
|
|
||||||
image_paths = self._extract_images_from_docx(doc, image_folder)
|
|
||||||
|
|
||||||
for element in doc.element.body:
|
|
||||||
if element.tag.endswith('p'): # paragraph
|
|
||||||
paragraph = element.text.strip()
|
|
||||||
if paragraph:
|
|
||||||
content.append(paragraph)
|
|
||||||
elif element.tag.endswith('tbl'): # table
|
|
||||||
table = doc.tables[image_index]
|
|
||||||
content.append(self._table_to_html(table))
|
|
||||||
image_index += 1
|
|
||||||
|
|
||||||
# 替换图片占位符
|
|
||||||
content_with_images = []
|
|
||||||
for item in content:
|
|
||||||
if '!' in item and '[]' in item:
|
|
||||||
item = image_paths.pop(0)
|
|
||||||
content_with_images.append(item)
|
|
||||||
|
|
||||||
return content_with_images
|
|
@ -1,63 +0,0 @@
|
|||||||
import os
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import mock_open, patch
|
|
||||||
|
|
||||||
from extensions.storage.local_storage import LocalStorage
|
|
||||||
|
|
||||||
|
|
||||||
class TestLocalStorage(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
# Configuration for each test
|
|
||||||
self.app_config = {'root': '/test'}
|
|
||||||
self.folder = 'test_folder/'
|
|
||||||
self.storage = LocalStorage(self.app_config, self.folder)
|
|
||||||
|
|
||||||
@patch('os.makedirs')
|
|
||||||
def test_save(self, mock_makedirs):
|
|
||||||
# Test the save functionality
|
|
||||||
test_data = b"test data"
|
|
||||||
with patch('builtins.open', mock_open()) as mocked_file:
|
|
||||||
self.storage.save('file.txt', test_data)
|
|
||||||
mocked_file.assert_called_with(os.path.join(os.getcwd(), 'test_folder/file.txt'), "wb")
|
|
||||||
handle = mocked_file()
|
|
||||||
handle.write.assert_called_once_with(test_data)
|
|
||||||
|
|
||||||
@patch('os.path.exists', return_value=True)
|
|
||||||
@patch('builtins.open', new_callable=mock_open, read_data=b"test data")
|
|
||||||
def test_load_once(self, mock_open, mock_exists):
|
|
||||||
# Test the load_once method
|
|
||||||
data = self.storage.load_once('file.txt')
|
|
||||||
self.assertEqual(data, b"test data")
|
|
||||||
|
|
||||||
@patch('os.path.exists', return_value=True)
|
|
||||||
def test_load_stream(self, mock_exists):
|
|
||||||
# Test the load_stream method
|
|
||||||
with patch('builtins.open', mock_open(read_data=b"test data")) as mocked_file:
|
|
||||||
generator = self.storage.load_stream('file.txt')
|
|
||||||
output = list(generator)
|
|
||||||
self.assertEqual(output, [b'test data'])
|
|
||||||
|
|
||||||
@patch('shutil.copyfile')
|
|
||||||
@patch('os.path.exists', return_value=True)
|
|
||||||
def test_download(self, mock_exists, mock_copyfile):
|
|
||||||
# Test the download method
|
|
||||||
self.storage.download('file.txt', 'target.txt')
|
|
||||||
mock_copyfile.assert_called_once_with('test_folder/file.txt', 'target.txt')
|
|
||||||
|
|
||||||
@patch('os.path.exists', return_value=True)
|
|
||||||
def test_exists(self, mock_exists):
|
|
||||||
# Test the exists method
|
|
||||||
self.assertTrue(self.storage.exists('file.txt'))
|
|
||||||
|
|
||||||
@patch('os.path.exists', return_value=True)
|
|
||||||
@patch('os.remove')
|
|
||||||
def test_delete(self, mock_remove, mock_exists):
|
|
||||||
# Test the delete method
|
|
||||||
self.storage.delete('file.txt')
|
|
||||||
mock_remove.assert_called_once_with('test_folder/file.txt')
|
|
||||||
|
|
||||||
@patch('os.path.exists', return_value=False)
|
|
||||||
def test_delete_file_not_found(self, mock_exists):
|
|
||||||
# Test deleting a file that does not exist
|
|
||||||
with self.assertRaises(FileNotFoundError):
|
|
||||||
self.storage.delete('file.txt')
|
|
@ -1,61 +0,0 @@
|
|||||||
from collections.abc import Generator
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from extensions import ext_redis
|
|
||||||
|
|
||||||
|
|
||||||
def get_example_filename() -> str:
|
|
||||||
return 'test_text.txt'
|
|
||||||
|
|
||||||
|
|
||||||
def get_example_file_data() -> bytes:
|
|
||||||
return b'test_text'
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def setup_mock_redis() -> None:
|
|
||||||
# get
|
|
||||||
ext_redis.redis_client.get = MagicMock(return_value=None)
|
|
||||||
|
|
||||||
# set
|
|
||||||
ext_redis.redis_client.set = MagicMock(return_value=None)
|
|
||||||
|
|
||||||
# lock
|
|
||||||
mock_redis_lock = MagicMock()
|
|
||||||
mock_redis_lock.__enter__ = MagicMock()
|
|
||||||
mock_redis_lock.__exit__ = MagicMock()
|
|
||||||
ext_redis.redis_client.lock = mock_redis_lock
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractOssTest:
|
|
||||||
def __init__(self):
|
|
||||||
self.client = None
|
|
||||||
self.filename = get_example_filename()
|
|
||||||
self.data = get_example_file_data()
|
|
||||||
|
|
||||||
def save(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def load_once(self) -> bytes:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def load_stream(self) -> Generator:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def download(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def exists(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def delete(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def run_all_tests(self):
|
|
||||||
self.save()
|
|
||||||
self.load_once()
|
|
||||||
self.load_stream()
|
|
||||||
self.exists()
|
|
||||||
self.delete()
|
|
Loading…
Reference in New Issue
Block a user