Skip to content

Commit

Permalink
Merge branch 'feat/model-runtime' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Jan 2, 2024
2 parents 7552314 + 0a08553 commit b30397c
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 5 deletions.
2 changes: 1 addition & 1 deletion api/controllers/console/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def generate() -> Generator:
except ModelCurrentlyNotSupportError:
yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
except InvokeError as e:
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n"
except ValueError as e:
yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
except Exception:
Expand Down
5 changes: 3 additions & 2 deletions api/core/entities/provider_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,12 @@ def custom_model_credentials_validate(self, model_type: ModelType, model: str, c
except JSONDecodeError:
original_credentials = {}

# encrypt credentials
# decrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
if value == '[__HIDDEN__]' and key in original_credentials:
credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

model_provider_factory.model_credentials_validate(
provider=self.provider.provider,
Expand Down
5 changes: 4 additions & 1 deletion api/core/model_runtime/model_providers/openai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
base_model = model.split(':')[1]

# check if model exists
remote_models = self.remote_models()
remote_models = self.remote_models(credentials)
remote_model_map = {model.model: model for model in remote_models}
if model not in remote_model_map:
raise CredentialsValidateFailedError(f'Fine-tuned model {model} not found')
Expand Down Expand Up @@ -687,6 +687,9 @@ def _num_tokens_from_messages(self, model: str, messages: List[PromptMessage],
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model.startswith('ft:'):
model = model.split(':')[1]

try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""modify provider model name length
Revision ID: 187385f442fc
Revises: 88072f0caa04
Create Date: 2024-01-02 07:18:43.887428
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '187385f442fc'
down_revision = '88072f0caa04'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.alter_column('model_name',
existing_type=sa.VARCHAR(length=40),
type_=sa.String(length=255),
existing_nullable=False)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.alter_column('model_name',
existing_type=sa.String(length=255),
type_=sa.VARCHAR(length=40),
existing_nullable=False)

# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion api/models/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ProviderModel(db.Model):
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
tenant_id = db.Column(UUID, nullable=False)
provider_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(40), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
Expand Down

0 comments on commit b30397c

Please sign in to comment.