Skip to content

Commit

Permalink
fix(ops_tracing): enhance error handle in celery tasks. (langgenius#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhouhaoJiang authored and JunXu01 committed Nov 9, 2024
1 parent b78f3c6 commit 5403817
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 12 deletions.
4 changes: 4 additions & 0 deletions api/core/ops/entities/config_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,7 @@ def set_value(cls, v, info: ValidationInfo):
raise ValueError("endpoint must start with https://")

return v


# Storage path prefix under which serialized trace payloads are written; full
# paths are built as f"{OPS_FILE_PATH}{app_id}/{file_id}.json".
OPS_FILE_PATH = "ops_trace/"
# Prefix of the Redis counter key incremented when processing a trace task
# fails; the full key is f"{OPS_TRACE_FAILED_KEY}_{app_id}".
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
11 changes: 11 additions & 0 deletions api/core/ops/entities/trace_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def ensure_type(cls, v):
return v
return ""

# Model configuration: registers a JSON encoder that renders datetime values
# as ISO 8601 strings via datetime.isoformat().
# NOTE(review): the enclosing class header is outside this view (presumably
# BaseTraceInfo) — confirm before relying on which models inherit this.
class Config:
json_encoders = {
datetime: lambda v: v.isoformat(),
}


class WorkflowTraceInfo(BaseTraceInfo):
workflow_data: Any
Expand Down Expand Up @@ -100,6 +105,12 @@ class GenerateNameTraceInfo(BaseTraceInfo):
tenant_id: str


# Payload for a single trace task. It is serialized with model_dump_json() and
# written to storage before the Celery job is enqueued.
class TaskData(BaseModel):
# ID of the app the trace task belongs to.
app_id: str
# Class name of the trace info object, i.e. type(trace_info).__name__.
# Presumably the worker uses it to pick the model class to rebuild — confirm.
trace_info_type: str
# trace_info.model_dump() output, or None when the task produced no trace info.
trace_info: Any


trace_info_info_map = {
"WorkflowTraceInfo": WorkflowTraceInfo,
"MessageTraceInfo": MessageTraceInfo,
Expand Down
20 changes: 15 additions & 5 deletions api/core/ops/ops_trace_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import time
from datetime import timedelta
from typing import Any, Optional, Union
from uuid import UUID
from uuid import UUID, uuid4

from flask import current_app

from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
LangfuseConfig,
LangSmithConfig,
TracingProviderEnum,
Expand All @@ -22,6 +23,7 @@
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
TaskData,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
Expand All @@ -30,6 +32,7 @@
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks
Expand Down Expand Up @@ -740,10 +743,17 @@ def start_timer(self):
def send_to_celery(self, tasks: list[TraceTask]):
    """
    Persist each trace task's payload to storage and enqueue a Celery job for it.

    The payload itself is not passed through the broker: it is written to
    storage as JSON and the worker receives only the identifiers needed to
    rebuild the file path and load it.

    :param tasks: trace tasks to execute and hand off to the worker
    """
    with self.flask_app.app_context():
        for task in tasks:
            # A random hex ID gives every payload its own file name.
            file_id = uuid4().hex
            trace_info = task.execute()
            task_data = TaskData(
                app_id=task.app_id,
                trace_info_type=type(trace_info).__name__,
                trace_info=trace_info.model_dump() if trace_info else None,
            )
            # Must match the path the worker rebuilds from app_id and file_id.
            file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
            storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
            file_info = {
                "file_id": file_id,
                "app_id": task.app_id,
            }
            process_trace_tasks.delay(file_info)
24 changes: 17 additions & 7 deletions api/tasks/ops_trace_task.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import json
import logging
import time

from celery import shared_task
from flask import current_app

from core.ops.entities.config_entity import OPS_FILE_PATH, OPS_TRACE_FAILED_KEY
from core.ops.entities.trace_entity import trace_info_info_map
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.model import Message
from models.workflow import WorkflowRun


@shared_task(queue="ops_trace")
def process_trace_tasks(tasks_data):
def process_trace_tasks(file_info):
"""
Async process trace tasks
:param file_info: Dictionary with "app_id" and "file_id" identifying the stored trace payload to load
Expand All @@ -20,9 +23,12 @@ def process_trace_tasks(tasks_data):
"""
from core.ops.ops_trace_manager import OpsTraceManager

trace_info = tasks_data.get("trace_info")
app_id = tasks_data.get("app_id")
trace_info_type = tasks_data.get("trace_info_type")
app_id = file_info.get("app_id")
file_id = file_info.get("file_id")
file_path = f"{OPS_FILE_PATH}{app_id}/{file_id}.json"
file_data = json.loads(storage.load(file_path))
trace_info = file_data.get("trace_info")
trace_info_type = file_data.get("trace_info_type")
trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)

if trace_info.get("message_data"):
Expand All @@ -39,6 +45,10 @@ def process_trace_tasks(tasks_data):
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
end_at = time.perf_counter()
logging.info(f"Processing trace tasks success, app_id: {app_id}")
except Exception:
logging.exception("Processing trace tasks failed")
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
logging.info(f"Processing trace tasks failed, app_id: {app_id}")
finally:
storage.delete(file_path)

0 comments on commit 5403817

Please sign in to comment.