Fix: The topk parameter doesn't work in sagemaker rerank tool (#12150)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
parent
dae1b5a619
commit
0fdb39f1c3
@ -10,8 +10,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
class SageMakerReRankTool(BuiltinTool):
|
||||
sagemaker_client: Any = None
|
||||
sagemaker_endpoint: str | None = None
|
||||
topk: int | None = None
|
||||
sagemaker_endpoint: str = None
|
||||
|
||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
||||
inputs = [query_input] * len(docs)
|
||||
@ -47,8 +46,7 @@ class SageMakerReRankTool(BuiltinTool):
|
||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||
|
||||
line = 2
|
||||
if not self.topk:
|
||||
self.topk = tool_parameters.get("topk", 5)
|
||||
topk = tool_parameters.get("topk", 5)
|
||||
|
||||
line = 3
|
||||
query = tool_parameters.get("query", "")
|
||||
@ -75,7 +73,7 @@ class SageMakerReRankTool(BuiltinTool):
|
||||
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 9
|
||||
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
|
||||
return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]]
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
|
Loading…
Reference in New Issue
Block a user