Skip to content

Commit

Permalink
Model Runtime (langgenius#1858)
Browse files (browse the repository at this point in the history)
Co-authored-by: StyleZhang <[email protected]>
Co-authored-by: Garfield Dai <[email protected]>
Co-authored-by: chenhe <[email protected]>
Co-authored-by: jyong <[email protected]>
Co-authored-by: Joel <[email protected]>
Co-authored-by: Yeuoly <[email protected]>
  • Loading branch information
7 people authored Jan 2, 2024
1 parent 11f204e commit 200e502
Show file tree
Hide file tree
Showing 663 changed files with 166,604 additions and 19,131 deletions.
15 changes: 15 additions & 0 deletions api/.vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Celery",
"type": "python",
"request": "launch",
"module": "celery",
"justMyCode": true,
"args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
"envFile": "${workspaceFolder}/.env",
"env": {
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"console": "integratedTerminal"
},
{
"name": "Python: Flask",
"type": "python",
Expand Down
3 changes: 0 additions & 3 deletions api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ RUN apt-get update \
COPY --from=base /pkg /usr/local
COPY . /app/api/

RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
ENV TRANSFORMERS_OFFLINE true

COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

Expand Down
34 changes: 20 additions & 14 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
# if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()

import langchain
langchain.verbose = True

import time
import logging
Expand All @@ -18,9 +21,8 @@
from flask import Flask, request, Response
from flask_cors import CORS

from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_code_based_extension
ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
from extensions.ext_database import db
from extensions.ext_login import login_manager

Expand Down Expand Up @@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
register_blueprints(app)
register_commands(app)

hosted.init_app(app)

return app


Expand All @@ -95,6 +95,7 @@ def initialize_extensions(app):
ext_celery.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_hosting_provider.init_app(app)
ext_sentry.init_app(app)


Expand All @@ -105,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

if not auth_header:
auth_token = request.args.get('_token')
if not auth_token:
raise Unauthorized('Invalid Authorization token.')
else:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')

Expand Down
28 changes: 15 additions & 13 deletions api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound

from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
Expand Down Expand Up @@ -327,26 +323,32 @@ def create_qdrant_indexes():
except NotFound:
break

model_manager = ModelManager()

page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model

)
except Exception:
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
embedding_model = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
except Exception:

provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
Expand Down
2 changes: 1 addition & 1 deletion api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self):
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.34"
self.CURRENT_VERSION = "0.4.0"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

# Import workspace controllers
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
from .workspace import workspace, members, model_providers, account, tool_providers, models

# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
Expand Down
78 changes: 44 additions & 34 deletions api/controllers/console/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from datetime import datetime

from flask_login import current_user

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden
Expand All @@ -13,9 +17,7 @@
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError
from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
Expand Down Expand Up @@ -73,39 +75,41 @@ def post(self):
raise Forbidden()

try:
default_model = ModelFactory.get_text_generation_model(
tenant_id=current_user.current_tenant_id
provider_manager = ProviderManager()
default_model_entity = provider_manager.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
except (ProviderTokenNotInitError, LLMBadRequestError):
default_model = None
default_model_entity = None
except Exception as e:
logging.exception(e)
default_model = None
default_model_entity = None

if args['model_config'] is not None:
# validate config
model_config_dict = args['model_config']

# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
model_config_dict["model"]["provider"]
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)

if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
model_config_dict["model"]["name"] = default_model.name
if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = model_instance.provider
model_config_dict["model"]["name"] = model_instance.model

model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=model_config_dict,
mode=args['mode']
app_mode=args['mode']
)

app = App(
Expand All @@ -129,21 +133,27 @@ def post(self):
app_model_config = AppModelConfig(**model_config_template['model_config'])

# get model provider
model_provider = ModelProviderFactory.get_preferred_model_provider(
current_user.current_tenant_id,
app_model_config.model_dict["provider"]
)

if not model_provider:
if not default_model:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.model_provider.provider_name
model_dict['name'] = default_model.name
app_model_config.model = json.dumps(model_dict)
model_manager = ModelManager()

try:
model_instance = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.LLM
)
except ProviderTokenNotInitError:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")

if not model_instance:
raise ProviderNotInitializeError(
f"No Default System Reasoning Model available. Please configure "
f"in the Settings -> Model Provider.")
else:
model_dict = app_model_config.model_dict
model_dict['provider'] = model_instance.provider
model_dict['name'] = model_instance.model
app_model_config.model = json.dumps(model_dict)

app.name = args['name']
app.mode = args['mode']
Expand Down
8 changes: 4 additions & 4 deletions api/controllers/console/app/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging

from flask import request

from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from werkzeug.exceptions import InternalServerError

Expand All @@ -14,8 +16,7 @@
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
Expand Down Expand Up @@ -56,8 +57,7 @@ def post(self, app_id):
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
except InvokeError as e:
raise CompletionRequestError(str(e))
except ValueError as e:
raise e
Expand Down
Loading

0 comments on commit 200e502

Please sign in to comment.