From 2c30d19cbe19f9491f16098b5d91f6d818e411bb Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 24 Aug 2023 10:22:36 +0800 Subject: [PATCH] feat: add baichuan prompt (#985) --- api/core/completion.py | 130 ++-------------- api/core/model_providers/models/llm/base.py | 142 ++++++++++++++++-- .../models/llm/huggingface_hub_model.py | 9 ++ .../models/llm/openllm_model.py | 9 ++ .../models/llm/xinference_model.py | 9 ++ .../generate_prompts/baichuan_chat.json | 13 ++ .../generate_prompts/baichuan_completion.json | 9 ++ .../prompt/generate_prompts/common_chat.json | 13 ++ .../generate_prompts/common_completion.json | 9 ++ 9 files changed, 213 insertions(+), 130 deletions(-) create mode 100644 api/core/prompt/generate_prompts/baichuan_chat.json create mode 100644 api/core/prompt/generate_prompts/baichuan_completion.json create mode 100644 api/core/prompt/generate_prompts/common_chat.json create mode 100644 api/core/prompt/generate_prompts/common_completion.json diff --git a/api/core/completion.py b/api/core/completion.py index 4614f8084d..192b16f4f9 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -130,13 +130,12 @@ class Completion: fake_response = agent_execute_result.output # get llm prompt - prompt_messages, stop_words = cls.get_main_llm_prompt( + prompt_messages, stop_words = model_instance.get_prompt( mode=mode, - model=app_model_config.model_dict, pre_prompt=app_model_config.pre_prompt, - query=query, inputs=inputs, - agent_execute_result=agent_execute_result, + query=query, + context=agent_execute_result.output if agent_execute_result else None, memory=memory ) @@ -154,113 +153,6 @@ class Completion: return response - @classmethod - def get_main_llm_prompt(cls, mode: str, model: dict, - pre_prompt: str, query: str, inputs: dict, - agent_execute_result: Optional[AgentExecuteResult], - memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \ - Tuple[List[PromptMessage], Optional[List[str]]]: - if mode == 'completion': - prompt_template = JinjaPromptTemplate.from_template( - template=("""Use the following context as your learned knowledge, inside XML tags. - - -{{context}} - - -When answer to user: -- If you don't know, just say that you don't know. -- If you don't know when you are not sure, ask for clarification. -Avoid mentioning that you obtained the information from the context. -And answer according to the language of the user's question. -""" if agent_execute_result else "") - + (pre_prompt + "\n" if pre_prompt else "") - + "{{query}}\n" - ) - - if agent_execute_result: - inputs['context'] = agent_execute_result.output - - prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs} - prompt_content = prompt_template.format( - query=query, - **prompt_inputs - ) - - return [PromptMessage(content=prompt_content)], None - else: - messages: List[BaseMessage] = [] - - human_inputs = { - "query": query - } - - human_message_prompt = "" - - if pre_prompt: - pre_prompt_inputs = {k: inputs[k] for k in - JinjaPromptTemplate.from_template(template=pre_prompt).input_variables - if k in inputs} - - if pre_prompt_inputs: - human_inputs.update(pre_prompt_inputs) - - if agent_execute_result: - human_inputs['context'] = agent_execute_result.output - human_message_prompt += """Use the following context as your learned knowledge, inside XML tags. - - -{{context}} - - -When answer to user: -- If you don't know, just say that you don't know. -- If you don't know when you are not sure, ask for clarification. 
-Avoid mentioning that you obtained the information from the context. -And answer according to the language of the user's question. -""" - - if pre_prompt: - human_message_prompt += pre_prompt - - query_prompt = "\n\nHuman: {{query}}\n\nAssistant: " - - if memory: - # append chat histories - tmp_human_message = PromptBuilder.to_human_message( - prompt_content=human_message_prompt + query_prompt, - inputs=human_inputs - ) - - if memory.model_instance.model_rules.max_tokens.max: - curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message])) - max_tokens = model.get("completion_params").get('max_tokens') - rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - else: - rest_tokens = 2000 - - histories = cls.get_history_messages_from_memory(memory, rest_tokens) - human_message_prompt += "\n\n" if human_message_prompt else "" - human_message_prompt += "Here is the chat histories between human and assistant, " \ - "inside XML tags.\n\n\n" - human_message_prompt += histories + "\n" - - human_message_prompt += query_prompt - - # construct main prompt - human_message = PromptBuilder.to_human_message( - prompt_content=human_message_prompt, - inputs=human_inputs - ) - - messages.append(human_message) - - for message in messages: - message.content = re.sub(r'<\|.*?\|>', '', message.content) - - return to_prompt_messages(messages), ['\nHuman:', ''] - @classmethod def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, max_token_limit: int) -> str: @@ -307,13 +199,12 @@ And answer according to the language of the user's question. max_tokens = 0 # get prompt without memory and context - prompt_messages, _ = cls.get_main_llm_prompt( + prompt_messages, _ = model_instance.get_prompt( mode=mode, - model=app_model_config.model_dict, pre_prompt=app_model_config.pre_prompt, - query=query, inputs=inputs, - agent_execute_result=None, + query=query, + context=None, memory=None ) @@ -358,13 +249,12 @@ And answer according to the language of the user's question. 
) # get llm prompt - old_prompt_messages, _ = cls.get_main_llm_prompt( - mode="completion", - model=app_model_config.model_dict, + old_prompt_messages, _ = final_model_instance.get_prompt( + mode='completion', pre_prompt=pre_prompt, - query=message.query, inputs=message.inputs, - agent_execute_result=None, + query=message.query, + context=None, memory=None ) diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py index fe92407069..8662e73275 100644 --- a/api/core/model_providers/models/llm/base.py +++ b/api/core/model_providers/models/llm/base.py @@ -1,17 +1,24 @@ +import json +import os +import re from abc import abstractmethod -from typing import List, Optional, Any, Union +from typing import List, Optional, Any, Union, Tuple import decimal from langchain.callbacks.manager import Callbacks +from langchain.memory.chat_memory import BaseChatMemory from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult +from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules from core.model_providers.providers.base import BaseModelProvider +from core.prompt.prompt_builder import PromptBuilder +from core.prompt.prompt_template import JinjaPromptTemplate from core.third_party.langchain.llms.fake import FakeLLM import logging + logger = logging.getLogger(__name__) @@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel): def price_config(self) -> dict: def get_or_default(): default_price_config = { - 'prompt': decimal.Decimal('0'), - 'completion': decimal.Decimal('0'), - 'unit': decimal.Decimal('0'), - 'currency': 'USD' - } + 'prompt': decimal.Decimal('0'), + 'completion': decimal.Decimal('0'), + 'unit': decimal.Decimal('0'), + 'currency': 'USD' + } rules = self.model_provider.get_rules() - price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config + price_config = rules['price_config'][ + self.base_model_name] if 'price_config' in rules else default_price_config price_config = { 'prompt': decimal.Decimal(price_config['prompt']), 'completion': decimal.Decimal(price_config['completion']), @@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel): 'currency': price_config['currency'] } return price_config - + self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default() logger.debug(f"model: {self.name} price_config: {self._price_config}") @@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel): total_tokens = result.llm_output['token_usage']['total_tokens'] else: prompt_tokens = self.get_num_tokens(messages) - completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)]) + completion_tokens = self.get_num_tokens( + [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)]) total_tokens = prompt_tokens + completion_tokens self.model_provider.update_last_used() @@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel): def support_streaming(cls): return False + def get_prompt(self, mode: str, + pre_prompt: str, inputs: dict, + query: str, 
+ context: Optional[str], + memory: Optional[BaseChatMemory]) -> \ + Tuple[List[PromptMessage], Optional[List[str]]]: + prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode)) + prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory) + return [PromptMessage(content=prompt)], stops + + def prompt_file_name(self, mode: str) -> str: + if mode == 'completion': + return 'common_completion' + else: + return 'common_chat' + + def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]: + context_prompt_content = '' + if context and 'context_prompt' in prompt_rules: + prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt']) + context_prompt_content = prompt_template.format( + context=context + ) + + pre_prompt_content = '' + if pre_prompt: + prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs} + pre_prompt_content = prompt_template.format( + **prompt_inputs + ) + + prompt = '' + for order in prompt_rules['system_prompt_orders']: + if order == 'context_prompt': + prompt += context_prompt_content + elif order == 'pre_prompt': + prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else '' + + query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}' + + if memory and 'histories_prompt' in prompt_rules: + # append chat histories + tmp_human_message = PromptBuilder.to_human_message( + prompt_content=prompt + query_prompt, + inputs={ + 'query': query + } + ) + + if self.model_rules.max_tokens.max: + curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message])) + max_tokens = self.model_kwargs.max_tokens + rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + else: + rest_tokens = 2000 + + memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human' + memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + + histories = self._get_history_messages_from_memory(memory, rest_tokens) + prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt']) + histories_prompt_content = prompt_template.format( + histories=histories + ) + + prompt = '' + for order in prompt_rules['system_prompt_orders']: + if order == 'context_prompt': + prompt += context_prompt_content + elif order == 'pre_prompt': + prompt += (pre_prompt_content + '\n') if pre_prompt_content else '' + elif order == 'histories_prompt': + prompt += histories_prompt_content + + prompt_template = JinjaPromptTemplate.from_template(template=query_prompt) + query_prompt_content = prompt_template.format( + query=query + ) + + prompt += query_prompt_content + + prompt = re.sub(r'<\|.*?\|>', '', prompt) + + stops = prompt_rules.get('stops') + if stops is not None and len(stops) == 0: + stops = None + + return prompt, stops + + def _read_prompt_rules_from_file(self, prompt_name: str) -> dict: + # Get the absolute path of the subdirectory + prompt_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), + 'prompt/generate_prompts') + + json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') + # Open the JSON 
file and read its content + with open(json_file_path, 'r') as json_file: + return json.load(json_file) + + def _get_history_messages_from_memory(self, memory: BaseChatMemory, + max_token_limit: int) -> str: + """Get memory messages.""" + memory.max_token_limit = max_token_limit + memory_key = memory.memory_variables[0] + external_context = memory.load_memory_variables({}) + return external_context[memory_key] + def _get_prompt_from_messages(self, messages: List[PromptMessage], model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]: if not model_mode: diff --git a/api/core/model_providers/models/llm/huggingface_hub_model.py b/api/core/model_providers/models/llm/huggingface_hub_model.py index 7e800e3fea..fb381bf64d 100644 --- a/api/core/model_providers/models/llm/huggingface_hub_model.py +++ b/api/core/model_providers/models/llm/huggingface_hub_model.py @@ -60,6 +60,15 @@ class HuggingfaceHubModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return self._client.get_num_tokens(prompts) + def prompt_file_name(self, mode: str) -> str: + if 'baichuan' in self.name.lower(): + if mode == 'completion': + return 'baichuan_completion' + else: + return 'baichuan_chat' + else: + return super().prompt_file_name(mode) + def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) self.client.model_kwargs = provider_model_kwargs diff --git a/api/core/model_providers/models/llm/openllm_model.py b/api/core/model_providers/models/llm/openllm_model.py index 2f9876a92b..217d893c48 100644 --- a/api/core/model_providers/models/llm/openllm_model.py +++ b/api/core/model_providers/models/llm/openllm_model.py @@ -49,6 +49,15 @@ class OpenLLMModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens(prompts), 0) + def prompt_file_name(self, mode: str) -> str: + if 'baichuan' in self.name.lower(): + if mode == 'completion': + return 'baichuan_completion' + else: + return 'baichuan_chat' + else: + return super().prompt_file_name(mode) + def _set_model_kwargs(self, model_kwargs: ModelKwargs): pass diff --git a/api/core/model_providers/models/llm/xinference_model.py b/api/core/model_providers/models/llm/xinference_model.py index 8af6356c99..a058a601b1 100644 --- a/api/core/model_providers/models/llm/xinference_model.py +++ b/api/core/model_providers/models/llm/xinference_model.py @@ -59,6 +59,15 @@ class XinferenceModel(BaseLLM): prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens(prompts), 0) + def prompt_file_name(self, mode: str) -> str: + if 'baichuan' in self.name.lower(): + if mode == 'completion': + return 'baichuan_completion' + else: + return 'baichuan_chat' + else: + return super().prompt_file_name(mode) + def _set_model_kwargs(self, model_kwargs: ModelKwargs): pass diff --git a/api/core/prompt/generate_prompts/baichuan_chat.json b/api/core/prompt/generate_prompts/baichuan_chat.json new file mode 100644 index 0000000000..93b9b72544 --- /dev/null +++ b/api/core/prompt/generate_prompts/baichuan_chat.json @@ -0,0 +1,13 @@ +{ + "human_prefix": "用户", + "assistant_prefix": "助手", + "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n\n", + "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n", + "system_prompt_orders": [ + "context_prompt", + "pre_prompt", + "histories_prompt" + ], + "query_prompt": "用户:{{query}}\n助手:", + "stops": ["用户:"] +} \ No 
newline at end of file
diff --git a/api/core/prompt/generate_prompts/baichuan_completion.json b/api/core/prompt/generate_prompts/baichuan_completion.json
new file mode 100644
index 0000000000..bdc7cf976c
--- /dev/null
+++ b/api/core/prompt/generate_prompts/baichuan_completion.json
@@ -0,0 +1,9 @@
+{
+  "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt"
+  ],
+  "query_prompt": "{{query}}",
+  "stops": null
+}
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/common_chat.json b/api/core/prompt/generate_prompts/common_chat.json
new file mode 100644
index 0000000000..baa000a7a2
--- /dev/null
+++ b/api/core/prompt/generate_prompts/common_chat.json
@@ -0,0 +1,13 @@
+{
+  "human_prefix": "Human",
+  "assistant_prefix": "Assistant",
+  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+  "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt",
+    "histories_prompt"
+  ],
+  "query_prompt": "Human: {{query}}\n\nAssistant: ",
+  "stops": ["\nHuman:", "</histories>"]
+}
\ No newline at end of file
diff --git a/api/core/prompt/generate_prompts/common_completion.json b/api/core/prompt/generate_prompts/common_completion.json
new file mode 100644
index 0000000000..9e7e8d68ef
--- /dev/null
+++ b/api/core/prompt/generate_prompts/common_completion.json
@@ -0,0 +1,9 @@
+{
+  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt"
+  ],
+  "query_prompt": "{{query}}",
+  "stops": null
+}
\ No newline at end of file
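
Note (editorial addition, not part of the commit): the diff moves prompt construction into per-model rule files. "system_prompt_orders" decides how "context_prompt", "pre_prompt" and "histories_prompt" are stacked, "query_prompt" closes the prompt, "stops" supplies stop words, and subclasses pick the file via prompt_file_name(). The sketch below is a minimal, self-contained illustration of that flow under stated assumptions: it uses a plain regex substitution instead of the repo's JinjaPromptTemplate, skips the token-budgeted history trimming done in _get_prompt_and_stop, and every model name, path and sample string in it is made up for illustration.

# Editorial sketch, not part of the commit: a standalone illustration of how the
# prompt-rule JSON files added above drive prompt assembly. It substitutes plain
# re-based replacement for the repo's JinjaPromptTemplate and skips the
# token-budgeted history trimming that _get_prompt_and_stop performs.
import json
import re


def pick_prompt_file(model_name: str, mode: str) -> str:
    # Mirrors the prompt_file_name() overrides in the diff: baichuan models get
    # their own templates, every other model falls back to the common ones.
    family = 'baichuan' if 'baichuan' in model_name.lower() else 'common'
    return f"{family}_{'completion' if mode == 'completion' else 'chat'}.json"


def render(template: str, **values: str) -> str:
    # Minimal stand-in for JinjaPromptTemplate: fill {{name}} placeholders.
    return re.sub(r'\{\{(\w+)\}\}', lambda m: values.get(m.group(1), ''), template)


def build_prompt(rules: dict, pre_prompt: str, context: str, histories: str, query: str):
    prompt = ''
    for part in rules['system_prompt_orders']:
        if part == 'context_prompt' and context:
            prompt += render(rules['context_prompt'], context=context)
        elif part == 'pre_prompt' and pre_prompt:
            prompt += pre_prompt + '\n'
        elif part == 'histories_prompt' and histories:
            prompt += render(rules['histories_prompt'], histories=histories)
    prompt += render(rules.get('query_prompt', '{{query}}'), query=query)
    stops = rules.get('stops') or None  # an empty list behaves like "no stop words"
    return prompt, stops


if __name__ == '__main__':
    # Example values below are made up; the path assumes the repo root as CWD.
    rules_file = pick_prompt_file('baichuan-13b-chat', 'chat')
    with open(f'api/core/prompt/generate_prompts/{rules_file}', 'r') as f:
        rules = json.load(f)

    prompt, stops = build_prompt(
        rules,
        pre_prompt='You are a concise assistant for the Dify docs.',
        context='Dify is an LLM application development platform.',
        histories='用户:你好\n助手:你好,请问有什么可以帮你?',
        query='What is Dify?',
    )
    print(prompt)
    print('stop words:', stops)

Passing any non-baichuan model name selects common_chat.json instead, which frames context and history in <context>/<histories> XML tags and stops generation at "\nHuman:".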