Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
carpedm20 committed Feb 15, 2018
0 parents commit 51fbb9a
Show file tree
Hide file tree
Showing 31 changed files with 95,585 additions and 0 deletions.
126 changes: 126 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Data
*.png
*.gif
*.tar.gz
data/cifar-10-batches-py

# ipython checkpoints
.ipynb_checkpoints

# Log
logs

# ETC
.vscode

# Created by https://www.gitignore.io/api/python,vim

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
.venv/
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject


### Vim ###
# swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-v][a-z]
[._]sw[a-p]
# session
Session.vim
# temporary
.netrwhist
*~
# auto-generated tag files
tags

# End of https://www.gitignore.io/api/python,vim
108 changes: 108 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Efficient Neural Architecture Search (ENAS) in PyTorch

PyTorch implementation of [Efficient Neural Architecture Search via Parameters Sharing](https://arxiv.org/abs/1802.03268).

<p align="center"><img src="assets/ENAS_rnn.png" alt="ENAS_rnn" width="60%"></p>

**ENAS** reduce the computational requirement (GPU-hours) of [Neural Architecture Search](https://arxiv.org/abs/1611.01578) (**NAS**) by 1000x via parameter sharing between models that are subgraphs within a large computational graph. SOTA on `Penn Treebank` language modeling.


## Prerequisites

- Python 3.6+
- [PyTorch](http://pytorch.org/)
- tqdm, scipy, imageio, graphviz, tensorboardX

## Usage

Install prerequisites with:

conda install graphviz
pip install -r requirements.txt

To train **ENAS** to discover a recurrent cell for RNN:

python main.py --network_type rnn --dataset ptb --controller_optim adam --controller_lr 0.00035 \
--shared_optim sgd --shared_lr 20.0 --entropy_coeff 0.0001

python main.py --network_type rnn --dataset wikitext

To train **ENAS** to discover CNN architecture (in progress):

python main.py --network_type cnn --dataset cifar --controller_optim momentum --controller_lr cosine \
--controller_lr_max 0.05 --controller_lr_min 0.0001 --entropy_coeff 0.1

or you can use your own dataset by placing images like:

data
├── YOUR_TEXT_DATASET
│ ├── test.txt
│ ├── train.txt
│ └── valid.txt
├── YOUR_IMAGE_DATASET
│ ├── test
│ │ ├── xxx.jpg (name doesn't matter)
│ │ ├── yyy.jpg (name doesn't matter)
│ │ └── ...
│ ├── train
│ │ ├── xxx.jpg
│ │ └── ...
│ └── valid
│ ├── xxx.jpg
│ └── ...
├── image.py
└── text.py

To generate `gif` image of generated samples:

python generate_gif.py --model_name=ptb_2018-02-15_11-20-02 --output=sample.gif

More configurations can be found [here](config.py).


## Results

Efficient Neural Architecture Search (**ENAS**) is composed of two sets of learnable parameters, controller LSTM *θ* and the shared parameters *ω*. These two parameters are alternatively trained and only trained controller is used to derive novel architectures.

### 1. Discovering Recurrent Cells

![rnn](./assets/rnn.png)

Controller LSTM decide 1) what activation function to use and 2) which previous node to connect.

The RNN cell **ENAS** discovered for `Penn Treebank` and `WikiText-2` dataset:

<img src="assets/ptb.gif" alt="ptb" width="45%"> <img src="assets/wikitext.gif" alt="wikitext" width="45%">

You can see the details of training (e.g. `reward`, `entropy`, `loss`) with:

tensorboard --logdir=logs --port=6006

![training](assets/training.png)


### 2. Discovering Convolutional Neural Networks

![cnn](./assets/cnn.png)

Controller LSTM samples 1) what computation operation to use and 2) which previous node to connect.

The CNN network **ENAS** discovered for `CIFAR-10` dataset:

(in progress)


### 3. Designing Convolutional Cells

(in progress)


## Reference

- [Neural Architecture Search with Reinforcement Learning](https://arxiv.org/abs/1611.01578)
- [Neural Optimizer Search with Reinforcement Learning](https://arxiv.org/abs/1709.07417)


## Author

Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)
Binary file added assets/ENAS_cnn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/ENAS_rnn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/arial.ttf
Binary file not shown.
Binary file added assets/cnn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/cnn_cell.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/ptb.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/rnn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/training.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/wikitext.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
106 changes: 106 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import argparse
from utils import get_logger

logger = get_logger()


arg_lists = []
parser = argparse.ArgumentParser()

def str2bool(v):
return v.lower() in ('true')

def add_argument_group(name):
arg = parser.add_argument_group(name)
arg_lists.append(arg)
return arg

# Network
net_arg = add_argument_group('Network')
net_arg.add_argument('--network_type', type=str, choices=['rnn', 'cnn'], default='rnn')

# Controller
net_arg.add_argument('--num_blocks', type=int, default=12)
net_arg.add_argument('--tie_weights', type=str2bool, default=True)
net_arg.add_argument('--controller_hid', type=int, default=100)

# Shared parameters for PTB
net_arg.add_argument('--shared_dropout', type=float, default=0.4) # TODO
net_arg.add_argument('--shared_dropoute', type=float, default=0.1) # TODO
net_arg.add_argument('--shared_dropouti', type=float, default=0.65) # TODO
net_arg.add_argument('--shared_embed', type=int, default=1000)
net_arg.add_argument('--shared_hid', type=int, default=1000)
net_arg.add_argument('--shared_rnn_max_length', type=int, default=35)
net_arg.add_argument('--shared_rnn_activations', type=eval,
default="['tanh', 'ReLU', 'identity', 'sigmoid']")
net_arg.add_argument('--shared_cnn_types', type=eval,
default="['3x3', '5x5', 'sep 3x3', 'sep 5x5', 'max 3x3', 'max 5x5']")

# Shared parameters for CIFAR
net_arg.add_argument('--cnn_hid', type=int, default=64)


# Data
data_arg = add_argument_group('Data')
data_arg.add_argument('--dataset', type=str, default='ptb')


# Training / test parameters
learn_arg = add_argument_group('Learning')
learn_arg.add_argument('--mode', type=str, default='train',
choices=['train', 'derive', 'test'],
help='train: Training ENAS, derive: Deriving Architectures')
learn_arg.add_argument('--batch_size', type=int, default=64)
learn_arg.add_argument('--test_batch_size', type=int, default=1)
learn_arg.add_argument('--max_epoch', type=int, default=150)

# Controller
learn_arg.add_argument('--reward_c', type=int, default=80,
help="WE DON'T KNOW WHAT THIS VALUE SHOULD BE") # TODO
learn_arg.add_argument('--ema_baseline_decay', type=float, default=0.9) # TODO
learn_arg.add_argument('--discount', type=float, default=0.95) # TODO
learn_arg.add_argument('--controller_max_step', type=int, default=2000,
help='step for controller parameters')
learn_arg.add_argument('--controller_optim', type=str, default='adam')
learn_arg.add_argument('--controller_lr', type=float, default=3.5e-4)
learn_arg.add_argument('--tanh_c', type=float, default=2.5)
learn_arg.add_argument('--softmax_temperature', type=float, default=5.0)
learn_arg.add_argument('--entropy_coeff', type=float, default=1e-4)

# Shared parameters
learn_arg.add_argument('--shared_max_step', type=int, default=400,
help='step for shared parameters')
learn_arg.add_argument('--shared_num_sample', type=int, default=1,
help='# of Monte Carlo samples')
learn_arg.add_argument('--shared_optim', type=str, default='sgd')
learn_arg.add_argument('--shared_lr', type=float, default=20.0)
learn_arg.add_argument('--shared_decay', type=float, default=0.96)
learn_arg.add_argument('--shared_decay_after', type=float, default=15)
learn_arg.add_argument('--shared_l2_reg', type=float, default=1e-7)
learn_arg.add_argument('--shared_grad_clip', type=float, default=0.25)

# Deriving Architectures
learn_arg.add_argument('--derive_num_sample', type=int, default=100)


# Misc
misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--load_path', type=str, default='')
misc_arg.add_argument('--log_step', type=int, default=50)
misc_arg.add_argument('--save_epoch', type=int, default=1)
misc_arg.add_argument('--max_save_num', type=int, default=5)
misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'])
misc_arg.add_argument('--log_dir', type=str, default='logs')
misc_arg.add_argument('--data_dir', type=str, default='data')
misc_arg.add_argument('--num_gpu', type=int, default=1)
misc_arg.add_argument('--random_seed', type=int, default=12345)
misc_arg.add_argument('--use_tensorboard', type=str2bool, default=True)


def get_args():
args, unparsed = parser.parse_known_args()
if args.num_gpu > 0:
setattr(args, 'cuda', True)
if len(unparsed) > 1:
logger.info(f"Unparsed args: {unparsed}")
return args, unparsed
2 changes: 2 additions & 0 deletions data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import data.text
import data.image
32 changes: 32 additions & 0 deletions data/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch as t
import torchvision.datasets as datasets
import torchvision.transforms as transforms


class Image(object):
def __init__(self, args):
if args.datset == 'cifar10':
Dataset = datasets.CIFAR10
elif args.datset == 'MNIST':
Dataset = datasets.MNIST
else:
raise NotImplemented(f"Unknown dataset: {args.dataset}")

self.train = t.utils.data.DataLoader(
Dataset(root='./data', train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=args.batch_size, shuffle=True,
num_workers=args.num_workers, pin_memory=True)

self.valid = t.utils.data.DataLoader(
Dataset(root='./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])),
batch_size=args.batch_size, shuffle=False,
num_workers=args.num_workers, pin_memory=True)

Loading

0 comments on commit 51fbb9a

Please sign in to comment.