feat: validate credentials

This commit is contained in:
Yeuoly 2024-09-23 21:13:02 +08:00
parent 7a3e756020
commit 947bfdc807
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
3 changed files with 56 additions and 16 deletions

View File

@ -35,4 +35,11 @@ class PluginToolProviderEntity(BaseModel):
provider: str
plugin_unique_identifier: str
plugin_id: str
declaration: ToolProviderEntityWithPlugin
declaration: ToolProviderEntityWithPlugin
class PluginBasicBooleanResponse(BaseModel):
"""
Basic boolean response from plugin daemon.
"""
result: bool

View File

@ -1,5 +1,5 @@
import json
from collections.abc import Generator
from collections.abc import Callable, Generator
from typing import TypeVar
import requests
@ -21,7 +21,7 @@ class BasePluginManager:
method: str,
path: str,
headers: dict | None = None,
data: bytes | dict | None = None,
data: bytes | dict | str | None = None,
params: dict | None = None,
stream: bool = False,
) -> requests.Response:
@ -31,6 +31,10 @@ class BasePluginManager:
url = URL(str(plugin_daemon_inner_api_baseurl)) / path
headers = headers or {}
headers["X-Api-Key"] = plugin_daemon_inner_api_key
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
data = json.dumps(data)
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream
)
@ -48,7 +52,11 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API
"""
response = self._request(method, path, headers, data, params, stream=True)
yield from response.iter_lines()
for line in response.iter_lines():
line = line.decode("utf-8").strip()
if line.startswith("data:"):
line = line[5:].strip()
yield line
def _stream_request_with_model(
self,
@ -88,17 +96,15 @@ class BasePluginManager:
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
transformer: Callable[[dict], dict] | None = None,
) -> T:
"""
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params)
json_response = response.json()
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name
if transformer:
json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response)
if rep.code != 0:
@ -128,3 +134,4 @@ class BasePluginManager:
if rep.data is None:
raise ValueError("got empty data from plugin daemon")
yield rep.data

View File

@ -1,7 +1,7 @@
from collections.abc import Generator
from typing import Any
from core.plugin.entities.plugin_daemon import PluginToolProviderEntity
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.manager.base import BasePluginManager
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -11,8 +11,22 @@ class PluginToolManager(BasePluginManager):
"""
Fetch tool providers for the given asset.
"""
def transformer(json_response: dict[str, Any]) -> dict:
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET", f"plugin/{tenant_id}/tools", list[PluginToolProviderEntity], params={"page": 1, "page_size": 256}
"GET",
f"plugin/{tenant_id}/management/tools",
list[PluginToolProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
)
return response
@ -28,7 +42,7 @@ class PluginToolManager(BasePluginManager):
) -> Generator[ToolInvokeMessage, None, None]:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/tool/invoke",
f"plugin/{tenant_id}/dispatch/tool/invoke",
ToolInvokeMessage,
data={
"plugin_unique_identifier": plugin_unique_identifier,
@ -40,6 +54,10 @@ class PluginToolManager(BasePluginManager):
"tool_parameters": tool_parameters,
},
},
headers={
"X-Plugin-Identifier": plugin_unique_identifier,
"Content-Type": "application/json",
}
)
return response
@ -49,10 +67,10 @@ class PluginToolManager(BasePluginManager):
"""
validate the credentials of the provider
"""
response = self._request_with_plugin_daemon_response(
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/tool/validate_credentials",
bool,
f"plugin/{tenant_id}/dispatch/tool/validate_credentials",
PluginBasicBooleanResponse,
data={
"plugin_unique_identifier": plugin_unique_identifier,
"user_id": user_id,
@ -61,5 +79,13 @@ class PluginToolManager(BasePluginManager):
"credentials": credentials,
},
},
headers={
"X-Plugin-Identifier": plugin_unique_identifier,
"Content-Type": "application/json",
}
)
return response
for resp in response:
return resp.result
return False