diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 7498072e3e..48c57193b3 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource): 'enabled': v.enabled, 'min': v.min, 'max': v.max, - 'default': v.default + 'default': v.default, + 'precision': v.precision } for k, v in vars(parameter_rules).items() } @@ -290,10 +291,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): @login_required @account_initialization_required def get(self, provider_name: str): + parser = reqparse.RequestParser() + parser.add_argument('token', type=str, required=False, nullable=True, location='args') + args = parser.parse_args() + provider_service = ProviderService() result = provider_service.free_quota_qualification_verify( tenant_id=current_user.current_tenant_id, - provider_name=provider_name + provider_name=provider_name, + token=args['token'] ) return result diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py index b1f3ec393e..b8eb99b2e5 100644 --- a/api/core/callback_handler/llm_callback_handler.py +++ b/api/core/callback_handler/llm_callback_handler.py @@ -63,7 +63,18 @@ class LLMCallbackHandler(BaseCallbackHandler): self.conversation_message_task.append_message_text(response.generations[0][0].text) self.llm_message.completion = response.generations[0][0].text - self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)]) + if response.llm_output and 'token_usage' in response.llm_output: + if 'prompt_tokens' in response.llm_output['token_usage']: + self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] + + if 'completion_tokens' in response.llm_output['token_usage']: + self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens'] + else: + self.llm_message.completion_tokens = self.model_instance.get_num_tokens( + [PromptMessage(content=self.llm_message.completion)]) + else: + self.llm_message.completion_tokens = self.model_instance.get_num_tokens( + [PromptMessage(content=self.llm_message.completion)]) self.conversation_message_task.save_message(self.llm_message) diff --git a/api/core/chain/sensitive_word_avoidance_chain.py b/api/core/chain/sensitive_word_avoidance_chain.py index 5fc20c5cea..62d5854275 100644 --- a/api/core/chain/sensitive_word_avoidance_chain.py +++ b/api/core/chain/sensitive_word_avoidance_chain.py @@ -2,13 +2,8 @@ import enum import logging from typing import List, Dict, Optional, Any -import openai -from flask import current_app from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain -from openai import InvalidRequestError -from openai.error import APIConnectionError, APIError, ServiceUnavailableError, Timeout, RateLimitError, \ - AuthenticationError, OpenAIError from pydantic import BaseModel from core.model_providers.error import LLMBadRequestError @@ -86,6 +81,12 @@ class SensitiveWordAvoidanceChain(Chain): result = self._check_moderation(text) if not result: - raise LLMBadRequestError(self.sensitive_word_avoidance_rule.canned_response) + raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response) return {self.output_key: text} + + +class SensitiveWordAvoidanceError(Exception): + def __init__(self, 
message): + super().__init__(message) + self.message = message diff --git a/api/core/completion.py b/api/core/completion.py index bb2da1e8ec..b66e965b33 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -7,6 +7,7 @@ from requests.exceptions import ChunkedEncodingError from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler from core.callback_handler.llm_callback_handler import LLMCallbackHandler +from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException from core.model_providers.error import LLMBadRequestError from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ @@ -76,28 +77,53 @@ class Completion: app_model_config=app_model_config ) - # parse sensitive_word_avoidance_chain - chain_callback = MainChainGatherCallbackHandler(conversation_message_task) - sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(final_model_instance, [chain_callback]) - if sensitive_word_avoidance_chain: - query = sensitive_word_avoidance_chain.run(query) - - # get agent executor - agent_executor = orchestrator_rule_parser.to_agent_executor( - conversation_message_task=conversation_message_task, - memory=memory, - rest_tokens=rest_tokens_for_context_and_memory, - chain_callback=chain_callback - ) - - # run agent executor - agent_execute_result = None - if agent_executor: - should_use_agent = agent_executor.should_use_agent(query) - if should_use_agent: - agent_execute_result = agent_executor.run(query) - # run the final llm try: + # parse sensitive_word_avoidance_chain + chain_callback = MainChainGatherCallbackHandler(conversation_message_task) + sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain( + final_model_instance, [chain_callback]) + if sensitive_word_avoidance_chain: + try: + query = sensitive_word_avoidance_chain.run(query) + except SensitiveWordAvoidanceError as ex: + cls.run_final_llm( + model_instance=final_model_instance, + mode=app.mode, + app_model_config=app_model_config, + query=query, + inputs=inputs, + agent_execute_result=None, + conversation_message_task=conversation_message_task, + memory=memory, + fake_response=ex.message + ) + return + + # get agent executor + agent_executor = orchestrator_rule_parser.to_agent_executor( + conversation_message_task=conversation_message_task, + memory=memory, + rest_tokens=rest_tokens_for_context_and_memory, + chain_callback=chain_callback, + retriever_from=retriever_from + ) + + # run agent executor + agent_execute_result = None + if agent_executor: + should_use_agent = agent_executor.should_use_agent(query) + if should_use_agent: + agent_execute_result = agent_executor.run(query) + + # When no extra pre prompt is specified, + # the output of the agent can be used directly as the main output content without calling LLM again + fake_response = None + if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ + and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, + PlanningStrategy.REACT_ROUTER]: + fake_response = agent_execute_result.output + + # run the final llm cls.run_final_llm( model_instance=final_model_instance, mode=app.mode, @@ -106,7 +132,8 @@ class Completion: inputs=inputs, agent_execute_result=agent_execute_result, 
conversation_message_task=conversation_message_task, - memory=memory + memory=memory, + fake_response=fake_response ) except ConversationTaskStoppedException: return @@ -121,14 +148,8 @@ class Completion: inputs: dict, agent_execute_result: Optional[AgentExecuteResult], conversation_message_task: ConversationMessageTask, - memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]): - # When no extra pre prompt is specified, - # the output of the agent can be used directly as the main output content without calling LLM again - fake_response = None - if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ - and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]: - fake_response = agent_execute_result.output - + memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], + fake_response: Optional[str]): # get llm prompt prompt_messages, stop_words = model_instance.get_prompt( mode=mode, diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index e567e9ed22..63d968774e 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -1,32 +1,34 @@ import logging import openai -from flask import current_app from core.model_providers.error import LLMBadRequestError from core.model_providers.providers.base import BaseModelProvider +from core.model_providers.providers.hosted import hosted_config, hosted_model_providers from models.provider import ProviderType def check_moderation(model_provider: BaseModelProvider, text: str) -> bool: - if current_app.config['HOSTED_MODERATION_ENABLED'] and current_app.config['HOSTED_MODERATION_PROVIDERS']: - moderation_providers = current_app.config['HOSTED_MODERATION_PROVIDERS'].split(',') - + if hosted_config.moderation.enabled is True and hosted_model_providers.openai: if model_provider.provider.provider_type == ProviderType.SYSTEM.value \ - and model_provider.provider_name in moderation_providers: + and model_provider.provider_name in hosted_config.moderation.providers: # 2000 text per chunk length = 2000 - chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i:i + length] for i in range(0, len(text), length)] - try: - moderation_result = openai.Moderation.create(input=chunks, - api_key=current_app.config['HOSTED_OPENAI_API_KEY']) - except Exception as ex: - logging.exception(ex) - raise LLMBadRequestError('Rate limit exceeded, please try again later.') + max_text_chunks = 32 + chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] - for result in moderation_result.results: - if result['flagged'] is True: - return False + for text_chunk in chunks: + try: + moderation_result = openai.Moderation.create(input=text_chunk, + api_key=hosted_model_providers.openai.api_key) + except Exception as ex: + logging.exception(ex) + raise LLMBadRequestError('Rate limit exceeded, please try again later.') + + for result in moderation_result.results: + if result['flagged'] is True: + return False return True diff --git a/api/core/model_providers/model_provider_factory.py b/api/core/model_providers/model_provider_factory.py index 45989d430f..0517676807 100644 --- a/api/core/model_providers/model_provider_factory.py +++ b/api/core/model_providers/model_provider_factory.py @@ -45,6 +45,9 @@ class ModelProviderFactory: elif provider_name == 'wenxin': from core.model_providers.providers.wenxin_provider import WenxinProvider return WenxinProvider + elif provider_name == 
'zhipuai': + from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider + return ZhipuAIProvider elif provider_name == 'chatglm': from core.model_providers.providers.chatglm_provider import ChatGLMProvider return ChatGLMProvider diff --git a/api/core/model_providers/models/embedding/zhipuai_embedding.py b/api/core/model_providers/models/embedding/zhipuai_embedding.py new file mode 100644 index 0000000000..97d5056c37 --- /dev/null +++ b/api/core/model_providers/models/embedding/zhipuai_embedding.py @@ -0,0 +1,22 @@ +from core.model_providers.error import LLMBadRequestError +from core.model_providers.providers.base import BaseModelProvider +from core.model_providers.models.embedding.base import BaseEmbedding +from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings + + +class ZhipuAIEmbedding(BaseEmbedding): + def __init__(self, model_provider: BaseModelProvider, name: str): + credentials = model_provider.get_model_credentials( + model_name=name, + model_type=self.type + ) + + client = ZhipuAIEmbeddings( + model=name, + **credentials, + ) + + super().__init__(model_provider, client, name) + + def handle_exceptions(self, ex: Exception) -> Exception: + return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}") diff --git a/api/core/model_providers/models/entity/model_params.py b/api/core/model_providers/models/entity/model_params.py index 2a6a1bc510..225a5cc674 100644 --- a/api/core/model_providers/models/entity/model_params.py +++ b/api/core/model_providers/models/entity/model_params.py @@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel): max: Optional[T] = None default: Optional[T] = None alias: Optional[str] = None + precision: Optional[int] = None class ModelKwargsRules(BaseModel): diff --git a/api/core/model_providers/models/llm/zhipuai_model.py b/api/core/model_providers/models/llm/zhipuai_model.py new file mode 100644 index 0000000000..7f32c1dc70 --- /dev/null +++ b/api/core/model_providers/models/llm/zhipuai_model.py @@ -0,0 +1,61 @@ +from typing import List, Optional, Any + +from langchain.callbacks.manager import Callbacks +from langchain.schema import LLMResult + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs +from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM + + +class ZhipuAIModel(BaseLLM): + model_mode: ModelMode = ModelMode.CHAT + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + return ZhipuAIChatLLM( + streaming=self.streaming, + callbacks=self.callbacks, + **self.credentials, + **provider_model_kwargs + ) + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. 
+ + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return max(self._client.get_num_tokens_from_messages(prompts), 0) + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + + def handle_exceptions(self, ex: Exception) -> Exception: + return LLMBadRequestError(f"ZhipuAI: {str(ex)}") + + @property + def support_streaming(self): + return True diff --git a/api/core/model_providers/models/moderation/openai_moderation.py b/api/core/model_providers/models/moderation/openai_moderation.py index 9aeb6f0292..f3dedc542a 100644 --- a/api/core/model_providers/models/moderation/openai_moderation.py +++ b/api/core/model_providers/models/moderation/openai_moderation.py @@ -23,14 +23,18 @@ class OpenAIModeration(BaseModeration): # 2000 text per chunk length = 2000 - chunks = [text[i:i + length] for i in range(0, len(text), length)] + text_chunks = [text[i:i + length] for i in range(0, len(text), length)] - moderation_result = self._client.create(input=chunks, - api_key=credentials['openai_api_key']) + max_text_chunks = 32 + chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)] - for result in moderation_result.results: - if result['flagged'] is True: - return False + for text_chunk in chunks: + moderation_result = self._client.create(input=text_chunk, + api_key=credentials['openai_api_key']) + + for result in moderation_result.results: + if result['flagged'] is True: + return False return True diff --git a/api/core/model_providers/providers/anthropic_provider.py b/api/core/model_providers/providers/anthropic_provider.py index e1ef06d140..35532b0ec4 100644 --- a/api/core/model_providers/providers/anthropic_provider.py +++ b/api/core/model_providers/providers/anthropic_provider.py @@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider): :return: """ return ModelKwargsRules( - temperature=KwargRule[float](min=0, max=1, default=1), - top_p=KwargRule[float](min=0, max=1, default=0.7), + temperature=KwargRule[float](min=0, max=1, default=1, precision=2), + top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256), + max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/azure_openai_provider.py b/api/core/model_providers/providers/azure_openai_provider.py index 2abedd014c..4f7c8b717c 100644 --- a/api/core/model_providers/providers/azure_openai_provider.py +++ b/api/core/model_providers/providers/azure_openai_provider.py @@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider): model_credentials = self.get_model_credentials(model_name, model_type) return ModelKwargsRules( - temperature=KwargRule[float](min=0, max=2, default=1), - top_p=KwargRule[float](min=0, max=1, default=1), - presence_penalty=KwargRule[float](min=-2, max=2, default=0), - frequency_penalty=KwargRule[float](min=-2, max=2, default=0), + temperature=KwargRule[float](min=0, max=2, default=1, precision=2), + top_p=KwargRule[float](min=0, max=1, default=1, precision=2), + presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), + 
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get( model_credentials['base_model_name'], 4097 - ), default=16), + ), default=16, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/chatglm_provider.py b/api/core/model_providers/providers/chatglm_provider.py index f905da6f23..4b2a46ad42 100644 --- a/api/core/model_providers/providers/chatglm_provider.py +++ b/api/core/model_providers/providers/chatglm_provider.py @@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider): } return ModelKwargsRules( - temperature=KwargRule[float](min=0, max=2, default=1), - top_p=KwargRule[float](min=0, max=1, default=0.7), + temperature=KwargRule[float](min=0, max=2, default=1, precision=2), + top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048), + max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/hosted.py b/api/core/model_providers/providers/hosted.py index a5f1ce83b6..d2dc39b73f 100644 --- a/api/core/model_providers/providers/hosted.py +++ b/api/core/model_providers/providers/hosted.py @@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel): hosted_model_providers = HostedModelProviders() +class HostedModerationConfig(BaseModel): + enabled: bool = False + providers: list[str] = [] + + +class HostedConfig(BaseModel): + moderation = HostedModerationConfig() + + +hosted_config = HostedConfig() + + def init_app(app: Flask): if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': langchain.verbose = True @@ -78,3 +90,9 @@ def init_app(app: Flask): paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"), paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"), ) + + if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"): + hosted_config.moderation = HostedModerationConfig( + enabled=app.config.get("HOSTED_MODERATION_ENABLED"), + providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',') + ) diff --git a/api/core/model_providers/providers/huggingface_hub_provider.py b/api/core/model_providers/providers/huggingface_hub_provider.py index f033a28963..75fffac722 100644 --- a/api/core/model_providers/providers/huggingface_hub_provider.py +++ b/api/core/model_providers/providers/huggingface_hub_provider.py @@ -47,11 +47,11 @@ class HuggingfaceHubProvider(BaseModelProvider): :return: """ return ModelKwargsRules( - temperature=KwargRule[float](min=0, max=2, default=1), - top_p=KwargRule[float](min=0.01, max=0.99, default=0.7), + temperature=KwargRule[float](min=0, max=2, default=1, precision=2), + top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200), + max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/localai_provider.py b/api/core/model_providers/providers/localai_provider.py index 36cceff0ec..f5b07b1e6c 100644 --- 
a/api/core/model_providers/providers/localai_provider.py +++ b/api/core/model_providers/providers/localai_provider.py @@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider): :return: """ return ModelKwargsRules( - temperature=KwargRule[float](min=0, max=2, default=0.7), - top_p=KwargRule[float](min=0, max=1, default=1), - max_tokens=KwargRule[int](min=10, max=4097, default=16), + temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2), + top_p=KwargRule[float](min=0, max=1, default=1, precision=2), + max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/minimax_provider.py b/api/core/model_providers/providers/minimax_provider.py index 46ec84a6d8..488e6438b4 100644 --- a/api/core/model_providers/providers/minimax_provider.py +++ b/api/core/model_providers/providers/minimax_provider.py @@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider): } return ModelKwargsRules( - temperature=KwargRule[float](min=0.01, max=1, default=0.9), - top_p=KwargRule[float](min=0, max=1, default=0.95), + temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2), + top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024), + max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/openai_provider.py b/api/core/model_providers/providers/openai_provider.py index 0041d23ca6..1ac6af4429 100644 --- a/api/core/model_providers/providers/openai_provider.py +++ b/api/core/model_providers/providers/openai_provider.py @@ -133,11 +133,11 @@ class OpenAIProvider(BaseModelProvider): } return ModelKwargsRules( - temperature=KwargRule[float](min=0, max=2, default=1), - top_p=KwargRule[float](min=0, max=1, default=1), - presence_penalty=KwargRule[float](min=-2, max=2, default=0), - frequency_penalty=KwargRule[float](min=-2, max=2, default=0), - max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16), + temperature=KwargRule[float](min=0, max=2, default=1, precision=2), + top_p=KwargRule[float](min=0, max=1, default=1, precision=2), + presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), + frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), + max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/openllm_provider.py b/api/core/model_providers/providers/openllm_provider.py index 5abb5efa63..f1274a8082 100644 --- a/api/core/model_providers/providers/openllm_provider.py +++ b/api/core/model_providers/providers/openllm_provider.py @@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider): :return: """ return ModelKwargsRules( - temperature=KwargRule[float](min=0.01, max=2, default=1), - top_p=KwargRule[float](min=0, max=1, default=0.7), - presence_penalty=KwargRule[float](min=-2, max=2, default=0), - frequency_penalty=KwargRule[float](min=-2, max=2, default=0), - max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128), + temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2), + top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), + presence_penalty=KwargRule[float](min=-2, 
max=2, default=0, precision=2), + frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), + max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/replicate_provider.py b/api/core/model_providers/providers/replicate_provider.py index a5c62e77f1..9324d432a4 100644 --- a/api/core/model_providers/providers/replicate_provider.py +++ b/api/core/model_providers/providers/replicate_provider.py @@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider): min=float(value.get('minimum')) if value.get('minimum') is not None else None, max=float(value.get('maximum')) if value.get('maximum') is not None else None, default=float(value.get('default')) if value.get('default') is not None else None, + precision = 2 ) if key == 'temperature': model_kwargs_rules.temperature = kwarg_rule @@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider): min=int(value.get('minimum')) if value.get('minimum') is not None else 1, max=int(value.get('maximum')) if value.get('maximum') is not None else 8000, default=int(value.get('default')) if value.get('default') is not None else 500, + precision = 0 ) return model_kwargs_rules diff --git a/api/core/model_providers/providers/spark_provider.py b/api/core/model_providers/providers/spark_provider.py index 9a3e3643a0..89ed5d30b7 100644 --- a/api/core/model_providers/providers/spark_provider.py +++ b/api/core/model_providers/providers/spark_provider.py @@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider): :return: """ return ModelKwargsRules( - temperature=KwargRule[float](min=0, max=1, default=0.5), + temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2), top_p=KwargRule[float](enabled=False), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](min=10, max=4096, default=2048), + max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/tongyi_provider.py b/api/core/model_providers/providers/tongyi_provider.py index ffa7c72db4..d3074b885c 100644 --- a/api/core/model_providers/providers/tongyi_provider.py +++ b/api/core/model_providers/providers/tongyi_provider.py @@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider): return ModelKwargsRules( temperature=KwargRule[float](enabled=False), - top_p=KwargRule[float](min=0, max=1, default=0.8), + top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024), + max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0), ) @classmethod diff --git a/api/core/model_providers/providers/wenxin_provider.py b/api/core/model_providers/providers/wenxin_provider.py index 1c62b72d95..0def5f15b0 100644 --- a/api/core/model_providers/providers/wenxin_provider.py +++ b/api/core/model_providers/providers/wenxin_provider.py @@ -63,8 +63,8 @@ class WenxinProvider(BaseModelProvider): """ if model_name in ['ernie-bot', 'ernie-bot-turbo']: return ModelKwargsRules( - temperature=KwargRule[float](min=0.01, max=1, default=0.95), - top_p=KwargRule[float](min=0.01, max=1, default=0.8), + temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2), + top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2), 
presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), max_tokens=KwargRule[int](enabled=False), diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py index 7c43804c7f..f56c5fb59d 100644 --- a/api/core/model_providers/providers/xinference_provider.py +++ b/api/core/model_providers/providers/xinference_provider.py @@ -53,27 +53,27 @@ class XinferenceProvider(BaseModelProvider): credentials = self.get_model_credentials(model_name, model_type) if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm": return ModelKwargsRules( - temperature=KwargRule[float](min=0.01, max=2, default=1), - top_p=KwargRule[float](min=0, max=1, default=0.7), + temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2), + top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](min=10, max=4000, default=256), + max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0), ) elif credentials['model_format'] == "ggmlv3": return ModelKwargsRules( - temperature=KwargRule[float](min=0.01, max=2, default=1), - top_p=KwargRule[float](min=0, max=1, default=0.7), - presence_penalty=KwargRule[float](min=-2, max=2, default=0), - frequency_penalty=KwargRule[float](min=-2, max=2, default=0), - max_tokens=KwargRule[int](min=10, max=4000, default=256), + temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2), + top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), + presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), + frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2), + max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0), ) else: return ModelKwargsRules( - temperature=KwargRule[float](min=0.01, max=2, default=1), - top_p=KwargRule[float](min=0, max=1, default=0.7), + temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2), + top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](min=10, max=4000, default=256), + max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0), ) diff --git a/api/core/model_providers/providers/zhipuai_provider.py b/api/core/model_providers/providers/zhipuai_provider.py new file mode 100644 index 0000000000..0f7dae5f4f --- /dev/null +++ b/api/core/model_providers/providers/zhipuai_provider.py @@ -0,0 +1,176 @@ +import json +from json import JSONDecodeError +from typing import Type + +from langchain.schema import HumanMessage + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM +from models.provider import ProviderType, ProviderQuotaType + + +class ZhipuAIProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. 
+ """ + return 'zhipuai' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + return [ + { + 'id': 'chatglm_pro', + 'name': 'chatglm_pro', + }, + { + 'id': 'chatglm_std', + 'name': 'chatglm_std', + }, + { + 'id': 'chatglm_lite', + 'name': 'chatglm_lite', + }, + { + 'id': 'chatglm_lite_32k', + 'name': 'chatglm_lite_32k', + } + ] + elif model_type == ModelType.EMBEDDINGS: + return [ + { + 'id': 'text_embedding', + 'name': 'text_embedding', + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = ZhipuAIModel + elif model_type == ModelType.EMBEDDINGS: + model_class = ZhipuAIEmbedding + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + return ModelKwargsRules( + temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2), + top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](enabled=False), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. + """ + if 'api_key' not in credentials: + raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.') + + try: + credential_kwargs = { + 'api_key': credentials['api_key'] + } + + llm = ZhipuAIChatLLM( + temperature=0.01, + **credential_kwargs + ) + + llm([HumanMessage(content='ping')]) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value \ + or (self.provider.provider_type == ProviderType.SYSTEM.value + and self.provider.quota_type == ProviderQuotaType.FREE.value): + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'api_key': None, + } + + if credentials['api_key']: + credentials['api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['api_key'] + ) + + if obfuscated: + credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key']) + + return credentials + else: + return {} + + def should_deduct_quota(self): + return True + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + return + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. 
+ + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + return {} + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. + + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + return self.get_provider_credentials(obfuscated) diff --git a/api/core/model_providers/rules/_providers.json b/api/core/model_providers/rules/_providers.json index 3bdf890aa6..f004dc5261 100644 --- a/api/core/model_providers/rules/_providers.json +++ b/api/core/model_providers/rules/_providers.json @@ -6,6 +6,7 @@ "tongyi", "spark", "wenxin", + "zhipuai", "chatglm", "replicate", "huggingface_hub", diff --git a/api/core/model_providers/rules/zhipuai.json b/api/core/model_providers/rules/zhipuai.json new file mode 100644 index 0000000000..4a258a8a6a --- /dev/null +++ b/api/core/model_providers/rules/zhipuai.json @@ -0,0 +1,44 @@ +{ + "support_provider_types": [ + "system", + "custom" + ], + "system_config": { + "supported_quota_types": [ + "free" + ], + "quota_unit": "tokens" + }, + "model_flexibility": "fixed", + "price_config": { + "chatglm_pro": { + "prompt": "0.01", + "completion": "0.01", + "unit": "0.001", + "currency": "RMB" + }, + "chatglm_std": { + "prompt": "0.005", + "completion": "0.005", + "unit": "0.001", + "currency": "RMB" + }, + "chatglm_lite": { + "prompt": "0.002", + "completion": "0.002", + "unit": "0.001", + "currency": "RMB" + }, + "chatglm_lite_32k": { + "prompt": "0.0004", + "completion": "0.0004", + "unit": "0.001", + "currency": "RMB" + }, + "text_embedding": { + "completion": "0", + "unit": "0.001", + "currency": "RMB" + } + } +} \ No newline at end of file diff --git a/api/core/third_party/langchain/embeddings/zhipuai_embedding.py b/api/core/third_party/langchain/embeddings/zhipuai_embedding.py new file mode 100644 index 0000000000..f4ec88fe6b --- /dev/null +++ b/api/core/third_party/langchain/embeddings/zhipuai_embedding.py @@ -0,0 +1,64 @@ +"""Wrapper around ZhipuAI embedding models.""" +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Extra, root_validator + +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env + +from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI + + +class ZhipuAIEmbeddings(BaseModel, Embeddings): + """Wrapper around ZhipuAI embedding models. + 1024 dimensions. + """ + + client: Any #: :meta private: + model: str + """Model name to use.""" + + base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" + api_key: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["api_key"] = get_from_dict_or_env( + values, "api_key", "ZHIPUAI_API_KEY" + ) + values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url']) + return values + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call out to ZhipuAI's embedding endpoint. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. 
+ """ + embeddings = [] + for text in texts: + response = self.client.invoke(model=self.model, prompt=text) + data = response["data"] + embeddings.append(data.get('embedding')) + + return [list(map(float, e)) for e in embeddings] + + def embed_query(self, text: str) -> List[float]: + """Call out to ZhipuAI's embedding endpoint. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self.embed_documents([text])[0] diff --git a/api/core/third_party/langchain/llms/zhipuai_llm.py b/api/core/third_party/langchain/llms/zhipuai_llm.py new file mode 100644 index 0000000000..06016f8c1c --- /dev/null +++ b/api/core/third_party/langchain/llms/zhipuai_llm.py @@ -0,0 +1,315 @@ +"""Wrapper around ZhipuAI APIs.""" +from __future__ import annotations + +import json +import logging +import posixpath +from typing import ( + Any, + Dict, + List, + Optional, Iterator, Sequence, +) + +import zhipuai +from langchain.chat_models.base import BaseChatModel +from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage +from langchain.schema.messages import AIMessageChunk +from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration +from pydantic import Extra, root_validator, BaseModel + +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) +from langchain.utils import get_from_dict_or_env +from zhipuai.model_api.api import InvokeType +from zhipuai.utils import jwt_token +from zhipuai.utils.http_client import post, stream +from zhipuai.utils.sse_client import SSEClient + +logger = logging.getLogger(__name__) + + +class ZhipuModelAPI(BaseModel): + base_url: str + api_key: str + api_timeout_seconds = 60 + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def invoke(self, **kwargs): + url = self._build_api_url(kwargs, InvokeType.SYNC) + response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds) + if not response['success']: + raise ValueError( + f"Error Code: {response['code']}, Message: {response['msg']}" + ) + return response + + def sse_invoke(self, **kwargs): + url = self._build_api_url(kwargs, InvokeType.SSE) + data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds) + return SSEClient(data) + + def _build_api_url(self, kwargs, *path): + if kwargs: + if "model" not in kwargs: + raise Exception("model param is missing") + model = kwargs.pop("model") + else: + model = "-" + + return posixpath.join(self.base_url, model, *path) + + def _generate_token(self): + if not self.api_key: + raise Exception( + "api_key not provided, please provide it." + ) + + try: + return jwt_token.generate_token(self.api_key) + except Exception: + raise ValueError( + "Your api_key is invalid, please check it." + ) + + +class ZhipuAIChatLLM(BaseChatModel): + """Wrapper around ZhipuAI large language models. + To use, you should pass the api_key as a named parameter to the constructor. + Example: + ..
code-block:: python + from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM + model = ZhipuAIChatLLM(model="chatglm_lite", api_key="my-api-key") + """ + + @property + def lc_secrets(self) -> Dict[str, str]: + return {"api_key": "API_KEY"} + + @property + def lc_serializable(self) -> bool: + return True + + client: Any = None #: :meta private: + model: str = "chatglm_lite" + """Model name to use.""" + temperature: float = 0.95 + """A non-negative float that tunes the degree of randomness in generation.""" + top_p: float = 0.7 + """Total probability mass of tokens to consider at each step.""" + streaming: bool = False + """Whether to stream the response or return it all at once.""" + api_key: Optional[str] = None + + base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exist in environment.""" + values["api_key"] = get_from_dict_or_env( + values, "api_key", "ZHIPUAI_API_KEY" + ) + + if 'test' in values['base_url']: + values['model'] = 'chatglm_130b_test' + + values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url']) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling the ZhipuAI API.""" + return { + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p + } + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + return self._default_params + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "zhipuai" + + def _convert_message_to_dict(self, message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "user", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + return message_dict + + def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + return AIMessage(content=_dict["content"]) + elif role == "system": + return SystemMessage(content=_dict["content"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + def _create_message_dicts( + self, messages: List[BaseMessage] + ) -> List[Dict[str, Any]]: + dict_messages = [] + for m in messages: + message = self._convert_message_to_dict(m) + if dict_messages: + previous_message = dict_messages[-1] + if previous_message['role'] == message['role']: + dict_messages[-1]['content'] += f"\n{message['content']}" + else: + dict_messages.append(message) + else: + dict_messages.append(message) + + return dict_messages + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + if self.streaming: + generation: Optional[ChatGenerationChunk] = None + llm_output: Optional[Dict] = None + for chunk in self._stream( + messages=messages, stop=stop, run_manager=run_manager, **kwargs + ): + if
chunk.generation_info is not None \ + and 'token_usage' in chunk.generation_info: + llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model} + continue + + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation], llm_output=llm_output) + else: + message_dicts = self._create_message_dicts(messages) + request = self._default_params + request["prompt"] = message_dicts + request.update(kwargs) + response = self.client.invoke(**request) + return self._create_chat_result(response) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + message_dicts = self._create_message_dicts(messages) + request = self._default_params + request["prompt"] = message_dicts + request.update(kwargs) + + for event in self.client.sse_invoke(incremental=True, **request).events(): + if event.event == "add": + yield ChatGenerationChunk(message=AIMessageChunk(content=event.data)) + if run_manager: + run_manager.on_llm_new_token(event.data) + elif event.event == "error" or event.event == "interrupted": + raise ValueError( + f"{event.data}" + ) + elif event.event == "finish": + meta = json.loads(event.meta) + token_usage = meta['usage'] + if token_usage is not None: + if 'prompt_tokens' not in token_usage: + token_usage['prompt_tokens'] = 0 + if 'completion_tokens' not in token_usage: + token_usage['completion_tokens'] = token_usage['total_tokens'] + + yield ChatGenerationChunk( + message=AIMessageChunk(content=event.data), + generation_info=dict({'token_usage': token_usage}) + ) + + def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult: + data = response["data"] + generations = [] + for res in data["choices"]: + message = self._convert_dict_to_message(res) + gen = ChatGeneration( + message=message + ) + generations.append(gen) + token_usage = data.get("usage") + if token_usage is not None: + if 'prompt_tokens' not in token_usage: + token_usage['prompt_tokens'] = 0 + if 'completion_tokens' not in token_usage: + token_usage['completion_tokens'] = token_usage['total_tokens'] + + llm_output = {"token_usage": token_usage, "model_name": self.model} + return ChatResult(generations=generations, llm_output=llm_output) + + # def get_token_ids(self, text: str) -> List[int]: + # """Return the ordered ids of the tokens in a text. + # + # Args: + # text: The string input to tokenize. + # + # Returns: + # A list of ids corresponding to the tokens in the text, in order they occur + # in the text. + # """ + # from core.third_party.transformers.Token import ChatGLMTokenizer + # + # tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b") + # return tokenizer.encode(text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Get the number of tokens in the messages. + + Useful for checking if an input will fit in a model's context window. + + Args: + messages: The message inputs to tokenize. + + Returns: + The sum of the number of tokens across the messages. 
+ """ + return sum([self.get_num_tokens(m.content) for m in messages]) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage, "model_name": self.model} diff --git a/api/requirements.txt b/api/requirements.txt index 26bdce61da..f2db4e14a8 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -50,4 +50,5 @@ transformers~=4.31.0 stripe~=5.5.0 pandas==1.5.3 xinference==0.4.2 -safetensors==0.3.2 \ No newline at end of file +safetensors==0.3.2 +zhipuai==1.0.7 diff --git a/api/services/provider_service.py b/api/services/provider_service.py index a0eed4d273..34064d0c33 100644 --- a/api/services/provider_service.py +++ b/api/services/provider_service.py @@ -548,7 +548,7 @@ class ProviderService: 'result': 'success' } - def free_quota_qualification_verify(self, tenant_id: str, provider_name: str): + def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") api_url = api_base_url + '/api/v1/providers/qualification-verify' @@ -557,8 +557,11 @@ class ProviderService: 'Content-Type': 'application/json', 'Authorization': f"Bearer {api_key}" } + json_data = {'workspace_id': tenant_id, 'provider_name': provider_name} + if token: + json_data['token'] = token response = requests.post(api_url, headers=headers, - json={'workspace_id': tenant_id, 'provider_name': provider_name}) + json=json_data) if not response.ok: logging.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 26dd38f472..3400cfaddb 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -31,6 +31,9 @@ TONGYI_DASHSCOPE_API_KEY= WENXIN_API_KEY= WENXIN_SECRET_KEY= +# ZhipuAI Credentials +ZHIPUAI_API_KEY= + # ChatGLM Credentials CHATGLM_API_BASE= diff --git a/api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py b/api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py new file mode 100644 index 0000000000..be2898402c --- /dev/null +++ b/api/tests/integration_tests/models/embedding/test_zhipuai_embedding.py @@ -0,0 +1,50 @@ +import json +import os +from unittest.mock import patch + +from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding +from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='zhipuai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({ + 'api_key': valid_api_key + }), + is_valid=True, + ) + + +def get_mock_embedding_model(): + model_name = 'text_embedding' + valid_api_key = os.environ['ZHIPUAI_API_KEY'] + provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key)) + return ZhipuAIEmbedding( + model_provider=provider, + name=model_name + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return 
encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_embedding(mock_decrypt): + embedding_model = get_mock_embedding_model() + rst = embedding_model.client.embed_query('test') + assert isinstance(rst, list) + assert len(rst) == 1024 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_doc_embedding(mock_decrypt): + embedding_model = get_mock_embedding_model() + rst = embedding_model.client.embed_documents(['test', 'test2']) + assert isinstance(rst, list) + assert len(rst[0]) == 1024 diff --git a/api/tests/integration_tests/models/llm/test_zhipuai_model.py b/api/tests/integration_tests/models/llm/test_zhipuai_model.py new file mode 100644 index 0000000000..4bc47bec9b --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_zhipuai_model.py @@ -0,0 +1,79 @@ +import json +import os +from unittest.mock import patch + + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs +from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel +from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='zhipuai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({ + 'api_key': valid_api_key + }), + is_valid=True, + ) + + +def get_mock_model(model_name: str, streaming: bool = False): + model_kwargs = ModelKwargs( + temperature=0.01, + ) + valid_api_key = os.environ['ZHIPUAI_API_KEY'] + model_provider = ZhipuAIProvider(provider=get_mock_provider(valid_api_key)) + return ZhipuAIModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs, + streaming=streaming + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_chat_get_num_tokens(mock_decrypt): + model = get_mock_model('chatglm_lite') + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst > 0 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_chat_run(mock_decrypt, mocker): + mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) + + model = get_mock_model('chatglm_lite') + messages = [ + PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') + ] + rst = model.run( + messages, + ) + assert len(rst.content) > 0 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_chat_stream_run(mock_decrypt, mocker): + mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) + + model = get_mock_model('chatglm_lite', streaming=True) + messages = [ + PromptMessage(type=MessageType.HUMAN, content='Are you Human? 
you MUST only answer `y` or `n`?') + ] + rst = model.run( + messages + ) + assert len(rst.content) > 0 diff --git a/api/tests/unit_tests/model_providers/test_spark_provider.py b/api/tests/unit_tests/model_providers/test_spark_provider.py index 7193221f1d..c9e9c58321 100644 --- a/api/tests/unit_tests/model_providers/test_spark_provider.py +++ b/api/tests/unit_tests/model_providers/test_spark_provider.py @@ -39,7 +39,7 @@ def test_is_provider_credentials_valid_or_raise_invalid(): MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) credential = VALIDATE_CREDENTIAL.copy() - credential['api_key'] = 'invalid_key' + del credential['api_key'] # raise CredentialsValidateFailedError if api_key is invalid with pytest.raises(CredentialsValidateFailedError): diff --git a/api/tests/unit_tests/model_providers/test_zhipuai_provider.py b/api/tests/unit_tests/model_providers/test_zhipuai_provider.py new file mode 100644 index 0000000000..7f9a43a3d7 --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_zhipuai_provider.py @@ -0,0 +1,88 @@ +import pytest +from unittest.mock import patch +import json + +from langchain.schema import ChatResult, ChatGeneration, AIMessage + +from core.model_providers.providers.base import CredentialsValidateFailedError +from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider +from models.provider import ProviderType, Provider + + +PROVIDER_NAME = 'zhipuai' +MODEL_PROVIDER_CLASS = ZhipuAIProvider +VALIDATE_CREDENTIAL = { + 'api_key': 'valid_key', +} + + +def encrypt_side_effect(tenant_id, encrypt_key): + return f'encrypted_{encrypt_key}' + + +def decrypt_side_effect(tenant_id, encrypted_key): + return encrypted_key.replace('encrypted_', '') + + +def test_is_provider_credentials_valid_or_raise_valid(mocker): + mocker.patch('core.third_party.langchain.llms.zhipuai_llm.ZhipuAIChatLLM._generate', + return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])) + + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL) + + +def test_is_provider_credentials_valid_or_raise_invalid(): + # raise CredentialsValidateFailedError if api_key is not in credentials + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) + + credential = VALIDATE_CREDENTIAL.copy() + credential['api_key'] = 'invalid_key' + + # raise CredentialsValidateFailedError if api_key is invalid + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential) + + +@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) +def test_encrypt_credentials(mock_encrypt): + result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy()) + assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_credentials_custom(mock_decrypt): + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key'] + + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps(encrypted_credential), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials() + assert result['api_key'] == 'valid_key' + + 
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_credentials_obfuscated(mock_decrypt): + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key'] + + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps(encrypted_credential), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials(obfuscated=True) + middle_token = result['api_key'][6:-2] + assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0) + assert all(char == '*' for char in middle_token)