perf: skip cluster in vss pipeline for < 2000 detections
danellecline committed Nov 14, 2024
1 parent 8b3aa53 commit 2591116
Showing 2 changed files with 28 additions and 15 deletions.
aipipeline/prediction/library.py (21 additions, 10 deletions)
@@ -4,7 +4,6 @@
 import multiprocessing
 import os
 import shutil
-import time
 from datetime import datetime
 import numpy as np
 from pathlib import Path
@@ -79,6 +78,9 @@ def generate_multicrop_views2(image) -> List[tuple]:
 def clean_bad_images(element) -> tuple:
     count, crop_path, save_path = element
     num_removed = 0
+    # Check if any images exist
+    if count == 0:
+        return count, crop_path, save_path
     imagelab = Imagelab(data_path=crop_path)
     imagelab.find_issues()
     imagelab.report()
@@ -139,8 +141,8 @@ def generate_multicrop_views(elements) -> List[tuple]:
     return data
 
 
-def cluster(data, config_dict: Dict) -> List[tuple]:
-    logger.info(f'Clustering {data}')
+def cluster(data, config_dict: Dict, min_detections: int) -> List[tuple]:
+    logger.info(f'Clustering {data} with min_detections {min_detections}')
     num_images, crop_dir, cluster_dir = data
     project = config_dict["tator"]["project"]
     sdcat_config = config_dict["sdcat"]["ini"]
@@ -152,11 +154,6 @@ def cluster(data, config_dict: Dict) -> List[tuple]:
     short_name = get_short_name(project)
     logger.info(data)
 
-    # If there are less than 500 images, skip clustering
-    if num_images < 500:
-        logger.info(f"Skipping clustering for {num_images} images in {crop_dir}")
-        return [(Path(crop_dir).name, cluster_dir)]
-
     logger.info(f"Clustering {num_images} images in {crop_dir} ....")
     min_cluster_size = 2
 
@@ -179,6 +176,19 @@
     label = Path(crop_dir).name
     machine_friendly_label = gen_machine_friendly_label(label)
     try:
+        # Skip clustering if there are too few images, but generate a detection file for the next step
+        if num_images < min_detections:
+            logger.info(f"Skipping clustering for {label} with {num_images} images")
+            Path(cluster_dir).mkdir(parents=True, exist_ok=True)
+            cluster_results.append((Path(crop_dir).name, cluster_dir))
+            images = [f"{crop_dir}/{f}" for f in os.listdir(crop_dir) if f.endswith(".jpg")]
+            with open(f"{cluster_dir}/no_cluster_exemplars.csv", "w") as f:
+                # Add the header image_path,image_width,image_height,crop_path,cluster
+                f.write("image_path,image_width,image_height,crop_path,cluster\n")
+                for image in images:
+                    f.write(f"{image},224,224,{image},-1\n")
+            return cluster_results
+
         container = run_docker(
             image=config_dict["docker"]["sdcat"],
             name=f"{short_name}-sdcat-clu-{machine_friendly_label}",
@@ -203,16 +213,17 @@
 
 class ProcessClusterBatch(beam.DoFn):
 
-    def __init__(self, config_dict: Dict):
+    def __init__(self, config_dict: Dict, min_detections: int):
         self.config_dict = config_dict
+        self.min_detections = min_detections
 
     def process(self, batch):
         if len(batch) > 1:
             num_processes = min(1, len(batch))
         else:
             num_processes = 1
         with multiprocessing.Pool(num_processes) as pool:
-            args = [(data, self.config_dict) for data in batch]
+            args = [(data, self.config_dict, self.min_detections) for data in batch]
             results = pool.starmap(cluster, args)
         return results
 
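For reference, the no_cluster_exemplars.csv written by the new skip path has one row per crop image, with the width and height hard-coded to 224 and the cluster id set to -1. A short illustrative sample, assuming hypothetical crop paths:

    image_path,image_width,image_height,crop_path,cluster
    /data/crops/Otter/crop_0001.jpg,224,224,/data/crops/Otter/crop_0001.jpg,-1
    /data/crops/Otter/crop_0002.jpg,224,224,/data/crops/Otter/crop_0002.jpg,-1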
aipipeline/prediction/vss_init_pipeline.py (7 additions, 5 deletions)
@@ -46,7 +46,7 @@
 
 
 # Load exemplars into Vector Search Server
-def load_exemplars(data, config_dict=Dict, conf_files=Dict) -> str:
+def load_exemplars(data, config_dict=Dict, conf_files=Dict, min_exemplars:int = 10, min_detections: int = 2000) -> str:
     project = str(config_dict["tator"]["project"])
     short_name = get_short_name(project)
 
@@ -79,7 +79,7 @@ def load_exemplars(data, config_dict=Dict, conf_files=Dict) -> str:
             logger.info(f"No exemplars or detections found for {label}")
             continue
 
-        if exemplar_count < 10 or detection_count < 100:
+        if exemplar_count < min_exemplars or detection_count < min_detections:
             logger.info(f"Too few exemplars, using detections file {detection_file} instead")
             exemplar_file = detection_file
 
@@ -140,9 +140,11 @@ def run_pipeline(argv=None):
     parser.add_argument("--config", required=True, help=f"Config file path, e.g. {example_project}")
     parser.add_argument("--skip-clean", required=False, default=False, help="Skip cleaning of previously downloaded data")
     parser.add_argument("--skip-download", required=False, default=False, help="Skip downloading data")
-    parser.add_argument("--batch-size", required=False, type=int, default=2, help="Batch size")
+    parser.add_argument("--batch-size", required=False, type=int, default=1, help="Batch size")
     args, beam_args = parser.parse_known_args(argv)
 
+    MIN_EXEMPLARS = 10
+    MIN_DETECTIONS = 2000
     conf_files, config_dict = setup_config(args.config)
     batch_size = int(args.batch_size)
     download_args = config_dict["data"]["download_args"]
@@ -175,8 +177,8 @@ def run_pipeline(argv=None):
             | "Generate views" >> beam.Map(generate_multicrop_views)
             | "Clean dark blurry examples" >> beam.Map(clean_images)
             | 'Batch cluster ROI elements' >> beam.FlatMap(lambda x: batch_elements(x, batch_size=batch_size))
-            | 'Process cluster ROI batches' >> beam.ParDo(ProcessClusterBatch(config_dict=config_dict))
-            | "Load exemplars" >> beam.Map(load_exemplars, config_dict=config_dict, conf_files=conf_files)
+            | 'Process cluster ROI batches' >> beam.ParDo(ProcessClusterBatch(config_dict=config_dict, min_detections=MIN_DETECTIONS))
+            | "Load exemplars" >> beam.Map(load_exemplars, config_dict=config_dict, conf_files=conf_files, min_exemplars=MIN_EXEMPLARS, min_detections=MIN_DETECTIONS)
             | "Log results" >> beam.Map(logger.info)
         )
 
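A minimal sketch of driving the updated pipeline programmatically, assuming a hypothetical config path; run_pipeline parses the same flags as the command line, so the new batch-size default of 1 applies unless overridden:

    # Minimal sketch, assuming a project config at config/uav.yml (hypothetical path)
    from aipipeline.prediction.vss_init_pipeline import run_pipeline

    run_pipeline([
        "--config", "config/uav.yml",
        "--batch-size", "1",  # new default introduced by this commit
    ])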
