feat: support plugin permission management

This commit is contained in:
Yeuoly 2024-10-28 15:54:34 +08:00
parent 685e8cdc7d
commit c657378d06
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
6 changed files with 289 additions and 153 deletions

View File

@ -5,8 +5,7 @@ from datetime import datetime, timezone
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select
from sqlalchemy.orm import Session
from sqlalchemy import asc, desc
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -105,8 +104,7 @@ class GetProcessRuleApi(Resource):
rules = DocumentService.DEFAULT_RULES["rules"]
if document_id:
# get the latest process rule
with Session(db.engine) as session:
document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none()
document = Document.query.get_or_404(document_id)
dataset = DatasetService.get_dataset(document.dataset_id)
@ -169,77 +167,66 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
with Session(db.engine) as session:
query = session.query(Document).filter_by(
dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
if search:
search = f"%{search}%"
query = query.filter(Document.name.like(search))
if sort.startswith("-"):
sort_logic = desc
sort = sort[1:]
else:
sort_logic = asc
if sort == "hit_count":
sub_query = (
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.group_by(DocumentSegment.document_id)
.subquery()
)
if search:
search = f"%{search}%"
query = query.filter(Document.name.like(search))
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
elif sort == "created_at":
query = query.order_by(
sort_logic(Document.created_at),
sort_logic(Document.position),
)
else:
query = query.order_by(
desc(Document.created_at),
desc(Document.position),
)
if sort.startswith("-"):
sort_logic = desc
sort = sort[1:]
else:
sort_logic = asc
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
if fetch:
for document in documents:
completed_segments = DocumentSegment.query.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
).count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
else:
data = marshal(documents, document_fields)
response = {
"data": data,
"has_more": len(documents) == limit,
"limit": limit,
"total": paginated_documents.total,
"page": page,
}
if sort == "hit_count":
sub_query = (
db.select(
DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")
)
.group_by(DocumentSegment.document_id)
.subquery()
)
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
elif sort == "created_at":
query = query.order_by(
sort_logic(Document.created_at),
sort_logic(Document.position),
)
else:
query = query.order_by(
desc(Document.created_at),
desc(Document.position),
)
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
if fetch:
for document in documents:
completed_segments = (
session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
)
.count()
)
total_segments = (
session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
else:
data = marshal(documents, document_fields)
response = {
"data": data,
"has_more": len(documents) == limit,
"limit": limit,
"total": paginated_documents.total,
"page": page,
}
return response
return response
documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}

View File

@ -0,0 +1,56 @@
from functools import wraps
from flask_login import current_user
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from models.account import TenantPluginPermission
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
user = current_user
tenant_id = user.current_tenant_id
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
.filter(
TenantPluginPermission.tenant_id == tenant_id,
)
.first()
)
if not permission:
# no permission set, allow access for everyone
return view(*args, **kwargs)
if install_required:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
if debug_required:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
return view(*args, **kwargs)
return decorated
return interceptor

View File

@ -8,9 +8,12 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@ -18,12 +21,9 @@ class PluginDebuggingKeyApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
return {
"key": PluginService.get_debugging_key(tenant_id),
@ -37,8 +37,7 @@ class PluginListApi(Resource):
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
plugins = PluginService.list(tenant_id)
return jsonable_encoder({"plugins": plugins})
@ -57,32 +56,13 @@ class PluginIconApi(Resource):
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
class PluginUploadPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
file = request.files["pkg"]
content = file.read()
return jsonable_encoder(PluginService.upload_pkg(tenant_id, content))
class PluginUploadFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
file = request.files["pkg"]
@ -100,12 +80,9 @@ class PluginUploadFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@ -124,12 +101,9 @@ class PluginInstallFromPkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@ -149,12 +123,9 @@ class PluginInstallFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("repo", type=str, required=True, location="json")
@ -178,12 +149,9 @@ class PluginInstallFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@ -203,15 +171,14 @@ class PluginFetchManifestApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
user = current_user
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
args = parser.parse_args()
tenant_id = user.current_tenant_id
return jsonable_encoder(
{"manifest": PluginService.fetch_plugin_manifest(tenant_id, args["plugin_unique_identifier"]).model_dump()}
)
@ -221,12 +188,9 @@ class PluginFetchInstallTasksApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, required=True, location="args")
@ -242,12 +206,9 @@ class PluginFetchInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def get(self, task_id: str):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
@ -256,12 +217,9 @@ class PluginDeleteInstallTaskApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
@ -270,12 +228,9 @@ class PluginDeleteInstallTaskItemApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self, task_id: str, identifier: str):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
@ -284,12 +239,9 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@ -307,12 +259,9 @@ class PluginUpgradeFromGithubApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@ -338,18 +287,62 @@ class PluginUninstallApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
args = req.parse_args()
tenant_id = current_user.current_tenant_id
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
class PluginChangePermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser()
req.add_argument("install_permission", type=str, required=True, location="json")
req.add_argument("debug_permission", type=str, required=True, location="json")
args = req.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = user.current_tenant_id
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
class PluginFetchPermissionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
permission = PluginPermissionService.get_permission(tenant_id)
if not permission:
return jsonable_encoder(
{
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
)
return jsonable_encoder(
{
"install_permission": permission.install_permission,
"debug_permission": permission.debug_permission,
}
)
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
@ -368,3 +361,6 @@ api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<t
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

View File

@ -0,0 +1,37 @@
"""add_tenant_plugin_permisisons
Revision ID: 08ec4f75af5e
Revises: ddcc8bbef391
Create Date: 2024-10-28 07:20:39.711124
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '08ec4f75af5e'
down_revision = 'ddcc8bbef391'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_plugin_permissions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('install_permission', sa.String(length=16), server_default='everyone', nullable=False),
sa.Column('debug_permission', sa.String(length=16), server_default='noone', nullable=False),
sa.PrimaryKeyConstraint('id', name='account_plugin_permission_pkey'),
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('account_plugin_permissions')
# ### end Alembic commands ###

View File

@ -2,6 +2,7 @@ import enum
import json
from flask_login import UserMixin
from sqlalchemy.orm import Mapped, mapped_column
from extensions.ext_database import db
from models.base import Base
@ -260,3 +261,28 @@ class InvitationCode(db.Model):
used_by_account_id = db.Column(StringUUID)
deprecated_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class TenantPluginPermission(Base):
class InstallPermission(str, enum.Enum):
EVERYONE = "everyone"
ADMINS = "admins"
NOBODY = "noone"
class DebugPermission(str, enum.Enum):
EVERYONE = "everyone"
ADMINS = "admins"
NOBODY = "noone"
__tablename__ = "account_plugin_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(
db.String(16), nullable=False, server_default="everyone"
)
debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone")

View File

@ -0,0 +1,34 @@
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.account import TenantPluginPermission
class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with Session(db.engine) as session:
return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
@staticmethod
def change_permission(
tenant_id: str,
install_permission: TenantPluginPermission.InstallPermission,
debug_permission: TenantPluginPermission.DebugPermission,
):
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
)
if not permission:
permission = TenantPluginPermission(
tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission
)
session.add(permission)
else:
permission.install_permission = install_permission
permission.debug_permission = debug_permission
session.commit()
return True