Compare commits

...

3 Commits

Author SHA1 Message Date
-LAN-
4f27d5d987
refactor: remove max token limit from history prompt methods and related calculations
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-18 18:48:29 +08:00
-LAN-
c53a0db4e0
feat: remove pre-calculation of token counts in ChatAppRunner
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-18 14:40:54 +08:00
-LAN-
832d1ada20
chore: mark invoke_llm as deprecated with a reference to the issue
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-03-18 14:40:24 +08:00
15 changed files with 44 additions and 266 deletions

View File

@@ -53,20 +53,6 @@ class AgentChatAppRunner(AppRunner):
         query = application_generate_entity.query
         files = application_generate_entity.files
-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=dict(inputs),
-            files=list(files),
-            query=query,
-        )
         memory = None
         if application_generate_entity.conversation_id:
             # get memory of conversation (read-only)

View File

@@ -35,69 +35,6 @@ if TYPE_CHECKING:
 class AppRunner:
-    def get_pre_calculate_rest_tokens(
-        self,
-        app_record: App,
-        model_config: ModelConfigWithCredentialsEntity,
-        prompt_template_entity: PromptTemplateEntity,
-        inputs: Mapping[str, str],
-        files: Sequence["File"],
-        query: Optional[str] = None,
-    ) -> int:
-        """
-        Get pre calculate rest tokens
-        :param app_record: app record
-        :param model_config: model config entity
-        :param prompt_template_entity: prompt template entity
-        :param inputs: inputs
-        :param files: files
-        :param query: query
-        :return:
-        """
-        # Invoke model
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(parameter_rule.use_template or "")
-                ) or 0
-        if model_context_tokens is None:
-            return -1
-        if max_tokens is None:
-            max_tokens = 0
-        # get prompt messages without memory and context
-        prompt_messages, stop = self.organize_prompt_messages(
-            app_record=app_record,
-            model_config=model_config,
-            prompt_template_entity=prompt_template_entity,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
-        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-        rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
-        if rest_tokens < 0:
-            raise InvokeBadRequestError(
-                "Query or prefix prompt is too long, you can reduce the prefix prompt, "
-                "or shrink the max token, or switch to a llm with a larger token limit size."
-            )
-        return rest_tokens
     def recalc_llm_max_tokens(
         self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
     ):
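
Note: the removed budget check above boils down to simple arithmetic. A minimal standalone sketch of that calculation (hypothetical helper, not Dify code; `context_size`, `max_tokens`, and `prompt_tokens` stand in for the model-schema lookup and tokenizer call in the removed method):

def rest_tokens_after_prompt(context_size: int, max_tokens: int, prompt_tokens: int) -> int:
    # Same arithmetic as the removed get_pre_calculate_rest_tokens: what is left of the
    # context window after reserving the completion budget and the prompt itself.
    rest = context_size - max_tokens - prompt_tokens
    if rest < 0:
        raise ValueError(
            "Prompt too long: shorten the prefix prompt, lower max_tokens, "
            "or switch to a model with a larger context window."
        )
    return rest

# Example: an 8192-token context with max_tokens=1024 and a 7000-token prompt leaves 168 tokens.
assert rest_tokens_after_prompt(8192, 1024, 7000) == 168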

View File

@@ -61,20 +61,6 @@ class ChatAppRunner(AppRunner):
         )
         image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
         memory = None
         if application_generate_entity.conversation_id:
             # get memory of conversation (read-only)

View File

@@ -54,20 +54,6 @@ class CompletionAppRunner(AppRunner):
         )
         image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
         prompt_messages, stop = self.organize_prompt_messages(

View File

@@ -25,9 +25,7 @@ class TokenBufferMemory:
         self.conversation = conversation
         self.model_instance = model_instance
-    def get_history_prompt_messages(
-        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> Sequence[PromptMessage]:
+    def get_history_prompt_messages(self, message_limit: Optional[int] = None) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit
@@ -118,33 +116,23 @@
         if not prompt_messages:
             return []
-        # prune the chat message if it exceeds the max token limit
-        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-        if curr_message_tokens > max_token_limit:
-            pruned_memory = []
-            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
-                pruned_memory.append(prompt_messages.pop(0))
-                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
         return prompt_messages
     def get_history_prompt_text(
         self,
+        *,
         human_prefix: str = "Human",
         ai_prefix: str = "Assistant",
-        max_token_limit: int = 2000,
         message_limit: Optional[int] = None,
     ) -> str:
         """
         Get history prompt text.
         :param human_prefix: human prefix
         :param ai_prefix: ai prefix
-        :param max_token_limit: max token limit
         :param message_limit: message limit
         :return:
         """
-        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
+        prompt_messages = self.get_history_prompt_messages(message_limit=message_limit)
         string_messages = []
         for m in prompt_messages:
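
Note: with the pruning loop gone, the only remaining cap on history is the message-count window. A minimal sketch of that behaviour, assuming the window keeps the most recent messages (illustrative helper, not Dify code; the real method works on stored conversation records and returns PromptMessage objects):

from typing import Optional, Sequence

def window_history(messages: Sequence[str], message_limit: Optional[int] = None) -> list[str]:
    # No token counting any more: keep the most recent `message_limit` entries, or everything.
    if message_limit is not None and message_limit > 0:
        return list(messages[-message_limit:])
    return list(messages)

assert window_history(["m1", "m2", "m3", "m4"], message_limit=2) == ["m3", "m4"]
# New call shape on the real class: memory.get_history_prompt_messages(message_limit=10)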

View File

@@ -2,6 +2,8 @@ import logging
 from collections.abc import Callable, Generator, Iterable, Sequence
 from typing import IO, Any, Literal, Optional, Union, cast, overload
+from typing_extensions import deprecated
 from configs import dify_config
 from core.entities.embedding_type import EmbeddingInputType
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
@@ -176,6 +178,7 @@ class ModelInstance:
             ),
         )
+    @deprecated("invoke_llm is deprecated, see https://github.com/langgenius/dify/issues/16090")
     def get_llm_num_tokens(
         self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
     ) -> int:
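
Note: the @deprecated marker added here comes from typing_extensions (PEP 702). Type checkers flag call sites, and the wrapped callable also emits a DeprecationWarning at runtime. A small self-contained illustration (generic example, not Dify code; the function name and message are made up):

import warnings
from typing_extensions import deprecated

@deprecated("this helper is deprecated, see the issue linked above")
def count_tokens(text: str) -> int:
    # Stand-in function; only the decorator behaviour is being demonstrated.
    return len(text.split())

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    count_tokens("hello world")

assert any(issubclass(w.category, DeprecationWarning) for w in caught)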

View File

@@ -202,9 +202,6 @@ class PluginModelManager(BasePluginManager):
         prompt_messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> int:
-        """
-        Get number of tokens for llm
-        """
         response = self._request_with_plugin_daemon_response_stream(
             method="POST",
             path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",

View File

@@ -284,14 +284,10 @@ class AdvancedPromptTransform(PromptTransform):
             inputs = {"#histories#": "", **prompt_inputs}
             parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
             prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
-            tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
-            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
             histories = self._get_history_messages_from_memory(
                 memory=memory,
                 memory_config=memory_config,
-                max_token_limit=rest_tokens,
                 human_prefix=role_prefix.user,
                 ai_prefix=role_prefix.assistant,
             )

View File

@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Optional
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -16,8 +16,7 @@ class PromptTransform:
         prompt_messages: list[PromptMessage],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> list[PromptMessage]:
-        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
-        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
+        histories = self._get_history_messages_list_from_memory(memory=memory, memory_config=memory_config)
         prompt_messages.extend(histories)
         return prompt_messages
@@ -54,31 +53,29 @@
         self,
         memory: TokenBufferMemory,
         memory_config: MemoryConfig,
-        max_token_limit: int,
         human_prefix: Optional[str] = None,
         ai_prefix: Optional[str] = None,
     ) -> str:
-        """Get memory messages."""
-        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
-        if human_prefix:
-            kwargs["human_prefix"] = human_prefix
-        if ai_prefix:
-            kwargs["ai_prefix"] = ai_prefix
+        human_prefix = human_prefix or "Human"
+        ai_prefix = ai_prefix or "Assistant"
         if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
-            kwargs["message_limit"] = memory_config.window.size
-        return memory.get_history_prompt_text(**kwargs)
+            message_limit = memory_config.window.size
+        else:
+            message_limit = None
+        return memory.get_history_prompt_text(
+            human_prefix=human_prefix,
+            ai_prefix=ai_prefix,
+            message_limit=message_limit,
+        )
     def _get_history_messages_list_from_memory(
-        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
     ) -> list[PromptMessage]:
         """Get memory messages."""
         return list(
             memory.get_history_prompt_messages(
-                max_token_limit=max_token_limit,
                 message_limit=memory_config.window.size
                 if (
                     memory_config.window.enabled
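
Note: the replacement logic is just a window-size resolution: an enabled, positive window caps the history, otherwise the full history is requested. A minimal sketch with plain arguments standing in for `memory_config.window` (illustrative, not Dify code):

from typing import Optional

def resolve_message_limit(window_enabled: bool, window_size: Optional[int]) -> Optional[int]:
    # Mirrors the condition in _get_history_messages_from_memory above; None means "no cap".
    if window_enabled and window_size is not None and window_size > 0:
        return window_size
    return None

assert resolve_message_limit(True, 20) == 20
assert resolve_message_limit(False, 20) is None
assert resolve_message_limit(True, 0) is None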

View File

@@ -238,9 +238,6 @@ class SimplePromptTransform(PromptTransform):
         )
         if memory:
-            tmp_human_message = UserPromptMessage(content=prompt)
-            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
             histories = self._get_history_messages_from_memory(
                 memory=memory,
                 memory_config=MemoryConfig(
@@ -248,7 +245,6 @@ class SimplePromptTransform(PromptTransform):
                        enabled=False,
                    )
                ),
-                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get("human_prefix", "Human"),
                ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
            )

View File

@@ -27,7 +27,7 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
@@ -91,7 +91,7 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM
-    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None] | NodeRunResult:
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""
@@ -624,7 +624,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             memory_text = _handle_memory_completion_mode(
                 memory=memory,
                 memory_config=memory_config,
-                model_config=model_config,
             )
             # Insert histories into the prompt
             prompt_content = prompt_messages[0].content
@@ -960,36 +959,6 @@
     return result_text
-def _calculate_rest_token(
-    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
-) -> int:
-    rest_tokens = 2000
-    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-    if model_context_tokens:
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(str(parameter_rule.use_template))
-                    or 0
-                )
-        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-        rest_tokens = max(rest_tokens, 0)
-    return rest_tokens
 def _handle_memory_chat_mode(
     *,
     memory: TokenBufferMemory | None,
@@ -999,9 +968,7 @@ def _handle_memory_chat_mode(
     memory_messages: Sequence[PromptMessage] = []
     # Get messages from memory for chat model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
         memory_messages = memory.get_history_prompt_messages(
-            max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
         )
     return memory_messages
@@ -1011,16 +978,13 @@ def _handle_memory_completion_mode(
     *,
     memory: TokenBufferMemory | None,
     memory_config: MemoryConfig | None,
-    model_config: ModelConfigWithCredentialsEntity,
 ) -> str:
     memory_text = ""
     # Get history text from memory for completion model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
         if not memory_config.role_prefix:
             raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
         memory_text = memory.get_history_prompt_text(
-            max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
             human_prefix=memory_config.role_prefix.user,
             ai_prefix=memory_config.role_prefix.assistant,

View File

@@ -280,9 +280,11 @@ class ParameterExtractorNode(LLMNode):
         )
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
         prompt_template = self._get_function_calling_prompt_template(
-            node_data, query, variable_pool, memory, rest_token
+            node_data=node_data,
+            query=query,
+            variable_pool=variable_pool,
+            memory=memory,
         )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
@@ -396,11 +398,8 @@
         Generate completion prompt.
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(
-            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
-        )
         prompt_template = self._get_prompt_engineering_prompt_template(
-            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
+            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory
         )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
@@ -430,9 +429,6 @@
         Generate chat prompt.
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(
-            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
-        )
         prompt_template = self._get_prompt_engineering_prompt_template(
             node_data=node_data,
             query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
@@ -440,7 +436,6 @@
             ),
             variable_pool=variable_pool,
             memory=memory,
-            max_token_limit=rest_token,
         )
         prompt_messages = prompt_transform.get_prompt(
@@ -652,11 +647,11 @@
     def _get_function_calling_prompt_template(
         self,
+        *,
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ) -> list[ChatModelMessage]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         input_text = query
@@ -664,9 +659,7 @@
         instruction = variable_pool.convert_template(node_data.instruction or "").text
         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
-            )
+            memory_str = memory.get_history_prompt_text(message_limit=node_data.memory.window.size)
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
@@ -683,7 +676,6 @@
         query: str,
         variable_pool: VariablePool,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ):
         model_mode = ModelMode.value_of(node_data.model.mode)
         input_text = query
@@ -691,9 +683,7 @@
         instruction = variable_pool.convert_template(node_data.instruction or "").text
         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
-            )
+            memory_str = memory.get_history_prompt_text(message_limit=node_data.memory.window.size)
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
@@ -732,9 +722,19 @@
             raise ModelSchemaNotFoundError("Model schema not found")
         if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
-            prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
+            prompt_template = self._get_function_calling_prompt_template(
+                node_data=node_data,
+                query=query,
+                variable_pool=variable_pool,
+                memory=None,
+            )
         else:
-            prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
+            prompt_template = self._get_prompt_engineering_prompt_template(
+                node_data=node_data,
+                query=query,
+                variable_pool=variable_pool,
+                memory=None,
+            )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
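
Note on the added `*,` in the signature above: keyword-only parameters make call sites explicit, so removing or reordering a parameter (here, dropping `max_token_limit`) cannot silently shift a caller's positional arguments onto the wrong parameter. A generic illustration, not Dify code (the function and parameter names are made up):

def build_prompt(*, node_data: dict, query: str, memory: object | None = None) -> str:
    # Everything after "*" must be passed by name.
    return f"{query} ({len(node_data)} inputs, memory={'on' if memory else 'off'})"

print(build_prompt(node_data={"order_id": "string"}, query="extract the order id"))
# build_prompt({"order_id": "string"}, "extract the order id")  # -> TypeError: takes 0 positional arguments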

View File

@@ -2,12 +2,9 @@ import json
 from collections.abc import Mapping, Sequence
 from typing import Any, Optional, cast
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance
-from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
+from core.model_runtime.entities import LLMUsage, PromptMessageRole
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
@@ -67,17 +64,10 @@ class QuestionClassifierNode(LLMNode):
         )
         # fetch prompt messages
-        rest_token = self._calculate_rest_token(
-            node_data=node_data,
-            query=query or "",
-            model_config=model_config,
-            context="",
-        )
         prompt_template = self._get_prompt_template(
             node_data=node_data,
             query=query or "",
             memory=memory,
-            max_token_limit=rest_token,
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
@@ -196,56 +186,11 @@
         """
         return {"type": "question-classifier", "config": {"instructions": ""}}
-    def _calculate_rest_token(
-        self,
-        node_data: QuestionClassifierNodeData,
-        query: str,
-        model_config: ModelConfigWithCredentialsEntity,
-        context: Optional[str],
-    ) -> int:
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=[],
-            context=context,
-            memory_config=node_data.memory,
-            memory=None,
-            model_config=model_config,
-        )
-        rest_tokens = 2000
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-        if model_context_tokens:
-            model_instance = ModelInstance(
-                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-            )
-            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-            max_tokens = 0
-            for parameter_rule in model_config.model_schema.parameter_rules:
-                if parameter_rule.name == "max_tokens" or (
-                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-                ):
-                    max_tokens = (
-                        model_config.parameters.get(parameter_rule.name)
-                        or model_config.parameters.get(parameter_rule.use_template or "")
-                    ) or 0
-            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-            rest_tokens = max(rest_tokens, 0)
-        return rest_tokens
     def _get_prompt_template(
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ):
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
@@ -258,7 +203,6 @@
         memory_str = ""
         if memory:
             memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit,
                 message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
             )
         prompt_messages: list[LLMNodeChatModelMessage] = []

View File

@@ -271,7 +271,6 @@ class MessageService:
         memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
         histories = memory.get_history_prompt_text(
-            max_token_limit=3000,
             message_limit=3,
         )

View File

@@ -235,7 +235,6 @@ def test__get_completion_model_prompt_messages():
         "#context#": context,
         "#query#": query,
         "#histories#": memory.get_history_prompt_text(
-            max_token_limit=2000,
             human_prefix=prompt_rules.get("human_prefix", "Human"),
             ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
         ),