-
Notifications
You must be signed in to change notification settings - Fork 492
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 51fbb9a
Showing
31 changed files
with
95,585 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Data | ||
*.png | ||
*.gif | ||
*.tar.gz | ||
data/cifar-10-batches-py | ||
|
||
# ipython checkpoints | ||
.ipynb_checkpoints | ||
|
||
# Log | ||
logs | ||
|
||
# ETC | ||
.vscode | ||
|
||
# Created by https://www.gitignore.io/api/python,vim | ||
|
||
### Python ### | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*,cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# dotenv | ||
.env | ||
|
||
# virtualenv | ||
.venv/ | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
|
||
### Vim ### | ||
# swap | ||
[._]*.s[a-v][a-z] | ||
[._]*.sw[a-p] | ||
[._]s[a-v][a-z] | ||
[._]sw[a-p] | ||
# session | ||
Session.vim | ||
# temporary | ||
.netrwhist | ||
*~ | ||
# auto-generated tag files | ||
tags | ||
|
||
# End of https://www.gitignore.io/api/python,vim |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Efficient Neural Architecture Search (ENAS) in PyTorch | ||
|
||
PyTorch implementation of [Efficient Neural Architecture Search via Parameters Sharing](https://arxiv.org/abs/1802.03268). | ||
|
||
<p align="center"><img src="assets/ENAS_rnn.png" alt="ENAS_rnn" width="60%"></p> | ||
|
||
**ENAS** reduces the computational requirement (GPU-hours) of [Neural Architecture Search](https://arxiv.org/abs/1611.01578) (**NAS**) by 1000x via parameter sharing between models that are subgraphs within a large computational graph. SOTA on `Penn Treebank` language modeling. | |
|
||
|
||
## Prerequisites | ||
|
||
- Python 3.6+ | ||
- [PyTorch](http://pytorch.org/) | ||
- tqdm, scipy, imageio, graphviz, tensorboardX | ||
|
||
## Usage | ||
|
||
Install prerequisites with: | ||
|
||
conda install graphviz | ||
pip install -r requirements.txt | ||
|
||
To train **ENAS** to discover a recurrent cell for RNN: | ||
|
||
python main.py --network_type rnn --dataset ptb --controller_optim adam --controller_lr 0.00035 \ | ||
--shared_optim sgd --shared_lr 20.0 --entropy_coeff 0.0001 | ||
|
||
python main.py --network_type rnn --dataset wikitext | ||
|
||
To train **ENAS** to discover CNN architecture (in progress): | ||
|
||
python main.py --network_type cnn --dataset cifar --controller_optim momentum --controller_lr cosine \ | ||
--controller_lr_max 0.05 --controller_lr_min 0.0001 --entropy_coeff 0.1 | ||
|
||
Or you can use your own dataset by placing your text or image files like this: | |
|
||
data | ||
├── YOUR_TEXT_DATASET | ||
│ ├── test.txt | ||
│ ├── train.txt | ||
│ └── valid.txt | ||
├── YOUR_IMAGE_DATASET | ||
│ ├── test | ||
│ │ ├── xxx.jpg (name doesn't matter) | ||
│ │ ├── yyy.jpg (name doesn't matter) | ||
│ │ └── ... | ||
│ ├── train | ||
│ │ ├── xxx.jpg | ||
│ │ └── ... | ||
│ └── valid | ||
│ ├── xxx.jpg | ||
│ └── ... | ||
├── image.py | ||
└── text.py | ||
|
||
To generate a `gif` image of generated samples: | |
|
||
python generate_gif.py --model_name=ptb_2018-02-15_11-20-02 --output=sample.gif | ||
|
||
More configurations can be found [here](config.py). | ||
|
||
|
||
## Results | ||
|
||
Efficient Neural Architecture Search (**ENAS**) is composed of two sets of learnable parameters: the controller LSTM parameters *θ* and the shared parameters *ω*. The two sets are trained alternately, and only the trained controller is used to derive novel architectures. | |
|
||
### 1. Discovering Recurrent Cells | ||
|
||
![rnn](./assets/rnn.png) | ||
|
||
The controller LSTM decides 1) which activation function to use and 2) which previous node to connect to. | |
|
||
The RNN cells **ENAS** discovered for the `Penn Treebank` and `WikiText-2` datasets: | |
|
||
<img src="assets/ptb.gif" alt="ptb" width="45%"> <img src="assets/wikitext.gif" alt="wikitext" width="45%"> | ||
|
||
You can see the details of training (e.g. `reward`, `entropy`, `loss`) with: | ||
|
||
tensorboard --logdir=logs --port=6006 | ||
|
||
![training](assets/training.png) | ||
|
||
|
||
### 2. Discovering Convolutional Neural Networks | ||
|
||
![cnn](./assets/cnn.png) | ||
|
||
Controller LSTM samples 1) what computation operation to use and 2) which previous node to connect. | ||
|
||
The CNN network **ENAS** discovered for the `CIFAR-10` dataset: | |
|
||
(in progress) | ||
|
||
|
||
### 3. Designing Convolutional Cells | ||
|
||
(in progress) | ||
|
||
|
||
## Reference | ||
|
||
- [Neural Architecture Search with Reinforcement Learning](https://arxiv.org/abs/1611.01578) | ||
- [Neural Optimizer Search with Reinforcement Learning](https://arxiv.org/abs/1709.07417) | ||
|
||
|
||
## Author | ||
|
||
Taehoon Kim / [@carpedm20](http://carpedm20.github.io/) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import argparse
import ast

from utils import get_logger
|
||
# Module-wide logger obtained from the project's `utils` helper.
logger = get_logger()

# Single parser shared by every option group declared in this module.
parser = argparse.ArgumentParser()
# Argument groups, collected in creation order by add_argument_group().
arg_lists = []
|
||
def str2bool(v):
    """Convert a command-line string into a boolean.

    Args:
        v: raw option text, e.g. ``"True"`` or ``"false"``.

    Returns:
        bool: ``True`` when ``v`` is, case-insensitively, one of
        ``true``, ``t``, ``yes``, ``y`` or ``1``; ``False`` otherwise.
    """
    # The membership test must run against a real tuple: ``('true')`` is just
    # the string 'true', which turns ``in`` into a substring check and makes
    # inputs such as '' or 'rue' evaluate to True.
    return v.lower() in ('true', 't', 'yes', 'y', '1')
|
||
def add_argument_group(name):
    """Create a titled argument group on the module parser and record it.

    Args:
        name: title handed to ``parser.add_argument_group``.

    Returns:
        The newly created group, which is also appended to ``arg_lists``.
    """
    arg_lists.append(parser.add_argument_group(name))
    # The group just registered is the last entry of the list.
    return arg_lists[-1]
|
||
# Network
net_arg = add_argument_group('Network')
net_arg.add_argument('--network_type', type=str, choices=['rnn', 'cnn'], default='rnn')

# Controller
net_arg.add_argument('--num_blocks', type=int, default=12)
net_arg.add_argument('--tie_weights', type=str2bool, default=True)
net_arg.add_argument('--controller_hid', type=int, default=100)

# Shared parameters for PTB
net_arg.add_argument('--shared_dropout', type=float, default=0.4)  # TODO
net_arg.add_argument('--shared_dropoute', type=float, default=0.1)  # TODO
net_arg.add_argument('--shared_dropouti', type=float, default=0.65)  # TODO
net_arg.add_argument('--shared_embed', type=int, default=1000)
net_arg.add_argument('--shared_hid', type=int, default=1000)
net_arg.add_argument('--shared_rnn_max_length', type=int, default=35)
# SECURITY: these two options used `type=eval`, which executes arbitrary
# Python supplied on the command line. `ast.literal_eval` only accepts
# literals (such as the list-of-strings defaults below), so normal usage is
# unchanged. argparse also runs the converter on string defaults, so the
# parsed values are still lists.
net_arg.add_argument('--shared_rnn_activations', type=ast.literal_eval,
                     default="['tanh', 'ReLU', 'identity', 'sigmoid']")
net_arg.add_argument('--shared_cnn_types', type=ast.literal_eval,
                     default="['3x3', '5x5', 'sep 3x3', 'sep 5x5', 'max 3x3', 'max 5x5']")

# Shared parameters for CIFAR
net_arg.add_argument('--cnn_hid', type=int, default=64)
|
||
|
||
# Data: which corpus / image set to load.
data_arg = add_argument_group('Data')
data_arg.add_argument(
    '--dataset',
    default='ptb',
    type=str,
)
|
||
|
||
# Training / test parameters
learn_arg = add_argument_group('Learning')
learn_arg.add_argument(
    '--mode', default='train', type=str,
    choices=['train', 'derive', 'test'],
    help='train: Training ENAS, derive: Deriving Architectures',
)
learn_arg.add_argument('--batch_size', default=64, type=int)
learn_arg.add_argument('--test_batch_size', default=1, type=int)
learn_arg.add_argument('--max_epoch', default=150, type=int)

# Controller
learn_arg.add_argument(
    '--reward_c', default=80, type=int,
    help="WE DON'T KNOW WHAT THIS VALUE SHOULD BE",
)  # TODO
learn_arg.add_argument('--ema_baseline_decay', default=0.9, type=float)  # TODO
learn_arg.add_argument('--discount', default=0.95, type=float)  # TODO
learn_arg.add_argument(
    '--controller_max_step', default=2000, type=int,
    help='step for controller parameters',
)
learn_arg.add_argument('--controller_optim', default='adam', type=str)
learn_arg.add_argument('--controller_lr', default=3.5e-4, type=float)
learn_arg.add_argument('--tanh_c', default=2.5, type=float)
learn_arg.add_argument('--softmax_temperature', default=5.0, type=float)
learn_arg.add_argument('--entropy_coeff', default=1e-4, type=float)

# Shared parameters
learn_arg.add_argument(
    '--shared_max_step', default=400, type=int,
    help='step for shared parameters',
)
learn_arg.add_argument(
    '--shared_num_sample', default=1, type=int,
    help='# of Monte Carlo samples',
)
learn_arg.add_argument('--shared_optim', default='sgd', type=str)
learn_arg.add_argument('--shared_lr', default=20.0, type=float)
learn_arg.add_argument('--shared_decay', default=0.96, type=float)
learn_arg.add_argument('--shared_decay_after', default=15, type=float)
learn_arg.add_argument('--shared_l2_reg', default=1e-7, type=float)
learn_arg.add_argument('--shared_grad_clip', default=0.25, type=float)

# Deriving Architectures
learn_arg.add_argument('--derive_num_sample', default=100, type=int)
|
||
|
||
# Misc: checkpointing, logging, paths and runtime environment.
misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--load_path', default='', type=str)
misc_arg.add_argument('--log_step', default=50, type=int)
misc_arg.add_argument('--save_epoch', default=1, type=int)
misc_arg.add_argument('--max_save_num', default=5, type=int)
misc_arg.add_argument(
    '--log_level',
    default='INFO',
    type=str,
    choices=['INFO', 'DEBUG', 'WARN'],
)
misc_arg.add_argument('--log_dir', default='logs', type=str)
misc_arg.add_argument('--data_dir', default='data', type=str)
misc_arg.add_argument('--num_gpu', default=1, type=int)
misc_arg.add_argument('--random_seed', default=12345, type=int)
misc_arg.add_argument('--use_tensorboard', default=True, type=str2bool)
|
||
|
||
def get_args():
    """Parse the command line with the module-level ``parser``.

    Returns:
        tuple: ``(args, unparsed)`` where ``args`` is the parsed namespace,
        extended with a boolean ``cuda`` attribute, and ``unparsed`` is the
        list of arguments the parser did not recognise.
    """
    args, unparsed = parser.parse_known_args()
    # Always define `cuda`; previously it was only set when num_gpu > 0, so
    # reading `args.cuda` on a CPU-only run raised AttributeError.
    args.cuda = args.num_gpu > 0
    # Report any leftover argument (the old `> 1` check skipped a single one).
    if unparsed:
        logger.info(f"Unparsed args: {unparsed}")
    return args, unparsed
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
import data.text | ||
import data.image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch as t | ||
import torchvision.datasets as datasets | ||
import torchvision.transforms as transforms | ||
|
||
|
||
class Image(object):
    """Train/validation ``DataLoader`` pair for an image dataset.

    Attributes set by ``__init__``:
        train: shuffled loader over the ``train=True`` split, with random
            horizontal flip and padded random crop applied.
        valid: unshuffled loader over the ``train=False`` split.
    """

    def __init__(self, args):
        """Build the loaders selected by ``args.dataset``.

        Args:
            args: parsed options; reads ``dataset`` and ``batch_size`` and,
                when present, ``num_workers``.

        Raises:
            NotImplementedError: if ``args.dataset`` names an unsupported
                dataset.
        """
        # The original read the misspelled `args.datset` (AttributeError).
        # 'cifar' is accepted as an alias because the README invokes
        # `--dataset cifar`.
        if args.dataset in ('cifar', 'cifar10'):
            Dataset = datasets.CIFAR10
            # NOTE(review): `normalize` was never defined in the original;
            # these are the commonly used CIFAR-10 per-channel statistics —
            # confirm they match the intended preprocessing.
            normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                             std=[0.2470, 0.2435, 0.2616])
        elif args.dataset == 'MNIST':
            Dataset = datasets.MNIST
            # NOTE(review): commonly used MNIST statistics (one channel) —
            # confirm. The 32-pixel random crop below is kept from the
            # original; verify it is intended for MNIST inputs.
            normalize = transforms.Normalize(mean=[0.1307], std=[0.3081])
        else:
            # `NotImplemented` is a non-callable constant; the exception
            # class is `NotImplementedError`.
            raise NotImplementedError(f"Unknown dataset: {args.dataset}")

        # The option may be absent from the parsed args; 0 is DataLoader's
        # own default (load in the main process).
        num_workers = getattr(args, 'num_workers', 0)

        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        self.train = t.utils.data.DataLoader(
            Dataset(root='./data', train=True, transform=train_transform,
                    download=True),
            batch_size=args.batch_size, shuffle=True,
            num_workers=num_workers, pin_memory=True)

        self.valid = t.utils.data.DataLoader(
            Dataset(root='./data', train=False, transform=valid_transform),
            batch_size=args.batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True)
|
Oops, something went wrong.