diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 72a1717112..71328f6d1b 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -53,20 +53,6 @@ class AgentChatAppRunner(AppRunner):
         query = application_generate_entity.query
         files = application_generate_entity.files
 
-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=dict(inputs),
-            files=list(files),
-            query=query,
-        )
-
         memory = None
         if application_generate_entity.conversation_id:
             # get memory of conversation (read-only)
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index 8c6b29731e..092b04363b 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -35,69 +35,6 @@ if TYPE_CHECKING:
 
 
 class AppRunner:
-    def get_pre_calculate_rest_tokens(
-        self,
-        app_record: App,
-        model_config: ModelConfigWithCredentialsEntity,
-        prompt_template_entity: PromptTemplateEntity,
-        inputs: Mapping[str, str],
-        files: Sequence["File"],
-        query: Optional[str] = None,
-    ) -> int:
-        """
-        Get pre calculate rest tokens
-        :param app_record: app record
-        :param model_config: model config entity
-        :param prompt_template_entity: prompt template entity
-        :param inputs: inputs
-        :param files: files
-        :param query: query
-        :return:
-        """
-        # Invoke model
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(parameter_rule.use_template or "")
-                ) or 0
-
-        if model_context_tokens is None:
-            return -1
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        # get prompt messages without memory and context
-        prompt_messages, stop = self.organize_prompt_messages(
-            app_record=app_record,
-            model_config=model_config,
-            prompt_template_entity=prompt_template_entity,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
-
-        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-        rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
-        if rest_tokens < 0:
-            raise InvokeBadRequestError(
-                "Query or prefix prompt is too long, you can reduce the prefix prompt, "
-                "or shrink the max token, or switch to a llm with a larger token limit size."
-            )
-
-        return rest_tokens
-
     def recalc_llm_max_tokens(
         self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
     ):
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 0ed06c9c98..27f163b2b4 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -54,20 +54,6 @@ class CompletionAppRunner(AppRunner):
         )
         image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
 
-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
-
         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
         prompt_messages, stop = self.organize_prompt_messages(
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 003a0c85b1..42716c810a 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -25,9 +25,7 @@ class TokenBufferMemory:
         self.conversation = conversation
         self.model_instance = model_instance
 
-    def get_history_prompt_messages(
-        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> Sequence[PromptMessage]:
+    def get_history_prompt_messages(self, message_limit: Optional[int] = None) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit
@@ -118,33 +116,23 @@ class TokenBufferMemory:
         if not prompt_messages:
             return []
 
-        # prune the chat message if it exceeds the max token limit
-        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-
-        if curr_message_tokens > max_token_limit:
-            pruned_memory = []
-            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
-                pruned_memory.append(prompt_messages.pop(0))
-                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-
         return prompt_messages
 
     def get_history_prompt_text(
         self,
+        *,
         human_prefix: str = "Human",
         ai_prefix: str = "Assistant",
-        max_token_limit: int = 2000,
         message_limit: Optional[int] = None,
     ) -> str:
         """
         Get history prompt text.
         :param human_prefix: human prefix
         :param ai_prefix: ai prefix
-        :param max_token_limit: max token limit
         :param message_limit: message limit
         :return:
         """
-        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
+        prompt_messages = self.get_history_prompt_messages(message_limit=message_limit)
 
         string_messages = []
         for m in prompt_messages:
diff --git a/api/core/plugin/manager/model.py b/api/core/plugin/manager/model.py
index 5ebc0c2320..eda18dbafb 100644
--- a/api/core/plugin/manager/model.py
+++ b/api/core/plugin/manager/model.py
@@ -202,9 +202,6 @@ class PluginModelManager(BasePluginManager):
         prompt_messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> int:
-        """
-        Get number of tokens for llm
-        """
         response = self._request_with_plugin_daemon_response_stream(
             method="POST",
             path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index c7427f797e..f7f9a35dff 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -284,14 +284,10 @@ class AdvancedPromptTransform(PromptTransform):
                 inputs = {"#histories#": "", **prompt_inputs}
                 parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
                 prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
-                tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
-
-                rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
 
                 histories = self._get_history_messages_from_memory(
                     memory=memory,
                     memory_config=memory_config,
-                    max_token_limit=rest_tokens,
                     human_prefix=role_prefix.user,
                     ai_prefix=role_prefix.assistant,
                 )
diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py
index 1f040599be..ae32af03a7 100644
--- a/api/core/prompt/prompt_transform.py
+++ b/api/core/prompt/prompt_transform.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Optional
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -16,8 +16,7 @@ class PromptTransform:
         prompt_messages: list[PromptMessage],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> list[PromptMessage]:
-        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
-        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
+        histories = self._get_history_messages_list_from_memory(memory=memory, memory_config=memory_config)
         prompt_messages.extend(histories)
         return prompt_messages
 
@@ -54,31 +53,29 @@ class PromptTransform:
         self,
         memory: TokenBufferMemory,
         memory_config: MemoryConfig,
-        max_token_limit: int,
         human_prefix: Optional[str] = None,
         ai_prefix: Optional[str] = None,
     ) -> str:
-        """Get memory messages."""
-        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
-
-        if human_prefix:
-            kwargs["human_prefix"] = human_prefix
-
-        if ai_prefix:
-            kwargs["ai_prefix"] = ai_prefix
-
+        human_prefix = human_prefix or "Human"
+        ai_prefix = ai_prefix or "Assistant"
         if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
-            kwargs["message_limit"] = memory_config.window.size
-
-        return memory.get_history_prompt_text(**kwargs)
+            message_limit = memory_config.window.size
+        else:
+            message_limit = None
+        return memory.get_history_prompt_text(
+            human_prefix=human_prefix,
+            ai_prefix=ai_prefix,
+            message_limit=message_limit,
+        )
 
     def _get_history_messages_list_from_memory(
-        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
     ) -> list[PromptMessage]:
         """Get memory messages."""
         return list(
             memory.get_history_prompt_messages(
-                max_token_limit=max_token_limit,
                 message_limit=memory_config.window.size
                 if (
                     memory_config.window.enabled
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index 421b14e0df..a6d8ab168c 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -238,9 +238,6 @@ class SimplePromptTransform(PromptTransform):
         )
 
         if memory:
-            tmp_human_message = UserPromptMessage(content=prompt)
-
-            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
             histories = self._get_history_messages_from_memory(
                 memory=memory,
                 memory_config=MemoryConfig(
@@ -248,7 +245,6 @@ class SimplePromptTransform(PromptTransform):
                         enabled=False,
                     )
                 ),
-                max_token_limit=rest_tokens,
                 human_prefix=prompt_rules.get("human_prefix", "Human"),
                 ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
             )
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index fe0ed3e564..87e65c89c4 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -27,7 +27,7 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
@@ -91,7 +91,7 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM
 
-    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None] | NodeRunResult:
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""
@@ -624,7 +624,6 @@ class LLMNode(BaseNode[LLMNodeData]):
         memory_text = _handle_memory_completion_mode(
             memory=memory,
             memory_config=memory_config,
-            model_config=model_config,
         )
         # Insert histories into the prompt
         prompt_content = prompt_messages[0].content
@@ -960,36 +959,6 @@ def _render_jinja2_message(
     return result_text
 
 
-def _calculate_rest_token(
-    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
-) -> int:
-    rest_tokens = 2000
-
-    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-    if model_context_tokens:
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-
-        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(str(parameter_rule.use_template))
-                    or 0
-                )
-
-        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-        rest_tokens = max(rest_tokens, 0)
-
-    return rest_tokens
-
-
 def _handle_memory_chat_mode(
     *,
     memory: TokenBufferMemory | None,
@@ -999,9 +968,7 @@ def _handle_memory_chat_mode(
     memory_messages: Sequence[PromptMessage] = []
     # Get messages from memory for chat model
    if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
         memory_messages = memory.get_history_prompt_messages(
-            max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
         )
     return memory_messages
@@ -1011,16 +978,13 @@ def _handle_memory_completion_mode(
     *,
     memory: TokenBufferMemory | None,
     memory_config: MemoryConfig | None,
-    model_config: ModelConfigWithCredentialsEntity,
 ) -> str:
     memory_text = ""
     # Get history text from memory for completion model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
         if not memory_config.role_prefix:
             raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
         memory_text = memory.get_history_prompt_text(
-            max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
             human_prefix=memory_config.role_prefix.user,
             ai_prefix=memory_config.role_prefix.assistant,
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index 7b1b8cf483..aebdc0d21b 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -280,9 +280,11 @@ class ParameterExtractorNode(LLMNode):
         )
 
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
         prompt_template = self._get_function_calling_prompt_template(
-            node_data, query, variable_pool, memory, rest_token
+            node_data=node_data,
+            query=query,
+            variable_pool=variable_pool,
+            memory=memory,
         )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
@@ -396,11 +398,8 @@ class ParameterExtractorNode(LLMNode):
         Generate completion prompt.
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(
-            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
-        )
         prompt_template = self._get_prompt_engineering_prompt_template(
-            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
+            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory
         )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
@@ -430,9 +429,6 @@ class ParameterExtractorNode(LLMNode):
         Generate chat prompt.
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(
-            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
-        )
         prompt_template = self._get_prompt_engineering_prompt_template(
             node_data=node_data,
             query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
@@ -440,7 +436,6 @@ class ParameterExtractorNode(LLMNode):
             ),
             variable_pool=variable_pool,
             memory=memory,
-            max_token_limit=rest_token,
         )
 
         prompt_messages = prompt_transform.get_prompt(
@@ -652,11 +647,11 @@ class ParameterExtractorNode(LLMNode):
     def _get_function_calling_prompt_template(
         self,
+        *,
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ) -> list[ChatModelMessage]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         input_text = query
 
@@ -664,9 +659,7 @@ class ParameterExtractorNode(LLMNode):
         instruction = variable_pool.convert_template(node_data.instruction or "").text
 
         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
-            )
+            memory_str = memory.get_history_prompt_text(message_limit=node_data.memory.window.size)
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
@@ -683,7 +676,6 @@ class ParameterExtractorNode(LLMNode):
         query: str,
         variable_pool: VariablePool,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ):
         model_mode = ModelMode.value_of(node_data.model.mode)
         input_text = query
@@ -691,9 +683,7 @@ class ParameterExtractorNode(LLMNode):
         instruction = variable_pool.convert_template(node_data.instruction or "").text
 
         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
-            )
+            memory_str = memory.get_history_prompt_text(message_limit=node_data.memory.window.size)
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
@@ -732,9 +722,19 @@ class ParameterExtractorNode(LLMNode):
             raise ModelSchemaNotFoundError("Model schema not found")
 
         if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
-            prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
+            prompt_template = self._get_function_calling_prompt_template(
+                node_data=node_data,
+                query=query,
+                variable_pool=variable_pool,
+                memory=None,
+            )
         else:
-            prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
+            prompt_template = self._get_prompt_engineering_prompt_template(
+                node_data=node_data,
+                query=query,
+                variable_pool=variable_pool,
+                memory=None,
+            )
 
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 0ec44eefac..6797324f9d 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -2,12 +2,9 @@ import json
 from collections.abc import Mapping, Sequence
 from typing import Any, Optional, cast
 
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance
-from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
+from core.model_runtime.entities import LLMUsage, PromptMessageRole
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
@@ -67,17 +64,10 @@ class QuestionClassifierNode(LLMNode):
         )
 
         # fetch prompt messages
-        rest_token = self._calculate_rest_token(
-            node_data=node_data,
-            query=query or "",
-            model_config=model_config,
-            context="",
-        )
         prompt_template = self._get_prompt_template(
             node_data=node_data,
             query=query or "",
             memory=memory,
-            max_token_limit=rest_token,
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
@@ -196,56 +186,11 @@ class QuestionClassifierNode(LLMNode):
         """
         return {"type": "question-classifier", "config": {"instructions": ""}}
 
-    def _calculate_rest_token(
-        self,
-        node_data: QuestionClassifierNodeData,
-        query: str,
-        model_config: ModelConfigWithCredentialsEntity,
-        context: Optional[str],
-    ) -> int:
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=[],
-            context=context,
-            memory_config=node_data.memory,
-            memory=None,
-            model_config=model_config,
-        )
-        rest_tokens = 2000
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-        if model_context_tokens:
-            model_instance = ModelInstance(
-                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-            )
-
-            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-            max_tokens = 0
-            for parameter_rule in model_config.model_schema.parameter_rules:
-                if parameter_rule.name == "max_tokens" or (
-                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-                ):
-                    max_tokens = (
-                        model_config.parameters.get(parameter_rule.name)
-                        or model_config.parameters.get(parameter_rule.use_template or "")
-                    ) or 0
-
-            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-            rest_tokens = max(rest_tokens, 0)
-
-        return rest_tokens
-
     def _get_prompt_template(
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ):
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
@@ -258,7 +203,6 @@ class QuestionClassifierNode(LLMNode):
         memory_str = ""
         if memory:
             memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit,
                 message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
             )
         prompt_messages: list[LLMNodeChatModelMessage] = []
diff --git a/api/services/message_service.py b/api/services/message_service.py
index 480d038623..b88dd05293 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -271,7 +271,6 @@ class MessageService:
         memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 
         histories = memory.get_history_prompt_text(
-            max_token_limit=3000,
            message_limit=3,
         )
 
diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
index c32fc2bc34..8467e002c3 100644
--- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
+++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py
@@ -235,7 +235,6 @@ def test__get_completion_model_prompt_messages():
         "#context#": context,
         "#query#": query,
         "#histories#": memory.get_history_prompt_text(
-            max_token_limit=2000,
            human_prefix=prompt_rules.get("human_prefix", "Human"),
            ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
         ),
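Every hunk above drops token-budget ("rest token") pre-calculation in favour of a plain message-count window. For reference, the arithmetic shared by the deleted helpers (AppRunner.get_pre_calculate_rest_tokens, the module-level _calculate_rest_token in llm/node.py, and QuestionClassifierNode._calculate_rest_token) reduces to the minimal sketch below. The function name and the plain-dict completion_parameters argument are illustrative only, not part of the codebase; the real helpers read these values from ModelConfigWithCredentialsEntity and counted prompt tokens via ModelInstance.get_llm_num_tokens.

from typing import Any, Optional


def remaining_context_tokens(
    context_size: Optional[int],
    completion_parameters: dict[str, Any],
    prompt_tokens: int,
) -> int:
    # Budget = model context window - tokens reserved for the completion - tokens already in the prompt.
    if context_size is None:
        # The removed AppRunner helper returned -1 when the model schema exposed no CONTEXT_SIZE.
        return -1
    reserved_for_completion = int(completion_parameters.get("max_tokens") or 0)
    rest = context_size - reserved_for_completion - prompt_tokens
    # The removed node-level helpers clamped at zero; the AppRunner variant raised
    # InvokeBadRequestError instead when the budget went negative.
    return max(rest, 0)

After this patch, history length is bounded only by message_limit (for example, memory.get_history_prompt_text(human_prefix="Human", ai_prefix="Assistant", message_limit=3) in MessageService), and TokenBufferMemory no longer prunes history by token count.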