diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 3eb006b46e..15fa4d99fd 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,6 +1,6 @@ import re -from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity +from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType from core.external_data_tool.factory import ExternalDataToolFactory @@ -13,7 +13,7 @@ class BasicVariablesConfigManager: :param config: model config args """ external_data_variables = [] - variables = [] + variable_entities = [] # old external_data_tools external_data_tools = config.get('external_data_tools', []) @@ -30,50 +30,41 @@ class BasicVariablesConfigManager: ) # variables and external_data_tools - for variable in config.get('user_input_form', []): - typ = list(variable.keys())[0] - if typ == 'external_data_tool': - val = variable[typ] - if 'config' not in val: + for variables in config.get('user_input_form', []): + variable_type = list(variables.keys())[0] + if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: + variable = variables[variable_type] + if 'config' not in variable: continue external_data_variables.append( ExternalDataVariableEntity( - variable=val['variable'], - type=val['type'], - config=val['config'] + variable=variable['variable'], + type=variable['type'], + config=variable['config'] ) ) - elif typ in [ - VariableEntity.Type.TEXT_INPUT.value, - VariableEntity.Type.PARAGRAPH.value, - VariableEntity.Type.NUMBER.value, + elif variable_type in [ + VariableEntityType.TEXT_INPUT, + VariableEntityType.PARAGRAPH, + VariableEntityType.NUMBER, + VariableEntityType.SELECT, ]: - variables.append( + variable = variables[variable_type] + variable_entities.append( VariableEntity( - type=VariableEntity.Type.value_of(typ), - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - max_length=variable[typ].get('max_length'), - default=variable[typ].get('default'), - ) - ) - elif typ == VariableEntity.Type.SELECT.value: - variables.append( - VariableEntity( - type=VariableEntity.Type.SELECT, - variable=variable[typ].get('variable'), - description=variable[typ].get('description'), - label=variable[typ].get('label'), - required=variable[typ].get('required', False), - options=variable[typ].get('options'), - default=variable[typ].get('default'), + type=variable_type, + variable=variable.get('variable'), + description=variable.get('description'), + label=variable.get('label'), + required=variable.get('required', False), + max_length=variable.get('max_length'), + options=variable.get('options'), + default=variable.get('default'), ) ) - return variables, external_data_variables + return variable_entities, external_data_variables @classmethod def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: @@ -183,4 +174,4 @@ class BasicVariablesConfigManager: config=config ) - return config, ["external_data_tools"] \ No newline at end of file + return config, ["external_data_tools"] diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 05a42a898e..bbb10d3d76 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel): advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None +class VariableEntityType(str, Enum): + TEXT_INPUT = "text-input" + SELECT = "select" + PARAGRAPH = "paragraph" + NUMBER = "number" + EXTERNAL_DATA_TOOL = "external-data-tool" + + class VariableEntity(BaseModel): """ Variable Entity. """ - class Type(Enum): - TEXT_INPUT = 'text-input' - SELECT = 'select' - PARAGRAPH = 'paragraph' - NUMBER = 'number' - - @classmethod - def value_of(cls, value: str) -> 'VariableEntity.Type': - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f'invalid variable type value {value}') variable: str label: str description: Optional[str] = None - type: Type + type: VariableEntityType required: bool = False max_length: Optional[int] = None options: Optional[list[str]] = None default: Optional[str] = None hint: Optional[str] = None - @property - def name(self) -> str: - return self.variable - class ExternalDataVariableEntity(BaseModel): """ @@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig): """ Workflow UI Based App Config Entity. """ - workflow_id: str \ No newline at end of file + workflow_id: str diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 6f48aa2363..9e331dff4d 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any, Optional -from core.app.app_config.entities import AppConfig, VariableEntity +from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType class BaseAppGenerator: @@ -9,29 +9,29 @@ class BaseAppGenerator: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} + filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} return filtered_inputs def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): - user_input_value = inputs.get(var.name) + user_input_value = inputs.get(var.variable) if var.required and not user_input_value: - raise ValueError(f'{var.name} is required in input form') + raise ValueError(f'{var.variable} is required in input form') if not var.required and not user_input_value: # TODO: should we return None here if the default value is None? return var.default or '' if ( var.type in ( - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.SELECT, - VariableEntity.Type.PARAGRAPH, + VariableEntityType.TEXT_INPUT, + VariableEntityType.SELECT, + VariableEntityType.PARAGRAPH, ) and user_input_value and not isinstance(user_input_value, str) ): - raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") - if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): + raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") + if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): # may raise ValueError if user_input_value is not a valid number try: if '.' in user_input_value: @@ -39,14 +39,14 @@ class BaseAppGenerator: else: return int(user_input_value) except ValueError: - raise ValueError(f"{var.name} in input form must be a valid number") - if var.type == VariableEntity.Type.SELECT: + raise ValueError(f"{var.variable} in input form must be a valid number") + if var.type == VariableEntityType.SELECT: options = var.options or [] if user_input_value not in options: - raise ValueError(f'{var.name} in input form must be one of the following: {options}') - elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): + raise ValueError(f'{var.variable} in input form must be one of the following: {options}') + elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): if var.max_length and user_input_value and len(user_input_value) > var.max_length: - raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') + raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') return user_input_value diff --git a/api/core/tools/provider/workflow_tool_provider.py b/api/core/tools/provider/workflow_tool_provider.py index f7911fea1d..f14abac767 100644 --- a/api/core/tools/provider/workflow_tool_provider.py +++ b/api/core/tools/provider/workflow_tool_provider.py @@ -1,6 +1,6 @@ from typing import Optional -from core.app.app_config.entities import VariableEntity +from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -18,6 +18,13 @@ from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import Workflow +VARIABLE_TO_PARAMETER_TYPE_MAPPING = { + VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING, + VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, + VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, + VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, +} + class WorkflowToolProviderController(ToolProviderController): provider_id: str @@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController): if not app: raise ValueError('app not found') - + controller = WorkflowToolProviderController(**{ 'identity': { 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', @@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController): 'credentials_schema': {}, 'provider_id': db_provider.id or '', }) - + # init tools controller.tools = [controller._get_db_provider_tool(db_provider, app)] @@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController): @property def provider_type(self) -> ToolProviderType: return ToolProviderType.WORKFLOW - + def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: """ get db provider tool @@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController): if variable: parameter_type = None options = None - if variable.type in [ - VariableEntity.Type.TEXT_INPUT, - VariableEntity.Type.PARAGRAPH, - ]: - parameter_type = ToolParameter.ToolParameterType.STRING - elif variable.type in [ - VariableEntity.Type.SELECT - ]: - parameter_type = ToolParameter.ToolParameterType.SELECT - elif variable.type in [ - VariableEntity.Type.NUMBER - ]: - parameter_type = ToolParameter.ToolParameterType.NUMBER - else: + if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: raise ValueError(f'unsupported variable type {variable.type}') - - if variable.type == VariableEntity.Type.SELECT and variable.options: + parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] + + if variable.type == VariableEntityType.SELECT and variable.options: options = [ ToolParameterOption( value=option, @@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController): """ if self.tools is not None: return self.tools - + db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == self.provider_id, @@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController): if not db_providers: return [] - + self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] return self.tools - + def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: """ get tool by name @@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController): for tool in self.tools: if tool.identity.name == tool_name: return tool - + return None diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py index 0bd5f203bf..b81ce15bd7 100644 --- a/api/core/workflow/nodes/start/entities.py +++ b/api/core/workflow/nodes/start/entities.py @@ -1,3 +1,7 @@ +from collections.abc import Sequence + +from pydantic import Field + from core.app.app_config.entities import VariableEntity from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData): """ Start Node Data """ - variables: list[VariableEntity] = [] + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index f589cd2097..a45423bf39 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( ModelConfigEntity, PromptTemplateEntity, VariableEntity, + VariableEntityType, ) from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode @@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter @pytest.fixture def default_variables(): - return [ + value = [ VariableEntity( variable="text_input", label="text-input", - type=VariableEntity.Type.TEXT_INPUT + type=VariableEntityType.TEXT_INPUT, ), VariableEntity( variable="paragraph", label="paragraph", - type=VariableEntity.Type.PARAGRAPH + type=VariableEntityType.PARAGRAPH, ), VariableEntity( variable="select", label="select", - type=VariableEntity.Type.SELECT - ) + type=VariableEntityType.SELECT, + ), ] + return value def test__convert_to_start_node(default_variables):