From c5fe72523fdfb1eb7b674565e5724686a4ec65d1 Mon Sep 17 00:00:00 2001 From: danellecline Date: Mon, 29 Jul 2024 09:56:48 -0700 Subject: [PATCH] perf: migrated to transformers library with batch size of 8, moved some imports to only where needed for some speed-up, and removed unused activation maps. --- requirements.txt | 2 +- sdcat/cluster/cluster.py | 72 ++++++++--------- sdcat/cluster/commands.py | 5 +- sdcat/cluster/embedding.py | 162 +++++++++++++++---------------------- sdcat/cluster/utils.py | 136 +++++++++++++------------------ sdcat/config/config.ini | 18 ++--- sdcat/detect/commands.py | 6 +- 7 files changed, 172 insertions(+), 229 deletions(-) diff --git a/requirements.txt b/requirements.txt index 19a30dc..f291522 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ torch==2.3.1 piexif yolov5==7.0.13 torchvision==0.18.1 -transformers +transformers[torch] timm pandas>=1.2.4 ultralytics diff --git a/sdcat/cluster/cluster.py b/sdcat/cluster/cluster.py index daaca84..290f673 100644 --- a/sdcat/cluster/cluster.py +++ b/sdcat/cluster/cluster.py @@ -74,7 +74,7 @@ def _run_hdbscan_assign( if not numerical.empty: numerical = numerical.fillna(0) - # Normalize the numerical data from 0 to 1 + # Normalize the numerical data from 0 to 1 and add it to the dataframe numerical = (numerical - numerical.min()) / (numerical.max() - numerical.min()) df = pd.merge(df, numerical, left_index=True, right_index=True, how='left') @@ -107,30 +107,15 @@ def _run_hdbscan_assign( labels = scan.fit_predict(x) else: scan = HDBSCAN( - metric='l2', - allow_single_cluster=True, - min_cluster_size=min_cluster_size, - min_samples=min_samples, - alpha=alpha, - cluster_selection_epsilon=cluster_selection_epsilon, - cluster_selection_method='leaf') + metric='l2', + allow_single_cluster=True, + min_cluster_size=min_cluster_size, + min_samples=min_samples, + alpha=alpha, + cluster_selection_epsilon=cluster_selection_epsilon, + cluster_selection_method='leaf') labels = scan.fit_predict(x) -# title_tree = f'HDBSCAN Tree Distances {cluster_selection_epsilon} min_cluster_size {min_cluster_size} min_samples {min_samples} alpha {alpha}' -# title_linkage = title_tree.replace('Tree Distances', 'Linkage') - -# scan.condensed_tree_.plot(select_clusters=True, -# selection_palette=sns.color_palette('deep', 8)) -# plt.title(title_tree) -# plt.xlabel('Index') -# plt.savefig(f"{out_path}/{prefix}_condensed_tree.png") - -# plt.figure(figsize=(10, 6)) -# scan.single_linkage_tree_.plot(cmap='viridis', colorbar=True) -# plt.title(title_linkage) -# plt.xlabel('Index') -# plt.savefig(f"{out_path}/{prefix}_tree.png") - # Get the unique clusters and sort them; -1 are unassigned clusters cluster_df = pd.DataFrame(labels, columns=['cluster']) unique_clusters = cluster_df['cluster'].unique().tolist() @@ -149,7 +134,7 @@ def _run_hdbscan_assign( if len(unique_clusters) == 1 and unique_clusters[0] == -1: avg_sim_scores = [] exemplar_df = pd.DataFrame() - exemplar_df['cluster'] = len(x)*['Unknown'] + exemplar_df['cluster'] = len(x) * ['Unknown'] exemplar_df['embedding'] = x.tolist() exemplar_df['image_path'] = ancillary_df['image_path'].tolist() clusters = [] @@ -191,6 +176,9 @@ def _run_hdbscan_assign( avg_sim_scores = [] for i, c in enumerate(clusters): debug(f'Computing similarity for cluster {i} with {len(c)} samples') + if len(c) == 0: + avg_sim_scores.append(0) + continue cosine_sim_matrix = cosine_similarity(image_emb[c]) avg_sim_scores.append(np.mean(cosine_sim_matrix)) @@ -223,7 +211,7 @@ def _run_hdbscan_assign( else: init = 'spectral' - # Reduce the dimensionality of the embeddings using UMAP to 2 dimensions for visualization + # Reduce the dimensionality of the embeddings using UMAP to 2 dimensions to visualize the clusters if have_gpu: xx = cuUMAP(init=init, n_components=2, @@ -233,8 +221,6 @@ def _run_hdbscan_assign( else: xx = UMAP(init=init, n_components=2, - n_neighbors=3, - min_dist=0.1, metric='cosine', low_memory=True).fit_transform(df.values) @@ -285,14 +271,14 @@ def cluster_vits( # Skip cropping if all the crops are already done if num_crop != len(df_dets): num_processes = min(multiprocessing.cpu_count(), len(df_dets)) - if roi == True: - info(f'ROI crops already exist. Creating square crops in parallel using {multiprocessing.cpu_count()} processes...') + if roi is True: + info(f'ROI crops already exist. Creating square crops in parallel using {num_processes} processes...') with multiprocessing.Pool(num_processes) as pool: args = [(row, 224) for index, row in df_dets.iterrows()] pool.starmap(square_image, args) else: # Crop and squaring the images in parallel using multiprocessing to speed up the processing - info(f'Cropping {len(df_dets)} detections in parallel using {multiprocessing.cpu_count()} processes...') + info(f'Cropping {len(df_dets)} detections in parallel using {num_processes} processes...') with multiprocessing.Pool(num_processes) as pool: args = [(row, 224) for index, row in df_dets.iterrows()] pool.starmap(crop_square_image, args) @@ -317,9 +303,17 @@ def cluster_vits( for filename in images: emb = fetch_embedding(model, filename) if len(emb) == 0: + # If the embeddings are zero, then the extraction failed; add a zero array image_emb.append(np.zeros(384, dtype=np.float32)) else: image_emb.append(emb) + + # If the embeddings are zero, then the extraction failed + num_failed = [i for i, e in enumerate(image_emb) if np.all(e == 0)] + if len(num_failed) == len(images): + warn('Failed to extract embeddings from all images') + return pd.DataFrame() + image_emb = np.array(image_emb) if not (output_path / prefix).exists(): @@ -338,15 +332,15 @@ def cluster_vits( # Cluster the images cluster_sim, exemplar_df, unique_clusters, cluster_means, coverage = _run_hdbscan_assign(prefix, - image_emb, - alpha, - cluster_selection_epsilon, - min_similarity, - min_cluster_size, - min_samples, - use_tsne, - ancillary_df, - output_path / prefix) + image_emb, + alpha, + cluster_selection_epsilon, + min_similarity, + min_cluster_size, + min_samples, + use_tsne, + ancillary_df, + output_path / prefix) # Get the average similarity across all clusters avg_similarity = np.mean(cluster_sim) diff --git a/sdcat/cluster/commands.py b/sdcat/cluster/commands.py index dce1e6b..547fbe5 100644 --- a/sdcat/cluster/commands.py +++ b/sdcat/cluster/commands.py @@ -246,8 +246,11 @@ def is_day(utc_dt): info(df.head(5)) if len(df) > 0: + # Replace / with _ in the model name + model_machine_friendly = model.replace('/', '_') + # A prefix for the output files to make sure the output is unique for each execution - prefix = f'{model}_{datetime.now().strftime("%Y%m%d_%H%M%S")}' + prefix = f'{model_machine_friendly}_{datetime.now().strftime("%Y%m%d_%H%M%S")}' # Cluster the detections df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, min_similarity, diff --git a/sdcat/cluster/embedding.py b/sdcat/cluster/embedding.py index 5776078..31ea682 100644 --- a/sdcat/cluster/embedding.py +++ b/sdcat/cluster/embedding.py @@ -13,45 +13,49 @@ from sahi.utils.torch import torch from torchvision import transforms as pth_transforms import cv2 +from transformers import ViTModel, ViTImageProcessor + from sdcat.logger import info, err -def cache_embedding(embedding, model_name: str, filename: str): - # save numpy array as npy file - save(f'{filename}_{model_name}.npy', embedding) +class ViTWrapper: + MODEL_NAME = "google/vit-base-patch16-224" + VECTOR_DIMENSIONS = 768 + + def __init__(self, device: str = "cpu", reset: bool = False, batch_size: int = 32): + self.batch_size = batch_size + self.model = ViTModel.from_pretrained(self.MODEL_NAME) + self.processor = ViTImageProcessor.from_pretrained(self.MODEL_NAME) -def cache_attention(attention, model_name: str, filename: str): + # Load the model and processor + if 'cuda' in device and torch.cuda.is_available(): + device_num = int(device.split(":")[-1]) + info(f"Using GPU device {device_num}") + torch.cuda.set_device(device_num) + self.device = "cuda" + self.model.to("cuda") + else: + self.device = "cpu" + + +def cache_embedding(embedding, model_name: str, filename: str): + model_machine_friendly_name = model_name.replace("/", "_") # save numpy array as npy file - save(f'{filename}_{model_name}_a.npy', attention) + save(f'{filename}_{model_machine_friendly_name}.npy', embedding) def fetch_embedding(model_name: str, filename: str) -> np.array: + model_machine_friendly_name = model_name.replace("/", "_") # if the npy file exists, return it - if os.path.exists(f'{filename}_{model_name}.npy'): - data = load(f'{filename}_{model_name}.npy') + if os.path.exists(f'{filename}_{model_machine_friendly_name}.npy'): + data = load(f'{filename}_{model_machine_friendly_name}.npy') return data else: info(f'No embedding found for {filename}') return [] -def fetch_attention(model_name: str, filename: str) -> np.array: - """ - Fetch the attention map for the given filename and model name - :param model_name: Name of the model - :param filename: Name of the file - :return: Numpy array of the attention map - """ - # if the npy file exists, return it - if os.path.exists(f'{filename}_{model_name}_a.npy'): - data = load(f'{filename}_{model_name}_a.npy') - return data - else: - info(f'No attention map found for {filename}') - return [] - - def has_cached_embedding(model_name: str, filename: str) -> int: """ Check if the given filename has a cached embedding @@ -59,7 +63,8 @@ def has_cached_embedding(model_name: str, filename: str) -> int: :param filename: Name of the file :return: 1 if the image has a cached embedding, otherwise 0 """ - if os.path.exists(f'{filename}_{model_name}.npy'): + model_machine_friendly_name = model_name.replace("/", "_") + if os.path.exists(f'{filename}_{model_machine_friendly_name}.npy'): return 1 return 0 @@ -71,89 +76,48 @@ def encode_image(filename): return keep -def compute_embedding(images: list, model_name: str): +def compute_embedding_vits(images: list, model_name: str, device: str = "cpu"): """ Compute the embedding for the given images using the given model :param images: List of image filenames - :param model_name: Name of the model + :param model_name: Name of the model (i.e. google/vit-base-patch16-224, dinov2_vits16, etc.) + :param device: Device to use for the computation (cpu or cuda:0, cuda:1, etc.) """ - - # Load the model - if 'dinov2' in model_name: - info(f'Loading model {model_name} from facebookresearch/dinov2...') - model = torch.hub.load('facebookresearch/dinov2', model_name) - elif 'dino' in model_name: - info(f'Loading model {model_name} from facebookresearch/dino:main...') - model = torch.hub.load('facebookresearch/dino:main', model_name) - else: - # TODO: Add more models - err(f'Unknown model {model_name}!') - return - - # The patch size is in the model name, e.g. dino_vits16 is a 16x16 patch size, dino_vits8 is a 8x8 patch size - res = re.findall(r'\d+$', model_name) - if len(res) > 0: - patch_size = int(res[0]) + batch_size = 8 + vit_model = ViTModel.from_pretrained(model_name) + processor = ViTImageProcessor.from_pretrained(model_name) + + if 'cuda' in device and torch.cuda.is_available(): + device_num = int(device.split(":")[-1]) + info(f"Using GPU device {device_num}") + torch.cuda.set_device(device_num) + vit_model.to("cuda") + device = "cuda" else: - raise ValueError(f'Could not find patch size in model name {model_name}') - info(f'Using patch size {patch_size} for model {model_name}') - - # Load images and generate embeddings - device = 'cuda' if torch.cuda.is_available() else 'cpu' - with torch.no_grad(): - # Set the cuda device - if torch.cuda.is_available(): - model = model.to(device) - - for filename in images: - # Skip if the embedding already exists - if Path(f'{filename}_{model_name}.npy').exists(): + device = "cpu" + + # Batch process the images + batches = [images[i:i + batch_size] for i in range(0, len(images), batch_size)] + for batch in batches: + try: + # Skip running the model if the embeddings already exist + if all([has_cached_embedding(model_name, filename) for filename in batch]): continue - try: - # Load the image - square_img = Image.open(filename) - - # Do some image processing to reduce the noise in the image - # Gaussian blur - square_img = square_img.filter(ImageFilter.GaussianBlur(radius=1)) - - image = np.array(square_img) - - norm_transform = pth_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - img_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 - # Noramlize the tensor with the mean and std of the ImageNet dataset - img_tensor = norm_transform(img_tensor) - img_tensor = img_tensor.unsqueeze(0) # Add batch dimension - if 'cuda' in device: - img_tensor = img_tensor.to(device) - features = model(img_tensor) - - # TODO: add attention map cach as optional - # attentions = model.get_last_selfattention(img_tensor) - - # nh = attentions.shape[1] # number of head - - # w_featmap = 224 // patch_size - # h_featmap = 224 // patch_size + images = [Image.open(filename).convert("RGB") for filename in batch] + inputs = processor(images=images, return_tensors="pt").to(device) - # Keep only the output patch attention - # attentions = attentions[0, :, 0, 1:].reshape(nh, -1) - # attentions = attentions.reshape(nh, w_featmap, h_featmap) - # attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[ - # 0].cpu().numpy() - # - # # Resize the attention map to the original image size - # attentions = np.uint8(255 * attentions[0]) + with torch.no_grad(): + embeddings = vit_model(**inputs) - # Get the feature embeddings - embeddings = features.squeeze(dim=0) # Remove batch dimension - embeddings = embeddings.cpu().numpy() # Convert to numpy array + batch_embeddings = embeddings.last_hidden_state[:, 0, :].cpu().numpy() - cache_embedding(embeddings, model_name, filename) # save the embedding to disk - #cache_attention(attentions, model_name, filename) # save the attention map to disk - except Exception as e: - err(f'Error processing {filename}: {e}') + # Save the embeddings + for emb, filename in zip(batch_embeddings, batch): + emb = emb.astype(np.float32) + cache_embedding(emb, model_name, filename) + except Exception as e: + err(f'Error processing {batch}: {e}') def compute_norm_embedding(model_name: str, images: list): @@ -172,7 +136,7 @@ def compute_norm_embedding(model_name: str, images: list): # If using a GPU, set then skip the parallel CPU processing if torch.cuda.is_available(): - compute_embedding(images, model_name) + compute_embedding_vits(images, model_name) else: # Use a pool of processes to speed up the embedding generation 20 images at a time on each process num_processes = min(multiprocessing.cpu_count(), len(images) // 20) @@ -180,7 +144,7 @@ def compute_norm_embedding(model_name: str, images: list): info(f'Using {num_processes} processes to compute {len(images)} embeddings 20 at a time ...') with multiprocessing.Pool(num_processes) as pool: args = [(images[i:i + 20], model_name) for i in range(0, len(images), 20)] - pool.starmap(compute_embedding, args) + pool.starmap(compute_embedding_vits, args) def calc_mean_std(image_files: list) -> tuple: diff --git a/sdcat/cluster/utils.py b/sdcat/cluster/utils.py index a5c5cc7..d5695e0 100644 --- a/sdcat/cluster/utils.py +++ b/sdcat/cluster/utils.py @@ -9,14 +9,13 @@ from mpl_toolkits.axes_grid1 import ImageGrid from pathlib import Path -from sdcat.cluster.embedding import fetch_attention from sdcat.logger import debug, warn, exception def cluster_grid(prefix: str, cluster_sim: float, cluster_id: int, cluster_size: int, nb_images_display: int, images: list, output_path: Path): """ - Cluster visualization; create a grid of images both with and without attention map + Cluster visualization; create a grid of images :param cluster_sim: Cluster similarity :param cluster_size: Size of the cluster :param cluster_id: Cluster ID @@ -26,74 +25,60 @@ def cluster_grid(prefix: str, cluster_sim: float, cluster_id: int, cluster_size: """ debug(f'Cluster number {cluster_id} size {len(cluster_size)} similarity {cluster_sim}\n') - def gen_grid(with_attention: bool): - # Plot a grid for each group of images nb_images_display at a time (e.g. 8x8) - for i in range(0, len(images), nb_images_display * nb_images_display): - fig = plt.figure(figsize=(10., 10.)) - grid = ImageGrid(fig, 111, # similar to subplot(111) - nrows_ncols=(nb_images_display, nb_images_display), - # creates nb_images_display x nb_images_display grid of axes - axes_pad=0.025, - share_all=True, - cbar_pad=0.025) - images_display = images[i:i + nb_images_display * nb_images_display] - page = i // (nb_images_display * nb_images_display) - - # If we have more than 3 pages, then only display the first 3 pages - # There can be a large number of pages for detections in common classes - if page > 3: - break - - total_pages = len(images) // (nb_images_display * nb_images_display) - # debug(f"{i} Image filename:", images[j]) - for j, image in enumerate(images_display): - try: - image_square = Image.open(image) - grid[j].imshow(image_square) - except Exception as e: - exception(f'Error opening {image} {e}') - continue - - if with_attention: - # Get the attention map - # TODO: remove this or refactor with pass through of model name - attention = fetch_attention('dino_vitb8', image) - - # Overlay the attention map on top of the original image - grid[j].imshow(attention, cmap='jet', alpha=0.125) - - grid[j].axis('off') - # If the verified is in the image name, then add a label to the image in the top center corner - if 'verified' in image: - n = Path(image) - title = f"{n.stem.split('_')[0]}" - grid[j].text(30, 10, title, fontsize=8, color='white', ha='center', va='center') - # clear the x and y-axis - grid[j].set_xticklabels([]) - - # Add a title to the figure - if total_pages > 1: - fig.suptitle( - f"{prefix} Cluster {cluster_id}, Size: {len(cluster_size)}, Similarity: {cluster_sim:.2f}, Page: {page} of {total_pages}", - fontsize=16) - else: - fig.suptitle(f"{prefix} Cluster {cluster_id}, Size: {len(cluster_size)}, Similarity: {cluster_sim:.2f}", - fontsize=16) - - # Set the background color of the grid to white - fig.set_facecolor('white') - - # Write the figure to a file - if with_attention: - out = output_path / f'{prefix}_cluster_{cluster_id}_p{page}_attention.png' - else: - out = output_path / f'{prefix}_cluster_{cluster_id}_p{page}.png' - debug(f'Writing {out}') - fig.savefig(out.as_posix()) - plt.close(fig) - - gen_grid(with_attention=False) - # gen_grid(with_attention=True) + # Plot a grid for each group of images nb_images_display at a time (e.g. 8x8) + for i in range(0, len(images), nb_images_display * nb_images_display): + fig = plt.figure(figsize=(10., 10.)) + grid = ImageGrid(fig, 111, # similar to subplot(111) + nrows_ncols=(nb_images_display, nb_images_display), + # creates nb_images_display x nb_images_display grid of axes + axes_pad=0.025, + share_all=True, + cbar_pad=0.025) + images_display = images[i:i + nb_images_display * nb_images_display] + page = i // (nb_images_display * nb_images_display) + + # If we have more than 3 pages, then only display the first 3 pages + # There can be a large number of pages for detections in common classes + if page > 3: + break + + total_pages = len(images) // (nb_images_display * nb_images_display) + # debug(f"{i} Image filename:", images[j]) + for j, image in enumerate(images_display): + try: + image_square = Image.open(image) + grid[j].imshow(image_square) + except Exception as e: + exception(f'Error opening {image} {e}') + continue + + grid[j].axis('off') + # If the verified is in the image name, then add a label to the image in the top center corner + if 'verified' in image: + n = Path(image) + title = f"{n.stem.split('_')[0]}" + grid[j].text(30, 10, title, fontsize=8, color='white', ha='center', va='center') + # clear the x and y-axis + grid[j].set_xticklabels([]) + + # Add a title to the figure + if total_pages > 1: + fig.suptitle( + f"{prefix} Cluster {cluster_id}, Size: {len(cluster_size)}, Similarity: {cluster_sim:.2f}, Page: {page} of {total_pages}", + fontsize=16) + else: + fig.suptitle(f"{prefix} Cluster {cluster_id}, Size: {len(cluster_size)}, Similarity: {cluster_sim:.2f}", + fontsize=16) + + # Set the background color of the grid to white + fig.set_facecolor('white') + + # Write the figure to a file + out = output_path / f'{prefix}_cluster_{cluster_id}_p{page}.png' + debug(f'Writing {out}') + fig.savefig(out.as_posix()) + plt.close(fig) + def square_image(row, square_dim: int): @@ -132,6 +117,7 @@ def square_image(row, square_dim: int): exception(f'Error cropping {row.image_path} {e}') raise e + def crop_square_image(row, square_dim: int): """ Crop the image to a square padding the shortest dimension, then resize it to square_dim x square_dim @@ -203,14 +189,8 @@ def crop_square_image(row, square_dim: int): img = img.resize((square_dim, square_dim), Image.LANCZOS) # Save the image - # img.save(row.crop_path) - - # Every 10th index, Create a zero byte file to indicate that the crop was successful - if Path(row.image_path).stem is 'e1f5e2b8-9e3c-5904-a896-acb3c7a9cbf6': - Path(row.crop_path).touch() - else: - img.save(row.crop_path) - img.close() + img.save(row.crop_path) + img.close() except Exception as e: exception(f'Error cropping {row.image_path} {e}') diff --git a/sdcat/config/config.ini b/sdcat/config/config.ini index ef3a1a6..5b4959e 100644 --- a/sdcat/config/config.ini +++ b/sdcat/config/config.ini @@ -28,18 +28,18 @@ min_cluster_size = 2 min_samples = 1 max_area = 4375000 min_area = 100 +# Detections not assigned with hdbscan are assigned to the nearest cluster with a similarity > min_similarity +# This is useful for merging examples not assigned to clusters; set to 0 to disable +# A value of .9 would be very conservative, while a value of .5 would be very aggressive (merging only somewhat similar detections) # min_similarity must be in the range [0, 1] -# Clusters not assigned with hdbscan are assigned to the nearest cluster with a similarity > min_similarity min_similarity = 0.70 -# Examples: dinov2_vits14, dino_vits8, dino_vits16 -# dinov2 models were pretrained on a dataset of 142 M images without any labels -# dino models were pretrained on ImageNet which contains 1.3 M images with labels -# dino_vits8 has block_size=8 which can be good for very small objects -# dino_vits14 has block_size=14 +# google/vit-base-patch16-224 is a model trained on ImageNet21k with 21k classes good for general detection +# dino models were pretrained on ImageNet which contains 1.3 M images with labels from 1000 classes # Smaller block_size means more patches and more accurate fine-grained clustering on smaller objects -model = dino_vits8 -;model = dinov2_vits14 -;model = dinov2_vitb14 +# Larger block_size means fewer patches and faster processing +model = google/vit-base-patch16-224 +;model = facebook/dino-vits8 +;model = facebook/dino-vits16 [detect] ######################################################################## diff --git a/sdcat/detect/commands.py b/sdcat/detect/commands.py index fb8bf35..97b20a9 100644 --- a/sdcat/detect/commands.py +++ b/sdcat/detect/commands.py @@ -8,8 +8,6 @@ import cv2 import pandas as pd import torch -from huggingface_hub import hf_hub_download -from sahi import AutoDetectionModel from sahi.postprocess.combine import nms from sdcat import common_args @@ -67,6 +65,7 @@ def run_detect(show: bool, image_dir: str, save_dir: str, model: str, create_logger_file('detect') if not skip_sahi: + from sahi import AutoDetectionModel if model == 'yolov8s': detection_model = AutoDetectionModel.from_pretrained( model_type='yolov8', @@ -101,6 +100,7 @@ def run_detect(show: bool, image_dir: str, save_dir: str, model: str, ) elif model == 'MBARI/megamidwater': # Download model path + from huggingface_hub import hf_hub_download model_path = hf_hub_download(repo_id="MBARI-org/megamidwater", filename="best.pt") detection_model = AutoDetectionModel.from_pretrained( model_type='yolov5', @@ -111,6 +111,7 @@ def run_detect(show: bool, image_dir: str, save_dir: str, model: str, ) elif model == 'MBARI/uav-yolov5': # Download model path + from huggingface_hub import hf_hub_download model_path = hf_hub_download(repo_id="MBARI-org/uav-yolov5", filename="best.pt") detection_model = AutoDetectionModel.from_pretrained( model_type='yolov5', @@ -121,6 +122,7 @@ def run_detect(show: bool, image_dir: str, save_dir: str, model: str, ) elif model == 'FathomNet/MBARI-315k-yolov5': # Download model path + from huggingface_hub import hf_hub_download model_path = hf_hub_download(repo_id="FathomNet/MBARI-315k-yolov5", filename="mbari_315k_yolov5.pt") detection_model = AutoDetectionModel.from_pretrained( model_type='yolov5',