fix rerank mode is none

jyong 2024-08-22 15:33:43 +08:00
parent 067b956b2c
commit cb70e12827
9 changed files with 366 additions and 25 deletions

api/celerybeat-schedule.db Normal file

Binary file not shown.


@@ -43,12 +43,10 @@ class OAuthDataSource(Resource):
             if not internal_secret:
                 return {'error': 'Internal secret is not set'},
             oauth_provider.save_internal_access_token(internal_secret)
-            return { 'data': '' }
+            return {'data': ''}
         else:
             auth_url = oauth_provider.get_authorization_url()
-            return { 'data': auth_url }, 200
+            return {'data': auth_url}, 200
 class OAuthDataSourceCallback(Resource):
@@ -68,7 +66,7 @@ class OAuthDataSourceCallback(Resource):
             return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
         else:
             return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
 class OAuthDataSourceBinding(Resource):
     def get(self, provider: str):


@@ -0,0 +1,118 @@
"""Abstract interface for document loader implementations."""
import os
import tempfile
from urllib.parse import urlparse
import requests
from docx import Document as DocxDocument
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
class WordExtractorTest(BaseExtractor):
    """Load docx files.

    Args:
        file_path: Path to the file to load.
    """

    def __init__(self, file_path: str):
        """Initialize with file path."""
        self.file_path = file_path
        if "~" in self.file_path:
            self.file_path = os.path.expanduser(self.file_path)

        # If the file is a web path, download it to a temporary file, and use that
        if not os.path.isfile(self.file_path) and self._is_valid_url(self.file_path):
            r = requests.get(self.file_path)
            if r.status_code != 200:
                raise ValueError(
                    f"Check the url of your file; returned status code {r.status_code}"
                )
            self.web_path = self.file_path
            self.temp_file = tempfile.NamedTemporaryFile()
            self.temp_file.write(r.content)
            self.file_path = self.temp_file.name
        elif not os.path.isfile(self.file_path):
            raise ValueError(f"File path {self.file_path} is not a valid file or url")

    def __del__(self) -> None:
        if hasattr(self, "temp_file"):
            self.temp_file.close()
    def extract(self) -> list[Document]:
        """Load given path as single page."""
        # DocxDocument is already imported at module level; no local re-import needed
        document = DocxDocument(self.file_path)
        doc_texts = [paragraph.text for paragraph in document.paragraphs]
        content = '\n'.join(doc_texts)
        return [Document(
            page_content=content,
            metadata={"source": self.file_path},
        )]
    @staticmethod
    def _is_valid_url(url: str) -> bool:
        """Check if the url is valid."""
        parsed = urlparse(url)
        return bool(parsed.netloc) and bool(parsed.scheme)

    def _extract_images_from_docx(self, doc, image_folder):
        image_count = 0
        image_paths = []
        for rel in doc.part.rels.values():
            if "image" in rel.target_ref:
                image_count += 1
                image_ext = rel.target_ref.split('.')[-1]
                image_name = f"image{image_count}.{image_ext}"
                image_path = os.path.join(image_folder, image_name)
                with open(image_path, "wb") as img_file:
                    img_file.write(rel.target_part.blob)
                image_paths.append(f"![](/api/system/img/{image_name})")
        return image_paths

    def _table_to_html(self, table):
        html = "<table border='1'>"
        for row in table.rows:
            html += "<tr>"
            for cell in row.cells:
                html += f"<td>{cell.text}</td>"
            html += "</tr>"
        html += "</table>"
        return html
    def parse_docx(self, docx_path, image_folder):
        doc = DocxDocument(docx_path)
        os.makedirs(image_folder, exist_ok=True)
        content = []
        table_index = 0
        image_paths = self._extract_images_from_docx(doc, image_folder)
        for element in doc.element.body:
            if element.tag.endswith('p'):  # paragraph
                paragraph = element.text.strip()
                if paragraph:
                    content.append(paragraph)
            elif element.tag.endswith('tbl'):  # table
                table = doc.tables[table_index]
                content.append(self._table_to_html(table))
                table_index += 1
        # Replace image placeholders
        content_with_images = []
        for item in content:
            if '!' in item and '[]' in item:
                item = image_paths.pop(0)
            content_with_images.append(item)
        return content_with_images
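A minimal usage sketch for the extractor above (the paths are hypothetical, for illustration only):

    # Hypothetical paths; any .docx on disk works
    extractor = WordExtractorTest('/tmp/report.docx')
    docs = extractor.extract()  # one Document carrying the full plain text
    blocks = extractor.parse_docx('/tmp/report.docx', '/tmp/report_images')  # paragraphs + HTML tables
    print(docs[0].page_content[:100])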


@@ -695,6 +695,7 @@ class ExternalApiTemplates(db.Model):
     id = db.Column(StringUUID, nullable=False,
                    server_default=db.text('uuid_generate_v4()'))
     name = db.Column(db.String(255), nullable=False)
+    description = db.Column(db.String(255), nullable=False)
     tenant_id = db.Column(StringUUID, nullable=False)
     settings = db.Column(db.Text, nullable=True)
     created_by = db.Column(StringUUID, nullable=False)
@@ -709,6 +710,7 @@ class ExternalApiTemplates(db.Model):
         'id': self.id,
         'tenant_id': self.tenant_id,
         'name': self.name,
+        'description': self.description,
         'settings': self.settings_dict,
         'created_by': self.created_by,
         'created_at': self.created_at,
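Because the new description column is NOT NULL, existing rows need a default or a backfill when the schema migrates. A minimal migration sketch, assuming Alembic (the revision docstring and the empty-string default are assumptions, not taken from this commit):

    # Hypothetical Alembic migration for the column added above
    import sqlalchemy as sa
    from alembic import op

    def upgrade():
        # server_default='' keeps the NOT NULL constraint satisfiable for pre-existing rows
        op.add_column('external_api_templates',
                      sa.Column('description', sa.String(length=255),
                                nullable=False, server_default=''))

    def downgrade():
        op.drop_column('external_api_templates', 'description')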


@@ -0,0 +1,90 @@
import datetime
import time

import click
from sqlalchemy import func
from werkzeug.exceptions import NotFound

import app
from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, Document


# NOTE: despite the task's name, this cleans up unused *datasets*: if a dataset
# has no recent queries and no recently updated documents, its index is dropped
# and its documents are disabled.
@app.celery.task(queue='dataset')
def clean_unused_message_task():
    click.echo(click.style('Start cleaning unused datasets.', fg='green'))
    clean_days = int(dify_config.CLEAN_DAY_SETTING)
    start_at = time.perf_counter()
    cutoff_date = datetime.datetime.now() - datetime.timedelta(days=clean_days)
    page = 1
    while True:
        try:
            # Subquery for counting new documents
            document_subquery_new = db.session.query(
                Document.dataset_id,
                func.count(Document.id).label('document_count')
            ).filter(
                Document.indexing_status == 'completed',
                Document.enabled == True,
                Document.archived == False,
                Document.updated_at > cutoff_date
            ).group_by(Document.dataset_id).subquery()

            # Subquery for counting old documents
            document_subquery_old = db.session.query(
                Document.dataset_id,
                func.count(Document.id).label('document_count')
            ).filter(
                Document.indexing_status == 'completed',
                Document.enabled == True,
                Document.archived == False,
                Document.updated_at < cutoff_date
            ).group_by(Document.dataset_id).subquery()

            # Main query with join and filter
            datasets = (db.session.query(Dataset)
                        .outerjoin(
                            document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id
                        ).outerjoin(
                            document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id
                        ).filter(
                            Dataset.created_at < cutoff_date,
                            func.coalesce(document_subquery_new.c.document_count, 0) == 0,
                            func.coalesce(document_subquery_old.c.document_count, 0) > 0
                        ).order_by(
                            Dataset.created_at.desc()
                        ).paginate(page=page, per_page=50))
        except NotFound:
            break
        if datasets.items is None or len(datasets.items) == 0:
            break
        page += 1
        for dataset in datasets:
            dataset_query = db.session.query(DatasetQuery).filter(
                DatasetQuery.created_at > cutoff_date,
                DatasetQuery.dataset_id == dataset.id
            ).all()
            if not dataset_query or len(dataset_query) == 0:
                try:
                    # remove index
                    index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
                    index_processor.clean(dataset, None)

                    # update document
                    update_params = {
                        Document.enabled: False
                    }
                    Document.query.filter_by(dataset_id=dataset.id).update(update_params)
                    db.session.commit()
                    click.echo(click.style('Cleaned unused dataset {} from db success!'.format(dataset.id),
                                           fg='green'))
                except Exception as e:
                    click.echo(
                        click.style('clean dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                                    fg='red'))
    end_at = time.perf_counter()
    click.echo(click.style('Cleaned unused datasets from db; latency: {}'.format(end_at - start_at), fg='green'))
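Since the commit also checks in api/celerybeat-schedule.db, this task is presumably meant to run on a beat schedule. A sketch of one way to register it (the schedule key, module path, and interval are assumptions, not taken from this commit):

    # Hypothetical beat entry; adapt to wherever the celery app is configured
    from celery.schedules import crontab

    beat_schedule = {
        'clean-unused-datasets': {
            'task': 'schedule.clean_unused_message_task.clean_unused_message_task',  # assumed module path
            'schedule': crontab(minute=0, hour=2),  # daily at 02:00
        },
    }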


@@ -48,29 +48,14 @@ class ExternalDatasetService:
         return api_templates.items, api_templates.total
 
     @classmethod
-    def validate_api_list(cls, api_settings: list[dict]):
+    def validate_api_list(cls, api_settings: dict):
         if not api_settings:
             raise ValueError('api list is empty')
-        for api_settings_dict in api_settings:
-            if not api_settings_dict.get('method'):
-                raise ValueError('api name is required')
-            if not api_settings_dict.get('url'):
-                raise ValueError('api url is required')
-            if api_settings_dict.get('authorization'):
-                if not api_settings_dict.get('authorization').get('type'):
-                    raise ValueError('authorization type is required')
-                if api_settings_dict.get('authorization').get('type') == 'bearer':
-                    if not api_settings_dict.get('authorization').get('api_key'):
-                        raise ValueError('authorization token is required')
-                if api_settings_dict.get('authorization').get('type') == 'custom':
-                    if not api_settings_dict.get('authorization').get('header'):
-                        raise ValueError('authorization header is required')
-            if api_settings_dict.get('method') in ['create', 'update']:
-                if not api_settings_dict.get('callback_setting'):
-                    raise ValueError('callback_setting is required for create and update method')
+        if 'endpoint' not in api_settings or not api_settings['endpoint']:
+            raise ValueError('endpoint is required')
+        if 'api_key' not in api_settings or not api_settings['api_key']:
+            raise ValueError('api_key is required')
 
     @staticmethod
     def create_api_template(tenant_id: str, user_id: str, args: dict) -> ExternalApiTemplates:
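For reference, a settings dict that passes the new validation (values are illustrative only):

    # Illustrative payload; real values come from the external API template form
    api_settings = {
        'endpoint': 'https://example.com/retrieval',
        'api_key': 'your-api-key',
    }
    ExternalDatasetService.validate_api_list(api_settings)  # raises no exception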


@@ -0,0 +1,63 @@
import os
import unittest
from unittest.mock import mock_open, patch

from extensions.storage.local_storage import LocalStorage


class TestLocalStorage(unittest.TestCase):
    def setUp(self):
        # Configuration for each test
        self.app_config = {'root': '/test'}
        self.folder = 'test_folder/'
        self.storage = LocalStorage(self.app_config, self.folder)

    @patch('os.makedirs')
    def test_save(self, mock_makedirs):
        # Test the save functionality
        test_data = b"test data"
        with patch('builtins.open', mock_open()) as mocked_file:
            self.storage.save('file.txt', test_data)
            mocked_file.assert_called_with(os.path.join(os.getcwd(), 'test_folder/file.txt'), "wb")
            handle = mocked_file()
            handle.write.assert_called_once_with(test_data)

    @patch('os.path.exists', return_value=True)
    @patch('builtins.open', new_callable=mock_open, read_data=b"test data")
    def test_load_once(self, mock_file, mock_exists):
        # Test the load_once method; the injected mock is named mock_file so it
        # does not shadow the imported mock_open helper
        data = self.storage.load_once('file.txt')
        self.assertEqual(data, b"test data")

    @patch('os.path.exists', return_value=True)
    def test_load_stream(self, mock_exists):
        # Test the load_stream method
        with patch('builtins.open', mock_open(read_data=b"test data")) as mocked_file:
            generator = self.storage.load_stream('file.txt')
            output = list(generator)
            self.assertEqual(output, [b'test data'])

    @patch('shutil.copyfile')
    @patch('os.path.exists', return_value=True)
    def test_download(self, mock_exists, mock_copyfile):
        # Test the download method
        self.storage.download('file.txt', 'target.txt')
        mock_copyfile.assert_called_once_with('test_folder/file.txt', 'target.txt')

    @patch('os.path.exists', return_value=True)
    def test_exists(self, mock_exists):
        # Test the exists method
        self.assertTrue(self.storage.exists('file.txt'))

    @patch('os.path.exists', return_value=True)
    @patch('os.remove')
    def test_delete(self, mock_remove, mock_exists):
        # Test the delete method
        self.storage.delete('file.txt')
        mock_remove.assert_called_once_with('test_folder/file.txt')

    @patch('os.path.exists', return_value=False)
    def test_delete_file_not_found(self, mock_exists):
        # Test deleting a file that does not exist
        with self.assertRaises(FileNotFoundError):
            self.storage.delete('file.txt')
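To make the module directly runnable (an assumption; the suite may instead rely on pytest discovery), the standard unittest entry point can be appended:

    if __name__ == '__main__':
        unittest.main()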


@@ -0,0 +1,61 @@
from collections.abc import Generator
from unittest.mock import MagicMock

import pytest

from extensions import ext_redis


def get_example_filename() -> str:
    return 'test_text.txt'


def get_example_file_data() -> bytes:
    return b'test_text'


@pytest.fixture
def setup_mock_redis() -> None:
    # get
    ext_redis.redis_client.get = MagicMock(return_value=None)

    # set
    ext_redis.redis_client.set = MagicMock(return_value=None)

    # lock
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = mock_redis_lock


class AbstractOssTest:
    def __init__(self):
        self.client = None
        self.filename = get_example_filename()
        self.data = get_example_file_data()

    def save(self):
        raise NotImplementedError

    def load_once(self) -> bytes:
        raise NotImplementedError

    def load_stream(self) -> Generator:
        raise NotImplementedError

    def download(self):
        raise NotImplementedError

    def exists(self):
        raise NotImplementedError

    def delete(self):
        raise NotImplementedError

    def run_all_tests(self):
        self.save()
        self.load_once()
        self.load_stream()
        self.exists()
        self.delete()
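A sketch of how a backend-specific test might build on this helper (the subclass and its in-memory "client" are hypothetical, for illustration; a real one would wrap an actual OSS client):

    class InMemoryOssTest(AbstractOssTest):
        def __init__(self):
            super().__init__()
            self.client = {}  # stands in for a real object-storage client

        def save(self):
            self.client[self.filename] = self.data

        def load_once(self) -> bytes:
            return self.client[self.filename]

        def load_stream(self) -> Generator:
            yield self.client[self.filename]

        def exists(self):
            assert self.filename in self.client

        def delete(self):
            del self.client[self.filename]

    InMemoryOssTest().run_all_tests()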

docker/unstructured.yaml Normal file

@@ -0,0 +1,24 @@
# unstructured service
# (if used, you need to set ETL_TYPE to Unstructured in the api & worker service.)
unstructured:
  image: downloads.unstructured.io/unstructured-io/unstructured-api:latest
  profiles:
    - unstructured
  restart: always
  volumes:
    - ./volumes/unstructured:/app/data

networks:
  # create a network between sandbox, api and ssrf_proxy, with no access to the outside.
  ssrf_proxy_network:
    driver: bridge
    internal: true
  milvus:
    driver: bridge
  opensearch-net:
    driver: bridge
    internal: true

volumes:
  oradata:
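Since the service sits behind a compose profile, starting it should be a matter of running something like "docker compose --profile unstructured up -d" and exporting ETL_TYPE=Unstructured (plus the unstructured endpoint URL) to the api and worker services; variable names beyond ETL_TYPE are assumptions here, not taken from this commit.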