This is the official PyTorch implementation of the WaSR-T network [1]. It contains scripts for training and running the network, as well as weights pretrained on the MaSTr1325 [2] and MaSTr1478 datasets.
Our work was presented at the IROS 2022 conference in Kyoto, Japan.
Comparison between WaSR (single-frame) and WaSR-T (temporal context) on hard examples.
WaSR-T is a temporal extension of the established WaSR model [3] for maritime obstacle detection. It harnesses the temporal context of recent image frames to reduce ambiguity caused by reflections and to improve the overall robustness of predictions.
The target and context (i.e. previous) frames are encoded using a shared encoder network. To extract the temporal context from the past frames, we apply a 3D convolution operation over the temporal dimension. The 3D convolution is able to extract discriminative information about local texture changes over the recent frames. The resulting temporal context is concatenated with the target frame features and passed to the decoder, which produces the final predictions.
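The data flow described above (shared encoder, 3D convolution over the temporal dimension of the past-frame features, concatenation with the target features) can be illustrated with a small PyTorch module. This is a minimal sketch under stated assumptions, not the repository's implementation: the class name TemporalContextSketch, the tiny stand-in encoder, the channel counts and the kernel shape are all hypothetical, and the temporal kernel is assumed to span exactly hist_len context frames so the time axis collapses to one slice.

```python
import torch
import torch.nn as nn


class TemporalContextSketch(nn.Module):
    """Hypothetical illustration of WaSR-T-style temporal fusion (not the official module)."""

    def __init__(self, in_ch=3, feat_ch=16, ctx_ch=8, hist_len=5):
        super().__init__()
        # Stand-in shared encoder; WaSR-T itself uses a ResNet-101 backbone.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, feat_ch, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # 3D convolution whose temporal kernel spans all hist_len context frames,
        # so the time axis collapses to a single slice.
        self.temporal_conv = nn.Conv3d(
            feat_ch, ctx_ch, kernel_size=(hist_len, 3, 3), padding=(0, 1, 1)
        )

    def forward(self, image, hist_images):
        # image: [B,3,H,W], hist_images: [B,T,3,H,W] with T == hist_len
        B, T = hist_images.shape[:2]
        target_feat = self.encoder(image)                          # [B,C,h,w]
        hist_feat = self.encoder(hist_images.flatten(0, 1))        # [B*T,C,h,w], same weights
        hist_feat = hist_feat.reshape(B, T, *hist_feat.shape[1:])  # [B,T,C,h,w]
        hist_feat = hist_feat.permute(0, 2, 1, 3, 4)               # [B,C,T,h,w] for Conv3d
        context = self.temporal_conv(hist_feat).squeeze(2)         # [B,ctx_ch,h,w]
        # Concatenate temporal context with the target features; a decoder would follow.
        return torch.cat([target_feat, context], dim=1)


sketch = TemporalContextSketch()
out = sketch(torch.randn(2, 3, 32, 32), torch.randn(2, 5, 3, 32, 32))
print(out.shape)  # torch.Size([2, 24, 16, 16])
```

Sharing one encoder means the past frames cost extra compute but no extra parameters; only the 3D convolution is specific to the temporal context.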
Requirements: Python >= 3.6 (tested on Python 3.8), PyTorch 1.8.1, PyTorch Lightning 1.4.4 (for training)
The required Python libraries can be installed using the following pip command:
pip install -r requirements.txt
The WaSR-T model used in our experiments (ResNet-101 backbone) can be initialized with the following code.
from wasr_t.wasr_t import wasr_temporal_resnet101
model = wasr_temporal_resnet101(num_classes=3)
The WaSR-T model operates in two modes:
- sequential: An online mode, useful for inference. Only one frame is processed at a time, and the features of the previous frames are stored in a circular buffer. Because the frames must be processed one after the other, the batch size must be 1. The context buffer is initialized with copies of the first frame in the sequence.
- unrolled: An offline mode, used during training. Each sample consists of a target frame and the required previous frames. Supports batched processing.
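The circular context buffer of the sequential mode can be sketched with collections.deque. This is a hypothetical illustration, not the model's internal code: the class name ContextBuffer, the oldest-first ordering and the choice to return only the previous frames as context are assumptions. The buffer holds hist_len entries, is filled with copies of the first frame when empty, and drops its oldest entry on every push.

```python
from collections import deque


class ContextBuffer:
    """Hypothetical fixed-length circular buffer of per-frame features."""

    def __init__(self, hist_len=5):
        self.hist_len = hist_len
        self.buffer = deque(maxlen=hist_len)  # the oldest entry drops off automatically

    def clear(self):
        self.buffer.clear()

    def get_context(self, feat):
        """Return the context (previous frames, oldest first), then push the current frame."""
        if not self.buffer:
            # First frame of a sequence: fill the buffer with copies of it.
            self.buffer.extend([feat] * self.hist_len)
        context = list(self.buffer)
        self.buffer.append(feat)
        return context


buf = ContextBuffer(hist_len=3)
for frame_id in ["a", "b", "c", "d"]:
    print(frame_id, buf.get_context(frame_id))
# a ['a', 'a', 'a']
# b ['a', 'a', 'a']
# c ['a', 'a', 'b']
# d ['a', 'b', 'c']
```

Calling clear() at a sequence boundary plays the same role as the model's clear_state(): the next frame refills the buffer with copies of itself.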
You can switch between the two modes by calling sequential() or unrolled() on the model.
Example of sequential operation:
model = model.sequential()
model.clear_state()  # Clear the temporal buffer of the model
for image in sequence:
    # image is a [1,3,H,W] tensor
    output = model({'image': image})
Note: If you run inference on multiple sequences, you must call clear_state() on the model to clear the buffer before moving to a new sequence. Otherwise, the context from the last frames of the previous sequence will be used, which may lead to faulty predictions.
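A driver loop over several sequences then resets the context once per sequence. In this sketch, the function predict_sequences, the dictionary layout of sequences and the CountingStub stand-in model are hypothetical; the function only relies on the clear_state() call and the {'image': ...} input shown above, and assumes the model is already in sequential mode.

```python
def predict_sequences(model, sequences):
    """Run sequential inference over several sequences, resetting the context between them.

    Assumes `model` is already in sequential mode and that `sequences` maps a
    sequence name to an iterable of [1,3,H,W] image tensors (hypothetical layout).
    """
    results = {}
    for name, frames in sequences.items():
        model.clear_state()  # drop context left over from the previous sequence
        results[name] = [model({'image': image}) for image in frames]
    return results


class CountingStub:
    """Hypothetical stand-in model: returns how many frames it has seen since the last reset."""

    def __init__(self):
        self.seen = 0

    def clear_state(self):
        self.seen = 0

    def __call__(self, batch):
        self.seen += 1
        return self.seen


print(predict_sequences(CountingStub(), {'seq_a': [None] * 3, 'seq_b': [None] * 2}))
# {'seq_a': [1, 2, 3], 'seq_b': [1, 2]}
```

The stub's counter restarting at 1 for seq_b shows that no state leaks across the sequence boundary.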
Example of unrolled operation:
model = model.unrolled()
# images is a batch of images: [B,3,H,W] tensor, where B is the batch size
# hist_images is a batch of context images: [B,T,3,H,W], where T is the number of context frames used by the network (default 5)
output = model({'image': images, 'hist_images': hist_images})
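One way to assemble the images and hist_images tensors from a single ordered list of frames is sketched below. The helper build_unrolled_batch is hypothetical, as is the oldest-first ordering of the context frames; indices before the start of the sequence are clamped to frame 0, mirroring how the sequential mode initializes its buffer with copies of the first frame.

```python
import torch


def build_unrolled_batch(frames, target_indices, hist_len=5):
    """Assemble (images, hist_images) for unrolled mode from one frame sequence.

    frames: list of [3,H,W] tensors in temporal order.
    target_indices: indices of the target frames to put in the batch.
    Context frames are ordered oldest first (an assumption); indices before the
    start of the sequence are clamped to frame 0.
    """
    images, hist = [], []
    for t in target_indices:
        ctx_ids = [max(t - k, 0) for k in range(hist_len, 0, -1)]  # t-T, ..., t-1
        images.append(frames[t])
        hist.append(torch.stack([frames[i] for i in ctx_ids]))      # [T,3,H,W]
    return torch.stack(images), torch.stack(hist)  # [B,3,H,W], [B,T,3,H,W]


frames = [torch.full((3, 4, 4), float(i)) for i in range(10)]
images, hist_images = build_unrolled_batch(frames, [0, 7])
print(images.shape, hist_images.shape)
# torch.Size([2, 3, 4, 4]) torch.Size([2, 5, 3, 4, 4])
print(hist_images[1, :, 0, 0, 0].tolist())  # [2.0, 3.0, 4.0, 5.0, 6.0]
```

The resulting pair can then be passed as model({'image': images, 'hist_images': hist_images}).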
To run sequential WaSR-T inference on a sequence of image frames, use the predict_sequential.py script.
# export CUDA_VISIBLE_DEVICES=-1 # CPU only
export CUDA_VISIBLE_DEVICES=0 # GPU to use
python predict_sequential.py \
--sequence-dir examples/sequence \
--weights path/to/model/weights.pth \
--output-dir output/predictions
The script loops over the images in the --sequence-dir directory in alphabetical order. Predictions are stored as color-coded masks in the specified output directory.
If you wish to run inference on a video file, first convert the file to a sequence of images. For example, using ffmpeg:
mkdir sequence_images
ffmpeg -i video.mp4 sequence_images/frame_%05d.jpg
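Because frames are read in alphabetical order, file names must sort in temporal order. The %05d pattern above zero-pads the frame number to five digits, so lexicographic and temporal order agree for up to 99,999 frames. The listing helper below is a hypothetical sketch; the exact file filtering done by predict_sequential.py is not documented here, and IMAGE_EXTS is an assumption.

```python
import os

IMAGE_EXTS = {'.jpg', '.jpeg', '.png'}  # assumed set of accepted extensions


def list_sequence_frames(sequence_dir):
    """Return image paths from sequence_dir in alphabetical (lexicographic) order."""
    names = sorted(
        name for name in os.listdir(sequence_dir)
        if os.path.splitext(name)[1].lower() in IMAGE_EXTS
    )
    return [os.path.join(sequence_dir, name) for name in names]


# Lexicographic order breaks without zero padding:
print(sorted(['frame_10.jpg', 'frame_2.jpg', 'frame_1.jpg']))
# ['frame_1.jpg', 'frame_10.jpg', 'frame_2.jpg']
# With ffmpeg's %05d pattern the names sort in temporal order:
print(sorted(['frame_00010.jpg', 'frame_00002.jpg', 'frame_00001.jpg']))
# ['frame_00001.jpg', 'frame_00002.jpg', 'frame_00010.jpg']
```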
The table below lists the currently available pretrained model weights. All models are evaluated on the MODS benchmark [4]; F1 is the overall F1 score and F1D is the F1 score inside the danger zone.
| Backbone | T (context frames) | Dataset | F1 (%) | F1D (%) | Weights |
|---|---|---|---|---|---|
| ResNet-101 | 5 | MaSTr1325 | 93.7 | 87.3 | link |
| ResNet-101 | 5 | MaSTr1478 | 94.4 | 93.6 | link |
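For reference, F1 is the harmonic mean of precision and recall, which simplifies to 2TP / (2TP + FP + FN). The sketch below only shows that arithmetic: how MODS matches detections to ground-truth obstacles to obtain the TP, FP and FN counts is not reproduced here, F1D applies the same formula to obstacles inside the danger zone, and the counts used are made-up illustrative numbers, not MODS results.

```python
def precision_recall(tp, fp, fn):
    """Precision and recall from detection counts (0.0 when undefined)."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


def f1_score(tp, fp, fn):
    """F1 = 2PR / (P + R), which simplifies to 2TP / (2TP + FP + FN)."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0  # empty case reported as 0 (a convention)


# Illustrative counts only:
print(round(100 * f1_score(90, 5, 10), 1))  # 92.3
```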
To train your own models, use the train.py script. For example, to reproduce the results of our experiments, follow these steps:
- Download and prepare the MaSTr1325 dataset (images and GT masks). Also download the context frames for the MaSTr1325 images here.
- Edit the dataset configuration files (configs/mastr1325_train.yaml, configs/mastr1325_extra and configs/mastr1325_val.yaml) so that they correctly point to the dataset directories.
- Use the train.py script to train the network.
export CUDA_VISIBLE_DEVICES=0,1,2,3 # GPUs to use
python train.py \
--train-config configs/mastr1325_train.yaml \
--val-config configs/mastr1325_val.yaml \
--validation \
--model-name my_wasr \
--batch-size 2 \
--epochs 100
Note: Model training requires a large amount of GPU memory (>11 GB per GPU). If you use smaller GPUs, you can reduce memory consumption by decreasing the number of backbone backpropagation steps (--backbone-grad-steps) or by using a shorter context length (--hist-len).
A log directory named after the specified model is created inside the output directory, and model checkpoints and training logs are stored there. At the end of training, the model weights are also exported to a weights.pth file inside this directory.
Logged metrics (loss, validation accuracy, validation IoU) can be inspected using TensorBoard:
tensorboard --logdir output/logs/model_name
We extend the MaSTr1325 dataset by providing the context frames (5 preceding frames). We also extend the dataset with additional hard examples to form MaSTr1478.
If you use this code, please cite our paper:
@InProceedings{Zust2022Temporal,
title={Temporal Context for Robust Maritime Obstacle Detection},
author={{\v{Z}}ust, Lojze and Kristan, Matej},
booktitle={2022 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
year={2022}
}
[1] Žust, L., & Kristan, M. (2022). Temporal Context for Robust Maritime Obstacle Detection. 2022 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS).
[2] Bovcon, B., Muhovič, J., Perš, J., & Kristan, M. (2019). The MaSTr1325 dataset for training deep USV obstacle detection models. 2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS).
[3] Bovcon, B., & Kristan, M. (2021). WaSR--A Water Segmentation and Refinement Maritime Obstacle Detection Network. IEEE Transactions on Cybernetics.
[4] Bovcon, B., Muhovič, J., Vranac, D., Mozetič, D., Perš, J., & Kristan, M. (2021). MODS -- A USV-oriented object detection and obstacle segmentation benchmark.
You should use this fork of the repository (https://github.com/playertr/WaSR-T) for compatibility with the MobileNetV3 backbone. I trained a network on an RTX 2060 with the following command:
python train.py \
--train-config MaSTr1325/mastr1325_train.yaml \
--val-config MaSTr1325/mastr1325_val.yaml \
--validation --model-name mastr1478_mobilenetv3_e \
--batch-size 6 --epochs 500 --patience 50 \
--additional-train-config MaSTr153/mastr153_train.yaml
Training for 27 epochs (starting from DeepLab's COCO pretraining) took a couple of hours on my RTX 2060.
You can get the weights here: https://drive.google.com/file/d/19uASKkNV-IwsBNGR5WtAV_P--hyVPs03/view?usp=sharing
Inference used 156 MB of VRAM and ran at 48 FPS on my RTX 2060, but at only 3.5 FPS on the Jetson Nano. The model reached 99.5% obstacle IoU on the validation set. The output on MaSTr1325 is below (it's wild to me that so little training data was needed! I wonder how well it will generalize to different lighting and weather conditions.)
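A per-class IoU like the obstacle IoU quoted above can be computed from a predicted and a ground-truth label map as intersection over union of the pixels assigned to that class. This is a sketch, not the script that produced the 99.5% figure: the function class_iou and its ignore_index option are hypothetical, the obstacle label id is assumed to be 0 (check the MaSTr1325 label convention), and whether the reported value averages per-image IoUs or accumulates counts over the whole validation set is not stated.

```python
import torch


def class_iou(pred, target, class_id, ignore_index=None):
    """IoU of one class between a predicted and a ground-truth label map of equal shape."""
    valid = torch.ones_like(target, dtype=torch.bool)
    if ignore_index is not None:
        valid = target != ignore_index  # skip unlabeled pixels
    p = (pred == class_id) & valid
    t = (target == class_id) & valid
    union = (p | t).sum().item()
    if union == 0:
        return float('nan')  # class absent from both masks
    return (p & t).sum().item() / union


pred = torch.tensor([[0, 0, 1], [1, 2, 0]])
target = torch.tensor([[0, 1, 1], [1, 2, 0]])
print(class_iou(pred, target, class_id=0))  # 0.6666666666666666
```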
Note: there may still be performance improvements to be had by retraining with a different implementation. For instance, the designation of the skip1 and skip2 intermediate variables, and the resulting tensor sizes within the decoder module, might be mistaken.
You need PyTorch and the other dependencies (not necessarily the exact versions listed) to run this network. The version of PyTorch that ships with the classic Jetson Nano (on JetPack 4.6) is pretty old, so I used version 1.10 from the pip wheel at this link. The latest versions (1.13 or 2.0dev as of this writing) are not easily available for ARM devices.