diff --git a/aipipeline/projects/bio/run_strided_inference.py b/aipipeline/projects/bio/run_strided_inference.py index 49e6397..e2954fc 100644 --- a/aipipeline/projects/bio/run_strided_inference.py +++ b/aipipeline/projects/bio/run_strided_inference.py @@ -2,6 +2,8 @@ # Filename: projects/bio/run_strided_inference.py # Description: commands related to running inference on strided video with REDIS queue based load import argparse +import ast + import dotenv import io import json @@ -143,7 +145,7 @@ def video_to_frame(timestamp: str, video_path: Path, output_path: Path): "-i", video_path.as_posix(), "-frames:v", "1", "-qmin", "1", - "-q:v", "1", # Best quality JPG + "-q:v", "1", # Best quality JPG "-y", output_path.as_posix(), ] @@ -157,11 +159,13 @@ def video_to_frame(timestamp: str, video_path: Path, output_path: Path): except subprocess.CalledProcessError as e: logger.error(f"Error occurred: {e}") + def run_inference( video_file: str, stride: int, endpoint_url: str, - class_name: str, + allowed_class_names: [str], + remapped_class_names: dict, version_id: int = 0, min_confidence: float = 0.1, remove_vignette: bool = False, @@ -209,7 +213,7 @@ def seconds_to_timestamp(seconds): logger.error(f"Failed to get ancillary data for {dive}") return if ancillary_data_start["depthMeters"] < 200: - logger.info(f"{video_path.name}====>Depth {ancillary_data_start['depthMeters']} " + logger.info(f"{video_path.name}====>Depth {ancillary_data_start['depthMeters']} " f"is less than 200 meters, skipping") return @@ -228,7 +232,8 @@ def seconds_to_timestamp(seconds): if response.status_code == 200: break except Exception as e: - logger.error(f"{video_path.name}: error processing frame at {current_time_secs} seconds: {e} in {video_path}") + logger.error( + f"{video_path.name}: error processing frame at {current_time_secs} seconds: {e} in {video_path}") # delay to avoid overloading the server time.sleep(5) continue @@ -247,132 +252,132 @@ def seconds_to_timestamp(seconds): # Remove duplicates data = [dict(t) for t in {tuple(d.items()) for d in data}] - # Normalize the bounding box coordinates - for loc in data: - loc["x"] = loc["x"] / frame_width - loc["y"] = loc["y"] / frame_height - loc["xx"] = (loc["x"] + loc["width"]) / frame_width - loc["xy"] = (loc["y"] + loc["height"]) / frame_height - if remove_vignette: # Remove any detections in the corner 1% of the frame threshold = 0.01 # 1% threshold for loc in data: - if not ( - (0 <= loc['x'] <= threshold or 1 - threshold <= loc['x'] <= 1) or - (0 <= loc['y'] <= threshold or 1 - threshold <= loc['y'] <= 1) or - (0 <= loc['xx'] <= threshold or 1 - threshold <= loc['xx'] <= 1) or - (0 <= loc['xy'] <= threshold or 1 - threshold <= loc['xy'] <= 1) + x = loc["x"] / frame_width + y = loc["y"] / frame_height + xx = (loc["x"] + loc["width"]) / frame_width + xy = (loc["y"] + loc["height"]) / frame_height + if ( + (0 <= x <= threshold or 1 - threshold <= x <= 1) or + (0 <= y <= threshold or 1 - threshold <= y <= 1) or + (0 <= xx <= threshold or 1 - threshold <= xx <= 1) or + (0 <= xy <= threshold or 1 - threshold <= xy <= 1) ): - logger.info(f"{loc} is not in the corner") - else: data.remove(loc) for loc in data: - if loc["class_name"] == class_name and loc["confidence"] >= min_confidence: - if loc["confidence"] < 0.9: - if not skip_vss: - logger.info( - f"{video_path.name}: running VSS model on low confidence {class_name} detection {loc['confidence']}") - - # For low confidence detections, run through the vss model - # Crop the image to the bounding box - crop_path = output_path / f"{video_path.stem}_{index}_crop.jpg" - data = { - "image_path": output_frame.as_posix(), - "crop_path": crop_path.as_posix(), - "image_width": frame_width, - "image_height": frame_height, - "x": loc["x"], - "y": loc["y"], - "xx": loc["xx"], - "xy": loc["xy"], - } - s = pd.Series(data) - crop_square_image(s, 224) - images = [read_image(crop_path.as_posix())] - file_paths, best_predictions, best_scores = run_vss(images, config_dict, top_k=3) - crop_path.unlink() - if len(best_predictions) == 0: - logger.info(f"{video_path.name}: no predictions from VSS model. Skipping this detection.") - continue - if best_predictions[0] != class_name: - logger.info( - f"{video_path.name}: VSS model prediction {best_predictions[0]} does not match {class_name}. Skipping this detection.") - continue - logger.info(f"===>{video_path.name}: VSS model prediction {best_predictions[0]} matches {class_name}<====") - else: - logger.info(f"{video_path.name}: {class_name} detection {loc['confidence']}") - else: - logger.info(f"====>{video_path.name}: {class_name} detection {loc['confidence']}<====") - - if not queued_video: - queued_video = True - # Only queue the video if we have a valid localization to queue - # Video transcoding to gif for thumbnail generation is expensive - try: - logger.info(f"Queuing video {video_path.name}") - video_path = Path(video_file) - md = get_video_metadata(video_path.name) - if md is None: - logger.error(f"{video_path.name} failed to get video metadata") - return - iso_start = md["start_timestamp"] - # Convert the start time to a datetime object - iso_start_datetime = datetime.strptime(iso_start, "%Y-%m-%dT%H:%M:%SZ") - # Queue the video first - video_ref_uuid = md["video_reference_uuid"] - iso_start = md["start_timestamp"] - video_url = md["uri"] - logger.info(f"video_ref_uuid: {video_ref_uuid}") - redis_queue.hset( - f"video_refs_start:{video_ref_uuid}", - "start_timestamp", - iso_start, - ) - redis_queue.hset( - f"video_refs_load:{video_ref_uuid}", - "video_uri", - video_url, - ) - except Exception as e: - logger.info(f"Error: {e}") - # Remove the video reference from the queue - redis_queue.delete(f"video_refs_start:{video_ref_uuid}") - redis_queue.delete(f"video_refs_load:{video_ref_uuid}") - return - - loc_datetime = iso_start_datetime + timedelta(seconds=current_time_secs) - loc_datetime_str = loc_datetime.strftime("%Y-%m-%dT%H:%M:%SZ") - logger.info(f"queuing loc: {loc} for {class_name} {dive} {loc_datetime}") - ancillary_data = get_ancillary_data(dive, config_dict, loc_datetime) - if ancillary_data is None or "depthMeters" not in ancillary_data: - logger.error(f"Failed to get ancillary data for {dive}") - continue - - new_loc = { - "x1": loc["x"], - "y1": loc["y"], - "x2": loc["x"] + loc["width"], - "y2": loc["y"] + loc["height"], - "width": frame_width, - "height": frame_height, - "frame": frame_number, - "version_id": version_id, - "score": loc["confidence"], - "cluster": -1, - "label": loc["class_name"], - "dive": dive, - "depth": ancillary_data["depthMeters"], - "iso_datetime": loc_datetime_str, - "latitude": ancillary_data["latitude"], - "longitude": ancillary_data["longitude"], - "temperature": ancillary_data["temperature"], - "oxygen": ancillary_data["oxygen"], + # Skip detections with low confidence or not the target class + if loc["confidence"] < min_confidence or loc["class_name"] not in allowed_class_names: + continue + + if not skip_vss: + logger.info( + f"{video_path.name}: running VSS model on detection {loc['confidence']}") + + # Crop the image to the bounding box + crop_path = output_path / f"{video_path.stem}_{index}_crop.jpg" + data = { + "image_path": output_frame.as_posix(), + "crop_path": crop_path.as_posix(), + "image_width": frame_width, + "image_height": frame_height, + "x": loc["x"], + "y": loc["y"], + "xx": loc["xx"], + "xy": loc["xy"], } - redis_queue.hset(f"locs:{video_ref_uuid}", str(idl), json.dumps(new_loc)) - logger.info(f"{video_path.name} found total possible {idl} localizations of {class_name}") - idl += 1 + s = pd.Series(data) + crop_square_image(s, 224) + images = [read_image(crop_path.as_posix())] + file_paths, best_predictions, best_scores = run_vss(images, config_dict, top_k=3) + crop_path.unlink() + if len(best_predictions) == 0: + logger.info(f"{video_path.name}: no predictions from VSS model. Skipping this detection.") + continue + if best_predictions[0] not in allowed_class_names: + logger.info( + f"{video_path.name}: VSS model prediction {best_predictions[0]} not in {allowed_class_names}. Skipping this detection.") + continue + logger.info(f"===>{video_path.name}: VSS model prediction {best_predictions[0]} in {allowed_class_names}<====") + loc["class_name"] = best_predictions[0] + else: + logger.info(f"{video_path.name}: {loc['class_name']} detection {loc['confidence']}") + + if not queued_video: + queued_video = True + # Only queue the video if we have a valid localization to queue + # Video transcoding to gif for thumbnail generation is expensive + try: + logger.info(f"Queuing video {video_path.name}") + video_path = Path(video_file) + md = get_video_metadata(video_path.name) + if md is None: + logger.error(f"{video_path.name} failed to get video metadata") + return + iso_start = md["start_timestamp"] + # Convert the start time to a datetime object + iso_start_datetime = datetime.strptime(iso_start, "%Y-%m-%dT%H:%M:%SZ") + # Queue the video first + video_ref_uuid = md["video_reference_uuid"] + iso_start = md["start_timestamp"] + video_url = md["uri"] + logger.info(f"video_ref_uuid: {video_ref_uuid}") + redis_queue.hset( + f"video_refs_start:{video_ref_uuid}", + "start_timestamp", + iso_start, + ) + redis_queue.hset( + f"video_refs_load:{video_ref_uuid}", + "video_uri", + video_url, + ) + except Exception as e: + logger.info(f"Error: {e}") + # Remove the video reference from the queue + redis_queue.delete(f"video_refs_start:{video_ref_uuid}") + redis_queue.delete(f"video_refs_load:{video_ref_uuid}") + return + + loc_datetime = iso_start_datetime + timedelta(seconds=current_time_secs) + loc_datetime_str = loc_datetime.strftime("%Y-%m-%dT%H:%M:%SZ") + logger.info(f"queuing loc: {loc} {dive} {loc_datetime}") + ancillary_data = get_ancillary_data(dive, config_dict, loc_datetime) + if ancillary_data is None or "depthMeters" not in ancillary_data: + logger.error(f"Failed to get ancillary data for {dive}") + continue + + if remapped_class_names: + label = remapped_class_names[loc["class_name"]] + else: + label = loc["class_name"] + + new_loc = { + "x1": loc["x"], + "y1": loc["y"], + "x2": loc["x"] + loc["width"], + "y2": loc["y"] + loc["height"], + "width": frame_width, + "height": frame_height, + "frame": frame_number, + "version_id": version_id, + "score": loc["confidence"], + "cluster": -1, + "label": label, + "dive": dive, + "depth": ancillary_data["depthMeters"], + "iso_datetime": loc_datetime_str, + "latitude": ancillary_data["latitude"], + "longitude": ancillary_data["longitude"], + "temperature": ancillary_data["temperature"], + "oxygen": ancillary_data["oxygen"], + } + redis_queue.hset(f"locs:{video_ref_uuid}", str(idl), json.dumps(new_loc)) + logger.info(f"{video_path.name} found total possible {idl} localizations") + idl += 1 else: logger.error(f"Error processing frame at {current_time_secs} seconds: {response.text}") @@ -383,12 +388,13 @@ def seconds_to_timestamp(seconds): def process_videos(video_files, stride, endpoint_url, class_name, version_id, min_confidence, - remove_vignette=False, skip_vss=False): + remove_vignette=False, skip_vss=False): num_cpus = multiprocessing.cpu_count() pool = multiprocessing.Pool(processes=num_cpus) pool.starmap( run_inference, - [(v, stride, endpoint_url, class_name, version_id, min_confidence, remove_vignette, skip_vss) for v in video_files], + [(v, stride, endpoint_url, class_name, version_id, min_confidence, remove_vignette, skip_vss) for v in + video_files], ) pool.close() pool.join() @@ -431,6 +437,17 @@ def parse_args(): parser.add_argument("--min-confidence", help="Minimum confidence for detections.", default=0.1, type=float) parser.add_argument("--flush", help="Flush the REDIS database.", action="store_true") parser.add_argument("--remove-vignette", help="Remove vignette detection.", action="store_true") + parser.add_argument( + '--allowed-classes', + type=str, + nargs='+', # Accepts multiple values + help='List of allowed classes.' + ) + parser.add_argument( + '--class-remap', + type=str, + help='Dictionary of class remapping, formatted as a string.' + ) return parser.parse_args() @@ -450,7 +467,7 @@ def parse_args(): # Get the version id from the database project = config_dict["tator"]["project"] host = config_dict["tator"]["host"] - if args.version: # Override the version in the config file + if args.version: # Override the version in the config file config_dict["data"]["version"] = args.version version = config_dict["data"]["version"] api, project = init_api_project(host=host, token=TATOR_TOKEN, project=project) @@ -459,6 +476,10 @@ def parse_args(): logger.error(f"Failed to get version id for {version}") exit(1) + # Convert the remapped class names to a dictionary + if args.class_remap: + args.class_remap = ast.literal_eval(args.class_remap) + # Need to have a video or TSV file with video paths to process if not args.video and not args.tsv: logger.error("Must provide either a video or TSV file with video paths") @@ -498,7 +519,8 @@ def parse_args(): video_path.as_posix(), args.stride, args.endpoint_url, - args.class_name, + args.allowed_classes, + args.class_remap, version_id, args.min_confidence, remove_vignette=args.remove_vignette, @@ -512,7 +534,8 @@ def parse_args(): video_files, args.stride, args.endpoint_url, - args.class_name, + args.allowed_classes, + args.class_remap, version_id, args.min_confidence, remove_vignette=args.remove_vignette, @@ -535,9 +558,10 @@ def parse_args(): video_files, args.stride, args.endpoint_url, - args.class_name, + args.allowed_classes, + args.class_remap, version_id, skip_vss=args.skip_vss, ) - logger.info("Finished processing videos") \ No newline at end of file + logger.info("Finished processing videos")