External knowledge api

This commit is contained in:
jyong 2024-09-24 18:00:45 +08:00
parent ed92c90a40
commit 089da063d4
12 changed files with 179 additions and 183 deletions

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic import Field
from pydantic_settings import BaseSettings
@ -8,6 +8,7 @@ class BedrockConfig(BaseSettings):
"""
bedrock configs
"""
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
description="AWS secret access key",
default=None,

View File

@ -37,7 +37,17 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
from .billing import billing
# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website, test_external
from .datasets import (
data_source,
datasets,
datasets_document,
datasets_segments,
external,
file,
hit_testing,
test_external,
website,
)
# Import explore controllers
from .explore import (

View File

@ -1,7 +1,7 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound, InternalServerError
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
@ -234,7 +234,6 @@ class ExternalDatasetCreateApi(Resource):
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator

View File

@ -1,18 +1,12 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from flask_restful import Resource, reqparse
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from services.external_knowledge_service import ExternalDatasetService
class TestExternalApi(Resource):
@setup_required
@login_required
@ -50,5 +44,4 @@ class TestExternalApi(Resource):
return result, 200
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrieval-documents")

View File

@ -23,19 +23,18 @@ default_retrieval_model = {
class RetrievalService:
@classmethod
def retrieve(cls,
retrieval_method: str,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None,
reranking_mode: Optional[str] = 'reranking_model',
weights: Optional[dict] = None
):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
def retrieve(
cls,
retrieval_method: str,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float] = 0.0,
reranking_model: Optional[dict] = None,
reranking_mode: Optional[str] = "reranking_model",
weights: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
@ -45,46 +44,55 @@ class RetrievalService:
threads = []
exceptions = []
# retrieval_model source with keyword
if retrieval_method == 'keyword_search':
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'all_documents': all_documents,
'exceptions': exceptions,
})
if retrieval_method == "keyword_search":
keyword_thread = threading.Thread(
target=RetrievalService.keyword_search,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
"exceptions": exceptions,
},
)
threads.append(keyword_thread)
keyword_thread.start()
# retrieval_model source with semantic
if RetrievalMethod.is_support_semantic_search(retrieval_method):
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'top_k': top_k,
'score_threshold': score_threshold,
'reranking_model': reranking_model,
'all_documents': all_documents,
'retrieval_method': retrieval_method,
'exceptions': exceptions,
})
embedding_thread = threading.Thread(
target=RetrievalService.embedding_search,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
"score_threshold": score_threshold,
"reranking_model": reranking_model,
"all_documents": all_documents,
"retrieval_method": retrieval_method,
"exceptions": exceptions,
},
)
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source with full text
if RetrievalMethod.is_support_fulltext_search(retrieval_method):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset_id,
'query': query,
'retrieval_method': retrieval_method,
'score_threshold': score_threshold,
'top_k': top_k,
'reranking_model': reranking_model,
'all_documents': all_documents,
'exceptions': exceptions,
})
full_text_index_thread = threading.Thread(
target=RetrievalService.full_text_index_search,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"retrieval_method": retrieval_method,
"score_threshold": score_threshold,
"top_k": top_k,
"reranking_model": reranking_model,
"all_documents": all_documents,
"exceptions": exceptions,
},
)
threads.append(full_text_index_thread)
full_text_index_thread.start()
@ -92,41 +100,31 @@ class RetrievalService:
thread.join()
if exceptions:
exception_message = ';\n'.join(exceptions)
exception_message = ";\n".join(exceptions)
raise Exception(exception_message)
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
reranking_model, weights, False)
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
)
return all_documents
@classmethod
def external_retrieve(cls,
dataset_id: str,
query: str,
external_retrieval_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,
dataset_id,
query,
external_retrieval_model
dataset.tenant_id, dataset_id, query, external_retrieval_model
)
return all_documents
@classmethod
def keyword_search(
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
):
with flask_app.app_context():
try:
@ -141,16 +139,16 @@ class RetrievalService:
@classmethod
def embedding_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
):
with flask_app.app_context():
try:
@ -168,10 +166,10 @@ class RetrievalService:
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
@ -188,16 +186,16 @@ class RetrievalService:
@classmethod
def full_text_index_search(
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
cls,
flask_app: Flask,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float],
reranking_model: Optional[dict],
all_documents: list,
retrieval_method: str,
exceptions: list,
):
with flask_app.app_context():
try:
@ -210,10 +208,10 @@ class RetrievalService:
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
):
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False

