Fix: The topk parameter doesn't work in sagemaker rerank tool (#12150)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
ybalbert001 2024-12-27 14:42:25 +08:00 committed by GitHub
parent dae1b5a619
commit 0fdb39f1c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,8 +10,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
class SageMakerReRankTool(BuiltinTool):
sagemaker_client: Any = None
sagemaker_endpoint: str | None = None
topk: int | None = None
sagemaker_endpoint: str = None
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
inputs = [query_input] * len(docs)
@ -47,8 +46,7 @@ class SageMakerReRankTool(BuiltinTool):
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
line = 2
if not self.topk:
self.topk = tool_parameters.get("topk", 5)
topk = tool_parameters.get("topk", 5)
line = 3
query = tool_parameters.get("query", "")
@ -75,7 +73,7 @@ class SageMakerReRankTool(BuiltinTool):
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
line = 9
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]]
except Exception as e:
return self.create_text_message(f"Exception {str(e)}, line : {line}")