Skip to content

Commit

Permalink
fix update dataset failed when embedding model is not exist (#6920)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnJyong authored Aug 2, 2024
1 parent 4d0a6cc commit 048bc4c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
9 changes: 7 additions & 2 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ def patch(self, dataset_id):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)

parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False,
Expand All @@ -215,6 +213,13 @@ def patch(self, dataset_id):
args = parser.parse_args()
data = request.get_json()

# check embedding model setting
if data.get('indexing_technique') == 'high_quality':
DatasetService.check_embedding_model_setting(dataset.tenant_id,
data.get('embedding_model_provider'),
data.get('embedding_model')
)

# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get('permission'), data.get('partial_member_list')
Expand Down
22 changes: 22 additions & 0 deletions api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,28 @@ def check_dataset_model_setting(dataset):
f"{ex.description}"
)

@staticmethod
def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model:str):
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=tenant_id,
provider=embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=embedding_model
)
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(
f"The dataset in unavailable, due to: "
f"{ex.description}"
)


@staticmethod
def update_dataset(dataset_id, data, user):
data.pop('partial_member_list', None)
Expand Down

0 comments on commit 048bc4c

Please sign in to comment.