Skip to content

Commit

Permalink
Model Runtime (#1858)
Browse files Browse the repository at this point in the history
Co-authored-by: StyleZhang <[email protected]>
Co-authored-by: Garfield Dai <[email protected]>
Co-authored-by: chenhe <[email protected]>
Co-authored-by: jyong <[email protected]>
Co-authored-by: Joel <[email protected]>
Co-authored-by: Yeuoly <[email protected]>
  • Loading branch information
7 people authored Jan 2, 2024
1 parent e91dd28 commit d069c66
Show file tree
Hide file tree
Showing 807 changed files with 171,309 additions and 23,805 deletions.
58 changes: 58 additions & 0 deletions .github/workflows/api-model-runtime-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
name: Run Pytest

on:
pull_request:
branches:
- main
push:
branches:
- deploy/dev
- feat/model-runtime

jobs:
test:
runs-on: ubuntu-latest

env:
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
CHATGLM_API_BASE: http://a.abc.com:11451
XINFERENCE_SERVER_URL: http://a.abc.com:11451
XINFERENCE_GENERATION_MODEL_UID: generate
XINFERENCE_CHAT_MODEL_UID: chat
XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
XINFERENCE_RERANK_MODEL_UID: rerank
GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
MOCK_SWITCH: true


steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.10'

- name: Cache pip dependencies
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest
pip install -r api/requirements.txt
- name: Run pytest
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
38 changes: 0 additions & 38 deletions .github/workflows/api-unit-tests.yml

This file was deleted.

5 changes: 5 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r

Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!


### Provider Integrations
If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR.


### i18n (Internationalization) Support

We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
Expand Down
15 changes: 15 additions & 0 deletions api/.vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Celery",
"type": "python",
"request": "launch",
"module": "celery",
"justMyCode": true,
"args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
"envFile": "${workspaceFolder}/.env",
"env": {
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"console": "integratedTerminal"
},
{
"name": "Python: Flask",
"type": "python",
Expand Down
3 changes: 0 additions & 3 deletions api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ RUN apt-get update \
COPY --from=base /pkg /usr/local
COPY . /app/api/

RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
ENV TRANSFORMERS_OFFLINE true

COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

Expand Down
34 changes: 20 additions & 14 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
# if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()

import langchain
langchain.verbose = True

import time
import logging
Expand All @@ -18,9 +21,8 @@
from flask import Flask, request, Response
from flask_cors import CORS

from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_code_based_extension
ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider
from extensions.ext_database import db
from extensions.ext_login import login_manager

Expand Down Expand Up @@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask:
register_blueprints(app)
register_commands(app)

hosted.init_app(app)

return app


Expand All @@ -95,6 +95,7 @@ def initialize_extensions(app):
ext_celery.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_hosting_provider.init_app(app)
ext_sentry.init_app(app)


Expand All @@ -105,13 +106,18 @@ def load_user_from_request(request_from_flask_login):
if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

if not auth_header:
auth_token = request.args.get('_token')
if not auth_token:
raise Unauthorized('Invalid Authorization token.')
else:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')

Expand Down
28 changes: 15 additions & 13 deletions api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
from tqdm import tqdm
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound

from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
Expand Down Expand Up @@ -327,26 +323,32 @@ def create_qdrant_indexes():
except NotFound:
break

model_manager = ModelManager()

page += 1
for dataset in datasets:
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] != 'qdrant':
try:
click.echo('Create dataset qdrant index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model

)
except Exception:
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id
embedding_model = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
except Exception:

provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
Expand Down
2 changes: 1 addition & 1 deletion api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self):
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.34"
self.CURRENT_VERSION = "0.4.0"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

# Import workspace controllers
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models
from .workspace import workspace, members, model_providers, account, tool_providers, models

# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio
Expand Down
Loading

0 comments on commit d069c66

Please sign in to comment.