fix: workflow loads tool provider icon

This commit is contained in:
Yeuoly 2024-12-02 21:08:36 +08:00
parent b10d6051ba
commit ad899844a1
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
24 changed files with 217 additions and 55 deletions

View File

@ -1,9 +1,14 @@
from contextvars import ContextVar from contextvars import ContextVar
from threading import Lock
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id") tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")

View File

@ -15,6 +15,7 @@ class AgentToolEntity(BaseModel):
provider_id: str provider_id: str
tool_name: str tool_name: str
tool_parameters: dict[str, Any] = {} tool_parameters: dict[str, Any] = {}
plugin_unique_identifier: str | None = None
class AgentPromptEntity(BaseModel): class AgentPromptEntity(BaseModel):

View File

@ -152,6 +152,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
@ -201,6 +203,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
), ),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,

View File

@ -1,3 +1,4 @@
import contextvars
import logging import logging
import threading import threading
import uuid import uuid
@ -170,7 +171,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
worker_thread = threading.Thread( worker_thread = threading.Thread(
target=self._generate_worker, target=self._generate_worker,
kwargs={ kwargs={
"flask_app": current_app._get_current_object(), "flask_app": current_app._get_current_object(), # type: ignore
"contexts": contextvars.copy_context(),
"application_generate_entity": application_generate_entity, "application_generate_entity": application_generate_entity,
"queue_manager": queue_manager, "queue_manager": queue_manager,
"conversation_id": conversation.id, "conversation_id": conversation.id,
@ -195,6 +197,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker( def _generate_worker(
self, self,
flask_app: Flask, flask_app: Flask,
context: contextvars.Context,
application_generate_entity: AgentChatAppGenerateEntity, application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation_id: str, conversation_id: str,
@ -209,6 +212,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param message_id: message ID :param message_id: message ID
:return: :return:
""" """
for var, val in context.items():
var.set(val)
with flask_app.app_context(): with flask_app.app_context():
try: try:
# get conversation and message # get conversation and message

View File

@ -119,7 +119,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
trace_manager=trace_manager, trace_manager=trace_manager,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
@ -223,6 +226,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
), ),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,

View File

@ -1,7 +1,5 @@
from collections.abc import Sequence from collections.abc import Sequence
from pydantic import BaseModel
from core.plugin.entities.bundle import PluginBundleDependency from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import ( from core.plugin.entities.plugin import (
GenericProviderID, GenericProviderID,
@ -152,15 +150,12 @@ class PluginInstallationManager(BasePluginManager):
Fetch a plugin manifest. Fetch a plugin manifest.
""" """
class PluginDeclarationResponse(BaseModel):
declaration: PluginDeclaration
return self._request_with_plugin_daemon_response( return self._request_with_plugin_daemon_response(
"GET", "GET",
f"plugin/{tenant_id}/management/fetch/manifest", f"plugin/{tenant_id}/management/fetch/manifest",
PluginDeclarationResponse, PluginDeclaration,
params={"plugin_unique_identifier": plugin_unique_identifier}, params={"plugin_unique_identifier": plugin_unique_identifier},
).declaration )
def fetch_plugin_installation_by_ids( def fetch_plugin_installation_by_ids(
self, tenant_id: str, plugin_ids: Sequence[str] self, tenant_id: str, plugin_ids: Sequence[str]

View File

@ -74,6 +74,7 @@ class BuiltinToolProviderController(ToolProviderController):
tool["identity"]["provider"] = provider tool["identity"]["provider"] = provider
tools.append( tools.append(
assistant_tool_class( assistant_tool_class(
provider=provider,
entity=ToolEntity(**tool), entity=ToolEntity(**tool),
runtime=ToolRuntime(tenant_id=""), runtime=ToolRuntime(tenant_id=""),
) )

View File

@ -1,6 +1,7 @@
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils from core.tools.utils.model_invocation_utils import ModelInvocationUtils
@ -19,6 +20,25 @@ class BuiltinTool(Tool):
:param meta: the meta data of a tool call processing :param meta: the meta data of a tool call processing
""" """
provider: str
def __init__(self, provider: str, **kwargs):
super().__init__(**kwargs)
self.provider = provider
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
provider=self.provider,
)
def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult: def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
""" """
invoke model invoke model

View File

@ -109,6 +109,7 @@ class ApiToolProviderController(ToolProviderController):
""" """
return ApiTool( return ApiTool(
api_bundle=tool_bundle, api_bundle=tool_bundle,
provider_id=self.provider_id,
entity=ToolEntity( entity=ToolEntity(
identity=ToolIdentity( identity=ToolIdentity(
author=tool_bundle.author, author=tool_bundle.author,

View File

@ -22,14 +22,16 @@ API_TOOL_DEFAULT_TIMEOUT = (
class ApiTool(Tool): class ApiTool(Tool):
api_bundle: ApiToolBundle api_bundle: ApiToolBundle
provider_id: str
""" """
Api tool Api tool
""" """
def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime): def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime, provider_id: str):
super().__init__(entity, runtime) super().__init__(entity, runtime)
self.api_bundle = api_bundle self.api_bundle = api_bundle
self.provider_id = provider_id
def fork_tool_runtime(self, runtime: ToolRuntime): def fork_tool_runtime(self, runtime: ToolRuntime):
""" """
@ -42,6 +44,7 @@ class ApiTool(Tool):
entity=self.entity, entity=self.entity,
api_bundle=self.api_bundle.model_copy(), api_bundle=self.api_bundle.model_copy(),
runtime=runtime, runtime=runtime,
provider_id=self.provider_id,
) )
def validate_credentials( def validate_credentials(

View File

@ -34,6 +34,7 @@ class ToolProviderApiEntity(BaseModel):
is_team_authorization: bool = False is_team_authorization: bool = False
allow_delete: bool = True allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
tools: list[ToolApiEntity] = Field(default_factory=list) tools: list[ToolApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list) labels: list[str] = Field(default_factory=list)
@ -58,6 +59,7 @@ class ToolProviderApiEntity(BaseModel):
"author": self.author, "author": self.author,
"name": self.name, "name": self.name,
"plugin_id": self.plugin_id, "plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(), "description": self.description.to_dict(),
"icon": self.icon, "icon": self.icon,
"label": self.label.to_dict(), "label": self.label.to_dict(),

View File

@ -12,11 +12,15 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity: ToolProviderEntityWithPlugin entity: ToolProviderEntityWithPlugin
tenant_id: str tenant_id: str
plugin_id: str plugin_id: str
plugin_unique_identifier: str
def __init__(self, entity: ToolProviderEntityWithPlugin, plugin_id: str, tenant_id: str) -> None: def __init__(
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
self.entity = entity self.entity = entity
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.plugin_id = plugin_id self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
@ -53,6 +57,8 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity=tool_entity, entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id), runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
) )
def get_tools(self) -> list[PluginTool]: def get_tools(self) -> list[PluginTool]:
@ -64,6 +70,8 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity=tool_entity, entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id), runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
) )
for tool_entity in self.entity.tools for tool_entity in self.entity.tools
] ]

View File

@ -11,11 +11,17 @@ from models.model import File
class PluginTool(Tool): class PluginTool(Tool):
tenant_id: str tenant_id: str
icon: str
plugin_unique_identifier: str
runtime_parameters: Optional[list[ToolParameter]] runtime_parameters: Optional[list[ToolParameter]]
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None: def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
) -> None:
super().__init__(entity, runtime) super().__init__(entity, runtime)
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
self.runtime_parameters = None self.runtime_parameters = None
def tool_provider_type(self) -> ToolProviderType: def tool_provider_type(self) -> ToolProviderType:
@ -64,6 +70,8 @@ class PluginTool(Tool):
entity=self.entity, entity=self.entity,
runtime=runtime, runtime=runtime,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
) )
def get_runtime_parameters(self) -> list[ToolParameter]: def get_runtime_parameters(self) -> list[ToolParameter]:

View File

@ -6,6 +6,9 @@ from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL
import contexts
from core.plugin.entities.plugin import GenericProviderID from core.plugin.entities.plugin import GenericProviderID
from core.plugin.manager.tool import PluginToolManager from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
@ -97,17 +100,27 @@ class ToolManager:
""" """
get the plugin provider get the plugin provider
""" """
with contexts.plugin_tool_providers_lock.get():
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
manager = PluginToolManager() manager = PluginToolManager()
provider_entity = manager.fetch_tool_provider(tenant_id, provider) provider_entity = manager.fetch_tool_provider(tenant_id, provider)
if not provider_entity: if not provider_entity:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found") raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
return PluginToolProviderController( controller = PluginToolProviderController(
entity=provider_entity.declaration, entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id, plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id, tenant_id=tenant_id,
) )
plugin_tool_providers[provider] = controller
return controller
@classmethod @classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
""" """
@ -132,7 +145,7 @@ class ToolManager:
tenant_id: str, tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
) -> Union[BuiltinTool, ApiTool, WorkflowTool]: ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
""" """
get the tool runtime get the tool runtime
@ -260,6 +273,8 @@ class ToolManager:
) )
elif provider_type == ToolProviderType.APP: elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented") raise NotImplementedError("app provider not implemented")
elif provider_type == ToolProviderType.PLUGIN:
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
else: else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@ -477,6 +492,7 @@ class ToolManager:
PluginToolProviderController( PluginToolProviderController(
entity=provider.declaration, entity=provider.declaration,
plugin_id=provider.plugin_id, plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id, tenant_id=tenant_id,
) )
for provider in provider_entities for provider in provider_entities
@ -758,7 +774,66 @@ class ToolManager:
) )
@classmethod @classmethod
def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]: def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
return (
dify_config.CONSOLE_API_URL
+ "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id
+ "/icon"
)
@classmethod
def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
return str(
URL(dify_config.CONSOLE_API_URL)
/ "console"
/ "api"
/ "workspaces"
/ "current"
/ "plugin"
/ "icon"
% {"tenant_id": tenant_id, "filename": filename}
)
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return json.loads(workflow_provider.icon)
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
return json.loads(api_provider.icon)
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def get_tool_icon(
cls,
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> Union[str, dict]:
""" """
get the tool icon get the tool icon
@ -770,36 +845,25 @@ class ToolManager:
provider_type = provider_type provider_type = provider_type
provider_id = provider_id provider_id = provider_id
if provider_type == ToolProviderType.BUILT_IN: if provider_type == ToolProviderType.BUILT_IN:
return ( provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
dify_config.CONSOLE_API_URL if isinstance(provider, PluginToolProviderController):
+ "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id
+ "/icon"
)
elif provider_type == ToolProviderType.API:
try: try:
api_provider: ApiToolProvider | None = ( return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
if not api_provider:
raise ValueError("api tool not found")
return json.loads(api_provider.icon)
except: except:
return {"background": "#252525", "content": "\ud83d\ude01"} return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_builtin_tool_icon_url(provider_id)
elif provider_type == ToolProviderType.API:
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.WORKFLOW: elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider: WorkflowToolProvider | None = ( return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
db.session.query(WorkflowToolProvider) elif provider_type == ToolProviderType.PLUGIN:
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
.first() if isinstance(provider, PluginToolProviderController):
) try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
if workflow_provider is None: except:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
return json.loads(workflow_provider.icon)
else: else:
raise ValueError(f"provider type {provider_type} not found") raise ValueError(f"provider type {provider_type} not found")

View File

@ -148,6 +148,7 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("variable not found") raise ValueError("variable not found")
return WorkflowTool( return WorkflowTool(
workflow_as_tool_id=db_provider.id,
entity=ToolEntity( entity=ToolEntity(
identity=ToolIdentity( identity=ToolIdentity(
author=user.name if user else "", author=user.name if user else "",

View File

@ -21,6 +21,7 @@ class WorkflowTool(Tool):
workflow_entities: dict[str, Any] workflow_entities: dict[str, Any]
workflow_call_depth: int workflow_call_depth: int
thread_pool_id: Optional[str] = None thread_pool_id: Optional[str] = None
workflow_as_tool_id: str
label: str label: str
@ -31,6 +32,7 @@ class WorkflowTool(Tool):
def __init__( def __init__(
self, self,
workflow_app_id: str, workflow_app_id: str,
workflow_as_tool_id: str,
version: str, version: str,
workflow_entities: dict[str, Any], workflow_entities: dict[str, Any],
workflow_call_depth: int, workflow_call_depth: int,
@ -40,6 +42,7 @@ class WorkflowTool(Tool):
thread_pool_id: Optional[str] = None, thread_pool_id: Optional[str] = None,
): ):
self.workflow_app_id = workflow_app_id self.workflow_app_id = workflow_app_id
self.workflow_as_tool_id = workflow_as_tool_id
self.version = version self.version = version
self.workflow_entities = workflow_entities self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth self.workflow_call_depth = workflow_call_depth
@ -134,6 +137,7 @@ class WorkflowTool(Tool):
entity=self.entity.model_copy(), entity=self.entity.model_copy(),
runtime=runtime, runtime=runtime,
workflow_app_id=self.workflow_app_id, workflow_app_id=self.workflow_app_id,
workflow_as_tool_id=self.workflow_as_tool_id,
workflow_entities=self.workflow_entities, workflow_entities=self.workflow_entities,
workflow_call_depth=self.workflow_call_depth, workflow_call_depth=self.workflow_call_depth,
version=self.version, version=self.version,

View File

@ -147,8 +147,8 @@ class IterationRunStartedEvent(BaseIterationEvent):
class IterationRunNextEvent(BaseIterationEvent): class IterationRunNextEvent(BaseIterationEvent):
index: int = Field(..., description="index") index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") pre_iteration_output: Optional[Any] = None
duration: Optional[float] = Field(None, description="duration") duration: Optional[float] = None
class IterationRunSucceededEvent(BaseIterationEvent): class IterationRunSucceededEvent(BaseIterationEvent):

View File

@ -1,3 +1,4 @@
import contextvars
import logging import logging
import queue import queue
import time import time
@ -434,6 +435,7 @@ class GraphEngine:
**{ **{
"flask_app": current_app._get_current_object(), # type: ignore[attr-defined] "flask_app": current_app._get_current_object(), # type: ignore[attr-defined]
"q": q, "q": q,
"context": contextvars.copy_context(),
"parallel_id": parallel_id, "parallel_id": parallel_id,
"parallel_start_node_id": edge.target_node_id, "parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id, "parent_parallel_id": in_parallel_id,
@ -476,6 +478,7 @@ class GraphEngine:
def _run_parallel_node( def _run_parallel_node(
self, self,
flask_app: Flask, flask_app: Flask,
context: contextvars.Context,
q: queue.Queue, q: queue.Queue,
parallel_id: str, parallel_id: str,
parallel_start_node_id: str, parallel_start_node_id: str,
@ -485,6 +488,9 @@ class GraphEngine:
""" """
Run parallel nodes Run parallel nodes
""" """
for var, val in context.items():
var.set(val)
with flask_app.app_context(): with flask_app.app_context():
try: try:
q.put( q.put(

View File

@ -1,3 +1,4 @@
import contextvars
import logging import logging
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
@ -166,7 +167,8 @@ class IterationNode(BaseNode[IterationNodeData]):
for index, item in enumerate(iterator_list_value): for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit( future: Future = thread_pool.submit(
self._run_single_iter_parallel, self._run_single_iter_parallel,
current_app._get_current_object(), current_app._get_current_object(), # type: ignore
contextvars.copy_context(),
q, q,
iterator_list_value, iterator_list_value,
inputs, inputs,
@ -372,7 +374,10 @@ class IterationNode(BaseNode[IterationNodeData]):
try: try:
rst = graph_engine.run() rst = graph_engine.run()
# get current iteration index # get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value variable = variable_pool.get([self.node_id, "index"])
if variable is None:
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
current_index = variable.value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}" iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1 next_index = int(current_index) + 1
@ -540,6 +545,7 @@ class IterationNode(BaseNode[IterationNodeData]):
def _run_single_iter_parallel( def _run_single_iter_parallel(
self, self,
flask_app: Flask, flask_app: Flask,
context: contextvars.Context,
q: Queue, q: Queue,
iterator_list_value: list[str], iterator_list_value: list[str],
inputs: dict[str, list], inputs: dict[str, list],
@ -554,6 +560,8 @@ class IterationNode(BaseNode[IterationNodeData]):
""" """
run single iteration in parallel mode run single iteration in parallel mode
""" """
for var, val in context.items():
var.set(val)
with flask_app.app_context(): with flask_app.app_context():
parallel_mode_run_id = uuid.uuid4().hex parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy() graph_engine_copy = graph_engine.create_copy()

View File

@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
tool_name: str tool_name: str
tool_label: str # redundancy tool_label: str # redundancy
tool_configurations: dict[str, Any] tool_configurations: dict[str, Any]
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before") @field_validator("tool_configurations", mode="before")
@classmethod @classmethod

View File

@ -50,7 +50,11 @@ class ToolNode(BaseNode[ToolNodeData]):
node_data = cast(ToolNodeData, self.node_data) node_data = cast(ToolNodeData, self.node_data)
# fetch tool icon # fetch tool icon
tool_info = {"provider_type": node_data.provider_type.value, "provider_id": node_data.provider_id} tool_info = {
"provider_type": node_data.provider_type.value,
"provider_id": node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier,
}
# get tool runtime # get tool runtime
try: try:

View File

@ -1,6 +1,9 @@
import threading
import pytz import pytz
from flask_login import current_user from flask_login import current_user
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from extensions.ext_database import db from extensions.ext_database import db
@ -14,6 +17,9 @@ class AgentService:
""" """
Service to get agent logs Service to get agent logs
""" """
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
conversation: Conversation = ( conversation: Conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter( .filter(

View File

@ -97,6 +97,7 @@ class ToolTransformService:
if isinstance(provider_controller, PluginToolProviderController): if isinstance(provider_controller, PluginToolProviderController):
result.plugin_id = provider_controller.plugin_id result.plugin_id = provider_controller.plugin_id
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema # get credentials schema
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
@ -173,6 +174,7 @@ class ToolTransformService:
masked_credentials={}, masked_credentials={},
is_team_authorization=True, is_team_authorization=True,
plugin_id=None, plugin_id=None,
plugin_unique_identifier=None,
tools=[], tools=[],
labels=labels or [], labels=labels or [],
) )
@ -214,6 +216,7 @@ class ToolTransformService:
), ),
type=ToolProviderType.API, type=ToolProviderType.API,
plugin_id=None, plugin_id=None,
plugin_unique_identifier=None,
masked_credentials={}, masked_credentials={},
is_team_authorization=True, is_team_authorization=True,
tools=[], tools=[],

View File

@ -1,3 +1,6 @@
import threading
import contexts
from extensions.ext_database import db from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
@ -117,6 +120,9 @@ class WorkflowRunService:
""" """
workflow_run = self.get_workflow_run(app_model, run_id) workflow_run = self.get_workflow_run(app_model, run_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if not workflow_run: if not workflow_run:
return [] return []