fix rerank mode is none

This commit is contained in:
jyong 2024-08-22 15:36:47 +08:00
parent cb70e12827
commit 0724640bbb
4 changed files with 0 additions and 242 deletions

Binary file not shown.

View File

@ -1,118 +0,0 @@
"""Abstract interface for document loader implementations."""
import os
import tempfile
from urllib.parse import urlparse
import requests
from docx import Document as DocxDocument
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class WordExtractorTest(BaseExtractor):
"""Load docx files.
Args:
file_path: Path to the file to load.
"""
def __init__(self, file_path: str):
"""Initialize with file path."""
self.file_path = file_path
if "~" in self.file_path:
self.file_path = os.path.expanduser(self.file_path)
# If the file is a web path, download it to a temporary file, and use that
if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
r = requests.get(self.file_path)
if r.status_code != 200:
raise ValueError(
f"Check the url of your file; returned status code {r.status_code}"
)
self.web_path = self.file_path
self.temp_file = tempfile.NamedTemporaryFile()
self.temp_file.write(r.content)
self.file_path = self.temp_file.name
elif not os.path.isfile(self.file_path):
raise ValueError(f"File path {self.file_path} is not a valid file or url")
def __del__(self) -> None:
if hasattr(self, "temp_file"):
self.temp_file.close()
def extract(self) -> list[Document]:
"""Load given path as single page."""
from docx import Document as docx_Document
document = docx_Document(self.file_path)
doc_texts = [paragraph.text for paragraph in document.paragraphs]
content = '\n'.join(doc_texts)
return [Document(
page_content=content,
metadata={"source": self.file_path},
)]
@staticmethod
def _is_valid_url(url: str) -> bool:
"""Check if the url is valid."""
parsed = urlparse(url)
return bool(parsed.netloc) and bool(parsed.scheme)
def _extract_images_from_docx(self, doc, image_folder):
image_count = 0
image_paths = []
for rel in doc.part.rels.values():
if "image" in rel.target_ref:
image_count += 1
image_ext = rel.target_ref.split('.')[-1]
image_name = f"image{image_count}.{image_ext}"
image_path = os.path.join(image_folder, image_name)
with open(image_path, "wb") as img_file:
img_file.write(rel.target_part.blob)
image_paths.append(f"![](/api/system/img/{image_name})")
return image_paths
def _table_to_html(self, table):
html = "<table border='1'>"
for row in table.rows:
html += "<tr>"
for cell in row.cells:
html += f"<td>{cell.text}</td>"
html += "</tr>"
html += "</table>"
return html
def parse_docx(self, docx_path, image_folder):
doc = DocxDocument(docx_path)
os.makedirs(image_folder, exist_ok=True)
content = []
image_index = 0
image_paths = self._extract_images_from_docx(doc, image_folder)
for element in doc.element.body:
if element.tag.endswith('p'): # paragraph
paragraph = element.text.strip()
if paragraph:
content.append(paragraph)
elif element.tag.endswith('tbl'): # table
table = doc.tables[image_index]
content.append(self._table_to_html(table))
image_index += 1
# 替换图片占位符
content_with_images = []
for item in content:
if '!' in item and '[]' in item:
item = image_paths.pop(0)
content_with_images.append(item)
return content_with_images

View File

@ -1,63 +0,0 @@
import os
import unittest
from unittest.mock import mock_open, patch
from extensions.storage.local_storage import LocalStorage
class TestLocalStorage(unittest.TestCase):
def setUp(self):
# Configuration for each test
self.app_config = {'root': '/test'}
self.folder = 'test_folder/'
self.storage = LocalStorage(self.app_config, self.folder)
@patch('os.makedirs')
def test_save(self, mock_makedirs):
# Test the save functionality
test_data = b"test data"
with patch('builtins.open', mock_open()) as mocked_file:
self.storage.save('file.txt', test_data)
mocked_file.assert_called_with(os.path.join(os.getcwd(), 'test_folder/file.txt'), "wb")
handle = mocked_file()
handle.write.assert_called_once_with(test_data)
@patch('os.path.exists', return_value=True)
@patch('builtins.open', new_callable=mock_open, read_data=b"test data")
def test_load_once(self, mock_open, mock_exists):
# Test the load_once method
data = self.storage.load_once('file.txt')
self.assertEqual(data, b"test data")
@patch('os.path.exists', return_value=True)
def test_load_stream(self, mock_exists):
# Test the load_stream method
with patch('builtins.open', mock_open(read_data=b"test data")) as mocked_file:
generator = self.storage.load_stream('file.txt')
output = list(generator)
self.assertEqual(output, [b'test data'])
@patch('shutil.copyfile')
@patch('os.path.exists', return_value=True)
def test_download(self, mock_exists, mock_copyfile):
# Test the download method
self.storage.download('file.txt', 'target.txt')
mock_copyfile.assert_called_once_with('test_folder/file.txt', 'target.txt')
@patch('os.path.exists', return_value=True)
def test_exists(self, mock_exists):
# Test the exists method
self.assertTrue(self.storage.exists('file.txt'))
@patch('os.path.exists', return_value=True)
@patch('os.remove')
def test_delete(self, mock_remove, mock_exists):
# Test the delete method
self.storage.delete('file.txt')
mock_remove.assert_called_once_with('test_folder/file.txt')
@patch('os.path.exists', return_value=False)
def test_delete_file_not_found(self, mock_exists):
# Test deleting a file that does not exist
with self.assertRaises(FileNotFoundError):
self.storage.delete('file.txt')

View File

@ -1,61 +0,0 @@
from collections.abc import Generator
from unittest.mock import MagicMock
import pytest
from extensions import ext_redis
def get_example_filename() -> str:
return 'test_text.txt'
def get_example_file_data() -> bytes:
return b'test_text'
@pytest.fixture
def setup_mock_redis() -> None:
# get
ext_redis.redis_client.get = MagicMock(return_value=None)
# set
ext_redis.redis_client.set = MagicMock(return_value=None)
# lock
mock_redis_lock = MagicMock()
mock_redis_lock.__enter__ = MagicMock()
mock_redis_lock.__exit__ = MagicMock()
ext_redis.redis_client.lock = mock_redis_lock
class AbstractOssTest:
def __init__(self):
self.client = None
self.filename = get_example_filename()
self.data = get_example_file_data()
def save(self):
raise NotImplementedError
def load_once(self) -> bytes:
raise NotImplementedError
def load_stream(self) -> Generator:
raise NotImplementedError
def download(self):
raise NotImplementedError
def exists(self):
raise NotImplementedError
def delete(self):
raise NotImplementedError
def run_all_tests(self):
self.save()
self.load_once()
self.load_stream()
self.exists()
self.delete()