diff --git a/aipipeline/prediction/library.py b/aipipeline/prediction/library.py
index fd37132..7ca7c47 100644
--- a/aipipeline/prediction/library.py
+++ b/aipipeline/prediction/library.py
@@ -4,7 +4,6 @@
 import multiprocessing
 import os
 import shutil
-import time
 from datetime import datetime
 import numpy as np
 from pathlib import Path
@@ -79,6 +78,9 @@ def generate_multicrop_views2(image) -> List[tuple]:
 def clean_bad_images(element) -> tuple:
     count, crop_path, save_path = element
     num_removed = 0
+    # Check if any images exist
+    if count == 0:
+        return count, crop_path, save_path
     imagelab = Imagelab(data_path=crop_path)
     imagelab.find_issues()
     imagelab.report()
@@ -139,8 +141,8 @@ def generate_multicrop_views(elements) -> List[tuple]:
     return data
 
 
-def cluster(data, config_dict: Dict) -> List[tuple]:
-    logger.info(f'Clustering {data}')
+def cluster(data, config_dict: Dict, min_detections: int) -> List[tuple]:
+    logger.info(f'Clustering {data} with min_detections {min_detections}')
     num_images, crop_dir, cluster_dir = data
     project = config_dict["tator"]["project"]
     sdcat_config = config_dict["sdcat"]["ini"]
@@ -152,11 +154,6 @@ def cluster(data, config_dict: Dict) -> List[tuple]:
     short_name = get_short_name(project)
     logger.info(data)
 
-    # If there are less than 500 images, skip clustering
-    if num_images < 500:
-        logger.info(f"Skipping clustering for {num_images} images in {crop_dir}")
-        return [(Path(crop_dir).name, cluster_dir)]
-
     logger.info(f"Clustering {num_images} images in {crop_dir} ....")
 
     min_cluster_size = 2
@@ -179,6 +176,19 @@ def cluster(data, config_dict: Dict) -> List[tuple]:
     label = Path(crop_dir).name
     machine_friendly_label = gen_machine_friendly_label(label)
     try:
+        # Skip clustering if there are too few images, but generate a detection file for the next step
+        if num_images < min_detections:
+            logger.info(f"Skipping clustering for {label} with {num_images} images")
+            Path(cluster_dir).mkdir(parents=True, exist_ok=True)
+            cluster_results.append((Path(crop_dir).name, cluster_dir))
+            images = [f"{crop_dir}/{f}" for f in os.listdir(crop_dir) if f.endswith(".jpg")]
+            with open(f"{cluster_dir}/no_cluster_exemplars.csv", "w") as f:
+                # Add the header image_path,image_width,image_height,crop_path,cluster
+                f.write("image_path,image_width,image_height,crop_path,cluster\n")
+                for image in images:
+                    f.write(f"{image},224,224,{image},-1\n")
+            return cluster_results
+
         container = run_docker(
             image=config_dict["docker"]["sdcat"],
             name=f"{short_name}-sdcat-clu-{machine_friendly_label}",
@@ -203,8 +213,9 @@ def cluster(data, config_dict: Dict) -> List[tuple]:
 
 
 class ProcessClusterBatch(beam.DoFn):
-    def __init__(self, config_dict: Dict):
+    def __init__(self, config_dict: Dict, min_detections: int):
         self.config_dict = config_dict
+        self.min_detections = min_detections
 
     def process(self, batch):
         if len(batch) > 1:
@@ -212,7 +223,7 @@ def process(self, batch):
         else:
            num_processes = 1
         with multiprocessing.Pool(num_processes) as pool:
-            args = [(data, self.config_dict) for data in batch]
+            args = [(data, self.config_dict, self.min_detections) for data in batch]
            results = pool.starmap(cluster, args)
            return results
diff --git a/aipipeline/prediction/vss_init_pipeline.py b/aipipeline/prediction/vss_init_pipeline.py
index 9f1d4aa..3557eea 100755
--- a/aipipeline/prediction/vss_init_pipeline.py
+++ b/aipipeline/prediction/vss_init_pipeline.py
@@ -46,7 +46,7 @@
 
 
 # Load exemplars into Vector Search Server
-def load_exemplars(data, config_dict=Dict, conf_files=Dict) -> str:
+def load_exemplars(data, config_dict=Dict, conf_files=Dict, min_exemplars: int = 10, min_detections: int = 2000) -> str:
     project = str(config_dict["tator"]["project"])
     short_name = get_short_name(project)
 
@@ -79,7 +79,7 @@ def load_exemplars(data, config_dict=Dict, conf_files=Dict) -> str:
             logger.info(f"No exemplars or detections found for {label}")
             continue
 
-        if exemplar_count < 10 or detection_count < 100:
+        if exemplar_count < min_exemplars or detection_count < min_detections:
             logger.info(f"Too few exemplars, using detections file {detection_file} instead")
             exemplar_file = detection_file
 
@@ -140,9 +140,11 @@
     parser.add_argument("--config", required=True, help=f"Config file path, e.g. {example_project}")
     parser.add_argument("--skip-clean", required=False, default=False, help="Skip cleaning of previously downloaded data")
     parser.add_argument("--skip-download", required=False, default=False, help="Skip downloading data")
-    parser.add_argument("--batch-size", required=False, type=int, default=2, help="Batch size")
+    parser.add_argument("--batch-size", required=False, type=int, default=1, help="Batch size")
     args, beam_args = parser.parse_known_args(argv)
 
+    MIN_EXEMPLARS = 10
+    MIN_DETECTIONS = 2000
     conf_files, config_dict = setup_config(args.config)
     batch_size = int(args.batch_size)
     download_args = config_dict["data"]["download_args"]
@@ -175,8 +177,8 @@ def run_pipeline(argv=None):
             | "Generate views" >> beam.Map(generate_multicrop_views)
             | "Clean dark blurry examples" >> beam.Map(clean_images)
             | 'Batch cluster ROI elements' >> beam.FlatMap(lambda x: batch_elements(x, batch_size=batch_size))
-            | 'Process cluster ROI batches' >> beam.ParDo(ProcessClusterBatch(config_dict=config_dict))
-            | "Load exemplars" >> beam.Map(load_exemplars, config_dict=config_dict, conf_files=conf_files)
+            | 'Process cluster ROI batches' >> beam.ParDo(ProcessClusterBatch(config_dict=config_dict, min_detections=MIN_DETECTIONS))
+            | "Load exemplars" >> beam.Map(load_exemplars, config_dict=config_dict, conf_files=conf_files, min_exemplars=MIN_EXEMPLARS, min_detections=MIN_DETECTIONS)
             | "Log results" >> beam.Map(logger.info)
         )
 
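
For reference, a minimal standalone sketch (not part of the patch) of the fallback that the new min_detections threshold triggers in cluster(): when a label has fewer crops than the threshold, clustering is skipped and a detection-style CSV is written so the downstream load_exemplars step still finds a file to load. The function name and example paths below are hypothetical; the header, the fixed 224x224 size, and the cluster id of -1 mirror the patched code.

import os
from pathlib import Path

def write_no_cluster_exemplars(crop_dir: str, cluster_dir: str) -> Path:
    # Emit the same detection-style CSV the patched cluster() writes when it skips clustering
    Path(cluster_dir).mkdir(parents=True, exist_ok=True)
    out_csv = Path(cluster_dir) / "no_cluster_exemplars.csv"
    images = [f"{crop_dir}/{f}" for f in os.listdir(crop_dir) if f.endswith(".jpg")]
    with open(out_csv, "w") as f:
        # Header matches the patch: image_path,image_width,image_height,crop_path,cluster
        f.write("image_path,image_width,image_height,crop_path,cluster\n")
        for image in images:
            # 224x224 crop size and cluster id -1, as in the patched cluster() fallback
            f.write(f"{image},224,224,{image},-1\n")
    return out_csv

# Hypothetical usage for a label with too few crops to cluster:
# write_no_cluster_exemplars("/tmp/crops/some_label", "/tmp/clusters/some_label")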