From cb7be3767c4ce4bcbc5367fa5a8927a4f3b93a63 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 12 Jan 2024 21:15:07 +0800 Subject: [PATCH] feat: huggingface llm add new params. (#2014) --- .../huggingface_hub/llm/llm.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index f3d5a853d7..e0701dff59 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -134,7 +134,55 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel precision=0, ) - return [temperature_rule, top_k_rule, top_p_rule] + max_new_tokens = ParameterRule( + name='max_new_tokens', + label={ + 'en_US': 'Max New Tokens', + 'zh_Hans': '最大新标记', + }, + type='int', + help={ + 'en_US': 'Maximum number of generated tokens.', + 'zh_Hans': '生成的标记的最大数量。', + }, + required=False, + default=20, + min=1, + max=4096, + precision=0, + ) + + seed = ParameterRule( + name='seed', + label={ + 'en_US': 'Random sampling seed', + 'zh_Hans': '随机采样种子', + }, + type='int', + help={ + 'en_US': 'Random sampling seed.', + 'zh_Hans': '随机采样种子。', + }, + required=False, + precision=0, + ) + + repetition_penalty = ParameterRule( + name='repetition_penalty', + label={ + 'en_US': 'Repetition Penalty', + 'zh_Hans': '重复惩罚', + }, + type='float', + help={ + 'en_US': 'The parameter for repetition penalty. 1.0 means no penalty.', + 'zh_Hans': '重复惩罚的参数。1.0 表示没有惩罚。', + }, + required=False, + precision=1, + ) + + return [temperature_rule, top_k_rule, top_p_rule, max_new_tokens, seed, repetition_penalty] def _handle_generate_stream_response(self, model: str,