External knowledge api
This commit is contained in:
parent
ed92c90a40
commit
089da063d4
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ class BedrockConfig(BaseSettings):
|
||||
"""
|
||||
bedrock configs
|
||||
"""
|
||||
|
||||
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
|
||||
description="AWS secret access key",
|
||||
default=None,
|
||||
|
@ -37,7 +37,17 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
|
||||
from .billing import billing
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website, test_external
|
||||
from .datasets import (
|
||||
data_source,
|
||||
datasets,
|
||||
datasets_document,
|
||||
datasets_segments,
|
||||
external,
|
||||
file,
|
||||
hit_testing,
|
||||
test_external,
|
||||
website,
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
|
@ -1,7 +1,7 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound, InternalServerError
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
@ -234,7 +234,6 @@ class ExternalDatasetCreateApi(Resource):
|
||||
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
|
@ -1,18 +1,12 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
|
||||
class TestExternalApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -50,5 +44,4 @@ class TestExternalApi(Resource):
|
||||
return result, 200
|
||||
|
||||
|
||||
|
||||
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents")
|
||||
|
@ -23,19 +23,18 @@ default_retrieval_model = {
|
||||
|
||||
class RetrievalService:
|
||||
@classmethod
|
||||
def retrieve(cls,
|
||||
retrieval_method: str,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float] = .0,
|
||||
reranking_model: Optional[dict] = None,
|
||||
reranking_mode: Optional[str] = 'reranking_model',
|
||||
weights: Optional[dict] = None
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
def retrieve(
|
||||
cls,
|
||||
retrieval_method: str,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float] = 0.0,
|
||||
reranking_model: Optional[dict] = None,
|
||||
reranking_mode: Optional[str] = "reranking_model",
|
||||
weights: Optional[dict] = None,
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
|
||||
@ -45,46 +44,55 @@ class RetrievalService:
|
||||
threads = []
|
||||
exceptions = []
|
||||
# retrieval_model source with keyword
|
||||
if retrieval_method == 'keyword_search':
|
||||
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'all_documents': all_documents,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
if retrieval_method == "keyword_search":
|
||||
keyword_thread = threading.Thread(
|
||||
target=RetrievalService.keyword_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(keyword_thread)
|
||||
keyword_thread.start()
|
||||
# retrieval_model source with semantic
|
||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'top_k': top_k,
|
||||
'score_threshold': score_threshold,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'retrieval_method': retrieval_method,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
embedding_thread = threading.Thread(
|
||||
target=RetrievalService.embedding_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"retrieval_method": retrieval_method,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(embedding_thread)
|
||||
embedding_thread.start()
|
||||
|
||||
# retrieval source with full text
|
||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset_id': dataset_id,
|
||||
'query': query,
|
||||
'retrieval_method': retrieval_method,
|
||||
'score_threshold': score_threshold,
|
||||
'top_k': top_k,
|
||||
'reranking_model': reranking_model,
|
||||
'all_documents': all_documents,
|
||||
'exceptions': exceptions,
|
||||
})
|
||||
full_text_index_thread = threading.Thread(
|
||||
target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"retrieval_method": retrieval_method,
|
||||
"score_threshold": score_threshold,
|
||||
"top_k": top_k,
|
||||
"reranking_model": reranking_model,
|
||||
"all_documents": all_documents,
|
||||
"exceptions": exceptions,
|
||||
},
|
||||
)
|
||||
threads.append(full_text_index_thread)
|
||||
full_text_index_thread.start()
|
||||
|
||||
@ -92,41 +100,31 @@ class RetrievalService:
|
||||
thread.join()
|
||||
|
||||
if exceptions:
|
||||
exception_message = ';\n'.join(exceptions)
|
||||
exception_message = ";\n".join(exceptions)
|
||||
raise Exception(exception_message)
|
||||
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
|
||||
reranking_model, weights, False)
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||
)
|
||||
all_documents = data_post_processor.invoke(
|
||||
query=query,
|
||||
documents=all_documents,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_k
|
||||
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def external_retrieve(cls,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_model: Optional[dict] = None):
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id,
|
||||
dataset_id,
|
||||
query,
|
||||
external_retrieval_model
|
||||
dataset.tenant_id, dataset_id, query, external_retrieval_model
|
||||
)
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def keyword_search(
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
@ -141,16 +139,16 @@ class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def embedding_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
@ -168,10 +166,10 @@ class RetrievalService:
|
||||
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
@ -188,16 +186,16 @@ class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def full_text_index_search(
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
cls,
|
||||
flask_app: Flask,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
top_k: int,
|
||||
score_threshold: Optional[float],
|
||||
reranking_model: Optional[dict],
|
||||
all_documents: list,
|
||||
retrieval_method: str,
|
||||
exceptions: list,
|
||||
):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
@ -210,10 +208,10 @@ class RetrievalService:
|
||||
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
|
||||
if documents:
|
||||
if (
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
reranking_model
|
||||
and reranking_model.get("reranking_model_name")
|
||||
and reranking_model.get("reranking_provider_name")
|
||||
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
|
||||
):
|
||||
data_post_processor = DataPostProcessor(
|
||||
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
|
||||
|
@ -17,7 +17,7 @@ class Document(BaseModel):
|
||||
"""
|
||||
metadata: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
provider: Optional[str] = 'dify'
|
||||
provider: Optional[str] = "dify"
|
||||
|
||||
|
||||
class BaseDocumentTransformer(ABC):
|
||||
|
@ -112,7 +112,12 @@ class DatasetRetrieval:
|
||||
continue
|
||||
|
||||
# pass if dataset is not available
|
||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0 and dataset.provider != "external":
|
||||
if (
|
||||
dataset
|
||||
and dataset.available_document_count == 0
|
||||
and dataset.available_document_count == 0
|
||||
and dataset.provider != "external"
|
||||
):
|
||||
continue
|
||||
|
||||
available_datasets.append(dataset)
|
||||
@ -172,7 +177,6 @@ class DatasetRetrieval:
|
||||
if item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
|
||||
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(dataset_ids),
|
||||
@ -188,9 +192,19 @@ class DatasetRetrieval:
|
||||
)
|
||||
for segment in sorted_segments:
|
||||
if segment.answer:
|
||||
document_context_list.append(DocumentContext(content=f"question:{segment.get_sign_content()} answer:{segment.answer}", score=document_score_list.get(segment.index_node_id, None)))
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
||||
score=document_score_list.get(segment.index_node_id, None),
|
||||
)
|
||||
)
|
||||
else:
|
||||
document_context_list.append(DocumentContext(content=segment.get_sign_content(), score=document_score_list.get(segment.index_node_id, None)))
|
||||
document_context_list.append(
|
||||
DocumentContext(
|
||||
content=segment.get_sign_content(),
|
||||
score=document_score_list.get(segment.index_node_id, None),
|
||||
)
|
||||
)
|
||||
if show_retrieve_source:
|
||||
for segment in sorted_segments:
|
||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||
@ -279,7 +293,7 @@ class DatasetRetrieval:
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = Document(
|
||||
@ -304,7 +318,9 @@ class DatasetRetrieval:
|
||||
retrieval_method = retrieval_model_config["search_method"]
|
||||
# get reranking model
|
||||
reranking_model = (
|
||||
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
|
||||
retrieval_model_config["reranking_model"]
|
||||
if retrieval_model_config["reranking_enable"]
|
||||
else None
|
||||
)
|
||||
# get score threshold
|
||||
score_threshold = 0.0
|
||||
@ -452,7 +468,7 @@ class DatasetRetrieval:
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = Document(
|
||||
|
@ -168,7 +168,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": 'workflow',
|
||||
"retriever_from": "workflow",
|
||||
"score": item.metadata.get("score"),
|
||||
},
|
||||
"title": item.metadata.get("title"),
|
||||
|
@ -37,8 +37,8 @@ class Dataset(db.Model):
|
||||
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
|
||||
)
|
||||
|
||||
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
|
||||
PROVIDER_LIST = ['vendor', 'external', None]
|
||||
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
|
||||
PROVIDER_LIST = ["vendor", "external", None]
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
@ -74,10 +74,9 @@ class Dataset(db.Model):
|
||||
|
||||
@property
|
||||
def external_retrieval_model(self):
|
||||
|
||||
default_retrieval_model = {
|
||||
"top_k": 2,
|
||||
"score_threshold": .0,
|
||||
"score_threshold": 0.0,
|
||||
}
|
||||
return self.retrieval_model or default_retrieval_model
|
||||
|
||||
@ -700,35 +699,32 @@ class DatasetPermission(db.Model):
|
||||
|
||||
|
||||
class ExternalApiTemplates(db.Model):
|
||||
__tablename__ = 'external_api_templates'
|
||||
__tablename__ = "external_api_templates"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='external_api_template_pkey'),
|
||||
db.Index('external_api_templates_tenant_idx', 'tenant_id'),
|
||||
db.Index('external_api_templates_name_idx', 'name'),
|
||||
db.PrimaryKeyConstraint("id", name="external_api_template_pkey"),
|
||||
db.Index("external_api_templates_tenant_idx", "tenant_id"),
|
||||
db.Index("external_api_templates_name_idx", "name"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, nullable=False,
|
||||
server_default=db.text('uuid_generate_v4()'))
|
||||
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.String(255), nullable=False)
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
settings = db.Column(db.Text, nullable=True)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False,
|
||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False,
|
||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'tenant_id': self.tenant_id,
|
||||
'name': self.name,
|
||||
'description': self.description,
|
||||
'settings': self.settings_dict,
|
||||
'created_by': self.created_by,
|
||||
'created_at': self.created_at.isoformat(),
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"settings": self.settings_dict,
|
||||
"created_by": self.created_by,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
}
|
||||
|
||||
@property
|
||||
@ -740,24 +736,21 @@ class ExternalApiTemplates(db.Model):
|
||||
|
||||
|
||||
class ExternalKnowledgeBindings(db.Model):
|
||||
__tablename__ = 'external_knowledge_bindings'
|
||||
__tablename__ = "external_knowledge_bindings"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey'),
|
||||
db.Index('external_knowledge_bindings_tenant_idx', 'tenant_id'),
|
||||
db.Index('external_knowledge_bindings_dataset_idx', 'dataset_id'),
|
||||
db.Index('external_knowledge_bindings_external_knowledge_idx', 'external_knowledge_id'),
|
||||
db.Index('external_knowledge_bindings_external_api_template_idx', 'external_api_template_id'),
|
||||
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
|
||||
db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
|
||||
db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
|
||||
db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
|
||||
db.Index("external_knowledge_bindings_external_api_template_idx", "external_api_template_id"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, nullable=False,
|
||||
server_default=db.text('uuid_generate_v4()'))
|
||||
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
external_api_template_id = db.Column(StringUUID, nullable=False)
|
||||
dataset_id = db.Column(StringUUID, nullable=False)
|
||||
external_knowledge_id = db.Column(db.Text, nullable=False)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False,
|
||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False,
|
||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
@ -59,9 +59,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
|
||||
class DatasetService:
|
||||
@staticmethod
|
||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None):
|
||||
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(
|
||||
Dataset.created_at.desc()
|
||||
)
|
||||
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
||||
|
||||
if user:
|
||||
# get permitted dataset ids
|
||||
|
@ -5,8 +5,10 @@ from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import boto3
|
||||
import httpx
|
||||
|
||||
# from tasks.external_document_indexing_task import external_document_indexing_task
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
@ -16,13 +18,9 @@ from models.dataset import (
|
||||
ExternalApiTemplates,
|
||||
ExternalKnowledgeBindings,
|
||||
)
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from models.model import UploadFile
|
||||
from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
# from tasks.external_document_indexing_task import external_document_indexing_task
|
||||
import requests
|
||||
import boto3
|
||||
|
||||
|
||||
class ExternalDatasetService:
|
||||
@ -266,7 +264,7 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def fetch_external_knowledge_retrieval(
|
||||
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
|
||||
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
|
||||
) -> list:
|
||||
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
|
||||
dataset_id=dataset_id, tenant_id=tenant_id
|
||||
@ -281,9 +279,7 @@ class ExternalDatasetService:
|
||||
raise ValueError("external api template not found")
|
||||
|
||||
settings = json.loads(external_api_template.settings)
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if settings.get("api_key"):
|
||||
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
|
||||
|
||||
@ -302,26 +298,19 @@ class ExternalDatasetService:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def test_external_knowledge_retrieval(
|
||||
top_k: int, score_threshold: float, query: str, external_knowledge_id: str
|
||||
):
|
||||
def test_external_knowledge_retrieval(top_k: int, score_threshold: float, query: str, external_knowledge_id: str):
|
||||
client = boto3.client(
|
||||
"bedrock-agent-runtime",
|
||||
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
|
||||
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
|
||||
region_name='us-east-1',
|
||||
region_name="us-east-1",
|
||||
)
|
||||
response = client.retrieve(
|
||||
knowledgeBaseId=external_knowledge_id,
|
||||
retrievalConfiguration={
|
||||
'vectorSearchConfiguration': {
|
||||
'numberOfResults': top_k,
|
||||
'overrideSearchType': 'HYBRID'
|
||||
}
|
||||
"vectorSearchConfiguration": {"numberOfResults": top_k, "overrideSearchType": "HYBRID"}
|
||||
},
|
||||
retrievalQuery={
|
||||
'text': query
|
||||
}
|
||||
retrievalQuery={"text": query},
|
||||
)
|
||||
results = []
|
||||
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
|
||||
|
@ -20,15 +20,15 @@ default_retrieval_model = {
|
||||
class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(
|
||||
cls,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
retrieval_model: dict,
|
||||
external_retrieval_model: dict,
|
||||
limit: int = 10,
|
||||
cls,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
retrieval_model: dict,
|
||||
external_retrieval_model: dict,
|
||||
limit: int = 10,
|
||||
) -> dict:
|
||||
if (dataset.available_document_count == 0 or dataset.available_segment_count == 0):
|
||||
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
@ -72,16 +72,15 @@ class HitTestingService:
|
||||
|
||||
@classmethod
|
||||
def external_retrieve(
|
||||
cls,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
external_retrieval_model: dict,
|
||||
cls,
|
||||
dataset: Dataset,
|
||||
query: str,
|
||||
account: Account,
|
||||
external_retrieval_model: dict,
|
||||
) -> dict:
|
||||
if dataset.provider != "external":
|
||||
return {
|
||||
"query": {
|
||||
"content": query},
|
||||
"query": {"content": query},
|
||||
"records": [],
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user