refactor(api/core/app/app_config/entities.py): Move Type to outside and add EXTERNAL_DATA_TOOL. (#7444)

This commit is contained in:
-LAN- 2024-08-20 17:30:14 +08:00 committed by GitHub
parent e2d214e030
commit a10b207de2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 82 additions and 104 deletions

View File

@ -1,6 +1,6 @@
import re import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory from core.external_data_tool.factory import ExternalDataToolFactory
@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
:param config: model config args :param config: model config args
""" """
external_data_variables = [] external_data_variables = []
variables = [] variable_entities = []
# old external_data_tools # old external_data_tools
external_data_tools = config.get('external_data_tools', []) external_data_tools = config.get('external_data_tools', [])
@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
) )
# variables and external_data_tools # variables and external_data_tools
for variable in config.get('user_input_form', []): for variables in config.get('user_input_form', []):
typ = list(variable.keys())[0] variable_type = list(variables.keys())[0]
if typ == 'external_data_tool': if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
val = variable[typ] variable = variables[variable_type]
if 'config' not in val: if 'config' not in variable:
continue continue
external_data_variables.append( external_data_variables.append(
ExternalDataVariableEntity( ExternalDataVariableEntity(
variable=val['variable'], variable=variable['variable'],
type=val['type'], type=variable['type'],
config=val['config'] config=variable['config']
) )
) )
elif typ in [ elif variable_type in [
VariableEntity.Type.TEXT_INPUT.value, VariableEntityType.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH.value, VariableEntityType.PARAGRAPH,
VariableEntity.Type.NUMBER.value, VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]: ]:
variables.append( variable = variables[variable_type]
variable_entities.append(
VariableEntity( VariableEntity(
type=VariableEntity.Type.value_of(typ), type=variable_type,
variable=variable[typ].get('variable'), variable=variable.get('variable'),
description=variable[typ].get('description'), description=variable.get('description'),
label=variable[typ].get('label'), label=variable.get('label'),
required=variable[typ].get('required', False), required=variable.get('required', False),
max_length=variable[typ].get('max_length'), max_length=variable.get('max_length'),
default=variable[typ].get('default'), options=variable.get('options'),
) default=variable.get('default'),
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
) )
) )
return variables, external_data_variables return variable_entities, external_data_variables
@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
config=config config=config
) )
return config, ["external_data_tools"] return config, ["external_data_tools"]

View File

