From 0fdb39f1c3d76143360d9a87d205e8746b68a988 Mon Sep 17 00:00:00 2001 From: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Date: Fri, 27 Dec 2024 14:42:25 +0800 Subject: [PATCH] Fix: The topk parameter doesn't work in sagemaker rerank tool (#12150) Co-authored-by: Yuanbo Li --- .../provider/builtin/aws/tools/sagemaker_text_rerank.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index 715b1ddedd..8320bd84ef 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -10,8 +10,7 @@ from core.tools.tool.builtin_tool import BuiltinTool class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint: str | None = None - topk: int | None = None + sagemaker_endpoint: str = None def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): inputs = [query_input] * len(docs) @@ -47,8 +46,7 @@ class SageMakerReRankTool(BuiltinTool): self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") line = 2 - if not self.topk: - self.topk = tool_parameters.get("topk", 5) + topk = tool_parameters.get("topk", 5) line = 3 query = tool_parameters.get("query", "") @@ -75,7 +73,7 @@ class SageMakerReRankTool(BuiltinTool): sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True) line = 9 - return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] + return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]] except Exception as e: return self.create_text_message(f"Exception {str(e)}, line : {line}")