diff --git a/.devcontainer/post_start_command.sh b/.devcontainer/post_start_command.sh index e3d5a6d59d..56e87614ba 100755 --- a/.devcontainer/post_start_command.sh +++ b/.devcontainer/post_start_command.sh @@ -1,3 +1,3 @@ #!/bin/bash -poetry install -C api \ No newline at end of file +cd api && poetry install \ No newline at end of file diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index c87d5a4dd4..eb09abe77c 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -7,6 +7,7 @@ on: paths: - api/** - docker/** + - .github/workflows/api-tests.yml concurrency: group: api-tests-${{ github.head_ref || github.run_id }} @@ -27,16 +28,15 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Install Poetry + uses: abatilo/actions-poetry@v3 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - cache-dependency-path: | - api/pyproject.toml - api/poetry.lock - - - name: Install Poetry - uses: abatilo/actions-poetry@v3 + cache: poetry + cache-dependency-path: api/poetry.lock - name: Check Poetry lockfile run: | @@ -67,7 +67,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@v2.0.0 + uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: | docker/docker-compose.middleware.yaml @@ -77,22 +77,3 @@ jobs: - name: Run Workflow run: poetry run -C api bash dev/pytest/pytest_workflow.sh - - - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase) - uses: hoverkraft-tech/compose-action@v2.0.0 - with: - compose-file: | - docker/docker-compose.yaml - services: | - weaviate - qdrant - couchbase-server - etcd - minio - milvus-standalone - pgvecto-rs - pgvector - chroma - elasticsearch - - name: Test Vector Stores - run: poetry run -C api bash dev/pytest/pytest_vdb.sh diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index 6daaaf5791..8e5279fb67 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -49,7 +49,7 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} @@ -114,7 +114,7 @@ jobs: merge-multiple: true - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index a33cdacd80..b8246aacb3 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -43,7 +43,7 @@ jobs: cp middleware.env.example middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@v2.0.0 + uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: | docker/docker-compose.middleware.yaml diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 5c4ee18bc9..9377fa84f6 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -24,16 +24,16 @@ jobs: with: files: api/** + - name: Install Poetry + if: steps.changed-files.outputs.any_changed == 'true' + uses: abatilo/actions-poetry@v3 + - name: Set up Python uses: actions/setup-python@v5 if: steps.changed-files.outputs.any_changed == 'true' with: 
python-version: '3.10' - - name: Install Poetry - if: steps.changed-files.outputs.any_changed == 'true' - uses: abatilo/actions-poetry@v3 - - name: Python dependencies if: steps.changed-files.outputs.any_changed == 'true' run: poetry install -C api --only lint diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml new file mode 100644 index 0000000000..8ea38fde76 --- /dev/null +++ b/.github/workflows/vdb-tests.yml @@ -0,0 +1,75 @@ +name: Run VDB Tests + +on: + pull_request: + branches: + - main + paths: + - api/core/rag/datasource/** + - docker/** + - .github/workflows/vdb-tests.yml + +concurrency: + group: vdb-tests-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + test: + name: VDB Tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - "3.10" + - "3.11" + - "3.12" + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Poetry + uses: abatilo/actions-poetry@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: poetry + cache-dependency-path: api/poetry.lock + + - name: Check Poetry lockfile + run: | + poetry check -C api --lock + poetry show -C api + + - name: Install dependencies + run: poetry install -C api --with dev + + - name: Set up dotenvs + run: | + cp docker/.env.example docker/.env + cp docker/middleware.env.example docker/middleware.env + + - name: Expose Service Ports + run: sh .github/workflows/expose_service_ports.sh + + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase) + uses: hoverkraft-tech/compose-action@v2.0.2 + with: + compose-file: | + docker/docker-compose.yaml + services: | + weaviate + qdrant + couchbase-server + etcd + minio + milvus-standalone + pgvecto-rs + pgvector + chroma + elasticsearch + + - name: Test Vector Stores + run: poetry run -C api bash dev/pytest/pytest_vdb.sh diff --git a/.gitignore b/.gitignore index cc1521c249..ddc393ee83 100644 --- a/.gitignore +++ b/.gitignore @@ -175,6 +175,7 @@ docker/volumes/pgvector/data/* docker/volumes/pgvecto_rs/data/* docker/volumes/couchbase/* docker/volumes/oceanbase/* +!docker/volumes/oceanbase/init.d docker/nginx/conf.d/default.conf docker/nginx/ssl/* diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8f57cd545e..da2928d189 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install. -Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot. +Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and steps to troubleshoot. ### 5. Visit dify in your browser diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md index 80e68a046e..a77239ff38 100644 --- a/CONTRIBUTING_VI.md +++ b/CONTRIBUTING_VI.md @@ -79,7 +79,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. 
Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt. -Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục. +Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/install-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục. ### 5. Truy cập Dify trong trình duyệt của bạn diff --git a/README.md b/README.md index cd783501e2..4779048001 100644 --- a/README.md +++ b/README.md @@ -46,9 +46,33 @@

-Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features: -

+Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. +## Quick start +> Before installing Dify, make sure your machine meets the following minimum system requirements: +> +>- CPU >= 2 Cores +>- RAM >= 4 GiB +
+ +The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: + +```bash +cd dify +cd docker +cp .env.example .env +docker compose up -d +``` + +After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. + +#### Seeking help +Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) if you encounter problems setting up Dify. Reach out to [the community and us](#community--contact) if you are still having issues. + +> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) + +## Key features **1. Workflow**: Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. @@ -79,73 +103,6 @@ Dify is an open-source LLM app development platform. Its intuitive interface com All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. -## Feature comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-| Feature | Dify.AI | LangChain | Flowise | OpenAI Assistants API |
-| --- | --- | --- | --- | --- |
-| Programming Approach | API + App-oriented | Python Code | App-oriented | API-oriented |
-| Supported LLMs | Rich Variety | Rich Variety | Rich Variety | OpenAI-only |
-| RAG Engine | ✅ | ✅ | ✅ | ✅ |
-| Agent | ✅ | ✅ | ❌ | ✅ |
-| Workflow | ✅ | ❌ | ✅ | ❌ |
-| Observability | ✅ | ✅ | ❌ | ❌ |
-| Enterprise Features (SSO/Access control) | ✅ | ❌ | ❌ | ❌ |
-| Local Deployment | ✅ | ✅ | ✅ | ❌ |
- ## Using Dify - **Cloud
** @@ -167,28 +124,7 @@ Star Dify on GitHub and be instantly notified of new releases. ![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - -## Quick start -> Before installing Dify, make sure your machine meets the following minimum system requirements: -> ->- CPU >= 2 Core ->- RAM >= 4 GiB - -
- -The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: - -```bash -cd docker -cp .env.example .env -docker compose up -d -``` - -After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. - -> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) - -## Next steps +## Advanced Setup If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). @@ -216,12 +152,6 @@ At the same time, please consider supporting Dify by sharing it on social media > We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). -**Contributors** - - - - - ## Community & contact * [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. @@ -229,6 +159,12 @@ At the same time, please consider supporting Dify by sharing it on social media * [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. * [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. 
+**Contributors** + + + + + ## Star history [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) diff --git a/api/.env.example b/api/.env.example index 79d6ffdf6a..a92490608f 100644 --- a/api/.env.example +++ b/api/.env.example @@ -120,7 +120,8 @@ SUPABASE_URL=your-server-url WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash + +# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase VECTOR_STORE=weaviate # Weaviate configuration @@ -263,14 +264,20 @@ VIKINGDB_SCHEMA=http VIKINGDB_CONNECTION_TIMEOUT=30 VIKINGDB_SOCKET_TIMEOUT=30 +# Lindorm configuration +LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070 +LINDORM_USERNAME=admin +LINDORM_PASSWORD=admin + # OceanBase Vector configuration OCEANBASE_VECTOR_HOST=127.0.0.1 OCEANBASE_VECTOR_PORT=2881 OCEANBASE_VECTOR_USER=root@test -OCEANBASE_VECTOR_PASSWORD= +OCEANBASE_VECTOR_PASSWORD=difyai123456 OCEANBASE_VECTOR_DATABASE=test OCEANBASE_MEMORY_LIMIT=6G + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 @@ -313,13 +320,21 @@ ETL_TYPE=dify UNSTRUCTURED_API_URL= UNSTRUCTURED_API_KEY= +#ssrf SSRF_PROXY_HTTP_URL= SSRF_PROXY_HTTPS_URL= SSRF_DEFAULT_MAX_RETRIES=3 +SSRF_DEFAULT_TIME_OUT= +SSRF_DEFAULT_CONNECT_TIME_OUT= +SSRF_DEFAULT_READ_TIME_OUT= +SSRF_DEFAULT_WRITE_TIME_OUT= BATCH_UPLOAD_LIMIT=10 KEYWORD_DATA_SOURCE_TYPE=database +# Workflow file upload limit +WORKFLOW_FILE_UPLOAD_LIMIT=10 + # CODE EXECUTION CONFIGURATION CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 CODE_EXECUTION_API_KEY=dify-sandbox diff --git a/api/Dockerfile b/api/Dockerfile index f078181264..eb37303182 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -55,7 +55,7 @@ RUN apt-get update \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \ + && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \ # install a chinese font to support the use of tools like matplotlib && apt-get install -y fonts-noto-cjk \ && apt-get autoremove -y \ diff --git a/api/README.md b/api/README.md index 92cd88a6d4..de2baee4c5 100644 --- a/api/README.md +++ b/api/README.md @@ -76,13 +76,13 @@ 1. Install dependencies for both the backend and the test environment ```bash - poetry install --with dev + poetry install -C api --with dev ``` 2. 
Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` ```bash - cd ../ poetry run -C api bash dev/pytest/pytest_all_tests.sh ``` + diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index a8a4170f67..3ac2c28c1f 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -10,7 +10,6 @@ from pydantic import ( PositiveInt, computed_field, ) -from pydantic_extra_types.timezone_name import TimeZoneName from pydantic_settings import BaseSettings from configs.feature.hosted_service import HostedServiceConfig @@ -110,7 +109,7 @@ class CodeExecutionSandboxConfig(BaseSettings): ) CODE_MAX_PRECISION: PositiveInt = Field( - description="mMaximum number of decimal places for floating-point numbers in code execution", + description="Maximum number of decimal places for floating-point numbers in code execution", default=20, ) @@ -217,6 +216,11 @@ class FileUploadConfig(BaseSettings): default=20, ) + WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a workflow upload operation", + default=10, + ) + class HttpConfig(BaseSettings): """ @@ -282,6 +286,26 @@ class HttpConfig(BaseSettings): default=None, ) + SSRF_DEFAULT_TIME_OUT: PositiveFloat = Field( + description="The default timeout period used for network requests (SSRF)", + default=5, + ) + + SSRF_DEFAULT_CONNECT_TIME_OUT: PositiveFloat = Field( + description="The default connect timeout period used for network requests (SSRF)", + default=5, + ) + + SSRF_DEFAULT_READ_TIME_OUT: PositiveFloat = Field( + description="The default read timeout period used for network requests (SSRF)", + default=5, + ) + + SSRF_DEFAULT_WRITE_TIME_OUT: PositiveFloat = Field( + description="The default write timeout period used for network requests (SSRF)", + default=5, + ) + RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field( description="Enable or disable the X-Forwarded-For Proxy Fix middleware from Werkzeug" " to respect X-* headers to redirect clients", @@ -340,9 +364,8 @@ class LoggingConfig(BaseSettings): default=None, ) - LOG_TZ: Optional[TimeZoneName] = Field( - description="Timezone for log timestamps. 
Allowed timezone values can be referred to IANA Time Zone Database," - " e.g., 'America/New_York')", + LOG_TZ: Optional[str] = Field( + description="Timezone for log timestamps (e.g., 'America/New_York')", default=None, ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 38bb804613..57cc805ebf 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -16,9 +16,11 @@ from configs.middleware.storage.supabase_storage_config import SupabaseStorageCo from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig +from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig from configs.middleware.vdb.chroma_config import ChromaConfig from configs.middleware.vdb.couchbase_config import CouchbaseConfig from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig +from configs.middleware.vdb.lindorm_config import LindormConfig from configs.middleware.vdb.milvus_config import MilvusConfig from configs.middleware.vdb.myscale_config import MyScaleConfig from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig @@ -258,6 +260,8 @@ class MiddlewareConfig( VikingDBConfig, UpstashConfig, TidbOnQdrantConfig, + LindormConfig, OceanBaseVectorConfig, + BaiduVectorDBConfig, ): pass diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py new file mode 100644 index 0000000000..0f6c652806 --- /dev/null +++ b/api/configs/middleware/vdb/lindorm_config.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings + + +class LindormConfig(BaseSettings): + """ + Lindorm configs + """ + + LINDORM_URL: Optional[str] = Field( + description="Lindorm url", + default=None, + ) + LINDORM_USERNAME: Optional[str] = Field( + description="Lindorm user", + default=None, + ) + LINDORM_PASSWORD: Optional[str] = Field( + description="Lindorm password", + default=None, + ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 3dc87e3058..b5cb1f06d9 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.10.2", + default="0.11.0", ) COMMIT_SHA: str = Field( diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py new file mode 100644 index 0000000000..c71f1ce5a3 --- /dev/null +++ b/api/controllers/common/errors.py @@ -0,0 +1,6 @@ +from werkzeug.exceptions import HTTPException + + +class FilenameNotExistsError(HTTPException): + code = 400 + description = "The specified filename does not exist." 
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py new file mode 100644 index 0000000000..79869916ed --- /dev/null +++ b/api/controllers/common/fields.py @@ -0,0 +1,24 @@ +from flask_restful import fields + +parameters__system_parameters = { + "image_file_size_limit": fields.Integer, + "video_file_size_limit": fields.Integer, + "audio_file_size_limit": fields.Integer, + "file_size_limit": fields.Integer, + "workflow_file_upload_limit": fields.Integer, +} + +parameters_fields = { + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(parameters__system_parameters), +} diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py new file mode 100644 index 0000000000..2bae203712 --- /dev/null +++ b/api/controllers/common/helpers.py @@ -0,0 +1,97 @@ +import mimetypes +import os +import re +import urllib.parse +from collections.abc import Mapping +from typing import Any +from uuid import uuid4 + +import httpx +from pydantic import BaseModel + +from configs import dify_config + + +class FileInfo(BaseModel): + filename: str + extension: str + mimetype: str + size: int + + +def guess_file_info_from_response(response: httpx.Response): + url = str(response.url) + # Try to extract filename from URL + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + + # If filename couldn't be extracted, use Content-Disposition header + if not filename: + content_disposition = response.headers.get("Content-Disposition") + if content_disposition: + filename_match = re.search(r'filename="?(.+)"?', content_disposition) + if filename_match: + filename = filename_match.group(1) + + # If still no filename, generate a unique one + if not filename: + unique_name = str(uuid4()) + filename = f"{unique_name}" + + # Guess MIME type from filename first, then URL + mimetype, _ = mimetypes.guess_type(filename) + if mimetype is None: + mimetype, _ = mimetypes.guess_type(url) + if mimetype is None: + # If guessing fails, use Content-Type from response headers + mimetype = response.headers.get("Content-Type", "application/octet-stream") + + extension = os.path.splitext(filename)[1] + + # Ensure filename has an extension + if not extension: + extension = mimetypes.guess_extension(mimetype) or ".bin" + filename = f"{filename}{extension}" + + return FileInfo( + filename=filename, + extension=extension, + mimetype=mimetype, + size=int(response.headers.get("Content-Length", -1)), + ) + + +def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]): + return { + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": 
False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": { + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + }, + } diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index c7282fcf14..8a5c2e5b8f 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -2,9 +2,21 @@ from flask import Blueprint from libs.external_api import ExternalApi +from .files import FileApi, FilePreviewApi, FileSupportTypeApi +from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi + bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi(bp) +# File +api.add_resource(FileApi, "/files/upload") +api.add_resource(FilePreviewApi, "/files//preview") +api.add_resource(FileSupportTypeApi, "/files/support-type") + +# Remote files +api.add_resource(RemoteFileInfoApi, "/remote-files/") +api.add_resource(RemoteFileUploadApi, "/remote-files/upload") + # Import other controllers from . import admin, apikey, extension, feature, ping, setup, version @@ -43,7 +55,6 @@ from .datasets import ( datasets_document, datasets_segments, external, - file, hit_testing, website, ) diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 35ac42a14c..9537708689 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -10,8 +10,7 @@ from models.dataset import Dataset from models.model import ApiToken, App from . 
import api -from .setup import setup_required -from .wraps import account_initialization_required +from .wraps import account_initialization_required, setup_required api_key_fields = { "id": fields.String, diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index e7346bdf1d..c228743fa5 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,8 +1,7 @@ from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index 51899da705..d433415894 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.helper import uuid_value from libs.login import login_required from models.model import AppMode diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 1ea1c82679..fd05cbc19b 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -6,8 +6,11 @@ from werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.app.error import NoFileUploadedError from controllers.console.datasets.error import TooManyFilesError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_redis import redis_client from fields.annotation_fields import ( annotation_fields, diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 1b46a3a7d3..36338cbd8a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -6,8 +6,11 @@ from werkzeug.exceptions import BadRequest, Forbidden, abort from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.ops.ops_trace_manager import OpsTraceManager from fields.app_fields import ( app_detail_fields, diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index c1ef05a488..112446613f 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -18,8 +18,7 @@ from controllers.console.app.error import ( UnsupportedAudioTypeError, ) from 
controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.login import login_required diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index d3296d3dff..9896fcaab8 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -15,8 +15,7 @@ from controllers.console.app.error import ( ProviderQuotaExceededError, ) from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index b60a424d98..7b78f622b9 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -10,8 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index 23b234dac9..d49f433ba1 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -4,8 +4,7 @@ from sqlalchemy.orm import Session from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.conversation_variable_fields import paginated_conversation_variable_fields from libs.login import login_required diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 7108759b0b..9c3cbe4e3e 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -10,8 +10,7 @@ from controllers.console.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.llm_generator.llm_generator import LLMGenerator from 
core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index fe06201982..b7a4c31a15 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -14,8 +14,11 @@ from controllers.console.app.error import ( ) from controllers.console.app.wraps import get_app_model from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index f5068a4cd8..8ba195f5a5 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -6,8 +6,7 @@ from flask_restful import Resource from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.agent.entities import AgentToolEntity from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 374bd2b815..47b58396a1 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.ops_service import OpsService diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 115a832da9..2f5645852f 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden, NotFound from constants.languages import supported_language from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.app_fields import app_site_fields from libs.login import login_required diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 3ef442812d..db5e282409 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse from 
controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index a8f601aeee..f7027fb226 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -9,8 +9,7 @@ import services from controllers.console import api from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from factories import variable_factory diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 629b7a8bf4..2940556f84 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -3,8 +3,7 @@ from flask_restful.inputs import int_range from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs.login import login_required from models import App diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 5824ead9c3..08ab61bbb9 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -3,8 +3,7 @@ from flask_restful.inputs import int_range from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.workflow_run_fields import ( advanced_chat_workflow_run_pagination_fields, workflow_run_detail_fields, diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index f46af0f1ca..6c7c73707b 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required diff --git a/api/controllers/console/auth/data_source_bearer_auth.py 
b/api/controllers/console/auth/data_source_bearer_auth.py index 50db6eebc1..465c44e9b6 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -7,8 +7,7 @@ from controllers.console.auth.error import ApiKeyAuthFailedError from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService -from ..setup import setup_required -from ..wraps import account_initialization_required +from ..wraps import account_initialization_required, setup_required class ApiKeyAuthDataSource(Resource): diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index fd31e5ccc3..3c3f45260a 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -11,8 +11,7 @@ from controllers.console import api from libs.login import login_required from libs.oauth_data_source import NotionOAuth -from ..setup import setup_required -from ..wraps import account_initialization_required +from ..wraps import account_initialization_required, setup_required def get_oauth_providers(): diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 7fea610610..735edae5f6 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -13,7 +13,7 @@ from controllers.console.auth.error import ( PasswordMismatchError, ) from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister -from controllers.console.setup import setup_required +from controllers.console.wraps import setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import email, extract_remote_ip diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 6c795f95b6..e2e8f84920 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -20,7 +20,7 @@ from controllers.console.error import ( NotAllowedCreateWorkspace, NotAllowedRegister, ) -from controllers.console.setup import setup_required +from controllers.console.wraps import setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.password import valid_password diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 9a1d914869..4b0c82ae6c 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -2,8 +2,7 @@ from flask_login import current_user from flask_restful import Resource, reqparse from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, only_edition_cloud +from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from libs.login import login_required from services.billing_service import BillingService diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index a2c9760782..ef1e87905a 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -7,8 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import NotFound from controllers.console import api -from 
controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.indexing_runner import IndexingRunner from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.notion_extractor import NotionExtractor diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 4f4d186edd..82163a32ee 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -10,8 +10,7 @@ from controllers.console import api from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType @@ -457,7 +456,7 @@ class DatasetIndexingEstimateApi(Resource): ) except LLMBadRequestError: raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -621,6 +620,7 @@ class DatasetRetrievalSettingApi(Resource): case ( VectorType.MILVUS | VectorType.RELYT + | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT @@ -641,6 +641,7 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.ELASTICSEARCH | VectorType.PGVECTOR | VectorType.TIDB_ON_QDRANT + | VectorType.LINDORM | VectorType.COUCHBASE ): return { @@ -683,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.ELASTICSEARCH | VectorType.COUCHBASE | VectorType.PGVECTOR + | VectorType.LINDORM ): return { "retrieval_method": [ diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index cdabac491e..8e784dc70b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -24,8 +24,11 @@ from controllers.console.datasets.error import ( InvalidActionError, InvalidMetadataError, ) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from core.errors.error import ( LLMBadRequestError, ModelCurrentlyNotSupportError, diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 08ea414288..5d8d664e41 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -11,11 +11,11 @@ import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError 
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError -from controllers.console.setup import setup_required from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_resource_check, + setup_required, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 2dc054cfbd..bc6e3687c1 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -6,8 +6,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services from controllers.console import api from controllers.console.datasets.error import DatasetNameDuplicateError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.dataset_fields import dataset_detail_fields from libs.login import login_required from services.dataset_service import DatasetService diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 5c9bcef84c..495f511275 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -2,8 +2,7 @@ from flask_restful import Resource from controllers.console import api from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index e80ce17c68..9127c8af45 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required from services.website_service import WebsiteService diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index ed6a99a017..e0630ca66c 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -62,3 +62,27 @@ class EmailSendIpLimitError(BaseHTTPException): error_code = "email_send_ip_limit" description = "Too many emails have been sent from this IP address recently. Please try again later." code = 429 + + +class FileTooLargeError(BaseHTTPException): + error_code = "file_too_large" + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = "unsupported_file_type" + description = "File type not allowed." + code = 415 + + +class TooManyFilesError(BaseHTTPException): + error_code = "too_many_files" + description = "Only one file is allowed." 
+ code = 400 + + +class NoFileUploadedError(BaseHTTPException): + error_code = "no_file_uploaded" + description = "Please upload your file." + code = 400 diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 7c7580e3c6..fee52248a6 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,6 +1,7 @@ -from flask_restful import fields, marshal_with +from flask_restful import marshal_with -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as controller_helpers from controllers.console import api from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource @@ -11,43 +12,14 @@ from services.app_service import AppService class AppParameterApi(InstalledAppResource): """Resource for app variables.""" - variable_fields = { - "key": fields.String, - "name": fields.String, - "description": fields.String, - "type": fields.String, - "default": fields.String, - "max_length": fields.Integer, - "options": fields.List(fields.String), - } - - system_parameters_fields = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - } - - parameters_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": fields.Raw, - "file_upload": fields.Raw, - "system_parameters": fields.Nested(system_parameters_fields), - } - - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: @@ -57,43 +29,16 @@ class AppParameterApi(InstalledAppResource): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() user_input_form = features_dict.get("user_input_form", []) - return { - "opening_statement": features_dict.get("opening_statement"), - "suggested_questions": features_dict.get("suggested_questions", []), - "suggested_questions_after_answer": features_dict.get( - "suggested_questions_after_answer", {"enabled": False} - ), - "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), - "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), - "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), - "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), - "more_like_this": features_dict.get("more_like_this", {"enabled": False}), - "user_input_form": user_input_form, - "sensitive_word_avoidance": features_dict.get( - "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} - ), - "file_upload": features_dict.get( - "file_upload", - { - "image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - 
"transfer_methods": ["remote_url", "local_file"], - } - }, - ), - "system_parameters": { - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - }, - } + return controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class ExploreAppMetaApi(InstalledAppResource): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 5d6a8bf152..4ac0aa497e 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -3,8 +3,7 @@ from flask_restful import Resource, marshal_with, reqparse from constants import HIDDEN_VALUE from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.api_based_extension_fields import api_based_extension_fields from libs.login import login_required from models.api_based_extension import APIBasedExtension diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index f0482f749d..70ab4ff865 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -5,8 +5,7 @@ from libs.login import login_required from services.feature_service import FeatureService from . import api -from .setup import setup_required -from .wraps import account_initialization_required, cloud_utm_record +from .wraps import account_initialization_required, cloud_utm_record, setup_required class FeatureApi(Resource): diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/files.py similarity index 57% rename from api/controllers/console/datasets/file.py rename to api/controllers/console/files.py index 17d2879875..946d3db37f 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/files.py @@ -1,25 +1,26 @@ -import urllib.parse - from flask import request from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse +from flask_restful import Resource, marshal_with import services from configs import dify_config from constants import DOCUMENT_EXTENSIONS -from controllers.console import api -from controllers.console.datasets.error import ( +from controllers.common.errors import FilenameNotExistsError +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) +from fields.file_fields import file_fields, upload_config_fields +from libs.login import login_required +from services.file_service import FileService + +from .error import ( FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.helper import ssrf_proxy -from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields -from libs.login import login_required -from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -36,6 +37,7 @@ class FileApi(Resource): "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, "video_file_size_limit": 
dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, }, 200 @setup_required @login_required @account_initialization_required @marshal_with(file_fields) @cloud_edition_billing_resource_check("documents") def post(self): - # get file from request file = request.files["file"] + source = request.form.get("source") - parser = reqparse.RequestParser() - parser.add_argument("source", type=str, required=False, location="args") - source = parser.parse_args().get("source") - - # check file if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() + + if not file.filename: + raise FilenameNotExistsError + + if source not in ("datasets", None): + source = None + try: - upload_file = FileService.upload_file(file=file, user=current_user, source=source) + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source=source, + ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: @@ -83,23 +93,3 @@ class FileSupportTypeApi(Resource): @account_initialization_required def get(self): return {"allowed_extensions": DOCUMENT_EXTENSIONS} - - -class RemoteFileInfoApi(Resource): - @marshal_with(remote_file_info_fields) - def get(self, url): - decoded_url = urllib.parse.unquote(url) - try: - response = ssrf_proxy.head(decoded_url) - return { - "file_type": response.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(response.headers.get("Content-Length", 0)), - } - except Exception as e: - return {"error": str(e)}, 400 - - -api.add_resource(FileApi, "/files/upload") -api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") -api.add_resource(FileSupportTypeApi, "/files/support-type") -api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py new file mode 100644 index 0000000000..9b899bef64 --- /dev/null +++ b/api/controllers/console/remote_files.py @@ -0,0 +1,81 @@ +import urllib.parse +from typing import cast + +import httpx +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse + +import services +from controllers.common import helpers +from core.file import helpers as file_helpers +from core.helper import ssrf_proxy +from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from models.account import Account +from services.file_service import FileService + +from .error import ( + FileTooLargeError, + UnsupportedFileTypeError, +) + + +class RemoteFileInfoApi(Resource): + @marshal_with(remote_file_info_fields) + def get(self, url): + decoded_url = urllib.parse.unquote(url) + resp = ssrf_proxy.head(decoded_url) + if resp.status_code != httpx.codes.OK: + # fall back to GET method + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return { + "file_type": resp.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(resp.headers.get("Content-Length", 0)), + } + + +class RemoteFileUploadApi(Resource): + @marshal_with(file_fields_with_signed_url) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("url", type=str, required=True, help="URL is required") + args = parser.parse_args() + + 
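
Both new remote-file controllers use the same probe pattern: try HEAD first, and when the server rejects it, retry with a short-timeout GET so the headers can still be read. Isolated below for clarity; the helper name is hypothetical, and ssrf_proxy is the project's hardened HTTP wrapper.

import httpx

from core.helper import ssrf_proxy


def probe_remote_file(url: str) -> httpx.Response:
    """Hypothetical helper mirroring the HEAD-then-GET fallback above."""
    resp = ssrf_proxy.head(url)
    if resp.status_code != httpx.codes.OK:
        # Some servers answer HEAD with 405/403; a bounded GET still
        # exposes Content-Type and Content-Length.
        resp = ssrf_proxy.get(url, timeout=3)
        resp.raise_for_status()
    return resp
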
url = args["url"] + + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3) + resp.raise_for_status() + + file_info = helpers.guess_file_info_from_response(resp) + + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + raise FileTooLargeError + + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + + try: + user = cast(Account, current_user) + upload_file = FileService.upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=user, + source_url=url, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at, + }, 201 diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 15a4af118b..e0b728d977 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,5 +1,3 @@ -from functools import wraps - from flask import request from flask_restful import Resource, reqparse @@ -10,7 +8,7 @@ from models.model import DifySetup from services.account_service import RegisterService, TenantService from . import api -from .error import AlreadySetupError, NotInitValidateError, NotSetupError +from .error import AlreadySetupError, NotInitValidateError from .init_validate import get_init_validate_status from .wraps import only_edition_self_hosted @@ -52,26 +50,10 @@ class SetupApi(Resource): return {"result": "success"}, 201 -def setup_required(view): - @wraps(view) - def decorated(*args, **kwargs): - # check setup - if not get_init_validate_status(): - raise NotInitValidateError() - - elif not get_setup_status(): - raise NotSetupError() - - return view(*args, **kwargs) - - return decorated - - def get_setup_status(): if dify_config.EDITION == "SELF_HOSTED": return DifySetup.query.first() - else: - return True + return True api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index de30547e93..ccd3293a62 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -4,8 +4,7 @@ from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from fields.tag_fields import tag_fields from libs.login import login_required from models.model import Tag diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index deda1a0d02..7dea8e554e 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -3,6 +3,7 @@ import logging import requests from flask_restful import Resource, reqparse +from packaging import version from configs import dify_config @@ -47,43 +48,15 @@ class VersionApi(Resource): def _has_new_version(*, latest_version: str, 
current_version: str) -> bool: - def parse_version(version: str) -> tuple: - # Split version into parts and pre-release suffix if any - parts = version.split("-") - version_parts = parts[0].split(".") - pre_release = parts[1] if len(parts) > 1 else None + try: + latest = version.parse(latest_version) + current = version.parse(current_version) - # Validate version format - if len(version_parts) != 3: - raise ValueError(f"Invalid version format: {version}") - - try: - # Convert version parts to integers - major, minor, patch = map(int, version_parts) - return (major, minor, patch, pre_release) - except ValueError: - raise ValueError(f"Invalid version format: {version}") - - latest = parse_version(latest_version) - current = parse_version(current_version) - - # Compare major, minor, and patch versions - for latest_part, current_part in zip(latest[:3], current[:3]): - if latest_part > current_part: - return True - elif latest_part < current_part: - return False - - # If versions are equal, check pre-release suffixes - if latest[3] is None and current[3] is not None: - return True - elif latest[3] is not None and current[3] is None: + # Compare versions + return latest > current + except version.InvalidVersion: + logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}") return False - elif latest[3] is not None and current[3] is not None: - # Simple string comparison for pre-release versions - return latest[3] > current[3] - - return False api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 97f5625726..aabc417759 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -8,14 +8,13 @@ from flask_restful import Resource, fields, marshal_with, reqparse from configs import dify_config from constants.languages import supported_language from controllers.console import api -from controllers.console.setup import setup_required from controllers.console.workspace.error import ( AccountAlreadyInitedError, CurrentPasswordIncorrectError, InvalidInvitationCodeError, RepeatPasswordNotMatchError, ) -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.member_fields import account_fields from libs.helper import TimestampField, timezone diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 771a866624..d2b2092b75 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_user, login_required diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 3e87bebf59..8f694c65e0 100644 --- a/api/controllers/console/workspace/members.py +++ 
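
The hand-rolled tuple parser is gone; packaging.version supplies PEP 440 ordering, including the pre-release handling the old string comparison got wrong. A few checks showing the semantics the new _has_new_version inherits:

from packaging import version

assert version.parse("0.11.0") > version.parse("0.10.2")
# A pre-release sorts below its final release ("-beta1" normalizes to "b1"):
assert version.parse("0.11.0") > version.parse("0.11.0-beta1")
assert version.parse("0.11.0-beta2") > version.parse("0.11.0-beta1")
# Malformed tags raise InvalidVersion, which the new code logs and
# treats as "no update available":
try:
    version.parse("not-a-version")
except version.InvalidVersion:
    pass
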
b/api/controllers/console/workspace/members.py @@ -4,8 +4,11 @@ from flask_restful import Resource, abort, marshal_with, reqparse import services from configs import dify_config from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields from libs.login import login_required diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 9e8a53bbfb..0e54126063 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -6,8 +6,7 @@ from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 3138a260b3..57443cc3b3 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -5,8 +5,7 @@ from flask_restful import Resource, reqparse from werkzeug.exceptions import Forbidden from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index aaa24d501c..daadb85d84 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required +from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import login_required diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 96f866fca2..76d76f6b58 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -6,6 +6,7 @@ from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqpa from werkzeug.exceptions import Unauthorized import services +from controllers.common.errors import FilenameNotExistsError from controllers.console import api from 
controllers.console.admin import admin_required from controllers.console.datasets.error import ( @@ -15,8 +16,11 @@ from controllers.console.datasets.error import ( UnsupportedFileTypeError, ) from controllers.console.error import AccountNotLinkTenantError -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_resource_check, + setup_required, +) from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required @@ -193,12 +197,20 @@ class WebappLogoWorkspaceApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + if not file.filename: + raise FilenameNotExistsError + extension = file.filename.split(".")[-1] if extension.lower() not in {"svg", "png"}: raise UnsupportedFileTypeError() try: - upload_file = FileService.upload_file(file=file, user=current_user) + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 46223d104f..9f294cb93c 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,4 +1,5 @@ import json +import os from functools import wraps from flask import abort, request @@ -6,9 +7,12 @@ from flask_login import current_user from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError +from models.model import DifySetup from services.feature_service import FeatureService from services.operation_service import OperationService +from .error import NotInitValidateError, NotSetupError + def account_initialization_required(view): @wraps(view) @@ -124,3 +128,17 @@ def cloud_utm_record(view): return view(*args, **kwargs) return decorated + + +def setup_required(view): + @wraps(view) + def decorated(*args, **kwargs): + # check setup + if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first(): + raise NotInitValidateError() + elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first(): + raise NotSetupError() + + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index fee840b30d..99d32af593 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -1,6 +1,6 @@ from flask_restful import Resource, reqparse -from controllers.console.setup import setup_required +from controllers.console.wraps import setup_required from controllers.inner_api import api from controllers.inner_api.wraps import inner_api_only from events.tenant_event import tenant_was_created diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 9a4cdc26cd..88b13faa52 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,6 +1,7 @@ -from flask_restful import Resource, fields, marshal_with +from flask_restful import Resource, marshal_with -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as 
controller_helpers from controllers.service_api import api from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token @@ -11,40 +12,8 @@ from services.app_service import AppService class AppParameterApi(Resource): """Resource for app variables.""" - variable_fields = { - "key": fields.String, - "name": fields.String, - "description": fields.String, - "type": fields.String, - "default": fields.String, - "max_length": fields.Integer, - "options": fields.List(fields.String), - } - - system_parameters_fields = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - } - - parameters_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": fields.Raw, - "file_upload": fields.Raw, - "system_parameters": fields.Nested(system_parameters_fields), - } - @validate_app_token - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: @@ -56,43 +25,16 @@ class AppParameterApi(Resource): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() user_input_form = features_dict.get("user_input_form", []) - return { - "opening_statement": features_dict.get("opening_statement"), - "suggested_questions": features_dict.get("suggested_questions", []), - "suggested_questions_after_answer": features_dict.get( - "suggested_questions_after_answer", {"enabled": False} - ), - "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), - "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), - "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), - "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), - "more_like_this": features_dict.get("more_like_this", {"enabled": False}), - "user_input_form": user_input_form, - "sensitive_word_avoidance": features_dict.get( - "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} - ), - "file_upload": features_dict.get( - "file_upload", - { - "image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"], - } - }, - ), - "system_parameters": { - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - }, - } + return controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class AppMetaApi(Resource): diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index e0a772eb31..b0fd8e65ef 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -2,6 +2,7 @@ from flask import 
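
The three copies of this response-building dict (console explore, service API, and web controllers) now collapse into controllers.common.helpers.get_parameters_from_feature_dict, which the diff itself never shows. A plausible sketch of that helper, reconstructed from the literal dicts removed above; the exact defaults and signature are assumptions, not the canonical implementation.

from collections.abc import Mapping
from typing import Any

from configs import dify_config


def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list) -> dict:
    # Mirrors the dict literals deleted from the three controllers.
    return {
        "opening_statement": features_dict.get("opening_statement"),
        "suggested_questions": features_dict.get("suggested_questions", []),
        "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
        "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
        "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
        "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
        "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
        "more_like_this": features_dict.get("more_like_this", {"enabled": False}),
        "user_input_form": user_input_form,
        "sensitive_word_avoidance": features_dict.get(
            "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
        ),
        "file_upload": features_dict.get(
            "file_upload",
            {
                "image": {
                    "enabled": False,
                    "number_limits": 3,
                    "detail": "high",
                    "transfer_methods": ["remote_url", "local_file"],
                }
            },
        ),
        "system_parameters": {
            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
        },
    }
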
request from flask_restful import Resource, marshal_with import services +from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api from controllers.service_api.app.error import ( FileTooLargeError, @@ -31,8 +32,16 @@ class FileApi(Resource): if len(request.files) > 1: raise TooManyFilesError() + if not file.filename: + raise FilenameNotExistsError + try: - upload_file = FileService.upload_file(file, end_user) + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=end_user, + ) except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 0a0a38c4c6..5c3fc7b241 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,6 +6,7 @@ from sqlalchemy import desc from werkzeug.exceptions import NotFound import services.dataset_service +from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.dataset.error import ( @@ -55,7 +56,12 @@ class DocumentAddByTextApi(DatasetApiResource): if not dataset.indexing_technique and not args["indexing_technique"]: raise ValueError("indexing_technique is required.") - upload_file = FileService.upload_text(args.get("text"), args.get("name")) + text = args.get("text") + name = args.get("name") + if text is None or name is None: + raise ValueError("Both 'text' and 'name' must be non-null values.") + + upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, @@ -104,7 +110,11 @@ class DocumentUpdateByTextApi(DatasetApiResource): raise ValueError("Dataset is not exist.") if args["text"]: - upload_file = FileService.upload_text(args.get("text"), args.get("name")) + text = args.get("text") + name = args.get("name") + if text is None or name is None: + raise ValueError("Both text and name must be strings.") + upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, @@ -163,7 +173,16 @@ class DocumentAddByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() - upload_file = FileService.upload_file(file, current_user) + if not file.filename: + raise FilenameNotExistsError + + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} args["data_source"] = data_source # validate args @@ -212,7 +231,16 @@ class DocumentUpdateByFileApi(DatasetApiResource): if len(request.files) > 1: raise TooManyFilesError() - upload_file = FileService.upload_file(file, current_user) + if not file.filename: + raise FilenameNotExistsError + + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=current_user, + source="datasets", + ) data_source 
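
Every call site in this diff switches from handing FileService the Werkzeug FileStorage object to the new keyword-only shape of FileService.upload_file. The shared pattern, condensed; current_user stands in for whichever user object the controller has.

from flask import request
from flask_login import current_user

from controllers.common.errors import FilenameNotExistsError
from services.file_service import FileService

file = request.files["file"]
if not file.filename:
    raise FilenameNotExistsError

upload_file = FileService.upload_file(
    filename=file.filename,
    content=file.read(),  # raw bytes instead of the FileStorage wrapper
    mimetype=file.mimetype,
    user=current_user,
    source="datasets",  # dataset endpoints pin this; others pass None
)
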
= {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} args["data_source"] = data_source # validate args @@ -331,10 +359,26 @@ class DocumentIndexingStatusApi(DatasetApiResource): return data -api.add_resource(DocumentAddByTextApi, "/datasets/<uuid:dataset_id>/document/create_by_text") -api.add_resource(DocumentAddByFileApi, "/datasets/<uuid:dataset_id>/document/create_by_file") -api.add_resource(DocumentUpdateByTextApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text") -api.add_resource(DocumentUpdateByFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file") +api.add_resource( + DocumentAddByTextApi, + "/datasets/<uuid:dataset_id>/document/create_by_text", + "/datasets/<uuid:dataset_id>/document/create-by-text", +) +api.add_resource( + DocumentAddByFileApi, + "/datasets/<uuid:dataset_id>/document/create_by_file", + "/datasets/<uuid:dataset_id>/document/create-by-file", +) +api.add_resource( + DocumentUpdateByTextApi, + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text", + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text", +) +api.add_resource( + DocumentUpdateByFileApi, + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file", + "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file", +) api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents") api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status") diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 9c9a4302c9..465f71bf03 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -14,4 +14,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): return self.perform_hit_testing(dataset, args) -api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") +api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve") diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 630b9468a7..50a04a6254 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -2,8 +2,17 @@ from flask import Blueprint from libs.external_api import ExternalApi +from .files import FileApi +from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi + bp = Blueprint("web", __name__, url_prefix="/api") api = ExternalApi(bp) +# Files +api.add_resource(FileApi, "/files/upload") -from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow +# Remote files +api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") +api.add_resource(RemoteFileUploadApi, "/remote-files/upload") + +from . 
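
flask_restful lets one resource own several URL rules, which is how the kebab-case aliases above are added without breaking the legacy snake_case routes. A self-contained demonstration; the resource and paths here are made up.

from flask import Flask
from flask_restful import Api, Resource

app = Flask(__name__)
api = Api(app)


class PingApi(Resource):
    def get(self):
        return {"result": "pong"}


# Both spellings dispatch to the same handler, mirroring the
# create_by_text / create-by-text pairs registered above.
api.add_resource(PingApi, "/ping_check", "/ping-check")
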
import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 974d2cff94..cc8255ccf4 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,6 +1,7 @@ -from flask_restful import fields, marshal_with +from flask_restful import marshal_with -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as controller_helpers from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource @@ -11,39 +12,7 @@ from services.app_service import AppService class AppParameterApi(WebApiResource): """Resource for app variables.""" - variable_fields = { - "key": fields.String, - "name": fields.String, - "description": fields.String, - "type": fields.String, - "default": fields.String, - "max_length": fields.Integer, - "options": fields.List(fields.String), - } - - system_parameters_fields = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - } - - parameters_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": fields.Raw, - "file_upload": fields.Raw, - "system_parameters": fields.Nested(system_parameters_fields), - } - - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: @@ -55,43 +24,16 @@ class AppParameterApi(WebApiResource): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() user_input_form = features_dict.get("user_input_form", []) - return { - "opening_statement": features_dict.get("opening_statement"), - "suggested_questions": features_dict.get("suggested_questions", []), - "suggested_questions_after_answer": features_dict.get( - "suggested_questions_after_answer", {"enabled": False} - ), - "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), - "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), - "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), - "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), - "more_like_this": features_dict.get("more_like_this", {"enabled": False}), - "user_input_form": user_input_form, - "sensitive_word_avoidance": features_dict.get( - "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} - ), - "file_upload": features_dict.get( - "file_upload", - { - "image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"], - } - }, - ), - "system_parameters": { - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - 
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - }, - } + return controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class AppMeta(WebApiResource): diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py deleted file mode 100644 index 6eeaa0e3f0..0000000000 --- a/api/controllers/web/file.py +++ /dev/null @@ -1,56 +0,0 @@ -import urllib.parse - -from flask import request -from flask_restful import marshal_with, reqparse - -import services -from controllers.web import api -from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError -from controllers.web.wraps import WebApiResource -from core.helper import ssrf_proxy -from fields.file_fields import file_fields, remote_file_info_fields -from services.file_service import FileService - - -class FileApi(WebApiResource): - @marshal_with(file_fields) - def post(self, app_model, end_user): - # get file from request - file = request.files["file"] - - parser = reqparse.RequestParser() - parser.add_argument("source", type=str, required=False, location="args") - source = parser.parse_args().get("source") - - # check file - if "file" not in request.files: - raise NoFileUploadedError() - - if len(request.files) > 1: - raise TooManyFilesError() - try: - upload_file = FileService.upload_file(file=file, user=end_user, source=source) - except services.errors.file.FileTooLargeError as file_too_large_error: - raise FileTooLargeError(file_too_large_error.description) - except services.errors.file.UnsupportedFileTypeError: - raise UnsupportedFileTypeError() - - return upload_file, 201 - - -class RemoteFileInfoApi(WebApiResource): - @marshal_with(remote_file_info_fields) - def get(self, url): - decoded_url = urllib.parse.unquote(url) - try: - response = ssrf_proxy.head(decoded_url) - return { - "file_type": response.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(response.headers.get("Content-Length", -1)), - } - except Exception as e: - return {"error": str(e)}, 400 - - -api.add_resource(FileApi, "/files/upload") -api.add_resource(RemoteFileInfoApi, "/remote-files/") diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py new file mode 100644 index 0000000000..a282fc63a8 --- /dev/null +++ b/api/controllers/web/files.py @@ -0,0 +1,43 @@ +from flask import request +from flask_restful import marshal_with + +import services +from controllers.common.errors import FilenameNotExistsError +from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError +from controllers.web.wraps import WebApiResource +from fields.file_fields import file_fields +from services.file_service import FileService + + +class FileApi(WebApiResource): + @marshal_with(file_fields) + def post(self, app_model, end_user): + file = request.files["file"] + source = request.form.get("source") + + if "file" not in request.files: + raise NoFileUploadedError() + + if len(request.files) > 1: + raise TooManyFilesError() + + if not file.filename: + raise FilenameNotExistsError + + if source not in ("datasets", None): + source = None + + try: + upload_file = FileService.upload_file( + filename=file.filename, + content=file.read(), + mimetype=file.mimetype, + user=end_user, + source=source, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except 
services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return upload_file, 201 diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py new file mode 100644 index 0000000000..d6b8eb2855 --- /dev/null +++ b/api/controllers/web/remote_files.py @@ -0,0 +1,75 @@ +import urllib.parse + +import httpx +from flask_restful import marshal_with, reqparse + +import services +from controllers.common import helpers +from controllers.web.wraps import WebApiResource +from core.file import helpers as file_helpers +from core.helper import ssrf_proxy +from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from services.file_service import FileService + +from .error import FileTooLargeError, UnsupportedFileTypeError + + +class RemoteFileInfoApi(WebApiResource): + @marshal_with(remote_file_info_fields) + def get(self, app_model, end_user, url): + decoded_url = urllib.parse.unquote(url) + resp = ssrf_proxy.head(decoded_url) + if resp.status_code != httpx.codes.OK: + # fall back to GET method + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return { + "file_type": resp.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(resp.headers.get("Content-Length", -1)), + } + + +class RemoteFileUploadApi(WebApiResource): + @marshal_with(file_fields_with_signed_url) + def post(self, app_model, end_user): # Add app_model and end_user parameters + parser = reqparse.RequestParser() + parser.add_argument("url", type=str, required=True, help="URL is required") + args = parser.parse_args() + + url = args["url"] + + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3) + resp.raise_for_status() + + file_info = helpers.guess_file_info_from_response(resp) + + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + raise FileTooLargeError + + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + + try: + upload_file = FileService.upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=end_user, + source_url=url, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError + + return { + "id": upload_file.id, + "name": upload_file.name, + "size": upload_file.size, + "extension": upload_file.extension, + "url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + "mime_type": upload_file.mime_type, + "created_by": upload_file.created_by, + "created_at": upload_file.created_at, + }, 201 diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 42beec2535..d0f75d0b75 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,8 +1,7 @@ from collections.abc import Mapping from typing import Any -from core.file.models import FileExtraConfig -from models import FileUploadConfig +from core.file import FileExtraConfig class FileUploadConfigManager: @@ -43,6 +42,6 @@ class FileUploadConfigManager: if not config.get("file_upload"): config["file_upload"] = {} else: - FileUploadConfig.model_validate(config["file_upload"]) + FileExtraConfig.model_validate(config["file_upload"]) return 
config, ["file_upload"] diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index e4cb3f8527..1fc7ffe2c7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -20,6 +20,7 @@ from core.app.entities.queue_entities import ( QueueIterationStartEvent, QueueMessageReplaceEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if response: yield response - elif isinstance(event, QueueNodeFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 2707ada6cb..d8e38476c7 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -22,7 +22,10 @@ class BaseAppGenerator: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} + user_inputs = { + var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var) + for var in variables + } user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} # Convert files in inputs to File entity_dictionary = {item.variable: item for item in app_config.variables} @@ -74,50 +77,66 @@ class BaseAppGenerator: return user_inputs - def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"): - user_input_value = inputs.get(var.variable) - if not user_input_value: - if var.required: - raise ValueError(f"{var.variable} is required in input form") - else: - return None + def _validate_inputs( + self, + *, + variable_entity: "VariableEntity", + value: Any, + ): + if value is None: + if variable_entity.required: + raise ValueError(f"{variable_entity.variable} is required in input form") + return value - if var.type in { + if variable_entity.type in { VariableEntityType.TEXT_INPUT, VariableEntityType.SELECT, VariableEntityType.PARAGRAPH, - } and not isinstance(user_input_value, str): - raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") - if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): + } and not isinstance(value, str): + raise ValueError( + f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string" + ) + + if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str): # may raise ValueError if user_input_value is not a valid number try: - if "." in user_input_value: - return float(user_input_value) + if "." 
in value: + return float(value) else: - return int(user_input_value) + return int(value) except ValueError: - raise ValueError(f"{var.variable} in input form must be a valid number") - if var.type == VariableEntityType.SELECT: - options = var.options - if user_input_value not in options: - raise ValueError(f"{var.variable} in input form must be one of the following: {options}") - elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: - if var.max_length and len(user_input_value) > var.max_length: - raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") - elif var.type == VariableEntityType.FILE: - if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File): - raise ValueError(f"{var.variable} in input form must be a file") - elif var.type == VariableEntityType.FILE_LIST: - if not ( - isinstance(user_input_value, list) - and ( - all(isinstance(item, dict) for item in user_input_value) - or all(isinstance(item, File) for item in user_input_value) - ) - ): - raise ValueError(f"{var.variable} in input form must be a list of files") + raise ValueError(f"{variable_entity.variable} in input form must be a valid number") - return user_input_value + match variable_entity.type: + case VariableEntityType.SELECT: + if value not in variable_entity.options: + raise ValueError( + f"{variable_entity.variable} in input form must be one of the following: " + f"{variable_entity.options}" + ) + case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH: + if variable_entity.max_length and len(value) > variable_entity.max_length: + raise ValueError( + f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} " + "characters" + ) + case VariableEntityType.FILE: + if not isinstance(value, dict) and not isinstance(value, File): + raise ValueError(f"{variable_entity.variable} in input form must be a file") + case VariableEntityType.FILE_LIST: + # if number of files exceeds the limit, raise ValueError + if not ( + isinstance(value, list) + and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value)) + ): + raise ValueError(f"{variable_entity.variable} in input form must be a list of files") + + if variable_entity.max_length and len(value) > variable_entity.max_length: + raise ValueError( + f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files" + ) + + return value def _sanitize_value(self, value: Any) -> Any: if isinstance(value, str): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 419a5da806..d119d94a61 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -16,6 +16,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if response: yield response - elif isinstance(event, QueueNodeFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( diff --git a/api/core/app/apps/workflow_app_runner.py 
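
One behavior the renamed _validate_inputs deliberately keeps: NUMBER values that arrive as strings are coerced before the match dispatch. That branch in isolation, as a standalone stand-in rather than the real method:

def coerce_number(value: str) -> int | float:
    """Stand-in for the NUMBER branch: '42' -> 42, '4.2' -> 4.2."""
    if "." in value:
        return float(value)
    return int(value)


assert coerce_number("42") == 42
assert coerce_number("4.2") == 4.2
try:
    coerce_number("forty-two")
except ValueError:
    pass  # re-raised as "... must be a valid number" in the real method
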
b/api/core/app/apps/workflow_app_runner.py index ca23bbdd47..9a01e8a253 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -9,6 +9,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import ( IterationRunNextEvent, IterationRunStartedEvent, IterationRunSucceededEvent, + NodeInIterationFailedEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, NodeRunStartedEvent, @@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner): node_run_index=event.route_node_state.index, predecessor_node_id=event.predecessor_node_id, in_iteration_id=event.in_iteration_id, + parallel_mode_run_id=event.parallel_mode_run_id, ) ) elif isinstance(event, NodeRunSucceededEvent): @@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner): error=event.route_node_state.node_run_result.error if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error else "Unknown error", + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, in_iteration_id=event.in_iteration_id, ) ) + elif isinstance(event, NodeInIterationFailedEvent): + self._publish_event( + QueueNodeInIterationFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + error=event.error, + ) + ) elif isinstance(event, NodeRunStreamChunkEvent): self._publish_event( QueueTextChunkEvent( @@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner): index=event.index, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, output=event.pre_iteration_output, + parallel_mode_run_id=event.parallel_mode_run_id, ) ) elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index bc43baf8a5..f1542ec5d8 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent): """parent parallel id if node is in parallel""" parent_parallel_start_node_id: Optional[str] = None """parent parallel start node id if node is in parallel""" - + parallel_mode_run_id: Optional[str] = None + """iteration run in parallel mode run id""" node_run_index: int output: Optional[Any] = None # output for the current iteration @@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent): in_iteration_id: Optional[str] = None """iteration id if node is in 
iteration""" start_at: datetime + parallel_mode_run_id: Optional[str] = None + """iteration run in parallel mode run id""" class QueueNodeSucceededEvent(AppQueueEvent): @@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent): error: Optional[str] = None +class QueueNodeInIterationFailedEvent(AppQueueEvent): + """ + QueueNodeInIterationFailedEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_FAILED + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + + error: str + + class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity @@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent): inputs: Optional[dict[str, Any]] = None process_data: Optional[dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 4b5f4716ed..7e9aad54be 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse): parent_parallel_id: Optional[str] = None parent_parallel_start_node_id: Optional[str] = None iteration_id: Optional[str] = None + parallel_run_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse): extras: dict = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None + parallel_mode_run_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 2abee5bef5..b89edf9079 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -12,6 +12,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData @@ -251,6 +253,12 @@ class WorkflowCycleManage: workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value workflow_node_execution.created_by_role = 
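
Both task pipelines route the new event through the existing failure path via isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent). Since Python 3.10, isinstance accepts a union type directly; a minimal illustration with stand-in classes:

class NodeFailed: ...


class NodeInIterationFailed: ...


event = NodeInIterationFailed()
# Equivalent to isinstance(event, (NodeFailed, NodeInIterationFailed)).
if isinstance(event, NodeFailed | NodeInIterationFailed):
    print("hand off to the shared failure handler")
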
workflow_run.created_by_role workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) session.add(workflow_node_execution) @@ -305,7 +313,9 @@ class WorkflowCycleManage: return workflow_node_execution - def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_failed( + self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent + ) -> WorkflowNodeExecution: """ Workflow node execution failed :param event: queue node failed event @@ -318,16 +328,19 @@ class WorkflowCycleManage: outputs = WorkflowEntry.handle_special_values(event.outputs) finished_at = datetime.now(timezone.utc).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - + execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( { WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, + WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.finished_at: finished_at, WorkflowNodeExecution.elapsed_time: elapsed_time, + WorkflowNodeExecution.execution_metadata: execution_metadata, } ) @@ -342,6 +355,7 @@ class WorkflowCycleManage: workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution.execution_metadata = execution_metadata self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) @@ -448,6 +462,7 @@ class WorkflowCycleManage: parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, + parallel_run_id=event.parallel_mode_run_id, ), ) @@ -464,7 +479,7 @@ class WorkflowCycleManage: def _workflow_node_finish_to_stream_response( self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: @@ -608,6 +623,7 @@ class WorkflowCycleManage: extras={}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parallel_mode_run_id=event.parallel_mode_run_id, ), ) @@ -633,7 +649,9 @@ class WorkflowCycleManage: created_at=int(time.time()), extras={}, inputs=event.inputs or {}, - status=WorkflowNodeExecutionStatus.SUCCEEDED, + status=WorkflowNodeExecutionStatus.SUCCEEDED + if event.error is None + else WorkflowNodeExecutionStatus.FAILED, error=None, elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, diff --git 
a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 6793e41978..df812ca83f 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -12,6 +12,10 @@ SSRF_PROXY_ALL_URL = os.getenv("SSRF_PROXY_ALL_URL", "") SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "") SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "") SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3")) +SSRF_DEFAULT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_TIME_OUT", "5")) +SSRF_DEFAULT_CONNECT_TIME_OUT = float(os.getenv("SSRF_DEFAULT_CONNECT_TIME_OUT", "5")) +SSRF_DEFAULT_READ_TIME_OUT = float(os.getenv("SSRF_DEFAULT_READ_TIME_OUT", "5")) +SSRF_DEFAULT_WRITE_TIME_OUT = float(os.getenv("SSRF_DEFAULT_WRITE_TIME_OUT", "5")) proxy_mounts = ( { @@ -32,6 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "follow_redirects" not in kwargs: kwargs["follow_redirects"] = allow_redirects + if "timeout" not in kwargs: + kwargs["timeout"] = httpx.Timeout( + SSRF_DEFAULT_TIME_OUT, + connect=SSRF_DEFAULT_CONNECT_TIME_OUT, + read=SSRF_DEFAULT_READ_TIME_OUT, + write=SSRF_DEFAULT_WRITE_TIME_OUT, + ) + retries = 0 while retries <= max_retries: try: diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 8df26172b7..e2a94073cf 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -17,6 +17,7 @@ from core.errors.error import ProviderTokenNotInitError from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting @@ -597,26 +598,9 @@ class IndexingRunner: rules = DatasetProcessRule.AUTOMATIC_RULES else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} + document_text = CleanProcessor.clean(text, {"rules": rules}) - if "pre_processing_rules" in rules: - pre_processing_rules = rules["pre_processing_rules"] - for pre_processing_rule in pre_processing_rules: - if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: - # Remove extra spaces - pattern = r"\n{3,}" - text = re.sub(pattern, "\n\n", text) - pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}" - text = re.sub(pattern, " ", text) - elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: - # Remove email - pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" - text = re.sub(pattern, "", text) - - # Remove URL - pattern = r"https?://[^\s]+" - text = re.sub(pattern, "", text) - - return text + return document_text @staticmethod def format_split_text(text): diff --git a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml index aca9456313..b7b28a70d4 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml @@ -1,3 +1,4 @@ +- claude-3-5-haiku-20241022 - claude-3-5-sonnet-20241022 - claude-3-5-sonnet-20240620 - claude-3-haiku-20240307 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml 
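
ssrf_proxy now builds a default httpx.Timeout from four environment variables, so callers that don't pass an explicit timeout get a bounded, tunable one. How the knobs map onto a request; the values shown are the shipped defaults:

import os

import httpx

timeout = httpx.Timeout(
    float(os.getenv("SSRF_DEFAULT_TIME_OUT", "5")),  # fallback for unset phases
    connect=float(os.getenv("SSRF_DEFAULT_CONNECT_TIME_OUT", "5")),
    read=float(os.getenv("SSRF_DEFAULT_READ_TIME_OUT", "5")),
    write=float(os.getenv("SSRF_DEFAULT_WRITE_TIME_OUT", "5")),
)
resp = httpx.get("https://example.com", timeout=timeout)
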
b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml new file mode 100644 index 0000000000..892146f6a5 --- /dev/null +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml @@ -0,0 +1,38 @@ +model: claude-3-5-haiku-20241022 +label: + en_US: claude-3-5-haiku-20241022 +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '1.00' + output: '5.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 24657167dd..e61a9e0474 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -37,6 +37,17 @@ def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: return rule +def _get_o1_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule: + rule = ParameterRule( + name="max_completion_tokens", + **PARAMETER_RULE_TEMPLATE[DefaultParameterName.MAX_TOKENS], + ) + rule.default = default + rule.min = min_val + rule.max = max_val + return rule + + class AzureBaseModel(BaseModel): base_model_name: str entity: AIModelEntity @@ -1098,14 +1109,6 @@ LLM_BASE_MODELS = [ ModelPropertyKey.CONTEXT_SIZE: 128000, }, parameter_rules=[ - ParameterRule( - name="temperature", - **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], - ), - ParameterRule( - name="top_p", - **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], - ), ParameterRule( name="response_format", label=I18nObject(zh_Hans="回复格式", en_US="response_format"), @@ -1116,7 +1119,7 @@ LLM_BASE_MODELS = [ required=False, options=["text", "json_object"], ), - _get_max_tokens(default=512, min_val=1, max_val=32768), + _get_o1_max_tokens(default=512, min_val=1, max_val=32768), ], pricing=PriceConfig( input=15.00, @@ -1143,14 +1146,6 @@ LLM_BASE_MODELS = [ ModelPropertyKey.CONTEXT_SIZE: 128000, }, parameter_rules=[ - ParameterRule( - name="temperature", - **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE], - ), - ParameterRule( - name="top_p", - **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P], - ), ParameterRule( name="response_format", label=I18nObject(zh_Hans="回复格式", en_US="response_format"), @@ -1161,7 +1156,7 @@ LLM_BASE_MODELS = [ required=False, options=["text", "json_object"], ), - _get_max_tokens(default=512, min_val=1, max_val=65536), + _get_o1_max_tokens(default=512, min_val=1, max_val=65536), ], pricing=PriceConfig( input=3.00, diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml new file mode 100644 index 0000000000..9d693dcd48 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml @@ -0,0 +1,60 @@ +model: anthropic.claude-3-5-haiku-20241022-v1:0 +label: + en_US: Claude 3.5 
Haiku +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 8192 + min: 1 + max: 8192 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.005' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml new file mode 100644 index 0000000000..9781965555 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml @@ -0,0 +1,60 @@ +model: us.anthropic.claude-3-5-haiku-20241022-v1:0 +label: + en_US: Claude 3.5 Haiku(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. 
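+      # As with the base Haiku model above, adjust either temperature or top_p, but not both.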
+ - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.005' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py index ca67594ce4..14aa811905 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py +++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py @@ -1,6 +1,7 @@ import logging -from core.model_runtime.entities.model_entities import ModelType +import requests + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -16,8 +17,18 @@ class GiteeAIProvider(ModelProvider): :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. """ try: - model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials) + api_key = credentials.get("api_key") + if not api_key: + raise CredentialsValidateFailedError("Credentials validation failed: api_key not given") + + # send a get request to validate the credentials + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300)) + + if response.status_code != 200: + raise CredentialsValidateFailedError( + f"Credentials validation failed with status code {response.status_code}" + ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png new file mode 100644 index 0000000000..dfe8e78049 Binary files /dev/null and b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg new file mode 100644 index 0000000000..bb23bffcf1 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png new file mode 100644 index 0000000000..b154821db9 Binary files /dev/null and b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.png differ diff --git a/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg 
b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg new file mode 100644 index 0000000000..c5c608cd7c --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/_assets/icon_s_en.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.py b/api/core/model_runtime/model_providers/gpustack/gpustack.py new file mode 100644 index 0000000000..321100167e --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/gpustack.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class GPUStackProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/gpustack/gpustack.yaml b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml new file mode 100644 index 0000000000..ee4a3c159a --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/gpustack.yaml @@ -0,0 +1,120 @@ +provider: gpustack +label: + en_US: GPUStack +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: endpoint_url + label: + zh_Hans: 服务器地址 + en_US: Server URL + type: text-input + required: true + placeholder: + zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100 + en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100 + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 输入您的 API Key + en_US: Enter your API Key + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择补全类型 + en_US: Select completion type + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: "8192" + placeholder: + zh_Hans: 输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens_to_sample + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: "8192" + type: text-input + - variable: function_calling_type + show_on: + - variable: __model_type + value: llm + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: function_call + label: + en_US: Function Call + zh_Hans: Function Call + - value: tool_call + label: + en_US: Tool Call + zh_Hans: Tool Call + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - variable: vision_support + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: Vision 支持 + en_US: Vision Support + type: select + required: false + default: no_support + options: + - value: support + label: + en_US: Support + zh_Hans: 支持 + - value: no_support + label: + en_US: Not Support + zh_Hans: 不支持 diff --git a/api/core/model_runtime/model_providers/gpustack/llm/__init__.py b/api/core/model_runtime/model_providers/gpustack/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff 
--git a/api/core/model_runtime/model_providers/gpustack/llm/llm.py b/api/core/model_runtime/model_providers/gpustack/llm/llm.py new file mode 100644 index 0000000000..ce6780b6a7 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/llm/llm.py @@ -0,0 +1,45 @@ +from collections.abc import Generator + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import ( + OAIAPICompatLargeLanguageModel, +) + + +class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: list[PromptMessageTool] | None = None, + stop: list[str] | None = None, + stream: bool = True, + user: str | None = None, + ) -> LLMResult | Generator: + return super()._invoke( + model, + credentials, + prompt_messages, + model_parameters, + tools, + stop, + stream, + user, + ) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") + credentials["mode"] = "chat" diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py b/api/core/model_runtime/model_providers/gpustack/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py new file mode 100644 index 0000000000..5ea7532564 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py @@ -0,0 +1,146 @@ +from json import dumps +from typing import Optional + +import httpx +from requests import post +from yarl import URL + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, +) +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class GPUStackRerankModel(RerankModel): + """ + Model class for GPUStack rerank model. 
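+    Reranking requests are POSTed to the server's v1/rerank endpoint, and results are filtered client-side against score_threshold before being returned.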
+ """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + endpoint_url = credentials["endpoint_url"] + headers = { + "Authorization": f"Bearer {credentials.get('api_key')}", + "Content-Type": "application/json", + } + + data = {"model": model, "query": query, "documents": docs, "top_n": top_n} + + try: + response = post( + str(URL(endpoint_url) / "v1" / "rerank"), + headers=headers, + data=dumps(data), + timeout=10, + ) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results["results"]: + index = result["index"] + if "document" in result: + text = result["document"]["text"] + else: + text = docs[index] + + rerank_document = RerankDocument( + index=index, + text=text, + score=result["relevance_score"], + ) + + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py new file mode 100644 index 0000000000..eb324491a2 --- /dev/null +++ b/api/core/model_runtime/model_providers/gpustack/text_embedding/text_embedding.py @@ -0,0 +1,35 @@ +from typing import Optional + +from yarl import URL + +from core.entities.embedding_type import EmbeddingInputType +from core.model_runtime.entities.text_embedding_entities import ( + TextEmbeddingResult, +) +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, +) + + +class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel): + """ + Model class for GPUStack text embedding model. 
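+    Requests are delegated to the OpenAI-compatible embedding implementation after endpoint_url is rewritten to the server's v1-openai prefix.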
+ """ + + def _invoke( + self, + model: str, + credentials: dict, + texts: list[str], + user: Optional[str] = None, + input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, + ) -> TextEmbeddingResult: + return super()._invoke(model, credentials, texts, user, input_type) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai") diff --git a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml index 1f94a8623b..8504b90eb3 100644 --- a/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml +++ b/api/core/model_runtime/model_providers/hunyuan/llm/hunyuan-standard-256k.yaml @@ -1,7 +1,7 @@ -model: hunyuan-standard-256k +model: hunyuan-standard-256K label: - zh_Hans: hunyuan-standard-256k - en_US: hunyuan-standard-256k + zh_Hans: hunyuan-standard-256K + en_US: hunyuan-standard-256K model_type: llm features: - agent-thought diff --git a/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-haiku.yaml b/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-haiku.yaml new file mode 100644 index 0000000000..de45093a72 --- /dev/null +++ b/api/core/model_runtime/model_providers/openrouter/llm/claude-3-5-haiku.yaml @@ -0,0 +1,38 @@ +model: anthropic/claude-3-5-haiku +label: + en_US: claude-3-5-haiku +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
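+      # Optional: restricts sampling to the K most likely tokens at each step.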
+ required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: "1" + output: "5" + unit: "0.000001" + currency: USD diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-11b-vision-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-11b-vision-instruct.yaml index 235156997f..6ad2c26cc8 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-11b-vision-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-11b-vision-instruct.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - vision model_properties: mode: chat context_size: 131072 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-90b-vision-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-90b-vision-instruct.yaml index 5d597f00a2..c264db0f20 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-90b-vision-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.2-90b-vision-instruct.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - vision model_properties: mode: chat context_size: 131072 diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index dd53914a69..5ff00f008e 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -4,10 +4,7 @@ import re from collections.abc import Generator, Iterator from typing import Any, Optional, Union, cast -# from openai.types.chat import ChatCompletion, ChatCompletionChunk import boto3 -from sagemaker import Predictor, serializers -from sagemaker.session import Session from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( @@ -212,6 +209,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): :param user: unique user id :return: full response or stream response chunk generator result """ + from sagemaker import Predictor, serializers + from sagemaker.session import Session + if not self.sagemaker_session: access_key = credentials.get("aws_access_key_id") secret_key = credentials.get("aws_secret_access_key") diff --git a/api/core/model_runtime/model_providers/vessl_ai/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png new file mode 100644 index 0000000000..18ba350fa0 Binary files /dev/null and b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg new file mode 100644 index 0000000000..242f4e82b2 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/_assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py b/api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py new file mode 100644 index 0000000000..034c066ab5 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/llm/llm.py @@ -0,0 +1,83 @@ +from decimal import Decimal + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + features = [] + + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties={ + ModelPropertyKey.MODE: credentials.get("mode"), + }, + parameter_rules=[ + ParameterRule( + name=DefaultParameterName.TEMPERATURE.value, + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + default=float(credentials.get("temperature", 0.7)), + min=0, + max=2, + precision=2, + ), + ParameterRule( + name=DefaultParameterName.TOP_P.value, + label=I18nObject(en_US="Top P"), + type=ParameterType.FLOAT, + default=float(credentials.get("top_p", 1)), + min=0, + max=1, + precision=2, + ), + ParameterRule( + name=DefaultParameterName.TOP_K.value, + label=I18nObject(en_US="Top K"), + type=ParameterType.INT, + default=int(credentials.get("top_k", 50)), + min=-2147483647, + max=2147483647, + precision=0, + ), + ParameterRule( + name=DefaultParameterName.MAX_TOKENS.value, + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + default=512, + min=1, + max=int(credentials.get("max_tokens_to_sample", 4096)), + ), + ], + pricing=PriceConfig( + input=Decimal(credentials.get("input_price", 0)), + output=Decimal(credentials.get("output_price", 0)), + unit=Decimal(credentials.get("unit", 0)), + currency=credentials.get("currency", "USD"), + ), + ) + + if credentials["mode"] == "chat": + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value + elif credentials["mode"] == "completion": + entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value + else: + raise ValueError(f"Unknown completion mode {credentials['mode']}") + + return entity diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py new file mode 100644 index 0000000000..7a987c6710 --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py @@ -0,0 +1,10 @@ +import logging + +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class VesslAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + pass diff --git a/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml new file mode 100644 index 0000000000..6052756cae --- /dev/null +++ b/api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml @@ -0,0 +1,56 @@ +provider: vessl_ai +label: + en_US: vessl_ai +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US:
icon_l_en.png +background: "#F1EFED" +help: + title: + en_US: How to deploy VESSL AI LLM Model Endpoint + url: + en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment +supported_model_types: + - llm +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + placeholder: + en_US: Enter your model name + credential_form_schemas: + - variable: endpoint_url + label: + en_US: Endpoint URL + type: text-input + required: true + placeholder: + en_US: Enter the URL of your endpoint + - variable: api_key + required: true + label: + en_US: API Key + type: secret-input + placeholder: + en_US: Enter your VESSL AI secret key + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + - value: chat + label: + en_US: Chat diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py index 1a4cc15371..c77a499982 100644 --- a/api/core/model_runtime/model_providers/wenxin/_common.py +++ b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -115,6 +115,7 @@ class _CommonWenxin: "ernie-character-8k-0321": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k", "ernie-4.0-turbo-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k", "ernie-4.0-turbo-8k-preview": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview", + "ernie-4.0-turbo-128k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-128k", "yi_34b_chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat", "embedding-v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1", "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en", diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-128k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-128k.yaml new file mode 100644 index 0000000000..f8d56406d9 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-4.0-turbo-128k.yaml @@ -0,0 +1,40 @@ +model: ernie-4.0-turbo-128k +label: + en_US: Ernie-4.0-turbo-128K +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + min: 0.1 + max: 1.0 + default: 0.8 + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 2 + max: 4096 + - name: presence_penalty + use_template: presence_penalty + default: 1.0 + min: 1.0 + max: 2.0 + - name: frequency_penalty + use_template: frequency_penalty + - name: response_format + use_template: response_format + - name: disable_search + label: + zh_Hans: 禁用搜索 + en_US: Disable Search + type: boolean + help: + zh_Hans: 禁用模型自行进行外部搜索。 + en_US: Prevent the model from performing external searches on its own.
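+      # When set to true, the model will not call ERNIE's built-in external search.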
+  required: false diff --git a/api/core/model_runtime/model_providers/x/__init__.py b/api/core/model_runtime/model_providers/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg new file mode 100644 index 0000000000..f8b745cb13 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/llm/__init__.py b/api/core/model_runtime/model_providers/x/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml new file mode 100644 index 0000000000..7c305735b9 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml @@ -0,0 +1,63 @@ +model: grok-beta +label: + en_US: Grok beta +model_type: llm +features: + - multi-tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The sampling temperature controls the randomness of the output, within the range [0.0, 1.0]. Higher values make the output more random and creative; lower values make it more stable. Adjust either temperature or top_p to your needs, and avoid changing both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内，值越高，输出越随机和创造性；值越低，输出越稳定。建议根据需求调整 top_p 或 temperature 参数，避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "Nucleus sampling over the range [0.0, 1.0]: the model selects tokens from the top p% of candidates by probability; when top_p is 0, this parameter has no effect. Adjust either top_p or temperature to your needs, and avoid changing both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens；当 top_p 为 0 时，此参数无效。建议根据需求调整 top_p 或 temperature 参数，避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." + zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们，从而降低模型一字不差地重复同一句话的可能性。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users."
+      zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py new file mode 100644 index 0000000000..3f5325a857 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/llm.py @@ -0,0 +1,37 @@ +from collections.abc import Generator +from typing import Optional, Union + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials) -> None: + # fall back to the public xAI endpoint when none is configured + credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url") or "https://api.x.ai/v1")) + credentials["mode"] = LLMMode.CHAT.value + credentials["function_calling_type"] = "tool_call" diff --git a/api/core/model_runtime/model_providers/x/x.py b/api/core/model_runtime/model_providers/x/x.py new file mode 100644 index 0000000000..e3f2b8eeba --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.py @@ -0,0 +1,25 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class XAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials; + if validation fails, raise an exception. + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials(model="grok-beta", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validation failed") + raise ex diff --git a/api/core/model_runtime/model_providers/x/x.yaml b/api/core/model_runtime/model_providers/x/x.yaml new file mode 100644 index 0000000000..90d1cbfe7e --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.yaml @@ -0,0 +1,38 @@ +provider: x +label: + en_US: xAI +description: + en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
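+# Provider credentials are validated with a test call to grok-beta (see x.py above).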
+icon_small: + en_US: x-ai-logo.svg +icon_large: + en_US: x-ai-logo.svg +help: + title: + en_US: Get your token from xAI + zh_Hans: 从 xAI 获取 token + url: + en_US: https://x.ai/api +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: endpoint_url + label: + en_US: API Base + type: text-input + required: false + default: https://api.x.ai/v1 + placeholder: + zh_Hans: 在此输入您的 API Base + en_US: Enter your API Base diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 3affbd2d0a..57af05861c 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -34,6 +34,8 @@ class RetrievalService: reranking_mode: Optional[str] = "reranking_model", weights: Optional[dict] = None, ): + if not query: + return [] dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 1d4bfef76d..eb78e8aa69 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -3,11 +3,13 @@ import time import uuid from typing import Any +import numpy as np from pydantic import BaseModel, model_validator from pymochow import MochowClient from pymochow.auth.bce_credentials import BceCredentials from pymochow.configuration import Configuration -from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState +from pymochow.exception import ServerError +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row @@ -116,6 +118,7 @@ class BaiduVector(BaseVector): self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'") def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector] anns = AnnSearch( vector_field=self.field_vector, vector_floats=query_vector, @@ -149,7 +152,13 @@ class BaiduVector(BaseVector): return docs def delete(self) -> None: - self._db.drop_table(table_name=self._collection_name) + try: + self._db.drop_table(table_name=self._collection_name) + except ServerError as e: + if e.code == ServerErrCode.TABLE_NOT_EXIST: + pass + else: + raise def _init_client(self, config) -> MochowClient: config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint) @@ -166,7 +175,14 @@ class BaiduVector(BaseVector): if exists: return self._client.database(self._client_config.database) else: + try: + self._client.create_database(database_name=self._client_config.database) + except ServerError as e: + if e.code == ServerErrCode.DB_ALREADY_EXIST: + pass + else: + raise + return self._client.database(self._client_config.database) def _table_existed(self) -> bool: tables = self._db.list_table() @@ -175,7 +191,7 @@ class BaiduVector(BaseVector): def _create_table(self, dimension: int) -> None: # Try to grab distributed lock and create table lock_name
= "vector_indexing_lock_{}".format(self._collection_name) - with redis_client.lock(lock_name, timeout=20): + with redis_client.lock(lock_name, timeout=60): table_exist_cache_key = "vector_indexing_{}".format(self._collection_name) if redis_client.get(table_exist_cache_key): return @@ -238,15 +254,14 @@ class BaiduVector(BaseVector): description="Table for Dify", ) + # Wait for table created + while True: + time.sleep(1) + table = self._db.describe_table(self._collection_name) + if table.state == TableState.NORMAL: + break redis_client.set(table_exist_cache_key, 1, ex=3600) - # Wait for table created - while True: - time.sleep(1) - table = self._db.describe_table(self._collection_name) - if table.state == TableState.NORMAL: - break - class BaiduVectorFactory(AbstractVectorFactory): def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector: diff --git a/api/core/rag/datasource/vdb/lindorm/__init__.py b/api/core/rag/datasource/vdb/lindorm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py new file mode 100644 index 0000000000..abd8261a69 --- /dev/null +++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py @@ -0,0 +1,498 @@ +import copy +import json +import logging +from collections.abc import Iterable +from typing import Any, Optional + +from opensearchpy import OpenSearch +from opensearchpy.helpers import bulk +from pydantic import BaseModel, model_validator +from tenacity import retry, stop_after_attempt, wait_fixed + +from configs import dify_config +from core.rag.datasource.vdb.field import Field +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logging.getLogger("lindorm").setLevel(logging.WARN) + + +class LindormVectorStoreConfig(BaseModel): + hosts: str + username: Optional[str] = None + password: Optional[str] = None + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["hosts"]: + raise ValueError("config URL is required") + if not values["username"]: + raise ValueError("config USERNAME is required") + if not values["password"]: + raise ValueError("config PASSWORD is required") + return values + + def to_opensearch_params(self) -> dict[str, Any]: + params = { + "hosts": self.hosts, + } + if self.username and self.password: + params["http_auth"] = (self.username, self.password) + return params + + +class LindormVectorStore(BaseVector): + def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs): + super().__init__(collection_name.lower()) + self._client_config = config + self._client = OpenSearch(**config.to_opensearch_params()) + self.kwargs = kwargs + + def get_type(self) -> str: + return VectorType.LINDORM + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self.create_collection(len(embeddings[0]), **kwargs) + self.add_texts(texts, embeddings) + + def refresh(self): + self._client.indices.refresh(index=self._collection_name) 
+ + def __filter_existed_ids( + self, + texts: list[str], + metadatas: list[dict], + ids: list[str], + bulk_size: int = 1024, + ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]: + @retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) + def __fetch_existing_ids(batch_ids: list[str]) -> set[str]: + try: + existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False) + return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} + except Exception as e: + logger.error(f"Error fetching batch {batch_ids}: {e}") + return set() + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(60)) + def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]: + try: + existing_docs = self._client.mget( + body={ + "docs": [ + {"_index": self._collection_name, "_id": id, "routing": routing} + for id, routing in zip(batch_ids, route_ids) + ] + }, + _source=False, + ) + return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]} + except Exception as e: + logger.error(f"Error fetching batch {batch_ids}: {e}") + return set() + + if ids is None: + return texts, metadatas, ids + + if len(texts) != len(ids): + raise RuntimeError(f"texts {len(texts)} != ids {len(ids)}") + + filtered_texts = [] + filtered_metadatas = [] + filtered_ids = [] + + def batch(iterable, n): + length = len(iterable) + for idx in range(0, length, n): + yield iterable[idx : min(idx + n, length)] + + for ids_batch, texts_batch, metadatas_batch in zip( + batch(ids, bulk_size), + batch(texts, bulk_size), + batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size), + ): + existing_ids_set = __fetch_existing_ids(ids_batch) + for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch): + if doc_id not in existing_ids_set: + filtered_texts.append(text) + filtered_ids.append(doc_id) + if metadatas is not None: + filtered_metadatas.append(metadata) + + return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + actions = [] + uuids = self._get_uuids(documents) + for i in range(len(documents)): + action = { + "_op_type": "index", + "_index": self._collection_name.lower(), + "_id": uuids[i], + "_source": { + Field.CONTENT_KEY.value: documents[i].page_content, + Field.VECTOR.value: embeddings[i], # Make sure you pass an array here + Field.METADATA_KEY.value: documents[i].metadata, + }, + } + actions.append(action) + bulk(self._client, actions) + self.refresh() + + def get_ids_by_metadata_field(self, key: str, value: str): + query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}} + response = self._client.search(index=self._collection_name, body=query) + if response["hits"]["hits"]: + return [hit["_id"] for hit in response["hits"]["hits"]] + else: + return None + + def delete_by_metadata_field(self, key: str, value: str): + query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}} + results = self._client.search(index=self._collection_name, body=query_str) + ids = [hit["_id"] for hit in results["hits"]["hits"]] + if ids: + self.delete_by_ids(ids) + + def delete_by_ids(self, ids: list[str]) -> None: + for id in ids: + if self._client.exists(index=self._collection_name, id=id): + self._client.delete(index=self._collection_name, id=id) + else: + logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") + + def delete(self) -> None: + try:
if self._client.indices.exists(index=self._collection_name): + self._client.indices.delete(index=self._collection_name, params={"timeout": 60}) + logger.info("Delete index success") + else: + logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.") + except Exception as e: + logger.error(f"Error occurred while deleting the index: {e}") + raise e + + def text_exists(self, id: str) -> bool: + try: + self._client.get(index=self._collection_name, id=id) + return True + except Exception: + return False + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + # Make sure query_vector is a list + if not isinstance(query_vector, list): + raise ValueError("query_vector should be a list of floats") + + # Check whether query_vector is a floating-point number list + if not all(isinstance(x, float) for x in query_vector): + raise ValueError("All elements in query_vector should be floats") + + top_k = kwargs.get("top_k", 10) + query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs) + try: + response = self._client.search(index=self._collection_name, body=query) + except Exception as e: + logger.error(f"Error executing search: {e}") + raise + + docs_and_scores = [] + for hit in response["hits"]["hits"]: + docs_and_scores.append( + ( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ), + hit["_score"], + ) + ) + docs = [] + for doc, score in docs_and_scores: + score_threshold = kwargs.get("score_threshold", 0.0) or 0.0 + if score > score_threshold: + doc.metadata["score"] = score + docs.append(doc) + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + must = kwargs.get("must") + must_not = kwargs.get("must_not") + should = kwargs.get("should") + minimum_should_match = kwargs.get("minimum_should_match", 0) + top_k = kwargs.get("top_k", 10) + filters = kwargs.get("filter") + routing = kwargs.get("routing") + full_text_query = default_text_search_query( + query_text=query, + k=top_k, + text_field=Field.CONTENT_KEY.value, + must=must, + must_not=must_not, + should=should, + minimum_should_match=minimum_should_match, + filters=filters, + routing=routing, + ) + response = self._client.search(index=self._collection_name, body=full_text_query) + docs = [] + for hit in response["hits"]["hits"]: + docs.append( + Document( + page_content=hit["_source"][Field.CONTENT_KEY.value], + vector=hit["_source"][Field.VECTOR.value], + metadata=hit["_source"][Field.METADATA_KEY.value], + ) + ) + + return docs + + def create_collection(self, dimension: int, **kwargs): + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + logger.info(f"Collection {self._collection_name} already exists.") + return + if self._client.indices.exists(index=self._collection_name): + logger.info(f"{self._collection_name.lower()} already exists.") + return + if len(self.kwargs) == 0 and len(kwargs) != 0: + self.kwargs = copy.deepcopy(kwargs) + vector_field = kwargs.pop("vector_field", Field.VECTOR.value) + shards = kwargs.pop("shards", 2) + + engine = kwargs.pop("engine", "lvector") + method_name = kwargs.pop("method_name", "hnsw") + data_type = kwargs.pop("data_type", "float") + space_type = kwargs.pop("space_type", "cosinesimil")
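+            # Index tuning knobs: the hnsw_* values apply when method_name == "hnsw"; ivfpq_m, nlist, and the centroids_* values apply when method_name == "ivfpq"; any remaining kwargs fall through to default_text_mapping below.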
+ + hnsw_m = kwargs.pop("hnsw_m", 24) + hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500) + ivfpq_m = kwargs.pop("ivfpq_m", dimension) + nlist = kwargs.pop("nlist", 1000) + centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False) + centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24) + centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500) + centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100) + mapping = default_text_mapping( + dimension, + method_name, + shards=shards, + engine=engine, + data_type=data_type, + space_type=space_type, + vector_field=vector_field, + hnsw_m=hnsw_m, + hnsw_ef_construction=hnsw_ef_construction, + nlist=nlist, + ivfpq_m=ivfpq_m, + centroids_use_hnsw=centroids_use_hnsw, + centroids_hnsw_m=centroids_hnsw_m, + centroids_hnsw_ef_construct=centroids_hnsw_ef_construct, + centroids_hnsw_ef_search=centroids_hnsw_ef_search, + **kwargs, + ) + self._client.indices.create(index=self._collection_name.lower(), body=mapping) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + # logger.info(f"create index success: {self._collection_name}") + + +def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict: + routing_field = kwargs.get("routing_field") + excludes_from_source = kwargs.get("excludes_from_source") + analyzer = kwargs.get("analyzer", "ik_max_word") + text_field = kwargs.get("text_field", Field.CONTENT_KEY.value) + engine = kwargs["engine"] + shard = kwargs["shards"] + space_type = kwargs["space_type"] + data_type = kwargs["data_type"] + vector_field = kwargs.get("vector_field", Field.VECTOR.value) + + if method_name == "ivfpq": + ivfpq_m = kwargs["ivfpq_m"] + nlist = kwargs["nlist"] + centroids_use_hnsw = True if nlist > 10000 else False + centroids_hnsw_m = 24 + centroids_hnsw_ef_construct = 500 + centroids_hnsw_ef_search = 100 + parameters = { + "m": ivfpq_m, + "nlist": nlist, + "centroids_use_hnsw": centroids_use_hnsw, + "centroids_hnsw_m": centroids_hnsw_m, + "centroids_hnsw_ef_construct": centroids_hnsw_ef_construct, + "centroids_hnsw_ef_search": centroids_hnsw_ef_search, + } + elif method_name == "hnsw": + neighbor = kwargs["hnsw_m"] + ef_construction = kwargs["hnsw_ef_construction"] + parameters = {"m": neighbor, "ef_construction": ef_construction} + elif method_name == "flat": + parameters = {} + else: + raise RuntimeError(f"unexpected method_name: {method_name}") + + mapping = { + "settings": {"index": {"number_of_shards": shard, "knn": True}}, + "mappings": { + "properties": { + vector_field: { + "type": "knn_vector", + "dimension": dimension, + "data_type": data_type, + "method": { + "engine": engine, + "name": method_name, + "space_type": space_type, + "parameters": parameters, + }, + }, + text_field: {"type": "text", "analyzer": analyzer}, + } + }, + } + + if excludes_from_source: + mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. 
{"excludes": ["vector_field"]}
+
+    if method_name == "ivfpq" and routing_field is not None:
+        mapping["settings"]["index"]["knn_routing"] = True
+        mapping["settings"]["index"]["knn.offline.construction"] = True
+
+    if method_name == "flat" and routing_field is not None:
+        mapping["settings"]["index"]["knn_routing"] = True
+
+    return mapping
+
+
+def default_text_search_query(
+    query_text: str,
+    k: int = 4,
+    text_field: str = Field.CONTENT_KEY.value,
+    must: Optional[list[dict]] = None,
+    must_not: Optional[list[dict]] = None,
+    should: Optional[list[dict]] = None,
+    minimum_should_match: int = 0,
+    filters: Optional[list[dict]] = None,
+    routing: Optional[str] = None,
+    **kwargs,
+) -> dict:
+    if routing is not None:
+        routing_field = kwargs.get("routing_field", "routing_field")
+        query_clause = {
+            "bool": {
+                "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
+            }
+        }
+    else:
+        query_clause = {"match": {text_field: query_text}}
+    # build the simplest search_query when only query_text is specified
+    if not must and not must_not and not should and not filters:
+        search_query = {"size": k, "query": query_clause}
+        return search_query
+
+    # build complex search_query when either of must/must_not/should/filter is specified
+    if must:
+        if not isinstance(must, list):
+            raise RuntimeError(f"unexpected [must] clause with {type(must)}")
+        if query_clause not in must:
+            must.append(query_clause)
+    else:
+        must = [query_clause]
+
+    boolean_query = {"must": must}
+
+    if must_not:
+        if not isinstance(must_not, list):
+            raise RuntimeError(f"unexpected [must_not] clause with {type(must_not)}")
+        boolean_query["must_not"] = must_not
+
+    if should:
+        if not isinstance(should, list):
+            raise RuntimeError(f"unexpected [should] clause with {type(should)}")
+        boolean_query["should"] = should
+        if minimum_should_match != 0:
+            boolean_query["minimum_should_match"] = minimum_should_match
+
+    if filters:
+        if not isinstance(filters, list):
+            raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
+        boolean_query["filter"] = filters
+
+    search_query = {"size": k, "query": {"bool": boolean_query}}
+    return search_query
+
+
+def default_vector_search_query(
+    query_vector: list[float],
+    k: int = 4,
+    min_score: str = "0.0",
+    ef_search: Optional[str] = None,  # only for hnsw
+    nprobe: Optional[str] = None,  # "2000"
+    reorder_factor: Optional[str] = None,  # "20"
+    client_refactor: Optional[str] = None,  # "true"
+    vector_field: str = Field.VECTOR.value,
+    filters: Optional[list[dict]] = None,
+    filter_type: Optional[str] = None,
+    **kwargs,
+) -> dict:
+    if filters is not None:
+        filter_type = "post_filter" if filter_type is None else filter_type
+        if not isinstance(filters, list):
+            raise RuntimeError(f"unexpected filter with {type(filters)}")
+    final_ext = {"lvector": {}}
+    if min_score != "0.0":
+        final_ext["lvector"]["min_score"] = min_score
+    if ef_search:
+        final_ext["lvector"]["ef_search"] = ef_search
+    if nprobe:
+        final_ext["lvector"]["nprobe"] = nprobe
+    if reorder_factor:
+        final_ext["lvector"]["reorder_factor"] = reorder_factor
+    if client_refactor:
+        final_ext["lvector"]["client_refactor"] = client_refactor
+
+    search_query = {
+        "size": k,
+        "_source": True,  # force return '_source'
+        "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
+    }
+
+    if filters is not None:
+        # when using filter, transform filter from List[Dict] to Dict as valid format
+        filters = {"bool": {"must": filters}} if len(filters) > 1 else
filters[0] + search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict + if filter_type: + final_ext["lvector"]["filter_type"] = filter_type + + if final_ext != {"lvector": {}}: + search_query["ext"] = final_ext + return search_query + + +class LindormVectorStoreFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name)) + lindorm_config = LindormVectorStoreConfig( + hosts=dify_config.LINDORM_URL, + username=dify_config.LINDORM_USERNAME, + password=dify_config.LINDORM_PASSWORD, + ) + return LindormVectorStore(collection_name, lindorm_config) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 0cd2a46460..a6f3ad7fef 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -37,7 +37,7 @@ class TidbService: } spending_limit = { - "monthly": 100, + "monthly": dify_config.TIDB_SPEND_LIMIT, } password = str(uuid.uuid4()).replace("-", "")[:16] display_name = str(uuid.uuid4()).replace("-", "")[:16] diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index c8cb007ae8..6d2e04fc02 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -134,6 +134,10 @@ class Vector: from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory return TidbOnQdrantVectorFactory + case VectorType.LINDORM: + from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory + + return LindormVectorStoreFactory case VectorType.OCEANBASE: from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index e3b37ece88..8e53e3ae84 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -16,6 +16,7 @@ class VectorType(str, Enum): TENCENT = "tencent" ORACLE = "oracle" ELASTICSEARCH = "elasticsearch" + LINDORM = "lindorm" COUCHBASE = "couchbase" BAIDU = "baidu" VIKINGDB = "vikingdb" diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index ae3c25125c..d4434ea28f 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -14,6 +14,7 @@ import requests from docx import Document as DocxDocument from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db @@ -86,7 +87,7 @@ class WordExtractor(BaseExtractor): image_count += 1 if rel.is_external: url = rel.reltype - response = requests.get(url, stream=True) + response = ssrf_proxy.get(url, stream=True) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) diff --git a/api/core/rag/rerank/rerank_model.py 
b/api/core/rag/rerank/rerank_model.py index 40ebf0befd..fc82b2080b 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -27,18 +27,17 @@ class RerankModelRunner(BaseRerankRunner): :return: """ docs = [] - doc_id = [] + doc_id = set() unique_documents = [] - dify_documents = [item for item in documents if item.provider == "dify"] - external_documents = [item for item in documents if item.provider == "external"] - for document in dify_documents: - if document.metadata["doc_id"] not in doc_id: - doc_id.append(document.metadata["doc_id"]) + for document in documents: + if document.provider == "dify" and document.metadata["doc_id"] not in doc_id: + doc_id.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) - for document in external_documents: - docs.append(document.page_content) - unique_documents.append(document) + elif document.provider == "external": + if document not in unique_documents: + docs.append(document.page_content) + unique_documents.append(document) documents = unique_documents diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index dd9371f70d..38123f125a 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -4,7 +4,7 @@ from hmac import new as hmac_new from json import loads as json_loads from threading import Lock from time import sleep, time -from typing import Any, Optional +from typing import Any from httpx import get, post from requests import get as requests_get @@ -15,27 +15,27 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, from core.tools.tool.builtin_tool import BuiltinTool -class AIPPTGenerateTool(BuiltinTool): +class AIPPTGenerateToolAdapter: """ A tool for generating a ppt """ _api_base_url = URL("https://co.aippt.cn/api") _api_token_cache = {} - _api_token_cache_lock: Optional[Lock] = None _style_cache = {} - _style_cache_lock: Optional[Lock] = None + + _api_token_cache_lock = Lock() + _style_cache_lock = Lock() _task = {} _task_type_map = { "auto": 1, "markdown": 7, } + _tool: BuiltinTool - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - self._api_token_cache_lock = Lock() - self._style_cache_lock = Lock() + def __init__(self, tool: BuiltinTool = None): + self._tool = tool def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ @@ -51,11 +51,11 @@ class AIPPTGenerateTool(BuiltinTool): """ title = tool_parameters.get("title", "") if not title: - return self.create_text_message("Please provide a title for the ppt") + return self._tool.create_text_message("Please provide a title for the ppt") model = tool_parameters.get("model", "aippt") if not model: - return self.create_text_message("Please provide a model for the ppt") + return self._tool.create_text_message("Please provide a model for the ppt") outline = tool_parameters.get("outline", "") @@ -68,8 +68,8 @@ class AIPPTGenerateTool(BuiltinTool): ) # get suit - color = tool_parameters.get("color") - style = tool_parameters.get("style") + color: str = tool_parameters.get("color") + style: str = tool_parameters.get("style") if color == "__default__": color_id = "" @@ -93,9 +93,9 @@ class AIPPTGenerateTool(BuiltinTool): # generate ppt _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id) - return self.create_text_message( + return 
self._tool.create_text_message( """the ppt has been created successfully,""" - f"""the ppt url is {ppt_url}""" + f"""the ppt url is {ppt_url} .""" """please give the ppt url to user and direct user to download it.""" ) @@ -111,8 +111,8 @@ class AIPPTGenerateTool(BuiltinTool): """ headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = post( str(self._api_base_url / "ai" / "chat" / "v2" / "task"), @@ -139,8 +139,8 @@ class AIPPTGenerateTool(BuiltinTool): headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) @@ -183,8 +183,8 @@ class AIPPTGenerateTool(BuiltinTool): headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) @@ -236,14 +236,15 @@ class AIPPTGenerateTool(BuiltinTool): """ headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = post( str(self._api_base_url / "design" / "v2" / "save"), headers=headers, data={"task_id": task_id, "template_id": suit_id}, + timeout=(10, 60), ) if response.status_code != 200: @@ -350,11 +351,13 @@ class AIPPTGenerateTool(BuiltinTool): return token - @classmethod - def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: + @staticmethod + def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1 + key=secret_key.encode("utf-8"), + msg=f"GET@/api/grant/token/@{timestamp}".encode(), + digestmod=sha1, ).digest() ).decode("utf-8") @@ -419,10 +422,12 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ - if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"): + if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get( + "aippt_secret_key" + ): raise Exception("Please provide aippt credentials") - return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) + return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id) def _get_suit(self, style_id: int, colour_id: int) -> int: """ @@ -430,8 +435,8 @@ class AIPPTGenerateTool(BuiltinTool): """ 
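+        # look up the template suit id matching the selected style and colour via
+        # the suit/search endpoint; _generate_ppt later saves the task with this id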
headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"), } response = get( str(self._api_base_url / "template_component" / "suit" / "search"), @@ -496,3 +501,18 @@ class AIPPTGenerateTool(BuiltinTool): ], ), ] + + +class AIPPTGenerateTool(BuiltinTool): + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters) + + def get_runtime_parameters(self) -> list[ToolParameter]: + return AIPPTGenerateToolAdapter(self).get_runtime_parameters() + + @classmethod + def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: + return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id) diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index 209d6ecba4..dfa3fbea6a 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -1,5 +1,5 @@ import matplotlib.pyplot as plt -from matplotlib.font_manager import FontProperties +from matplotlib.font_manager import FontProperties, fontManager from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -17,9 +17,10 @@ def set_chinese_font(): ] for font in font_list: - chinese_font = FontProperties(font) - if chinese_font.get_name() == font: - return chinese_font + if font in fontManager.ttflist: + chinese_font = FontProperties(font) + if chinese_font.get_name() == font: + return chinese_font return FontProperties() diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py index 3a47c0cfc0..20ce5e138b 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.py +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -33,7 +33,9 @@ class BarChartTool(BuiltinTool): if axis: axis = [label[:10] + "..." 
if len(label) > 10 else label for label in axis] ax.set_xticklabels(axis, rotation=45, ha="right") - ax.bar(axis, data) + # ensure all labels, including duplicates, are correctly displayed + ax.bar(range(len(data)), data) + ax.set_xticks(range(len(data))) else: ax.bar(range(len(data)), data) diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py index d4bf713441..1aae7b2442 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -1,5 +1,3 @@ -import base64 -import io import json import random import uuid @@ -8,7 +6,7 @@ import httpx from websocket import WebSocket from yarl import URL -from core.file.file_manager import _get_encoded_string +from core.file.file_manager import download from core.file.models import File @@ -29,8 +27,7 @@ class ComfyUiClient: return response.content def upload_image(self, image_file: File) -> dict: - image_content = base64.b64decode(_get_encoded_string(image_file)) - file = io.BytesIO(image_content) + file = download(image_file) files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"} res = httpx.post(str(self.base_url / "upload/image"), files=files) return res.json() @@ -47,12 +44,7 @@ class ComfyUiClient: ws.connect(ws_address) return ws, client_id - def set_prompt( - self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = "" - ) -> dict: - """ - find the first KSampler, then can find the prompt node through it. - """ + def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict: prompt = origin_prompt.copy() id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] @@ -64,9 +56,20 @@ class ComfyUiClient: negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt - if image_name != "": - image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0] - prompt.get(image_loader)["inputs"]["image"] = image_name + return prompt + + def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict: + prompt = origin_prompt.copy() + for index, image_node_id in enumerate(image_ids): + prompt[image_node_id]["inputs"]["image"] = image_names[index] + return prompt + + def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict: + prompt = origin_prompt.copy() + id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} + load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"] + for load_image, image_name in zip(load_image_nodes, image_names): + prompt.get(load_image)["inputs"]["image"] = image_name return prompt def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py index 11320d5d0f..d62772cda7 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py @@ -1,28 +1,75 @@ import json from typing import Any +from core.file import FileType from core.tools.entities.tool_entities import 
ToolInvokeMessage
+from core.tools.errors import ToolParameterValidationError
 from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
 from core.tools.tool.builtin_tool import BuiltinTool
 
 
+def sanitize_json_string(s):
+    escape_dict = {
+        "\n": "\\n",
+        "\r": "\\r",
+        "\t": "\\t",
+        "\b": "\\b",
+        "\f": "\\f",
+    }
+    for char, escaped in escape_dict.items():
+        s = s.replace(char, escaped)
+
+    return s
+
+
 class ComfyUIWorkflowTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
         comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
 
-        positive_prompt = tool_parameters.get("positive_prompt")
-        negative_prompt = tool_parameters.get("negative_prompt")
+        positive_prompt = tool_parameters.get("positive_prompt", "")
+        negative_prompt = tool_parameters.get("negative_prompt", "")
+        images = tool_parameters.get("images") or []
         workflow = tool_parameters.get("workflow_json")
-        image_name = ""
-        if image := tool_parameters.get("image"):
+        image_names = []
+        for image in images:
+            if image.type != FileType.IMAGE:
+                continue
             image_name = comfyui.upload_image(image).get("name")
+            image_names.append(image_name)
+
+        set_prompt_with_ksampler = True
+        if "{{positive_prompt}}" in workflow:
+            set_prompt_with_ksampler = False
+            workflow = workflow.replace("{{positive_prompt}}", positive_prompt.replace('"', "'"))
+            workflow = workflow.replace("{{negative_prompt}}", negative_prompt.replace('"', "'"))
 
         try:
-            origin_prompt = json.loads(workflow)
-        except:
-            return self.create_text_message("the Workflow JSON is not correct")
+            prompt = json.loads(workflow)
+        except json.JSONDecodeError:
+            cleaned_string = sanitize_json_string(workflow)
+            try:
+                prompt = json.loads(cleaned_string)
+            except json.JSONDecodeError:
+                return self.create_text_message("the Workflow JSON is not correct")
+
+        if set_prompt_with_ksampler:
+            try:
+                prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt)
+            except (KeyError, IndexError):
+                raise ToolParameterValidationError(
+                    "Failed to set the prompt with KSampler; try replacing the prompt with {{positive_prompt}} in your workflow JSON"
+                )
+
+        if image_names:
+            if image_ids := tool_parameters.get("image_ids"):
+                image_ids = image_ids.split(",")
+                try:
+                    prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids)
+                except (KeyError, IndexError):
+                    raise ToolParameterValidationError("the Image Node ID List does not match your uploaded image files.")
+            else:
+                prompt = comfyui.set_prompt_images_by_default(prompt, image_names)
 
-        prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
         images = comfyui.generate_image_by_prompt(prompt)
         result = []
         for img in images:
diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml
index 55fcdad825..dc4e0d77b2 100644
--- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml
+++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml
@@ -24,12 +24,12 @@ parameters:
     zh_Hans: 负面提示词
     llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
     form: llm
-  - name: image
-    type: file
+  - name: images
+    type: files
     label:
-      en_US: Input Image
+      en_US: Input Images
       zh_Hans: 输入的图片
-    llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
+ llm_description: The input images, used to transfer to the comfyui workflow to generate another image. form: llm - name: workflow_json type: string @@ -40,3 +40,15 @@ parameters: en_US: exported from ComfyUI workflow zh_Hans: 从ComfyUI的工作流中导出 form: form + - name: image_ids + type: string + label: + en_US: Image Node ID List + zh_Hans: 图片节点ID列表 + placeholder: + en_US: Use commas to separate multiple node ID + zh_Hans: 多个节点ID时使用半角逗号分隔 + human_description: + en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list. + zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI + form: form diff --git a/api/core/tools/provider/builtin/gitee_ai/_assets/icon.svg b/api/core/tools/provider/builtin/gitee_ai/_assets/icon.svg new file mode 100644 index 0000000000..6dd75d1a6b --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/_assets/icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/api/core/tools/provider/builtin/gitee_ai/gitee_ai.py b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.py new file mode 100644 index 0000000000..151cafec14 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.py @@ -0,0 +1,17 @@ +import requests + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class GiteeAIProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + url = "https://ai.gitee.com/api/base/account/me" + headers = { + "accept": "application/json", + "authorization": f"Bearer {credentials.get('api_key')}", + } + + response = requests.get(url, headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError("GiteeAI API key is invalid") diff --git a/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml new file mode 100644 index 0000000000..2e18f8a7fc --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml @@ -0,0 +1,22 @@ +identity: + author: Gitee AI + name: gitee_ai + label: + en_US: Gitee AI + zh_Hans: Gitee AI + description: + en_US: 快速体验大模型,领先探索 AI 开源世界 + zh_Hans: 快速体验大模型,领先探索 AI 开源世界 + icon: icon.svg + tags: + - image +credentials_for_provider: + api_key: + type: secret-input + required: true + label: + en_US: API Key + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + url: https://ai.gitee.com/dashboard/settings/tokens diff --git a/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py new file mode 100644 index 0000000000..14291d1729 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py @@ -0,0 +1,33 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GiteeAITool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['api_key']}", + } + + payload = { + "inputs": tool_parameters.get("inputs"), + "width": tool_parameters.get("width", "720"), + "height": tool_parameters.get("height", "720"), + } + model = tool_parameters.get("model", "Kolors") + url = 
f"https://ai.gitee.com/api/serverless/{model}/text-to-image" + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + # The returned image is base64 and needs to be mark as an image + result = [self.create_blob_message(blob=response.content, meta={"mime_type": "image/jpeg"})] + + return result diff --git a/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.yaml b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.yaml new file mode 100644 index 0000000000..5e03f9abe9 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.yaml @@ -0,0 +1,72 @@ +identity: + name: text to image + author: gitee_ai + label: + en_US: text to image + icon: icon.svg +description: + human: + en_US: generate images using a variety of popular models + llm: This tool is used to generate image from text. +parameters: + - name: model + type: select + required: true + options: + - value: flux-1-schnell + label: + en_US: flux-1-schnell + - value: Kolors + label: + en_US: Kolors + - value: stable-diffusion-3-medium + label: + en_US: stable-diffusion-3-medium + - value: stable-diffusion-xl-base-1.0 + label: + en_US: stable-diffusion-xl-base-1.0 + - value: stable-diffusion-v1-4 + label: + en_US: stable-diffusion-v1-4 + default: Kolors + label: + en_US: Choose Image Model + zh_Hans: 选择生成图片的模型 + form: form + - name: inputs + type: string + required: true + label: + en_US: Input Text + zh_Hans: 输入文本 + human_description: + en_US: The text input used to generate the image. + zh_Hans: 用于生成图片的输入文本。 + llm_description: This text input will be used to generate image. + form: llm + - name: width + type: number + required: true + default: 720 + min: 1 + max: 1024 + label: + en_US: Image Width + zh_Hans: 图片宽度 + human_description: + en_US: The width of the generated image. + zh_Hans: 生成图片的宽度。 + form: form + - name: height + type: number + required: true + default: 720 + min: 1 + max: 1024 + label: + en_US: Image Height + zh_Hans: 图片高度 + human_description: + en_US: The height of the generated image. 
+ zh_Hans: 生成图片的高度。 + form: form diff --git a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py index 8c8dd9bf68..476e2d01e1 100644 --- a/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py +++ b/api/core/tools/provider/builtin/podcast_generator/tools/podcast_audio_generator.py @@ -1,15 +1,19 @@ import concurrent.futures import io import random +import warnings from typing import Any, Literal, Optional, Union import openai -from pydub import AudioSegment from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.builtin_tool import BuiltinTool +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from pydub import AudioSegment + class PodcastAudioGeneratorTool(BuiltinTool): @staticmethod diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index 2443991d57..1c7cb39c92 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -35,7 +35,8 @@ class VannaTool(BuiltinTool): password = tool_parameters.get("password", "") port = tool_parameters.get("port", 0) - vn = VannaDefault(model=model, api_key=api_key) + base_url = self.runtime.credentials.get("base_url", None) + vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url}) db_type = tool_parameters.get("db_type", "") if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: diff --git a/api/core/tools/provider/builtin/vanna/vanna.py b/api/core/tools/provider/builtin/vanna/vanna.py index 84724e921a..1d71414bf3 100644 --- a/api/core/tools/provider/builtin/vanna/vanna.py +++ b/api/core/tools/provider/builtin/vanna/vanna.py @@ -1,4 +1,6 @@ +import re from typing import Any +from urllib.parse import urlparse from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vanna.tools.vanna import VannaTool @@ -6,7 +8,26 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class VannaProvider(BuiltinToolProviderController): + def _get_protocol_and_main_domain(self, url): + parsed_url = urlparse(url) + protocol = parsed_url.scheme + hostname = parsed_url.hostname + port = f":{parsed_url.port}" if parsed_url.port else "" + + # Check if the hostname is an IP address + is_ip = re.match(r"^\d{1,3}(\.\d{1,3}){3}$", hostname) is not None + + # Return the full hostname (with port if present) for IP addresses, otherwise return the main domain + main_domain = f"{hostname}{port}" if is_ip else ".".join(hostname.split(".")[-2:]) + port + return f"{protocol}://{main_domain}" + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + base_url = credentials.get("base_url") + if not base_url: + base_url = "https://ask.vanna.ai/rpc" + else: + base_url = base_url.removesuffix("/") + credentials["base_url"] = base_url try: VannaTool().fork_tool_runtime( runtime={ @@ -17,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController): tool_parameters={ "model": "chinook", "db_type": "SQLite", - "url": "https://vanna.ai/Chinook.sqlite", + "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite', "query": "What are the top 10 customers by sales?", }, ) diff --git a/api/core/tools/provider/builtin/vanna/vanna.yaml 
b/api/core/tools/provider/builtin/vanna/vanna.yaml
index 7f953be172..cf3fdca562 100644
--- a/api/core/tools/provider/builtin/vanna/vanna.yaml
+++ b/api/core/tools/provider/builtin/vanna/vanna.yaml
@@ -26,3 +26,10 @@ credentials_for_provider:
       en_US: Get your API key from Vanna.AI
       zh_Hans: 从 Vanna.AI 获取你的 API key
     url: https://vanna.ai/account/profile
+  base_url:
+    type: text-input
+    required: false
+    label:
+      en_US: Vanna.AI Endpoint Base URL
+    placeholder:
+      en_US: https://ask.vanna.ai/rpc
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index 63f7775164..6abe0a9cba 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -3,7 +3,7 @@ import logging
 import mimetypes
 from collections.abc import Generator
 from os import listdir, path
-from threading import Lock
+from threading import Lock, Thread
 from typing import Any, Optional, Union
 
 from configs import dify_config
@@ -647,4 +647,5 @@ class ToolManager:
         raise ValueError(f"provider type {provider_type} not found")
 
 
-ToolManager.load_builtin_providers_cache()
+# preload builtin tool providers
+Thread(target=ToolManager.load_builtin_providers_cache, name="pre_load_builtin_providers_cache", daemon=True).start()
diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py
index 0131bb342b..7e10cddc71 100644
--- a/api/core/workflow/entities/node_entities.py
+++ b/api/core/workflow/entities/node_entities.py
@@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
     PARALLEL_START_NODE_ID = "parallel_start_node_id"
     PARENT_PARALLEL_ID = "parent_parallel_id"
     PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
+    PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
 
 
 class NodeRunResult(BaseModel):
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index 86d89e0a32..bacea191dd 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent):
 
 class NodeRunStartedEvent(BaseNodeEvent):
     predecessor_node_id: Optional[str] = None
     """predecessor node id"""
+    parallel_mode_run_id: Optional[str] = None
 
 
@@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")
 
 
+class NodeInIterationFailedEvent(BaseNodeEvent):
+    error: str = Field(..., description="error")
+
+
 ###########################################
 # Parallel Branch Events
 ###########################################
@@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent):
     """parent parallel id if node is in parallel"""
     parent_parallel_start_node_id: Optional[str] = None
     """parent parallel start node id if node is in parallel"""
+    parallel_mode_run_id: Optional[str] = None
+    """iteration run id when running in parallel mode"""
 
 
 class IterationRunStartedEvent(BaseIterationEvent):
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 8f58af00ef..f07ad4de11 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -4,6 +4,7 @@ import time
 import uuid
 from collections.abc import Generator, Mapping
 from concurrent.futures import ThreadPoolExecutor, wait
+from copy import copy, deepcopy
 from typing import Any, Optional
 
 from flask import Flask, current_app
@@ -724,6 +725,16 @@ class GraphEngine:
         """
         return time.perf_counter() - start_at > max_execution_time
 
+    def create_copy(self):
+        """
+        create
a graph engine copy + :return: with a new variable pool instance of graph engine + """ + new_instance = copy(self) + new_instance.graph_runtime_state = copy(self.graph_runtime_state) + new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) + return new_instance + class GraphRunFailedError(Exception): def __init__(self, error: str): diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 9d7d9027c3..de70af58dd 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -12,6 +12,12 @@ from core.workflow.nodes.code.entities import CodeNodeData from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus +from .exc import ( + CodeNodeError, + DepthLimitError, + OutputValidationError, +) + class CodeNode(BaseNode[CodeNodeData]): _node_data_cls = CodeNodeData @@ -60,7 +66,7 @@ class CodeNode(BaseNode[CodeNodeData]): # Transform result result = self._transform_result(result, self.node_data.outputs) - except (CodeExecutionError, ValueError) as e: + except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) @@ -76,10 +82,10 @@ class CodeNode(BaseNode[CodeNodeData]): if value is None: return None else: - raise ValueError(f"Output variable `{variable}` must be a string") + raise OutputValidationError(f"Output variable `{variable}` must be a string") if len(value) > dify_config.CODE_MAX_STRING_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{variable}` must be" f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" ) @@ -97,10 +103,10 @@ class CodeNode(BaseNode[CodeNodeData]): if value is None: return None else: - raise ValueError(f"Output variable `{variable}` must be a number") + raise OutputValidationError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: - raise ValueError( + raise OutputValidationError( f"Output variable `{variable}` is out of range," f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." ) @@ -108,7 +114,7 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(value, float): # raise error if precision is too high if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: - raise ValueError( + raise OutputValidationError( f"Output variable `{variable}` has too high precision," f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." ) @@ -125,7 +131,7 @@ class CodeNode(BaseNode[CodeNodeData]): :return: """ if depth > dify_config.CODE_MAX_DEPTH: - raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") + raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") transformed_result = {} if output_schema is None: @@ -177,14 +183,14 @@ class CodeNode(BaseNode[CodeNodeData]): depth=depth + 1, ) else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}.{output_name} is not a valid array." f" make sure all elements are of the same type." 
) elif output_value is None: pass else: - raise ValueError(f"Output {prefix}.{output_name} is not a valid type.") + raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.") return result @@ -192,7 +198,7 @@ class CodeNode(BaseNode[CodeNodeData]): for output_name, output_config in output_schema.items(): dot = "." if prefix else "" if output_name not in result: - raise ValueError(f"Output {prefix}{dot}{output_name} is missing.") + raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.") if output_config.type == "object": # check if output is object @@ -200,7 +206,7 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(result.get(output_name), type(None)): transformed_result[output_name] = None else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name} is not an object," f" got {type(result.get(output_name))} instead." ) @@ -228,13 +234,13 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name} is not an array," f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." ) @@ -249,13 +255,13 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name} is not an array," f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." ) @@ -270,13 +276,13 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name} is not an array," f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." ) @@ -286,7 +292,7 @@ class CodeNode(BaseNode[CodeNodeData]): if value is None: pass else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name}[{i}] is not an object," f" got {type(value)} instead at index {i}." 
) @@ -303,13 +309,13 @@ class CodeNode(BaseNode[CodeNodeData]): for i, value in enumerate(result[output_name]) ] else: - raise ValueError(f"Output type {output_config.type} is not supported.") + raise OutputValidationError(f"Output type {output_config.type} is not supported.") parameters_validated[output_name] = True # check if all output parameters are validated if len(parameters_validated) != len(result): - raise ValueError("Not all output parameters are validated.") + raise CodeNodeError("Not all output parameters are validated.") return transformed_result diff --git a/api/core/workflow/nodes/code/exc.py b/api/core/workflow/nodes/code/exc.py new file mode 100644 index 0000000000..d6334fd554 --- /dev/null +++ b/api/core/workflow/nodes/code/exc.py @@ -0,0 +1,16 @@ +class CodeNodeError(ValueError): + """Base class for code node errors.""" + + pass + + +class OutputValidationError(CodeNodeError): + """Raised when there is an output validation error.""" + + pass + + +class DepthLimitError(CodeNodeError): + """Raised when the depth limit is reached.""" + + pass diff --git a/api/core/workflow/nodes/document_extractor/exc.py b/api/core/workflow/nodes/document_extractor/exc.py index c9d4bb8ef6..5caf00ebc5 100644 --- a/api/core/workflow/nodes/document_extractor/exc.py +++ b/api/core/workflow/nodes/document_extractor/exc.py @@ -1,4 +1,4 @@ -class DocumentExtractorError(Exception): +class DocumentExtractorError(ValueError): """Base exception for errors related to the DocumentExtractorNode.""" diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 9e09b6d29a..c90017d5e1 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,12 +5,15 @@ import json import docx import pandas as pd import pypdfium2 +import yaml +from unstructured.partition.api import partition_via_api from unstructured.partition.email import partition_email from unstructured.partition.epub import partition_epub from unstructured.partition.msg import partition_msg from unstructured.partition.ppt import partition_ppt from unstructured.partition.pptx import partition_pptx +from configs import dify_config from core.file import File, FileTransferMethod, file_manager from core.helper import ssrf_proxy from core.variables import ArrayFileSegment @@ -101,6 +104,8 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: return _extract_text_from_msg(file_content) case "application/json": return _extract_text_from_json(file_content) + case "application/x-yaml" | "text/yaml": + return _extract_text_from_yaml(file_content) case _: raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") @@ -112,6 +117,8 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) return _extract_text_from_plain_text(file_content) case ".json": return _extract_text_from_json(file_content) + case ".yaml" | ".yml": + return _extract_text_from_yaml(file_content) case ".pdf": return _extract_text_from_pdf(file_content) case ".doc" | ".docx": @@ -149,6 +156,15 @@ def _extract_text_from_json(file_content: bytes) -> str: raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e +def _extract_text_from_yaml(file_content: bytes) -> str: + """Extract the content from yaml file""" + try: + yaml_data = yaml.safe_load_all(file_content.decode("utf-8")) + return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) + except (UnicodeDecodeError, 
yaml.YAMLError) as e: + raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e + + def _extract_text_from_pdf(file_content: bytes) -> str: try: pdf_file = io.BytesIO(file_content) @@ -182,10 +198,8 @@ def _download_file_content(file: File) -> bytes: response = ssrf_proxy.get(file.remote_url) response.raise_for_status() return response.content - elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - return file_manager.download(file) else: - raise ValueError(f"Unsupported transfer method: {file.transfer_method}") + return file_manager.download(file) except Exception as e: raise FileDownloadError(f"Error downloading file: {str(e)}") from e @@ -249,7 +263,14 @@ def _extract_text_from_ppt(file_content: bytes) -> str: def _extract_text_from_pptx(file_content: bytes) -> str: try: with io.BytesIO(file_content) as file: - elements = partition_pptx(file=file) + if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY: + elements = partition_via_api( + file=file, + api_url=dify_config.UNSTRUCTURED_API_URL, + api_key=dify_config.UNSTRUCTURED_API_KEY, + ) + else: + elements = partition_pptx(file=file) return "\n".join([getattr(element, "text", "") for element in elements]) except Exception as e: raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py new file mode 100644 index 0000000000..7a5ab7dbc1 --- /dev/null +++ b/api/core/workflow/nodes/http_request/exc.py @@ -0,0 +1,18 @@ +class HttpRequestNodeError(ValueError): + """Custom error for HTTP request node.""" + + +class AuthorizationConfigError(HttpRequestNodeError): + """Raised when authorization config is missing or invalid.""" + + +class FileFetchError(HttpRequestNodeError): + """Raised when a file cannot be fetched.""" + + +class InvalidHttpMethodError(HttpRequestNodeError): + """Raised when an invalid HTTP method is used.""" + + +class ResponseSizeError(HttpRequestNodeError): + """Raised when the response size exceeds the allowed threshold.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 6872478299..d90dfcc766 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -18,6 +18,12 @@ from .entities import ( HttpRequestNodeTimeout, Response, ) +from .exc import ( + AuthorizationConfigError, + FileFetchError, + InvalidHttpMethodError, + ResponseSizeError, +) BODY_TYPE_TO_CONTENT_TYPE = { "json": "application/json", @@ -51,7 +57,7 @@ class Executor: # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": if node_data.authorization.config is None: - raise ValueError("authorization config is required") + raise AuthorizationConfigError("authorization config is required") node_data.authorization.config.api_key = variable_pool.convert_template( node_data.authorization.config.api_key ).text @@ -82,8 +88,10 @@ class Executor: self.url = self.variable_pool.convert_template(self.node_data.url).text def _init_params(self): - params = self.variable_pool.convert_template(self.node_data.params).text - self.params = _plain_text_to_dict(params) + params = _plain_text_to_dict(self.node_data.params) + for key in params: + params[key] = self.variable_pool.convert_template(params[key]).text + self.params = params def _init_headers(self): headers = 
self.variable_pool.convert_template(self.node_data.headers).text @@ -116,7 +124,7 @@ class Executor: file_selector = data[0].file file_variable = self.variable_pool.get_file(file_selector) if file_variable is None: - raise ValueError(f"cannot fetch file with selector {file_selector}") + raise FileFetchError(f"cannot fetch file with selector {file_selector}") file = file_variable.value self.content = file_manager.download(file) case "x-www-form-urlencoded": @@ -155,12 +163,12 @@ class Executor: headers = deepcopy(self.headers) or {} if self.auth.type == "api-key": if self.auth.config is None: - raise ValueError("self.authorization config is required") + raise AuthorizationConfigError("self.authorization config is required") if authorization.config is None: - raise ValueError("authorization config is required") + raise AuthorizationConfigError("authorization config is required") if self.auth.config.api_key is None: - raise ValueError("api_key is required") + raise AuthorizationConfigError("api_key is required") if not authorization.config.header: authorization.config.header = "Authorization" @@ -183,7 +191,7 @@ class Executor: else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE ) if executor_response.size > threshold_size: - raise ValueError( + raise ResponseSizeError( f'{"File" if executor_response.is_file else "Text"} size is too large,' f' max size is {threshold_size / 1024 / 1024:.2f} MB,' f' but current size is {executor_response.readable_size}.' @@ -196,7 +204,7 @@ class Executor: do http request depending on api bundle """ if self.method not in {"get", "head", "post", "put", "delete", "patch"}: - raise ValueError(f"Invalid http method {self.method}") + raise InvalidHttpMethodError(f"Invalid http method {self.method}") request_args = { "url": self.url, diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index a037bee665..61c661e587 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -20,6 +20,7 @@ from .entities import ( HttpRequestNodeTimeout, Response, ) +from .exc import HttpRequestNodeError HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -77,7 +78,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): "request": http_executor.to_log(), }, ) - except Exception as e: + except HttpRequestNodeError as e: logger.warning(f"http request node {self.node_id} failed to run: {e}") return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 4afc870e50..ebcb6f82fb 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Optional from pydantic import Field @@ -5,6 +6,12 @@ from pydantic import Field from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData +class ErrorHandleMode(str, Enum): + TERMINATED = "terminated" + CONTINUE_ON_ERROR = "continue-on-error" + REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" + + class IterationNodeData(BaseIterationNodeData): """ Iteration Node Data. 
@@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData): parent_loop_id: Optional[str] = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector + is_parallel: bool = False # open the parallel mode or not + parallel_nums: int = 10 # the numbers of parallel + error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error class IterationStartNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/iteration/exc.py b/api/core/workflow/nodes/iteration/exc.py new file mode 100644 index 0000000000..d9947e09bc --- /dev/null +++ b/api/core/workflow/nodes/iteration/exc.py @@ -0,0 +1,22 @@ +class IterationNodeError(ValueError): + """Base class for iteration node errors.""" + + +class IteratorVariableNotFoundError(IterationNodeError): + """Raised when the iterator variable is not found.""" + + +class InvalidIteratorValueError(IterationNodeError): + """Raised when the iterator value is invalid.""" + + +class StartNodeIdNotFoundError(IterationNodeError): + """Raised when the start node ID is not found.""" + + +class IterationGraphNotFoundError(IterationNodeError): + """Raised when the iteration graph is not found.""" + + +class IterationIndexNotFoundError(IterationNodeError): + """Raised when the iteration index is not found.""" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index af79da9215..e1d2b88360 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,12 +1,20 @@ import logging +import uuid from collections.abc import Generator, Mapping, Sequence +from concurrent.futures import Future, wait from datetime import datetime, timezone -from typing import Any, cast +from queue import Empty, Queue +from typing import TYPE_CHECKING, Any, Optional, cast + +from flask import Flask, current_app from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.variables import IntegerSegment -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import ( + NodeRunMetadataKey, + NodeRunResult, +) +from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, BaseNodeEvent, @@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import ( IterationRunNextEvent, IterationRunStartedEvent, IterationRunSucceededEvent, + NodeInIterationFailedEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) @@ -24,9 +35,20 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from models.workflow import WorkflowNodeExecutionStatus +from .exc import ( + InvalidIteratorValueError, + IterationGraphNotFoundError, + IterationIndexNotFoundError, + IterationNodeError, + IteratorVariableNotFoundError, + StartNodeIdNotFoundError, +) + +if TYPE_CHECKING: + from core.workflow.graph_engine.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -38,6 +60,17 @@ class 
IterationNode(BaseNode[IterationNodeData]): _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + return { + "type": "iteration", + "config": { + "is_parallel": False, + "parallel_nums": 10, + "error_handle_mode": ErrorHandleMode.TERMINATED.value, + }, + } + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. @@ -45,7 +78,7 @@ class IterationNode(BaseNode[IterationNodeData]): iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) if not iterator_list_segment: - raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") + raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found") if len(iterator_list_segment.value) == 0: yield RunCompletedEvent( @@ -59,14 +92,14 @@ class IterationNode(BaseNode[IterationNodeData]): iterator_list_value = iterator_list_segment.to_object() if not isinstance(iterator_list_value, list): - raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") + raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") inputs = {"iterator_selector": iterator_list_value} graph_config = self.graph_config if not self.node_data.start_node_id: - raise ValueError(f"field start_node_id in iteration {self.node_id} not found") + raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") root_node_id = self.node_data.start_node_id @@ -74,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) if not iteration_graph: - raise ValueError("iteration graph not found") + raise IterationGraphNotFoundError("iteration graph not found") variable_pool = self.graph_runtime_state.variable_pool @@ -83,7 +116,7 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine - from core.workflow.graph_engine.graph_engine import GraphEngine + from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool graph_engine = GraphEngine( tenant_id=self.tenant_id, @@ -123,108 +156,64 @@ class IterationNode(BaseNode[IterationNodeData]): index=0, pre_iteration_output=None, ) - outputs: list[Any] = [] try: - for _ in range(len(iterator_list_value)): - # run workflow - rst = graph_engine.run() - for event in rst: - if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: - event.in_iteration_id = self.node_id - - if ( - isinstance(event, BaseNodeEvent) - and event.node_type == NodeType.ITERATION_START - and not isinstance(event, NodeRunStreamChunkEvent) - ): + if self.node_data.is_parallel: + futures: list[Future] = [] + q = Queue() + thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) + for index, item in enumerate(iterator_list_value): + future: Future = thread_pool.submit( + self._run_single_iter_parallel, + current_app._get_current_object(), + q, + iterator_list_value, + inputs, + outputs, + start_at, + graph_engine, + iteration_graph, + index, + item, + ) + future.add_done_callback(thread_pool.task_done_callback) + futures.append(future) + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break + if isinstance(event, IterationRunNextEvent): + 
succeeded_count += 1 + if succeeded_count == len(futures): + q.put(None) + yield event + if isinstance(event, RunCompletedEvent): + q.put(None) + for f in futures: + if not f.done(): + f.cancel() + yield event + if isinstance(event, IterationRunFailedEvent): + q.put(None) + yield event + except Empty: continue - if isinstance(event, NodeRunSucceededEvent): - if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} - - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(index_variable, IntegerSegment): - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Invalid index variable type: {type(index_variable)}", - ) - ) - return - metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value - event.route_node_state.node_run_result.metadata = metadata - - yield event - elif isinstance(event, BaseGraphEvent): - if isinstance(event, GraphRunFailedEvent): - # iteration run failed - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, - ) - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=event.error, - ) - ) - return - else: - event = cast(InNodeEvent, event) - yield event - - # append to iteration output variable list - current_iteration_output_variable = variable_pool.get(self.node_data.output_selector) - if current_iteration_output_variable is None: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Iteration output variable {self.node_data.output_selector} not found", - ) + # wait all threads + wait(futures) + else: + for _ in range(len(iterator_list_value)): + yield from self._run_single_iter( + iterator_list_value, + variable_pool, + inputs, + outputs, + start_at, + graph_engine, + iteration_graph, ) - return - current_iteration_output = current_iteration_output_variable.to_object() - outputs.append(current_iteration_output) - - # remove all nodes outputs from variable pool - for node_id in iteration_graph.node_ids: - variable_pool.remove([node_id]) - - # move to next iteration - current_index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(current_index_variable, IntegerSegment): - raise ValueError(f"iteration {self.node_id} current index not found") - - next_index = current_index_variable.value + 1 - variable_pool.add([self.node_id, "index"], next_index) - - if next_index < len(iterator_list_value): - variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - - yield IterationRunNextEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - index=next_index, - pre_iteration_output=jsonable_encoder(current_iteration_output), - ) - yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -242,9 +231,9 @@ class IterationNode(BaseNode[IterationNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, 
outputs={"output": jsonable_encoder(outputs)} ) ) - except Exception as e: + except IterationNodeError as e: # iteration run failed - logger.exception("Iteration run failed") + logger.warning("Iteration run failed") yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -292,7 +281,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) if not iteration_graph: - raise ValueError("iteration graph not found") + raise IterationGraphNotFoundError("iteration graph not found") for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): if sub_node_config.get("data", {}).get("iteration_id") != node_id: @@ -330,3 +319,231 @@ class IterationNode(BaseNode[IterationNodeData]): } return variable_mapping + + def _handle_event_metadata( + self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str + ) -> NodeRunStartedEvent | BaseNodeEvent: + """ + add iteration metadata to event. + """ + if not isinstance(event, BaseNodeEvent): + return event + if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): + event.parallel_mode_run_id = parallel_mode_run_id + return event + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} + + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + if self.node_data.is_parallel: + metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id + else: + metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + event.route_node_state.node_run_result.metadata = metadata + return event + + def _run_single_iter( + self, + iterator_list_value: list[str], + variable_pool: VariablePool, + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + parallel_mode_run_id: Optional[str] = None, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: + """ + run single iteration + """ + try: + rst = graph_engine.run() + # get current iteration index + current_index = variable_pool.get([self.node_id, "index"]).value + next_index = int(current_index) + 1 + + if current_index is None: + raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id + + if ( + isinstance(event, BaseNodeEvent) + and event.node_type == NodeType.ITERATION_START + and not isinstance(event, NodeRunStreamChunkEvent) + ): + continue + + if isinstance(event, NodeRunSucceededEvent): + yield self._handle_event_metadata(event, current_index, parallel_mode_run_id) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + if self.node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + 
iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return + else: + event = cast(InNodeEvent, event) + metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) + if isinstance(event, NodeRunFailedEvent): + if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + outputs.insert(current_index, None) + variable_pool.add([self.node_id, "index"], next_index) + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield metadata_event + + current_iteration_output = variable_pool.get(self.node_data.output_selector).value + outputs.insert(current_index, current_iteration_output) + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # move to next iteration + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, + ) + + except IterationNodeError as e: + logger.warning(f"Iteration run failed:{str(e)}") + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": 
graph_engine.graph_runtime_state.total_tokens}, + error=str(e), + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + + def _run_single_iter_parallel( + self, + flask_app: Flask, + q: Queue, + iterator_list_value: list[str], + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + index: int, + item: Any, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: + """ + run single iteration in parallel mode + """ + with flask_app.app_context(): + parallel_mode_run_id = uuid.uuid4().hex + graph_engine_copy = graph_engine.create_copy() + variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool + variable_pool_copy.add([self.node_id, "index"], index) + variable_pool_copy.add([self.node_id, "item"], item) + for event in self._run_single_iter( + iterator_list_value=iterator_list_value, + variable_pool=variable_pool_copy, + inputs=inputs, + outputs=outputs, + start_at=start_at, + graph_engine=graph_engine_copy, + iteration_graph=iteration_graph, + parallel_mode_run_id=parallel_mode_run_id, + ): + q.put(event) diff --git a/api/core/workflow/nodes/knowledge_retrieval/exc.py b/api/core/workflow/nodes/knowledge_retrieval/exc.py new file mode 100644 index 0000000000..0c3b6e86fa --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/exc.py @@ -0,0 +1,18 @@ +class KnowledgeRetrievalNodeError(ValueError): + """Base class for KnowledgeRetrievalNode errors.""" + + +class ModelNotExistError(KnowledgeRetrievalNodeError): + """Raised when the model does not exist.""" + + +class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError): + """Raised when the model credentials are not initialized.""" + + +class ModelNotSupportedError(KnowledgeRetrievalNodeError): + """Raised when the model is not supported.""" + + +class ModelQuotaExceededError(KnowledgeRetrievalNodeError): + """Raised when the model provider quota is exceeded.""" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2a5795a3ed..8c5a9b5ecb 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,7 +8,6 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -18,11 +17,19 @@ from core.variables import StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus +from .entities import KnowledgeRetrievalNodeData +from .exc import ( + KnowledgeRetrievalNodeError, + 
ModelCredentialsNotInitializedError, + ModelNotExistError, + ModelNotSupportedError, + ModelQuotaExceededError, +) + logger = logging.getLogger(__name__) default_retrieval_model = { @@ -61,8 +68,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs ) - except Exception as e: - logger.exception("Error when running knowledge retrieval node") + except KnowledgeRetrievalNodeError as e: + logger.warning("Error when running knowledge retrieval node") return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: @@ -295,14 +302,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): ) if provider_model is None: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.") elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.") elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.") # model config completion_params = node_data.single_retrieval_config.model.completion_params @@ -314,12 +321,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): # get model mode model_mode = node_data.single_retrieval_config.model.mode if not model_mode: - raise ValueError("LLM mode is required.") + raise ModelNotExistError("LLM mode is required.") model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, diff --git a/api/core/workflow/nodes/list_operator/exc.py b/api/core/workflow/nodes/list_operator/exc.py new file mode 100644 index 0000000000..f88aa0be29 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/exc.py @@ -0,0 +1,16 @@ +class ListOperatorError(ValueError): + """Base class for all ListOperator errors.""" + + pass + + +class InvalidFilterValueError(ListOperatorError): + pass + + +class InvalidKeyError(ListOperatorError): + pass + + +class InvalidConditionError(ListOperatorError): + pass diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index d7e4c64313..49e7ca85fd 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Literal +from typing import Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus from .entities import ListOperatorNodeData +from .exc 
import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError class ListOperatorNode(BaseNode[ListOperatorNodeData]): @@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) - if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): + if not variable.value: + inputs = {"variable": []} + process_data = {"variable": []} + outputs = {"result": [], "first_record": None, "last_record": None} + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): error_message = ( f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " "or ArrayStringSegment" @@ -36,70 +47,98 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) if isinstance(variable, ArrayFileSegment): + inputs = {"variable": [item.to_dict() for item in variable.value]} process_data["variable"] = [item.to_dict() for item in variable.value] else: + inputs = {"variable": variable.value} process_data["variable"] = variable.value - # Filter - if self.node_data.filter_by.enabled: - for condition in self.node_data.filter_by.conditions: - if isinstance(variable, ArrayStringSegment): - if not isinstance(condition.value, str): - raise ValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - if not isinstance(condition.value, str): - raise ValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayFileSegment): - if isinstance(condition.value, str): - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - else: - value = condition.value - filter_func = _get_file_filter_func( - key=condition.key, - condition=condition.comparison_operator, - value=value, - ) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) + try: + # Filter + if self.node_data.filter_by.enabled: + variable = self._apply_filter(variable) - # Order - if self.node_data.order_by.enabled: + # Order + if self.node_data.order_by.enabled: + variable = self._apply_order(variable) + + # Slice + if self.node_data.limit.enabled: + variable = self._apply_slice(variable) + + outputs = { + "result": variable.value, + "first_record": variable.value[0] if variable.value else None, + "last_record": variable.value[-1] if variable.value else None, + } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + except ListOperatorError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=inputs, + 
process_data=process_data, + outputs=outputs, + ) + + def _apply_filter( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + for condition in self.node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self.node_data.order_by.value, array=variable.value) + if not isinstance(condition.value, str): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self.node_data.order_by.value, array=variable.value) + if not isinstance(condition.value, str): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayFileSegment): - result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + if isinstance(condition.value, str): + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + else: + value = condition.value + filter_func = _get_file_filter_func( + key=condition.key, + condition=condition.comparison_operator, + value=value, ) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) + return variable - # Slice - if self.node_data.limit.enabled: - result = variable.value[: self.node_data.limit.size] + def _apply_order( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + if isinstance(variable, ArrayStringSegment): + result = _order_string(order=self.node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + result = _order_number(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + result = _order_file( + order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + ) + variable = variable.model_copy(update={"value": result}) + return variable - outputs = { - "result": variable.value, - "first_record": variable.value[0] if variable.value else None, - "last_record": variable.value[-1] if variable.value else None, - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) + def _apply_slice( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + result = variable.value[: self.node_data.limit.size] + return variable.model_copy(update={"value": result}) def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: @@ -107,7 +146,7 @@ def 
_get_file_extract_number_func(*, key: str) -> Callable[[File], int]: case "size": return lambda x: x.size case _: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: @@ -118,14 +157,14 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: return lambda x: x.type case "extension": return lambda x: x.extension or "" - case "mimetype": + case "mime_type": return lambda x: x.mime_type or "" case "transfer_method": return lambda x: x.transfer_method case "url": return lambda x: x.remote_url or "" case _: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: @@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo case "not empty": return lambda x: x != "" case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: @@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab case "not in": return lambda x: not _in(value)(x) case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: @@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ case "≥": return _ge(value) case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: @@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str extract_func = _get_file_extract_number_func(key=key) return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) else: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _contains(value: str): @@ -256,4 +295,4 @@ def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Seq extract_func = _get_file_extract_number_func(key=order_by) return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") else: - raise ValueError(f"Invalid order key: {order_by}") + raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py new file mode 100644 index 0000000000..f858be2515 --- /dev/null +++ b/api/core/workflow/nodes/llm/exc.py @@ -0,0 +1,26 @@ +class LLMNodeError(ValueError): + """Base class for LLM Node errors.""" + + +class VariableNotFoundError(LLMNodeError): + """Raised when a required variable is not found.""" + + +class InvalidContextStructureError(LLMNodeError): + """Raised when the context structure is invalid.""" + + +class InvalidVariableTypeError(LLMNodeError): + """Raised when the variable type is invalid.""" + + +class ModelNotExistError(LLMNodeError): + """Raised when the specified model does not exist.""" + + +class LLMModeRequiredError(LLMNodeError): + """Raised when LLM mode is required but not provided.""" + + +class NoPromptFoundError(LLMNodeError): + """Raised when no prompt is found in the LLM 
configuration.""" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 472587cb03..47b0e25d9c 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -56,6 +56,15 @@ from .entities import ( LLMNodeData, ModelConfig, ) +from .exc import ( + InvalidContextStructureError, + InvalidVariableTypeError, + LLMModeRequiredError, + LLMNodeError, + ModelNotExistError, + NoPromptFoundError, + VariableNotFoundError, +) if TYPE_CHECKING: from core.file.models import File @@ -103,7 +112,7 @@ class LLMNode(BaseNode[LLMNodeData]): yield event if context: - node_inputs["#context#"] = context # type: ignore + node_inputs["#context#"] = context # fetch model config model_instance, model_config = self._fetch_model_config(self.node_data.model) @@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]): if self.node_data.memory: query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) if not query: - raise ValueError("Query not found") + raise VariableNotFoundError("Query not found") query = query.text else: query = None @@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]): usage = event.usage finish_reason = event.finish_reason break - except Exception as e: + except LLMNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") def parse_dict(input_dict: Mapping[str, Any]) -> str: """ @@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): inputs[variable_selector.variable] = "" inputs[variable_selector.variable] = variable.to_object() @@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in query_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): continue inputs[variable_selector.variable] = variable.to_object() @@ -349,15 +358,13 @@ class LLMNode(BaseNode[LLMNodeData]): variable = self.graph_runtime_state.variable_pool.get(selector) if variable is None: return [] - if isinstance(variable, FileSegment): + elif isinstance(variable, FileSegment): return [variable.value] - if isinstance(variable, ArrayFileSegment): + elif isinstance(variable, ArrayFileSegment): return variable.value - # FIXME: Temporary fix for empty array, - # all variables added to variable pool should be a Segment instance. 
- if isinstance(variable, ArrayAnySegment) and len(variable.value) == 0: + elif isinstance(variable, NoneSegment | ArrayAnySegment): return [] - raise ValueError(f"Invalid variable type: {type(variable)}") + raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: @@ -378,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]): context_str += item + "\n" else: if "content" not in item: - raise ValueError(f"Invalid context structure: {item}") + raise InvalidContextStructureError(f"Invalid context structure: {item}") context_str += item["content"] + "\n" @@ -443,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]): ) if provider_model is None: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") if provider_model.status == ModelStatus.NO_CONFIGURE: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") @@ -462,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]): # get model mode model_mode = node_data_model.mode if not model_mode: - raise ValueError("LLM mode is required.") + raise LLMModeRequiredError("LLM mode is required.") model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, @@ -566,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]): filtered_prompt_messages.append(prompt_message) if not filtered_prompt_messages: - raise ValueError( + raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." 
) @@ -638,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() else: - raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") + raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") variable_mapping = {} for variable_selector in variable_selectors: diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py new file mode 100644 index 0000000000..6511aba185 --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -0,0 +1,50 @@ +class ParameterExtractorNodeError(ValueError): + """Base error for ParameterExtractorNode.""" + + +class InvalidModelTypeError(ParameterExtractorNodeError): + """Raised when the model is not a Large Language Model.""" + + +class ModelSchemaNotFoundError(ParameterExtractorNodeError): + """Raised when the model schema is not found.""" + + +class InvalidInvokeResultError(ParameterExtractorNodeError): + """Raised when the invoke result is invalid.""" + + +class InvalidTextContentTypeError(ParameterExtractorNodeError): + """Raised when the text content type is invalid.""" + + +class InvalidNumberOfParametersError(ParameterExtractorNodeError): + """Raised when the number of parameters is invalid.""" + + +class RequiredParameterMissingError(ParameterExtractorNodeError): + """Raised when a required parameter is missing.""" + + +class InvalidSelectValueError(ParameterExtractorNodeError): + """Raised when a select value is invalid.""" + + +class InvalidNumberValueError(ParameterExtractorNodeError): + """Raised when a number value is invalid.""" + + +class InvalidBoolValueError(ParameterExtractorNodeError): + """Raised when a bool value is invalid.""" + + +class InvalidStringValueError(ParameterExtractorNodeError): + """Raised when a string value is invalid.""" + + +class InvalidArrayValueError(ParameterExtractorNodeError): + """Raised when an array value is invalid.""" + + +class InvalidModelModeError(ParameterExtractorNodeError): + """Raised when the model mode is invalid.""" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 49546e9356..b64bde8ac5 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -32,6 +32,21 @@ from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus from .entities import ParameterExtractorNodeData +from .exc import ( + InvalidArrayValueError, + InvalidBoolValueError, + InvalidInvokeResultError, + InvalidModelModeError, + InvalidModelTypeError, + InvalidNumberOfParametersError, + InvalidNumberValueError, + InvalidSelectValueError, + InvalidStringValueError, + InvalidTextContentTypeError, + ModelSchemaNotFoundError, + ParameterExtractorNodeError, + RequiredParameterMissingError, +) from .prompts import ( CHAT_EXAMPLE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, @@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language Model") llm_model = 
model_instance.model_type_instance model_schema = llm_model.get_model_schema( @@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode): credentials=model_config.credentials, ) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") # fetch memory memory = self._fetch_memory( @@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode): process_data["usage"] = jsonable_encoder(usage) process_data["tool_call"] = jsonable_encoder(tool_call) process_data["llm_text"] = text - except Exception as e: + except ParameterExtractorNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, @@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode): try: result = self._validate_result(data=node_data, result=result or {}) - except Exception as e: + except ParameterExtractorNodeError as e: error = str(e) # transform result into standard format @@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode): # handle invoke result if not isinstance(invoke_result, LLMResult): - raise ValueError(f"Invalid invoke result: {invoke_result}") + raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}") text = invoke_result.message.content if not isinstance(text, str): - raise ValueError(f"Invalid text content type: {type(text)}. Expected str.") + raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") usage = invoke_result.usage tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None @@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode): files=files, ) else: - raise ValueError(f"Invalid model mode: {model_mode}") + raise InvalidModelModeError(f"Invalid model mode: {model_mode}") def _generate_prompt_engineering_completion_prompt( self, @@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode): Validate result. 
""" if len(data.parameters) != len(result): - raise ValueError("Invalid number of parameters") + raise InvalidNumberOfParametersError("Invalid number of parameters") for parameter in data.parameters: if parameter.required and parameter.name not in result: - raise ValueError(f"Parameter {parameter.name} is required") + raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: - raise ValueError(f"Invalid `select` value for parameter {parameter.name}") + raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): - raise ValueError(f"Invalid `number` value for parameter {parameter.name}") + raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}") if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): - raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") + raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}") if parameter.type == "string" and not isinstance(result.get(parameter.name), str): - raise ValueError(f"Invalid `string` value for parameter {parameter.name}") + raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}") if parameter.type.startswith("array"): parameters = result.get(parameter.name) if not isinstance(parameters, list): - raise ValueError(f"Invalid `array` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] for item in parameters: if nested_type == "number" and not isinstance(item, int | float): - raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}") if nested_type == "string" and not isinstance(item, str): - raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}") if nested_type == "object" and not isinstance(item, dict): - raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: @@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode): user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {model_mode} not support.") def _get_prompt_engineering_prompt_template( self, @@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode): .replace("}γγγ", "") ) else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {model_mode} not support.") def _calculate_rest_token( self, @@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language 
Model") llm_model = model_instance.model_type_instance model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) diff --git a/api/core/workflow/nodes/question_classifier/exc.py b/api/core/workflow/nodes/question_classifier/exc.py new file mode 100644 index 0000000000..2c6354e2a7 --- /dev/null +++ b/api/core/workflow/nodes/question_classifier/exc.py @@ -0,0 +1,6 @@ +class QuestionClassifierNodeError(ValueError): + """Base class for QuestionClassifierNode errors.""" + + +class InvalidModelTypeError(QuestionClassifierNodeError): + """Raised when the model is not a Large Language Model.""" diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index ee160e7c69..0489020e5e 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.llm_generator.output_parser.errors import OutputParserError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole @@ -24,6 +25,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown from models.workflow import WorkflowNodeExecutionStatus from .entities import QuestionClassifierNodeData +from .exc import InvalidModelTypeError from .template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, @@ -124,7 +126,7 @@ class QuestionClassifierNode(LLMNode): category_name = classes_map[category_id_result] category_id = category_id_result - except Exception: + except OutputParserError: logging.error(f"Failed to parse result text: {result_text}") try: process_data = { @@ -309,4 +311,4 @@ class QuestionClassifierNode(LLMNode): ) else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelTypeError(f"Model mode {model_mode} not support.") diff --git a/api/core/workflow/nodes/tool/exc.py b/api/core/workflow/nodes/tool/exc.py new file mode 100644 index 0000000000..7212e8bfc0 --- /dev/null +++ b/api/core/workflow/nodes/tool/exc.py @@ -0,0 +1,16 @@ +class ToolNodeError(ValueError): + """Base exception for tool node errors.""" + + pass + + +class ToolParameterError(ToolNodeError): + """Exception raised for errors in tool parameters.""" + + pass + + +class ToolFileError(ToolNodeError): + """Exception raised for errors related to tool files.""" + + pass diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index df22130d69..42e870c46c 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.models import File, 
FileTransferMethod, FileType +from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager @@ -15,12 +15,18 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models import ToolFile from models.workflow import WorkflowNodeExecutionStatus +from .entities import ToolNodeData +from .exc import ( + ToolFileError, + ToolNodeError, + ToolParameterError, +) + class ToolNode(BaseNode[ToolNodeData]): """ @@ -42,7 +48,7 @@ class ToolNode(BaseNode[ToolNodeData]): tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) - except Exception as e: + except ToolNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, @@ -53,7 +59,7 @@ class ToolNode(BaseNode[ToolNodeData]): ) # get parameters - tool_parameters = tool_runtime.get_runtime_parameters() or [] + tool_parameters = tool_runtime.parameters or [] parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, @@ -75,7 +81,7 @@ class ToolNode(BaseNode[ToolNodeData]): workflow_call_depth=self.workflow_call_depth, thread_pool_id=self.thread_pool_id, ) - except Exception as e: + except ToolNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, @@ -133,13 +139,13 @@ class ToolNode(BaseNode[ToolNodeData]): if tool_input.type == "variable": variable = variable_pool.get(tool_input.value) if variable is None: - raise ValueError(f"variable {tool_input.value} not exists") + raise ToolParameterError(f"Variable {tool_input.value} does not exist") parameter_value = variable.value elif tool_input.type in {"mixed", "constant"}: segment_group = variable_pool.convert_template(str(tool_input.value)) parameter_value = segment_group.log if for_log else segment_group.text else: - raise ValueError(f"unknown tool input type '{tool_input.type}'") + raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") result[parameter_name] = parameter_value return result @@ -181,7 +187,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") result.append( File( @@ -203,7 +209,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") result.append( File( tenant_id=self.tenant_id, @@ -224,7 +230,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") 
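For context on the _generate_parameters hunk above: each tool input is typed as "variable", "mixed", or "constant", and both failure paths now raise ToolParameterError instead of a bare ValueError. A rough standalone sketch of that dispatch, using a plain dict in place of the real VariablePool and a simplified stand-in for its template conversion (both are assumptions for illustration, not the actual API):

class ToolParameterError(ValueError):
    """Raised for errors in tool parameters."""


def resolve_tool_input(input_type: str, value, pool: dict):
    """Resolve one tool parameter, mirroring the node's three-way dispatch."""
    if input_type == "variable":
        # value is a variable selector that must already exist in the pool
        if value not in pool:
            raise ToolParameterError(f"Variable {value} does not exist")
        return pool[value]
    elif input_type in {"mixed", "constant"}:
        # mixed templates interpolate {{#selector#}} references; constants pass through
        text = str(value)
        for name, val in pool.items():
            text = text.replace("{{#" + name + "#}}", str(val))
        return text
    else:
        raise ToolParameterError(f"Unknown tool input type '{input_type}'")


pool = {"sys.query": "hello"}
assert resolve_tool_input("variable", "sys.query", pool) == "hello"
assert resolve_tool_input("mixed", "q={{#sys.query#}}", pool) == "q=hello"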
if "." in url: extension = "." + url.split("/")[-1].split(".")[1] else: diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 0fa832f420..56b1d6bd28 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -1,10 +1,8 @@ import logging import os import sys -from datetime import datetime from logging.handlers import RotatingFileHandler -import pytz from flask import Flask from configs import dify_config @@ -32,10 +30,16 @@ def init_app(app: Flask): handlers=log_handlers, force=True, ) - log_tz = dify_config.LOG_TZ if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + for handler in logging.root.handlers: - handler.formatter.converter = lambda seconds: ( - datetime.fromtimestamp(seconds, tz=pytz.UTC).astimezone(log_tz).timetuple() - ) + handler.formatter.converter = time_converter diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index ead7b9a8b3..1066dc8862 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -160,7 +160,7 @@ def _build_from_local_file( tenant_id=tenant_id, type=file_type, transfer_method=transfer_method, - remote_url=None, + remote_url=row.source_url, related_id=mapping.get("upload_file_id"), _extra_config=config, size=row.size, diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index a758f9981f..0191102b90 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -10,6 +10,7 @@ from core.variables import ( ArrayNumberVariable, ArrayObjectSegment, ArrayObjectVariable, + ArraySegment, ArrayStringSegment, ArrayStringVariable, FileSegment, @@ -79,7 +80,7 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, list): items = [build_segment(item) for item in value] types = {item.value_type for item in items} - if len(types) != 1: + if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): return ArrayAnySegment(value=value) match types.pop(): case SegmentType.STRING: @@ -90,6 +91,8 @@ def build_segment(value: Any, /) -> Segment: return ArrayObjectSegment(value=value) case SegmentType.FILE: return ArrayFileSegment(value=value) + case SegmentType.NONE: + return ArrayAnySegment(value=value) case _: raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}") diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index bf1c491a05..2eb19c2667 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -121,6 +121,7 @@ conversation_fields = { "from_account_name": fields.String, "read_at": TimestampField, "created_at": TimestampField, + "updated_at": TimestampField, "annotation": fields.Nested(annotation_fields, allow_null=True), "model_config": fields.Nested(simple_model_config_fields), "user_feedback_stats": fields.Nested(feedback_stat_fields), @@ -182,6 +183,7 @@ conversation_detail_fields = { "from_end_user_id": fields.String, "from_account_id": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, "annotated": fields.Boolean, "introduction": fields.String, "model_config": fields.Nested(model_config_fields), @@ -197,6 +199,7 @@ simple_conversation_fields = { "status": fields.String, "introduction": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, } 
conversation_infinite_scroll_pagination_fields = { diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 9ff1111b74..afaacc0568 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -8,6 +8,7 @@ upload_config_fields = { "image_file_size_limit": fields.Integer, "video_file_size_limit": fields.Integer, "audio_file_size_limit": fields.Integer, + "workflow_file_upload_limit": fields.Integer, } file_fields = { @@ -24,3 +25,15 @@ remote_file_info_fields = { "file_type": fields.String(attribute="file_type"), "file_length": fields.Integer(attribute="file_length"), } + + +file_fields_with_signed_url = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "url": fields.String, + "mime_type": fields.String, + "created_by": fields.String, + "created_at": TimestampField, +} diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 9131408817..41c5d20c4b 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -9,6 +9,7 @@ def parse_json_markdown(json_string: str) -> dict: starts = ["```json", "```", "``", "`", "{"] ends = ["```", "``", "`", "}"] end_index = -1 + start_index = 0 for s in starts: start_index = json_string.find(s) if start_index != -1: @@ -24,7 +25,6 @@ def parse_json_markdown(json_string: str) -> dict: break if start_index != -1 and end_index != -1 and start_index < end_index: extracted_content = json_string[start_index:end_index].strip() - print("content:", extracted_content, start_index, end_index) parsed = json.loads(extracted_content) else: raise Exception("Could not find JSON block in the output.") diff --git a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py index 6a7402b16a..153861a71a 100644 --- a/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py +++ b/api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py @@ -28,16 +28,12 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False), sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey') ) - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ## - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') - op.drop_table('tracing_app_configs') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py new file mode 100644 index 0000000000..a749c8bddf --- /dev/null +++ b/api/migrations/versions/2024_11_01_0434-d3f6769a94a3_add_upload_files_source_url.py @@ -0,0 +1,31 @@ +"""Add upload_files.source_url + +Revision ID: d3f6769a94a3 +Revises: 43fa78bc3b7d +Create Date: 2024-11-01 04:34:23.816198 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd3f6769a94a3' +down_revision = '43fa78bc3b7d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.add_column(sa.Column('source_url', sa.String(length=255), server_default='', nullable=False)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.drop_column('source_url') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py new file mode 100644 index 0000000000..81a7978f73 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0449-93ad8c19c40b_rename_conversation_variables_index_name.py @@ -0,0 +1,52 @@ +"""rename conversation variables index name + +Revision ID: 93ad8c19c40b +Revises: d3f6769a94a3 +Create Date: 2024-11-01 04:49:53.100250 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '93ad8c19c40b' +down_revision = 'd3f6769a94a3' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + conn = op.get_bind() + if conn.dialect.name == 'postgresql': + # Rename indexes for PostgreSQL + op.execute('ALTER INDEX workflow__conversation_variables_app_id_idx RENAME TO workflow_conversation_variables_app_id_idx') + op.execute('ALTER INDEX workflow__conversation_variables_created_at_idx RENAME TO workflow_conversation_variables_created_at_idx') + else: + # For other databases, use the original drop and create method + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index('workflow__conversation_variables_app_id_idx') + batch_op.drop_index('workflow__conversation_variables_created_at_idx') + batch_op.create_index(batch_op.f('workflow_conversation_variables_app_id_idx'), ['app_id'], unique=False) + batch_op.create_index(batch_op.f('workflow_conversation_variables_created_at_idx'), ['created_at'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + conn = op.get_bind() + if conn.dialect.name == 'postgresql': + # Rename indexes back for PostgreSQL + op.execute('ALTER INDEX workflow_conversation_variables_app_id_idx RENAME TO workflow__conversation_variables_app_id_idx') + op.execute('ALTER INDEX workflow_conversation_variables_created_at_idx RENAME TO workflow__conversation_variables_created_at_idx') + else: + # For other databases, use the original drop and create method + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow_conversation_variables_created_at_idx')) + batch_op.drop_index(batch_op.f('workflow_conversation_variables_app_id_idx')) + batch_op.create_index('workflow__conversation_variables_created_at_idx', ['created_at'], unique=False) + batch_op.create_index('workflow__conversation_variables_app_id_idx', ['app_id'], unique=False) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py new file mode 100644 index 0000000000..222379a490 --- /dev/null +++ b/api/migrations/versions/2024_11_01_0540-f4d7ce70a7ca_update_upload_files_source_url.py @@ -0,0 +1,41 @@ +"""update upload_files.source_url + +Revision ID: f4d7ce70a7ca +Revises: 93ad8c19c40b +Create Date: 2024-11-01 05:40:03.531751 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'f4d7ce70a7ca' +down_revision = '93ad8c19c40b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.alter_column('source_url', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.alter_column('source_url', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + existing_nullable=False, + existing_server_default=sa.text("''::character varying")) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py new file mode 100644 index 0000000000..9a4ccf352d --- /dev/null +++ b/api/migrations/versions/2024_11_01_0622-d07474999927_update_type_of_custom_disclaimer_to_text.py @@ -0,0 +1,67 @@ +"""update type of custom_disclaimer to TEXT + +Revision ID: d07474999927 +Revises: f4d7ce70a7ca +Create Date: 2024-11-01 06:22:27.981398 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'd07474999927' +down_revision = 'f4d7ce70a7ca' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.execute("UPDATE recommended_apps SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL") + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.VARCHAR(length=255), + type_=sa.TEXT(), + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.alter_column('custom_disclaimer', + existing_type=sa.TEXT(), + type_=sa.VARCHAR(length=255), + nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py new file mode 100644 index 0000000000..117a7351cd --- /dev/null +++ b/api/migrations/versions/2024_11_01_0623-09a8d1878d9b_update_workflows_graph_features_and_.py @@ -0,0 +1,73 @@ +"""update workflows graph, features and updated_at + +Revision ID: 09a8d1878d9b +Revises: d07474999927 +Create Date: 2024-11-01 06:23:59.579186 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '09a8d1878d9b' +down_revision = 'd07474999927' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + + op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL") + op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL") + op.execute("UPDATE workflows SET features = '' WHERE features IS NULL") + + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('graph', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('features', + existing_type=sa.TEXT(), + nullable=False) + batch_op.alter_column('updated_at', + existing_type=postgresql.TIMESTAMP(), + nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('updated_at', + existing_type=postgresql.TIMESTAMP(), + nullable=True) + batch_op.alter_column('features', + existing_type=sa.TEXT(), + nullable=True) + batch_op.alter_column('graph', + existing_type=sa.TEXT(), + nullable=True) + + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.alter_column('inputs', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py index 09ef5e186c..99b7010612 100644 --- a/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py +++ b/api/migrations/versions/2a3aebbbf4bb_add_app_tracing.py @@ -22,17 +22,11 @@ def upgrade(): with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True)) - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False) - # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') - with op.batch_alter_table('apps', schema=None) as batch_op: batch_op.drop_column('tracing') diff --git a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py index 469c04338a..f87819c367 100644 --- a/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py +++ b/api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py @@ -30,30 +30,15 @@ def upgrade(): sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False), sa.PrimaryKeyConstraint('id', name='trace_app_config_pkey') ) + with op.batch_alter_table('trace_app_config', schema=None) as batch_op: batch_op.create_index('trace_app_config_app_id_idx', ['app_id'], unique=False) - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - op.create_table('tracing_app_configs', - sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False), - sa.Column('app_id', sa.UUID(), autoincrement=False, nullable=False), - sa.Column('tracing_provider', sa.VARCHAR(length=255), autoincrement=False, nullable=True), - sa.Column('tracing_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True), - sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False), - sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('now()'), autoincrement=False, nullable=False), - sa.PrimaryKeyConstraint('id', name='trace_app_config_pkey') - ) - with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op: - batch_op.create_index('trace_app_config_app_id_idx', ['app_id'], unique=False) - - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('trace_app_config_app_id_idx') - op.drop_table('trace_app_config') + # ### end Alembic commands ### diff --git a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py index 271b2490de..6f76a361d9 100644 --- a/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py +++ b/api/migrations/versions/fecff1c3da27_remove_extra_tracing_app_config_table .py @@ -20,12 +20,10 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.drop_table('tracing_app_configs') - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.drop_index('tracing_app_config_app_id_idx') - # idx_dataset_permissions_tenant_id with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: batch_op.create_index('idx_dataset_permissions_tenant_id', ['tenant_id']) + # ### end Alembic commands ### @@ -46,9 +44,7 @@ def downgrade(): sa.PrimaryKeyConstraint('id', name='tracing_app_config_pkey') ) - with op.batch_alter_table('trace_app_config', schema=None) as batch_op: - batch_op.create_index('tracing_app_config_app_id_idx', ['app_id']) - with op.batch_alter_table('dataset_permissions', schema=None) as batch_op: batch_op.drop_index('idx_dataset_permissions_tenant_id') + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 1d8bae6cfa..cd6c7674da 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -6,7 +6,6 @@ from .model import ( AppMode, Conversation, EndUser, - FileUploadConfig, InstalledApp, Message, MessageAnnotation, @@ -50,6 +49,5 @@ __all__ = [ "Tenant", "Conversation", "MessageAnnotation", - "FileUploadConfig", "ToolFile", ] diff --git a/api/models/model.py b/api/models/model.py index 3bd5886d75..d049cd373d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,14 +1,14 @@ import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from datetime import datetime from enum import Enum from typing import Any, Literal, Optional +import sqlalchemy as sa from flask import request from flask_login import UserMixin -from pydantic import BaseModel, Field from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column @@ -24,14 +24,6 @@ from .account import Account, Tenant from .types import StringUUID -class FileUploadConfig(BaseModel): - enabled: bool = Field(default=False) - allowed_file_types: Sequence[FileType] = 
Field(default_factory=list) - allowed_extensions: Sequence[str] = Field(default_factory=list) - allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = Field(default=0, gt=0, le=10) - - class DifySetup(db.Model): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -114,7 +106,7 @@ class App(db.Model): return site @property - def app_model_config(self) -> Optional["AppModelConfig"]: + def app_model_config(self): if self.app_model_config_id: return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() @@ -396,7 +388,7 @@ class AppModelConfig(db.Model): "file_upload": self.file_upload_dict, } - def from_model_config_dict(self, model_config: dict): + def from_model_config_dict(self, model_config: Mapping[str, Any]): self.opening_statement = model_config.get("opening_statement") self.suggested_questions = ( json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None @@ -483,7 +475,7 @@ class RecommendedApp(db.Model): description = db.Column(db.JSON, nullable=False) copyright = db.Column(db.String(255), nullable=False) privacy_policy = db.Column(db.String(255), nullable=False) - custom_disclaimer = db.Column(db.String(255), nullable=True) + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") category = db.Column(db.String(255), nullable=False) position = db.Column(db.Integer, nullable=False, default=0) is_listed = db.Column(db.Boolean, nullable=False, default=True) @@ -1306,7 +1298,7 @@ class Site(db.Model): privacy_policy = db.Column(db.String(255)) show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - custom_disclaimer = db.Column(db.String(255), nullable=True) + _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @@ -1317,6 +1309,16 @@ class Site(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) code = db.Column(db.String(255)) + @property + def custom_disclaimer(self): + return self._custom_disclaimer + + @custom_disclaimer.setter + def custom_disclaimer(self, value: str): + if len(value) > 512: + raise ValueError("Custom disclaimer cannot exceed 512 characters.") + self._custom_disclaimer = value + @staticmethod def generate_code(n): while True: @@ -1384,6 +1386,7 @@ class UploadFile(db.Model): used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) + source_url: Mapped[str] = mapped_column(sa.TEXT, default="") def __init__( self, @@ -1402,7 +1405,8 @@ class UploadFile(db.Model): used_by: str | None = None, used_at: datetime | None = None, hash: str | None = None, - ) -> None: + source_url: str = "", + ): self.tenant_id = tenant_id self.storage_type = storage_type self.key = key @@ -1417,6 +1421,7 @@ class UploadFile(db.Model): self.used_by = used_by self.used_at = used_at self.hash = hash + self.source_url = source_url class ApiRequest(db.Model): diff --git a/api/models/tools.py b/api/models/tools.py index 
691f3f3cb6..4040339e02 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,6 +1,7 @@ import json from typing import Optional +import sqlalchemy as sa from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, mapped_column @@ -117,7 +118,7 @@ class ApiToolProvider(db.Model): # privacy policy privacy_policy = db.Column(db.String(255), nullable=True) # custom_disclaimer - custom_disclaimer = db.Column(db.String(255), nullable=True) + custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/models/workflow.py b/api/models/workflow.py index e5fbcaf87e..4f0e9a5e03 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,9 +1,10 @@ import json from collections.abc import Mapping, Sequence -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from typing import Any, Optional, Union +import sqlalchemy as sa from sqlalchemy import func from sqlalchemy.orm import Mapped, mapped_column @@ -99,14 +100,16 @@ class Workflow(db.Model): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(db.String(255), nullable=False) version: Mapped[str] = mapped_column(db.String(255), nullable=False) - graph: Mapped[str] = mapped_column(db.Text) - _features: Mapped[str] = mapped_column("features") + graph: Mapped[str] = mapped_column(sa.Text) + _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) - updated_by: Mapped[str] = mapped_column(StringUUID) - updated_at: Mapped[datetime] = mapped_column(db.DateTime) + updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) + updated_at: Mapped[datetime] = mapped_column( + sa.DateTime, nullable=False, default=lambda: datetime.now(tz=timezone.utc), server_onupdate=func.current_timestamp() + ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" ) diff --git a/api/poetry.lock b/api/poetry.lock index 5b581b9965..6cd5e24dec 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
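One detail in the Workflow.updated_at mapping above is worth spelling out: a Python-side default= must be a callable if it is to be evaluated per INSERT. Passing datetime.now(tz=timezone.utc) directly would compute a single timestamp once, at import time, and stamp every new row with it, which is why the callable form is used. A self-contained sketch (SQLAlchemy 2.0 style; the model and table names are illustrative):

```python
from datetime import datetime, timezone

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Example(Base):
    __tablename__ = "examples"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Callable default: evaluated on every INSERT.
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, default=lambda: datetime.now(tz=timezone.utc)
    )
    # Anti-pattern: default=datetime.now(tz=timezone.utc) would be
    # evaluated once, when this module is imported.
```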
[[package]] name = "aiohappyeyeballs" @@ -932,10 +932,6 @@ files = [ {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, @@ -948,14 +944,8 @@ files = [ {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, @@ 
-966,24 +956,8 @@ files = [ {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, - {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, - {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, @@ -993,10 +967,6 @@ files = [ {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, @@ -1008,10 +978,6 @@ files = [ {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, @@ -1024,10 +990,6 @@ files = [ {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = 
"sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, @@ -1040,10 +1002,6 @@ files = [ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, @@ -2574,6 +2532,19 @@ files = [ {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, ] +[[package]] +name = "fire" +version = "0.7.0" +description = "A library for automatically generating command line interfaces." 
+optional = false +python-versions = "*" +files = [ + {file = "fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf"}, +] + +[package.dependencies] +termcolor = "*" + [[package]] name = "flasgger" version = "0.9.7.1" @@ -2739,6 +2710,19 @@ files = [ {file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"}, ] +[[package]] +name = "fontmeta" +version = "1.6.1" +description = "An Utility to get ttf/otf font metadata" +optional = false +python-versions = "*" +files = [ + {file = "fontmeta-1.6.1.tar.gz", hash = "sha256:837e5bc4da879394b41bda1428a8a480eb7c4e993799a93cfb582bab771a9c24"}, +] + +[package.dependencies] +fonttools = "*" + [[package]] name = "fonttools" version = "4.54.1" @@ -5321,6 +5305,22 @@ files = [ {file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"}, ] +[[package]] +name = "mplfonts" +version = "0.0.8" +description = "Fonts manager for matplotlib" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mplfonts-0.0.8-py3-none-any.whl", hash = "sha256:b2182e5b0baa216cf016dec19942740e5b48956415708ad2d465e03952112ec1"}, + {file = "mplfonts-0.0.8.tar.gz", hash = "sha256:0abcb2fc0605645e1e7561c6923014d856f11676899b33b4d89757843f5e7c22"}, +] + +[package.dependencies] +fire = ">=0.4.0" +fontmeta = ">=1.6.1" +matplotlib = ">=3.4" + [[package]] name = "mpmath" version = "1.3.0" @@ -8735,11 +8735,6 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, - {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -9342,6 +9337,20 @@ files = [ [package.dependencies] tencentcloud-sdk-python-common = "3.0.1257" +[[package]] +name = "termcolor" +version = "2.5.0" +description = "ANSI color formatting for output in terminal" +optional = false +python-versions = ">=3.9" +files = [ + {file = 
"termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"}, + {file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] + [[package]] name = "threadpoolctl" version = "3.5.0" @@ -10088,13 +10097,13 @@ files = [ [[package]] name = "vanna" -version = "0.7.3" +version = "0.7.5" description = "Generate SQL queries from natural language" optional = false python-versions = ">=3.9" files = [ - {file = "vanna-0.7.3-py3-none-any.whl", hash = "sha256:82ba39e5d6c503d1c8cca60835ed401d20ec3a3da98d487f529901dcb30061d6"}, - {file = "vanna-0.7.3.tar.gz", hash = "sha256:4590dd94d2fe180b4efc7a83c867b73144ef58794018910dc226857cfb703077"}, + {file = "vanna-0.7.5-py3-none-any.whl", hash = "sha256:07458c7befa49de517a8760c2d80a13147278b484c515d49a906acc88edcb835"}, + {file = "vanna-0.7.5.tar.gz", hash = "sha256:2fdffc58832898e4fc8e93c45b173424db59a22773b22ca348640161d391eacf"}, ] [package.dependencies] @@ -10115,7 +10124,7 @@ sqlparse = "*" tabulate = "*" [package.extras] -all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "chromadb", "db-dtypes", "duckdb", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "zhipuai"] +all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "boto", "boto3", "botocore", "chromadb", "db-dtypes", "duckdb", "faiss-cpu", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "langchain_core", "langchain_postgres", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "xinference-client", "zhipuai"] anthropic = ["anthropic"] azuresearch = ["azure-common", "azure-identity", "azure-search-documents", "fastembed"] bedrock = ["boto3", "botocore"] @@ -10123,6 +10132,8 @@ bigquery = ["google-cloud-bigquery"] chromadb = ["chromadb"] clickhouse = ["clickhouse_connect"] duckdb = ["duckdb"] +faiss-cpu = ["faiss-cpu"] +faiss-gpu = ["faiss-gpu"] gemini = ["google-generativeai"] google = ["google-cloud-aiplatform", "google-generativeai"] hf = ["transformers"] @@ -10133,6 +10144,7 @@ mysql = ["PyMySQL"] ollama = ["httpx", "ollama"] openai = ["openai"] opensearch = ["opensearch-dsl", "opensearch-py"] +pgvector = ["langchain-postgres (>=0.0.12)"] pinecone = ["fastembed", "pinecone-client"] postgres = ["db-dtypes", "psycopg2-binary"] qdrant = ["fastembed", "qdrant-client"] @@ -10141,6 +10153,7 @@ snowflake = ["snowflake-connector-python"] test = ["tox"] vllm = ["vllm"] weaviate = ["weaviate-client"] +xinference-client = ["xinference-client"] zhipuai = ["zhipuai"] [[package]] @@ -10982,4 +10995,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "ef927b98c33d704d680e08db0e5c7d9a4e05454c66fcd6a5f656a65eb08e886b" +content-hash = "bb8385625eb61de086b7a7156745066b4fb171d9ca67afd1d092fa7e872f3abd" diff --git a/api/pyproject.toml b/api/pyproject.toml index ee7cf4d618..4438cf61db 100644 --- 
a/api/pyproject.toml +++ b/api/pyproject.toml @@ -168,7 +168,7 @@ readabilipy = "0.2.0" redis = { version = "~5.0.3", extras = ["hiredis"] } replicate = "~0.22.0" resend = "~0.7.0" -sagemaker = "2.231.0" +sagemaker = "~2.231.0" scikit-learn = "~1.5.1" sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sqlalchemy = "~2.0.29" @@ -206,13 +206,14 @@ cloudscraper = "1.2.71" duckduckgo-search = "~6.3.0" jsonpath-ng = "1.6.1" matplotlib = "~3.8.2" +mplfonts = "~0.0.8" newspaper3k = "0.2.8" nltk = "3.9.1" numexpr = "~2.9.0" pydub = "~0.25.1" qrcode = "~7.4.2" twilio = "~9.0.4" -vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } +vanna = { version = "0.7.5", extras = ["postgres", "mysql", "clickhouse", "duckdb", "oracle"] } wikipedia = "1.4.0" yfinance = "~0.2.40" diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 67d0706828..9efe120b7a 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -14,7 +14,7 @@ from models.dataset import Embedding @app.celery.task(queue="dataset") def clean_embedding_cache_task(): click.echo(click.style("Start clean embedding cache.", fg="green")) - clean_days = int(dify_config.CLEAN_DAY_SETTING) + clean_days = int(dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING) start_at = time.perf_counter() thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: diff --git a/api/services/app_dsl_service/__init__.py b/api/services/app_dsl_service/__init__.py new file mode 100644 index 0000000000..9fc988ffb3 --- /dev/null +++ b/api/services/app_dsl_service/__init__.py @@ -0,0 +1,3 @@ +from .service import AppDslService + +__all__ = ["AppDslService"] diff --git a/api/services/app_dsl_service/exc.py b/api/services/app_dsl_service/exc.py new file mode 100644 index 0000000000..6da4b1938f --- /dev/null +++ b/api/services/app_dsl_service/exc.py @@ -0,0 +1,34 @@ +class DSLVersionNotSupportedError(ValueError): + """Raised when the imported DSL version is not supported by the current Dify version.""" + + +class InvalidYAMLFormatError(ValueError): + """Raised when the provided YAML format is invalid.""" + + +class MissingAppDataError(ValueError): + """Raised when the app data is missing in the provided DSL.""" + + +class InvalidAppModeError(ValueError): + """Raised when the app mode is invalid.""" + + +class MissingWorkflowDataError(ValueError): + """Raised when the workflow data is missing in the provided DSL.""" + + +class MissingModelConfigError(ValueError): + """Raised when the model config data is missing in the provided DSL.""" + + +class FileSizeLimitExceededError(ValueError): + """Raised when the file size exceeds the allowed limit.""" + + +class EmptyContentError(ValueError): + """Raised when the content fetched from the URL is empty.""" + + +class ContentDecodingError(ValueError): + """Raised when there is an error decoding the content.""" diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service/service.py similarity index 75% rename from api/services/app_dsl_service.py rename to api/services/app_dsl_service/service.py index 750d0a8cd2..e6b0d9a272 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service/service.py @@ -1,8 +1,11 @@ import logging +from collections.abc import Mapping +from typing import Any -import httpx -import yaml # type: ignore +import yaml +from packaging import version +from core.helper import ssrf_proxy from events.app_event import 
app_model_config_was_updated, app_was_created from extensions.ext_database import db from factories import variable_factory @@ -11,13 +14,20 @@ from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow from services.workflow_service import WorkflowService +from .exc import ( + ContentDecodingError, + EmptyContentError, + FileSizeLimitExceededError, + InvalidAppModeError, + InvalidYAMLFormatError, + MissingAppDataError, + MissingModelConfigError, + MissingWorkflowDataError, +) + logger = logging.getLogger(__name__) -current_dsl_version = "0.1.2" -dsl_to_dify_version_mapping: dict[str, str] = { - "0.1.2": "0.8.0", - "0.1.1": "0.6.0", # dsl version -> from dify version -} +current_dsl_version = "0.1.3" class AppDslService: @@ -30,32 +40,21 @@ class AppDslService: :param args: request args :param account: Account instance """ - try: - max_size = 10 * 1024 * 1024 # 10MB - timeout = httpx.Timeout(10.0) - with httpx.stream("GET", url.strip(), follow_redirects=True, timeout=timeout) as response: - response.raise_for_status() - total_size = 0 - content = b"" - for chunk in response.iter_bytes(): - total_size += len(chunk) - if total_size > max_size: - raise ValueError("File size exceeds the limit of 10MB") - content += chunk - except httpx.HTTPStatusError as http_err: - raise ValueError(f"HTTP error occurred: {http_err}") - except httpx.RequestError as req_err: - raise ValueError(f"Request error occurred: {req_err}") - except Exception as e: - raise ValueError(f"Failed to fetch DSL from URL: {e}") + max_size = 10 * 1024 * 1024 # 10MB + response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10)) + response.raise_for_status() + content = response.content + + if len(content) > max_size: + raise FileSizeLimitExceededError("File size exceeds the limit of 10MB") if not content: - raise ValueError("Empty content from url") + raise EmptyContentError("Empty content from url") try: data = content.decode("utf-8") except UnicodeDecodeError as e: - raise ValueError(f"Error decoding content: {e}") + raise ContentDecodingError(f"Error decoding content: {e}") return cls.import_and_create_new_app(tenant_id, data, args, account) @@ -71,14 +70,14 @@ class AppDslService: try: import_data = yaml.safe_load(data) except yaml.YAMLError: - raise ValueError("Invalid YAML format in data argument.") + raise InvalidYAMLFormatError("Invalid YAML format in data argument.") # check or repair dsl version - import_data = cls._check_or_fix_dsl(import_data) + import_data = _check_or_fix_dsl(import_data) app_data = import_data.get("app") if not app_data: - raise ValueError("Missing app in data argument") + raise MissingAppDataError("Missing app in data argument") # get app basic info name = args.get("name") or app_data.get("name") @@ -90,11 +89,18 @@ class AppDslService: # import dsl and create app app_mode = AppMode.value_of(app_data.get("mode")) + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_data = import_data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, - workflow_data=import_data.get("workflow"), + workflow_data=workflow_data, account=account, name=name, description=description, @@ -104,10 +110,16 @@ class AppDslService: use_icon_as_answer_icon=use_icon_as_answer_icon, ) elif app_mode in {AppMode.CHAT, 
AppMode.AGENT_CHAT, AppMode.COMPLETION}: + model_config = import_data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise MissingModelConfigError( + "Missing model_config in data argument when app mode is chat, agent-chat or completion" + ) + app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, - model_config_data=import_data.get("model_config"), + model_config_data=model_config, account=account, name=name, description=description, @@ -117,7 +129,7 @@ class AppDslService: use_icon_as_answer_icon=use_icon_as_answer_icon, ) else: - raise ValueError("Invalid app mode") + raise InvalidAppModeError("Invalid app mode") return app @@ -132,26 +144,32 @@ class AppDslService: try: import_data = yaml.safe_load(data) except yaml.YAMLError: - raise ValueError("Invalid YAML format in data argument.") + raise InvalidYAMLFormatError("Invalid YAML format in data argument.") # check or repair dsl version - import_data = cls._check_or_fix_dsl(import_data) + import_data = _check_or_fix_dsl(import_data) app_data = import_data.get("app") if not app_data: - raise ValueError("Missing app in data argument") + raise MissingAppDataError("Missing app in data argument") # import dsl and overwrite app app_mode = AppMode.value_of(app_data.get("mode")) if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - raise ValueError("Only support import workflow in advanced-chat or workflow app.") + raise InvalidAppModeError("Only support import workflow in advanced-chat or workflow app.") if app_data.get("mode") != app_model.mode: raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") + workflow_data = import_data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) + return cls._import_and_overwrite_workflow_based_app( app_model=app_model, - workflow_data=import_data.get("workflow"), + workflow_data=workflow_data, account=account, ) @@ -186,35 +204,12 @@ class AppDslService: return yaml.dump(export_data, allow_unicode=True) - @classmethod - def _check_or_fix_dsl(cls, import_data: dict) -> dict: - """ - Check or fix dsl - - :param import_data: import data - """ - if not import_data.get("version"): - import_data["version"] = "0.1.0" - - if not import_data.get("kind") or import_data.get("kind") != "app": - import_data["kind"] = "app" - - if import_data.get("version") != current_dsl_version: - # Currently only one DSL version, so no difference checks or compatibility fixes will be performed. - logger.warning( - f"DSL version {import_data.get('version')} is not compatible " - f"with current version {current_dsl_version}, related to " - f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}." 
- ) - - return import_data - @classmethod def _import_and_create_new_workflow_based_app( cls, tenant_id: str, app_mode: AppMode, - workflow_data: dict, + workflow_data: Mapping[str, Any], account: Account, name: str, description: str, @@ -238,7 +233,9 @@ class AppDslService: :param use_icon_as_answer_icon: use app icon as answer icon """ if not workflow_data: - raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow") + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) app = cls._create_app( tenant_id=tenant_id, @@ -277,7 +274,7 @@ class AppDslService: @classmethod def _import_and_overwrite_workflow_based_app( - cls, app_model: App, workflow_data: dict, account: Account + cls, app_model: App, workflow_data: Mapping[str, Any], account: Account ) -> Workflow: """ Import app dsl and overwrite workflow based app @@ -287,7 +284,9 @@ class AppDslService: :param account: Account instance """ if not workflow_data: - raise ValueError("Missing workflow in data argument when app mode is advanced-chat or workflow") + raise MissingWorkflowDataError( + "Missing workflow in data argument when app mode is advanced-chat or workflow" + ) # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -323,7 +322,7 @@ class AppDslService: cls, tenant_id: str, app_mode: AppMode, - model_config_data: dict, + model_config_data: Mapping[str, Any], account: Account, name: str, description: str, @@ -345,7 +344,9 @@ class AppDslService: :param icon_background: app icon background """ if not model_config_data: - raise ValueError("Missing model_config in data argument when app mode is chat, agent-chat or completion") + raise MissingModelConfigError( + "Missing model_config in data argument when app mode is chat, agent-chat or completion" + ) app = cls._create_app( tenant_id=tenant_id, @@ -448,3 +449,36 @@ class AppDslService: raise ValueError("Missing app configuration, please check.") export_data["model_config"] = app_model_config.to_dict() + + +def _check_or_fix_dsl(import_data: dict[str, Any]) -> Mapping[str, Any]: + """ + Check or fix dsl + + :param import_data: import data + :raises DSLVersionNotSupportedError: if the imported DSL version is newer than the current version + """ + if not import_data.get("version"): + import_data["version"] = "0.1.0" + + if not import_data.get("kind") or import_data.get("kind") != "app": + import_data["kind"] = "app" + + imported_version = import_data.get("version") + if imported_version != current_dsl_version: + if imported_version and version.parse(imported_version) > version.parse(current_dsl_version): + errmsg = ( + f"The imported DSL version {imported_version} is newer than " + f"the current supported version {current_dsl_version}. " + f"Please upgrade your Dify instance to import this configuration." + ) + logger.warning(errmsg) + # raise DSLVersionNotSupportedError(errmsg) + else: + logger.warning( + f"DSL version {imported_version} is older than " + f"the current version {current_dsl_version}. " + f"This may cause compatibility issues." 
+ ) + + return import_data diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 414ef0224a..50da547fd8 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -4,7 +4,7 @@ import logging import random import time import uuid -from typing import Optional +from typing import Any, Optional from flask_login import current_user from sqlalchemy import func @@ -675,7 +675,7 @@ class DocumentService: def save_document_with_dataset_id( dataset: Dataset, document_data: dict, - account: Account, + account: Account | Any, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): @@ -736,11 +736,12 @@ class DocumentService: dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model documents = [] - batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) if document_data.get("original_document_id"): document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) documents.append(document) + batch = document.batch else: + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) # save process rule if not dataset_process_rule: process_rule = document_data["process_rule"] @@ -921,7 +922,7 @@ class DocumentService: if duplicate_document_ids: duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids) - return documents, batch + return documents, batch @staticmethod def check_documents_upload_quota(count: int, features: FeatureModel): @@ -985,9 +986,6 @@ class DocumentService: raise NotFound("Document not found") if document.display_status != "available": raise ValueError("Document is not available") - # update document name - if document_data.get("name"): - document.name = document_data["name"] # save process rule if document_data.get("process_rule"): process_rule = document_data["process_rule"] @@ -1064,6 +1062,10 @@ class DocumentService: document.data_source_type = document_data["data_source"]["type"] document.data_source_info = json.dumps(data_source_info) document.name = file_name + + # update document name + if document_data.get("name"): + document.name = document_data["name"] # update document to be waiting document.indexing_status = "waiting" document.completed_at = None diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index b49738c61c..98e5d9face 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -7,8 +7,6 @@ import httpx import validators from constants import HIDDEN_VALUE - -# from tasks.external_document_indexing_task import external_document_indexing_task from core.helper import ssrf_proxy from extensions.ext_database import db from models.dataset import ( diff --git a/api/services/file_service.py b/api/services/file_service.py index 6193a39669..976111502c 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,10 +1,9 @@ import datetime import hashlib import uuid -from typing import Literal, Union +from typing import Any, Literal, Union from flask_login import current_user -from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound from configs import dify_config @@ -21,7 +20,8 @@ from extensions.ext_storage import storage from models.account import Account from models.enums import CreatedByRole from models.model import EndUser, UploadFile -from services.errors.file import FileNotExistsError, FileTooLargeError, 
UnsupportedFileTypeError + +from .errors.file import FileTooLargeError, UnsupportedFileTypeError PREVIEW_WORDS_LIMIT = 3000 @@ -29,38 +29,28 @@ PREVIEW_WORDS_LIMIT = 3000 class FileService: @staticmethod def upload_file( - file: FileStorage, user: Union[Account, EndUser], source: Literal["datasets"] | None = None + *, + filename: str, + content: bytes, + mimetype: str, + user: Union[Account, EndUser, Any], + source: Literal["datasets"] | None = None, + source_url: str = "", ) -> UploadFile: - # get file name - filename = file.filename - if not filename: - raise FileNotExistsError - extension = filename.split(".")[-1] + # get file extension + extension = filename.split(".")[-1].lower() if len(filename) > 200: filename = filename.split(".")[0][:200] + "." + extension if source == "datasets" and extension not in DOCUMENT_EXTENSIONS: raise UnsupportedFileTypeError() - # select file size limit - if extension in IMAGE_EXTENSIONS: - file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 - elif extension in VIDEO_EXTENSIONS: - file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 - elif extension in AUDIO_EXTENSIONS: - file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 - else: - file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 - - # read file content - file_content = file.read() # get file size - file_size = len(file_content) + file_size = len(content) # check if the file size is exceeded - if file_size > file_size_limit: - message = f"File size exceeded. {file_size} > {file_size_limit}" - raise FileTooLargeError(message) + if not FileService.is_file_size_within_limit(extension=extension, file_size=file_size): + raise FileTooLargeError # generate file key file_uuid = str(uuid.uuid4()) @@ -74,7 +64,7 @@ class FileService: file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." 
+ extension # save file to storage - storage.save(file_key, file_content) + storage.save(file_key, content) # save file to db upload_file = UploadFile( @@ -84,12 +74,13 @@ class FileService: name=filename, size=file_size, extension=extension, - mime_type=file.mimetype, + mime_type=mimetype, created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), created_by=user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=False, - hash=hashlib.sha3_256(file_content).hexdigest(), + hash=hashlib.sha3_256(content).hexdigest(), + source_url=source_url, ) db.session.add(upload_file) @@ -97,6 +88,19 @@ class FileService: return upload_file + @staticmethod + def is_file_size_within_limit(*, extension: str, file_size: int) -> bool: + if extension in IMAGE_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in VIDEO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT * 1024 * 1024 + elif extension in AUDIO_EXTENSIONS: + file_size_limit = dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT * 1024 * 1024 + else: + file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 + + return file_size <= file_size_limit + @staticmethod def upload_text(text: str, text_name: str) -> UploadFile: if len(text_name) > 200: diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 6791cd891b..6fd144c5c2 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -84,5 +84,18 @@ VOLC_EMBEDDING_ENDPOINT_ID= # 360 AI Credentials ZHINAO_API_KEY= +# VESSL AI Credentials +VESSL_AI_MODEL_NAME= +VESSL_AI_API_KEY= +VESSL_AI_ENDPOINT_URL= + +# GPUStack Credentials +GPUSTACK_SERVER_URL= +GPUSTACK_API_KEY= + # Gitee AI Credentials GITEE_AI_API_KEY= + +# xAI Credentials +XAI_API_KEY= +XAI_API_BASE= diff --git a/api/tests/integration_tests/model_runtime/gpustack/__init__.py b/api/tests/integration_tests/model_runtime/gpustack/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py new file mode 100644 index 0000000000..f56ad0dadc --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_embedding.py @@ -0,0 +1,49 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import ( + GPUStackTextEmbeddingModel, +) + + +def test_validate_credentials(): + model = GPUStackTextEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-m3", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + }, + ) + + model.validate_credentials( + model="bge-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + ) + + +def test_invoke_model(): + model = GPUStackTextEmbeddingModel() + + result = model.invoke( + model="bge-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "context_size": 8192, + }, + texts=["hello", "world"], + user="abc-123", + ) + + assert isinstance(result, TextEmbeddingResult) + 
assert len(result.embeddings) == 2 + assert result.usage.total_tokens == 7 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_llm.py b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py new file mode 100644 index 0000000000..326b7b16f0 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_llm.py @@ -0,0 +1,162 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, +) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel + + +def test_validate_credentials_for_chat_model(): + model = GPUStackLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + "mode": "chat", + }, + ) + + model.validate_credentials( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + ) + + +def test_invoke_completion_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "completion", + }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=[], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_chat_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="ping")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=[], + user="abc-123", + stream=False, + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + assert response.usage.total_tokens > 0 + + +def test_invoke_stream_chat_model(): + model = GPUStackLanguageModel() + + response = model.invoke( + model="llama-3.2-1b-instruct", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10}, + stop=["you"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + +def test_get_num_tokens(): + model = GPUStackLanguageModel() + + num_tokens = model.get_num_tokens( + model="????", + credentials={ + "endpoint_url": 
os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ) + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 80 + + num_tokens = model.get_num_tokens( + model="????", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + "mode": "chat", + }, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 10 diff --git a/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py new file mode 100644 index 0000000000..f5c2d2d21c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/gpustack/test_rerank.py @@ -0,0 +1,107 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.gpustack.rerank.rerank import ( + GPUStackRerankModel, +) + + +def test_validate_credentials_for_rerank_model(): + model = GPUStackRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": "invalid_url", + "api_key": "invalid_api_key", + }, + ) + + model.validate_credentials( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + ) + + +def test_invoke_rerank_model(): + model = GPUStackRerankModel() + + response = model.invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials", + ], + top_n=3, + score_threshold=-0.75, + user="abc-123", + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 3 + + +def test__invoke(): + model = GPUStackRerankModel() + + # Test case 1: Empty docs + result = model._invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[], + top_n=3, + score_threshold=0.75, + user="abc-123", + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 0 + + # Test case 2: Expected docs + result = 
model._invoke( + model="bge-reranker-v2-m3", + credentials={ + "endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"), + "api_key": os.environ.get("GPUSTACK_API_KEY"), + }, + query="Organic skincare products for sensitive skin", + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials", + ], + top_n=3, + score_threshold=-0.75, + user="abc-123", + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 3 + assert all(isinstance(doc, RerankDocument) for doc in result.docs) diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py b/api/tests/integration_tests/model_runtime/vessl_ai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py new file mode 100644 index 0000000000..7797d0f8e4 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py @@ -0,0 +1,131 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel + + +def test_validate_credentials(): + model = VesslAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": "invalid_key", + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + ) + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": "http://invalid_url", + "mode": "chat", + }, + ) + + model.validate_credentials( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + ) + + +def test_invoke_model(): + model = VesslAILargeLanguageModel() + + response = model.invoke( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=False, + user="abc-123", + ) + + assert isinstance(response, LLMResult) + assert len(response.message.content) > 0 + + +def test_invoke_stream_model(): + model = VesslAILargeLanguageModel() + + response = model.invoke( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": 
os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Who are you?"), + ], + model_parameters={ + "temperature": 1.0, + "top_k": 2, + "top_p": 0.5, + }, + stop=["How"], + stream=True, + user="abc-123", + ) + + assert isinstance(response, Generator) + + for chunk in response: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + + +def test_get_num_tokens(): + model = VesslAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model=os.environ.get("VESSL_AI_MODEL_NAME"), + credentials={ + "api_key": os.environ.get("VESSL_AI_API_KEY"), + "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"), + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + ) + + assert isinstance(num_tokens, int) + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/x/__init__.py b/api/tests/integration_tests/model_runtime/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/x/test_llm.py b/api/tests/integration_tests/model_runtime/x/test_llm.py new file mode 100644 index 0000000000..647a2f6480 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/x/test_llm.py @@ -0,0 +1,204 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = XAILargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials( + model="gpt-3.5-turbo", + credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"}, + ) + + model.validate_credentials( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + 
UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = XAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. 
San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 diff --git a/api/tests/integration_tests/vdb/lindorm/__init__.py b/api/tests/integration_tests/vdb/lindorm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/lindorm/test_lindorm.py b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py new file mode 100644 index 0000000000..f8f43ba6ef --- /dev/null +++ b/api/tests/integration_tests/vdb/lindorm/test_lindorm.py @@ -0,0 +1,35 @@ +import environs + +from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis + +env = environs.Env() + + +class Config: + SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070") + SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN") + SEARCH_PWD = env.str("SEARCH_PWD", "PWD") + + +class TestLindormVectorStore(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = LindormVectorStore( + collection_name=self.collection_name, + config=LindormVectorStoreConfig( + hosts=Config.SEARCH_ENDPOINT, + username=Config.SEARCH_USERNAME, + password=Config.SEARCH_PWD, + ), + ) + + def get_ids_by_metadata_field(self): + ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id) + assert ids is not None + assert len(ids) == 1 + assert ids[0] == self.example_doc_id + + +def test_lindorm_vector(setup_mock_redis): + TestLindormVectorStore().run_all_tests() diff --git a/api/tests/unit_tests/controllers/test_compare_versions.py b/api/tests/unit_tests/controllers/test_compare_versions.py index 87902b6d44..9db57a8446 100644 --- a/api/tests/unit_tests/controllers/test_compare_versions.py +++ b/api/tests/unit_tests/controllers/test_compare_versions.py @@ -22,17 +22,3 @@ from controllers.console.version import _has_new_version ) def test_has_new_version(latest_version, current_version, expected): assert _has_new_version(latest_version=latest_version, current_version=current_version) == expected - - -def test_has_new_version_invalid_input(): - with pytest.raises(ValueError): - _has_new_version(latest_version="1.0", current_version="1.0.0") - - with pytest.raises(ValueError): - _has_new_version(latest_version="1.0.0", current_version="1.0") - - with pytest.raises(ValueError): - _has_new_version(latest_version="invalid", current_version="1.0.0") - - with pytest.raises(ValueError): - _has_new_version(latest_version="1.0.0", current_version="invalid") diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py new file mode 100644 index 0000000000..a6bf43ab0c --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -0,0 +1,52 @@ +import pytest + +from core.app.app_config.entities import VariableEntity, VariableEntityType +from core.app.apps.base_app_generator import BaseAppGenerator + + +def test_validate_inputs_with_zero(): + base_app_generator = BaseAppGenerator() + + var = VariableEntity( + variable="test_var", + label="test_var", + type=VariableEntityType.NUMBER, + required=True, + ) + + # Test with input 0 + result = base_app_generator._validate_inputs( + variable_entity=var, + value=0, + ) + + assert result == 0 + + # Test with input "0" (string) + result = 
base_app_generator._validate_inputs( + variable_entity=var, + value="0", + ) + + assert result == 0 + + +def test_validate_input_with_none_for_required_variable(): + base_app_generator = BaseAppGenerator() + + for var_type in VariableEntityType: + var = VariableEntity( + variable="test_var", + label="test_var", + type=var_type, + required=True, + ) + + # Test with input None + with pytest.raises(ValueError) as exc_info: + base_app_generator._validate_inputs( + variable_entity=var, + value=None, + ) + + assert str(exc_info.value) == "test_var is required in input form" diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index 72d277fad4..882a87239b 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -13,6 +13,7 @@ from core.variables import ( StringVariable, ) from core.variables.exc import VariableError +from core.variables.segments import ArrayAnySegment from factories import variable_factory @@ -156,3 +157,9 @@ def test_variable_cannot_large_than_200_kb(): "value": "a" * 1024 * 201, } ) + + +def test_array_none_variable(): + var = variable_factory.build_segment([None, None, None, None]) + assert isinstance(var, ArrayAnySegment) + assert var.value == [None, None, None, None] diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py new file mode 100644 index 0000000000..12c469a81a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -0,0 +1,198 @@ +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.http_request import ( + BodyData, + HttpRequestNodeAuthorization, + HttpRequestNodeBody, + HttpRequestNodeData, +) +from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout +from core.workflow.nodes.http_request.executor import Executor + + +def test_executor_with_json_body_and_number_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "number"], 42) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Number Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value='{"number": {{#pre_node_id.number#}}}', + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == {} + assert executor.json == {"number": 42} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '{"number": 42}' in raw_request + + +def 
test_executor_with_json_body_and_object_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value="{{#pre_node_id.object#}}", + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == {} + assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '"name": "John Doe"' in raw_request + assert '"age": 30' in raw_request + assert '"email": "john@example.com"' in raw_request + + +def test_executor_with_json_body_and_nested_object_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Nested Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value='{"object": {{#pre_node_id.object#}}}', + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == {} + assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '"object": {' in raw_request + assert '"name": "John Doe"' in raw_request + assert '"age": 30' in raw_request + assert '"email": "john@example.com"' in raw_request + + +def test_extract_selectors_from_template_with_newline(): + variable_pool = VariablePool() + variable_pool.add(("node_id", "custom_query"), "line1\nline2") + node_data = 
HttpRequestNodeData( + title="Test JSON Body with Nested Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="test: {{#node_id.custom_query#}}", + body=HttpRequestNodeBody( + type="none", + data=[], + ), + ) + + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + assert executor.params == {"test": "line1\nline2"} diff --git a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py similarity index 52% rename from api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py rename to api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 720037d05f..741a3a1894 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -1,5 +1,3 @@ -import json - import httpx from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,8 +14,7 @@ from core.workflow.nodes.http_request import ( HttpRequestNodeBody, HttpRequestNodeData, ) -from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout -from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict +from core.workflow.nodes.http_request.executor import _plain_text_to_dict from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -203,167 +200,3 @@ def test_http_request_node_form_with_file(monkeypatch): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs is not None assert result.outputs["body"] == "" - - -def test_executor_with_json_body_and_number_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "number"], 42) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Number Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value='{"number": {{#pre_node_id.number#}}}', - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"number": 42} - assert executor.data is None - assert executor.files is None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in raw_request - assert "Content-Type: application/json" in raw_request - assert '{"number": 42}' in raw_request - - -def test_executor_with_json_body_and_object_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "object"], {"name": "John 
Doe", "age": 30, "email": "john@example.com"}) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Object Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value="{{#pre_node_id.object#}}", - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} - assert executor.data is None - assert executor.files is None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in raw_request - assert "Content-Type: application/json" in raw_request - assert '"name": "John Doe"' in raw_request - assert '"age": 30' in raw_request - assert '"email": "john@example.com"' in raw_request - - -def test_executor_with_json_body_and_nested_object_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Nested Object Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value='{"object": {{#pre_node_id.object#}}}', - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} - assert executor.data is None - assert executor.files is None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in raw_request - assert "Content-Type: application/json" in raw_request - assert '"object": {' in raw_request - assert '"name": "John Doe"' in raw_request - assert '"age": 30' in raw_request - assert '"email": "john@example.com"' in raw_request diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index d755faee8a..29bd4d6c6c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.graph import 
Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.entities import ErrorHandleMode from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.enums import UserFrom @@ -185,8 +186,6 @@ def test_run(): outputs={"output": "dify 123"}, ) - # print("") - with patch.object(TemplateTransformNode, "_run", new=tt_generator): # execute node result = iteration_node._run() @@ -404,18 +403,458 @@ def test_run_parallel(): outputs={"output": "dify 123"}, ) - # print("") - with patch.object(TemplateTransformNode, "_run", new=tt_generator): # execute node result = iteration_node._run() count = 0 for item in result: - # print(type(item), item) count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} assert count == 32 + + +def test_iteration_run_in_parallel_mode(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + 
"id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + parallel_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + sequential_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + parallel_result = parallel_iteration_node._run() + sequential_result = sequential_iteration_node._run() + assert parallel_iteration_node.node_data.parallel_nums == 10 + assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED + count = 0 + parallel_arr = [] + sequential_arr = [] + for item in parallel_result: + count += 1 + parallel_arr.append(item) + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 32 + + for item in sequential_result: + sequential_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == 
{"output": ["dify 123", "dify 123"]} + assert count == 64 + + +def test_iteration_run_error_handle(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "tt-source-if-else-target", + "source": "iteration-start", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "tt", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "tt2", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt2", "output"], + "output_type": "array[string]", + "start_node_id": "if-else", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1.split(arg2) }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + {"value_selector": ["iteration-1", "index"], "variable": "arg2"}, + ], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + ], + }, + "id": "tt2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "1", + "variable_selector": ["iteration-1", "item"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["1", "1"]) + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", 
"list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + "is_parallel": True, + "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, + }, + "id": "iteration-1", + }, + ) + # execute continue on error node + result = iteration_node._run() + result_arr = [] + count = 0 + for item in result: + result_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": [None, None]} + + assert count == 14 + # execute remove abnormal output + iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT + result = iteration_node._run() + count = 0 + for item in result: + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": []} + assert count == 14 diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py new file mode 100644 index 0000000000..def6c2a232 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -0,0 +1,125 @@ +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.file import File, FileTransferMethod, FileType +from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState +from core.workflow.nodes.answer import AnswerStreamGenerateRoute +from core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions +from core.workflow.nodes.llm.node import LLMNode +from models.enums import UserFrom +from models.workflow import WorkflowType + + +class TestLLMNode: + @pytest.fixture + def llm_node(self): + data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[], + memory=None, + context=ContextConfig(enabled=False), + vision=VisionConfig( + enabled=True, + configs=VisionConfigOptions( + variable_selector=["sys", "files"], + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + def test_fetch_files_with_file_segment(self, llm_node): + file = File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test.jpg", + 
transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ) + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [file] + + def test_fetch_files_with_array_file_segment(self, llm_node): + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ), + File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="test2.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="2", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == files + + def test_fetch_files_with_none_segment(self, llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + def test_fetch_files_with_array_any_segment(self, llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + def test_fetch_files_with_non_existent_variable(self, llm_node): + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 53e3c93fcc..0f5c8bf51b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,11 +2,11 @@ from unittest.mock import MagicMock import pytest -from core.file import File -from core.file.models import FileTransferMethod, FileType +from core.file import File, FileTransferMethod, FileType from core.variables import ArrayFileSegment from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy -from core.workflow.nodes.list_operator.node import ListOperatorNode +from core.workflow.nodes.list_operator.exc import InvalidKeyError +from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from models.workflow import WorkflowNodeExecutionStatus @@ -109,3 +109,46 @@ def test_filter_files_by_type(list_operator_node): assert expected_file["tenant_id"] == result_file.tenant_id assert expected_file["transfer_method"] == result_file.transfer_method assert expected_file["related_id"] == result_file.related_id + + +def test_get_file_extract_string_func(): + # Create a File object + file = File( + tenant_id="test_tenant", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + remote_url="https://example.com/test_file.txt", + related_id="test_related_id", + ) + + # Test each case + assert _get_file_extract_string_func(key="name")(file) == "test_file.txt" + assert _get_file_extract_string_func(key="type")(file) == "document" + assert _get_file_extract_string_func(key="extension")(file) == ".txt" + assert _get_file_extract_string_func(key="mime_type")(file) == "text/plain" + assert _get_file_extract_string_func(key="transfer_method")(file) == "local_file" + assert _get_file_extract_string_func(key="url")(file) == "https://example.com/test_file.txt" + + # Test with empty values + 
empty_file = File( + tenant_id="test_tenant", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + filename=None, + extension=None, + mime_type=None, + remote_url=None, + related_id="test_related_id", + ) + + assert _get_file_extract_string_func(key="name")(empty_file) == "" + assert _get_file_extract_string_func(key="extension")(empty_file) == "" + assert _get_file_extract_string_func(key="mime_type")(empty_file) == "" + assert _get_file_extract_string_func(key="url")(empty_file) == "" + + # Test invalid key + with pytest.raises(InvalidKeyError): + _get_file_extract_string_func(key="invalid_key") diff --git a/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py b/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py new file mode 100644 index 0000000000..842e8268d1 --- /dev/null +++ b/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py @@ -0,0 +1,47 @@ +import pytest +from packaging import version + +from services.app_dsl_service import AppDslService +from services.app_dsl_service.exc import DSLVersionNotSupportedError +from services.app_dsl_service.service import _check_or_fix_dsl, current_dsl_version + + +class TestAppDSLService: + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_missing_version(self): + import_data = {} + result = _check_or_fix_dsl(import_data) + assert result["version"] == "0.1.0" + assert result["kind"] == "app" + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_missing_kind(self): + import_data = {"version": "0.1.0"} + result = _check_or_fix_dsl(import_data) + assert result["kind"] == "app" + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_older_version(self): + import_data = {"version": "0.0.9", "kind": "app"} + result = _check_or_fix_dsl(import_data) + assert result["version"] == "0.0.9" + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_current_version(self): + import_data = {"version": current_dsl_version, "kind": "app"} + result = _check_or_fix_dsl(import_data) + assert result["version"] == current_dsl_version + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_newer_version(self): + current_version = version.parse(current_dsl_version) + newer_version = f"{current_version.major}.{current_version.minor + 1}.0" + import_data = {"version": newer_version, "kind": "app"} + with pytest.raises(DSLVersionNotSupportedError): + _check_or_fix_dsl(import_data) + + @pytest.mark.skip(reason="Test skipped") + def test_check_or_fix_dsl_invalid_kind(self): + import_data = {"version": current_dsl_version, "kind": "invalid"} + result = _check_or_fix_dsl(import_data) + assert result["kind"] == "app" diff --git a/dev/reformat b/dev/reformat index ad83e897d9..94a7f3e6fe 100755 --- a/dev/reformat +++ b/dev/reformat @@ -9,10 +9,10 @@ if ! command -v ruff &> /dev/null || ! 
command -v dotenv-linter &> /dev/null; th fi # run ruff linter -ruff check --fix ./api +poetry run -C api ruff check --fix ./api # run ruff formatter -ruff format ./api +poetry run -C api ruff format ./api # run dotenv-linter linter -dotenv-linter ./api/.env.example ./web/.env.example +poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index e3f1c3b761..88650194ec 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.10.2 + image: langgenius/dify-api:0.11.0 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.10.2 + image: langgenius/dify-api:0.11.0 restart: always environment: CONSOLE_WEB_URL: '' @@ -396,7 +396,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.10.2 + image: langgenius/dify-web:0.11.0 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index ef2f331c11..9a178dc44c 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -222,7 +222,6 @@ REDIS_PORT=6379 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false -REDIS_DB=0 # Whether to use Redis Sentinel mode. # If set to true, the application will automatically discover and connect to the master node through Sentinel. @@ -375,7 +374,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`. +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`. VECTOR_STORE=weaviate # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. @@ -531,11 +530,17 @@ VIKINGDB_SCHEMA=http VIKINGDB_CONNECTION_TIMEOUT=30 VIKINGDB_SOCKET_TIMEOUT=30 + +# Lindorm configuration, only available when VECTOR_STORE is `lindorm` +LINDORM_URL=http://ld-***************-proxy-search-pub.lindorm.aliyuncs.com:30070 +LINDORM_USERNAME=username +LINDORM_PASSWORD=password + # OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase` -OCEANBASE_VECTOR_HOST=oceanbase-vector +OCEANBASE_VECTOR_HOST=oceanbase OCEANBASE_VECTOR_PORT=2881 OCEANBASE_VECTOR_USER=root@test -OCEANBASE_VECTOR_PASSWORD= +OCEANBASE_VECTOR_PASSWORD=difyai123456 OCEANBASE_VECTOR_DATABASE=test OCEANBASE_MEMORY_LIMIT=6G @@ -558,6 +563,22 @@ ETL_TYPE=dify # For example: http://unstructured:8000/general/v0/general UNSTRUCTURED_API_URL= +# ------------------------------ +# Model Configuration +# ------------------------------ + +# The maximum number of tokens allowed for prompt generation. +# This setting controls the upper limit of tokens that can be used by the LLM +# when generating a prompt in the prompt generation tool. +# Default: 512 tokens. +PROMPT_GENERATION_MAX_TOKENS=512 + +# The maximum number of tokens allowed for code generation. 
+# This setting controls the upper limit of tokens that can be used by the LLM +# when generating code in the code generation tool. +# Default: 1024 tokens. +CODE_GENERATION_MAX_TOKENS=1024 + # ------------------------------ # Multi-modal Configuration # ------------------------------ @@ -572,6 +593,12 @@ MULTIMODAL_SEND_IMAGE_FORMAT=base64 # Upload image file size limit, default 10M. UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 +# Upload video file size limit, default 100M. +UPLOAD_VIDEO_FILE_SIZE_LIMIT=100 + +# Upload audio file size limit, default 50M. +UPLOAD_AUDIO_FILE_SIZE_LIMIT=50 + # ------------------------------ # Sentry Configuration # Used for application monitoring and error log tracking. @@ -623,7 +650,6 @@ MAIL_DEFAULT_SEND_FROM= # API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. RESEND_API_KEY=your-resend-api-key -RESEND_API_URL=https://api.resend.com # SMTP server configuration, used when MAIL_TYPE is `smtp` SMTP_SERVER= @@ -664,6 +690,7 @@ WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 MAX_VARIABLE_SIZE=204800 +WORKFLOW_FILE_UPLOAD_LIMIT=10 # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 31624285b1..2eea273e72 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -56,6 +56,7 @@ services: SANDBOX_PORT: ${SANDBOX_PORT:-8194} volumes: - ./volumes/sandbox/dependencies:/dependencies + - ./volumes/sandbox/conf:/conf healthcheck: test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ] networks: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 06c99b5eab..a7cb8576fd 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,4 +1,5 @@ x-shared-env: &shared-api-worker-env + WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} LOG_LEVEL: ${LOG_LEVEL:-INFO} LOG_FILE: ${LOG_FILE:-} LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} @@ -167,6 +168,9 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_PORT: ${ELASTICSEARCH_PORT:-9200} ELASTICSEARCH_USERNAME: ${ELASTICSEARCH_USERNAME:-elastic} ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} + LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} + LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} + LINDORM_PASSWORD: ${LINDORM_USERNAME:-lindorm } KIBANA_PORT: ${KIBANA_PORT:-5601} # AnalyticDB configuration ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} @@ -207,8 +211,12 @@ x-shared-env: &shared-api-worker-env UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-} + PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512} + CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024} MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64} UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10} + UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100} + UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50} SENTRY_DSN: ${API_SENTRY_DSN:-} SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0} SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0} @@ -258,14 +266,15 @@ x-shared-env: &shared-api-worker-env OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-http://oceanbase-vector} OCEANBASE_VECTOR_PORT: ${OCEANBASE_VECTOR_PORT:-2881} OCEANBASE_VECTOR_USER: ${OCEANBASE_VECTOR_USER:-root@test} - 
OCEANBASE_VECTOR_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-""} + OCEANBASE_VECTOR_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test} + OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} services: # API service api: - image: langgenius/dify-api:0.10.2 + image: langgenius/dify-api:0.11.0 restart: always environment: # Use the shared environment variables. @@ -285,7 +294,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.10.2 + image: langgenius/dify-api:0.11.0 restart: always environment: # Use the shared environment variables. @@ -304,7 +313,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.10.2 + image: langgenius/dify-web:0.11.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -589,16 +598,21 @@ services: IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE} # OceanBase vector database - oceanbase-vector: + oceanbase: image: quay.io/oceanbase/oceanbase-ce:4.3.3.0-100000142024101215 profiles: - - oceanbase-vector + - oceanbase restart: always volumes: - ./volumes/oceanbase/data:/root/ob - ./volumes/oceanbase/conf:/root/.obd/cluster + - ./volumes/oceanbase/init.d:/root/boot/init.d environment: OB_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} + OB_SYS_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} + OB_TENANT_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456} + OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} + OB_SERVER_IP: '127.0.0.1' # Oracle vector database oracle: diff --git a/docker/volumes/oceanbase/init.d/vec_memory.sql b/docker/volumes/oceanbase/init.d/vec_memory.sql new file mode 100644 index 0000000000..f4c283fdf4 --- /dev/null +++ b/docker/volumes/oceanbase/init.d/vec_memory.sql @@ -0,0 +1 @@ +ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30; \ No newline at end of file diff --git a/docker/volumes/sandbox/conf/config.yaml b/docker/volumes/sandbox/conf/config.yaml new file mode 100644 index 0000000000..8c1a1deb54 --- /dev/null +++ b/docker/volumes/sandbox/conf/config.yaml @@ -0,0 +1,14 @@ +app: + port: 8194 + debug: True + key: dify-sandbox +max_workers: 4 +max_requests: 50 +worker_timeout: 5 +python_path: /usr/local/bin/python3 +enable_network: True # please make sure there is no network risk in your environment +allowed_syscalls: # please leave it empty if you have no idea how seccomp works +proxy: + socks5: '' + http: '' + https: '' diff --git a/docker/volumes/sandbox/conf/config.yaml.example b/docker/volumes/sandbox/conf/config.yaml.example new file mode 100644 index 0000000000..f92c19e51a --- /dev/null +++ b/docker/volumes/sandbox/conf/config.yaml.example @@ -0,0 +1,35 @@ +app: + port: 8194 + debug: True + key: dify-sandbox +max_workers: 4 +max_requests: 50 +worker_timeout: 5 +python_path: /usr/local/bin/python3 +python_lib_path: + - /usr/local/lib/python3.10 + - /usr/lib/python3.10 + - /usr/lib/python3 + - /usr/lib/x86_64-linux-gnu + - /etc/ssl/certs/ca-certificates.crt + - /etc/nsswitch.conf + - /etc/hosts + - /etc/resolv.conf + - /run/systemd/resolve/stub-resolv.conf + - /run/resolvconf/resolv.conf + - /etc/localtime + - /usr/share/zoneinfo + - /etc/timezone + # add more paths if needed +python_pip_mirror_url: https://pypi.tuna.tsinghua.edu.cn/simple +nodejs_path: /usr/local/bin/node +enable_network: True +allowed_syscalls: + - 1 + - 2 + - 3 + # add all the syscalls which you require +proxy: + socks5: '' + http: '' + https: '' diff 
--git a/web/app/(commonLayout)/datasets/DatasetFooter.tsx b/web/app/(commonLayout)/datasets/DatasetFooter.tsx index 6eac815a1a..b87098000f 100644 --- a/web/app/(commonLayout)/datasets/DatasetFooter.tsx +++ b/web/app/(commonLayout)/datasets/DatasetFooter.tsx @@ -9,8 +9,8 @@ const DatasetFooter = () => {

{t('dataset.didYouKnow')}

- {t('dataset.intro1')}{t('dataset.intro2')}{t('dataset.intro3')}
- {t('dataset.intro4')}{t('dataset.intro5')}{t('dataset.intro6')} + {t('dataset.intro1')}{t('dataset.intro2')}{t('dataset.intro3')}
+ {t('dataset.intro4')}{t('dataset.intro5')}{t('dataset.intro6')}

) diff --git a/web/app/(commonLayout)/datasets/Doc.tsx b/web/app/(commonLayout)/datasets/Doc.tsx index a6dd8c23ef..553dca5008 100644 --- a/web/app/(commonLayout)/datasets/Doc.tsx +++ b/web/app/(commonLayout)/datasets/Doc.tsx @@ -1,6 +1,6 @@ 'use client' -import type { FC } from 'react' +import { type FC, useEffect } from 'react' import { useContext } from 'use-context-selector' import TemplateEn from './template/template.en.mdx' import TemplateZh from './template/template.zh.mdx' @@ -14,6 +14,13 @@ const Doc: FC = ({ apiBaseUrl, }) => { const { locale } = useContext(I18n) + + useEffect(() => { + const hash = location.hash + if (hash) + document.querySelector(hash)?.scrollIntoView() + }, []) + return (
{ diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index e264fd707e..263230d049 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -20,17 +20,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
- This api is based on an existing Knowledge and creates a new document through text based on this Knowledge. + This API is based on an existing knowledge and creates a new document through text based on this knowledge. ### Params @@ -50,7 +50,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Index mode - high_quality High quality: embedding using embedding model, built as vector database index - - economy Economy: Build using inverted index of Keyword Table Index + - economy Economy: Build using inverted index of keyword table index Processing rules @@ -62,7 +62,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_extra_spaces Replace consecutive spaces, newlines, tabs - remove_urls_emails Delete URL, email address - enabled (bool) Whether to select this rule or not. If no document ID is passed in, it represents the default value. - - segmentation (object) segmentation rules + - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 @@ -72,11 +72,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_text' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -123,17 +123,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
- This api is based on an existing Knowledge and creates a new document through a file based on this Knowledge. + This API is based on an existing knowledge and creates a new document through a file based on this knowledge. ### Params @@ -145,17 +145,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - - original_document_id Source document ID (optional) + - original_document_id Source document ID (optional) - Used to re-upload the document or modify the document cleaning and segmentation configuration. The missing information is copied from the source document - The source document cannot be an archived document - When original_document_id is passed in, the update operation is performed on behalf of the document. process_rule is a fillable item. If not filled in, the segmentation method of the source document will be used by default - When original_document_id is not passed in, the new operation is performed on behalf of the document, and process_rule is required - - indexing_technique Index mode + - indexing_technique Index mode - high_quality High quality: embedding using embedding model, built as vector database index - - economy Economy: Build using inverted index of Keyword Table Index + - economy Economy: Build using inverted index of keyword table index - - process_rule Processing rules + - process_rule Processing rules - mode (string) Cleaning, segmentation mode, automatic / custom - rules (object) Custom rules (in automatic mode, this field is empty) - pre_processing_rules (array[object]) Preprocessing rules @@ -164,7 +164,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_extra_spaces Replace consecutive spaces, newlines, tabs - remove_urls_emails Delete URL, email address - enabled (bool) Whether to select this rule or not. If no document ID is passed in, it represents the default value. - - segmentation (object) segmentation rules + - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 @@ -177,11 +177,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -221,12 +221,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -240,9 +240,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Knowledge description (optional) - Index Technique (optional) - - high_quality high_quality - - economy economy + Index technique (optional) + - high_quality High quality + - economy Economy Permission @@ -252,21 +252,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from Provider (optional, default: vendor) - - vendor vendor - - external external knowledge + - vendor Vendor + - external External knowledge - External Knowledge api id (optional) + External knowledge API ID (optional) - External Knowledge id (optional) + External knowledge ID (optional) - @@ -306,12 +306,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -327,9 +327,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - @@ -369,12 +369,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -406,17 +406,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
- This api is based on an existing Knowledge and updates the document through text based on this Knowledge. + This API is based on an existing knowledge and updates the document through text based on this knowledge. ### Params @@ -446,7 +446,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_extra_spaces Replace consecutive spaces, newlines, tabs - remove_urls_emails Delete URL, email address - enabled (bool) Whether to select this rule or not. If no document ID is passed in, it represents the default value. - - segmentation (object) segmentation rules + - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 @@ -456,11 +456,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_text' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -503,17 +503,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
- This api is based on an existing Knowledge, and updates documents through files based on this Knowledge + This API is based on an existing knowledge, and updates documents through files based on this knowledge ### Params @@ -543,7 +543,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - remove_extra_spaces Replace consecutive spaces, newlines, tabs - remove_urls_emails Delete URL, email address - enabled (bool) Whether to select this rule or not. If no document ID is passed in, it represents the default value. - - segmentation (object) segmentation rules + - segmentation (object) Segmentation rules - separator Custom segment identifier, currently only allows one delimiter to be set. Default is \n - max_tokens Maximum length (token) defaults to 1000 @@ -553,11 +553,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -597,12 +597,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -652,12 +652,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -694,12 +694,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -714,13 +714,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Query - Search keywords, currently only search document names(optional) + Search keywords, currently only search document names (optional) - Page number(optional) + Page number (optional) - Number of items returned, default 20, range 1-100(optional) + Number of items returned, default 20, range 1-100 (optional) @@ -769,12 +769,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -792,9 +792,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - - content (text) Text content/question content, required - - answer (text) Answer content, if the mode of the Knowledge is qa mode, pass the value(optional) - - keywords (list) Keywords(optional) + - content (text) Text content / question content, required + - answer (text) Answer content, if the mode of the knowledge is Q&A mode, pass the value (optional) + - keywords (list) Keywords (optional) @@ -855,12 +855,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -878,10 +878,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Query - keyword,choosable + Keyword (optional) - Search status,completed + Search status, completed @@ -933,12 +933,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -979,12 +979,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -1005,10 +1005,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - - content (text) text content/question content,required - - answer (text) Answer content, not required, passed if the Knowledge is in qa mode - - keywords (list) keyword, not required - - enabled (bool) false/true, not required + - content (text) Text content / question content, required + - answer (text) Answer content, passed if the knowledge is in Q&A mode (optional) + - keywords (list) Keyword (optional) + - enabled (bool) False / true (optional) @@ -1067,41 +1067,41 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
### Path - Dataset ID + Knowledge ID ### Request Body - retrieval keywordc + Query keyword - retrieval keyword(Optional, if not filled, it will be recalled according to the default method) + Retrieval model (optional, if not filled, it will be recalled according to the default method) - search_method (text) Search method: One of the following four keywords is required - keyword_search Keyword search - semantic_search Semantic search - full_text_search Full-text search - hybrid_search Hybrid search - - reranking_enable (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search - - reranking_mode (object) Rerank model configuration, optional, required if reranking is enabled + - reranking_enable (bool) Whether to enable reranking, required if the search mode is semantic_search or hybrid_search (optional) + - reranking_mode (object) Rerank model configuration, required if reranking is enabled - reranking_provider_name (string) Rerank model provider - reranking_model_name (string) Rerank model name - weights (double) Semantic search weight setting in hybrid search mode - - top_k (integer) Number of results to return, optional + - top_k (integer) Number of results to return (optional) - score_threshold_enabled (bool) Whether to enable score threshold - score_threshold (double) Score threshold @@ -1114,26 +1114,26 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -1212,7 +1212,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index 5d52664db4..9c25d1e7bb 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -20,13 +20,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -50,7 +50,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 索引方式 - high_quality 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - - economy 经济:使用 Keyword Table Index 的倒排索引进行构建 + - economy 经济:使用 keyword table index 的倒排索引进行构建 处理规则 @@ -64,7 +64,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -72,11 +72,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_text' \ + curl --location --request --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -123,13 +123,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -145,17 +145,17 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - - original_document_id 源文档 ID (选填) + - original_document_id 源文档 ID(选填) - 用于重新上传文档或修改文档清洗、分段配置,缺失的信息从源文档复制 - 源文档不可为归档的文档 - 当传入 original_document_id 时,代表文档进行更新操作,process_rule 为可填项目,不填默认使用源文档的分段方式 - 未传入 original_document_id 时,代表文档进行新增操作,process_rule 为必填 - - indexing_technique 索引方式 + - indexing_technique 索引方式 - high_quality 高质量:使用 embedding 模型进行嵌入,构建为向量数据库索引 - - economy 经济:使用 Keyword Table Index 的倒排索引进行构建 + - economy 经济:使用 keyword table index 的倒排索引进行构建 - - process_rule 处理规则 + - process_rule 处理规则 - mode (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 - rules (object) 自定义规则(自动模式下,该字段为空) - pre_processing_rules (array[object]) 预处理规则 @@ -166,7 +166,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 需要上传的文件。 @@ -177,11 +177,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/document/create-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -221,7 +221,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
economy 经济 - 权限(选填,默认only_me) + 权限(选填,默认 only_me) - only_me 仅自己 - all_team_members 所有团队成员 - partial_members 部分团队成员 - provider,(选填,默认 vendor) + Provider(选填,默认 vendor) - vendor 上传文件 - external 外部知识库 @@ -264,9 +264,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - @@ -306,7 +306,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
---- +
---- +
---- +
@@ -431,7 +431,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 文档名称 (选填) + 文档名称(选填) 文档内容(选填) @@ -448,7 +448,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -456,11 +456,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_text' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-text' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -503,13 +503,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -528,7 +528,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 文档名称 (选填) + 文档名称(选填) 需要上传的文件 @@ -545,7 +545,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - enabled (bool) 是否选中该规则,不传入文档 ID 时代表默认值 - segmentation (object) 分段规则 - separator 自定义分段标识符,目前仅允许设置一个分隔符。默认为 \n - - max_tokens 最大长度 (token) 默认为 1000 + - max_tokens 最大长度(token)默认为 1000 @@ -553,11 +553,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update_by_file' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/documents/{document_id}/update-by-file' \ --header 'Authorization: Bearer {api_key}' \ --form 'data="{\"name\":\"Dify\",\"indexing_technique\":\"high_quality\",\"process_rule\":{\"rules\":{\"pre_processing_rules\":[{\"id\":\"remove_extra_spaces\",\"enabled\":true},{\"id\":\"remove_urls_emails\",\"enabled\":true}],\"segmentation\":{\"separator\":\"###\",\"max_tokens\":500}},\"mode\":\"custom\"}}";type=text/plain' \ --form 'file=@"/path/to/file"' @@ -597,7 +597,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
---- +
---- +
---- +
- content (text) 文本内容/问题内容,必填 - - answer (text) 答案内容,非必填,如果知识库的模式为qa模式则传值 + - answer (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - keywords (list) 关键字,非必填 @@ -855,7 +855,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
---- +
---- +
- content (text) 文本内容/问题内容,必填 - - answer (text) 答案内容,非必填,如果知识库的模式为qa模式则传值 + - answer (text) 答案内容,非必填,如果知识库的模式为 Q&A 模式则传值 - keywords (list) 关键字,非必填 - enabled (bool) false/true,非必填 @@ -1068,13 +1068,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
@@ -1088,23 +1088,23 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### Request Body - 召回关键词 + 检索关键词 - 召回参数(选填,如不填,按照默认方式召回) + 检索参数(选填,如不填,按照默认方式召回) - search_method (text) 检索方法:以下三个关键字之一,必填 - keyword_search 关键字检索 - semantic_search 语义检索 - full_text_search 全文检索 - hybrid_search 混合检索 - - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值 + - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为 semantic_search 模式或者 hybrid_search 则传值 - reranking_mode (object) Rerank模型配置,非必填,如果启用了 reranking 则传值 - reranking_provider_name (string) Rerank 模型提供商 - reranking_model_name (string) Rerank 模型名称 - weights (double) 混合检索模式下语意检索的权重设置 - top_k (integer) 返回结果数量,非必填 - - score_threshold_enabled (bool) 是否开启Score阈值 - - score_threshold (double) Score阈值 + - score_threshold_enabled (bool) 是否开启 score 阈值 + - score_threshold (double) Score 阈值 未启用字段 @@ -1115,26 +1115,26 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/retrieve' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -1214,7 +1214,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ---- +
diff --git a/web/app/account/avatar.tsx b/web/app/account/avatar.tsx index 4d8082b410..94984ebe4d 100644 --- a/web/app/account/avatar.tsx +++ b/web/app/account/avatar.tsx @@ -23,8 +23,9 @@ export default function AppSelector() { params: {}, }) - if (localStorage?.getItem('console_token')) - localStorage.removeItem('console_token') + localStorage.removeItem('setup_status') + localStorage.removeItem('console_token') + localStorage.removeItem('refresh_token') router.push('/signin') } diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index 97c153b464..549421401c 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -33,6 +33,10 @@ import { LoveMessage } from '@/app/components/base/icons/src/vender/features' // type import type { AutomaticRes } from '@/service/debug' import { Generator } from '@/app/components/base/icons/src/vender/other' +import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' +import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' export interface IGetAutomaticResProps { mode: AppType @@ -68,7 +72,10 @@ const GetAutomaticRes: FC = ({ onFinished, }) => { const { t } = useTranslation() - + const { + currentProvider, + currentModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const tryList = [ { icon: RiTerminalBoxLine, @@ -191,6 +198,19 @@ const GetAutomaticRes: FC = ({
{t('appDebug.generate.title')}
{t('appDebug.generate.description')}
+
+ + +
{t('appDebug.generate.tryIt')}
diff --git a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx index b63e3e2693..85c522ca0f 100644 --- a/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx +++ b/web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx @@ -105,6 +105,15 @@ export const GetCodeGeneratorResModal: FC = (
{t('appDebug.codegen.loading')}
) + const renderNoData = ( +
+ +
+
{t('appDebug.codegen.noDataLine1')}
+
{t('appDebug.codegen.noDataLine2')}
+
+
+ ) return ( = (
{isLoading && renderLoading} + {!isLoading && !res && renderNoData} {(!isLoading && res) && (
{t('appDebug.codegen.resTitle')}
diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index 2c082d8815..0d9d575c1e 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -15,6 +15,7 @@ import { AppType } from '@/types/app' import type { DataSet } from '@/models/datasets' import { getMultipleRetrievalConfig, + getSelectedDatasetsMode, } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -38,6 +39,7 @@ const DatasetConfig: FC = () => { isAgent, datasetConfigs, setDatasetConfigs, + setRerankSettingModalOpen, } = useContext(ConfigContext) const formattingChangedDispatcher = useFormattingChangedDispatcher() @@ -55,6 +57,20 @@ const DatasetConfig: FC = () => { ...(datasetConfigs as any), ...retrievalConfig, }) + const { + allExternal, + allInternal, + mixtureInternalAndExternal, + mixtureHighQualityAndEconomic, + inconsistentEmbeddingModel, + } = getSelectedDatasetsMode(filteredDataSets) + + if ( + (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)) + || mixtureInternalAndExternal + || allExternal + ) + setRerankSettingModalOpen(true) formattingChangedDispatcher() } diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 6b1983f5e2..75f0c33349 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -266,7 +266,7 @@ const ConfigContent: FC = ({
{ - selectedDatasetsMode.allEconomic && ( + selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
{ let errMsg = '' if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) { - if (!tempDataSetConfigs.reranking_model?.reranking_model_name && (rerankDefaultModel && !isRerankDefaultModelValid)) + if (tempDataSetConfigs.reranking_enable + && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel + && !isRerankDefaultModelValid + ) errMsg = t('appDebug.datasetConfig.rerankModelRequired') } if (errMsg) { @@ -62,7 +66,9 @@ const ParamsConfig = ({ if (!isValid()) return const config = { ...tempDataSetConfigs } - if (config.retrieval_model === RETRIEVE_TYPE.multiWay && !config.reranking_model) { + if (config.retrieval_model === RETRIEVE_TYPE.multiWay + && config.reranking_mode === RerankingModeEnum.RerankingModel + && !config.reranking_model) { config.reranking_model = { reranking_provider_name: rerankDefaultModel?.provider?.provider, reranking_model_name: rerankDefaultModel?.model, diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index af50fc65c3..bf6c5e79c8 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -252,12 +252,18 @@ const Configuration: FC = () => { } hideSelectDataSet() const { - allEconomic, + allExternal, + allInternal, + mixtureInternalAndExternal, mixtureHighQualityAndEconomic, inconsistentEmbeddingModel, } = getSelectedDatasetsMode(newDatasets) - if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel) + if ( + (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)) + || mixtureInternalAndExternal + || allExternal + ) setRerankSettingModalOpen(true) const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index f289c5a401..b78aaffef2 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -36,6 +36,7 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import TextGeneration from '@/app/components/app/text-generate/item' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import MessageLogModal from '@/app/components/base/message-log-modal' +import PromptLogModal from '@/app/components/base/prompt-log-modal' import { useStore as useAppStore } from '@/app/components/app/store' import { useAppContext } from '@/context/app-context' import useTimestamp from '@/hooks/use-timestamp' @@ -168,11 +169,13 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const { userProfile: { timezone } } = useAppContext() const { formatTime } = useTimestamp() const { onClose, appDetail } = useContext(DrawerContext) - const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ + const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ currentLogItem: state.currentLogItem, setCurrentLogItem: state.setCurrentLogItem, showMessageLogModal: state.showMessageLogModal, setShowMessageLogModal: state.setShowMessageLogModal, + showPromptLogModal: state.showPromptLogModal, + setShowPromptLogModal: state.setShowPromptLogModal, currentLogModalActiveTab: state.currentLogModalActiveTab, }))) const { t } = useTranslation() @@ -192,8 +195,8 @@ function DetailPanel({ detail, onFeedback 
}: IDetailPanel) { conversation_id: detail.id, limit: 10, } - if (allChatItems.at(-1)?.id) - params.first_id = allChatItems.at(-1)?.id.replace('question-', '') + if (allChatItems[0]?.id) + params.first_id = allChatItems[0]?.id.replace('question-', '') const messageRes = await fetchChatMessages({ url: `/apps/${appDetail?.id}/chat-messages`, params, @@ -557,6 +560,16 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { defaultTab={currentLogModalActiveTab} /> )} + {showPromptLogModal && ( + { + setCurrentLogItem() + setShowPromptLogModal(false) + }} + /> + )}
) } diff --git a/web/app/components/base/app-icon-picker/Uploader.tsx b/web/app/components/base/app-icon-picker/Uploader.tsx index 4ddaa40447..ba0ef6b2b2 100644 --- a/web/app/components/base/app-icon-picker/Uploader.tsx +++ b/web/app/components/base/app-icon-picker/Uploader.tsx @@ -8,18 +8,22 @@ import classNames from 'classnames' import { ImagePlus } from '../icons/src/vender/line/images' import { useDraggableUploader } from './hooks' +import { checkIsAnimatedImage } from './utils' import { ALLOW_FILE_EXTENSIONS } from '@/types/app' type UploaderProps = { className?: string onImageCropped?: (tempUrl: string, croppedAreaPixels: Area, fileName: string) => void + onUpload?: (file?: File) => void } const Uploader: FC = ({ className, onImageCropped, + onUpload, }) => { const [inputImage, setInputImage] = useState<{ file: File; url: string }>() + const [isAnimatedImage, setIsAnimatedImage] = useState(false) useEffect(() => { return () => { if (inputImage) @@ -34,12 +38,19 @@ const Uploader: FC = ({ if (!inputImage) return onImageCropped?.(inputImage.url, croppedAreaPixels, inputImage.file.name) + onUpload?.(undefined) } const handleLocalFileInput = (e: ChangeEvent) => { const file = e.target.files?.[0] - if (file) + if (file) { setInputImage({ file, url: URL.createObjectURL(file) }) + checkIsAnimatedImage(file).then((isAnimatedImage) => { + setIsAnimatedImage(!!isAnimatedImage) + if (isAnimatedImage) + onUpload?.(file) + }) + } } const { @@ -52,6 +63,26 @@ const Uploader: FC = ({ const inputRef = createRef() + const handleShowImage = () => { + if (isAnimatedImage) { + return ( + + ) + } + + return ( + + ) + } + return (
= ({
Supports PNG, JPG, JPEG, WEBP and GIF
- : + : handleShowImage() }
diff --git a/web/app/components/base/app-icon-picker/index.tsx b/web/app/components/base/app-icon-picker/index.tsx index ba375abdd9..8a10d28653 100644 --- a/web/app/components/base/app-icon-picker/index.tsx +++ b/web/app/components/base/app-icon-picker/index.tsx @@ -74,6 +74,11 @@ const AppIconPicker: FC = ({ setImageCropInfo({ tempUrl, croppedAreaPixels, fileName }) } + const [uploadImageInfo, setUploadImageInfo] = useState<{ file?: File }>() + const handleUpload = async (file?: File) => { + setUploadImageInfo({ file }) + } + const handleSelect = async () => { if (activeTab === 'emoji') { if (emoji) { @@ -85,9 +90,13 @@ const AppIconPicker: FC = ({ } } else { - if (!imageCropInfo) + if (!imageCropInfo && !uploadImageInfo) return setUploading(true) + if (imageCropInfo.file) { + handleLocalFileUpload(imageCropInfo.file) + return + } const blob = await getCroppedImg(imageCropInfo.tempUrl, imageCropInfo.croppedAreaPixels, imageCropInfo.fileName) const file = new File([blob], imageCropInfo.fileName, { type: blob.type }) handleLocalFileUpload(file) @@ -121,7 +130,7 @@ const AppIconPicker: FC = ({ - +
diff --git a/web/app/components/base/app-icon-picker/utils.ts b/web/app/components/base/app-icon-picker/utils.ts index 14c9ae3f28..99154d56da 100644 --- a/web/app/components/base/app-icon-picker/utils.ts +++ b/web/app/components/base/app-icon-picker/utils.ts @@ -115,3 +115,52 @@ export default async function getCroppedImg( }, mimeType) }) } + +export function checkIsAnimatedImage(file) { + return new Promise((resolve, reject) => { + const fileReader = new FileReader() + + fileReader.onload = function (e) { + const arr = new Uint8Array(e.target.result) + + // Check file extension + const fileName = file.name.toLowerCase() + if (fileName.endsWith('.gif')) { + // If file is a GIF, assume it's animated + resolve(true) + } + // Check for WebP signature (RIFF and WEBP) + else if (isWebP(arr)) { + resolve(checkWebPAnimation(arr)) // Check if it's animated + } + else { + resolve(false) // Not a GIF or WebP + } + } + + fileReader.onerror = function (err) { + reject(err) // Reject the promise on error + } + + // Read the file as an array buffer + fileReader.readAsArrayBuffer(file) + }) +} + +// Function to check for WebP signature +function isWebP(arr) { + return ( + arr[0] === 0x52 && arr[1] === 0x49 && arr[2] === 0x46 && arr[3] === 0x46 + && arr[8] === 0x57 && arr[9] === 0x45 && arr[10] === 0x42 && arr[11] === 0x50 + ) // "WEBP" +} + +// Function to check if the WebP is animated (contains ANIM chunk) +function checkWebPAnimation(arr) { + // Search for the ANIM chunk in WebP to determine if it's animated + for (let i = 12; i < arr.length - 4; i++) { + if (arr[i] === 0x41 && arr[i + 1] === 0x4E && arr[i + 2] === 0x49 && arr[i + 3] === 0x4D) + return true // Found animation + } + return false // No animation chunk found +} diff --git a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap index 070975bfa7..7da09c4529 100644 --- a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap +++ b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap @@ -1804,6 +1804,280 @@ exports[`build chat item tree and get thread messages should get thread messages ] `; +exports[`build chat item tree and get thread messages should work with partial messages 1`] = ` +[ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105809, + "files": [], + "id": "1019cd79-d141-4f9f-880a-fc1441cfd802", + "message_id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "observation": "", + "position": 1, + "thought": "Sure! My number is 54. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105822, + "files": [], + "id": "0773bec7-b992-4a53-92b2-20ebaeae8798", + "message_id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "observation": "", + "position": 1, + "thought": "My number is 4729. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 4729. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "324bce32-c98c-435d-a66b-bac974ebb5ed", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. 
I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4729. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.30", + "time": "09/11/2024 09:50 PM", + "tokens": 66, + }, + "parentMessageId": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-324bce32-c98c-435d-a66b-bac974ebb5ed", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726107812, + "files": [], + "id": "5ca650f3-982c-4399-8b95-9ea241c76707", + "message_id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "observation": "", + "position": 1, + "thought": "My number is 4821. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726111024, + "files": [], + "id": "095cacab-afad-4387-a41d-1662578b8b13", + "message_id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "observation": "", + "position": 1, + "thought": "My number is 1456. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "My number is 1456. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "19904a7b-7494-4ed8-b72c-1d18668cea8c", + "input": { + "inputs": {}, + "query": "1003", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "1003", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 1456. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.38", + "time": "09/11/2024 11:17 PM", + "tokens": 86, + }, + "parentMessageId": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "1003", + "id": "question-19904a7b-7494-4ed8-b72c-1d18668cea8c", + "isAnswer": false, + "message_files": [], + "parentMessageId": "684b5396-4e91-4043-88e9-aabe48b21acc", + }, + ], + "content": "My number is 4821. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "684b5396-4e91-4043-88e9-aabe48b21acc", + "input": { + "inputs": {}, + "query": "3306", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "3306", + }, + { + "files": [], + "role": "assistant", + "text": "My number is 4821. 
Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.48", + "time": "09/11/2024 10:23 PM", + "tokens": 66, + }, + "parentMessageId": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "siblingIndex": 1, + "workflow_run_id": null, + }, + ], + "content": "3306", + "id": "question-684b5396-4e91-4043-88e9-aabe48b21acc", + "isAnswer": false, + "message_files": [], + "parentMessageId": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + }, + ], + "content": "Sure! My number is 54. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "input": { + "inputs": {}, + "query": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure! My number is 54. Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.52", + "time": "09/11/2024 09:50 PM", + "tokens": 46, + }, + "parentMessageId": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + "id": "question-cd5affb0-7bc2-4a6f-be7e-25e74595c9dd", + "isAnswer": false, + "message_files": [], + }, +] +`; + exports[`build chat item tree and get thread messages should work with real world messages 1`] = ` [ { diff --git a/web/app/components/base/chat/__tests__/utils.spec.ts b/web/app/components/base/chat/__tests__/utils.spec.ts index c602ac8a99..1dead1c949 100644 --- a/web/app/components/base/chat/__tests__/utils.spec.ts +++ b/web/app/components/base/chat/__tests__/utils.spec.ts @@ -255,4 +255,10 @@ describe('build chat item tree and get thread messages', () => { const threadMessages6_2 = getThreadMessages(tree6, 'ff4c2b43-48a5-47ad-9dc5-08b34ddba61b') expect(threadMessages6_2).toMatchSnapshot() }) + + const partialMessages = (realWorldMessages as ChatItemInTree[]).slice(-10) + const tree7 = buildChatItemTree(partialMessages) + it('should work with partial messages', () => { + expect(tree7).toMatchSnapshot() + }) }) diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index 16357361cf..61dfaecffc 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -134,6 +134,12 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] { } } + // If no messages have parentMessageId=null (indicating a root node), + // then we likely have a partial chat history. In this case, + // use the first available message as the root node. + if (rootNodes.length === 0 && allMessages.length > 0) + rootNodes.push(map[allMessages[0]!.id]!) 
+ return rootNodes } diff --git a/web/app/components/base/file-uploader/constants.ts b/web/app/components/base/file-uploader/constants.ts index 629fe2566b..a749d73c74 100644 --- a/web/app/components/base/file-uploader/constants.ts +++ b/web/app/components/base/file-uploader/constants.ts @@ -3,5 +3,6 @@ export const IMG_SIZE_LIMIT = 10 * 1024 * 1024 export const FILE_SIZE_LIMIT = 15 * 1024 * 1024 export const AUDIO_SIZE_LIMIT = 50 * 1024 * 1024 export const VIDEO_SIZE_LIMIT = 100 * 1024 * 1024 +export const MAX_FILE_UPLOAD_LIMIT = 10 export const FILE_URL_REGEX = /^(https?|ftp):\/\// diff --git a/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx index d22d6ff4ec..2a042bab40 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-attachment/file-item.tsx @@ -1,6 +1,5 @@ import { memo, - useMemo, } from 'react' import { RiDeleteBinLine, @@ -35,17 +34,9 @@ const FileInAttachmentItem = ({ onRemove, onReUpload, }: FileInAttachmentItemProps) => { - const { id, name, type, progress, supportFileType, base64Url, url } = file - const ext = getFileExtension(name, type) + const { id, name, type, progress, supportFileType, base64Url, url, isRemote } = file + const ext = getFileExtension(name, type, isRemote) const isImageFile = supportFileType === SupportUploadFileTypes.image - const nameArr = useMemo(() => { - const nameMatch = name.match(/(.+)\.([^.]+)$/) - - if (nameMatch) - return [nameMatch[1], nameMatch[2]] - - return [name, ''] - }, [name]) return (
-
{nameArr[0]}
- { - nameArr[1] && ( - .{nameArr[1]} - ) - } +
{name}
{ @@ -93,7 +79,11 @@ const FileInAttachmentItem = ({ ) } - {formatFileSize(file.size || 0)} + { + !!file.size && ( + {formatFileSize(file.size)} + ) + }
diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx index 6597373020..a051b89ec1 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx @@ -31,8 +31,8 @@ const FileItem = ({ onRemove, onReUpload, }: FileItemProps) => { - const { id, name, type, progress, url } = file - const ext = getFileExtension(name, type) + const { id, name, type, progress, url, isRemote } = file + const ext = getFileExtension(name, type, isRemote) const uploadError = progress === -1 return ( @@ -75,7 +75,9 @@ const FileItem = ({ ) } - {formatFileSize(file.size || 0)} + { + !!file.size && formatFileSize(file.size) + }
{ showDownloadAction && ( diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 942e5d612a..c735754ffe 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -18,6 +18,7 @@ import { AUDIO_SIZE_LIMIT, FILE_SIZE_LIMIT, IMG_SIZE_LIMIT, + MAX_FILE_UPLOAD_LIMIT, VIDEO_SIZE_LIMIT, } from '@/app/components/base/file-uploader/constants' import { useToastContext } from '@/app/components/base/toast' @@ -25,7 +26,7 @@ import { TransferMethod } from '@/types/app' import { SupportUploadFileTypes } from '@/app/components/workflow/types' import type { FileUpload } from '@/app/components/base/features/types' import { formatFileSize } from '@/utils/format' -import { fetchRemoteFileInfo } from '@/service/common' +import { uploadRemoteFileInfo } from '@/service/common' import type { FileUploadConfigResponse } from '@/models/common' export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) => { @@ -33,12 +34,14 @@ export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) => const docSizeLimit = Number(fileUploadConfig?.file_size_limit) * 1024 * 1024 || FILE_SIZE_LIMIT const audioSizeLimit = Number(fileUploadConfig?.audio_file_size_limit) * 1024 * 1024 || AUDIO_SIZE_LIMIT const videoSizeLimit = Number(fileUploadConfig?.video_file_size_limit) * 1024 * 1024 || VIDEO_SIZE_LIMIT + const maxFileUploadLimit = Number(fileUploadConfig?.workflow_file_upload_limit) || MAX_FILE_UPLOAD_LIMIT return { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit, + maxFileUploadLimit, } } @@ -49,7 +52,7 @@ export const useFile = (fileConfig: FileUpload) => { const params = useParams() const { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit } = useFileSizeLimit(fileConfig.fileUploadConfig) - const checkSizeLimit = (fileType: string, fileSize: number) => { + const checkSizeLimit = useCallback((fileType: string, fileSize: number) => { switch (fileType) { case SupportUploadFileTypes.image: { if (fileSize > imgSizeLimit) { @@ -120,7 +123,7 @@ export const useFile = (fileConfig: FileUpload) => { return true } } - } + }, [audioSizeLimit, docSizeLimit, imgSizeLimit, notify, t, videoSizeLimit]) const handleAddFile = useCallback((newFile: FileEntity) => { const { @@ -188,6 +191,17 @@ export const useFile = (fileConfig: FileUpload) => { } }, [fileStore, notify, t, handleUpdateFile, params]) + const startProgressTimer = useCallback((fileId: string) => { + const timer = setInterval(() => { + const files = fileStore.getState().files + const file = files.find(file => file.id === fileId) + + if (file && file.progress < 80 && file.progress >= 0) + handleUpdateFile({ ...file, progress: file.progress + 20 }) + else + clearInterval(timer) + }, 200) + }, [fileStore, handleUpdateFile]) const handleLoadFileFromLink = useCallback((url: string) => { const allowedFileTypes = fileConfig.allowed_file_types @@ -197,19 +211,27 @@ export const useFile = (fileConfig: FileUpload) => { type: '', size: 0, progress: 0, - transferMethod: TransferMethod.remote_url, + transferMethod: TransferMethod.local_file, supportFileType: '', url, + isRemote: true, } handleAddFile(uploadingFile) + startProgressTimer(uploadingFile.id) - fetchRemoteFileInfo(url).then((res) => { + uploadRemoteFileInfo(url, !!params.token).then((res) => { const newFile = { ...uploadingFile, - type: res.file_type, - size: res.file_length, + type: res.mime_type, + size: res.size, progress: 100, - supportFileType:
getSupportFileType(url, res.file_type, allowedFileTypes?.includes(SupportUploadFileTypes.custom)), + supportFileType: getSupportFileType(res.name, res.mime_type, allowedFileTypes?.includes(SupportUploadFileTypes.custom)), + uploadedId: res.id, + url: res.url, + } + if (!isAllowedFileExtension(res.name, res.mime_type, fileConfig.allowed_file_types || [], fileConfig.allowed_file_extensions || [])) { + notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') }) + handleRemoveFile(uploadingFile.id) } if (!checkSizeLimit(newFile.supportFileType, newFile.size)) handleRemoveFile(uploadingFile.id) @@ -219,7 +241,7 @@ export const useFile = (fileConfig: FileUpload) => { notify({ type: 'error', message: t('common.fileUploader.pasteFileLinkInvalid') }) handleRemoveFile(uploadingFile.id) }) - }, [checkSizeLimit, handleAddFile, handleUpdateFile, notify, t, handleRemoveFile, fileConfig?.allowed_file_types]) + }, [checkSizeLimit, handleAddFile, handleUpdateFile, notify, t, handleRemoveFile, fileConfig?.allowed_file_types, fileConfig.allowed_file_extensions, startProgressTimer]) const handleLoadFileFromLinkSuccess = useCallback(() => { }, []) diff --git a/web/app/components/base/file-uploader/types.ts b/web/app/components/base/file-uploader/types.ts index ac4584bb4c..285023f0af 100644 --- a/web/app/components/base/file-uploader/types.ts +++ b/web/app/components/base/file-uploader/types.ts @@ -29,4 +29,5 @@ export type FileEntity = { uploadedId?: string base64Url?: string url?: string + isRemote?: boolean } diff --git a/web/app/components/base/file-uploader/utils.ts b/web/app/components/base/file-uploader/utils.ts index 4c7ef0d89b..eb9199d74b 100644 --- a/web/app/components/base/file-uploader/utils.ts +++ b/web/app/components/base/file-uploader/utils.ts @@ -43,10 +43,13 @@ export const fileUpload: FileUpload = ({ }) } -export const getFileExtension = (fileName: string, fileMimetype: string) => { +export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => { if (fileMimetype) return mime.getExtension(fileMimetype) || '' + if (isRemote) + return '' + if (fileName) { const fileNamePair = fileName.split('.') const fileNamePairLength = fileNamePair.length diff --git a/web/app/components/base/image-uploader/image-list.tsx b/web/app/components/base/image-uploader/image-list.tsx index 8d5d1a1af5..35f6149b13 100644 --- a/web/app/components/base/image-uploader/image-list.tsx +++ b/web/app/components/base/image-uploader/image-list.tsx @@ -133,6 +133,7 @@ const ImageList: FC = ({ setImagePreviewUrl('')} + title='' /> )}
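Two things change in the link-upload path above: the URL is now resolved through the backend (uploadRemoteFileInfo) rather than probed from the browser, and startProgressTimer animates a provisional progress value while that request is pending. The timer pattern in isolation, with assumed accessor names standing in for the real file store:

```ts
// Sketch: creep a fake progress value toward 80% while the server-side fetch
// of a remote URL is still pending; the success handler later jumps it to 100.
// `getFile` and `updateFile` are assumed names, not the real store API.
type PendingFile = { id: string; progress: number }

function startFakeProgress(
  getFile: (id: string) => PendingFile | undefined,
  updateFile: (f: PendingFile) => void,
  fileId: string,
) {
  const timer = setInterval(() => {
    const file = getFile(fileId)
    if (file && file.progress >= 0 && file.progress < 80)
      updateFile({ ...file, progress: file.progress + 20 }) // 0, 20, 40, 60, 80
    else
      clearInterval(timer) // finished, failed (-1), or capped at 80
  }, 200)
}
```

Capping at 80% leaves headroom so the jump to 100 only ever happens when the server actually responds.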
diff --git a/web/app/components/base/search-input/index.tsx b/web/app/components/base/search-input/index.tsx index 4b3821da5a..89345fbe32 100644 --- a/web/app/components/base/search-input/index.tsx +++ b/web/app/components/base/search-input/index.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import { useState } from 'react' +import { useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { RiSearchLine } from '@remixicon/react' import cn from '@/utils/classnames' @@ -12,6 +12,7 @@ type SearchInputProps = { onChange: (v: string) => void white?: boolean } + const SearchInput: FC = ({ placeholder, className, @@ -21,6 +22,7 @@ const SearchInput: FC = ({ }) => { const { t } = useTranslation() const [focus, setFocus] = useState(false) + const isComposing = useRef(false) return (
= ({ placeholder={placeholder || t('common.operation.search')!} value={value} onChange={(e) => { - onChange(e.target.value) + if (!isComposing.current) + onChange(e.target.value) + }} + onCompositionStart={() => { + isComposing.current = true + }} + onCompositionEnd={() => { + isComposing.current = false }} onFocus={() => setFocus(true)} onBlur={() => setFocus(false)} diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index 02a642b94c..ba667955ce 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -126,7 +126,7 @@ const Select: FC = ({
- {filteredItems.length > 0 && ( + {(filteredItems.length > 0 && open) && ( {filteredItems.map((item: Item) => ( = ({ onClick={onClick} > -
{t(`billing.upgradeBtn.${isShort ? 'encourageShort' : 'encourage'}`)}
+
{t(`billing.upgradeBtn.${isShort ? 'encourageShort' : 'encourage'}`)}
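Further up, the composition guard added to SearchInput is the standard IME-safe input pattern: while a CJK user is composing a candidate word, the field fires intermediate change events that should not trigger filtering yet. The same idea outside React, as a self-contained sketch:

```ts
// Sketch: ignore input events while an IME composition session is in flight.
function bindSearchInput(input: HTMLInputElement, onSearch: (v: string) => void) {
  let isComposing = false
  input.addEventListener('compositionstart', () => {
    isComposing = true
  })
  input.addEventListener('compositionend', () => {
    isComposing = false
    onSearch(input.value) // commit once the composed text is final
  })
  input.addEventListener('input', () => {
    if (!isComposing)
      onSearch(input.value)
  })
}
```

Browsers disagree on whether a final input event follows compositionend, which is why this sketch commits the value explicitly at that point.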
diff --git a/web/app/components/develop/md.tsx b/web/app/components/develop/md.tsx index 793e294389..7cb0dd7dde 100644 --- a/web/app/components/develop/md.tsx +++ b/web/app/components/develop/md.tsx @@ -39,6 +39,7 @@ export const Heading = function H2({ } return ( <> +
{method} {/* */} diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index 7d80367ce4..6642c5cedc 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -656,6 +656,11 @@ Chat applications support session persistence, allowing previous chat history to Return only pinned conversations as `true`, only non-pinned as `false` + + Sorting Field (Optional), Default: -updated_at (sorted in descending order by update time) + - Available Values: created_at, -created_at, updated_at, -updated_at + - A leading "-" before the field name sorts in reverse (descending) order. + ### Response diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index 690d700f05..8e64d63ac5 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -691,6 +691,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 只返回置顶 true,只返回非置顶 false + + 排序字段(选填),默认 -updated_at(按更新时间倒序排列) + - 可选值:created_at, -created_at, updated_at, -updated_at + - 字段前面的符号代表顺序或倒序,-代表倒序 + ### Response diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 907a1ab0b4..a94016ca3a 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -690,6 +690,11 @@ Chat applications support session persistence, allowing previous chat history to Return only pinned conversations as `true`, only non-pinned as `false` + + Sorting Field (Optional), Default: -updated_at (sorted in descending order by update time) + - Available Values: created_at, -created_at, updated_at, -updated_at + - A leading "-" before the field name sorts in reverse (descending) order. + ### Response diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index f6dc7daa1e..92b13b2c7d 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -705,6 +705,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 只返回置顶 true,只返回非置顶 false + + 排序字段(选填),默认 -updated_at(按更新时间倒序排列) + - 可选值:created_at, -created_at, updated_at, -updated_at + - 字段前面的符号代表顺序或倒序,-代表倒序 + ### Response diff --git a/web/app/components/explore/category.tsx b/web/app/components/explore/category.tsx index cbf6cd26fe..8f67f0fd49 100644 --- a/web/app/components/explore/category.tsx +++ b/web/app/components/explore/category.tsx @@ -28,7 +28,7 @@ const Category: FC = ({ allCategoriesEn, }) => { const { t } = useTranslation() - const isAllCategories = !list.includes(value as AppCategory) + const isAllCategories = !list.includes(value as AppCategory) || value === allCategoriesEn const itemClassName = (isSelected: boolean) => cn( 'flex items-center px-3 py-[7px] h-[32px] rounded-lg border-[0.5px] border-transparent text-gray-700 font-medium leading-[18px] cursor-pointer hover:bg-gray-200', @@ -44,7 +44,7 @@ const Category: FC = ({ {t('explore.apps.allCategories')}
- {list.map(name => ( + {list.filter(name => name !== allCategoriesEn).map(name => (
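For the sort_by parameter documented in the four API templates above (created_at, -created_at, updated_at, -updated_at), a hypothetical client call; the base URL, path, and token here are illustrative placeholders, not values taken from the docs:

```ts
// Sketch: fetch the conversation list sorted by most recent update first.
// Endpoint, base URL, and auth token are illustrative placeholders.
async function listConversations(sortBy: string = '-updated_at') {
  const url = new URL('https://api.example.com/v1/conversations')
  url.searchParams.set('sort_by', sortBy) // e.g. 'created_at' for oldest-first
  const res = await fetch(url, { headers: { Authorization: 'Bearer app-xxx' } })
  if (!res.ok)
    throw new Error(`Request failed: ${res.status}`)
  return res.json()
}
```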
= ({ updatePayload, on : '', selectedVersion: '', selectedPackage: '', - releases: [], + releases: updatePayload ? updatePayload.originalPackageInfo.releases : [], }) const [uniqueIdentifier, setUniqueIdentifier] = useState(null) const [manifest, setManifest] = useState(null) @@ -133,11 +133,6 @@ const InstallFromGitHub: React.FC = ({ updatePayload, on }) } - useEffect(() => { - if (state.step === InstallStepFromGitHub.selectPackage) - handleUrlSubmit() - }, []) - return ( = ({ repo: meta!.repo, version: meta!.version, package: meta!.package, + releases: fetchedReleases, }, }, }, diff --git a/web/app/components/plugins/types.ts b/web/app/components/plugins/types.ts index f0f80a3e57..304ebcab69 100644 --- a/web/app/components/plugins/types.ts +++ b/web/app/components/plugins/types.ts @@ -155,6 +155,7 @@ export type UpdateFromGitHubPayload = { repo: string version: string package: string + releases: GitHubRepoReleaseResponse[] } } diff --git a/web/app/components/swr-initor.tsx b/web/app/components/swr-initor.tsx index 8c5d9725d8..a2ae003139 100644 --- a/web/app/components/swr-initor.tsx +++ b/web/app/components/swr-initor.tsx @@ -4,7 +4,6 @@ import { SWRConfig } from 'swr' import { useCallback, useEffect, useState } from 'react' import type { ReactNode } from 'react' import { usePathname, useRouter, useSearchParams } from 'next/navigation' -import useRefreshToken from '@/hooks/use-refresh-token' import { fetchSetupStatus } from '@/service/common' interface SwrInitorProps { @@ -15,12 +14,11 @@ const SwrInitor = ({ }: SwrInitorProps) => { const router = useRouter() const searchParams = useSearchParams() - const pathname = usePathname() - const { getNewAccessToken } = useRefreshToken() - const consoleToken = searchParams.get('access_token') - const refreshToken = searchParams.get('refresh_token') + const consoleToken = decodeURIComponent(searchParams.get('access_token') || '') + const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '') const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') const refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token') + const pathname = usePathname() const [init, setInit] = useState(false) const isSetupFinished = useCallback(async () => { @@ -41,25 +39,6 @@ const SwrInitor = ({ } }, []) - const setRefreshToken = useCallback(async () => { - try { - if (!(consoleToken || refreshToken || consoleTokenFromLocalStorage || refreshTokenFromLocalStorage)) - return Promise.reject(new Error('No token found')) - - if (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage) - await getNewAccessToken() - - if (consoleToken && refreshToken) { - localStorage.setItem('console_token', consoleToken) - localStorage.setItem('refresh_token', refreshToken) - await getNewAccessToken() - } - } - catch (error) { - return Promise.reject(error) - } - }, [consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage, getNewAccessToken]) - useEffect(() => { (async () => { try { @@ -68,9 +47,15 @@ const SwrInitor = ({ router.replace('/install') return } - await setRefreshToken() - if (searchParams.has('access_token') || searchParams.has('refresh_token')) + if (!((consoleToken && refreshToken) || (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage))) { + router.replace('/signin') + return + } + if (searchParams.has('access_token') || searchParams.has('refresh_token')) { + consoleToken && localStorage.setItem('console_token', consoleToken) + refreshToken && localStorage.setItem('refresh_token', 
refreshToken) router.replace(pathname) + } setInit(true) } @@ -78,7 +63,7 @@ const SwrInitor = ({ router.replace('/signin') } })() - }, [isSetupFinished, setRefreshToken, router, pathname, searchParams]) + }, [isSetupFinished, router, pathname, searchParams, consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage]) return init ? ( diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index 9d533c1ee9..09ac2ed8ea 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -340,7 +340,9 @@ export const NODES_INITIAL_DATA = { ...ListFilterDefault.defaultValue, }, } - +export const MAX_ITERATION_PARALLEL_NUM = 10 +export const MIN_ITERATION_PARALLEL_NUM = 1 +export const DEFAULT_ITER_TIMES = 1 export const NODE_WIDTH = 240 export const X_OFFSET = 60 export const NODE_WIDTH_X_OFFSET = NODE_WIDTH + X_OFFSET diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index af2a1500ba..375a269377 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -644,6 +644,11 @@ export const useNodesInteractions = () => { newNode.data.isInIteration = true newNode.data.iteration_id = prevNode.parentId newNode.zIndex = ITERATION_CHILDREN_Z_INDEX + if (newNode.data.type === BlockEnum.Answer || newNode.data.type === BlockEnum.Tool || newNode.data.type === BlockEnum.Assigner) { + const parentIterNodeIndex = nodes.findIndex(node => node.id === prevNode.parentId) + const iterNodeData: IterationNodeType = nodes[parentIterNodeIndex].data + iterNodeData._isShowTips = true + } } const newEdge: Edge = { diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 0bbb1adab8..26654ef71e 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -14,6 +14,7 @@ import { NodeRunningStatus, WorkflowRunningStatus, } from '../types' +import { DEFAULT_ITER_TIMES } from '../constants' import { useWorkflowUpdate } from './use-workflow-interactions' import { useStore as useAppStore } from '@/app/components/app/store' import type { IOtherOptions } from '@/service/base' @@ -170,11 +171,13 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterParallelLogMap, } = workflowStore.getState() const { edges, setEdges, } = store.getState() + setIterParallelLogMap(new Map()) setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.task_id = task_id draft.result = { @@ -244,6 +247,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -259,10 +264,21 @@ export const useWorkflowRun = () => { const tracing = draft.tracing! 
const iterations = tracing.find(trace => trace.node_id === node?.parentId) const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1] - currIteration?.push({ - ...data, - status: NodeRunningStatus.Running, - } as any) + if (!data.parallel_run_id) { + currIteration?.push({ + ...data, + status: NodeRunningStatus.Running, + } as any) + } + else { + if (!iterParallelLogMap.has(data.parallel_run_id)) + iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any]) + else + iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any) + setIterParallelLogMap(iterParallelLogMap) + if (iterations) + iterations.details = Array.from(iterParallelLogMap.values()) + } })) } else { @@ -309,6 +325,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -317,21 +335,21 @@ export const useWorkflowRun = () => { const nodes = getNodes() const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId if (nodeParentId) { - setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const tracing = draft.tracing! - const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + if (!data.execution_metadata.parallel_mode_run_id) { + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! + const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node - if (iterations && iterations.details) { - const iterationIndex = data.execution_metadata?.iteration_index || 0 - if (!iterations.details[iterationIndex]) - iterations.details[iterationIndex] = [] + if (iterations && iterations.details) { + const iterationIndex = data.execution_metadata?.iteration_index || 0 + if (!iterations.details[iterationIndex]) + iterations.details[iterationIndex] = [] - const currIteration = iterations.details[iterationIndex] - const nodeIndex = currIteration.findIndex(node => - node.node_id === data.node_id && ( - node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), - ) - if (data.status === NodeRunningStatus.Succeeded) { + const currIteration = iterations.details[iterationIndex] + const nodeIndex = currIteration.findIndex(node => + node.node_id === data.node_id && ( + node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), + ) if (nodeIndex !== -1) { currIteration[nodeIndex] = { ...currIteration[nodeIndex], @@ -344,8 +362,40 @@ export const useWorkflowRun = () => { } as any) } } - } - })) + })) + } + else { + // open parallel mode + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! 
+ const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + + if (iterations && iterations.details) { + const iterRunID = data.execution_metadata?.parallel_mode_run_id + + const currIteration = iterParallelLogMap.get(iterRunID) + const nodeIndex = currIteration?.findIndex(node => + node.node_id === data.node_id && ( + node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id), + ) + if (currIteration) { + if (nodeIndex !== undefined && nodeIndex !== -1) { + currIteration[nodeIndex] = { + ...currIteration[nodeIndex], + ...data, + } as any + } + else { + currIteration.push({ + ...data, + } as any) + } + } + setIterParallelLogMap(iterParallelLogMap) + iterations.details = Array.from(iterParallelLogMap.values()) + } + })) + } } else { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { @@ -379,6 +429,7 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterTimes, } = workflowStore.getState() const { getNodes, @@ -388,6 +439,7 @@ export const useWorkflowRun = () => { transform, } = store.getState() const nodes = getNodes() + setIterTimes(DEFAULT_ITER_TIMES) setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.tracing!.push({ ...data, @@ -431,6 +483,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterTimes, + setIterTimes, } = workflowStore.getState() const { data } = params @@ -445,13 +499,14 @@ export const useWorkflowRun = () => { if (iteration.details!.length >= iteration.metadata.iterator_length!) return } - iteration?.details!.push([]) + if (!data.parallel_mode_run_id) + iteration?.details!.push([]) })) const nodes = getNodes() const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! - - currentNode.data._iterationIndex = data.index > 0 ? data.index : 1 + currentNode.data._iterationIndex = iterTimes + setIterTimes(iterTimes + 1) }) setNodes(newNodes) @@ -464,6 +519,7 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterTimes, } = workflowStore.getState() const { getNodes, @@ -480,7 +536,7 @@ export const useWorkflowRun = () => { }) } })) - + setIterTimes(DEFAULT_ITER_TIMES) const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! diff --git a/web/app/components/workflow/nodes/_base/components/editor/base.tsx b/web/app/components/workflow/nodes/_base/components/editor/base.tsx index 55b8e7dd3b..fa4389efdc 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/base.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/base.tsx @@ -26,7 +26,7 @@ interface Props { isFocus: boolean isInNode?: boolean onGenerated?: (prompt: string) => void - codeLanguages: CodeLanguage + codeLanguages?: CodeLanguage fileList?: FileEntity[] showFileList?: boolean showCodeGenerator?: boolean @@ -78,7 +78,7 @@ const Base: FC = ({ e.stopPropagation() }}> {headerRight} - {showCodeGenerator && ( + {showCodeGenerator && codeLanguages && (
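The tracing branch added above keys everything off parallel_mode_run_id: events from the same parallel branch land in one bucket of iterParallelLogMap, and the iteration's details become the list of buckets. Reduced to its core, with a simplified event shape and the sequential per-index bookkeeping folded into one default lane:

```ts
// Sketch: one tracing lane per parallel branch, in first-seen order.
type NodeEvent = { node_id: string; parallel_run_id?: string }

function groupIterationLanes(events: NodeEvent[]): NodeEvent[][] {
  const lanes = new Map<string, NodeEvent[]>()
  for (const e of events) {
    const key = e.parallel_run_id ?? 'sequential' // simplification: real code indexes sequential runs by iteration_index
    const lane = lanes.get(key) ?? []
    lane.push(e)
    lanes.set(key, lane)
  }
  // Equivalent of `iterations.details = Array.from(iterParallelLogMap.values())`
  return Array.from(lanes.values())
}
```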
diff --git a/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx b/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx index b5ca968185..1656d5e43d 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx @@ -31,6 +31,7 @@ export interface Props { noWrapper?: boolean isExpand?: boolean showFileList?: boolean + onGenerated?: (value: string) => void showCodeGenerator?: boolean } @@ -64,6 +65,7 @@ const CodeEditor: FC = ({ noWrapper, isExpand, showFileList, + onGenerated, showCodeGenerator = false, }) => { const [isFocus, setIsFocus] = React.useState(false) @@ -151,9 +153,6 @@ const CodeEditor: FC = ({ return isFocus ? 'focus-theme' : 'blur-theme' })() - const handleGenerated = (code: string) => { - handleEditorChange(code) - } const main = ( <> @@ -205,7 +204,7 @@ const CodeEditor: FC = ({ isFocus={isFocus && !readOnly} minHeight={minHeight} isInNode={isInNode} - onGenerated={handleGenerated} + onGenerated={onGenerated} codeLanguages={language} fileList={fileList} showFileList={showFileList} diff --git a/web/app/components/workflow/nodes/_base/components/field.tsx b/web/app/components/workflow/nodes/_base/components/field.tsx index a36dadbbef..8e83a0508a 100644 --- a/web/app/components/workflow/nodes/_base/components/field.tsx +++ b/web/app/components/workflow/nodes/_base/components/field.tsx @@ -12,15 +12,15 @@ import Tooltip from '@/app/components/base/tooltip' interface Props { className?: string title: JSX.Element | string | DefaultTFuncReturn + tooltip?: React.ReactNode isSubTitle?: boolean - tooltip?: string supportFold?: boolean children?: JSX.Element | string | null operations?: JSX.Element inline?: boolean } -const Filed: FC = ({ +const Field: FC = ({ className, title, isSubTitle, @@ -60,4 +60,4 @@ const Filed: FC = ({
) } -export default React.memo(Filed) +export default React.memo(Field) diff --git a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx index 82a3a906cf..42a7213f80 100644 --- a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx +++ b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx @@ -39,7 +39,13 @@ const FileUploadSetting: FC = ({ allowed_file_extensions, } = payload const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) - const { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit } = useFileSizeLimit(fileUploadConfigResponse) + const { + imgSizeLimit, + docSizeLimit, + audioSizeLimit, + videoSizeLimit, + maxFileUploadLimit, + } = useFileSizeLimit(fileUploadConfigResponse) const handleSupportFileTypeChange = useCallback((type: SupportUploadFileTypes) => { const newPayload = produce(payload, (draft) => { @@ -156,7 +162,7 @@ const FileUploadSetting: FC = ({
diff --git a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts index c6cffb7331..64d9f5fd7e 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts @@ -106,32 +106,29 @@ const useOneStepRun = ({ const availableNodesIncludeParent = getBeforeNodesInSameBranchIncludeParent(id) const allOutputVars = toNodeOutputVars(availableNodes, isChatMode, undefined, undefined, conversationVariables) const getVar = (valueSelector: ValueSelector): Var | undefined => { - let res: Var | undefined const isSystem = valueSelector[0] === 'sys' - const targetVar = isSystem ? allOutputVars.find(item => !!item.isStartNode) : allOutputVars.find(v => v.nodeId === valueSelector[0]) + const targetVar = allOutputVars.find(item => isSystem ? !!item.isStartNode : item.nodeId === valueSelector[0]) if (!targetVar) return undefined + if (isSystem) return targetVar.vars.find(item => item.variable.split('.')[1] === valueSelector[1]) let curr: any = targetVar.vars - if (!curr) - return + for (let i = 1; i < valueSelector.length; i++) { + const key = valueSelector[i] + const isLast = i === valueSelector.length - 1 - valueSelector.slice(1).forEach((key, i) => { - const isLast = i === valueSelector.length - 2 - // conversation variable is start with 'conversation.' - curr = curr?.find((v: any) => v.variable.replace('conversation.', '') === key) - if (isLast) { - res = curr - } - else { - if (curr?.type === VarType.object || curr?.type === VarType.file) - curr = curr.children - } - }) + if (Array.isArray(curr)) + curr = curr.find((v: any) => v.variable.replace('conversation.', '') === key) - return res + if (isLast) + return curr + else if (curr?.type === VarType.object || curr?.type === VarType.file) + curr = curr.children + } + + return undefined } const checkValid = checkValidFns[data.type] diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index bd5921c735..e864c419e2 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -25,6 +25,7 @@ import { useToolIcon, } from '../../hooks' import { useNodeIterationInteractions } from '../iteration/use-interactions' +import type { IterationNodeType } from '../iteration/types' import { NodeSourceHandle, NodeTargetHandle, @@ -34,6 +35,7 @@ import NodeControl from './components/node-control' import AddVariablePopupWithPosition from './components/add-variable-popup-with-position' import cn from '@/utils/classnames' import BlockIcon from '@/app/components/workflow/block-icon' +import Tooltip from '@/app/components/base/tooltip' type BaseNodeProps = { children: ReactElement @@ -166,9 +168,27 @@ const BaseNode: FC = ({ />
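The getVar rewrite in use-one-step-run.ts above replaces the forEach-plus-flag walk with a plain loop that consumes one selector segment per step and only descends into children for object- and file-typed vars. The traversal on its own, with a simplified Var shape and the conversation-variable prefix stripping omitted:

```ts
// Sketch: resolve a path like ['nodeId', 'a', 'b'] against a nested var tree.
type Var = { variable: string; type: string; children?: Var[] }

function resolve(vars: Var[], selector: string[]): Var | undefined {
  let curr: Var[] | undefined = vars // vars of the node selected by selector[0]
  for (let i = 1; i < selector.length; i++) {
    const match: Var | undefined = curr?.find(v => v.variable === selector[i])
    if (i === selector.length - 1)
      return match
    // Only container types expose children to descend into.
    curr = (match && (match.type === 'object' || match.type === 'file')) ? match.children : undefined
  }
  return undefined
}
```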
- {data.title} +
+ {data.title} +
+ { + data.type === BlockEnum.Iteration && (data as IterationNodeType).is_parallel && ( + +
+ {t('workflow.nodes.iteration.parallelModeEnableTitle')} +
+ {t('workflow.nodes.iteration.parallelModeEnableDesc')} +
} + > +
+ {t('workflow.nodes.iteration.parallelModeUpper')} +
+ + ) + } { data._iterationLength && data._iterationIndex && data._runningStatus === NodeRunningStatus.Running && ( diff --git a/web/app/components/workflow/nodes/code/code-parser.spec.ts b/web/app/components/workflow/nodes/code/code-parser.spec.ts new file mode 100644 index 0000000000..b5d28dd136 --- /dev/null +++ b/web/app/components/workflow/nodes/code/code-parser.spec.ts @@ -0,0 +1,326 @@ +import { VarType } from '../../types' +import { extractFunctionParams, extractReturnType } from './code-parser' +import { CodeLanguage } from './types' + +const SAMPLE_CODES = { + python3: { + noParams: 'def main():', + singleParam: 'def main(param1):', + multipleParams: `def main(param1, param2, param3): + return {"result": param1}`, + withTypes: `def main(param1: str, param2: int, param3: List[str]): + result = process_data(param1, param2) + return {"output": result}`, + withDefaults: `def main(param1: str = "default", param2: int = 0): + return {"data": param1}`, + }, + javascript: { + noParams: 'function main() {', + singleParam: 'function main(param1) {', + multipleParams: `function main(param1, param2, param3) { + return { result: param1 } + }`, + withComments: `// Main function + function main(param1, param2) { + // Process data + return { output: process(param1, param2) } + }`, + withSpaces: 'function main( param1 , param2 ) {', + }, +} + +describe('extractFunctionParams', () => { + describe('Python3', () => { + test('handles no parameters', () => { + const result = extractFunctionParams(SAMPLE_CODES.python3.noParams, CodeLanguage.python3) + expect(result).toEqual([]) + }) + + test('extracts single parameter', () => { + const result = extractFunctionParams(SAMPLE_CODES.python3.singleParam, CodeLanguage.python3) + expect(result).toEqual(['param1']) + }) + + test('extracts multiple parameters', () => { + const result = extractFunctionParams(SAMPLE_CODES.python3.multipleParams, CodeLanguage.python3) + expect(result).toEqual(['param1', 'param2', 'param3']) + }) + + test('handles type hints', () => { + const result = extractFunctionParams(SAMPLE_CODES.python3.withTypes, CodeLanguage.python3) + expect(result).toEqual(['param1', 'param2', 'param3']) + }) + + test('handles default values', () => { + const result = extractFunctionParams(SAMPLE_CODES.python3.withDefaults, CodeLanguage.python3) + expect(result).toEqual(['param1', 'param2']) + }) + }) + + // JavaScript test cases + describe('JavaScript', () => { + test('handles no parameters', () => { + const result = extractFunctionParams(SAMPLE_CODES.javascript.noParams, CodeLanguage.javascript) + expect(result).toEqual([]) + }) + + test('extracts single parameter', () => { + const result = extractFunctionParams(SAMPLE_CODES.javascript.singleParam, CodeLanguage.javascript) + expect(result).toEqual(['param1']) + }) + + test('extracts multiple parameters', () => { + const result = extractFunctionParams(SAMPLE_CODES.javascript.multipleParams, CodeLanguage.javascript) + expect(result).toEqual(['param1', 'param2', 'param3']) + }) + + test('handles comments in code', () => { + const result = extractFunctionParams(SAMPLE_CODES.javascript.withComments, CodeLanguage.javascript) + expect(result).toEqual(['param1', 'param2']) + }) + + test('handles whitespace', () => { + const result = extractFunctionParams(SAMPLE_CODES.javascript.withSpaces, CodeLanguage.javascript) + expect(result).toEqual(['param1', 'param2']) + }) + }) +}) + +const RETURN_TYPE_SAMPLES = { + python3: { + singleReturn: ` +def main(param1): + return {"result": "value"}`, + + multipleReturns: ` +def main(param1, param2): + return {"result": "value", "status": "success"}`, + + noReturn: ` +def main(): + print("Hello")`, + + complexReturn: ` +def main(): + data = process() + return {"result": data, "count": 42, "messages": ["hello"]}`, + nestedObject: ` + def main(name, age, city): + return { + 'personal_info': { + 'name': name, + 'age': age, + 'city': city + }, + 'timestamp': int(time.time()), + 'status': 'active' + }`, + }, + + javascript: { + singleReturn: ` +function main(param1) { + return { result: "value" } +}`, + + multipleReturns: ` +function main(param1) { + return { result: "value", status: "success" } +}`, + + withParentheses: ` +function main() { + return ({ result: "value", status: "success" }) +}`, + + noReturn: ` +function main() { + console.log("Hello") +}`, + + withQuotes: ` +function main() { + return { "result": 'value', 'status': "success" } +}`, + nestedObject: ` +function main(name, age, city) { + return { + personal_info: { + name: name, + age: age, + city: city + }, + timestamp: Date.now(), + status: 'active' + } +}`, + withJSDoc: ` +/** + * Creates a user profile with personal information and metadata + * @param {string} name - The user's name + * @param {number} age - The user's age + * @param {string} city - The user's city of residence + * @returns {Object} An object containing the user profile + */ +function main(name, age, city) { + return { + result: { + personal_info: { + name: name, + age: age, + city: city + }, + timestamp: Date.now(), + status: 'active' + } + }; +}`, + + }, +} + +describe('extractReturnType', () => { + // Python3 tests + describe('Python3', () => { + test('extracts single return value', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.singleReturn, CodeLanguage.python3) + expect(result).toEqual({ + result: { + type: VarType.string, + children: null, + }, + }) + }) + + test('extracts multiple return values', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.multipleReturns, CodeLanguage.python3) + expect(result).toEqual({ + result: { + type: VarType.string, + children: null, + }, + status: { + type: VarType.string, + children: null, + }, + }) + }) + + test('returns empty object when no return statement', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.noReturn, CodeLanguage.python3) + expect(result).toEqual({}) + }) + + test('handles complex return statement', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.complexReturn, CodeLanguage.python3) + expect(result).toEqual({ + result: { + type: VarType.string, + children: null, + }, + count: { + type: VarType.string, + children: null, + }, + messages: { + type: VarType.string, + children: null, + }, + }) + }) + test('handles nested object structure', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.python3.nestedObject, CodeLanguage.python3) + expect(result).toEqual({ + personal_info: { + type: VarType.string, + children: null, + }, + timestamp: { + type: VarType.string, + children: null, + }, + status: { + type: VarType.string, + children: null, + }, + }) + }) + }) + + // JavaScript tests + describe('JavaScript', () => { + test('extracts single return value', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.singleReturn, CodeLanguage.javascript) + expect(result).toEqual({ + result: { + type: VarType.string, + children: null, + }, + }) + }) + + test('extracts multiple return values', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.multipleReturns, CodeLanguage.javascript) + expect(result).toEqual({ + result: { + type: VarType.string, + children: null, + }, + status: { + type: VarType.string, + children: null, + }, + }) + }) + + test('handles return with parentheses', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.withParentheses, CodeLanguage.javascript) + expect(result).toEqual({ + result: { + type: VarType.string, + children: null, + }, + status: { + type: VarType.string, + children: null, + }, + }) + }) + + test('returns empty object when no return statement', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.noReturn, CodeLanguage.javascript) + expect(result).toEqual({}) + }) + + test('handles quoted keys', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.withQuotes, CodeLanguage.javascript) + expect(result).toEqual({ + result: { + type: VarType.string, + children: null, + }, + status: { + type: VarType.string, + children: null, + }, + }) + }) + test('handles nested object structure', () => { + const result = extractReturnType(RETURN_TYPE_SAMPLES.javascript.nestedObject, CodeLanguage.javascript) + expect(result).toEqual({ + personal_info: { + type: VarType.string, + children: null, + }, + timestamp: { + type: VarType.string, + children: null, + }, + status: { + type: VarType.string, + children: null, + }, + }) + }) + }) +}) diff --git a/web/app/components/workflow/nodes/code/code-parser.ts b/web/app/components/workflow/nodes/code/code-parser.ts new file mode 100644 index 0000000000..e1b0928f14 --- /dev/null +++ b/web/app/components/workflow/nodes/code/code-parser.ts @@ -0,0 +1,81 @@ +import { VarType } from '../../types' +import type { OutputVar } from './types' +import { CodeLanguage } from './types' + +export const extractFunctionParams = (code: string, language: CodeLanguage) => { + if (language === CodeLanguage.json) + return [] + + const patterns: Record<Exclude<CodeLanguage, CodeLanguage.json>, RegExp> = { + [CodeLanguage.python3]: /def\s+main\s*\((.*?)\)/, + [CodeLanguage.javascript]: /function\s+main\s*\((.*?)\)/, + } + const match = code.match(patterns[language]) + const params: string[] = [] + + if (match?.[1]) { + params.push(...match[1].split(',') + .map(p => p.trim()) + .filter(Boolean) + .map(p => p.split(':')[0].trim()), + ) + } + + return params +} +export const extractReturnType = (code: string, language: CodeLanguage): OutputVar => { + const codeWithoutComments = code.replace(/\/\*\*[\s\S]*?\*\//, '') + + const returnIndex = codeWithoutComments.indexOf('return') + if (returnIndex === -1) + return {} + + // Take the substring that starts at the return statement + const codeAfterReturn = codeWithoutComments.slice(returnIndex) + + let bracketCount = 0 + let startIndex = codeAfterReturn.indexOf('{') + + if (language === CodeLanguage.javascript && startIndex === -1) { + const parenStart = codeAfterReturn.indexOf('(') + if (parenStart !== -1) + startIndex = codeAfterReturn.indexOf('{', parenStart) + } + + if (startIndex === -1) + return {} + + let endIndex = -1 + + for (let i = startIndex; i < codeAfterReturn.length; i++) { + if (codeAfterReturn[i] === '{') + bracketCount++ + if (codeAfterReturn[i] === '}') { + bracketCount-- + if (bracketCount === 0) { + endIndex = i + 1 + break + } + } + } + + if (endIndex === -1) + return {} + + const returnContent = codeAfterReturn.slice(startIndex + 1, endIndex - 1) + + const result: OutputVar = {} + + const keyRegex = /['"]?(\w+)['"]?\s*:(?![^{]*})/g + const matches = returnContent.matchAll(keyRegex) + + for (const match of matches) { + const key = match[1] + result[key] = { + type: VarType.string, + children: null, + } + } + + return result +} diff --git a/web/app/components/workflow/nodes/code/panel.tsx b/web/app/components/workflow/nodes/code/panel.tsx index d3e5e58634..08fc565836 100644 --- a/web/app/components/workflow/nodes/code/panel.tsx +++ b/web/app/components/workflow/nodes/code/panel.tsx @@ -5,6 +5,7 @@ import RemoveEffectVarConfirm from '../_base/components/remove-effect-var-confir import useConfig from './use-config' import type { CodeNodeType } from './types' import { CodeLanguage } from './types' +import { extractFunctionParams, extractReturnType } from './code-parser' import VarList from '@/app/components/workflow/nodes/_base/components/variable/var-list' import OutputVarList from '@/app/components/workflow/nodes/_base/components/variable/output-var-list' import AddButton from '@/app/components/base/button/add-button' @@ -12,10 +13,9 @@ import Field from '@/app/components/workflow/nodes/_base/components/field' import Split from '@/app/components/workflow/nodes/_base/components/split' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import TypeSelector from '@/app/components/workflow/nodes/_base/components/selector' -import type { NodePanelProps } from '@/app/components/workflow/types' +import { type NodePanelProps } from '@/app/components/workflow/types' import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import ResultPanel from '@/app/components/workflow/run/result-panel' - const i18nPrefix = 'workflow.nodes.code' const codeLanguages = [ @@ -38,6 +38,7 @@ const Panel: FC<NodePanelProps<CodeNodeType>> = ({ readOnly, inputs, outputKeyOrders, + handleCodeAndVarsChange, handleVarListChange, handleAddVariable, handleRemoveVariable, @@ -61,6 +62,18 @@ const Panel: FC<NodePanelProps<CodeNodeType>> = ({ setInputVarValues, } = useConfig(id, data) + const handleGeneratedCode = (value: string) => { + const params = extractFunctionParams(value, inputs.code_language) + const codeNewInput = params.map((p) => { + return { + variable: p, + value_selector: [], + } + }) + const returnTypes = extractReturnType(value, inputs.code_language) + handleCodeAndVarsChange(value, codeNewInput, returnTypes) + } + return (
@@ -92,6 +105,7 @@ const Panel: FC> = ({ language={inputs.code_language} value={inputs.code} onChange={handleCodeChange} + onGenerated={handleGeneratedCode} showCodeGenerator={true} />
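handleGeneratedCode above wires the two new helpers together: extracted parameters become input variables with empty value selectors, and top-level return keys become outputs (all string-typed for now). What the helpers produce for a generated snippet, assuming the signatures introduced in code-parser.ts:

```ts
// Sketch: what the parser extracts from a generated function (assumed import
// paths; every return key is typed as string in the current implementation).
import { extractFunctionParams, extractReturnType } from './code-parser'
import { CodeLanguage } from './types'

const generated = `
function main(query, limit) {
  return { items: search(query, limit), total: count(query) }
}`

const params = extractFunctionParams(generated, CodeLanguage.javascript)
// -> ['query', 'limit'], mapped to { variable, value_selector: [] } entries

const outputs = extractReturnType(generated, CodeLanguage.javascript)
// -> { items: { type: VarType.string, children: null },
//      total: { type: VarType.string, children: null } }
```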
diff --git a/web/app/components/workflow/nodes/code/use-config.ts b/web/app/components/workflow/nodes/code/use-config.ts index 07fe85aa0f..c53c07a28e 100644 --- a/web/app/components/workflow/nodes/code/use-config.ts +++ b/web/app/components/workflow/nodes/code/use-config.ts @@ -3,7 +3,7 @@ import produce from 'immer' import useVarList from '../_base/hooks/use-var-list' import useOutputVarList from '../_base/hooks/use-output-var-list' import { BlockEnum, VarType } from '../../types' -import type { Var } from '../../types' +import type { Var, Variable } from '../../types' import { useStore } from '../../store' import type { CodeNodeType, OutputVar } from './types' import { CodeLanguage } from './types' @@ -136,7 +136,15 @@ const useConfig = (id: string, payload: CodeNodeType) => { const setInputVarValues = useCallback((newPayload: Record) => { setRunInputData(newPayload) }, [setRunInputData]) - + const handleCodeAndVarsChange = useCallback((code: string, inputVariables: Variable[], outputVariables: OutputVar) => { + const newInputs = produce(inputs, (draft) => { + draft.code = code + draft.variables = inputVariables + draft.outputs = outputVariables + }) + setInputs(newInputs) + syncOutputKeyOrders(outputVariables) + }, [inputs, setInputs, syncOutputKeyOrders]) return { readOnly, inputs, @@ -163,6 +171,7 @@ const useConfig = (id: string, payload: CodeNodeType) => { inputVarValues, setInputVarValues, runResult, + handleCodeAndVarsChange, } } diff --git a/web/app/components/workflow/nodes/if-else/use-config.ts b/web/app/components/workflow/nodes/if-else/use-config.ts index d1210431a0..41e41f6b8b 100644 --- a/web/app/components/workflow/nodes/if-else/use-config.ts +++ b/web/app/components/workflow/nodes/if-else/use-config.ts @@ -78,24 +78,24 @@ const useConfig = (id: string, payload: IfElseNodeType) => { }) const handleAddCase = useCallback(() => { - const newInputs = produce(inputs, () => { - if (inputs.cases) { + const newInputs = produce(inputs, (draft) => { + if (draft.cases) { const case_id = uuid4() - inputs.cases.push({ + draft.cases.push({ case_id, logical_operator: LogicalOperator.and, conditions: [], }) - if (inputs._targetBranches) { - const elseCaseIndex = inputs._targetBranches.findIndex(branch => branch.id === 'false') + if (draft._targetBranches) { + const elseCaseIndex = draft._targetBranches.findIndex(branch => branch.id === 'false') if (elseCaseIndex > -1) { - inputs._targetBranches = branchNameCorrect([ - ...inputs._targetBranches.slice(0, elseCaseIndex), + draft._targetBranches = branchNameCorrect([ + ...draft._targetBranches.slice(0, elseCaseIndex), { id: case_id, name: '', }, - ...inputs._targetBranches.slice(elseCaseIndex), + ...draft._targetBranches.slice(elseCaseIndex), ]) } } diff --git a/web/app/components/workflow/nodes/iteration/default.ts b/web/app/components/workflow/nodes/iteration/default.ts index 3afa52d06e..cdef268adb 100644 --- a/web/app/components/workflow/nodes/iteration/default.ts +++ b/web/app/components/workflow/nodes/iteration/default.ts @@ -1,7 +1,10 @@ -import { BlockEnum } from '../../types' +import { BlockEnum, ErrorHandleMode } from '../../types' import type { NodeDefault } from '../../types' import type { IterationNodeType } from './types' -import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' +import { + ALL_CHAT_AVAILABLE_BLOCKS, + ALL_COMPLETION_AVAILABLE_BLOCKS, +} from '@/app/components/workflow/constants' const i18nPrefix = 'workflow' const nodeDefault: NodeDefault = { @@ -10,25 
+13,45 @@ const nodeDefault: NodeDefault = { iterator_selector: [], output_selector: [], _children: [], + _isShowTips: false, + is_parallel: false, + parallel_nums: 10, + error_handle_mode: ErrorHandleMode.Terminated, }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS - : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End) + : ALL_COMPLETION_AVAILABLE_BLOCKS.filter( + type => type !== BlockEnum.End, + ) return nodes }, getAvailableNextNodes(isChatMode: boolean) { - const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + const nodes = isChatMode + ? ALL_CHAT_AVAILABLE_BLOCKS + : ALL_COMPLETION_AVAILABLE_BLOCKS return nodes }, checkValid(payload: IterationNodeType, t: any) { let errorMessages = '' - if (!errorMessages && (!payload.iterator_selector || payload.iterator_selector.length === 0)) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.input`) }) + if ( + !errorMessages + && (!payload.iterator_selector || payload.iterator_selector.length === 0) + ) { + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { + field: t(`${i18nPrefix}.nodes.iteration.input`), + }) + } - if (!errorMessages && (!payload.output_selector || payload.output_selector.length === 0)) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.output`) }) + if ( + !errorMessages + && (!payload.output_selector || payload.output_selector.length === 0) + ) { + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { + field: t(`${i18nPrefix}.nodes.iteration.output`), + }) + } return { isValid: !errorMessages, diff --git a/web/app/components/workflow/nodes/iteration/node.tsx b/web/app/components/workflow/nodes/iteration/node.tsx index 48a005a261..fda033b87a 100644 --- a/web/app/components/workflow/nodes/iteration/node.tsx +++ b/web/app/components/workflow/nodes/iteration/node.tsx @@ -8,12 +8,16 @@ import { useNodesInitialized, useViewport, } from 'reactflow' +import { useTranslation } from 'react-i18next' import { IterationStartNodeDumb } from '../iteration-start' import { useNodeIterationInteractions } from './use-interactions' import type { IterationNodeType } from './types' import AddBlock from './add-block' import cn from '@/utils/classnames' import type { NodeProps } from '@/app/components/workflow/types' +import Toast from '@/app/components/base/toast' + +const i18nPrefix = 'workflow.nodes.iteration' const Node: FC> = ({ id, @@ -22,11 +26,20 @@ const Node: FC> = ({ const { zoom } = useViewport() const nodesInitialized = useNodesInitialized() const { handleNodeIterationRerender } = useNodeIterationInteractions() + const { t } = useTranslation() useEffect(() => { if (nodesInitialized) handleNodeIterationRerender(id) - }, [nodesInitialized, id, handleNodeIterationRerender]) + if (data.is_parallel && data._isShowTips) { + Toast.notify({ + type: 'warning', + message: t(`${i18nPrefix}.answerNodeWarningDesc`), + duration: 5000, + }) + data._isShowTips = false + } + }, [nodesInitialized, id, handleNodeIterationRerender, data, t]) return (
> = ({ data, }) => { const { t } = useTranslation() - + const responseMethod = [ + { + value: ErrorHandleMode.Terminated, + name: t(`${i18nPrefix}.ErrorMethod.operationTerminated`), + }, + { + value: ErrorHandleMode.ContinueOnError, + name: t(`${i18nPrefix}.ErrorMethod.continueOnError`), + }, + { + value: ErrorHandleMode.RemoveAbnormalOutput, + name: t(`${i18nPrefix}.ErrorMethod.removeAbnormalOutput`), + }, + ] const { readOnly, inputs, @@ -47,6 +66,9 @@ const Panel: FC> = ({ setIterator, iteratorInputKey, iterationRunResult, + changeParallel, + changeErrorResponseMode, + changeParallelNums, } = useConfig(id, data) return ( @@ -87,6 +109,39 @@ const Panel: FC> = ({ />
+
+ {t(`${i18nPrefix}.parallelPanelDesc`)}
} inline> + + +
+ { + inputs.is_parallel && (
+ {t(`${i18nPrefix}.MaxParallelismDesc`)}
}> +
+ { changeParallelNums(Number(e.target.value)) }} /> + +
+ + + ) + } +
+ +
+ +
+ + + +
+ {isShowSingleRun && ( { @@ -184,6 +185,25 @@ const useConfig = (id: string, payload: IterationNodeType) => { }) }, [iteratorInputKey, runInputData, setRunInputData]) + const changeParallel = useCallback((value: boolean) => { + const newInputs = produce(inputs, (draft) => { + draft.is_parallel = value + }) + setInputs(newInputs) + }, [inputs, setInputs]) + + const changeErrorResponseMode = useCallback((item: Item) => { + const newInputs = produce(inputs, (draft) => { + draft.error_handle_mode = item.value as ErrorHandleMode + }) + setInputs(newInputs) + }, [inputs, setInputs]) + const changeParallelNums = useCallback((num: number) => { + const newInputs = produce(inputs, (draft) => { + draft.parallel_nums = num + }) + setInputs(newInputs) + }, [inputs, setInputs]) return { readOnly, inputs, @@ -210,6 +230,9 @@ const useConfig = (id: string, payload: IterationNodeType) => { setIterator, iteratorInputKey, iterationRunResult, + changeParallel, + changeErrorResponseMode, + changeParallelNums, } } diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts index d280a2d63e..288a718aa2 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts @@ -240,7 +240,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { if ( (allInternal && (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)) || mixtureInternalAndExternal - || (allExternal && newDatasets.length > 1) + || allExternal ) setRerankModelOpen(true) }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index e24ca89d4c..1596bd1cd9 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -9,6 +9,8 @@ import { produce, setAutoFreeze } from 'immer' import { uniqBy } from 'lodash-es' import { useWorkflowRun } from '../../hooks' import { NodeRunningStatus, WorkflowRunningStatus } from '../../types' +import { useWorkflowStore } from '../../store' +import { DEFAULT_ITER_TIMES } from '../../constants' import type { ChatItem, Inputs, @@ -43,6 +45,7 @@ export const useChat = ( const { notify } = useToastContext() const { handleRun } = useWorkflowRun() const hasStopResponded = useRef(false) + const workflowStore = useWorkflowStore() const conversationId = useRef('') const taskIdRef = useRef('') const [chatList, setChatList] = useState(prevChatList || []) @@ -52,6 +55,9 @@ export const useChat = ( const [suggestedQuestions, setSuggestQuestions] = useState([]) const suggestedQuestionsAbortControllerRef = useRef(null) + const { + setIterTimes, + } = workflowStore.getState() useEffect(() => { setAutoFreeze(false) return () => { @@ -102,15 +108,16 @@ export const useChat = ( handleResponding(false) if (stopChat && taskIdRef.current) stopChat(taskIdRef.current) - + setIterTimes(DEFAULT_ITER_TIMES) if (suggestedQuestionsAbortControllerRef.current) suggestedQuestionsAbortControllerRef.current.abort() - }, [handleResponding, stopChat]) + }, [handleResponding, setIterTimes, stopChat]) const handleRestart = useCallback(() => { conversationId.current = '' taskIdRef.current = '' handleStop() + setIterTimes(DEFAULT_ITER_TIMES) const newChatList = config?.opening_statement ? 
[{ id: `${Date.now()}`, @@ -126,6 +133,7 @@ export const useChat = ( config, handleStop, handleUpdateChatList, + setIterTimes, ]) const updateCurrentQA = useCallback(({ diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index e0fcb8040f..6e269c2714 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -60,36 +60,67 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe }, [notify, getResultCallback]) const formatNodeList = useCallback((list: NodeTracing[]) => { - const allItems = list.reverse() + const allItems = [...list].reverse() const result: NodeTracing[] = [] - allItems.forEach((item) => { - const { node_type, execution_metadata } = item - if (node_type !== BlockEnum.Iteration) { - const isInIteration = !!execution_metadata?.iteration_id + const groupMap = new Map() - if (isInIteration) { - const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id) - const iterationDetails = iterationNode?.details - const currentIterationIndex = execution_metadata?.iteration_index ?? 0 - - if (Array.isArray(iterationDetails)) { - if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex]) - iterationDetails[currentIterationIndex] = [item] - else - iterationDetails[currentIterationIndex].push(item) - } - return - } - // not in iteration - result.push(item) - - return - } + const processIterationNode = (item: NodeTracing) => { result.push({ ...item, details: [], }) + } + const updateParallelModeGroup = (runId: string, item: NodeTracing, iterationNode: NodeTracing) => { + if (!groupMap.has(runId)) + groupMap.set(runId, [item]) + else + groupMap.get(runId)!.push(item) + if (item.status === 'failed') { + iterationNode.status = 'failed' + iterationNode.error = item.error + } + + iterationNode.details = Array.from(groupMap.values()) + } + const updateSequentialModeGroup = (index: number, item: NodeTracing, iterationNode: NodeTracing) => { + const { details } = iterationNode + if (details) { + if (!details[index]) + details[index] = [item] + else + details[index].push(item) + } + + if (item.status === 'failed') { + iterationNode.status = 'failed' + iterationNode.error = item.error + } + } + const processNonIterationNode = (item: NodeTracing) => { + const { execution_metadata } = item + if (!execution_metadata?.iteration_id) { + result.push(item) + return + } + + const iterationNode = result.find(node => node.node_id === execution_metadata.iteration_id) + if (!iterationNode || !Array.isArray(iterationNode.details)) + return + + const { parallel_mode_run_id, iteration_index = 0 } = execution_metadata + + if (parallel_mode_run_id) + updateParallelModeGroup(parallel_mode_run_id, item, iterationNode) + else + updateSequentialModeGroup(iteration_index, item, iterationNode) + } + + allItems.forEach((item) => { + item.node_type === BlockEnum.Iteration + ? 
+        ? processIterationNode(item)
+        : processNonIterationNode(item)
+    })
+
     return result
   }, [])
diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx
index 8847c43fb9..44b8ac6b84 100644
--- a/web/app/components/workflow/run/iteration-result-panel.tsx
+++ b/web/app/components/workflow/run/iteration-result-panel.tsx
@@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next'
 import {
   RiArrowRightSLine,
   RiCloseLine,
+  RiErrorWarningLine,
 } from '@remixicon/react'
 import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows'
 import TracingPanel from './tracing-panel'
@@ -27,7 +28,7 @@ const IterationResultPanel: FC = ({
   noWrap,
 }) => {
   const { t } = useTranslation()
-  const [expandedIterations, setExpandedIterations] = useState<number[]>([])
+  const [expandedIterations, setExpandedIterations] = useState<Record<number, boolean>>({})

   const toggleIteration = useCallback((index: number) => {
     setExpandedIterations(prev => ({
       ...prev,
@@ -71,10 +72,19 @@ const IterationResultPanel: FC = ({
                 {t(`${i18nPrefix}.iteration`)} {index + 1}
-
+              {
+                iteration.some(item => item.status === 'failed')
+                  ? (
+                    <RiErrorWarningLine className='w-4 h-4 text-text-destructive flex-shrink-0' />
+                  )
+                  : (<RiArrowRightSLine className={
+                    cn(
+                      'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0',
+                      expandedIterations[index] && 'transform rotate-90',
+                    )} />
+                  )
+              }
+            {expandedIterations[index] &&
diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx
@@ const NodePanel: FC = ({
     return iteration_length
   }
+  const getErrorCount = (details: NodeTracing[][] | undefined) => {
+    if (!details || details.length === 0)
+      return 0
+    return details.reduce((acc, iteration) => {
+      if (iteration.some(item => item.status === 'failed'))
+        acc++
+      return acc
+    }, 0)
+  }
   useEffect(() => {
     setCollapseState(!nodeInfo.expand)
   }, [nodeInfo.expand, setCollapseState])
@@ -136,7 +145,12 @@ const NodePanel: FC = ({
           onClick={handleOnShowIterationDetail}
         >
-          {t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}
+          {t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}{getErrorCount(nodeInfo.details) > 0 && (
+            <>
+              {t('workflow.nodes.iteration.comma')}
+              {t('workflow.nodes.iteration.error', { count: getErrorCount(nodeInfo.details) })}
+            </>
+          )}
           {justShowIterationNavArrow
             ? (
diff --git a/web/app/components/workflow/store.ts b/web/app/components/workflow/store.ts
index 3202f5d498..e1e9e530c1 100644
--- a/web/app/components/workflow/store.ts
+++ b/web/app/components/workflow/store.ts
@@ -21,6 +21,7 @@ import type {
   WorkflowRunningData,
 } from './types'
 import { WorkflowContext } from './context'
+import type { NodeTracing } from '@/types/workflow'

 // #TODO chatVar#
 // const MOCK_DATA = [
@@ -166,6 +167,10 @@ interface Shape {
   setShowImportDSLModal: (showImportDSLModal: boolean) => void
   showTips: string
   setShowTips: (showTips: string) => void
+  iterTimes: number
+  setIterTimes: (iterTimes: number) => void
+  iterParallelLogMap: Map
+  setIterParallelLogMap: (iterParallelLogMap: Map) => void
 }

 export const createWorkflowStore = () => {
@@ -281,6 +286,11 @@
     setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })),
     showTips: '',
     setShowTips: showTips => set(() => ({ showTips })),
+    iterTimes: 1,
+    setIterTimes: iterTimes => set(() => ({ iterTimes })),
+    iterParallelLogMap: new Map(),
+    setIterParallelLogMap: iterParallelLogMap => set(() => ({ iterParallelLogMap })),
+
   }))
 }
diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts
index 811ec0d70c..1a57308d28 100644
--- a/web/app/components/workflow/types.ts
+++ b/web/app/components/workflow/types.ts
@@ -37,7 +37,12 @@ export enum ControlMode {
   Hand = 'hand',
 }

-export interface Branch {
+export enum ErrorHandleMode {
+  Terminated = 'terminated',
+  ContinueOnError = 'continue-on-error',
+  RemoveAbnormalOutput = 'remove-abnormal-output',
+}
+export type Branch = {
   id: string
   name: string
 }
diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts
index 91656e3bbc..aaf333f4d7 100644
--- a/web/app/components/workflow/utils.ts
+++ b/web/app/components/workflow/utils.ts
@@ -19,7 +19,7 @@ import type {
   ToolWithProvider,
   ValueSelector,
 } from './types'
-import { BlockEnum } from './types'
+import { BlockEnum, ErrorHandleMode } from './types'
 import {
   CUSTOM_NODE,
   ITERATION_CHILDREN_Z_INDEX,
@@ -267,8 +267,13 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
     })
   }

-    if (node.data.type === BlockEnum.Iteration)
-      node.data._children = iterationNodeMap[node.id] || []
+    if (node.data.type === BlockEnum.Iteration) {
+      const iterationNodeData = node.data as IterationNodeType
+      iterationNodeData._children = iterationNodeMap[node.id] || []
+      iterationNodeData.is_parallel = iterationNodeData.is_parallel || false
+      iterationNodeData.parallel_nums = iterationNodeData.parallel_nums || 10
+      iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated
+    }

     return node
   })
refreshToken) {
       localStorage.setItem('console_token', consoleToken)
       localStorage.setItem('refresh_token', refreshToken)
-      getNewAccessToken()
       router.replace('/apps')
       return
     }
@@ -71,7 +68,7 @@
       setSystemFeatures(defaultSystemFeatures)
     }
     finally { setIsLoading(false) }
-  }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink, getNewAccessToken])
+  }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink])

   useEffect(() => {
     init()
   }, [init])
diff --git a/web/hooks/use-refresh-token.ts b/web/hooks/use-refresh-token.ts
deleted file mode 100644
index 53dc4faf00..0000000000
--- a/web/hooks/use-refresh-token.ts
+++ /dev/null
@@ -1,99 +0,0 @@
-'use client'
-import { useCallback, useEffect, useRef } from 'react'
-import { jwtDecode } from 'jwt-decode'
-import dayjs from 'dayjs'
-import utc from 'dayjs/plugin/utc'
-import { useRouter } from 'next/navigation'
-import type { CommonResponse } from '@/models/common'
-import { fetchNewToken } from '@/service/common'
-import { fetchWithRetry } from '@/utils'
-
-dayjs.extend(utc)
-
-const useRefreshToken = () => {
-  const router = useRouter()
-  const timer = useRef()
-  const advanceTime = useRef(5 * 60 * 1000)
-
-  const getExpireTime = useCallback((token: string) => {
-    if (!token)
-      return 0
-    const decoded = jwtDecode(token)
-    return (decoded.exp || 0) * 1000
-  }, [])
-
-  const getCurrentTimeStamp = useCallback(() => {
-    return dayjs.utc().valueOf()
-  }, [])
-
-  const handleError = useCallback(() => {
-    localStorage?.removeItem('is_refreshing')
-    localStorage?.removeItem('console_token')
-    localStorage?.removeItem('refresh_token')
-    router.replace('/signin')
-  }, [])
-
-  const getNewAccessToken = useCallback(async () => {
-    const currentAccessToken = localStorage?.getItem('console_token')
-    const currentRefreshToken = localStorage?.getItem('refresh_token')
-    if (!currentAccessToken || !currentRefreshToken) {
-      handleError()
-      return new Error('No access token or refresh token found')
-    }
-    if (localStorage?.getItem('is_refreshing') === '1') {
-      clearTimeout(timer.current)
-      timer.current = setTimeout(() => {
-        getNewAccessToken()
-      }, 1000)
-      return null
-    }
-    const currentTokenExpireTime = getExpireTime(currentAccessToken)
-    if (getCurrentTimeStamp() + advanceTime.current > currentTokenExpireTime) {
-      localStorage?.setItem('is_refreshing', '1')
-      const [e, res] = await fetchWithRetry(fetchNewToken({
-        body: { refresh_token: currentRefreshToken },
-      }) as Promise)
-      if (e) {
-        handleError()
-        return e
-      }
-      const { access_token, refresh_token } = res.data
-      localStorage?.setItem('is_refreshing', '0')
-      localStorage?.setItem('console_token', access_token)
-      localStorage?.setItem('refresh_token', refresh_token)
-      const newTokenExpireTime = getExpireTime(access_token)
-      clearTimeout(timer.current)
-      timer.current = setTimeout(() => {
-        getNewAccessToken()
-      }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
-    }
-    else {
-      const newTokenExpireTime = getExpireTime(currentAccessToken)
-      clearTimeout(timer.current)
-      timer.current = setTimeout(() => {
-        getNewAccessToken()
-      }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
-    }
-    return null
-  }, [getExpireTime, getCurrentTimeStamp, handleError])
-
-  const handleVisibilityChange = useCallback(() => {
-    if (document.visibilityState === 'visible')
-      getNewAccessToken()
-  }, [])
-
-  useEffect(() => {
-    window.addEventListener('visibilitychange', handleVisibilityChange)
-    return () => {
-      window.removeEventListener('visibilitychange', handleVisibilityChange)
-      clearTimeout(timer.current)
-      localStorage?.removeItem('is_refreshing')
-    }
-  }, [])
-
-  return {
-    getNewAccessToken,
-  }
-}
-
-export default useRefreshToken
diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts
index bde0250fcc..d05070c308 100644
--- a/web/i18n/de-DE/workflow.ts
+++ b/web/i18n/de-DE/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Iteration',
     iteration_other: '{{count}} Iterationen',
     currentIteration: 'Aktuelle Iteration',
+    ErrorMethod: {
+      operationTerminated: 'beendet',
+      removeAbnormalOutput: 'Anormale Ausgabe entfernen',
+      continueOnError: 'Bei Fehler fortfahren',
+    },
+    MaxParallelismTitle: 'Maximale Parallelität',
+    parallelMode: 'Paralleler Modus',
+    errorResponseMethod: 'Methode der Fehlerantwort',
+    error_one: '{{count}} Fehler',
+    error_other: '{{count}} Fehler',
+    MaxParallelismDesc: 'Die maximale Parallelität wird verwendet, um die Anzahl der Aufgaben zu steuern, die gleichzeitig in einer einzigen Iteration ausgeführt werden.',
+    parallelPanelDesc: 'Im parallelen Modus unterstützen Aufgaben in der Iteration die parallele Ausführung.',
+    parallelModeEnableDesc: 'Im parallelen Modus unterstützen Aufgaben innerhalb von Iterationen die parallele Ausführung. Sie können dies im Eigenschaftenbereich auf der rechten Seite konfigurieren.',
+    answerNodeWarningDesc: 'Warnung im parallelen Modus: Antwortknoten, Zuweisungen von Konversationsvariablen und persistente Lese-/Schreibvorgänge innerhalb von Iterationen können Ausnahmen verursachen.',
+    parallelModeEnableTitle: 'Paralleler Modus aktiviert',
+    parallelModeUpper: 'PARALLELER MODUS',
+    comma: ',',
   },
   note: {
     editor: {
diff --git a/web/i18n/en-US/app-debug.ts b/web/i18n/en-US/app-debug.ts
index b2144262f6..e17afc38bf 100644
--- a/web/i18n/en-US/app-debug.ts
+++ b/web/i18n/en-US/app-debug.ts
@@ -224,6 +224,8 @@ const translation = {
     description: 'The Code Generator uses configured models to generate high-quality code based on your instructions. Please provide clear and detailed instructions.',
     instruction: 'Instructions',
     instructionPlaceholder: 'Enter detailed description of the code you want to generate.',
+    noDataLine1: 'Describe your use case on the left,',
+    noDataLine2: 'the code preview will show here.',
     generate: 'Generate',
     generatedCodeTitle: 'Generated Code',
     loading: 'Generating code...',
diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts
index 8b5f96453c..7f237b1a49 100644
--- a/web/i18n/en-US/workflow.ts
+++ b/web/i18n/en-US/workflow.ts
@@ -556,6 +556,23 @@ const translation = {
     iteration_one: '{{count}} Iteration',
     iteration_other: '{{count}} Iterations',
     currentIteration: 'Current Iteration',
+    comma: ', ',
+    error_one: '{{count}} Error',
+    error_other: '{{count}} Errors',
+    parallelMode: 'Parallel Mode',
+    parallelModeUpper: 'PARALLEL MODE',
+    parallelModeEnableTitle: 'Parallel Mode Enabled',
+    parallelModeEnableDesc: 'In parallel mode, tasks within iterations support parallel execution. You can configure this in the properties panel on the right.',
+    parallelPanelDesc: 'In parallel mode, tasks in the iteration support parallel execution.',
+    MaxParallelismTitle: 'Maximum parallelism',
+    MaxParallelismDesc: 'The maximum parallelism is used to control the number of tasks executed simultaneously in a single iteration.',
+    errorResponseMethod: 'Error response method',
+    ErrorMethod: {
+      operationTerminated: 'terminated',
+      continueOnError: 'continue-on-error',
+      removeAbnormalOutput: 'remove-abnormal-output',
+    },
+    answerNodeWarningDesc: 'Parallel mode warning: Answer nodes, conversation variable assignments, and persistent read/write operations within iterations may cause exceptions.',
   },
   note: {
     addNote: 'Add Note',
diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts
index 59a330e7f4..6c9af49c4d 100644
--- a/web/i18n/es-ES/workflow.ts
+++ b/web/i18n/es-ES/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Iteración',
     iteration_other: '{{count}} Iteraciones',
     currentIteration: 'Iteración actual',
+    ErrorMethod: {
+      operationTerminated: 'Terminado',
+      continueOnError: 'Continuar en el error',
+      removeAbnormalOutput: 'eliminar-salida-anormal',
+    },
+    comma: ',',
+    errorResponseMethod: 'Método de respuesta a errores',
+    error_one: '{{count}} Error',
+    parallelPanelDesc: 'En el modo paralelo, las tareas de la iteración admiten la ejecución en paralelo.',
+    MaxParallelismTitle: 'Máximo paralelismo',
+    error_other: '{{count}} Errores',
+    parallelMode: 'Modo paralelo',
+    parallelModeEnableDesc: 'En el modo paralelo, las tareas dentro de las iteraciones admiten la ejecución en paralelo. Puede configurar esto en el panel de propiedades a la derecha.',
+    parallelModeUpper: 'MODO PARALELO',
+    MaxParallelismDesc: 'El paralelismo máximo se utiliza para controlar el número de tareas ejecutadas simultáneamente en una sola iteración.',
+    answerNodeWarningDesc: 'Advertencia de modo paralelo: Los nodos de respuesta, las asignaciones de variables de conversación y las operaciones de lectura/escritura persistentes dentro de las iteraciones pueden provocar excepciones.',
+    parallelModeEnableTitle: 'Modo paralelo habilitado',
   },
   note: {
     addNote: 'Agregar nota',
diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts
index b1f9384159..4b00390663 100644
--- a/web/i18n/fa-IR/workflow.ts
+++ b/web/i18n/fa-IR/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} تکرار',
     iteration_other: '{{count}} تکرارها',
     currentIteration: 'تکرار فعلی',
+    ErrorMethod: {
+      continueOnError: 'ادامه در خطا',
+      operationTerminated: 'فسخ',
+      removeAbnormalOutput: 'حذف خروجی غیرطبیعی',
+    },
+    error_one: '{{count}} خطا',
+    error_other: '{{count}} خطاها',
+    parallelMode: 'حالت موازی',
+    errorResponseMethod: 'روش پاسخ به خطا',
+    parallelModeEnableTitle: 'حالت موازی فعال است',
+    parallelModeUpper: 'حالت موازی',
+    comma: ',',
+    parallelModeEnableDesc: 'در حالت موازی، وظایف درون تکرارها از اجرای موازی پشتیبانی می کنند. می توانید این را در پانل ویژگی ها در سمت راست پیکربندی کنید.',
+    MaxParallelismTitle: 'حداکثر موازی سازی',
+    parallelPanelDesc: 'در حالت موازی، وظایف در تکرار از اجرای موازی پشتیبانی می کنند.',
+    MaxParallelismDesc: 'حداکثر موازی سازی برای کنترل تعداد وظایف اجرا شده به طور همزمان در یک تکرار واحد استفاده می شود.',
+    answerNodeWarningDesc: 'هشدار حالت موازی: گره های پاسخ، تکالیف متغیر مکالمه و عملیات خواندن/نوشتن مداوم در تکرارها ممکن است باعث استثنائات شود.',
   },
   note: {
     addNote: 'افزودن یادداشت',
diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts
index e56932455f..e736e2cb07 100644
--- a/web/i18n/fr-FR/workflow.ts
+++ b/web/i18n/fr-FR/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Itération',
     iteration_other: '{{count}} Itérations',
     currentIteration: 'Itération actuelle',
+    ErrorMethod: {
+      operationTerminated: 'Terminé',
+      removeAbnormalOutput: 'remove-abnormal-output',
+      continueOnError: 'continuer sur l’erreur',
+    },
+    comma: ',',
+    error_one: '{{count}} Erreur',
+    error_other: '{{count}} Erreurs',
+    parallelModeEnableDesc: 'En mode parallèle, les tâches au sein des itérations prennent en charge l’exécution parallèle. Vous pouvez le configurer dans le panneau des propriétés à droite.',
+    parallelModeUpper: 'MODE PARALLÈLE',
+    parallelPanelDesc: 'En mode parallèle, les tâches de l’itération prennent en charge l’exécution parallèle.',
+    MaxParallelismDesc: 'Le parallélisme maximal est utilisé pour contrôler le nombre de tâches exécutées simultanément en une seule itération.',
+    errorResponseMethod: 'Méthode de réponse aux erreurs',
+    MaxParallelismTitle: 'Parallélisme maximal',
+    answerNodeWarningDesc: 'Avertissement en mode parallèle : les nœuds de réponse, les affectations de variables de conversation et les opérations de lecture/écriture persistantes au sein des itérations peuvent provoquer des exceptions.',
+    parallelModeEnableTitle: 'Mode parallèle activé',
+    parallelMode: 'Mode parallèle',
   },
   note: {
     addNote: 'Ajouter note',
diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts
index 1473f78ccd..4112643488 100644
--- a/web/i18n/hi-IN/workflow.ts
+++ b/web/i18n/hi-IN/workflow.ts
@@ -577,6 +577,23 @@ const translation = {
     iteration_one: '{{count}} इटरेशन',
     iteration_other: '{{count}} इटरेशन्स',
     currentIteration: 'वर्तमान इटरेशन',
+    ErrorMethod: {
+      operationTerminated: 'समाप्त',
+      continueOnError: 'जारी रखें-पर-त्रुटि',
+      removeAbnormalOutput: 'निकालें-असामान्य-आउटपुट',
+    },
+    comma: ',',
+    error_other: '{{count}} त्रुटियाँ',
+    error_one: '{{count}} त्रुटि',
+    parallelMode: 'समानांतर मोड',
+    parallelModeUpper: 'समानांतर मोड',
+    errorResponseMethod: 'त्रुटि प्रतिक्रिया विधि',
+    MaxParallelismTitle: 'अधिकतम समांतरता',
+    parallelModeEnableTitle: 'समानांतर मोड सक्षम किया गया',
+    parallelModeEnableDesc: 'समानांतर मोड में, पुनरावृत्तियों के भीतर कार्य समानांतर निष्पादन का समर्थन करते हैं। आप इसे दाईं ओर गुण पैनल में कॉन्फ़िगर कर सकते हैं।',
+    parallelPanelDesc: 'समानांतर मोड में, पुनरावृत्ति में कार्य समानांतर निष्पादन का समर्थन करते हैं।',
+    MaxParallelismDesc: 'अधिकतम समांतरता का उपयोग एकल पुनरावृत्ति में एक साथ निष्पादित कार्यों की संख्या को नियंत्रित करने के लिए किया जाता है।',
+    answerNodeWarningDesc: 'समानांतर मोड चेतावनी: उत्तर नोड्स, वार्तालाप चर असाइनमेंट, और पुनरावृत्तियों के भीतर लगातार पढ़ने/लिखने की कार्रवाई अपवाद पैदा कर सकती है।',
   },
   note: {
     addNote: 'नोट जोड़ें',
diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts
index 19fa7bfbb5..756fb665af 100644
--- a/web/i18n/it-IT/workflow.ts
+++ b/web/i18n/it-IT/workflow.ts
@@ -584,6 +584,23 @@ const translation = {
     iteration_one: '{{count}} Iterazione',
     iteration_other: '{{count}} Iterazioni',
     currentIteration: 'Iterazione Corrente',
+    ErrorMethod: {
+      operationTerminated: 'Terminato',
+      continueOnError: 'continua sull\'errore',
+      removeAbnormalOutput: 'rimuovi-output-anomalo',
+    },
+    error_one: '{{count}} Errore',
+    parallelMode: 'Modalità parallela',
+    MaxParallelismTitle: 'Parallelismo massimo',
+    error_other: '{{count}} Errori',
+    parallelModeEnableDesc: 'In modalità parallela, le attività all\'interno delle iterazioni supportano l\'esecuzione parallela. È possibile configurare questa opzione nel pannello delle proprietà a destra.',
+    MaxParallelismDesc: 'Il parallelismo massimo viene utilizzato per controllare il numero di attività eseguite contemporaneamente in una singola iterazione.',
+    errorResponseMethod: 'Metodo di risposta all\'errore',
+    parallelModeEnableTitle: 'Modalità parallela abilitata',
+    parallelModeUpper: 'MODALITÀ PARALLELA',
+    comma: ',',
+    parallelPanelDesc: 'In modalità parallela, le attività nell\'iterazione supportano l\'esecuzione parallela.',
+    answerNodeWarningDesc: 'Avviso in modalità parallela: i nodi di risposta, le assegnazioni di variabili di conversazione e le operazioni di lettura/scrittura persistenti all\'interno delle iterazioni possono causare eccezioni.',
   },
   note: {
     addNote: 'Aggiungi Nota',
diff --git a/web/i18n/ja-JP/app-debug.ts b/web/i18n/ja-JP/app-debug.ts
index 620d9b2f55..05e81a2ae2 100644
--- a/web/i18n/ja-JP/app-debug.ts
+++ b/web/i18n/ja-JP/app-debug.ts
@@ -224,6 +224,8 @@ const translation = {
     description: 'コードジェネレーターは、設定されたモデルを使用して指示に基づいて高品質なコードを生成します。明確で詳細な指示を提供してください。',
     instruction: '指示',
     instructionPlaceholder: '生成したいコードの詳細な説明を入力してください。',
+    noDataLine1: '左側に使用例を記入してください、',
+    noDataLine2: 'コードのプレビューがこちらに表示されます。',
     generate: '生成',
     generatedCodeTitle: '生成されたコード',
     loading: 'コードを生成中...',
diff --git a/web/i18n/ja-JP/app.ts b/web/i18n/ja-JP/app.ts
index 76c7d1c4f4..48a35c61af 100644
--- a/web/i18n/ja-JP/app.ts
+++ b/web/i18n/ja-JP/app.ts
@@ -39,10 +39,10 @@ const translation = {
     workflowWarning: '現在ベータ版です',
     chatbotType: 'チャットボットのオーケストレーション方法',
     basic: '基本',
-    basicTip: '初心者向け。後で Chatflow に切り替えることができます',
+    basicTip: '初心者向け。後で「チャットフロー」に切り替えることができます',
     basicFor: '初心者向け',
     basicDescription: '基本オーケストレートは、組み込みのプロンプトを変更する機能がなく、簡単な設定を使用してチャットボット アプリをオーケストレートします。初心者向けです。',
-    advanced: 'Chatflow',
+    advanced: 'チャットフロー',
     advancedFor: '上級ユーザー向け',
     advancedDescription: 'ワークフロー オーケストレートは、ワークフロー形式でチャットボットをオーケストレートし、組み込みのプロンプトを編集する機能を含む高度なカスタマイズを提供します。経験豊富なユーザー向けです。',
     captionName: 'アプリのアイコンと名前',
diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts
index b6c7786081..a82ba71e48 100644
--- a/web/i18n/ja-JP/workflow.ts
+++ b/web/i18n/ja-JP/workflow.ts
@@ -558,6 +558,23 @@ const translation = {
     iteration_one: '{{count}} イテレーション',
     iteration_other: '{{count}} イテレーション',
     currentIteration: '現在のイテレーション',
+    ErrorMethod: {
+      operationTerminated: '終了',
+      continueOnError: 'エラー時に続行',
+      removeAbnormalOutput: '異常出力の削除',
+    },
+    comma: ',',
+    error_other: '{{count}}エラー',
+    error_one: '{{count}}エラー',
+    parallelModeUpper: 'パラレルモード',
+    parallelMode: 'パラレルモード',
+    MaxParallelismTitle: '最大並列処理',
+    errorResponseMethod: 'エラー応答方式',
+    parallelPanelDesc: '並列モードでは、イテレーションのタスクは並列実行をサポートします。',
+    parallelModeEnableDesc: '並列モードでは、イテレーション内のタスクは並列実行をサポートします。これは、右側のプロパティパネルで構成できます。',
+    parallelModeEnableTitle: 'パラレルモード有効',
+    MaxParallelismDesc: '最大並列処理は、1 回の反復で同時に実行されるタスクの数を制御するために使用されます。',
+    answerNodeWarningDesc: '並列モードの警告: 応答ノード、会話変数の割り当て、およびイテレーション内の永続的な読み取り/書き込み操作により、例外が発生する可能性があります。',
   },
   note: {
     addNote: 'コメントを追加',
diff --git a/web/i18n/ko-KR/common.ts b/web/i18n/ko-KR/common.ts
index d2035e7c71..43e7402bd4 100644
--- a/web/i18n/ko-KR/common.ts
+++ b/web/i18n/ko-KR/common.ts
@@ -169,7 +169,7 @@ const translation = {
     deleteConfirmTip: '확인하려면 등록된 이메일에서 다음 내용을 로 보내주세요 ',
     myAccount: '내 계정',
     studio: '디파이 스튜디오',
-    account: '계좌',
+    account: '계정',
   },
   members: {
     team: '팀',
diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts
index b62aff2068..589831401c 100644
--- a/web/i18n/ko-KR/workflow.ts
+++ b/web/i18n/ko-KR/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} 반복',
     iteration_other: '{{count}} 반복',
     currentIteration: '현재 반복',
+    ErrorMethod: {
+      operationTerminated: '종료',
+      continueOnError: '오류 발생 시 계속',
+      removeAbnormalOutput: '비정상 출력 제거',
+    },
+    comma: ',',
+    error_one: '{{count}} 오류',
+    parallelMode: '병렬 모드',
+    errorResponseMethod: '오류 응답 방법',
+    parallelModeUpper: '병렬 모드',
+    MaxParallelismTitle: '최대 병렬 처리',
+    error_other: '{{count}} 오류',
+    parallelModeEnableTitle: '병렬 모드 활성화됨',
+    parallelPanelDesc: '병렬 모드에서 반복의 작업은 병렬 실행을 지원합니다.',
+    parallelModeEnableDesc: '병렬 모드에서는 반복 내의 작업이 병렬 실행을 지원합니다. 오른쪽의 속성 패널에서 이를 구성할 수 있습니다.',
+    MaxParallelismDesc: '최대 병렬 처리는 단일 반복에서 동시에 실행되는 작업 수를 제어하는 데 사용됩니다.',
+    answerNodeWarningDesc: '병렬 모드 경고: 응답 노드, 대화 변수 할당 및 반복 내의 지속적인 읽기/쓰기 작업으로 인해 예외가 발생할 수 있습니다.',
   },
   note: {
     editor: {
diff --git a/web/i18n/pl-PL/app-debug.ts b/web/i18n/pl-PL/app-debug.ts
index 7cf6c77cb4..cf7232e563 100644
--- a/web/i18n/pl-PL/app-debug.ts
+++ b/web/i18n/pl-PL/app-debug.ts
@@ -355,7 +355,7 @@ const translation = {
   openingStatement: {
     title: 'Wstęp do rozmowy',
     add: 'Dodaj',
-    writeOpner: 'Napisz wstęp',
+    writeOpener: 'Napisz wstęp',
     placeholder: 'Tutaj napisz swoją wiadomość wprowadzającą, możesz użyć zmiennych, spróbuj wpisać {{variable}}.',
     openingQuestion: 'Pytania otwierające',
diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts
index aace1b2642..f118f7945c 100644
--- a/web/i18n/pl-PL/workflow.ts
+++ b/web/i18n/pl-PL/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Iteracja',
     iteration_other: '{{count}} Iteracje',
     currentIteration: 'Bieżąca iteracja',
+    ErrorMethod: {
+      continueOnError: 'kontynuacja w przypadku błędu',
+      operationTerminated: 'Zakończone',
+      removeAbnormalOutput: 'usuń-nieprawidłowe-wyjście',
+    },
+    comma: ',',
+    parallelModeUpper: 'TRYB RÓWNOLEGŁY',
+    parallelModeEnableTitle: 'Włączony tryb równoległy',
+    MaxParallelismTitle: 'Maksymalna równoległość',
+    error_one: '{{count}} Błąd',
+    error_other: '{{count}} Błędy',
+    parallelPanelDesc: 'W trybie równoległym zadania w iteracji obsługują wykonywanie równoległe.',
+    parallelMode: 'Tryb równoległy',
+    MaxParallelismDesc: 'Maksymalna równoległość służy do kontrolowania liczby zadań wykonywanych jednocześnie w jednej iteracji.',
+    parallelModeEnableDesc: 'W trybie równoległym zadania w iteracjach obsługują wykonywanie równoległe. Możesz to skonfigurować w panelu właściwości po prawej stronie.',
+    answerNodeWarningDesc: 'Ostrzeżenie w trybie równoległym: węzły odpowiedzi, przypisania zmiennych konwersacji i trwałe operacje odczytu/zapisu w iteracjach mogą powodować wyjątki.',
+    errorResponseMethod: 'Metoda odpowiedzi na błąd',
   },
   note: {
     editor: {
diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts
index f0f2fec0e2..44afda5cd4 100644
--- a/web/i18n/pt-BR/workflow.ts
+++ b/web/i18n/pt-BR/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Iteração',
     iteration_other: '{{count}} Iterações',
     currentIteration: 'Iteração atual',
+    ErrorMethod: {
+      continueOnError: 'continuar em erro',
+      removeAbnormalOutput: 'remover saída anormal',
+      operationTerminated: 'Terminada',
+    },
+    MaxParallelismTitle: 'Paralelismo máximo',
+    parallelModeEnableTitle: 'Modo paralelo ativado',
+    errorResponseMethod: 'Método de resposta de erro',
+    error_other: '{{count}} Erros',
+    parallelMode: 'Modo paralelo',
+    parallelModeUpper: 'MODO PARALELO',
+    error_one: '{{count}} Erro',
+    parallelModeEnableDesc: 'No modo paralelo, as tarefas dentro das iterações dão suporte à execução paralela. Você pode configurar isso no painel de propriedades à direita.',
+    comma: ',',
+    MaxParallelismDesc: 'O paralelismo máximo é usado para controlar o número de tarefas executadas simultaneamente em uma única iteração.',
+    answerNodeWarningDesc: 'Aviso de modo paralelo: nós de resposta, atribuições de variáveis de conversação e operações persistentes de leitura/gravação em iterações podem causar exceções.',
+    parallelPanelDesc: 'No modo paralelo, as tarefas na iteração dão suporte à execução paralela.',
   },
   note: {
     editor: {
diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts
index ab0100d347..d8cd84f730 100644
--- a/web/i18n/ro-RO/workflow.ts
+++ b/web/i18n/ro-RO/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Iterație',
     iteration_other: '{{count}} Iterații',
     currentIteration: 'Iterație curentă',
+    ErrorMethod: {
+      operationTerminated: 'Încheiată',
+      continueOnError: 'continuare-la-eroare',
+      removeAbnormalOutput: 'elimină-ieșire-anormală',
+    },
+    parallelModeEnableTitle: 'Modul paralel activat',
+    errorResponseMethod: 'Metoda de răspuns la eroare',
+    comma: ',',
+    parallelModeEnableDesc: 'În modul paralel, sarcinile din iterații acceptă execuția paralelă. Puteți configura acest lucru în panoul de proprietăți din dreapta.',
+    parallelModeUpper: 'MOD PARALEL',
+    MaxParallelismTitle: 'Paralelism maxim',
+    parallelMode: 'Mod paralel',
+    error_other: '{{count}} Erori',
+    error_one: '{{count}} Eroare',
+    parallelPanelDesc: 'În modul paralel, activitățile din iterație acceptă execuția paralelă.',
+    MaxParallelismDesc: 'Paralelismul maxim este utilizat pentru a controla numărul de sarcini executate simultan într-o singură iterație.',
+    answerNodeWarningDesc: 'Avertisment modul paralel: Nodurile de răspuns, atribuirea variabilelor de conversație și operațiunile persistente de citire/scriere în iterații pot cauza excepții.',
   },
   note: {
     editor: {
diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts
index 27735fbb7d..c822f8c3e5 100644
--- a/web/i18n/ru-RU/workflow.ts
+++ b/web/i18n/ru-RU/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Итерация',
     iteration_other: '{{count}} Итераций',
     currentIteration: 'Текущая итерация',
+    ErrorMethod: {
+      operationTerminated: 'Прекращено',
+      continueOnError: 'продолжить по ошибке',
+      removeAbnormalOutput: 'удалить аномальный вывод',
+    },
+    comma: ',',
+    error_other: '{{count}} Ошибки',
+    errorResponseMethod: 'Метод реагирования на ошибку',
+    MaxParallelismTitle: 'Максимальный параллелизм',
+    parallelModeUpper: 'ПАРАЛЛЕЛЬНЫЙ РЕЖИМ',
+    error_one: '{{count}} Ошибка',
+    parallelModeEnableTitle: 'Параллельный режим включен',
+    parallelMode: 'Параллельный режим',
+    parallelPanelDesc: 'В параллельном режиме задачи в итерации поддерживают параллельное выполнение.',
+    parallelModeEnableDesc: 'В параллельном режиме задачи в итерациях поддерживают параллельное выполнение. Вы можете настроить это на панели свойств справа.',
+    MaxParallelismDesc: 'Максимальный параллелизм используется для управления количеством задач, выполняемых одновременно в одной итерации.',
+    answerNodeWarningDesc: 'Предупреждение о параллельном режиме: узлы ответов, присвоение переменных диалога и постоянные операции чтения и записи в итерациях могут вызывать исключения.',
   },
   note: {
     addNote: 'Добавить заметку',
diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts
index 82718ebc03..e6e25f6d0e 100644
--- a/web/i18n/tr-TR/workflow.ts
+++ b/web/i18n/tr-TR/workflow.ts
@@ -558,6 +558,23 @@ const translation = {
     iteration_one: '{{count}} Yineleme',
     iteration_other: '{{count}} Yineleme',
     currentIteration: 'Mevcut Yineleme',
+    ErrorMethod: {
+      operationTerminated: 'Sonlandırıldı',
+      continueOnError: 'Hata Üzerine Devam Et',
+      removeAbnormalOutput: 'anormal çıktıyı kaldır',
+    },
+    parallelModeUpper: 'PARALEL MOD',
+    parallelMode: 'Paralel Mod',
+    MaxParallelismTitle: 'Maksimum paralellik',
+    error_one: '{{count}} Hata',
+    errorResponseMethod: 'Hata yanıtı yöntemi',
+    comma: ',',
+    parallelModeEnableTitle: 'Paralel Mod Etkin',
+    error_other: '{{count}} Hata',
+    parallelPanelDesc: 'Paralel modda, yinelemedeki görevler paralel yürütmeyi destekler.',
+    answerNodeWarningDesc: 'Paralel mod uyarısı: Yinelemeler içindeki yanıt düğümleri, konuşma değişkeni atamaları ve kalıcı okuma/yazma işlemleri özel durumlara neden olabilir.',
+    parallelModeEnableDesc: 'Paralel modda, yinelemeler içindeki görevler paralel yürütmeyi destekler. Bunu sağdaki özellikler panelinde yapılandırabilirsiniz.',
+    MaxParallelismDesc: 'Maksimum paralellik, tek bir yinelemede aynı anda yürütülen görevlerin sayısını kontrol etmek için kullanılır.',
   },
   note: {
     addNote: 'Not Ekle',
diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts
index 1828b6499f..663b5e4c13 100644
--- a/web/i18n/uk-UA/workflow.ts
+++ b/web/i18n/uk-UA/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Ітерація',
     iteration_other: '{{count}} Ітерацій',
     currentIteration: 'Поточна ітерація',
+    ErrorMethod: {
+      operationTerminated: 'Припинено',
+      continueOnError: 'Продовжити після помилки',
+      removeAbnormalOutput: 'видалити-ненормальний-вивід',
+    },
+    error_one: '{{count}} Помилка',
+    comma: ',',
+    MaxParallelismTitle: 'Максимальна паралельність',
+    parallelModeUpper: 'ПАРАЛЕЛЬНИЙ РЕЖИМ',
+    error_other: '{{count}} Помилки',
+    parallelMode: 'Паралельний режим',
+    parallelModeEnableTitle: 'Увімкнено паралельний режим',
+    errorResponseMethod: 'Метод реагування на помилку',
+    parallelPanelDesc: 'У паралельному режимі завдання в ітерації підтримують паралельне виконання.',
+    parallelModeEnableDesc: 'У паралельному режимі завдання всередині ітерацій підтримують паралельне виконання. Ви можете налаштувати це на панелі властивостей праворуч.',
+    MaxParallelismDesc: 'Максимальний паралелізм використовується для контролю числа завдань, що виконуються одночасно за одну ітерацію.',
+    answerNodeWarningDesc: 'Попередження в паралельному режимі: вузли відповідей, призначення змінних розмови та постійні операції читання/запису в межах ітерацій можуть спричинити винятки.',
   },
   note: {
     editor: {
diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts
index 2866af8a2a..1176fdd2b5 100644
--- a/web/i18n/vi-VN/workflow.ts
+++ b/web/i18n/vi-VN/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}} Lặp',
     iteration_other: '{{count}} Lặp',
     currentIteration: 'Lặp hiện tại',
+    ErrorMethod: {
+      operationTerminated: 'Chấm dứt',
+      removeAbnormalOutput: 'loại bỏ-bất thường-đầu ra',
+      continueOnError: 'Tiếp tục lỗi',
+    },
+    comma: ',',
+    error_other: '{{count}} Lỗi',
+    error_one: '{{count}} Lỗi',
+    MaxParallelismTitle: 'Song song tối đa',
+    parallelPanelDesc: 'Ở chế độ song song, các tác vụ trong quá trình lặp hỗ trợ thực thi song song.',
+    parallelMode: 'Chế độ song song',
+    parallelModeEnableTitle: 'Đã bật Chế độ song song',
+    errorResponseMethod: 'Phương pháp phản hồi lỗi',
+    MaxParallelismDesc: 'Tính song song tối đa được sử dụng để kiểm soát số lượng tác vụ được thực hiện đồng thời trong một lần lặp.',
+    answerNodeWarningDesc: 'Cảnh báo chế độ song song: Các nút trả lời, bài tập biến hội thoại và các thao tác đọc/ghi liên tục trong các lần lặp có thể gây ra ngoại lệ.',
+    parallelModeEnableDesc: 'Trong chế độ song song, các tác vụ trong các lần lặp hỗ trợ thực thi song song. Bạn có thể định cấu hình điều này trong bảng thuộc tính ở bên phải.',
+    parallelModeUpper: 'CHẾ ĐỘ SONG SONG',
   },
   note: {
     editor: {
diff --git a/web/i18n/zh-Hans/app-debug.ts b/web/i18n/zh-Hans/app-debug.ts
index 3e801bcf62..9e21945755 100644
--- a/web/i18n/zh-Hans/app-debug.ts
+++ b/web/i18n/zh-Hans/app-debug.ts
@@ -224,6 +224,8 @@ const translation = {
     description: '代码生成器使用配置的模型根据您的指令生成高质量的代码。请提供清晰详细的说明。',
     instruction: '指令',
     instructionPlaceholder: '请输入您想要生成的代码的详细描述。',
+    noDataLine1: '在左侧描述您的用例,',
+    noDataLine2: '代码预览将在此处显示。',
     generate: '生成',
     generatedCodeTitle: '生成的代码',
     loading: '正在生成代码...',
diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts
index 519f25d34e..6d574bc6f5 100644
--- a/web/i18n/zh-Hans/workflow.ts
+++ b/web/i18n/zh-Hans/workflow.ts
@@ -556,6 +556,23 @@ const translation = {
     iteration_one: '{{count}}个迭代',
     iteration_other: '{{count}}个迭代',
     currentIteration: '当前迭代',
+    comma: ',',
+    error_one: '{{count}}个失败',
+    error_other: '{{count}}个失败',
+    parallelMode: '并行模式',
+    parallelModeUpper: '并行模式',
+    parallelModeEnableTitle: '并行模式启用',
+    parallelModeEnableDesc: '启用并行模式时迭代内的任务支持并行执行。你可以在右侧的属性面板中进行配置。',
+    parallelPanelDesc: '在并行模式下,迭代中的任务支持并行执行。',
+    MaxParallelismTitle: '最大并行度',
+    MaxParallelismDesc: '最大并行度用于控制单次迭代中同时执行的任务数量。',
+    errorResponseMethod: '错误响应方法',
+    ErrorMethod: {
+      operationTerminated: '错误时终止',
+      continueOnError: '忽略错误并继续',
+      removeAbnormalOutput: '移除错误输出',
+    },
+    answerNodeWarningDesc: '并行模式警告:在迭代中,回答节点、会话变量赋值和工具持久读/写操作可能会导致异常。',
   },
   note: {
     addNote: '添加注释',
diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts
index d65b3999d2..f3fbfdedc2 100644
--- a/web/i18n/zh-Hant/workflow.ts
+++ b/web/i18n/zh-Hant/workflow.ts
@@ -557,6 +557,23 @@ const translation = {
     iteration_one: '{{count}}個迭代',
     iteration_other: '{{count}}個迭代',
     currentIteration: '當前迭代',
+    ErrorMethod: {
+      operationTerminated: '終止',
+      removeAbnormalOutput: '移除異常輸出',
+      continueOnError: '出錯時繼續',
+    },
+    comma: ',',
+    parallelMode: '並行模式',
+    parallelModeEnableTitle: '已啟用並行模式',
+    MaxParallelismTitle: '最大並行度',
+    parallelModeUpper: '並行模式',
+    parallelPanelDesc: '在並行模式下,反覆運算中的任務支援並行執行。',
+    error_one: '{{count}}錯誤',
+    errorResponseMethod: '錯誤回應方法',
+    parallelModeEnableDesc: '在並行模式下,反覆運算中的任務支援並行執行。您可以在右側的屬性面板中進行配置。',
+    answerNodeWarningDesc: '並行模式警告:反覆運算中的應答節點、對話變數賦值和持久讀/寫操作可能會導致異常。',
+    error_other: '{{count}}錯誤',
+    MaxParallelismDesc: '最大並行度用於控制在單個反覆運算中同時執行的任務數。',
   },
   note: {
     editor: {
diff --git a/web/models/common.ts b/web/models/common.ts
index bb694385ef..48bdc8ae44 100644
--- a/web/models/common.ts
+++ b/web/models/common.ts
@@ -216,7 +216,7 @@ export interface FileUploadConfigResponse {
   file_size_limit: number // default is 15MB
   audio_file_size_limit?: number // default is 50MB
   video_file_size_limit?: number // default is 100MB
-
+  workflow_file_upload_limit?: number // default is 10
 }

 export type InvitationResult = {
diff --git a/web/package.json b/web/package.json
index 471a720fba..8d69bbc209 100644
--- a/web/package.json
+++ b/web/package.json
@@ -1,6 +1,6 @@
 {
   "name": "dify-web",
-  "version": "0.10.2",
+  "version": "0.11.0",
   "private": true,
   "engines": {
     "node": ">=18.17.0"
diff --git a/web/service/base.ts b/web/service/base.ts
index 8efb97cff3..e1a04217c7 100644
--- a/web/service/base.ts
+++ b/web/service/base.ts
@@ -1,4 +1,5 @@
 import { API_PREFIX, IS_CE_EDITION, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config'
+import { refreshAccessTokenOrRelogin } from './refresh-token'
 import Toast from '@/app/components/base/toast'
 import type { AnnotationReply, MessageEnd, MessageReplace, ThoughtItem } from '@/app/components/base/chat/chat/type'
 import type { VisionFile } from '@/types/app'
@@ -368,42 +369,8 @@ const baseFetch = (
         if (!/^(2|3)\d{2}$/.test(String(res.status))) {
           const bodyJson = res.json()
           switch (res.status) {
-            case 401: {
-              if (isMarketplaceAPI)
-                return
-
-              if (isPublicAPI) {
-                return bodyJson.then((data: ResponseError) => {
-                  if (data.code === 'web_sso_auth_required')
-                    requiredWebSSOLogin()
-
-                  if (data.code === 'unauthorized') {
-                    removeAccessToken()
-                    globalThis.location.reload()
-                  }
-
-                  return Promise.reject(data)
-                })
-              }
-              const loginUrl = `${globalThis.location.origin}/signin`
-              bodyJson.then((data: ResponseError) => {
-                if (data.code === 'init_validate_failed' && IS_CE_EDITION && !silent)
-                  Toast.notify({ type: 'error', message: data.message, duration: 4000 })
-                else if (data.code === 'not_init_validated' && IS_CE_EDITION)
-                  globalThis.location.href = `${globalThis.location.origin}/init`
-                else if (data.code === 'not_setup' && IS_CE_EDITION)
-                  globalThis.location.href = `${globalThis.location.origin}/install`
-                else if (location.pathname !== '/signin' || !IS_CE_EDITION)
-                  globalThis.location.href = loginUrl
-                else if (!silent)
-                  Toast.notify({ type: 'error', message: data.message })
-              }).catch(() => {
-                // Handle any other errors
-                globalThis.location.href = loginUrl
-              })
-
-              break
-            }
+            case 401:
+              return Promise.reject(resClone)
             case 403:
               bodyJson.then((data: ResponseError) => {
                 if (!silent)
@@ -499,7 +466,9 @@ export const upload = (options: any, isPublicAPI?: boolean, url?: string, search
 export const ssePost = (
   url: string,
   fetchOptions: FetchOptionType,
-  {
+  otherOptions: IOtherOptions,
+) => {
+  const {
     isPublicAPI = false,
     onData,
     onCompleted,
@@ -522,8 +491,7 @@
     onTextReplace,
     onError,
     getAbortController,
-  }: IOtherOptions,
-) => {
+  } = otherOptions
   const abortController = new AbortController()

   const options = Object.assign({}, baseOptions, {
@@ -547,21 +515,29 @@
   globalThis.fetch(urlWithPrefix, options as RequestInit)
     .then((res) => {
       if (!/^(2|3)\d{2}$/.test(String(res.status))) {
-        res.json().then((data: any) => {
-          if (isPublicAPI) {
-            if (data.code === 'web_sso_auth_required')
-              requiredWebSSOLogin()
+        if (res.status === 401) {
+          refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
+            ssePost(url, fetchOptions, otherOptions)
+          }).catch(() => {
+            res.json().then((data: any) => {
+              if (isPublicAPI) {
+                if (data.code === 'web_sso_auth_required')
+                  requiredWebSSOLogin()

-            if (data.code === 'unauthorized') {
-              removeAccessToken()
-              globalThis.location.reload()
-            }
-            if (res.status === 401)
-              return
-          }
-          Toast.notify({ type: 'error', message: data.message || 'Server Error' })
-        })
-        onError?.('Server Error')
+                if (data.code === 'unauthorized') {
+                  removeAccessToken()
+                  globalThis.location.reload()
+                }
+              }
+            })
+          })
+        }
+        else {
+          res.json().then((data) => {
+            Toast.notify({ type: 'error', message: data.message || 'Server Error' })
+          })
+          onError?.('Server Error')
+        }
         return
       }
       return handleStream(res, (str: string, isFirstMessage: boolean, moreInfo: IOnDataMoreInfo) => {
@@ -584,7 +560,54 @@
 // base request
 export const request = (url: string, options = {}, otherOptions?: IOtherOptions) => {
-  return baseFetch(url, options, otherOptions || {})
+  return new Promise((resolve, reject) => {
+    const otherOptionsForBaseFetch = otherOptions || {}
+    baseFetch(url, options, otherOptionsForBaseFetch).then(resolve).catch((errResp) => {
+      if (errResp?.status === 401) {
+        return refreshAccessTokenOrRelogin(TIME_OUT).then(() => {
+          baseFetch(url, options, otherOptionsForBaseFetch).then(resolve).catch(reject)
+        }).catch(() => {
+          const {
+            isPublicAPI = false,
+            silent,
+          } = otherOptionsForBaseFetch
+          const bodyJson = errResp.json()
+          if (isPublicAPI) {
+            return bodyJson.then((data: ResponseError) => {
+              if (data.code === 'web_sso_auth_required')
+                requiredWebSSOLogin()
+
+              if (data.code === 'unauthorized') {
+                removeAccessToken()
+                globalThis.location.reload()
+              }
+
+              return Promise.reject(data)
+            })
+          }
+          const loginUrl = `${globalThis.location.origin}/signin`
+          bodyJson.then((data: ResponseError) => {
+            if (data.code === 'init_validate_failed' && IS_CE_EDITION && !silent)
+              Toast.notify({ type: 'error', message: data.message, duration: 4000 })
+            else if (data.code === 'not_init_validated' && IS_CE_EDITION)
+              globalThis.location.href = `${globalThis.location.origin}/init`
+            else if (data.code === 'not_setup' && IS_CE_EDITION)
+              globalThis.location.href = `${globalThis.location.origin}/install`
+            else if (location.pathname !== '/signin' || !IS_CE_EDITION)
+              globalThis.location.href = loginUrl
+            else if (!silent)
+              Toast.notify({ type: 'error', message: data.message })
+          }).catch(() => {
+            // Handle any other errors
+            globalThis.location.href = loginUrl
+          })
+        })
+      }
+      else {
+        reject(errResp)
+      }
+    })
+  })
 }

 // request methods
diff --git a/web/service/common.ts b/web/service/common.ts
index 70586b6ff6..01b3a60991 100644
--- a/web/service/common.ts
+++ b/web/service/common.ts
@@ -320,9 +320,10 @@
 export const verifyForgotPasswordToken: Fetcher = ({ url, body }) => post(url, { body })

-export const fetchRemoteFileInfo = (url: string) => {
-  return get<{ file_type: string; file_length: number }>(`/remote-files/${url}`)
+export const uploadRemoteFileInfo = (url: string, isPublic?: boolean) => {
+  return post<{ id: string; name: string; size: number; mime_type: string; url: string }>('/remote-files/upload', { body: { url } }, { isPublicAPI: isPublic })
 }
+
 export const sendEMailLoginCode = (email: string, language = 'en-US') =>
   post('/email-code-login', { body: { email, language } })
diff --git a/web/service/refresh-token.ts b/web/service/refresh-token.ts
new file mode 100644
index 0000000000..8bd2215041
--- /dev/null
+++ b/web/service/refresh-token.ts
@@ -0,0 +1,75 @@
+import { apiPrefix } from '@/config'
+import { fetchWithRetry } from '@/utils'
+
+let isRefreshing = false
+function waitUntilTokenRefreshed() {
+  return new Promise<void>((resolve, reject) => {
+    function _check() {
+      const isRefreshingSign = localStorage.getItem('is_refreshing')
+      if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) {
+        setTimeout(() => {
+          _check()
+        }, 1000)
+      }
+      else {
+        resolve()
+      }
+    }
+    _check()
+  })
+}
+
+// only one refresh request can be in flight at a time
+async function getNewAccessToken(): Promise<void> {
+  try {
+    const isRefreshingSign = localStorage.getItem('is_refreshing')
+    if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) {
+      await waitUntilTokenRefreshed()
+    }
+    else {
+      globalThis.localStorage.setItem('is_refreshing', '1')
+      isRefreshing = true
+      const refresh_token = globalThis.localStorage.getItem('refresh_token')
+
+      // Do not use baseFetch to refresh tokens.
+      // If a 401 response occurs and baseFetch itself attempts to refresh the token,
+      // it can lead to an infinite loop if the refresh attempt also returns 401.
+      // To avoid this, handle token refresh separately in a dedicated function
+      // that does not call baseFetch and uses a single retry mechanism.
+      const [error, ret] = await fetchWithRetry(globalThis.fetch(`${apiPrefix}/refresh-token`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json; charset=utf-8',
+        },
+        body: JSON.stringify({ refresh_token }),
+      }))
+      if (error) {
+        return Promise.reject(error)
+      }
+      else {
+        if (ret.status === 401)
+          return Promise.reject(ret)
+
+        const { data } = await ret.json()
+        globalThis.localStorage.setItem('console_token', data.access_token)
+        globalThis.localStorage.setItem('refresh_token', data.refresh_token)
+      }
+    }
+  }
+  catch (error) {
+    console.error(error)
+    return Promise.reject(error)
+  }
+  finally {
+    isRefreshing = false
+    globalThis.localStorage.removeItem('is_refreshing')
+  }
+}
+
+export async function refreshAccessTokenOrRelogin(timeout: number) {
+  return Promise.race([new Promise<void>((resolve, reject) => setTimeout(() => {
+    isRefreshing = false
+    globalThis.localStorage.removeItem('is_refreshing')
+    reject(new Error('request timeout'))
+  }, timeout)), getNewAccessToken()])
+}
diff --git a/web/themes/dark.css b/web/themes/dark.css
index 3440a1a7a8..fad5c02559 100644
--- a/web/themes/dark.css
+++ b/web/themes/dark.css
@@ -244,6 +244,8 @@ html[data-theme="dark"] {
   --color-components-Avatar-default-avatar-bg: #222225;

+  --color-components-label-gray: #C8CEDA24;
+
   --color-text-primary: #FBFBFC;
   --color-text-secondary: #D9D9DE;
   --color-text-tertiary: #C8CEDA99;
diff --git a/web/themes/light.css b/web/themes/light.css
index 717226e462..ecc1930360 100644
--- a/web/themes/light.css
+++ b/web/themes/light.css
@@ -244,6 +244,8 @@ html[data-theme="light"] {
   --color-components-Avatar-default-avatar-bg: #D0D5DC;

+  --color-components-label-gray: #F2F4F7;
+
   --color-text-primary: #101828;
   --color-text-secondary: #354052;
   --color-text-tertiary: #676F83;
diff --git a/web/themes/tailwind-theme-var-define.ts b/web/themes/tailwind-theme-var-define.ts
index 643c96d1a1..9d17c361f8 100644
--- a/web/themes/tailwind-theme-var-define.ts
+++ b/web/themes/tailwind-theme-var-define.ts
@@ -244,6 +244,8 @@ const vars = {
   'components-Avatar-default-avatar-bg': 'var(--color-components-Avatar-default-avatar-bg)',

+  'components-label-gray': 'var(--color-components-label-gray)',
+
   'text-primary': 'var(--color-text-primary)',
   'text-secondary': 'var(--color-text-secondary)',
   'text-tertiary': 'var(--color-text-tertiary)',
diff --git a/web/types/workflow.ts b/web/types/workflow.ts
index 28a8bd627e..8c0d81639d 100644
--- a/web/types/workflow.ts
+++ b/web/types/workflow.ts
@@ -19,6 +19,7 @@ export interface NodeTracing {
   process_data: any
   outputs?: any
   status: string
+  parallel_run_id?: string
   error?: string
   elapsed_time: number
   execution_metadata: {
@@ -31,6 +32,7 @@
     parallel_start_node_id?: string
     parent_parallel_id?: string
     parent_parallel_start_node_id?: string
+    parallel_mode_run_id?: string
   }
   metadata: {
     iterator_length: number
@@ -121,6 +123,7 @@ export interface NodeStartedResponse {
     id: string
     node_id: string
     iteration_id?: string
+    parallel_run_id?: string
     node_type: string
     index: number
     predecessor_node_id?: string
@@ -166,6 +169,7 @@ export interface NodeFinishedResponse {
     parallel_start_node_id?: string
     iteration_index?: number
     iteration_id?: string
+    parallel_mode_run_id: string
   }
   created_at: number
   files?: FileResponse[]
@@ -200,6 +204,7 @@ export interface IterationNextResponse {
   output: any
   extras?: any
   created_at: number
+  parallel_mode_run_id: string
   execution_metadata: {
     parallel_id?: string
   }