From f414d241c1a30777cda97072851f1343da5f1f56 Mon Sep 17 00:00:00 2001
From: Novice <857526207@qq.com>
Date: Mon, 11 Nov 2024 14:47:52 +0800
Subject: [PATCH 01/53] Feat/iteration single run time (#10512)

---
 api/core/app/apps/workflow_app_runner.py      |  1 +
 api/core/app/entities/queue_entities.py       |  3 ++
 api/core/app/entities/task_entities.py        |  1 +
 .../task_pipeline/workflow_cycle_manage.py    |  1 +
 api/core/workflow/entities/node_entities.py   |  1 +
 .../workflow/graph_engine/entities/event.py   |  2 +
 .../nodes/iteration/iteration_node.py         | 22 +++++++-
 .../workflow/hooks/use-workflow-run.ts        |  3 ++
 .../components/workflow/nodes/_base/node.tsx  |  2 +-
 .../workflow/panel/workflow-preview.tsx       |  8 ++-
 web/app/components/workflow/run/index.tsx     |  9 ++--
 .../workflow/run/iteration-result-panel.tsx   | 54 ++++++++++++++-----
 web/app/components/workflow/run/node.tsx      |  6 +--
 .../components/workflow/run/tracing-panel.tsx |  4 +-
 web/i18n/en-US/workflow.ts                    |  6 +--
 web/types/workflow.ts                         |  7 +++
 16 files changed, 101 insertions(+), 29 deletions(-)

diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 9a01e8a253f97b..2872390d4662db 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -361,6 +361,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
                     node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                     output=event.pre_iteration_output,
                     parallel_mode_run_id=event.parallel_mode_run_id,
+                    duration=event.duration,
                 )
             )
         elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index f1542ec5d8c578..69bc0d7f9ec102 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -111,6 +111,7 @@ class QueueIterationNextEvent(AppQueueEvent):
     """iteration run in parallel mode run id"""
     node_run_index: int
     output: Optional[Any] = None  # output for the current iteration
+    duration: Optional[float] = None

     @field_validator("output", mode="before")
     @classmethod
@@ -307,6 +308,8 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

     error: Optional[str] = None
+    """single iteration duration map"""
+    iteration_duration_map: Optional[dict[str, float]] = None


 class QueueNodeInIterationFailedEvent(AppQueueEvent):
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index 7e9aad54be57e4..03cc6941a84623 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -434,6 +434,7 @@ class Data(BaseModel):
         parallel_id: Optional[str] = None
         parallel_start_node_id: Optional[str] = None
         parallel_mode_run_id: Optional[str] = None
+        duration: Optional[float] = None

     event: StreamEvent = StreamEvent.ITERATION_NEXT
     workflow_run_id: str
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index b89edf9079f043..042339969fb8ad 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -624,6 +624,7 @@ def _workflow_iteration_next_to_stream_response(
                 parallel_id=event.parallel_id,
                 parallel_start_node_id=event.parallel_start_node_id,
                 parallel_mode_run_id=event.parallel_mode_run_id,
+                duration=event.duration,
             ),
         )

diff --git a/api/core/workflow/entities/node_entities.py
b/api/core/workflow/entities/node_entities.py index 7e10cddc712baa..a7472666614fab 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -24,6 +24,7 @@ class NodeRunMetadataKey(str, Enum): PARENT_PARALLEL_ID = "parent_parallel_id" PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" + ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs class NodeRunResult(BaseModel): diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index bacea191dd866c..3736e632c3f1eb 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -148,6 +148,7 @@ class IterationRunStartedEvent(BaseIterationEvent): class IterationRunNextEvent(BaseIterationEvent): index: int = Field(..., description="index") pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output") + duration: Optional[float] = Field(None, description="duration") class IterationRunSucceededEvent(BaseIterationEvent): @@ -156,6 +157,7 @@ class IterationRunSucceededEvent(BaseIterationEvent): outputs: Optional[dict[str, Any]] = None metadata: Optional[dict[str, Any]] = None steps: int = 0 + iteration_duration_map: Optional[dict[str, float]] = None class IterationRunFailedEvent(BaseIterationEvent): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index e5863d771b0431..941ebde7a9c572 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -156,6 +156,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: index=0, pre_iteration_output=None, ) + iter_run_map: dict[str, float] = {} outputs: list[Any] = [None] * len(iterator_list_value) try: if self.node_data.is_parallel: @@ -175,6 +176,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: iteration_graph, index, item, + iter_run_map, ) future.add_done_callback(thread_pool.task_done_callback) futures.append(future) @@ -213,6 +215,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: start_at, graph_engine, iteration_graph, + iter_run_map, ) if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs = [output for output in outputs if output is not None] @@ -230,7 +233,9 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: yield RunCompletedEvent( run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)} + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"output": jsonable_encoder(outputs)}, + metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map}, ) ) except IterationNodeError as e: @@ -356,15 +361,19 @@ def _run_single_iter( start_at: datetime, graph_engine: "GraphEngine", iteration_graph: Graph, + iter_run_map: dict[str, float], parallel_mode_run_id: Optional[str] = None, ) -> Generator[NodeEvent | InNodeEvent, None, None]: """ run single iteration """ + iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None) + try: rst = graph_engine.run() # get current iteration index current_index = variable_pool.get([self.node_id, "index"]).value + iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}" next_index = int(current_index) + 1 if current_index is None: 
@@ -431,6 +440,8 @@ def _run_single_iter( variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -439,6 +450,7 @@ def _run_single_iter( index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, + duration=duration, ) return elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: @@ -449,6 +461,8 @@ def _run_single_iter( if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -457,6 +471,7 @@ def _run_single_iter( index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, + duration=duration, ) return elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: @@ -485,6 +500,8 @@ def _run_single_iter( if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -493,6 +510,7 @@ def _run_single_iter( index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, + duration=duration, ) except IterationNodeError as e: @@ -528,6 +546,7 @@ def _run_single_iter_parallel( iteration_graph: Graph, index: int, item: Any, + iter_run_map: dict[str, float], ) -> Generator[NodeEvent | InNodeEvent, None, None]: """ run single iteration in parallel mode @@ -546,6 +565,7 @@ def _run_single_iter_parallel( start_at=start_at, graph_engine=graph_engine_copy, iteration_graph=iteration_graph, + iter_run_map=iter_run_map, parallel_mode_run_id=parallel_mode_run_id, ): q.put(event) diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 26654ef71e804c..eab353550510b1 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -445,6 +445,7 @@ export const useWorkflowRun = () => { ...data, status: NodeRunningStatus.Running, details: [], + iterDurationMap: {}, } as any) })) @@ -496,6 +497,8 @@ export const useWorkflowRun = () => { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { const iteration = draft.tracing!.find(trace => trace.node_id === data.node_id) if (iteration) { + if (iteration.iterDurationMap && data.duration) + iteration.iterDurationMap[data.parallel_mode_run_id ?? `${data.index - 1}`] = data.duration if (iteration.details!.length >= iteration.metadata.iterator_length!) 
              return
          }
diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx
index e864c419e22042..c5b78c5c2140f9 100644
--- a/web/app/components/workflow/nodes/_base/node.tsx
+++ b/web/app/components/workflow/nodes/_base/node.tsx
@@ -193,7 +193,7 @@ const BaseNode: FC = ({
         {
           data._iterationLength && data._iterationIndex && data._runningStatus === NodeRunningStatus.Running && (
             <div>
-              {data._iterationIndex}/{data._iterationLength}
+              {data._iterationIndex > data._iterationLength ? data._iterationLength : data._iterationIndex}/{data._iterationLength}
             </div>
) } diff --git a/web/app/components/workflow/panel/workflow-preview.tsx b/web/app/components/workflow/panel/workflow-preview.tsx index 361f9d6bf44fa1..d560c0b2cb6c10 100644 --- a/web/app/components/workflow/panel/workflow-preview.tsx +++ b/web/app/components/workflow/panel/workflow-preview.tsx @@ -28,7 +28,7 @@ import IterationResultPanel from '../run/iteration-result-panel' import InputsPanel from './inputs-panel' import cn from '@/utils/classnames' import Loading from '@/app/components/base/loading' -import type { NodeTracing } from '@/types/workflow' +import type { IterationDurationMap, NodeTracing } from '@/types/workflow' const WorkflowPreview = () => { const { t } = useTranslation() @@ -53,12 +53,14 @@ const WorkflowPreview = () => { }, [workflowRunningData]) const [iterationRunResult, setIterationRunResult] = useState([]) + const [iterDurationMap, setIterDurationMap] = useState({}) const [isShowIterationDetail, { setTrue: doShowIterationDetail, setFalse: doHideIterationDetail, }] = useBoolean(false) - const handleShowIterationDetail = useCallback((detail: NodeTracing[][]) => { + const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterationDurationMap: IterationDurationMap) => { + setIterDurationMap(iterationDurationMap) setIterationRunResult(detail) doShowIterationDetail() }, [doShowIterationDetail]) @@ -72,6 +74,7 @@ const WorkflowPreview = () => { list={iterationRunResult} onHide={doHideIterationDetail} onBack={doHideIterationDetail} + iterDurationMap={iterDurationMap} /> ) @@ -94,6 +97,7 @@ const WorkflowPreview = () => { list={iterationRunResult} onHide={doHideIterationDetail} onBack={doHideIterationDetail} + iterDurationMap={iterDurationMap} /> ) : ( diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 89db43fa35e857..5267cf257db640 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -13,7 +13,7 @@ import cn from '@/utils/classnames' import { ToastContext } from '@/app/components/base/toast' import Loading from '@/app/components/base/loading' import { fetchRunDetail, fetchTracingList } from '@/service/log' -import type { NodeTracing } from '@/types/workflow' +import type { IterationDurationMap, NodeTracing } from '@/types/workflow' import type { WorkflowRunDetailResponse } from '@/models/log' import { useStore as useAppStore } from '@/app/components/app/store' @@ -172,15 +172,17 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe }, [loading]) const [iterationRunResult, setIterationRunResult] = useState([]) + const [iterDurationMap, setIterDurationMap] = useState({}) const [isShowIterationDetail, { setTrue: doShowIterationDetail, setFalse: doHideIterationDetail, }] = useBoolean(false) - const handleShowIterationDetail = useCallback((detail: NodeTracing[][]) => { + const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => { setIterationRunResult(detail) doShowIterationDetail() - }, [doShowIterationDetail]) + setIterDurationMap(iterDurationMap) + }, [doShowIterationDetail, setIterationRunResult, setIterDurationMap]) if (isShowIterationDetail) { return ( @@ -189,6 +191,7 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe list={iterationRunResult} onHide={doHideIterationDetail} onBack={doHideIterationDetail} + iterDurationMap={iterDurationMap} /> ) diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx 
b/web/app/components/workflow/run/iteration-result-panel.tsx
index c4cd909f2ed9ae..b13eadec99bcb2 100644
--- a/web/app/components/workflow/run/iteration-result-panel.tsx
+++ b/web/app/components/workflow/run/iteration-result-panel.tsx
@@ -6,12 +6,14 @@ import {
   RiArrowRightSLine,
   RiCloseLine,
   RiErrorWarningLine,
+  RiLoader2Line,
 } from '@remixicon/react'
 import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows'
+import { NodeRunningStatus } from '../types'
 import TracingPanel from './tracing-panel'
 import { Iteration } from '@/app/components/base/icons/src/vender/workflow'
 import cn from '@/utils/classnames'
-import type { NodeTracing } from '@/types/workflow'
+import type { IterationDurationMap, NodeTracing } from '@/types/workflow'
 const i18nPrefix = 'workflow.singleRun'

 type Props = {
@@ -19,6 +21,7 @@ type Props = {
   onHide: () => void
   onBack: () => void
   noWrap?: boolean
+  iterDurationMap?: IterationDurationMap
 }

 const IterationResultPanel: FC<Props> = ({
@@ -26,6 +29,7 @@ const IterationResultPanel: FC<Props> = ({
   onHide,
   onBack,
   noWrap,
+  iterDurationMap,
 }) => {
   const { t } = useTranslation()
   const [expandedIterations, setExpandedIterations] = useState<Record<number, boolean>>({})
@@ -36,6 +40,40 @@ const IterationResultPanel: FC<Props> = ({
       [index]: !prev[index],
     }))
   }, [])
+  const countIterDuration = (iteration: NodeTracing[], iterDurationMap: IterationDurationMap): string => {
+    const iterRunIndex = iteration[0].execution_metadata.iteration_index as number
+    const iterRunId = iteration[0].execution_metadata.parallel_mode_run_id
+    const duration = iterDurationMap[iterRunId || iterRunIndex]
+    return `${(duration && duration > 0.01) ? duration.toFixed(2) : 0.01}s`
+  }
+  const iterationStatusShow = (index: number, iteration: NodeTracing[], iterDurationMap?: IterationDurationMap) => {
+    const hasFailed = iteration.some(item => item.status === NodeRunningStatus.Failed)
+    const isRunning = iteration.some(item => item.status === NodeRunningStatus.Running)
+    const hasDurationMap = iterDurationMap && Object.keys(iterDurationMap).length !== 0
+
+    if (hasFailed)
+      return <RiErrorWarningLine />
+
+    if (isRunning)
+      return <RiLoader2Line />
+
+    return (
+      <>
+        {hasDurationMap && (
+ {countIterDuration(iteration, iterDurationMap)} +
+ )} + + + ) + } const main = ( <> @@ -72,19 +110,7 @@ const IterationResultPanel: FC = ({ {t(`${i18nPrefix}.iteration`)} {index + 1} - { - iteration.some(item => item.status === 'failed') - ? ( - - ) - : (< RiArrowRightSLine className={ - cn( - 'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0', - expandedIterations[index] && 'transform rotate-90', - )} /> - ) - } - + {iterationStatusShow(index, iteration, iterDurationMap)} {expandedIterations[index] &&
void + onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void notShowIterationNav?: boolean justShowIterationNavArrow?: boolean } @@ -90,7 +90,7 @@ const NodePanel: FC = ({ const handleOnShowIterationDetail = (e: React.MouseEvent) => { e.stopPropagation() e.nativeEvent.stopImmediatePropagation() - onShowIterationDetail?.(nodeInfo.details || []) + onShowIterationDetail?.(nodeInfo.details || [], nodeInfo?.iterDurationMap || nodeInfo.execution_metadata?.iteration_duration_map || {}) } return (
diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx
index 613c10198de884..57b3a5cf5f145d 100644
--- a/web/app/components/workflow/run/tracing-panel.tsx
+++ b/web/app/components/workflow/run/tracing-panel.tsx
@@ -16,11 +16,11 @@ import NodePanel from './node'
 import {
   BlockEnum,
 } from '@/app/components/workflow/types'
-import type { NodeTracing } from '@/types/workflow'
+import type { IterationDurationMap, NodeTracing } from '@/types/workflow'

 type TracingPanelProps = {
   list: NodeTracing[]
-  onShowIterationDetail?: (detail: NodeTracing[][]) => void
+  onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void
   className?: string
   hideNodeInfo?: boolean
   hideNodeProcessDetail?: boolean
diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts
index 7bfad01f2320a1..b3fece702a3b2b 100644
--- a/web/i18n/en-US/workflow.ts
+++ b/web/i18n/en-US/workflow.ts
@@ -569,9 +569,9 @@ const translation = {
       MaxParallelismDesc: 'The maximum parallelism is used to control the number of tasks executed simultaneously in a single iteration.',
       errorResponseMethod: 'Error response method',
       ErrorMethod: {
-        operationTerminated: 'terminated',
-        continueOnError: 'continue on error',
-        removeAbnormalOutput: 'remove abnormal output',
+        operationTerminated: 'Terminated',
+        continueOnError: 'Continue on Error',
+        removeAbnormalOutput: 'Remove Abnormal Output',
       },
       answerNodeWarningDesc: 'Parallel mode warning: Answer nodes, conversation variable assignments, and persistent read/write operations within iterations may cause exceptions.',
     },
diff --git a/web/types/workflow.ts b/web/types/workflow.ts
index 3c0675b60572a3..34b08e878e855d 100644
--- a/web/types/workflow.ts
+++ b/web/types/workflow.ts
@@ -33,6 +33,7 @@ export type NodeTracing = {
     parent_parallel_id?: string
     parent_parallel_start_node_id?: string
     parallel_mode_run_id?: string
+    iteration_duration_map?: IterationDurationMap
   }
   metadata: {
     iterator_length: number
@@ -44,6 +45,7 @@
     name: string
     email: string
   }
+  iterDurationMap?: IterationDurationMap
   finished_at: number
   extras?: any
   expand?: boolean // for UI
@@ -207,7 +209,10 @@ export type IterationNextResponse = {
     parallel_mode_run_id: string
     execution_metadata: {
       parallel_id?: string
+      iteration_index: number
+      parallel_mode_run_id?: string
     }
+    duration?: number
   }
 }

@@ -323,3 +328,5 @@ export type ConversationVariableResponse = {
   total: number
   page: number
 }
+
+export type IterationDurationMap = Record<string, number>
From 508f84893fb0e30d1cddf9c2f12e07e077849cf8 Mon Sep 17 00:00:00 2001
From: zxhlyh
Date: Mon, 11 Nov 2024 14:57:28 +0800
Subject: [PATCH 02/53] fix: workflow start node form optional value (#10529)

---
 web/app/components/base/chat/chat-with-history/hooks.tsx | 2 +-
 web/app/components/base/chat/embedded-chatbot/hooks.tsx  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx
index d4fa170e4c60d0..a67cc3cd885387 100644
--- a/web/app/components/base/chat/chat-with-history/hooks.tsx
+++ b/web/app/components/base/chat/chat-with-history/hooks.tsx
@@ -173,7 +173,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
     const conversationInputs: Record<string, any> = {}

     inputsForms.forEach((item: any) => {
-      conversationInputs[item.variable] = item.default || ''
+      conversationInputs[item.variable] = item.default || null
     })
handleNewConversationInputsChange(conversationInputs) }, [handleNewConversationInputsChange, inputsForms]) diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 631d3b56bc0232..0a8bc0993f997c 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -159,7 +159,7 @@ export const useEmbeddedChatbot = () => { const conversationInputs: Record = {} inputsForms.forEach((item: any) => { - conversationInputs[item.variable] = item.default || '' + conversationInputs[item.variable] = item.default || null }) handleNewConversationInputsChange(conversationInputs) }, [handleNewConversationInputsChange, inputsForms]) From 9018ef30feabd0d95ab63db1c80631a9aa29b0ae Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Mon, 11 Nov 2024 15:02:33 +0800 Subject: [PATCH 03/53] chore: (dockerfile) upgrade perl version (#10534) --- api/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/Dockerfile b/api/Dockerfile index 51e2a10506474e..ed981e46d6f56c 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -55,7 +55,7 @@ RUN apt-get update \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \ + && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-7 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \ # install a chinese font to support the use of tools like matplotlib && apt-get install -y fonts-noto-cjk \ && apt-get autoremove -y \ From 867bf70f1a62463ff5fec38edd4c0f190503ea5b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 11 Nov 2024 16:06:53 +0800 Subject: [PATCH 04/53] fix(model_runtime): ensure compatibility with O1 models by adjusting token parameters (#10537) --- api/core/model_runtime/model_providers/openai/llm/llm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 922e5e131417ee..68317d71796189 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -617,6 +617,10 @@ def _chat_generate( # o1 compatibility block_as_stream = False if model.startswith("o1"): + if "max_tokens" in model_parameters: + model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] + del model_parameters["max_tokens"] + if stream: block_as_stream = True stream = False From be33875199c4c292b757ec340f0ec19bd0f67a53 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 11 Nov 2024 16:23:11 +0800 Subject: [PATCH 05/53] fix(gitee_ai): update English description for clarity and accuracy (#10540) --- api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml index 2e18f8a7fca56a..d0475665dd7ac7 100644 --- a/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml +++ b/api/core/tools/provider/builtin/gitee_ai/gitee_ai.yaml @@ -5,7 +5,7 @@ identity: en_US: Gitee AI zh_Hans: Gitee AI description: - en_US: 快速体验大模型,领先探索 AI 开源世界 + en_US: Quickly experience large models and explore the leading AI 
open source world zh_Hans: 快速体验大模型,领先探索 AI 开源世界 icon: icon.svg tags: From 90087160c6bbb509bfe2010c7e7e184e09ef646a Mon Sep 17 00:00:00 2001 From: Benjamin Date: Mon, 11 Nov 2024 16:41:47 +0800 Subject: [PATCH 06/53] chore (vanna): update form parameter from 'form' to 'llm' in vanna.yaml (#10488) --- api/core/tools/provider/builtin/vanna/tools/vanna.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml index 12ca8a862e966f..3520ba55709775 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml @@ -32,7 +32,7 @@ parameters: en_US: RAG Model for your database DDL zh_Hans: 存储数据库训练数据的RAG模型 llm_description: RAG Model for generating SQL - form: form + form: llm - name: db_type type: select required: true From a1543b7da053657041489f34540942a312f89ffa Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 11 Nov 2024 17:31:27 +0800 Subject: [PATCH 07/53] fix(extractor): temporary file (#10543) --- api/core/rag/extractor/word_extractor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index b59e7f94fd5013..8e084ab4ff1679 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -50,9 +50,9 @@ def __init__(self, file_path: str, tenant_id: str, user_id: str): self.web_path = self.file_path # TODO: use a better way to handle the file - with tempfile.NamedTemporaryFile(delete=False) as self.temp_file: - self.temp_file.write(r.content) - self.file_path = self.temp_file.name + self.temp_file = tempfile.NamedTemporaryFile() # noqa: SIM115 + self.temp_file.write(r.content) + self.file_path = self.temp_file.name elif not os.path.isfile(self.file_path): raise ValueError(f"File path {self.file_path} is not a valid file or url") From 4b45ef62ed7ac95bdc6c40724a30a1d4b8336781 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Mon, 11 Nov 2024 17:34:48 +0800 Subject: [PATCH 08/53] fix: iteration invalid output selector doesn't throw an error (#10544) --- api/core/workflow/nodes/iteration/iteration_node.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 941ebde7a9c572..d5428f02868b76 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -489,7 +489,10 @@ def _run_single_iter( ) yield metadata_event - current_iteration_output = variable_pool.get(self.node_data.output_selector).value + current_output_segment = variable_pool.get(self.node_data.output_selector) + if current_output_segment is None: + raise IterationNodeError("iteration output selector not found") + current_iteration_output = current_output_segment.value outputs[current_index] = current_iteration_output # remove all nodes outputs from variable pool for node_id in iteration_graph.node_ids: From 9550b884f71450ac380a71cdf13ab67816777e0f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 11 Nov 2024 18:32:28 +0800 Subject: [PATCH 09/53] chore: update version to 0.11.1 across all configurations and Docker images (#10539) --- api/app.py | 4 ++++ api/configs/packaging/__init__.py | 2 +- docker-legacy/docker-compose.yaml | 6 +++--- docker/docker-compose.yaml | 6 +++--- web/package.json | 2 +- 5 files 
changed, 12 insertions(+), 8 deletions(-) diff --git a/api/app.py b/api/app.py index 60cd622ef4d0a8..ead60e98d7dcc0 100644 --- a/api/app.py +++ b/api/app.py @@ -1,4 +1,5 @@ import os +import sys from configs import dify_config @@ -29,6 +30,9 @@ # DO NOT REMOVE ABOVE +if sys.version_info[:2] == (3, 10): + print("Warning: Python 3.10 will not be supported in the next version.") + warnings.simplefilter("ignore", ResourceWarning) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index b5cb1f06d951f0..65065efbc09329 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.11.0", + default="0.11.1", ) COMMIT_SHA: str = Field( diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 90110f49a23abe..9c2a1fe980a2f9 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.11.0 + image: langgenius/dify-api:0.11.1 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.11.0 + image: langgenius/dify-api:0.11.1 restart: always environment: CONSOLE_WEB_URL: '' @@ -397,7 +397,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.11.0 + image: langgenius/dify-web:0.11.1 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index fcc0c562168e7a..d9ff9654732ad0 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -280,7 +280,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.11.0 + image: langgenius/dify-api:0.11.1 restart: always environment: # Use the shared environment variables. @@ -300,7 +300,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.11.0 + image: langgenius/dify-api:0.11.1 restart: always environment: # Use the shared environment variables. @@ -319,7 +319,7 @@ services: # Frontend web application. 
   web:
-    image: langgenius/dify-web:0.11.0
+    image: langgenius/dify-web:0.11.1
     restart: always
     environment:
       CONSOLE_API_URL: ${CONSOLE_API_URL:-}
diff --git a/web/package.json b/web/package.json
index de01eb4d489df1..d863ba13d34867 100644
--- a/web/package.json
+++ b/web/package.json
@@ -1,6 +1,6 @@
 {
   "name": "dify-web",
-  "version": "0.11.0",
+  "version": "0.11.1",
   "private": true,
   "engines": {
     "node": ">=18.17.0"
   }
From 570f10d91ce3df5a38ff4284fac99bd24bb3cf0c Mon Sep 17 00:00:00 2001
From: liuhaoran <75237518+liuhaoran1212@users.noreply.github.com>
Date: Mon, 11 Nov 2024 21:43:37 +0800
Subject: [PATCH 10/53] fix issues: Image file not deleted when a doc is removed #9541 (#10465)

Signed-off-by: root
Co-authored-by: root
---
 api/core/tools/utils/web_reader_tool.py | 16 ++++++++++++++++
 api/tasks/clean_dataset_task.py         | 11 +++++++++++
 api/tasks/clean_document_task.py        | 11 +++++++++++
 3 files changed, 38 insertions(+)

diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py
index 5807d61b9409a6..3aae31e93a1304 100644
--- a/api/core/tools/utils/web_reader_tool.py
+++ b/api/core/tools/utils/web_reader_tool.py
@@ -356,3 +356,19 @@ def content_digest(element):
         digest.update(child.encode("utf-8"))
     digest = digest.hexdigest()
     return digest
+
+
+def get_image_upload_file_ids(content):
+    pattern = r"!\[image\]\((https?://.*?(file-preview|image-preview))\)"
+    matches = re.findall(pattern, content)
+    image_upload_file_ids = []
+    for match in matches:
+        if match[1] == "file-preview":
+            content_pattern = r"files/([^/]+)/file-preview"
+        else:
+            content_pattern = r"files/([^/]+)/image-preview"
+        content_match = re.search(content_pattern, match[0])
+        if content_match:
+            image_upload_file_id = content_match.group(1)
+            image_upload_file_ids.append(image_upload_file_id)
+    return image_upload_file_ids
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 36249038011747..4d45df4d2a87e8 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -5,6 +5,7 @@ from celery import shared_task

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.dataset import (
@@ -67,6 +68,16 @@ def clean_dataset_task(
             db.session.delete(document)

         for segment in segments:
+            image_upload_file_ids = get_image_upload_file_ids(segment.content)
+            for upload_file_id in image_upload_file_ids:
+                image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+                try:
+                    storage.delete(image_file.key)
+                except Exception:
+                    logging.exception(
+                        "Failed to delete image file from storage, image_upload_file_id: {}".format(upload_file_id)
+                    )
             db.session.delete(segment)

         db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index ae2855aa2ebc4d..54c89450c91419 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -6,6 +6,7 @@ from celery import shared_task

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
+from core.tools.utils.web_reader_tool import get_image_upload_file_ids
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.dataset import Dataset, DocumentSegment
@@ -40,6 +41,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id
         index_processor.clean(dataset, index_node_ids)

         for segment in segments:
+            image_upload_file_ids = get_image_upload_file_ids(segment.content)
+            for upload_file_id in image_upload_file_ids:
+                image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+                try:
+                    storage.delete(image_file.key)
+                except Exception:
+                    logging.exception(
+                        "Failed to delete image file from storage, image_upload_file_id: {}".format(upload_file_id)
+                    )
             db.session.delete(segment)

         db.session.commit()
From f19c18dc1412a9026ed8babacc71016309badf4f Mon Sep 17 00:00:00 2001
From: smyhw
Date: Mon, 11 Nov 2024 21:50:32 +0800
Subject: [PATCH 11/53] Fixes `you have not added provider None` (#10501)

---
 api/core/tools/tool_manager.py                 | 3 ++-
 api/services/tools/api_tools_manage_service.py | 7 ++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index bf2ad13620b629..d2723df7b2aec0 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -555,6 +555,7 @@ def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
         """
         get tool provider
         """
+        provider_name = provider
         provider: ApiToolProvider = (
             db.session.query(ApiToolProvider)
             .filter(
@@ -565,7 +566,7 @@
         )

         if provider is None:
-            raise ValueError(f"you have not added provider {provider}")
+            raise ValueError(f"you have not added provider {provider_name}")

         try:
             credentials = json.loads(provider.credentials_str) or {}
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 4a938918550ab8..ed0cebf460b47f 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -113,6 +113,8 @@ def create_api_tool_provider(
         if schema_type not in [member.value for member in ApiProviderSchemaType]:
             raise ValueError(f"invalid schema type {schema}")

+        provider_name = provider_name.strip()
+
         # check if the provider exists
         provider: ApiToolProvider = (
             db.session.query(ApiToolProvider)
             .filter(
@@ -203,6 +205,7 @@ def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
         """
         list api tool provider tools
         """
+        provider_name = provider
        provider: ApiToolProvider = (
             db.session.query(ApiToolProvider)
             .filter(
@@ -213,7 +216,7 @@
         )

         if provider is None:
-            raise ValueError(f"you have not added provider {provider}")
+            raise ValueError(f"you have not added provider {provider_name}")

         controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
         labels = ToolLabelManager.get_tool_labels(controller)
@@ -246,6 +249,8 @@ def update_api_tool_provider(
         if schema_type not in [member.value for member in ApiProviderSchemaType]:
             raise ValueError(f"invalid schema type {schema}")

+        provider_name = provider_name.strip()
+
         # check if the provider exists
         provider: ApiToolProvider = (
             db.session.query(ApiToolProvider)
From bd4a61addd5729af094b6ea7233b3018bf31d086 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Mon, 11 Nov 2024 23:32:40 +0800
Subject: [PATCH 12/53] fix: set default factory for extract_by in ListOperatorNodeData (#10561)

---
 api/core/workflow/nodes/list_operator/entities.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/workflow/nodes/list_operator/entities.py
b/api/core/workflow/nodes/list_operator/entities.py index 6a27de40fd4b13..75df784a922226 100644 --- a/api/core/workflow/nodes/list_operator/entities.py +++ b/api/core/workflow/nodes/list_operator/entities.py @@ -59,4 +59,4 @@ class ListOperatorNodeData(BaseNodeData): filter_by: FilterBy order_by: OrderBy limit: Limit - extract_by: ExtractConfig + extract_by: ExtractConfig = Field(default_factory=ExtractConfig) From 16db2c4e573e4a8d24b70d15688b92f7c39581ae Mon Sep 17 00:00:00 2001 From: fdb02983rhy <91766386+fdb02983rhy@users.noreply.github.com> Date: Tue, 12 Nov 2024 00:53:12 +0900 Subject: [PATCH 13/53] Fix: Set Celery LOG_File only when available, always log to console (#10563) --- api/extensions/ext_celery.py | 6 +++++- api/extensions/ext_logging.py | 12 +++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index c5de7395b8cd29..42012eee8e6b12 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -46,7 +46,6 @@ def __call__(self, *args: object, **kwargs: object) -> object: broker_connection_retry_on_startup=True, worker_log_format=dify_config.LOG_FORMAT, worker_task_log_format=dify_config.LOG_FORMAT, - worker_logfile=dify_config.LOG_FILE, worker_hijack_root_logger=False, timezone=pytz.timezone(dify_config.LOG_TZ), ) @@ -56,6 +55,11 @@ def __call__(self, *args: object, **kwargs: object) -> object: broker_use_ssl=ssl_options, # Add the SSL options to the broker configuration ) + if dify_config.LOG_FILE: + celery_app.conf.update( + worker_logfile=dify_config.LOG_FILE, + ) + celery_app.set_default() app.extensions["celery"] = celery_app diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 56b1d6bd28ba90..a15c73bd71786d 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -9,19 +9,21 @@ def init_app(app: Flask): - log_handlers = None + log_handlers = [] log_file = dify_config.LOG_FILE if log_file: log_dir = os.path.dirname(log_file) os.makedirs(log_dir, exist_ok=True) - log_handlers = [ + log_handlers.append( RotatingFileHandler( filename=log_file, maxBytes=dify_config.LOG_FILE_MAX_SIZE * 1024 * 1024, backupCount=dify_config.LOG_FILE_BACKUP_COUNT, - ), - logging.StreamHandler(sys.stdout), - ] + ) + ) + + # Always add StreamHandler to log to console + log_handlers.append(logging.StreamHandler(sys.stdout)) logging.basicConfig( level=dify_config.LOG_LEVEL, From e63c0e3cbb477c7ddf22b8b80a77459b067baffa Mon Sep 17 00:00:00 2001 From: Hiroshi Fujita Date: Tue, 12 Nov 2024 00:53:43 +0900 Subject: [PATCH 14/53] feat(settings): add chat color theme inverted toggle in settings modal (#10558) --- web/app/components/app/overview/settings/index.tsx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index a8ab456f43cb11..e7cc4148efb099 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -261,6 +261,10 @@ const SettingsModal: FC = ({ onChange={onChange('chatColorTheme')} placeholder='E.g #A020F0' /> +
+          <div>
+            <div>{t(`${prefixSettings}.chatColorThemeInverted`)}</div>
+            <Switch defaultValue={inputInfo.chatColorThemeInverted} onChange={v => setInputInfo({ ...inputInfo, chatColorThemeInverted: v })} />
+          </div>
         }
         {systemFeatures.enable_web_sso_switch_component && (
           <div>
             <p>{t(`${prefixSettings}.sso.label`)}</p>
From b7238caea51753112ec8765399f240dafa34cd69 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 12 Nov 2024 00:00:27 +0800 Subject: [PATCH 15/53] chore(vanna): update form parameter from 'form' to 'llm' in vanna.yaml (#10548) --- api/core/tools/provider/builtin/vanna/tools/vanna.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml index 3520ba55709775..309681321b1f3f 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.yaml +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.yaml @@ -136,7 +136,7 @@ parameters: human_description: en_US: DDL statements for training data zh_Hans: 用于训练RAG Model的建表语句 - form: form + form: llm - name: question type: string required: false @@ -146,7 +146,7 @@ parameters: human_description: en_US: Question-SQL Pairs zh_Hans: Question-SQL中的问题 - form: form + form: llm - name: sql type: string required: false @@ -156,7 +156,7 @@ parameters: human_description: en_US: SQL queries to your training data zh_Hans: 用于训练RAG Model的SQL语句 - form: form + form: llm - name: memos type: string required: false @@ -166,7 +166,7 @@ parameters: human_description: en_US: Sometimes you may want to add documentation about your business terminology or definitions zh_Hans: 添加更多关于数据库的业务说明 - form: form + form: llm - name: enable_training type: boolean required: false From 16b9665033b38da4975f6e8cd4793a83a675b468 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 12 Nov 2024 00:08:04 +0800 Subject: [PATCH 16/53] refactor(api): improve handling of `tools` field and cleanup variable usage (#10553) --- api/core/tools/entities/api_entities.py | 9 +++++++-- api/services/tools/api_tools_manage_service.py | 15 +++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index b1db5594414470..ddb1481276df67 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -1,6 +1,6 @@ from typing import Literal, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject @@ -32,9 +32,14 @@ class UserToolProvider(BaseModel): original_credentials: Optional[dict] = None is_team_authorization: bool = False allow_delete: bool = True - tools: list[UserTool] | None = None + tools: list[UserTool] = Field(default_factory=list) labels: list[str] | None = None + @field_validator("tools", mode="before") + @classmethod + def convert_none_to_empty_list(cls, v): + return v if v is not None else [] + def to_dict(self) -> dict: # ------------- # overwrite tool parameter types for temp fix diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ed0cebf460b47f..b6b0143facf38c 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -116,7 +116,7 @@ def create_api_tool_provider( provider_name = provider_name.strip() # check if the provider exists - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -201,16 +201,15 @@ def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): return {"schema": schema} @staticmethod - def list_api_tool_provider_tools(user_id: str, 
tenant_id: str, provider: str) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]: """ list api tool provider tools """ - provider_name = provider - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, + ApiToolProvider.name == provider_name, ) .first() ) @@ -252,7 +251,7 @@ def update_api_tool_provider( provider_name = provider_name.strip() # check if the provider exists - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -319,7 +318,7 @@ def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ delete tool provider """ - provider: ApiToolProvider = ( + provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, @@ -369,7 +368,7 @@ def test_api_tool_preview( if tool_bundle is None: raise ValueError(f"invalid tool name {tool_name}") - db_provider: ApiToolProvider = ( + db_provider = ( db.session.query(ApiToolProvider) .filter( ApiToolProvider.tenant_id == tenant_id, From e4d175780e2eab98fafcc82aa6a42f73b265cba2 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Tue, 12 Nov 2024 14:38:24 +0800 Subject: [PATCH 17/53] fix: retrieval setting validate (#10454) --- .../configuration/dataset-config/index.tsx | 6 +- .../params-config/config-content.tsx | 2 +- .../dataset-config/params-config/index.tsx | 6 +- .../components/app/configuration/index.tsx | 11 +- .../nodes/knowledge-retrieval/default.ts | 13 +- .../nodes/knowledge-retrieval/types.ts | 2 + .../nodes/knowledge-retrieval/use-config.ts | 24 +++- .../nodes/knowledge-retrieval/utils.ts | 113 +++++++++++++----- 8 files changed, 129 insertions(+), 48 deletions(-) diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index 0d9d575c1eb022..78b49f81d00824 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -47,12 +47,16 @@ const DatasetConfig: FC = () => { const { currentModel: currentRerankModel, + currentProvider: currentRerankProvider, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const onRemove = (id: string) => { const filteredDataSets = dataSet.filter(item => item.id !== id) setDataSet(filteredDataSets) - const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel) + const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) setDatasetConfigs({ ...(datasetConfigs as any), ...retrievalConfig, diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 5bd748382ed905..dcb2b1a3fd5e46 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -172,7 +172,7 @@ const ConfigContent: FC = ({ return false return datasetConfigs.reranking_enable - }, [canManuallyToggleRerank, datasetConfigs.reranking_enable]) + }, [canManuallyToggleRerank, datasetConfigs.reranking_enable, 
isRerankDefaultModelValid]) const handleDisabledSwitchClick = useCallback(() => { if (!currentRerankModel && !showRerankModel) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 94920fbd39ce21..7f7a4799d12217 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -43,6 +43,7 @@ const ParamsConfig = ({ const { defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelValid, + currentProvider: rerankDefaultProvider, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const isValid = () => { @@ -91,7 +92,10 @@ const ParamsConfig = ({ reranking_mode: restConfigs.reranking_mode, weights: restConfigs.weights, reranking_enable: restConfigs.reranking_enable, - }, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid) + }, selectedDatasets, selectedDatasets, { + provider: rerankDefaultProvider?.provider, + model: isRerankDefaultModelValid?.model, + }) setTempDataSetConfigs({ ...retrievalConfig, diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 2bb11a870c2746..b5b7e98d4337e6 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -226,6 +226,7 @@ const Configuration: FC = () => { const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) const { currentModel: currentRerankModel, + currentProvider: currentRerankProvider, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const handleSelect = (data: DataSet[]) => { if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { @@ -279,7 +280,10 @@ const Configuration: FC = () => { reranking_mode: restConfigs.reranking_mode, weights: restConfigs.weights, reranking_enable: restConfigs.reranking_enable, - }, newDatasets, dataSets, !!currentRerankModel) + }, newDatasets, dataSets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) setDatasetConfigs({ ...retrievalConfig, @@ -620,7 +624,10 @@ const Configuration: FC = () => { syncToPublishedConfig(config) setPublishedConfig(config) - const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel) + const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) setDatasetConfigs({ retrieval_model: RETRIEVE_TYPE.multiWay, ...modelConfig.dataset_configs, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/default.ts b/web/app/components/workflow/nodes/knowledge-retrieval/default.ts index 03591dd527aaa1..e902d29b963e80 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/default.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/default.ts @@ -1,7 +1,7 @@ import { BlockEnum } from '../../types' import type { NodeDefault } from '../../types' import type { KnowledgeRetrievalNodeType } from './types' -import { RerankingModeEnum } from '@/models/datasets' +import { checkoutRerankModelConfigedInRetrievalSettings } from './utils' import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' import { DATASET_DEFAULT } from '@/config' import { 
RETRIEVE_TYPE } from '@/types/app' @@ -36,12 +36,17 @@ const nodeDefault: NodeDefault = { if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0)) errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) }) - if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && payload.multiple_retrieval_config?.reranking_mode === RerankingModeEnum.RerankingModel && !payload.multiple_retrieval_config?.reranking_model?.provider && payload.multiple_retrieval_config?.reranking_enable) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) - if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider) errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') }) + const { _datasets, multiple_retrieval_config, retrieval_mode } = payload + if (retrieval_mode === RETRIEVE_TYPE.multiWay) { + const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config) + + if (!errorMessages && !checked) + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) + } + return { isValid: !errorMessages, errorMessage: errorMessages, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts index da9373962b39fe..1b85bfc0b51b46 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts @@ -1,6 +1,7 @@ import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types' import type { RETRIEVE_TYPE } from '@/types/app' import type { + DataSet, RerankingModeEnum, } from '@/models/datasets' @@ -35,4 +36,5 @@ export type KnowledgeRetrievalNodeType = CommonNodeType & { retrieval_mode: RETRIEVE_TYPE multiple_retrieval_config?: MultipleRetrievalConfig single_retrieval_config?: SingleRetrievalConfig + _datasets?: DataSet[] } diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts index 288a718aa25e87..e90fe2c2ff26b2 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts @@ -67,6 +67,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const { currentModel: currentRerankModel, + currentProvider: currentRerankProvider, } = useCurrentProviderAndModel( rerankModelList, rerankDefaultModel @@ -163,7 +164,10 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { draft.retrieval_mode = newMode if (newMode === RETRIEVE_TYPE.multiWay) { const multipleRetrievalConfig = draft.multiple_retrieval_config - draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) } else { const hasSetModel = draft.single_retrieval_config?.model?.provider @@ -180,14 +184,17 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } }) 
setInputs(newInputs) - }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel]) + }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { const newInputs = produce(inputs, (draft) => { - draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) }) setInputs(newInputs) - }, [inputs, setInputs, selectedDatasets, currentRerankModel]) + }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) // datasets useEffect(() => { @@ -200,6 +207,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } const newInputs = produce(inputs, (draft) => { draft.dataset_ids = datasetIds + draft._datasets = selectedDatasets }) setInputs(newInputs) })() @@ -228,10 +236,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } = getSelectedDatasetsMode(newDatasets) const newInputs = produce(inputs, (draft) => { draft.dataset_ids = newDatasets.map(d => d.id) + draft._datasets = newDatasets if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { const multipleRetrievalConfig = draft.multiple_retrieval_config - draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) } }) setInputs(newInputs) @@ -243,7 +255,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { || allExternal ) setRerankModelOpen(true) - }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) + }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider]) const filterVar = useCallback((varPayload: Var) => { return varPayload.type === VarType.string diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts index fd3d3ebab9dd9d..e9da9accccf783 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts @@ -94,9 +94,10 @@ export const getMultipleRetrievalConfig = ( multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[], originalDatasets: DataSet[], - isValidRerankModel?: boolean, + validRerankModel?: { provider?: string; model?: string }, ) => { const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 + const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model const { allHighQuality, @@ -128,18 +129,10 @@ export const getMultipleRetrievalConfig = ( reranking_enable: ((allInternal && allEconomic) || allExternal) ? 
reranking_enable : true, } - if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) - result.reranking_mode = RerankingModeEnum.RerankingModel - - if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal) - result.reranking_mode = RerankingModeEnum.WeightedScore - - if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) { - if (!isValidRerankModel) - result.reranking_mode = RerankingModeEnum.WeightedScore - else - result.reranking_mode = RerankingModeEnum.RerankingModel + if (!rerankModelIsValid) + result.reranking_model = undefined + const setDefaultWeights = () => { result.weights = { vector_setting: { vector_weight: allHighQualityVectorSearch @@ -160,31 +153,85 @@ export const getMultipleRetrievalConfig = ( } } - if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) { - if (!isValidRerankModel) - result.reranking_mode = RerankingModeEnum.WeightedScore - else + if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) { + result.reranking_mode = RerankingModeEnum.RerankingModel + + if (rerankModelIsValid) { result.reranking_mode = RerankingModeEnum.RerankingModel + result.reranking_model = { + provider: validRerankModel?.provider || '', + model: validRerankModel?.model || '', + } + } + else { + result.reranking_model = undefined + } + } - result.weights = { - vector_setting: { - vector_weight: allHighQualityVectorSearch - ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic - : allHighQualityFullTextSearch - ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic - : DEFAULT_WEIGHTED_SCORE.other.semantic, - embedding_provider_name: selectedDatasets[0].embedding_model_provider, - embedding_model_name: selectedDatasets[0].embedding_model, - }, - keyword_setting: { - keyword_weight: allHighQualityVectorSearch - ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword - : allHighQualityFullTextSearch - ? 
DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword - : DEFAULT_WEIGHTED_SCORE.other.keyword, - }, + if (allHighQuality && !inconsistentEmbeddingModel && allInternal) { + if (!reranking_mode) { + if (validRerankModel?.provider && validRerankModel?.model) { + result.reranking_mode = RerankingModeEnum.RerankingModel + result.reranking_model = { + provider: validRerankModel.provider, + model: validRerankModel.model, + } + } + else { + result.reranking_mode = RerankingModeEnum.WeightedScore + setDefaultWeights() + } + } + + if (reranking_mode === RerankingModeEnum.WeightedScore && !weights) + setDefaultWeights() + + if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) { + if (rerankModelIsValid) { + result.reranking_mode = RerankingModeEnum.RerankingModel + result.reranking_model = { + provider: validRerankModel.provider || '', + model: validRerankModel.model || '', + } + } + else { + setDefaultWeights() + } + } + + if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) { + result.reranking_mode = RerankingModeEnum.WeightedScore + setDefaultWeights() } } return result } + +export const checkoutRerankModelConfigedInRetrievalSettings = ( + datasets: DataSet[], + multipleRetrievalConfig?: MultipleRetrievalConfig, +) => { + if (!multipleRetrievalConfig) + return true + + const { + allEconomic, + allExternal, + } = getSelectedDatasetsMode(datasets) + + const { + reranking_enable, + reranking_mode, + reranking_model, + } = multipleRetrievalConfig + + if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) { + if ((allEconomic || allExternal) && !reranking_enable) + return true + + return false + } + + return true +} From 40c5e6d67a2da47e71136e9c3a860222cda002f8 Mon Sep 17 00:00:00 2001 From: NFish Date: Tue, 12 Nov 2024 15:18:19 +0800 Subject: [PATCH 18/53] fix: Page may lock if the user closes the page while refreshing access_token (#10550) --- web/service/refresh-token.ts | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/web/service/refresh-token.ts b/web/service/refresh-token.ts index 8bd22150414cea..b193779629e26e 100644 --- a/web/service/refresh-token.ts +++ b/web/service/refresh-token.ts @@ -1,11 +1,13 @@ import { apiPrefix } from '@/config' import { fetchWithRetry } from '@/utils' +const LOCAL_STORAGE_KEY = 'is_other_tab_refreshing' + let isRefreshing = false function waitUntilTokenRefreshed() { return new Promise((resolve, reject) => { function _check() { - const isRefreshingSign = localStorage.getItem('is_refreshing') + const isRefreshingSign = globalThis.localStorage.getItem(LOCAL_STORAGE_KEY) if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) { setTimeout(() => { _check() }, @@ -22,13 +24,14 @@ // only one request can send async function getNewAccessToken(): Promise<void> { try { - const isRefreshingSign = localStorage.getItem('is_refreshing') + const isRefreshingSign = globalThis.localStorage.getItem(LOCAL_STORAGE_KEY) if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) { await waitUntilTokenRefreshed() } else { - globalThis.localStorage.setItem('is_refreshing', '1') isRefreshing = true + isRefreshing = true + globalThis.localStorage.setItem(LOCAL_STORAGE_KEY, '1') + globalThis.addEventListener('beforeunload', releaseRefreshLock) const refresh_token = globalThis.localStorage.getItem('refresh_token') // Do not use baseFetch to refresh tokens.
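The hunks above and below wire up a cross-tab refresh lock: a localStorage flag shared between tabs, released both in a finally block and on beforeunload. A minimal, self-contained TypeScript sketch of the same pattern follows (REFRESH_LOCK_KEY, acquireLock, releaseLock, and withRefreshLock are illustrative names, not the patch's identifiers):

const REFRESH_LOCK_KEY = 'is_other_tab_refreshing'

function releaseLock() {
  // clear the shared flag and drop the unload hook
  globalThis.localStorage.removeItem(REFRESH_LOCK_KEY)
  globalThis.removeEventListener('beforeunload', releaseLock)
}

function acquireLock(): boolean {
  if (globalThis.localStorage.getItem(REFRESH_LOCK_KEY) === '1')
    return false // another tab is already refreshing
  globalThis.localStorage.setItem(REFRESH_LOCK_KEY, '1')
  // if this tab closes mid-refresh, beforeunload clears the flag
  // instead of leaving every other tab waiting forever
  globalThis.addEventListener('beforeunload', releaseLock)
  return true
}

async function withRefreshLock(refresh: () => Promise<void>): Promise<void> {
  if (!acquireLock()) {
    // poll until the owning tab clears the flag
    while (globalThis.localStorage.getItem(REFRESH_LOCK_KEY) === '1')
      await new Promise(resolve => setTimeout(resolve, 1000))
    return
  }
  try {
    await refresh()
  }
  finally {
    releaseLock() // runs on success and failure alike
  }
}

The essential choice is that a single release helper is reachable from the finally branch, the timeout branch, and the beforeunload event, so no exit path can leave the '1' flag stranded in localStorage and wedge other tabs in the polling loop.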
@@ -61,15 +64,21 @@ async function getNewAccessToken(): Promise<void> { return Promise.reject(error) } finally { + releaseRefreshLock() + } +} + +function releaseRefreshLock() { + if (isRefreshing) { isRefreshing = false - globalThis.localStorage.removeItem('is_refreshing') + globalThis.localStorage.removeItem(LOCAL_STORAGE_KEY) + globalThis.removeEventListener('beforeunload', releaseRefreshLock) } } export async function refreshAccessTokenOrRelogin(timeout: number) { return Promise.race([new Promise((resolve, reject) => setTimeout(() => { - isRefreshing = false - globalThis.localStorage.removeItem('is_refreshing') + releaseRefreshLock() reject(new Error('request timeout')) }, timeout)), getNewAccessToken()]) } From b77628c45863e72ff8fe0ef5a24538a9c0e69574 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Tue, 12 Nov 2024 15:35:12 +0800 Subject: [PATCH 19/53] fix: text-generation webapp file form (#10578) --- .../share/text-generation/index.tsx | 2 ++ .../share/text-generation/run-once/index.tsx | 21 ++++++++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/web/app/components/share/text-generation/index.tsx b/web/app/components/share/text-generation/index.tsx index 0860560e7c0fa5..b853100b69dab3 100644 --- a/web/app/components/share/text-generation/index.tsx +++ b/web/app/components/share/text-generation/index.tsx @@ -94,6 +94,7 @@ const TextGeneration: FC<IMainProps> = ({ const [isCallBatchAPI, setIsCallBatchAPI] = useState(false) const isInBatchTab = currentTab === 'batch' const [inputs, setInputs] = useState<Record<string, any>>({}) + const inputsRef = useRef(inputs) const [appId, setAppId] = useState('') const [siteInfo, setSiteInfo] = useState(null) const [canReplaceLogo, setCanReplaceLogo] = useState(false) @@ -604,6 +605,7 @@ const TextGeneration: FC<IMainProps> = ({ + inputsRef: React.MutableRefObject<Record<string, any>> onInputsChange: (inputs: Record<string, any>) => void onSend: () => void visionConfig: VisionSettings @@ -27,6 +28,7 @@ export type IRunOnceProps = { const RunOnce: FC<IRunOnceProps> = ({ promptConfig, inputs, + inputsRef, onInputsChange, onSend, visionConfig, @@ -47,6 +49,11 @@ const RunOnce: FC<IRunOnceProps> = ({ onSend() } + const handleInputsChange = useCallback((newInputs: Record<string, any>) => { + onInputsChange(newInputs) + inputsRef.current = newInputs + }, [onInputsChange, inputsRef]) + return (
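The commit above fixes the web app's form handling by pairing the inputs state with an inputsRef that is updated in the same callback that sets the state. A minimal sketch of this state-plus-ref pattern (useLatestInputs is a hypothetical hook name, not the patch's code):

import { useCallback, useRef, useState } from 'react'

function useLatestInputs() {
  const [inputs, setInputs] = useState<Record<string, any>>({})
  const inputsRef = useRef(inputs)

  // update the state for rendering and the ref for async readers together
  const handleInputsChange = useCallback((newInputs: Record<string, any>) => {
    setInputs(newInputs)
    inputsRef.current = newInputs
  }, [])

  // created once, yet always reads the current inputs through the ref
  const handleSend = useCallback(() => {
    const latest = inputsRef.current
    console.log('submitting', latest)
  }, [])

  return { inputs, handleInputsChange, handleSend }
}

Reading inputsRef.current instead of a captured inputs value is the usual cure for stale-closure bugs when a child component keeps an old handler, which appears to be what the file form fix addresses here.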
@@ -60,7 +67,7 @@ const RunOnce: FC<IRunOnceProps> = ({
<Button