-
Notifications
You must be signed in to change notification settings - Fork 8.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(ext_login): Move the login_manager outside.
Signed-off-by: -LAN- <[email protected]>
- Loading branch information
Showing
1 changed file
with
57 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,62 @@ | ||
from dify_app import DifyApp | ||
|
||
import json | ||
|
||
def init_app(app: DifyApp): | ||
import json | ||
import flask_login | ||
from flask import Response, request | ||
from flask_login import user_loaded_from_request, user_logged_in | ||
from werkzeug.exceptions import Unauthorized | ||
|
||
import flask_login | ||
from flask import Response, request | ||
from flask_login import user_loaded_from_request, user_logged_in | ||
from werkzeug.exceptions import Unauthorized | ||
import contexts | ||
from dify_app import DifyApp | ||
from libs.passport import PassportService | ||
from services.account_service import AccountService | ||
|
||
login_manager = flask_login.LoginManager() | ||
|
||
|
||
# Flask-Login configuration | ||
@login_manager.request_loader | ||
def load_user_from_request(request_from_flask_login): | ||
"""Load user based on the request.""" | ||
if request.blueprint not in {"console", "inner_api"}: | ||
return None | ||
# Check if the user_id contains a dot, indicating the old format | ||
auth_header = request.headers.get("Authorization", "") | ||
if not auth_header: | ||
auth_token = request.args.get("_token") | ||
if not auth_token: | ||
raise Unauthorized("Invalid Authorization token.") | ||
else: | ||
if " " not in auth_header: | ||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||
auth_scheme, auth_token = auth_header.split(None, 1) | ||
auth_scheme = auth_scheme.lower() | ||
if auth_scheme != "bearer": | ||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||
|
||
decoded = PassportService().verify(auth_token) | ||
user_id = decoded.get("user_id") | ||
|
||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id) | ||
return logged_in_account | ||
|
||
|
||
@user_logged_in.connect | ||
@user_loaded_from_request.connect | ||
def on_user_logged_in(_sender, user): | ||
"""Called when a user logged in.""" | ||
if user: | ||
contexts.tenant_id.set(user.current_tenant_id) | ||
|
||
|
||
@login_manager.unauthorized_handler | ||
def unauthorized_handler(): | ||
"""Handle unauthorized requests.""" | ||
return Response( | ||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}), | ||
status=401, | ||
content_type="application/json", | ||
) | ||
|
||
import contexts | ||
from libs.passport import PassportService | ||
from services.account_service import AccountService | ||
|
||
login_manager = flask_login.LoginManager() | ||
def init_app(app: DifyApp): | ||
login_manager.init_app(app) | ||
|
||
# Flask-Login configuration | ||
@login_manager.request_loader | ||
def load_user_from_request(request_from_flask_login): | ||
"""Load user based on the request.""" | ||
if request.blueprint not in {"console", "inner_api"}: | ||
return None | ||
# Check if the user_id contains a dot, indicating the old format | ||
auth_header = request.headers.get("Authorization", "") | ||
if not auth_header: | ||
auth_token = request.args.get("_token") | ||
if not auth_token: | ||
raise Unauthorized("Invalid Authorization token.") | ||
else: | ||
if " " not in auth_header: | ||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||
auth_scheme, auth_token = auth_header.split(None, 1) | ||
auth_scheme = auth_scheme.lower() | ||
if auth_scheme != "bearer": | ||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | ||
|
||
decoded = PassportService().verify(auth_token) | ||
user_id = decoded.get("user_id") | ||
|
||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id) | ||
return logged_in_account | ||
|
||
@user_logged_in.connect | ||
@user_loaded_from_request.connect | ||
def on_user_logged_in(_sender, user): | ||
"""Called when a user logged in.""" | ||
if user: | ||
contexts.tenant_id.set(user.current_tenant_id) | ||
|
||
@login_manager.unauthorized_handler | ||
def unauthorized_handler(): | ||
"""Handle unauthorized requests.""" | ||
return Response( | ||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}), | ||
status=401, | ||
content_type="application/json", | ||
) |