Skip to content

Commit

Permalink
Merge branch 'fix/chore-fix' into dev/plugin-deploy
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Dec 4, 2024
2 parents 44989ae + 0af9c4f commit 3d3a429
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 24 deletions.
3 changes: 2 additions & 1 deletion api/controllers/console/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from libs.external_api import ExternalApi

from .app.app_import import AppImportApi, AppImportConfirmApi
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi

Expand All @@ -21,6 +21,7 @@
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")

# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version
Expand Down
21 changes: 20 additions & 1 deletion api/controllers/console/app/app_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_fields
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
from libs.login import login_required
from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus


Expand Down Expand Up @@ -88,3 +90,20 @@ def post(self, import_id):
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200


class AppImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_fields)
def get(self, app_model: App):
if not current_user.is_editor:
raise Forbidden()

with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

return result.model_dump(mode="json"), 200
3 changes: 3 additions & 0 deletions api/fields/app_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,8 @@
"current_dsl_version": fields.String,
"imported_dsl_version": fields.String,
"error": fields.String,
}

app_import_check_dependencies_fields = {
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
}
8 changes: 5 additions & 3 deletions api/libs/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Generator
from datetime import datetime
from hashlib import sha256
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from flask import Response, stream_with_context
from flask_restful import fields
Expand All @@ -19,7 +19,9 @@
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers
from extensions.ext_redis import redis_client
from models.account import Account

if TYPE_CHECKING:
from models.account import Account


def run(script):
Expand Down Expand Up @@ -196,7 +198,7 @@ class TokenManager:
def generate_token(
cls,
token_type: str,
account: Optional[Account] = None,
account: Optional["Account"] = None,
email: Optional[str] = None,
additional_data: Optional[dict] = None,
) -> str:
Expand Down
68 changes: 49 additions & 19 deletions api/services/app_dsl_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
logger = logging.getLogger(__name__)

IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
IMPORT_INFO_REDIS_EXPIRY = 2 * 60 * 60 # 2 hours
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
CURRENT_DSL_VERSION = "0.1.4"
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB

Expand All @@ -54,10 +55,13 @@ class Import(BaseModel):
app_id: Optional[str] = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
error: str = ""


class CheckDependenciesResult(BaseModel):
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)


def _check_version_compatibility(imported_version: str) -> ImportStatus:
"""Determine import status based on version comparison"""
try:
Expand Down Expand Up @@ -87,6 +91,11 @@ class PendingData(BaseModel):
app_id: str | None


class CheckDependenciesPendingData(BaseModel):
dependencies: list[PluginDependency]
app_id: str | None


class AppDslService:
def __init__(self, session: Session):
self._session = session
Expand Down Expand Up @@ -243,23 +252,11 @@ def import_app(
imported_dsl_version=imported_version,
)

try:
dependencies = self.get_leaked_dependencies(account.current_tenant_id, data.get("dependencies", []))
except Exception as e:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=str(e),
)

if len(dependencies) > 0:
return Import(
id=import_id,
status=ImportStatus.PENDING,
app_id=app_id,
imported_dsl_version=imported_version,
leaked_dependencies=dependencies,
)
# Extract dependencies
dependencies = data.get("dependencies", [])
check_dependencies_pending_data = None
if dependencies:
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]

# Create or update app
app = self._create_or_update_app(
Expand All @@ -271,6 +268,7 @@ def import_app(
icon_type=icon_type,
icon=icon,
icon_background=icon_background,
dependencies=check_dependencies_pending_data,
)

return Import(
Expand Down Expand Up @@ -355,6 +353,29 @@ def confirm_import(self, *, import_id: str, account: Account) -> Import:
error=str(e),
)

def check_dependencies(
self,
*,
app_model: App,
) -> CheckDependenciesResult:
"""Check dependencies"""
# Get dependencies from Redis
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_model.id}"
dependencies = redis_client.get(redis_key)
if not dependencies:
return CheckDependenciesResult()

# Extract dependencies
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)

# Get leaked dependencies
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies
)
return CheckDependenciesResult(
leaked_dependencies=leaked_dependencies,
)

def _create_or_update_app(
self,
*,
Expand All @@ -366,6 +387,7 @@ def _create_or_update_app(
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
dependencies: Optional[list[PluginDependency]] = None,
) -> App:
"""Create a new app or update an existing one."""
app_data = data.get("app", {})
Expand Down Expand Up @@ -408,6 +430,14 @@ def _create_or_update_app(
self._session.commit()
app_was_created.send(app, account=account)

# save dependencies
if dependencies:
redis_client.setex(
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}",
IMPORT_INFO_REDIS_EXPIRY,
CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(),
)

# Initialize app based on mode
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_data = data.get("workflow")
Expand Down

0 comments on commit 3d3a429

Please sign in to comment.