dify/api/controllers/inner_api/plugin/wraps.py

100 lines
2.9 KiB
Python
Raw Normal View History

2024-07-08 22:37:20 +08:00
from collections.abc import Callable
from functools import wraps
from typing import Optional
2024-07-29 16:40:04 +08:00
from flask import request
2024-07-08 22:37:20 +08:00
from flask_restful import reqparse
2024-07-29 16:40:04 +08:00
from pydantic import BaseModel
2024-07-08 22:37:20 +08:00
from extensions.ext_database import db
from models.account import Account, Tenant
from models.model import EndUser
from services.account_service import AccountService
2024-07-08 22:37:20 +08:00
def get_user(user_id: str | None) -> Account | EndUser:
    """Resolve ``user_id`` to an ``Account`` or ``EndUser`` model.

    An empty or missing ``user_id`` is treated as ``"DEFAULT-USER"``. That id is
    looked up as the ``EndUser`` whose ``session_id`` is ``"DEFAULT-USER"``.

    Any other id is first resolved through ``AccountService.load_user``. If that
    finds nothing, it is looked up as an ``EndUser`` primary key.

    :param user_id: account id, end-user id, or ``None``/empty for the default user.
    :return: the matching ``Account`` or ``EndUser`` instance.
    :raises ValueError: ``"user not found"`` when no model matches or the lookup
        itself fails; in the latter case the original error is chained as
        ``__cause__`` so it is not lost.
    """
    user_id = user_id or "DEFAULT-USER"

    try:
        if user_id == "DEFAULT-USER":
            user_model = db.session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
        else:
            user_model = AccountService.load_user(user_id)
            if not user_model:
                # Not an account id: fall back to treating it as an end-user id.
                user_model = db.session.query(EndUser).filter(EndUser.id == user_id).first()
    except Exception as exc:
        raise ValueError("user not found") from exc

    # Checked outside the try so this error is not caught and re-wrapped, and so
    # both branches (including a missing DEFAULT-USER row) never return None.
    if not user_model:
        raise ValueError("user not found")
    return user_model
def get_user_tenant(view: Optional[Callable] = None):
    """Decorator that resolves the tenant and user named in the JSON request body.

    It can be applied bare (``@get_user_tenant``) or called
    (``@get_user_tenant()``).

    The request body must carry ``tenant_id`` and ``user_id``. The wrapped view
    is called with its original ``*args``/``**kwargs`` plus two extra keyword
    arguments:
    ``tenant_model``, the ``Tenant`` row whose id is ``tenant_id``;
    ``user_model``, the result of ``get_user(user_id)``.

    :raises ValueError: ``"tenant not found"`` when the tenant query fails or
        matches nothing. Errors from ``get_user`` propagate unchanged.
    """

    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            # Both identifiers are declared required JSON fields.
            parser = reqparse.RequestParser()
            parser.add_argument("tenant_id", type=str, required=True, location="json")
            parser.add_argument("user_id", type=str, required=True, location="json")

            # Parse into a separate object: assigning the result to ``kwargs``
            # would discard the keyword arguments the view was invoked with.
            parsed = parser.parse_args()
            tenant_id = parsed.get("tenant_id")
            user_id = parsed.get("user_id")

            try:
                tenant_model = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
            except Exception as exc:
                raise ValueError("tenant not found") from exc

            if not tenant_model:
                raise ValueError("tenant not found")

            kwargs["tenant_model"] = tenant_model
            kwargs["user_model"] = get_user(user_id)

            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    return decorator(view)
2024-07-29 16:40:04 +08:00
2024-09-14 02:47:01 +08:00
2024-07-29 16:40:04 +08:00
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
    """Decorator that validates the JSON request body against ``payload_type``.

    Use it as ``@plugin_data(payload_type=Model)`` or call it as
    ``plugin_data(view, payload_type=Model)``.

    The body is parsed with ``request.get_json()`` and passed to
    ``payload_type(**data)``. The resulting model instance is handed to the view
    as the ``payload`` keyword argument.

    :raises ValueError: ``"invalid json"`` if reading the JSON body raises;
        ``"invalid payload: ..."`` if the body is not a JSON object or
        ``payload_type`` rejects it. The original error is chained as
        ``__cause__``.
    """

    def decorator(view_func):
        # Preserve the view's __name__/__doc__, matching get_user_tenant.
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            try:
                data = request.get_json()
            except Exception as exc:
                raise ValueError("invalid json") from exc

            # ``**data`` needs a mapping. A missing body (None) or a JSON
            # array/scalar would otherwise fail with an opaque TypeError.
            if not isinstance(data, dict):
                raise ValueError(f"invalid payload: expected a JSON object, got {type(data).__name__}")

            try:
                payload = payload_type(**data)
            except Exception as exc:
                raise ValueError(f"invalid payload: {exc}") from exc

            kwargs["payload"] = payload
            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    return decorator(view)