feat: support backwards invoke summary

This commit is contained in:
Yeuoly 2024-10-17 19:44:30 +08:00
parent 7754431a34
commit 45f8651a3d
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
3 changed files with 173 additions and 0 deletions

View File

@ -19,6 +19,7 @@ from core.plugin.entities.request import (
RequestInvokeQuestionClassifierNode,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeSummary,
RequestInvokeTextEmbedding,
RequestInvokeTool,
RequestInvokeTTS,
@ -230,6 +231,24 @@ class PluginInvokeEncryptApi(Resource):
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
class PluginInvokeSummaryApi(Resource):
@setup_required
@plugin_inner_api_only
@get_tenant
@plugin_data(payload_type=RequestInvokeSummary)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSummary):
try:
return BaseBackwardsInvocationResponse(
data=PluginModelBackwardsInvocation.invoke_summary(
user_id=user_id,
tenant=tenant_model,
payload=payload,
)
).model_dump()
except Exception as e:
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
@ -241,3 +260,4 @@ api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extra
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")

View File

@ -4,15 +4,23 @@ from collections.abc import Generator
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
UserPromptMessage,
)
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeLLM,
RequestInvokeModeration,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
RequestInvokeSummary,
RequestInvokeTextEmbedding,
RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.llm_node import LLMNode
from models.account import Tenant
@ -175,3 +183,139 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
return {
"result": response,
}
@classmethod
def get_system_model_max_tokens(cls, tenant_id: str) -> int:
"""
get system model max tokens
"""
return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)
@classmethod
def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
"""
get prompt tokens
"""
return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)
@classmethod
def invoke_system_model(
cls,
user_id: str,
tenant: Tenant,
prompt_messages: list[PromptMessage],
) -> LLMResult:
"""
invoke system model
"""
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=tenant.id,
tool_type=ToolProviderType.PLUGIN,
tool_name="plugin",
prompt_messages=prompt_messages,
)
@classmethod
def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary):
"""
invoke summary
"""
max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
content = payload.text
SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
retain the original meaning and keep the key points.
however, the text you got is too long, what you got is possible a part of the text.
Please summarize the text you got.
Here is the extra instruction you need to follow:
<extra_instruction>
{payload.instruction}
</extra_instruction>
"""
if (
cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[UserPromptMessage(content=content)],
)
< max_tokens * 0.6
):
return content
def get_prompt_tokens(content: str) -> int:
return cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
UserPromptMessage(content=content),
],
)
def summarize(content: str) -> str:
summary = cls.invoke_system_model(
user_id=user_id,
tenant=tenant,
prompt_messages=[
SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
UserPromptMessage(content=content),
],
)
assert isinstance(summary.message.content, str)
return summary.message.content
lines = content.split("\n")
new_lines = []
# split long line into multiple lines
for i in range(len(lines)):
line = lines[i]
if not line.strip():
continue
if len(line) < max_tokens * 0.5:
new_lines.append(line)
elif get_prompt_tokens(line) > max_tokens * 0.7:
while get_prompt_tokens(line) > max_tokens * 0.7:
new_lines.append(line[: int(max_tokens * 0.5)])
line = line[int(max_tokens * 0.5) :]
new_lines.append(line)
else:
new_lines.append(line)
# merge lines into messages with max tokens
messages: list[str] = []
for i in new_lines:
if len(messages) == 0:
messages.append(i)
else:
if len(messages[-1]) + len(i) < max_tokens * 0.5:
messages[-1] += i
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
messages.append(i)
else:
messages[-1] += i
summaries = []
for i in range(len(messages)):
message = messages[i]
summary = summarize(message)
summaries.append(summary)
result = "\n".join(summaries)
if (
cls.get_prompt_tokens(
tenant_id=tenant.id,
prompt_messages=[UserPromptMessage(content=result)],
)
> max_tokens * 0.7
):
return cls.invoke_summary(
user_id=user_id,
tenant=tenant,
payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
)
return result

View File

@ -186,3 +186,12 @@ class RequestInvokeEncrypt(BaseModel):
identity: str
data: dict = Field(default_factory=dict)
config: list[BasicProviderConfig] = Field(default_factory=list)
class RequestInvokeSummary(BaseModel):
"""
Request to summary
"""
text: str
instruction: str