This is a PyTorch/GPU re-implementation of the paper MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis (to appear in CVPR 2023):
@article{li2022mage,
title={MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis},
author={Li, Tianhong and Chang, Huiwen and Mishra, Shlok Kumar and Zhang, Han and Katabi, Dina and Krishnan, Dilip},
journal={arXiv preprint arXiv:2211.09117},
year={2022}
}
MAGE is a unified framework for both generative modeling and representation learning, achieving state-of-the-art results in both class-unconditional image generation and linear probing on ImageNet-1K.
A large portion of the code in this repo is based on MAE and VQGAN. The original implementation was in JAX/TPU.
Download the ImageNet dataset and place it in your IMAGENET_DIR.
A suitable conda environment named mage can be created and activated with:
conda env create -f environment.yaml
conda activate mage
Download the code
git clone https://github.com/LTH14/mage.git
cd mage
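Use this link to download the pre-trained VQGAN tokenizer and put it in the mage directory.
To pre-train a MAGE ViT-B model with a total batch size of 4096 (64 per GPU) using 8 servers with 8 V100 GPUs each, run the following on every server, setting --node_rank to that server's index (0 to 7):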
python -m torch.distributed.launch --node_rank=0 --nproc_per_node=8 --nnodes=8 \
--master_addr="${MASTER_SERVER_ADDRESS}" --master_port=12344 \
main_pretrain.py \
--batch_size 64 \
--model mage_vit_base_patch16 \
--mask_ratio_min 0.5 --mask_ratio_max 1.0 \
--mask_ratio_mu 0.55 --mask_ratio_std 0.25 \
--epochs 1600 \
--warmup_epochs 40 \
--blr 1.5e-4 --weight_decay 0.05 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214
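As a rough guide to what the batch-size, learning-rate and masking flags control, here is a minimal sketch under two assumptions: that main_pretrain.py follows MAE in scaling the base learning rate `--blr` by the effective batch size (lr = blr × effective batch / 256), and that the masking ratio is drawn from a Gaussian with mean `--mask_ratio_mu` and standard deviation `--mask_ratio_std`, truncated to [`--mask_ratio_min`, `--mask_ratio_max`]. The helper names are hypothetical, rejection sampling stands in for whatever sampler the repo actually uses, and the 256-token sequence assumes a 256x256 image tokenized on a 16x16 grid.

```python
import math
import random


def effective_lr(blr, batch_size_per_gpu, num_gpus, accum_iter=1):
    """Absolute LR from the base LR, assuming the MAE rule lr = blr * eff_batch / 256."""
    eff_batch_size = batch_size_per_gpu * accum_iter * num_gpus
    return blr * eff_batch_size / 256, eff_batch_size


def sample_mask_ratio(mu=0.55, std=0.25, lo=0.5, hi=1.0, rng=random):
    """Draw a masking ratio from a Gaussian truncated to [lo, hi] (rejection sampling)."""
    while True:
        r = rng.gauss(mu, std)
        if lo <= r <= hi:
            return r


def num_masked_tokens(mask_ratio, seq_len=256):
    """Hypothetical helper: round up so at least mask_ratio * seq_len tokens are masked."""
    return int(math.ceil(seq_len * mask_ratio))


# 8 servers x 8 GPUs x --batch_size 64, with --blr 1.5e-4
lr, eff = effective_lr(1.5e-4, 64, 8 * 8)
print(eff, round(lr, 6))  # 4096 0.0024

ratio = sample_mask_ratio()
print(f"mask ratio {ratio:.3f} -> {num_masked_tokens(ratio)} of 256 tokens masked")
```

Under this rule the command above trains with an absolute learning rate of about 2.4e-3. On fewer GPUs, gradient accumulation (an `accum_iter` setting, assumed here by analogy with MAE) would keep the effective batch size, and hence the learning rate, unchanged.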
The following table provides the performance and weights of the pre-trained checkpoints used in the paper, converted from JAX/TPU to PT/GPU:
| | ViT-Base | ViT-Large |
|---|---|---|
| Checkpoint | Google Drive | Google Drive |
| Class-unconditional Generation FID | 11.1 | 9.10 |
| Class-unconditional Generation IS | 81.2 | 105.1 |
| Linear Probing Top-1 Accuracy | 74.7% | 78.9% |
| Fine-tuning Top-1 Accuracy | 82.5% (Checkpoint) | 83.9% (Checkpoint) |
To perform linear probing on a pre-trained MAGE model using 4 servers with 8 V100 GPUs each, run the following on every server, setting --node_rank to that server's index (0 to 3):
python -m torch.distributed.launch --node_rank=0 --nproc_per_node=8 --nnodes=4 \
--master_addr="${MASTER_SERVER_ADDRESS}" --master_port=12344 \
main_linprobe.py \
--batch_size 128 \
--model vit_base_patch16 \
--global_pool \
--finetune ${PRETRAIN_CHKPT} \
--epochs 90 \
--blr 0.1 \
--weight_decay 0.0 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_eval --dist_url tcp://${MASTER_SERVER_ADDRESS}:6311
For ViT-L, set --blr 0.05.
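The `--global_pool` flag changes which encoder output the linear classifier sees. The minimal sketch below assumes the MAE convention that MAGE's ViT code appears to follow: average the patch tokens and skip the [CLS] token, rather than using the [CLS] token alone. MAE additionally applies a normalization layer after pooling, which is omitted here. During linear probing only the final linear layer is trained while the encoder stays frozen. All function names are illustrative, and plain lists stand in for tensors.

```python
from typing import List, Sequence

Vector = List[float]


def global_pool(tokens: Sequence[Vector]) -> Vector:
    """Mean of the patch tokens; tokens[0] is assumed to be [CLS] and is skipped."""
    patches = tokens[1:]
    dim = len(patches[0])
    return [sum(tok[d] for tok in patches) / len(patches) for d in range(dim)]


def cls_pool(tokens: Sequence[Vector]) -> Vector:
    """Without global pooling, the [CLS] token alone is used as the image feature."""
    return list(tokens[0])


def linear_head(feature: Vector, weight: Sequence[Vector], bias: Vector) -> Vector:
    """Logits = W @ feature + b; in linear probing only W and b are trained."""
    return [sum(w * f for w, f in zip(row, feature)) + b for row, b in zip(weight, bias)]


tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]  # [CLS] followed by two patch tokens
feat = global_pool(tokens)  # [2.0, 3.0]
logits = linear_head(feat, weight=[[1.0, 0.0], [0.0, 1.0]], bias=[0.5, -0.5])  # [2.5, 2.5]
print(feat, logits)
```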
To perform fine-tuning with a pre-trained ViT-B model using 4 servers with 8 V100 GPUs each, run the following on every server, setting --node_rank to that server's index (0 to 3):
python -m torch.distributed.launch --node_rank=0 --nproc_per_node=8 --nnodes=4 \
--master_addr="${MASTER_SERVER_ADDRESS}" --master_port=12344 \
main_finetune.py \
--batch_size 32 \
--model vit_base_patch16 \
--global_pool \
--finetune ${PRETRAIN_CHKPT} \
--epochs 100 \
--blr 2.5e-4 --layer_decay 0.65 --interpolation bicubic \
--weight_decay 0.05 --drop_path 0.1 --reprob 0 --mixup 0.8 --cutmix 1.0 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_eval --dist_url tcp://${MASTER_SERVER_ADDRESS}:6311
For ViT-L, set --epochs 50 --layer_decay 0.75 --drop_path 0.2.
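`--layer_decay` gives earlier transformer layers a smaller learning rate than later ones. The sketch below assumes the fine-tuning script reuses MAE's layer-wise decay scheme: parameters are mapped to a depth index (embeddings 0, block i to i + 1, head and final norm to the top), and each gets the multiplier `layer_decay ** (num_layers - layer_id)` with `num_layers = depth + 1`. The function names and example parameter names are illustrative, assuming timm-style ViT naming.

```python
def layer_id_for_vit(name: str, num_layers: int) -> int:
    """Map a parameter name to a depth index (MAE-style scheme, assumed)."""
    if name in ("cls_token", "pos_embed") or name.startswith("patch_embed"):
        return 0  # embeddings get the smallest learning rate
    if name.startswith("blocks."):
        return int(name.split(".")[1]) + 1  # transformer block i -> layer i + 1
    return num_layers  # head, final norm, etc. keep the full learning rate


def lr_scale(name: str, depth: int = 12, layer_decay: float = 0.65) -> float:
    """Multiplier applied to the base learning rate for this parameter."""
    num_layers = depth + 1
    return layer_decay ** (num_layers - layer_id_for_vit(name, num_layers))


# ViT-B defaults (12 blocks, --layer_decay 0.65)
for n in ("patch_embed.proj.weight", "blocks.0.attn.qkv.weight",
          "blocks.11.mlp.fc2.weight", "head.weight"):
    print(f"{n:28s} {lr_scale(n):.4f}")
```

For ViT-L, the corresponding call would be `lr_scale(name, depth=24, layer_decay=0.75)`. The larger decay keeps the multipliers from shrinking too quickly across the deeper network.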
To perform class-unconditional generation with a pre-trained MAGE model on a single V100 GPU:
python gen_img_uncond.py --temp 6.0 --num_iter 20 \
--ckpt ${PRETRAIN_CHKPT} --batch_size 32 --num_images 50000 \
--model mage_vit_base_patch16 --output_dir ${OUTPUT_DIR}
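Generation starts from a fully masked token sequence and fills it in over `--num_iter` steps. The sketch below shows how many tokens would remain masked after each step, assuming a MaskGIT-style cosine schedule, and assuming that `--temp` sets a sampling temperature that is annealed linearly to zero. It only computes the schedule, not the model's token sampling. The function name and the 256-token sequence length (a 16x16 token grid) are assumptions.

```python
import math
from typing import List, Tuple


def decoding_schedule(num_iter: int = 20, seq_len: int = 256,
                      choice_temp: float = 6.0) -> List[Tuple[int, int, float]]:
    """Return (step, tokens_still_masked_after_step, sampling_temperature) per step.

    Assumes a MaskGIT-style cosine schedule and linear temperature annealing.
    """
    schedule = []
    unknown = seq_len  # unconditional generation starts from an all-masked sequence
    for step in range(num_iter):
        ratio = (step + 1) / num_iter
        temperature = choice_temp * (1.0 - ratio)  # reaches 0 at the last step
        if step == num_iter - 1:
            mask_len = 0  # final step: every remaining token is filled in
        else:
            mask_len = math.floor(seq_len * math.cos(math.pi / 2 * ratio))
            # clamp so each step unmasks at least one token and keeps at least one masked
            mask_len = max(1, min(unknown - 1, mask_len))
        schedule.append((step, mask_len, temperature))
        unknown = mask_len
    return schedule


for step, masked, temp in decoding_schedule():
    print(f"step {step:2d}: {masked:3d} tokens still masked, temperature {temp:.2f}")
```

The cosine curve unmasks only a few tokens in the early steps, when predictions are least certain, and many more toward the end.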
To evaluate FID/IS quantitatively, first prepare 256x256 ImageNet validation images using:
python prepare_imgnet_val.py --data_path ${IMAGENET_DIR} --output_dir ${OUTPUT_DIR}
Then install the torch-fidelity package:
pip install torch-fidelity
Finally, use torch-fidelity to evaluate the FID/IS of the generated images against the 256x256 ImageNet validation images:
fidelity --gpu 0 --isc --fid --input1 ${GENERATED_IMAGES_DIR} --input2 ${IMAGENET256X256_DIR}
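To script the evaluation rather than use the CLI, torch-fidelity also provides a Python entry point, `torch_fidelity.calculate_metrics`. Here is a minimal sketch. The wrapper functions and directory paths are hypothetical. Only the `calculate_metrics` call, its keyword arguments, and the result keys come from torch-fidelity.

```python
from typing import Any, Dict


def fidelity_kwargs(generated_dir: str, reference_dir: str,
                    use_cuda: bool = True) -> Dict[str, Any]:
    """Keyword arguments roughly mirroring `fidelity --isc --fid --input1 ... --input2 ...`."""
    return {
        "input1": generated_dir,   # folder of generated samples
        "input2": reference_dir,   # folder of 256x256 ImageNet validation images
        "cuda": use_cuda,          # run Inception feature extraction on the GPU
        "isc": True,               # Inception Score (computed on input1)
        "fid": True,               # Frechet Inception Distance between input1 and input2
        "verbose": False,
    }


def evaluate(generated_dir: str, reference_dir: str) -> Dict[str, float]:
    import torch_fidelity  # imported lazily so fidelity_kwargs works without the package
    metrics = torch_fidelity.calculate_metrics(**fidelity_kwargs(generated_dir, reference_dir))
    return {
        "fid": metrics["frechet_inception_distance"],
        "is_mean": metrics["inception_score_mean"],
        "is_std": metrics["inception_score_std"],
    }


# Example (paths are placeholders):
# print(evaluate("/path/to/generated_images", "/path/to/imagenet256x256_val"))
```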
Here are some examples of our class-unconditional generation:
Here we provide the pre-trained MAGE-C checkpoints converted from JAX/TPU to PT/GPU: ViT-B, ViT-L. PyTorch training script coming soon.
If you have any questions, feel free to contact me through email ([email protected]). Enjoy!