diff --git a/sdcat/cluster/cluster.py b/sdcat/cluster/cluster.py index 317f655..dbb26fb 100755 --- a/sdcat/cluster/cluster.py +++ b/sdcat/cluster/cluster.py @@ -86,7 +86,7 @@ def run_vss(image_t: tuple[str, np.array], vss_url: str, project: str, vss_thres :param vss_url: url for vss service :param project: project name in vss :param top_k: number of vss to use for prediction; 1, 3, 5 etc. - :return: + :return: best prediction and score """ url_vss = f"{vss_url}/{top_k}/{project}" debug(f"URL: {url_vss} threshold: {vss_threshold}") @@ -96,13 +96,13 @@ def run_vss(image_t: tuple[str, np.array], vss_url: str, project: str, vss_thres if response.status_code != 200: err(f"Error processing images: {response.text}") - return None, None + return "", 0.0 predictions = response.json()["predictions"] debug(f"Predictions: {predictions}") if len(predictions) == 0: - return None, None + return "", 0.0 scores = response.json()["scores"][0] # Scores are 1 - score, so we need to invert them @@ -112,7 +112,7 @@ def run_vss(image_t: tuple[str, np.array], vss_url: str, project: str, vss_thres if best_pred is None: err(f"No majority prediction for {image_t[0]}") - return None, None + return "", 0.0 return best_pred, best_score @@ -514,7 +514,7 @@ def cluster_vits( # Run the VSS service to assign the cluster to a class image_t = read_image(exemplar['image_path']) best_prediction, best_score = run_vss(image_t, vss_url=vss_url, vss_threshold=.1, project='901103-biodiversity', top_k=1) - if best_prediction is None: + if len(best_prediction) == 0: warn(f'No predictions found for {exemplar["image_path"]}') continue # Assign the class to the cluster in df_dets @@ -522,6 +522,19 @@ def cluster_vits( df_dets.loc[df_dets['cluster'] == cluster_id, 'class'] = best_prediction df_dets.loc[df_dets['cluster'] == cluster_id, 'score'] = best_score + # Try to assign everything not in a cluster to a class + unknowns = df_dets[df_dets['cluster'] == -1] + for idx, row in unknowns.iterrows(): + image_t = read_image(row['crop_path']) + best_prediction, best_score = run_vss(image_t, vss_url=vss_url, vss_threshold=.1, project='901103-biodiversity', top_k=1) + if len(best_prediction) == 0: + warn(f'No predictions found for {row["crop_path"]}') + continue + # Assign the class to the cluster in df_dets + info(f'Assigning {row["crop_path"]} to class {best_prediction} with score {best_score}') + df_dets.loc[df_dets['crop_path'] == row['crop_path'], 'class'] = best_prediction + df_dets.loc[df_dets['crop_path'] == row['crop_path'], 'score'] = best_score + # Save the exemplar embeddings with the model type exemplar_df['model'] = model exemplar_df.to_csv(output_path / f'{prefix}_exemplars.csv', index=False)