forked from lojzezust/WaSR-T
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_sequential.py
113 lines (86 loc) · 4.11 KB
/
predict_sequential.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import torch
from wasr_t.data.folder import FolderDataset
from wasr_t.data.transforms import PytorchHubNormalization
from wasr_t.inference import Predictor
from wasr_t.wasr_t import wasr_temporal_resnet101, wasr_temporal_mobilenetv3
from wasr_t.utils import load_weights, Option
# Color table for the exported masks: row i holds the color triple painted
# for segmentation class index i.
SEGMENTATION_COLORS = np.array(
    [
        [247, 195, 37],
        [41, 167, 224],
        [90, 75, 164],
    ],
    dtype=np.uint8,
)

# Default values for the command-line options.
OUTPUT_DIR = 'output/predictions'
HIST_LEN = 5
RESIZE = (512, 384)
def get_arguments():
    """Build the CLI parser for sequential inference and parse the command line.

    Returns:
        The `argparse.Namespace` returned by `parse_args()`, carrying the
        attributes `sequence_dir`, `hist_len`, `weights`, `output_dir`,
        `resize`, `fp16` and `gpus`.
    """
    cli = argparse.ArgumentParser(description="WaSR Network Sequential Inference")

    # Options are registered in table order, which is also the order in which
    # they are listed by `--help`.
    # NOTE(review): `--sequence-dir` is optional at the parser level although
    # `run_inference` forwards it unconditionally -- confirm whether it should
    # be required.
    # NOTE(review): `--gpus` is parsed but not read anywhere in this file.
    option_table = (
        ("--sequence-dir", {
            "type": str,
            "required": False,
            "help": "Path to the directory containing frames of the input sequence.",
        }),
        ("--hist-len", {
            "default": HIST_LEN,
            "type": int,
            "help": "Number of past frames to be considered in addition to the target frame (context length). Must match the value used in training.",
        }),
        ("--weights", {
            "type": str,
            "required": True,
            "help": "Model weights file.",
        }),
        ("--output-dir", {
            "type": str,
            "default": OUTPUT_DIR,
            "help": "Directory where the predictions will be stored.",
        }),
        ("--resize", {
            "type": Option(int),
            "default": RESIZE,
            "nargs": '+',
            "help": "Resize input images to a specified size. Use `none` for no resizing.",
        }),
        ("--fp16", {
            "action": 'store_true',
            "help": "Use half precision for inference.",
        }),
        ("--gpus", {
            "default": -1,
            "help": "Number of gpus (or GPU ids) used for training.",
        }),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    parsed = cli.parse_args()
    return parsed
def export_predictions(probs, batch, output_dir):
    """Write the predicted segmentation mask of every batch item as a PNG.

    Each item's per-pixel class index is mapped to its color through
    `SEGMENTATION_COLORS`, and the image is stored under `output_dir` at the
    item's `image_path` with the suffix replaced by `.png`. Missing parent
    directories are created.

    Args:
        probs: Class scores for the batch; the arg-max is taken over axis 1.
            Presumably a NumPy array (`.astype` is called on the arg-max
            result) -- confirm against `Predictor.predict_batch`.
        batch: `(features, metadata)` pair; only `metadata['image_path']`
            is read here.
        output_dir: Root directory for the exported masks.
    """
    _, metadata = batch

    # Class prediction: index of the highest-scoring class along axis 1.
    out_class = probs.argmax(1).astype(np.uint8)

    for i, pred_mask in enumerate(out_class):
        # Indexing the color table with the class map replaces every class
        # index by its color triple.
        mask_img = Image.fromarray(SEGMENTATION_COLORS[pred_mask])

        # NOTE(review): if `image_path` were absolute, the `/` join would
        # discard `output_dir`; presumably the dataset yields paths relative
        # to the sequence directory -- confirm.
        out_path = output_dir / Path(metadata['image_path'][i]).with_suffix('.png')

        # `exist_ok=True` already covers the "directory exists" case, so no
        # separate existence check is needed.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        mask_img.save(str(out_path))
def predict_sequence(predictor, sequence_dir, output_dir, size):
    """Run stateful, frame-by-frame inference over one image sequence.

    The model state is cleared before the first frame. Frames are then loaded
    from `sequence_dir` one per batch, passed through the predictor, and the
    resulting masks are exported to `output_dir`.

    Args:
        predictor: Object exposing `model.clear_state()` and
            `predict_batch(features)`.
        sequence_dir: Directory handed to `FolderDataset` as the frame source.
        output_dir: Destination root forwarded to `export_predictions`.
        size: Value forwarded to `FolderDataset` as `resize`.
    """
    # Start the sequence with a cleared model state.
    predictor.model.clear_state()

    frames = FolderDataset(sequence_dir, normalize_t=PytorchHubNormalization(), resize=size)
    # NOTE: Batch size must be 1 in sequential mode.
    loader = DataLoader(frames, batch_size=1, num_workers=1)

    for batch in tqdm(loader, desc='Processing frames'):
        features, _ = batch
        export_predictions(predictor.predict_batch(features), batch, output_dir=output_dir)
def run_inference(args):
    """Load the model and weights, then segment the requested sequence.

    Args:
        args: Parsed CLI namespace; reads `hist_len`, `weights`, `fp16`,
            `output_dir`, `resize` and `sequence_dir`.
    """
    model = wasr_temporal_mobilenetv3(pretrained=False, hist_len=args.hist_len, sequential=True)

    # If PyTorch 2.0's torch.compile() generated these weights, every
    # parameter name carries an "_orig_mod." label that has to be removed
    # before the weights can be loaded into the plain model.
    loaded_weights = load_weights(args.weights)
    cleaned_weights = {}
    for param_name, tensor in loaded_weights.items():
        cleaned_weights[param_name.replace("_orig_mod.", "")] = tensor
    model.load_state_dict(cleaned_weights)

    # Enable sequential mode and switch to evaluation mode.
    model = model.sequential()
    model.eval()

    # Inference is pinned to the CPU here.
    # NOTE(review): `--fp16` is still forwarded -- confirm half precision is
    # supported on CPU for this model.
    cpu_device = torch.device('cpu')
    predictor = Predictor(model, half_precision=args.fp16, device=cpu_device)

    output_dir = Path(args.output_dir)

    # A leading `None` in `--resize` means "do not resize".
    size = args.resize if args.resize[0] is not None else None

    with torch.inference_mode():
        predict_sequence(predictor, args.sequence_dir, output_dir, size=size)
def main():
    """Script entry point: parse the CLI options, print them, run inference."""
    cli_args = get_arguments()
    # Echo the parsed options so the effective configuration shows up in the output.
    print(cli_args)
    run_inference(cli_args)


if __name__ == '__main__':
    main()