diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index f4f08bc8b2..dc1f1ada11 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -436,7 +436,8 @@ class DatasetRetrieval: if retrieval_model['score_threshold_enabled'] else None, reranking_model=retrieval_model.get('reranking_model', None) if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode', 'reranking_model'), + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index f978eccd7c..7cb7c033bb 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -181,7 +181,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): if retrieval_model['score_threshold_enabled'] else None, reranking_model=retrieval_model.get('reranking_model', None) if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode', 'reranking_model'), + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index c1443cd09f..de8ff7ad38 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -79,7 +79,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, reranking_model=retrieval_model.get('reranking_model', None), - reranking_mode=retrieval_model.get('reranking_mode', 'reranking_model'), + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) else: diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 59ffa38cea..0e072a3e21 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -45,7 +45,8 @@ class HitTestingService: score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, reranking_model=retrieval_model.get('reranking_model', None), - reranking_mode=retrieval_model.get('reranking_mode', 'reranking_model'), + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), )