External knowledge API

jyong 2024-09-24 18:00:45 +08:00
parent ed92c90a40
commit 089da063d4
12 changed files with 179 additions and 183 deletions


@@ -1,6 +1,6 @@
 from typing import Optional
-from pydantic import Field, PositiveInt
+from pydantic import Field
 from pydantic_settings import BaseSettings
@@ -8,6 +8,7 @@ class BedrockConfig(BaseSettings):
     """
     bedrock configs
     """
     AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
         description="AWS secret access key",
         default=None,
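
These are ordinary pydantic-settings fields, so each one resolves from an identically named environment variable and falls back to its default of None when the variable is absent. A minimal sketch of that behaviour, mirroring the class from the hunk above (the env value is a placeholder; in Dify these settings are consumed through dify_config rather than constructed directly):

import os
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings


class BedrockConfig(BaseSettings):
    AWS_SECRET_ACCESS_KEY: Optional[str] = Field(description="AWS secret access key", default=None)


os.environ["AWS_SECRET_ACCESS_KEY"] = "example-secret"  # placeholder value for illustration
print(BedrockConfig().AWS_SECRET_ACCESS_KEY)  # -> "example-secret"; None when the variable is unset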


@@ -37,7 +37,17 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
 from .billing import billing
 # Import datasets controllers
-from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website, test_external
+from .datasets import (
+    data_source,
+    datasets,
+    datasets_document,
+    datasets_segments,
+    external,
+    file,
+    hit_testing,
+    test_external,
+    website,
+)
 # Import explore controllers
 from .explore import (


@@ -1,7 +1,7 @@
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal, reqparse
-from werkzeug.exceptions import Forbidden, NotFound, InternalServerError
+from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 import services
 from controllers.console import api
@@ -234,7 +234,6 @@ class ExternalDatasetCreateApi(Resource):
         parser.add_argument("description", type=str, required=False, nullable=True, location="json")
         parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
         args = parser.parse_args()
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator


@@ -1,18 +1,12 @@
-from flask import request
-from flask_login import current_user
-from flask_restful import Resource, marshal, reqparse
-from werkzeug.exceptions import Forbidden, NotFound
-import services
+from flask_restful import Resource, reqparse
 from controllers.console import api
-from controllers.console.app.error import ProviderNotInitializeError
-from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from fields.dataset_fields import dataset_detail_fields
 from libs.login import login_required
 from services.external_knowledge_service import ExternalDatasetService
 class TestExternalApi(Resource):
     @setup_required
     @login_required
@@ -50,5 +44,4 @@ class TestExternalApi(Resource):
         return result, 200
 api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents")
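
The TestExternalApi resource registered above is the console endpoint behind the external-retrieval test. A minimal request sketch, assuming the console API is mounted at /console/api on the default development port, a valid console session token, and that the JSON body mirrors the parameters of test_external_knowledge_retrieval shown later in this diff (the field names and values here are assumptions, not taken from the commit):

import requests

resp = requests.post(
    "http://localhost:5001/console/api/dify/external-knowledge/retrieval-documents",
    headers={"Authorization": "Bearer <console-session-token>"},  # placeholder credential
    json={
        "query": "What is the refund policy?",
        "top_k": 2,
        "score_threshold": 0.0,
        "external_knowledge_id": "<bedrock-knowledge-base-id>",
    },
    timeout=30,
)
print(resp.status_code, resp.json())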


@@ -23,19 +23,18 @@ default_retrieval_model = {
 class RetrievalService:
     @classmethod
-    def retrieve(cls,
-                 retrieval_method: str,
-                 dataset_id: str,
-                 query: str,
-                 top_k: int,
-                 score_threshold: Optional[float] = .0,
-                 reranking_model: Optional[dict] = None,
-                 reranking_mode: Optional[str] = 'reranking_model',
-                 weights: Optional[dict] = None
-                 ):
-        dataset = db.session.query(Dataset).filter(
-            Dataset.id == dataset_id
-        ).first()
+    def retrieve(
+        cls,
+        retrieval_method: str,
+        dataset_id: str,
+        query: str,
+        top_k: int,
+        score_threshold: Optional[float] = 0.0,
+        reranking_model: Optional[dict] = None,
+        reranking_mode: Optional[str] = "reranking_model",
+        weights: Optional[dict] = None,
+    ):
+        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if not dataset:
             return []
@@ -45,46 +44,55 @@ class RetrievalService:
         threads = []
         exceptions = []
         # retrieval_model source with keyword
-        if retrieval_method == 'keyword_search':
-            keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': dataset_id,
-                'query': query,
-                'top_k': top_k,
-                'all_documents': all_documents,
-                'exceptions': exceptions,
-            })
+        if retrieval_method == "keyword_search":
+            keyword_thread = threading.Thread(
+                target=RetrievalService.keyword_search,
+                kwargs={
+                    "flask_app": current_app._get_current_object(),
+                    "dataset_id": dataset_id,
+                    "query": query,
+                    "top_k": top_k,
+                    "all_documents": all_documents,
+                    "exceptions": exceptions,
+                },
+            )
             threads.append(keyword_thread)
             keyword_thread.start()
         # retrieval_model source with semantic
         if RetrievalMethod.is_support_semantic_search(retrieval_method):
-            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': dataset_id,
-                'query': query,
-                'top_k': top_k,
-                'score_threshold': score_threshold,
-                'reranking_model': reranking_model,
-                'all_documents': all_documents,
-                'retrieval_method': retrieval_method,
-                'exceptions': exceptions,
-            })
+            embedding_thread = threading.Thread(
+                target=RetrievalService.embedding_search,
+                kwargs={
+                    "flask_app": current_app._get_current_object(),
+                    "dataset_id": dataset_id,
+                    "query": query,
+                    "top_k": top_k,
+                    "score_threshold": score_threshold,
+                    "reranking_model": reranking_model,
+                    "all_documents": all_documents,
+                    "retrieval_method": retrieval_method,
+                    "exceptions": exceptions,
+                },
+            )
             threads.append(embedding_thread)
             embedding_thread.start()
         # retrieval source with full text
         if RetrievalMethod.is_support_fulltext_search(retrieval_method):
-            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'dataset_id': dataset_id,
-                'query': query,
-                'retrieval_method': retrieval_method,
-                'score_threshold': score_threshold,
-                'top_k': top_k,
-                'reranking_model': reranking_model,
-                'all_documents': all_documents,
-                'exceptions': exceptions,
-            })
+            full_text_index_thread = threading.Thread(
+                target=RetrievalService.full_text_index_search,
+                kwargs={
+                    "flask_app": current_app._get_current_object(),
+                    "dataset_id": dataset_id,
+                    "query": query,
+                    "retrieval_method": retrieval_method,
+                    "score_threshold": score_threshold,
+                    "top_k": top_k,
+                    "reranking_model": reranking_model,
+                    "all_documents": all_documents,
+                    "exceptions": exceptions,
+                },
+            )
             threads.append(full_text_index_thread)
             full_text_index_thread.start()
@@ -92,35 +100,25 @@ class RetrievalService:
             thread.join()
         if exceptions:
-            exception_message = ';\n'.join(exceptions)
+            exception_message = ";\n".join(exceptions)
             raise Exception(exception_message)
         if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
-            data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
-                                                    reranking_model, weights, False)
+            data_post_processor = DataPostProcessor(
+                str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
+            )
             all_documents = data_post_processor.invoke(
-                query=query,
-                documents=all_documents,
-                score_threshold=score_threshold,
-                top_n=top_k
+                query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
             )
         return all_documents
     @classmethod
-    def external_retrieve(cls,
-                          dataset_id: str,
-                          query: str,
-                          external_retrieval_model: Optional[dict] = None):
-        dataset = db.session.query(Dataset).filter(
-            Dataset.id == dataset_id
-        ).first()
+    def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
+        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if not dataset:
             return []
         all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
-            dataset.tenant_id,
-            dataset_id,
-            query,
-            external_retrieval_model
+            dataset.tenant_id, dataset_id, query, external_retrieval_model
        )
         return all_documents
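
Both retrieval entry points are classmethods, so callers hit them directly once an application context is available. A minimal usage sketch with placeholder IDs (the retrieval_method string and the external_retrieval_model keys follow values visible elsewhere in this diff; everything else is illustrative):

# Internal retrieval over a Dify-managed dataset (keyword / semantic / full-text threads joined above)
docs = RetrievalService.retrieve(
    retrieval_method="hybrid_search",
    dataset_id="<dataset-uuid>",
    query="refund policy",
    top_k=4,
    score_threshold=0.0,
)

# External retrieval, delegated to ExternalDatasetService.fetch_external_knowledge_retrieval
external_docs = RetrievalService.external_retrieve(
    dataset_id="<external-dataset-uuid>",
    query="refund policy",
    external_retrieval_model={"top_k": 2, "score_threshold": 0.0},
)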


@@ -17,7 +17,7 @@ class Document(BaseModel):
     """
     metadata: Optional[dict] = Field(default_factory=dict)
-    provider: Optional[str] = 'dify'
+    provider: Optional[str] = "dify"
 class BaseDocumentTransformer(ABC):
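
The new provider field is how externally retrieved chunks are told apart from Dify-managed ones downstream. A minimal construction sketch, assuming the model also keeps its usual page_content field (that field is not shown in this hunk; the values are placeholders):

external_chunk = Document(
    page_content="Refunds are issued within 14 days.",  # assumed field, not visible in this hunk
    metadata={"title": "policy.md", "score": 0.87},
    provider="external",  # defaults to "dify" for internally indexed chunks
)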


@@ -112,7 +112,12 @@ class DatasetRetrieval:
                 continue
             # pass if dataset is not available
-            if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0 and dataset.provider != "external":
+            if (
+                dataset
+                and dataset.available_document_count == 0
+                and dataset.available_document_count == 0
+                and dataset.provider != "external"
+            ):
                 continue
             available_datasets.append(dataset)
@@ -172,7 +177,6 @@ class DatasetRetrieval:
                 if item.metadata.get("score"):
                     document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
-            index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
             segments = DocumentSegment.query.filter(
                 DocumentSegment.dataset_id.in_(dataset_ids),
@@ -188,9 +192,19 @@ class DatasetRetrieval:
             )
             for segment in sorted_segments:
                 if segment.answer:
-                    document_context_list.append(DocumentContext(content=f"question:{segment.get_sign_content()} answer:{segment.answer}", score=document_score_list.get(segment.index_node_id, None)))
+                    document_context_list.append(
+                        DocumentContext(
+                            content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
+                            score=document_score_list.get(segment.index_node_id, None),
+                        )
+                    )
                 else:
-                    document_context_list.append(DocumentContext(content=segment.get_sign_content(), score=document_score_list.get(segment.index_node_id, None)))
+                    document_context_list.append(
+                        DocumentContext(
+                            content=segment.get_sign_content(),
+                            score=document_score_list.get(segment.index_node_id, None),
+                        )
+                    )
             if show_retrieve_source:
                 for segment in sorted_segments:
                     dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
@@ -279,7 +293,7 @@ class DatasetRetrieval:
                 tenant_id=dataset.tenant_id,
                 dataset_id=dataset_id,
                 query=query,
-                external_retrieval_parameters=dataset.retrieval_model
+                external_retrieval_parameters=dataset.retrieval_model,
             )
             for external_document in external_documents:
                 document = Document(
@@ -304,7 +318,9 @@ class DatasetRetrieval:
             retrieval_method = retrieval_model_config["search_method"]
             # get reranking model
             reranking_model = (
-                retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
+                retrieval_model_config["reranking_model"]
+                if retrieval_model_config["reranking_enable"]
+                else None
             )
             # get score threshold
             score_threshold = 0.0
@@ -452,7 +468,7 @@ class DatasetRetrieval:
                 tenant_id=dataset.tenant_id,
                 dataset_id=dataset_id,
                 query=query,
-                external_retrieval_parameters=dataset.retrieval_model
+                external_retrieval_parameters=dataset.retrieval_model,
             )
             for external_document in external_documents:
                 document = Document(


@@ -168,7 +168,7 @@ class KnowledgeRetrievalNode(BaseNode):
                     "dataset_name": item.metadata.get("dataset_name"),
                     "document_name": item.metadata.get("title"),
                     "data_source_type": "external",
-                    "retriever_from": 'workflow',
+                    "retriever_from": "workflow",
                     "score": item.metadata.get("score"),
                 },
                 "title": item.metadata.get("title"),


@@ -37,8 +37,8 @@ class Dataset(db.Model):
         db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
     )
-    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
-    PROVIDER_LIST = ['vendor', 'external', None]
+    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
+    PROVIDER_LIST = ["vendor", "external", None]
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id = db.Column(StringUUID, nullable=False)
@@ -74,10 +74,9 @@ class Dataset(db.Model):
     @property
     def external_retrieval_model(self):
         default_retrieval_model = {
             "top_k": 2,
-            "score_threshold": .0,
+            "score_threshold": 0.0,
         }
         return self.retrieval_model or default_retrieval_model
@@ -700,35 +699,32 @@ class DatasetPermission(db.Model):
 class ExternalApiTemplates(db.Model):
-    __tablename__ = 'external_api_templates'
+    __tablename__ = "external_api_templates"
     __table_args__ = (
-        db.PrimaryKeyConstraint('id', name='external_api_template_pkey'),
-        db.Index('external_api_templates_tenant_idx', 'tenant_id'),
-        db.Index('external_api_templates_name_idx', 'name'),
+        db.PrimaryKeyConstraint("id", name="external_api_template_pkey"),
+        db.Index("external_api_templates_tenant_idx", "tenant_id"),
+        db.Index("external_api_templates_name_idx", "name"),
     )
-    id = db.Column(StringUUID, nullable=False,
-                   server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
     name = db.Column(db.String(255), nullable=False)
     description = db.Column(db.String(255), nullable=False)
     tenant_id = db.Column(StringUUID, nullable=False)
     settings = db.Column(db.Text, nullable=True)
     created_by = db.Column(StringUUID, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False,
-                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
     updated_by = db.Column(StringUUID, nullable=True)
-    updated_at = db.Column(db.DateTime, nullable=False,
-                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
     def to_dict(self):
         return {
-            'id': self.id,
-            'tenant_id': self.tenant_id,
-            'name': self.name,
-            'description': self.description,
-            'settings': self.settings_dict,
-            'created_by': self.created_by,
-            'created_at': self.created_at.isoformat(),
+            "id": self.id,
+            "tenant_id": self.tenant_id,
+            "name": self.name,
+            "description": self.description,
+            "settings": self.settings_dict,
+            "created_by": self.created_by,
+            "created_at": self.created_at.isoformat(),
         }
     @property
@@ -740,24 +736,21 @@ class ExternalApiTemplates(db.Model):
 class ExternalKnowledgeBindings(db.Model):
-    __tablename__ = 'external_knowledge_bindings'
+    __tablename__ = "external_knowledge_bindings"
     __table_args__ = (
-        db.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey'),
-        db.Index('external_knowledge_bindings_tenant_idx', 'tenant_id'),
-        db.Index('external_knowledge_bindings_dataset_idx', 'dataset_id'),
-        db.Index('external_knowledge_bindings_external_knowledge_idx', 'external_knowledge_id'),
-        db.Index('external_knowledge_bindings_external_api_template_idx', 'external_api_template_id'),
+        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
+        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
+        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
+        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
+        db.Index("external_knowledge_bindings_external_api_template_idx", "external_api_template_id"),
     )
-    id = db.Column(StringUUID, nullable=False,
-                   server_default=db.text('uuid_generate_v4()'))
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
     tenant_id = db.Column(StringUUID, nullable=False)
    external_api_template_id = db.Column(StringUUID, nullable=False)
     dataset_id = db.Column(StringUUID, nullable=False)
     external_knowledge_id = db.Column(db.Text, nullable=False)
     created_by = db.Column(StringUUID, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False,
-                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
     updated_by = db.Column(StringUUID, nullable=True)
-    updated_at = db.Column(db.DateTime, nullable=False,
-                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


@@ -59,9 +59,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
 class DatasetService:
     @staticmethod
     def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None):
-        query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(
-            Dataset.created_at.desc()
-        )
+        query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
         if user:
             # get permitted dataset ids

@@ -5,8 +5,10 @@ from copy import deepcopy
 from datetime import datetime, timezone
 from typing import Any, Optional, Union
+import boto3
 import httpx
+# from tasks.external_document_indexing_task import external_document_indexing_task
 from configs import dify_config
 from core.helper import ssrf_proxy
 from extensions.ext_database import db
@@ -16,13 +18,9 @@ from models.dataset import (
     ExternalApiTemplates,
     ExternalKnowledgeBindings,
 )
-from core.rag.models.document import Document as RetrievalDocument
 from models.model import UploadFile
 from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization
 from services.errors.dataset import DatasetNameDuplicateError
-# from tasks.external_document_indexing_task import external_document_indexing_task
-import requests
-import boto3
 class ExternalDatasetService:
@@ -281,9 +279,7 @@ class ExternalDatasetService:
             raise ValueError("external api template not found")
         settings = json.loads(external_api_template.settings)
-        headers = {
-            "Content-Type": "application/json"
-        }
+        headers = {"Content-Type": "application/json"}
         if settings.get("api_key"):
             headers["Authorization"] = f"Bearer {settings.get('api_key')}"
@@ -302,26 +298,19 @@ class ExternalDatasetService:
         return []
     @staticmethod
-    def test_external_knowledge_retrieval(
-        top_k: int, score_threshold: float, query: str, external_knowledge_id: str
-    ):
+    def test_external_knowledge_retrieval(top_k: int, score_threshold: float, query: str, external_knowledge_id: str):
         client = boto3.client(
             "bedrock-agent-runtime",
             aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
             aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
-            region_name='us-east-1',
+            region_name="us-east-1",
         )
         response = client.retrieve(
             knowledgeBaseId=external_knowledge_id,
             retrievalConfiguration={
-                'vectorSearchConfiguration': {
-                    'numberOfResults': top_k,
-                    'overrideSearchType': 'HYBRID'
-                }
+                "vectorSearchConfiguration": {"numberOfResults": top_k, "overrideSearchType": "HYBRID"}
             },
-            retrievalQuery={
-                'text': query
-            }
+            retrievalQuery={"text": query},
         )
         results = []
         if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
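
A sketch of how a Bedrock retrieve response is typically consumed, based on the documented bedrock-agent-runtime response shape (retrievalResults entries carrying content.text and score); this is an illustration under those assumptions, not the elided code from this commit:

# response comes from client.retrieve(...) above
for retrieval_result in response.get("retrievalResults", []):
    results.append(
        {
            "content": retrieval_result.get("content", {}).get("text"),  # chunk text returned by Bedrock
            "score": retrieval_result.get("score"),                      # relevance score, if present
            "metadata": retrieval_result.get("metadata"),                # source metadata, if present
        }
    )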


@@ -28,7 +28,7 @@ class HitTestingService:
         external_retrieval_model: dict,
         limit: int = 10,
     ) -> dict:
-        if (dataset.available_document_count == 0 or dataset.available_segment_count == 0):
+        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            return {
                "query": {
                    "content": query,
@@ -80,8 +80,7 @@ class HitTestingService:
     ) -> dict:
         if dataset.provider != "external":
             return {
-                "query": {
-                    "content": query},
+                "query": {"content": query},
                 "records": [],
             }