Merge branch 'main' into feat/knowledge-dark-mode
This commit is contained in: commit 49674507c6
2  .github/actions/setup-poetry/action.yml  vendored
@@ -8,7 +8,7 @@ inputs:
   poetry-version:
     description: Poetry version to set up
     required: true
-    default: '1.8.4'
+    default: '2.0.1'
   poetry-lockfile:
     description: Path to the Poetry lockfile to restore cache from
     required: true
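Note on the recurring flag change in the workflow hunks below: most CI commands move from `poetry run -C api …` to `poetry run -P api …`. Per the Poetry 2.0 release notes, `-C/--directory` now only changes the working directory while `-P/--project` is the flag that points Poetry at another project's `pyproject.toml`, so `-P` is what keeps these out-of-tree invocations working after the bump to Poetry 2.0.1 above.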
16  .github/workflows/api-tests.yml  vendored
@@ -42,25 +42,23 @@ jobs:
         run: poetry install -C api --with dev

       - name: Check dependencies in pyproject.toml
-        run: poetry run -C api bash dev/pytest/pytest_artifacts.sh
+        run: poetry run -P api bash dev/pytest/pytest_artifacts.sh

       - name: Run Unit tests
-        run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
+        run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh

       - name: Run ModelRuntime
-        run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
+        run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh

       - name: Run dify config tests
-        run: poetry run -C api python dev/pytest/pytest_config_tests.py
+        run: poetry run -P api python dev/pytest/pytest_config_tests.py

       - name: Run Tool
-        run: poetry run -C api bash dev/pytest/pytest_tools.sh
+        run: poetry run -P api bash dev/pytest/pytest_tools.sh

       - name: Run mypy
         run: |
-          pushd api
-          poetry run python -m mypy --install-types --non-interactive .
-          popd
+          poetry run -C api python -m mypy --install-types --non-interactive .

       - name: Set up dotenvs
         run: |
@@ -80,4 +78,4 @@ jobs:
           ssrf_proxy

       - name: Run Workflow
-        run: poetry run -C api bash dev/pytest/pytest_workflow.sh
+        run: poetry run -P api bash dev/pytest/pytest_workflow.sh
6  .github/workflows/style.yml  vendored
@@ -38,12 +38,12 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: |
           poetry run -C api ruff --version
-          poetry run -C api ruff check ./api
-          poetry run -C api ruff format --check ./api
+          poetry run -C api ruff check ./
+          poetry run -C api ruff format --check ./

       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
+        run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example

       - name: Lint hints
         if: failure()
2  .github/workflows/vdb-tests.yml  vendored
@@ -70,4 +70,4 @@ jobs:
           tidb

       - name: Test Vector Stores
-        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
+        run: poetry run -P api bash dev/pytest/pytest_vdb.sh
@@ -53,10 +53,12 @@ ignore = [
     "FURB152", # math-constant
     "UP007", # non-pep604-annotation
     "UP032", # f-string
+    "UP045", # non-pep604-annotation-optional
     "B005", # strip-with-multi-characters
     "B006", # mutable-argument-default
     "B007", # unused-loop-control-variable
     "B026", # star-arg-unpacking-after-keyword-arg
+    "B903", # class-as-data-structure
     "B904", # raise-without-from-inside-except
     "B905", # zip-without-explicit-strict
     "N806", # non-lowercase-variable-in-function
@@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base

 WORKDIR /app/api

 # Install Poetry
-ENV POETRY_VERSION=1.8.4
+ENV POETRY_VERSION=2.0.1

 # if you located in China, you can use aliyun mirror to speed up
 # RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
@@ -79,5 +79,5 @@
 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`

    ```bash
-   poetry run -C api bash dev/pytest/pytest_all_tests.sh
+   poetry run -P api bash dev/pytest/pytest_all_tests.sh
    ```
@@ -146,7 +146,7 @@ class EndpointConfig(BaseSettings):
     )

     CONSOLE_WEB_URL: str = Field(
-        description="Base URL for the console web interface," "used for frontend references and CORS configuration",
+        description="Base URL for the console web interface,used for frontend references and CORS configuration",
         default="",
     )

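The description fix above (and the many similar string fixes below) removes implicit string concatenation: Python joins adjacent string literals at compile time with no separator, which is why the merged literal reads "interface,used" with no space — the concatenated original produced exactly that. A minimal illustration, with hypothetical strings rather than the real config:

```python
# Adjacent string literals concatenate at compile time -- no separator is
# inserted, so a missing trailing space silently glues words together.
description = (
    "Base URL for the console web interface,"  # no trailing space here
    "used for frontend references"
)
assert description == "Base URL for the console web interface,used for frontend references"
```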
@@ -181,7 +181,7 @@ class HostedFetchAppTemplateConfig(BaseSettings):
     """

     HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
-        description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
+        description="Mode for fetching app templates: remote, db, or builtin default to remote,",
         default="remote",
     )

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.15.0",
+        default="0.15.1",
     )

     COMMIT_SHA: str = Field(
@@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):

         app = App.query.filter(App.id == args["app_id"]).first()
         if not app:
-            raise NotFound(f'App \'{args["app_id"]}\' is not found')
+            raise NotFound(f"App '{args['app_id']}' is not found")

         site = app.site
         if not site:
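The quote flip here and in the model-provider hunks further down follow one pattern: before Python 3.12, an f-string delimited with single quotes cannot reuse single quotes inside its replacement field, so subscripts like `args['app_id']` forced either escapes or double quotes. Rewriting the outer string with double quotes and the subscript with single quotes avoids the escapes entirely. A small sketch with a hypothetical dict:

```python
args = {"app_id": "1234"}

# Equivalent f-strings; the second is the style this merge standardizes on.
old_style = f'App \'{args["app_id"]}\' is not found'
new_style = f"App '{args['app_id']}' is not found"
assert old_style == new_style
```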
@@ -22,7 +22,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
-from models.model import AppMode
+from models import App, AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model
-    def post(self, app_model):
+    def post(self, app_model: App):
         from werkzeug.exceptions import InternalServerError

         try:
@@ -98,9 +98,13 @@ class ChatMessageTextApi(Resource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+            if text_to_speech is None:
+                raise ValueError("TTS is not enabled")
             voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
+                if app_model.app_model_config is None:
+                    raise ValueError("AppModelConfig not found")
                 voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
@@ -10,12 +10,7 @@ from controllers.console import api
 from controllers.console.apikey import api_key_fields, api_key_list
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
-from controllers.console.wraps import (
-    account_initialization_required,
-    cloud_edition_billing_rate_limit_check,
-    enterprise_license_required,
-    setup_required,
-)
+from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
@@ -98,7 +93,6 @@ class DatasetListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument(
@@ -213,7 +207,6 @@ class DatasetApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
@@ -317,7 +310,6 @@ class DatasetApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id):
         dataset_id_str = str(dataset_id)

@@ -465,7 +457,7 @@ class DatasetIndexingEstimateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -627,8 +619,7 @@ class DatasetRetrievalSettingApi(Resource):
         vector_type = dify_config.VECTOR_STORE
         match vector_type:
             case (
-                VectorType.MILVUS
-                | VectorType.RELYT
+                VectorType.RELYT
                 | VectorType.PGVECTOR
                 | VectorType.TIDB_VECTOR
                 | VectorType.CHROMA
@@ -653,6 +644,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.TIDB_ON_QDRANT
                 | VectorType.LINDORM
                 | VectorType.COUCHBASE
+                | VectorType.MILVUS
             ):
                 return {
                     "retrieval_method": [
@@ -27,7 +27,6 @@ from controllers.console.datasets.error import (
 )
 from controllers.console.wraps import (
     account_initialization_required,
-    cloud_edition_billing_rate_limit_check,
     cloud_edition_billing_resource_check,
     setup_required,
 )
@@ -231,7 +230,6 @@ class DatasetDocumentListApi(Resource):
     @account_initialization_required
     @marshal_with(documents_and_batch_fields)
     @cloud_edition_billing_resource_check("vector_space")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         dataset_id = str(dataset_id)

@@ -287,7 +285,6 @@ class DatasetDocumentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -311,7 +308,6 @@ class DatasetInitApi(Resource):
     @account_initialization_required
     @marshal_with(dataset_and_document_fields)
     @cloud_edition_billing_resource_check("vector_space")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
@@ -354,8 +350,7 @@ class DatasetInitApi(Resource):
             )
         except InvokeAuthorizationError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -530,8 +525,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             return response.model_dump(), 200
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -684,7 +678,6 @@ class DocumentProcessingApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
@@ -721,7 +714,6 @@ class DocumentDeleteApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
@@ -790,7 +782,6 @@ class DocumentStatusApi(DocumentResource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -886,7 +877,6 @@ class DocumentPauseApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id):
         """pause document."""
         dataset_id = str(dataset_id)
@@ -919,7 +909,6 @@ class DocumentRecoverApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id):
         """recover document."""
         dataset_id = str(dataset_id)
@@ -949,7 +938,6 @@ class DocumentRetryApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         """retry document."""

@@ -19,7 +19,6 @@ from controllers.console.datasets.error import (
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_knowledge_limit_check,
-    cloud_edition_billing_rate_limit_check,
     cloud_edition_billing_resource_check,
     setup_required,
 )
@@ -107,7 +106,6 @@ class DatasetDocumentSegmentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -139,7 +137,6 @@ class DatasetDocumentSegmentApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -171,8 +168,7 @@ class DatasetDocumentSegmentApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -195,7 +191,6 @@ class DatasetDocumentSegmentAddApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -221,8 +216,7 @@ class DatasetDocumentSegmentAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -246,7 +240,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -272,8 +265,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -307,7 +299,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -345,7 +336,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -412,7 +402,6 @@ class ChildChunkAddApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -445,8 +434,7 @@ class ChildChunkAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -511,7 +499,6 @@ class ChildChunkAddApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -555,7 +542,6 @@ class ChildChunkUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -600,7 +586,6 @@ class ChildChunkUpdateApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -2,11 +2,7 @@ from flask_restful import Resource  # type: ignore

 from controllers.console import api
 from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
-from controllers.console.wraps import (
-    account_initialization_required,
-    cloud_edition_billing_rate_limit_check,
-    setup_required,
-)
+from controllers.console.wraps import account_initialization_required, setup_required
 from libs.login import login_required


@@ -14,7 +10,6 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         dataset_id_str = str(dataset_id)

@@ -1,6 +1,5 @@
 import json
 import os
-import time
 from functools import wraps

 from flask import abort, request
@@ -8,7 +7,6 @@ from flask_login import current_user  # type: ignore

 from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
-from extensions.ext_redis import redis_client
 from models.model import DifySetup
 from services.feature_service import FeatureService, LicenseStatus
 from services.operation_service import OperationService
@@ -68,9 +66,7 @@ def cloud_edition_billing_resource_check(resource: str):
         elif resource == "apps" and 0 < apps.limit <= apps.size:
             abort(403, "The number of apps has reached the limit of your subscription.")
         elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
-            abort(
-                403, "The capacity of the knowledge storage space has reached the limit of your subscription."
-            )
+            abort(403, "The capacity of the vector space has reached the limit of your subscription.")
         elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
             # The api of file upload is used in the multiple places,
             # so we need to check the source of the request from datasets
@@ -115,33 +111,6 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
     return interceptor


-def cloud_edition_billing_rate_limit_check(resource: str):
-    def interceptor(view):
-        @wraps(view)
-        def decorated(*args, **kwargs):
-            if resource == "knowledge":
-                knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
-                if knowledge_rate_limit.enabled:
-                    current_time = int(time.time() * 1000)
-                    key = f"rate_limit_{current_user.current_tenant_id}"
-
-                    redis_client.zadd(key, {current_time: current_time})
-
-                    redis_client.zremrangebyscore(key, 0, current_time - 60000)
-
-                    request_count = redis_client.zcard(key)
-
-                    if request_count > knowledge_rate_limit.limit:
-                        abort(
-                            403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
-                        )
-            return view(*args, **kwargs)
-
-        return decorated
-
-    return interceptor
-
-
 def cloud_utm_record(view):
     @wraps(view)
     def decorated(*args, **kwargs):
@@ -53,8 +53,7 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -95,8 +94,7 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -175,8 +173,7 @@ class DatasetSegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -1,4 +1,3 @@
-import time
 from collections.abc import Callable
 from datetime import UTC, datetime, timedelta
 from enum import Enum
@@ -14,7 +13,6 @@ from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, Unauthorized

 from extensions.ext_database import db
-from extensions.ext_redis import redis_client
 from libs.login import _get_user
 from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
 from models.model import ApiToken, App, EndUser
@@ -141,35 +139,6 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
     return interceptor


-def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
-    def interceptor(view):
-        @wraps(view)
-        def decorated(*args, **kwargs):
-            api_token = validate_and_get_api_token(api_token_type)
-
-            if resource == "knowledge":
-                knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
-                if knowledge_rate_limit.enabled:
-                    current_time = int(time.time() * 1000)
-                    key = f"rate_limit_{api_token.tenant_id}"
-
-                    redis_client.zadd(key, {current_time: current_time})
-
-                    redis_client.zremrangebyscore(key, 0, current_time - 60000)
-
-                    request_count = redis_client.zcard(key)
-
-                    if request_count > knowledge_rate_limit.limit:
-                        raise Forbidden(
-                            "Sorry, you have reached the knowledge base request rate limit of your subscription."
-                        )
-            return view(*args, **kwargs)
-
-        return decorated
-
-    return interceptor
-
-
 def validate_dataset_token(view=None):
     def decorator(view):
         @wraps(view)
@@ -226,7 +195,11 @@ def validate_and_get_api_token(scope: str | None = None):
     with Session(db.engine, expire_on_commit=False) as session:
         update_stmt = (
             update(ApiToken)
-            .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
+            .where(
+                ApiToken.token == auth_token,
+                (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
+                ApiToken.type == scope,
+            )
             .values(last_used_at=current_time)
             .returning(ApiToken)
         )
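The `.where()` change above is the functional fix in this hunk: a freshly created token has a `last_used_at` of NULL, and `NULL < cutoff_time` is unknown in SQL, so the old predicate never matched and the timestamp was never initialized. Adding an IS NULL branch lets the first use write `last_used_at`. A reduced sketch of the same predicate, using a hypothetical stand-in table and SQLAlchemy 2.x:

```python
from datetime import datetime, timedelta

from sqlalchemy import DateTime, String, update
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ApiToken(Base):  # illustrative stand-in for the real model
    __tablename__ = "api_tokens"
    token: Mapped[str] = mapped_column(String, primary_key=True)
    last_used_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)


def touch_stmt(auth_token: str, now: datetime):
    cutoff = now - timedelta(minutes=1)
    return (
        update(ApiToken)
        # NULL never compares true with <, so match it explicitly:
        .where(
            ApiToken.token == auth_token,
            (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff)),
        )
        .values(last_used_at=now)
    )
```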
@@ -172,7 +172,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

         self.save_agent_thought(
             agent_thought=agent_thought,
-            tool_name=scratchpad.action.action_name if scratchpad.action else "",
+            tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
             tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
             tool_invoke_meta={},
             thought=scratchpad.thought or "",
@@ -167,8 +167,7 @@ class AppQueueManager:
         else:
             if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
                 raise TypeError(
-                    "Critical Error: Passing SQLAlchemy Model instances "
-                    "that cause thread safety issues is not allowed."
+                    "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
                 )

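Context for the guard itself (unchanged here, only the message is joined): SQLAlchemy attaches an InstanceState to every mapped object through the `_sa_instance_state` attribute, and that state is tied to a Session that is not thread-safe, which is why the queue refuses such payloads. A hedged illustration of the detection trick, with a hypothetical mapped class:

```python
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Message(Base):  # hypothetical mapped class, not Dify's
    __tablename__ = "messages"
    id: Mapped[int] = mapped_column(primary_key=True)


plain = {"id": 1}
mapped = Message(id=1)

assert not hasattr(plain, "_sa_instance_state")
assert hasattr(mapped, "_sa_instance_state")  # what the queue guard checks for
```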
@@ -89,6 +89,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             Conversation.id == conversation_id,
             Conversation.app_id == app_model.id,
             Conversation.status == "normal",
+            Conversation.is_deleted.is_(False),
         ]

         if isinstance(user, Account):
@@ -145,7 +145,7 @@ class MessageCycleManage:

         # get extension
         if "." in message_file.url:
-            extension = f'.{message_file.url.split(".")[-1]}'
+            extension = f".{message_file.url.split('.')[-1]}"
             if len(extension) > 10:
                 extension = ".bin"
         else:
@@ -62,8 +62,9 @@ class ApiExternalDataTool(ExternalDataTool):

         if not api_based_extension:
             raise ValueError(
-                "[External data tool] API query failed, variable: {}, "
-                "error: api_based_extension_id is invalid".format(self.variable)
+                "[External data tool] API query failed, variable: {}, error: api_based_extension_id is invalid".format(
+                    self.variable
+                )
             )

         # decrypt api_key
@@ -90,7 +90,7 @@ class File(BaseModel):
     def markdown(self) -> str:
         url = self.generate_url()
         if self.type == FileType.IMAGE:
-            text = f'![{self.filename or ""}]({url})'
+            text = f"![{self.filename or ''}]({url})"
         else:
             text = f"[{self.filename or url}]({url})"

@@ -131,7 +131,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
 SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     "Please help me predict the three most likely questions that human would ask, "
     "and keeping each question under 20 characters.\n"
-    "MAKE SURE your output is the SAME language as the Assistant's latest response"
+    "MAKE SURE your output is the SAME language as the Assistant's latest response. "
     "The output must be an array in JSON format following the specified schema:\n"
     '["question1","question2","question3"]\n'
 )
@@ -108,7 +108,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)

         if not ai_model_entity:
-            raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
+            raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")

         try:
             client = AzureOpenAI(**self._to_credential_kwargs(credentials))
@@ -130,7 +130,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
             raise CredentialsValidateFailedError("Base Model Name is required")

         if not self._get_ai_model_entity(credentials["base_model_name"], model):
-            raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
+            raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")

         try:
             credentials_kwargs = self._to_credential_kwargs(credentials)
@@ -70,7 +70,7 @@ class BedrockRerankModel(RerankModel):
         rerankingConfiguration = {
             "type": "BEDROCK_RERANKING_MODEL",
             "bedrockRerankingConfiguration": {
-                "numberOfResults": top_n,
+                "numberOfResults": min(top_n, len(text_sources)),
                 "modelConfiguration": {
                     "modelArn": model_package_arn,
                 },
@@ -1,2 +1,3 @@
 - deepseek-chat
 - deepseek-coder
+- deepseek-reasoner
@@ -10,7 +10,7 @@ features:
   - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 128000
+  context_size: 64000
 parameter_rules:
   - name: temperature
     use_template: temperature
@@ -10,7 +10,7 @@ features:
   - stream-tool-call
 model_properties:
   mode: chat
-  context_size: 128000
+  context_size: 64000
 parameter_rules:
   - name: temperature
     use_template: temperature
@@ -0,0 +1,21 @@
+model: deepseek-reasoner
+label:
+  zh_Hans: deepseek-reasoner
+  en_US: deepseek-reasoner
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 64000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    min: 1
+    max: 8192
+    default: 4096
+pricing:
+  input: "4"
+  output: "16"
+  unit: "0.000001"
+  currency: RMB
@@ -24,9 +24,6 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
-        # {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
-        if "response_format" in model_parameters:
-            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

     def validate_credentials(self, model: str, credentials: dict) -> None:
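The deleted block (also removed from the Moonshot and Siliconflow providers below) rewrote a flat `response_format` string into the nested object shape that OpenAI-compatible chat APIs expect; presumably the shared `OAIAPICompatLargeLanguageModel` base now handles this, though the diff alone does not show that. The transformation itself is just:

```python
model_parameters = {"response_format": "json_object"}

# Flat string -> nested object, the shape OpenAI-compatible APIs accept:
# {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
    model_parameters["response_format"] = {"type": model_parameters.get("response_format")}

assert model_parameters == {"response_format": {"type": "json_object"}}
```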
@@ -1,5 +1,6 @@
 - gemini-2.0-flash-exp
 - gemini-2.0-flash-thinking-exp-1219
+- gemini-2.0-flash-thinking-exp-01-21
 - gemini-1.5-pro
 - gemini-1.5-pro-latest
 - gemini-1.5-pro-001
@@ -0,0 +1,39 @@
+model: gemini-2.0-flash-thinking-exp-01-21
+label:
+  en_US: Gemini 2.0 Flash Thinking Exp 01-21
+model_type: llm
+features:
+  - agent-thought
+  - vision
+  - document
+  - video
+  - audio
+model_properties:
+  mode: chat
+  context_size: 32767
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_output_tokens
+    use_template: max_tokens
+    default: 8192
+    min: 1
+    max: 8192
+  - name: json_schema
+    use_template: json_schema
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
@@ -162,9 +162,9 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
     @staticmethod
     def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str):
         try:
-            url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
+            url = f"{HUGGINGFACE_ENDPOINT_API}{credentials['huggingface_namespace']}"
             headers = {
-                "Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}',
+                "Authorization": f"Bearer {credentials['huggingfacehub_api_token']}",
                 "Content-Type": "application/json",
             }

@@ -34,6 +34,7 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage

 class MinimaxLargeLanguageModel(LargeLanguageModel):
     model_apis = {
+        "minimax-text-01": MinimaxChatCompletionPro,
         "abab7-chat-preview": MinimaxChatCompletionPro,
         "abab6.5t-chat": MinimaxChatCompletionPro,
         "abab6.5s-chat": MinimaxChatCompletionPro,
@@ -0,0 +1,46 @@
+model: minimax-text-01
+label:
+  en_US: Minimax-Text-01
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 1000192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    min: 0.01
+    max: 1
+    default: 0.1
+  - name: top_p
+    use_template: top_p
+    min: 0.01
+    max: 1
+    default: 0.95
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    default: 2048
+    min: 1
+    max: 1000192
+  - name: mask_sensitive_info
+    type: boolean
+    default: true
+    label:
+      zh_Hans: 隐私保护
+      en_US: Moderate
+    help:
+      zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码,目前包括但不限于邮箱、域名、链接、证件号、家庭住址等,默认true,即开启打码
+      en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+pricing:
+  input: '0.001'
+  output: '0.008'
+  unit: '0.001'
+  currency: RMB
@@ -44,9 +44,6 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         self._add_custom_parameters(credentials)
         self._add_function_call(model, credentials)
         user = user[:32] if user else None
-        # {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
-        if "response_format" in model_parameters:
-            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

     def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast

@@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

         # o1 compatibility
         block_as_stream = False
         if model.startswith("o1"):
+            if "max_tokens" in model_parameters:
+                model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
+                del model_parameters["max_tokens"]
+
+            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
                 if stream:
                     block_as_stream = True
                     stream = False
                     if "stream_options" in extra_model_kwargs:
                         del extra_model_kwargs["stream_options"]

                 if "stop" in extra_model_kwargs:
                     del extra_model_kwargs["stop"]

@@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)

-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle llm chat response
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )

     def _handle_chat_generate_response(
         self,
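Two things happen in the o1 hunks above: every o1-prefixed model gets `max_tokens` renamed to `max_completion_tokens`, while only bare `o1` or a dated `o1` snapshot (per the regex) is forced into blocking mode and replayed as a single stream chunk, leaving `o1-mini` and `o1-preview` streaming natively. The regex's reach can be checked directly:

```python
import re

O1_BLOCKING = re.compile(r"^o1(-\d{4}-\d{2}-\d{2})?$")

assert O1_BLOCKING.match("o1")
assert O1_BLOCKING.match("o1-2024-12-17")   # dated snapshot form
assert not O1_BLOCKING.match("o1-mini")      # "mini" has no leading digits
assert not O1_BLOCKING.match("o1-preview")
```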
@@ -29,9 +29,6 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
-        # {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
-        if "response_format" in model_parameters:
-            model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
         return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

     def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -21,7 +21,7 @@ class SparkLLMClient:
             domain = api_domain

         model_api_configs = {
-            "spark-lite": {"version": "v1.1", "chat_domain": "general"},
+            "spark-lite": {"version": "v1.1", "chat_domain": "lite"},
             "spark-pro": {"version": "v3.1", "chat_domain": "generalv3"},
             "spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"},
             "spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"},
@@ -257,8 +257,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
         for index, response in enumerate(responses):
             if response.status_code not in {200, HTTPStatus.OK}:
                 raise ServiceUnavailableError(
-                    f"Failed to invoke model {model}, status code: {response.status_code}, "
-                    f"message: {response.message}"
+                    f"Failed to invoke model {model}, status code: {response.status_code}, message: {response.message}"
                 )

             resp_finish_reason = response.output.choices[0].finish_reason
@@ -146,7 +146,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
         elif credentials["completion_type"] == "completion":
             completion_type = LLMMode.COMPLETION.value
         else:
-            raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
+            raise ValueError(f"completion_type {credentials['completion_type']} is not supported")

         entity = AIModelEntity(
             model=model,
@@ -41,15 +41,15 @@ class BaiduAccessToken:
         resp = response.json()
         if "error" in resp:
             if resp["error"] == "invalid_client":
-                raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
+                raise InvalidAPIKeyError(f"Invalid API key or secret key: {resp['error_description']}")
             elif resp["error"] == "unknown_error":
-                raise InternalServerError(f'Internal server error: {resp["error_description"]}')
+                raise InternalServerError(f"Internal server error: {resp['error_description']}")
             elif resp["error"] == "invalid_request":
-                raise BadRequestError(f'Bad request: {resp["error_description"]}')
+                raise BadRequestError(f"Bad request: {resp['error_description']}")
             elif resp["error"] == "rate_limit_exceeded":
-                raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
+                raise RateLimitReachedError(f"Rate limit reached: {resp['error_description']}")
             else:
-                raise Exception(f'Unknown error: {resp["error_description"]}')
+                raise Exception(f"Unknown error: {resp['error_description']}")

         return resp["access_token"]

@@ -406,7 +406,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             elif credentials["completion_type"] == "completion":
                 completion_type = LLMMode.COMPLETION.value
             else:
-                raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
+                raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
         else:
             extra_args = XinferenceHelper.get_xinference_extra_parameter(
                 server_url=credentials["server_url"],
@@ -472,7 +472,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         api_key = credentials.get("api_key") or "abc"

         client = OpenAI(
-            base_url=f'{credentials["server_url"]}/v1',
+            base_url=f"{credentials['server_url']}/v1",
             api_key=api_key,
             max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
             timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),
@@ -6,6 +6,7 @@ from pydantic import BaseModel, ValidationInfo, field_validator
 class TracingProviderEnum(Enum):
     LANGFUSE = "langfuse"
     LANGSMITH = "langsmith"
+    OPIK = "opik"


 class BaseTracingConfig(BaseModel):
@@ -56,5 +57,36 @@ class LangSmithConfig(BaseTracingConfig):
         return v


+class OpikConfig(BaseTracingConfig):
+    """
+    Model class for Opik tracing config.
+    """
+
+    api_key: str | None = None
+    project: str | None = None
+    workspace: str | None = None
+    url: str = "https://www.comet.com/opik/api/"
+
+    @field_validator("project")
+    @classmethod
+    def project_validator(cls, v, info: ValidationInfo):
+        if v is None or v == "":
+            v = "Default Project"
+
+        return v
+
+    @field_validator("url")
+    @classmethod
+    def url_validator(cls, v, info: ValidationInfo):
+        if v is None or v == "":
+            v = "https://www.comet.com/opik/api/"
+        if not v.startswith(("https://", "http://")):
+            raise ValueError("url must start with https:// or http://")
+        if not v.endswith("/api/"):
+            raise ValueError("url should ends with /api/")
+
+        return v
+
+
 OPS_FILE_PATH = "ops_trace/"
 OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
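Given the validators above, empty values fall back to defaults and malformed endpoints fail fast. A quick behavioral sketch (assuming pydantic v2 semantics, as the `field_validator` usage suggests):

```python
from core.ops.entities.config_entity import OpikConfig

cfg = OpikConfig(api_key="example-key", project="", workspace="default")
assert cfg.project == "Default Project"               # empty project falls back
assert cfg.url == "https://www.comet.com/opik/api/"   # default endpoint kept

# A URL missing the required /api/ suffix is rejected by url_validator,
# surfacing as a pydantic ValidationError at construction time.
try:
    OpikConfig(url="https://www.comet.com/opik")
except Exception as exc:
    print(type(exc).__name__)  # ValidationError
```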
0  api/core/ops/opik_trace/__init__.py  Normal file
469  api/core/ops/opik_trace/opik_trace.py  Normal file
@@ -0,0 +1,469 @@
+import json
+import logging
+import os
+import uuid
+from datetime import datetime, timedelta
+from typing import Optional, cast
+
+from opik import Opik, Trace
+from opik.id_helpers import uuid4_to_uuid7
+
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import OpikConfig
+from core.ops.entities.trace_entity import (
+    BaseTraceInfo,
+    DatasetRetrievalTraceInfo,
+    GenerateNameTraceInfo,
+    MessageTraceInfo,
+    ModerationTraceInfo,
+    SuggestedQuestionTraceInfo,
+    ToolTraceInfo,
+    TraceTaskName,
+    WorkflowTraceInfo,
+)
+from extensions.ext_database import db
+from models.model import EndUser, MessageFile
+from models.workflow import WorkflowNodeExecution
+
+logger = logging.getLogger(__name__)
+
+
+def wrap_dict(key_name, data):
+    """Make sure that the input data is a dict"""
+    if not isinstance(data, dict):
+        return {key_name: data}
+
+    return data
+
+
+def wrap_metadata(metadata, **kwargs):
+    """Add common metatada to all Traces and Spans"""
+    metadata["created_from"] = "dify"
+
+    metadata.update(kwargs)
+
+    return metadata
+
+
+def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
+    """Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most
+    messages and objects. The type-hints of BaseTraceInfo indicates that
+    objects start_time and message_id could be null which means we cannot map
+    it to a UUIDv7. Given that we have no way to identify that object
+    uniquely, generate a new random one UUIDv7 in that case.
+    """
+
+    if user_datetime is None:
+        user_datetime = datetime.now()
+
+    if user_uuid is None:
+        user_uuid = str(uuid.uuid4())
+
+    return uuid4_to_uuid7(user_datetime, user_uuid)
+
+
+class OpikDataTrace(BaseTraceInstance):
+    def __init__(
+        self,
+        opik_config: OpikConfig,
+    ):
+        super().__init__(opik_config)
+        self.opik_client = Opik(
+            project_name=opik_config.project,
+            workspace=opik_config.workspace,
+            host=opik_config.url,
+            api_key=opik_config.api_key,
+        )
+        self.project = opik_config.project
+        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
+
+    def trace(self, trace_info: BaseTraceInfo):
+        if isinstance(trace_info, WorkflowTraceInfo):
+            self.workflow_trace(trace_info)
+        if isinstance(trace_info, MessageTraceInfo):
+            self.message_trace(trace_info)
+        if isinstance(trace_info, ModerationTraceInfo):
+            self.moderation_trace(trace_info)
+        if isinstance(trace_info, SuggestedQuestionTraceInfo):
+            self.suggested_question_trace(trace_info)
+        if isinstance(trace_info, DatasetRetrievalTraceInfo):
+            self.dataset_retrieval_trace(trace_info)
+        if isinstance(trace_info, ToolTraceInfo):
+            self.tool_trace(trace_info)
+        if isinstance(trace_info, GenerateNameTraceInfo):
+            self.generate_name_trace(trace_info)
+
+    def workflow_trace(self, trace_info: WorkflowTraceInfo):
+        dify_trace_id = trace_info.workflow_run_id
+        opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
+        workflow_metadata = wrap_metadata(
+            trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
+        )
+        root_span_id = None
+
+        if trace_info.message_id:
+            dify_trace_id = trace_info.message_id
+            opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
+
+            trace_data = {
+                "id": opik_trace_id,
+                "name": TraceTaskName.MESSAGE_TRACE.value,
+                "start_time": trace_info.start_time,
+                "end_time": trace_info.end_time,
+                "metadata": workflow_metadata,
+                "input": wrap_dict("input", trace_info.workflow_run_inputs),
+                "output": wrap_dict("output", trace_info.workflow_run_outputs),
+                "tags": ["message", "workflow"],
+                "project_name": self.project,
+            }
+            self.add_trace(trace_data)
+
+            root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
+            span_data = {
+                "id": root_span_id,
+                "parent_span_id": None,
+                "trace_id": opik_trace_id,
+                "name": TraceTaskName.WORKFLOW_TRACE.value,
+                "input": wrap_dict("input", trace_info.workflow_run_inputs),
+                "output": wrap_dict("output", trace_info.workflow_run_outputs),
+                "start_time": trace_info.start_time,
+                "end_time": trace_info.end_time,
+                "metadata": workflow_metadata,
+                "tags": ["workflow"],
+                "project_name": self.project,
+            }
+            self.add_span(span_data)
+        else:
+            trace_data = {
+                "id": opik_trace_id,
+                "name": TraceTaskName.MESSAGE_TRACE.value,
+                "start_time": trace_info.start_time,
+                "end_time": trace_info.end_time,
+                "metadata": workflow_metadata,
+                "input": wrap_dict("input", trace_info.workflow_run_inputs),
+                "output": wrap_dict("output", trace_info.workflow_run_outputs),
+                "tags": ["workflow"],
+                "project_name": self.project,
+            }
+            self.add_trace(trace_data)
+
+        # through workflow_run_id get all_nodes_execution
+        workflow_nodes_execution_id_records = (
+            db.session.query(WorkflowNodeExecution.id)
+            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
+            .all()
+        )
+
+        for node_execution_id_record in workflow_nodes_execution_id_records:
+            node_execution = (
+                db.session.query(
+                    WorkflowNodeExecution.id,
+                    WorkflowNodeExecution.tenant_id,
+                    WorkflowNodeExecution.app_id,
+                    WorkflowNodeExecution.title,
+                    WorkflowNodeExecution.node_type,
+                    WorkflowNodeExecution.status,
+                    WorkflowNodeExecution.inputs,
+                    WorkflowNodeExecution.outputs,
+                    WorkflowNodeExecution.created_at,
+                    WorkflowNodeExecution.elapsed_time,
+                    WorkflowNodeExecution.process_data,
+                    WorkflowNodeExecution.execution_metadata,
+                )
+                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
+                .first()
+            )
+
+            if not node_execution:
+                continue
+
+            node_execution_id = node_execution.id
+            tenant_id = node_execution.tenant_id
+            app_id = node_execution.app_id
+            node_name = node_execution.title
+            node_type = node_execution.node_type
+            status = node_execution.status
+            if node_type == "llm":
+                inputs = (
+                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
+                )
+            else:
+                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
+            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+            created_at = node_execution.created_at or datetime.now()
+            elapsed_time = node_execution.elapsed_time
+            finished_at = created_at + timedelta(seconds=elapsed_time)
+
+            execution_metadata = (
+                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
+            )
+            metadata = execution_metadata.copy()
+            metadata.update(
+                {
+                    "workflow_run_id": trace_info.workflow_run_id,
+                    "node_execution_id": node_execution_id,
+                    "tenant_id": tenant_id,
+                    "app_id": app_id,
+                    "app_name": node_name,
+                    "node_type": node_type,
+                    "status": status,
+                }
+            )
+
+            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+
+            provider = None
+            model = None
+            total_tokens = 0
+            completion_tokens = 0
+            prompt_tokens = 0
+
+            if process_data and process_data.get("model_mode") == "chat":
+                run_type = "llm"
+                provider = process_data.get("model_provider", None)
+                model = process_data.get("model_name", "")
+                metadata.update(
+                    {
+                        "ls_provider": provider,
+                        "ls_model_name": model,
+                    }
+                )
+
+                try:
+                    if outputs.get("usage"):
+                        total_tokens = outputs["usage"].get("total_tokens", 0)
+                        prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
+                        completion_tokens = outputs["usage"].get("completion_tokens", 0)
+                except Exception:
+                    logger.error("Failed to extract usage", exc_info=True)
+
+            else:
+                run_type = "tool"
+
+            parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
+
+            if not total_tokens:
+                total_tokens = execution_metadata.get("total_tokens", 0)
+
+            span_data = {
+                "trace_id": opik_trace_id,
+                "id": prepare_opik_uuid(created_at, node_execution_id),
+                "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
+                "name": node_type,
+                "type": run_type,
+                "start_time": created_at,
+                "end_time": finished_at,
+                "metadata": wrap_metadata(metadata),
+                "input": wrap_dict("input", inputs),
+                "output": wrap_dict("output", outputs),
+                "tags": ["node_execution"],
+                "project_name": self.project,
+                "usage": {
+                    "total_tokens": total_tokens,
+                    "completion_tokens": completion_tokens,
+                    "prompt_tokens": prompt_tokens,
+                },
+                "model": model,
+                "provider": provider,
+            }
+
+            self.add_span(span_data)
+
+    def message_trace(self, trace_info: MessageTraceInfo):
+        # get message file data
+        file_list = cast(list[str], trace_info.file_list) or []
+        message_file_data: Optional[MessageFile] = trace_info.message_file_data
+
+        if message_file_data is not None:
+            file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
+            file_list.append(file_url)
+
+        message_data = trace_info.message_data
+        if message_data is None:
+            return
+
+        metadata = trace_info.metadata
+        message_id = trace_info.message_id
+
+        user_id = message_data.from_account_id
+        metadata["user_id"] = user_id
+        metadata["file_list"] = file_list
+
+        if message_data.from_end_user_id:
+            end_user_data: Optional[EndUser] = (
+                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+            )
+            if end_user_data is not None:
+                end_user_id = end_user_data.session_id
+                metadata["end_user_id"] = end_user_id
+
+        trace_data = {
+            "id": prepare_opik_uuid(trace_info.start_time, message_id),
+            "name": TraceTaskName.MESSAGE_TRACE.value,
+            "start_time": trace_info.start_time,
+            "end_time": trace_info.end_time,
+            "metadata": wrap_metadata(metadata),
+            "input": trace_info.inputs,
+            "output": message_data.answer,
+            "tags": ["message", str(trace_info.conversation_mode)],
+            "project_name": self.project,
+        }
+        trace = self.add_trace(trace_data)
+
+        span_data = {
+            "trace_id": trace.id,
+            "name": "llm",
+            "type": "llm",
+            "start_time": trace_info.start_time,
+            "end_time": trace_info.end_time,
+            "metadata": wrap_metadata(metadata),
+            "input": {"input": trace_info.inputs},
+            "output": {"output": message_data.answer},
+            "tags": ["llm", str(trace_info.conversation_mode)],
+            "usage": {
+                "completion_tokens": trace_info.answer_tokens,
+                "prompt_tokens": trace_info.message_tokens,
+                "total_tokens": trace_info.total_tokens,
+            },
+            "project_name": self.project,
+        }
+        self.add_span(span_data)
+
+    def moderation_trace(self, trace_info: ModerationTraceInfo):
+        if trace_info.message_data is None:
+            return
+
+        start_time = trace_info.start_time or trace_info.message_data.created_at
+
+        span_data = {
+            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
+            "name": TraceTaskName.MODERATION_TRACE.value,
+            "type": "tool",
+            "start_time": start_time,
+            "end_time": trace_info.end_time or trace_info.message_data.updated_at,
+            "metadata": wrap_metadata(trace_info.metadata),
+            "input": wrap_dict("input", trace_info.inputs),
+            "output": {
+                "action": trace_info.action,
+                "flagged": trace_info.flagged,
+                "preset_response": trace_info.preset_response,
+                "inputs": trace_info.inputs,
+            },
+            "tags": ["moderation"],
+        }
+
+        self.add_span(span_data)
+
+    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
+        message_data = trace_info.message_data
+        if message_data is None:
+            return
+
+        start_time = trace_info.start_time or message_data.created_at
+
+        span_data = {
+            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
+            "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
+            "type": "tool",
+            "start_time": start_time,
+            "end_time": trace_info.end_time or message_data.updated_at,
+            "metadata": wrap_metadata(trace_info.metadata),
+            "input": wrap_dict("input", trace_info.inputs),
+            "output": wrap_dict("output", trace_info.suggested_question),
+            "tags": ["suggested_question"],
+        }
+
+        self.add_span(span_data)
+
+    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
+        if trace_info.message_data is None:
+            return
+
+        start_time = trace_info.start_time or trace_info.message_data.created_at
+
+        span_data = {
+            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
+            "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
+            "type": "tool",
+            "start_time": start_time,
+            "end_time": trace_info.end_time or trace_info.message_data.updated_at,
+            "metadata": wrap_metadata(trace_info.metadata),
+            "input": wrap_dict("input", trace_info.inputs),
|
||||
"output": {"documents": trace_info.documents},
|
||||
"tags": ["dataset_retrieval"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
span_data = {
|
||||
"trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
|
||||
"name": trace_info.tool_name,
|
||||
"type": "tool",
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.tool_inputs),
|
||||
"output": wrap_dict("output", trace_info.tool_outputs),
|
||||
"tags": ["tool", trace_info.tool_name],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
trace_data = {
|
||||
"id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": trace_info.inputs,
|
||||
"output": trace_info.outputs,
|
||||
"tags": ["generate_name"],
|
||||
"project_name": self.project,
|
||||
}
|
||||
|
||||
trace = self.add_trace(trace_data)
|
||||
|
||||
span_data = {
|
||||
"trace_id": trace.id,
|
||||
"name": TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
"start_time": trace_info.start_time,
|
||||
"end_time": trace_info.end_time,
|
||||
"metadata": wrap_metadata(trace_info.metadata),
|
||||
"input": wrap_dict("input", trace_info.inputs),
|
||||
"output": wrap_dict("output", trace_info.outputs),
|
||||
"tags": ["generate_name"],
|
||||
}
|
||||
|
||||
self.add_span(span_data)
|
||||
|
||||
def add_trace(self, opik_trace_data: dict) -> Trace:
|
||||
try:
|
||||
trace = self.opik_client.trace(**opik_trace_data)
|
||||
logger.debug("Opik Trace created successfully")
|
||||
return trace
|
||||
except Exception as e:
|
||||
raise ValueError(f"Opik Failed to create trace: {str(e)}")
|
||||
|
||||
def add_span(self, opik_span_data: dict):
|
||||
try:
|
||||
self.opik_client.span(**opik_span_data)
|
||||
logger.debug("Opik Span created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Opik Failed to create span: {str(e)}")
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
self.opik_client.auth_check()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Opik API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
return self.opik_client.get_project_url(project_name=self.project)
|
||||
except Exception as e:
|
||||
logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Opik get run url failed: {str(e)}")

@ -17,6 +17,7 @@ from core.ops.entities.config_entity import (
    OPS_FILE_PATH,
    LangfuseConfig,
    LangSmithConfig,
    OpikConfig,
    TracingProviderEnum,
)
from core.ops.entities.trace_entity import (

@ -32,6 +33,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage

@ -52,6 +54,12 @@ provider_config_map: dict[str, dict[str, Any]] = {
        "other_keys": ["project", "endpoint"],
        "trace_instance": LangSmithDataTrace,
    },
    TracingProviderEnum.OPIK.value: {
        "config_class": OpikConfig,
        "secret_keys": ["api_key"],
        "other_keys": ["project", "url", "workspace"],
        "trace_instance": OpikDataTrace,
    },
}
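
# A hedged sketch (not in the commit) of how this registry is typically consumed:
# resolve the provider entry, validate the raw config with its config class, and
# build the matching trace client. `raw_config` is a hypothetical caller-supplied dict.
def build_trace_instance(tracing_provider: str, raw_config: dict):
    entry = provider_config_map[tracing_provider]
    config = entry["config_class"](**raw_config)  # e.g. OpikConfig(api_key=..., project=...)
    return entry["trace_instance"](config)  # e.g. OpikDataTrace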

@ -22,7 +22,12 @@ from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.entities.provider_entities import (
    ConfigurateMethod,
    CredentialFormSchema,
    FormType,
    ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider
from extensions.ext_database import db

@ -835,11 +840,18 @@ class ProviderManager:
        :return:
        """
        # Get provider model credential secret variables
        model_credential_secret_variables = self._extract_secret_variables(
            provider_entity.model_credential_schema.credential_form_schemas
            if provider_entity.model_credential_schema
            else []
        )
        if ConfigurateMethod.PREDEFINED_MODEL in provider_entity.configurate_methods:
            model_credential_secret_variables = self._extract_secret_variables(
                provider_entity.provider_credential_schema.credential_form_schemas
                if provider_entity.provider_credential_schema
                else []
            )
        else:
            model_credential_secret_variables = self._extract_secret_variables(
                provider_entity.model_credential_schema.credential_form_schemas
                if provider_entity.model_credential_schema
                else []
            )

        model_settings: list[ModelSettings] = []
        if not provider_model_settings:

@ -31,7 +31,7 @@ class FirecrawlApp:
                    "markdown": data.get("markdown"),
                }
            else:
                raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
                raise Exception(f"Failed to scrape URL. Error: {response_data['error']}")

        elif response.status_code in {402, 409, 500}:
            error_message = response.json().get("error", "Unknown error occurred")

@ -358,8 +358,7 @@ class NotionExtractor(BaseExtractor):

        if not data_source_binding:
            raise Exception(
                f"No notion data source binding found for tenant {tenant_id} "
                f"and notion workspace {notion_workspace_id}"
                f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
            )

        return cast(str, data_source_binding.access_token)

@ -127,7 +127,7 @@ class AIPPTGenerateToolAdapter:

        response = response.json()
        if response.get("code") != 0:
            raise Exception(f'Failed to create task: {response.get("msg")}')
            raise Exception(f"Failed to create task: {response.get('msg')}")

        return response.get("data", {}).get("id")

@ -222,7 +222,7 @@ class AIPPTGenerateToolAdapter:
        elif model == "wenxin":
            response = response.json()
            if response.get("code") != 0:
                raise Exception(f'Failed to generate content: {response.get("msg")}')
                raise Exception(f"Failed to generate content: {response.get('msg')}")

            return response.get("data", "")

@ -254,7 +254,7 @@ class AIPPTGenerateToolAdapter:

        response = response.json()
        if response.get("code") != 0:
            raise Exception(f'Failed to generate ppt: {response.get("msg")}')
            raise Exception(f"Failed to generate ppt: {response.get('msg')}")

        id = response.get("data", {}).get("id")
        cover_url = response.get("data", {}).get("cover_url")

@ -270,7 +270,7 @@ class AIPPTGenerateToolAdapter:

        response = response.json()
        if response.get("code") != 0:
            raise Exception(f'Failed to generate ppt: {response.get("msg")}')
            raise Exception(f"Failed to generate ppt: {response.get('msg')}")

        export_code = response.get("data")
        if not export_code:

@ -290,7 +290,7 @@ class AIPPTGenerateToolAdapter:

        response = response.json()
        if response.get("code") != 0:
            raise Exception(f'Failed to generate ppt: {response.get("msg")}')
            raise Exception(f"Failed to generate ppt: {response.get('msg')}")

        if response.get("msg") == "导出中":
            current_iteration += 1

@ -343,7 +343,7 @@ class AIPPTGenerateToolAdapter:
            raise Exception(f"Failed to connect to aippt: {response.text}")
        response = response.json()
        if response.get("code") != 0:
            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
            raise Exception(f"Failed to connect to aippt: {response.get('msg')}")

        token = response.get("data", {}).get("token")
        expire = response.get("data", {}).get("time_expire")

@ -379,7 +379,7 @@ class AIPPTGenerateToolAdapter:
            if cls._style_cache[key]["expire"] < now:
                del cls._style_cache[key]

        key = f'{credentials["aippt_access_key"]}#@#{user_id}'
        key = f"{credentials['aippt_access_key']}#@#{user_id}"
        if key in cls._style_cache:
            return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]

@ -396,11 +396,11 @@ class AIPPTGenerateToolAdapter:
        response = response.json()

        if response.get("code") != 0:
            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
            raise Exception(f"Failed to connect to aippt: {response.get('msg')}")

        colors = [
            {
                "id": f'id-{item.get("id")}',
                "id": f"id-{item.get('id')}",
                "name": item.get("name"),
                "en_name": item.get("en_name", item.get("name")),
            }

@ -408,7 +408,7 @@ class AIPPTGenerateToolAdapter:
        ]
        styles = [
            {
                "id": f'id-{item.get("id")}',
                "id": f"id-{item.get('id')}",
                "name": item.get("title"),
            }
            for item in response.get("data", {}).get("suit_style") or []

@ -454,7 +454,7 @@ class AIPPTGenerateToolAdapter:
        response = response.json()

        if response.get("code") != 0:
            raise Exception(f'Failed to connect to aippt: {response.get("msg")}')
            raise Exception(f"Failed to connect to aippt: {response.get('msg')}")

        if len(response.get("data", {}).get("list") or []) > 0:
            return response.get("data", {}).get("list")[0].get("id")

@ -229,8 +229,7 @@ class NovaReelTool(BuiltinTool):

        if async_mode:
            return self.create_text_message(
                f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
                f"Video will be available at: {video_uri}"
                f"Video generation started.\nInvocation ARN: {invocation_arn}\nVideo will be available at: {video_uri}"
            )

        return self._wait_for_completion(bedrock, s3_client, invocation_arn)

@ -65,7 +65,7 @@ class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase):
            if "trans_result" in result:
                result_text = result["trans_result"][0]["dst"]
            else:
                result_text = f'{result["error_code"]}: {result["error_msg"]}'
                result_text = f"{result['error_code']}: {result['error_msg']}"

            return self.create_text_message(str(result_text))
        except requests.RequestException as e:

@ -52,7 +52,7 @@ class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase):

            result_text = ""
            if result["error_code"] != 0:
                result_text = f'{result["error_code"]}: {result["error_msg"]}'
                result_text = f"{result['error_code']}: {result['error_msg']}"
            else:
                result_text = result["data"]["src"]
                result_text = self.mapping_result(description_language, result_text)

@ -58,7 +58,7 @@ class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase):
            if "trans_result" in result:
                result_text = result["trans_result"][0]["dst"]
            else:
                result_text = f'{result["error_code"]}: {result["error_msg"]}'
                result_text = f"{result['error_code']}: {result['error_msg']}"

            return self.create_text_message(str(result_text))
        except requests.RequestException as e:

@ -30,7 +30,7 @@ class BingSearchTool(BuiltinTool):
        headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}

        query = quote(query)
        server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}'
        server_url = f"{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={','.join(filters)}"
        response = get(server_url, headers=headers)

        if response.status_code != 200:

@ -47,23 +47,23 @@ class BingSearchTool(BuiltinTool):
            results = []
            if search_results:
                for result in search_results:
                    url = f': {result["url"]}' if "url" in result else ""
                    results.append(self.create_text_message(text=f'{result["name"]}{url}'))
                    url = f": {result['url']}" if "url" in result else ""
                    results.append(self.create_text_message(text=f"{result['name']}{url}"))

            if entities:
                for entity in entities:
                    url = f': {entity["url"]}' if "url" in entity else ""
                    results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}'))
                    url = f": {entity['url']}" if "url" in entity else ""
                    results.append(self.create_text_message(text=f"{entity.get('name', '')}{url}"))

            if news:
                for news_item in news:
                    url = f': {news_item["url"]}' if "url" in news_item else ""
                    results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}'))
                    url = f": {news_item['url']}" if "url" in news_item else ""
                    results.append(self.create_text_message(text=f"{news_item.get('name', '')}{url}"))

            if related_searches:
                for related in related_searches:
                    url = f': {related["displayText"]}' if "displayText" in related else ""
                    results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
                    url = f": {related['displayText']}" if "displayText" in related else ""
                    results.append(self.create_text_message(text=f"{related.get('displayText', '')}{url}"))

            return results
        elif result_type == "json":

@ -106,29 +106,29 @@ class BingSearchTool(BuiltinTool):
            text = ""
            if search_results:
                for i, result in enumerate(search_results):
                    text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
                    text += f"{i + 1}: {result.get('name', '')} - {result.get('snippet', '')}\n"

            if computation and "expression" in computation and "value" in computation:
                text += "\nComputation:\n"
                text += f'{computation["expression"]} = {computation["value"]}\n'
                text += f"{computation['expression']} = {computation['value']}\n"

            if entities:
                text += "\nEntities:\n"
                for entity in entities:
                    url = f'- {entity["url"]}' if "url" in entity else ""
                    text += f'{entity.get("name", "")}{url}\n'
                    url = f"- {entity['url']}" if "url" in entity else ""
                    text += f"{entity.get('name', '')}{url}\n"

            if news:
                text += "\nNews:\n"
                for news_item in news:
                    url = f'- {news_item["url"]}' if "url" in news_item else ""
                    text += f'{news_item.get("name", "")}{url}\n'
                    url = f"- {news_item['url']}" if "url" in news_item else ""
                    text += f"{news_item.get('name', '')}{url}\n"

            if related_searches:
                text += "\n\nRelated Searches:\n"
                for related in related_searches:
                    url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
                    text += f'{related.get("displayText", "")}{url}\n'
                    url = f"- {related['webSearchUrl']}" if "webSearchUrl" in related else ""
                    text += f"{related.get('displayText', '')}{url}\n"

            return self.create_text_message(text=self.summary(user_id=user_id, content=text))

@ -83,5 +83,5 @@ class DIDApp:
            if status["status"] == "done":
                return status
            elif status["status"] == "error" or status["status"] == "rejected":
                raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}')
                raise HTTPError(f"Talks {id} failed: {status['status']} {status.get('error', {}).get('description')}")
            time.sleep(poll_interval)

@ -20,33 +20,33 @@ class SendEmailToolParameters(BaseModel):
    encrypt_method: str


def send_mail(parmas: SendEmailToolParameters):
def send_mail(params: SendEmailToolParameters):
    timeout = 60
    msg = MIMEMultipart("alternative")
    msg["From"] = parmas.email_account
    msg["To"] = parmas.sender_to
    msg["Subject"] = parmas.subject
    msg.attach(MIMEText(parmas.email_content, "plain"))
    msg.attach(MIMEText(parmas.email_content, "html"))
    msg["From"] = params.email_account
    msg["To"] = params.sender_to
    msg["Subject"] = params.subject
    msg.attach(MIMEText(params.email_content, "plain"))
    msg.attach(MIMEText(params.email_content, "html"))

    ctx = ssl.create_default_context()

    if parmas.encrypt_method.upper() == "SSL":
    if params.encrypt_method.upper() == "SSL":
        try:
            with smtplib.SMTP_SSL(parmas.smtp_server, parmas.smtp_port, context=ctx, timeout=timeout) as server:
                server.login(parmas.email_account, parmas.email_password)
                server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
            with smtplib.SMTP_SSL(params.smtp_server, params.smtp_port, context=ctx, timeout=timeout) as server:
                server.login(params.email_account, params.email_password)
                server.sendmail(params.email_account, params.sender_to, msg.as_string())
                return True
        except Exception as e:
            logging.exception("send email failed")
            return False
    else:  # NONE or TLS
        try:
            with smtplib.SMTP(parmas.smtp_server, parmas.smtp_port, timeout=timeout) as server:
                if parmas.encrypt_method.upper() == "TLS":
            with smtplib.SMTP(params.smtp_server, params.smtp_port, timeout=timeout) as server:
                if params.encrypt_method.upper() == "TLS":
                    server.starttls(context=ctx)
                server.login(parmas.email_account, parmas.email_password)
                server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string())
                server.login(params.email_account, params.email_password)
                server.sendmail(params.email_account, params.sender_to, msg.as_string())
                return True
        except Exception as e:
            logging.exception("send email failed")
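
# A usage sketch for the renamed helper (not part of the commit). Every field
# name comes from the model and call sites above; the values are placeholders.
params = SendEmailToolParameters(
    smtp_server="smtp.example.com",
    smtp_port=465,
    email_account="bot@example.com",
    email_password="app-password",
    sender_to="user@example.com",
    subject="Hello",
    email_content="<p>Hi there</p>",
    encrypt_method="SSL",
)
sent = send_mail(params)  # True on success, False on failure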

@ -74,7 +74,7 @@ class FirecrawlApp:
        if response is None:
            raise HTTPError("Failed to initiate crawl after multiple retries")
        elif response.get("success") == False:
            raise HTTPError(f'Failed to crawl: {response.get("error")}')
            raise HTTPError(f"Failed to crawl: {response.get('error')}")
        job_id: str = response["id"]
        if wait:
            return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)

@ -100,7 +100,7 @@ class FirecrawlApp:
            if status["status"] == "completed":
                return status
            elif status["status"] == "failed":
                raise HTTPError(f'Job {job_id} failed: {status["error"]}')
                raise HTTPError(f"Job {job_id} failed: {status['error']}")
            time.sleep(poll_interval)

@ -37,8 +37,9 @@ class GaodeRepositoriesTool(BuiltinTool):
                CityCode = City_data["districts"][0]["adcode"]
                weatherInfo_response = s.request(
                    method="GET",
                    url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json"
                    "".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")),
                    url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json".format(
                        url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")
                    ),
                )
                weatherInfo_data = weatherInfo_response.json()
                if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK":

@ -11,19 +11,21 @@ class GitlabFilesTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        project = tool_parameters.get("project", "")
        repository = tool_parameters.get("repository", "")
        project = tool_parameters.get("project", "")
        branch = tool_parameters.get("branch", "")
        path = tool_parameters.get("path", "")
        file_path = tool_parameters.get("file_path", "")

        if not project and not repository:
            return self.create_text_message("Either project or repository is required")
        if not repository and not project:
            return self.create_text_message("Either repository or project is required")
        if not branch:
            return self.create_text_message("Branch is required")
        if not path:
            return self.create_text_message("Path is required")
        if not path and not file_path:
            return self.create_text_message("Either path or file_path is required")

        access_token = self.runtime.credentials.get("access_tokens")
        headers = {"PRIVATE-TOKEN": access_token}
        site_url = self.runtime.credentials.get("site_url")

        if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):

@ -31,33 +33,45 @@ class GitlabFilesTool(BuiltinTool):
        if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
            site_url = "https://gitlab.com"

        # Get file content
        if repository:
            result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True)
            # URL encode the repository path
            identifier = urllib.parse.quote(repository, safe="")
        else:
            result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False)
            identifier = self.get_project_id(site_url, access_token, project)
            if not identifier:
                raise Exception(f"Project '{project}' not found.")

        return [self.create_json_message(item) for item in result]
        # Get file content
        if path:
            results = self.fetch_files(site_url, headers, identifier, branch, path)
            return [self.create_json_message(item) for item in results]
        else:
            result = self.fetch_file(site_url, headers, identifier, branch, file_path)
            return [self.create_json_message(result)]

    @staticmethod
    def fetch_file(
        site_url: str,
        headers: dict[str, str],
        identifier: str,
        branch: str,
        path: str,
    ) -> dict[str, Any]:
        encoded_file_path = urllib.parse.quote(path, safe="")
        file_url = f"{site_url}/api/v4/projects/{identifier}/repository/files/{encoded_file_path}/raw?ref={branch}"

        file_response = requests.get(file_url, headers=headers)
        file_response.raise_for_status()
        file_content = file_response.text
        return {"path": path, "branch": branch, "content": file_content}

    def fetch_files(
        self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool
        self, site_url: str, headers: dict[str, str], identifier: str, branch: str, path: str
    ) -> list[dict[str, Any]]:
        domain = site_url
        headers = {"PRIVATE-TOKEN": access_token}
        results = []

        try:
            if is_repository:
                # URL encode the repository path
                encoded_identifier = urllib.parse.quote(identifier, safe="")
                tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
            else:
                # Get project ID from project name
                project_id = self.get_project_id(site_url, access_token, identifier)
                if not project_id:
                    return self.create_text_message(f"Project '{identifier}' not found.")
                tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"

            tree_url = f"{site_url}/api/v4/projects/{identifier}/repository/tree?path={path}&ref={branch}"
            response = requests.get(tree_url, headers=headers)
            response.raise_for_status()
            items = response.json()

@ -65,26 +79,10 @@ class GitlabFilesTool(BuiltinTool):
            for item in items:
                item_path = item["path"]
                if item["type"] == "tree":  # It's a directory
                    results.extend(
                        self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
                    )
                    results.extend(self.fetch_files(site_url, headers, identifier, branch, item_path))
                else:  # It's a file
                    encoded_item_path = urllib.parse.quote(item_path, safe="")
                    if is_repository:
                        file_url = (
                            f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
                            f"/{encoded_item_path}/raw?ref={branch}"
                        )
                    else:
                        file_url = (
                            f"{domain}/api/v4/projects/{project_id}/repository/files"
                            f"{encoded_item_path}/raw?ref={branch}"
                        )

                    file_response = requests.get(file_url, headers=headers)
                    file_response.raise_for_status()
                    file_content = file_response.text
                    results.append({"path": item_path, "branch": branch, "content": file_content})
                    result = self.fetch_file(site_url, headers, identifier, branch, item_path)
                    results.append(result)
        except requests.RequestException as e:
            print(f"Error fetching data from GitLab: {e}")

@ -29,7 +29,7 @@ parameters:
      zh_Hans: 项目
    human_description:
      en_US: project
      zh_Hans: 项目
      zh_Hans: 项目(和仓库路径二选一,都填写以仓库路径优先)
    llm_description: Project for GitLab
    form: llm
  - name: branch

@ -45,12 +45,21 @@ parameters:
    form: llm
  - name: path
    type: string
    required: true
    label:
      en_US: path
      zh_Hans: 文件路径
      zh_Hans: 文件夹
    human_description:
      en_US: path
      zh_Hans: 文件夹
    llm_description: Dir path for GitLab
    form: llm
  - name: file_path
    type: string
    label:
      en_US: file_path
      zh_Hans: 文件路径
    human_description:
      en_US: file_path
      zh_Hans: 文件路径(和文件夹二选一,都填写以文件夹优先)
    llm_description: File path for GitLab
    form: llm

@ -110,7 +110,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
                result["rows"].append(self.get_row_field_value(row, schema))
            return self.create_text_message(json.dumps(result, ensure_ascii=False))
        else:
            result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"."
            result_text = f'Found {result["total"]} rows in worksheet "{worksheet_name}".'
            if result["total"] > 0:
                result_text += (
                    f" The following are {min(limit, result['total'])}"

@ -28,4 +28,4 @@ class BaseStabilityAuthorization:
        """
        This method is responsible for generating the authorization headers.
        """
        return {"Authorization": f'Bearer {credentials.get("api_key", "")}'}
        return {"Authorization": f"Bearer {credentials.get('api_key', '')}"}

@ -38,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
            tool_parameters={
                "model": "chinook",
                "db_type": "SQLite",
                "url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite',
                "url": f"{self._get_protocol_and_main_domain(credentials['base_url'])}/Chinook.sqlite",
                "query": "What are the top 10 customers by sales?",
            },
        )

@ -43,7 +43,7 @@ class SerplyApi:
    def parse_results(res: dict) -> str:
        """Process response from Serply Job Search."""
        jobs = res.get("jobs", [])
        if not jobs:
        if not res or "jobs" not in res:
            raise ValueError(f"Got error from Serply: {res}")

        string = []

@ -43,7 +43,7 @@ class SerplyApi:
    def parse_results(res: dict) -> str:
        """Process response from Serply News Search."""
        news = res.get("entries", [])
        if not news:
        if not res or "entries" not in res:
            raise ValueError(f"Got error from Serply: {res}")

        string = []

@ -43,7 +43,7 @@ class SerplyApi:
    def parse_results(res: dict) -> str:
        """Process response from Serply News Search."""
        articles = res.get("articles", [])
        if not articles:
        if not res or "articles" not in res:
            raise ValueError(f"Got error from Serply: {res}")

        string = []

@ -42,7 +42,7 @@ class SerplyApi:
    def parse_results(res: dict) -> str:
        """Process response from Serply Web Search."""
        results = res.get("results", [])
        if not results:
        if not res or "results" not in res:
            raise ValueError(f"Got error from Serply: {res}")

        string = []
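
# Why the guard changed, in isolation (a sketch, not part of the commit):
# `res.get("jobs", [])` treated an error payload and a legitimately empty
# result the same way, so both raised. The membership check only raises when
# the expected key is missing entirely.
ok_but_empty = {"jobs": []}
error_payload = {"message": "invalid api key"}

assert "jobs" in ok_but_empty       # empty result: no error, just no items
assert "jobs" not in error_payload  # this is the case worth raising on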

@ -84,9 +84,9 @@ class ApiTool(Tool):
        if "api_key_header_prefix" in credentials:
            api_key_header_prefix = credentials["api_key_header_prefix"]
            if api_key_header_prefix == "basic" and credentials["api_key_value"]:
                credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
                credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
            elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
                credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
                credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
            elif api_key_header_prefix == "custom":
                pass

@ -29,7 +29,7 @@ class ToolFileMessageTransformer:
                    user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message
                )

                url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
                url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"

                result.append(
                    ToolInvokeMessage(

@ -122,4 +122,4 @@ class ToolFileMessageTransformer:

    @classmethod
    def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
        return f'/files/tools/{tool_file_id}{extension or ".bin"}'
        return f"/files/tools/{tool_file_id}{extension or '.bin'}"

@ -5,6 +5,7 @@ from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional

from flask import request
from requests import get
from yaml import YAMLError, safe_load  # type: ignore

@ -29,6 +30,10 @@ class ApiBasedToolSchemaParser:
            raise ToolProviderNotFoundError("No server found in the openapi yaml.")

        server_url = openapi["servers"][0]["url"]
        request_env = request.headers.get("X-Request-Env")
        if request_env:
            matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
            server_url = matched_servers[0] if matched_servers else server_url

        # list all interfaces
        interfaces = []

@ -144,7 +149,7 @@ class ApiBasedToolSchemaParser:
        if not path:
            path = str(uuid.uuid4())

        interface["operation"]["operationId"] = f'{path}_{interface["method"]}'
        interface["operation"]["operationId"] = f"{path}_{interface['method']}"

        bundles.append(
            ApiToolBundle(

@ -134,6 +134,10 @@ class ArrayStringSegment(ArraySegment):
    value_type: SegmentType = SegmentType.ARRAY_STRING
    value: Sequence[str]

    @property
    def text(self) -> str:
        return json.dumps(self.value)


class ArrayNumberSegment(ArraySegment):
    value_type: SegmentType = SegmentType.ARRAY_NUMBER

@ -1,6 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional

from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent

@ -48,25 +49,35 @@ class StreamProcessor(ABC):
                    # removing the node may shortcut the answer node, so comment this code out for now;
                    # it has no effect on the answer node or the workflow. When we have a better solution
                    # we can re-enable this code. Issues: #11542 #9560 #10638 #10564
                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
                    if "answer" in ids:
                        continue
                    else:
                        reachable_node_ids.extend(ids)
                    # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
                    # if "answer" in ids:
                    #     continue
                    # else:
                    #     reachable_node_ids.extend(ids)

                    # The branch_identify parameter is added to ensure that
                    # only nodes in the correct logical branch are included.
                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
                    reachable_node_ids.extend(ids)
                else:
                    unreachable_first_node_ids.append(edge.target_node_id)

        for node_id in unreachable_first_node_ids:
            self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
    def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
        node_ids = []
        for edge in self.graph.edge_mapping.get(node_id, []):
            if edge.target_node_id == self.graph.root_node_id:
                continue

            # Only follow edges that match the branch_identify or have no run_condition
            if edge.run_condition and edge.run_condition.branch_identify:
                if not branch_identify or edge.run_condition.branch_identify != branch_identify:
                    continue

            node_ids.append(edge.target_node_id)
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
        return node_ids

    def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
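
# A standalone sketch of the branch-aware traversal above, using a plain dict
# instead of the engine's edge_mapping (all names here are illustrative):
from typing import Optional

def reachable(graph: dict[str, list[tuple[str, Optional[str]]]], node_id: str,
              branch: Optional[str] = None) -> list[str]:
    ids: list[str] = []
    for target, edge_branch in graph.get(node_id, []):
        if edge_branch is not None and edge_branch != branch:
            continue  # edge belongs to a different branch
        ids.append(target)
        ids.extend(reachable(graph, target, branch))
    return ids

# Edges out of "if" carry the branch they belong to; only the "true" side is kept.
g = {"if": [("a", "true"), ("b", "false")], "a": [("end", None)]}
assert reachable(g, "if", "true") == ["a", "end"]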

@ -253,9 +253,9 @@ class Executor:
        )
        if executor_response.size > threshold_size:
            raise ResponseSizeError(
                f'{"File" if executor_response.is_file else "Text"} size is too large,'
                f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
                f' but current size is {executor_response.readable_size}.'
                f"{'File' if executor_response.is_file else 'Text'} size is too large,"
                f" max size is {threshold_size / 1024 / 1024:.2f} MB,"
                f" but current size is {executor_response.readable_size}."
            )

        return executor_response

@ -338,7 +338,7 @@ class Executor:
            if self.auth.config and self.auth.config.header:
                authorization_header = self.auth.config.header
                if k.lower() == authorization_header.lower():
                    raw += f'{k}: {"*" * len(v)}\r\n'
                    raw += f"{k}: {'*' * len(v)}\r\n"
                    continue
            raw += f"{k}: {v}\r\n"

@ -1,5 +1,4 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast

@ -20,10 +19,8 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService

from .entities import KnowledgeRetrievalNodeData
from .exc import (

@ -64,23 +61,6 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
            )
        # check rate limit
        if self.tenant_id:
            knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
            if knowledge_rate_limit.enabled:
                current_time = int(time.time() * 1000)
                key = f"rate_limit_{self.tenant_id}"
                redis_client.zadd(key, {current_time: current_time})
                redis_client.zremrangebyscore(key, 0, current_time - 60000)
                request_count = redis_client.zcard(key)
                if request_count > knowledge_rate_limit.limit:
                    return NodeRunResult(
                        status=WorkflowNodeExecutionStatus.FAILED,
                        inputs=variables,
                        error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
                        error_type="RateLimitExceeded",
                    )

        # retrieve knowledge
        try:
            results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)
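
# The rate-limit block above is a sliding-window counter on a Redis sorted set;
# this sketch isolates the idea. The key name and 60s window come from the code
# above; `redis_client` is any redis-py compatible client.
import time

def over_limit(redis_client, tenant_id: str, limit: int, window_ms: int = 60000) -> bool:
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    redis_client.zadd(key, {now_ms: now_ms})                   # record this request
    redis_client.zremrangebyscore(key, 0, now_ms - window_ms)  # evict old entries
    return redis_client.zcard(key) > limit                     # count what remains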

@ -1,4 +1,5 @@
import json
from collections.abc import Sequence
from typing import Any, cast

from core.variables import SegmentType, Variable

@ -31,7 +32,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
        inputs = self.node_data.model_dump()
        process_data: dict[str, Any] = {}
        # NOTE: This node has no outputs
        updated_variables: list[Variable] = []
        updated_variable_selectors: list[Sequence[str]] = []

        try:
            for item in self.node_data.items:

@ -98,7 +99,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
                    value=item.value,
                )
                variable = variable.model_copy(update={"value": updated_value})
                updated_variables.append(variable)
                self.graph_runtime_state.variable_pool.add(variable.selector, variable)
                updated_variable_selectors.append(variable.selector)
        except VariableOperatorNodeError as e:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,

@ -107,9 +109,15 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
                error=str(e),
            )

        # `updated_variable_selectors` is a list of list[str], which is not hashable;
        # remove the duplicated items first.
        updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))

        # Update variables
        for variable in updated_variables:
            self.graph_runtime_state.variable_pool.add(variable.selector, variable)
        for selector in updated_variable_selectors:
            variable = self.graph_runtime_state.variable_pool.get(selector)
            if not isinstance(variable, Variable):
                raise VariableNotFoundError(variable_selector=selector)
            process_data[variable.name] = variable.value

            if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:
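
# The dedup line above in miniature (sketch): selectors are list[str], which is
# unhashable, so they are converted to tuples before deduplication via a set.
selectors = [["conversation", "x"], ["node1", "y"], ["conversation", "x"]]
unique = list(set(map(tuple, selectors)))
assert sorted(unique) == [("conversation", "x"), ("node1", "y")]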

@ -26,7 +26,7 @@ def handle(sender, **kwargs):
                tool_runtime=tool_runtime,
                provider_name=tool_entity.provider_name,
                provider_type=tool_entity.provider_type,
                identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}',
                identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
            )
            manager.delete_tool_parameters_cache()
        except:

@ -1,6 +1,6 @@
from flask_restful import fields  # type: ignore

from libs.helper import TimestampField
from libs.helper import AvatarUrlField, TimestampField

simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String}

@ -8,6 +8,7 @@ account_fields = {
    "id": fields.String,
    "name": fields.String,
    "avatar": fields.String,
    "avatar_url": AvatarUrlField,
    "email": fields.String,
    "is_password_set": fields.Boolean,
    "interface_language": fields.String,

@ -22,6 +23,7 @@ account_with_role_fields = {
    "id": fields.String,
    "name": fields.String,
    "avatar": fields.String,
    "avatar_url": AvatarUrlField,
    "email": fields.String,
    "last_login_at": TimestampField,
    "last_active_at": TimestampField,

@ -41,6 +41,18 @@ class AppIconUrlField(fields.Raw):
        return None


class AvatarUrlField(fields.Raw):
    def output(self, key, obj):
        if obj is None:
            return None

        from models.account import Account

        if isinstance(obj, Account) and obj.avatar is not None:
            return file_helpers.get_signed_file_url(obj.avatar)
        return None


class TimestampField(fields.Raw):
    def format(self, value) -> int:
        return int(value.timestamp())
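
# A hedged usage sketch (not in the commit): a custom Raw field such as
# AvatarUrlField is applied when an object is marshalled with the field dicts
# defined above. `some_account` stands in for a real Account instance.
from flask_restful import marshal

payload = marshal(some_account, account_fields)
# payload["avatar_url"] is a signed URL, or None when no avatar is set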

@ -13,6 +13,7 @@ from typing import Any, cast

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped

from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod

@ -515,7 +516,7 @@ class DocumentSegment(db.Model):  # type: ignore[name-defined]
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    position: Mapped[int]
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)

1074
api/poetry.lock
generated
File diff suppressed because it is too large

@ -1,9 +1,10 @@
[project]
name = "dify-api"
requires-python = ">=3.11,<3.13"
dynamic = [ "dependencies" ]

[build-system]
requires = ["poetry-core"]
requires = ["poetry-core>=2.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]

@ -59,6 +60,7 @@ numpy = "~1.26.4"
oci = "~2.135.1"
openai = "~1.52.0"
openpyxl = "~3.1.5"
opik = "~1.3.4"
pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
pandas-stubs = "~2.2.3.241009"
psycogreen = "~1.0.2"

@ -190,4 +192,4 @@ pytest-mock = "~3.14.0"
optional = true
[tool.poetry.group.lint.dependencies]
dotenv-linter = "~0.5.0"
ruff = "~0.8.1"
ruff = "~0.9.2"

@ -1,7 +1,7 @@
import logging
import uuid
from enum import StrEnum
from typing import Optional, cast
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4

@ -139,15 +139,6 @@ class AppDslService:
                status=ImportStatus.FAILED,
                error="Empty content from url",
            )

        try:
            content = cast(bytes, content).decode("utf-8")
        except UnicodeDecodeError as e:
            return Import(
                id=import_id,
                status=ImportStatus.FAILED,
                error=f"Error decoding content: {e}",
            )
        except Exception as e:
            return Import(
                id=import_id,

@ -82,7 +82,7 @@ class AudioService:
        from app import app
        from extensions.ext_database import db

        def invoke_tts(text_content: str, app_model, voice: Optional[str] = None):
        def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None):
            with app.app_context():
                if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
                    workflow = app_model.workflow

@ -95,6 +95,8 @@ class AudioService:

                    voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
                else:
                    if app_model.app_model_config is None:
                        raise ValueError("AppModelConfig not found")
                    text_to_speech_dict = app_model.app_model_config.text_to_speech_dict

                    if not text_to_speech_dict.get("enabled"):

@ -19,14 +19,6 @@ class BillingService:
        billing_info = cls._send_request("GET", "/subscription/info", params=params)
        return billing_info

    @classmethod
    def get_knowledge_rate_limit(cls, tenant_id: str):
        params = {"tenant_id": tenant_id}

        knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)

        return knowledge_rate_limit.get("limit", 10)

    @classmethod
    def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
        params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}

@ -4,6 +4,7 @@ import logging
import random
import time
import uuid
from collections import Counter
from typing import Any, Optional

from flask_login import current_user  # type: ignore

@ -221,8 +222,7 @@ class DatasetService:
            )
        except LLMBadRequestError:
            raise ValueError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider."
                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
            )
        except ProviderTokenNotInitError as ex:
            raise ValueError(f"The dataset is unavailable, due to: {ex.description}")

@ -859,7 +859,7 @@ class DocumentService:
        position = DocumentService.get_documents_position(dataset.id)
        document_ids = []
        duplicate_document_ids = []
        if knowledge_config.data_source.info_list.data_source_type == "upload_file":
        if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
            upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
            for file_id in upload_file_list:
                file = (

@ -901,7 +901,7 @@ class DocumentService:
                document = DocumentService.build_document(
                    dataset,
                    dataset_process_rule.id,  # type: ignore
                    knowledge_config.data_source.info_list.data_source_type,
                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                    knowledge_config.doc_form,
                    knowledge_config.doc_language,
                    data_source_info,

@ -916,8 +916,8 @@ class DocumentService:
                document_ids.append(document.id)
                documents.append(document)
                position += 1
        elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
            notion_info_list = knowledge_config.data_source.info_list.notion_info_list
        elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
            notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
            if not notion_info_list:
                raise ValueError("No notion info list found.")
            exist_page_ids = []

@ -956,7 +956,7 @@ class DocumentService:
                        document = DocumentService.build_document(
                            dataset,
                            dataset_process_rule.id,  # type: ignore
                            knowledge_config.data_source.info_list.data_source_type,
                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                            knowledge_config.doc_form,
                            knowledge_config.doc_language,
                            data_source_info,

@ -976,8 +976,8 @@ class DocumentService:
            # delete not selected documents
            if len(exist_document) > 0:
                clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
        elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
            website_info = knowledge_config.data_source.info_list.website_info_list
        elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
            website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
            if not website_info:
                raise ValueError("No website info list found.")
            urls = website_info.urls

@ -996,7 +996,7 @@ class DocumentService:
                document = DocumentService.build_document(
                    dataset,
                    dataset_process_rule.id,  # type: ignore
                    knowledge_config.data_source.info_list.data_source_type,
                    knowledge_config.data_source.info_list.data_source_type,  # type: ignore
                    knowledge_config.doc_form,
                    knowledge_config.doc_language,
                    data_source_info,

@ -1195,20 +1195,20 @@ class DocumentService:

        if features.billing.enabled:
            count = 0
            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
            if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
                upload_file_list = (
                    knowledge_config.data_source.info_list.file_info_list.file_ids
                    if knowledge_config.data_source.info_list.file_info_list
                    knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                    if knowledge_config.data_source.info_list.file_info_list  # type: ignore
                    else []
                )
                count = len(upload_file_list)
            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                notion_info_list = knowledge_config.data_source.info_list.notion_info_list
            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
                notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                if notion_info_list:
                    for notion_info in notion_info_list:
                        count = count + len(notion_info.pages)
            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                website_info = knowledge_config.data_source.info_list.website_info_list
            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
                website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
                if website_info:
                    count = len(website_info.urls)
            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)

@ -1239,7 +1239,7 @@ class DocumentService:
        dataset = Dataset(
            tenant_id=tenant_id,
            name="",
            data_source_type=knowledge_config.data_source.info_list.data_source_type,
            data_source_type=knowledge_config.data_source.info_list.data_source_type,  # type: ignore
            indexing_technique=knowledge_config.indexing_technique,
            created_by=account.id,
            embedding_model=knowledge_config.embedding_model,

@ -1611,8 +1611,11 @@ class SegmentService:
            segment.answer = args.answer
            segment.word_count += len(args.answer) if args.answer else 0
            word_count_change = segment.word_count - word_count_change
            keyword_changed = False
            if args.keywords:
                segment.keywords = args.keywords
                if Counter(segment.keywords) != Counter(args.keywords):
                    segment.keywords = args.keywords
                    keyword_changed = True
            segment.enabled = True
            segment.disabled_at = None
            segment.disabled_by = None

@ -1623,13 +1626,6 @@ class SegmentService:
                document.word_count = max(0, document.word_count + word_count_change)
                db.session.add(document)
            # update segment index task
            if args.enabled:
                VectorService.create_segments_vector(
                    [args.keywords] if args.keywords else None,
                    [segment],
                    dataset,
                    document.doc_form,
                )
            if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                # regenerate child chunks
                # get embedding model instance

@ -1662,6 +1658,14 @@ class SegmentService:
                VectorService.generate_child_chunks(
                    segment, document, dataset, embedding_model_instance, processing_rule, True
                )
            elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
                if args.enabled or keyword_changed:
                    VectorService.create_segments_vector(
                        [args.keywords] if args.keywords else None,
                        [segment],
                        dataset,
                        document.doc_form,
                    )
        else:
            segment_hash = helper.generate_text_hash(content)
            tokens = 0
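
# The Counter comparison above, in isolation (sketch): comparing Counters
# ignores order but detects added, removed, or duplicated keywords.
from collections import Counter

assert Counter(["a", "b"]) == Counter(["b", "a"])       # reordered: unchanged
assert Counter(["a", "b"]) != Counter(["a", "b", "b"])  # duplicate added: changed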
|
||||
|
@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel):
|
||||
original_document_id: Optional[str] = None
|
||||
duplicate: bool = True
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
data_source: DataSource
|
||||
data_source: Optional[DataSource] = None
|
||||
process_rule: Optional[ProcessRule] = None
|
||||
retrieval_model: Optional[RetrievalModel] = None
|
||||
doc_form: str = "text_model"
|
||||
|
@ -155,7 +155,7 @@ class ExternalDatasetService:
if custom_parameters:
    for parameter in custom_parameters:
        if parameter.get("required", False) and not process_parameter.get(parameter.get("name")):
            raise ValueError(f'{parameter.get("name")} is required')
            raise ValueError(f"{parameter.get('name')} is required")

@staticmethod
def process_external_api(

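The f-string rewrite above changes only the quoting convention: the outer string becomes double-quoted and the inner `.get()` key single-quoted, matching the double-quote-outer style the project's formatter prefers. Both spellings produce identical strings on any supported Python version (only Python 3.12+ would allow reusing the same quote inside the replacement field):

parameter = {"name": "api_key"}

old_style = f'{parameter.get("name")} is required'
new_style = f"{parameter.get('name')} is required"
assert old_style == new_style  # identical output; only the quoting convention differs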
@ -41,7 +41,6 @@ class FeatureModel(BaseModel):
members: LimitationModel = LimitationModel(size=0, limit=1)
apps: LimitationModel = LimitationModel(size=0, limit=10)
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
knowledge_rate_limit: int = 10
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = "standard"
@ -53,11 +52,6 @@ class FeatureModel(BaseModel):
model_config = ConfigDict(protected_namespaces=())


class KnowledgeRateLimitModel(BaseModel):
    enabled: bool = False
    limit: int = 10


class SystemFeatureModel(BaseModel):
    sso_enforced_for_signin: bool = False
    sso_enforced_for_signin_protocol: str = ""
@ -85,14 +79,6 @@ class FeatureService:

        return features

    @classmethod
    def get_knowledge_rate_limit(cls, tenant_id: str):
        knowledge_rate_limit = KnowledgeRateLimitModel()
        if dify_config.BILLING_ENABLED and tenant_id:
            knowledge_rate_limit.enabled = True
            knowledge_rate_limit.limit = BillingService.get_knowledge_rate_limit(tenant_id)
        return knowledge_rate_limit

    @classmethod
    def get_system_features(cls) -> SystemFeatureModel:
        system_features = SystemFeatureModel()
@ -158,9 +144,6 @@ class FeatureService:
        if "model_load_balancing_enabled" in billing_info:
            features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]

        if "knowledge_rate_limit" in billing_info:
            features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]

    @classmethod
    def _fulfill_params_from_enterprise(cls, features):
        enterprise_info = EnterpriseService.get_info()

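The hunks above remove the `knowledge_rate_limit` plumbing from `FeatureService`. For reference, the deleted helper followed a simple billing-gated pattern — a default limit unless billing is enabled for the tenant; the sketch below reproduces it with stand-ins for `dify_config` and `BillingService`:

# Stand-alone sketch of the removed lookup; dify_config and BillingService
# are replaced by stand-ins here.
from pydantic import BaseModel


class KnowledgeRateLimitModel(BaseModel):
    enabled: bool = False
    limit: int = 10


BILLING_ENABLED = True  # stand-in for dify_config.BILLING_ENABLED


def billing_knowledge_rate_limit(tenant_id: str) -> int:
    return 30  # stand-in for BillingService.get_knowledge_rate_limit(tenant_id)


def get_knowledge_rate_limit(tenant_id: str) -> KnowledgeRateLimitModel:
    """Default limit unless billing is enabled and a tenant is given."""
    rate_limit = KnowledgeRateLimitModel()
    if BILLING_ENABLED and tenant_id:
        rate_limit.enabled = True
        rate_limit.limit = billing_knowledge_rate_limit(tenant_id)
    return rate_limit


print(get_knowledge_rate_limit("tenant-123"))  # enabled=True limit=30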
@ -59,6 +59,15 @@ class OpsService:
except Exception:
    new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"})

if tracing_provider == "opik" and (
    "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
    try:
        project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
        new_decrypt_tracing_config.update({"project_url": project_url})
    except Exception:
        new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})

trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict()

@ -92,7 +101,7 @@ class OpsService:
if tracing_provider == "langfuse":
    project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
    project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key)
elif tracing_provider == "langsmith":
elif tracing_provider in ("langsmith", "opik"):
    project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
else:
    project_url = None

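The new opik branch mirrors the langsmith one above: if the stored tracing config has no `project_url`, resolve it once from the provider and cache it, falling back to the provider's public console when the lookup fails. A condensed sketch of that backfill pattern — `resolve_project_url` is a hypothetical stand-in for `OpsTraceManager.get_trace_config_project_url`:

# Condensed sketch of the project_url backfill; resolve_project_url is a
# hypothetical stand-in for OpsTraceManager.get_trace_config_project_url.
DEFAULT_PROJECT_URLS = {
    "langsmith": "https://smith.langchain.com/",
    "opik": "https://www.comet.com/opik/",
}


def resolve_project_url(config: dict, provider: str) -> str:
    raise RuntimeError("provider unreachable")  # simulate a failed lookup


def backfill_project_url(config: dict, provider: str) -> dict:
    """Populate config["project_url"] lazily, with a per-provider fallback."""
    if provider in DEFAULT_PROJECT_URLS and not config.get("project_url"):
        try:
            config["project_url"] = resolve_project_url(config, provider)
        except Exception:
            config["project_url"] = DEFAULT_PROJECT_URLS[provider]
    return config


print(backfill_project_url({}, "opik"))  # {'project_url': 'https://www.comet.com/opik/'}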
@ -5,7 +5,8 @@ import uuid

import click
from celery import shared_task # type: ignore
from sqlalchemy import func
from sqlalchemy import func, select
from sqlalchemy.orm import Session

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -18,7 +19,12 @@ from services.vector_service import VectorService

@shared_task(queue="dataset")
def batch_create_segment_to_index_task(
    job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str
    job_id: str,
    content: list,
    dataset_id: str,
    document_id: str,
    tenant_id: str,
    user_id: str,
):
    """
    Async batch create segment to index
@ -37,71 +43,80 @@ def batch_create_segment_to_index_task(
indexing_cache_key = "segment_batch_import_{}".format(job_id)

try:
    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    if not dataset:
        raise ValueError("Dataset not exist.")
    with Session(db.engine) as session:
        dataset = session.get(Dataset, dataset_id)
        if not dataset:
            raise ValueError("Dataset not exist.")

    dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
    if not dataset_document:
        raise ValueError("Document not exist.")
        dataset_document = session.get(Document, document_id)
        if not dataset_document:
            raise ValueError("Document not exist.")

    if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
        raise ValueError("Document is not available.")
    document_segments = []
    embedding_model = None
    if dataset.indexing_technique == "high_quality":
        model_manager = ModelManager()
        embedding_model = model_manager.get_model_instance(
            tenant_id=dataset.tenant_id,
            provider=dataset.embedding_model_provider,
            model_type=ModelType.TEXT_EMBEDDING,
            model=dataset.embedding_model,
        if (
            not dataset_document.enabled
            or dataset_document.archived
            or dataset_document.indexing_status != "completed"
        ):
            raise ValueError("Document is not available.")
        document_segments = []
        embedding_model = None
        if dataset.indexing_technique == "high_quality":
            model_manager = ModelManager()
            embedding_model = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )
        word_count_change = 0
        segments_to_insert: list[str] = []
        max_position_stmt = select(func.max(DocumentSegment.position)).where(
            DocumentSegment.document_id == dataset_document.id
        )
    word_count_change = 0
    segments_to_insert: list[str] = [] # Explicitly type hint the list as List[str]
    for segment in content:
        content_str = segment["content"]
        doc_id = str(uuid.uuid4())
        segment_hash = helper.generate_text_hash(content_str)
        # calc embedding use tokens
        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
        max_position = (
            db.session.query(func.max(DocumentSegment.position))
            .filter(DocumentSegment.document_id == dataset_document.id)
            .scalar()
        )
        segment_document = DocumentSegment(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            document_id=document_id,
            index_node_id=doc_id,
            index_node_hash=segment_hash,
            position=max_position + 1 if max_position else 1,
            content=content_str,
            word_count=len(content_str),
            tokens=tokens,
            created_by=user_id,
            indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            status="completed",
            completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
        )
        if dataset_document.doc_form == "qa_model":
            segment_document.answer = segment["answer"]
            segment_document.word_count += len(segment["answer"])
        word_count_change += segment_document.word_count
        db.session.add(segment_document)
        document_segments.append(segment_document)
        segments_to_insert.append(str(segment)) # Cast to string if needed
    # update document word count
    dataset_document.word_count += word_count_change
    db.session.add(dataset_document)
    # add index to db
    VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
    db.session.commit()
        max_position = session.scalar(max_position_stmt) or 1
        for segment in content:
            content_str = segment["content"]
            doc_id = str(uuid.uuid4())
            segment_hash = helper.generate_text_hash(content_str)
            # calc embedding use tokens
            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
            segment_document = DocumentSegment(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                document_id=document_id,
                index_node_id=doc_id,
                index_node_hash=segment_hash,
                position=max_position,
                content=content_str,
                word_count=len(content_str),
                tokens=tokens,
                created_by=user_id,
                indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                status="completed",
                completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
            )
            max_position += 1
            if dataset_document.doc_form == "qa_model":
                segment_document.answer = segment["answer"]
                segment_document.word_count += len(segment["answer"])
            word_count_change += segment_document.word_count
            session.add(segment_document)
            document_segments.append(segment_document)
            segments_to_insert.append(str(segment)) # Cast to string if needed
        # update document word count
        dataset_document.word_count += word_count_change
        session.add(dataset_document)
        # add index to db
        VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
        session.commit()

    redis_client.setex(indexing_cache_key, 600, "completed")
    end_at = time.perf_counter()
    logging.info(
        click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green")
        click.style(
            "Segment batch created job: {} latency: {}".format(job_id, end_at - start_at),
            fg="green",
        )
    )
except Exception as e:
    logging.exception("Segments batch created index failed")

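The bulk of this hunk is a query-style migration: legacy `db.session.query(...).filter(...).first()` / `.scalar()` calls become an explicit `Session`, `session.get()` for primary-key lookups, and `select()` statements executed via `session.scalar()` — and the `max(position)` query is hoisted out of the per-segment loop into a local counter. A self-contained sketch of the 2.0-style calls against an in-memory SQLite database, with a simplified stand-in model:

# Self-contained illustration of the query-style migration; the model and
# engine here are simplified stand-ins for the Dify ones.
from sqlalchemy import Integer, String, create_engine, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DocumentSegment(Base):
    __tablename__ = "document_segments"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    document_id: Mapped[str] = mapped_column(String)
    position: Mapped[int] = mapped_column(Integer)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DocumentSegment(document_id="doc-1", position=1))
    session.commit()

    # 2.0 style: primary-key lookup and a scalar aggregate via select()
    segment = session.get(DocumentSegment, 1)
    max_position_stmt = select(func.max(DocumentSegment.position)).where(
        DocumentSegment.document_id == "doc-1"
    )
    max_position = session.scalar(max_position_stmt) or 1
    print(segment.position, max_position)  # 1 1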
@ -44,6 +44,6 @@ def test_duplicated_dependency_crossing_groups() -> None:
dependency_names = list(dependencies.keys())
all_dependency_names.extend(dependency_names)
expected_all_dependency_names = set(all_dependency_names)
assert sorted(expected_all_dependency_names) == sorted(
    all_dependency_names
), "Duplicated dependencies crossing groups are found"
assert sorted(expected_all_dependency_names) == sorted(all_dependency_names), (
    "Duplicated dependencies crossing groups are found"
)

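This hunk and the two test hunks that follow are pure reformatting, consistent with a formatter update: when a long `assert condition, message` has to wrap, the newer style parenthesizes the message rather than splitting the condition. The two spellings are equivalent:

items = ["alpha", "beta"]
unique_names = set(items)

# old wrapping: the condition is broken across lines
assert sorted(unique_names) == sorted(
    items
), "Duplicated dependencies crossing groups are found"

# new wrapping: the condition stays on one line, the message is parenthesized
assert sorted(unique_names) == sorted(items), (
    "Duplicated dependencies crossing groups are found"
)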
@ -89,9 +89,9 @@ class TestOpenSearchVector:
print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits")

assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
assert (
    hits_by_vector[0].metadata["document_id"] == self.example_doc_id
), f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
assert hits_by_vector[0].metadata["document_id"] == self.example_doc_id, (
    f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
)

def test_get_ids_by_metadata_field(self):
    mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}

@ -438,9 +438,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):

# Verify the result
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
assert (
    prompt_messages == scenario.expected_messages
), f"Message content mismatch in scenario: {scenario.description}"
assert prompt_messages == scenario.expected_messages, (
    f"Message content mismatch in scenario: {scenario.description}"
)


def test_handle_list_messages_basic(llm_node):

@ -401,8 +401,7 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
prompt_template = PromptTemplateEntity(
    prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
    advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
        prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n"
        "Human: hi\nAssistant: ",
        prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ",
        role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
    ),
)