refactor: remove max token limit from history prompt methods and related calculations
Signed-off-by: -LAN- <laipz8200@outlook.com>
parent c53a0db4e0
commit 4f27d5d987
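For reviewers, a minimal sketch of the call-site migration this diff implies. The FakeMemory class below is a hypothetical stand-in (the real TokenBufferMemory requires a conversation and a model instance); only the method signatures are meant to mirror the before/after API.

from typing import Optional, Sequence


class FakeMemory:
    """Hypothetical stand-in for TokenBufferMemory with the post-refactor signatures."""

    def __init__(self, messages: Sequence[str]):
        self._messages = list(messages)

    # Before: get_history_prompt_messages(self, max_token_limit: int = 2000, message_limit: Optional[int] = None)
    # After:  the token-budget argument is gone; history is bounded only by message_limit.
    def get_history_prompt_messages(self, message_limit: Optional[int] = None) -> Sequence[str]:
        return self._messages[-message_limit:] if message_limit else self._messages

    def get_history_prompt_text(
        self,
        *,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        message_limit: Optional[int] = None,
    ) -> str:
        messages = self.get_history_prompt_messages(message_limit=message_limit)
        prefixes = [human_prefix, ai_prefix]
        # Alternate prefixes only to produce readable demo output; real messages carry their own roles.
        return "\n".join(f"{prefixes[i % 2]}: {m}" for i, m in enumerate(messages))


if __name__ == "__main__":
    memory = FakeMemory(["hi", "hello!", "what changed?", "the token budget argument is gone"])
    # Old call sites passed max_token_limit=rest_tokens; that keyword no longer exists.
    print(memory.get_history_prompt_text(message_limit=2))

Accordingly, callers such as PromptTransform._get_history_messages_from_memory and the LLM, parameter-extractor, and question-classifier nodes stop computing a rest-token budget and pass only message_limit (plus role prefixes where applicable), as the hunks below show.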
@@ -53,20 +53,6 @@ class AgentChatAppRunner(AppRunner):
         query = application_generate_entity.query
         files = application_generate_entity.files

-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=dict(inputs),
-            files=list(files),
-            query=query,
-        )
-
         memory = None
         if application_generate_entity.conversation_id:
             # get memory of conversation (read-only)
@@ -35,69 +35,6 @@ if TYPE_CHECKING:


 class AppRunner:
-    def get_pre_calculate_rest_tokens(
-        self,
-        app_record: App,
-        model_config: ModelConfigWithCredentialsEntity,
-        prompt_template_entity: PromptTemplateEntity,
-        inputs: Mapping[str, str],
-        files: Sequence["File"],
-        query: Optional[str] = None,
-    ) -> int:
-        """
-        Get pre calculate rest tokens
-        :param app_record: app record
-        :param model_config: model config entity
-        :param prompt_template_entity: prompt template entity
-        :param inputs: inputs
-        :param files: files
-        :param query: query
-        :return:
-        """
-        # Invoke model
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(parameter_rule.use_template or "")
-                ) or 0
-
-        if model_context_tokens is None:
-            return -1
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        # get prompt messages without memory and context
-        prompt_messages, stop = self.organize_prompt_messages(
-            app_record=app_record,
-            model_config=model_config,
-            prompt_template_entity=prompt_template_entity,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
-
-        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-        rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
-        if rest_tokens < 0:
-            raise InvokeBadRequestError(
-                "Query or prefix prompt is too long, you can reduce the prefix prompt, "
-                "or shrink the max token, or switch to a llm with a larger token limit size."
-            )
-
-        return rest_tokens
-
     def recalc_llm_max_tokens(
         self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
     ):
@@ -54,20 +54,6 @@ class CompletionAppRunner(AppRunner):
         )
         image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

-        # Pre-calculate the number of tokens of the prompt messages,
-        # and return the rest number of tokens by model context token size limit and max token size limit.
-        # If the rest number of tokens is not enough, raise exception.
-        # Include: prompt template, inputs, query(optional), files(optional)
-        # Not Include: memory, external data, dataset context
-        self.get_pre_calculate_rest_tokens(
-            app_record=app_record,
-            model_config=application_generate_entity.model_conf,
-            prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
-            query=query,
-        )
-
         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
         prompt_messages, stop = self.organize_prompt_messages(
@@ -25,9 +25,7 @@ class TokenBufferMemory:
         self.conversation = conversation
         self.model_instance = model_instance

-    def get_history_prompt_messages(
-        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> Sequence[PromptMessage]:
+    def get_history_prompt_messages(self, message_limit: Optional[int] = None) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit
@@ -118,33 +116,23 @@ class TokenBufferMemory:
         if not prompt_messages:
             return []

-        # prune the chat message if it exceeds the max token limit
-        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-
-        if curr_message_tokens > max_token_limit:
-            pruned_memory = []
-            while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
-                pruned_memory.append(prompt_messages.pop(0))
-                curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
-
         return prompt_messages

     def get_history_prompt_text(
         self,
         *,
         human_prefix: str = "Human",
         ai_prefix: str = "Assistant",
-        max_token_limit: int = 2000,
         message_limit: Optional[int] = None,
     ) -> str:
         """
         Get history prompt text.
         :param human_prefix: human prefix
         :param ai_prefix: ai prefix
-        :param max_token_limit: max token limit
         :param message_limit: message limit
         :return:
         """
-        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
+        prompt_messages = self.get_history_prompt_messages(message_limit=message_limit)

         string_messages = []
         for m in prompt_messages:
@@ -202,9 +202,6 @@ class PluginModelManager(BasePluginManager):
         prompt_messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
     ) -> int:
         """
         Get number of tokens for llm
         """
         response = self._request_with_plugin_daemon_response_stream(
             method="POST",
             path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
@@ -284,14 +284,10 @@ class AdvancedPromptTransform(PromptTransform):
             inputs = {"#histories#": "", **prompt_inputs}
             parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
             prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
-            tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
-
-            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)

             histories = self._get_history_messages_from_memory(
                 memory=memory,
                 memory_config=memory_config,
-                max_token_limit=rest_tokens,
                 human_prefix=role_prefix.user,
                 ai_prefix=role_prefix.assistant,
             )
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Optional

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -16,8 +16,7 @@ class PromptTransform:
         prompt_messages: list[PromptMessage],
         model_config: ModelConfigWithCredentialsEntity,
     ) -> list[PromptMessage]:
-        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
-        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
+        histories = self._get_history_messages_list_from_memory(memory=memory, memory_config=memory_config)
         prompt_messages.extend(histories)

         return prompt_messages
@@ -54,31 +53,29 @@ class PromptTransform:
         self,
         memory: TokenBufferMemory,
         memory_config: MemoryConfig,
-        max_token_limit: int,
         human_prefix: Optional[str] = None,
         ai_prefix: Optional[str] = None,
     ) -> str:
         """Get memory messages."""
-        kwargs: dict[str, Any] = {"max_token_limit": max_token_limit}
-
-        if human_prefix:
-            kwargs["human_prefix"] = human_prefix
-
-        if ai_prefix:
-            kwargs["ai_prefix"] = ai_prefix
-
+        human_prefix = human_prefix or "Human"
+        ai_prefix = ai_prefix or "Assistant"
         if memory_config.window.enabled and memory_config.window.size is not None and memory_config.window.size > 0:
-            kwargs["message_limit"] = memory_config.window.size
-
-        return memory.get_history_prompt_text(**kwargs)
+            message_limit = memory_config.window.size
+        else:
+            message_limit = None
+        return memory.get_history_prompt_text(
+            human_prefix=human_prefix,
+            ai_prefix=ai_prefix,
+            message_limit=message_limit,
+        )

     def _get_history_messages_list_from_memory(
-        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
+        self,
+        memory: TokenBufferMemory,
+        memory_config: MemoryConfig,
     ) -> list[PromptMessage]:
         """Get memory messages."""
         return list(
             memory.get_history_prompt_messages(
-                max_token_limit=max_token_limit,
                 message_limit=memory_config.window.size
                 if (
                     memory_config.window.enabled
@@ -238,9 +238,6 @@ class SimplePromptTransform(PromptTransform):
         )

         if memory:
-            tmp_human_message = UserPromptMessage(content=prompt)
-
-            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
             histories = self._get_history_messages_from_memory(
                 memory=memory,
                 memory_config=MemoryConfig(
@@ -248,7 +245,6 @@ class SimplePromptTransform(PromptTransform):
                         enabled=False,
                     )
                 ),
-                max_token_limit=rest_tokens,
                 human_prefix=prompt_rules.get("human_prefix", "Human"),
                 ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
             )
@@ -27,7 +27,7 @@ from core.model_runtime.entities.message_entities import (
     SystemPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin import ModelProviderID
@@ -91,7 +91,7 @@ class LLMNode(BaseNode[LLMNodeData]):
     _node_data_cls = LLMNodeData
     _node_type = NodeType.LLM

-    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+    def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None] | NodeRunResult:
         node_inputs: Optional[dict[str, Any]] = None
         process_data = None
         result_text = ""
@@ -624,7 +624,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             memory_text = _handle_memory_completion_mode(
                 memory=memory,
                 memory_config=memory_config,
                 model_config=model_config,
             )
             # Insert histories into the prompt
             prompt_content = prompt_messages[0].content
@@ -960,36 +959,6 @@ def _render_jinja2_message(
     return result_text


-def _calculate_rest_token(
-    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
-) -> int:
-    rest_tokens = 2000
-
-    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-    if model_context_tokens:
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
-
-        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(str(parameter_rule.use_template))
-                    or 0
-                )
-
-        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-        rest_tokens = max(rest_tokens, 0)
-
-    return rest_tokens
-
-
 def _handle_memory_chat_mode(
     *,
     memory: TokenBufferMemory | None,
@@ -999,9 +968,7 @@ def _handle_memory_chat_mode(
     memory_messages: Sequence[PromptMessage] = []
     # Get messages from memory for chat model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
         memory_messages = memory.get_history_prompt_messages(
-            max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
         )
     return memory_messages
@@ -1011,16 +978,13 @@ def _handle_memory_completion_mode(
     *,
     memory: TokenBufferMemory | None,
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> str:
     memory_text = ""
     # Get history text from memory for completion model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
         if not memory_config.role_prefix:
             raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
         memory_text = memory.get_history_prompt_text(
-            max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
             human_prefix=memory_config.role_prefix.user,
             ai_prefix=memory_config.role_prefix.assistant,
@@ -280,9 +280,11 @@ class ParameterExtractorNode(LLMNode):
         )

         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
         prompt_template = self._get_function_calling_prompt_template(
-            node_data, query, variable_pool, memory, rest_token
+            node_data=node_data,
+            query=query,
+            variable_pool=variable_pool,
+            memory=memory,
         )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
@@ -396,11 +398,8 @@ class ParameterExtractorNode(LLMNode):
         Generate completion prompt.
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(
-            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
-        )
         prompt_template = self._get_prompt_engineering_prompt_template(
-            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
+            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory
         )
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
@@ -430,9 +429,6 @@ class ParameterExtractorNode(LLMNode):
         Generate chat prompt.
         """
         prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        rest_token = self._calculate_rest_token(
-            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
-        )
         prompt_template = self._get_prompt_engineering_prompt_template(
             node_data=node_data,
             query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
@@ -440,7 +436,6 @@ class ParameterExtractorNode(LLMNode):
             ),
             variable_pool=variable_pool,
             memory=memory,
-            max_token_limit=rest_token,
         )

         prompt_messages = prompt_transform.get_prompt(
@@ -652,11 +647,11 @@ class ParameterExtractorNode(LLMNode):

     def _get_function_calling_prompt_template(
         self,
+        *,
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ) -> list[ChatModelMessage]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         input_text = query
@@ -664,9 +659,7 @@ class ParameterExtractorNode(LLMNode):
         instruction = variable_pool.convert_template(node_data.instruction or "").text

         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
-            )
+            memory_str = memory.get_history_prompt_text(message_limit=node_data.memory.window.size)
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
@@ -683,7 +676,6 @@ class ParameterExtractorNode(LLMNode):
         query: str,
         variable_pool: VariablePool,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ):
         model_mode = ModelMode.value_of(node_data.model.mode)
         input_text = query
@@ -691,9 +683,7 @@ class ParameterExtractorNode(LLMNode):
         instruction = variable_pool.convert_template(node_data.instruction or "").text

         if memory and node_data.memory and node_data.memory.window:
-            memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
-            )
+            memory_str = memory.get_history_prompt_text(message_limit=node_data.memory.window.size)
         if model_mode == ModelMode.CHAT:
             system_prompt_messages = ChatModelMessage(
                 role=PromptMessageRole.SYSTEM,
@@ -732,9 +722,19 @@ class ParameterExtractorNode(LLMNode):
             raise ModelSchemaNotFoundError("Model schema not found")

         if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
-            prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
+            prompt_template = self._get_function_calling_prompt_template(
+                node_data=node_data,
+                query=query,
+                variable_pool=variable_pool,
+                memory=None,
+            )
         else:
-            prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
+            prompt_template = self._get_prompt_engineering_prompt_template(
+                node_data=node_data,
+                query=query,
+                variable_pool=variable_pool,
+                memory=None,
+            )

         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
@@ -2,12 +2,9 @@ import json
 from collections.abc import Mapping, Sequence
 from typing import Any, Optional, cast

-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance
-from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
+from core.model_runtime.entities import LLMUsage, PromptMessageRole
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.simple_prompt_transform import ModelMode
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
@@ -67,17 +64,10 @@ class QuestionClassifierNode(LLMNode):
         )

         # fetch prompt messages
-        rest_token = self._calculate_rest_token(
-            node_data=node_data,
-            query=query or "",
-            model_config=model_config,
-            context="",
-        )
         prompt_template = self._get_prompt_template(
             node_data=node_data,
             query=query or "",
             memory=memory,
-            max_token_limit=rest_token,
         )
         prompt_messages, stop = self._fetch_prompt_messages(
             prompt_template=prompt_template,
@@ -196,56 +186,11 @@ class QuestionClassifierNode(LLMNode):
         """
         return {"type": "question-classifier", "config": {"instructions": ""}}

-    def _calculate_rest_token(
-        self,
-        node_data: QuestionClassifierNodeData,
-        query: str,
-        model_config: ModelConfigWithCredentialsEntity,
-        context: Optional[str],
-    ) -> int:
-        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
-        prompt_template = self._get_prompt_template(node_data, query, None, 2000)
-        prompt_messages = prompt_transform.get_prompt(
-            prompt_template=prompt_template,
-            inputs={},
-            query="",
-            files=[],
-            context=context,
-            memory_config=node_data.memory,
-            memory=None,
-            model_config=model_config,
-        )
-        rest_tokens = 2000
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-        if model_context_tokens:
-            model_instance = ModelInstance(
-                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-            )
-
-            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-            max_tokens = 0
-            for parameter_rule in model_config.model_schema.parameter_rules:
-                if parameter_rule.name == "max_tokens" or (
-                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-                ):
-                    max_tokens = (
-                        model_config.parameters.get(parameter_rule.name)
-                        or model_config.parameters.get(parameter_rule.use_template or "")
-                    ) or 0
-
-            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-            rest_tokens = max(rest_tokens, 0)
-
-        return rest_tokens
-
     def _get_prompt_template(
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
         memory: Optional[TokenBufferMemory],
-        max_token_limit: int = 2000,
     ):
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
@@ -258,7 +203,6 @@ class QuestionClassifierNode(LLMNode):
         memory_str = ""
         if memory:
             memory_str = memory.get_history_prompt_text(
-                max_token_limit=max_token_limit,
                 message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
             )
         prompt_messages: list[LLMNodeChatModelMessage] = []
@@ -271,7 +271,6 @@ class MessageService:
         memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

         histories = memory.get_history_prompt_text(
-            max_token_limit=3000,
             message_limit=3,
         )
@@ -235,7 +235,6 @@ def test__get_completion_model_prompt_messages():
             "#context#": context,
             "#query#": query,
             "#histories#": memory.get_history_prompt_text(
-                max_token_limit=2000,
                 human_prefix=prompt_rules.get("human_prefix", "Human"),
                 ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
             ),
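A closing note on the PromptTransform hunk above: with the token budget removed, how much history is included now depends only on the memory window. A standalone sketch of that selection rule, using a simplified WindowConfig as a stand-in for MemoryConfig.WindowConfig:

from dataclasses import dataclass
from typing import Optional


@dataclass
class WindowConfig:
    """Simplified stand-in for MemoryConfig.WindowConfig."""
    enabled: bool = False
    size: Optional[int] = None


def resolve_message_limit(window: WindowConfig) -> Optional[int]:
    # Mirrors the new branch in _get_history_messages_from_memory:
    # only an enabled window with a positive size yields a message_limit.
    if window.enabled and window.size is not None and window.size > 0:
        return window.size
    return None


assert resolve_message_limit(WindowConfig(enabled=True, size=10)) == 10
assert resolve_message_limit(WindowConfig(enabled=False, size=10)) is None
assert resolve_message_limit(WindowConfig(enabled=True, size=0)) is None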