diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 8518d34a8e..29e9907fca 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,4 +1,5 @@
 import os
+from typing import cast
 
 from flask_login import current_user  # type: ignore
 from flask_restful import Resource, reqparse  # type: ignore
@@ -11,8 +12,11 @@ from controllers.console.app.error import (
     ProviderQuotaExceededError,
 )
 from controllers.console.wraps import account_initialization_required, setup_required
+from core.auto.workflow_generator.workflow_generator import WorkflowGenerator
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.llm_generator.llm_generator import LLMGenerator
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
 
@@ -85,5 +89,45 @@ class RuleCodeGenerateApi(Resource):
         return code_result
 
 
+class AutoGenerateWorkflowApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        """
+        Auto-generate a workflow from a natural-language instruction.
+        """
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("instruction", type=str, required=True, location="json")
+        parser.add_argument("model_config", type=dict, required=True, location="json")
+        tenant_id = cast(str, current_user.current_tenant_id)
+        args = parser.parse_args()
+        instruction = args.get("instruction")
+        if not instruction:
+            raise ValueError("Instruction is required")
+        if not args.get("model_config"):
+            raise ValueError("Model config is required")
+        model_config = cast(dict, args.get("model_config"))
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            model_type=ModelType.LLM,
+            provider=model_config.get("provider", ""),
+            model=model_config.get("name", ""),
+        )
+        workflow_generator = WorkflowGenerator(
+            model_instance=model_instance,
+        )
+        workflow_yaml = workflow_generator.generate_workflow(
+            user_requirement=instruction,
+        )
+        return workflow_yaml
+
+
 api.add_resource(RuleGenerateApi, "/rule-generate")
 api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
+api.add_resource(
+    AutoGenerateWorkflowApi,
+    "/auto-generate",
+)
diff --git a/api/core/auto/config/custom.yaml b/api/core/auto/config/custom.yaml
new file mode 100644
index 0000000000..66192a102f
--- /dev/null
+++ b/api/core/auto/config/custom.yaml
@@ -0,0 +1,27 @@
+# Custom configuration file
+workflow_generator:
+  # Model configuration used for generating workflows
+  models:
+    default: my-gpt-4o-mini  # Model used by default
+    available:  # List of available models
+      my-gpt-4o-mini:
+        model_name: gpt-4o-mini
+        base_url: https://api.pandalla.ai/v1
+        key_path: ./openai_key
+        max_tokens: 4096
+      my-gpt-4o:
+        model_name: gpt-4o
+        base_url: https://api.pandalla.ai/v1
+        key_path: ./openai_key
+        max_tokens: 4096
+
+# Debug configuration
+debug:
+  enabled: false  # Debug mode is off by default; enable it with the --debug command-line flag
+  dir: debug/  # Directory where debug information is saved
+  save_options:  # What to save when debugging
+    prompt: true  # Save prompts
+    response: true  # Save LLM responses
+    json: true  # Save the JSON parsing process
+    workflow: true  # Save the workflow generation process
+  case_id_format: "%Y%m%d_%H%M%S_%f"  # Run ID format, using datetime.strftime syntax
diff --git a/api/core/auto/config/default.yaml b/api/core/auto/config/default.yaml
new file mode 100644
index 0000000000..eb4dccb40b
--- /dev/null
+++ b/api/core/auto/config/default.yaml
@@ -0,0 +1,33 @@
+# Default configuration file
+
+# Workflow generator configuration
+workflow_generator:
+  # Model configuration used for generating workflows
+  models:
+    default: gpt-4  # Model used by default
available: # 可用的模型列表 + gpt-4: + model_name: gpt-4 + base_url: https://api.openai.com/v1 + key_path: ./openai_key + max_tokens: 8192 + gpt-4-turbo: + model_name: gpt-4-1106-preview + base_url: https://api.openai.com/v1 + key_path: ./openai_key + max_tokens: 4096 + +# 工作流节点配置 +workflow_nodes: + # LLM节点默认配置(使用 Dify 平台配置的模型) + llm: + provider: zhipuai + model: glm-4-flash + max_tokens: 16384 + temperature: 0.7 + mode: chat + +# 输出配置 +output: + dir: output/ + filename: generated_workflow.yml \ No newline at end of file diff --git a/api/core/auto/node_types/__init__.py b/api/core/auto/node_types/__init__.py new file mode 100644 index 0000000000..4db06e766d --- /dev/null +++ b/api/core/auto/node_types/__init__.py @@ -0,0 +1,78 @@ +from .agent import AgentNodeType +from .answer import AnswerNodeType +from .assigner import AssignerNodeType +from .code import CodeLanguage, CodeNodeType, OutputVar +from .common import ( + BlockEnum, + CommonEdgeType, + CommonNodeType, + CompleteEdge, + CompleteNode, + Context, + InputVar, + InputVarType, + Memory, + ModelConfig, + PromptItem, + PromptRole, + ValueSelector, + Variable, + VarType, + VisionSetting, +) +from .end import EndNodeType +from .http import HttpNodeType +from .if_else import IfElseNodeType +from .iteration import IterationNodeType +from .iteration_start import IterationStartNodeType +from .knowledge_retrieval import KnowledgeRetrievalNodeType +from .list_operator import ListFilterNodeType +from .llm import LLMNodeType, VisionConfig +from .note_node import NoteNodeType +from .parameter_extractor import ParameterExtractorNodeType +from .question_classifier import QuestionClassifierNodeType +from .start import StartNodeType +from .template_transform import TemplateTransformNodeType +from .tool import ToolNodeType +from .variable_assigner import VariableAssignerNodeType + +__all__ = [ + "AgentNodeType", + "AnswerNodeType", + "AssignerNodeType", + "BlockEnum", + "CodeLanguage", + "CodeNodeType", + "CommonEdgeType", + "CommonNodeType", + "CompleteEdge", + "CompleteNode", + "Context", + "EndNodeType", + "HttpNodeType", + "IfElseNodeType", + "InputVar", + "InputVarType", + "IterationNodeType", + "IterationStartNodeType", + "KnowledgeRetrievalNodeType", + "LLMNodeType", + "ListFilterNodeType", + "Memory", + "ModelConfig", + "NoteNodeType", + "OutputVar", + "ParameterExtractorNodeType", + "PromptItem", + "PromptRole", + "QuestionClassifierNodeType", + "StartNodeType", + "TemplateTransformNodeType", + "ToolNodeType", + "ValueSelector", + "VarType", + "Variable", + "VariableAssignerNodeType", + "VisionConfig", + "VisionSetting", +] diff --git a/api/core/auto/node_types/agent.py b/api/core/auto/node_types/agent.py new file mode 100644 index 0000000000..cba45096d8 --- /dev/null +++ b/api/core/auto/node_types/agent.py @@ -0,0 +1,34 @@ +from typing import Any, Optional + +from pydantic import BaseModel + +from .common import BlockEnum, CommonNodeType + +# Introduce previously defined CommonNodeType and ToolVarInputs +# Assume they are defined in the same module + + +class ToolVarInputs(BaseModel): + variable_name: Optional[str] = None + default_value: Optional[Any] = None + + +class AgentNodeType(CommonNodeType): + agent_strategy_provider_name: Optional[str] = None + agent_strategy_name: Optional[str] = None + agent_strategy_label: Optional[str] = None + agent_parameters: Optional[ToolVarInputs] = None + output_schema: dict[str, Any] + plugin_unique_identifier: Optional[str] = None + + +# 示例用法 +if __name__ == "__main__": + example_node = AgentNodeType( + 
title="Example Agent", + desc="An agent node example", + type=BlockEnum.agent, + output_schema={"key": "value"}, + agent_parameters=ToolVarInputs(variable_name="example_var", default_value="default"), + ) + print(example_node) diff --git a/api/core/auto/node_types/answer.py b/api/core/auto/node_types/answer.py new file mode 100644 index 0000000000..c4e9600ad1 --- /dev/null +++ b/api/core/auto/node_types/answer.py @@ -0,0 +1,21 @@ +from .common import BlockEnum, CommonNodeType, Variable + + +class AnswerNodeType(CommonNodeType): + variables: list[Variable] + answer: str + + +# Example usage +if __name__ == "__main__": + example_node = AnswerNodeType( + title="Example Answer Node", + desc="An answer node example", + type=BlockEnum.answer, + answer="This is the answer", + variables=[ + Variable(variable="var1", value_selector=["node1", "key1"]), + Variable(variable="var2", value_selector=["node2", "key2"]), + ], + ) + print(example_node) diff --git a/api/core/auto/node_types/assigner.py b/api/core/auto/node_types/assigner.py new file mode 100644 index 0000000000..f49ea45055 --- /dev/null +++ b/api/core/auto/node_types/assigner.py @@ -0,0 +1,62 @@ +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from .common import BlockEnum, CommonNodeType + +# Import previously defined CommonNodeType and ValueSelector +# Assume they are defined in the same module + + +class WriteMode(str, Enum): + overwrite = "over-write" + clear = "clear" + append = "append" + extend = "extend" + set = "set" + increment = "+=" + decrement = "-=" + multiply = "*=" + divide = "/=" + + +class AssignerNodeInputType(str, Enum): + variable = "variable" + constant = "constant" + + +class AssignerNodeOperation(BaseModel): + variable_selector: Any # Placeholder for ValueSelector type + input_type: AssignerNodeInputType + operation: WriteMode + value: Any + + +class AssignerNodeType(CommonNodeType): + version: Optional[str] = Field(None, pattern="^[12]$") # Version is '1' or '2' + items: list[AssignerNodeOperation] + + +# Example usage +if __name__ == "__main__": + example_node = AssignerNodeType( + title="Example Assigner Node", + desc="An assigner node example", + type=BlockEnum.variable_assigner, + items=[ + AssignerNodeOperation( + variable_selector={"nodeId": "node1", "key": "value"}, # Example ValueSelector + input_type=AssignerNodeInputType.variable, + operation=WriteMode.set, + value="newValue", + ), + AssignerNodeOperation( + variable_selector={"nodeId": "node2", "key": "value"}, + input_type=AssignerNodeInputType.constant, + operation=WriteMode.increment, + value=1, + ), + ], + ) + print(example_node) diff --git a/api/core/auto/node_types/code.py b/api/core/auto/node_types/code.py new file mode 100644 index 0000000000..3f057a7796 --- /dev/null +++ b/api/core/auto/node_types/code.py @@ -0,0 +1,56 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from core.auto.node_types.common import BlockEnum, CommonNodeType, Variable, VarType + +# 引入之前定义的 CommonNodeType、VarType 和 Variable +# 假设它们在同一模块中定义 + + +class CodeLanguage(str, Enum): + python3 = "python3" + javascript = "javascript" + json = "json" + + +class OutputVar(BaseModel): + type: VarType + children: Optional[None] = None # 未来支持嵌套 + + def dict(self, *args, **kwargs): + """自定义序列化方法,确保正确序列化""" + result = {"type": self.type.value if isinstance(self.type, Enum) else self.type} + + if self.children is not None: + result["children"] = self.children + + return result + + +class 
CodeNodeType(CommonNodeType):
+    variables: list[Variable]
+    code_language: CodeLanguage
+    code: str
+    outputs: dict[str, OutputVar]
+
+
+# Example usage
+if __name__ == "__main__":
+    # Create an example node
+    example_node = CodeNodeType(
+        title="Example Code Node",
+        desc="A code node example",
+        type=BlockEnum.code,
+        code_language=CodeLanguage.python3,
+        code="print('Hello, World!')",
+        outputs={
+            "output1": OutputVar(type=VarType.string),
+            "output2": OutputVar(type=VarType.number),
+        },
+        variables=[
+            Variable(variable="var1", value_selector=["node1", "key1"]),
+        ],
+    )
+    print(example_node.get_all_required_fields())
diff --git a/api/core/auto/node_types/common.py b/api/core/auto/node_types/common.py
new file mode 100644
index 0000000000..6943d63b25
--- /dev/null
+++ b/api/core/auto/node_types/common.py
@@ -0,0 +1,691 @@
+from enum import Enum
+from typing import Any, Optional, Union
+
+import yaml
+from pydantic import BaseModel, Field
+
+
+# BlockEnum enum
+class BlockEnum(str, Enum):
+    start = "start"
+    end = "end"
+    answer = "answer"
+    llm = "llm"
+    knowledge_retrieval = "knowledge-retrieval"
+    question_classifier = "question-classifier"
+    if_else = "if-else"
+    code = "code"
+    template_transform = "template-transform"
+    http_request = "http-request"
+    variable_assigner = "variable-assigner"
+    variable_aggregator = "variable-aggregator"
+    tool = "tool"
+    parameter_extractor = "parameter-extractor"
+    iteration = "iteration"
+    document_extractor = "document-extractor"
+    list_operator = "list-operator"
+    iteration_start = "iteration-start"
+    assigner = "assigner"  # is now named VariableAssigner
+    agent = "agent"
+    custom_note = "custom-note"  # assumed value; referenced by NoteNodeType's example below
+
+
+# Error enums
+class ErrorHandleMode(str, Enum):
+    terminated = "terminated"
+    continue_on_error = "continue-on-error"
+    remove_abnormal_output = "remove-abnormal-output"
+
+
+class ErrorHandleTypeEnum(str, Enum):
+    none = "none"
+    failBranch = "fail-branch"
+    defaultValue = "default-value"
+
+
+# Branch type
+class Branch(BaseModel):
+    id: str
+    name: str
+
+
+# NodeRunningStatus enum
+class NodeRunningStatus(str, Enum):
+    not_start = "not-start"
+    waiting = "waiting"
+    running = "running"
+    succeeded = "succeeded"
+    failed = "failed"
+    exception = "exception"
+    retry = "retry"
+
+
+# Base class unifying the serialization logic of CommonNodeType and CommonEdgeType
+class BaseType(BaseModel):
+    """Base class that unifies the serialization logic of CommonNodeType and CommonEdgeType."""
+
+    def to_json(self) -> dict[str, Any]:
+        """
+        Convert the object to a JSON-style dict by iterating over the model's fields.
+        """
+        json_data = {}
+
+        # Iterate over all fields of the model
+        for field_name, field_value in self.__dict__.items():
+            if field_value is not None:
+                # Special-case lists of Branch objects
+                if field_name == "_targetBranches" and field_value is not None:
+                    json_data[field_name] = [branch.dict(exclude_none=True) for branch in field_value]
+                # Enum values
+                elif isinstance(field_value, Enum):
+                    json_data[field_name] = field_value.value
+                # Nested Pydantic models
+                elif hasattr(field_value, "dict") and callable(field_value.dict):
+                    json_data[field_name] = field_value.dict(exclude_none=True)
+                # Pydantic models inside lists
+                elif isinstance(field_value, list):
+                    processed_list = []
+                    for item in field_value:
+                        if hasattr(item, "dict") and callable(item.dict):
+                            processed_list.append(item.dict(exclude_none=True))
+                        else:
+                            processed_list.append(item)
+                    json_data[field_name] = processed_list
+                # Pydantic models inside dicts
+                elif isinstance(field_value, dict):
+                    processed_dict = {}
+                    for key, value in field_value.items():
+                        if hasattr(value, "dict") and callable(value.dict):
+                            processed_dict[key] = value.dict(exclude_none=True)
+                        else:
+                            processed_dict[key] = value
+                    json_data[field_name]
= processed_dict + # 其他字段直接添加 + else: + json_data[field_name] = field_value + + return json_data + + +# CommonNodeType 类型 +class CommonNodeType(BaseType): + _connectedSourceHandleIds: Optional[list[str]] = None + _connectedTargetHandleIds: Optional[list[str]] = None + _targetBranches: Optional[list[Branch]] = None + _isSingleRun: Optional[bool] = None + _runningStatus: Optional[NodeRunningStatus] = None + _singleRunningStatus: Optional[NodeRunningStatus] = None + _isCandidate: Optional[bool] = None + _isBundled: Optional[bool] = None + _children: Optional[list[str]] = None + _isEntering: Optional[bool] = None + _showAddVariablePopup: Optional[bool] = None + _holdAddVariablePopup: Optional[bool] = None + _iterationLength: Optional[int] = None + _iterationIndex: Optional[int] = None + _inParallelHovering: Optional[bool] = None + isInIteration: Optional[bool] = None + iteration_id: Optional[str] = None + selected: Optional[bool] = None + title: str + desc: str + type: BlockEnum + width: Optional[float] = None + height: Optional[float] = None + + @classmethod + def get_all_required_fields(cls) -> dict[str, str]: + """ + 获取所有必选字段,包括从父类继承的字段 + 这是一个类方法,可以通过类直接调用 + """ + all_required_fields = {} + + # 获取所有父类(除了 object 和 BaseModel) + mro = [c for c in cls.__mro__ if c not in (object, BaseModel, BaseType)] + + # 从父类到子类的顺序处理,这样子类的字段会覆盖父类的同名字段 + for class_type in reversed(mro): + if hasattr(class_type, "__annotations__"): + for field_name, field_info in class_type.__annotations__.items(): + # 检查字段是否有默认值 + has_default = hasattr(class_type, field_name) + # 检查字段是否为可选类型 + is_optional = "Optional" in str(field_info) + + # 如果字段没有默认值且不是Optional类型,则为必选字段 + if not has_default and not is_optional: + all_required_fields[field_name] = str(field_info) + + return all_required_fields + + +# CommonEdgeType 类型 +class CommonEdgeType(BaseType): + _hovering: Optional[bool] = None + _connectedNodeIsHovering: Optional[bool] = None + _connectedNodeIsSelected: Optional[bool] = None + _run: Optional[bool] = None + _isBundled: Optional[bool] = None + isInIteration: Optional[bool] = None + iteration_id: Optional[str] = None + sourceType: BlockEnum + targetType: BlockEnum + + +class ValueSelector(BaseModel): + """Value selector for selecting values from other nodes.""" + + value: list[str] = Field(default_factory=list) + + def dict(self, *args, **kwargs): + """自定义序列化方法,直接返回 value 列表""" + return self.value + + +# Add Context class for LLM node +class Context(BaseModel): + """Context configuration for LLM node.""" + + enabled: bool = False + variable_selector: Optional[ValueSelector] = None + + def dict(self, *args, **kwargs): + """自定义序列化方法,确保 variable_selector 字段正确序列化""" + result = {"enabled": self.enabled} + + if self.variable_selector: + result["variable_selector"] = self.variable_selector.dict() + else: + result["variable_selector"] = [] + + return result + + +# Variable 类型 +class Variable(BaseModel): + """ + 变量类型,用于定义节点的输入/输出变量 + 与Dify中的Variable类型保持一致 + """ + + variable: str # 变量名 + label: Optional[Union[str, dict[str, str]]] = None # 变量标签,可以是字符串或对象 + value_selector: list[str] # 变量值选择器,格式为[nodeId, key] + variable_type: Optional[str] = None # 变量类型,对应Dify中的VarType枚举 + value: Optional[str] = None # 变量值(常量值) + options: Optional[list[str]] = None # 选项列表(用于select类型) + required: Optional[bool] = None # 是否必填 + isParagraph: Optional[bool] = None # 是否为段落 + max_length: Optional[int] = None # 最大长度 + + def dict(self, *args, **kwargs): + """自定义序列化方法,确保正确序列化""" + result = {"variable": self.variable} + + if self.label is not None: + 
result["label"] = self.label + + if self.value_selector: + result["value_selector"] = self.value_selector + + if self.variable_type is not None: + result["type"] = self.variable_type # 使用type而不是variable_type,与Dify保持一致 + + if self.value is not None: + result["value"] = self.value + + if self.options is not None: + result["options"] = self.options + + if self.required is not None: + result["required"] = self.required + + if self.isParagraph is not None: + result["isParagraph"] = self.isParagraph + + if self.max_length is not None: + result["max_length"] = self.max_length + + return result + + +# EnvironmentVariable 类型 +class EnvironmentVariable(BaseModel): + id: str + name: str + value: Any + value_type: str # Expecting to be either 'string', 'number', or 'secret' + + +# ConversationVariable 类型 +class ConversationVariable(BaseModel): + id: str + name: str + value_type: str + value: Any + description: str + + +# GlobalVariable 类型 +class GlobalVariable(BaseModel): + name: str + value_type: str # Expecting to be either 'string' or 'number' + description: str + + +# VariableWithValue 类型 +class VariableWithValue(BaseModel): + key: str + value: str + + +# InputVarType 枚举 +class InputVarType(str, Enum): + text_input = "text-input" + paragraph = "paragraph" + select = "select" + number = "number" + url = "url" + files = "files" + json = "json" + contexts = "contexts" + iterator = "iterator" + file = "file" + file_list = "file-list" + + +# InputVar 类型 +class InputVar(BaseModel): + type: InputVarType + label: Union[str, dict[str, Any]] # 可以是字符串或对象 + variable: str + max_length: Optional[int] = None + default: Optional[str] = None + required: bool + hint: Optional[str] = None + options: Optional[list[str]] = None + value_selector: Optional[list[str]] = None + + def dict(self, *args, **kwargs): + """自定义序列化方法,确保正确序列化""" + result = { + "type": self.type.value if isinstance(self.type, Enum) else self.type, + "label": self.label, + "variable": self.variable, + "required": self.required, + } + + if self.max_length is not None: + result["max_length"] = self.max_length + + if self.default is not None: + result["default"] = self.default + + if self.hint is not None: + result["hint"] = self.hint + + if self.options is not None: + result["options"] = self.options + + if self.value_selector is not None: + result["value_selector"] = self.value_selector + + return result + + +# ModelConfig 类型 +class ModelConfig(BaseModel): + provider: str + name: str + mode: str + completion_params: dict[str, Any] + + +# PromptRole 枚举 +class PromptRole(str, Enum): + system = "system" + user = "user" + assistant = "assistant" + + +# EditionType 枚举 +class EditionType(str, Enum): + basic = "basic" + jinja2 = "jinja2" + + +# PromptItem 类型 +class PromptItem(BaseModel): + id: Optional[str] = None + role: Optional[PromptRole] = None + text: str + edition_type: Optional[EditionType] = None + jinja2_text: Optional[str] = None + + def dict(self, *args, **kwargs): + """自定义序列化方法,确保 role 字段正确序列化""" + result = {"id": self.id, "text": self.text} + + if self.role: + result["role"] = self.role.value + + if self.edition_type: + result["edition_type"] = self.edition_type.value + + if self.jinja2_text: + result["jinja2_text"] = self.jinja2_text + + return result + + +# MemoryRole 枚举 +class MemoryRole(str, Enum): + user = "user" + assistant = "assistant" + + +# RolePrefix 类型 +class RolePrefix(BaseModel): + user: str + assistant: str + + +# Memory 类型 +class Memory(BaseModel): + role_prefix: Optional[RolePrefix] = None + window: dict[str, Any] # Expecting 
to have 'enabled' and 'size' + query_prompt_template: str + + +# VarType 枚举 +class VarType(str, Enum): + string = "string" + number = "number" + secret = "secret" + boolean = "boolean" + object = "object" + file = "file" + array = "array" + arrayString = "array[string]" + arrayNumber = "array[number]" + arrayObject = "array[object]" + arrayFile = "array[file]" + any = "any" + + +# Var 类型 +class Var(BaseModel): + variable: str + type: VarType + children: Optional[list["Var"]] = None # Self-reference + isParagraph: Optional[bool] = None + isSelect: Optional[bool] = None + options: Optional[list[str]] = None + required: Optional[bool] = None + des: Optional[str] = None + isException: Optional[bool] = None + + def dict(self, *args, **kwargs): + """自定义序列化方法,确保type字段正确序列化""" + result = {"variable": self.variable, "type": self.type.value if isinstance(self.type, Enum) else self.type} + + if self.children is not None: + result["children"] = [child.dict() for child in self.children] + + if self.isParagraph is not None: + result["isParagraph"] = self.isParagraph + + if self.isSelect is not None: + result["isSelect"] = self.isSelect + + if self.options is not None: + result["options"] = self.options + + if self.required is not None: + result["required"] = self.required + + if self.des is not None: + result["des"] = self.des + + if self.isException is not None: + result["isException"] = self.isException + + return result + + +# NodeOutPutVar 类型 +class NodeOutPutVar(BaseModel): + nodeId: str + title: str + vars: list[Var] + isStartNode: Optional[bool] = None + + +# Block 类型 +class Block(BaseModel): + classification: Optional[str] = None + type: BlockEnum + title: str + description: Optional[str] = None + + +# NodeDefault 类型 +class NodeDefault(BaseModel): + defaultValue: dict[str, Any] + getAvailablePrevNodes: Any # Placeholder for function reference + getAvailableNextNodes: Any # Placeholder for function reference + checkValid: Any # Placeholder for function reference + + +# OnSelectBlock 类型 +class OnSelectBlock(BaseModel): + nodeType: BlockEnum + additional_data: Optional[dict[str, Any]] = None + + +# WorkflowRunningStatus 枚举 +class WorkflowRunningStatus(str, Enum): + waiting = "waiting" + running = "running" + succeeded = "succeeded" + failed = "failed" + stopped = "stopped" + + +# WorkflowVersion 枚举 +class WorkflowVersion(str, Enum): + draft = "draft" + latest = "latest" + + +# OnNodeAdd 类型 +class OnNodeAdd(BaseModel): + nodeType: BlockEnum + sourceHandle: Optional[str] = None + targetHandle: Optional[str] = None + toolDefaultValue: Optional[dict[str, Any]] = None + + +# CheckValidRes 类型 +class CheckValidRes(BaseModel): + isValid: bool + errorMessage: Optional[str] = None + + +# RunFile 类型 +class RunFile(BaseModel): + type: str + transfer_method: list[str] + url: Optional[str] = None + upload_file_id: Optional[str] = None + + +# WorkflowRunningData 类型 +class WorkflowRunningData(BaseModel): + task_id: Optional[str] = None + message_id: Optional[str] = None + conversation_id: Optional[str] = None + result: dict[str, Any] # Expecting a structured object + tracing: Optional[list[dict[str, Any]]] = None # Placeholder for NodeTracing + + +# HistoryWorkflowData 类型 +class HistoryWorkflowData(BaseModel): + id: str + sequence_number: int + status: str + conversation_id: Optional[str] = None + + +# ChangeType 枚举 +class ChangeType(str, Enum): + changeVarName = "changeVarName" + remove = "remove" + + +# MoreInfo 类型 +class MoreInfo(BaseModel): + type: ChangeType + payload: Optional[dict[str, Any]] = None + + +# 
ToolWithProvider 类型 +class ToolWithProvider(BaseModel): + tools: list[dict[str, Any]] # Placeholder for Tool type + + +# SupportUploadFileTypes 枚举 +class SupportUploadFileTypes(str, Enum): + image = "image" + document = "document" + audio = "audio" + video = "video" + custom = "custom" + + +# UploadFileSetting 类型 +class UploadFileSetting(BaseModel): + allowed_file_upload_methods: list[str] + allowed_file_types: list[SupportUploadFileTypes] + allowed_file_extensions: Optional[list[str]] = None + max_length: int + number_limits: Optional[int] = None + + +# VisionSetting 类型 +class VisionSetting(BaseModel): + variable_selector: list[str] + detail: dict[str, Any] # Placeholder for Resolution type + + +# 创建一个基类来统一序列化逻辑 +class CompleteBase(BaseModel): + """基类,用于统一CompleteNode和CompleteEdge的序列化逻辑""" + + def to_json(self): + """将对象转换为JSON格式的字典""" + json_data = {} + + # 获取模型的所有字段 + for field_name, field_value in self.__dict__.items(): + if field_value is not None: + # 处理嵌套的数据对象 + if field_name == "data" and hasattr(field_value, "to_json"): + json_data[field_name] = field_value.to_json() + # 处理枚举类型 + elif isinstance(field_value, Enum): + json_data[field_name] = field_value.value + # 处理嵌套的Pydantic模型 + elif hasattr(field_value, "dict") and callable(field_value.dict): + json_data[field_name] = field_value.dict(exclude_none=True) + # 处理列表中的Pydantic模型 + elif isinstance(field_value, list): + processed_list = [] + for item in field_value: + if hasattr(item, "dict") and callable(item.dict): + processed_list.append(item.dict(exclude_none=True)) + else: + processed_list.append(item) + json_data[field_name] = processed_list + # 处理字典中的Pydantic模型 + elif isinstance(field_value, dict): + processed_dict = {} + for key, value in field_value.items(): + if hasattr(value, "dict") and callable(value.dict): + processed_dict[key] = value.dict(exclude_none=True) + else: + processed_dict[key] = value + json_data[field_name] = processed_dict + # 其他字段直接添加 + else: + json_data[field_name] = field_value + + return json_data + + def to_yaml(self): + """将对象转换为YAML格式的字符串""" + return yaml.dump(self.to_json(), allow_unicode=True) + + +class CompleteNode(CompleteBase): + id: str + position: dict + height: int + width: float + positionAbsolute: dict + selected: bool + sourcePosition: Union[dict, str] + targetPosition: Union[dict, str] + type: str + data: Optional[Union[CommonNodeType, None]] = None # Flexible field to store CommonNodeType or None + + def add_data(self, data: Union[CommonNodeType, None]): + self.data = data + + def to_json(self): + json_data = super().to_json() + + # 特殊处理sourcePosition和targetPosition + json_data["sourcePosition"] = "right" # 直接输出为字符串"right" + json_data["targetPosition"] = "left" # 直接输出为字符串"left" + + # 确保 width 是整数而不是浮点数 + if isinstance(json_data["width"], float): + json_data["width"] = int(json_data["width"]) + + return json_data + + +class CompleteEdge(CompleteBase): + id: str + source: str + sourceHandle: str + target: str + targetHandle: str + type: str + zIndex: int + data: Optional[Union[CommonEdgeType, None]] = None # Flexible field to store CommonEdgeType or None + + def add_data(self, data: Union[CommonEdgeType, None]): + self.data = data + + +# 示例用法 +if __name__ == "__main__": + # 这里可以添加示例数据进行验证 + common_node = CompleteNode( + id="1740019130520", + position={"x": 80, "y": 282}, + height=100, + width=100, + positionAbsolute={"x": 80, "y": 282}, + selected=True, + sourcePosition={"x": 80, "y": 282}, + targetPosition={"x": 80, "y": 282}, + type="custom", + ) + common_data = 
CommonNodeType(title="Example Node", desc="This is an example node", type=BlockEnum.start)
+    print(CommonNodeType.get_all_required_fields())
+    common_node.add_data(common_data)
+    # print(common_node)
diff --git a/api/core/auto/node_types/end.py b/api/core/auto/node_types/end.py
new file mode 100644
index 0000000000..12401dc402
--- /dev/null
+++ b/api/core/auto/node_types/end.py
@@ -0,0 +1,22 @@
+from .common import BlockEnum, CommonNodeType, Variable
+
+# Import previously defined CommonNodeType and Variable
+# Assume they are defined in the same module
+
+
+class EndNodeType(CommonNodeType):
+    outputs: list[Variable]
+
+
+# Example usage
+if __name__ == "__main__":
+    example_node = EndNodeType(
+        title="Example End Node",
+        desc="An end node example",
+        type=BlockEnum.end,
+        outputs=[
+            Variable(variable="outputVar1", value_selector=["node1", "key1"]),
+            Variable(variable="outputVar2", value_selector=["node2", "key2"]),
+        ],
+    )
+    print(example_node)
diff --git a/api/core/auto/node_types/http.py b/api/core/auto/node_types/http.py
new file mode 100644
index 0000000000..f49e95b60e
--- /dev/null
+++ b/api/core/auto/node_types/http.py
@@ -0,0 +1,127 @@
+from enum import Enum
+from typing import Optional, Union
+
+from pydantic import BaseModel
+
+from .common import BlockEnum, CommonNodeType, ValueSelector, Variable
+
+# Import previously defined CommonNodeType, ValueSelector, and Variable
+# Assume they are defined in the same module
+
+
+class Method(str, Enum):
+    """HTTP request methods."""
+
+    get = "get"
+    post = "post"
+    head = "head"
+    patch = "patch"
+    put = "put"
+    delete = "delete"
+
+
+class BodyType(str, Enum):
+    """HTTP request body types."""
+
+    none = "none"
+    formData = "form-data"
+    xWwwFormUrlencoded = "x-www-form-urlencoded"
+    rawText = "raw-text"
+    json = "json"
+    binary = "binary"
+
+
+class BodyPayloadValueType(str, Enum):
+    """Types of values in body payload."""
+
+    text = "text"
+    file = "file"
+
+
+class BodyPayload(BaseModel):
+    """Body payload item for HTTP requests."""
+
+    id: Optional[str] = None
+    key: Optional[str] = None
+    type: BodyPayloadValueType
+    file: Optional[ValueSelector] = None  # Used when type is file
+    value: Optional[str] = None  # Used when type is text
+
+
+class Body(BaseModel):
+    """HTTP request body configuration."""
+
+    type: BodyType
+    data: Union[str, list[BodyPayload]]  # string is deprecated, will convert to BodyPayload
+
+
+class AuthorizationType(str, Enum):
+    """HTTP authorization types."""
+
+    none = "no-auth"
+    apiKey = "api-key"
+
+
+class APIType(str, Enum):
+    """API key types."""
+
+    basic = "basic"
+    bearer = "bearer"
+    custom = "custom"
+
+
+class AuthConfig(BaseModel):
+    """Authorization configuration."""
+
+    type: APIType
+    api_key: str
+    header: Optional[str] = None
+
+
+class Authorization(BaseModel):
+    """HTTP authorization settings."""
+
+    type: AuthorizationType
+    config: Optional[AuthConfig] = None
+
+
+class Timeout(BaseModel):
+    """HTTP request timeout settings."""
+
+    connect: Optional[int] = None
+    read: Optional[int] = None
+    write: Optional[int] = None
+    max_connect_timeout: Optional[int] = None
+    max_read_timeout: Optional[int] = None
+    max_write_timeout: Optional[int] = None
+
+
+class HttpNodeType(CommonNodeType):
+    """HTTP request node type implementation."""
+
+    variables: list[Variable]
+    method: Method
+    url: str
+    headers: str
+    params: str
+    body: Body
+    authorization: Authorization
+    timeout: Timeout
+
+
+# Example usage
+if __name__ == "__main__":
+    example_node = HttpNodeType(
+        title="Example HTTP Node",
+        desc="An HTTP request node example",
+        type=BlockEnum.http_request,
+        variables=[Variable(variable="var1", value_selector=["node1", "key1"])],
+        method=Method.get,
+        url="https://api.example.com/data",
+        headers="{}",
+        params="{}",
+        body=Body(type=BodyType.none, data=[]),
+        authorization=Authorization(type=AuthorizationType.none),
+        timeout=Timeout(connect=30, read=30, write=30),
+    )
+    print(example_node)
diff --git a/api/core/auto/node_types/if_else.py b/api/core/auto/node_types/if_else.py
new file mode 100644
index 0000000000..0bc5df084d
--- /dev/null
+++ b/api/core/auto/node_types/if_else.py
@@ -0,0 +1,99 @@
+from enum import Enum
+from typing import Optional, Union
+
+from pydantic import BaseModel
+
+from .common import BlockEnum, CommonNodeType, ValueSelector, VarType
+from .tool import VarType as NumberVarType
+
+# Import previously defined CommonNodeType, ValueSelector, and VarType
+# Assume they are defined in the same module
+
+
+class LogicalOperator(str, Enum):
+    and_ = "and"
+    or_ = "or"
+
+
+class ComparisonOperator(str, Enum):
+    contains = "contains"
+    notContains = "not contains"
+    startWith = "start with"
+    endWith = "end with"
+    is_ = "is"
+    isNot = "is not"
+    empty = "empty"
+    notEmpty = "not empty"
+    equal = "="
+    notEqual = "≠"
+    largerThan = ">"
+    lessThan = "<"
+    largerThanOrEqual = "≥"
+    lessThanOrEqual = "≤"
+    isNull = "is null"
+    isNotNull = "is not null"
+    in_ = "in"
+    notIn = "not in"
+    allOf = "all of"
+    exists = "exists"
+    notExists = "not exists"
+    equals = "="  # Alias of equal, kept for compatibility
+
+
+class Condition(BaseModel):
+    id: str
+    varType: VarType
+    variable_selector: Optional[ValueSelector] = None
+    key: Optional[str] = None  # Sub variable key
+    comparison_operator: Optional[ComparisonOperator] = None
+    value: Union[str, list[str]]
+    numberVarType: Optional[NumberVarType] = None
+    sub_variable_condition: Optional["CaseItem"] = None  # Recursive reference
+
+
+class CaseItem(BaseModel):
+    case_id: str
+    logical_operator: LogicalOperator
+    conditions: list[Condition]
+
+
+class IfElseNodeType(CommonNodeType):
+    logical_operator: Optional[LogicalOperator] = None
+    conditions: Optional[list[Condition]] = None
+    cases: list[CaseItem]
+    isInIteration: bool
+
+
+# Example usage
+if __name__ == "__main__":
+    example_node = IfElseNodeType(
+        title="Example IfElse Node",
+        desc="An if-else node example",
+        type=BlockEnum.if_else,
+        logical_operator=LogicalOperator.and_,
+        conditions=[
+            Condition(
+                id="condition1",
+                varType=VarType.string,
+                variable_selector=ValueSelector(value=["varNode", "value"]),
+                comparison_operator=ComparisonOperator.is_,
+                value="exampleValue",
+            )
+        ],
+        cases=[
+            CaseItem(
+                case_id="case1",
+                logical_operator=LogicalOperator.or_,
+                conditions=[
+                    Condition(
+                        id="condition2",
+                        varType=VarType.number,
+                        value="10",
+                        comparison_operator=ComparisonOperator.largerThan,
+                    )
+                ],
+            )
+        ],
+        isInIteration=True,
+    )
+    print(example_node)
diff --git a/api/core/auto/node_types/iteration.py b/api/core/auto/node_types/iteration.py
new file mode 100644
index 0000000000..617125653c
--- /dev/null
+++ b/api/core/auto/node_types/iteration.py
@@ -0,0 +1,45 @@
+from enum import Enum
+from typing import Optional
+
+from .common import BlockEnum, CommonNodeType, ValueSelector, VarType
+
+
+class ErrorHandleMode(str, Enum):
+    """Error handling modes for iteration."""
+
+    terminated = "terminated"
+    continue_on_error = "continue-on-error"
+    remove_abnormal_output = "remove-abnormal-output"
+
+
+class IterationNodeType(CommonNodeType):
+    """Iteration node type implementation."""
+
+    startNodeType: Optional[BlockEnum] = None
+    start_node_id: str  # Start node ID in the iteration
+    iteration_id: Optional[str] = None
+    iterator_selector: ValueSelector
+    output_selector: ValueSelector
+    output_type: VarType  # Output type
+    is_parallel: bool  # Whether parallel mode is enabled
+    parallel_nums: int  # Number of parallel runs
+    error_handle_mode: ErrorHandleMode  # How to handle errors during the iteration
+    _isShowTips: bool = False  # Show tips when an answer node is inside a parallel-mode iteration
+
+
+# Example usage
+if __name__ == "__main__":
+    example_node = IterationNodeType(
+        title="Example Iteration Node",
+        desc="An iteration node example",
+        type=BlockEnum.iteration,
+        start_node_id="startNode1",
+        iterator_selector=ValueSelector(value=["iteratorNode", "value"]),
+        output_selector=ValueSelector(value=["outputNode", "value"]),
+        output_type=VarType.string,
+        is_parallel=True,
+        parallel_nums=5,
+        error_handle_mode=ErrorHandleMode.continue_on_error,
+        # _isShowTips is a private attribute and is not set via the constructor
+    )
+    print(example_node)
diff --git a/api/core/auto/node_types/iteration_start.py b/api/core/auto/node_types/iteration_start.py
new file mode 100644
index 0000000000..b87ca7cb80
--- /dev/null
+++ b/api/core/auto/node_types/iteration_start.py
@@ -0,0 +1,25 @@
+from .common import BlockEnum, CommonNodeType
+
+# Import the previously defined CommonNodeType
+# Assume it is defined in the same module
+
+
+class IterationStartNodeType(CommonNodeType):
+    """
+    Iteration Start node type implementation.
+
+    This node type is used as the starting point within an iteration block.
+    It inherits all properties from CommonNodeType without adding any additional fields.
+    """
+
+    pass  # Inherits CommonNodeType only; no extra fields
+
+
+# Example usage
+if __name__ == "__main__":
+    example_node = IterationStartNodeType(
+        title="Example Iteration Start Node",
+        desc="An iteration start node example",
+        type=BlockEnum.iteration_start,
+    )
+    print(example_node)
diff --git a/api/core/auto/node_types/knowledge_retrieval.py b/api/core/auto/node_types/knowledge_retrieval.py
new file mode 100644
index 0000000000..95997e60d6
--- /dev/null
+++ b/api/core/auto/node_types/knowledge_retrieval.py
@@ -0,0 +1,115 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from .common import BlockEnum, CommonNodeType, ModelConfig, ValueSelector
+
+
+class RetrieveType(str, Enum):
+    """Retrieval mode types."""
+
+    single = "single"
+    multiple = "multiple"
+
+
+class RerankingModeEnum(str, Enum):
+    """Reranking mode types."""
+
+    simple = "simple"
+    advanced = "advanced"
+
+
+class VectorSetting(BaseModel):
+    """Vector weight settings."""
+
+    vector_weight: float
+    embedding_provider_name: str
+    embedding_model_name: str
+
+
+class KeywordSetting(BaseModel):
+    """Keyword weight settings."""
+
+    keyword_weight: float
+
+
+class Weights(BaseModel):
+    """Weight configuration for retrieval."""
+
+    vector_setting: VectorSetting
+    keyword_setting: KeywordSetting
+
+
+class RerankingModel(BaseModel):
+    """Reranking model configuration."""
+
+    provider: str
+    model: str
+
+
+class MultipleRetrievalConfig(BaseModel):
+    """Configuration for multiple retrieval mode."""
+
+    top_k: int
+    score_threshold: Optional[float] = None
+    reranking_model: Optional[RerankingModel] = None
+    reranking_mode: Optional[RerankingModeEnum] = None
+    weights: Optional[Weights] = None
+    reranking_enable: Optional[bool] = None
+
+
+class SingleRetrievalConfig(BaseModel):
+    """Configuration for single retrieval mode."""
+
+    model: ModelConfig
+
+
+class DataSet(BaseModel):
+    """Dataset information."""
+
+    id: str
+    name: str
+    description:
Optional[str] = None + + +class KnowledgeRetrievalNodeType(CommonNodeType): + """Knowledge retrieval node type implementation.""" + + query_variable_selector: ValueSelector + dataset_ids: list[str] + retrieval_mode: RetrieveType + multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None + single_retrieval_config: Optional[SingleRetrievalConfig] = None + _datasets: Optional[list[DataSet]] = None + + +# Example usage +if __name__ == "__main__": + example_node = KnowledgeRetrievalNodeType( + title="Example Knowledge Retrieval Node", + desc="A knowledge retrieval node example", + type=BlockEnum.knowledge_retrieval, + query_variable_selector=ValueSelector(value=["queryNode", "query"]), + dataset_ids=["dataset1", "dataset2"], + retrieval_mode=RetrieveType.multiple, + multiple_retrieval_config=MultipleRetrievalConfig( + top_k=10, + score_threshold=0.5, + reranking_model=RerankingModel(provider="example_provider", model="example_model"), + reranking_mode=RerankingModeEnum.simple, + weights=Weights( + vector_setting=VectorSetting( + vector_weight=0.7, embedding_provider_name="provider1", embedding_model_name="model1" + ), + keyword_setting=KeywordSetting(keyword_weight=0.3), + ), + reranking_enable=True, + ), + single_retrieval_config=SingleRetrievalConfig( + model=ModelConfig( + provider="example_provider", name="example_model", mode="chat", completion_params={"temperature": 0.7} + ) + ), + ) + print(example_node) diff --git a/api/core/auto/node_types/list_operator.py b/api/core/auto/node_types/list_operator.py new file mode 100644 index 0000000000..02b70fda69 --- /dev/null +++ b/api/core/auto/node_types/list_operator.py @@ -0,0 +1,73 @@ +from enum import Enum +from typing import Optional, Union + +from pydantic import BaseModel + +from .common import BlockEnum, CommonNodeType, ValueSelector, VarType + +# Import ComparisonOperator from if_else.py +from .if_else import ComparisonOperator + + +class OrderBy(str, Enum): + ASC = "asc" + DESC = "desc" + + +class Limit(BaseModel): + enabled: bool + size: Optional[int] = None + + +class Condition(BaseModel): + key: str + comparison_operator: ComparisonOperator + value: Union[str, int, list[str]] + + +class FilterBy(BaseModel): + enabled: bool + conditions: list[Condition] + + +class ExtractBy(BaseModel): + enabled: bool + serial: Optional[str] = None + + +class OrderByConfig(BaseModel): + enabled: bool + key: Union[ValueSelector, str] + value: OrderBy + + +class ListFilterNodeType(CommonNodeType): + """List filter/operator node type implementation.""" + + variable: ValueSelector + var_type: VarType + item_var_type: VarType + filter_by: FilterBy + extract_by: ExtractBy + order_by: OrderByConfig + limit: Limit + + +# 示例用法 +if __name__ == "__main__": + example_node = ListFilterNodeType( + title="Example List Filter Node", + desc="A list filter node example", + type=BlockEnum.list_operator, # Fixed: use list_operator instead of list_filter + variable=ValueSelector(value=["varNode", "value"]), + var_type=VarType.string, + item_var_type=VarType.number, + filter_by=FilterBy( + enabled=True, + conditions=[Condition(key="status", comparison_operator=ComparisonOperator.equals, value="active")], + ), + extract_by=ExtractBy(enabled=True, serial="serial_1"), + order_by=OrderByConfig(enabled=True, key="created_at", value=OrderBy.DESC), + limit=Limit(enabled=True, size=100), + ) + print(example_node) diff --git a/api/core/auto/node_types/llm.py b/api/core/auto/node_types/llm.py new file mode 100644 index 0000000000..699356f20a --- /dev/null +++ 
b/api/core/auto/node_types/llm.py @@ -0,0 +1,66 @@ +from typing import Optional, Union + +from pydantic import BaseModel + +from .common import ( + BlockEnum, + CommonNodeType, + Context, + Memory, + ModelConfig, + PromptItem, + Variable, + VisionSetting, +) + + +class PromptConfig(BaseModel): + """Configuration for prompt template variables.""" + + jinja2_variables: Optional[list[Variable]] = None + + +class VisionConfig(BaseModel): + """Configuration for vision settings.""" + + enabled: bool = False + configs: Optional[VisionSetting] = None + + def dict(self, *args, **kwargs): + """自定义序列化方法,确保正确序列化""" + result = {"enabled": self.enabled} + + if self.configs: + result["configs"] = self.configs.dict() + + return result + + +class LLMNodeType(CommonNodeType): + """LLM node type implementation.""" + + model: ModelConfig + prompt_template: Union[list[PromptItem], PromptItem] + prompt_config: Optional[PromptConfig] = None + memory: Optional[Memory] = None + context: Optional[Context] = Context(enabled=False, variable_selector=None) + vision: Optional[VisionConfig] = VisionConfig(enabled=False) + + +# 示例用法 +if __name__ == "__main__": + example_node = LLMNodeType( + title="Example LLM Node", + desc="A LLM node example", + type=BlockEnum.llm, + model=ModelConfig(provider="zhipuai", name="glm-4-flash", mode="chat", completion_params={"temperature": 0.7}), + prompt_template=[ + PromptItem( + id="system-id", role="system", text="你是一个代码工程师,你会根据用户的需求给出用户所需要的函数" + ), + PromptItem(id="user-id", role="user", text="给出两数相加的python 函数代码,函数名 func 不要添加其他内容"), + ], + context=Context(enabled=False, variable_selector=None), + vision=VisionConfig(enabled=False), + ) + print(example_node) diff --git a/api/core/auto/node_types/note_node.py b/api/core/auto/node_types/note_node.py new file mode 100644 index 0000000000..4175fb5e40 --- /dev/null +++ b/api/core/auto/node_types/note_node.py @@ -0,0 +1,38 @@ +from enum import Enum + +from .common import BlockEnum, CommonNodeType + +# Import previously defined CommonNodeType +# Assume it is defined in the same module + + +class NoteTheme(str, Enum): + blue = "blue" + cyan = "cyan" + green = "green" + yellow = "yellow" + pink = "pink" + violet = "violet" + + +class NoteNodeType(CommonNodeType): + """Custom note node type implementation.""" + + text: str + theme: NoteTheme + author: str + showAuthor: bool + + +# Example usage +if __name__ == "__main__": + example_node = NoteNodeType( + title="Example Note Node", + desc="A note node example", + type=BlockEnum.custom_note, + text="This is a note.", + theme=NoteTheme.green, + author="John Doe", + showAuthor=True, + ) + print(example_node) diff --git a/api/core/auto/node_types/parameter_extractor.py b/api/core/auto/node_types/parameter_extractor.py new file mode 100644 index 0000000000..cb6ef3d0bd --- /dev/null +++ b/api/core/auto/node_types/parameter_extractor.py @@ -0,0 +1,85 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from .common import BlockEnum, CommonNodeType, Memory, ModelConfig, ValueSelector, VisionSetting + +# Import previously defined CommonNodeType, Memory, ModelConfig, ValueSelector, and VisionSetting +# Assume they are defined in the same module + + +class ParamType(str, Enum): + """Parameter types for extraction.""" + + string = "string" + number = "number" + bool = "bool" + select = "select" + arrayString = "array[string]" + arrayNumber = "array[number]" + arrayObject = "array[object]" + + +class Param(BaseModel): + """Parameter definition for extraction.""" + + 
name: str
+    type: ParamType
+    options: Optional[list[str]] = None
+    description: str
+    required: Optional[bool] = None
+
+
+class ReasoningModeType(str, Enum):
+    """Reasoning mode types for parameter extraction."""
+
+    prompt = "prompt"
+    functionCall = "function_call"
+
+
+class VisionConfig(BaseModel):
+    """Vision configuration."""
+
+    enabled: bool
+    configs: Optional[VisionSetting] = None
+
+
+class ParameterExtractorNodeType(CommonNodeType):
+    """Parameter extractor node type implementation."""
+
+    model: ModelConfig
+    query: ValueSelector
+    reasoning_mode: ReasoningModeType
+    parameters: list[Param]
+    instruction: str
+    memory: Optional[Memory] = None
+    vision: VisionConfig
+
+
+# Example usage
+if __name__ == "__main__":
+    example_node = ParameterExtractorNodeType(
+        title="Example Parameter Extractor Node",
+        desc="A parameter extractor node example",
+        type=BlockEnum.parameter_extractor,
+        model=ModelConfig(
+            provider="example_provider", name="example_model", mode="chat", completion_params={"temperature": 0.7}
+        ),
+        query=ValueSelector(value=["queryNode", "value"]),
+        reasoning_mode=ReasoningModeType.prompt,
+        parameters=[
+            Param(name="param1", type=ParamType.string, description="This is a string parameter", required=True),
+            Param(
+                name="param2",
+                type=ParamType.number,
+                options=["1", "2", "3"],
+                description="This is a number parameter",
+                required=False,
+            ),
+        ],
+        instruction="Please extract the parameters from the input.",
+        memory=Memory(window={"enabled": True, "size": 10}, query_prompt_template="Extract parameters from: {{query}}"),
+        vision=VisionConfig(enabled=True, configs=VisionSetting(variable_selector=["sys", "files"], detail={"resolution": "high"})),
+    )
+    print(example_node)
diff --git a/api/core/auto/node_types/question_classifier.py b/api/core/auto/node_types/question_classifier.py
new file mode 100644
index 0000000000..af4f53cb74
--- /dev/null
+++ b/api/core/auto/node_types/question_classifier.py
@@ -0,0 +1,51 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+from .common import BlockEnum, CommonNodeType, Memory, ModelConfig, ValueSelector, VisionSetting
+
+# Import previously defined CommonNodeType, Memory, ModelConfig, ValueSelector, and VisionSetting
+# Assume they are defined in the same module
+
+
+class Topic(BaseModel):
+    """Topic for classification."""
+
+    id: str
+    name: str
+
+
+class VisionConfig(BaseModel):
+    """Vision configuration."""
+
+    enabled: bool
+    configs: Optional[VisionSetting] = None
+
+
+class QuestionClassifierNodeType(CommonNodeType):
+    """Question classifier node type implementation."""
+
+    query_variable_selector: ValueSelector
+    model: ModelConfig
+    classes: list[Topic]
+    instruction: str
+    memory: Optional[Memory] = None
+    vision: VisionConfig
+
+
+# Example usage
+if __name__ == "__main__":
+    example_node = QuestionClassifierNodeType(
+        title="Example Question Classifier Node",
+        desc="A question classifier node example",
+        type=BlockEnum.question_classifier,
+        query_variable_selector=ValueSelector(value=["queryNode", "value"]),
+        model=ModelConfig(
+            provider="example_provider", name="example_model", mode="chat", completion_params={"temperature": 0.7}
+        ),
+        classes=[Topic(id="1", name="Science"), Topic(id="2", name="Mathematics"), Topic(id="3", name="Literature")],
+        instruction="Classify the given question into the appropriate topic.",
+        memory=Memory(window={"enabled": True, "size": 10}, query_prompt_template="Classify this question: {{query}}"),
+        vision=VisionConfig(enabled=True, configs=VisionSetting(variable_selector=["sys", "files"], detail={"resolution": "high"})),
+    )
+    print(example_node)
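
Taken together, the node-type models above provide the `data` payload for a canvas node, while `CompleteNode` and `CompleteEdge` in `common.py` add the canvas-level envelope and handle YAML serialization. A minimal sketch of how the pieces compose, assuming the modules in this diff are importable as `core.auto.node_types`; the node id, coordinates, and sizes are hypothetical placeholders:

from core.auto.node_types.common import BlockEnum, CompleteNode, InputVar, InputVarType
from core.auto.node_types.start import StartNodeType

# Node payload: a start node with one required text input (hypothetical values).
start_data = StartNodeType(
    title="User Input",
    desc="Collects the text to be processed",
    type=BlockEnum.start,
    variables=[
        InputVar(type=InputVarType.text_input, label="input_text", variable="input_text", required=True),
    ],
)

# Canvas-level envelope; CompleteNode.to_json() forces sourcePosition/targetPosition
# to "right"/"left" and coerces width to an int before dumping.
node = CompleteNode(
    id="1740019130520",  # hypothetical node id
    position={"x": 80, "y": 282},
    positionAbsolute={"x": 80, "y": 282},
    height=116,
    width=244,
    selected=False,
    sourcePosition="right",
    targetPosition="left",
    type="custom",
)
node.add_data(start_data)
print(node.to_yaml())

The printed mapping has the same shape as the `nodes:` entries in the sample workflows under `api/core/auto/output/` later in this diff.
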
diff --git a/api/core/auto/node_types/start.py b/api/core/auto/node_types/start.py new file mode 100644 index 0000000000..3074df2251 --- /dev/null +++ b/api/core/auto/node_types/start.py @@ -0,0 +1,22 @@ +from .common import BlockEnum, CommonNodeType, InputVar + +# Import previously defined CommonNodeType and InputVar +# Assume they are defined in the same module + + +class StartNodeType(CommonNodeType): + variables: list[InputVar] + + +# Example usage +if __name__ == "__main__": + example_node = StartNodeType( + title="Example Start Node", + desc="A start node example", + type=BlockEnum.start, + variables=[ + InputVar(type="text-input", label="Input 1", variable="input1", required=True), + InputVar(type="number", label="Input 2", variable="input2", required=True), + ], + ) + print(example_node) diff --git a/api/core/auto/node_types/template_transform.py b/api/core/auto/node_types/template_transform.py new file mode 100644 index 0000000000..19339fae0a --- /dev/null +++ b/api/core/auto/node_types/template_transform.py @@ -0,0 +1,26 @@ +from .common import BlockEnum, CommonNodeType, Variable + +# 引入之前定义的 CommonNodeType 和 Variable +# 假设它们在同一模块中定义 + + +class TemplateTransformNodeType(CommonNodeType): + """Template transform node type implementation.""" + + variables: list[Variable] + template: str + + +# 示例用法 +if __name__ == "__main__": + example_node = TemplateTransformNodeType( + title="Example Template Transform Node", + desc="A template transform node example", + type=BlockEnum.template_transform, + variables=[ + Variable(variable="var1", value_selector=["node1", "key1"]), + Variable(variable="var2", value_selector=["node2", "key2"]), + ], + template="Hello, {{ var1 }}! You have {{ var2 }} new messages.", + ) + print(example_node) diff --git a/api/core/auto/node_types/tool.py b/api/core/auto/node_types/tool.py new file mode 100644 index 0000000000..d6d006b1d0 --- /dev/null +++ b/api/core/auto/node_types/tool.py @@ -0,0 +1,54 @@ +from enum import Enum +from typing import Any, Optional, Union + +from pydantic import BaseModel + +from .common import BlockEnum, CommonNodeType, ValueSelector + +# Import previously defined CommonNodeType and ValueSelector +# Assume they are defined in the same module + + +class VarType(str, Enum): + variable = "variable" + constant = "constant" + mixed = "mixed" + + +class ToolVarInputs(BaseModel): + type: VarType + value: Optional[Union[str, ValueSelector, Any]] = None + + +class ToolNodeType(CommonNodeType): + """Tool node type implementation.""" + + provider_id: str + provider_type: Any # Placeholder for CollectionType + provider_name: str + tool_name: str + tool_label: str + tool_parameters: dict[str, ToolVarInputs] + tool_configurations: dict[str, Any] + output_schema: dict[str, Any] + + +# Example usage +if __name__ == "__main__": + example_node = ToolNodeType( + title="Example Tool Node", + desc="A tool node example", + type=BlockEnum.tool, + provider_id="12345", + provider_type="some_collection_type", # Placeholder for CollectionType + provider_name="Example Provider", + tool_name="Example Tool", + tool_label="Example Tool Label", + tool_parameters={ + "input1": ToolVarInputs(type=VarType.variable, value="some_value"), + "input2": ToolVarInputs(type=VarType.constant, value="constant_value"), + }, + tool_configurations={"config1": "value1", "config2": {"nested": "value2"}}, + output_schema={"output1": "string", "output2": "number"}, + ) + print(example_node.json(indent=2)) # Print as JSON format for viewing diff --git 
a/api/core/auto/node_types/variable_assigner.py b/api/core/auto/node_types/variable_assigner.py new file mode 100644 index 0000000000..1ae271294f --- /dev/null +++ b/api/core/auto/node_types/variable_assigner.py @@ -0,0 +1,56 @@ +from pydantic import BaseModel + +from .common import BlockEnum, CommonNodeType, ValueSelector, VarType + + +class VarGroupItem(BaseModel): + """Variable group item configuration.""" + + output_type: VarType + variables: list[ValueSelector] + + +class GroupConfig(VarGroupItem): + """Group configuration for advanced settings.""" + + group_name: str + groupId: str + + +class AdvancedSettings(BaseModel): + """Advanced settings for variable assigner.""" + + group_enabled: bool + groups: list[GroupConfig] + + +class VariableAssignerNodeType(CommonNodeType, VarGroupItem): + """Variable assigner node type implementation.""" + + advanced_settings: AdvancedSettings + + class Config: + arbitrary_types_allowed = True + + +# Example usage +if __name__ == "__main__": + example_node = VariableAssignerNodeType( + title="Example Variable Assigner Node", + desc="A variable assigner node example", + type=BlockEnum.variable_assigner, + output_type=VarType.string, + variables=[ValueSelector(value=["varNode1", "value1"]), ValueSelector(value=["varNode2", "value2"])], + advanced_settings=AdvancedSettings( + group_enabled=True, + groups=[ + GroupConfig( + group_name="Group 1", + groupId="group1", + output_type=VarType.number, + variables=[ValueSelector(value=["varNode3", "value3"])], + ) + ], + ), + ) + print(example_node.json(indent=2)) # Print as JSON format for viewing diff --git a/api/core/auto/output/emotion_analysis_workflow.yml b/api/core/auto/output/emotion_analysis_workflow.yml new file mode 100644 index 0000000000..f3b1e24f31 --- /dev/null +++ b/api/core/auto/output/emotion_analysis_workflow.yml @@ -0,0 +1,239 @@ +app: + description: '' + icon: 🤖 + icon_background: '#FFEAD5' + mode: workflow + name: 情绪分析工作流 + use_icon_as_answer_icon: false +kind: app +version: 0.1.2 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - id: 1740019130520-source-1740019130521-target + source: '1740019130520' + sourceHandle: source + target: '1740019130521' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: start + targetType: llm + - id: 1740019130521-source-1740019130522-target + source: '1740019130521' + sourceHandle: source + target: '1740019130522' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: llm + targetType: code + - id: 1740019130522-source-1740019130523-target + source: '1740019130522' + sourceHandle: source + target: '1740019130523' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: 
false + sourceType: code + targetType: template-transform + - id: 1740019130523-source-1740019130524-target + source: '1740019130523' + sourceHandle: source + target: '1740019130524' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: template-transform + targetType: end + nodes: + - id: '1740019130520' + position: + x: 80 + y: 282 + height: 116 + width: 244 + positionAbsolute: + x: 80 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 开始节点 + desc: 开始节点,接收用户输入的文本。 + type: start + variables: + - type: text-input + label: input_text + variable: input_text + required: true + max_length: 48 + options: [] + - id: '1740019130521' + position: + x: 380 + y: 282 + height: 98 + width: 244 + positionAbsolute: + x: 380 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: LLM节点 + desc: LLM节点分析文本情绪,识别出积极、消极或中性情绪。 + type: llm + model: + provider: zhipuai + name: glm-4-flash + mode: chat + completion_params: + temperature: 0.7 + prompt_template: + - id: 1740019130521-system + text: 请分析以下文本的情绪,并返回情绪类型(积极、消极或中性)。 + role: system + - id: 1740019130521-user + text: 分析此文本的情绪:{{input_text}} + role: user + context: + enabled: false + variable_selector: [] + vision: + enabled: false + - id: '1740019130522' + position: + x: 680 + y: 282 + height: 54 + width: 244 + positionAbsolute: + x: 680 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 代码节点 + desc: 代码节点将根据LLM分析的结果处理情绪类型。 + type: code + variables: + - variable: emotion + value_selector: + - '1740019130521' + - emotion + code_language: python3 + code: "def analyze_sentiment(emotion):\n if emotion == 'positive':\n \ + \ return '积极'\n elif emotion == 'negative':\n return '消极'\n\ + \ else:\n return '中性'\n\nemotion = '{{emotion}}'\nresult = analyze_sentiment(emotion)\n\ + return {'result': result}" + outputs: + sentiment_result: + type: string + - id: '1740019130523' + position: + x: 980 + y: 282 + height: 54 + width: 244 + positionAbsolute: + x: 980 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 模板节点 + desc: 模板节点将情绪分析结果格式化输出。 + type: template-transform + variables: + - variable: sentiment_result + value_selector: + - '1740019130522' + - sentiment_result + template: 文本的情绪分析结果为:{{sentiment_result}} + - id: '1740019130524' + position: + x: 1280 + y: 282 + height: 90 + width: 244 + positionAbsolute: + x: 1280 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 结束节点 + desc: 结束节点,返回格式化后的情绪分析结果。 + type: end + outputs: + - variable: output + value_selector: + - '1740019130523' + - output + viewport: + x: 92.96659905656679 + y: 79.13437154762897 + zoom: 0.9002006986311041 diff --git a/api/core/auto/output/test_workflow.yml b/api/core/auto/output/test_workflow.yml new file mode 100644 index 0000000000..8f80c8ffac --- /dev/null +++ b/api/core/auto/output/test_workflow.yml @@ -0,0 +1,247 @@ +app: + description: '' + icon: 🤖 + icon_background: '#FFEAD5' + mode: workflow + name: 计算两个数字之和 + use_icon_as_answer_icon: false +kind: app +version: 0.1.2 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + 
audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - id: 1740019130520-source-1740019130521-target + source: '1740019130520' + sourceHandle: source + target: '1740019130521' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: start + targetType: llm + - id: 1740019130521-source-1740019130522-target + source: '1740019130521' + sourceHandle: source + target: '1740019130522' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: llm + targetType: code + - id: 1740019130522-source-1740019130523-target + source: '1740019130522' + sourceHandle: source + target: '1740019130523' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: code + targetType: template-transform + - id: 1740019130523-source-1740019130524-target + source: '1740019130523' + sourceHandle: source + target: '1740019130524' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: template-transform + targetType: end + nodes: + - id: '1740019130520' + position: + x: 80 + y: 282 + height: 116 + width: 244 + positionAbsolute: + x: 80 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 开始节点 + desc: 开始节点,接收两个数字输入参数。 + type: start + variables: + - type: number + label: num1 + variable: num1 + required: true + max_length: 48 + options: [] + - type: number + label: num2 + variable: num2 + required: true + max_length: 48 + options: [] + - id: '1740019130521' + position: + x: 380 + y: 282 + height: 98 + width: 244 + positionAbsolute: + x: 380 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: LLM节点 + desc: LLM节点,根据输入的两个数字生成计算它们之和的Python函数。 + type: llm + model: + provider: openai + name: gpt-4 + mode: chat + completion_params: + temperature: 0.7 + prompt_template: + - id: 1740019130521-system + text: 你是一个Python开发助手,请根据以下输入生成一个计算两数之和的Python函数。 + role: system + - id: 1740019130521-user + text: 请为两个数字{{num1}}和{{num2}}生成一个Python函数,计算它们的和。 + role: user + context: + enabled: false + variable_selector: [] + vision: + enabled: false + - id: '1740019130522' + position: + x: 680 + y: 282 + height: 54 + width: 244 + positionAbsolute: + x: 680 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 代码节点 + desc: 代码节点,执行LLM生成的Python函数,并计算结果。 + type: code + variables: + - variable: num1 + value_selector: + - '1740019130520' + - num1 + - variable: num2 + value_selector: + - '1740019130520' + - num2 + code_language: python3 + code: "def calculate_sum(num1, num2):\n return num1 + num2\n\nresult =\ + \ calculate_sum({{num1}}, {{num2}})\nreturn result" + outputs: + result: + type: number + - id: '1740019130523' + position: + x: 980 + y: 282 + height: 54 + width: 244 + positionAbsolute: + x: 980 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 模板节点 + desc: 模板节点,将计算结果格式化为输出字符串。 + type: template-transform + 
variables: + - variable: result + value_selector: + - '1740019130522' + - result + template: '计算结果为: {{result}}' + - id: '1740019130524' + position: + x: 1280 + y: 282 + height: 90 + width: 244 + positionAbsolute: + x: 1280 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 结束节点 + desc: 结束节点,输出格式化后的结果。 + type: end + outputs: + - variable: output + value_selector: + - '1740019130523' + - output + viewport: + x: 92.96659905656679 + y: 79.13437154762897 + zoom: 0.9002006986311041 diff --git a/api/core/auto/output/text_analysis_workflow.yml b/api/core/auto/output/text_analysis_workflow.yml new file mode 100644 index 0000000000..3c71437de5 --- /dev/null +++ b/api/core/auto/output/text_analysis_workflow.yml @@ -0,0 +1,262 @@ +app: + description: '' + icon: 🤖 + icon_background: '#FFEAD5' + mode: workflow + name: 文本分析工作流 + use_icon_as_answer_icon: false +kind: app +version: 0.1.2 +workflow: + conversation_variables: [] + environment_variables: [] + features: + file_upload: + allowed_file_extensions: + - .JPG + - .JPEG + - .PNG + - .GIF + - .WEBP + - .SVG + allowed_file_types: + - image + allowed_file_upload_methods: + - local_file + - remote_url + enabled: false + fileUploadConfig: + audio_file_size_limit: 50 + batch_count_limit: 5 + file_size_limit: 15 + image_file_size_limit: 10 + video_file_size_limit: 100 + image: + enabled: false + number_limits: 3 + transfer_methods: + - local_file + - remote_url + number_limits: 3 + opening_statement: '' + retriever_resource: + enabled: true + sensitive_word_avoidance: + enabled: false + speech_to_text: + enabled: false + suggested_questions: [] + suggested_questions_after_answer: + enabled: false + text_to_speech: + enabled: false + language: '' + voice: '' + graph: + edges: + - id: 1740019130520-source-1740019130521-target + source: '1740019130520' + sourceHandle: source + target: '1740019130521' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: start + targetType: llm + - id: 1740019130520-source-1740019130522-target + source: '1740019130520' + sourceHandle: source + target: '1740019130522' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: start + targetType: code + - id: 1740019130521-source-1740019130523-target + source: '1740019130521' + sourceHandle: source + target: '1740019130523' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: llm + targetType: template-transform + - id: 1740019130522-source-1740019130523-target + source: '1740019130522' + sourceHandle: source + target: '1740019130523' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: code + targetType: template-transform + - id: 1740019130523-source-1740019130524-target + source: '1740019130523' + sourceHandle: source + target: '1740019130524' + targetHandle: target + type: custom + zIndex: 0 + data: + isInIteration: false + sourceType: template-transform + targetType: end + nodes: + - id: '1740019130520' + position: + x: 80 + y: 282 + height: 116 + width: 244 + positionAbsolute: + x: 80 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 开始节点 + desc: 接收用户输入的文本参数 + type: start + variables: + - type: text-input + label: user_text + variable: user_text + required: true + max_length: 48 + options: [] + - id: '1740019130521' + position: + x: 380 + y: 282 + height: 98 + width: 244 + positionAbsolute: + x: 380 + y: 
282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: LLM节点 + desc: 使用大语言模型进行情感分析,返回文本的情感结果 + type: llm + model: + provider: zhipuai + name: glm-4-flash + mode: chat + completion_params: + temperature: 0.7 + prompt_template: + - id: 1740019130521-system + text: 请分析以下文本的情感,返回积极、消极或中性 + role: system + - id: 1740019130521-user + text: '{{user_text}}' + role: user + context: + enabled: false + variable_selector: [] + vision: + enabled: false + - id: '1740019130522' + position: + x: 680 + y: 282 + height: 54 + width: 244 + positionAbsolute: + x: 680 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 代码节点 + desc: 计算文本的统计信息,包括字符数、单词数和句子数 + type: code + variables: + - variable: text_for_analysis + value_selector: + - '1740019130520' + - user_text + code_language: python3 + code: "import re\n\ndef main(text):\n char_count = len(text)\n word_count\ + \ = len(text.split())\n sentence_count = len(re.findall(r'[.!?]', text))\n\ + \ return {'char_count': char_count, 'word_count': word_count, 'sentence_count':\ + \ sentence_count}" + outputs: + text_statistics: + type: object + - id: '1740019130523' + position: + x: 980 + y: 282 + height: 54 + width: 244 + positionAbsolute: + x: 980 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 模板节点 + desc: 将情感分析结果和统计信息组合成格式化报告 + type: template-transform + variables: + - variable: sentiment_result + value_selector: + - '1740019130521' + - sentiment_result + - variable: text_statistics + value_selector: + - '1740019130522' + - text_statistics + template: '情感分析结果:{{sentiment_result}} + + 文本统计信息: + + 字符数:{{text_statistics.char_count}} + + 单词数:{{text_statistics.word_count}} + + 句子数:{{text_statistics.sentence_count}}' + - id: '1740019130524' + position: + x: 1280 + y: 282 + height: 90 + width: 244 + positionAbsolute: + x: 1280 + y: 282 + selected: false + sourcePosition: right + targetPosition: left + type: custom + data: + title: 结束节点 + desc: 返回最终的分析报告 + type: end + outputs: + - variable: final_report + value_selector: + - '1740019130523' + - output + viewport: + x: 92.96659905656679 + y: 79.13437154762897 + zoom: 0.9002006986311041 diff --git a/api/core/auto/workflow_generator/__init__.py b/api/core/auto/workflow_generator/__init__.py new file mode 100644 index 0000000000..a29c886973 --- /dev/null +++ b/api/core/auto/workflow_generator/__init__.py @@ -0,0 +1,8 @@ +""" +Workflow generator package +Generates Dify workflows from user requirements +""" + +from .workflow_generator import WorkflowGenerator + +__all__ = ["WorkflowGenerator"] diff --git a/api/core/auto/workflow_generator/generators/__init__.py b/api/core/auto/workflow_generator/generators/__init__.py new file mode 100644 index 0000000000..cb4d5a82c6 --- /dev/null +++ b/api/core/auto/workflow_generator/generators/__init__.py @@ -0,0 +1,9 @@ +""" +Node and edge generator package +""" + +from .edge_generator import EdgeGenerator +from .layout_engine import LayoutEngine +from .node_generator import NodeGenerator + +__all__ = ["EdgeGenerator", "LayoutEngine", "NodeGenerator"] diff --git a/api/core/auto/workflow_generator/generators/edge_generator.py b/api/core/auto/workflow_generator/generators/edge_generator.py new file mode 100644 index 0000000000..ee85a61c20 --- /dev/null +++ b/api/core/auto/workflow_generator/generators/edge_generator.py @@ -0,0 +1,101 @@ +""" +Edge Generator +Used to generate edges in the workflow +""" + +from core.auto.node_types.common import CommonEdgeType, CompleteEdge, CompleteNode +from
core.auto.workflow_generator.models.workflow_description import ConnectionDescription + + +class EdgeGenerator: + """Edge generator for creating workflow edges""" + + @staticmethod + def create_edges(nodes: list[CompleteNode], connections: list[ConnectionDescription]) -> list[CompleteEdge]: + """ + Create edges based on nodes and connection information + + Args: + nodes: list of nodes + connections: list of connection descriptions + + Returns: + list of edges + """ + edges = [] + + # If connection information is provided, create edges based on it + if connections: + for connection in connections: + source_id = connection.source + target_id = connection.target + + if not source_id or not target_id: + continue + + # Find source and target nodes + source_node = next((node for node in nodes if node.id == source_id), None) + target_node = next((node for node in nodes if node.id == target_id), None) + + if not source_node or not target_node: + continue + + # Get node types + source_type = source_node.data.type + target_type = target_node.data.type + + # Create edge + edge_id = f"{source_id}-source-{target_id}-target" + + # Create edge data + edge_data = CommonEdgeType(isInIteration=False, sourceType=source_type, targetType=target_type) + + # Create complete edge + edge = CompleteEdge( + id=edge_id, + source=source_id, + sourceHandle="source", + target=target_id, + targetHandle="target", + type="custom", + zIndex=0, + ) + + # Add edge data + edge.add_data(edge_data) + + edges.append(edge) + # If no connection information is provided, automatically create edges + else: + # Create edges based on node order + for i in range(len(nodes) - 1): + source_node = nodes[i] + target_node = nodes[i + 1] + + # Get node types + source_type = source_node.data.type + target_type = target_node.data.type + + # Create edge + edge_id = f"{source_node.id}-source-{target_node.id}-target" + + # Create edge data + edge_data = CommonEdgeType(isInIteration=False, sourceType=source_type, targetType=target_type) + + # Create complete edge + edge = CompleteEdge( + id=edge_id, + source=source_node.id, + sourceHandle="source", + target=target_node.id, + targetHandle="target", + type="custom", + zIndex=0, + ) + + # Add edge data + edge.add_data(edge_data) + + edges.append(edge) + + return edges diff --git a/api/core/auto/workflow_generator/generators/layout_engine.py b/api/core/auto/workflow_generator/generators/layout_engine.py new file mode 100644 index 0000000000..1a48a744a6 --- /dev/null +++ b/api/core/auto/workflow_generator/generators/layout_engine.py @@ -0,0 +1,77 @@ +""" +Layout Engine +Used to arrange the positions of nodes and edges +""" + +from core.auto.node_types.common import CompleteEdge, CompleteNode + + +class LayoutEngine: + """Layout engine""" + + @staticmethod + def apply_layout(nodes: list[CompleteNode]) -> None: + """ + Apply layout, arranging nodes in a row + + Args: + nodes: list of nodes + """ + # Simple linear layout, arranging nodes from left to right + x_position = 80 + y_position = 282 + + for node in nodes: + node.position = {"x": x_position, "y": y_position} + node.positionAbsolute = {"x": x_position, "y": y_position} + + # Update position for the next node + x_position += 300 # Horizontal spacing between nodes + + @staticmethod + def apply_topological_layout(nodes: list[CompleteNode], edges: list[CompleteEdge]) -> None: + """ + Apply topological sort layout, arranging nodes based on their dependencies + + Args: + nodes: list of nodes + edges: list of edges + """ + # Create mapping from node ID to 
node + node_map = {node.id: node for node in nodes} + + # Create adjacency list + adjacency_list = {node.id: [] for node in nodes} + for edge in edges: + adjacency_list[edge.source].append(edge.target) + + # Create in-degree table + in_degree = {node.id: 0 for node in nodes} + for source, targets in adjacency_list.items(): + for target in targets: + in_degree[target] += 1 + + # Topological sort + queue = [node_id for node_id, degree in in_degree.items() if degree == 0] + sorted_nodes = [] + + while queue: + current = queue.pop(0) + sorted_nodes.append(current) + + for neighbor in adjacency_list[current]: + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + queue.append(neighbor) + + # Apply layout + x_position = 80 + y_position = 282 + + for node_id in sorted_nodes: + node = node_map[node_id] + node.position = {"x": x_position, "y": y_position} + node.positionAbsolute = {"x": x_position, "y": y_position} + + # Update position for the next node + x_position += 300 # Horizontal spacing between nodes diff --git a/api/core/auto/workflow_generator/generators/node_generator.py b/api/core/auto/workflow_generator/generators/node_generator.py new file mode 100644 index 0000000000..e856baaf6c --- /dev/null +++ b/api/core/auto/workflow_generator/generators/node_generator.py @@ -0,0 +1,446 @@ +""" +Node Generator +Generate nodes based on workflow description +""" + +from core.auto.node_types.code import CodeLanguage, CodeNodeType, OutputVar +from core.auto.node_types.common import ( + BlockEnum, + CompleteNode, + Context, + InputVar, + ModelConfig, + PromptItem, + PromptRole, + ValueSelector, + Variable, +) +from core.auto.node_types.end import EndNodeType +from core.auto.node_types.llm import LLMNodeType, VisionConfig +from core.auto.node_types.start import StartNodeType +from core.auto.node_types.template_transform import TemplateTransformNodeType +from core.auto.workflow_generator.models.workflow_description import NodeDescription +from core.auto.workflow_generator.utils.prompts import DEFAULT_MODEL_CONFIG, DEFAULT_SYSTEM_PROMPT +from core.auto.workflow_generator.utils.type_mapper import map_string_to_var_type, map_var_type_to_input_type + + +class NodeGenerator: + """Node generator for creating workflow nodes""" + + @staticmethod + def create_nodes(node_descriptions: list[NodeDescription]) -> list[CompleteNode]: + """ + Create nodes based on node descriptions + + Args: + node_descriptions: list of node descriptions + + Returns: + list of nodes + """ + nodes = [] + + for node_desc in node_descriptions: + node_type = node_desc.type + + if node_type == "start": + node = NodeGenerator._create_start_node(node_desc) + elif node_type == "llm": + node = NodeGenerator._create_llm_node(node_desc) + elif node_type == "code": + node = NodeGenerator._create_code_node(node_desc) + elif node_type == "template": + node = NodeGenerator._create_template_node(node_desc) + elif node_type == "end": + node = NodeGenerator._create_end_node(node_desc) + else: + raise ValueError(f"Unsupported node type: {node_type}") + + nodes.append(node) + + return nodes + + @staticmethod + def _create_start_node(node_desc: NodeDescription) -> CompleteNode: + """Create start node""" + variables = [] + + for var in node_desc.variables or []: + input_var = InputVar( + type=map_var_type_to_input_type(var.type), + label=var.name, + variable=var.name, + required=var.required, + max_length=48, + options=[], + ) + variables.append(input_var) + + start_node = StartNodeType( + title=node_desc.title, desc=node_desc.description or "", 
type=BlockEnum.start, variables=variables + ) + + return CompleteNode( + id=node_desc.id, + type="custom", + position={"x": 0, "y": 0}, # Temporary position, will be updated later + height=118, # Increase height to match reference file + width=244, + positionAbsolute={"x": 0, "y": 0}, + selected=False, + sourcePosition="right", + targetPosition="left", + data=start_node, + ) + + @staticmethod + def _create_llm_node(node_desc: NodeDescription) -> CompleteNode: + """Create LLM node""" + # Build prompt template + prompt_template = [] + + # Add system prompt + system_prompt = node_desc.system_prompt or DEFAULT_SYSTEM_PROMPT + prompt_template.append(PromptItem(id=f"{node_desc.id}-system", role=PromptRole.system, text=system_prompt)) + + # Add user prompt + user_prompt = node_desc.user_prompt or "Please answer these questions:" + + # Build variable list + variables = [] + for var in node_desc.variables or []: + source_node = var.source_node or "" + source_variable = var.source_variable or "" + + print( + f"DEBUG: Processing variable {var.name}, source_node={source_node}, source_variable={source_variable}" + ) + + # If source node is an LLM node, ensure source_variable is 'text' + if source_node: + # Check if the source node is an LLM node by checking connections + # This is a simple heuristic - if the source node is connected to a node with 'llm' in its ID + # or if the source node has 'llm' in its ID, assume it's an LLM node + if "llm" in source_node.lower(): + print(f"DEBUG: Found LLM node {source_node}") + if source_variable != "text": + old_var = source_variable + source_variable = "text" # LLM nodes output variable is always 'text' + print( + f"Auto-fixing: Changed source variable from '{old_var}' to 'text' for LLM node {source_node}" # noqa: E501 + ) + + # Check if the user prompt already contains correctly formatted variable references + # Variable references in LLM nodes should be in the format {{#nodeID.variableName#}} + correct_format = f"{{{{#{source_node}.{source_variable}#}}}}" + simple_format = f"{{{{{var.name}}}}}" + + # If simple format is used in the prompt, replace it with the correct format + if simple_format in user_prompt and source_node and source_variable: + user_prompt = user_prompt.replace(simple_format, correct_format) + + variable = Variable(variable=var.name, value_selector=[source_node, source_variable]) + variables.append(variable) + + # Update user prompt + prompt_template.append(PromptItem(id=f"{node_desc.id}-user", role=PromptRole.user, text=user_prompt)) + + # Use default model configuration, prioritize configuration in node description + provider = node_desc.provider or DEFAULT_MODEL_CONFIG["provider"] + model = node_desc.model or DEFAULT_MODEL_CONFIG["model"] + + llm_node = LLMNodeType( + title=node_desc.title, + desc=node_desc.description or "", + type=BlockEnum.llm, + model=ModelConfig( + provider=provider, + name=model, + mode=DEFAULT_MODEL_CONFIG["mode"], + completion_params=DEFAULT_MODEL_CONFIG["completion_params"], + ), + prompt_template=prompt_template, + variables=variables, + context=Context(enabled=False, variable_selector=ValueSelector(value=[])), + vision=VisionConfig(enabled=False), + ) + + return CompleteNode( + id=node_desc.id, + type="custom", + position={"x": 0, "y": 0}, # Temporary position, will be updated later + height=126, # Increase height to match reference file + width=244, + positionAbsolute={"x": 0, "y": 0}, + selected=False, + sourcePosition="right", + targetPosition="left", + data=llm_node, + ) + + @staticmethod + def 
_create_code_node(node_desc: NodeDescription) -> CompleteNode: + """Create code node""" + # Build variable list and function parameter names + variables = [] + var_names = [] + var_mapping = {} # Used to store mapping from variable names to function parameter names + + # First, identify all LLM nodes in the workflow + llm_nodes = set() + for connection in node_desc.workflow_description.connections: + for node in node_desc.workflow_description.nodes: + if node.id == connection.source and node.type.lower() == "llm": + llm_nodes.add(node.id) + + for var in node_desc.variables or []: + source_node = var.source_node or "" + source_variable = var.source_variable or "" + + # Check if source node is an LLM node and warn if source_variable is not 'text' + if source_node in llm_nodes and source_variable != "text": + print( + f"WARNING: LLM node {source_node} output variable should be 'text', but got '{source_variable}'. This may cause issues in Dify." # noqa: E501 + ) + print(" Consider changing the source_variable to 'text' in your workflow description.") + # Auto-fix: Always use 'text' as the source variable for LLM nodes + old_var = source_variable + source_variable = "text" + print(f"Auto-fixing: Changed source variable from '{old_var}' to 'text' for LLM node {source_node}") + elif source_node and "llm" in source_node.lower() and source_variable != "text": + # Fallback heuristic check based on node ID + print( + f"WARNING: Node {source_node} appears to be an LLM node based on its ID, but source_variable is not 'text'." # noqa: E501 + ) + print(" Consider changing the source_variable to 'text' in your workflow description.") + # Auto-fix: Always use 'text' as the source variable for LLM nodes + old_var = source_variable + source_variable = "text" + print(f"Auto-fixing: Changed source variable from '{old_var}' to 'text' for LLM node {source_node}") + + # Use variable name as function parameter name + variable_name = var.name # Variable name defined in this node + param_name = variable_name # Function parameter name must match variable name + + # Validate variable name format + if not variable_name.replace("_", "").isalnum(): + raise ValueError( + f"Invalid variable name: {variable_name}. Variable names must only contain letters, numbers, and underscores." # noqa: E501 + ) + if not variable_name[0].isalpha() and variable_name[0] != "_": + raise ValueError( + f"Invalid variable name: {variable_name}. Variable names must start with a letter or underscore." + ) + + var_names.append(param_name) + var_mapping[variable_name] = param_name + + variable = Variable(variable=variable_name, value_selector=[source_node, source_variable]) + variables.append(variable) + + # Build output + outputs = {} + for output in node_desc.outputs or []: + # Validate output variable name format + if not output.name.replace("_", "").isalnum(): + raise ValueError( + f"Invalid output variable name: {output.name}. Output names must only contain letters, numbers, and underscores." # noqa: E501 + ) + if not output.name[0].isalpha() and output.name[0] != "_": + raise ValueError( + f"Invalid output variable name: {output.name}. Output names must start with a letter or underscore." 
+ ) + + outputs[output.name] = OutputVar(type=map_string_to_var_type(output.type)) + + # Generate code; function parameters must match variable names and return keys must match output names + output_names = [output.name for output in node_desc.outputs or []] + + # Build function parameter list + params_str = ", ".join(var_names) if var_names else "" + + # Build return value dictionary + return_dict = {} + for output_name in output_names: + # Use the first variable as the return value by default + return_dict[output_name] = var_names[0] if var_names else f'"{output_name}"' + + return_dict_str = ", ".join([f'"{k}": {v}' for k, v in return_dict.items()]) + + # Default code template, ensure return dictionary type matches output variable + default_code = f"""def main({params_str}): + # Write your code here + # Process input variables + + # Return a dictionary, key names must match variable names defined in outputs + return {{{return_dict_str}}}""" + + # If custom code is provided, ensure it meets the specifications + if node_desc.code: + custom_code = node_desc.code + # Check if it contains main function definition + if not custom_code.strip().startswith("def main("): + # Try to fix the code by adding main function with correct parameters + custom_code = f"def main({params_str}):\n" + custom_code.strip() + else: + # Extract function parameters from the existing main function + import re + + func_params = re.search(r"def\s+main\s*\((.*?)\)", custom_code) + if func_params: + existing_params = [p.strip() for p in func_params.group(1).split(",") if p.strip()] + # Verify that all required parameters are present + missing_params = set(var_names) - set(existing_params) + if missing_params: + # Try to fix the code by replacing the function parameters with the correct names + old_params = func_params.group(1) + new_params = params_str + custom_code = custom_code.replace(f"def main({old_params})", f"def main({new_params})") + print( + f"Warning: Fixed missing parameters in code node: {', '.join(missing_params)}. Function parameters must match variable names defined in this node."
# noqa: E501 + ) + + # Check if the return value is a dictionary and keys match output variables + for output_name in output_names: + if f'"{output_name}"' not in custom_code and f"'{output_name}'" not in custom_code: + # Code may not meet specifications, use default code + custom_code = default_code + break + + # Use fixed code + code = custom_code + else: + code = default_code + + code_node = CodeNodeType( + title=node_desc.title, + desc=node_desc.description or "", + type=BlockEnum.code, + code_language=CodeLanguage.python3, + code=code, + variables=variables, + outputs=outputs, + ) + + return CompleteNode( + id=node_desc.id, + type="custom", + position={"x": 0, "y": 0}, # Temporary position, will be updated later + height=82, # Increase height to match reference file + width=244, + positionAbsolute={"x": 0, "y": 0}, + selected=False, + sourcePosition="right", + targetPosition="left", + data=code_node, + ) + + @staticmethod + def _create_template_node(node_desc: NodeDescription) -> CompleteNode: + """Create template node""" + # Build variable list + variables = [] + template_text = node_desc.template or "" + + # Collect all node IDs referenced in the template + referenced_nodes = set() + for var in node_desc.variables or []: + source_node = var.source_node or "" + source_variable = var.source_variable or "" + + variable = Variable(variable=var.name, value_selector=[source_node, source_variable]) + variables.append(variable) + + if source_node: + referenced_nodes.add(source_node) + + # Modify variable reference format in the template + # Replace {{#node_id.variable#}} with {{ variable }} + if source_node and source_variable: + template_text = template_text.replace(f"{{{{#{source_node}.{source_variable}#}}}}", f"{{ {var.name} }}") + + # Check if a reference to the start node needs to be added + # If the template contains a reference to the start node but the variable list does not have a corresponding variable # noqa: E501 + import re + + start_node_refs = re.findall(r"{{#(\d+)\.([^#]+)#}}", template_text) + + for node_id, var_name in start_node_refs: + # Check if there is already a reference to this variable + if not any(v.variable == var_name for v in variables): + # Add reference to start node variable + variable = Variable(variable=var_name, value_selector=[node_id, var_name]) + variables.append(variable) + + # Modify variable reference format in the template + template_text = template_text.replace(f"{{{{#{node_id}.{var_name}#}}}}", f"{{ {var_name} }}") + + # Get all variable names + var_names = [var.variable for var in variables] + + # Simple and crude method: directly replace all possible variable reference formats + for var_name in var_names: + # Replace {var_name} with {{ var_name }} + template_text = template_text.replace("{" + var_name + "}", "{{ " + var_name + " }}") + # Replace { var_name } with {{ var_name }} + template_text = template_text.replace("{ " + var_name + " }", "{{ " + var_name + " }}") + # Replace {var_name } with {{ var_name }} + template_text = template_text.replace("{" + var_name + " }", "{{ " + var_name + " }}") + # Replace { var_name} with {{ var_name }} + template_text = template_text.replace("{ " + var_name + "}", "{{ " + var_name + " }}") + # Replace {{{ var_name }}} with {{ var_name }} + template_text = template_text.replace("{{{ " + var_name + " }}}", "{{ " + var_name + " }}") + # Replace {{{var_name}}} with {{ var_name }} + template_text = template_text.replace("{{{" + var_name + "}}}", "{{ " + var_name + " }}") + + # Use regular expression to replace 
all triple curly braces with double curly braces + template_text = re.sub(r"{{{([^}]+)}}}", r"{{ \1 }}", template_text) + + template_node = TemplateTransformNodeType( + title=node_desc.title, + desc=node_desc.description or "", + type=BlockEnum.template_transform, + template=template_text, + variables=variables, + ) + + return CompleteNode( + id=node_desc.id, + type="custom", + position={"x": 0, "y": 0}, # Temporary position, will be updated later + height=82, # Increase height to match reference file + width=244, + positionAbsolute={"x": 0, "y": 0}, + selected=False, + sourcePosition="right", + targetPosition="left", + data=template_node, + ) + + @staticmethod + def _create_end_node(node_desc: NodeDescription) -> CompleteNode: + """Create end node""" + # Build output variable list + outputs = [] + for output in node_desc.outputs or []: + variable = Variable( + variable=output.name, value_selector=[output.source_node or "", output.source_variable or ""] + ) + outputs.append(variable) + + end_node = EndNodeType( + title=node_desc.title, desc=node_desc.description or "", type=BlockEnum.end, outputs=outputs + ) + + return CompleteNode( + id=node_desc.id, + type="custom", + position={"x": 0, "y": 0}, # Temporary position, will be updated later + height=90, + width=244, + positionAbsolute={"x": 0, "y": 0}, + selected=False, + sourcePosition="right", + targetPosition="left", + data=end_node, + ) diff --git a/api/core/auto/workflow_generator/models/__init__.py b/api/core/auto/workflow_generator/models/__init__.py new file mode 100644 index 0000000000..cb17339ab1 --- /dev/null +++ b/api/core/auto/workflow_generator/models/__init__.py @@ -0,0 +1,7 @@ +""" +Models package +""" + +from .workflow_description import ConnectionDescription, NodeDescription, WorkflowDescription + +__all__ = ["ConnectionDescription", "NodeDescription", "WorkflowDescription"] diff --git a/api/core/auto/workflow_generator/models/workflow_description.py b/api/core/auto/workflow_generator/models/workflow_description.py new file mode 100644 index 0000000000..fd02cd98df --- /dev/null +++ b/api/core/auto/workflow_generator/models/workflow_description.py @@ -0,0 +1,80 @@ +""" +Workflow Description Model +Used to represent the simplified workflow description generated by large language models +""" + +from typing import Optional + +from pydantic import BaseModel, Field + + +class VariableDescription(BaseModel): + """Variable description""" + + name: str + type: str + description: Optional[str] = None + required: bool = True + source_node: Optional[str] = None + source_variable: Optional[str] = None + + +class OutputDescription(BaseModel): + """Output description""" + + name: str + type: str = "string" + description: Optional[str] = None + source_node: Optional[str] = None + source_variable: Optional[str] = None + + +class NodeDescription(BaseModel): + """Node description""" + + id: str + type: str + title: str + description: Optional[str] = "" + variables: Optional[list[VariableDescription]] = Field(default_factory=list) + outputs: Optional[list[OutputDescription]] = Field(default_factory=list) + + # LLM node specific fields + system_prompt: Optional[str] = None + user_prompt: Optional[str] = None + provider: Optional[str] = "zhipuai" + model: Optional[str] = "glm-4-flash" + + # Code node specific fields + code: Optional[str] = None + + # Template node specific fields + template: Optional[str] = None + + # Reference to workflow description, used for node relationship analysis + workflow_description: Optional["WorkflowDescription"] = Field(default=None, exclude=True)
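+ # The back-reference is excluded from serialization so that dumping a node + # does not recurse through the whole workflow; it is wired up in + # WorkflowDescription.__init__ below.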
+ + class Config: + exclude = {"workflow_description"} + + +class ConnectionDescription(BaseModel): + """Connection description""" + + source: str + target: str + + +class WorkflowDescription(BaseModel): + """Workflow description""" + + name: str + description: Optional[str] = "" + nodes: list[NodeDescription] + connections: list[ConnectionDescription] + + def __init__(self, **data): + super().__init__(**data) + # Add workflow description reference to each node + for node in self.nodes: + node.workflow_description = self diff --git a/api/core/auto/workflow_generator/utils/__init__.py b/api/core/auto/workflow_generator/utils/__init__.py new file mode 100644 index 0000000000..c0f259f456 --- /dev/null +++ b/api/core/auto/workflow_generator/utils/__init__.py @@ -0,0 +1,16 @@ +""" +Utility functions package +""" + +from .llm_client import LLMClient +from .prompts import DEFAULT_MODEL_CONFIG, DEFAULT_SYSTEM_PROMPT, build_workflow_prompt +from .type_mapper import map_string_to_var_type, map_var_type_to_input_type + +__all__ = [ + "DEFAULT_MODEL_CONFIG", + "DEFAULT_SYSTEM_PROMPT", + "LLMClient", + "build_workflow_prompt", + "map_string_to_var_type", + "map_var_type_to_input_type", +] diff --git a/api/core/auto/workflow_generator/utils/config_manager.py b/api/core/auto/workflow_generator/utils/config_manager.py new file mode 100644 index 0000000000..a11969d441 --- /dev/null +++ b/api/core/auto/workflow_generator/utils/config_manager.py @@ -0,0 +1,142 @@ +""" +Configuration Manager +Used to manage configurations and prompts +""" + +import os +import time +from pathlib import Path +from typing import Any, Optional + +import yaml + + +class ConfigManager: + """Configuration manager for managing configurations""" + + def __init__(self, config_dir: str = "config"): + """ + Initialize configuration manager + + Args: + config_dir: Configuration directory path + """ + self.config_dir = Path(config_dir) + self.config: dict[str, Any] = {} + self.last_load_time: float = 0 + self.reload_interval: float = 60 # Check every 60 seconds + self._load_config() + + def _should_reload(self) -> bool: + """Check if configuration needs to be reloaded""" + return time.time() - self.last_load_time > self.reload_interval + + def _load_config(self) -> dict[str, Any]: + """Load configuration files""" + default_config = self._load_yaml(self.config_dir / "default.yaml") + custom_config = self._load_yaml(self.config_dir / "custom.yaml") + self.config = self._deep_merge(default_config, custom_config) + self.last_load_time = time.time() + return self.config + + @staticmethod + def _load_yaml(path: Path) -> dict[str, Any]: + """Load YAML file""" + try: + with open(path, encoding="utf-8") as f: + return yaml.safe_load(f) or {} + except FileNotFoundError: + print(f"Warning: Config file not found at {path}") + return {} + except Exception as e: + print(f"Error loading config file {path}: {e}") + return {} + + @staticmethod + def _deep_merge(dict1: dict, dict2: dict) -> dict: + """Recursively merge two dictionaries""" + result = dict1.copy() + for key, value in dict2.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = ConfigManager._deep_merge(result[key], value) + else: + result[key] = value + return result + + def get(self, *keys: str, default: Any = None) -> Any: + """ + Get configuration value + + Args: + *keys: Configuration key path + default: Default value + + Returns: + Configuration value or default value + """ + if self._should_reload(): + self._load_config()
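+ # Walk the merged config one key at a time; a call such as + # get("workflow_generator", "models", "default") would descend the nested + # dicts and fall back to `default` on the first missing key.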
+ + current = self.config + for key in keys: + if isinstance(current, dict) and key in current: + current = current[key] + else: + return default + return current + + @property + def workflow_generator(self) -> dict[str, Any]: + """Get workflow generator configuration""" + return self.get("workflow_generator", default={}) + + @property + def workflow_nodes(self) -> dict[str, Any]: + """Get workflow nodes configuration""" + return self.get("workflow_nodes", default={}) + + @property + def output(self) -> dict[str, Any]: + """Get output configuration""" + return self.get("output", default={}) + + def get_output_path(self, filename: Optional[str] = None) -> str: + """ + Get output file path + + Args: + filename: Optional filename, uses default filename from config if not specified + + Returns: + Complete output file path + """ + output_config = self.output + output_dir = output_config.get("dir", "output/") + output_filename = filename or output_config.get("filename", "generated_workflow.yml") + return os.path.join(output_dir, output_filename) + + def get_workflow_model(self, model_name: Optional[str] = None) -> dict[str, Any]: + """ + Get workflow generation model configuration + + Args: + model_name: Model name, uses default model if not specified + + Returns: + Model configuration + """ + models = self.workflow_generator.get("models", {}) + + if not model_name: + model_name = models.get("default") + + return models.get("available", {}).get(model_name, {}) + + def get_llm_node_config(self) -> dict[str, Any]: + """ + Get LLM node configuration + + Returns: + LLM node configuration + """ + return self.workflow_nodes.get("llm", {}) diff --git a/api/core/auto/workflow_generator/utils/debug_manager.py b/api/core/auto/workflow_generator/utils/debug_manager.py new file mode 100644 index 0000000000..a0b41ec992 --- /dev/null +++ b/api/core/auto/workflow_generator/utils/debug_manager.py @@ -0,0 +1,151 @@ +""" +Debug Manager +Used to manage debug information saving +""" + +import datetime +import json +import os +import uuid +from typing import Any, Optional, Union + + +class DebugManager: + """Debug manager for managing debug information saving""" + + _instance = None + + def __new__(cls, *args, **kwargs): + """Singleton pattern""" + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config: Optional[dict[str, Any]] = None, debug_enabled: bool = False): + """ + Initialize debug manager + + Args: + config: Debug configuration + debug_enabled: Whether to enable debug mode + """ + # Avoid repeated initialization + if self._initialized: + return + + self._initialized = True + self.config = config or {} + self.debug_enabled = debug_enabled or self.config.get("enabled", False) + self.debug_dir = self.config.get("dir", "debug/") + self.save_options = self.config.get( + "save_options", {"prompt": True, "response": True, "json": True, "workflow": True} + ) + + # Generate run ID + self.case_id = self._generate_case_id() + self.case_dir = os.path.join(self.debug_dir, self.case_id) + + # If debug is enabled, create debug directory + if self.debug_enabled: + os.makedirs(self.case_dir, exist_ok=True) + print(f"Debug mode enabled, debug information will be saved to: {self.case_dir}") + + def _generate_case_id(self) -> str: + """ + Generate run ID + + Returns: + Run ID + """ + # Use format from configuration to generate run ID + format_str = self.config.get("case_id_format", "%Y%m%d_%H%M%S_%f") + timestamp = datetime.datetime.now().strftime(format_str)
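+ # With the default format this yields timestamps like "20250218_153000_123456"; + # a short random suffix (e.g. "ab12cd34") is appended below to keep IDs unique.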
+ + # Add random string + random_str = str(uuid.uuid4())[:8] + + return f"{timestamp}_{random_str}" + + def save_text(self, content: str, filename: str, subdir: Optional[str] = None) -> Optional[str]: + """ + Save text content to file + + Args: + content: Text content + filename: File name + subdir: Subdirectory name + + Returns: + Saved file path, returns None if debug is not enabled + """ + if not self.debug_enabled: + return None + + try: + # Determine save path + save_dir = self.case_dir + if subdir: + save_dir = os.path.join(save_dir, subdir) + os.makedirs(save_dir, exist_ok=True) + + file_path = os.path.join(save_dir, filename) + + # Save content + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + + print(f"Debug information saved to: {file_path}") + return file_path + except Exception as e: + print(f"Failed to save debug information: {e}") + return None + + def save_json(self, data: Union[dict, list], filename: str, subdir: Optional[str] = None) -> Optional[str]: + """ + Save JSON data to file + + Args: + data: JSON data + filename: File name + subdir: Subdirectory name + + Returns: + Saved file path, returns None if debug is not enabled + """ + if not self.debug_enabled: + return None + + try: + # Determine save path + save_dir = self.case_dir + if subdir: + save_dir = os.path.join(save_dir, subdir) + os.makedirs(save_dir, exist_ok=True) + + file_path = os.path.join(save_dir, filename) + + # Save content + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + print(f"Debug information saved to: {file_path}") + return file_path + except Exception as e: + print(f"Failed to save debug information: {e}") + return None + + def should_save(self, option: str) -> bool: + """ + Check if specified type of debug information should be saved + + Args: + option: Debug information type + + Returns: + Whether it should be saved + """ + if not self.debug_enabled: + return False + + return self.save_options.get(option, False) diff --git a/api/core/auto/workflow_generator/utils/llm_client.py b/api/core/auto/workflow_generator/utils/llm_client.py new file mode 100644 index 0000000000..6a72e69893 --- /dev/null +++ b/api/core/auto/workflow_generator/utils/llm_client.py @@ -0,0 +1,438 @@ +""" +LLM Client +Used to call LLM API +""" + +import json +import re +from typing import Any + +from core.auto.workflow_generator.utils.debug_manager import DebugManager +from core.auto.workflow_generator.utils.prompts import DEFAULT_SYSTEM_PROMPT +from core.model_manager import ModelInstance +from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + + +class LLMClient: + """LLM Client""" + + def __init__(self, model_instance: ModelInstance, debug_manager: DebugManager): + """ + Initialize LLM client + + Args: + model_instance: Model instance used to invoke the LLM + debug_manager: Debug manager used to persist prompts and responses + """ + + self.debug_manager = debug_manager or DebugManager() + self.model_instance = model_instance + + def generate(self, prompt: str) -> str: + """ + Generate text + + Args: + prompt: Prompt text + + Returns: + Generated text + """ + + # Save prompt + if self.debug_manager.should_save("prompt"): + self.debug_manager.save_text(prompt, "prompt.txt", "llm") + + try: + response = self.model_instance.invoke_llm( + prompt_messages=[ + SystemPromptMessage(content=DEFAULT_SYSTEM_PROMPT), +
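# The system message pins the workflow-designer persona from prompts.py; the + # user message carries the requirement-specific prompt (typically built by + # build_workflow_prompt). +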
UserPromptMessage(content=prompt), + ], + model_parameters={"temperature": 0.7, "max_tokens": 4900}, + ) + content = "" + for chunk in response: + content += chunk.delta.message.content + print(f"Generation complete, text length: {len(content)} characters") + + # Save response + if self.debug_manager.should_save("response"): + self.debug_manager.save_text(content, "response.txt", "llm") + + return content + except Exception as e: + print(f"Error generating text: {e}") + raise e + + def extract_json(self, text: str) -> dict[str, Any]: + """ + Extract JSON from text + + Args: + text: Text containing JSON + + Returns: + Extracted JSON object + """ + print("Starting JSON extraction from text...") + + # Save original text + if self.debug_manager.should_save("json"): + self.debug_manager.save_text(text, "original_text.txt", "json") + + # Use regex to extract JSON part + json_match = re.search(r"```json\n(.*?)\n```", text, re.DOTALL) + if json_match: + json_str = json_match.group(1) + print("Successfully extracted JSON from code block") + else: + # Try to match code block without language identifier + json_match = re.search(r"```\n(.*?)\n```", text, re.DOTALL) + if json_match: + json_str = json_match.group(1) + print("Successfully extracted JSON from code block without language identifier") + else: + # Try to match JSON surrounded by curly braces + json_match = re.search(r"(\{.*\})", text, re.DOTALL) + if json_match: + json_str = json_match.group(1) + print("Successfully extracted JSON from curly braces") + else: + # Try to parse entire text + json_str = text + print("No JSON code block found, attempting to parse entire text") + + # Save extracted JSON string + if self.debug_manager.should_save("json"): + self.debug_manager.save_text(json_str, "extracted_json.txt", "json") + + # Try multiple methods to parse JSON + try: + # Try direct parsing + result = json.loads(json_str) + except json.JSONDecodeError as e: + print(f"Direct JSON parsing failed: {e}, attempting basic cleaning") + try: + # Try basic cleaning + cleaned_text = self._clean_text(json_str) + if self.debug_manager.should_save("json"): + self.debug_manager.save_text(cleaned_text, "cleaned_json_1.txt", "json") + result = json.loads(cleaned_text) + except json.JSONDecodeError as e: + print(f"JSON parsing after basic cleaning failed: {e}, attempting to fix common errors") + try: + # Try fixing common errors + fixed_text = self._fix_json_errors(json_str) + if self.debug_manager.should_save("json"): + self.debug_manager.save_text(fixed_text, "cleaned_json_2.txt", "json") + result = json.loads(fixed_text) + except json.JSONDecodeError as e: + print(f"JSON parsing after fixing common errors failed: {e}, attempting aggressive cleaning") + try: + # Try aggressive cleaning + aggressive_cleaned = self._aggressive_clean(json_str) + if self.debug_manager.should_save("json"): + self.debug_manager.save_text(aggressive_cleaned, "cleaned_json_3.txt", "json") + result = json.loads(aggressive_cleaned) + except json.JSONDecodeError as e: + print(f"JSON parsing after aggressive cleaning failed: {e}, attempting manual JSON extraction") + # Try manual JSON structure extraction + result = self._manual_json_extraction(json_str) + if self.debug_manager.should_save("json"): + self.debug_manager.save_json(result, "manual_json.json", "json") + + # Check for nested workflow structure + if "workflow" in result and isinstance(result["workflow"], dict): + print("Detected nested workflow structure, extracting top-level data") + # Extract workflow name and description + 
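# e.g. {"name": ..., "description": ..., "workflow": {"nodes": [...], "connections": [...]}} + # is flattened here into the expected top-level {name, description, nodes, connections} shape. +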
name = result.get("name", "Text Analysis Workflow") + description = result.get("description", "") + + # Extract nodes and connections + nodes = result["workflow"].get("nodes", []) + connections = [] + + # If there are connections, extract them + if "connections" in result["workflow"]: + connections = result["workflow"]["connections"] + + # Build standard format workflow description + result = {"name": name, "description": description, "nodes": nodes, "connections": connections} + + # Save final parsed JSON + if self.debug_manager.should_save("json"): + self.debug_manager.save_json(result, "final_json.json", "json") + + print( + f"JSON parsing successful, contains {len(result.get('nodes', []))} nodes and {len(result.get('connections', []))} connections" # noqa: E501 + ) + return result + + def _clean_text(self, text: str) -> str: + """ + Clean text by removing non-JSON characters + + Args: + text: Text to clean + + Returns: + Cleaned text + """ + print("Starting text cleaning...") + # Remove characters that might cause JSON parsing to fail + lines = text.split("\n") + cleaned_lines = [] + + in_json = False + for line in lines: + if line.strip().startswith("{") or line.strip().startswith("["): + in_json = True + + if in_json: + cleaned_lines.append(line) + + if line.strip().endswith("}") or line.strip().endswith("]"): + in_json = False + + cleaned_text = "\n".join(cleaned_lines) + print(f"Text cleaning complete, length before: {len(text)}, length after: {len(cleaned_text)}") + return cleaned_text + + def _fix_json_errors(self, text: str) -> str: + """ + Fix common JSON errors + + Args: + text: Text to fix + + Returns: + Fixed text + """ + print("Attempting to fix JSON errors...") + + # Replace single quotes with double quotes + text = re.sub(r"'([^']*)'", r'"\1"', text) + + # Fix missing commas + text = re.sub(r"}\s*{", "},{", text) + text = re.sub(r"]\s*{", "],{", text) + text = re.sub(r"}\s*\[", r"},\[", text) + text = re.sub(r"]\s*\[", r"],\[", text) + + # Fix extra commas + text = re.sub(r",\s*}", "}", text) + text = re.sub(r",\s*]", "]", text) + + # Ensure property names have double quotes + text = re.sub(r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', text) + + return text + + def _aggressive_clean(self, text: str) -> str: + """ + More aggressive text cleaning + + Args: + text: Text to clean + + Returns: + Cleaned text + """ + print("Using aggressive cleaning method...") + + # Try to find outermost curly braces + start_idx = text.find("{") + end_idx = text.rfind("}") + + if start_idx != -1 and end_idx != -1 and start_idx < end_idx: + text = text[start_idx : end_idx + 1] + + # Remove comments + text = re.sub(r"//.*?\n", "\n", text) + text = re.sub(r"/\*.*?\*/", "", text, flags=re.DOTALL) + + # Fix JSON format + text = self._fix_json_errors(text) + + # Remove escape characters + text = text.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"') + + # Fix potential Unicode escape issues + text = re.sub(r"\\u([0-9a-fA-F]{4})", lambda m: chr(int(m.group(1), 16)), text) + + return text + + def _manual_json_extraction(self, text: str) -> dict[str, Any]: + """ + Manual JSON structure extraction + + Args: + text: Text to extract from + + Returns: + Extracted JSON object + """ + print("Attempting manual JSON structure extraction...") + + # Extract workflow name + name_match = re.search(r'"name"\s*:\s*"([^"]*)"', text) + name = name_match.group(1) if name_match else "Simple Workflow" + + # Extract workflow description + desc_match = re.search(r'"description"\s*:\s*"([^"]*)"', text) + description = 
desc_match.group(1) if desc_match else "Automatically generated workflow" + + # Extract nodes + nodes = [] + node_matches = re.finditer(r'\{\s*"id"\s*:\s*"([^"]*)"\s*,\s*"type"\s*:\s*"([^"]*)"', text) + + for match in node_matches: + node_id = match.group(1) + node_type = match.group(2) + + # Extract node title + title_match = re.search(rf'"id"\s*:\s*"{node_id}".*?"title"\s*:\s*"([^"]*)"', text, re.DOTALL) + title = title_match.group(1) if title_match else f"{node_type.capitalize()} Node" + + # Extract node description + desc_match = re.search(rf'"id"\s*:\s*"{node_id}".*?"description"\s*:\s*"([^"]*)"', text, re.DOTALL) + desc = desc_match.group(1) if desc_match else "" + + # Create basic node based on node type + if node_type == "start": + # Extract variables + variables = [] + var_section_match = re.search(rf'"id"\s*:\s*"{node_id}".*?"variables"\s*:\s*\[(.*?)\]', text, re.DOTALL) + + if var_section_match: + var_section = var_section_match.group(1) + var_matches = re.finditer(r'\{\s*"name"\s*:\s*"([^"]*)"\s*,\s*"type"\s*:\s*"([^"]*)"', var_section) + + for var_match in var_matches: + var_name = var_match.group(1) + var_type = var_match.group(2) + + # Extract variable description + var_desc_match = re.search( + rf'"name"\s*:\s*"{var_name}".*?"description"\s*:\s*"([^"]*)"', var_section, re.DOTALL + ) + var_desc = var_desc_match.group(1) if var_desc_match else "" + + # Extract required status + var_required_match = re.search( + rf'"name"\s*:\s*"{var_name}".*?"required"\s*:\s*(true|false)', var_section, re.DOTALL + ) + var_required = var_required_match.group(1).lower() == "true" if var_required_match else True + + variables.append( + {"name": var_name, "type": var_type, "description": var_desc, "required": var_required} + ) + + # If no variables found but this is a greeting workflow, add default user_name variable + if not variables and ("greeting" in name.lower()): + variables.append( + {"name": "user_name", "type": "string", "description": "User's name", "required": True} + ) + + nodes.append({"id": node_id, "type": "start", "title": title, "desc": desc, "variables": variables}) + elif node_type == "llm": + # Extract system prompt + system_prompt_match = re.search( + rf'"id"\s*:\s*"{node_id}".*?"system_prompt"\s*:\s*"([^"]*)"', text, re.DOTALL + ) + system_prompt = system_prompt_match.group(1) if system_prompt_match else "You are a helpful assistant" + + # Extract user prompt + user_prompt_match = re.search( + rf'"id"\s*:\s*"{node_id}".*?"user_prompt"\s*:\s*"([^"]*)"', text, re.DOTALL + ) + user_prompt = user_prompt_match.group(1) if user_prompt_match else "Please answer the user's question" + + nodes.append( + { + "id": node_id, + "type": "llm", + "title": title, + "desc": desc, + "provider": "zhipuai", + "model": "glm-4-flash", + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "variables": [], + } + ) + elif node_type in ("template", "template-transform"): + # Extract template content + template_match = re.search(rf'"id"\s*:\s*"{node_id}".*?"template"\s*:\s*"([^"]*)"', text, re.DOTALL) + template = template_match.group(1) if template_match else "" + + # Fix triple curly brace issue in template, replace {{{ with {{ and }}} with }} + template = template.replace("{{{", "{{").replace("}}}", "}}") + + nodes.append( + { + "id": node_id, + "type": "template-transform", + "title": title, + "desc": desc, + "template": template, + "variables": [], + } + ) + elif node_type == "end": + # Extract outputs + outputs = [] + output_section_match = re.search( + 
rf'"id"\s*:\s*"{node_id}".*?"outputs"\s*:\s*\[(.*?)\]', text, re.DOTALL + ) + + if output_section_match: + output_section = output_section_match.group(1) + output_matches = re.finditer( + r'\{\s*"name"\s*:\s*"([^"]*)"\s*,\s*"type"\s*:\s*"([^"]*)"', output_section + ) + + for output_match in output_matches: + output_name = output_match.group(1) + output_type = output_match.group(2) + + # Extract source node + source_node_match = re.search( + rf'"name"\s*:\s*"{output_name}".*?"source_node"\s*:\s*"([^"]*)"', output_section, re.DOTALL + ) + source_node = source_node_match.group(1) if source_node_match else "" + + # Extract source variable + source_var_match = re.search( + rf'"name"\s*:\s*"{output_name}".*?"source_variable"\s*:\s*"([^"]*)"', + output_section, + re.DOTALL, + ) + source_var = source_var_match.group(1) if source_var_match else "" + + outputs.append( + { + "name": output_name, + "type": output_type, + "source_node": source_node, + "source_variable": source_var, + } + ) + + nodes.append({"id": node_id, "type": "end", "title": title, "desc": desc, "outputs": outputs}) + else: + # Other node types + nodes.append({"id": node_id, "type": node_type, "title": title, "desc": desc}) + + # Extract connections + connections = [] + conn_matches = re.finditer(r'\{\s*"source"\s*:\s*"([^"]*)"\s*,\s*"target"\s*:\s*"([^"]*)"', text) + + for match in conn_matches: + connections.append({"source": match.group(1), "target": match.group(2)}) + + return {"name": name, "description": description, "nodes": nodes, "connections": connections} diff --git a/api/core/auto/workflow_generator/utils/prompts.py b/api/core/auto/workflow_generator/utils/prompts.py new file mode 100644 index 0000000000..59151a9dcc --- /dev/null +++ b/api/core/auto/workflow_generator/utils/prompts.py @@ -0,0 +1,171 @@ +""" +Prompt Template Collection +Contains all prompt templates used for generating workflows +""" + +# Default model configuration +DEFAULT_MODEL_CONFIG = { + "provider": "zhipuai", + "model": "glm-4-flash", + "mode": "chat", + "completion_params": {"temperature": 0.7}, +} + + +# Default system prompt +DEFAULT_SYSTEM_PROMPT = "You are a workflow design expert who can design Dify workflows based on user requirements." + + +# Code node template +CODE_NODE_TEMPLATE = """def main(input_var): + # Process input variable + result = input_var + + # Return a dictionary; keys must exactly match variable names defined in outputs + return {"output_var_name": result}""" + + +def build_workflow_prompt(user_requirement: str) -> str: + """ + Build workflow generation prompt + + Args: + user_requirement: User requirement description + + Returns: + Prompt string + """ + # String concatenation to avoid brace escaping + prompt_part1 = ( + """ + Please design a Dify workflow based on the following user requirement: + + User requirement: """ + + user_requirement + + """ + + The description's language should align consistently with the user's requirements. + + Generate a concise workflow description containing the following node types: + - Start: Start node, defines workflow input parameters + - LLM: Large Language Model node for text generation + - Code: Code node to execute Python code + - Template: Template node for formatting outputs + - End: End node, defines workflow output + + 【Important Guidelines】: + 1. When referencing variables in LLM nodes, use the format {{#nodeID.variable_name#}}, e.g., {{#1740019130520.user_question#}}, where 1740019130520 is the source node ID. 
Otherwise, in most cases, the user prompt should define a template to guide the LLM’s response. + 2. Code nodes must define a `main` function that directly receives variables from upstream nodes as parameters; do not use template syntax inside the function. + 3. Dictionary keys returned by Code nodes must exactly match the variable names defined in outputs. + 4. Variables in Template nodes must strictly use double curly braces format "{{ variable_name }}"; note exactly two curly braces, neither one nor three. For example, "User question is: {{ user_question }}, answer: {{ answer }}". Triple curly braces such as "{{{ variable_name }}}" are strictly forbidden. + 5. IMPORTANT: In Code nodes, the function parameter names MUST EXACTLY MATCH the variable names defined in that Code node. For example, if a Code node defines a variable with name "input_text" that receives data from an upstream node, the function parameter must also be named "input_text" (e.g., def main(input_text): ...). + 6. CRITICAL: LLM nodes ALWAYS output their result in a variable named "text". When a Code node receives data from an LLM node, the source_variable MUST be "text". For example, if a Code node has a variable named "llm_output" that receives data from an LLM node, the source_variable should be "text", not "input_text" or any other name. + + Return the workflow description in JSON format as follows: + ```json + { + "name": "Workflow Name", + "description": "Workflow description", + "nodes": [ + { + "id": "node1", + "type": "start", + "title": "Start Node", + "description": "Description of the start node", + "variables": [ + { + "name": "variable_name", + "type": "string|number", + "description": "Variable description", + "required": true|false + } + ] + }, + { + "id": "node2", + "type": "llm", + "title": "LLM Node", + "description": "Description of LLM node", + "system_prompt": "System prompt", + "user_prompt": "User prompt, variables referenced using {{#nodeID.variable_name#}}, e.g., {{#node1.variable_name#}}", + "provider": "zhipuai", + "model": "glm-4-flash", + "variables": [ + { + "name": "variable_name", + "type": "string|number", + "source_node": "node1", + "source_variable": "variable_name" + } + ] + }, + { + "id": "node3", + "type": "code", + "title": "Code Node", + "description": "Description of the code node", + "code": "def main(input_var):\n import re\n match = re.search(r'Result[::](.*?)(?=[.]|$)', input_var)\n result = match.group(1).strip() if match else 'Not found'\n return {'output': result}", + "variables": [ + { + "name": "input_var", + "type": "string|number", + "source_node": "node2", + "source_variable": "text" + } + ], + "outputs": [ + { + "name": "output_var_name", + "type": "string|number|object" + } + ] + }, + { + "id": "node4", + "type": "template", + "title": "Template Node", + "description": "Description of the template node", + "template": "Template content using double curly braces, e.g.: The result is: {{ result }}", + "variables": [ + { + "name": "variable_name", + "type": "string|number", + "source_node": "node3", + "source_variable": "output_var_name" + } + ] + }, + { + "id": "node5", + "type": "end", + "title": "End Node", + "description": "Description of the end node", + "outputs": [ + { + "name": "output_variable_name", + "type": "string|number", + "source_node": "node4", + "source_variable": "output" + } + ] + } + ], + "connections": [ + {"source": "node1", "target": "node2"}, + {"source": "node2", "target": "node3"}, + {"source": "node3", "target": "node4"}, + {"source": 
"node4", "target": "node5"} + ] + } + ``` + + Ensure the workflow logic is coherent, node connections are correct, and variable passing is logical. + Generate unique numeric IDs for each node, e.g., 1740019130520. + Generate appropriate unique names for each variable across the workflow. + Ensure all LLM nodes use provider "zhipuai" and model "glm-4-flash". + + Note: LLM nodes usually return a long text; Code nodes typically require regex to extract relevant information. + """ # noqa: E501 + ) + + return prompt_part1 diff --git a/api/core/auto/workflow_generator/utils/type_mapper.py b/api/core/auto/workflow_generator/utils/type_mapper.py new file mode 100644 index 0000000000..44236aee28 --- /dev/null +++ b/api/core/auto/workflow_generator/utils/type_mapper.py @@ -0,0 +1,50 @@ +""" +Type Mapping Utility +Used to map string types to Dify types +""" + +from core.auto.node_types.common import InputVarType, VarType + + +def map_var_type_to_input_type(var_type: str) -> InputVarType: + """ + Map variable type to input variable type + + Args: + var_type: Variable type string + + Returns: + Input variable type + """ + type_map = { + "string": InputVarType.text_input, + "number": InputVarType.number, + "boolean": InputVarType.select, + "object": InputVarType.json, + "array": InputVarType.json, + "file": InputVarType.file, + } + + return type_map.get(var_type.lower(), InputVarType.text_input) + + +def map_string_to_var_type(type_str: str) -> VarType: + """ + Map string to variable type + + Args: + type_str: Type string + + Returns: + Variable type + """ + type_map = { + "string": VarType.string, + "number": VarType.number, + "boolean": VarType.boolean, + "object": VarType.object, + "array": VarType.array, + "file": VarType.file, + } + + return type_map.get(type_str.lower(), VarType.string) diff --git a/api/core/auto/workflow_generator/workflow.py b/api/core/auto/workflow_generator/workflow.py new file mode 100644 index 0000000000..8a5e07b6b6 --- /dev/null +++ b/api/core/auto/workflow_generator/workflow.py @@ -0,0 +1,134 @@ +import json +from typing import Any + +import yaml + +from core.auto.node_types.common import CompleteEdge, CompleteNode + + +class Workflow: + """ + Workflow class + """ + + def __init__(self, name: str, nodes: list[CompleteNode], edges: list[CompleteEdge]): + """ + Initialize workflow + + Args: + name: Workflow name + nodes: List of nodes + edges: List of edges + """ + self.name = name + self.nodes = nodes + self.edges = edges + + def to_dict(self) -> dict[str, Any]: + """ + Convert workflow to dictionary + + Returns: + Workflow dictionary + """ + # Apply basic information (fixed template) + app_info = { + "description": "", + "icon": "🤖", + "icon_background": "#FFEAD5", + "mode": "workflow", + "name": self.name, + "use_icon_as_answer_icon": False, + } + + # Feature configuration (fixed template) + features = { + "file_upload": { + "allowed_file_extensions": [".JPG", ".JPEG", ".PNG", ".GIF", ".WEBP", ".SVG"], + "allowed_file_types": ["image"], + "allowed_file_upload_methods": ["local_file", "remote_url"], + "enabled": False, + "fileUploadConfig": { + "audio_file_size_limit": 50, + "batch_count_limit": 5, + "file_size_limit": 15, + "image_file_size_limit": 10, + "video_file_size_limit": 100, + }, + "image": {"enabled": False, "number_limits": 3, "transfer_methods": ["local_file", "remote_url"]}, + "number_limits": 3, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, 
+ "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + + # View configuration (fixed template) + viewport = {"x": 92.96659905656679, "y": 79.13437154762897, "zoom": 0.9002006986311041} + + # Nodes and edges + nodes_data = [] + for node in self.nodes: + node_data = node.to_json() + nodes_data.append(node_data) + + edges_data = [] + for edge in self.edges: + edge_data = edge.to_json() + edges_data.append(edge_data) + + # Build a complete workflow dictionary + workflow_dict = { + "app": app_info, + "kind": "app", + "version": "0.1.2", + "workflow": { + "conversation_variables": [], + "environment_variables": [], + "features": features, + "graph": {"edges": edges_data, "nodes": nodes_data, "viewport": viewport}, + }, + } + + return workflow_dict + + def save_to_yaml(self, file_path: str): + """ + Save workflow to YAML file + + Args: + file_path: File path + """ + workflow_dict = self.to_dict() + + with open(file_path, "w", encoding="utf-8") as f: + yaml.dump(workflow_dict, f, allow_unicode=True, sort_keys=False) + + print(f"Workflow saved to: {file_path}") + + def save_to_json(self, file_path: str): + """ + Save workflow to JSON file + + Args: + file_path: File path + """ + workflow_dict = self.to_dict() + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(workflow_dict, f, indent=2, ensure_ascii=False) + + print(f"Workflow saved to: {file_path}") + + def to_yaml(self) -> str: + """ + Convert workflow to YAML string + + Returns: + YAML string + """ + return yaml.dump(self.to_dict(), allow_unicode=True, sort_keys=False) diff --git a/api/core/auto/workflow_generator/workflow_generator.py b/api/core/auto/workflow_generator/workflow_generator.py new file mode 100644 index 0000000000..accf4ddacc --- /dev/null +++ b/api/core/auto/workflow_generator/workflow_generator.py @@ -0,0 +1,159 @@ +""" +Workflow Generator +Used to generate Dify workflows based on user requirements +""" + +from pydantic import ValidationError + +from core.auto.workflow_generator.generators.edge_generator import EdgeGenerator +from core.auto.workflow_generator.generators.layout_engine import LayoutEngine +from core.auto.workflow_generator.generators.node_generator import NodeGenerator +from core.auto.workflow_generator.models.workflow_description import WorkflowDescription +from core.auto.workflow_generator.utils.config_manager import ConfigManager +from core.auto.workflow_generator.utils.debug_manager import DebugManager +from core.auto.workflow_generator.utils.llm_client import LLMClient +from core.auto.workflow_generator.utils.prompts import build_workflow_prompt +from core.auto.workflow_generator.workflow import Workflow +from core.model_manager import ModelInstance + + +class WorkflowGenerator: + """Workflow generator for creating Dify workflows based on user requirements""" + + def __init__(self, model_instance: ModelInstance, config_dir: str = "config", debug_enabled: bool = False): + """ + Initialize workflow generator + + Args: + api_key: LLM API key + config_dir: Configuration directory path + model_name: Specified model name, uses default model if not specified + debug_enabled: Whether to enable debug mode + """ + # Load configuration + self.config = ConfigManager(config_dir) + + # Initialize debug manager + self.debug_manager = DebugManager(config=self.config.get("debug", default={}), debug_enabled=debug_enabled) + + # Get model configuration + + # Initialize LLM client + self.llm_client = 
diff --git a/api/core/auto/workflow_generator/workflow_generator.py b/api/core/auto/workflow_generator/workflow_generator.py
new file mode 100644
index 0000000000..accf4ddacc
--- /dev/null
+++ b/api/core/auto/workflow_generator/workflow_generator.py
@@ -0,0 +1,159 @@
+"""
+Workflow Generator
+Used to generate Dify workflows based on user requirements
+"""
+
+from pydantic import ValidationError
+
+from core.auto.workflow_generator.generators.edge_generator import EdgeGenerator
+from core.auto.workflow_generator.generators.layout_engine import LayoutEngine
+from core.auto.workflow_generator.generators.node_generator import NodeGenerator
+from core.auto.workflow_generator.models.workflow_description import WorkflowDescription
+from core.auto.workflow_generator.utils.config_manager import ConfigManager
+from core.auto.workflow_generator.utils.debug_manager import DebugManager
+from core.auto.workflow_generator.utils.llm_client import LLMClient
+from core.auto.workflow_generator.utils.prompts import build_workflow_prompt
+from core.auto.workflow_generator.workflow import Workflow
+from core.model_manager import ModelInstance
+
+
+class WorkflowGenerator:
+    """Workflow generator for creating Dify workflows based on user requirements"""
+
+    def __init__(self, model_instance: ModelInstance, config_dir: str = "config", debug_enabled: bool = False):
+        """
+        Initialize workflow generator
+
+        Args:
+            model_instance: Model instance used to call the LLM
+            config_dir: Configuration directory path
+            debug_enabled: Whether to enable debug mode
+        """
+        # Load configuration
+        self.config = ConfigManager(config_dir)
+
+        # Initialize debug manager
+        self.debug_manager = DebugManager(config=self.config.get("debug", default={}), debug_enabled=debug_enabled)
+
+        # Initialize LLM client
+        self.llm_client = LLMClient(model_instance=model_instance, debug_manager=self.debug_manager)
+
+    def generate_workflow(self, user_requirement: str) -> str:
+        """
+        Generate workflow based on user requirements
+
+        Args:
+            user_requirement: User requirement description
+
+        Returns:
+            Generated workflow as a YAML (Dify DSL) string
+        """
+        print("\n===== Starting Workflow Generation =====")
+        print(f"User requirement: {user_requirement}")
+
+        # Save user requirement
+        if self.debug_manager.should_save("workflow"):
+            self.debug_manager.save_text(user_requirement, "user_requirement.txt", "workflow")
+
+        # Step 1: Generate simple workflow description
+        print("\n----- Step 1: Generating Simple Workflow Description -----")
+        workflow_description = self._generate_workflow_description(user_requirement)
+        print(f"Workflow name: {workflow_description.name}")
+        print(f"Workflow description: {workflow_description.description}")
+        print(f"Number of nodes: {len(workflow_description.nodes)}")
+        print(f"Number of connections: {len(workflow_description.connections)}")
+
+        # Save workflow description
+        if self.debug_manager.should_save("workflow"):
+            self.debug_manager.save_json(workflow_description.dict(), "workflow_description.json", "workflow")
+
+        # Step 2: Parse description and generate nodes
+        print("\n----- Step 2: Parsing Description, Generating Nodes -----")
+        nodes = NodeGenerator.create_nodes(workflow_description.nodes)
+        print(f"Generated nodes: {len(nodes)}")
+        for i, node in enumerate(nodes):
+            print(f"Node {i + 1}: ID={node.id}, Type={node.data.type.value}, Title={node.data.title}")
+
+        # Save node information
+        if self.debug_manager.should_save("workflow"):
+            nodes_data = [node.dict() for node in nodes]
+            self.debug_manager.save_json(nodes_data, "nodes.json", "workflow")
+
+        # Step 3: Generate edges
+        print("\n----- Step 3: Generating Edges -----")
+        edges = EdgeGenerator.create_edges(nodes, workflow_description.connections)
+        print(f"Generated edges: {len(edges)}")
+        for i, edge in enumerate(edges):
+            print(f"Edge {i + 1}: ID={edge.id}, Source={edge.source}, Target={edge.target}")
+
+        # Save edge information
+        if self.debug_manager.should_save("workflow"):
+            edges_data = [edge.dict() for edge in edges]
+            self.debug_manager.save_json(edges_data, "edges.json", "workflow")
+
+        # Step 4: Apply layout
+        print("\n----- Step 4: Applying Layout -----")
+        LayoutEngine.apply_topological_layout(nodes, edges)
+        print("Applied topological sort layout")
+
+        # Save nodes with layout
+        if self.debug_manager.should_save("workflow"):
+            nodes_with_layout = [node.dict() for node in nodes]
+            self.debug_manager.save_json(nodes_with_layout, "nodes_with_layout.json", "workflow")
+
+        # Step 5: Generate YAML
+        print("\n----- Step 5: Generating YAML -----")
+        workflow = Workflow(name=workflow_description.name, nodes=nodes, edges=edges)
+
+        print("\n===== Workflow Generation Complete =====")
+        return workflow.to_yaml()
+
+    def _generate_workflow_description(self, user_requirement: str) -> WorkflowDescription:
+        """
+        Generate simple workflow description using LLM
+
+        Args:
+            user_requirement: User requirement description
+
+        Returns:
+            Simple workflow description
+        """
+        # Build prompt
+        print("Building prompt...")
+        prompt = build_workflow_prompt(user_requirement)
+
+        # Call LLM
+        print("Calling LLM to generate workflow description...")
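+        # The completion should contain a single JSON object (often wrapped in
+        # a ```json fence); extract_json() below is expected to parse it and
+        # fall back to the regex-based extraction shown earlier when strict
+        # JSON parsing fails.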
response_text = self.llm_client.generate(prompt) + + # Parse LLM response + print("Parsing LLM response...") + workflow_description_dict = self.llm_client.extract_json(response_text) + + try: + # Parse into WorkflowDescription object + print("Converting JSON to WorkflowDescription object...") + workflow_description = WorkflowDescription.parse_obj(workflow_description_dict) + return workflow_description + except ValidationError as e: + # If parsing fails, print error and raise exception + error_msg = f"Failed to parse workflow description: {e}" + print(error_msg) + + # Save error information + if self.debug_manager.should_save("workflow"): + self.debug_manager.save_text(str(e), "validation_error.txt", "workflow") + self.debug_manager.save_json(workflow_description_dict, "invalid_workflow_description.json", "workflow") + + raise ValueError(error_msg) diff --git a/web/app/(commonLayout)/apps/NewAppCard.tsx b/web/app/(commonLayout)/apps/NewAppCard.tsx index a90af4ea85..7093ff4aad 100644 --- a/web/app/(commonLayout)/apps/NewAppCard.tsx +++ b/web/app/(commonLayout)/apps/NewAppCard.tsx @@ -12,6 +12,8 @@ import CreateFromDSLModal, { CreateFromDSLModalTab } from '@/app/components/app/ import { useProviderContext } from '@/context/provider-context' import { FileArrow01, FilePlus01, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files' import cn from '@/utils/classnames' +import AutoGenerateModal from '@/app/components/app/auto-generate-modal' +import { Agent } from '@/app/components/base/icons/src/vender/workflow' export type CreateAppCardProps = { className?: string @@ -28,7 +30,7 @@ const CreateAppCard = forwardRef(({ classNam const [showNewAppTemplateDialog, setShowNewAppTemplateDialog] = useState(false) const [showNewAppModal, setShowNewAppModal] = useState(false) const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(!!dslUrl) - + const [showAutoGenerateModal, setShowAutoGenerateModal] = useState(false) const activeTab = useMemo(() => { if (dslUrl) return CreateFromDSLModalTab.FROM_URL @@ -39,7 +41,7 @@ const CreateAppCard = forwardRef(({ classNam return (
{t('app.createApp')}
@@ -57,6 +59,12 @@ const CreateAppCard = forwardRef(({ classNam {t('app.importDSL')} +
(({ classNam
          onSuccess()
        }}
      />
+      <AutoGenerateModal
+        isShow={showAutoGenerateModal}
+        onClose={() => setShowAutoGenerateModal(false)}
+        onSuccess={() => {
+          onPlanInfoChanged()
+          if (onSuccess)
+            onSuccess()
+        }}
+      />
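+      {/* Mirrors CreateFromDSLModal above: the modal performs the DSL import
+          itself; the caller only refreshes plan info and the app list. */}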
    </div>
  )
})
diff --git a/web/app/components/app/auto-generate-modal/index.tsx b/web/app/components/app/auto-generate-modal/index.tsx
new file mode 100644
index 0000000000..c1de14745c
--- /dev/null
+++ b/web/app/components/app/auto-generate-modal/index.tsx
@@ -0,0 +1,203 @@
+import type { FC } from 'react'
+import React from 'react'
+import cn from 'classnames'
+import useBoolean from 'ahooks/lib/useBoolean'
+import { useTranslation } from 'react-i18next'
+import { generateWorkflow } from '@/service/debug'
+import { type Model, ModelModeType } from '@/types/app'
+import Modal from '@/app/components/base/modal'
+import Button from '@/app/components/base/button'
+import { useContext } from 'use-context-selector'
+
+import Loading from '@/app/components/base/loading'
+import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
+import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon'
+import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name'
+import { importDSL } from '@/service/apps'
+import { DSLImportMode, DSLImportStatus } from '@/models/app'
+import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
+import { getRedirection } from '@/utils/app-redirection'
+import { useAppContext } from '@/context/app-context'
+import { useRouter } from 'next/navigation'
+import { ToastContext } from '../../base/toast'
+import Generator from '../../base/icons/src/vender/other/Generator'
+
+export type IGetCodeGeneratorResProps = {
+  isShow: boolean
+  onClose: () => void
+  onSuccess?: () => void
+}
+
+export const AutoGenerateModal: FC<IGetCodeGeneratorResProps> = (
+  {
+    isShow,
+    onClose,
+    onSuccess,
+  },
+) => {
+  const { notify } = useContext(ToastContext)
+
+  const {
+    currentProvider,
+    currentModel,
+  } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
+  const { t } = useTranslation()
+  const { push } = useRouter()
+
+  const [instruction, setInstruction] = React.useState('')
+  const [isLoading, { setTrue: setLoadingTrue, setFalse: setLoadingFalse }] = useBoolean(false)
+  const { isCurrentWorkspaceEditor } = useAppContext()
+  const [res, setRes] = React.useState<string | null>(null)
+  const isValid = () => {
+    if (instruction.trim() === '') {
+      notify({
+        type: 'error',
+        message: t('common.errorMsg.fieldRequired', {
+          field: t('appDebug.code.instruction'),
+        }),
+      })
+      return false
+    }
+    return true
+  }
+  const model: Model = {
+    provider: currentProvider?.provider || '',
+    name: currentModel?.model || '',
+    mode: ModelModeType.chat,
+    // This is a fixed parameter
+    completion_params: {
+      temperature: 0.7,
+      max_tokens: 0,
+      top_p: 0,
+      echo: false,
+      stop: [],
+      presence_penalty: 0,
+      frequency_penalty: 0,
+    },
+  }
+  const isInLLMNode = true
+  const onGenerate = async () => {
+    if (!isValid())
+      return
+    if (isLoading)
+      return
+    setLoadingTrue()
+    try {
+      const res = await generateWorkflow({
+        instruction,
+        model_config: model,
+      })
+      setRes(res)
+    }
+    finally {
+      setLoadingFalse()
+    }
+  }
+
+  const renderLoading = (
+    <div className='flex h-full w-0 grow flex-col items-center justify-center space-y-3'>
+      <Loading />
+      <div className='text-[13px] text-gray-400'>{t('appDebug.autoGenerate.loading')}</div>
+    </div>
+  )
+  const renderNoData = (
+    <div className='flex h-full w-0 grow flex-col items-center justify-center space-y-3 px-8'>
+      <Generator className='h-14 w-14 text-gray-300' />
+      <div className='text-center text-[13px] font-normal leading-5 text-gray-400'>
+        <div>{t('appDebug.autoGenerate.noDataLine1')}</div>
+        <div>{t('appDebug.autoGenerate.noDataLine2')}</div>
+      </div>
+    </div>
+  )
+
+  return (
+    <Modal
+      isShow={isShow}
+      onClose={onClose}
+      className='min-w-[1140px] !p-0'
+    >
+      <div className='flex h-[680px] flex-wrap'>
+        <div className='h-full w-[570px] shrink-0 overflow-y-auto border-r border-gray-100 p-8'>
+          <div className='mb-8'>
+            <div className='text-lg font-bold leading-[28px] text-gray-900'>{t('appDebug.autoGenerate.title')}</div>
+            <div className='mt-1 text-[13px] font-normal text-gray-500'>{t('appDebug.autoGenerate.description')}</div>
+          </div>
+          <div className='flex items-center'>
+            <ModelIcon className='shrink-0' provider={currentProvider} modelName={currentModel?.model} />
+            <ModelName className='ml-1 shrink-0' modelItem={currentModel!} showMode />
+          </div>
+          <div className='mt-6'>
+            <div className='mb-2 text-sm font-medium leading-5 text-gray-900'>{t('appDebug.autoGenerate.instruction')}</div>
+