bump to most recent S4 repo
mlech26l committed Aug 10, 2022
1 parent afbe5de commit 8d07d19
Showing 233 changed files with 14,693 additions and 3,770 deletions.
58 changes: 58 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,58 @@
## Changelog

### 2022-08-03 - [V3.0]

#### Models and Features
- Updated version of S4 module, including new measures and theory from [[How to Train Your HiPPO](https://arxiv.org/abs/2206.12037)] (https://github.com/HazyResearch/state-spaces/issues/21, https://github.com/HazyResearch/state-spaces/issues/54)
- Complete version of S4D module from [[On the Parameterization and Initialization of Diagonal State Space Models](https://arxiv.org/abs/2206.11893)]
- [State forwarding](src/models/s4/README.md#state-forwarding) (https://github.com/HazyResearch/state-spaces/issues/49, https://github.com/HazyResearch/state-spaces/issues/56)
- Support for S4 variants including DSS and GSS ([documentation](src/models/s4/README.md#other-variants))

<!--
#### Compilation of additional resources
- Recommended resources for understanding S4-style models, including the [Simplifying S4 blog](https://hazyresearch.stanford.edu/blog/2022-06-11-simplifying-s4) ([code](https://github.com/HazyResearch/state-spaces/tree/simple/src/models/sequence/ss/s4_simple)) and a minimal pedagogical version of S4D ([code](src/models/s4/s4d.py))
- Tips & Tricks page for getting started with tuning S4
-->

#### Bug fixes and library compatibility issues
- PyTorch 1.11 had a [Dropout bug](https://github.com/pytorch/pytorch/issues/77081) which is now avoided with a custom Dropout implementation (https://github.com/HazyResearch/state-spaces/issues/42, https://github.com/HazyResearch/state-spaces/issues/22)
- Conjugated tensors API change in PyTorch 1.10 (https://github.com/HazyResearch/state-spaces/issues/35)

#### SaShiMi
- Release of the SaShiMi+DiffWave model (https://github.com/HazyResearch/state-spaces/issues/46). It can be found at [albertfgu/diffwave-sashimi](https://github.com/albertfgu/diffwave-sashimi)

#### Generation
- Improved generation script for any models trained using this repository (https://github.com/HazyResearch/state-spaces/issues/38)

#### Model Checkpoints
- Re-trained SaShiMi models with the latest version of S4 (https://github.com/HazyResearch/state-spaces/issues/37, https://github.com/HazyResearch/state-spaces/issues/32)
- New WikiText-103 checkpoint with generation functionality (https://github.com/HazyResearch/state-spaces/issues/5, https://github.com/HazyResearch/state-spaces/issues/19)

#### HiPPO
- Release of new [notebook](notebooks/hippo_function_approximation.ipynb) (and equivalent .py [file](src/models/hippo/visualizations.py)) visualizing HiPPO function reconstruction. Includes animations used in HTTYH, the Annotated S4D, and various S4 talks.

#### Experiments
- Improved configs for Long Range Arena reported in HTTYH and S4D papers
- New datasets and ablation experiments from the S4D paper

Note that there have been various refactors and miscellaneous changes which may affect results slightly, but results should be close and general trends should hold. Feel free to file an issue for any results which do not match the papers.

#### Documentation
- Reorganized the [README](README.md) and added much more [documentation](README.md#readmes) for using this codebase


### 2022-05-01 - [V2.1]
- Minor updates to S4 modules
- By default, S4 no longer requires installing Pykeops or a custom CUDA kernel.
- New S4D (S4-diagonal) standalone model found at `src/models/sequence/ss/standalone/s4d.py`. Simple variant using diagonal SSMs that recovers S4's performance on most tasks. Can be run with any existing experiment config with the additional flag `model/layer=s4d` on the command line.
- New [LRA configs](#long-range-arena-lra) for updated S4 code, with an average score of ~86

### 2022-02-27 - [V2]
Code release for SaShiMi audio model

### 2022-01-29 - [V1.1]
Added configs for time series datasets from the Informer paper (https://github.com/HazyResearch/state-spaces/issues/4)

### 2021-11-18 - [V1]
First release of this repository containing the S4 module and configs to reproduce sCIFAR, Speech Commands, Long Range Arena, and WikiText-103 results

154 changes: 154 additions & 0 deletions configs/README.md
@@ -0,0 +1,154 @@

```
config.yaml Main config
model/ Instantiates a model backbone (see src/models/)
dataset/ Instantiates a datamodule (see src/dataloaders/)
loader/ Defines a PyTorch DataLoader
task/ Defines loss, metrics, optional encoder/decoder (see src/tasks/)
pipeline/ Combination of dataset/loader/task for convenience
optimizer/ Instantiates an optimizer
scheduler/ Instantiates a learning rate scheduler
trainer/ Flags for the PyTorch Lightning Trainer class
callbacks/ Misc options for the Trainer (see src/callbacks/)
experiment/ Defines a full experiment (combination of all of the above configs)
generate/ Additional flags used by the generate.py script
```

This README provides a brief overview of the organization of this configs folder. These configs are composed to define a full Hydra config for every experiment.

## Overview
The main config is found at `configs/config.yaml`, which is an example experiment for Permuted MNIST. Different combinations of flags can be overridden to define alternate experiments. The config files in this folder define useful combinations of flags that can be composed. Examples of full configs defining end-to-end experiments can be found in [experiment/](experiment/).

Flags can also be passed on the command line.

<!--
The end-to-end training pipeline can be broken down into the following rough groups, where group XX is found under `configs/XX/`:
```
model: the sequence-to-sequence model backbone (e.g. a src.models.sequence.SequenceModel)
dataset: the raw dataset (data/target pairs) (e.g. a pytorch Dataset)
loader: how the data is loaded (e.g. a pytorch DataLoader)
encoder: defines a Module that interfaces between data and model backbone
decoder: defines a Module that interfaces between model backbone and targets
task: specifies loss and metrics
```
Default combinations of dataset+loader+encoder+decoder+task are further consolidated into groups called `pipelines`.
-->

## Helpful Tips

### Inspect the Config
- At the beginning of running `train.py`, the full Hydra config is printed. This is very useful for making sure all flags were passed in as intended. Try running `python -m train` and inspecting the full base config.

### Class Instantiation
- Generally, most dictionaries in the config correspond exactly to the arguments passed into a Python class. For example, the configs in `model/`, `dataset/`, `loader/`, `optimizer/`, `scheduler/`, `trainer/` define dictionaries which each instantiate exactly one object (a PyTorch `nn.Module`, `SequenceDataset`, PyTorch [DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), PyTorch optimizer, PyTorch scheduler, and PyTorch Lightning [Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html), respectively).
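
As an illustration (a minimal sketch, not the codebase's actual wiring, and with hypothetical flag values; the real defaults live in `configs/trainer/`), such a dictionary can be unpacked directly into the corresponding constructor:

```
import pytorch_lightning as pl

# Hypothetical flags as they might appear under the `trainer` key of the printed config.
trainer_cfg = {"max_epochs": 100, "accumulate_grad_batches": 1, "gradient_clip_val": 0.0}

# The keys mirror the constructor's keyword arguments, so instantiation is plain unpacking.
trainer = pl.Trainer(**trainer_cfg)
```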

### Registries
- Instantiating objects is controlled by the very useful Hydra [instantiate](https://hydra.cc/docs/advanced/instantiate_objects/overview/) utility.
- In this codebase, instead of defining a `_target_=<path>.<to>.<module>`, we use shorthand names for each desired class (wherever a `_name_` attribute appears). The file `src/utils/registry.py` maps the shorthand names used in these configs to their full class paths.
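
A minimal sketch of the idea, assuming a hypothetical two-entry registry and a hypothetical helper function (the real mapping lives in `src/utils/registry.py` and is more extensive):

```
from hydra.utils import instantiate

# Hypothetical excerpt of the shorthand-to-class-path mapping;
# the real dictionaries are defined in src/utils/registry.py.
optimizer_registry = {"adamw": "torch.optim.AdamW", "sgd": "torch.optim.SGD"}

def instantiate_from_name(registry, config, *args):
    """Swap the `_name_` shorthand for a Hydra `_target_` path, then instantiate."""
    config = dict(config)                    # avoid mutating the caller's config
    target = registry[config.pop("_name_")]  # e.g. "adamw" -> "torch.optim.AdamW"
    return instantiate({"_target_": target, **config}, *args)

# Usage (given some `model`):
#   optimizer = instantiate_from_name(optimizer_registry, {"_name_": "adamw", "lr": 1e-3}, model.parameters())
```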

### Source Code Documentation
- Check the READMEs for the source code. For example, the configs in [configs/model](model) correspond to classes in [src/models](../src/models), and the configs in [configs/dataset](dataset) correspond to classes in [src/dataloaders](../src/dataloaders).

<!--
It is recommended to read the overview in `src/README.md` to fully understand how models, datasets, tasks, and pipelines are put together.
-->


## Example
```
configs/optimizer/adamw.yaml
_name_: adamw
lr: 0.001
weight_decay: 0.00
```

1. When composed into a larger config, this should define a dictionary under the corresponding sub-config name. For example, the config printed by `python -m train optimizer=adamw optimizer.weight_decay=0.1` includes the following dictionary, confirming that the flags were passed in correctly.
```
├── optimizer
│ └── _name_: adamw
│ lr: 0.001
│ weight_decay: 0.1
```

2. The file `src/utils/registry.py` includes an `optimizer` dictionary mapping `adamw: torch.optim.AdamW`.

3. The full optimizer config is equivalent to instantiating `torch.optim.AdamW(lr=0.001, weight_decay=0.1)`.

## Models

The `model/` configs correspond to modules in `src/models/`.
See `model/README.md`.

## Datasets

The `dataset/` configs correspond to modules in `src/dataloaders/`.

## Loader

`loader/` configs are used to instantiate a dataloader such as PyTorch's `torch.utils.data.DataLoader`.
Other configs correspond to extensions of this class, found in the source file `src/dataloaders/base.py`, for example dataloaders that sample the data at different resolutions.
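
Roughly speaking (flag names here are illustrative, not necessarily the exact keys in `loader/default.yaml`), a loader config resolves to something like:

```
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative flags; the real keys are defined by the configs in configs/loader/.
loader_cfg = {"batch_size": 50, "num_workers": 4, "shuffle": True, "pin_memory": True}

# Stand-in dataset; in this codebase the dataset comes from a SequenceDataset in src/dataloaders/.
dataset = TensorDataset(torch.randn(1000, 1024, 1), torch.randint(0, 10, (1000,)))

train_loader = DataLoader(dataset, **loader_cfg)
```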

## Tasks

A task is something like "multiclass classification" or "regression", and defines *how the model interfaces with the data*.
A task defines the loss function and additional metrics, and an optional encoder and decoder.
These configs correspond to modules in `src/tasks/`.

### Encoder/Decoder

The encoder is the interface between the input data and model backbone. It defines how the input data is transformed before being fed to the model.

The decoder is the interface between the model backbone and target data. It defines how the model's outputs are transformed so that the task's loss and metrics can be calculated on them.
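
Conceptually (a simplified sketch; the real implementations in `src/tasks/` handle additional arguments such as state and sequence lengths), the two interfaces look like:

```
import torch.nn as nn

class LinearEncoder(nn.Module):
    """Interface between data and backbone: (B, L, d_data) -> (B, L, d_model)."""
    def __init__(self, d_data, d_model):
        super().__init__()
        self.proj = nn.Linear(d_data, d_model)

    def forward(self, x):
        return self.proj(x)

class PoolDecoder(nn.Module):
    """Example decoder for sequence classification: (B, L, d_model) -> (B, n_classes)."""
    def __init__(self, d_model, n_classes):
        super().__init__()
        self.out = nn.Linear(d_model, n_classes)

    def forward(self, y):
        return self.out(y.mean(dim=1))  # pool over the length dimension, then project
```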


## Optimizer/Scheduler

`optimizer/` and `scheduler/` configs are used to instantiate an optimizer class and scheduler class respectively.


## Pipeline
A pipeline consists of a dataset + loader + encoder + decoder + task (and optionally optimizer+scheduler).
This is sometimes what people refer to as a "task", such as the "CIFAR-10 classification task".
A pipeline fully defines a training scheme; combining a pipeline with a model specifies an end-to-end experiment.

Overall, a pipeline fully defines a training experiment aside from the model backbone. This means *any pipeline* can be flexibly combined with *any* model backbone to define an experiment, regardless of the dimensions of the data and model, e.g. `python -m train pipeline=cifar model=transformer`.

### Example: sCIFAR

```
defaults:
  - /trainer: default
  - /loader: default
  - /dataset: cifar
  - /task: multiclass_classification
  - /optimizer: adamw
  - /scheduler: plateau
train:
  monitor: val/accuracy # Needed for plateau scheduler
  mode: max
encoder: linear
decoder:
  _name_: sequence
  mode: pool
```

1. The `trainer/default.yaml` and `loader/default.yaml` configs specify a basic PyTorch Lightning Trainer and PyTorch DataLoader.

2. The `dataset/cifar.yaml` defines a dataset object that specifies data and target pairs. In this case, the data has shape `(batch size, 1024, 1)` and the target has shape `(batch size,)`, whose entries are class IDs from 0-9.

3. The model is not part of the pipeline; any model can be combined with this pipeline as long as it maps shape `(batch size, 1024, d_input) -> (batch size, 1024, d_output)`.

4. The task consists of an encoder, decoder, loss function, and metrics. The `encoder` interfaces between the input data and model backbone; this example specifies that the data will pass through an `nn.Linear` module mapping the data from `(batch size, 1024, 1) -> (batch size, 1024, d_input)`. The `decoder` will map the model's outputs from `(batch size, 1024, d_output) -> (batch size,)` by pooling over the sequence length and passing through another `nn.Linear`. Finally, the `multiclass_classification` task defines a cross entropy loss and accuracy metric (see the sketch after this list).

5. This pipeline also defines an optimizer and scheduler, which are optional.
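
Putting steps 2-4 together, a hedged shape trace (the backbone is replaced by an identity stand-in, and `mode: pool` is assumed to be mean pooling):

```
import torch
import torch.nn as nn

B, L, n_classes = 32, 1024, 10
d_input = d_output = 128        # backbone width; the identity stand-in requires d_input == d_output

x = torch.randn(B, L, 1)                         # sCIFAR input: one channel per pixel
targets = torch.randint(0, n_classes, (B,))      # class IDs 0-9

encoder = nn.Linear(1, d_input)                  # `encoder: linear`
backbone = nn.Identity()                         # any (B, L, d_input) -> (B, L, d_output) model fits here
decoder = nn.Linear(d_output, n_classes)         # projection half of the `sequence` decoder

h = encoder(x)                                   # (B, L, d_input)
y = backbone(h)                                  # (B, L, d_output)
logits = decoder(y.mean(dim=1))                  # `mode: pool`: pool over length -> (B, n_classes)

loss = nn.CrossEntropyLoss()(logits, targets)    # multiclass_classification: cross entropy
accuracy = (logits.argmax(dim=-1) == targets).float().mean()
```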


## Experiment

An experiment combines every type of config into a complete end-to-end experiment.
Generally, this consists of a pipeline and model, together with training details such as the optimizer and scheduler.
See [experiment/README.md](experiment/).
6 changes: 6 additions & 0 deletions configs/callbacks/progressive_resizing.yaml
@@ -0,0 +1,6 @@
progressive_resizing:
  stage_params:
    - resolution: null
      epochs: null
    - resolution: null
      epochs: null
50 changes: 32 additions & 18 deletions configs/config.yaml
@@ -5,14 +5,14 @@ defaults:
# - model: s4 # Model backbone
# - pipeline: cifar # Specifies collection of configs, equivalent to next 5 lines
# Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order)
# # - loader: torch # Dataloader (e.g. handles batches)
# # - loader: default # Dataloader (e.g. handles batches)
# # - dataset: cifar # Defines the data (x and y pairs)
# # - task: multiclass_classification # Defines loss and metrics
# # - encoder: null # Interface between data and model
# # - decoder: null # Interface between model and targets
- callbacks:
- base
- checkpoint
- callbacks: # Extra pytorch-lightning features
- base
- checkpoint

# Additional arguments used to configure the training loop
# Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer
@@ -27,26 +27,36 @@ train:
test: False # Test after training
debug: False # Special settings to make debugging more convenient
ignore_warnings: False # Disable python warnings
# These control state

# These control state passing between batches
state:
mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ]
chunk_len: null # [ None | int ] chunk length for tbptt (used by TBPTTDataLoader)
overlap_len: null # [ None | int ] overlap length for tbptt (used by TBPTTDataLoader)
n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context
n_context_eval: ${.n_context}
n_context_eval: ${.n_context} # Context at evaluation time
# Convenience keys to allow grouping runs
sweep: null
group: null

benchmark_step: False # Whether to benchmark the step function
benchmark_step_k: 1 # Multiplier for loader.batch_size when benchmarking the step function with larger batch sizes than used for the dataset
benchmark_step_T: 1 # Number of additional repeats to benchmark the step function
checkpoint_path: null # Path to checkpoint file: only used for visualization at the moment
visualizer: 'filters' # Which visualizer to use: [ 'filters' | 'forecasting' ]

ckpt: null # Resume training

disable_dataset: False # Disable dataset loading
validate_at_start: false

pretrained_model_path: null # Path to pretrained model
pretrained_model_strict_load: true # Whether to load the pretrained model even if the model is not compatible
pretrained_model_state_hook: # Hook called on the loaded model's state_dict
_name_: null
post_init_hook: # After initializing model, call method on model
_name_: null

# We primarily use wandb so this is moved to top level for convenience
# Set ~wandb or wandb=null or wandb.mode=disabled to disable logging
layer_decay: # Used for ImageNet finetuning
_name_: null
decay: 0.7

tolerance: # fault tolerance for training on preemptible machines
logdir: ./resume
id: null # must be set to resume training on preemption

# We primarily use wandb so this is moved to top level in the config for convenience
# Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging
# If other loggers are added, it would make sense to put this one level lower under train/ or logger/
wandb:
project: hippo
@@ -61,3 +71,7 @@ wandb:
# prefix: ""
# job_type: "train"
# tags: []

hydra:
run:
dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f}
2 changes: 1 addition & 1 deletion configs/dataset/beethoven.yaml
@@ -7,4 +7,4 @@ quantization: linear
drop_last: true
context_len: null
pad_len: null
__l_max: ${.sample_len}
__l_max: ${.sample_len}
2 changes: 1 addition & 1 deletion configs/dataset/cifar.yaml
@@ -7,4 +7,4 @@ cutout: False
random_erasing: False
val_split: 0.1
seed: 42 # For validation split
__l_max: 1024
# __l_max: 1024
12 changes: 8 additions & 4 deletions configs/dataset/copying.yaml
@@ -2,8 +2,12 @@ _name_: copying
l_noise: 100 # length
l_memorize: 10 # number of tokens to memorize
n_tokens: 10 # alphabet size
variable: False # Randomly distribute memorization tokens throughout sequence instead of frontloading them
n_samples: 50000
variable: false # Randomly distribute memorization tokens throughout sequence instead of frontloading them
n_train: 10000 # Training samples per epoch (random)
n_eval: 1000 # Evaluation samples per epoch (fixed)
one_hot: false
static: false
lag: false
# test_samples: 5000
val_split: 0.1
__l_max: ${eval:${.l_noise} + 2*${.l_memorize}}
# val_split: 0.1
__l_max: null # ${eval:${.l_noise} + 2*${.l_memorize}}
11 changes: 11 additions & 0 deletions configs/dataset/delay.yaml
@@ -0,0 +1,11 @@
_name_: delay
l_seq: 4000 # length of total sequence
n_lag: 1 # number of lags to copy
l_lag: 1000 # length of lag (default to l_seq / n_lag)
dt: 0.00025
freq: 1000.0
seed: 0
static: False # Use a static dataset of size n_train, otherwise always use random data with n_train per epoch
n_train: 16384 # Training samples per epoch (random)
n_eval: 1024 # Evaluation samples per epoch (fixed)
__l_max: ${.l_seq}
7 changes: 7 additions & 0 deletions configs/dataset/ljspeech.yaml
@@ -0,0 +1,7 @@
_name_: ljspeech
bits: 8
sample_len: 161792 # (10.1s @ 16kHz rounded to nearest multiple of 1024)
train_percentage: 0.88
quantization: mu-law
use_text: false
__l_max: ${.sample_len}
10 changes: 10 additions & 0 deletions configs/dataset/qautomusic.yaml
@@ -0,0 +1,10 @@
_name_: qautoaudio
path: music_data
bits: 8
sample_len: 131072
train_percentage: 0.88
quantization: mu-law
drop_last: true
context_len: null
pad_len: null
__l_max: ${.sample_len}
10 changes: 10 additions & 0 deletions configs/dataset/reconstruct.yaml
@@ -0,0 +1,10 @@
_name_: reconstruct
l_seq: 4000 # length of total sequence
l_mem: 1000 # length to reconstruct
dt: 0.001
freq: 100.0
seed: 0
static: False # Use a static dataset of size n_train, otherwise always use random data with n_train per epoch
n_train: 16384 # Training samples per epoch (random)
n_eval: 1024 # Evaluation samples per epoch (fixed)
__l_max: ${.l_seq}
2 changes: 1 addition & 1 deletion configs/dataset/sc.yaml
@@ -2,5 +2,5 @@ _name_: sc
mfcc: False
dropped_rate: 0.
length: 16000
all_classes: false # Use original dataset or 10-way version
all_classes: true # Use original dataset or 10-way version
__l_max: ${.length}
5 changes: 5 additions & 0 deletions configs/dataset/sc10.yaml
@@ -0,0 +1,5 @@
defaults:
- sc

all_classes: false
