refactor(api/core/app/app_config/entities.py): Move Type to outside and add EXTERNAL_DATA_TOOL. (#7444)
This commit is contained in:
parent
e2d214e030
commit
a10b207de2
@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
|
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||||
|
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
|
|||||||
:param config: model config args
|
:param config: model config args
|
||||||
"""
|
"""
|
||||||
external_data_variables = []
|
external_data_variables = []
|
||||||
variables = []
|
variable_entities = []
|
||||||
|
|
||||||
# old external_data_tools
|
# old external_data_tools
|
||||||
external_data_tools = config.get('external_data_tools', [])
|
external_data_tools = config.get('external_data_tools', [])
|
||||||
@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# variables and external_data_tools
|
# variables and external_data_tools
|
||||||
for variable in config.get('user_input_form', []):
|
for variables in config.get('user_input_form', []):
|
||||||
typ = list(variable.keys())[0]
|
variable_type = list(variables.keys())[0]
|
||||||
if typ == 'external_data_tool':
|
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
|
||||||
val = variable[typ]
|
variable = variables[variable_type]
|
||||||
if 'config' not in val:
|
if 'config' not in variable:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
external_data_variables.append(
|
external_data_variables.append(
|
||||||
ExternalDataVariableEntity(
|
ExternalDataVariableEntity(
|
||||||
variable=val['variable'],
|
variable=variable['variable'],
|
||||||
type=val['type'],
|
type=variable['type'],
|
||||||
config=val['config']
|
config=variable['config']
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif typ in [
|
elif variable_type in [
|
||||||
VariableEntity.Type.TEXT_INPUT.value,
|
VariableEntityType.TEXT_INPUT,
|
||||||
VariableEntity.Type.PARAGRAPH.value,
|
VariableEntityType.PARAGRAPH,
|
||||||
VariableEntity.Type.NUMBER.value,
|
VariableEntityType.NUMBER,
|
||||||
|
VariableEntityType.SELECT,
|
||||||
]:
|
]:
|
||||||
variables.append(
|
variable = variables[variable_type]
|
||||||
|
variable_entities.append(
|
||||||
VariableEntity(
|
VariableEntity(
|
||||||
type=VariableEntity.Type.value_of(typ),
|
type=variable_type,
|
||||||
variable=variable[typ].get('variable'),
|
variable=variable.get('variable'),
|
||||||
description=variable[typ].get('description'),
|
description=variable.get('description'),
|
||||||
label=variable[typ].get('label'),
|
label=variable.get('label'),
|
||||||
required=variable[typ].get('required', False),
|
required=variable.get('required', False),
|
||||||
max_length=variable[typ].get('max_length'),
|
max_length=variable.get('max_length'),
|
||||||
default=variable[typ].get('default'),
|
options=variable.get('options'),
|
||||||
)
|
default=variable.get('default'),
|
||||||
)
|
|
||||||
elif typ == VariableEntity.Type.SELECT.value:
|
|
||||||
variables.append(
|
|
||||||
VariableEntity(
|
|
||||||
type=VariableEntity.Type.SELECT,
|
|
||||||
variable=variable[typ].get('variable'),
|
|
||||||
description=variable[typ].get('description'),
|
|
||||||
label=variable[typ].get('label'),
|
|
||||||
required=variable[typ].get('required', False),
|
|
||||||
options=variable[typ].get('options'),
|
|
||||||
default=variable[typ].get('default'),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return variables, external_data_variables
|
return variable_entities, external_data_variables
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||||
@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
|
|||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
||||||
return config, ["external_data_tools"]
|
return config, ["external_data_tools"]
|
||||||
|
@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
|
|||||||
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
|
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
|
||||||
|
|
||||||
|
|
||||||
|
class VariableEntityType(str, Enum):
|
||||||
|
TEXT_INPUT = "text-input"
|
||||||
|
SELECT = "select"
|
||||||
|
PARAGRAPH = "paragraph"
|
||||||
|
NUMBER = "number"
|
||||||
|
EXTERNAL_DATA_TOOL = "external-data-tool"
|
||||||
|
|
||||||
|
|
||||||
class VariableEntity(BaseModel):
|
class VariableEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
Variable Entity.
|
Variable Entity.
|
||||||
"""
|
"""
|
||||||
class Type(Enum):
|
|
||||||
TEXT_INPUT = 'text-input'
|
|
||||||
SELECT = 'select'
|
|
||||||
PARAGRAPH = 'paragraph'
|
|
||||||
NUMBER = 'number'
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def value_of(cls, value: str) -> 'VariableEntity.Type':
|
|
||||||
"""
|
|
||||||
Get value of given mode.
|
|
||||||
|
|
||||||
:param value: mode value
|
|
||||||
:return: mode
|
|
||||||
"""
|
|
||||||
for mode in cls:
|
|
||||||
if mode.value == value:
|
|
||||||
return mode
|
|
||||||
raise ValueError(f'invalid variable type value {value}')
|
|
||||||
|
|
||||||
variable: str
|
variable: str
|
||||||
label: str
|
label: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
type: Type
|
type: VariableEntityType
|
||||||
required: bool = False
|
required: bool = False
|
||||||
max_length: Optional[int] = None
|
max_length: Optional[int] = None
|
||||||
options: Optional[list[str]] = None
|
options: Optional[list[str]] = None
|
||||||
default: Optional[str] = None
|
default: Optional[str] = None
|
||||||
hint: Optional[str] = None
|
hint: Optional[str] = None
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self.variable
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalDataVariableEntity(BaseModel):
|
class ExternalDataVariableEntity(BaseModel):
|
||||||
"""
|
"""
|
||||||
@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
|
|||||||
"""
|
"""
|
||||||
Workflow UI Based App Config Entity.
|
Workflow UI Based App Config Entity.
|
||||||
"""
|
"""
|
||||||
workflow_id: str
|
workflow_id: str
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
|
||||||
|
|
||||||
|
|
||||||
class BaseAppGenerator:
|
class BaseAppGenerator:
|
||||||
@ -9,29 +9,29 @@ class BaseAppGenerator:
|
|||||||
user_inputs = user_inputs or {}
|
user_inputs = user_inputs or {}
|
||||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||||
variables = app_config.variables
|
variables = app_config.variables
|
||||||
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||||
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
||||||
return filtered_inputs
|
return filtered_inputs
|
||||||
|
|
||||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
||||||
user_input_value = inputs.get(var.name)
|
user_input_value = inputs.get(var.variable)
|
||||||
if var.required and not user_input_value:
|
if var.required and not user_input_value:
|
||||||
raise ValueError(f'{var.name} is required in input form')
|
raise ValueError(f'{var.variable} is required in input form')
|
||||||
if not var.required and not user_input_value:
|
if not var.required and not user_input_value:
|
||||||
# TODO: should we return None here if the default value is None?
|
# TODO: should we return None here if the default value is None?
|
||||||
return var.default or ''
|
return var.default or ''
|
||||||
if (
|
if (
|
||||||
var.type
|
var.type
|
||||||
in (
|
in (
|
||||||
VariableEntity.Type.TEXT_INPUT,
|
VariableEntityType.TEXT_INPUT,
|
||||||
VariableEntity.Type.SELECT,
|
VariableEntityType.SELECT,
|
||||||
VariableEntity.Type.PARAGRAPH,
|
VariableEntityType.PARAGRAPH,
|
||||||
)
|
)
|
||||||
and user_input_value
|
and user_input_value
|
||||||
and not isinstance(user_input_value, str)
|
and not isinstance(user_input_value, str)
|
||||||
):
|
):
|
||||||
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
|
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||||
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
|
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||||
# may raise ValueError if user_input_value is not a valid number
|
# may raise ValueError if user_input_value is not a valid number
|
||||||
try:
|
try:
|
||||||
if '.' in user_input_value:
|
if '.' in user_input_value:
|
||||||
@ -39,14 +39,14 @@ class BaseAppGenerator:
|
|||||||
else:
|
else:
|
||||||
return int(user_input_value)
|
return int(user_input_value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"{var.name} in input form must be a valid number")
|
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||||
if var.type == VariableEntity.Type.SELECT:
|
if var.type == VariableEntityType.SELECT:
|
||||||
options = var.options or []
|
options = var.options or []
|
||||||
if user_input_value not in options:
|
if user_input_value not in options:
|
||||||
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
|
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
|
||||||
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
|
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||||
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
|
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
|
||||||
|
|
||||||
return user_input_value
|
return user_input_value
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
@ -18,6 +18,13 @@ from models.model import App, AppMode
|
|||||||
from models.tools import WorkflowToolProvider
|
from models.tools import WorkflowToolProvider
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||||
|
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
|
||||||
|
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||||
|
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||||
|
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class WorkflowToolProviderController(ToolProviderController):
|
class WorkflowToolProviderController(ToolProviderController):
|
||||||
provider_id: str
|
provider_id: str
|
||||||
@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
raise ValueError('app not found')
|
raise ValueError('app not found')
|
||||||
|
|
||||||
controller = WorkflowToolProviderController(**{
|
controller = WorkflowToolProviderController(**{
|
||||||
'identity': {
|
'identity': {
|
||||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||||
@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
'credentials_schema': {},
|
'credentials_schema': {},
|
||||||
'provider_id': db_provider.id or '',
|
'provider_id': db_provider.id or '',
|
||||||
})
|
})
|
||||||
|
|
||||||
# init tools
|
# init tools
|
||||||
|
|
||||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
||||||
@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
@property
|
@property
|
||||||
def provider_type(self) -> ToolProviderType:
|
def provider_type(self) -> ToolProviderType:
|
||||||
return ToolProviderType.WORKFLOW
|
return ToolProviderType.WORKFLOW
|
||||||
|
|
||||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
||||||
"""
|
"""
|
||||||
get db provider tool
|
get db provider tool
|
||||||
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
if variable:
|
if variable:
|
||||||
parameter_type = None
|
parameter_type = None
|
||||||
options = None
|
options = None
|
||||||
if variable.type in [
|
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
||||||
VariableEntity.Type.TEXT_INPUT,
|
|
||||||
VariableEntity.Type.PARAGRAPH,
|
|
||||||
]:
|
|
||||||
parameter_type = ToolParameter.ToolParameterType.STRING
|
|
||||||
elif variable.type in [
|
|
||||||
VariableEntity.Type.SELECT
|
|
||||||
]:
|
|
||||||
parameter_type = ToolParameter.ToolParameterType.SELECT
|
|
||||||
elif variable.type in [
|
|
||||||
VariableEntity.Type.NUMBER
|
|
||||||
]:
|
|
||||||
parameter_type = ToolParameter.ToolParameterType.NUMBER
|
|
||||||
else:
|
|
||||||
raise ValueError(f'unsupported variable type {variable.type}')
|
raise ValueError(f'unsupported variable type {variable.type}')
|
||||||
|
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
||||||
if variable.type == VariableEntity.Type.SELECT and variable.options:
|
|
||||||
|
if variable.type == VariableEntityType.SELECT and variable.options:
|
||||||
options = [
|
options = [
|
||||||
ToolParameterOption(
|
ToolParameterOption(
|
||||||
value=option,
|
value=option,
|
||||||
@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
"""
|
"""
|
||||||
if self.tools is not None:
|
if self.tools is not None:
|
||||||
return self.tools
|
return self.tools
|
||||||
|
|
||||||
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||||
WorkflowToolProvider.tenant_id == tenant_id,
|
WorkflowToolProvider.tenant_id == tenant_id,
|
||||||
WorkflowToolProvider.app_id == self.provider_id,
|
WorkflowToolProvider.app_id == self.provider_id,
|
||||||
@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
|
|
||||||
if not db_providers:
|
if not db_providers:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
||||||
|
|
||||||
return self.tools
|
return self.tools
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||||
"""
|
"""
|
||||||
get tool by name
|
get tool by name
|
||||||
@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
|
|||||||
for tool in self.tools:
|
for tool in self.tools:
|
||||||
if tool.identity.name == tool_name:
|
if tool.identity.name == tool_name:
|
||||||
return tool
|
return tool
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity
|
from core.app.app_config.entities import VariableEntity
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
|
||||||
@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
|
|||||||
"""
|
"""
|
||||||
Start Node Data
|
Start Node Data
|
||||||
"""
|
"""
|
||||||
variables: list[VariableEntity] = []
|
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||||
|
@ -14,6 +14,7 @@ from core.app.app_config.entities import (
|
|||||||
ModelConfigEntity,
|
ModelConfigEntity,
|
||||||
PromptTemplateEntity,
|
PromptTemplateEntity,
|
||||||
VariableEntity,
|
VariableEntity,
|
||||||
|
VariableEntityType,
|
||||||
)
|
)
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def default_variables():
|
def default_variables():
|
||||||
return [
|
value = [
|
||||||
VariableEntity(
|
VariableEntity(
|
||||||
variable="text_input",
|
variable="text_input",
|
||||||
label="text-input",
|
label="text-input",
|
||||||
type=VariableEntity.Type.TEXT_INPUT
|
type=VariableEntityType.TEXT_INPUT,
|
||||||
),
|
),
|
||||||
VariableEntity(
|
VariableEntity(
|
||||||
variable="paragraph",
|
variable="paragraph",
|
||||||
label="paragraph",
|
label="paragraph",
|
||||||
type=VariableEntity.Type.PARAGRAPH
|
type=VariableEntityType.PARAGRAPH,
|
||||||
),
|
),
|
||||||
VariableEntity(
|
VariableEntity(
|
||||||
variable="select",
|
variable="select",
|
||||||
label="select",
|
label="select",
|
||||||
type=VariableEntity.Type.SELECT
|
type=VariableEntityType.SELECT,
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def test__convert_to_start_node(default_variables):
|
def test__convert_to_start_node(default_variables):
|
||||||
|
Loading…
Reference in New Issue
Block a user