View File

@ -17,7 +17,7 @@ class Document(BaseModel):
"""
metadata: Optional[dict] = Field(default_factory=dict)
provider: Optional[str] = 'dify'
provider: Optional[str] = "dify"
class BaseDocumentTransformer(ABC):

View File

@ -112,7 +112,12 @@ class DatasetRetrieval:
continue
# pass if dataset is not available
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0 and dataset.provider != "external":
if (
dataset
and dataset.available_document_count == 0
and dataset.available_document_count == 0
and dataset.provider != "external"
):
continue
available_datasets.append(dataset)
@ -172,7 +177,6 @@ class DatasetRetrieval:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
@ -188,9 +192,19 @@ class DatasetRetrieval:
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(DocumentContext(content=f"question:{segment.get_sign_content()} answer:{segment.answer}", score=document_score_list.get(segment.index_node_id, None)))
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=document_score_list.get(segment.index_node_id, None),
)
)
else:
document_context_list.append(DocumentContext(content=segment.get_sign_content(), score=document_score_list.get(segment.index_node_id, None)))
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=document_score_list.get(segment.index_node_id, None),
)
)
if show_retrieve_source:
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
@ -279,7 +293,7 @@ class DatasetRetrieval:
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model
external_retrieval_parameters=dataset.retrieval_model,
)
for external_document in external_documents:
document = Document(
@ -304,7 +318,9 @@ class DatasetRetrieval:
retrieval_method = retrieval_model_config["search_method"]
# get reranking model
reranking_model = (
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
retrieval_model_config["reranking_model"]
if retrieval_model_config["reranking_enable"]
else None
)
# get score threshold
score_threshold = 0.0
@ -452,7 +468,7 @@ class DatasetRetrieval:
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
query=query,
external_retrieval_parameters=dataset.retrieval_model
external_retrieval_parameters=dataset.retrieval_model,
)
for external_document in external_documents:
document = Document(

View File

@ -168,7 +168,7 @@ class KnowledgeRetrievalNode(BaseNode):
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": 'workflow',
"retriever_from": "workflow",
"score": item.metadata.get("score"),
},
"title": item.metadata.get("title"),

View File

@ -37,8 +37,8 @@ class Dataset(db.Model):
db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
)
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]
PROVIDER_LIST = ['vendor', 'external', None]
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
PROVIDER_LIST = ["vendor", "external", None]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
@ -74,10 +74,9 @@ class Dataset(db.Model):
@property
def external_retrieval_model(self):
default_retrieval_model = {
"top_k": 2,
"score_threshold": .0,
"score_threshold": 0.0,
}
return self.retrieval_model or default_retrieval_model
@ -700,35 +699,32 @@ class DatasetPermission(db.Model):
class ExternalApiTemplates(db.Model):
__tablename__ = 'external_api_templates'
__tablename__ = "external_api_templates"
__table_args__ = (
db.PrimaryKeyConstraint('id', name='external_api_template_pkey'),
db.Index('external_api_templates_tenant_idx', 'tenant_id'),
db.Index('external_api_templates_name_idx', 'name'),
db.PrimaryKeyConstraint("id", name="external_api_template_pkey"),
db.Index("external_api_templates_tenant_idx", "tenant_id"),
db.Index("external_api_templates_name_idx", "name"),
)
id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.String(255), nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
settings = db.Column(db.Text, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
def to_dict(self):
return {
'id': self.id,
'tenant_id': self.tenant_id,
'name': self.name,
'description': self.description,
'settings': self.settings_dict,
'created_by': self.created_by,
'created_at': self.created_at.isoformat(),
"id": self.id,
"tenant_id": self.tenant_id,
"name": self.name,
"description": self.description,
"settings": self.settings_dict,
"created_by": self.created_by,
"created_at": self.created_at.isoformat(),
}
@property
@ -740,24 +736,21 @@ class ExternalApiTemplates(db.Model):
class ExternalKnowledgeBindings(db.Model):
__tablename__ = 'external_knowledge_bindings'
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint('id', name='external_knowledge_bindings_pkey'),
db.Index('external_knowledge_bindings_tenant_idx', 'tenant_id'),
db.Index('external_knowledge_bindings_dataset_idx', 'dataset_id'),
db.Index('external_knowledge_bindings_external_knowledge_idx', 'external_knowledge_id'),
db.Index('external_knowledge_bindings_external_api_template_idx', 'external_api_template_id'),
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
db.Index("external_knowledge_bindings_external_api_template_idx", "external_api_template_id"),
)
id = db.Column(StringUUID, nullable=False,
server_default=db.text('uuid_generate_v4()'))
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
external_api_template_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
external_knowledge_id = db.Column(db.Text, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False,
server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

View File

@ -59,9 +59,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(
Dataset.created_at.desc()
)
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids

View File

@ -5,8 +5,10 @@ from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Optional, Union
import boto3
import httpx
# from tasks.external_document_indexing_task import external_document_indexing_task
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
@ -16,13 +18,9 @@ from models.dataset import (
ExternalApiTemplates,
ExternalKnowledgeBindings,
)
from core.rag.models.document import Document as RetrievalDocument
from models.model import UploadFile
from services.entities.external_knowledge_entities.external_knowledge_entities import ApiTemplateSetting, Authorization
from services.errors.dataset import DatasetNameDuplicateError
# from tasks.external_document_indexing_task import external_document_indexing_task
import requests
import boto3
class ExternalDatasetService:
@ -266,7 +264,7 @@ class ExternalDatasetService:
@staticmethod
def fetch_external_knowledge_retrieval(
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
tenant_id: str, dataset_id: str, query: str, external_retrieval_parameters: dict
) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id, tenant_id=tenant_id
@ -281,9 +279,7 @@ class ExternalDatasetService:
raise ValueError("external api template not found")
settings = json.loads(external_api_template.settings)
headers = {
"Content-Type": "application/json"
}
headers = {"Content-Type": "application/json"}
if settings.get("api_key"):
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
@ -302,26 +298,19 @@ class ExternalDatasetService:
return []
@staticmethod
def test_external_knowledge_retrieval(
top_k: int, score_threshold: float, query: str, external_knowledge_id: str
):
def test_external_knowledge_retrieval(top_k: int, score_threshold: float, query: str, external_knowledge_id: str):
client = boto3.client(
"bedrock-agent-runtime",
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
region_name='us-east-1',
region_name="us-east-1",
)
response = client.retrieve(
knowledgeBaseId=external_knowledge_id,
retrievalConfiguration={
'vectorSearchConfiguration': {
'numberOfResults': top_k,
'overrideSearchType': 'HYBRID'
}
"vectorSearchConfiguration": {"numberOfResults": top_k, "overrideSearchType": "HYBRID"}
},
retrievalQuery={
'text': query
}
retrievalQuery={"text": query},
)
results = []
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:

View File

@ -20,15 +20,15 @@ default_retrieval_model = {
class HitTestingService:
@classmethod
def retrieve(
cls,
dataset: Dataset,
query: str,
account: Account,
retrieval_model: dict,
external_retrieval_model: dict,
limit: int = 10,
cls,
dataset: Dataset,
query: str,
account: Account,
retrieval_model: dict,
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
if (dataset.available_document_count == 0 or dataset.available_segment_count == 0):
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
"query": {
"content": query,
@ -72,16 +72,15 @@ class HitTestingService:
@classmethod
def external_retrieve(
cls,
dataset: Dataset,
query: str,
account: Account,
external_retrieval_model: dict,
cls,
dataset: Dataset,
query: str,
account: Account,
external_retrieval_model: dict,
) -> dict:
if dataset.provider != "external":
return {
"query": {
"content": query},
"query": {"content": query},
"records": [],
}