refactor: using DeclarativeBase as parent class of models, refactored tools
This commit is contained in:
parent
c8bc3892b3
commit
e9e5c8806a
@ -6,7 +6,7 @@ from flask_restful import Resource, reqparse
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from libs.helper import StrLen, email, get_remote_ip
|
from libs.helper import StrLen, email, get_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.model import DifySetup
|
from models.model import DifySetup, db
|
||||||
from services.account_service import RegisterService, TenantService
|
from services.account_service import RegisterService, TenantService
|
||||||
|
|
||||||
from . import api
|
from . import api
|
||||||
@ -69,7 +69,7 @@ def setup_required(view):
|
|||||||
|
|
||||||
def get_setup_status():
|
def get_setup_status():
|
||||||
if dify_config.EDITION == "SELF_HOSTED":
|
if dify_config.EDITION == "SELF_HOSTED":
|
||||||
return DifySetup.query.first()
|
return db.session.query(DifySetup).first()
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -610,16 +610,17 @@ class ToolLabelsApi(Resource):
|
|||||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||||
|
|
||||||
# builtin tool provider
|
# builtin tool provider
|
||||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
|
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
|
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
|
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
|
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
|
||||||
)
|
)
|
||||||
api.add_resource(
|
api.add_resource(
|
||||||
ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
|
ToolBuiltinProviderCredentialsSchemaApi,
|
||||||
|
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
|
||||||
)
|
)
|
||||||
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
|
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||||
|
|
||||||
# api tool provider
|
# api tool provider
|
||||||
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
|
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
|
||||||
|
@ -14,9 +14,9 @@ class PluginToolManager(BasePluginManager):
|
|||||||
provider follows format: plugin_id/provider_name
|
provider follows format: plugin_id/provider_name
|
||||||
"""
|
"""
|
||||||
if "/" in provider:
|
if "/" in provider:
|
||||||
parts = provider.split("/", 1)
|
parts = provider.split("/", -1)
|
||||||
if len(parts) == 2:
|
if len(parts) >= 2:
|
||||||
return parts[0], parts[1]
|
return "/".join(parts[:-1]), parts[-1]
|
||||||
raise ValueError(f"invalid provider format: {provider}")
|
raise ValueError(f"invalid provider format: {provider}")
|
||||||
|
|
||||||
raise ValueError(f"invalid provider format: {provider}")
|
raise ValueError(f"invalid provider format: {provider}")
|
||||||
@ -46,6 +46,10 @@ class PluginToolManager(BasePluginManager):
|
|||||||
for provider in response:
|
for provider in response:
|
||||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||||
|
|
||||||
|
# override the provider name for each tool to plugin_id/provider_name
|
||||||
|
for tool in provider.declaration.tools:
|
||||||
|
tool.identity.provider = provider.declaration.identity.name
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
|
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
|
||||||
@ -54,15 +58,26 @@ class PluginToolManager(BasePluginManager):
|
|||||||
"""
|
"""
|
||||||
plugin_id, provider_name = self._split_provider(provider)
|
plugin_id, provider_name = self._split_provider(provider)
|
||||||
|
|
||||||
|
def transformer(json_response: dict[str, Any]) -> dict:
|
||||||
|
for tool in json_response.get("data", {}).get("declaration", {}).get("tools", []):
|
||||||
|
tool["identity"]["provider"] = provider_name
|
||||||
|
|
||||||
|
return json_response
|
||||||
|
|
||||||
response = self._request_with_plugin_daemon_response(
|
response = self._request_with_plugin_daemon_response(
|
||||||
"GET",
|
"GET",
|
||||||
f"plugin/{tenant_id}/management/tool",
|
f"plugin/{tenant_id}/management/tool",
|
||||||
PluginToolProviderEntity,
|
PluginToolProviderEntity,
|
||||||
params={"provider": provider_name, "plugin_id": plugin_id},
|
params={"provider": provider_name, "plugin_id": plugin_id},
|
||||||
|
transformer=transformer,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||||
|
|
||||||
|
# override the provider name for each tool to plugin_id/provider_name
|
||||||
|
for tool in response.declaration.tools:
|
||||||
|
tool.identity.provider = response.declaration.identity.name
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
|
@ -11,12 +11,10 @@ from core.tools.plugin_tool.tool import PluginTool
|
|||||||
class PluginToolProviderController(BuiltinToolProviderController):
|
class PluginToolProviderController(BuiltinToolProviderController):
|
||||||
entity: ToolProviderEntityWithPlugin
|
entity: ToolProviderEntityWithPlugin
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
plugin_id: str
|
|
||||||
|
|
||||||
def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str, plugin_id: str) -> None:
|
def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||||
self.entity = entity
|
self.entity = entity
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.plugin_id = plugin_id
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_type(self) -> ToolProviderType:
|
def provider_type(self) -> ToolProviderType:
|
||||||
@ -35,7 +33,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
|||||||
if not manager.validate_provider_credentials(
|
if not manager.validate_provider_credentials(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
plugin_id=self.plugin_id,
|
|
||||||
provider=self.entity.identity.name,
|
provider=self.entity.identity.name,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
):
|
):
|
||||||
@ -54,7 +51,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
|||||||
entity=tool_entity,
|
entity=tool_entity,
|
||||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
plugin_id=self.plugin_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tools(self) -> list[PluginTool]:
|
def get_tools(self) -> list[PluginTool]:
|
||||||
@ -66,7 +62,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
|||||||
entity=tool_entity,
|
entity=tool_entity,
|
||||||
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
runtime=ToolRuntime(tenant_id=self.tenant_id),
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
plugin_id=self.plugin_id,
|
|
||||||
)
|
)
|
||||||
for tool_entity in self.entity.tools
|
for tool_entity in self.entity.tools
|
||||||
]
|
]
|
||||||
|
@ -9,12 +9,10 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
|||||||
|
|
||||||
class PluginTool(Tool):
|
class PluginTool(Tool):
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
plugin_id: str
|
|
||||||
|
|
||||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, plugin_id: str) -> None:
|
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None:
|
||||||
super().__init__(entity, runtime)
|
super().__init__(entity, runtime)
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.plugin_id = plugin_id
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_provider_type(self) -> ToolProviderType:
|
def tool_provider_type(self) -> ToolProviderType:
|
||||||
@ -25,7 +23,6 @@ class PluginTool(Tool):
|
|||||||
return manager.invoke(
|
return manager.invoke(
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
plugin_id=self.plugin_id,
|
|
||||||
tool_provider=self.entity.identity.provider,
|
tool_provider=self.entity.identity.provider,
|
||||||
tool_name=self.entity.identity.name,
|
tool_name=self.entity.identity.name,
|
||||||
credentials=self.runtime.credentials,
|
credentials=self.runtime.credentials,
|
||||||
@ -37,5 +34,4 @@ class PluginTool(Tool):
|
|||||||
entity=self.entity,
|
entity=self.entity,
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
plugin_id=self.plugin_id,
|
|
||||||
)
|
)
|
||||||
|
@ -86,7 +86,6 @@ class ToolManager:
|
|||||||
return PluginToolProviderController(
|
return PluginToolProviderController(
|
||||||
entity=provider_entity.declaration,
|
entity=provider_entity.declaration,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
plugin_id=provider_entity.plugin_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -158,12 +157,11 @@ class ToolManager:
|
|||||||
|
|
||||||
# decrypt the credentials
|
# decrypt the credentials
|
||||||
credentials = builtin_provider.credentials
|
credentials = builtin_provider.credentials
|
||||||
controller = cls.get_builtin_provider(provider_id, tenant_id)
|
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
tool_configuration = ProviderConfigEncrypter(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
config=controller.get_credentials_schema(),
|
config=provider_controller.get_credentials_schema(),
|
||||||
provider_type=controller.provider_type.value,
|
provider_type=provider_controller.provider_type.value,
|
||||||
provider_identity=controller.entity.identity.name,
|
provider_identity=provider_controller.entity.identity.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||||
@ -400,7 +398,6 @@ class ToolManager:
|
|||||||
PluginToolProviderController(
|
PluginToolProviderController(
|
||||||
entity=provider.declaration,
|
entity=provider.declaration,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
plugin_id=provider.plugin_id,
|
|
||||||
)
|
)
|
||||||
for provider in provider_entities
|
for provider in provider_entities
|
||||||
]
|
]
|
||||||
@ -525,7 +522,7 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(provider, PluginToolProviderController):
|
if isinstance(provider, PluginToolProviderController):
|
||||||
result_providers[f"plugin_provider.{user_provider.name}.{provider.plugin_id}"] = user_provider
|
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
|
||||||
else:
|
else:
|
||||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||||
|
|
||||||
|
@ -31,19 +31,16 @@ def get_engine_url():
|
|||||||
# from myapp import mymodel
|
# from myapp import mymodel
|
||||||
# target_metadata = mymodel.Base.metadata
|
# target_metadata = mymodel.Base.metadata
|
||||||
config.set_main_option('sqlalchemy.url', get_engine_url())
|
config.set_main_option('sqlalchemy.url', get_engine_url())
|
||||||
target_db = current_app.extensions['migrate'].db
|
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
# other values from the config, defined by the needs of env.py,
|
||||||
# can be acquired:
|
# can be acquired:
|
||||||
# my_important_option = config.get_main_option("my_important_option")
|
# my_important_option = config.get_main_option("my_important_option")
|
||||||
# ... etc.
|
# ... etc.
|
||||||
|
|
||||||
|
from models.base import Base
|
||||||
|
|
||||||
def get_metadata():
|
def get_metadata():
|
||||||
if hasattr(target_db, 'metadatas'):
|
return Base.metadata
|
||||||
return target_db.metadatas[None]
|
|
||||||
return target_db.metadata
|
|
||||||
|
|
||||||
|
|
||||||
def include_object(object, name, type_, reflected, compare_to):
|
def include_object(object, name, type_, reflected, compare_to):
|
||||||
if type_ == "foreign_key_constraint":
|
if type_ == "foreign_key_constraint":
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
"""increase max length of builtin tool provider
|
||||||
|
|
||||||
|
Revision ID: ddcc8bbef391
|
||||||
|
Revises: d57ba9ebb251
|
||||||
|
Create Date: 2024-09-29 08:35:58.062698
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'ddcc8bbef391'
|
||||||
|
down_revision = 'd57ba9ebb251'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('provider',
|
||||||
|
existing_type=sa.VARCHAR(length=40),
|
||||||
|
type_=sa.String(length=256),
|
||||||
|
existing_nullable=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('provider',
|
||||||
|
existing_type=sa.String(length=256),
|
||||||
|
type_=sa.VARCHAR(length=40),
|
||||||
|
existing_nullable=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
@ -1,5 +1,5 @@
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import declarative_base
|
||||||
|
|
||||||
|
from extensions.ext_database import metadata
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
Base = declarative_base(metadata=metadata)
|
||||||
pass
|
|
||||||
|
@ -2,12 +2,15 @@ import json
|
|||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from models.workflow import Workflow
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import UserMixin
|
from flask_login import UserMixin
|
||||||
from sqlalchemy import Float, func, text
|
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file.tool_file_parser import ToolFileParser
|
from core.file.tool_file_parser import ToolFileParser
|
||||||
@ -20,7 +23,7 @@ from .account import Account, Tenant
|
|||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
|
|
||||||
class DifySetup(db.Model):
|
class DifySetup(Base):
|
||||||
__tablename__ = "dify_setups"
|
__tablename__ = "dify_setups"
|
||||||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||||
|
|
||||||
@ -55,7 +58,7 @@ class IconType(Enum):
|
|||||||
EMOJI = "emoji"
|
EMOJI = "emoji"
|
||||||
|
|
||||||
|
|
||||||
class App(db.Model):
|
class App(Base):
|
||||||
__tablename__ = "apps"
|
__tablename__ = "apps"
|
||||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
|
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
|
||||||
|
|
||||||
@ -133,7 +136,8 @@ class App(db.Model):
|
|||||||
return False
|
return False
|
||||||
if not app_model_config.agent_mode:
|
if not app_model_config.agent_mode:
|
||||||
return False
|
return False
|
||||||
if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get(
|
|
||||||
|
if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get(
|
||||||
"strategy", ""
|
"strategy", ""
|
||||||
) in {"function_call", "react"}:
|
) in {"function_call", "react"}:
|
||||||
self.mode = AppMode.AGENT_CHAT.value
|
self.mode = AppMode.AGENT_CHAT.value
|
||||||
@ -250,7 +254,7 @@ class AppModelConfig(Base):
|
|||||||
return app
|
return app
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_dict(self) -> dict:
|
def model_dict(self):
|
||||||
return json.loads(self.model) if self.model else None
|
return json.loads(self.model) if self.model else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -284,6 +288,9 @@ class AppModelConfig(Base):
|
|||||||
)
|
)
|
||||||
if annotation_setting:
|
if annotation_setting:
|
||||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||||
|
if not collection_binding_detail:
|
||||||
|
raise ValueError("Collection binding detail not found")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": annotation_setting.id,
|
"id": annotation_setting.id,
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
@ -314,7 +321,7 @@ class AppModelConfig(Base):
|
|||||||
return json.loads(self.external_data_tools) if self.external_data_tools else []
|
return json.loads(self.external_data_tools) if self.external_data_tools else []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user_input_form_list(self) -> dict:
|
def user_input_form_list(self):
|
||||||
return json.loads(self.user_input_form) if self.user_input_form else []
|
return json.loads(self.user_input_form) if self.user_input_form else []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -458,7 +465,7 @@ class AppModelConfig(Base):
|
|||||||
return new_app_model_config
|
return new_app_model_config
|
||||||
|
|
||||||
|
|
||||||
class RecommendedApp(db.Model):
|
class RecommendedApp(Base):
|
||||||
__tablename__ = "recommended_apps"
|
__tablename__ = "recommended_apps"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
|
||||||
@ -486,7 +493,7 @@ class RecommendedApp(db.Model):
|
|||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
class InstalledApp(db.Model):
|
class InstalledApp(Base):
|
||||||
__tablename__ = "installed_apps"
|
__tablename__ = "installed_apps"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
|
db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
|
||||||
@ -522,7 +529,7 @@ class Conversation(Base):
|
|||||||
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
|
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
app_id = db.Column(StringUUID, nullable=False)
|
app_id = db.Column(StringUUID, nullable=False)
|
||||||
app_model_config_id = db.Column(StringUUID, nullable=True)
|
app_model_config_id = db.Column(StringUUID, nullable=True)
|
||||||
model_provider = db.Column(db.String(255), nullable=True)
|
model_provider = db.Column(db.String(255), nullable=True)
|
||||||
@ -546,10 +553,8 @@ class Conversation(Base):
|
|||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
messages: Mapped[list["Message"]] = relationship(
|
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
|
||||||
"Message", backref="conversation", lazy="select", passive_deletes="all"
|
message_annotations = db.relationship(
|
||||||
)
|
|
||||||
message_annotations: Mapped[list["MessageAnnotation"]] = relationship(
|
|
||||||
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
|
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -578,7 +583,7 @@ class Conversation(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not app_model_config:
|
if not app_model_config:
|
||||||
raise ValueError("app config not found")
|
return {}
|
||||||
|
|
||||||
model_config = app_model_config.to_dict()
|
model_config = app_model_config.to_dict()
|
||||||
|
|
||||||
@ -692,12 +697,12 @@ class Conversation(Base):
|
|||||||
class Message(Base):
|
class Message(Base):
|
||||||
__tablename__ = "messages"
|
__tablename__ = "messages"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="message_pkey"),
|
PrimaryKeyConstraint("id", name="message_pkey"),
|
||||||
db.Index("message_app_id_idx", "app_id", "created_at"),
|
Index("message_app_id_idx", "app_id", "created_at"),
|
||||||
db.Index("message_conversation_id_idx", "conversation_id"),
|
Index("message_conversation_id_idx", "conversation_id"),
|
||||||
db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"),
|
Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"),
|
||||||
db.Index("message_account_idx", "app_id", "from_source", "from_account_id"),
|
Index("message_account_idx", "app_id", "from_source", "from_account_id"),
|
||||||
db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
|
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
@ -705,10 +710,10 @@ class Message(Base):
|
|||||||
model_provider = db.Column(db.String(255), nullable=True)
|
model_provider = db.Column(db.String(255), nullable=True)
|
||||||
model_id = db.Column(db.String(255), nullable=True)
|
model_id = db.Column(db.String(255), nullable=True)
|
||||||
override_model_configs = db.Column(db.Text)
|
override_model_configs = db.Column(db.Text)
|
||||||
conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
|
conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
|
||||||
inputs: Mapped[str] = mapped_column(db.JSON)
|
inputs = db.Column(db.JSON)
|
||||||
query: Mapped[str] = mapped_column(db.Text, nullable=False)
|
query = db.Column(db.Text, nullable=False)
|
||||||
message: Mapped[str] = mapped_column(db.JSON, nullable=False)
|
message = db.Column(db.JSON, nullable=False)
|
||||||
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||||
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
|
||||||
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
|
||||||
@ -974,7 +979,7 @@ class Message(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MessageFeedback(db.Model):
|
class MessageFeedback(Base):
|
||||||
__tablename__ = "message_feedbacks"
|
__tablename__ = "message_feedbacks"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
|
||||||
@ -1009,15 +1014,15 @@ class MessageFile(Base):
|
|||||||
db.Index("message_file_created_by_idx", "created_by"),
|
db.Index("message_file_created_by_idx", "created_by"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
message_id = db.Column(StringUUID, nullable=False)
|
||||||
type: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
type = db.Column(db.String(255), nullable=False)
|
||||||
transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
transfer_method = db.Column(db.String(255), nullable=False)
|
||||||
url: Mapped[str] = mapped_column(db.Text, nullable=True)
|
url = db.Column(db.Text, nullable=True)
|
||||||
belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True)
|
belongs_to = db.Column(db.String(255), nullable=True)
|
||||||
upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
upload_file_id = db.Column(StringUUID, nullable=True)
|
||||||
created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
created_by_role = db.Column(db.String(255), nullable=False)
|
||||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
created_by = db.Column(StringUUID, nullable=False)
|
||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
@ -1032,7 +1037,7 @@ class MessageAnnotation(Base):
|
|||||||
|
|
||||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
app_id = db.Column(StringUUID, nullable=False)
|
app_id = db.Column(StringUUID, nullable=False)
|
||||||
conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=True)
|
conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True)
|
||||||
message_id = db.Column(StringUUID, nullable=True)
|
message_id = db.Column(StringUUID, nullable=True)
|
||||||
question = db.Column(db.Text, nullable=True)
|
question = db.Column(db.Text, nullable=True)
|
||||||
content = db.Column(db.Text, nullable=False)
|
content = db.Column(db.Text, nullable=False)
|
||||||
@ -1052,7 +1057,7 @@ class MessageAnnotation(Base):
|
|||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
class AppAnnotationHitHistory(db.Model):
|
class AppAnnotationHitHistory(Base):
|
||||||
__tablename__ = "app_annotation_hit_histories"
|
__tablename__ = "app_annotation_hit_histories"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
|
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
|
||||||
@ -1090,7 +1095,7 @@ class AppAnnotationHitHistory(db.Model):
|
|||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
class AppAnnotationSetting(db.Model):
|
class AppAnnotationSetting(Base):
|
||||||
__tablename__ = "app_annotation_settings"
|
__tablename__ = "app_annotation_settings"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
|
||||||
@ -1138,7 +1143,7 @@ class AppAnnotationSetting(db.Model):
|
|||||||
return collection_binding_detail
|
return collection_binding_detail
|
||||||
|
|
||||||
|
|
||||||
class OperationLog(db.Model):
|
class OperationLog(Base):
|
||||||
__tablename__ = "operation_logs"
|
__tablename__ = "operation_logs"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
|
db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
|
||||||
@ -1155,7 +1160,7 @@ class OperationLog(db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class EndUser(UserMixin, db.Model):
|
class EndUser(UserMixin, Base):
|
||||||
__tablename__ = "end_users"
|
__tablename__ = "end_users"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
|
db.PrimaryKeyConstraint("id", name="end_user_pkey"),
|
||||||
@ -1175,7 +1180,7 @@ class EndUser(UserMixin, db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class Site(db.Model):
|
class Site(Base):
|
||||||
__tablename__ = "sites"
|
__tablename__ = "sites"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="site_pkey"),
|
db.PrimaryKeyConstraint("id", name="site_pkey"),
|
||||||
@ -1222,7 +1227,7 @@ class Site(db.Model):
|
|||||||
return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
|
return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
class ApiToken(db.Model):
|
class ApiToken(Base):
|
||||||
__tablename__ = "api_tokens"
|
__tablename__ = "api_tokens"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
|
db.PrimaryKeyConstraint("id", name="api_token_pkey"),
|
||||||
@ -1249,7 +1254,7 @@ class ApiToken(db.Model):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class UploadFile(db.Model):
|
class UploadFile(Base):
|
||||||
__tablename__ = "upload_files"
|
__tablename__ = "upload_files"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
|
||||||
@ -1273,7 +1278,7 @@ class UploadFile(db.Model):
|
|||||||
hash = db.Column(db.String(255), nullable=True)
|
hash = db.Column(db.String(255), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class ApiRequest(db.Model):
|
class ApiRequest(Base):
|
||||||
__tablename__ = "api_requests"
|
__tablename__ = "api_requests"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
|
db.PrimaryKeyConstraint("id", name="api_request_pkey"),
|
||||||
@ -1290,7 +1295,7 @@ class ApiRequest(db.Model):
|
|||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class MessageChain(db.Model):
|
class MessageChain(Base):
|
||||||
__tablename__ = "message_chains"
|
__tablename__ = "message_chains"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
|
db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
|
||||||
@ -1395,7 +1400,7 @@ class MessageAgentThought(Base):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_outputs_dict(self) -> dict:
|
def tool_outputs_dict(self):
|
||||||
tools = self.tools
|
tools = self.tools
|
||||||
try:
|
try:
|
||||||
if self.observation:
|
if self.observation:
|
||||||
@ -1417,7 +1422,7 @@ class MessageAgentThought(Base):
|
|||||||
return dict.fromkeys(tools, self.observation)
|
return dict.fromkeys(tools, self.observation)
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrieverResource(db.Model):
|
class DatasetRetrieverResource(Base):
|
||||||
__tablename__ = "dataset_retriever_resources"
|
__tablename__ = "dataset_retriever_resources"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
|
db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
|
||||||
@ -1444,7 +1449,7 @@ class DatasetRetrieverResource(db.Model):
|
|||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
|
||||||
|
|
||||||
|
|
||||||
class Tag(db.Model):
|
class Tag(Base):
|
||||||
__tablename__ = "tags"
|
__tablename__ = "tags"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tag_pkey"),
|
db.PrimaryKeyConstraint("id", name="tag_pkey"),
|
||||||
@ -1462,7 +1467,7 @@ class Tag(db.Model):
|
|||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class TagBinding(db.Model):
|
class TagBinding(Base):
|
||||||
__tablename__ = "tag_bindings"
|
__tablename__ = "tag_bindings"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
|
db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
|
||||||
@ -1478,7 +1483,7 @@ class TagBinding(db.Model):
|
|||||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class TraceAppConfig(db.Model):
|
class TraceAppConfig(Base):
|
||||||
__tablename__ = "trace_app_config"
|
__tablename__ = "trace_app_config"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
|
db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.base import Base
|
||||||
|
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
@ -35,7 +36,7 @@ class ProviderQuotaType(Enum):
|
|||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
class Provider(db.Model):
|
class Provider(Base):
|
||||||
"""
|
"""
|
||||||
Provider model representing the API providers and their configurations.
|
Provider model representing the API providers and their configurations.
|
||||||
"""
|
"""
|
||||||
@ -88,7 +89,7 @@ class Provider(db.Model):
|
|||||||
return self.is_valid and self.token_is_set
|
return self.is_valid and self.token_is_set
|
||||||
|
|
||||||
|
|
||||||
class ProviderModel(db.Model):
|
class ProviderModel(Base):
|
||||||
"""
|
"""
|
||||||
Provider model representing the API provider_models and their configurations.
|
Provider model representing the API provider_models and their configurations.
|
||||||
"""
|
"""
|
||||||
@ -113,7 +114,7 @@ class ProviderModel(db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class TenantDefaultModel(db.Model):
|
class TenantDefaultModel(Base):
|
||||||
__tablename__ = "tenant_default_models"
|
__tablename__ = "tenant_default_models"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
|
db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
|
||||||
@ -129,7 +130,7 @@ class TenantDefaultModel(db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class TenantPreferredModelProvider(db.Model):
|
class TenantPreferredModelProvider(Base):
|
||||||
__tablename__ = "tenant_preferred_model_providers"
|
__tablename__ = "tenant_preferred_model_providers"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
|
db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
|
||||||
@ -144,7 +145,7 @@ class TenantPreferredModelProvider(db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class ProviderOrder(db.Model):
|
class ProviderOrder(Base):
|
||||||
__tablename__ = "provider_orders"
|
__tablename__ = "provider_orders"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
|
db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
|
||||||
@ -169,7 +170,7 @@ class ProviderOrder(db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class ProviderModelSetting(db.Model):
|
class ProviderModelSetting(Base):
|
||||||
"""
|
"""
|
||||||
Provider model settings for record the model enabled status and load balancing status.
|
Provider model settings for record the model enabled status and load balancing status.
|
||||||
"""
|
"""
|
||||||
@ -191,7 +192,7 @@ class ProviderModelSetting(db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class LoadBalancingModelConfig(db.Model):
|
class LoadBalancingModelConfig(Base):
|
||||||
"""
|
"""
|
||||||
Configurations for load balancing models.
|
Configurations for load balancing models.
|
||||||
"""
|
"""
|
||||||
|
@ -3,11 +3,12 @@ import json
|
|||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.base import Base
|
||||||
|
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
|
|
||||||
class DataSourceOauthBinding(db.Model):
|
class DataSourceOauthBinding(Base):
|
||||||
__tablename__ = "data_source_oauth_bindings"
|
__tablename__ = "data_source_oauth_bindings"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
|
||||||
@ -25,7 +26,7 @@ class DataSourceOauthBinding(db.Model):
|
|||||||
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
|
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||||
|
|
||||||
|
|
||||||
class DataSourceApiKeyAuthBinding(db.Model):
|
class DataSourceApiKeyAuthBinding(Base):
|
||||||
__tablename__ = "data_source_api_key_auth_bindings"
|
__tablename__ = "data_source_api_key_auth_bindings"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),
|
||||||
|
@ -3,9 +3,10 @@ from datetime import datetime, timezone
|
|||||||
from celery import states
|
from celery import states
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.base import Base
|
||||||
|
|
||||||
|
|
||||||
class CeleryTask(db.Model):
|
class CeleryTask(Base):
|
||||||
"""Task result/status."""
|
"""Task result/status."""
|
||||||
|
|
||||||
__tablename__ = "celery_taskmeta"
|
__tablename__ = "celery_taskmeta"
|
||||||
@ -29,7 +30,7 @@ class CeleryTask(db.Model):
|
|||||||
queue = db.Column(db.String(155), nullable=True)
|
queue = db.Column(db.String(155), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class CeleryTaskSet(db.Model):
|
class CeleryTaskSet(Base):
|
||||||
"""TaskSet result."""
|
"""TaskSet result."""
|
||||||
|
|
||||||
__tablename__ = "celery_tasksetmeta"
|
__tablename__ = "celery_tasksetmeta"
|
||||||
|
@ -2,6 +2,7 @@ import json
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.base import Base
|
||||||
|
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ class ToolProviderName(Enum):
|
|||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
class ToolProvider(db.Model):
|
class ToolProvider(Base):
|
||||||
__tablename__ = "tool_providers"
|
__tablename__ = "tool_providers"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tool_provider_pkey"),
|
db.PrimaryKeyConstraint("id", name="tool_provider_pkey"),
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from deprecated import deprecated
|
||||||
|
from sqlalchemy import ForeignKey
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -31,7 +34,7 @@ class BuiltinToolProvider(Base):
|
|||||||
# who created this tool provider
|
# who created this tool provider
|
||||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
# name of the tool provider
|
# name of the tool provider
|
||||||
provider: Mapped[str] = mapped_column(db.String(40), nullable=False)
|
provider: Mapped[str] = mapped_column(db.String(256), nullable=False)
|
||||||
# credential of the tool provider
|
# credential of the tool provider
|
||||||
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
|
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
@ -182,7 +185,7 @@ class WorkflowToolProvider(Base):
|
|||||||
return db.session.query(App).filter(App.id == self.app_id).first()
|
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||||
|
|
||||||
|
|
||||||
class ToolModelInvoke(db.Model):
|
class ToolModelInvoke(Base):
|
||||||
"""
|
"""
|
||||||
store the invoke logs from tool invoke
|
store the invoke logs from tool invoke
|
||||||
"""
|
"""
|
||||||
@ -219,7 +222,7 @@ class ToolModelInvoke(db.Model):
|
|||||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
|
||||||
class ToolConversationVariables(db.Model):
|
class ToolConversationVariables(Base):
|
||||||
"""
|
"""
|
||||||
store the conversation variables from tool invoke
|
store the conversation variables from tool invoke
|
||||||
"""
|
"""
|
||||||
@ -275,3 +278,46 @@ class ToolFile(Base):
|
|||||||
mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
|
||||||
# original url
|
# original url
|
||||||
original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
|
original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated
|
||||||
|
class DeprecatedPublishedAppTool(Base):
|
||||||
|
"""
|
||||||
|
The table stores the apps published as a tool for each person.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "tool_published_apps"
|
||||||
|
__table_args__ = (
|
||||||
|
db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
|
||||||
|
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# id of the tool provider
|
||||||
|
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
|
# id of the app
|
||||||
|
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
|
||||||
|
# who published this tool
|
||||||
|
user_id = db.Column(StringUUID, nullable=False)
|
||||||
|
# description of the tool, stored in i18n format, for human
|
||||||
|
description = db.Column(db.Text, nullable=False)
|
||||||
|
# llm_description of the tool, for LLM
|
||||||
|
llm_description = db.Column(db.Text, nullable=False)
|
||||||
|
# query description, query will be seem as a parameter of the tool,
|
||||||
|
# to describe this parameter to llm, we need this field
|
||||||
|
query_description = db.Column(db.Text, nullable=False)
|
||||||
|
# query name, the name of the query parameter
|
||||||
|
query_name = db.Column(db.String(40), nullable=False)
|
||||||
|
# name of the tool provider
|
||||||
|
tool_name = db.Column(db.String(40), nullable=False)
|
||||||
|
# author
|
||||||
|
author = db.Column(db.String(40), nullable=False)
|
||||||
|
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description_i18n(self) -> I18nObject:
|
||||||
|
return I18nObject(**json.loads(self.description))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def app(self) -> App:
|
||||||
|
return db.session.query(App).filter(App.id == self.app_id).first()
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models.base import Base
|
||||||
|
|
||||||
from .model import Message
|
from .model import Message
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
|
|
||||||
|
|
||||||
class SavedMessage(db.Model):
|
class SavedMessage(Base):
|
||||||
__tablename__ = "saved_messages"
|
__tablename__ = "saved_messages"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
|
||||||
@ -23,7 +24,7 @@ class SavedMessage(db.Model):
|
|||||||
return db.session.query(Message).filter(Message.id == self.message_id).first()
|
return db.session.query(Message).filter(Message.id == self.message_id).first()
|
||||||
|
|
||||||
|
|
||||||
class PinnedConversation(db.Model):
|
class PinnedConversation(Base):
|
||||||
__tablename__ = "pinned_conversations"
|
__tablename__ = "pinned_conversations"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),
|
||||||
|
@ -2,10 +2,13 @@ import json
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
|
||||||
from sqlalchemy import func
|
if TYPE_CHECKING:
|
||||||
from sqlalchemy.orm import Mapped
|
from models.model import AppMode
|
||||||
|
|
||||||
|
from sqlalchemy import Index, PrimaryKeyConstraint, func
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
@ -13,6 +16,7 @@ from core.app.segments import SecretVariable, Variable, factory
|
|||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
from models.base import Base
|
||||||
|
|
||||||
from .account import Account
|
from .account import Account
|
||||||
from .types import StringUUID
|
from .types import StringUUID
|
||||||
@ -75,7 +79,7 @@ class WorkflowType(Enum):
|
|||||||
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
|
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
|
||||||
|
|
||||||
|
|
||||||
class Workflow(db.Model):
|
class Workflow(Base):
|
||||||
"""
|
"""
|
||||||
Workflow, for `Workflow App` and `Chat App workflow mode`.
|
Workflow, for `Workflow App` and `Chat App workflow mode`.
|
||||||
|
|
||||||
@ -345,7 +349,7 @@ class WorkflowRunStatus(Enum):
|
|||||||
raise ValueError(f"invalid workflow run status value {value}")
|
raise ValueError(f"invalid workflow run status value {value}")
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRun(db.Model):
|
class WorkflowRun(Base):
|
||||||
"""
|
"""
|
||||||
Workflow Run
|
Workflow Run
|
||||||
|
|
||||||
@ -436,7 +440,7 @@ class WorkflowRun(db.Model):
|
|||||||
return json.loads(self.outputs) if self.outputs else None
|
return json.loads(self.outputs) if self.outputs else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def message(self) -> Optional["Message"]:
|
def message(self):
|
||||||
from models.model import Message
|
from models.model import Message
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -542,7 +546,7 @@ class WorkflowNodeExecutionStatus(Enum):
|
|||||||
raise ValueError(f"invalid workflow node execution status value {value}")
|
raise ValueError(f"invalid workflow node execution status value {value}")
|
||||||
|
|
||||||
|
|
||||||
class WorkflowNodeExecution(db.Model):
|
class WorkflowNodeExecution(Base):
|
||||||
"""
|
"""
|
||||||
Workflow Node Execution
|
Workflow Node Execution
|
||||||
|
|
||||||
@ -708,7 +712,7 @@ class WorkflowAppLogCreatedFrom(Enum):
|
|||||||
raise ValueError(f"invalid workflow app log created from value {value}")
|
raise ValueError(f"invalid workflow app log created from value {value}")
|
||||||
|
|
||||||
|
|
||||||
class WorkflowAppLog(db.Model):
|
class WorkflowAppLog(Base):
|
||||||
"""
|
"""
|
||||||
Workflow App execution log, excluding workflow debugging records.
|
Workflow App execution log, excluding workflow debugging records.
|
||||||
|
|
||||||
@ -770,15 +774,20 @@ class WorkflowAppLog(db.Model):
|
|||||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
|
||||||
|
|
||||||
|
|
||||||
class ConversationVariable(db.Model):
|
class ConversationVariable(Base):
|
||||||
__tablename__ = "workflow_conversation_variables"
|
__tablename__ = "workflow_conversation_variables"
|
||||||
|
__table_args__ = (
|
||||||
|
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
|
||||||
|
Index("workflow__conversation_variables_app_id_idx", "app_id"),
|
||||||
|
Index("workflow__conversation_variables_created_at_idx", "created_at"),
|
||||||
|
)
|
||||||
|
|
||||||
id: Mapped[str] = db.Column(StringUUID, primary_key=True)
|
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
|
||||||
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
|
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
|
||||||
app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
|
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
data = db.Column(db.Text, nullable=False)
|
data = mapped_column(db.Text, nullable=False)
|
||||||
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||||
updated_at = db.Column(
|
updated_at = mapped_column(
|
||||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user