
fix: correct check for exemplar files with v0.38.2 perf improvement
danellecline committed Nov 19, 2024
1 parent c05b37f commit 86f585a
Showing 1 changed file with 10 additions and 27 deletions.
37 changes: 10 additions & 27 deletions aipipeline/prediction/vss_init_pipeline.py
@@ -1,5 +1,5 @@
 # aipipeline, Apache-2.0 license
-# Filename: aipiipeline/prediction/vss-_init_pipeline.py
+# Filename: aipiipeline/prediction/vss_init_pipeline.py
 # Description: Run the VSS initialization pipeline
 import time
 from datetime import datetime
@@ -46,7 +46,7 @@
 
 
 # Load exemplars into Vector Search Server
-def load_exemplars(data, config_dict=Dict, conf_files=Dict, min_exemplars:int = 10, min_detections: int = 2000) -> str:
+def load_exemplars(data, config_dict=Dict, conf_files=Dict) -> str:
     project = str(config_dict["tator"]["project"])
     short_name = get_short_name(project)
 
@@ -55,33 +55,17 @@ def load_exemplars(data, config_dict=Dict, conf_files=Dict, min_exemplars:int =
     for label, save_dir in data:
         machine_friendly_label = gen_machine_friendly_label(label)
         # Grab the most recent file
-        all_exemplars = list(Path(save_dir).rglob("*_exemplars.csv"))
+        all_exemplars = list(Path(save_dir).rglob("*exemplars.csv"))
         logger.info(f"Found {len(all_exemplars)} exemplar files for {label}")
         exemplar_file = sorted(all_exemplars, key=os.path.getmtime, reverse=True)[0] if all_exemplars else None
 
-        # Grab the most recent detections file
-        all_detections = list(Path(save_dir).rglob("*cluster_detections.csv"))
-        detection_file = sorted(all_detections, key=os.path.getmtime, reverse=True)[0] if all_detections else None
-
         if exemplar_file is None:
             logger.info(f"No exemplar file found for {label}")
-            exemplar_count = 0
-        else:
-            with open(exemplar_file, "r") as f:
-                exemplar_count = len(f.readlines())
-        if detection_file is None:
-            logger.info(f"No detection file found for {label}")
-            detection_count = 0
-        else:
-            with open(detection_file, "r") as f:
-                detection_count = len(f.readlines())
-
-        if exemplar_count == 0 or detection_count == 0:
-            logger.info(f"No exemplars or detections found for {label}")
-            continue
-
-        if exemplar_count < min_exemplars or detection_count < min_detections:
-            logger.info(f"Too few exemplars, using detections file {detection_file} instead")
-            exemplar_file = detection_file
+            return f"No exemplar file found for {label}"
+
+        exemplar_count = 0
+        with open(exemplar_file, "r") as f:
+            exemplar_count = len(f.readlines())
+
         logger.info(f"Loading {exemplar_count} exemplars for {label} as {label} from {exemplar_file}")
         args = [
@@ -143,7 +127,6 @@ def run_pipeline(argv=None):
     parser.add_argument("--batch-size", required=False, type=int, default=1, help="Batch size")
     args, beam_args = parser.parse_known_args(argv)
 
-    MIN_EXEMPLARS = 10
     MIN_DETECTIONS = 2000
     conf_files, config_dict = setup_config(args.config)
     batch_size = int(args.batch_size)
@@ -178,7 +161,7 @@ def run_pipeline(argv=None):
             | "Clean dark blurry examples" >> beam.Map(clean_images)
             | 'Batch cluster ROI elements' >> beam.FlatMap(lambda x: batch_elements(x, batch_size=batch_size))
             | 'Process cluster ROI batches' >> beam.ParDo(ProcessClusterBatch(config_dict=config_dict, min_detections=MIN_DETECTIONS))
-            | "Load exemplars" >> beam.Map(load_exemplars, config_dict=config_dict, conf_files=conf_files,min_exemplars=MIN_EXEMPLARS, min_detections=MIN_DETECTIONS)
+            | "Load exemplars" >> beam.Map(load_exemplars, config_dict=config_dict, conf_files=conf_files)
            | "Log results" >> beam.Map(logger.info)
         )
 
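For context, a minimal standalone sketch (not the project's code) of the lookup behavior this commit settles on: glob with the relaxed *exemplars.csv pattern, take the newest match by modification time, and bail out early when nothing is found. The directory and label below are hypothetical.

import os
from pathlib import Path
from typing import Optional

def newest_exemplar_file(save_dir: str, label: str) -> Optional[Path]:
    # Relaxed pattern: "*exemplars.csv" matches names with or without the
    # underscore that the old "*_exemplars.csv" pattern required.
    candidates = list(Path(save_dir).rglob("*exemplars.csv"))
    print(f"Found {len(candidates)} exemplar files for {label}")
    if not candidates:
        return None  # the patched load_exemplars returns an error string here
    # Most recently modified file wins
    return sorted(candidates, key=os.path.getmtime, reverse=True)[0]

def count_exemplars(exemplar_file: Path) -> int:
    # Same counting approach as the patched code: one exemplar per line
    with open(exemplar_file, "r") as f:
        return len(f.readlines())

if __name__ == "__main__":
    found = newest_exemplar_file("/tmp/vss_data", "Atolla")  # hypothetical inputs
    if found is None:
        print("No exemplar file found for Atolla")
    else:
        print(f"Loading {count_exemplars(found)} exemplars from {found}")

Note the change in failure handling as well: where the old loop logged a zero count and continued to the next label, the patched code returns an error string for the batch as soon as one label has no exemplar file.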

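On the pipeline side, the beam.Map call simply stops forwarding the removed keyword arguments, matching the new load_exemplars signature. A hedged sketch of that wiring with a stub in place of the real function (names and values are illustrative, not from the repository):

import apache_beam as beam

def load_exemplars(data, config_dict=None, conf_files=None) -> str:
    # Stub standing in for aipipeline's function; beam.Map forwards any
    # extra keyword arguments to the callable for every element.
    label, save_dir = data
    return f"loaded exemplars for {label} (project={config_dict['tator']['project']})"

with beam.Pipeline() as p:
    (
        p
        | "Create" >> beam.Create([("Atolla", "/tmp/atolla")])  # hypothetical input
        | "Load exemplars" >> beam.Map(
            load_exemplars,
            config_dict={"tator": {"project": "demo"}},
            conf_files={},
        )
        | "Log results" >> beam.Map(print)
    )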