diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 0b4a7be986..cf2e8af1a2 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -47,6 +47,7 @@ class HitTestingApi(Resource): parser = reqparse.RequestParser() parser.add_argument("query", type=str, location="json") parser.add_argument("retrieval_model", type=dict, required=False, location="json") + parser.add_argument("external_retrival_model", type=dict, required=False, location="json") args = parser.parse_args() HitTestingService.hit_testing_args_check(args) @@ -57,6 +58,7 @@ class HitTestingApi(Resource): query=args["query"], account=current_user, retrieval_model=args["retrieval_model"], + external_retrieval_model=args["external_retrival_model"], limit=10, ) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 9f45771794..26a728bcb6 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -10,6 +10,7 @@ from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.retrieval.retrival_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset +from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, @@ -29,76 +30,87 @@ class RetrievalService: def retrieve(cls, retrival_method: str, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', - weights: Optional[dict] = None): + weights: Optional[dict] = None, provider: Optional[str] = None, + external_retrieval_model: Optional[dict] = None): dataset = db.session.query(Dataset).filter( Dataset.id == dataset_id ).first() - if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: + if not dataset: return [] - all_documents = [] - threads = [] - exceptions = [] - # retrieval_model source with keyword - if retrival_method == 'keyword_search': - keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) - threads.append(keyword_thread) - keyword_thread.start() - # retrieval_model source with semantic - if RetrievalMethod.is_support_semantic_search(retrival_method): - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'score_threshold': score_threshold, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'retrival_method': retrival_method, - 'exceptions': exceptions, - }) - threads.append(embedding_thread) - embedding_thread.start() - - # retrieval source with full text - if RetrievalMethod.is_support_fulltext_search(retrival_method): - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'retrival_method': retrival_method, - 'score_threshold': score_threshold, - 'top_k': top_k, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) - threads.append(full_text_index_thread) - full_text_index_thread.start() - - for thread in threads: - thread.join() - - if exceptions: - exception_message = ';\n'.join(exceptions) - raise Exception(exception_message) - - if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, - reranking_model, weights, False) - all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + if provider == 'external': + external_knowledge_binding = ExternalDatasetService.fetch_external_knowledge_retrival( + dataset.tenant_id, + dataset_id, + query, + external_retrieval_model ) - return all_documents + else: + if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: + return [] + all_documents = [] + threads = [] + exceptions = [] + # retrieval_model source with keyword + if retrival_method == 'keyword_search': + keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'top_k': top_k, + 'all_documents': all_documents, + 'exceptions': exceptions, + }) + threads.append(keyword_thread) + keyword_thread.start() + # retrieval_model source with semantic + if RetrievalMethod.is_support_semantic_search(retrival_method): + embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'top_k': top_k, + 'score_threshold': score_threshold, + 'reranking_model': reranking_model, + 'all_documents': all_documents, + 'retrival_method': retrival_method, + 'exceptions': exceptions, + }) + threads.append(embedding_thread) + embedding_thread.start() + + # retrieval source with full text + if RetrievalMethod.is_support_fulltext_search(retrival_method): + full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ + 'flask_app': current_app._get_current_object(), + 'dataset_id': dataset_id, + 'query': query, + 'retrival_method': retrival_method, + 'score_threshold': score_threshold, + 'top_k': top_k, + 'reranking_model': reranking_model, + 'all_documents': all_documents, + 'exceptions': exceptions, + }) + threads.append(full_text_index_thread) + full_text_index_thread.start() + + for thread in threads: + thread.join() + + if exceptions: + exception_message = ';\n'.join(exceptions) + raise Exception(exception_message) + + if retrival_method == RetrievalMethod.HYBRID_SEARCH.value: + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, + reranking_model, weights, False) + all_documents = data_post_processor.invoke( + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=top_k + ) + return all_documents @classmethod def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, diff --git a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py index 068dbd96ad..bca86e3d7e 100644 --- a/api/services/entities/external_knowledge_entities/external_knowledge_entities.py +++ b/api/services/entities/external_knowledge_entities/external_knowledge_entities.py @@ -23,7 +23,6 @@ class ApiTemplateSetting(BaseModel): method: str url: str request_method: str - authorization: Authorization + api_token: str headers: Optional[dict] = None params: Optional[dict] = None - callback_setting: Optional[ProcessStatusSetting] = None diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index d221c20744..364335a058 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -117,6 +117,16 @@ class ExternalDatasetService: return True return False + @staticmethod + def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: + external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + dataset_id=dataset_id, + tenant_id=tenant_id + ).first() + if not external_knowledge_binding: + raise ValueError('external knowledge binding not found') + return external_knowledge_binding + @staticmethod def document_create_args_validate(tenant_id: str, api_template_id: str, process_parameter: dict): api_template = ExternalApiTemplates.query.filter_by( @@ -196,8 +206,6 @@ class ExternalDatasetService: @staticmethod def process_external_api(settings: ApiTemplateSetting, - headers: Union[None, dict[str, Any]], - parameter: Union[None, dict[str, Any]], files: Union[None, dict[str, Any]]) -> httpx.Response: """ do http request depending on api bundle @@ -205,14 +213,12 @@ class ExternalDatasetService: kwargs = { 'url': settings.url, - 'headers': headers, + 'headers': settings.headers, 'follow_redirects': True, } - if settings.request_method in ('get', 'head', 'post', 'put', 'delete', 'patch'): - response = getattr(ssrf_proxy, settings.request_method)(data=parameter, files=files, **kwargs) - else: - raise ValueError(f'Invalid http method {settings.request_method}') + response = getattr(ssrf_proxy, settings.request_method)(data=settings.params, files=files, **kwargs) + return response @staticmethod @@ -246,7 +252,7 @@ class ExternalDatasetService: return ApiTemplateSetting.parse_obj(settings) @staticmethod - def create_external_dataset(tenant_id, user_id, args): + def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: # check if dataset name already exists if Dataset.query.filter_by(name=args.get('name'), tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.") @@ -254,6 +260,7 @@ class ExternalDatasetService: id=args.get('api_template_id'), tenant_id=tenant_id ).first() + if api_template is None: raise ValueError('api template not found') @@ -281,4 +288,37 @@ class ExternalDatasetService: return dataset + @staticmethod + def fetch_external_knowledge_retrival(tenant_id: str, + dataset_id: str, + query: str, + external_retrival_parameters: dict): + external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( + dataset_id=dataset_id, + tenant_id=tenant_id + ).first() + if not external_knowledge_binding: + raise ValueError('external knowledge binding not found') + external_api_template = ExternalApiTemplates.query.filter_by( + id=external_knowledge_binding.external_api_template_id + ).first() + if not external_api_template: + raise ValueError('external api template not found') + + settings = json.loads(external_api_template.settings) + headers = {} + if settings.get('api_token'): + headers['Authorization'] = f"Bearer {settings.get('api_token')}" + + external_retrival_parameters['query'] = query + + api_template_setting = { + 'url': f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents", + 'request_method': 'post', + 'headers': settings.get('headers'), + 'params': external_retrival_parameters + } + response = ExternalDatasetService.process_external_api( + ApiTemplateSetting(**api_template_setting), None + ) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index db99064814..5e00f740df 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -19,7 +19,8 @@ default_retrieval_model = { class HitTestingService: @classmethod - def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict: + def retrieve(cls, dataset: Dataset, query: str, account: Account, + retrieval_model: dict, external_retrieval_model: dict, limit: int = 10) -> dict: if dataset.available_document_count == 0 or dataset.available_segment_count == 0: return { "query": { @@ -50,6 +51,8 @@ class HitTestingService: if retrieval_model.get("reranking_mode") else "reranking_model", weights=retrieval_model.get("weights", None), + provider=dataset.provider, + external_retrieval_model=external_retrieval_model, ) end = time.perf_counter()