From 7c9b585a47f9de0e08a9836582e439f9a5bac68f Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 18 Oct 2023 15:35:24 +0800 Subject: [PATCH 1/5] feat: support weixin ernie-bot-4 and chat mode (#1375) --- .../models/llm/wenxin_model.py | 19 +- .../providers/wenxin_provider.py | 18 +- api/core/model_providers/rules/wenxin.json | 6 + api/core/third_party/langchain/llms/wenxin.py | 198 ++++++++++++------ .../models/llm/test_wenxin_model.py | 5 +- .../model_providers/test_wenxin_provider.py | 5 +- 6 files changed, 174 insertions(+), 77 deletions(-) diff --git a/api/core/model_providers/models/llm/wenxin_model.py b/api/core/model_providers/models/llm/wenxin_model.py index 00ddbb82ddd91f..c912ffb5d32c74 100644 --- a/api/core/model_providers/models/llm/wenxin_model.py +++ b/api/core/model_providers/models/llm/wenxin_model.py @@ -6,17 +6,16 @@ from core.model_providers.error import LLMBadRequestError from core.model_providers.models.llm.base import BaseLLM -from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.message import PromptMessage from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs from core.third_party.langchain.llms.wenxin import Wenxin class WenxinModel(BaseLLM): - model_mode: ModelMode = ModelMode.COMPLETION + model_mode: ModelMode = ModelMode.CHAT def _init_client(self) -> Any: provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) - # TODO load price_config from configs(db) return Wenxin( model=self.name, streaming=self.streaming, @@ -38,7 +37,13 @@ def _run(self, messages: List[PromptMessage], :return: """ prompts = self._get_prompt_from_messages(messages) - return self._client.generate([prompts], stop, callbacks) + + generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]} + + if 'functions' in kwargs: + generate_kwargs['functions'] = kwargs['functions'] + + return self._client.generate(**generate_kwargs) def get_num_tokens(self, messages: List[PromptMessage]) -> int: """ @@ -48,7 +53,7 @@ def get_num_tokens(self, messages: List[PromptMessage]) -> int: :return: """ prompts = self._get_prompt_from_messages(messages) - return max(self._client.get_num_tokens(prompts), 0) + return max(self._client.get_num_tokens_from_messages(prompts), 0) def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) @@ -58,3 +63,7 @@ def _set_model_kwargs(self, model_kwargs: ModelKwargs): def handle_exceptions(self, ex: Exception) -> Exception: return LLMBadRequestError(f"Wenxin: {str(ex)}") + + @property + def support_streaming(self): + return True diff --git a/api/core/model_providers/providers/wenxin_provider.py b/api/core/model_providers/providers/wenxin_provider.py index e729358c0a99d7..650acffcc4567c 100644 --- a/api/core/model_providers/providers/wenxin_provider.py +++ b/api/core/model_providers/providers/wenxin_provider.py @@ -2,6 +2,8 @@ from json import JSONDecodeError from typing import Type +from langchain.schema import HumanMessage + from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode @@ -23,20 +25,25 @@ def provider_name(self): def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: if model_type == ModelType.TEXT_GENERATION: return [ + { + 'id': 'ernie-bot-4', + 'name': 'ERNIE-Bot-4', + 'mode': ModelMode.CHAT.value, + }, { 'id': 'ernie-bot', 'name': 'ERNIE-Bot', - 'mode': ModelMode.COMPLETION.value, + 'mode': ModelMode.CHAT.value, }, { 'id': 'ernie-bot-turbo', 'name': 'ERNIE-Bot-turbo', - 'mode': ModelMode.COMPLETION.value, + 'mode': ModelMode.CHAT.value, }, { 'id': 'bloomz-7b', 'name': 'BLOOMZ-7B', - 'mode': ModelMode.COMPLETION.value, + 'mode': ModelMode.CHAT.value, } ] else: @@ -68,11 +75,12 @@ def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> M :return: """ model_max_tokens = { + 'ernie-bot-4': 4800, 'ernie-bot': 4800, 'ernie-bot-turbo': 11200, } - if model_name in ['ernie-bot', 'ernie-bot-turbo']: + if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']: return ModelKwargsRules( temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2), top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2), @@ -111,7 +119,7 @@ def is_provider_credentials_valid_or_raise(cls, credentials: dict): **credential_kwargs ) - llm("ping") + llm([HumanMessage(content='ping')]) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_providers/rules/wenxin.json b/api/core/model_providers/rules/wenxin.json index e5c136d32625d8..dbb692fb42c443 100644 --- a/api/core/model_providers/rules/wenxin.json +++ b/api/core/model_providers/rules/wenxin.json @@ -5,6 +5,12 @@ "system_config": null, "model_flexibility": "fixed", "price_config": { + "ernie-bot-4": { + "prompt": "0", + "completion": "0", + "unit": "0.001", + "currency": "RMB" + }, "ernie-bot": { "prompt": "0.012", "completion": "0.012", diff --git a/api/core/third_party/langchain/llms/wenxin.py b/api/core/third_party/langchain/llms/wenxin.py index 7b31b676241c31..d6aee116c7991d 100644 --- a/api/core/third_party/langchain/llms/wenxin.py +++ b/api/core/third_party/langchain/llms/wenxin.py @@ -8,12 +8,15 @@ Any, Dict, List, - Optional, Iterator, + Optional, Iterator, Tuple, ) import requests +from langchain.chat_models.base import BaseChatModel from langchain.llms.utils import enforce_stop_tokens -from langchain.schema.output import GenerationChunk +from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage +from langchain.schema.messages import AIMessageChunk +from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator from langchain.callbacks.manager import ( @@ -61,6 +64,7 @@ def post(self, request: dict) -> Any: raise ValueError(f"Wenxin Model name is required") model_url_map = { + 'ernie-bot-4': 'completions_pro', 'ernie-bot': 'completions', 'ernie-bot-turbo': 'eb-instant', 'bloomz-7b': 'bloomz_7b1', @@ -70,6 +74,7 @@ def post(self, request: dict) -> Any: access_token = self.get_access_token() api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}" + del request['model'] headers = {"Content-Type": "application/json"} response = requests.post(api_url, @@ -86,22 +91,21 @@ def post(self, request: dict) -> Any: f"Wenxin API {json_response['error_code']}" f" error: {json_response['error_msg']}" ) - return json_response["result"] + return json_response else: return response -class Wenxin(LLM): - """Wrapper around Wenxin large language models. - To use, you should have the environment variable - ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key, - or pass them as a named parameter to the constructor. - Example: - .. code-block:: python - from langchain.llms.wenxin import Wenxin - wenxin = Wenxin(model="", api_key="my-api-key", - secret_key="my-group-id") - """ +class Wenxin(BaseChatModel): + """Wrapper around Wenxin large language models.""" + + @property + def lc_secrets(self) -> Dict[str, str]: + return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"} + + @property + def lc_serializable(self) -> bool: + return True _client: _WenxinEndpointClient = PrivateAttr() model: str = "ernie-bot" @@ -161,64 +165,89 @@ def __init__(self, **data: Any): secret_key=self.secret_key, ) - def _call( + def _convert_message_to_dict(self, message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + return message_dict + + def _create_message_dicts( + self, messages: List[BaseMessage] + ) -> Tuple[List[Dict[str, Any]], str]: + dict_messages = [] + system = None + for m in messages: + message = self._convert_message_to_dict(m) + if message['role'] == 'system': + if not system: + system = message['content'] + else: + system += f"\n{message['content']}" + continue + + if dict_messages: + previous_message = dict_messages[-1] + if previous_message['role'] == message['role']: + dict_messages[-1]['content'] += f"\n{message['content']}" + else: + dict_messages.append(message) + else: + dict_messages.append(message) + + return dict_messages, system + + def _generate( self, - prompt: str, + messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> str: - r"""Call out to Wenxin's completion endpoint to chat - Args: - prompt: The prompt to pass into the model. - Returns: - The string generated by the model. - Example: - .. code-block:: python - response = wenxin("Tell me a joke.") - """ + ) -> ChatResult: if self.streaming: - completion = "" + generation: Optional[ChatGenerationChunk] = None + llm_output: Optional[Dict] = None for chunk in self._stream( - prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + messages=messages, stop=stop, run_manager=run_manager, **kwargs ): - completion += chunk.text + if chunk.generation_info is not None \ + and 'token_usage' in chunk.generation_info: + llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model} + + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return ChatResult(generations=[generation], llm_output=llm_output) else: + message_dicts, system = self._create_message_dicts(messages) request = self._default_params - request["messages"] = [{"role": "user", "content": prompt}] + request["messages"] = message_dicts + if system: + request["system"] = system request.update(kwargs) - completion = self._client.post(request) - - if stop is not None: - completion = enforce_stop_tokens(completion, stop) - - return completion + response = self._client.post(request) + return self._create_chat_result(response) def _stream( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> Iterator[GenerationChunk]: - r"""Call wenxin completion_stream and return the resulting generator. - - Args: - prompt: The prompt to pass into the model. - stop: Optional list of stop words to use when generating. - Returns: - A generator representing the stream of tokens from Wenxin. - Example: - .. code-block:: python - - prompt = "Write a poem about a stream." - prompt = f"\n\nHuman: {prompt}\n\nAssistant:" - generator = wenxin.stream(prompt) - for token in generator: - yield token - """ + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + message_dicts, system = self._create_message_dicts(messages) request = self._default_params - request["messages"] = [{"role": "user", "content": prompt}] + request["messages"] = message_dicts + if system: + request["system"] = system request.update(kwargs) for token in self._client.post(request).iter_lines(): @@ -228,12 +257,18 @@ def _stream( if token.startswith('data:'): completion = json.loads(token[5:]) - yield GenerationChunk(text=completion['result']) - if run_manager: - run_manager.on_llm_new_token(completion['result']) + chunk_dict = { + 'message': AIMessageChunk(content=completion['result']), + } if completion['is_end']: - break + token_usage = completion['usage'] + token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens'] + chunk_dict['generation_info'] = dict({'token_usage': token_usage}) + + yield ChatGenerationChunk(**chunk_dict) + if run_manager: + run_manager.on_llm_new_token(completion['result']) else: try: json_response = json.loads(token) @@ -245,3 +280,40 @@ def _stream( f" error: {json_response['error_msg']}, " f"please confirm if the model you have chosen is already paid for." ) + + def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult: + generations = [ChatGeneration( + message=AIMessage(content=response['result']), + )] + token_usage = response.get("usage") + token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens'] + + llm_output = {"token_usage": token_usage, "model_name": self.model} + return ChatResult(generations=generations, llm_output=llm_output) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Get the number of tokens in the messages. + + Useful for checking if an input will fit in a model's context window. + + Args: + messages: The message inputs to tokenize. + + Returns: + The sum of the number of tokens across the messages. + """ + return sum([self.get_num_tokens(m.content) for m in messages]) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage, "model_name": self.model} diff --git a/api/tests/integration_tests/models/llm/test_wenxin_model.py b/api/tests/integration_tests/models/llm/test_wenxin_model.py index 8cc4779160b660..9378c620d802e3 100644 --- a/api/tests/integration_tests/models/llm/test_wenxin_model.py +++ b/api/tests/integration_tests/models/llm/test_wenxin_model.py @@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker): mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None) model = get_mock_model('ernie-bot') - messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')] + messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=? \nAssistant: Integer answer is:')] rst = model.run( - messages, - stop=['\nHuman:'], + messages ) assert len(rst.content) > 0 diff --git a/api/tests/unit_tests/model_providers/test_wenxin_provider.py b/api/tests/unit_tests/model_providers/test_wenxin_provider.py index 9f714bb6d39746..5a7f8dab1910d3 100644 --- a/api/tests/unit_tests/model_providers/test_wenxin_provider.py +++ b/api/tests/unit_tests/model_providers/test_wenxin_provider.py @@ -2,6 +2,8 @@ from unittest.mock import patch import json +from langchain.schema import AIMessage, ChatGeneration, ChatResult + from core.model_providers.providers.base import CredentialsValidateFailedError from core.model_providers.providers.wenxin_provider import WenxinProvider from models.provider import ProviderType, Provider @@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key): def test_is_provider_credentials_valid_or_raise_valid(mocker): - mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc") + mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate', + return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))])) MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL) From d14f15863d3f15395484e6d19f0aded98be1324e Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Wed, 18 Oct 2023 16:00:56 +0800 Subject: [PATCH 2/5] fix: i18n runtime error (#1376) --- web/app/components/i18n-server.tsx | 2 -- web/app/components/i18n.tsx | 13 ++++++++----- web/i18n/i18next-config.ts | 5 +---- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/web/app/components/i18n-server.tsx b/web/app/components/i18n-server.tsx index 5707fa82c59a85..39a5f7ca1d1096 100644 --- a/web/app/components/i18n-server.tsx +++ b/web/app/components/i18n-server.tsx @@ -4,12 +4,10 @@ import { ToastProvider } from './base/toast' import { getDictionary, getLocaleOnServer } from '@/i18n/server' export type II18NServerProps = { - // locale: Locale children: React.ReactNode } const I18NServer = async ({ - // locale, children, }: II18NServerProps) => { const locale = getLocaleOnServer() diff --git a/web/app/components/i18n.tsx b/web/app/components/i18n.tsx index e64f958b309541..32e1c930731bab 100644 --- a/web/app/components/i18n.tsx +++ b/web/app/components/i18n.tsx @@ -1,23 +1,26 @@ 'use client' import type { FC } from 'react' -import React from 'react' -import '@/i18n/i18next-config' +import React, { useEffect } from 'react' +import { changeLanguage } from '@/i18n/i18next-config' import I18NContext from '@/context/i18n' import type { Locale } from '@/i18n' -import { getLocaleOnClient, setLocaleOnClient } from '@/i18n/client' +import { setLocaleOnClient } from '@/i18n/client' export type II18nProps = { locale: Locale dictionary: Record children: React.ReactNode - setLocaleOnClient: (locale: Locale) => void } const I18n: FC = ({ + locale, dictionary, children, }) => { - const locale = getLocaleOnClient() + useEffect(() => { + changeLanguage(locale) + }, [locale]) + return ( Date: Wed, 18 Oct 2023 18:07:36 +0800 Subject: [PATCH 3/5] =?UTF-8?q?fix:=20app=20config=20zhipu=20chatglm=5Fstd?= =?UTF-8?q?=20model,=20but=20it=20still=20use=20chatglm=5Flit=E2=80=A6=20(?= =?UTF-8?q?#1377)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: wayne.wang --- api/core/model_providers/models/llm/zhipuai_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/model_providers/models/llm/zhipuai_model.py b/api/core/model_providers/models/llm/zhipuai_model.py index 7f32c1dc70bcdf..f94eb4e7298ea0 100644 --- a/api/core/model_providers/models/llm/zhipuai_model.py +++ b/api/core/model_providers/models/llm/zhipuai_model.py @@ -16,6 +16,7 @@ class ZhipuAIModel(BaseLLM): def _init_client(self) -> Any: provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) return ZhipuAIChatLLM( + model=self.name, streaming=self.streaming, callbacks=self.callbacks, **self.credentials, From fe14130b3ce909ff99e39d1afea870486e1a2074 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Wed, 18 Oct 2023 20:02:52 +0800 Subject: [PATCH 4/5] refactor advanced prompt core. (#1350) --- api/core/completion.py | 44 ++- .../models/llm/baichuan_model.py | 6 - api/core/model_providers/models/llm/base.py | 218 +---------- .../models/llm/huggingface_hub_model.py | 9 - .../models/llm/openllm_model.py | 9 - .../models/llm/xinference_model.py | 9 - api/core/prompt/prompt_transform.py | 344 ++++++++++++++++++ .../advanced_prompt_template_service.py | 28 +- api/services/app_model_config_service.py | 10 +- 9 files changed, 405 insertions(+), 272 deletions(-) create mode 100644 api/core/prompt/prompt_transform.py diff --git a/api/core/completion.py b/api/core/completion.py index 768231a53d941a..57e18199271ccb 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -16,6 +16,7 @@ from core.model_providers.models.llm.base import BaseLLM from core.orchestrator_rule_parser import OrchestratorRuleParser from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.prompt_transform import PromptTransform from models.model import App, AppModelConfig, Account, Conversation, EndUser @@ -156,24 +157,28 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App conversation_message_task: ConversationMessageTask, memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], fake_response: Optional[str]): + prompt_transform = PromptTransform() + # get llm prompt if app_model_config.prompt_type == 'simple': - prompt_messages, stop_words = model_instance.get_prompt( + prompt_messages, stop_words = prompt_transform.get_prompt( mode=mode, pre_prompt=app_model_config.pre_prompt, inputs=inputs, query=query, context=agent_execute_result.output if agent_execute_result else None, - memory=memory + memory=memory, + model_instance=model_instance ) else: - prompt_messages = model_instance.get_advanced_prompt( + prompt_messages = prompt_transform.get_advanced_prompt( app_mode=mode, app_model_config=app_model_config, inputs=inputs, query=query, context=agent_execute_result.output if agent_execute_result else None, - memory=memory + memory=memory, + model_instance=model_instance ) model_config = app_model_config.model_dict @@ -238,15 +243,30 @@ def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_ if max_tokens is None: max_tokens = 0 + prompt_transform = PromptTransform() + prompt_messages = [] + # get prompt without memory and context - prompt_messages, _ = model_instance.get_prompt( - mode=mode, - pre_prompt=app_model_config.pre_prompt, - inputs=inputs, - query=query, - context=None, - memory=None - ) + if app_model_config.prompt_type == 'simple': + prompt_messages, _ = prompt_transform.get_prompt( + mode=mode, + pre_prompt=app_model_config.pre_prompt, + inputs=inputs, + query=query, + context=None, + memory=None, + model_instance=model_instance + ) + else: + prompt_messages = prompt_transform.get_advanced_prompt( + app_mode=mode, + app_model_config=app_model_config, + inputs=inputs, + query=query, + context=None, + memory=None, + model_instance=model_instance + ) prompt_tokens = model_instance.get_num_tokens(prompt_messages) rest_tokens = model_limited_tokens - max_tokens - prompt_tokens diff --git a/api/core/model_providers/models/llm/baichuan_model.py b/api/core/model_providers/models/llm/baichuan_model.py index d2aea36ccaa135..e614547fa3d517 100644 --- a/api/core/model_providers/models/llm/baichuan_model.py +++ b/api/core/model_providers/models/llm/baichuan_model.py @@ -37,12 +37,6 @@ def _run(self, messages: List[PromptMessage], prompts = self._get_prompt_from_messages(messages) return self._client.generate([prompts], stop, callbacks) - def prompt_file_name(self, mode: str) -> str: - if mode == 'completion': - return 'baichuan_completion' - else: - return 'baichuan_chat' - def get_num_tokens(self, messages: List[PromptMessage]) -> int: """ get num tokens of prompt messages. diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py index 3a6e8b41ca7a52..41724dd54bfae0 100644 --- a/api/core/model_providers/models/llm/base.py +++ b/api/core/model_providers/models/llm/base.py @@ -1,28 +1,18 @@ -import json -import os -import re -import time from abc import abstractmethod -from typing import List, Optional, Any, Union, Tuple +from typing import List, Optional, Any, Union import decimal +import logging from langchain.callbacks.manager import Callbacks -from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration +from langchain.schema import LLMResult, BaseMessage, ChatGeneration from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler from core.helper import moderation from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \ - to_lc_messages +from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_lc_messages from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules from core.model_providers.providers.base import BaseModelProvider -from core.prompt.prompt_builder import PromptBuilder -from core.prompt.prompt_template import PromptTemplateParser from core.third_party.langchain.llms.fake import FakeLLM -import logging - -from extensions.ext_database import db logger = logging.getLogger(__name__) @@ -320,206 +310,8 @@ def add_callbacks(self, callbacks: Callbacks): def support_streaming(self): return False - def get_prompt(self, mode: str, - pre_prompt: str, inputs: dict, - query: str, - context: Optional[str], - memory: Optional[BaseChatMemory]) -> \ - Tuple[List[PromptMessage], Optional[List[str]]]: - prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode)) - prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory) - return [PromptMessage(content=prompt)], stops - - def get_advanced_prompt(self, app_mode: str, - app_model_config: str, inputs: dict, - query: str, - context: Optional[str], - memory: Optional[BaseChatMemory]) -> List[PromptMessage]: - - model_mode = app_model_config.model_dict['mode'] - conversation_histories_role = {} - - raw_prompt_list = [] - prompt_messages = [] - - if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value: - prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text'] - raw_prompt_list = [{ - 'role': MessageType.USER.value, - 'text': prompt_text - }] - conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role'] - elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value: - raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] - elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value: - raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] - elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value: - prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text'] - raw_prompt_list = [{ - 'role': MessageType.USER.value, - 'text': prompt_text - }] - else: - raise Exception("app_mode or model_mode not support") - - for prompt_item in raw_prompt_list: - prompt = prompt_item['text'] - - # set prompt template variables - prompt_template = PromptTemplateParser(template=prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - - if '#context#' in prompt: - if context: - prompt_inputs['#context#'] = context - else: - prompt_inputs['#context#'] = '' - - if '#query#' in prompt: - if query: - prompt_inputs['#query#'] = query - else: - prompt_inputs['#query#'] = '' - - if '#histories#' in prompt: - if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value: - memory.human_prefix = conversation_histories_role['user_prefix'] - memory.ai_prefix = conversation_histories_role['assistant_prefix'] - histories = self._get_history_messages_from_memory(memory, 2000) - prompt_inputs['#histories#'] = histories - else: - prompt_inputs['#histories#'] = '' - - prompt = prompt_template.format( - prompt_inputs - ) - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - - prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt)) - - if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value: - memory.human_prefix = MessageType.USER.value - memory.ai_prefix = MessageType.ASSISTANT.value - histories = self._get_history_messages_list_from_memory(memory, 2000) - prompt_messages.extend(histories) - - if app_mode == 'chat' and model_mode == ModelMode.CHAT.value: - prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query)) - - return prompt_messages - - def prompt_file_name(self, mode: str) -> str: - if mode == 'completion': - return 'common_completion' - else: - return 'common_chat' - - def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict, - query: str, - context: Optional[str], - memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]: - context_prompt_content = '' - if context and 'context_prompt' in prompt_rules: - prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) - context_prompt_content = prompt_template.format( - {'context': context} - ) - - pre_prompt_content = '' - if pre_prompt: - prompt_template = PromptTemplateParser(template=pre_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} - pre_prompt_content = prompt_template.format( - prompt_inputs - ) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += pre_prompt_content - - query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}' - - if memory and 'histories_prompt' in prompt_rules: - # append chat histories - tmp_human_message = PromptBuilder.to_human_message( - prompt_content=prompt + query_prompt, - inputs={ - 'query': query - } - ) - - if self.model_rules.max_tokens.max: - curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message])) - max_tokens = self.model_kwargs.max_tokens - rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - else: - rest_tokens = 2000 - - memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human' - memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' - - histories = self._get_history_messages_from_memory(memory, rest_tokens) - prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt']) - histories_prompt_content = prompt_template.format({'histories': histories}) - - prompt = '' - for order in prompt_rules['system_prompt_orders']: - if order == 'context_prompt': - prompt += context_prompt_content - elif order == 'pre_prompt': - prompt += (pre_prompt_content + '\n') if pre_prompt_content else '' - elif order == 'histories_prompt': - prompt += histories_prompt_content - - prompt_template = PromptTemplateParser(template=query_prompt) - query_prompt_content = prompt_template.format({'query': query}) - - prompt += query_prompt_content - - prompt = re.sub(r'<\|.*?\|>', '', prompt) - - stops = prompt_rules.get('stops') - if stops is not None and len(stops) == 0: - stops = None - - return prompt, stops - - def _read_prompt_rules_from_file(self, prompt_name: str) -> dict: - # Get the absolute path of the subdirectory - prompt_path = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))), - 'prompt/generate_prompts') - - json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') - # Open the JSON file and read its content - with open(json_file_path, 'r') as json_file: - return json.load(json_file) - - def _get_history_messages_from_memory(self, memory: BaseChatMemory, - max_token_limit: int) -> str: - """Get memory messages.""" - memory.max_token_limit = max_token_limit - memory_key = memory.memory_variables[0] - external_context = memory.load_memory_variables({}) - return external_context[memory_key] - - def _get_history_messages_list_from_memory(self, memory: BaseChatMemory, - max_token_limit: int) -> List[PromptMessage]: - """Get memory messages.""" - memory.max_token_limit = max_token_limit - memory.return_messages = True - memory_key = memory.memory_variables[0] - external_context = memory.load_memory_variables({}) - memory.return_messages = False - return to_prompt_messages(external_context[memory_key]) - def _get_prompt_from_messages(self, messages: List[PromptMessage], - model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]: + model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]: if not model_mode: model_mode = self.model_mode diff --git a/api/core/model_providers/models/llm/huggingface_hub_model.py b/api/core/model_providers/models/llm/huggingface_hub_model.py index 3eae369fe9e644..ca3f1d2cf72657 100644 --- a/api/core/model_providers/models/llm/huggingface_hub_model.py +++ b/api/core/model_providers/models/llm/huggingface_hub_model.py @@ -66,15 +66,6 @@ def get_num_tokens(self, messages: List[PromptMessage]) -> int: prompts = self._get_prompt_from_messages(messages) return self._client.get_num_tokens(prompts) - def prompt_file_name(self, mode: str) -> str: - if 'baichuan' in self.name.lower(): - if mode == 'completion': - return 'baichuan_completion' - else: - return 'baichuan_chat' - else: - return super().prompt_file_name(mode) - def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) self.client.model_kwargs = provider_model_kwargs diff --git a/api/core/model_providers/models/llm/openllm_model.py b/api/core/model_providers/models/llm/openllm_model.py index 0ee6ce0f6467af..c92877fd8b6ae6 100644 --- a/api/core/model_providers/models/llm/openllm_model.py +++ b/api/core/model_providers/models/llm/openllm_model.py @@ -49,15 +49,6 @@ def get_num_tokens(self, messages: List[PromptMessage]) -> int: prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens(prompts), 0) - def prompt_file_name(self, mode: str) -> str: - if 'baichuan' in self.name.lower(): - if mode == 'completion': - return 'baichuan_completion' - else: - return 'baichuan_chat' - else: - return super().prompt_file_name(mode) - def _set_model_kwargs(self, model_kwargs: ModelKwargs): pass diff --git a/api/core/model_providers/models/llm/xinference_model.py b/api/core/model_providers/models/llm/xinference_model.py index 551450bec38bc8..2239ef1336856c 100644 --- a/api/core/model_providers/models/llm/xinference_model.py +++ b/api/core/model_providers/models/llm/xinference_model.py @@ -59,15 +59,6 @@ def get_num_tokens(self, messages: List[PromptMessage]) -> int: prompts = self._get_prompt_from_messages(messages) return max(self._client.get_num_tokens(prompts), 0) - def prompt_file_name(self, mode: str) -> str: - if 'baichuan' in self.name.lower(): - if mode == 'completion': - return 'baichuan_completion' - else: - return 'baichuan_chat' - else: - return super().prompt_file_name(mode) - def _set_model_kwargs(self, model_kwargs: ModelKwargs): pass diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py new file mode 100644 index 00000000000000..62fd814678aea0 --- /dev/null +++ b/api/core/prompt/prompt_transform.py @@ -0,0 +1,344 @@ +import json +import os +import re +import enum +from typing import List, Optional, Tuple + +from langchain.memory.chat_memory import BaseChatMemory +from langchain.schema import BaseMessage + +from core.model_providers.models.entity.model_params import ModelMode +from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.llm.baichuan_model import BaichuanModel +from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel +from core.model_providers.models.llm.openllm_model import OpenLLMModel +from core.model_providers.models.llm.xinference_model import XinferenceModel +from core.prompt.prompt_builder import PromptBuilder +from core.prompt.prompt_template import PromptTemplateParser + +class AppMode(enum.Enum): + COMPLETION = 'completion' + CHAT = 'chat' + +class PromptTransform: + def get_prompt(self, mode: str, + pre_prompt: str, inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM) -> \ + Tuple[List[PromptMessage], Optional[List[str]]]: + prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance)) + prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance) + return [PromptMessage(content=prompt)], stops + + def get_advanced_prompt(self, + app_mode: str, + app_model_config: str, + inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM) -> List[PromptMessage]: + + model_mode = app_model_config.model_dict['mode'] + + app_mode_enum = AppMode(app_mode) + model_mode_enum = ModelMode(model_mode) + + prompt_messages = [] + + if app_mode_enum == AppMode.CHAT: + if model_mode_enum == ModelMode.COMPLETION: + prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance) + elif model_mode_enum == ModelMode.CHAT: + prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance) + elif app_mode_enum == AppMode.COMPLETION: + if model_mode_enum == ModelMode.CHAT: + prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context) + elif model_mode_enum == ModelMode.COMPLETION: + prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context) + + return prompt_messages + + def _get_history_messages_from_memory(self, memory: BaseChatMemory, + max_token_limit: int) -> str: + """Get memory messages.""" + memory.max_token_limit = max_token_limit + memory_key = memory.memory_variables[0] + external_context = memory.load_memory_variables({}) + return external_context[memory_key] + + def _get_history_messages_list_from_memory(self, memory: BaseChatMemory, + max_token_limit: int) -> List[PromptMessage]: + """Get memory messages.""" + memory.max_token_limit = max_token_limit + memory.return_messages = True + memory_key = memory.memory_variables[0] + external_context = memory.load_memory_variables({}) + memory.return_messages = False + return to_prompt_messages(external_context[memory_key]) + + def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str: + # baichuan + if isinstance(model_instance, BaichuanModel): + return self._prompt_file_name_for_baichuan(mode) + + baichuan_model_hosted_platforms = (HuggingfaceHubModel, OpenLLMModel, XinferenceModel) + if isinstance(model_instance, baichuan_model_hosted_platforms) and 'baichuan' in model_instance.name.lower(): + return self._prompt_file_name_for_baichuan(mode) + + # common + if mode == 'completion': + return 'common_completion' + else: + return 'common_chat' + + def _prompt_file_name_for_baichuan(self, mode: str) -> str: + if mode == 'completion': + return 'baichuan_completion' + else: + return 'baichuan_chat' + + def _read_prompt_rules_from_file(self, prompt_name: str) -> dict: + # Get the absolute path of the subdirectory + prompt_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'generate_prompts') + + json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') + # Open the JSON file and read its content + with open(json_file_path, 'r') as json_file: + return json.load(json_file) + + def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM) -> Tuple[str, Optional[list]]: + context_prompt_content = '' + if context and 'context_prompt' in prompt_rules: + prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) + context_prompt_content = prompt_template.format( + {'context': context} + ) + + pre_prompt_content = '' + if pre_prompt: + prompt_template = PromptTemplateParser(template=pre_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + pre_prompt_content = prompt_template.format( + prompt_inputs + ) + + prompt = '' + for order in prompt_rules['system_prompt_orders']: + if order == 'context_prompt': + prompt += context_prompt_content + elif order == 'pre_prompt': + prompt += pre_prompt_content + + query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}' + + if memory and 'histories_prompt' in prompt_rules: + # append chat histories + tmp_human_message = PromptBuilder.to_human_message( + prompt_content=prompt + query_prompt, + inputs={ + 'query': query + } + ) + + rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance) + + memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human' + memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' + + histories = self._get_history_messages_from_memory(memory, rest_tokens) + prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt']) + histories_prompt_content = prompt_template.format({'histories': histories}) + + prompt = '' + for order in prompt_rules['system_prompt_orders']: + if order == 'context_prompt': + prompt += context_prompt_content + elif order == 'pre_prompt': + prompt += (pre_prompt_content + '\n') if pre_prompt_content else '' + elif order == 'histories_prompt': + prompt += histories_prompt_content + + prompt_template = PromptTemplateParser(template=query_prompt) + query_prompt_content = prompt_template.format({'query': query}) + + prompt += query_prompt_content + + prompt = re.sub(r'<\|.*?\|>', '', prompt) + + stops = prompt_rules.get('stops') + if stops is not None and len(stops) == 0: + stops = None + + return prompt, stops + + def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + if '#context#' in prompt_template.variable_keys: + if context: + prompt_inputs['#context#'] = context + else: + prompt_inputs['#context#'] = '' + + def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None: + if '#query#' in prompt_template.variable_keys: + if query: + prompt_inputs['#query#'] = query + else: + prompt_inputs['#query#'] = '' + + def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict, + prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None: + if '#histories#' in prompt_template.variable_keys: + if memory: + tmp_human_message = PromptBuilder.to_human_message( + prompt_content=raw_prompt, + inputs={ '#histories#': '', **prompt_inputs } + ) + + rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance) + + memory.human_prefix = conversation_histories_role['user_prefix'] + memory.ai_prefix = conversation_histories_role['assistant_prefix'] + histories = self._get_history_messages_from_memory(memory, rest_tokens) + prompt_inputs['#histories#'] = histories + else: + prompt_inputs['#histories#'] = '' + + def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None: + if memory: + rest_tokens = self._calculate_rest_token(prompt_messages, model_instance) + + memory.human_prefix = MessageType.USER.value + memory.ai_prefix = MessageType.ASSISTANT.value + histories = self._get_history_messages_list_from_memory(memory, rest_tokens) + prompt_messages.extend(histories) + + def _calculate_rest_token(self, prompt_messages: BaseMessage, model_instance: BaseLLM) -> int: + rest_tokens = 2000 + + if model_instance.model_rules.max_tokens.max: + curr_message_tokens = model_instance.get_num_tokens(to_prompt_messages(prompt_messages)) + max_tokens = model_instance.model_kwargs.max_tokens + rest_tokens = model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str: + prompt = prompt_template.format( + prompt_inputs + ) + + prompt = re.sub(r'<\|.*?\|>', '', prompt) + return prompt + + def _get_chat_app_completion_model_prompt_messages(self, + app_model_config: str, + inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM) -> List[PromptMessage]: + + raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text'] + conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role'] + + prompt_messages = [] + prompt = '' + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + self._set_context_variable(context, prompt_template, prompt_inputs) + + self._set_query_variable(query, prompt_template, prompt_inputs) + + self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance) + + prompt = self._format_prompt(prompt_template, prompt_inputs) + + prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt)) + + return prompt_messages + + def _get_chat_app_chat_model_prompt_messages(self, + app_model_config: str, + inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory], + model_instance: BaseLLM) -> List[PromptMessage]: + raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] + + prompt_messages = [] + + for prompt_item in raw_prompt_list: + raw_prompt = prompt_item['text'] + prompt = '' + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + self._set_context_variable(context, prompt_template, prompt_inputs) + + prompt = self._format_prompt(prompt_template, prompt_inputs) + + prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt)) + + self._append_chat_histories(memory, prompt_messages, model_instance) + + prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query)) + + return prompt_messages + + def _get_completion_app_completion_model_prompt_messages(self, + app_model_config: str, + inputs: dict, + context: Optional[str]) -> List[PromptMessage]: + raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text'] + + prompt_messages = [] + prompt = '' + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + self._set_context_variable(context, prompt_template, prompt_inputs) + + prompt = self._format_prompt(prompt_template, prompt_inputs) + + prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt)) + + return prompt_messages + + def _get_completion_app_chat_model_prompt_messages(self, + app_model_config: str, + inputs: dict, + context: Optional[str]) -> List[PromptMessage]: + raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] + + prompt_messages = [] + + for prompt_item in raw_prompt_list: + raw_prompt = prompt_item['text'] + prompt = '' + + prompt_template = PromptTemplateParser(template=raw_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + self._set_context_variable(context, prompt_template, prompt_inputs) + + prompt = self._format_prompt(prompt_template, prompt_inputs) + + prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt)) + + return prompt_messages \ No newline at end of file diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 58e2c658fb9c2e..bdbc2b82f807b0 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,6 +1,8 @@ import copy +from core.model_providers.models.entity.model_params import ModelMode +from core.prompt.prompt_transform import AppMode from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT @@ -13,7 +15,7 @@ def get_prompt(cls, args: dict) -> dict: model_name = args['model_name'] has_context = args['has_context'] - if 'baichuan' in model_name: + if 'baichuan' in model_name.lower(): return cls.get_baichuan_prompt(app_mode, model_mode, has_context) else: return cls.get_common_prompt(app_mode, model_mode, has_context) @@ -22,15 +24,15 @@ def get_prompt(cls, args: dict) -> dict: def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: context_prompt = copy.deepcopy(CONTEXT) - if app_mode == 'chat': - if model_mode == 'completion': + if app_mode == AppMode.CHAT.value: + if model_mode == ModelMode.COMPLETION.value: return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) - elif model_mode == 'chat': + elif model_mode == ModelMode.CHAT.value: return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - elif app_mode == 'completion': - if model_mode == 'completion': + elif app_mode == AppMode.COMPLETION.value: + if model_mode == ModelMode.COMPLETION.value: return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) - elif model_mode == 'chat': + elif model_mode == ModelMode.CHAT.value: return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) @classmethod @@ -51,13 +53,13 @@ def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) - if app_mode == 'chat': - if model_mode == 'completion': + if app_mode == AppMode.CHAT.value: + if model_mode == ModelMode.COMPLETION.value: return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) - elif model_mode == 'chat': + elif model_mode == ModelMode.CHAT.value: return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) - elif app_mode == 'completion': - if model_mode == 'completion': + elif app_mode == AppMode.COMPLETION.value: + if model_mode == ModelMode.COMPLETION.value: return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) - elif model_mode == 'chat': + elif model_mode == ModelMode.CHAT.value: return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 4c49b43b887c6d..79c1ed0ad6f663 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,6 +1,7 @@ import re import uuid +from core.prompt.prompt_transform import AppMode from core.agent.agent_executor import PlanningStrategy from core.model_providers.model_provider_factory import ModelProviderFactory from core.model_providers.models.entity.model_params import ModelType, ModelMode @@ -418,7 +419,7 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: if config['model']["mode"] not in ['chat', 'completion']: raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced") - if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value: + if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value: user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] @@ -427,3 +428,10 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: if not assistant_prefix: config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' + + + if config['model']["mode"] == ModelMode.CHAT.value: + prompt_list = config['chat_prompt_config']['prompt'] + + if len(prompt_list) > 10: + raise ValueError("prompt messages must be less than 10") \ No newline at end of file From 211b379c070c9497949148895e056b108a1aa8ac Mon Sep 17 00:00:00 2001 From: StyleZhang Date: Thu, 19 Oct 2023 11:32:15 +0800 Subject: [PATCH 5/5] fix: npm run start --- web/README.md | 12 ++++-------- web/package.json | 7 ++++--- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/web/README.md b/web/README.md index 1a24a135e795b0..f6d8e6cdb96193 100644 --- a/web/README.md +++ b/web/README.md @@ -64,24 +64,20 @@ Open [http://localhost:3000](http://localhost:3000) with your browser to see the ## Deploy ### Deploy on server First, build the app for production: - ```bash npm run build ``` -Then, move the static files to standalone folder: +Then, start the server: ```bash -mv .next/static .next/standalone/.next -cp -r ./public .next/standalone/.next/ +npm run start ``` -Finally, start the app: +If you want to customize the host and port: ```bash -node .next/standalone/server.js +npm run start --port=3001 --host=0.0.0.0 ``` -If your project needs alternative port or hostname for listening, you can define PORT and HOSTNAME environment variables, before running server.js. For example, `PORT=3000 HOSTNAME=localhost node .next/standalone/server.js`. - ## Lint Code If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code setting. diff --git a/web/package.json b/web/package.json index 7a225d397f6597..88a0d1dc965410 100644 --- a/web/package.json +++ b/web/package.json @@ -5,7 +5,7 @@ "scripts": { "dev": "next dev", "build": "next build", - "start": "next dev", + "start": "cp -r .next/static .next/standalone/.next/static && cp -r public .next/standalone/public && cross-env PORT=$npm_config_port HOSTNAME=$npm_config_host node .next/standalone/server.js", "lint": "next lint", "fix": "next lint --fix", "eslint-fix": "eslint --fix", @@ -95,6 +95,8 @@ "@types/react-window-infinite-loader": "^1.0.6", "@types/recordrtc": "^5.6.11", "@types/sortablejs": "^1.15.1", + "autoprefixer": "^10.4.14", + "cross-env": "^7.0.3", "eslint": "8.36.0", "eslint-config-next": "^13.4.7", "husky": "^8.0.3", @@ -102,10 +104,9 @@ "miragejs": "^0.1.47", "postcss": "^8.4.21", "sass": "^1.61.0", - "uglify-js": "^3.17.4", "tailwindcss": "^3.3.3", "typescript": "4.9.5", - "autoprefixer": "^10.4.14" + "uglify-js": "^3.17.4" }, "lint-staged": { "**/*.js?(x)": [