diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index 715b1ddedd..8320bd84ef 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -10,8 +10,7 @@ from core.tools.tool.builtin_tool import BuiltinTool class SageMakerReRankTool(BuiltinTool): sagemaker_client: Any = None - sagemaker_endpoint: str | None = None - topk: int | None = None + sagemaker_endpoint: str = None def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint: str): inputs = [query_input] * len(docs) @@ -47,8 +46,7 @@ class SageMakerReRankTool(BuiltinTool): self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") line = 2 - if not self.topk: - self.topk = tool_parameters.get("topk", 5) + topk = tool_parameters.get("topk", 5) line = 3 query = tool_parameters.get("query", "") @@ -75,7 +73,7 @@ class SageMakerReRankTool(BuiltinTool): sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True) line = 9 - return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] + return [self.create_json_message(res) for res in sorted_candidate_docs[:topk]] except Exception as e: return self.create_text_message(f"Exception {str(e)}, line : {line}")