diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 85380b7330..b8354fa012 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -1,9 +1,14 @@ from contextvars import ContextVar +from threading import Lock from typing import TYPE_CHECKING if TYPE_CHECKING: + from core.tools.plugin_tool.provider import PluginToolProviderController from core.workflow.entities.variable_pool import VariablePool tenant_id: ContextVar[str] = ContextVar("tenant_id") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") + +plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") +plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 5287b9a714..386dedf798 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -15,6 +15,7 @@ class AgentToolEntity(BaseModel): provider_id: str tool_name: str tool_parameters: dict[str, Any] = {} + plugin_unique_identifier: str | None = None class AgentPromptEntity(BaseModel): diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 660527487e..b4024defba 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -152,6 +152,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_run_id=workflow_run_id, ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) return self._generate( workflow=workflow, @@ -201,6 +203,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) return self._generate( workflow=workflow, diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 3789166c92..2bf696cbe0 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -170,7 +171,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): worker_thread = threading.Thread( target=self._generate_worker, kwargs={ - "flask_app": current_app._get_current_object(), + "flask_app": current_app._get_current_object(), # type: ignore + "contexts": contextvars.copy_context(), "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -195,6 +197,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): def _generate_worker( self, flask_app: Flask, + context: contextvars.Context, application_generate_entity: AgentChatAppGenerateEntity, queue_manager: AppQueueManager, conversation_id: str, @@ -209,6 +212,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): :param message_id: message ID :return: """ + for var, val in context.items(): + var.set(val) + with flask_app.app_context(): try: # get conversation and message diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index a652d205d3..c32bf84ac8 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -119,7 +119,10 @@ class WorkflowAppGenerator(BaseAppGenerator): trace_manager=trace_manager, workflow_run_id=workflow_run_id, ) + contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) return self._generate( app_model=app_model, @@ -223,6 +226,8 @@ class WorkflowAppGenerator(BaseAppGenerator): ), ) contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) return self._generate( app_model=app_model, diff --git a/api/core/plugin/manager/plugin.py b/api/core/plugin/manager/plugin.py index 14fba6c989..c96e6c621b 100644 --- a/api/core/plugin/manager/plugin.py +++ b/api/core/plugin/manager/plugin.py @@ -1,7 +1,5 @@ from collections.abc import Sequence -from pydantic import BaseModel - from core.plugin.entities.bundle import PluginBundleDependency from core.plugin.entities.plugin import ( GenericProviderID, @@ -152,15 +150,12 @@ class PluginInstallationManager(BasePluginManager): Fetch a plugin manifest. """ - class PluginDeclarationResponse(BaseModel): - declaration: PluginDeclaration - return self._request_with_plugin_daemon_response( "GET", f"plugin/{tenant_id}/management/fetch/manifest", - PluginDeclarationResponse, + PluginDeclaration, params={"plugin_unique_identifier": plugin_unique_identifier}, - ).declaration + ) def fetch_plugin_installation_by_ids( self, tenant_id: str, plugin_ids: Sequence[str] diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 15635b4f25..78949f8b1a 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -74,6 +74,7 @@ class BuiltinToolProviderController(ToolProviderController): tool["identity"]["provider"] = provider tools.append( assistant_tool_class( + provider=provider, entity=ToolEntity(**tool), runtime=ToolRuntime(tenant_id=""), ) diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index abba542b8e..a51813ba40 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,6 +1,7 @@ from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils @@ -19,6 +20,25 @@ class BuiltinTool(Tool): :param meta: the meta data of a tool call processing """ + provider: str + + def __init__(self, provider: str, **kwargs): + super().__init__(**kwargs) + self.provider = provider + + def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool": + """ + fork a new tool with meta data + + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool + """ + return self.__class__( + entity=self.entity.model_copy(), + runtime=runtime, + provider=self.provider, + ) + def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult: """ invoke model diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py index 9c9f26b60a..457954c961 100644 --- a/api/core/tools/custom_tool/provider.py +++ b/api/core/tools/custom_tool/provider.py @@ -109,6 +109,7 @@ class ApiToolProviderController(ToolProviderController): """ return ApiTool( api_bundle=tool_bundle, + provider_id=self.provider_id, entity=ToolEntity( identity=ToolIdentity( author=tool_bundle.author, diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index e8594b5847..7587b00b41 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -22,14 +22,16 @@ API_TOOL_DEFAULT_TIMEOUT = ( class ApiTool(Tool): api_bundle: ApiToolBundle + provider_id: str """ Api tool """ - def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime): + def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime, provider_id: str): super().__init__(entity, runtime) self.api_bundle = api_bundle + self.provider_id = provider_id def fork_tool_runtime(self, runtime: ToolRuntime): """ @@ -42,6 +44,7 @@ class ApiTool(Tool): entity=self.entity, api_bundle=self.api_bundle.model_copy(), runtime=runtime, + provider_id=self.provider_id, ) def validate_credentials( diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2d8af8cabe..b96c994cff 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -34,6 +34,7 @@ class ToolProviderApiEntity(BaseModel): is_team_authorization: bool = False allow_delete: bool = True plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool") + plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool") tools: list[ToolApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) @@ -58,6 +59,7 @@ class ToolProviderApiEntity(BaseModel): "author": self.author, "name": self.name, "plugin_id": self.plugin_id, + "plugin_unique_identifier": self.plugin_unique_identifier, "description": self.description.to_dict(), "icon": self.icon, "label": self.label.to_dict(), diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py index d285576996..6a3c701a59 100644 --- a/api/core/tools/plugin_tool/provider.py +++ b/api/core/tools/plugin_tool/provider.py @@ -12,11 +12,15 @@ class PluginToolProviderController(BuiltinToolProviderController): entity: ToolProviderEntityWithPlugin tenant_id: str plugin_id: str + plugin_unique_identifier: str - def __init__(self, entity: ToolProviderEntityWithPlugin, plugin_id: str, tenant_id: str) -> None: + def __init__( + self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str + ) -> None: self.entity = entity self.tenant_id = tenant_id self.plugin_id = plugin_id + self.plugin_unique_identifier = plugin_unique_identifier @property def provider_type(self) -> ToolProviderType: @@ -53,6 +57,8 @@ class PluginToolProviderController(BuiltinToolProviderController): entity=tool_entity, runtime=ToolRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, ) def get_tools(self) -> list[PluginTool]: @@ -64,6 +70,8 @@ class PluginToolProviderController(BuiltinToolProviderController): entity=tool_entity, runtime=ToolRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, + icon=self.entity.identity.icon, + plugin_unique_identifier=self.plugin_unique_identifier, ) for tool_entity in self.entity.tools ] diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index 4e5c65ab94..559829bbb0 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -11,11 +11,17 @@ from models.model import File class PluginTool(Tool): tenant_id: str + icon: str + plugin_unique_identifier: str runtime_parameters: Optional[list[ToolParameter]] - def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None: + def __init__( + self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str + ) -> None: super().__init__(entity, runtime) self.tenant_id = tenant_id + self.icon = icon + self.plugin_unique_identifier = plugin_unique_identifier self.runtime_parameters = None def tool_provider_type(self) -> ToolProviderType: @@ -64,6 +70,8 @@ class PluginTool(Tool): entity=self.entity, runtime=runtime, tenant_id=self.tenant_id, + icon=self.icon, + plugin_unique_identifier=self.plugin_unique_identifier, ) def get_runtime_parameters(self) -> list[ToolParameter]: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d4a6878fcd..9e6a3dbc6d 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -6,6 +6,9 @@ from os import listdir, path from threading import Lock from typing import TYPE_CHECKING, Any, Union, cast +from yarl import URL + +import contexts from core.plugin.entities.plugin import GenericProviderID from core.plugin.manager.tool import PluginToolManager from core.tools.__base.tool_runtime import ToolRuntime @@ -97,16 +100,26 @@ class ToolManager: """ get the plugin provider """ - manager = PluginToolManager() - provider_entity = manager.fetch_tool_provider(tenant_id, provider) - if not provider_entity: - raise ToolProviderNotFoundError(f"plugin provider {provider} not found") + with contexts.plugin_tool_providers_lock.get(): + plugin_tool_providers = contexts.plugin_tool_providers.get() + if provider in plugin_tool_providers: + return plugin_tool_providers[provider] - return PluginToolProviderController( - entity=provider_entity.declaration, - plugin_id=provider_entity.plugin_id, - tenant_id=tenant_id, - ) + manager = PluginToolManager() + provider_entity = manager.fetch_tool_provider(tenant_id, provider) + if not provider_entity: + raise ToolProviderNotFoundError(f"plugin provider {provider} not found") + + controller = PluginToolProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + + plugin_tool_providers[provider] = controller + + return controller @classmethod def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: @@ -132,7 +145,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - ) -> Union[BuiltinTool, ApiTool, WorkflowTool]: + ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: """ get the tool runtime @@ -260,6 +273,8 @@ class ToolManager: ) elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") + elif provider_type == ToolProviderType.PLUGIN: + return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @@ -477,6 +492,7 @@ class ToolManager: PluginToolProviderController( entity=provider.declaration, plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, tenant_id=tenant_id, ) for provider in provider_entities @@ -758,7 +774,66 @@ class ToolManager: ) @classmethod - def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]: + def generate_builtin_tool_icon_url(cls, provider_id: str) -> str: + return ( + dify_config.CONSOLE_API_URL + + "/console/api/workspaces/current/tool-provider/builtin/" + + provider_id + + "/icon" + ) + + @classmethod + def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str: + return str( + URL(dify_config.CONSOLE_API_URL) + / "console" + / "api" + / "workspaces" + / "current" + / "plugin" + / "icon" + % {"tenant_id": tenant_id, "filename": filename} + ) + + @classmethod + def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: + try: + workflow_provider: WorkflowToolProvider | None = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) + .first() + ) + + if workflow_provider is None: + raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") + + return json.loads(workflow_provider.icon) + except: + return {"background": "#252525", "content": "\ud83d\ude01"} + + @classmethod + def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: + try: + api_provider: ApiToolProvider | None = ( + db.session.query(ApiToolProvider) + .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) + .first() + ) + + if api_provider is None: + raise ToolProviderNotFoundError(f"api provider {provider_id} not found") + + return json.loads(api_provider.icon) + except: + return {"background": "#252525", "content": "\ud83d\ude01"} + + @classmethod + def get_tool_icon( + cls, + tenant_id: str, + provider_type: ToolProviderType, + provider_id: str, + ) -> Union[str, dict]: """ get the tool icon @@ -770,36 +845,25 @@ class ToolManager: provider_type = provider_type provider_id = provider_id if provider_type == ToolProviderType.BUILT_IN: - return ( - dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/" - + provider_id - + "/icon" - ) + provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + if isinstance(provider, PluginToolProviderController): + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except: + return {"background": "#252525", "content": "\ud83d\ude01"} + return cls.generate_builtin_tool_icon_url(provider_id) elif provider_type == ToolProviderType.API: - try: - api_provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) - .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) - .first() - ) - if not api_provider: - raise ValueError("api tool not found") - - return json.loads(api_provider.icon) - except: - return {"background": "#252525", "content": "\ud83d\ude01"} + return cls.generate_api_tool_icon_url(tenant_id, provider_id) elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() - ) - - if workflow_provider is None: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - return json.loads(workflow_provider.icon) + return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) + elif provider_type == ToolProviderType.PLUGIN: + provider = ToolManager.get_builtin_provider(provider_id, tenant_id) + if isinstance(provider, PluginToolProviderController): + try: + return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) + except: + return {"background": "#252525", "content": "\ud83d\ude01"} + raise ValueError(f"plugin provider {provider_id} not found") else: raise ValueError(f"provider type {provider_type} not found") diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 57bf5b0755..e80e286e12 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -148,6 +148,7 @@ class WorkflowToolProviderController(ToolProviderController): raise ValueError("variable not found") return WorkflowTool( + workflow_as_tool_id=db_provider.id, entity=ToolEntity( identity=ToolIdentity( author=user.name if user else "", diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index b49f6faa50..2521c87372 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -21,6 +21,7 @@ class WorkflowTool(Tool): workflow_entities: dict[str, Any] workflow_call_depth: int thread_pool_id: Optional[str] = None + workflow_as_tool_id: str label: str @@ -31,6 +32,7 @@ class WorkflowTool(Tool): def __init__( self, workflow_app_id: str, + workflow_as_tool_id: str, version: str, workflow_entities: dict[str, Any], workflow_call_depth: int, @@ -40,6 +42,7 @@ class WorkflowTool(Tool): thread_pool_id: Optional[str] = None, ): self.workflow_app_id = workflow_app_id + self.workflow_as_tool_id = workflow_as_tool_id self.version = version self.workflow_entities = workflow_entities self.workflow_call_depth = workflow_call_depth @@ -134,6 +137,7 @@ class WorkflowTool(Tool): entity=self.entity.model_copy(), runtime=runtime, workflow_app_id=self.workflow_app_id, + workflow_as_tool_id=self.workflow_as_tool_id, workflow_entities=self.workflow_entities, workflow_call_depth=self.workflow_call_depth, version=self.version, diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 3736e632c3..e95522a36f 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -147,8 +147,8 @@ class IterationRunStartedEvent(BaseIterationEvent): class IterationRunNextEvent(BaseIterationEvent): index: int = Field(..., description="index") - pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") - duration: Optional[float] = Field(None, description="duration") + pre_iteration_output: Optional[Any] = None + duration: Optional[float] = None class IterationRunSucceededEvent(BaseIterationEvent): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 60a5901b21..21ee0f22e1 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -1,3 +1,4 @@ +import contextvars import logging import queue import time @@ -434,6 +435,7 @@ class GraphEngine: **{ "flask_app": current_app._get_current_object(), # type: ignore[attr-defined] "q": q, + "context": contextvars.copy_context(), "parallel_id": parallel_id, "parallel_start_node_id": edge.target_node_id, "parent_parallel_id": in_parallel_id, @@ -476,6 +478,7 @@ class GraphEngine: def _run_parallel_node( self, flask_app: Flask, + context: contextvars.Context, q: queue.Queue, parallel_id: str, parallel_start_node_id: str, @@ -485,6 +488,9 @@ class GraphEngine: """ Run parallel nodes """ + for var, val in context.items(): + var.set(val) + with flask_app.app_context(): try: q.put( diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 22f242a42f..1a5aa49809 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,3 +1,4 @@ +import contextvars import logging import uuid from collections.abc import Generator, Mapping, Sequence @@ -166,7 +167,8 @@ class IterationNode(BaseNode[IterationNodeData]): for index, item in enumerate(iterator_list_value): future: Future = thread_pool.submit( self._run_single_iter_parallel, - current_app._get_current_object(), + current_app._get_current_object(), # type: ignore + contextvars.copy_context(), q, iterator_list_value, inputs, @@ -372,7 +374,10 @@ class IterationNode(BaseNode[IterationNodeData]): try: rst = graph_engine.run() # get current iteration index - current_index = variable_pool.get([self.node_id, "index"]).value + variable = variable_pool.get([self.node_id, "index"]) + if variable is None: + raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") + current_index = variable.value iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}" next_index = int(current_index) + 1 @@ -540,6 +545,7 @@ class IterationNode(BaseNode[IterationNodeData]): def _run_single_iter_parallel( self, flask_app: Flask, + context: contextvars.Context, q: Queue, iterator_list_value: list[str], inputs: dict[str, list], @@ -554,6 +560,8 @@ class IterationNode(BaseNode[IterationNodeData]): """ run single iteration in parallel mode """ + for var, val in context.items(): + var.set(val) with flask_app.app_context(): parallel_mode_run_id = uuid.uuid4().hex graph_engine_copy = graph_engine.create_copy() diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index a3eed8fa5b..21023d4ab7 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -14,6 +14,7 @@ class ToolEntity(BaseModel): tool_name: str tool_label: str # redundancy tool_configurations: dict[str, Any] + plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 79e0073df6..7354086b03 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -50,7 +50,11 @@ class ToolNode(BaseNode[ToolNodeData]): node_data = cast(ToolNodeData, self.node_data) # fetch tool icon - tool_info = {"provider_type": node_data.provider_type.value, "provider_id": node_data.provider_id} + tool_info = { + "provider_type": node_data.provider_type.value, + "provider_id": node_data.provider_id, + "plugin_unique_identifier": node_data.plugin_unique_identifier, + } # get tool runtime try: diff --git a/api/services/agent_service.py b/api/services/agent_service.py index c8819535f1..bc64316faa 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -1,6 +1,9 @@ +import threading + import pytz from flask_login import current_user +import contexts from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager from core.tools.tool_manager import ToolManager from extensions.ext_database import db @@ -14,6 +17,9 @@ class AgentService: """ Service to get agent logs """ + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + conversation: Conversation = ( db.session.query(Conversation) .filter( diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 6df52891f1..621100d858 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -97,6 +97,7 @@ class ToolTransformService: if isinstance(provider_controller, PluginToolProviderController): result.plugin_id = provider_controller.plugin_id + result.plugin_unique_identifier = provider_controller.plugin_unique_identifier # get credentials schema schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} @@ -173,6 +174,7 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=True, plugin_id=None, + plugin_unique_identifier=None, tools=[], labels=labels or [], ) @@ -214,6 +216,7 @@ class ToolTransformService: ), type=ToolProviderType.API, plugin_id=None, + plugin_unique_identifier=None, masked_credentials={}, is_team_authorization=True, tools=[], diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index d8ee323908..8cfa29ee4d 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,3 +1,6 @@ +import threading + +import contexts from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom @@ -117,6 +120,9 @@ class WorkflowRunService: """ workflow_run = self.get_workflow_run(app_model, run_id) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + if not workflow_run: return []