Fix: The topk parameter doesn't work in sagemaker rerank tool (#12150)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
parent
dae1b5a619
commit
0fdb39f1c3
@ -10,8 +10,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|||||||
|
|
||||||
class SageMakerReRankTool(BuiltinTool):
|
class SageMakerReRankTool(BuiltinTool):
|
||||||
sagemaker_client: Any = None
|
sagemaker_client: Any = None
|
||||||
sagemaker_endpoint: str | None = None
|
sagemaker_endpoint: str = None
|
||||||
topk: int | None = None
|
|
||||||
|
|
||||||
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str):
|
||||||
inputs = [query_input] * len(docs)
|
inputs = [query_input] * len(docs)
|
||||||
@ -47,8 +46,7 @@ class SageMakerReRankTool(BuiltinTool):
|
|||||||
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")
|
||||||
|
|
||||||
line = 2
|
line = 2
|
||||||
if not self.topk:
|
topk = tool_parameters.get("topk", 5)
|
||||||
self.topk = tool_parameters.get("topk", 5)
|
|
||||||
|
|
||||||
line = 3
|
line = 3
|
||||||
query = tool_parameters.get("query", "")
|
query = tool_parameters.get("query", "")
|
||||||
@ -75,7 +73,7 @@ class SageMakerReRankTool(BuiltinTool):
|
|||||||
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
|
||||||
|
|
||||||
line = 9
|
line = 9
|
||||||
return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
|
return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||||
|
Loading…
Reference in New Issue
Block a user