merge error
This commit is contained in:
parent
9ca0e56a8a
commit
89e81873c4
@ -47,6 +47,7 @@ class HitTestingApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrival_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
@ -57,6 +58,7 @@ class HitTestingApi(Resource):
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrival_model"],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
|
@ -10,6 +10,7 @@ from core.rag.rerank.constants.rerank_mode import RerankMode
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
@ -29,10 +30,21 @@ class RetrievalService:
|
||||
def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float] = .0,
|
||||
reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model',
|
||||
weights: Optional[dict] = None):
|
||||
weights: Optional[dict] = None, provider: Optional[str] = None,
|
||||
external_retrieval_model: Optional[dict] = None):
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
return []
|
||||
if provider == 'external':
|
||||
external_knowledge_binding = ExternalDatasetService.fetch_external_knowledge_retrival(
|
||||
dataset.tenant_id,
|
||||
dataset_id,
|
||||
query,
|
||||
external_retrieval_model
|
||||
)
|
||||
else:
|
||||
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return []
|
||||
all_documents = []
|
||||
|
@ -23,7 +23,6 @@ class ApiTemplateSetting(BaseModel):
|
||||
method: str
|
||||
url: str
|
||||
request_method: str
|
||||
authorization: Authorization
|
||||
api_token: str
|
||||
headers: Optional[dict] = None
|
||||
params: Optional[dict] = None
|
||||
callback_setting: Optional[ProcessStatusSetting] = None
|
||||
|
@ -117,6 +117,16 @@ class ExternalDatasetService:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
    """Look up the external knowledge binding of a dataset within a tenant.

    :param tenant_id: tenant id the binding row must carry
    :param dataset_id: dataset id the binding row must carry
    :return: the first ``ExternalKnowledgeBindings`` row matching both ids
    :raises ValueError: when no matching binding row exists
    """
    lookup = ExternalKnowledgeBindings.query.filter_by(dataset_id=dataset_id, tenant_id=tenant_id)
    binding = lookup.first()
    if binding:
        return binding
    # Nothing matched both ids: report it to the caller instead of returning None.
    raise ValueError('external knowledge binding not found')
|
||||
|
||||
@staticmethod
|
||||
def document_create_args_validate(tenant_id: str, api_template_id: str, process_parameter: dict):
|
||||
api_template = ExternalApiTemplates.query.filter_by(
|
||||
@ -196,8 +206,6 @@ class ExternalDatasetService:
|
||||
|
||||
@staticmethod
|
||||
def process_external_api(settings: ApiTemplateSetting,
|
||||
headers: Union[None, dict[str, Any]],
|
||||
parameter: Union[None, dict[str, Any]],
|
||||
files: Union[None, dict[str, Any]]) -> httpx.Response:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
@ -205,14 +213,12 @@ class ExternalDatasetService:
|
||||
|
||||
kwargs = {
|
||||
'url': settings.url,
|
||||
'headers': headers,
|
||||
'headers': settings.headers,
|
||||
'follow_redirects': True,
|
||||
}
|
||||
|
||||
if settings.request_method in ('get', 'head', 'post', 'put', 'delete', 'patch'):
|
||||
response = getattr(ssrf_proxy, settings.request_method)(data=parameter, files=files, **kwargs)
|
||||
else:
|
||||
raise ValueError(f'Invalid http method {settings.request_method}')
|
||||
response = getattr(ssrf_proxy, settings.request_method)(data=settings.params, files=files, **kwargs)
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
@ -246,7 +252,7 @@ class ExternalDatasetService:
|
||||
return ApiTemplateSetting.parse_obj(settings)
|
||||
|
||||
@staticmethod
|
||||
def create_external_dataset(tenant_id, user_id, args):
|
||||
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
|
||||
# check if dataset name already exists
|
||||
if Dataset.query.filter_by(name=args.get('name'), tenant_id=tenant_id).first():
|
||||
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
|
||||
@ -254,6 +260,7 @@ class ExternalDatasetService:
|
||||
id=args.get('api_template_id'),
|
||||
tenant_id=tenant_id
|
||||
).first()
|
||||
|
||||
if api_template is None:
|
||||
raise ValueError('api template not found')
|
||||
|
||||
@ -281,4 +288,37 @@ class ExternalDatasetService:
|
||||
|
||||
return dataset
|
||||
|
||||
@staticmethod
def fetch_external_knowledge_retrival(tenant_id: str,
                                      dataset_id: str,
                                      query: str,
                                      external_retrival_parameters: dict):
    """Send a retrieval query to the external knowledge API bound to a dataset.

    Resolves the dataset's ``ExternalKnowledgeBindings`` row, loads the
    ``ExternalApiTemplates`` row it points at, and POSTs the retrieval
    parameters (plus ``query``) to the template's endpoint through
    ``ExternalDatasetService.process_external_api``.

    :param tenant_id: tenant id used to look up the binding
    :param dataset_id: dataset id used to look up the binding
    :param query: query text, sent under the ``query`` key of the request params
    :param external_retrival_parameters: extra retrieval parameters; ``None`` is
        treated as empty and the caller's dict is never modified
    :return: whatever ``process_external_api`` returns for the request
    :raises ValueError: when the binding or its API template does not exist
    """
    external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by(
        dataset_id=dataset_id,
        tenant_id=tenant_id
    ).first()
    if not external_knowledge_binding:
        raise ValueError('external knowledge binding not found')

    external_api_template = ExternalApiTemplates.query.filter_by(
        id=external_knowledge_binding.external_api_template_id
    ).first()
    if not external_api_template:
        raise ValueError('external api template not found')

    settings = json.loads(external_api_template.settings)

    # Start from the headers stored in the template settings and add the bearer
    # token on top. Previously the Authorization header was built into a
    # separate dict that was never passed to the request.
    headers = dict(settings.get('headers') or {})
    if settings.get('api_token'):
        headers['Authorization'] = f"Bearer {settings.get('api_token')}"

    # Build a fresh dict: the argument may be None, and writing 'query' into
    # the caller's dict would leak state back to the caller.
    params = dict(external_retrival_parameters or {})
    params['query'] = query

    api_template_setting = {
        'url': f"{settings.get('endpoint')}/dify/external-knowledge/retrival-documents",
        'request_method': 'post',
        'headers': headers,
        'params': params
    }
    # NOTE(review): ApiTemplateSetting appears to declare further fields without
    # defaults; confirm that building it from only these four keys validates.
    response = ExternalDatasetService.process_external_api(
        ApiTemplateSetting(**api_template_setting), None
    )
    # Hand the result back; before this the function implicitly returned None.
    return response
|
||||
|
@ -19,7 +19,8 @@ default_retrieval_model = {
|
||||
|
||||
class HitTestingService:
|
||||
@classmethod
|
||||
def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
|
||||
def retrieve(cls, dataset: Dataset, query: str, account: Account,
|
||||
retrieval_model: dict, external_retrieval_model: dict, limit: int = 10) -> dict:
|
||||
if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
|
||||
return {
|
||||
"query": {
|
||||
@ -50,6 +51,8 @@ class HitTestingService:
|
||||
if retrieval_model.get("reranking_mode")
|
||||
else "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
provider=dataset.provider,
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
Loading…
Reference in New Issue
Block a user