dify/api/services/plugin/plugin_migration.py

523 lines
20 KiB
Python
Raw Normal View History

2024-12-26 14:07:12 +08:00
import datetime
import json
import logging
import sys
import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from threading import Lock
from typing import Any, Optional
from uuid import uuid4

import click
import tqdm
from flask import Flask, current_app
from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.helper import marketplace
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.manager.plugin import PluginInstallationManager
from core.tools.entities.tool_entities import ToolProviderType
from models.account import Tenant
from models.engine import db
from models.model import App, AppMode, AppModelConfig
from models.tools import BuiltinToolProvider
from models.workflow import Workflow

# Module-level logger shared by every migration helper in this file.
logger = logging.getLogger(__name__)

# Built-in tool providers that the workflow and agent-app extractors skip
# when collecting plugin ids.
excluded_providers = [
    "time",
    "audio",
    "code",
    "webscraper",
]
class PluginMigration:
    """
    Helpers for migrating tenants' models/tools to the plugin mechanism.

    The methods form three stages:

    1. ``extract_plugins`` scans tenants and appends one JSON line per tenant
       (``{"tenant_id": ..., "plugins": [...]}``) to a file.
    2. ``extract_unique_plugins`` / ``extract_unique_plugins_to_file`` resolve the plugin
       ids found in that file to marketplace package identifiers.
    3. ``install_plugins`` installs the resolved plugins for every tenant in the file.
    """

    # Built-in tool providers whose plugin name differs from the legacy provider name.
    _BUILTIN_TOOL_PROVIDER_RENAMES = {
        "jina": "jina_tool",
        "siliconflow": "siliconflow_tool",
        "stepfun": "stepfun_tool",
    }

    # Upper bound on the number of plugins sent in a single install request.
    _INSTALL_BATCH_SIZE = 64

    @classmethod
    def extract_plugins(cls, filepath: str, workers: int) -> None:
        """
        Scan tenants and append each tenant's plugin ids to ``filepath``.

        Tenants are walked in ``created_at`` order in time windows of between one hour
        and one day; every tenant found is handed to a thread pool. For each tenant one
        JSON line ``{"tenant_id": ..., "plugins": [...]}`` is appended to the file.
        A failure for a single tenant is logged and does not stop the scan.

        :param filepath: file the JSON lines are appended to
        :param workers: maximum number of worker threads
        """
        click.echo(click.style("Migrating models/tools to new plugin Mechanism", fg="white"))
        ended_at = datetime.datetime.now()
        # Hard-coded lower bound of the scan; tenants created before it are not visited.
        started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
        current_time = started_at

        with Session(db.engine) as session:
            total_tenant_count = session.query(Tenant.id).count()

        click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))

        handled_tenant_count = 0
        file_lock = Lock()
        counter_lock = Lock()

        def process_tenant(flask_app: Flask, tenant_id: str) -> None:
            nonlocal handled_tenant_count
            # Worker threads do not inherit the caller's application context.
            with flask_app.app_context():
                try:
                    plugins = cls.extract_installed_plugin_ids(tenant_id)

                    # Serialize appends so lines from different threads never interleave.
                    with file_lock:
                        with open(filepath, "a") as f:
                            f.write(json.dumps({"tenant_id": tenant_id, "plugins": plugins}) + "\n")

                    # The counter is shared by all workers.
                    with counter_lock:
                        handled_tenant_count += 1
                        percent = (handled_tenant_count / total_tenant_count) * 100 if total_tenant_count else 100.0
                        click.echo(
                            click.style(
                                f"[{datetime.datetime.now()}] "
                                f"Processed {handled_tenant_count} tenants "
                                f"({percent:.1f}%), "
                                f"{handled_tenant_count}/{total_tenant_count}",
                                fg="green",
                            )
                        )
                except Exception:
                    logger.exception(f"Failed to process tenant {tenant_id}")

        # Candidate window sizes, largest first.
        test_intervals = [
            datetime.timedelta(days=1),
            datetime.timedelta(hours=12),
            datetime.timedelta(hours=6),
            datetime.timedelta(hours=3),
            datetime.timedelta(hours=1),
        ]

        # Resolve the real application object once; it is passed to every worker.
        flask_app = current_app._get_current_object()  # type: ignore[attr-defined]

        futures = []
        # The context manager guarantees the pool is shut down, even if the scan raises.
        with ThreadPoolExecutor(max_workers=workers) as thread_pool:
            while current_time < ended_at:
                click.echo(
                    click.style(f"Current time: {current_time}, Started at: {datetime.datetime.now()}", fg="white")
                )

                with Session(db.engine) as session:
                    # Pick the largest candidate window holding at most 100 tenants.
                    for test_interval in test_intervals:
                        tenant_count = (
                            session.query(Tenant.id)
                            .filter(Tenant.created_at.between(current_time, current_time + test_interval))
                            .count()
                        )
                        if tenant_count <= 100:
                            interval = test_interval
                            break
                    else:
                        # If all intervals have too many tenants, use minimum interval
                        interval = datetime.timedelta(hours=1)

                    # Adjust interval to target ~100 tenants per batch
                    if tenant_count > 0:
                        interval = min(
                            datetime.timedelta(days=1),  # Max 1 day
                            max(
                                datetime.timedelta(hours=1),  # Min 1 hour
                                interval * (100 / tenant_count),  # Scale to target 100
                            ),
                        )

                    batch_end = min(current_time + interval, ended_at)

                    # Half-open window [current_time, batch_end): a tenant created exactly on a
                    # boundary belongs to one batch only. The final window is closed so that
                    # nothing created at ``ended_at`` is dropped.
                    created_in_window = [Tenant.created_at >= current_time]
                    if batch_end >= ended_at:
                        created_in_window.append(Tenant.created_at <= batch_end)
                    else:
                        created_in_window.append(Tenant.created_at < batch_end)

                    rs = session.query(Tenant.id).filter(*created_in_window).order_by(Tenant.created_at)

                    for row in rs:
                        futures.append(thread_pool.submit(process_tenant, flask_app, str(row.id)))

                current_time = batch_end

            # wait for all threads to finish
            for future in futures:
                future.result()

    @classmethod
    def extract_installed_plugin_ids(cls, tenant_id: str) -> Sequence[str]:
        """
        Collect the de-duplicated plugin ids used by a tenant.

        Combines the results of the tool, model, workflow and app extractors.

        :param tenant_id: tenant to inspect
        :return: list of unique plugin ids (order not guaranteed)
        """
        tools = cls.extract_tool_tables(tenant_id)
        models = cls.extract_model_tables(tenant_id)
        workflows = cls.extract_workflow_tables(tenant_id)
        apps = cls.extract_app_tables(tenant_id)

        return list({*tools, *models, *workflows, *apps})

    @classmethod
    def extract_model_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Collect plugin ids for the model providers referenced by a tenant.

        :param tenant_id: tenant to inspect
        :return: list of unique plugin ids gathered from the provider-related tables
        """
        models: list[str] = []
        table_pairs = [
            ("providers", "provider_name"),
            ("provider_models", "provider_name"),
            ("provider_orders", "provider_name"),
            ("tenant_default_models", "provider_name"),
            ("tenant_preferred_model_providers", "provider_name"),
            ("provider_model_settings", "provider_name"),
            ("load_balancing_model_configs", "provider_name"),
        ]

        for table, column in table_pairs:
            models.extend(cls.extract_model_table(tenant_id, table, column))

        # The same provider may appear in several tables; drop duplicates.
        return list(set(models))

    @classmethod
    def extract_model_table(cls, tenant_id: str, table: str, column: str) -> Sequence[str]:
        """
        Read the distinct provider names of one table and convert them to plugin ids.

        Names without a ``/`` are prefixed with ``DEFAULT_PLUGIN_ID`` (``google`` is
        renamed to ``gemini`` first); names that already contain a ``/`` are kept as is.
        NULL and empty values are skipped.

        :param tenant_id: tenant to inspect
        :param table: table name; interpolated into the SQL text, so it must be a trusted constant
        :param column: column name; interpolated into the SQL text, so it must be a trusted constant
        :return: plugin ids found in the table
        """
        with Session(db.engine) as session:
            rs = session.execute(
                db.text(f"SELECT DISTINCT {column} FROM {table} WHERE tenant_id = :tenant_id"), {"tenant_id": tenant_id}
            )
            result = []
            for row in rs:
                # str(None) would otherwise yield a bogus "None" provider.
                if row[0] is None:
                    continue
                provider_name = str(row[0])
                if not provider_name:
                    continue
                if "/" in provider_name:
                    result.append(provider_name)
                    continue
                if provider_name == "google":
                    provider_name = "gemini"
                result.append(DEFAULT_PLUGIN_ID + "/" + provider_name)

            return result

    @classmethod
    def extract_tool_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Collect plugin ids for the built-in tool providers stored for a tenant.

        Provider names go through ``_handle_builtin_tool_provider`` so they receive the
        same renames and prefixing as tools found in workflows and agent apps.

        :param tenant_id: tenant to inspect
        :return: plugin ids, one per stored provider row
        """
        with Session(db.engine) as session:
            rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
            return [cls._handle_builtin_tool_provider(row.provider) for row in rs]

    @classmethod
    def _handle_builtin_tool_provider(cls, provider_name: str) -> str:
        """
        Convert a built-in tool provider name to a plugin id.

        Applies the renames in ``_BUILTIN_TOOL_PROVIDER_RENAMES`` and prefixes names
        without a ``/`` with ``DEFAULT_PLUGIN_ID``; names containing a ``/`` are
        returned unchanged.

        :param provider_name: legacy provider name or plugin id
        :return: plugin id
        """
        provider_name = cls._BUILTIN_TOOL_PROVIDER_RENAMES.get(provider_name, provider_name)

        if "/" in provider_name:
            return provider_name
        return DEFAULT_PLUGIN_ID + "/" + provider_name

    @classmethod
    def extract_workflow_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Collect plugin ids for the built-in tools used by a tenant's workflows.

        Only nodes whose data ``type`` is ``"tool"`` are inspected. Providers listed in
        ``excluded_providers`` and nodes without a provider name are skipped.

        :param tenant_id: tenant to inspect
        :return: plugin ids (may contain duplicates)
        """
        with Session(db.engine) as session:
            rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
            result = []
            for row in rs:
                graph = row.graph_dict
                for node in graph.get("nodes") or []:
                    data = node.get("data") or {}
                    if data.get("type") != "tool":
                        continue
                    if data.get("provider_type") != ToolProviderType.BUILT_IN.value:
                        continue
                    provider_name = data.get("provider_name")
                    # A node without a usable provider name cannot be mapped to a plugin.
                    if not isinstance(provider_name, str) or not provider_name:
                        continue
                    if provider_name in excluded_providers:
                        continue
                    result.append(cls._handle_builtin_tool_provider(provider_name))

            return result

    @classmethod
    def extract_app_tables(cls, tenant_id: str) -> Sequence[str]:
        """
        Collect plugin ids for the built-in tools configured on a tenant's agent apps.

        Tool entries that cannot be parsed are logged and skipped.

        :param tenant_id: tenant to inspect
        :return: plugin ids (may contain duplicates)
        """
        with Session(db.engine) as session:
            apps = session.query(App).filter(App.tenant_id == tenant_id).all()
            if not apps:
                return []

            agent_app_model_config_ids = [
                app.app_model_config_id
                for app in apps
                if (app.is_agent or app.mode == AppMode.AGENT_CHAT.value) and app.app_model_config_id is not None
            ]
            if not agent_app_model_config_ids:
                return []

            rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
            result = []
            for row in rs:
                agent_config = row.agent_mode_dict
                tools = agent_config.get("tools") if "tools" in agent_config else None
                if not isinstance(tools, list):
                    continue
                for tool in tools:
                    if not isinstance(tool, dict):
                        continue
                    try:
                        tool_entity = AgentToolEntity(**tool)
                        if (
                            tool_entity.provider_type == ToolProviderType.BUILT_IN.value
                            and tool_entity.provider_id not in excluded_providers
                        ):
                            result.append(cls._handle_builtin_tool_provider(tool_entity.provider_id))
                    except Exception:
                        logger.exception(f"Failed to process tool {tool}")

            return result

    @classmethod
    def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
        """
        Look up the latest package identifier of a plugin on the marketplace.

        :param plugin_id: plugin id to resolve
        :return: the latest package identifier, or ``None`` when no manifest is returned
        """
        plugin_manifest = marketplace.batch_fetch_plugin_manifests([plugin_id])
        if not plugin_manifest:
            return None

        return plugin_manifest[0].latest_package_identifier

    @classmethod
    def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
        """
        Resolve the unique plugins of ``extracted_plugins`` and write them to ``output_file`` as JSON.

        :param extracted_plugins: path of the JSON-lines file written by ``extract_plugins``
        :param output_file: path of the JSON file to write
        """
        Path(output_file).write_text(json.dumps(cls.extract_unique_plugins(extracted_plugins)))

    @classmethod
    def extract_unique_plugins(cls, extracted_plugins: str) -> Mapping[str, Any]:
        """
        Resolve every distinct plugin id in a JSON-lines file to its package identifier.

        :param extracted_plugins: path of the JSON-lines file written by ``extract_plugins``
        :return: ``{"plugins": {plugin_id: identifier}, "plugin_not_exist": [plugin_id, ...]}``
        """
        plugins: dict[str, str] = {}
        plugin_not_exist: list[str] = []
        # A dict keeps first-seen order while giving O(1) duplicate checks.
        seen: dict[str, None] = {}
        logger.info(f"Extracting unique plugins from {extracted_plugins}")
        with open(extracted_plugins) as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                for plugin_id in data.get("plugins", []):
                    seen.setdefault(plugin_id, None)
        plugin_ids = list(seen)

        def fetch_plugin(plugin_id):
            unique_identifier = cls._fetch_plugin_unique_identifier(plugin_id)
            if unique_identifier:
                plugins[plugin_id] = unique_identifier
            else:
                plugin_not_exist.append(plugin_id)

        with ThreadPoolExecutor(max_workers=10) as executor:
            # list() drains the iterator so every lookup finishes and errors surface here.
            list(tqdm.tqdm(executor.map(fetch_plugin, plugin_ids), total=len(plugin_ids)))

        return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}

    @classmethod
    def install_plugins(cls, extracted_plugins: str, output_file: str) -> None:
        """
        Install the extracted plugins for every tenant and write a report.

        All resolvable plugins are first installed for a throw-away tenant id, then
        installed per tenant in batches, and finally uninstalled from the throw-away
        tenant. The report written to ``output_file`` is
        ``{"not_installed": [...], "plugin_install_failed": [...]}`` where
        ``not_installed`` lists, per tenant, the plugin ids that could not be resolved.

        :param extracted_plugins: path of the JSON-lines file written by ``extract_plugins``
        :param output_file: path of the JSON report to write
        """
        manager = PluginInstallationManager()

        plugins = cls.extract_unique_plugins(extracted_plugins)
        plugin_identifiers: Mapping[str, str] = plugins["plugins"]
        not_installed = []
        plugin_install_failed = []

        # use a fake tenant id to install all the plugins
        fake_tenant_id = uuid4().hex
        logger.info(f"Installing {len(plugin_identifiers)} plugin instances for fake tenant {fake_tenant_id}")

        response = cls.handle_plugin_instance_install(fake_tenant_id, plugin_identifiers)
        if response.get("failed"):
            plugin_install_failed.extend(response.get("failed", []))

        def install(tenant_id: str, plugin_ids: list[str]) -> None:
            logger.info(f"Installing {len(plugin_ids)} plugins for tenant {tenant_id}")
            try:
                # Ids that could not be resolved have no identifier to install.
                installable = [plugin_id for plugin_id in plugin_ids if plugin_id in plugin_identifiers]
                for i in range(0, len(installable), cls._INSTALL_BATCH_SIZE):
                    batch_plugin_identifiers = [
                        plugin_identifiers[plugin_id] for plugin_id in installable[i : i + cls._INSTALL_BATCH_SIZE]
                    ]
                    manager.install_from_identifiers(
                        tenant_id,
                        batch_plugin_identifiers,
                        PluginInstallationSource.Marketplace,
                        metas=[
                            {
                                "plugin_unique_identifier": identifier,
                            }
                            for identifier in batch_plugin_identifiers
                        ],
                    )
            except Exception:
                # Nobody reads the future's result, so log here or the failure is lost.
                logger.exception(f"Failed to install plugins for tenant {tenant_id}")

        # Leaving the ``with`` block waits for every submitted install to finish.
        with ThreadPoolExecutor(max_workers=40) as thread_pool:
            with open(extracted_plugins) as f:
                # Read line by line, and install plugins for each tenant.
                for line in f:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    tenant_id = data.get("tenant_id")
                    plugin_ids = data.get("plugins", [])

                    # Record the plugin ids that have no resolved identifier.
                    missing = [plugin_id for plugin_id in plugin_ids if not plugin_identifiers.get(plugin_id)]
                    if missing:
                        not_installed.append(
                            {
                                "tenant_id": tenant_id,
                                "plugin_not_exist": missing,
                            }
                        )

                    thread_pool.submit(install, tenant_id, plugin_ids)

        logger.info("Uninstall plugins")

        # Remove everything that was installed for the fake tenant.
        try:
            installation = manager.list_plugins(fake_tenant_id)
            while installation:
                for plugin in installation:
                    manager.uninstall(fake_tenant_id, plugin.installation_id)

                installation = manager.list_plugins(fake_tenant_id)
        except Exception:
            logger.exception(f"Failed to get installation for tenant {fake_tenant_id}")

        Path(output_file).write_text(
            json.dumps(
                {
                    "not_installed": not_installed,
                    "plugin_install_failed": plugin_install_failed,
                }
            )
        )

    @classmethod
    def handle_plugin_instance_install(
        cls, tenant_id: str, plugin_identifiers_map: Mapping[str, str]
    ) -> Mapping[str, Any]:
        """
        Download, upload and install a set of plugins for one tenant.

        Every package is downloaded from the marketplace and uploaded first; the plugins
        are then installed in batches, polling the installation task once per second
        until it reports success or failure.

        :param tenant_id: tenant to install for
        :param plugin_identifiers_map: mapping of plugin id to package identifier
        :return: ``{"success": [...], "failed": [...]}``, both lists holding plugin ids
        :raises Exception: when a package cannot be downloaded or uploaded
        """
        manager = PluginInstallationManager()

        def download_and_upload(plugin_identifier: str) -> None:
            plugin_package = marketplace.download_plugin_pkg(plugin_identifier)
            if not plugin_package:
                raise Exception(f"Failed to download plugin {plugin_identifier}")

            manager.upload_pkg(tenant_id, plugin_package, verify_signature=True)

        # download all the plugins and upload
        with ThreadPoolExecutor(max_workers=10) as thread_pool:
            futures = [
                thread_pool.submit(download_and_upload, plugin_identifier)
                for plugin_identifier in plugin_identifiers_map.values()
            ]

            # Wait for all downloads to complete
            for future in futures:
                future.result()  # This will raise any exceptions that occurred

        success = []
        failed = []

        reverse_map = {v: k for k, v in plugin_identifiers_map.items()}
        plugin_ids = list(plugin_identifiers_map)

        for i in range(0, len(plugin_ids), cls._INSTALL_BATCH_SIZE):
            batch_plugin_ids = plugin_ids[i : i + cls._INSTALL_BATCH_SIZE]
            batch_plugin_identifiers = [plugin_identifiers_map[plugin_id] for plugin_id in batch_plugin_ids]

            try:
                response = manager.install_from_identifiers(
                    tenant_id=tenant_id,
                    identifiers=batch_plugin_identifiers,
                    source=PluginInstallationSource.Marketplace,
                    metas=[
                        {
                            "plugin_unique_identifier": identifier,
                        }
                        for identifier in batch_plugin_identifiers
                    ],
                )
            except Exception:
                logger.exception(f"Failed to install plugins {batch_plugin_identifiers}")
                # Report plugin ids, matching what the polling branch below reports.
                failed.extend(batch_plugin_ids)
                continue

            if response.all_installed:
                success.extend(batch_plugin_ids)
                continue

            task_id = response.task_id
            done = False
            while not done:
                status = manager.fetch_plugin_installation_task(tenant_id, task_id)
                if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
                    for plugin in status.plugins:
                        # Fall back to the identifier if the task reports one we did not send.
                        plugin_id = reverse_map.get(plugin.plugin_unique_identifier, plugin.plugin_unique_identifier)
                        if plugin.status == PluginInstallTaskStatus.Success:
                            success.append(plugin_id)
                        else:
                            failed.append(plugin_id)
                            logger.error(
                                f"Failed to install plugin {plugin.plugin_unique_identifier}, error: {plugin.message}"
                            )

                    done = True
                else:
                    time.sleep(1)

        return {"success": success, "failed": failed}