fix: missing system prompt, missing arguments
Yeuoly committed Dec 28, 2023
1 parent 7118f31 commit 4f3ff71
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -108,6 +108,13 @@ def _generate(self, model: str, credentials_kwargs: dict,
             api_key=credentials_kwargs['api_key']
         )

+        if len(prompt_messages) == 0:
+            raise ValueError('At least one message is required')
+
+        if prompt_messages[0].role.value == 'system':
+            if not prompt_messages[0].content:
+                prompt_messages = prompt_messages[1:]
+
         params = {
             'model': model,
             'prompt': [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages],
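
Note: as a reading aid for the hunk above, here is a minimal standalone sketch of the new guard — reject an empty message list and drop a leading system message whose content is empty. The SimpleMessage dataclass below is an illustrative stand-in, not the runtime's PromptMessage class.

from dataclasses import dataclass

@dataclass
class SimpleMessage:
    role: str      # stand-in for PromptMessageRole.value ('system', 'user', ...)
    content: str

def drop_empty_system_prompt(messages: list[SimpleMessage]) -> list[SimpleMessage]:
    # At least one message must be present, mirroring the ValueError above.
    if len(messages) == 0:
        raise ValueError('At least one message is required')
    # A leading system message with empty content is removed so the
    # ZhipuAI API is not sent an empty system prompt.
    if messages[0].role == 'system' and not messages[0].content:
        return messages[1:]
    return messages

# Example: the empty system prompt is dropped, leaving only the user turn.
msgs = [SimpleMessage('system', ''), SimpleMessage('user', 'Hello')]
assert [m.role for m in drop_empty_system_prompt(msgs)] == ['user']
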
@@ -116,12 +123,14 @@ def _generate(self, model: str, credentials_kwargs: dict,

         if stream:
             response = client.sse_invoke(incremental=True, **params).events()
-            return self._handle_generate_stream_response(model, response, prompt_messages)
+            return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)

         response = client.invoke(**params)
-        return self._handle_generate_response(model, response, prompt_messages)
+        return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)

-    def _handle_generate_response(self, model: str, response: Dict[str, Any],
+    def _handle_generate_response(self, model: str,
+                                  credentials: dict,
+                                  response: Dict[str, Any],
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm response
@@ -144,18 +153,21 @@ def _handle_generate_response(self, model: str, response: Dict[str, Any],
             token_usage['completion_tokens'] = token_usage['total_tokens']

         # transform usage
-        usage = self._calc_response_usage(model, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+        usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])

         # transform response
         result = LLMResult(
             model=model,
+            prompt_messages=prompt_messages,
             message=AssistantPromptMessage(content=text),
             usage=usage,
         )

         return result

-    def _handle_generate_stream_response(self, model: str, responses: list[Generator],
+    def _handle_generate_stream_response(self, model: str,
+                                         credentials: dict,
+                                         responses: list[Generator],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -168,6 +180,7 @@ def _handle_generate_stream_response(self, model: str, responses: list[Generator
         for index, event in enumerate(responses):
             if event.event == "add":
                 yield LLMResultChunk(
+                    prompt_messages=prompt_messages,
                     model=model,
                     delta=LLMResultChunkDelta(
                         index=index,
@@ -187,10 +200,11 @@ def _handle_generate_stream_response(self, model: str, responses: list[Generator
                 if 'completion_tokens' not in token_usage:
                     token_usage['completion_tokens'] = token_usage['total_tokens']

-                usage = self._calc_response_usage(model, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+                usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])

                 yield LLMResultChunk(
                     model=model,
+                    prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=AssistantPromptMessage(content=event.data),
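
Note: the remaining hunks thread the provider credentials through to _calc_response_usage and attach prompt_messages to every LLMResultChunk. Below is a rough sketch of the resulting streaming shape under stand-in types — Chunk, calc_usage, and handle_stream are illustrative only, not the real model-runtime entities.

from dataclasses import dataclass, field
from typing import Any, Generator

@dataclass
class Chunk:
    model: str
    prompt_messages: list          # now carried on every streamed chunk
    delta: dict = field(default_factory=dict)

def calc_usage(model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> dict:
    # Stand-in for _calc_response_usage: pricing and quota lookups can depend
    # on the credentials, which is why they are now passed through.
    return {'prompt_tokens': prompt_tokens, 'completion_tokens': completion_tokens}

def handle_stream(model: str, credentials: dict, events: list[dict[str, Any]],
                  prompt_messages: list) -> Generator[Chunk, None, None]:
    for index, event in enumerate(events):
        if event['event'] == 'add':
            yield Chunk(model=model, prompt_messages=prompt_messages,
                        delta={'index': index, 'content': event['data']})
        elif event['event'] == 'finish':
            usage = calc_usage(model, credentials,
                               event['prompt_tokens'], event['completion_tokens'])
            yield Chunk(model=model, prompt_messages=prompt_messages,
                        delta={'index': index, 'content': event['data'], 'usage': usage})

In this shape, each chunk knows which prompt produced it, and the final chunk's usage is computed with the same credentials that were used for the request.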
