update to external knowledge api
This commit is contained in:
parent
5fa86074ed
commit
611f0fb3f6
@ -14,16 +14,11 @@ class TestExternalApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"top_k",
|
||||
"retrieval_setting",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=int,
|
||||
)
|
||||
parser.add_argument(
|
||||
"score_threshold",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=float,
|
||||
type=dict,
|
||||
location="json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"query",
|
||||
@ -32,14 +27,14 @@ class TestExternalApi(Resource):
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
"knowledge_id",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
result = ExternalDatasetService.test_external_knowledge_retrieval(
|
||||
args["top_k"], args["score_threshold"], args["query"], args["external_knowledge_id"]
|
||||
args["retrieval_setting"], args["query"], args["knowledge_id"]
|
||||
)
|
||||
return result, 200
|
||||
|
||||
|
@ -283,22 +283,28 @@ class ExternalDatasetService:
|
||||
if settings.get("api_key"):
|
||||
headers["Authorization"] = f"Bearer {settings.get('api_key')}"
|
||||
|
||||
external_retrieval_parameters["query"] = query
|
||||
external_retrieval_parameters["external_knowledge_id"] = external_knowledge_binding.external_knowledge_id
|
||||
request_params = {
|
||||
"retrieval_setting": {
|
||||
"top_k": external_retrieval_parameters.get("top_k"),
|
||||
"score_threshold": external_retrieval_parameters.get("score_threshold"),
|
||||
},
|
||||
"query": query,
|
||||
"knowledge_id": external_knowledge_binding.external_knowledge_id,
|
||||
}
|
||||
|
||||
external_knowledge_api_setting = {
|
||||
"url": f"{settings.get('endpoint')}/dify/external-knowledge/retrieval-documents",
|
||||
"request_method": "post",
|
||||
"headers": headers,
|
||||
"params": external_retrieval_parameters,
|
||||
"params": request_params,
|
||||
}
|
||||
response = ExternalDatasetService.process_external_api(ExternalKnowledgeApiSetting(**external_knowledge_api_setting), None)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
return response.json().get("records", [])
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def test_external_knowledge_retrieval(top_k: int, score_threshold: float, query: str, external_knowledge_id: str):
|
||||
def test_external_knowledge_retrieval(retrieval_setting: dict, query: str, external_knowledge_id: str):
|
||||
client = boto3.client(
|
||||
"bedrock-agent-runtime",
|
||||
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
|
||||
@ -308,7 +314,7 @@ class ExternalDatasetService:
|
||||
response = client.retrieve(
|
||||
knowledgeBaseId=external_knowledge_id,
|
||||
retrievalConfiguration={
|
||||
"vectorSearchConfiguration": {"numberOfResults": top_k, "overrideSearchType": "HYBRID"}
|
||||
"vectorSearchConfiguration": {"numberOfResults": retrieval_setting.get("top_k"), "overrideSearchType": "HYBRID"}
|
||||
},
|
||||
retrievalQuery={"text": query},
|
||||
)
|
||||
@ -317,7 +323,7 @@ class ExternalDatasetService:
|
||||
if response.get("retrievalResults"):
|
||||
retrieval_results = response.get("retrievalResults")
|
||||
for retrieval_result in retrieval_results:
|
||||
if retrieval_result.get("score") < score_threshold:
|
||||
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", .0):
|
||||
continue
|
||||
result = {
|
||||
"metadata": retrieval_result.get("metadata"),
|
||||
@ -326,4 +332,6 @@ class ExternalDatasetService:
|
||||
"content": retrieval_result.get("content").get("text"),
|
||||
}
|
||||
results.append(result)
|
||||
return results
|
||||
return {
|
||||
"records": results
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user