diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index df40aec154..c5bf35edb6 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -4,7 +4,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Union +from typing import Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -32,6 +32,28 @@ logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[True] = True, + call_depth: int = 0, + ) -> Generator[dict, None, None]: ... + + @overload + def generate( + self, app_model: App, + workflow: Workflow, + user: Union[Account, EndUser], + args: dict, + invoke_from: InvokeFrom, + stream: Literal[False] = False, + call_depth: int = 0, + ) -> dict: ... + def generate( self, app_model: App, workflow: Workflow, diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 52513c13f9..68db0d5b2e 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from collections.abc import Generator, Mapping +from collections.abc import Generator from copy import deepcopy from enum import Enum from typing import TYPE_CHECKING, Any, Optional, Union diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 4724c992d2..7a669a4fe4 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -48,6 +48,8 @@ class WorkflowTool(Tool): from core.app.apps.workflow.app_generator import WorkflowAppGenerator generator = WorkflowAppGenerator() + assert self.runtime and self.runtime.invoke_from + result = generator.generate( app_model=app, workflow=workflow, @@ -154,7 +156,7 @@ class WorkflowTool(Tool): try: file_var_list = [FileVar(**f) for f in file] for file_var in file_var_list: - file_dict = { + file_dict: dict[str, Any] = { 'transfer_method': file_var.transfer_method.value, 'type': file_var.type.value, }