Commit 46e104d

Merge branch 'feat/continue-on-error' into deploy/dev

Novice Lee authored and Novice Lee committed Dec 8, 2024
2 parents 4dfe73a + 1d118a0
Showing 16 changed files with 250 additions and 155 deletions.
14 changes: 1 addition & 13 deletions api/core/app/entities/queue_entities.py
@@ -2,7 +2,7 @@
from enum import Enum, StrEnum
from typing import Any, Optional

-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
@@ -115,18 +115,6 @@ class QueueIterationNextEvent(AppQueueEvent):
    output: Optional[Any] = None  # output for the current iteration
    duration: Optional[float] = None

-    @field_validator("output", mode="before")
-    @classmethod
-    def set_output(cls, v):
-        """
-        Set output
-        """
-        if v is None:
-            return None
-        if isinstance(v, int | float | str | bool | dict | list):
-            return v
-        raise ValueError("output must be a valid type")
-

class QueueIterationCompletedEvent(AppQueueEvent):
    """
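With continue-on-error merged, an iteration's `output` can carry values beyond the JSON-like primitives the validator allowed, so the type gate was dropped rather than extended. A minimal sketch (a simplified stand-in model keeping only the two fields from this diff) of the resulting behavior:

```python
# Simplified stand-in for QueueIterationNextEvent after the validator removal:
# `output` is typed Any, so arbitrary objects now pass through unchanged.
from typing import Any, Optional

from pydantic import BaseModel


class IterationNextEvent(BaseModel):
    output: Optional[Any] = None  # output for the current iteration
    duration: Optional[float] = None


class NodeError:  # hypothetical non-primitive payload
    pass


event = IterationNextEvent(output=NodeError())
print(type(event.output).__name__)  # NodeError -- previously raised ValueError
```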
@@ -0,0 +1,38 @@
model: gemini-exp-1206
label:
  en_US: Gemini exp 1206
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 2097152
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
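For context, each entry in `parameter_rules` constrains a user-supplied parameter (defaults, bounds, templates) before the request is built. A rough sketch of that clamping idea — not Dify's actual rule engine, and the `temperature` bounds here are assumptions:

```python
# Illustrative only: mimic how rules like max_output_tokens (min 1, max 8192,
# default 8192) could clamp incoming parameters. Bounds for temperature are assumed.
rules = {
    "temperature": {"min": 0.0, "max": 2.0, "default": None},
    "max_output_tokens": {"min": 1, "max": 8192, "default": 8192},
}


def apply_rules(params: dict) -> dict:
    out = {}
    for name, rule in rules.items():
        value = params.get(name, rule["default"])
        if value is None:
            continue
        out[name] = min(max(value, rule["min"]), rule["max"])
    return out


print(apply_rules({"temperature": 3.5}))
# {'temperature': 2.0, 'max_output_tokens': 8192}
```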
16 changes: 9 additions & 7 deletions api/core/model_runtime/model_providers/ollama/llm/llm.py
@@ -181,9 +181,11 @@ def _generate(
        # prepare the payload for a simple ping to the model
        data = {"model": model, "stream": stream}

-        if "format" in model_parameters:
-            data["format"] = model_parameters["format"]
-            del model_parameters["format"]
+        if format_schema := model_parameters.pop("format", None):
+            try:
+                data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
+            except json.JSONDecodeError as e:
+                raise InvokeBadRequestError(f"Invalid format schema: {str(e)}")

        if "keep_alive" in model_parameters:
            data["keep_alive"] = model_parameters["keep_alive"]
@@ -733,12 +735,12 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
            ParameterRule(
                name="format",
                label=I18nObject(en_US="Format", zh_Hans="返回格式"),
-                type=ParameterType.STRING,
+                type=ParameterType.TEXT,
                default="json",
                help=I18nObject(
-                    en_US="the format to return a response in. Currently the only accepted value is json.",
-                    zh_Hans="返回响应的格式。目前唯一接受的值是json。",
+                    en_US="the format to return a response in. Format can be `json` or a JSON schema.",
+                    zh_Hans="返回响应的格式。目前接受的值是字符串`json`或JSON schema.",
                ),
-                options=["json"],
            ),
        ],
        pricing=PriceConfig(
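The reworked `format` handling accepts either the literal string `json` or a serialized JSON schema, parsing the latter before it goes into the request payload. A self-contained sketch of the same pop-and-parse logic (`ValueError` stands in for `InvokeBadRequestError` so the snippet runs outside Dify):

```python
import json


def build_format_payload(model_parameters: dict) -> dict:
    """Mirror of the diff's logic: pass "json" through, parse anything else as a schema."""
    data = {}
    if format_schema := model_parameters.pop("format", None):
        try:
            data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid format schema: {e}") from e
    return data


print(build_format_payload({"format": "json"}))
# {'format': 'json'}
print(build_format_payload({"format": '{"type": "object", "properties": {}}'}))
# {'format': {'type': 'object', 'properties': {}}}
```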
10 changes: 6 additions & 4 deletions api/core/model_runtime/model_providers/vertex_ai/llm/llm.py
@@ -104,13 +104,14 @@ def _generate_anthropic(
"""
# use Anthropic official SDK references
# - https://github.com/anthropics/anthropic-sdk-python
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
token = ""

# get access token from service account credential
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
@@ -478,10 +479,11 @@ def _generate(
        if stop:
            config_kwargs["stop_sequences"] = stop

-        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_account_key = credentials.get("vertex_service_account_key", "")
        project_id = credentials["vertex_project_id"]
        location = credentials["vertex_location"]
-        if service_account_info:
+        if service_account_key:
+            service_account_info = json.loads(base64.b64decode(service_account_key))
            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
        else:
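The same optional-key pattern recurs below in the text-embedding model: decode the base64 service-account key when one is configured, otherwise let `aiplatform.init` pick up Application Default Credentials from the environment. A sketch distilled from the diff, using the `google-auth` and `google-cloud-aiplatform` APIs it already relies on:

```python
import base64
import json

from google.cloud import aiplatform
from google.oauth2 import service_account


def init_vertex(credentials: dict) -> None:
    # Empty key -> fall back to ambient credentials (e.g. GOOGLE_APPLICATION_CREDENTIALS).
    service_account_key = credentials.get("vertex_service_account_key", "")
    project_id = credentials["vertex_project_id"]
    location = credentials["vertex_location"]
    if service_account_key:
        info = json.loads(base64.b64decode(service_account_key))
        sa_credentials = service_account.Credentials.from_service_account_info(info)
        aiplatform.init(credentials=sa_credentials, project=project_id, location=location)
    else:
        aiplatform.init(project=project_id, location=location)
```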
@@ -48,10 +48,11 @@ def _invoke(
        :param input_type: input type
        :return: embeddings result
        """
-        service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+        service_account_key = credentials.get("vertex_service_account_key", "")
        project_id = credentials["vertex_project_id"]
        location = credentials["vertex_location"]
-        if service_account_info:
+        if service_account_key:
+            service_account_info = json.loads(base64.b64decode(service_account_key))
            service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
            aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
        else:
@@ -100,10 +101,11 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
        :return:
        """
        try:
-            service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
+            service_account_key = credentials.get("vertex_service_account_key", "")
            project_id = credentials["vertex_project_id"]
            location = credentials["vertex_location"]
-            if service_account_info:
+            if service_account_key:
+                service_account_info = json.loads(base64.b64decode(service_account_key))
                service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
                aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
            else:
@@ -0,0 +1,52 @@
model: glm-4v-flash
label:
  en_US: glm-4v-flash
model_type: llm
model_properties:
  mode: chat
  context_size: 2048
features:
  - vision
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.95
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
      en_US: Sampling temperature controls the randomness of the output; it must be a positive number in the range (0.0, 1.0], not equal to 0, with a default of 0.95. Larger values make the output more random and creative; smaller values make it more stable or deterministic. Adjust top_p or temperature according to your scenario, but do not adjust both at the same time.
  - name: top_p
    use_template: top_p
    default: 0.6
    help:
      zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
      en_US: An alternative to temperature sampling, called nucleus sampling. The value lies in the open interval (0.0, 1.0), not equal to 0 or 1, with a default of 0.7. The model considers the tokens holding the top_p probability mass; for example, 0.1 means the decoder samples only from the candidates in the top 10% of probability. Adjust top_p or temperature according to your scenario, but do not adjust both at the same time.
  - name: do_sample
    label:
      zh_Hans: 采样策略
      en_US: Sampling strategy
    type: boolean
    help:
      zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
      en_US: When `do_sample` is true, the sampling strategy is enabled; when false, the sampling parameters `temperature` and `top_p` have no effect. The default value is true.
    default: true
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 1024
  - name: web_search
    type: boolean
    label:
      zh_Hans: 联网搜索
      en_US: Web Search
    default: false
    help:
      zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
      en_US: The model has a built-in web search service. This parameter controls whether the model consults web search results when generating text. With search enabled, the model treats the results as reference information, but it decides on its own internal logic whether to actually use them.
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: RMB
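Because glm-4v-flash shares the `glm-4v` prefix, the `startswith` change in llm.py below routes it through the vision message builder automatically. A hedged sketch of the chat payload shape such a request produces, following ZhipuAI's OpenAI-style multimodal messages; the image URL is a placeholder:

```python
# Hypothetical request body for glm-4v-flash with one text part and one image part.
payload = {
    "model": "glm-4v-flash",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    "max_tokens": 1024,  # the config above caps this model at 1024
}
```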
6 changes: 4 additions & 2 deletions api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -144,7 +144,7 @@ def _generate(
            if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
                if isinstance(copy_prompt_message.content, list):
                    # check if model is 'glm-4v'
-                    if model not in {"glm-4v", "glm-4v-plus"}:
+                    if not model.startswith("glm-4v"):
                        # not support list message
                        continue
                    # get image and
@@ -188,7 +188,7 @@
            else:
                model_parameters["tools"] = [web_search_params]

-        if model in {"glm-4v", "glm-4v-plus"}:
+        if model.startswith("glm-4v"):
            params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
        else:
            params = {"model": model, "messages": [], **model_parameters}
@@ -412,6 +412,8 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
human_prompt = "\n\nHuman:"
ai_prompt = "\n\nAssistant:"
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)

if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
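Swapping the hard-coded set for a prefix check means new vision variants such as glm-4v-flash take the `_construct_glm_4v_parameter` path without further edits, while non-vision models stay excluded; the `_convert_one_message_to_text` addition likewise flattens multimodal content lists to their text parts before role formatting. A quick check of what the prefix test matches:

```python
# glm-4v* models match; glm-4-plus (no "v") correctly does not.
for model in ("glm-4v", "glm-4v-plus", "glm-4v-flash", "glm-4-plus"):
    print(f"{model}: {model.startswith('glm-4v')}")
# glm-4v: True
# glm-4v-plus: True
# glm-4v-flash: True
# glm-4-plus: False
```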
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
@@ -162,7 +162,7 @@ def batch_update_tidb_serverless_cluster_status(
        clusters = []
        tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
        cluster_ids = [item.cluster_id for item in tidb_serverless_list]
-        params = {"clusterIds": cluster_ids, "view": "FULL"}
+        params = {"clusterIds": cluster_ids, "view": "BASIC"}
        response = requests.get(
            f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
        )
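`view=BASIC` requests the trimmed cluster representation from the TiDB Cloud serverless API, which is all this status sync reads. A sketch of the request shape, with a placeholder endpoint and keys:

```python
import requests
from requests.auth import HTTPDigestAuth

api_url = "https://serverless.tidbapi.com/v1beta1"  # placeholder endpoint
params = {"clusterIds": ["cluster-a", "cluster-b"], "view": "BASIC"}
response = requests.get(
    f"{api_url}/clusters:batchGet",
    params=params,
    auth=HTTPDigestAuth("public-key", "private-key"),  # placeholder credentials
)
print(response.status_code)
```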
26 changes: 13 additions & 13 deletions api/core/workflow/graph_engine/graph_engine.py
@@ -339,6 +339,7 @@ def _run(
                    next_node_id = edge.target_node_id
                else:
                    final_node_id = None
+
                if any(edge.run_condition for edge in edge_mappings):
                    # if nodes has run conditions, get node id which branch to take based on the run condition results
                    condition_edge_mappings = {}
@@ -701,19 +702,18 @@ def _run_node(
                            run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
                                parent_parallel_start_node_id
                            )
-                    event_args = {
-                        "id": node_instance.id,
-                        "node_id": node_instance.node_id,
-                        "node_type": node_instance.node_type,
-                        "node_data": node_instance.node_data,
-                        "route_node_state": route_node_state,
-                        "parallel_id": parallel_id,
-                        "parallel_start_node_id": parallel_start_node_id,
-                        "parent_parallel_id": parent_parallel_id,
-                        "parent_parallel_start_node_id": parent_parallel_start_node_id,
-                    }
-                    event = NodeRunSucceededEvent(**event_args)
-                    yield event
+
+                    yield NodeRunSucceededEvent(
+                        id=node_instance.id,
+                        node_id=node_instance.node_id,
+                        node_type=node_instance.node_type,
+                        node_data=node_instance.node_data,
+                        route_node_state=route_node_state,
+                        parallel_id=parallel_id,
+                        parallel_start_node_id=parallel_start_node_id,
+                        parent_parallel_id=parent_parallel_id,
+                        parent_parallel_start_node_id=parent_parallel_start_node_id,
+                    )

                    break
                elif isinstance(item, RunStreamChunkEvent):
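Replacing the `**event_args` dict expansion with direct keyword arguments lets static type checkers verify every field against `NodeRunSucceededEvent`'s signature; a typo in a dict key only surfaces at runtime. A small illustration with a simplified stand-in event class:

```python
from dataclasses import dataclass


@dataclass
class Event:  # simplified stand-in for NodeRunSucceededEvent
    node_id: str
    parallel_id: str | None = None


event_args = {"node_id": "n1", "paralel_id": "p1"}  # typo hides until runtime
try:
    Event(**event_args)
except TypeError as e:
    print(e)  # unexpected keyword argument 'paralel_id'

Event(node_id="n1", parallel_id="p1")  # a type checker validates this form statically
```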
24 changes: 16 additions & 8 deletions api/core/workflow/nodes/document_extractor/node.py
@@ -1,6 +1,8 @@
import csv
import io
import json
+import os
+import tempfile

import docx
import pandas as pd
@@ -264,14 +266,20 @@ def _extract_text_from_ppt(file_content: bytes) -> str:

def _extract_text_from_pptx(file_content: bytes) -> str:
    try:
-        with io.BytesIO(file_content) as file:
-            if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
-                elements = partition_via_api(
-                    file=file,
-                    api_url=dify_config.UNSTRUCTURED_API_URL,
-                    api_key=dify_config.UNSTRUCTURED_API_KEY,
-                )
-            else:
+        if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
+            with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
+                temp_file.write(file_content)
+                temp_file.flush()
+                with open(temp_file.name, "rb") as file:
+                    elements = partition_via_api(
+                        file=file,
+                        metadata_filename=temp_file.name,
+                        api_url=dify_config.UNSTRUCTURED_API_URL,
+                        api_key=dify_config.UNSTRUCTURED_API_KEY,
+                    )
+            os.unlink(temp_file.name)
+        else:
+            with io.BytesIO(file_content) as file:
                elements = partition_pptx(file=file)
        return "\n".join([getattr(element, "text", "") for element in elements])
    except Exception as e:
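The temporary file exists so `partition_via_api` can send a real filename, letting the Unstructured API infer the `.pptx` type via `metadata_filename`. One caveat with the committed version: if the API call raises, `os.unlink` never runs and the file leaks. A hedged variant using `try`/`finally` (a suggestion, not the committed code):

```python
import os
import tempfile

from unstructured.partition.api import partition_via_api


def extract_pptx_via_api(file_content: bytes, api_url: str, api_key: str):
    # delete=False so the file survives the `with` block and can be reopened by name.
    with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
        temp_file.write(file_content)
        temp_file.flush()
    try:
        with open(temp_file.name, "rb") as file:
            return partition_via_api(
                file=file,
                metadata_filename=temp_file.name,
                api_url=api_url,
                api_key=api_key,
            )
    finally:
        os.unlink(temp_file.name)  # runs even if partition_via_api raises
```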
[The remaining changed files are not shown in this capture.]