Chore: optimize the code of PromptTransform (#16143)

This commit is contained in:
Yongtao Huang 2025-03-19 11:24:57 +08:00 committed by GitHub
parent e0cf55f5e9
commit d339403e89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 8 additions and 11 deletions

View File

@@ -93,7 +93,7 @@ class SimplePromptTransform(PromptTransform):
return prompt_messages, stops
-    def get_prompt_str_and_rules(
+    def _get_prompt_str_and_rules(
self,
app_mode: AppMode,
model_config: ModelConfigWithCredentialsEntity,
@@ -184,7 +184,7 @@ class SimplePromptTransform(PromptTransform):
prompt_messages: list[PromptMessage] = []
# get prompt
-        prompt, _ = self.get_prompt_str_and_rules(
+        prompt, _ = self._get_prompt_str_and_rules(
app_mode=app_mode,
model_config=model_config,
pre_prompt=pre_prompt,
@@ -209,9 +209,9 @@ class SimplePromptTransform(PromptTransform):
)
if query:
-            prompt_messages.append(self.get_last_user_message(query, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
else:
-            prompt_messages.append(self.get_last_user_message(prompt, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
return prompt_messages, None
@@ -228,7 +228,7 @@ class SimplePromptTransform(PromptTransform):
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
# get prompt
-        prompt, prompt_rules = self.get_prompt_str_and_rules(
+        prompt, prompt_rules = self._get_prompt_str_and_rules(
app_mode=app_mode,
model_config=model_config,
pre_prompt=pre_prompt,
@@ -254,7 +254,7 @@ class SimplePromptTransform(PromptTransform):
)
# get prompt
-        prompt, prompt_rules = self.get_prompt_str_and_rules(
+        prompt, prompt_rules = self._get_prompt_str_and_rules(
app_mode=app_mode,
model_config=model_config,
pre_prompt=pre_prompt,
@@ -268,9 +268,9 @@ class SimplePromptTransform(PromptTransform):
if stops is not None and len(stops) == 0:
stops = None
-        return [self.get_last_user_message(prompt, files, image_detail_config)], stops
+        return [self._get_last_user_message(prompt, files, image_detail_config)], stops
-    def get_last_user_message(
+    def _get_last_user_message(
self,
prompt: str,
files: Sequence["File"],

View File

@@ -64,12 +64,10 @@ def test_get_prompt():
transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
result = transform.get_prompt()
-    assert len(result) <= max_token_limit
+    assert len(result) == 4
max_token_limit = 20
transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
result = transform.get_prompt()
-    assert len(result) <= max_token_limit
+    assert len(result) == 12

View File

@@ -84,7 +84,6 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq():
query_in_prompt=True,
with_memory_prompt=False,
)
print(prompt_template["prompt_template"].template)
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]