Merge branch 'langgenius:main' into main
Commit: 617fec0dad
.github/workflows/api-tests.yml (13 changed lines, vendored)
@@ -8,6 +8,9 @@ on:
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]

env:
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
@@ -37,10 +40,10 @@ jobs:
with:
packages: ffmpeg

- name: Set up Python
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: |
./api/requirements.txt
@@ -50,10 +53,10 @@ jobs:
run: pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt

- name: Run ModelRuntime
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
run: dev/pytest/pytest_model_runtime.sh

- name: Run Tool
run: pytest api/tests/integration_tests/tools/test_all_provider.py
run: dev/pytest/pytest_tools.sh

- name: Run Workflow
run: pytest api/tests/integration_tests/workflow
run: dev/pytest/pytest_workflow.sh
.github/workflows/style.yml (5 changed lines, vendored)
@@ -24,11 +24,14 @@ jobs:
python-version: '3.10'

- name: Python dependencies
run: pip install ruff
run: pip install ruff dotenv-linter

- name: Ruff check
run: ruff check ./api

- name: Dotenv check
run: dotenv-linter ./api/.env.example ./web/.env.example

- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
README.md (12 changed lines)
@@ -29,12 +29,12 @@
</p>

<p align="center">
<a href="./README.md"><img alt="Commits last month" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_CN.md"><img alt="Commits last month" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="Commits last month" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="Commits last month" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_KL.md"><img alt="Commits last month" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_FR.md"><img alt="Commits last month" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
</p>

#
README_CN.md (32 changed lines)
@@ -44,11 +44,11 @@
<a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | 趋势转变" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>

Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工作流程、RAG管道、代理功能、模型管理、可观察性功能等,让您可以快速从原型到生产。以下是其核心功能列表:
Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 工作流、RAG 管道、Agent、模型管理、可观测性功能等,让您可以快速从原型到生产。以下是其核心功能列表:
</br> </br>

**1. 工作流**:
在视觉画布上构建和测试功能强大的AI工作流程,利用以下所有功能以及更多功能。
在画布上构建和测试功能强大的 AI 工作流程,利用以下所有功能以及更多功能。

https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
@@ -56,7 +56,7 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工

**2. 全面的模型支持**:
与数百种专有/开源LLMs以及数十种推理提供商和自托管解决方案无缝集成,涵盖GPT、Mistral、Llama2以及任何与OpenAI API兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。
与数百种专有/开源 LLMs 以及数十种推理提供商和自托管解决方案无缝集成,涵盖 GPT、Mistral、Llama3 以及任何与 OpenAI API 兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。

@@ -65,16 +65,16 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工
用于制作提示、比较模型性能以及向基于聊天的应用程序添加其他功能(如文本转语音)的直观界面。

**4. RAG Pipeline**:
广泛的RAG功能,涵盖从文档摄入到检索的所有内容,支持从PDF、PPT和其他常见文档格式中提取文本的开箱即用的支持。
广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。

**5. Agent 智能体**:
您可以基于LLM函数调用或ReAct定义代理,并为代理添加预构建或自定义工具。Dify为AI代理提供了50多种内置工具,如谷歌搜索、DELL·E、稳定扩散和WolframAlpha等。
您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了50多种内置工具,如谷歌搜索、DELL·E、Stable Diffusion 和 WolframAlpha 等。

**6. LLMOps**:
随时间监视和分析应用程序日志和性能。您可以根据生产数据和注释持续改进提示、数据集和模型。
随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。

**7. 后端即服务**:
所有Dify的功能都带有相应的API,因此您可以轻松地将Dify集成到自己的业务逻辑中。
所有 Dify 的功能都带有相应的 API,因此您可以轻松地将 Dify 集成到自己的业务逻辑中。

## 功能比较
@@ -84,21 +84,21 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工
<th align="center">Dify.AI</th>
<th align="center">LangChain</th>
<th align="center">Flowise</th>
<th align="center">OpenAI助理API</th>
<th align="center">OpenAI Assistant API</th>
</tr>
<tr>
<td align="center">编程方法</td>
<td align="center">API + 应用程序导向</td>
<td align="center">Python代码</td>
<td align="center">Python 代码</td>
<td align="center">应用程序导向</td>
<td align="center">API导向</td>
<td align="center">API 导向</td>
</tr>
<tr>
<td align="center">支持的LLMs</td>
<td align="center">支持的 LLMs</td>
<td align="center">丰富多样</td>
<td align="center">丰富多样</td>
<td align="center">丰富多样</td>
<td align="center">仅限OpenAI</td>
<td align="center">仅限 OpenAI</td>
</tr>
<tr>
<td align="center">RAG引擎</td>
@@ -108,21 +108,21 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工
<td align="center">✅</td>
</tr>
<tr>
<td align="center">代理</td>
<td align="center">Agent</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">✅</td>
</tr>
<tr>
<td align="center">工作流程</td>
<td align="center">工作流</td>
<td align="center">✅</td>
<td align="center">❌</td>
<td align="center">✅</td>
<td align="center">❌</td>
</tr>
<tr>
<td align="center">可观察性</td>
<td align="center">可观测性</td>
<td align="center">✅</td>
<td align="center">✅</td>
<td align="center">❌</td>
@@ -202,7 +202,7 @@ docker compose up -d
## Contributing

对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
同时,请考虑通过社交媒体、活动和会议来支持Dify的分享。
同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。

> 我们正在寻找贡献者来帮助将Dify翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
@@ -55,3 +55,16 @@
9. If you need to debug local async processing, please start the worker service by running
`celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`.
The started celery app handles the async tasks, e.g. dataset importing and documents indexing.

## Testing

1. Install dependencies for both the backend and the test environment
```bash
pip install -r requirements.txt -r requirements-dev.txt
```

2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash
dev/pytest/pytest_all_tests.sh
```
api/app.py (37 changed lines)
@@ -1,4 +1,6 @@
import os
import sys
from logging.handlers import RotatingFileHandler

if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
@@ -17,10 +19,13 @@ import warnings

from flask import Flask, Response, request
from flask_cors import CORS

from werkzeug.exceptions import Unauthorized

from commands import register_commands
from config import CloudEditionConfig, Config

# DO NOT REMOVE BELOW
from events import event_handlers
from extensions import (
ext_celery,
ext_code_based_extension,
@@ -37,11 +42,8 @@ from extensions import (
from extensions.ext_database import db
from extensions.ext_login import login_manager
from libs.passport import PassportService
from services.account_service import AccountService

# DO NOT REMOVE BELOW
from events import event_handlers
from models import account, dataset, model, source, task, tool, tools, web
from services.account_service import AccountService

# DO NOT REMOVE ABOVE

@@ -86,7 +88,25 @@ def create_app(test_config=None) -> Flask:

app.secret_key = app.config['SECRET_KEY']

logging.basicConfig(level=app.config.get('LOG_LEVEL', 'INFO'))
log_handlers = None
log_file = app.config.get('LOG_FILE')
if log_file:
log_dir = os.path.dirname(log_file)
os.makedirs(log_dir, exist_ok=True)
log_handlers = [
RotatingFileHandler(
filename=log_file,
maxBytes=1024 * 1024 * 1024,
backupCount=5
),
logging.StreamHandler(sys.stdout)
]
logging.basicConfig(
level=app.config.get('LOG_LEVEL'),
format=app.config.get('LOG_FORMAT'),
datefmt=app.config.get('LOG_DATEFORMAT'),
handlers=log_handlers
)

initialize_extensions(app)
register_blueprints(app)
@@ -115,7 +135,7 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint == 'console':
if request.blueprint in ['console', 'inner_api']:
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
if not auth_header:
@@ -151,6 +171,7 @@ def unauthorized_handler():
def register_blueprints(app):
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp

@@ -188,6 +209,8 @@ def register_blueprints(app):
)
app.register_blueprint(files_bp)

app.register_blueprint(inner_api_bp)

# create app
app = create_app()
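For context, the logging change above routes the app's logs through `LOG_FILE`, `LOG_FORMAT`, and `LOG_DATEFORMAT`: when a file path is configured, a size-rotated file handler is combined with a stdout handler. Below is a minimal standalone sketch of the same pattern, assuming the same format and date-format strings that the config diff introduces; the environment-variable names and the example path are placeholders, not an exact copy of Dify's wiring.

```python
import logging
import os
import sys
from logging.handlers import RotatingFileHandler

# Placeholder values for illustration; the app reads these from its config object.
log_file = os.environ.get("LOG_FILE", "")   # e.g. "logs/server.log"
log_level = os.environ.get("LOG_LEVEL", "INFO")

handlers = None
if log_file:
    # Ensure the directory exists before the handler opens the file.
    os.makedirs(os.path.dirname(log_file) or ".", exist_ok=True)
    handlers = [
        RotatingFileHandler(log_file, maxBytes=1024 * 1024 * 1024, backupCount=5),
        logging.StreamHandler(sys.stdout),  # keep console output alongside the file
    ]

logging.basicConfig(
    level=log_level,
    format="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=handlers,  # None keeps the default stderr handler
)
logging.getLogger(__name__).info("logging configured")
```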
@@ -38,6 +38,9 @@ DEFAULTS = {
'QDRANT_CLIENT_TIMEOUT': 20,
'CELERY_BACKEND': 'database',
'LOG_LEVEL': 'INFO',
'LOG_FILE': '',
'LOG_FORMAT': '%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
'LOG_DATEFORMAT': '%Y-%m-%d %H:%M:%S',
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
@@ -69,6 +72,8 @@ DEFAULTS = {
'TOOL_ICON_CACHE_MAX_AGE': 3600,
'MILVUS_DATABASE': 'default',
'KEYWORD_DATA_SOURCE_TYPE': 'database',
'INNER_API': 'False',
'ENTERPRISE_ENABLED': 'False',
}

@@ -99,12 +104,15 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.6.3"
self.CURRENT_VERSION = "0.6.4"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
self.TESTING = False
self.LOG_LEVEL = get_env('LOG_LEVEL')
self.LOG_FILE = get_env('LOG_FILE')
self.LOG_FORMAT = get_env('LOG_FORMAT')
self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT')

# The backend URL prefix of the console API.
# used to concatenate the login authorization callback or notion integration callback.
@@ -133,6 +141,11 @@ class Config:
# Alternatively you can set it with `SECRET_KEY` environment variable.
self.SECRET_KEY = get_env('SECRET_KEY')

# Enable or disable the inner API.
self.INNER_API = get_bool_env('INNER_API')
# The inner API key is used to authenticate the inner API.
self.INNER_API_KEY = get_env('INNER_API_KEY')

# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
@@ -336,6 +349,8 @@ class Config:
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')

self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')


class CloudEditionConfig(Config):
@@ -1,4 +1,3 @@
# -*- coding:utf-8 -*-

@@ -1,22 +1,57 @@
from flask import Blueprint

from libs.external_api import ExternalApi

bp = Blueprint('console', __name__, url_prefix='/console/api')
api = ExternalApi(bp)

# Import other controllers
from . import admin, apikey, extension, feature, setup, version, ping
from . import admin, apikey, extension, feature, ping, setup, version

# Import app controllers
from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message,
model_config, site, statistic, workflow, workflow_run, workflow_app_log, workflow_statistic, agent)
from .app import (
advanced_prompt_template,
agent,
annotation,
app,
audio,
completion,
conversation,
generator,
message,
model_config,
site,
statistic,
workflow,
workflow_app_log,
workflow_run,
workflow_statistic,
)

# Import auth controllers
from .auth import activate, data_source_oauth, login, oauth

# Import billing controllers
from .billing import billing

# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing

# Import enterprise controllers
from .enterprise import enterprise_sso

# Import explore controllers
from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app,
saved_message, workflow)
from .explore import (
audio,
completion,
conversation,
installed_app,
message,
parameter,
recommended_app,
saved_message,
workflow,
)

# Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace
from .workspace import account, members, model_providers, models, tool_providers, workspace
@@ -2,13 +2,15 @@ import json

from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, BadRequest
from werkzeug.exceptions import BadRequest, Forbidden

from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
@@ -16,11 +18,8 @@ from fields.app_fields import (
app_pagination_fields,
)
from libs.login import login_required
from models.model import App, AppMode, AppModelConfig
from services.app_service import AppService
from models.model import App, AppModelConfig, AppMode
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager


ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
@@ -26,10 +26,13 @@ class LoginApi(Resource):

try:
account = AccountService.authenticate(args['email'], args['password'])
except services.errors.account.AccountLoginError:
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
except services.errors.account.AccountLoginError as e:
return {'code': 'unauthorized', 'message': str(e)}, 401

TenantService.create_owner_tenant_if_not_exist(account)
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}

AccountService.update_last_login(account, request)
@@ -12,7 +12,7 @@ from controllers.console.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.errors.error import (
@@ -45,10 +45,6 @@ class HitTestingApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))

# only high quality dataset can be used for hit testing
if dataset.indexing_technique != 'high_quality':
raise HighQualityDatasetOnlyError()

parser = reqparse.RequestParser()
parser.add_argument('query', type=str, location='json')
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
api/controllers/console/enterprise/__init__.py (new file, 0 lines)
api/controllers/console/enterprise/enterprise_sso.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from flask import current_app, redirect
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.setup import setup_required
from services.enterprise.enterprise_sso_service import EnterpriseSSOService


class EnterpriseSSOSamlLogin(Resource):

    @setup_required
    def get(self):
        return EnterpriseSSOService.get_sso_saml_login()


class EnterpriseSSOSamlAcs(Resource):

    @setup_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('SAMLResponse', type=str, required=True, location='form')
        args = parser.parse_args()
        saml_response = args['SAMLResponse']

        try:
            token = EnterpriseSSOService.post_sso_saml_acs(saml_response)
            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
        except Exception as e:
            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')


class EnterpriseSSOOidcLogin(Resource):

    @setup_required
    def get(self):
        return EnterpriseSSOService.get_sso_oidc_login()


class EnterpriseSSOOidcCallback(Resource):

    @setup_required
    def get(self):
        parser = reqparse.RequestParser()
        parser.add_argument('state', type=str, required=True, location='args')
        parser.add_argument('code', type=str, required=True, location='args')
        parser.add_argument('oidc-state', type=str, required=True, location='cookies')
        args = parser.parse_args()

        try:
            token = EnterpriseSSOService.get_sso_oidc_callback(args)
            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
        except Exception as e:
            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')


api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login')
api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs')
api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login')
api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback')
@@ -1,6 +1,7 @@
from flask_login import current_user
from flask_restful import Resource

from services.enterprise.enterprise_feature_service import EnterpriseFeatureService
from services.feature_service import FeatureService

from . import api
@@ -14,4 +15,10 @@ class FeatureApi(Resource):
return FeatureService.get_features(current_user.current_tenant_id).dict()


class EnterpriseFeatureApi(Resource):
def get(self):
return EnterpriseFeatureService.get_enterprise_features().dict()


api.add_resource(FeatureApi, '/features')
api.add_resource(EnterpriseFeatureApi, '/enterprise-features')
@@ -58,6 +58,8 @@ class SetupApi(Resource):
password=args['password']
)

TenantService.create_owner_tenant_if_not_exist(account)

setup()
AccountService.update_last_login(account, request)
@@ -3,6 +3,7 @@ import logging
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from werkzeug.exceptions import Unauthorized

import services
from controllers.console import api
@@ -19,7 +20,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
from models.account import Tenant
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -116,6 +117,16 @@ class TenantApi(Resource):

tenant = current_user.current_tenant

if tenant.status == TenantStatus.ARCHIVE:
tenants = TenantService.get_join_tenants(current_user)
# if there is any tenant, switch to the first one
if len(tenants) > 0:
TenantService.switch_tenant(current_user, tenants[0].id)
tenant = tenants[0]
# else, raise Unauthorized
else:
raise Unauthorized('workspace is archived')

return WorkspaceService.get_tenant_info(tenant), 200
@@ -1,5 +1,5 @@
# -*- coding:utf-8 -*-
from flask import Blueprint

from libs.external_api import ExternalApi

bp = Blueprint('files', __name__)
api/controllers/inner_api/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from flask import Blueprint

from libs.external_api import ExternalApi

bp = Blueprint('inner_api', __name__, url_prefix='/inner/api')
api = ExternalApi(bp)

from .workspace import workspace

api/controllers/inner_api/workspace/__init__.py (new file, 0 lines)
api/controllers/inner_api/workspace/workspace.py (new file, 37 lines)
@@ -0,0 +1,37 @@
from flask_restful import Resource, reqparse

from controllers.console.setup import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import inner_api_only
from events.tenant_event import tenant_was_created
from models.account import Account
from services.account_service import TenantService


class EnterpriseWorkspace(Resource):

    @setup_required
    @inner_api_only
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, location='json')
        parser.add_argument('owner_email', type=str, required=True, location='json')
        args = parser.parse_args()

        account = Account.query.filter_by(email=args['owner_email']).first()
        if account is None:
            return {
                'message': 'owner account not found.'
            }, 404

        tenant = TenantService.create_tenant(args['name'])
        TenantService.create_tenant_member(tenant, account, role='owner')

        tenant_was_created.send(tenant)

        return {
            'message': 'enterprise workspace created.'
        }


api.add_resource(EnterpriseWorkspace, '/enterprise/workspace')
api/controllers/inner_api/wraps.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from base64 import b64encode
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new

from flask import abort, current_app, request

from extensions.ext_database import db
from models.model import EndUser


def inner_api_only(view):
    @wraps(view)
    def decorated(*args, **kwargs):
        if not current_app.config['INNER_API']:
            abort(404)

        # get header 'X-Inner-Api-Key'
        inner_api_key = request.headers.get('X-Inner-Api-Key')
        if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
            abort(404)

        return view(*args, **kwargs)

    return decorated


def inner_api_user_auth(view):
    @wraps(view)
    def decorated(*args, **kwargs):
        if not current_app.config['INNER_API']:
            return view(*args, **kwargs)

        # get header 'X-Inner-Api-Key'
        authorization = request.headers.get('Authorization')
        if not authorization:
            return view(*args, **kwargs)

        parts = authorization.split(':')
        if len(parts) != 2:
            return view(*args, **kwargs)

        user_id, token = parts
        if ' ' in user_id:
            user_id = user_id.split(' ')[1]

        inner_api_key = request.headers.get('X-Inner-Api-Key')

        data_to_sign = f'DIFY {user_id}'

        signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1)
        signature = b64encode(signature.digest()).decode('utf-8')

        if signature != token:
            return view(*args, **kwargs)

        kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first()

        return view(*args, **kwargs)

    return decorated
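For reference, `inner_api_user_auth` above expects an `Authorization` header of the form `DIFY <user_id>:<signature>`, where the signature is the base64-encoded HMAC-SHA1 of the string `DIFY <user_id>` keyed with the `X-Inner-Api-Key` value. A minimal sketch of how a trusted caller might build matching headers follows; the key and user id are placeholder values, and the helper name is not part of the codebase.

```python
from base64 import b64encode
from hashlib import sha1
from hmac import new as hmac_new


def build_inner_api_headers(inner_api_key: str, user_id: str) -> dict:
    # Sign exactly the payload the decorator recomputes: "DIFY <user_id>"
    digest = hmac_new(inner_api_key.encode("utf-8"), f"DIFY {user_id}".encode("utf-8"), sha1)
    signature = b64encode(digest.digest()).decode("utf-8")
    return {
        "X-Inner-Api-Key": inner_api_key,
        "Authorization": f"DIFY {user_id}:{signature}",
    }


# Example with placeholder values:
headers = build_inner_api_headers("my-inner-api-key", "end-user-uuid")
```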
@@ -1,5 +1,5 @@
# -*- coding:utf-8 -*-
from flask import Blueprint

from libs.external_api import ExternalApi

bp = Blueprint('service_api', __name__, url_prefix='/v1')
@@ -174,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):

if not dataset:
raise ValueError('Dataset is not exist.')
if not dataset.indexing_technique and not args['indexing_technique']:
if not dataset.indexing_technique and not args.get('indexing_technique'):
raise ValueError('indexing_technique is required.')

# save file info
@@ -12,7 +12,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized

from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService

@@ -47,6 +47,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not app_model.enable_api:
raise NotFound()

tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
if tenant.status == TenantStatus.ARCHIVE:
raise NotFound()

kwargs['app_model'] = app_model

if fetch_user_arg:
@@ -137,6 +141,7 @@ def validate_dataset_token(view=None):
.filter(Tenant.id == api_token.tenant_id) \
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
.filter(TenantAccountJoin.role.in_(['owner'])) \
.filter(Tenant.status == TenantStatus.NORMAL) \
.one_or_none()  # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
@@ -1,5 +1,5 @@
# -*- coding:utf-8 -*-
from flask import Blueprint

from libs.external_api import ExternalApi

bp = Blueprint('web', __name__, url_prefix='/api')

@@ -7,7 +7,7 @@ from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.model import App, AppModelConfig, AppMode
from models.model import App, AppMode, AppModelConfig
from models.tools import ApiToolProvider
from services.app_service import AppService
@@ -6,6 +6,7 @@ from werkzeug.exceptions import Forbidden
from controllers.web import api
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.account import TenantStatus
from models.model import Site
from services.feature_service import FeatureService

@@ -54,6 +55,9 @@ class AppSiteApi(WebApiResource):
if not site:
raise Forbidden()

if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()

can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo

return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
@@ -26,7 +26,10 @@ class AppGenerateResponseConverter(ABC):
else:
def _generate():
for chunk in cls.convert_stream_full_response(response):
yield f'data: {chunk}\n\n'
if chunk == 'ping':
yield f'event: {chunk}\n\n'
else:
yield f'data: {chunk}\n\n'

return _generate()
else:
@@ -35,7 +38,10 @@ class AppGenerateResponseConverter(ABC):
else:
def _generate():
for chunk in cls.convert_stream_simple_response(response):
yield f'data: {chunk}\n\n'
if chunk == 'ping':
yield f'event: {chunk}\n\n'
else:
yield f'data: {chunk}\n\n'

return _generate()
@@ -84,7 +84,7 @@ class DatasetDocumentStore:
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")

segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])

# NOTE: doc could already exist in the store, but we overwrite it
if not allow_update and segment_document:
@@ -30,34 +30,24 @@ class CodeExecutionResponse(BaseModel):

class CodeExecutor:
@classmethod
def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], preload: str, code: str) -> str:
"""
Execute code
:param language: code language
:param code: code
:param inputs: inputs
:return:
"""
template_transformer = None
if language == 'python3':
template_transformer = PythonTemplateTransformer
elif language == 'jinja2':
template_transformer = Jinja2TemplateTransformer
elif language == 'javascript':
template_transformer = NodeJsTemplateTransformer
else:
raise CodeExecutionException('Unsupported language')

runner, preload = template_transformer.transform_caller(code, inputs)
url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'

headers = {
'X-Api-Key': CODE_EXECUTION_API_KEY
}

data = {
'language': 'python3' if language == 'jinja2' else
'nodejs' if language == 'javascript' else
'python3' if language == 'python3' else None,
'code': runner,
'code': code,
'preload': preload
}

@@ -85,4 +75,32 @@ class CodeExecutor:
if response.data.error:
raise CodeExecutionException(response.data.error)

return template_transformer.transform_response(response.data.stdout)
return response.data.stdout

@classmethod
def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
"""
Execute code
:param language: code language
:param code: code
:param inputs: inputs
:return:
"""
template_transformer = None
if language == 'python3':
template_transformer = PythonTemplateTransformer
elif language == 'jinja2':
template_transformer = Jinja2TemplateTransformer
elif language == 'javascript':
template_transformer = NodeJsTemplateTransformer
else:
raise CodeExecutionException('Unsupported language')

runner, preload = template_transformer.transform_caller(code, inputs)

try:
response = cls.execute_code(language, preload, runner)
except CodeExecutionException as e:
raise e

return template_transformer.transform_response(response)
@@ -1,10 +1,13 @@
import json
import re
from base64 import b64encode

from core.helper.code_executor.template_transformer import TemplateTransformer

PYTHON_RUNNER = """
import jinja2
from json import loads
from base64 import b64decode

template = jinja2.Template('''{{code}}''')

@@ -12,7 +15,8 @@ def main(**inputs):
return template.render(**inputs)

# execute main function, and return the result
output = main(**{{inputs}})
inputs = b64decode('{{inputs}}').decode('utf-8')
output = main(**loads(inputs))

result = f'''<<RESULT>>{output}<<RESULT>>'''

@@ -39,6 +43,7 @@ JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %}

JINJA2_PRELOAD = f"""
import jinja2
from base64 import b64decode

def _jinja2_preload_():
# prepare jinja2 environment, load template and render before to avoid sandbox issue
@@ -60,9 +65,11 @@ class Jinja2TemplateTransformer(TemplateTransformer):
:return:
"""

inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')

# transform jinja2 template to python code
runner = PYTHON_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False))
runner = runner.replace('{{inputs}}', inputs_str)

return runner, JINJA2_PRELOAD
@@ -1,17 +1,22 @@
import json
import re
from base64 import b64encode

from core.helper.code_executor.template_transformer import TemplateTransformer

PYTHON_RUNNER = """# declare main function here
{{code}}

from json import loads, dumps
from base64 import b64decode

# execute main function, and return the result
# inputs is a dict, and it
output = main(**{{inputs}})
inputs = b64decode('{{inputs}}').decode('utf-8')
output = main(**json.loads(inputs))

# convert output to json and print
output = json.dumps(output, indent=4)
output = dumps(output, indent=4)

result = f'''<<RESULT>>
{output}
@@ -20,8 +25,28 @@ result = f'''<<RESULT>>
print(result)
"""

PYTHON_PRELOAD = """"""

PYTHON_PRELOAD = """
# prepare general imports
import json
import datetime
import math
import random
import re
import string
import sys
import time
import traceback
import uuid
import os
import base64
import hashlib
import hmac
import binascii
import collections
import functools
import operator
import itertools
"""

class PythonTemplateTransformer(TemplateTransformer):
@classmethod
@@ -34,7 +59,7 @@ class PythonTemplateTransformer(TemplateTransformer):
"""

# transform inputs to json string
inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')

# replace code and inputs
runner = PYTHON_RUNNER.replace('{{code}}', code)
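The transformer hunks above switch from pasting the inputs into the runner template as raw JSON to embedding them as a base64 string that the runner decodes before calling `main`. A short sketch of that round trip, under the assumption that the placeholder substitution works as shown in the templates above (the sample input values are invented for illustration):

```python
import json
from base64 import b64decode, b64encode

inputs = {"text": "quotes ' \" and\nnewlines no longer break the generated script"}

# What transform_caller now substitutes for {{inputs}} in the runner template...
encoded = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode("utf-8")

# ...and what the generated runner does before calling main(**inputs).
decoded = json.loads(b64decode(encoded).decode("utf-8"))
assert decoded == inputs
```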
@@ -88,6 +88,14 @@ class PromptMessage(ABC, BaseModel):
content: Optional[str | list[PromptMessageContent]] = None
name: Optional[str] = None

def is_empty(self) -> bool:
"""
Check if prompt message is empty.

:return: True if prompt message is empty, False otherwise
"""
return not self.content


class UserPromptMessage(PromptMessage):
"""
@@ -118,6 +126,16 @@ class AssistantPromptMessage(PromptMessage):
role: PromptMessageRole = PromptMessageRole.ASSISTANT
tool_calls: list[ToolCall] = []

def is_empty(self) -> bool:
"""
Check if prompt message is empty.

:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_calls:
return False

return True

class SystemPromptMessage(PromptMessage):
"""
@@ -132,3 +150,14 @@ class ToolPromptMessage(PromptMessage):
"""
role: PromptMessageRole = PromptMessageRole.TOOL
tool_call_id: str

def is_empty(self) -> bool:
"""
Check if prompt message is empty.

:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_call_id:
return False

return True
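The new `is_empty()` helpers above let callers drop messages that carry neither content nor (for assistant and tool messages) tool-call data before sending a prompt. A hedged usage sketch, assuming the entity classes shown above are importable from Dify's message-entities module as named below:

```python
# Assumed import path for illustration; adjust to wherever these entities live.
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    UserPromptMessage,
)

messages = [
    UserPromptMessage(content="Hello"),
    AssistantPromptMessage(content="", tool_calls=[]),  # empty: no content, no tool calls
]

# Keep only messages that actually contribute something to the prompt.
non_empty = [m for m in messages if not m.is_empty()]
assert len(non_empty) == 1
```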
@@ -10,3 +10,6 @@
- cohere.command-text-v14
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-large-2402-v1:0
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-7b-instruct-v0:2
@@ -0,0 +1,57 @@
model: anthropic.claude-3-opus-20240229-v1:0
label:
  en_US: Claude 3 Opus
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  # docs: https://docs.anthropic.com/claude/docs/system-prompts
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # tip docs from aws has error, max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.015'
  output: '0.075'
  unit: '0.001'
  currency: USD
@@ -449,6 +449,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
human_prompt_prefix = "\n[INST]"
human_prompt_postfix = "[\\INST]\n"
ai_prompt = ""

elif model_prefix == "mistral":
human_prompt_prefix = "<s>[INST]"
human_prompt_postfix = "[\\INST]\n"
ai_prompt = "\n\nAssistant:"

elif model_prefix == "amazon":
human_prompt_prefix = "\n\nUser:"
@@ -519,6 +524,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
if model_parameters.get("countPenalty"):
payload["countPenalty"] = {model_parameters.get("countPenalty")}

elif model_prefix == "mistral":
payload["temperature"] = model_parameters.get("temperature")
payload["top_p"] = model_parameters.get("top_p")
payload["max_tokens"] = model_parameters.get("max_tokens")
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
payload["stop"] = stop[:10] if stop else []

elif model_prefix == "anthropic":
payload = { **model_parameters }
@@ -648,6 +660,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
output = response_body.get("generation").strip('\n')
prompt_tokens = response_body.get("prompt_token_count")
completion_tokens = response_body.get("generation_token_count")

elif model_prefix == "mistral":
output = response_body.get("outputs")[0].get("text")
prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')

else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -731,6 +748,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
content_delta = payload.get("text")
finish_reason = payload.get("finish_reason")

elif model_prefix == "mistral":
content_delta = payload.get('outputs')[0].get("text")
finish_reason = payload.get('outputs')[0].get("stop_reason")

elif model_prefix == "meta":
content_delta = payload.get("generation").strip('\n')
finish_reason = payload.get("stop_reason")
@@ -0,0 +1,39 @@
model: mistral.mistral-7b-instruct-v0:2
label:
  en_US: Mistral 7B Instruct
model_type: llm
model_properties:
  mode: completion
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    required: false
    default: 0.5
  - name: top_p
    use_template: top_p
    required: false
    default: 0.9
  - name: top_k
    use_template: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 50
    max: 200
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 512
    min: 1
    max: 8192
pricing:
  input: '0.00015'
  output: '0.0002'
  unit: '0.00001'
  currency: USD
@@ -0,0 +1,27 @@
model: mistral.mistral-large-2402-v1:0
label:
  en_US: Mistral Large
model_type: llm
model_properties:
  mode: completion
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    required: false
    default: 0.7
  - name: top_p
    use_template: top_p
    required: false
    default: 1
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 512
    min: 1
    max: 4096
pricing:
  input: '0.008'
  output: '0.024'
  unit: '0.001'
  currency: USD
@@ -0,0 +1,39 @@
model: mistral.mixtral-8x7b-instruct-v0:1
label:
  en_US: Mixtral 8X7B Instruct
model_type: llm
model_properties:
  mode: completion
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    required: false
    default: 0.5
  - name: top_p
    use_template: top_p
    required: false
    default: 0.9
  - name: top_k
    use_template: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 50
    max: 200
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 512
    min: 1
    max: 8192
pricing:
  input: '0.00045'
  output: '0.0007'
  unit: '0.00001'
  currency: USD
@@ -0,0 +1,25 @@
model: llama3-70b-8192
label:
  zh_Hans: Llama-3-70B-8192
  en_US: Llama-3-70B-8192
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 8192
pricing:
  input: '0.05'
  output: '0.1'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,25 @@
model: llama3-8b-8192
label:
  zh_Hans: Llama-3-8B-8192
  en_US: Llama-3-8B-8192
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 8192
pricing:
  input: '0.59'
  output: '0.79'
  unit: '0.000001'
  currency: USD
@@ -1,5 +1,6 @@
- open-mistral-7b
- open-mixtral-8x7b
- open-mixtral-8x22b
- mistral-small-latest
- mistral-medium-latest
- mistral-large-latest

@@ -6,6 +6,7 @@ model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature

@@ -6,6 +6,7 @@ model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature

@@ -6,6 +6,7 @@ model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature

@@ -6,6 +6,7 @@ model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8000
parameter_rules:
  - name: temperature
@@ -0,0 +1,51 @@
model: open-mixtral-8x22b
label:
  zh_Hans: open-mixtral-8x22b
  en_US: open-mixtral-8x22b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 64000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.7
    min: 0
    max: 1
  - name: top_p
    use_template: top_p
    default: 1
    min: 0
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
      zh_Hans: 是否开启提示词审查
    label:
      en_US: SafePrompt
      zh_Hans: 提示词审查
  - name: random_seed
    type: int
    help:
      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
    label:
      en_US: RandomSeed
      zh_Hans: 随机数种子
    default: 0
    min: 0
    max: 2147483647
pricing:
  input: '0.002'
  output: '0.006'
  unit: '0.001'
  currency: USD
@@ -6,6 +6,7 @@ model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
@@ -5,6 +5,9 @@ label:
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000

@@ -5,6 +5,9 @@ label:
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000

@@ -5,6 +5,9 @@ label:
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 8192
@@ -1,5 +1,7 @@
- google/gemma-7b
- google/codegemma-7b
- meta/llama2-70b
- meta/llama3-8b
- meta/llama3-70b
- mistralai/mixtral-8x7b-instruct-v0.1
- fuyu-8b
@@ -11,13 +11,19 @@ model_properties:
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.5
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 1024
    default: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2

@@ -22,6 +22,6 @@ parameter_rules:
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 512
    default: 1024
    min: 1
    max: 1024

@@ -11,13 +11,19 @@ model_properties:
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.5
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 1024
    default: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2

@@ -7,17 +7,23 @@ features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.5
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    default: 512
    min: 1
    max: 1024
    default: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
@@ -0,0 +1,36 @@
model: meta/llama3-70b
label:
  zh_Hans: meta/llama3-70b
  en_US: meta/llama3-70b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.5
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 1024
    default: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
@@ -0,0 +1,36 @@
model: meta/llama3-8b
label:
  zh_Hans: meta/llama3-8b
  en_US: meta/llama3-8b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 1
    default: 0.5
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 1024
    default: 1024
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
@@ -25,7 +25,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
'mistralai/mixtral-8x7b-instruct-v0.1': '',
'google/gemma-7b': '',
'google/codegemma-7b': '',
'meta/llama2-70b': ''
'meta/llama2-70b': '',
'meta/llama3-8b': '',
'meta/llama3-70b': ''

}

def _invoke(self, model: str, credentials: dict,
@@ -131,7 +134,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
endpoint_url,
headers=headers,
json=data,
timeout=(10, 60)
timeout=(10, 300)
)

if response.status_code != 200:
@@ -232,7 +235,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
endpoint_url,
headers=headers,
json=data,
timeout=(10, 60),
timeout=(10, 300),
stream=stream
)
@ -11,13 +11,19 @@ model_properties:
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 1024
|
||||
default: 1024
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
|
@ -1,6 +1,9 @@
|
||||
provider: nvidia
|
||||
label:
|
||||
en_US: API Catalog
|
||||
description:
|
||||
en_US: API Catalog
|
||||
zh_Hans: API Catalog
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
|
@ -201,7 +201,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 60),
|
||||
timeout=(10, 300),
|
||||
stream=stream
|
||||
)
|
||||
|
||||
|
@ -138,7 +138,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 60)
|
||||
timeout=(10, 300)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
@ -154,7 +154,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
json_result['object'] = 'chat.completion'
|
||||
elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''):
|
||||
json_result['object'] = 'text_completion'
|
||||
|
||||
|
||||
if (completion_type is LLMMode.CHAT
|
||||
and ('object' not in json_result or json_result['object'] != 'chat.completion')):
|
||||
raise CredentialsValidateFailedError(
|
||||
@ -334,7 +334,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 60),
|
||||
timeout=(10, 300),
|
||||
stream=stream
|
||||
)
|
||||
|
||||
@ -425,6 +425,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
finish_reason = 'Unknown'
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
chunk = chunk.strip()
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(':'):
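Several hunks above (NVIDIA, Ollama, and the OpenAI-compatible base class) raise the HTTP read timeout from timeout=(10, 60) to timeout=(10, 300). In requests, a two-element timeout is (connect, read): the first value bounds connection establishment, the second bounds how long to wait for response data, which is what matters for slow generations. A small illustration with a placeholder endpoint:

import requests

try:
    response = requests.post(
        'http://localhost:8000/v1/chat/completions',  # placeholder endpoint
        json={'messages': [{'role': 'user', 'content': 'ping'}]},
        # Fail fast on unreachable hosts, but give a slow model up to 5 minutes to answer.
        timeout=(10, 300),
    )
except requests.exceptions.ConnectTimeout:
    print('could not connect within 10 seconds')
except requests.exceptions.ReadTimeout:
    print('connected, but no response data arrived within 300 seconds')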
|
||||
|
@ -73,3 +73,22 @@ model_credential_schema:
|
||||
value: llm
|
||||
default: "4096"
|
||||
type: text-input
|
||||
- variable: vision_support
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
zh_Hans: 是否支持 Vision
|
||||
en_US: Vision Support
|
||||
type: radio
|
||||
required: false
|
||||
default: 'no_support'
|
||||
options:
|
||||
- value: 'support'
|
||||
label:
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: 'no_support'
|
||||
label:
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
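The new vision_support credential above is a radio field whose stored value is the string 'support' or 'no_support'. A hedged sketch of how an integration might turn that into a boolean feature flag; supports_vision is an illustrative helper, not the actual Dify code path:

def supports_vision(credentials: dict) -> bool:
    # 'support' / 'no_support' come from the radio options defined in the schema above.
    return credentials.get('vision_support', 'no_support') == 'support'

print(supports_vision({'vision_support': 'support'}))  # True
print(supports_vision({}))                             # False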
|
||||
|
@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel):
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
|
||||
|
||||
response = xinference_client.rerank(
|
||||
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
|
||||
response = handle.rerank(
|
||||
documents=docs,
|
||||
query=query,
|
||||
top_n=top_n,
|
||||
@ -97,6 +88,20 @@ class XinferenceRerankModel(RerankModel):
|
||||
try:
|
||||
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
||||
raise InvokeBadRequestError(
|
||||
'please check model type, the model you want to invoke is not a rerank model')
|
||||
|
||||
self.invoke(
|
||||
model=model,
|
||||
@ -157,4 +162,4 @@ class XinferenceRerankModel(RerankModel):
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
||||
return entity
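The rerank change above stops resolving the model through Client.get_model() on the invoke path and instead constructs a RESTfulRerankModelHandle directly from the configured model_uid and server_url, keeping the Client-based type check only in credential validation. A minimal standalone sketch of that pattern, assuming the xinference-client package and a reachable Xinference server (the URL and UID below are placeholders):

from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle

server_url = 'http://localhost:9997'  # placeholder Xinference endpoint
model_uid = 'my-rerank-model'         # placeholder model UID

# Build the handle directly, as the new invoke path does.
handle = RESTfulRerankModelHandle(model_uid, server_url, auth_headers={})
response = handle.rerank(
    documents=['Dify is an LLM application platform.', 'Bananas are yellow.'],
    query='What is Dify?',
    top_n=1,
)
print(response)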
|
||||
|
@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
client = Client(base_url=server_url)
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
except RuntimeError as e:
|
||||
raise InvokeAuthorizationError(e)
|
||||
|
||||
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
||||
|
||||
try:
|
||||
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
|
||||
embeddings = handle.create_embedding(input=texts)
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(e)
|
||||
@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
if extra_args.max_tokens:
|
||||
credentials['max_tokens'] = extra_args.max_tokens
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
client = Client(base_url=server_url)
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
except RuntimeError as e:
|
||||
raise InvokeAuthorizationError(e)
|
||||
|
||||
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
||||
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
except InvokeAuthorizationError as e:
|
||||
@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
||||
return entity
|
||||
|
@ -1,6 +1,15 @@
|
||||
|
||||
from .__version__ import __version__
|
||||
from ._client import ZhipuAI
|
||||
from .core._errors import (APIAuthenticationError, APIInternalError, APIReachLimitError, APIRequestFailedError,
|
||||
APIResponseError, APIResponseValidationError, APIServerFlowExceedError, APIStatusError,
|
||||
APITimeoutError, ZhipuAIError)
|
||||
from .core._errors import (
|
||||
APIAuthenticationError,
|
||||
APIInternalError,
|
||||
APIReachLimitError,
|
||||
APIRequestFailedError,
|
||||
APIResponseError,
|
||||
APIResponseValidationError,
|
||||
APIServerFlowExceedError,
|
||||
APIStatusError,
|
||||
APITimeoutError,
|
||||
ZhipuAIError,
|
||||
)
|
||||
|
@ -4,6 +4,7 @@
|
||||
- searxng
|
||||
- dalle
|
||||
- azuredalle
|
||||
- stability
|
||||
- wikipedia
|
||||
- model.openai
|
||||
- model.google
|
||||
@ -17,6 +18,7 @@
|
||||
- model.zhipuai
|
||||
- aippt
|
||||
- youtube
|
||||
- code
|
||||
- wolframalpha
|
||||
- maths
|
||||
- github
|
||||
|
1
api/core/tools/provider/builtin/code/_assets/icon.svg
Normal file
@ -0,0 +1 @@
|
||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg" class="w-3.5 h-3.5" data-icon="Code" aria-hidden="true"><g id="icons/code"><path id="Vector (Stroke)" fill-rule="evenodd" clip-rule="evenodd" d="M8.32593 1.69675C8.67754 1.78466 8.89132 2.14096 8.80342 2.49257L6.47009 11.8259C6.38218 12.1775 6.02588 12.3913 5.67427 12.3034C5.32265 12.2155 5.10887 11.8592 5.19678 11.5076L7.53011 2.17424C7.61801 1.82263 7.97431 1.60885 8.32593 1.69675ZM3.96414 4.20273C4.22042 4.45901 4.22042 4.87453 3.96413 5.13081L2.45578 6.63914C2.45577 6.63915 2.45578 6.63914 2.45578 6.63914C2.25645 6.83851 2.25643 7.16168 2.45575 7.36103C2.45574 7.36103 2.45576 7.36104 2.45575 7.36103L3.96413 8.86936C4.22041 9.12564 4.22042 9.54115 3.96414 9.79744C3.70787 10.0537 3.29235 10.0537 3.03607 9.79745L1.52769 8.28913C0.815811 7.57721 0.815803 6.42302 1.52766 5.7111L3.03606 4.20272C3.29234 3.94644 3.70786 3.94644 3.96414 4.20273ZM10.0361 4.20273C10.2923 3.94644 10.7078 3.94644 10.9641 4.20272L12.4725 5.71108C13.1843 6.423 13.1844 7.57717 12.4725 8.28909L10.9641 9.79745C10.7078 10.0537 10.2923 10.0537 10.036 9.79744C9.77977 9.54115 9.77978 9.12564 10.0361 8.86936L11.5444 7.36107C11.7437 7.16172 11.7438 6.83854 11.5444 6.63917C11.5444 6.63915 11.5445 6.63918 11.5444 6.63917L10.0361 5.13081C9.77978 4.87453 9.77978 4.45901 10.0361 4.20273Z" fill="currentColor"></path></g></svg>
|
After Width: | Height: | Size: 1.4 KiB |
8
api/core/tools/provider/builtin/code/code.py
Normal file
@ -0,0 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
pass
|
13
api/core/tools/provider/builtin/code/code.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: code
|
||||
label:
|
||||
en_US: Code Interpreter
|
||||
zh_Hans: 代码解释器
|
||||
pt_BR: Interpretador de Código
|
||||
description:
|
||||
en_US: Run a piece of code and get the result back.
|
||||
zh_Hans: 运行一段代码并返回结果。
|
||||
pt_BR: Execute um trecho de código e obtenha o resultado de volta.
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
22
api/core/tools/provider/builtin/code/tools/simple_code.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SimpleCode(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
invoke simple code
|
||||
"""
|
||||
|
||||
language = tool_parameters.get('language', 'python3')
|
||||
code = tool_parameters.get('code', '')
|
||||
|
||||
if language not in ['python3', 'javascript']:
|
||||
raise ValueError(f'Only python3 and javascript are supported, not {language}')
|
||||
|
||||
result = CodeExecutor.execute_code(language, '', code)
|
||||
|
||||
return self.create_text_message(result)
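The tool above only validates the language, hands the snippet to CodeExecutor, and wraps the output in a text message. A standalone sketch that mirrors those checks outside the Dify runtime; run_simple_code is an illustrative stand-in and does not actually execute anything:

def run_simple_code(tool_parameters: dict) -> str:
    language = tool_parameters.get('language', 'python3')
    code = tool_parameters.get('code', '')
    if language not in ['python3', 'javascript']:
        raise ValueError(f'Only python3 and javascript are supported, not {language}')
    # In the real tool this is CodeExecutor.execute_code(language, '', code),
    # which runs the snippet in the sandbox service.
    return f'would execute {len(code)} characters of {language}'

print(run_simple_code({'language': 'python3', 'code': "print('hello world')"}))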
|
51
api/core/tools/provider/builtin/code/tools/simple_code.yaml
Normal file
@ -0,0 +1,51 @@
|
||||
identity:
|
||||
name: simple_code
|
||||
author: Dify
|
||||
label:
|
||||
en_US: Code Interpreter
|
||||
zh_Hans: 代码解释器
|
||||
pt_BR: Interpretador de Código
|
||||
description:
|
||||
human:
|
||||
en_US: Run code and get the result back. When you're using a lower-quality model, please make sure there are some tips to help the LLM understand how to write the code.
|
||||
zh_Hans: 运行一段代码并返回结果。当您使用较低质量的模型时,请确保有一些提示帮助LLM理解如何编写代码。
|
||||
pt_BR: Execute um trecho de código e obtenha o resultado de volta. quando você estiver usando um modelo de qualidade inferior, certifique-se de que existam algumas dicas para ajudar o LLM a entender como escrever o código.
|
||||
llm: A tool for running code and getting the result back. Only native packages are allowed and network/IO operations are disabled; you must use print() or console.log() to output the result, or the result will be empty.
|
||||
parameters:
|
||||
- name: language
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Language
|
||||
zh_Hans: 语言
|
||||
pt_BR: Idioma
|
||||
human_description:
|
||||
en_US: The programming language of the code
|
||||
zh_Hans: 代码的编程语言
|
||||
pt_BR: A linguagem de programação do código
|
||||
llm_description: language of the code, only "python3" and "javascript" are supported
|
||||
form: llm
|
||||
options:
|
||||
- value: python3
|
||||
label:
|
||||
en_US: Python3
|
||||
zh_Hans: Python3
|
||||
pt_BR: Python3
|
||||
- value: javascript
|
||||
label:
|
||||
en_US: JavaScript
|
||||
zh_Hans: JavaScript
|
||||
pt_BR: JavaScript
|
||||
- name: code
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Code
|
||||
zh_Hans: 代码
|
||||
pt_BR: Código
|
||||
human_description:
|
||||
en_US: The code to be executed
|
||||
zh_Hans: 要执行的代码
|
||||
pt_BR: O código a ser executado
|
||||
llm_description: code to be executed, only native packages are allowed, network/IO operations are disabled.
|
||||
form: llm
|
@ -20,7 +20,7 @@ class JinaReaderTool(BuiltinTool):
|
||||
url = tool_parameters['url']
|
||||
|
||||
headers = {
|
||||
'Accept': 'text/event-stream'
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
response = ssrf_proxy.get(
|
||||
|
21
api/core/tools/provider/builtin/judge0ce/_assets/icon.svg
Normal file
@ -0,0 +1,21 @@
|
||||
<?xml version="1.0" standalone="no"?>
|
||||
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 20010904//EN"
|
||||
"http://www.w3.org/TR/2001/REC-SVG-20010904/DTD/svg10.dtd">
|
||||
<svg version="1.0" xmlns="http://www.w3.org/2000/svg"
|
||||
width="128.000000pt" height="128.000000pt" viewBox="0 0 128.000000 128.000000"
|
||||
preserveAspectRatio="xMidYMid meet">
|
||||
|
||||
<g transform="translate(0.000000,128.000000) scale(0.100000,-0.100000)"
|
||||
fill="#000000" stroke="none">
|
||||
<path d="M0 975 l0 -305 33 1 c54 0 336 35 343 41 3 4 0 57 -7 118 -10 85 -17
|
||||
113 -29 120 -47 25 -45 104 2 133 13 8 118 26 246 41 208 26 225 26 248 11 14
|
||||
-9 30 -27 36 -41 10 -22 8 -33 -10 -68 l-23 -42 40 -316 40 -315 30 -31 c17
|
||||
-17 31 -38 31 -47 0 -25 -27 -72 -46 -79 -35 -13 -450 -59 -476 -53 -52 13
|
||||
-70 85 -32 127 10 13 10 33 -1 120 -8 58 -15 111 -15 118 0 16 -31 16 -237 -5
|
||||
l-173 -17 0 -243 0 -243 640 0 640 0 0 640 0 640 -640 0 -640 0 0 -305z"/>
|
||||
<path d="M578 977 c-128 -16 -168 -24 -168 -35 0 -10 8 -12 28 -8 15 3 90 12
|
||||
167 21 167 18 188 23 180 35 -7 12 -1 12 -207 -13z"/>
|
||||
<path d="M660 326 c-100 -13 -163 -25 -160 -31 3 -5 14 -9 25 -8 104 11 305
|
||||
35 323 39 12 2 22 9 22 14 0 13 -14 12 -210 -14z"/>
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 1.1 KiB |
23
api/core/tools/provider/builtin/judge0ce/judge0ce.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.judge0ce.tools.submitCodeExecutionTask import SubmitCodeExecutionTaskTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class Judge0CEProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
SubmitCodeExecutionTaskTool().fork_tool_runtime(
|
||||
meta={
|
||||
"credentials": credentials,
|
||||
}
|
||||
).invoke(
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"source_code": "print('hello world')",
|
||||
"language_id": 71,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
29
api/core/tools/provider/builtin/judge0ce/judge0ce.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
identity:
|
||||
author: Richards Tu
|
||||
name: judge0ce
|
||||
label:
|
||||
en_US: Judge0 CE
|
||||
zh_Hans: Judge0 CE
|
||||
pt_BR: Judge0 CE
|
||||
description:
|
||||
en_US: Judge0 CE is an open-source code execution system. It supports various languages, including C, C++, Java, Python, Ruby, etc.
|
||||
zh_Hans: Judge0 CE 是一个开源的代码执行系统。支持多种语言,包括 C、C++、Java、Python、Ruby 等。
|
||||
pt_BR: Judge0 CE é um sistema de execução de código de código aberto. Suporta várias linguagens, incluindo C, C++, Java, Python, Ruby, etc.
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
X-RapidAPI-Key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: RapidAPI Key
|
||||
zh_Hans: RapidAPI Key
|
||||
pt_BR: RapidAPI Key
|
||||
help:
|
||||
en_US: RapidAPI Key is required to access the Judge0 CE API.
|
||||
zh_Hans: RapidAPI Key 是访问 Judge0 CE API 所必需的。
|
||||
pt_BR: RapidAPI Key é necessário para acessar a API do Judge0 CE.
|
||||
placeholder:
|
||||
en_US: Enter your RapidAPI Key
|
||||
zh_Hans: 输入你的 RapidAPI Key
|
||||
pt_BR: Insira sua RapidAPI Key
|
||||
url: https://rapidapi.com/judge0-official/api/judge0-ce
|
@ -0,0 +1,37 @@
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class GetExecutionResultTool(BuiltinTool):
|
||||
def _invoke(self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
api_key = self.runtime.credentials['X-RapidAPI-Key']
|
||||
|
||||
url = f"https://judge0-ce.p.rapidapi.com/submissions/{tool_parameters['token']}"
|
||||
headers = {
|
||||
"X-RapidAPI-Key": api_key
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return self.create_text_message(text=f"Submission details:\n"
|
||||
f"stdout: {result.get('stdout', '')}\n"
|
||||
f"stderr: {result.get('stderr', '')}\n"
|
||||
f"compile_output: {result.get('compile_output', '')}\n"
|
||||
f"message: {result.get('message', '')}\n"
|
||||
f"status: {result['status']['description']}\n"
|
||||
f"time: {result.get('time', '')} seconds\n"
|
||||
f"memory: {result.get('memory', '')} bytes")
|
||||
else:
|
||||
return self.create_text_message(text=f"Error retrieving submission details: {response.text}")
|
@ -0,0 +1,23 @@
|
||||
identity:
|
||||
name: getExecutionResult
|
||||
author: Richards Tu
|
||||
label:
|
||||
en_US: Get Execution Result
|
||||
zh_Hans: 获取执行结果
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
|
||||
zh_Hans: 一个用于通过 submitCodeExecutionTask 工具提供的特定令牌来检索代码提交详细信息的工具。
|
||||
llm: A tool for retrieving the details of a code submission by a specific token from submitCodeExecutionTask.
|
||||
parameters:
|
||||
- name: token
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Token
|
||||
zh_Hans: 令牌
|
||||
human_description:
|
||||
en_US: The submission's unique token.
|
||||
zh_Hans: 提交的唯一令牌。
|
||||
llm_description: The submission's unique token. MUST be obtained from submitCodeExecutionTask.
|
||||
form: llm
|
@ -0,0 +1,49 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from httpx import post
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class SubmitCodeExecutionTaskTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
api_key = self.runtime.credentials['X-RapidAPI-Key']
|
||||
|
||||
source_code = tool_parameters['source_code']
|
||||
language_id = tool_parameters['language_id']
|
||||
stdin = tool_parameters.get('stdin', '')
|
||||
expected_output = tool_parameters.get('expected_output', '')
|
||||
additional_files = tool_parameters.get('additional_files', '')
|
||||
|
||||
url = "https://judge0-ce.p.rapidapi.com/submissions"
|
||||
|
||||
querystring = {"base64_encoded": "false", "fields": "*"}
|
||||
|
||||
payload = {
|
||||
"language_id": language_id,
|
||||
"source_code": source_code,
|
||||
"stdin": stdin,
|
||||
"expected_output": expected_output,
|
||||
"additional_files": additional_files,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"X-RapidAPI-Key": api_key,
|
||||
"X-RapidAPI-Host": "judge0-ce.p.rapidapi.com"
|
||||
}
|
||||
|
||||
response = post(url, data=json.dumps(payload), headers=headers, params=querystring)
|
||||
|
||||
if response.status_code != 201:
|
||||
raise Exception(response.text)
|
||||
|
||||
token = response.json()['token']
|
||||
|
||||
return self.create_text_message(text=token)
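Together, submitCodeExecutionTask and getExecutionResult implement a submit-then-poll flow against the Judge0 CE endpoints on RapidAPI shown above. A condensed sketch of the same two HTTP calls outside the tool framework; the API key is a placeholder, and 71 is the Judge0 language ID for Python 3 used in the provider's own credential check:

import time

import requests

API_KEY = 'your-rapidapi-key'  # placeholder
HEADERS = {
    'Content-Type': 'application/json',
    'X-RapidAPI-Key': API_KEY,
    'X-RapidAPI-Host': 'judge0-ce.p.rapidapi.com',
}

# 1. Submit the code and receive a token.
submit = requests.post(
    'https://judge0-ce.p.rapidapi.com/submissions',
    params={'base64_encoded': 'false', 'fields': '*'},
    json={'language_id': 71, 'source_code': "print('hello world')"},
    headers=HEADERS,
)
token = submit.json()['token']

# 2. Fetch the result by token (a single delayed poll for brevity;
#    production code would loop until the status is final).
time.sleep(2)
result = requests.get(f'https://judge0-ce.p.rapidapi.com/submissions/{token}', headers=HEADERS).json()
print(result.get('stdout'), result['status']['description'])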
|
@ -0,0 +1,67 @@
|
||||
identity:
|
||||
name: submitCodeExecutionTask
|
||||
author: Richards Tu
|
||||
label:
|
||||
en_US: Submit Code Execution Task
|
||||
zh_Hans: 提交代码执行任务
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for submitting code execution task to Judge0 CE.
|
||||
zh_Hans: 一个用于向 Judge0 CE 提交代码执行任务的工具。
|
||||
llm: A tool for submitting a new code execution task to Judge0 CE. It takes in the source code, language ID, standard input (optional), expected output (optional), and additional files (optional) as parameters; and returns a unique token representing the submission.
|
||||
parameters:
|
||||
- name: source_code
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Source Code
|
||||
zh_Hans: 源代码
|
||||
human_description:
|
||||
en_US: The source code to be executed.
|
||||
zh_Hans: 要执行的源代码。
|
||||
llm_description: The source code to be executed.
|
||||
form: llm
|
||||
- name: language_id
|
||||
type: number
|
||||
required: true
|
||||
label:
|
||||
en_US: Language ID
|
||||
zh_Hans: 语言 ID
|
||||
human_description:
|
||||
en_US: The ID of the language in which the source code is written.
|
||||
zh_Hans: 源代码所使用的语言的 ID。
|
||||
llm_description: The ID of the language in which the source code is written. For example, 50 for C++, 71 for Python, etc.
|
||||
form: llm
|
||||
- name: stdin
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Standard Input
|
||||
zh_Hans: 标准输入
|
||||
human_description:
|
||||
en_US: The standard input to be provided to the program.
|
||||
zh_Hans: 提供给程序的标准输入。
|
||||
llm_description: The standard input to be provided to the program. Optional.
|
||||
form: llm
|
||||
- name: expected_output
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Expected Output
|
||||
zh_Hans: 期望输出
|
||||
human_description:
|
||||
en_US: The expected output of the program. Used for comparison in some scenarios.
|
||||
zh_Hans: 程序的期望输出。在某些场景下用于比较。
|
||||
llm_description: The expected output of the program. Used for comparison in some scenarios. Optional.
|
||||
form: llm
|
||||
- name: additional_files
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Additional Files
|
||||
zh_Hans: 附加文件
|
||||
human_description:
|
||||
en_US: Base64 encoded additional files for the submission.
|
||||
zh_Hans: 提交的 Base64 编码的附加文件。
|
||||
llm_description: Base64 encoded additional files for the submission. Optional.
|
||||
form: llm
|
10
api/core/tools/provider/builtin/stability/_assets/icon.svg
Normal file
@ -0,0 +1,10 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="40" height="40" viewBox="0 0 40 40" fill="none">
|
||||
<path d="M12.0377 35C19.1243 35 23.7343 31.3 23.7343 25.7333C23.7343 21.4167 20.931 18.6733 15.9177 17.5367L12.701 16.585C9.87768 15.96 8.22935 15.21 8.61768 13.2933C8.94102 11.6983 9.90602 10.7983 12.1543 10.7983C19.296 10.7983 21.9427 13.2933 21.9427 13.2933V7.29333C21.9427 7.29333 19.366 5 12.1543 5C5.35435 5 1.66602 8.45 1.66602 13.7883C1.66602 18.105 4.22268 20.6167 9.40768 21.8083L9.96435 21.9467C10.7527 22.1867 11.8177 22.505 13.1577 22.9C15.8077 23.525 16.4893 24.1883 16.4893 26.1767C16.4893 27.9933 14.5727 29.0267 12.0393 29.0267C4.73435 29.0267 1.66602 25.385 1.66602 25.385V32.0333C1.66602 32.0333 3.58602 35 12.0377 35Z" fill="url(#paint0_linear_17756_15767)"/>
|
||||
<path d="M33.9561 34.55C36.4645 34.55 38.3328 32.7617 38.3328 30.34C38.3328 27.8667 36.5178 26.13 33.9561 26.13C31.4478 26.13 29.6328 27.8667 29.6328 30.34C29.6328 32.8133 31.4478 34.55 33.9561 34.55Z" fill="#E80000"/>
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear_17756_15767" x1="1105.08" y1="5" x2="1105.08" y2="3005" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#9D39FF"/>
|
||||
<stop offset="1" stop-color="#A380FF"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
After Width: | Height: | Size: 1.2 KiB |
15
api/core/tools/provider/builtin/stability/stability.py
Normal file
@ -0,0 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthorization):
|
||||
"""
|
||||
This class is responsible for providing the stability tool.
|
||||
"""
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
This method is responsible for validating the credentials.
|
||||
"""
|
||||
self.sd_validate_credentials(credentials)
|
29
api/core/tools/provider/builtin/stability/stability.yaml
Normal file
@ -0,0 +1,29 @@
|
||||
identity:
|
||||
author: Dify
|
||||
name: stability
|
||||
label:
|
||||
en_US: Stability
|
||||
zh_Hans: Stability
|
||||
pt_BR: Stability
|
||||
description:
|
||||
en_US: Activating humanity's potential through generative AI
|
||||
zh_Hans: 通过生成式 AI 激活人类的潜力
|
||||
pt_BR: Activating humanity's potential through generative AI
|
||||
icon: icon.svg
|
||||
credentials_for_provider:
|
||||
api_key:
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: API key
|
||||
zh_Hans: API key
|
||||
pt_BR: API key
|
||||
placeholder:
|
||||
en_US: Please input your API key
|
||||
zh_Hans: 请输入你的 API key
|
||||
pt_BR: Please input your API key
|
||||
help:
|
||||
en_US: Get your API key from Stability
|
||||
zh_Hans: 从 Stability 获取你的 API key
|
||||
pt_BR: Get your API key from Stability
|
||||
url: https://platform.stability.ai/account/keys
|
34
api/core/tools/provider/builtin/stability/tools/base.py
Normal file
@ -0,0 +1,34 @@
|
||||
import requests
|
||||
from yarl import URL
|
||||
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class BaseStabilityAuthorization:
|
||||
def sd_validate_credentials(self, credentials: dict):
|
||||
"""
|
||||
This method is responsible for validating the credentials.
|
||||
"""
|
||||
api_key = credentials.get('api_key', '')
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError('API key is required.')
|
||||
|
||||
response = requests.get(
|
||||
URL('https://api.stability.ai') / 'v1' / 'user' / 'account',
|
||||
headers=self.generate_authorization_headers(credentials),
|
||||
timeout=(5, 30)
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ToolProviderCredentialValidationError('Invalid API key.')
|
||||
|
||||
return True
|
||||
|
||||
def generate_authorization_headers(self, credentials: dict) -> dict[str, str]:
|
||||
"""
|
||||
This method is responsible for generating the authorization headers.
|
||||
"""
|
||||
return {
|
||||
'Authorization': f'Bearer {credentials.get("api_key", "")}'
|
||||
}
|
||||
|
@ -0,0 +1,60 @@
|
||||
from typing import Any
|
||||
|
||||
from httpx import post
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||
"""
|
||||
This class is responsible for providing the stable diffusion tool.
|
||||
"""
|
||||
model_endpoint_map = {
|
||||
'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
|
||||
'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
|
||||
}
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
Invoke the tool.
|
||||
"""
|
||||
payload = {
|
||||
'prompt': tool_parameters.get('prompt', ''),
|
||||
'aspect_radio': tool_parameters.get('aspect_radio', '16:9'),
|
||||
'mode': 'text-to-image',
|
||||
'seed': tool_parameters.get('seed', 0),
|
||||
'output_format': 'png',
|
||||
}
|
||||
|
||||
model = tool_parameters.get('model', 'core')
|
||||
|
||||
if model in ['sd3', 'sd3-turbo']:
|
||||
payload['model'] = tool_parameters.get('model')
|
||||
|
||||
if not model == 'sd3-turbo':
|
||||
payload['negative_prompt'] = tool_parameters.get('negative_prompt', '')
|
||||
|
||||
response = post(
|
||||
self.model_endpoint_map[tool_parameters.get('model', 'core')],
|
||||
headers={
|
||||
'accept': 'image/*',
|
||||
**self.generate_authorization_headers(self.runtime.credentials),
|
||||
},
|
||||
files={
|
||||
key: (None, str(value)) for key, value in payload.items()
|
||||
},
|
||||
timeout=(5, 30)
|
||||
)
|
||||
|
||||
if not response.status_code == 200:
|
||||
raise Exception(response.text)
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=response.content, meta={
|
||||
'mime_type': 'image/png'
|
||||
},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
)
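StableDiffusionTool posts every field as a multipart form part (the (None, str(value)) tuple trick) to one of the v2beta endpoints in model_endpoint_map and receives raw PNG bytes back. A trimmed sketch of the same request using requests, assuming a valid Stability API key (placeholder below); the sd3-specific fields and error mapping are omitted:

import requests

api_key = 'sk-...'  # placeholder Stability API key

payload = {
    'prompt': 'a lighthouse at dusk, watercolor',
    'mode': 'text-to-image',
    'output_format': 'png',
}
response = requests.post(
    'https://api.stability.ai/v2beta/stable-image/generate/core',
    headers={'accept': 'image/*', 'Authorization': f'Bearer {api_key}'},
    # Send each payload entry as a plain multipart field, mirroring the tool above.
    files={key: (None, str(value)) for key, value in payload.items()},
    timeout=(5, 30),
)
if response.status_code != 200:
    raise RuntimeError(response.text)
with open('output.png', 'wb') as f:
    f.write(response.content)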
|
142
api/core/tools/provider/builtin/stability/tools/text2image.yaml
Normal file
@ -0,0 +1,142 @@
|
||||
identity:
|
||||
name: stability_text2image
|
||||
author: Dify
|
||||
label:
|
||||
en_US: StableDiffusion
|
||||
zh_Hans: 稳定扩散
|
||||
pt_BR: StableDiffusion
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for generating images based on the text input
|
||||
zh_Hans: 一个基于文本输入生成图像的工具
|
||||
pt_BR: A tool for generate images based on the text input
|
||||
llm: A tool for generating images based on the text input
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
pt_BR: Prompt
|
||||
human_description:
|
||||
en_US: used for generating images
|
||||
zh_Hans: 用于生成图像
|
||||
pt_BR: used for generating images
|
||||
llm_description: key words for generating images
|
||||
form: llm
|
||||
- name: model
|
||||
type: select
|
||||
default: sd3-turbo
|
||||
required: true
|
||||
label:
|
||||
en_US: Model
|
||||
zh_Hans: 模型
|
||||
pt_BR: Model
|
||||
options:
|
||||
- value: core
|
||||
label:
|
||||
en_US: Core
|
||||
zh_Hans: Core
|
||||
pt_BR: Core
|
||||
- value: sd3
|
||||
label:
|
||||
en_US: Stable Diffusion 3
|
||||
zh_Hans: Stable Diffusion 3
|
||||
pt_BR: Stable Diffusion 3
|
||||
- value: sd3-turbo
|
||||
label:
|
||||
en_US: Stable Diffusion 3 Turbo
|
||||
zh_Hans: Stable Diffusion 3 Turbo
|
||||
pt_BR: Stable Diffusion 3 Turbo
|
||||
human_description:
|
||||
en_US: Model for generating images
|
||||
zh_Hans: 用于生成图像的模型
|
||||
pt_BR: Model for generating images
|
||||
llm_description: Model for generating images
|
||||
form: form
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
|
||||
required: false
|
||||
label:
|
||||
en_US: Negative Prompt
|
||||
zh_Hans: 负面提示
|
||||
pt_BR: Negative Prompt
|
||||
human_description:
|
||||
en_US: Negative Prompt
|
||||
zh_Hans: 负面提示
|
||||
pt_BR: Negative Prompt
|
||||
llm_description: Negative Prompt
|
||||
form: form
|
||||
- name: seeds
|
||||
type: number
|
||||
default: 0
|
||||
required: false
|
||||
label:
|
||||
en_US: Seeds
|
||||
zh_Hans: 种子
|
||||
pt_BR: Seeds
|
||||
human_description:
|
||||
en_US: Seeds
|
||||
zh_Hans: 种子
|
||||
pt_BR: Seeds
|
||||
llm_description: Seeds
|
||||
min: 0
|
||||
max: 4294967294
|
||||
form: form
|
||||
- name: aspect_radio
|
||||
type: select
|
||||
default: '16:9'
|
||||
options:
|
||||
- value: '16:9'
|
||||
label:
|
||||
en_US: '16:9'
|
||||
zh_Hans: '16:9'
|
||||
pt_BR: '16:9'
|
||||
- value: '1:1'
|
||||
label:
|
||||
en_US: '1:1'
|
||||
zh_Hans: '1:1'
|
||||
pt_BR: '1:1'
|
||||
- value: '21:9'
|
||||
label:
|
||||
en_US: '21:9'
|
||||
zh_Hans: '21:9'
|
||||
pt_BR: '21:9'
|
||||
- value: '2:3'
|
||||
label:
|
||||
en_US: '2:3'
|
||||
zh_Hans: '2:3'
|
||||
pt_BR: '2:3'
|
||||
- value: '4:5'
|
||||
label:
|
||||
en_US: '4:5'
|
||||
zh_Hans: '4:5'
|
||||
pt_BR: '4:5'
|
||||
- value: '5:4'
|
||||
label:
|
||||
en_US: '5:4'
|
||||
zh_Hans: '5:4'
|
||||
pt_BR: '5:4'
|
||||
- value: '9:16'
|
||||
label:
|
||||
en_US: '9:16'
|
||||
zh_Hans: '9:16'
|
||||
pt_BR: '9:16'
|
||||
- value: '9:21'
|
||||
label:
|
||||
en_US: '9:21'
|
||||
zh_Hans: '9:21'
|
||||
pt_BR: '9:21'
|
||||
required: false
|
||||
label:
|
||||
en_US: Aspect Radio
|
||||
zh_Hans: 长宽比
|
||||
pt_BR: Aspect Radio
|
||||
human_description:
|
||||
en_US: Aspect Radio
|
||||
zh_Hans: 长宽比
|
||||
pt_BR: Aspect Radio
|
||||
llm_description: Aspect Radio
|
||||
form: form
|
@ -16,6 +16,13 @@ class TavilyProvider(BuiltinToolProviderController):
|
||||
user_id='',
|
||||
tool_parameters={
|
||||
"query": "Sachin Tendulkar",
|
||||
"search_depth": "basic",
|
||||
"include_answer": True,
|
||||
"include_images": False,
|
||||
"include_raw_content": False,
|
||||
"max_results": 5,
|
||||
"include_domains": "",
|
||||
"exclude_domains": ""
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
@ -24,87 +24,43 @@ class TavilySearch:
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
|
||||
def raw_results(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = 3,
|
||||
search_depth: Optional[str] = "advanced",
|
||||
include_domains: Optional[list[str]] = [],
|
||||
exclude_domains: Optional[list[str]] = [],
|
||||
include_answer: Optional[bool] = False,
|
||||
include_raw_content: Optional[bool] = False,
|
||||
include_images: Optional[bool] = False,
|
||||
) -> dict:
|
||||
def raw_results(self, params: dict[str, Any]) -> dict:
|
||||
"""
|
||||
Retrieves raw search results from the Tavily Search API.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
max_results (int, optional): The maximum number of results to retrieve. Defaults to 3.
|
||||
search_depth (str, optional): The search depth. Defaults to "advanced".
|
||||
include_domains (List[str], optional): The domains to include in the search. Defaults to [].
|
||||
exclude_domains (List[str], optional): The domains to exclude from the search. Defaults to [].
|
||||
include_answer (bool, optional): Whether to include answer in the search results. Defaults to False.
|
||||
include_raw_content (bool, optional): Whether to include raw content in the search results. Defaults to False.
|
||||
include_images (bool, optional): Whether to include images in the search results. Defaults to False.
|
||||
params (Dict[str, Any]): The search parameters.
|
||||
|
||||
Returns:
|
||||
dict: The raw search results.
|
||||
|
||||
"""
|
||||
params = {
|
||||
"api_key": self.api_key,
|
||||
"query": query,
|
||||
"max_results": max_results,
|
||||
"search_depth": search_depth,
|
||||
"include_domains": include_domains,
|
||||
"exclude_domains": exclude_domains,
|
||||
"include_answer": include_answer,
|
||||
"include_raw_content": include_raw_content,
|
||||
"include_images": include_images,
|
||||
}
|
||||
params["api_key"] = self.api_key
|
||||
if 'exclude_domains' in params and isinstance(params['exclude_domains'], str) and params['exclude_domains'] != 'None':
|
||||
params['exclude_domains'] = params['exclude_domains'].split()
|
||||
else:
|
||||
params['exclude_domains'] = []
|
||||
if 'include_domains' in params and isinstance(params['include_domains'], str) and params['include_domains'] != 'None':
|
||||
params['include_domains'] = params['include_domains'].split()
|
||||
else:
|
||||
params['include_domains'] = []
|
||||
|
||||
response = requests.post(f"{TAVILY_API_URL}/search", json=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def results(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = 3,
|
||||
search_depth: Optional[str] = "advanced",
|
||||
include_domains: Optional[list[str]] = [],
|
||||
exclude_domains: Optional[list[str]] = [],
|
||||
include_answer: Optional[bool] = False,
|
||||
include_raw_content: Optional[bool] = False,
|
||||
include_images: Optional[bool] = False,
|
||||
) -> list[dict]:
|
||||
def results(self, params: dict[str, Any]) -> list[dict]:
|
||||
"""
|
||||
Retrieves cleaned search results from the Tavily Search API.
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
max_results (int, optional): The maximum number of results to retrieve. Defaults to 3.
|
||||
search_depth (str, optional): The search depth. Defaults to "advanced".
|
||||
include_domains (List[str], optional): The domains to include in the search. Defaults to [].
|
||||
exclude_domains (List[str], optional): The domains to exclude from the search. Defaults to [].
|
||||
include_answer (bool, optional): Whether to include answer in the search results. Defaults to False.
|
||||
include_raw_content (bool, optional): Whether to include raw content in the search results. Defaults to False.
|
||||
include_images (bool, optional): Whether to include images in the search results. Defaults to False.
|
||||
params (Dict[str, Any]): The search parameters.
|
||||
|
||||
Returns:
|
||||
list: The cleaned search results.
|
||||
|
||||
"""
|
||||
raw_search_results = self.raw_results(
|
||||
query,
|
||||
max_results=max_results,
|
||||
search_depth=search_depth,
|
||||
include_domains=include_domains,
|
||||
exclude_domains=exclude_domains,
|
||||
include_answer=include_answer,
|
||||
include_raw_content=include_raw_content,
|
||||
include_images=include_images,
|
||||
)
|
||||
raw_search_results = self.raw_results(params)
|
||||
return self.clean_results(raw_search_results["results"])
|
||||
|
||||
def clean_results(self, results: list[dict]) -> list[dict]:
|
||||
@ -149,13 +105,14 @@ class TavilySearchTool(BuiltinTool):
|
||||
ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily search tool invocation.
|
||||
"""
|
||||
query = tool_parameters.get("query", "")
|
||||
|
||||
api_key = self.runtime.credentials["tavily_api_key"]
|
||||
if not query:
|
||||
return self.create_text_message("Please input query")
|
||||
tavily_search = TavilySearch(api_key)
|
||||
results = tavily_search.results(query)
|
||||
results = tavily_search.results(tool_parameters)
|
||||
print(results)
|
||||
if not results:
|
||||
return self.create_text_message(f"No results found for '{query}' in Tavily")
|
||||
else:
|
||||
return self.create_text_message(text=results)
|
||||
return self.create_text_message(text=results)
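After this refactor, TavilySearchTool forwards the whole tool_parameters dict to TavilySearch.results(), which injects the API key and normalizes include_domains / exclude_domains from whitespace-separated strings into lists before posting to the Tavily API. A reduced sketch of just that normalization step (no network call; normalize_domains is an illustrative name):

def normalize_domains(params: dict) -> dict:
    for key in ('include_domains', 'exclude_domains'):
        value = params.get(key)
        if isinstance(value, str) and value != 'None':
            params[key] = value.split()  # whitespace-separated string -> list
        else:
            params[key] = []
    return params

params = {'query': 'Sachin Tendulkar', 'exclude_domains': 'example.com spam.net'}
print(normalize_domains(params)['exclude_domains'])  # ['example.com', 'spam.net']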
|
@ -25,3 +25,138 @@ parameters:
|
||||
pt_BR: used for searching
|
||||
llm_description: key words for searching
|
||||
form: llm
|
||||
- name: search_depth
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Search Depth
|
||||
zh_Hans: 搜索深度
|
||||
pt_BR: Search Depth
|
||||
human_description:
|
||||
en_US: The depth of search results
|
||||
zh_Hans: 搜索结果的深度
|
||||
pt_BR: The depth of search results
|
||||
form: form
|
||||
options:
|
||||
- value: basic
|
||||
label:
|
||||
en_US: Basic
|
||||
zh_Hans: 基本
|
||||
pt_BR: Basic
|
||||
- value: advanced
|
||||
label:
|
||||
en_US: Advanced
|
||||
zh_Hans: 高级
|
||||
pt_BR: Advanced
|
||||
default: basic
|
||||
- name: include_images
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Include Images
|
||||
zh_Hans: 包含图片
|
||||
pt_BR: Include Images
|
||||
human_description:
|
||||
en_US: Include images in the search results
|
||||
zh_Hans: 在搜索结果中包含图片
|
||||
pt_BR: Include images in the search results
|
||||
form: form
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
default: false
|
||||
- name: include_answer
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Include Answer
|
||||
zh_Hans: 包含答案
|
||||
pt_BR: Include Answer
|
||||
human_description:
|
||||
en_US: Include answers in the search results
|
||||
zh_Hans: 在搜索结果中包含答案
|
||||
pt_BR: Include answers in the search results
|
||||
form: form
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
default: false
|
||||
- name: include_raw_content
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Include Raw Content
|
||||
zh_Hans: 包含原始内容
|
||||
pt_BR: Include Raw Content
|
||||
human_description:
|
||||
en_US: Include raw content in the search results
|
||||
zh_Hans: 在搜索结果中包含原始内容
|
||||
pt_BR: Include raw content in the search results
|
||||
form: form
|
||||
options:
|
||||
- value: true
|
||||
label:
|
||||
en_US: Yes
|
||||
zh_Hans: 是
|
||||
pt_BR: Yes
|
||||
- value: false
|
||||
label:
|
||||
en_US: No
|
||||
zh_Hans: 否
|
||||
pt_BR: No
|
||||
default: false
|
||||
- name: max_results
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Max Results
|
||||
zh_Hans: 最大结果
|
||||
pt_BR: Max Results
|
||||
human_description:
|
||||
en_US: The number of maximum search results to return
|
||||
zh_Hans: 返回的最大搜索结果数
|
||||
pt_BR: The number of maximum search results to return
|
||||
form: form
|
||||
min: 1
|
||||
max: 20
|
||||
default: 5
|
||||
- name: include_domains
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Include Domains
|
||||
zh_Hans: 包含域
|
||||
pt_BR: Include Domains
|
||||
human_description:
|
||||
en_US: A list of domains to specifically include in the search results
|
||||
zh_Hans: 在搜索结果中特别包含的域名列表
|
||||
pt_BR: A list of domains to specifically include in the search results
|
||||
form: form
|
||||
- name: exclude_domains
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Exclude Domains
|
||||
zh_Hans: 排除域
|
||||
pt_BR: Exclude Domains
|
||||
human_description:
|
||||
en_US: A list of domains to specifically exclude from the search results
|
||||
zh_Hans: 从搜索结果中特别排除的域名列表
|
||||
pt_BR: A list of domains to specifically exclude from the search results
|
||||
form: form
|
||||
|
@ -291,6 +291,16 @@ class ApiTool(Tool):
|
||||
elif property['type'] == 'null':
|
||||
if value is None:
|
||||
return None
|
||||
elif property['type'] == 'object':
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except ValueError:
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return value
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
raise ValueError(f"Invalid type {property['type']} for property {property}")
|
||||
elif 'anyOf' in property and isinstance(property['anyOf'], list):
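The new 'object' branch above tries json.loads on string values so that a JSON-encoded object passed as text still reaches the API as a dict, and falls back to the raw value when parsing fails. The same coercion in isolation, as a standalone sketch:

import json

def coerce_object(value):
    if isinstance(value, str):
        try:
            return json.loads(value)
        except ValueError:
            return value  # not valid JSON: pass the string through unchanged
    return value

print(coerce_object('{"a": 1}'))  # {'a': 1}
print(coerce_object('not json'))  # not json
print(coerce_object({'a': 1}))    # {'a': 1}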
|
||||
|
@ -81,7 +81,7 @@ class ApiBasedToolSchemaParser:
|
||||
for content_type, content in request_body['content'].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if 'schema' not in content:
|
||||
content
|
||||
continue
|
||||
|
||||
if '$ref' in content['schema']:
|
||||
# get the reference
|
||||
|
@ -112,7 +112,7 @@ class CodeNode(BaseNode):
|
||||
variables[variable] = value
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_code(
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=code_language,
|
||||
code=code,
|
||||
inputs=variables
|
||||
|
@ -438,7 +438,11 @@ class LLMNode(BaseNode):
|
||||
stop = model_config.stop
|
||||
|
||||
vision_enabled = node_data.vision.enabled
|
||||
filtered_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.is_empty():
|
||||
continue
|
||||
|
||||
if not isinstance(prompt_message.content, str):
|
||||
prompt_message_content = []
|
||||
for content_item in prompt_message.content:
|
||||
@ -453,7 +457,13 @@ class LLMNode(BaseNode):
|
||||
and prompt_message_content[0].type == PromptMessageContentType.TEXT):
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
|
||||
return prompt_messages, stop
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
if not filtered_prompt_messages:
|
||||
raise ValueError("No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding.")
|
||||
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
@classmethod
|
||||
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
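The hunk above filters out empty prompt messages before invoking the model and raises a clear error when nothing is left, instead of silently sending an empty prompt. A compact stand-in for that guard, using plain dicts in place of PromptMessage objects:

def filter_prompt_messages(prompt_messages: list[dict]) -> list[dict]:
    # Drop messages with empty content, mirroring the is_empty() check above.
    filtered = [m for m in prompt_messages if m.get('content')]
    if not filtered:
        raise ValueError('No prompt found in the LLM configuration. '
                         'Please ensure a prompt is properly configured before proceeding.')
    return filtered

print(filter_prompt_messages([{'role': 'user', 'content': ''},
                              {'role': 'user', 'content': 'Hello'}]))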
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
@ -26,6 +25,7 @@ from core.workflow.nodes.question_classifier.template_prompts import (
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_2,
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_3,
|
||||
)
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -64,7 +64,8 @@ class QuestionClassifierNode(LLMNode):
|
||||
)
|
||||
categories = [_class.name for _class in node_data.classes]
|
||||
try:
|
||||
result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
#result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||
categories_result = result_text_json.get('categories', [])
|
||||
if categories_result:
|
||||
categories = categories_result
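Instead of stripping a literal ```JSON prefix, the classifier now delegates to parse_and_check_json_markdown from libs.json_in_md_parser, which is meant to tolerate model output that wraps the JSON in a fenced block. A rough standalone approximation of that kind of parsing (not the actual library code):

import json
import re

def parse_json_markdown(text: str) -> dict:
    # Pull the body out of a ```json ... ``` fence if present, otherwise parse as-is.
    match = re.search(r'```(?:json)?\s*(.*?)```', text, re.DOTALL | re.IGNORECASE)
    payload = match.group(1) if match else text
    return json.loads(payload.strip())

reply = '```json\n{"keywords": ["bad service"], "categories": ["Experience"]}\n```'
print(parse_json_markdown(reply)['categories'])  # ['Experience']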
|
||||
|
@ -19,29 +19,33 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_1 = """
|
||||
{ "input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."],
|
||||
"categories": ["Customer Service", "Satisfaction", "Sales", "Product"],
|
||||
"classification_instructions": ["classify the text based on the feedback provided by customer"]}```JSON
|
||||
"classification_instructions": ["classify the text based on the feedback provided by customer"]}
|
||||
"""
|
||||
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """
|
||||
```json
|
||||
{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],
|
||||
"categories": ["Customer Service"]}```
|
||||
"categories": ["Customer Service"]}
|
||||
```
|
||||
"""
|
||||
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_2 = """
|
||||
{"input_text": ["bad service, slow to bring the food"],
|
||||
"categories": ["Food Quality", "Experience", "Price" ],
|
||||
"classification_instructions": []}```JSON
|
||||
"classification_instructions": []}
|
||||
"""
|
||||
|
||||
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """
|
||||
```json
|
||||
{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],
|
||||
"categories": ["Experience"]}```
|
||||
"categories": ["Experience"]}
|
||||
```
|
||||
"""
|
||||
|
||||
QUESTION_CLASSIFIER_USER_PROMPT_3 = """
|
||||
'{{"input_text": ["{input_text}"],',
|
||||
'"categories": ["{categories}" ], ',
|
||||
'"classification_instructions": ["{classification_instructions}"]}}```JSON'
|
||||
'"classification_instructions": ["{classification_instructions}"]}}'
|
||||
"""
|
||||
|
||||
QUESTION_CLASSIFIER_COMPLETION_PROMPT = """
|
||||
|
@ -52,7 +52,7 @@ class TemplateTransformNode(BaseNode):
|
||||
variables[variable] = value
|
||||
# Run code
|
||||
try:
|
||||
result = CodeExecutor.execute_code(
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language='jinja2',
|
||||
code=node_data.template,
|
||||
inputs=variables
|
||||
|
@ -11,3 +11,6 @@ app_model_config_was_updated = signal('app-model-config-was-updated')
|
||||
|
||||
# sender: app, kwargs: published_workflow
|
||||
app_published_workflow_was_updated = signal('app-published-workflow-was-updated')
|
||||
|
||||
# sender: app, kwargs: synced_draft_workflow
|
||||
app_draft_workflow_was_synced = signal('app-draft-workflow-was-synced')
|
||||
|
@ -5,6 +5,7 @@ from .create_installed_app_when_app_created import handle
|
||||
from .create_site_record_when_app_created import handle
|
||||
from .deduct_quota_when_messaeg_created import handle
|
||||
from .delete_installed_app_when_app_deleted import handle
|
||||
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
|
||||
from .update_app_dataset_join_when_app_model_config_updated import handle
|
||||
from .update_provider_last_used_at_when_messaeg_created import handle
|
||||
from .update_app_dataset_join_when_app_published_workflow_updated import handle
|
||||
from .update_provider_last_used_at_when_messaeg_created import handle
|
||||
|
Some files were not shown because too many files have changed in this diff.