From 425174e82f9046c10386433c39261176f99f89e5 Mon Sep 17 00:00:00 2001
From: Joe <79627742+ZhouhaoJiang@users.noreply.github.com>
Date: Fri, 9 Aug 2024 15:22:16 +0800
Subject: [PATCH] feat: update ops trace (#7102)

---
 .../app/apps/advanced_chat/app_generator.py   |   3 +-
 .../easy_ui_based_generate_task_pipeline.py   |   3 +-
 .../task_pipeline/workflow_cycle_manage.py    |   3 +-
 .../agent_tool_callback_handler.py            |   3 +-
 api/core/llm_generator/llm_generator.py       |   3 +-
 api/core/moderation/input_moderation.py      |   3 +-
 api/core/ops/entities/trace_entity.py         |  14 ++-
 .../entities/langfuse_trace_entity.py         |  42 ++++---
 api/core/ops/langfuse_trace/langfuse_trace.py | 103 ++++++++----------
 .../ops/langsmith_trace/langsmith_trace.py    |  51 ++++-----
 api/core/ops/ops_trace_manager.py             |  13 +--
 api/core/rag/retrieval/dataset_retrieval.py   |   3 +-
 api/services/message_service.py               |   3 +-
 13 files changed, 118 insertions(+), 129 deletions(-)

diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 0141dbec58de6e..e854ea18b099b8 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -89,7 +89,8 @@ def generate(
         )
 
         # get tracing instance
-        trace_manager = TraceQueueManager(app_id=app_model.id)
+        user_id = user.id if isinstance(user, Account) else user.session_id
+        trace_manager = TraceQueueManager(app_model.id, user_id)
 
         if invoke_from == InvokeFrom.DEBUGGER:
             # always enable retriever resource in debugger mode
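The hunk above threads the acting user through to tracing: TraceQueueManager now receives a user identifier next to the app id rather than a keyword-only app_id. A minimal, self-contained sketch of the call-site logic, using hypothetical stand-ins for the real Account/EndUser models (only the id/session_id attributes are taken from the hunk):

    # Hypothetical stand-ins for the Account / end-user models referenced above.
    from dataclasses import dataclass


    @dataclass
    class Account:
        id: str


    @dataclass
    class WebAppUser:
        session_id: str


    def resolve_trace_user_id(user) -> str:
        # Mirrors the hunk: console users (Accounts) are identified by id,
        # everyone else by their session id.
        return user.id if isinstance(user, Account) else user.session_id


    assert resolve_trace_user_id(Account(id="account-1")) == "account-1"
    assert resolve_trace_user_id(WebAppUser(session_id="session-9")) == "session-9"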
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index c9644c7d4cf1c0..8d91a507a9e8ee 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -48,7 +48,8 @@ (
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from events.message_event import message_was_created
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index 69951e9371f2f4..4935c43ac437e4 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -22,7 +22,8 @@
 from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
 from core.file.file_obj import FileVar
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.tool_manager import ToolManager
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
 from core.workflow.nodes.tool.entities import ToolNodeData
diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py
index 03f8244bab212e..578996574739a8 100644
--- a/api/core/callback_handler/agent_tool_callback_handler.py
+++ b/api/core/callback_handler/agent_tool_callback_handler.py
@@ -4,7 +4,8 @@
 
 from pydantic import BaseModel
 
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.entities.tool_entities import ToolInvokeMessage
 
 _TEXT_COLOR_MAPPING = {
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 0b5029460a03ff..8c13b4a45cbe6c 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -14,7 +14,8 @@
 from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
diff --git a/api/core/moderation/input_moderation.py b/api/core/moderation/input_moderation.py
index c5dd88fb2458b1..8157b300b1f6c7 100644
--- a/api/core/moderation/input_moderation.py
+++ b/api/core/moderation/input_moderation.py
@@ -4,7 +4,8 @@
 from core.app.app_config.entities import AppConfig
 from core.moderation.base import ModerationAction, ModerationException
 from core.moderation.factory import ModerationFactory
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py
index db7e0806ee8d74..a1443f0691233b 100644
--- a/api/core/ops/entities/trace_entity.py
+++ b/api/core/ops/entities/trace_entity.py
@@ -1,4 +1,5 @@
 from datetime import datetime
+from enum import Enum
 from typing import Any, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, field_validator
@@ -105,4 +106,15 @@ class GenerateNameTraceInfo(BaseTraceInfo):
     'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
     'ToolTraceInfo': ToolTraceInfo,
     'GenerateNameTraceInfo': GenerateNameTraceInfo,
-}
\ No newline at end of file
+}
+
+
+class TraceTaskName(str, Enum):
+    CONVERSATION_TRACE = 'conversation'
+    WORKFLOW_TRACE = 'workflow'
+    MESSAGE_TRACE = 'message'
+    MODERATION_TRACE = 'moderation'
+    SUGGESTED_QUESTION_TRACE = 'suggested_question'
+    DATASET_RETRIEVAL_TRACE = 'dataset_retrieval'
+    TOOL_TRACE = 'tool'
+    GENERATE_NAME_TRACE = 'generate_conversation_name'
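TraceTaskName now lives in the entities module as a str-backed Enum, so a member can be handed to any API expecting a plain string while .value carries the short display name used by the tracers. A small self-contained sketch of that behaviour, mirroring two members from the hunk above:

    from enum import Enum


    # Mirrors the enum added to api/core/ops/entities/trace_entity.py.
    class TraceTaskName(str, Enum):
        WORKFLOW_TRACE = 'workflow'
        MESSAGE_TRACE = 'message'


    # The str mixin makes members compare equal to their values...
    assert TraceTaskName.WORKFLOW_TRACE == 'workflow'
    # ...and .value is what the tracers below send as the trace/span name.
    assert TraceTaskName.MESSAGE_TRACE.value == 'message'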
diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py
index b90c05f4cbc605..af7661f0afc9c8 100644
--- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py
+++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py
@@ -50,10 +50,11 @@ class LangfuseTrace(BaseModel):
     """
     Langfuse trace model
     """
+
     id: Optional[str] = Field(
         default=None,
         description="The id of the trace can be set, defaults to a random id. Used to link traces to external systems "
-                    "or when creating a distributed trace. Traces are upserted on id.",
+        "or when creating a distributed trace. Traces are upserted on id.",
     )
     name: Optional[str] = Field(
         default=None,
@@ -68,7 +69,7 @@ class LangfuseTrace(BaseModel):
     metadata: Optional[dict[str, Any]] = Field(
         default=None,
         description="Additional metadata of the trace. Can be any JSON object. Metadata is merged when being updated "
-                    "via the API.",
+        "via the API.",
     )
     user_id: Optional[str] = Field(
         default=None,
@@ -81,22 +82,22 @@ class LangfuseTrace(BaseModel):
     version: Optional[str] = Field(
         default=None,
         description="The version of the trace type. Used to understand how changes to the trace type affect metrics. "
-                    "Useful in debugging.",
+        "Useful in debugging.",
     )
     release: Optional[str] = Field(
         default=None,
         description="The release identifier of the current deployment. Used to understand how changes of different "
-                    "deployments affect metrics. Useful in debugging.",
+        "deployments affect metrics. Useful in debugging.",
     )
     tags: Optional[list[str]] = Field(
         default=None,
         description="Tags are used to categorize or label traces. Traces can be filtered by tags in the UI and GET "
-                    "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
+        "API. Tags can also be changed in the UI. Tags are merged and never deleted via the API.",
     )
     public: Optional[bool] = Field(
         default=None,
         description="You can make a trace public to share it via a public link. This allows others to view the trace "
-                    "without needing to log in or be members of your Langfuse project.",
+        "without needing to log in or be members of your Langfuse project.",
     )
 
     @field_validator("input", "output")
@@ -109,6 +110,7 @@ class LangfuseSpan(BaseModel):
     """
     Langfuse span model
     """
+
     id: Optional[str] = Field(
         default=None,
         description="The id of the span can be set, otherwise a random id is generated. Spans are upserted on id.",
@@ -140,17 +142,17 @@ class LangfuseSpan(BaseModel):
     metadata: Optional[dict[str, Any]] = Field(
         default=None,
         description="Additional metadata of the span. Can be any JSON object. Metadata is merged when being updated "
-                    "via the API.",
+        "via the API.",
     )
     level: Optional[str] = Field(
         default=None,
         description="The level of the span. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering of "
-                    "traces with elevated error levels and for highlighting in the UI.",
+        "traces with elevated error levels and for highlighting in the UI.",
     )
     status_message: Optional[str] = Field(
         default=None,
         description="The status message of the span. Additional field for context of the event. E.g. the error "
-                    "message of an error event.",
+        "message of an error event.",
     )
     input: Optional[Union[str, dict[str, Any], list, None]] = Field(
         default=None, description="The input of the span. Can be any JSON object."
@@ -161,7 +163,7 @@ class LangfuseSpan(BaseModel):
     version: Optional[str] = Field(
         default=None,
         description="The version of the span type. Used to understand how changes to the span type affect metrics. "
-                    "Useful in debugging.",
+        "Useful in debugging.",
     )
     parent_observation_id: Optional[str] = Field(
         default=None,
" - "Useful in debugging.", + "Useful in debugging.", ) parent_observation_id: Optional[str] = Field( default=None, @@ -185,10 +187,9 @@ class UnitEnum(str, Enum): class GenerationUsage(BaseModel): promptTokens: Optional[int] = None completionTokens: Optional[int] = None - totalTokens: Optional[int] = None + total: Optional[int] = None input: Optional[int] = None output: Optional[int] = None - total: Optional[int] = None unit: Optional[UnitEnum] = None inputCost: Optional[float] = None outputCost: Optional[float] = None @@ -224,15 +225,13 @@ class LangfuseGeneration(BaseModel): completion_start_time: Optional[datetime | str] = Field( default=None, description="The time at which the completion started (streaming). Set it to get latency analytics broken " - "down into time until completion started and completion duration.", + "down into time until completion started and completion duration.", ) end_time: Optional[datetime | str] = Field( default=None, description="The time at which the generation ended. Automatically set by generation.end().", ) - model: Optional[str] = Field( - default=None, description="The name of the model used for the generation." - ) + model: Optional[str] = Field(default=None, description="The name of the model used for the generation.") model_parameters: Optional[dict[str, Any]] = Field( default=None, description="The parameters of the model used for the generation; can be any key-value pairs.", @@ -248,27 +247,27 @@ class LangfuseGeneration(BaseModel): usage: Optional[GenerationUsage] = Field( default=None, description="The usage object supports the OpenAi structure with tokens and a more generic version with " - "detailed costs and units.", + "detailed costs and units.", ) metadata: Optional[dict[str, Any]] = Field( default=None, description="Additional metadata of the generation. Can be any JSON object. Metadata is merged when being " - "updated via the API.", + "updated via the API.", ) level: Optional[LevelEnum] = Field( default=None, description="The level of the generation. Can be DEBUG, DEFAULT, WARNING or ERROR. Used for sorting/filtering " - "of traces with elevated error levels and for highlighting in the UI.", + "of traces with elevated error levels and for highlighting in the UI.", ) status_message: Optional[str] = Field( default=None, description="The status message of the generation. Additional field for context of the event. E.g. the error " - "message of an error event.", + "message of an error event.", ) version: Optional[str] = Field( default=None, description="The version of the generation type. Used to understand how changes to the span type affect " - "metrics. Useful in debugging.", + "metrics. 
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index c520fe2aa9c089..698398e0cb8c16 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -16,6 +16,7 @@
     ModerationTraceInfo,
     SuggestedQuestionTraceInfo,
     ToolTraceInfo,
+    TraceTaskName,
     WorkflowTraceInfo,
 )
 from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
@@ -68,9 +69,9 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
         user_id = trace_info.metadata.get("user_id")
         if trace_info.message_id:
             trace_id = trace_info.message_id
-            name = f"message_{trace_info.message_id}"
+            name = TraceTaskName.MESSAGE_TRACE.value
             trace_data = LangfuseTrace(
-                id=trace_info.message_id,
+                id=trace_id,
                 user_id=user_id,
                 name=name,
                 input=trace_info.workflow_run_inputs,
@@ -78,11 +79,13 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
                 output=trace_info.workflow_run_outputs,
                 metadata=trace_info.metadata,
                 session_id=trace_info.conversation_id,
                 tags=["message", "workflow"],
+                created_at=trace_info.start_time,
+                updated_at=trace_info.end_time,
             )
             self.add_trace(langfuse_trace_data=trace_data)
             workflow_span_data = LangfuseSpan(
-                id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
-                name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
+                id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id),
+                name=TraceTaskName.WORKFLOW_TRACE.value,
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,
                 trace_id=trace_id,
@@ -97,7 +100,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
             trace_data = LangfuseTrace(
                 id=trace_id,
                 user_id=user_id,
-                name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
+                name=TraceTaskName.WORKFLOW_TRACE.value,
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,
                 metadata=trace_info.metadata,
@@ -134,14 +137,12 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
             node_type = node_execution.node_type
             status = node_execution.status
             if node_type == "llm":
-                inputs = json.loads(node_execution.process_data).get(
-                    "prompts", {}
-                ) if node_execution.process_data else {}
+                inputs = (
+                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
+                )
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = (
-                json.loads(node_execution.outputs) if node_execution.outputs else {}
-            )
+            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
             created_at = node_execution.created_at if node_execution.created_at else datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
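Both tracers reshape the same node-input extraction: LLM nodes keep their rendered prompts in process_data while other nodes use inputs, and both fields are JSON-encoded strings. A runnable sketch of the logic reformatted in the hunk above:

    import json
    from typing import Any, Optional


    def extract_node_inputs(node_type: str, process_data: Optional[str], inputs: Optional[str]) -> Any:
        # LLM nodes: pull the rendered prompts out of process_data.
        if node_type == "llm":
            return json.loads(process_data).get("prompts", {}) if process_data else {}
        # Every other node type: decode the raw inputs JSON.
        return json.loads(inputs) if inputs else {}


    assert extract_node_inputs("llm", '{"prompts": [{"role": "user"}]}', None) == [{"role": "user"}]
    assert extract_node_inputs("code", None, None) == {}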
@@ -163,28 +164,30 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
             if trace_info.message_id:
                 span_data = LangfuseSpan(
                     id=node_execution_id,
-                    name=f"{node_name}_{node_execution_id}",
+                    name=node_type,
                     input=inputs,
                     output=outputs,
                     trace_id=trace_id,
                     start_time=created_at,
                     end_time=finished_at,
                     metadata=metadata,
-                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                     status_message=trace_info.error if trace_info.error else "",
-                    parent_observation_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
+                    parent_observation_id=(
+                        trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
+                    ),
                 )
             else:
                 span_data = LangfuseSpan(
                     id=node_execution_id,
-                    name=f"{node_name}_{node_execution_id}",
+                    name=node_type,
                     input=inputs,
                     output=outputs,
                     trace_id=trace_id,
                     start_time=created_at,
                     end_time=finished_at,
                     metadata=metadata,
-                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                    level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
                     status_message=trace_info.error if trace_info.error else "",
                 )
@@ -195,11 +198,11 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
                 total_token = metadata.get("total_tokens", 0)
                 # add generation
                 generation_usage = GenerationUsage(
-                    totalTokens=total_token,
+                    total=total_token,
                 )
 
                 node_generation_data = LangfuseGeneration(
-                    name=f"generation_{node_execution_id}",
+                    name="llm",
                     trace_id=trace_id,
                     parent_observation_id=node_execution_id,
                     start_time=created_at,
name="moderation", + name=TraceTaskName.MODERATION_TRACE.value, input=trace_info.inputs, output={ "action": trace_info.action, @@ -303,22 +303,21 @@ def moderation_trace(self, trace_info: ModerationTraceInfo): def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): message_data = trace_info.message_data generation_usage = GenerationUsage( - totalTokens=len(str(trace_info.suggested_question)), + total=len(str(trace_info.suggested_question)), input=len(trace_info.inputs), output=len(trace_info.suggested_question), - total=len(trace_info.suggested_question), unit=UnitEnum.CHARACTERS, ) generation_data = LangfuseGeneration( - name="suggested_question", + name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value, input=trace_info.inputs, output=str(trace_info.suggested_question), trace_id=trace_info.message_id, start_time=trace_info.start_time, end_time=trace_info.end_time, metadata=trace_info.metadata, - level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR, + level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR), status_message=message_data.error if message_data.error else "", usage=generation_usage, ) @@ -327,7 +326,7 @@ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo): def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo): dataset_retrieval_span_data = LangfuseSpan( - name="dataset_retrieval", + name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value, input=trace_info.inputs, output={"documents": trace_info.documents}, trace_id=trace_info.message_id, @@ -347,7 +346,7 @@ def tool_trace(self, trace_info: ToolTraceInfo): start_time=trace_info.start_time, end_time=trace_info.end_time, metadata=trace_info.metadata, - level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR, + level=(LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR), status_message=trace_info.error, ) @@ -355,7 +354,7 @@ def tool_trace(self, trace_info: ToolTraceInfo): def generate_name_trace(self, trace_info: GenerateNameTraceInfo): name_generation_trace_data = LangfuseTrace( - name="generate_name", + name=TraceTaskName.GENERATE_NAME_TRACE.value, input=trace_info.inputs, output=trace_info.outputs, user_id=trace_info.tenant_id, @@ -366,7 +365,7 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo): self.add_trace(langfuse_trace_data=name_generation_trace_data) name_generation_span_data = LangfuseSpan( - name="generate_name", + name=TraceTaskName.GENERATE_NAME_TRACE.value, input=trace_info.inputs, output=trace_info.outputs, trace_id=trace_info.conversation_id, @@ -377,9 +376,7 @@ def generate_name_trace(self, trace_info: GenerateNameTraceInfo): self.add_span(langfuse_span_data=name_generation_span_data) def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None): - format_trace_data = ( - filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} - ) + format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {} try: self.langfuse_client.trace(**format_trace_data) logger.debug("LangFuse Trace created successfully") @@ -387,9 +384,7 @@ def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None): raise ValueError(f"LangFuse Failed to create trace: {str(e)}") def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None): - format_span_data = ( - filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {} - ) + 
+        format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
         try:
             self.langfuse_client.span(**format_span_data)
             logger.debug("LangFuse Span created successfully")
@@ -397,19 +392,13 @@ def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None):
             raise ValueError(f"LangFuse Failed to create span: {str(e)}")
 
     def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None):
-        format_span_data = (
-            filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
-        )
+        format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
 
         span.end(**format_span_data)
 
-    def add_generation(
-        self, langfuse_generation_data: Optional[LangfuseGeneration] = None
-    ):
+    def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None):
         format_generation_data = (
-            filter_none_values(langfuse_generation_data.model_dump())
-            if langfuse_generation_data
-            else {}
+            filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
         )
         try:
             self.langfuse_client.generation(**format_generation_data)
@@ -417,13 +406,9 @@ def add_generation(self, langfuse_generation_data: Optional[LangfuseGeneration] = None):
         except Exception as e:
             raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
 
-    def update_generation(
-        self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None
-    ):
+    def update_generation(self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None):
         format_generation_data = (
-            filter_none_values(langfuse_generation_data.model_dump())
-            if langfuse_generation_data
-            else {}
+            filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
         )
         generation.end(**format_generation_data)
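All of the add_*/update_* helpers share one payload pattern: dump the pydantic model, then strip None entries so unset optional fields never reach the Langfuse client. filter_none_values comes from core.ops.utils; the helper below is an assumed equivalent, shown only to illustrate the pattern:

    from typing import Any


    # Assumed behaviour of core.ops.utils.filter_none_values.
    def filter_none_values(data: dict[str, Any]) -> dict[str, Any]:
        return {key: value for key, value in data.items() if value is not None}


    dumped = {"name": "workflow", "input": {"query": "hi"}, "tags": None, "version": None}
    assert filter_none_values(dumped) == {"name": "workflow", "input": {"query": "hi"}}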
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index 0ce91db335cd94..fde8a06c612dd9 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -15,6 +15,7 @@
     ModerationTraceInfo,
     SuggestedQuestionTraceInfo,
     ToolTraceInfo,
+    TraceTaskName,
     WorkflowTraceInfo,
 )
 from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
@@ -39,9 +40,7 @@ def __init__(
         self.langsmith_key = langsmith_config.api_key
         self.project_name = langsmith_config.project
         self.project_id = None
-        self.langsmith_client = Client(
-            api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint
-        )
+        self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
 
     def trace(self, trace_info: BaseTraceInfo):
@@ -64,7 +63,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
         if trace_info.message_id:
             message_run = LangSmithRunModel(
                 id=trace_info.message_id,
-                name=f"message_{trace_info.message_id}",
+                name=TraceTaskName.MESSAGE_TRACE.value,
                 inputs=trace_info.workflow_run_inputs,
                 outputs=trace_info.workflow_run_outputs,
                 run_type=LangSmithRunType.chain,
@@ -73,8 +72,8 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
                 extra={
                     "metadata": trace_info.metadata,
                 },
-                tags=["message"],
-                error=trace_info.error
+                tags=["message", "workflow"],
+                error=trace_info.error,
             )
             self.add_run(message_run)
 
@@ -82,7 +81,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
             file_list=trace_info.file_list,
             total_tokens=trace_info.total_tokens,
             id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
-            name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
+            name=TraceTaskName.WORKFLOW_TRACE.value,
            inputs=trace_info.workflow_run_inputs,
             run_type=LangSmithRunType.tool,
             start_time=trace_info.workflow_data.created_at,
@@ -126,22 +125,18 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
             node_type = node_execution.node_type
             status = node_execution.status
             if node_type == "llm":
-                inputs = json.loads(node_execution.process_data).get(
-                    "prompts", {}
-                ) if node_execution.process_data else {}
+                inputs = (
+                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
+                )
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
-            outputs = (
-                json.loads(node_execution.outputs) if node_execution.outputs else {}
-            )
+            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
             created_at = node_execution.created_at if node_execution.created_at else datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
 
             execution_metadata = (
-                json.loads(node_execution.execution_metadata)
-                if node_execution.execution_metadata
-                else {}
+                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
             )
             node_total_tokens = execution_metadata.get("total_tokens", 0)
 
@@ -168,7 +163,7 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
 
             langsmith_run = LangSmithRunModel(
                 total_tokens=node_total_tokens,
-                name=f"{node_name}_{node_execution_id}",
+                name=node_type,
                 inputs=inputs,
                 run_type=run_type,
                 start_time=created_at,
@@ -178,7 +173,9 @@ def workflow_trace(self, trace_info: WorkflowTraceInfo):
                 extra={
                     "metadata": metadata,
                 },
-                parent_run_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
+                parent_run_id=trace_info.workflow_app_log_id
+                if trace_info.workflow_app_log_id
+                else trace_info.workflow_run_id,
                 tags=["node_execution"],
             )
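Node runs are parented to the workflow app log when one exists, falling back to the workflow run id; the Langfuse spans earlier in the patch use the same conditional for parent_observation_id. A tiny sketch of the fallback:

    from typing import Optional


    def resolve_parent_run_id(workflow_app_log_id: Optional[str], workflow_run_id: str) -> str:
        # Prefer the app-log id; otherwise attach the node to the raw run.
        return workflow_app_log_id if workflow_app_log_id else workflow_run_id


    assert resolve_parent_run_id("log-7", "run-123") == "log-7"
    assert resolve_parent_run_id(None, "run-123") == "run-123"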
@@ -198,9 +195,9 @@ def message_trace(self, trace_info: MessageTraceInfo):
             metadata["user_id"] = user_id
 
         if message_data.from_end_user_id:
-            end_user_data: EndUser = db.session.query(EndUser).filter(
-                EndUser.id == message_data.from_end_user_id
-            ).first()
+            end_user_data: EndUser = (
+                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+            )
             if end_user_data is not None:
                 end_user_id = end_user_data.session_id
                 metadata["end_user_id"] = end_user_id
@@ -210,7 +207,7 @@ def message_trace(self, trace_info: MessageTraceInfo):
             output_tokens=trace_info.answer_tokens,
             total_tokens=trace_info.total_tokens,
             id=message_id,
-            name=f"message_{message_id}",
+            name=TraceTaskName.MESSAGE_TRACE.value,
             inputs=trace_info.inputs,
             run_type=LangSmithRunType.chain,
             start_time=trace_info.start_time,
@@ -230,7 +227,7 @@ def message_trace(self, trace_info: MessageTraceInfo):
             input_tokens=trace_info.message_tokens,
             output_tokens=trace_info.answer_tokens,
             total_tokens=trace_info.total_tokens,
-            name=f"llm_{message_id}",
+            name="llm",
             inputs=trace_info.inputs,
             run_type=LangSmithRunType.llm,
             start_time=trace_info.start_time,
@@ -248,7 +245,7 @@ def message_trace(self, trace_info: MessageTraceInfo):
 
     def moderation_trace(self, trace_info: ModerationTraceInfo):
         langsmith_run = LangSmithRunModel(
-            name="moderation",
+            name=TraceTaskName.MODERATION_TRACE.value,
             inputs=trace_info.inputs,
             outputs={
                 "action": trace_info.action,
@@ -271,7 +268,7 @@ def moderation_trace(self, trace_info: ModerationTraceInfo):
     def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
         message_data = trace_info.message_data
         suggested_question_run = LangSmithRunModel(
-            name="suggested_question",
+            name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
             inputs=trace_info.inputs,
             outputs=trace_info.suggested_question,
             run_type=LangSmithRunType.tool,
@@ -288,7 +285,7 @@ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
 
     def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
         dataset_retrieval_run = LangSmithRunModel(
-            name="dataset_retrieval",
+            name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
             inputs=trace_info.inputs,
             outputs={"documents": trace_info.documents},
             run_type=LangSmithRunType.retriever,
@@ -323,7 +320,7 @@ def tool_trace(self, trace_info: ToolTraceInfo):
 
     def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
         name_run = LangSmithRunModel(
-            name="generate_name",
+            name=TraceTaskName.GENERATE_NAME_TRACE.value,
             inputs=trace_info.inputs,
             outputs=trace_info.outputs,
             run_type=LangSmithRunType.tool,
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index 61279e3f5f29c5..068b490ec887bd 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -5,7 +5,6 @@
 import threading
 import time
 from datetime import timedelta
-from enum import Enum
 from typing import Any, Optional, Union
 from uuid import UUID
 
@@ -24,6 +23,7 @@
     ModerationTraceInfo,
     SuggestedQuestionTraceInfo,
     ToolTraceInfo,
+    TraceTaskName,
     WorkflowTraceInfo,
 )
 from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
@@ -253,17 +253,6 @@ def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
         return trace_instance(tracing_config).api_check()
 
 
-class TraceTaskName(str, Enum):
-    CONVERSATION_TRACE = 'conversation_trace'
-    WORKFLOW_TRACE = 'workflow_trace'
-    MESSAGE_TRACE = 'message_trace'
-    MODERATION_TRACE = 'moderation_trace'
-    SUGGESTED_QUESTION_TRACE = 'suggested_question_trace'
-    DATASET_RETRIEVAL_TRACE = 'dataset_retrieval_trace'
-    TOOL_TRACE = 'tool_trace'
-    GENERATE_NAME_TRACE = 'generate_name_trace'
-
-
 class TraceTask:
     def __init__(
         self,
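With the enum deleted here, ops_trace_manager imports TraceTaskName from the entities module like every other call site. The member names are unchanged, so code such as TraceTaskName.WORKFLOW_TRACE still works, but the emitted values drop the _trace suffix, so anything filtering traces by the old strings needs updating. A before/after sketch built from the two hunks:

    from enum import Enum


    # Values as removed from ops_trace_manager (before)...
    class OldTraceTaskName(str, Enum):
        WORKFLOW_TRACE = 'workflow_trace'
        GENERATE_NAME_TRACE = 'generate_name_trace'


    # ...and as re-added in trace_entity (after).
    class NewTraceTaskName(str, Enum):
        WORKFLOW_TRACE = 'workflow'
        GENERATE_NAME_TRACE = 'generate_conversation_name'


    # Same member name, different wire value: filters keyed on the old
    # strings will no longer match.
    assert OldTraceTaskName.WORKFLOW_TRACE.value != NewTraceTaskName.WORKFLOW_TRACE.value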
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index dc1f1ada115453..e9453647969a97 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -14,7 +14,8 @@
 from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
diff --git a/api/services/message_service.py b/api/services/message_service.py
index e310d70d5314e7..491a914c776387 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -7,7 +7,8 @@
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
+from core.ops.entities.trace_entity import TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.ops.utils import measure_time
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination