diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 6b58df617d..9b7012c3fb 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -114,6 +114,10 @@ class VariableEntity(BaseModel): default: Optional[str] = None hint: Optional[str] = None + @property + def name(self) -> str: + return self.variable + class ExternalDataVariableEntity(BaseModel): """ diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 20ae6ff676..6f48aa2363 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,52 +1,56 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.app.app_config.entities import AppConfig, VariableEntity class BaseAppGenerator: - def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig): - if user_inputs is None: - user_inputs = {} - - filtered_inputs = {} - + def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]: + user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - for variable_config in variables: - variable = variable_config.variable - - if (variable not in user_inputs - or user_inputs[variable] is None - or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')): - if variable_config.required: - raise ValueError(f"{variable} is required in input form") - else: - filtered_inputs[variable] = variable_config.default if variable_config.default is not None else "" - continue - - value = user_inputs[variable] - - if value is not None: - if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str): - raise ValueError(f"{variable} in input form must be a string") - elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str): - if '.' in value: - value = float(value) - else: - value = int(value) - - if variable_config.type == VariableEntity.Type.SELECT: - options = variable_config.options if variable_config.options is not None else [] - if value not in options: - raise ValueError(f"{variable} in input form must be one of the following: {options}") - elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]: - if variable_config.max_length is not None: - max_length = variable_config.max_length - if len(value) > max_length: - raise ValueError(f'{variable} in input form must be less than {max_length} characters') - - if value and isinstance(value, str): - filtered_inputs[variable] = value.replace('\x00', '') - else: - filtered_inputs[variable] = value if value is not None else None - + filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} + filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} return filtered_inputs + def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): + user_input_value = inputs.get(var.name) + if var.required and not user_input_value: + raise ValueError(f'{var.name} is required in input form') + if not var.required and not user_input_value: + # TODO: should we return None here if the default value is None? + return var.default or '' + if ( + var.type + in ( + VariableEntity.Type.TEXT_INPUT, + VariableEntity.Type.SELECT, + VariableEntity.Type.PARAGRAPH, + ) + and user_input_value + and not isinstance(user_input_value, str) + ): + raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") + if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): + # may raise ValueError if user_input_value is not a valid number + try: + if '.' in user_input_value: + return float(user_input_value) + else: + return int(user_input_value) + except ValueError: + raise ValueError(f"{var.name} in input form must be a valid number") + if var.type == VariableEntity.Type.SELECT: + options = var.options or [] + if user_input_value not in options: + raise ValueError(f'{var.name} in input form must be one of the following: {options}') + elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): + if var.max_length and user_input_value and len(user_input_value) > var.max_length: + raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') + + return user_input_value + + def _sanitize_value(self, value: Any) -> Any: + if isinstance(value, str): + return value.replace('\x00', '') + return value