refactor: remove max token limit from history prompt methods and related calculations

Signed-off-by: -LAN- <laipz8200@outlook.com>

parent c53a0db4e0
commit 4f27d5d987
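For orientation, a minimal sketch of how call sites look after this refactor. `fetch_history_text` is a hypothetical helper, not part of the commit; the keyword arguments follow the new `TokenBufferMemory.get_history_prompt_text` signature shown in the diff below, which no longer accepts `max_token_limit`.

from typing import Optional

from core.memory.token_buffer_memory import TokenBufferMemory


def fetch_history_text(memory: TokenBufferMemory, message_limit: Optional[int] = 3) -> str:
    # Before this commit, callers pre-computed a token budget and passed
    # max_token_limit=..., and the memory pruned old messages to fit it.
    # After this commit, history is bounded only by the optional message count.
    return memory.get_history_prompt_text(
        human_prefix="Human",
        ai_prefix="Assistant",
        message_limit=message_limit,
    )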
@@ -53,20 +53,6 @@ class AgentChatAppRunner(AppRunner):
         query = application_generate_entity.query
         files = application_generate_entity.files
 
-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=dict(inputs),
-            files=list(files),
-            query=query,
-        )
-
         memory = None
         if application_generate_entity.conversation_id:
             # get memory of conversation (read-only)

@@ -35,69 +35,6 @@ if TYPE_CHECKING:
 
 
 class AppRunner:
-    def get_pre_calculate_rest_tokens(
-        self,
-        app_record: App,
-        model_config: ModelConfigWithCredentialsEntity,
-        prompt_template_entity: PromptTemplateEntity,
-        inputs: Mapping[str, str],
-        files: Sequence["File"],
-        query: Optional[str] = None,
-    ) -> int:
-        """
-        Get pre calculate rest tokens
-        :param app_record: app record
-        :param model_config: model config entity
-        :param prompt_template_entity: prompt template entity
-        :param inputs: inputs
-        :param files: files
-        :param query: query
-        :return:
-        """
-        # Invoke model
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(parameter_rule.use_template or "")
-                ) or 0
-
-        if model_context_tokens is None:
-            return -1
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        # get prompt messages without memory and context
-        prompt_messages, stop = self.organize_prompt_messages(
-            app_record=app_record,
-            model_config=model_config,
-            prompt_template_entity=prompt_template_entity,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
-
-        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-        rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
-        if rest_tokens < 0:
-            raise InvokeBadRequestError(
-                "Query or prefix prompt is too long, you can reduce the prefix prompt, "
-                "or shrink the max token, or switch to a llm with a larger token limit size."
-            )
-
-        return rest_tokens
-
     def recalc_llm_max_tokens(
         self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
     ):

@@ -54,20 +54,6 @@ class CompletionAppRunner(AppRunner):
         )
         image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
 
-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
-
         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
         prompt_messages, stop = self.organize_prompt_messages(

@@ -25,9 +25,7 @@ class TokenBufferMemory:
         self.conversation = conversation
         self.model_instance = model_instance
 
-    def get_history_prompt_messages(
-        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> Sequence[PromptMessage]:
+    def get_history_prompt_messages(self, message_limit: Optional[int] = None) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit

@@ -118,33 +116,23 @@ class TokenBufferMemory:
         if not prompt_messages:
             return []
 
-        # prune the chat message if it exceeds the max token limit
-        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-
-        if curr_message_tokens > max_token_limit:
-            pruned_memory = []
-            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
-                pruned_memory.append(prompt_messages.pop(0))
-                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-
         return prompt_messages
 
     def get_history_prompt_text(
         self,
+        *,
         human_prefix: str = "Human",
         ai_prefix: str = "Assistant",
-        max_token_limit: int = 2000,
         message_limit: Optional[int] = None,
     ) -> str:
         """
         Get history prompt text.
         :param human_prefix: human prefix
         :param ai_prefix: ai prefix
-        :param max_token_limit: max token limit
        :param message_limit: message limit
         :return:
         """
-        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
+        prompt_messages = self.get_history_prompt_messages(message_limit=message_limit)
 
         string_messages = []
         for m in prompt_messages:
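The hunk above drops the token-based pruning loop from TokenBufferMemory.get_history_prompt_messages. Below is a minimal sketch of equivalent pruning a caller could apply itself if it still needs a token budget; the helper name and the default budget are illustrative only, while get_llm_num_tokens is the same ModelInstance method the removed code used.

from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage


def prune_to_token_budget(
    model_instance: ModelInstance,
    prompt_messages: list[PromptMessage],
    max_token_limit: int = 2000,
) -> list[PromptMessage]:
    # Mirrors the removed loop: drop the oldest messages until the history
    # fits the budget, always keeping at least one message.
    curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
    while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
        prompt_messages.pop(0)
        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
    return prompt_messages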
@@ -202,9 +202,6 @@ class PluginModelManager(BasePluginManager):
         prompt_messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> int:
-        """
-        Get number of tokens for llm
-        """
         response = self._request_with_plugin_daemon_response_stream(
             method="POST",
             path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",

@@ -284,14 +284,10 @@ class AdvancedPromptTransform(PromptTransform):
                 inputs = {"#histories#": "", **prompt_inputs}
                 parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
                 prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
-                tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
-
-                rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
 
                 histories = self._get_history_messages_from_memory(
                     memory=memory,
                     memory_config=memory_config,
-                    max_token_limit=rest_tokens,
                     human_prefix=role_prefix.user,
                     ai_prefix=role_prefix.assistant,
                 )

@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Optional
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory

@@ -16,8 +16,7 @@ class PromptTransform:
         prompt_messages: list[PromptMessage],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> list[PromptMessage]:
-        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
-        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
+        histories = self._get_history_messages_list_from_memory(memory=memory, memory_config=memory_config)
         prompt_messages.extend(histories)
 
         return prompt_messages

@@ -54,31 +53,29 @@ class PromptTransform:
         self,
         memory: TokenBufferMemory,
         memory_config: MemoryConfig,
-        max_token_limit: int,
         human_prefix: Optional[str] = None,
         ai_prefix: Optional[str] = None,
     ) -> str:
-        """Get memory messages."""
-        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
-
-        if human_prefix:
-            kwargs["human_prefix"] = human_prefix
-
-        if ai_prefix:
-            kwargs["ai_prefix"] = ai_prefix
-
+        human_prefix = human_prefix or "Human"
+        ai_prefix = ai_prefix or "Assistant"
         if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
-            kwargs["message_limit"] = memory_config.window.size
-
-        return memory.get_history_prompt_text(**kwargs)
+            message_limit = memory_config.window.size
+        else:
+            message_limit = None
+        return memory.get_history_prompt_text(
+            human_prefix=human_prefix,
+            ai_prefix=ai_prefix,
+            message_limit=message_limit,
+        )
 
     def _get_history_messages_list_from_memory(
-        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
     ) -> list[PromptMessage]:
         """Get memory messages."""
         return list(
             memory.get_history_prompt_messages(
-                max_token_limit=max_token_limit,
                 message_limit=memory_config.window.size
                 if (
                     memory_config.window.enabled

@@ -238,9 +238,6 @@ class SimplePromptTransform(PromptTransform):
         )
 
         if memory:
-            tmp_human_message = UserPromptMessage(content=prompt)
-
-            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
             histories = self._get_history_messages_from_memory(
                 memory=memory,
                 memory_config=MemoryConfig(

@@ -248,7 +245,6 @@ class SimplePromptTransform(PromptTransform):
                         enabled=False,
                     )
                 ),
-                max_token_limit=rest_tokens,
                 human_prefix=prompt_rules.get("human_prefix", "Human"),
                 ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
             )

@@ -27,7 +27,7 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID

@@ -91,7 +91,7 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM
 
-    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None] | NodeRunResult:
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""

@@ -624,7 +624,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             memory_text = _handle_memory_completion_mode(
                 memory=memory,
                 memory_config=memory_config,
-                model_config=model_config,
             )
             # Insert histories into the prompt
             prompt_content = prompt_messages[0].content

@@ -960,36 +959,6 @@ def _render_jinja2_message(
     return result_text
 
 
-def _calculate_rest_token(
-    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
-) -> int:
-    rest_tokens = 2000
-
-    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-    if model_context_tokens:
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-
-        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(str(parameter_rule.use_template))
-                    or 0
-                )
-
-        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-        rest_tokens = max(rest_tokens, 0)
-
-    return rest_tokens
-
-
 def _handle_memory_chat_mode(
     *,
     memory: TokenBufferMemory | None,

@@ -999,9 +968,7 @@ def _handle_memory_chat_mode(
     memory_messages: Sequence[PromptMessage] = []
     # Get messages from memory for chat model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
         memory_messages = memory.get_history_prompt_messages(
-            max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
         )
     return memory_messages

@@ -1011,16 +978,13 @@ def _handle_memory_completion_mode(
     *,
     memory: TokenBufferMemory | None,
     memory_config: MemoryConfig | None,
-    model_config: ModelConfigWithCredentialsEntity,
 ) -> str:
     memory_text = ""
     # Get history text from memory for completion model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
         if not memory_config.role_prefix:
             raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
         memory_text = memory.get_history_prompt_text(
-            max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
             human_prefix=memory_config.role_prefix.user,
             ai_prefix=memory_config.role_prefix.assistant,

@@ -280,9 +280,11 @@ class ParameterExtractorNode(LLMNode):
         )
 
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
         prompt_template = self._get_function_calling_prompt_template(
-            node_data, query, variable_pool, memory, rest_token
+            node_data=node_data,
+            query=query,
+            variable_pool=variable_pool,
+            memory=memory,
         )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,

@@ -396,11 +398,8 @@ class ParameterExtractorNode(LLMNode):
         Generate completion prompt.
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(
-            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
-        )
         prompt_template = self._get_prompt_engineering_prompt_template(
-            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
+            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory
         )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,

@@ -430,9 +429,6 @@ class ParameterExtractorNode(LLMNode):
         Generate chat prompt.
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(
-            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
-        )
         prompt_template = self._get_prompt_engineering_prompt_template(
             node_data=node_data,
             query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(

@@ -440,7 +436,6 @@ class ParameterExtractorNode(LLMNode):
             ),
             variable_pool=variable_pool,
             memory=memory,
-            max_token_limit=rest_token,
         )
 
         prompt_messages = prompt_transform.get_prompt(

@@ -652,11 +647,11 @@ class ParameterExtractorNode(LLMNode):
 
     def _get_function_calling_prompt_template(
         self,
+        *,
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ) -> list[ChatModelMessage]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         input_text = query

@@ -664,9 +659,7 @@ class ParameterExtractorNode(LLMNode):
         instruction = variable_pool.convert_template(node_data.instruction or "").text
 
         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
-            )
+            memory_str = memory.get_history_prompt_text(message_limit=node_data.memory.window.size)
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,

@@ -683,7 +676,6 @@ class ParameterExtractorNode(LLMNode):
         query: str,
         variable_pool: VariablePool,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ):
         model_mode = ModelMode.value_of(node_data.model.mode)
         input_text = query

@@ -691,9 +683,7 @@ class ParameterExtractorNode(LLMNode):
         instruction = variable_pool.convert_template(node_data.instruction or "").text
 
         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
-            )
+            memory_str = memory.get_history_prompt_text(message_limit=node_data.memory.window.size)
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,

@@ -732,9 +722,19 @@ class ParameterExtractorNode(LLMNode):
             raise ModelSchemaNotFoundError("Model schema not found")
 
         if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
-            prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
+            prompt_template = self._get_function_calling_prompt_template(
+                node_data=node_data,
+                query=query,
+                variable_pool=variable_pool,
+                memory=None,
+            )
         else:
-            prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
+            prompt_template = self._get_prompt_engineering_prompt_template(
+                node_data=node_data,
+                query=query,
+                variable_pool=variable_pool,
+                memory=None,
+            )
 
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,

@@ -2,12 +2,9 @@ import json
 from collections.abc import Mapping, Sequence
 from typing import Any, Optional, cast
 
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance
-from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
+from core.model_runtime.entities import LLMUsage, PromptMessageRole
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult

@@ -67,17 +64,10 @@ class QuestionClassifierNode(LLMNode):
         )
 
         # fetch prompt messages
-        rest_token = self._calculate_rest_token(
-            node_data=node_data,
-            query=query or "",
-            model_config=model_config,
-            context="",
-        )
         prompt_template = self._get_prompt_template(
             node_data=node_data,
             query=query or "",
             memory=memory,
-            max_token_limit=rest_token,
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,

@@ -196,56 +186,11 @@ class QuestionClassifierNode(LLMNode):
         """
         return {"type": "question-classifier", "config": {"instructions": ""}}
 
-    def _calculate_rest_token(
-        self,
-        node_data: QuestionClassifierNodeData,
-        query: str,
-        model_config: ModelConfigWithCredentialsEntity,
-        context: Optional[str],
-    ) -> int:
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=[],
-            context=context,
-            memory_config=node_data.memory,
-            memory=None,
-            model_config=model_config,
-        )
-        rest_tokens = 2000
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-        if model_context_tokens:
-            model_instance = ModelInstance(
-                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-            )
-
-            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-            max_tokens = 0
-            for parameter_rule in model_config.model_schema.parameter_rules:
-                if parameter_rule.name == "max_tokens" or (
-                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-                ):
-                    max_tokens = (
-                        model_config.parameters.get(parameter_rule.name)
-                        or model_config.parameters.get(parameter_rule.use_template or "")
-                    ) or 0
-
-            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-            rest_tokens = max(rest_tokens, 0)
-
-        return rest_tokens
-
     def _get_prompt_template(
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ):
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes

@@ -258,7 +203,6 @@ class QuestionClassifierNode(LLMNode):
         memory_str = ""
         if memory:
             memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit,
                 message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
             )
         prompt_messages: list[LLMNodeChatModelMessage] = []

@@ -271,7 +271,6 @@ class MessageService:
         memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 
         histories = memory.get_history_prompt_text(
-            max_token_limit=3000,
             message_limit=3,
         )
 

@@ -235,7 +235,6 @@ def test__get_completion_model_prompt_messages():
             "#context#": context,
             "#query#": query,
             "#histories#": memory.get_history_prompt_text(
-                max_token_limit=2000,
                 human_prefix=prompt_rules.get("human_prefix", "Human"),
                 ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
             ),