@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"
class VariableEntity(BaseModel): class VariableEntity(BaseModel):
""" """
Variable Entity. Variable Entity.
""" """
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'
@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')
variable: str variable: str
label: str label: str
description: Optional[str] = None description: Optional[str] = None
type: Type type: VariableEntityType
required: bool = False required: bool = False
max_length: Optional[int] = None max_length: Optional[int] = None
options: Optional[list[str]] = None options: Optional[list[str]] = None
default: Optional[str] = None default: Optional[str] = None
hint: Optional[str] = None hint: Optional[str] = None
@property
def name(self) -> str:
return self.variable
class ExternalDataVariableEntity(BaseModel): class ExternalDataVariableEntity(BaseModel):
""" """
@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
""" """
Workflow UI Based App Config Entity. Workflow UI Based App Config Entity.
""" """
workflow_id: str workflow_id: str

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional from typing import Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
class BaseAppGenerator: class BaseAppGenerator:
@ -9,29 +9,29 @@ class BaseAppGenerator:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables variables = app_config.variables
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs return filtered_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name) user_input_value = inputs.get(var.variable)
if var.required and not user_input_value: if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form') raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value: if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None? # TODO: should we return None here if the default value is None?
return var.default or '' return var.default or ''
if ( if (
var.type var.type
in ( in (
VariableEntity.Type.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
VariableEntity.Type.SELECT, VariableEntityType.SELECT,
VariableEntity.Type.PARAGRAPH, VariableEntityType.PARAGRAPH,
) )
and user_input_value and user_input_value
and not isinstance(user_input_value, str) and not isinstance(user_input_value, str)
): ):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number # may raise ValueError if user_input_value is not a valid number
try: try:
if '.' in user_input_value: if '.' in user_input_value:
@ -39,14 +39,14 @@ class BaseAppGenerator:
else: else:
return int(user_input_value) return int(user_input_value)
except ValueError: except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number") raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT: if var.type == VariableEntityType.SELECT:
options = var.options or [] options = var.options or []
if user_input_value not in options: if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}') raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length: if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
return user_input_value return user_input_value

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from core.app.app_config.entities import VariableEntity from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
@ -18,6 +18,13 @@ from models.model import App, AppMode
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
from models.workflow import Workflow from models.workflow import Workflow
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}
class WorkflowToolProviderController(ToolProviderController): class WorkflowToolProviderController(ToolProviderController):
provider_id: str provider_id: str
@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):
if not app: if not app:
raise ValueError('app not found') raise ValueError('app not found')
controller = WorkflowToolProviderController(**{ controller = WorkflowToolProviderController(**{
'identity': { 'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
'credentials_schema': {}, 'credentials_schema': {},
'provider_id': db_provider.id or '', 'provider_id': db_provider.id or '',
}) })
# init tools # init tools
controller.tools = [controller._get_db_provider_tool(db_provider, app)] controller.tools = [controller._get_db_provider_tool(db_provider, app)]
@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
""" """
get db provider tool get db provider tool
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
if variable: if variable:
parameter_type = None parameter_type = None
options = None options = None
if variable.type in [ if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH,
]:
parameter_type = ToolParameter.ToolParameterType.STRING
elif variable.type in [
VariableEntity.Type.SELECT
]:
parameter_type = ToolParameter.ToolParameterType.SELECT
elif variable.type in [
VariableEntity.Type.NUMBER
]:
parameter_type = ToolParameter.ToolParameterType.NUMBER
else:
raise ValueError(f'unsupported variable type {variable.type}') raise ValueError(f'unsupported variable type {variable.type}')
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
if variable.type == VariableEntity.Type.SELECT and variable.options:
if variable.type == VariableEntityType.SELECT and variable.options:
options = [ options = [
ToolParameterOption( ToolParameterOption(
value=option, value=option,
@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
""" """
if self.tools is not None: if self.tools is not None:
return self.tools return self.tools
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id, WorkflowToolProvider.app_id == self.provider_id,
@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):
if not db_providers: if not db_providers:
return [] return []
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
return self.tools return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
""" """
get tool by name get tool by name
@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
for tool in self.tools: for tool in self.tools:
if tool.identity.name == tool_name: if tool.identity.name == tool_name:
return tool return tool
return None return None

View File

@ -1,3 +1,7 @@
from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
""" """
Start Node Data Start Node Data
""" """
variables: list[VariableEntity] = [] variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@ -14,6 +14,7 @@ from core.app.app_config.entities import (
ModelConfigEntity, ModelConfigEntity,
PromptTemplateEntity, PromptTemplateEntity,
VariableEntity, VariableEntity,
VariableEntityType,
) )
from core.helper import encrypter from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode
@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter
@pytest.fixture @pytest.fixture
def default_variables(): def default_variables():
return [ value = [
VariableEntity( VariableEntity(
variable="text_input", variable="text_input",
label="text-input", label="text-input",
type=VariableEntity.Type.TEXT_INPUT type=VariableEntityType.TEXT_INPUT,
), ),
VariableEntity( VariableEntity(
variable="paragraph", variable="paragraph",
label="paragraph", label="paragraph",
type=VariableEntity.Type.PARAGRAPH type=VariableEntityType.PARAGRAPH,
), ),
VariableEntity( VariableEntity(
variable="select", variable="select",
label="select", label="select",
type=VariableEntity.Type.SELECT type=VariableEntityType.SELECT,
) ),
] ]
return value
def test__convert_to_start_node(default_variables): def test__convert_to_start_node(default_variables):