diff --git a/api/configs/middleware/external/bedrock_config.py b/api/configs/middleware/external/bedrock_config.py index be10da1432..2f5e8536ff 100644 --- a/api/configs/middleware/external/bedrock_config.py +++ b/api/configs/middleware/external/bedrock_config.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import Field, PositiveInt +from pydantic import Field from pydantic_settings import BaseSettings @@ -8,6 +8,7 @@ class BedrockConfig(BaseSettings): """ bedrock configs """ + AWS_SECRET_ACCESS_KEY: Optional[str] = Field( description="AWS secret access key", default=None, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index a7dd97f51e..7bdcb47a61 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -37,7 +37,17 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p from .billing import billing # Import datasets controllers -from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website, test_external +from .datasets import ( + data_source, + datasets, + datasets_document, + datasets_segments, + external, + file, + hit_testing, + test_external, + website, +) # Import explore controllers from .explore import ( diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index c0e07d1bae..4b93ecbe10 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -1,7 +1,7 @@ from flask import request from flask_login import current_user from flask_restful import Resource, marshal, reqparse -from werkzeug.exceptions import Forbidden, NotFound, InternalServerError +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.console import api @@ -234,7 +234,6 @@ class ExternalDatasetCreateApi(Resource): parser.add_argument("description", type=str, required=False, nullable=True, location="json") parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") - args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator diff --git a/api/controllers/console/datasets/test_external.py b/api/controllers/console/datasets/test_external.py index 7c46be6533..3f3b760b9c 100644 --- a/api/controllers/console/datasets/test_external.py +++ b/api/controllers/console/datasets/test_external.py @@ -1,18 +1,12 @@ -from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse -from werkzeug.exceptions import Forbidden, NotFound +from flask_restful import Resource, reqparse -import services from controllers.console import api -from controllers.console.app.error import ProviderNotInitializeError -from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from fields.dataset_fields import dataset_detail_fields from libs.login import login_required from services.external_knowledge_service import ExternalDatasetService + class TestExternalApi(Resource): @setup_required @login_required @@ -50,5 +44,4 @@ class TestExternalApi(Resource): return result, 200 - api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents") diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 496c3e2678..d3fd0c672a 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -23,19 +23,18 @@ default_retrieval_model = { class RetrievalService: @classmethod - def retrieve(cls, - retrieval_method: str, - dataset_id: str, - query: str, - top_k: int, - score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, - reranking_mode: Optional[str] = 'reranking_model', - weights: Optional[dict] = None - ): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + def retrieve( + cls, + retrieval_method: str, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float] = 0.0, + reranking_model: Optional[dict] = None, + reranking_mode: Optional[str] = "reranking_model", + weights: Optional[dict] = None, + ): + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] @@ -45,46 +44,55 @@ class RetrievalService: threads = [] exceptions = [] # retrieval_model source with keyword - if retrieval_method == 'keyword_search': - keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + if retrieval_method == "keyword_search": + keyword_thread = threading.Thread( + target=RetrievalService.keyword_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(keyword_thread) keyword_thread.start() # retrieval_model source with semantic if RetrievalMethod.is_support_semantic_search(retrieval_method): - embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'top_k': top_k, - 'score_threshold': score_threshold, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'retrieval_method': retrieval_method, - 'exceptions': exceptions, - }) + embedding_thread = threading.Thread( + target=RetrievalService.embedding_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "top_k": top_k, + "score_threshold": score_threshold, + "reranking_model": reranking_model, + "all_documents": all_documents, + "retrieval_method": retrieval_method, + "exceptions": exceptions, + }, + ) threads.append(embedding_thread) embedding_thread.start() # retrieval source with full text if RetrievalMethod.is_support_fulltext_search(retrieval_method): - full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ - 'flask_app': current_app._get_current_object(), - 'dataset_id': dataset_id, - 'query': query, - 'retrieval_method': retrieval_method, - 'score_threshold': score_threshold, - 'top_k': top_k, - 'reranking_model': reranking_model, - 'all_documents': all_documents, - 'exceptions': exceptions, - }) + full_text_index_thread = threading.Thread( + target=RetrievalService.full_text_index_search, + kwargs={ + "flask_app": current_app._get_current_object(), + "dataset_id": dataset_id, + "query": query, + "retrieval_method": retrieval_method, + "score_threshold": score_threshold, + "top_k": top_k, + "reranking_model": reranking_model, + "all_documents": all_documents, + "exceptions": exceptions, + }, + ) threads.append(full_text_index_thread) full_text_index_thread.start() @@ -92,41 +100,31 @@ class RetrievalService: thread.join() if exceptions: - exception_message = ';\n'.join(exceptions) + exception_message = ";\n".join(exceptions) raise Exception(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: - data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode, - reranking_model, weights, False) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) all_documents = data_post_processor.invoke( - query=query, - documents=all_documents, - score_threshold=score_threshold, - top_n=top_k + query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k ) return all_documents @classmethod - def external_retrieve(cls, - dataset_id: str, - query: str, - external_retrieval_model: Optional[dict] = None): - dataset = db.session.query(Dataset).filter( - Dataset.id == dataset_id - ).first() + def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None): + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - dataset.tenant_id, - dataset_id, - query, - external_retrieval_model + dataset.tenant_id, dataset_id, query, external_retrieval_model ) return all_documents @classmethod def keyword_search( - cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list + cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list ): with flask_app.app_context(): try: @@ -141,16 +139,16 @@ class RetrievalService: @classmethod def embedding_search( - cls, - flask_app: Flask, - dataset_id: str, - query: str, - top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], - all_documents: list, - retrieval_method: str, - exceptions: list, + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, ): with flask_app.app_context(): try: @@ -168,10 +166,10 @@ class RetrievalService: if documents: if ( - reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") - and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value ): data_post_processor = DataPostProcessor( str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False @@ -188,16 +186,16 @@ class RetrievalService: @classmethod def full_text_index_search( - cls, - flask_app: Flask, - dataset_id: str, - query: str, - top_k: int, - score_threshold: Optional[float], - reranking_model: Optional[dict], - all_documents: list, - retrieval_method: str, - exceptions: list, + cls, + flask_app: Flask, + dataset_id: str, + query: str, + top_k: int, + score_threshold: Optional[float], + reranking_model: Optional[dict], + all_documents: list, + retrieval_method: str, + exceptions: list, ): with flask_app.app_context(): try: @@ -210,10 +208,10 @@ class RetrievalService: documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) if documents: if ( - reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") - and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value + reranking_model + and reranking_model.get("reranking_model_name") + and reranking_model.get("reranking_provider_name") + and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value ): data_post_processor = DataPostProcessor( str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 02b4bc82b0..1e9aaa24f0 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -17,7 +17,7 @@ class Document(BaseModel): """ metadata: Optional[dict] = Field(default_factory=dict) - provider: Optional[str] = 'dify' + provider: Optional[str] = "dify" class BaseDocumentTransformer(ABC): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 8c404fb12c..966df573c3 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -112,7 +112,12 @@ class DatasetRetrieval: continue # pass if dataset is not available - if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0 and dataset.provider != "external": + if ( + dataset + and dataset.available_document_count == 0 + and dataset.available_document_count == 0 + and dataset.provider != "external" + ): continue available_datasets.append(dataset) @@ -172,7 +177,6 @@ class DatasetRetrieval: if item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] segments = DocumentSegment.query.filter( DocumentSegment.dataset_id.in_(dataset_ids), @@ -188,9 +192,19 @@ class DatasetRetrieval: ) for segment in sorted_segments: if segment.answer: - document_context_list.append(DocumentContext(content=f"question:{segment.get_sign_content()} answer:{segment.answer}", score=document_score_list.get(segment.index_node_id, None))) + document_context_list.append( + DocumentContext( + content=f"question:{segment.get_sign_content()} answer:{segment.answer}", + score=document_score_list.get(segment.index_node_id, None), + ) + ) else: - document_context_list.append(DocumentContext(content=segment.get_sign_content(), score=document_score_list.get(segment.index_node_id, None))) + document_context_list.append( + DocumentContext( + content=segment.get_sign_content(), + score=document_score_list.get(segment.index_node_id, None), + ) + ) if show_retrieve_source: for segment in sorted_segments: dataset = Dataset.query.filter_by(id=segment.dataset_id).first() @@ -279,7 +293,7 @@ class DatasetRetrieval: tenant_id=dataset.tenant_id, dataset_id=dataset_id, query=query, - external_retrieval_parameters=dataset.retrieval_model + external_retrieval_parameters=dataset.retrieval_model, ) for external_document in external_documents: document = Document( @@ -304,7 +318,9 @@ class DatasetRetrieval: retrieval_method = retrieval_model_config["search_method"] # get reranking model reranking_model = ( - retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None + retrieval_model_config["reranking_model"] + if retrieval_model_config["reranking_enable"] + else None ) # get score threshold score_threshold = 0.0 @@ -452,7 +468,7 @@ class DatasetRetrieval: tenant_id=dataset.tenant_id, dataset_id=dataset_id, query=query, - external_retrieval_parameters=dataset.retrieval_model + external_retrieval_parameters=dataset.retrieval_model, ) for external_document in external_documents: document = Document( diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 121c96e619..c08dc143a6 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -168,7 +168,7 @@ class KnowledgeRetrievalNode(BaseNode): "dataset_name": item.metadata.get("dataset_name"), "document_name": item.metadata.get("title"), "data_source_type": "external", - "retriever_from": 'workflow', + "retriever_from": "workflow", "score": item.metadata.get("score"), }, "title": item.metadata.get("title"), diff --git a/api/models/dataset.py b/api/models/dataset.py index 585c83e24a..ecf2c244e6 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -37,8 +37,8 @@ class Dataset(db.Model): db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), ) - INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] - PROVIDER_LIST = ['vendor', 'external', None] + INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] + PROVIDER_LIST = ["vendor", "external", None] id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) @@ -74,10 +74,9 @@ class Dataset(db.Model): @property def external_retrieval_model(self): - default_retrieval_model = { "top_k": 2, - "score_threshold": .0, + "score_threshold": 0.0, } return self.retrieval_model or default_retrieval_model @@ -700,35 +699,32 @@ class DatasetPermission(db.Model): class ExternalApiTemplates(db.Model): - __tablename__ = 'external_api_templates' + __tablename__ = "external_api_templates" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='external_api_template_pkey'), - db.Index('external_api_templates_tenant_idx', 'tenant_id'), - db.Index('external_api_templates_name_idx', 'name'), + db.PrimaryKeyConstraint("id", name="external_api_template_pkey"), + db.Index("external_api_templates_tenant_idx", "tenant_id"), + db.Index("external_api_templates_name_idx", "name"), ) - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) description = db.Column(db.String(255), nullable=False) tenant_id = db.Column(StringUUID, nullable=False) settings = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) def to_dict(self): return { - 'id': self.id, - 'tenant_id': self.tenant_id, - 'name': self.name, - 'description': self.description, - 'settings': self.settings_dict, - 'created_by': self.created_by, - 'created_at': self.created_at.isoformat(), + "id": self.id, + "tenant_id": self.tenant_id, + "name": self.name, + "description": self.description, + "settings": self.settings_dict, + "created_by": self.created_by, + "created_at": self.created_at.isoformat(), } @property @@ -740,24 +736,21 @@ class ExternalApiTemplates(db.Model): class ExternalKnowledgeBindings(db.Model): - __tablename__ = 'external_knowledge_bindings' + __tablename__ = "external_knowledge_bindings" __table_args__ = ( - db.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey'), - db.Index('external_knowledge_bindings_tenant_idx', 'tenant_id'), - db.Index('external_knowledge_bindings_dataset_idx', 'dataset_id'), - db.Index('external_knowledge_bindings_external_knowledge_idx', 'external_knowledge_id'), - db.Index('external_knowledge_bindings_external_api_template_idx', 'external_api_template_id'), + db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), + db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"), + db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"), + db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"), + db.Index("external_knowledge_bindings_external_api_template_idx", "external_api_template_id"), ) - id = db.Column(StringUUID, nullable=False, - server_default=db.text('uuid_generate_v4()')) + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) external_api_template_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False) external_knowledge_id = db.Column(db.Text, nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, - server_default=db.text('CURRENT_TIMESTAMP(0)')) \ No newline at end of file + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 12e0418093..11e11d6974 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -59,9 +59,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None): - query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by( - Dataset.created_at.desc() - ) + query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) if user: # get permitted dataset ids diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index e1acac5c58..0365a03990 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -5,8 +5,10 @@ from copy import deepcopy from datetime import datetime, timezone from typing import Any, Optional, Union +import boto3 import httpx +# from tasks.external_document_indexing_task import external_document_indexing_task from configs import dify_config from core.helper import ssrf_proxy from extensions.ext_database import db @@ -16,13 +18,9 @@ from models.dataset import ( ExternalApiTemplates, ExternalKnowledgeBindings, ) -from core.rag.models.document import Document as RetrievalDocument from models.model import UploadFile from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization from services.errors.dataset import DatasetNameDuplicateError -# from tasks.external_document_indexing_task import external_document_indexing_task -import requests -import boto3 class ExternalDatasetService: @@ -266,7 +264,7 @@ class ExternalDatasetService: @staticmethod def fetch_external_knowledge_retrieval( - tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict + tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict ) -> list: external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( dataset_id=dataset_id, tenant_id=tenant_id @@ -281,9 +279,7 @@ class ExternalDatasetService: raise ValueError("external api template not found") settings = json.loads(external_api_template.settings) - headers = { - "Content-Type": "application/json" - } + headers = {"Content-Type": "application/json"} if settings.get("api_key"): headers["Authorization"] = f"Bearer {settings.get('api_key')}" @@ -302,26 +298,19 @@ class ExternalDatasetService: return [] @staticmethod - def test_external_knowledge_retrieval( - top_k: int, score_threshold: float, query: str, external_knowledge_id: str - ): + def test_external_knowledge_retrieval(top_k: int, score_threshold: float, query: str, external_knowledge_id: str): client = boto3.client( "bedrock-agent-runtime", aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY, aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID, - region_name='us-east-1', + region_name="us-east-1", ) response = client.retrieve( knowledgeBaseId=external_knowledge_id, retrievalConfiguration={ - 'vectorSearchConfiguration': { - 'numberOfResults': top_k, - 'overrideSearchType': 'HYBRID' - } + "vectorSearchConfiguration": {"numberOfResults": top_k, "overrideSearchType": "HYBRID"} }, - retrievalQuery={ - 'text': query - } + retrievalQuery={"text": query}, ) results = [] if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200: diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 196720882a..7957b4dc82 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -20,15 +20,15 @@ default_retrieval_model = { class HitTestingService: @classmethod def retrieve( - cls, - dataset: Dataset, - query: str, - account: Account, - retrieval_model: dict, - external_retrieval_model: dict, - limit: int = 10, + cls, + dataset: Dataset, + query: str, + account: Account, + retrieval_model: dict, + external_retrieval_model: dict, + limit: int = 10, ) -> dict: - if (dataset.available_document_count == 0 or dataset.available_segment_count == 0): + if dataset.available_document_count == 0 or dataset.available_segment_count == 0: return { "query": { "content": query, @@ -72,16 +72,15 @@ class HitTestingService: @classmethod def external_retrieve( - cls, - dataset: Dataset, - query: str, - account: Account, - external_retrieval_model: dict, + cls, + dataset: Dataset, + query: str, + account: Account, + external_retrieval_model: dict, ) -> dict: if dataset.provider != "external": return { - "query": { - "content": query}, + "query": {"content": query}, "records": [], }