merge error

This commit is contained in:
jyong 2024-09-13 09:49:24 +08:00
parent 9ca0e56a8a
commit 89e81873c4
5 changed files with 132 additions and 76 deletions

View File

@ -47,6 +47,7 @@ class HitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrival_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@ -57,6 +58,7 @@ class HitTestingApi(Resource):
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrival_model"],
limit=10,
)

View File

@ -10,6 +10,7 @@ from core.rag.rerank.constants.rerank_mode import RerankMode
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
@ -29,10 +30,21 @@ class RetrievalService:
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0,
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
weights: Optional[dict] = None):
weights: Optional[dict] = None, provider: Optional[str] = None,
external_retrieval_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset:
return []
if provider == 'external':
external_knowledge_binding = ExternalDatasetService.fetch_external_knowledge_retrival(
dataset.tenant_id,
dataset_id,
query,
external_retrieval_model
)
else:
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []

View File

@ -23,7 +23,6 @@ class ApiTemplateSetting(BaseModel):
method: str
url: str
request_method: str
authorization: Authorization
api_token: str
headers: Optional[dict] = None
params: Optional[dict] = None
callback_setting: Optional[ProcessStatusSetting] = None

View File

@ -117,6 +117,16 @@ class ExternalDatasetService:
return True
return False
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id,
tenant_id=tenant_id
).first()
if not external_knowledge_binding:
raise ValueError('external knowledge binding not found')
return external_knowledge_binding
@staticmethod
def document_create_args_validate(tenant_id: str, api_template_id: str, process_parameter: dict):
api_template = ExternalApiTemplates.query.filter_by(
@ -196,8 +206,6 @@ class ExternalDatasetService:
@staticmethod
def process_external_api(settings: ApiTemplateSetting,
headers: Union[None, dict[str, Any]],
parameter: Union[None, dict[str, Any]],
files: Union[None, dict[str, Any]]) -> httpx.Response:
"""
do http request depending on api bundle
@ -205,14 +213,12 @@ class ExternalDatasetService:
kwargs = {
'url': settings.url,
'headers': headers,
'headers': settings.headers,
'follow_redirects': True,
}
if settings.request_method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
response = getattr(ssrf_proxy, settings.request_method)(data=parameter, files=files, **kwargs)
else:
raise ValueError(f'Invalid http method {settings.request_method}')
response = getattr(ssrf_proxy, settings.request_method)(data=settings.params, files=files, **kwargs)
return response
@staticmethod
@ -246,7 +252,7 @@ class ExternalDatasetService:
return ApiTemplateSetting.parse_obj(settings)
@staticmethod
def create_external_dataset(tenant_id, user_id, args):
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists
if Dataset.query.filter_by(name=args.get('name'), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
@ -254,6 +260,7 @@ class ExternalDatasetService:
id=args.get('api_template_id'),
tenant_id=tenant_id
).first()
if api_template is None:
raise ValueError('api template not found')
@ -281,4 +288,37 @@ class ExternalDatasetService:
return dataset
@staticmethod
def fetch_external_knowledge_retrival(tenant_id: str,
dataset_id: str,
query: str,
external_retrival_parameters: dict):
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
dataset_id=dataset_id,
tenant_id=tenant_id
).first()
if not external_knowledge_binding:
raise ValueError('external knowledge binding not found')
external_api_template = ExternalApiTemplates.query.filter_by(
id=external_knowledge_binding.external_api_template_id
).first()
if not external_api_template:
raise ValueError('external api template not found')
settings = json.loads(external_api_template.settings)
headers = {}
if settings.get('api_token'):
headers['Authorization'] = f"Bearer {settings.get('api_token')}"
external_retrival_parameters['query'] = query
api_template_setting = {
'url': f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents",
'request_method': 'post',
'headers': settings.get('headers'),
'params': external_retrival_parameters
}
response = ExternalDatasetService.process_external_api(
ApiTemplateSetting(**api_template_setting), None
)

View File

@ -19,7 +19,8 @@ default_retrieval_model = {
class HitTestingService:
@classmethod
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
def retrieve(cls, dataset: Dataset, query: str, account: Account,
retrieval_model: dict, external_retrieval_model: dict, limit: int = 10) -> dict:
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return {
"query": {
@ -50,6 +51,8 @@ class HitTestingService:
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None),
provider=dataset.provider,
external_retrieval_model=external_retrieval_model,
)
end = time.perf_counter()