External knowledge api
This commit is contained in:
parent
ed92c90a40
commit
089da063d4
@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import Field, PositiveInt
|
from pydantic import Field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
@ -8,6 +8,7 @@ class BedrockConfig(BaseSettings):
|
|||||||
"""
|
"""
|
||||||
bedrock configs
|
bedrock configs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
|
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
|
||||||
description="AWS secret access key",
|
description="AWS secret access key",
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -37,7 +37,17 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
|
|||||||
from .billing import billing
|
from .billing import billing
|
||||||
|
|
||||||
# Import datasets controllers
|
# Import datasets controllers
|
||||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website, test_external
|
from .datasets import (
|
||||||
|
data_source,
|
||||||
|
datasets,
|
||||||
|
datasets_document,
|
||||||
|
datasets_segments,
|
||||||
|
external,
|
||||||
|
file,
|
||||||
|
hit_testing,
|
||||||
|
test_external,
|
||||||
|
website,
|
||||||
|
)
|
||||||
|
|
||||||
# Import explore controllers
|
# Import explore controllers
|
||||||
from .explore import (
|
from .explore import (
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, marshal, reqparse
|
from flask_restful import Resource, marshal, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound, InternalServerError
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
@ -234,7 +234,6 @@ class ExternalDatasetCreateApi(Resource):
|
|||||||
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
@ -1,18 +1,12 @@
|
|||||||
from flask import request
|
from flask_restful import Resource, reqparse
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restful import Resource, marshal, reqparse
|
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
|
||||||
|
|
||||||
import services
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import ProviderNotInitializeError
|
|
||||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from fields.dataset_fields import dataset_detail_fields
|
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
from services.external_knowledge_service import ExternalDatasetService
|
||||||
|
|
||||||
|
|
||||||
class TestExternalApi(Resource):
|
class TestExternalApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -50,5 +44,4 @@ class TestExternalApi(Resource):
|
|||||||
return result, 200
|
return result, 200
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents")
|
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents")
|
||||||
|
@ -23,19 +23,18 @@ default_retrieval_model = {
|
|||||||
|
|
||||||
class RetrievalService:
|
class RetrievalService:
|
||||||
@classmethod
|
@classmethod
|
||||||
def retrieve(cls,
|
def retrieve(
|
||||||
|
cls,
|
||||||
retrieval_method: str,
|
retrieval_method: str,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
query: str,
|
query: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
score_threshold: Optional[float] = .0,
|
score_threshold: Optional[float] = 0.0,
|
||||||
reranking_model: Optional[dict] = None,
|
reranking_model: Optional[dict] = None,
|
||||||
reranking_mode: Optional[str] = 'reranking_model',
|
reranking_mode: Optional[str] = "reranking_model",
|
||||||
weights: Optional[dict] = None
|
weights: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
dataset = db.session.query(Dataset).filter(
|
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -45,46 +44,55 @@ class RetrievalService:
|
|||||||
threads = []
|
threads = []
|
||||||
exceptions = []
|
exceptions = []
|
||||||
# retrieval_model source with keyword
|
# retrieval_model source with keyword
|
||||||
if retrieval_method == 'keyword_search':
|
if retrieval_method == "keyword_search":
|
||||||
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
|
keyword_thread = threading.Thread(
|
||||||
'flask_app': current_app._get_current_object(),
|
target=RetrievalService.keyword_search,
|
||||||
'dataset_id': dataset_id,
|
kwargs={
|
||||||
'query': query,
|
"flask_app": current_app._get_current_object(),
|
||||||
'top_k': top_k,
|
"dataset_id": dataset_id,
|
||||||
'all_documents': all_documents,
|
"query": query,
|
||||||
'exceptions': exceptions,
|
"top_k": top_k,
|
||||||
})
|
"all_documents": all_documents,
|
||||||
|
"exceptions": exceptions,
|
||||||
|
},
|
||||||
|
)
|
||||||
threads.append(keyword_thread)
|
threads.append(keyword_thread)
|
||||||
keyword_thread.start()
|
keyword_thread.start()
|
||||||
# retrieval_model source with semantic
|
# retrieval_model source with semantic
|
||||||
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
if RetrievalMethod.is_support_semantic_search(retrieval_method):
|
||||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
embedding_thread = threading.Thread(
|
||||||
'flask_app': current_app._get_current_object(),
|
target=RetrievalService.embedding_search,
|
||||||
'dataset_id': dataset_id,
|
kwargs={
|
||||||
'query': query,
|
"flask_app": current_app._get_current_object(),
|
||||||
'top_k': top_k,
|
"dataset_id": dataset_id,
|
||||||
'score_threshold': score_threshold,
|
"query": query,
|
||||||
'reranking_model': reranking_model,
|
"top_k": top_k,
|
||||||
'all_documents': all_documents,
|
"score_threshold": score_threshold,
|
||||||
'retrieval_method': retrieval_method,
|
"reranking_model": reranking_model,
|
||||||
'exceptions': exceptions,
|
"all_documents": all_documents,
|
||||||
})
|
"retrieval_method": retrieval_method,
|
||||||
|
"exceptions": exceptions,
|
||||||
|
},
|
||||||
|
)
|
||||||
threads.append(embedding_thread)
|
threads.append(embedding_thread)
|
||||||
embedding_thread.start()
|
embedding_thread.start()
|
||||||
|
|
||||||
# retrieval source with full text
|
# retrieval source with full text
|
||||||
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
|
||||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
full_text_index_thread = threading.Thread(
|
||||||
'flask_app': current_app._get_current_object(),
|
target=RetrievalService.full_text_index_search,
|
||||||
'dataset_id': dataset_id,
|
kwargs={
|
||||||
'query': query,
|
"flask_app": current_app._get_current_object(),
|
||||||
'retrieval_method': retrieval_method,
|
"dataset_id": dataset_id,
|
||||||
'score_threshold': score_threshold,
|
"query": query,
|
||||||
'top_k': top_k,
|
"retrieval_method": retrieval_method,
|
||||||
'reranking_model': reranking_model,
|
"score_threshold": score_threshold,
|
||||||
'all_documents': all_documents,
|
"top_k": top_k,
|
||||||
'exceptions': exceptions,
|
"reranking_model": reranking_model,
|
||||||
})
|
"all_documents": all_documents,
|
||||||
|
"exceptions": exceptions,
|
||||||
|
},
|
||||||
|
)
|
||||||
threads.append(full_text_index_thread)
|
threads.append(full_text_index_thread)
|
||||||
full_text_index_thread.start()
|
full_text_index_thread.start()
|
||||||
|
|
||||||
@ -92,35 +100,25 @@ class RetrievalService:
|
|||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
if exceptions:
|
if exceptions:
|
||||||
exception_message = ';\n'.join(exceptions)
|
exception_message = ";\n".join(exceptions)
|
||||||
raise Exception(exception_message)
|
raise Exception(exception_message)
|
||||||
|
|
||||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||||
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
|
data_post_processor = DataPostProcessor(
|
||||||
reranking_model, weights, False)
|
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
|
||||||
|
)
|
||||||
all_documents = data_post_processor.invoke(
|
all_documents = data_post_processor.invoke(
|
||||||
query=query,
|
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
|
||||||
documents=all_documents,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
top_n=top_k
|
|
||||||
)
|
)
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def external_retrieve(cls,
|
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
|
||||||
dataset_id: str,
|
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||||
query: str,
|
|
||||||
external_retrieval_model: Optional[dict] = None):
|
|
||||||
dataset = db.session.query(Dataset).filter(
|
|
||||||
Dataset.id == dataset_id
|
|
||||||
).first()
|
|
||||||
if not dataset:
|
if not dataset:
|
||||||
return []
|
return []
|
||||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
dataset.tenant_id,
|
dataset.tenant_id, dataset_id, query, external_retrieval_model
|
||||||
dataset_id,
|
|
||||||
query,
|
|
||||||
external_retrieval_model
|
|
||||||
)
|
)
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ class Document(BaseModel):
|
|||||||
"""
|
"""
|
||||||
metadata: Optional[dict] = Field(default_factory=dict)
|
metadata: Optional[dict] = Field(default_factory=dict)
|
||||||
|
|
||||||
provider: Optional[str] = 'dify'
|
provider: Optional[str] = "dify"
|
||||||
|
|
||||||
|
|
||||||
class BaseDocumentTransformer(ABC):
|
class BaseDocumentTransformer(ABC):
|
||||||
|
@ -112,7 +112,12 @@ class DatasetRetrieval:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# pass if dataset is not available
|
# pass if dataset is not available
|
||||||
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0 and dataset.provider != "external":
|
if (
|
||||||
|
dataset
|
||||||
|
and dataset.available_document_count == 0
|
||||||
|
and dataset.available_document_count == 0
|
||||||
|
and dataset.provider != "external"
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
available_datasets.append(dataset)
|
available_datasets.append(dataset)
|
||||||
@ -172,7 +177,6 @@ class DatasetRetrieval:
|
|||||||
if item.metadata.get("score"):
|
if item.metadata.get("score"):
|
||||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||||
|
|
||||||
|
|
||||||
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
|
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
|
||||||
segments = DocumentSegment.query.filter(
|
segments = DocumentSegment.query.filter(
|
||||||
DocumentSegment.dataset_id.in_(dataset_ids),
|
DocumentSegment.dataset_id.in_(dataset_ids),
|
||||||
@ -188,9 +192,19 @@ class DatasetRetrieval:
|
|||||||
)
|
)
|
||||||
for segment in sorted_segments:
|
for segment in sorted_segments:
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
document_context_list.append(DocumentContext(content=f"question:{segment.get_sign_content()} answer:{segment.answer}", score=document_score_list.get(segment.index_node_id, None)))
|
document_context_list.append(
|
||||||
|
DocumentContext(
|
||||||
|
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
|
||||||
|
score=document_score_list.get(segment.index_node_id, None),
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
document_context_list.append(DocumentContext(content=segment.get_sign_content(), score=document_score_list.get(segment.index_node_id, None)))
|
document_context_list.append(
|
||||||
|
DocumentContext(
|
||||||
|
content=segment.get_sign_content(),
|
||||||
|
score=document_score_list.get(segment.index_node_id, None),
|
||||||
|
)
|
||||||
|
)
|
||||||
if show_retrieve_source:
|
if show_retrieve_source:
|
||||||
for segment in sorted_segments:
|
for segment in sorted_segments:
|
||||||
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
|
||||||
@ -279,7 +293,7 @@ class DatasetRetrieval:
|
|||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
query=query,
|
query=query,
|
||||||
external_retrieval_parameters=dataset.retrieval_model
|
external_retrieval_parameters=dataset.retrieval_model,
|
||||||
)
|
)
|
||||||
for external_document in external_documents:
|
for external_document in external_documents:
|
||||||
document = Document(
|
document = Document(
|
||||||
@ -304,7 +318,9 @@ class DatasetRetrieval:
|
|||||||
retrieval_method = retrieval_model_config["search_method"]
|
retrieval_method = retrieval_model_config["search_method"]
|
||||||
# get reranking model
|
# get reranking model
|
||||||
reranking_model = (
|
reranking_model = (
|
||||||
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
|
retrieval_model_config["reranking_model"]
|
||||||
|
if retrieval_model_config["reranking_enable"]
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
# get score threshold
|
# get score threshold
|
||||||
score_threshold = 0.0
|
score_threshold = 0.0
|
||||||
@ -452,7 +468,7 @@ class DatasetRetrieval:
|
|||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
query=query,
|
query=query,
|
||||||
external_retrieval_parameters=dataset.retrieval_model
|
external_retrieval_parameters=dataset.retrieval_model,
|
||||||
)
|
)
|
||||||
for external_document in external_documents:
|
for external_document in external_documents:
|
||||||
document = Document(
|
document = Document(
|
||||||
|
@ -168,7 +168,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
"dataset_name": item.metadata.get("dataset_name"),
|
"dataset_name": item.metadata.get("dataset_name"),
|
||||||
"document_name": item.metadata.get("title"),
|
"document_name": item.metadata.get("title"),
|
||||||
"data_source_type": "external",
|
"data_source_type": "external",
|
||||||
"retriever_from": 'workflow',
|
"retriever_from": "workflow",
|
||||||
"score": item.metadata.get("score"),
|
"score": item.metadata.get("score"),
|
||||||
},
|
},
|
||||||
"title": item.metadata.get("title"),
|
"title": item.metadata.get("title"),
|
||||||
|
@ -37,8 +37,8 @@ class Dataset(db.Model):
|
|||||||
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
|
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
|
||||||
)
|
)
|
||||||
|
|
||||||
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
|
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
|
||||||
PROVIDER_LIST = ['vendor', 'external', None]
|
PROVIDER_LIST = ["vendor", "external", None]
|
||||||
|
|
||||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
tenant_id = db.Column(StringUUID, nullable=False)
|
tenant_id = db.Column(StringUUID, nullable=False)
|
||||||
@ -74,10 +74,9 @@ class Dataset(db.Model):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def external_retrieval_model(self):
|
def external_retrieval_model(self):
|
||||||
|
|
||||||
default_retrieval_model = {
|
default_retrieval_model = {
|
||||||
"top_k": 2,
|
"top_k": 2,
|
||||||
"score_threshold": .0,
|
"score_threshold": 0.0,
|
||||||
}
|
}
|
||||||
return self.retrieval_model or default_retrieval_model
|
return self.retrieval_model or default_retrieval_model
|
||||||
|
|
||||||
@ -700,35 +699,32 @@ class DatasetPermission(db.Model):
|
|||||||
|
|
||||||
|
|
||||||
class ExternalApiTemplates(db.Model):
|
class ExternalApiTemplates(db.Model):
|
||||||
__tablename__ = 'external_api_templates'
|
__tablename__ = "external_api_templates"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint('id', name='external_api_template_pkey'),
|
db.PrimaryKeyConstraint("id", name="external_api_template_pkey"),
|
||||||
db.Index('external_api_templates_tenant_idx', 'tenant_id'),
|
db.Index("external_api_templates_tenant_idx", "tenant_id"),
|
||||||
db.Index('external_api_templates_name_idx', 'name'),
|
db.Index("external_api_templates_name_idx", "name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = db.Column(StringUUID, nullable=False,
|
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||||
server_default=db.text('uuid_generate_v4()'))
|
|
||||||
name = db.Column(db.String(255), nullable=False)
|
name = db.Column(db.String(255), nullable=False)
|
||||||
description = db.Column(db.String(255), nullable=False)
|
description = db.Column(db.String(255), nullable=False)
|
||||||
tenant_id = db.Column(StringUUID, nullable=False)
|
tenant_id = db.Column(StringUUID, nullable=False)
|
||||||
settings = db.Column(db.Text, nullable=True)
|
settings = db.Column(db.Text, nullable=True)
|
||||||
created_by = db.Column(StringUUID, nullable=False)
|
created_by = db.Column(StringUUID, nullable=False)
|
||||||
created_at = db.Column(db.DateTime, nullable=False,
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
|
||||||
updated_by = db.Column(StringUUID, nullable=True)
|
updated_by = db.Column(StringUUID, nullable=True)
|
||||||
updated_at = db.Column(db.DateTime, nullable=False,
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
'id': self.id,
|
"id": self.id,
|
||||||
'tenant_id': self.tenant_id,
|
"tenant_id": self.tenant_id,
|
||||||
'name': self.name,
|
"name": self.name,
|
||||||
'description': self.description,
|
"description": self.description,
|
||||||
'settings': self.settings_dict,
|
"settings": self.settings_dict,
|
||||||
'created_by': self.created_by,
|
"created_by": self.created_by,
|
||||||
'created_at': self.created_at.isoformat(),
|
"created_at": self.created_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -740,24 +736,21 @@ class ExternalApiTemplates(db.Model):
|
|||||||
|
|
||||||
|
|
||||||
class ExternalKnowledgeBindings(db.Model):
|
class ExternalKnowledgeBindings(db.Model):
|
||||||
__tablename__ = 'external_knowledge_bindings'
|
__tablename__ = "external_knowledge_bindings"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey'),
|
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
|
||||||
db.Index('external_knowledge_bindings_tenant_idx', 'tenant_id'),
|
db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
|
||||||
db.Index('external_knowledge_bindings_dataset_idx', 'dataset_id'),
|
db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
|
||||||
db.Index('external_knowledge_bindings_external_knowledge_idx', 'external_knowledge_id'),
|
db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
|
||||||
db.Index('external_knowledge_bindings_external_api_template_idx', 'external_api_template_id'),
|
db.Index("external_knowledge_bindings_external_api_template_idx", "external_api_template_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = db.Column(StringUUID, nullable=False,
|
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
|
||||||
server_default=db.text('uuid_generate_v4()'))
|
|
||||||
tenant_id = db.Column(StringUUID, nullable=False)
|
tenant_id = db.Column(StringUUID, nullable=False)
|
||||||
external_api_template_id = db.Column(StringUUID, nullable=False)
|
external_api_template_id = db.Column(StringUUID, nullable=False)
|
||||||
dataset_id = db.Column(StringUUID, nullable=False)
|
dataset_id = db.Column(StringUUID, nullable=False)
|
||||||
external_knowledge_id = db.Column(db.Text, nullable=False)
|
external_knowledge_id = db.Column(db.Text, nullable=False)
|
||||||
created_by = db.Column(StringUUID, nullable=False)
|
created_by = db.Column(StringUUID, nullable=False)
|
||||||
created_at = db.Column(db.DateTime, nullable=False,
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
|
||||||
updated_by = db.Column(StringUUID, nullable=True)
|
updated_by = db.Column(StringUUID, nullable=True)
|
||||||
updated_at = db.Column(db.DateTime, nullable=False,
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
|
||||||
|
@ -59,9 +59,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
|
|||||||
class DatasetService:
|
class DatasetService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None):
|
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None):
|
||||||
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(
|
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
|
||||||
Dataset.created_at.desc()
|
|
||||||
)
|
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
# get permitted dataset ids
|
# get permitted dataset ids
|
||||||
|
@ -5,8 +5,10 @@ from copy import deepcopy
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import boto3
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
# from tasks.external_document_indexing_task import external_document_indexing_task
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -16,13 +18,9 @@ from models.dataset import (
|
|||||||
ExternalApiTemplates,
|
ExternalApiTemplates,
|
||||||
ExternalKnowledgeBindings,
|
ExternalKnowledgeBindings,
|
||||||
)
|
)
|
||||||
from core.rag.models.document import Document as RetrievalDocument
|
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization
|
from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization
|
||||||
from services.errors.dataset import DatasetNameDuplicateError
|
from services.errors.dataset import DatasetNameDuplicateError
|
||||||
# from tasks.external_document_indexing_task import external_document_indexing_task
|
|
||||||
import requests
|
|
||||||
import boto3
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalDatasetService:
|
class ExternalDatasetService:
|
||||||
@ -281,9 +279,7 @@ class ExternalDatasetService:
|
|||||||
raise ValueError("external api template not found")
|
raise ValueError("external api template not found")
|
||||||
|
|
||||||
settings = json.loads(external_api_template.settings)
|
settings = json.loads(external_api_template.settings)
|
||||||
headers = {
|
headers = {"Content-Type": "application/json"}
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
if settings.get("api_key"):
|
if settings.get("api_key"):
|
||||||
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
|
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
|
||||||
|
|
||||||
@ -302,26 +298,19 @@ class ExternalDatasetService:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def test_external_knowledge_retrieval(
|
def test_external_knowledge_retrieval(top_k: int, score_threshold: float, query: str, external_knowledge_id: str):
|
||||||
top_k: int, score_threshold: float, query: str, external_knowledge_id: str
|
|
||||||
):
|
|
||||||
client = boto3.client(
|
client = boto3.client(
|
||||||
"bedrock-agent-runtime",
|
"bedrock-agent-runtime",
|
||||||
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
|
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
|
||||||
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
|
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
|
||||||
region_name='us-east-1',
|
region_name="us-east-1",
|
||||||
)
|
)
|
||||||
response = client.retrieve(
|
response = client.retrieve(
|
||||||
knowledgeBaseId=external_knowledge_id,
|
knowledgeBaseId=external_knowledge_id,
|
||||||
retrievalConfiguration={
|
retrievalConfiguration={
|
||||||
'vectorSearchConfiguration': {
|
"vectorSearchConfiguration": {"numberOfResults": top_k, "overrideSearchType": "HYBRID"}
|
||||||
'numberOfResults': top_k,
|
|
||||||
'overrideSearchType': 'HYBRID'
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
retrievalQuery={
|
retrievalQuery={"text": query},
|
||||||
'text': query
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
|
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
|
||||||
|
@ -28,7 +28,7 @@ class HitTestingService:
|
|||||||
external_retrieval_model: dict,
|
external_retrieval_model: dict,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if (dataset.available_document_count == 0 or dataset.available_segment_count == 0):
|
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||||
return {
|
return {
|
||||||
"query": {
|
"query": {
|
||||||
"content": query,
|
"content": query,
|
||||||
@ -80,8 +80,7 @@ class HitTestingService:
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
if dataset.provider != "external":
|
if dataset.provider != "external":
|
||||||
return {
|
return {
|
||||||
"query": {
|
"query": {"content": query},
|
||||||
"content": query},
|
|
||||||
"records": [],
|
"records": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user