diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..c69f656 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,58 @@ +## Changelog + +### 2022-08-03 - [V3.0] + +#### Models and Features +- Updated version of S4 module, including new measures and theory from [[How to Train Your HiPPO](https://arxiv.org/abs/2206.12037)] (https://github.com/HazyResearch/state-spaces/issues/21, https://github.com/HazyResearch/state-spaces/issues/54) +- Complete version of S4D module from [[On the Parameterization and Initialization of Diagonal State Space Models](https://arxiv.org/abs/2206.11893)] +- [State forwarding](src/models/s4/README.md#state-forwarding) (https://github.com/HazyResearch/state-spaces/issues/49, https://github.com/HazyResearch/state-spaces/issues/56) +- Support for S4 variants including DSS and GSS ([documentation](src/models/s4/README.md#other-variants)) + + + +#### Bug fixes and library compatibility issues +- PyTorch 1.11 had a [Dropout bug](https://github.com/pytorch/pytorch/issues/77081) which is now avoided with a custom Dropout implementation (https://github.com/HazyResearch/state-spaces/issues/42, https://github.com/HazyResearch/state-spaces/issues/22) +- Conjugated tensors API change in PyTorch 1.10 (https://github.com/HazyResearch/state-spaces/issues/35) + +#### SaShiMi +- Release of Sashimi+DiffWave model (https://github.com/HazyResearch/state-spaces/issues/46). Can be found at [albertfgu/diffwave-sashimi](https://github.com/albertfgu/diffwave-sashimi) + +#### Generation +- Improved generation script for any models trained using this repository (https://github.com/HazyResearch/state-spaces/issues/38) + +#### Model Checkpoints +- Re-trained SaShiMi models with the latest version of S4 (https://github.com/HazyResearch/state-spaces/issues/37, https://github.com/HazyResearch/state-spaces/issues/32) +- New WikiText-103 checkpoint with generation functionality (https://github.com/HazyResearch/state-spaces/issues/5, https://github.com/HazyResearch/state-spaces/issues/19) + +#### HiPPO +- Release of new [notebook](notebooks/hippo_function_approximation.ipynb) (and equivalent .py [file](src/models/hippo/visualizations.py)) visualizing HiPPO function reconstruction. Includes animations used in HTTYH, the Annotated S4D, and various S4 talks. + +#### Experiments +- Improved configs for Long Range Arena reported in HTTYH and S4D papers +- New datasets and ablation experiments from the S4D paper + +Note that there have been various refactors and miscellaneous changes which may affect results slightly, but results should be close and general trends should hold. Feel free to file an issue for any results which do not match the papers. + +#### Documentation +- Reorganized the [README](README.md) and added much more [documentation](README.md#readmes) for using this codebase + + +### 2022-05-01 - [V2.1] +- Minor updates to S4 modules +- By default, S4 no longer requires installing Pykeops or a custom CUDA kernel. +- New S4D (S4-diagonal) standalone model found at `src/models/sequence/ss/standalone/s4d.py`. Simple variant using diagonal SSMs that recovers S4's performance on most tasks. Can be run with any existing experiment config with the additional flag `model/layer=s4d` on the command line. +- New [LRA configs](#long-range-arena-lra) for updated S4 code, with an average score of ~86 + +### 2022-02-27 - [V2] +Code release for SaShiMi audio model + +### 2022-01-29 - [V1.1] +Added configs for time series datasets from the Informer paper (https://github.com/HazyResearch/state-spaces/issues/4) + +### 2021-11-18 - [V1] +First release of this repository containing the S4 module and configs to reproduce sCIFAR, Speech Commands, Long Range Arena, and WikiText-103 results + diff --git a/configs/README.md b/configs/README.md new file mode 100644 index 0000000..e08b167 --- /dev/null +++ b/configs/README.md @@ -0,0 +1,154 @@ + +``` +config.yaml Main config +model/ Instantiates a model backbone (see src/models/) +dataset/ Instantiates a datamodule (see src/dataloaders/) +loader/ Defines a PyTorch DataLoader +task/ Defines loss, metrics, optional encoder/decoder (see src/tasks/) +pipeline/ Combination of dataset/loader/task for convenience +optimizer/ Instantiates an optimizer +scheduler/ Instantiates a learning rate scheduler +trainer/ Flags for the PyTorch Lightning Trainer class +callbacks/ Misc options for the Trainer (see src/callbacks/) +experiment/ Defines a full experiment (combination of all of the above configs) +generate/ Additional flags used by the generate.py script +``` + +This README provides a brief overview of the organization of this configs folder. These configs are composed to define a full Hydra config for every experiment. + +## Overview +The main config is found at `configs/config.yaml`, which is an example experiment for Permuted MNIST. Different combinations of flags can be overridden to define alternate experiments. The config files in this folder define useful combinations of flags that can be composed. Examples of full configs defining end-to-end experiments can be found in [experiment/](experiment/). + +Flags can also be passed on the command line. + + + +## Helpful Tips + +### Inspect the Config +- At the beginning of running `train.py`, the full Hydra config is printed. This is very useful for making sure all flags were passed in as intended. Try running `python -m train` and inspecting the full base config. + +### Class Instantiation +- Generally, most dictionaries in the config correspond exactly to the arguments passed into a Python class. For example, the configs in `model/`, `dataset/`, `loader/`, `optimizer/`, `scheduler/`, `trainer/` define dictionaries which each instantiate exactly one object (a PyTorch `nn.Module`, `SequenceDataset`, PyTorch [DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), PyTorch optimizer, PyTorch scheduler, and [PyTorch LightningTrainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html)). + +### Registries +- Instantiating objects is controlled by the very useful Hydra [instantiate](https://hydra.cc/docs/advanced/instantiate_objects/overview/) utility. +- In this codebase, instead of defining a `_target_=..`, we use shorthand names for each desired class (wherever a `_name_` attribute appears). The file `src/utils/registry.py` lists these shorthand names found in these configs to the full class path. + +### Source Code Documentation +- Check READMEs for the source code. For example, the configs in [configs/model](model) correspond to classes in [src/models](../src/models), the configs in [configs/dataset](dataset) correspond to classes in [src/dataloaders](../src/dataloaders). + + + + +## Example +``` +configs/optimizer/adamw.yaml + +_name_: adamw +lr: 0.001 +weight_decay: 0.00 +``` + +1. When composed into a larger config, this should define a dictionary under the corresponding sub-config name. For example, the config printed by `python -m train optimizer=adamw optimizer.weight_decay=0.1` includes the following dictionary, confirming that the flags were passed in correctly. +``` +├── optimizer +│ └── _name_: adamw +│ lr: 0.001 +│ weight_decay: 0.1 +``` + +2. The file `src/utils/registry.py` includes an `optimizer` dictionary mapping `adamw: torch.optim.AdamW`. + +3. The full optimizer config is equivalent to instantiating `torch.optim.AdamW(lr=0.001, weight_decay=0.1)` + +## Models + +The `model/` configs correspond to modules in `src/models/`. +See `model/README.md`. + +## Datasets + +The `dataset/` configs correspond to modules in `src/dataloaders/`. + +## Loader + +`loader/` configs are used to instantiate a dataloader such as PyTorch's `torch.utils.data.DataLoader`. +Other configs correspond to extensions of this found in the source file `src/dataloaders/base.py`, for example dataloaders that allow sampling the data at different resolutions. + +## Tasks + +A task is something like "multiclass classification" or "regression", and defines *how the model interfaces with the data*. +A task defines the loss function and additional metrics, and an optional encoder and decoder. +These configs correspond to modules in `src/tasks/`. + +### Encoder/Decoder + +The encoder is the interface between the input data and model backbone. It defines how the input data is transformed before being fed to the model. + +The decoder is the interface between the model backbone and target data. It defines how the model's outputs are transformed so that the task's loss and metrics can be calculated on it. + + +## Optimizer/Scheduler + +`optimizer/` and `scheduler/` configs are used to instantiate an optimizer class and scheduler class respectively. + + +## Pipeline +A pipeline consists of a dataset + loader + encoder + decoder + task (and optionally optimizer+scheduler). +This is sometimes what people refer to as a "task", such as the "CIFAR-10 classification task". +A pipeline fully defines a training scheme; combining a pipeline with a model specifies an end-to-end experiment. + +Overally, a pipeline fully defines a training experiment aside from the model backbone. This means *any pipeline* can be flexibly combined with *any* model backbone to define an experiment, regardless of the dimensions of the data and model, e.g.: `python -m train pipeline=cifar model=transformer`. + +### Example: sCIFAR + +``` +defaults: + - /trainer: default + - /loader: default + - /dataset: cifar + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + +train: + monitor: val/accuracy # Needed for plateau scheduler + mode: max + +encoder: linear + +decoder: + _name_: sequence + mode: pool +``` + +1. The `trainer/default.yaml` and `loader/default.yaml` specify a basic PyTorch Lightning trainer and PyTorch DataLoader + +2. The `dataset/cifar.yaml` defines a dataset object that specifies data and target pairs. In this case, the data has shape `(batch size, 1024, 1)` and the target has shape `(batch size,)` which are class IDs from 0-9. + +3. The model is not part of the pipeline; any model can be combined with this pipeline as long as it maps shape `(batch size, 1024, d_input) -> (batch size, 1024, d_output)` + +4. The task consists of an encoder, decoder, loss function, and metrics. The `encoder` interfaces between the input data and model backbone; this example specifies that the data will pass through an `nn.Linear` mapping dimension the data from `(batch size, 1024, 1) -> (batch size, 1024, d_input)`. The `decoder` will map the model's outputs from `(batch size, 1024, d_output) -> (batch size,)` by pooling over the sequence length and passing through another `nn.Linear`. Finally, the `multiclass_classification` task defines a cross entropy loss and Accuracy metric. + +5. This pipeline also defines a target optimizer and scheduler, which are optional. + + +## Experiment + +An experiment combines every type of config into a complete end-to-end experiment. +Generally, this consists of a pipeline and model, together with training details such as the optimizer and scheduler. +See [experiment/README.md](experiment/). diff --git a/configs/callbacks/progressive_resizing.yaml b/configs/callbacks/progressive_resizing.yaml new file mode 100644 index 0000000..6e06c6e --- /dev/null +++ b/configs/callbacks/progressive_resizing.yaml @@ -0,0 +1,6 @@ +progressive_resizing: + stage_params: + - resolution: null + epochs: null + - resolution: null + epochs: null diff --git a/configs/config.yaml b/configs/config.yaml index a1f615f..2da40ca 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -5,14 +5,14 @@ defaults: # - model: s4 # Model backbone # - pipeline: cifar # Specifies collection of configs, equivalent to next 5 lines # Pipelines should specify /loader, /dataset, /task, /encoder, /decoder (ideally in that order) - # # - loader: torch # Dataloader (e.g. handles batches) + # # - loader: default # Dataloader (e.g. handles batches) # # - dataset: cifar # Defines the data (x and y pairs) # # - task: multiclass_classification # Defines loss and metrics # # - encoder: null # Interface between data and model # # - decoder: null # Interface between model and targets - - callbacks: - - base - - checkpoint + - callbacks: # Extra pytorch-lightning features + - base + - checkpoint # Additional arguments used to configure the training loop # Most of these set combinations of options in the PL trainer, add callbacks, or add features to the optimizer @@ -27,26 +27,36 @@ train: test: False # Test after training debug: False # Special settings to make debugging more convenient ignore_warnings: False # Disable python warnings - # These control state + + # These control state passing between batches state: mode: null # [ None | 'none' | 'reset' | 'bptt' | 'tbptt' ] - chunk_len: null # [ None | int ] chunk length for tbptt (used by TBPTTDataLoader) - overlap_len: null # [ None | int ] overlap length for tbptt (used by TBPTTDataLoader) n_context: 0 # How many steps to use as memory context. Must be >= 0 or None (null), meaning infinite context - n_context_eval: ${.n_context} + n_context_eval: ${.n_context} # Context at evaluation time # Convenience keys to allow grouping runs - sweep: null - group: null - - benchmark_step: False # Whether to benchmark the step function - benchmark_step_k: 1 # Multipler for loader.batch_size when benchmarking step function with large batch sizes than the dataset - benchmark_step_T: 1 # Number of additional repeats to benchmark the step function - checkpoint_path: null # Path to checkpoint file: only used for visualization at the moment - visualizer: 'filters' # Which visualizer to use: [ 'filters' | 'forecasting' ] + + ckpt: null # Resume training + disable_dataset: False # Disable dataset loading + validate_at_start: false + + pretrained_model_path: null # Path to pretrained model + pretrained_model_strict_load: true # Whether to load the pretrained model even if the model is not compatible + pretrained_model_state_hook: # Hook called on the loaded model's state_dict + _name_: null + post_init_hook: # After initializing model, call method on model + _name_: null -# We primarily use wandb so this is moved to top level for convenience -# Set ~wandb or wandb=null or wandb.mode=disabled to disable logging + layer_decay: # Used for ImageNet finetuning + _name_: null + decay: 0.7 + +tolerance: # fault tolerance for training on preemptible machines + logdir: ./resume + id: null # must be set to resume training on preemption + +# We primarily use wandb so this is moved to top level in the config for convenience +# Set `~wandb` or `wandb=null` or `wandb.mode=disabled` to disable logging # If other loggers are added, it would make sense to put this one level lower under train/ or logger/ wandb: project: hippo @@ -61,3 +71,7 @@ wandb: # prefix: "" # job_type: "train" # tags: [] + +hydra: + run: + dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f} diff --git a/configs/dataset/beethoven.yaml b/configs/dataset/beethoven.yaml index 409147f..0914b63 100644 --- a/configs/dataset/beethoven.yaml +++ b/configs/dataset/beethoven.yaml @@ -7,4 +7,4 @@ quantization: linear drop_last: true context_len: null pad_len: null -__l_max: ${.sample_len} \ No newline at end of file +__l_max: ${.sample_len} diff --git a/configs/dataset/cifar.yaml b/configs/dataset/cifar.yaml index 3bbc3a7..1911d5e 100644 --- a/configs/dataset/cifar.yaml +++ b/configs/dataset/cifar.yaml @@ -7,4 +7,4 @@ cutout: False random_erasing: False val_split: 0.1 seed: 42 # For validation split -__l_max: 1024 +# __l_max: 1024 diff --git a/configs/dataset/copying.yaml b/configs/dataset/copying.yaml index b3b029a..272917d 100644 --- a/configs/dataset/copying.yaml +++ b/configs/dataset/copying.yaml @@ -2,8 +2,12 @@ _name_: copying l_noise: 100 # length l_memorize: 10 # number of tokens to memorize n_tokens: 10 # alphabet size -variable: False # Randomly distribute memorization tokens throughout sequence instead of frontloading them -n_samples: 50000 +variable: false # Randomly distribute memorization tokens throughout sequence instead of frontloading them +n_train: 10000 # Training samples per epoch (random) +n_eval: 1000 # Evaluation samples per epoch (fixed) +one_hot: false +static: false +lag: false # test_samples: 5000 -val_split: 0.1 -__l_max: ${eval:${.l_noise} + 2*${.l_memorize}} +# val_split: 0.1 +__l_max: null # ${eval:${.l_noise} + 2*${.l_memorize}} diff --git a/configs/dataset/delay.yaml b/configs/dataset/delay.yaml new file mode 100644 index 0000000..7d43a5b --- /dev/null +++ b/configs/dataset/delay.yaml @@ -0,0 +1,11 @@ +_name_: delay +l_seq: 4000 # length of total sequence +n_lag: 1 # number of lags to copy +l_lag: 1000 # length of lag (default to l_seq / n_lag) +dt: 0.00025 +freq: 1000.0 +seed: 0 +static: False # Use a static dataset of size n_train, otherwise always use random data with n_train per epoch +n_train: 16384 # Training samples per epoch (random) +n_eval: 1024 # Evaluation samples per epoch (fixed) +__l_max: ${.l_seq} diff --git a/configs/dataset/ljspeech.yaml b/configs/dataset/ljspeech.yaml new file mode 100644 index 0000000..a32c7d9 --- /dev/null +++ b/configs/dataset/ljspeech.yaml @@ -0,0 +1,7 @@ +_name_: ljspeech +bits: 8 +sample_len: 161792 # (10.1s @ 16kHz rounded to nearest multiple of 1024) +train_percentage: 0.88 +quantization: mu-law +use_text: false +__l_max: ${.sample_len} \ No newline at end of file diff --git a/configs/dataset/qautomusic.yaml b/configs/dataset/qautomusic.yaml new file mode 100644 index 0000000..551bf3c --- /dev/null +++ b/configs/dataset/qautomusic.yaml @@ -0,0 +1,10 @@ +_name_: qautoaudio +path: music_data +bits: 8 +sample_len: 131072 +train_percentage: 0.88 +quantization: mu-law +drop_last: true +context_len: null +pad_len: null +__l_max: ${.sample_len} diff --git a/configs/dataset/reconstruct.yaml b/configs/dataset/reconstruct.yaml new file mode 100644 index 0000000..11994ca --- /dev/null +++ b/configs/dataset/reconstruct.yaml @@ -0,0 +1,10 @@ +_name_: reconstruct +l_seq: 4000 # length of total sequence +l_mem: 1000 # length to reconstruct +dt: 0.001 +freq: 100.0 +seed: 0 +static: False # Use a static dataset of size n_train, otherwise always use random data with n_train per epoch +n_train: 16384 # Training samples per epoch (random) +n_eval: 1024 # Evaluation samples per epoch (fixed) +__l_max: ${.l_seq} diff --git a/configs/dataset/sc.yaml b/configs/dataset/sc.yaml index 542bd6b..26b4f1d 100644 --- a/configs/dataset/sc.yaml +++ b/configs/dataset/sc.yaml @@ -2,5 +2,5 @@ _name_: sc mfcc: False dropped_rate: 0. length: 16000 -all_classes: false # Use original dataset or 10-way version +all_classes: true # Use original dataset or 10-way version __l_max: ${.length} diff --git a/configs/dataset/sc10.yaml b/configs/dataset/sc10.yaml new file mode 100644 index 0000000..839089c --- /dev/null +++ b/configs/dataset/sc10.yaml @@ -0,0 +1,5 @@ +defaults: + - sc + +all_classes: false + diff --git a/configs/dataset/weather.yaml b/configs/dataset/weather.yaml new file mode 100644 index 0000000..96386ec --- /dev/null +++ b/configs/dataset/weather.yaml @@ -0,0 +1,15 @@ +_name_: weather +size: + - 384 + - 96 + - 96 +features: 'S' +target: 'WetBulbCelsius' +variant: 0 +scale: True +inverse: False +timeenc: 0 +eval_stamp: False +eval_mask: False +# freq: 'h' +__l_max: ${eval:${.size.0}+${.size.2}} diff --git a/configs/dataset/youtubemix.yaml b/configs/dataset/youtubemix.yaml index e914450..41f5e93 100644 --- a/configs/dataset/youtubemix.yaml +++ b/configs/dataset/youtubemix.yaml @@ -7,4 +7,4 @@ quantization: mu-law drop_last: true context_len: null pad_len: null -__l_max: ${.sample_len} \ No newline at end of file +__l_max: ${.sample_len} diff --git a/configs/experiment/README.md b/configs/experiment/README.md new file mode 100644 index 0000000..e653dfd --- /dev/null +++ b/configs/experiment/README.md @@ -0,0 +1,17 @@ +``` +audio/ Audio datasets (all Sashimi experiments are here) +bidmc/ BIDMC time series regression datasets +cifar/ CIFAR-10, both sequential and 2d +convnext/ ImageNet + ConvNext variants (S4ND paper) +forecasting/ Monash, ARIMA synthetics, datasets from the Informer paper +lra/ Long Range Arena datasets +old/ Deprecated experiment configs (e.g. from original S4 paper that have been improved) +progres/ Progressive Resizing (S4ND paper) +sc/ Speech Commands variants +segmentation/ Segmentation experiments (preliminary) +synthetic/ Synthetic experiments (Copying, Delay, Reconstruct) +ts/ Time Series (EEG, impedance) +vision/ Assorted other image/video datasets (CelebA, ImageNet, HMDB51) +vit/ ViT experiments +wt103/ WikiText-103 experiments +``` diff --git a/configs/experiment/audio/samplernn-beethoven.yaml b/configs/experiment/audio/samplernn-beethoven.yaml new file mode 100644 index 0000000..6caed13 --- /dev/null +++ b/configs/experiment/audio/samplernn-beethoven.yaml @@ -0,0 +1,45 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: tbptt + - /dataset: beethoven + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: samplernn + +model: + bits: 8 + quantization: linear + n_rnn: 1 + frame_sizes: + - 8 + - 2 + - 2 + +train: + monitor: val/loss # Needed for plateau scheduler + mode: min + state: + mode: tbptt + +loader: + chunk_len: 1024 + overlap_len: 32 # this is model dependent (product of model.frame_sizes here) + batch_size: 128 + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + + +encoder: id +decoder: id + +trainer: + gradient_clip_val: 1.0 + gradient_clip_algorithm: value diff --git a/configs/experiment/audio/samplernn-qautomusic.yaml b/configs/experiment/audio/samplernn-qautomusic.yaml new file mode 100644 index 0000000..702952d --- /dev/null +++ b/configs/experiment/audio/samplernn-qautomusic.yaml @@ -0,0 +1,31 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: tbptt + - /dataset: qautomusic + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: baseline/samplernn + +train: + monitor: val/loss # Needed for plateau scheduler + mode: min + state: + mode: tbptt + +loader: + chunk_len: 1024 + overlap_len: 64 # this is model dependent + batch_size: 128 + +task: + metrics: bpb + +encoder: id +decoder: id + + +trainer: + gradient_clip_val: 1.0 + gradient_clip_algorithm: value diff --git a/configs/experiment/audio/samplernn-sc09.yaml b/configs/experiment/audio/samplernn-sc09.yaml new file mode 100644 index 0000000..25a112c --- /dev/null +++ b/configs/experiment/audio/samplernn-sc09.yaml @@ -0,0 +1,45 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: tbptt + - /dataset: sc09 + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: baseline/samplernn + +model: + bits: 8 + quantization: mu-law + n_rnn: 1 + frame_sizes: + - 8 + - 2 + - 2 + +train: + monitor: val/loss # Needed for plateau scheduler + mode: min + state: + mode: tbptt + +loader: + chunk_len: 1024 + overlap_len: 32 # this is model dependent (product of model.frame_sizes here) + batch_size: 128 + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + +encoder: id +decoder: id + + +trainer: + gradient_clip_val: 1.0 + gradient_clip_algorithm: value diff --git a/configs/experiment/audio/samplernn-scg.yaml b/configs/experiment/audio/samplernn-scg.yaml new file mode 100644 index 0000000..e9d95cf --- /dev/null +++ b/configs/experiment/audio/samplernn-scg.yaml @@ -0,0 +1,38 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: tbptt + - /dataset: scg + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: baseline/samplernn + +dataset: + discrete_input: true + +model: + bits: 8 + quantization: mu-law + +train: + monitor: val/loss # Needed for plateau scheduler + mode: min + state: + mode: tbptt + +loader: + chunk_len: 1024 + overlap_len: 64 # this is model dependent + batch_size: 128 + +task: + metrics: bpb + +encoder: id +decoder: id + + +trainer: + gradient_clip_val: 1.0 + gradient_clip_algorithm: value diff --git a/configs/experiment/audio/samplernn-youtubemix.yaml b/configs/experiment/audio/samplernn-youtubemix.yaml new file mode 100644 index 0000000..e943006 --- /dev/null +++ b/configs/experiment/audio/samplernn-youtubemix.yaml @@ -0,0 +1,47 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: tbptt + - /dataset: youtubemix + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: samplernn + + +model: + bits: 8 + quantization: mu-law + n_rnn: 1 + frame_sizes: + - 8 + - 2 + - 2 + +train: + monitor: val/loss # Needed for plateau scheduler + mode: min + state: + mode: tbptt + +loader: + chunk_len: 1024 + overlap_len: 32 # this is model dependent (product of model.frame_sizes here) + batch_size: 128 + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + + +encoder: id +decoder: id + + +trainer: + gradient_clip_val: 1.0 + gradient_clip_algorithm: value diff --git a/configs/experiment/audio/sashimi-beethoven.yaml b/configs/experiment/audio/sashimi-beethoven.yaml new file mode 100644 index 0000000..efe76e7 --- /dev/null +++ b/configs/experiment/audio/sashimi-beethoven.yaml @@ -0,0 +1,47 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: default + - /dataset: beethoven + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: sashimi + +model: + n_layers: 8 + pool: + - 4 + - 4 + dropout: 0.0 + prenorm: True + +train: + monitor: val/loss + mode: min + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + +encoder: embedding + +decoder: + _name_: sequence + mode: last + +loader: + batch_size: 1 + +trainer: + max_epochs: 1000 + +optimizer: + lr: 0.004 + +scheduler: + patience: 20 diff --git a/configs/experiment/audio/sashimi-sc09.yaml b/configs/experiment/audio/sashimi-sc09.yaml new file mode 100644 index 0000000..71a619b --- /dev/null +++ b/configs/experiment/audio/sashimi-sc09.yaml @@ -0,0 +1,46 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: default + - /dataset: sc09 + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: sashimi + +model: + n_layers: 8 + expand: 2 + ff: 2 + pool: + - 4 + - 4 + dropout: 0.0 + prenorm: True + + layer: + n_ssm: 1 + +train: + monitor: val/loss + mode: min + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + +encoder: embedding + +decoder: + _name_: sequence + mode: last + +loader: + batch_size: 32 + +scheduler: + patience: 20 diff --git a/configs/experiment/audio/sashimi-youtubemix.yaml b/configs/experiment/audio/sashimi-youtubemix.yaml new file mode 100644 index 0000000..cb0baba --- /dev/null +++ b/configs/experiment/audio/sashimi-youtubemix.yaml @@ -0,0 +1,47 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: default + - /dataset: youtubemix + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: sashimi + +model: + n_layers: 8 + pool: + - 4 + - 4 + dropout: 0.0 + prenorm: True + +train: + monitor: val/loss + mode: min + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + +encoder: embedding + +decoder: + _name_: sequence + mode: last + +loader: + batch_size: 1 + +trainer: + max_epochs: 1000 + +optimizer: + lr: 0.004 + +scheduler: + patience: 20 diff --git a/configs/experiment/audio/wavenet-beethoven.yaml b/configs/experiment/audio/wavenet-beethoven.yaml new file mode 100644 index 0000000..dd886d8 --- /dev/null +++ b/configs/experiment/audio/wavenet-beethoven.yaml @@ -0,0 +1,38 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: default + - /dataset: beethoven + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: baseline/wavenet + +train: + monitor: val/loss + mode: min + +dataset: + pad_len: 4093 + +model: + skip_channels: 1024 + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + +encoder: embedding +decoder: + _name_: sequence + mode: last + +scheduler: + patience: 5 + +loader: + batch_size: 1 diff --git a/configs/experiment/audio/wavenet-qautomusic.yaml b/configs/experiment/audio/wavenet-qautomusic.yaml new file mode 100644 index 0000000..8f8ec3b --- /dev/null +++ b/configs/experiment/audio/wavenet-qautomusic.yaml @@ -0,0 +1,32 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: default + - /dataset: qautomusic + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: baseline/wavenet + +train: + monitor: val/loss + mode: min + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + +encoder: embedding +decoder: + _name_: sequence + mode: last + +scheduler: + patience: 5 + +loader: + batch_size: 1 diff --git a/configs/experiment/audio/wavenet-sc09.yaml b/configs/experiment/audio/wavenet-sc09.yaml new file mode 100644 index 0000000..2c877c7 --- /dev/null +++ b/configs/experiment/audio/wavenet-sc09.yaml @@ -0,0 +1,39 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: default + - /dataset: sc09 + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: baseline/wavenet + +train: + monitor: val/loss + mode: min + +dataset: + quantization: mu-law + pad_len: 4093 + +model: + skip_channels: 1024 + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + +encoder: embedding +decoder: + _name_: sequence + mode: last + +scheduler: + patience: 5 + +loader: + batch_size: 8 diff --git a/configs/experiment/audio/wavenet-youtubemix.yaml b/configs/experiment/audio/wavenet-youtubemix.yaml new file mode 100644 index 0000000..9ecc309 --- /dev/null +++ b/configs/experiment/audio/wavenet-youtubemix.yaml @@ -0,0 +1,38 @@ +# @package _global_ +defaults: + - /trainer: default + - /loader: default + - /dataset: youtubemix + - /task: multiclass_classification + - /optimizer: adamw + - /scheduler: plateau + - /model: baseline/wavenet + +train: + monitor: val/loss + mode: min + +dataset: + pad_len: 4093 + +model: + skip_channels: 1024 + +task: + metrics: + - bpb + - accuracy + - accuracy@3 + - accuracy@5 + - accuracy@10 + +encoder: embedding +decoder: + _name_: sequence + mode: last + +scheduler: + patience: 5 + +loader: + batch_size: 1 diff --git a/configs/experiment/base.yaml b/configs/experiment/base.yaml index f9f4846..6ef496d 100644 --- a/configs/experiment/base.yaml +++ b/configs/experiment/base.yaml @@ -1,5 +1,6 @@ # @package _global_ defaults: - /pipeline: mnist - - /model: base -# This file is a bare bones config for an experiment, consisting of a pipeline and model backbone + - /model: s4 + +# This file is a bare bones config for an experiment for illustration, consisting of a pipeline and model backbone diff --git a/configs/experiment/bidmc/ckconv-bidmc.yaml b/configs/experiment/bidmc/ckconv-bidmc.yaml new file mode 100644 index 0000000..03d87b2 --- /dev/null +++ b/configs/experiment/bidmc/ckconv-bidmc.yaml @@ -0,0 +1,34 @@ +# @package _global_ +defaults: + - /pipeline: adding + - /model: baseline/ckconv + - override /dataset: bidmc + - override /scheduler: multistep + +dataset: + target: SpO2 # 'RR' | 'HR' | 'SpO2' + +model: + d_input: 2 + d_output: 1 + dropout: 0.0 + +encoder: id +decoder: id + +loader: + batch_size: 32 + +optimizer: + lr: 0.01 + weight_decay: 0.00 + +trainer: + max_epochs: 500 + +scheduler: + milestones: [100,200,300,400,500] + gamma: 0.5 + +train: + seed: 1112 diff --git a/configs/experiment/bidmc/resnet-bidmc.yaml b/configs/experiment/bidmc/resnet-bidmc.yaml new file mode 100644 index 0000000..745ddfb --- /dev/null +++ b/configs/experiment/bidmc/resnet-bidmc.yaml @@ -0,0 +1,35 @@ +# @package _global_ +defaults: + - /pipeline: adding + - /model: nonaka/resnet + - override /dataset: bidmc + - override /scheduler: timm_cosine + +dataset: + target: SpO2 # 'RR' | 'HR' | 'SpO2' + +model: + input_channels: 2 + num_classes: 1 + +encoder: id +decoder: id + +loader: + batch_size: 32 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +trainer: + max_epochs: 500 + +scheduler: + # milestones: [100,200,300,400,500] + # gamma: 0.5 + t_initial: ${trainer.max_epochs} + warmup_t: 5 + +train: + seed: 1112 diff --git a/configs/experiment/bidmc/s4-bidmc-ablation.yaml b/configs/experiment/bidmc/s4-bidmc-ablation.yaml new file mode 100644 index 0000000..b82060c --- /dev/null +++ b/configs/experiment/bidmc/s4-bidmc-ablation.yaml @@ -0,0 +1,42 @@ +# @package _global_ +defaults: + - /pipeline: adding + - /model: s4 + - override /dataset: bidmc + - override /scheduler: multistep + +dataset: + target: RR # 'RR' | 'HR' | 'SpO2' + +model: + dropout: 0.0 + n_layers: 4 + d_model: 128 + prenorm: true + layer: + rank: 1 + measure: legs + deterministic: false + d_state: 256 + bidirectional: true + postact: null + +decoder: + mode: pool + +loader: + batch_size: 32 + +optimizer: + lr: 0.004 + weight_decay: 0.01 + +trainer: + max_epochs: 200 + +scheduler: + milestones: [50, 100, 150, 200] + gamma: 0.5 + +train: + seed: 1111 diff --git a/configs/experiment/bidmc/s4-bidmc.yaml b/configs/experiment/bidmc/s4-bidmc.yaml new file mode 100644 index 0000000..af82617 --- /dev/null +++ b/configs/experiment/bidmc/s4-bidmc.yaml @@ -0,0 +1,44 @@ +# @package _global_ +defaults: + - /pipeline: adding + - /model: s4 + - override /dataset: bidmc + - override /scheduler: multistep + +dataset: + target: RR # 'RR' | 'HR' | 'SpO2' + +model: + dropout: 0.0 + n_layers: 6 + d_model: 128 + prenorm: true + layer: + rank: 1 + measure: legs + deterministic: false + d_state: 256 + lr: 0.001 + bidirectional: true + postact: glu + n_ssm: 2 + +decoder: + mode: pool + +loader: + batch_size: 32 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +trainer: + max_epochs: 500 + +scheduler: + milestones: [100,200,300,400,500] + gamma: 0.5 + +train: + seed: 1112 diff --git a/configs/experiment/cifar/cnn-cifar-2d.yaml b/configs/experiment/cifar/cnn-cifar-2d.yaml new file mode 100644 index 0000000..0a0d601 --- /dev/null +++ b/configs/experiment/cifar/cnn-cifar-2d.yaml @@ -0,0 +1,35 @@ +# @package _global_ +defaults: + - /pipeline: cifar + - /model: s4 + - override /model/layer: conv2d + - override /scheduler: cosine_warmup + +dataset: + permute: 2d + augment: true + +model: + dropout: 0.1 + n_layers: 6 + d_model: 512 + prenorm: true + tie_dropout: true + +loader: + batch_size: 50 + eval_resolutions: [1, 2] + img_size: 32 + +optimizer: + lr: 0.01 + weight_decay: 0.03 + +trainer: + max_epochs: 100 + +scheduler: + num_training_steps: 100000 + +train: + seed: 2222 diff --git a/configs/experiment/cifar/resnet-cifar.yaml b/configs/experiment/cifar/resnet-cifar.yaml new file mode 100644 index 0000000..0f5bc69 --- /dev/null +++ b/configs/experiment/cifar/resnet-cifar.yaml @@ -0,0 +1,28 @@ +# @package _global_ +defaults: + - /pipeline: cifar + - /model: baseline/resnet2d + +model: + variant: resnet18 + +encoder: null +decoder: nd + +dataset: + permute: 2d + +train: + seed: 1111 + +optimizer: + lr: 0.01 + +loader: + batch_size: 50 + +trainer: + max_epochs: 100 + +scheduler: + patience: 10 diff --git a/configs/experiment/cifar/s4-cifar-ablation.yaml b/configs/experiment/cifar/s4-cifar-ablation.yaml new file mode 100644 index 0000000..f4ad026 --- /dev/null +++ b/configs/experiment/cifar/s4-cifar-ablation.yaml @@ -0,0 +1,35 @@ +# @package _global_ +defaults: + - /pipeline: cifar + - /model: s4 + - override /scheduler: cosine_warmup + +model: + dropout: 0.1 + tie_dropout: true + n_layers: 4 + d_model: 128 + prenorm: false + layer: + rank: 1 + measure: legs + deterministic: false + d_state: 64 + bidirectional: true + postact: null + +loader: + batch_size: 50 + +optimizer: + lr: 0.01 + weight_decay: 0.01 + +trainer: + max_epochs: 100 + +scheduler: + num_training_steps: 100000 + +train: + seed: 1111 diff --git a/configs/experiment/cifar/s4-cifar.yaml b/configs/experiment/cifar/s4-cifar.yaml new file mode 100644 index 0000000..5f80cd8 --- /dev/null +++ b/configs/experiment/cifar/s4-cifar.yaml @@ -0,0 +1,32 @@ +# @package _global_ +defaults: + - /pipeline: cifar + - /model: s4 + - override /scheduler: cosine_warmup + +model: + dropout: 0.1 + tie_dropout: true + n_layers: 6 + d_model: 512 + prenorm: false + layer: + bidirectional: true + postact: glu + n_ssm: 2 + +loader: + batch_size: 50 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +trainer: + max_epochs: 200 + +scheduler: + num_training_steps: 200000 + +train: + seed: 2222 diff --git a/configs/experiment/cifar/s4d-minimal-cifar.yaml b/configs/experiment/cifar/s4d-minimal-cifar.yaml new file mode 100644 index 0000000..ee84c9e --- /dev/null +++ b/configs/experiment/cifar/s4d-minimal-cifar.yaml @@ -0,0 +1,35 @@ +# @package _global_ +defaults: + - /pipeline: cifar + - /model: s4 + - override /model/layer: s4d_minimal + - override /scheduler: cosine_warmup + +model: + dropout: 0.1 + tie_dropout: true + n_layers: 4 + d_model: 128 + prenorm: false + layer: + # scaling: linear + d_state: 64 + lr: 0.001 + # postact: glu + # bidirectional: false + +loader: + batch_size: 50 + +optimizer: + lr: 0.01 + weight_decay: 0.01 + +trainer: + max_epochs: 100 + +scheduler: + num_training_steps: 100000 + +train: + seed: 1111 diff --git a/configs/experiment/forecasting/s4-informer-ecl.yaml b/configs/experiment/forecasting/s4-informer-ecl.yaml new file mode 100644 index 0000000..e56e9d2 --- /dev/null +++ b/configs/experiment/forecasting/s4-informer-ecl.yaml @@ -0,0 +1,19 @@ +# @package _global_ +defaults: + - /pipeline: informer + - /model: s4 + - override /dataset: ecl + +trainer: + max_epochs: 10 + +loader: + batch_size: 50 + +model: + dropout: 0.25 + n_layers: 2 + d_model: 128 + +optimizer: + lr: 0.01 diff --git a/configs/experiment/forecasting/s4-informer-etth.yaml b/configs/experiment/forecasting/s4-informer-etth.yaml new file mode 100644 index 0000000..96527dd --- /dev/null +++ b/configs/experiment/forecasting/s4-informer-etth.yaml @@ -0,0 +1,19 @@ +# @package _global_ +defaults: + - /pipeline: informer + - /model: s4 + - override /dataset: etth + +trainer: + max_epochs: 10 + +loader: + batch_size: 50 + +model: + dropout: 0.25 + n_layers: 2 + d_model: 128 + +optimizer: + lr: 0.01 diff --git a/configs/experiment/forecasting/s4-informer-ettm.yaml b/configs/experiment/forecasting/s4-informer-ettm.yaml new file mode 100644 index 0000000..3b48ccc --- /dev/null +++ b/configs/experiment/forecasting/s4-informer-ettm.yaml @@ -0,0 +1,19 @@ +# @package _global_ +defaults: + - /pipeline: informer + - /model: s4 + - override /dataset: ettm + +trainer: + max_epochs: 10 + +loader: + batch_size: 50 + +model: + dropout: 0.25 + n_layers: 2 + d_model: 128 + +optimizer: + lr: 0.01 diff --git a/configs/experiment/forecasting/s4-informer-weather.yaml b/configs/experiment/forecasting/s4-informer-weather.yaml new file mode 100644 index 0000000..3b60a25 --- /dev/null +++ b/configs/experiment/forecasting/s4-informer-weather.yaml @@ -0,0 +1,19 @@ +# @package _global_ +defaults: + - /pipeline: informer + - /model: s4 + - override /dataset: weather + +trainer: + max_epochs: 10 + +loader: + batch_size: 50 + +model: + dropout: 0.25 + n_layers: 2 + d_model: 128 + +optimizer: + lr: 0.01 diff --git a/configs/experiment/lm/s4-wt103.yaml b/configs/experiment/lm/s4-wt103.yaml new file mode 100644 index 0000000..3ee42d0 --- /dev/null +++ b/configs/experiment/lm/s4-wt103.yaml @@ -0,0 +1,47 @@ +# @package _global_ +defaults: + - /pipeline: wt103 + - /model: s4 + - override /model/layer: s4s4ff + +# Dataset +dataset: + test_split: True +loader: + batch_size: 1 + l_max: 8192 + n_context: 1 + eval: + batch_size: null + l_max: null + +task: + div_val: 4 + dropemb: 0.25 + dropsoft: 0.25 + +# Model +model: + dropinp: 0.0 + dropout: 0.25 + prenorm: True + n_layers: 16 + d_model: 1024 + transposed: false # Saves memory + tie_dropout: false # More standard + +# Optimizer (adamw) +optimizer: + lr: 5e-4 + weight_decay: 0.1 + +# Scheduler (cosine) +trainer: + max_epochs: 1000 + +scheduler: + num_warmup_steps: 1000 + num_training_steps: 800000 + +train: + seed: 1111 diff --git a/configs/experiment/lm/transformer-wt103.yaml b/configs/experiment/lm/transformer-wt103.yaml new file mode 100644 index 0000000..dcc1d65 --- /dev/null +++ b/configs/experiment/lm/transformer-wt103.yaml @@ -0,0 +1,49 @@ +# @package _global_ +defaults: + - /pipeline: wt103 + - /model: transformer + +# Dataset +dataset: + test_split: True +loader: + batch_size: 16 + l_max: 512 + n_context: 1 + eval: + batch_size: null + l_max: null + +task: + div_val: 4 + dropemb: 0.1 + dropsoft: 0.1 + +# Model +model: + dropinp: 0.0 + dropout: 0.1 + prenorm: true + n_layers: 16 + d_model: 512 + residual: R + prenorm: False + +# Optimizer +# optimizer: adamw +optimizer: + lr: 5e-4 + weight_decay: 0.0 + +# Scheduler +trainer: + max_epochs: 40 + gradient_clip_val: 0.25 + accumulate_grad_batches: 2 + +scheduler: # cosine_warmup + num_warmup_steps: 1000 + num_training_steps: 40000 + +train: + seed: 1111 diff --git a/configs/experiment/lra/old/s4-lra-aan.yaml b/configs/experiment/lra/old/s4-lra-aan.yaml new file mode 100644 index 0000000..c53d17c --- /dev/null +++ b/configs/experiment/lra/old/s4-lra-aan.yaml @@ -0,0 +1,26 @@ +# @package _global_ +defaults: + - /pipeline: aan + - /model: s4 + +model: + dropout: 0. + n_layers: 6 + prenorm: true + d_model: 256 + norm: batch + +loader: + batch_size: 64 + +optimizer: + lr: 0.002 + +scheduler: + patience: 20 + +trainer: + max_epochs: 25 + +train: + seed: 1112 diff --git a/configs/experiment/lra/old/s4-lra-cifar.yaml b/configs/experiment/lra/old/s4-lra-cifar.yaml new file mode 100644 index 0000000..9eed29a --- /dev/null +++ b/configs/experiment/lra/old/s4-lra-cifar.yaml @@ -0,0 +1,31 @@ +# @package _global_ +defaults: + - /pipeline: cifar + - /model: s4 + +model: + dropout: 0.2 + tie_dropout: true + n_layers: 6 + d_model: 512 + prenorm: false + norm: batch + +dataset: + grayscale: true + +loader: + batch_size: 50 + +optimizer: + lr: 0.004 + weight_decay: 0.01 + +scheduler: + patience: 10 + +trainer: + max_epochs: 100 + +train: + seed: 1112 diff --git a/configs/experiment/lra/old/s4-lra-imdb.yaml b/configs/experiment/lra/old/s4-lra-imdb.yaml new file mode 100644 index 0000000..f3b3dfe --- /dev/null +++ b/configs/experiment/lra/old/s4-lra-imdb.yaml @@ -0,0 +1,30 @@ +# @package _global_ +defaults: + - /pipeline: imdb + - /model: s4 + +model: + dropout: 0.0 + n_layers: 4 + d_model: 128 + prenorm: true + norm: batch + +dataset: + l_max: 2048 + level: char + +loader: + batch_size: 50 + +optimizer: + lr: 0.01 + +scheduler: + patience: 10 + +trainer: + max_epochs: 40 + +train: + seed: 1112 diff --git a/configs/experiment/lra/old/s4-lra-listops.yaml b/configs/experiment/lra/old/s4-lra-listops.yaml new file mode 100644 index 0000000..5591f18 --- /dev/null +++ b/configs/experiment/lra/old/s4-lra-listops.yaml @@ -0,0 +1,30 @@ +# @package _global_ +defaults: + - /pipeline: listops + - /model: s4 + +model: + dropout: 0. + n_layers: 6 + d_model: 128 + prenorm: false + norm: batch + +decoder: + mode: pool + +loader: + batch_size: 50 + +optimizer: + lr: 0.01 + weight_decay: 0.01 + +scheduler: + patience: 5 + +trainer: + max_epochs: 50 + +train: + seed: 1112 diff --git a/configs/experiment/lra/old/s4-lra-pathfinder.yaml b/configs/experiment/lra/old/s4-lra-pathfinder.yaml new file mode 100644 index 0000000..e62ac37 --- /dev/null +++ b/configs/experiment/lra/old/s4-lra-pathfinder.yaml @@ -0,0 +1,29 @@ +# @package _global_ +defaults: + - /pipeline: pathfinder + - /model: s4 + +model: + dropout: 0.1 + n_layers: 6 + prenorm: true + d_model: 256 + norm: batch + +decoder: + mode: last + +loader: + batch_size: 100 + +optimizer: + lr: 0.004 + +scheduler: + patience: 10 + +trainer: + max_epochs: 200 + +train: + seed: 1112 diff --git a/configs/experiment/lra/old/s4-lra-pathx.yaml b/configs/experiment/lra/old/s4-lra-pathx.yaml new file mode 100644 index 0000000..a38cbc2 --- /dev/null +++ b/configs/experiment/lra/old/s4-lra-pathx.yaml @@ -0,0 +1,26 @@ +# @package _global_ +defaults: + - /pipeline: pathx + - /model: s4 + +model: + dropout: 0. + n_layers: 6 + prenorm: true + d_model: 256 + norm: batch + +loader: + batch_size: 32 + +optimizer: + lr: 0.0005 + +scheduler: + patience: 40 + +trainer: + max_epochs: 100 + +train: + seed: 1112 diff --git a/configs/experiment/lra/s4-lra-aan.yaml b/configs/experiment/lra/s4-lra-aan.yaml new file mode 100644 index 0000000..98b6f26 --- /dev/null +++ b/configs/experiment/lra/s4-lra-aan.yaml @@ -0,0 +1,42 @@ +# @package _global_ +defaults: + - /pipeline: aan + - /model: s4 + - override /scheduler: cosine_warmup + +scheduler: + num_training_steps: 50000 # 20 epochs + num_warmup_steps: 2500 # 1 epoch + +model: + dropout: 0. + n_layers: 6 + prenorm: true + d_model: 256 + norm: batch + layer: + d_state: 64 + lr: + dt: null + A: 0.001 + B: 0.001 + dt_min: 0.001 + dt_max: 0.1 + measure: legs + bidirectional: true + postact: glu + n_ssm: 256 + +loader: + batch_size: 64 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +trainer: + max_epochs: 20 + +train: + seed: 2222 + interval: step diff --git a/configs/experiment/lra/s4-lra-cifar.yaml b/configs/experiment/lra/s4-lra-cifar.yaml new file mode 100644 index 0000000..8d88ecb --- /dev/null +++ b/configs/experiment/lra/s4-lra-cifar.yaml @@ -0,0 +1,44 @@ +# @package _global_ +defaults: + - /pipeline: cifar + - /model: s4 + - override /scheduler: cosine_warmup + +model: + dropout: 0.1 + tie_dropout: true + n_layers: 6 + d_model: 512 + prenorm: false + norm: layer + layer: + d_state: 64 + lr: + dt: null + A: 0.001 + B: 0.001 + dt_min: 0.001 + dt_max: 0.1 + measure: legs + bidirectional: true + postact: glu + n_ssm: 2 + +dataset: + grayscale: true + +loader: + batch_size: 50 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +trainer: + max_epochs: 200 + +scheduler: + num_training_steps: 200000 + +train: + seed: 2222 diff --git a/configs/experiment/lra/s4-lra-imdb.yaml b/configs/experiment/lra/s4-lra-imdb.yaml new file mode 100644 index 0000000..ead43b0 --- /dev/null +++ b/configs/experiment/lra/s4-lra-imdb.yaml @@ -0,0 +1,47 @@ +# @package _global_ +defaults: + - /pipeline: imdb + - /model: s4 + - override /scheduler: cosine_warmup + +decoder: + mode: pool + +model: + dropout: 0.0 + n_layers: 6 + d_model: 256 + prenorm: true + norm: batch + layer: + d_state: 64 + lr: + dt: null + A: 0.001 + B: 0.001 + dt_min: 0.001 + dt_max: 0.1 + measure: legs + bidirectional: true + postact: glu + n_ssm: ${..d_model} + +dataset: + l_max: 4096 + level: char + +loader: + batch_size: 16 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +scheduler: + num_training_steps: 50000 + +trainer: + max_epochs: 32 + +train: + seed: 2222 diff --git a/configs/experiment/lra/s4-lra-listops.yaml b/configs/experiment/lra/s4-lra-listops.yaml new file mode 100644 index 0000000..41693eb --- /dev/null +++ b/configs/experiment/lra/s4-lra-listops.yaml @@ -0,0 +1,44 @@ +# @package _global_ +defaults: + - /pipeline: listops + - /model: s4 + - override /scheduler: cosine_warmup + +model: + dropout: 0. + n_layers: 8 + d_model: 128 + prenorm: false + norm: batch + layer: + d_state: 64 + lr: + dt: null + A: 0.001 + B: 0.001 + dt_min: 0.001 + dt_max: 0.1 + measure: legs + bidirectional: true + postact: glu + n_ssm: ${..d_model} + +decoder: + mode: pool + +loader: + batch_size: 50 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +scheduler: + num_training_steps: 80000 + # patience: 5 + +trainer: + max_epochs: 40 + +train: + seed: 2222 diff --git a/configs/experiment/lra/s4-lra-pathfinder.yaml b/configs/experiment/lra/s4-lra-pathfinder.yaml new file mode 100644 index 0000000..44611ea --- /dev/null +++ b/configs/experiment/lra/s4-lra-pathfinder.yaml @@ -0,0 +1,42 @@ +# @package _global_ +defaults: + - /pipeline: pathfinder + - /model: s4 + - override /scheduler: cosine_warmup + +scheduler: + num_training_steps: 500000 # 200 epochs + num_warmup_steps: 2500 # 1 epoch + +model: + dropout: 0.0 + n_layers: 6 + prenorm: true + d_model: 256 + norm: batch + layer: + d_state: 64 + lr: 0.001 + dt_min: 0.001 + dt_max: 0.1 + measure: legs + bidirectional: true + postact: glu + n_ssm: 256 + +decoder: + mode: last + +loader: + batch_size: 64 + +optimizer: + lr: 0.004 + weight_decay: 0.03 + +trainer: + max_epochs: 200 + +train: + seed: 2222 + interval: step diff --git a/configs/experiment/lra/s4-lra-pathx.yaml b/configs/experiment/lra/s4-lra-pathx.yaml new file mode 100644 index 0000000..92695da --- /dev/null +++ b/configs/experiment/lra/s4-lra-pathx.yaml @@ -0,0 +1,39 @@ +# @package _global_ +defaults: + - /pipeline: pathx + - /model: s4 + - override /scheduler: cosine_warmup + +scheduler: + num_training_steps: 500000 # 50 epochs + num_warmup_steps: 10000 # 1 epoch + +model: + dropout: 0. + n_layers: 6 + prenorm: true + d_model: 256 + norm: batch + layer: + d_state: 64 + lr: 0.0005 + dt_min: 0.0001 + dt_max: 0.1 + measure: legs + bidirectional: true + postact: glu + n_ssm: 256 + +loader: + batch_size: 16 + +optimizer: + lr: 0.0005 + weight_decay: 0.05 + +trainer: + max_epochs: 50 + +train: + seed: 1112 + interval: step diff --git a/configs/experiment/rnn.yaml b/configs/experiment/rnn.yaml new file mode 100644 index 0000000..f26cbb0 --- /dev/null +++ b/configs/experiment/rnn.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Basic rnn experiment +# Override model/layer/cell for different RNN cells, e.g. `python -m train experiment=rnn model/layer/cell=hippo-legs +defaults: + - /pipeline: mnist + - /model: s4 + - override /model/layer: rnn + +# Different default settings for model backbone +model: + prenorm: False + transposed: False + n_layers: 1 + d_model: 256 + residual: N + pool: null + norm: none + dropout: 0.0 + tie_dropout: false + # In the 1 layer case, memory optimization by not returning outputs + track_norms: false + layer: + return_output: false + +# Decode using the end state +decoder: state diff --git a/configs/experiment/sc/convnet-sc.yaml b/configs/experiment/sc/convnet-sc.yaml new file mode 100644 index 0000000..be30c35 --- /dev/null +++ b/configs/experiment/sc/convnet-sc.yaml @@ -0,0 +1,29 @@ +# @package _global_ +defaults: + - /pipeline: sc + - /model: convnet_1d + - override /scheduler: cosine_warmup + +dataset: + length: 16384 + +encoder: + _name_: conv1d + kernel_size: 25 + stride: 1 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +loader: + batch_size: 16 + +trainer: + max_epochs: 40 + +scheduler: + num_training_steps: 200000 + +train: + seed: 2222 diff --git a/configs/experiment/sc/resnet-sc.yaml b/configs/experiment/sc/resnet-sc.yaml new file mode 100644 index 0000000..59b1b7f --- /dev/null +++ b/configs/experiment/sc/resnet-sc.yaml @@ -0,0 +1,28 @@ +# @package _global_ +defaults: + - /pipeline: sc + - /model: nonaka/resnet + - override /scheduler: cosine_warmup + +model: + num_classes: 35 + input_channels: 1 + +encoder: id +decoder: id + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +loader: + batch_size: 16 + +trainer: + max_epochs: 40 + +scheduler: + num_training_steps: 200000 + +train: + seed: 2222 diff --git a/configs/experiment/sc/s4-sc-ablation.yaml b/configs/experiment/sc/s4-sc-ablation.yaml new file mode 100644 index 0000000..6711127 --- /dev/null +++ b/configs/experiment/sc/s4-sc-ablation.yaml @@ -0,0 +1,37 @@ +# @package _global_ +defaults: + - /pipeline: sc + - /model: s4 + - override /scheduler: cosine_warmup + +model: + dropout: 0.0 + n_layers: 4 + prenorm: true + d_model: 128 + norm: batch + layer: + d_state: 64 + bidirectional: False + # resample: true + dt_min: 0.001 + dt_max: 0.1 + postact: null + +decoder: + mode: last + +optimizer: + lr: 0.004 + +loader: + batch_size: 32 + +trainer: + max_epochs: 10 + +scheduler: + num_training_steps: 50000 + +train: + seed: 1111 diff --git a/configs/experiment/sc/s4-sc.yaml b/configs/experiment/sc/s4-sc.yaml new file mode 100644 index 0000000..79cf045 --- /dev/null +++ b/configs/experiment/sc/s4-sc.yaml @@ -0,0 +1,34 @@ +# @package _global_ +# Should get to around 96.5% +defaults: + - /pipeline: sc + - /model: s4 + - override /scheduler: cosine_warmup + +model: + dropout: 0.0 + n_layers: 6 + prenorm: true + d_model: 128 + norm: batch + layer: + d_state: 64 + bidirectional: true + postact: glu + n_ssm: 2 + +optimizer: + lr: 0.01 + weight_decay: 0.05 + +loader: + batch_size: 16 + +trainer: + max_epochs: 40 + +scheduler: + num_training_steps: 200000 + +train: + seed: 2222 diff --git a/configs/experiment/sc/transformer-sc.yaml b/configs/experiment/sc/transformer-sc.yaml new file mode 100644 index 0000000..b5c7227 --- /dev/null +++ b/configs/experiment/sc/transformer-sc.yaml @@ -0,0 +1,26 @@ +# @package _global_ +defaults: + - /pipeline: sc + - /model: transformer + - override /model/layer: performer + +model: + dropout: 0.0 + n_layers: 4 + d_model: 128 + prenorm: true + +optimizer: + lr: 0.001 + +loader: + batch_size: 16 + +trainer: + max_epochs: 200 + +scheduler: + patience: 10 + +train: + seed: 1112 diff --git a/configs/experiment/synthetic/s4-copying.yaml b/configs/experiment/synthetic/s4-copying.yaml new file mode 100644 index 0000000..6dfa2ba --- /dev/null +++ b/configs/experiment/synthetic/s4-copying.yaml @@ -0,0 +1,39 @@ +# @package _global_ +defaults: + - /pipeline: copying + - /model: s4 + +dataset: + l_noise: 0 + l_memorize: 1024 + n_tokens: 64 + n_train: 10000 + n_eval: 1000 + +model: + dropout: 0.0 + n_layers: 4 + d_model: 64 + prenorm: true + layer: + rank: 2 + measure: fourier + deterministic: true + dt_min: ${eval:${model.n_layers}/${dataset.l_memorize}} # 0.0008 + dt_max: ${eval:${model.n_layers}/${dataset.l_memorize}} # 0.001 + d_state: 256 + lr: 0.0001 + # shift: true + +loader: + batch_size: 8 + +optimizer: + lr: 0.001 + weight_decay: 0.00 + +trainer: + max_epochs: 1000 + +train: + seed: 1112 diff --git a/configs/experiment/synthetic/s4-delay.yaml b/configs/experiment/synthetic/s4-delay.yaml new file mode 100644 index 0000000..e77038c --- /dev/null +++ b/configs/experiment/synthetic/s4-delay.yaml @@ -0,0 +1,37 @@ +# @package _global_ +defaults: + - /pipeline: delay + - /model: s4 + +model: + dropout: 0.0 + n_layers: 1 + d_model: 4 + prenorm: true + norm: null # Fully linear model + residual: null + layer: + rank: 1 + measure: legs + deterministic: false + d_state: 1024 + lr: + A: 0.001 + B: 0.001 + dt: 0.001 + dt_min: 0.002 + dt_max: 0.002 + linear: true + +loader: + batch_size: 64 + +optimizer: + lr: 0.001 + weight_decay: 0.0 + +trainer: + max_epochs: 20 + +train: + seed: 1112 diff --git a/configs/experiment/synthetic/s4-reconstruct.yaml b/configs/experiment/synthetic/s4-reconstruct.yaml new file mode 100644 index 0000000..9a3356e --- /dev/null +++ b/configs/experiment/synthetic/s4-reconstruct.yaml @@ -0,0 +1,37 @@ +# @package _global_ +defaults: + - /pipeline: reconstruct + - /model: s4 + +model: + dropout: 0.0 + n_layers: 1 + d_model: 256 + prenorm: true + norm: null + residual: null + layer: + rank: 1 + measure: fourier + deterministic: false + d_state: 256 + lr: + A: 0.0 + B: 0.0 + dt: 0.001 + dt_min: 0.002 + dt_max: 0.002 + linear: 1 + +loader: + batch_size: 64 + +optimizer: + lr: 0.001 + weight_decay: 0.0 + +trainer: + max_epochs: 20 + +train: + seed: 1112 diff --git a/configs/generate.yaml b/configs/generate.yaml new file mode 100644 index 0000000..651efec --- /dev/null +++ b/configs/generate.yaml @@ -0,0 +1,16 @@ +defaults: + - config + + +experiment_path: null # Path to state-spaces experiment folder +checkpoint_path: checkpoints/val/loss.ckpt # Relative path to checkpoint in state-spaces experiment folder. Uses checkpoints/val/loss.ckpt by default. +l_sample: 16000 # Sample length +n_samples: 32 # Number of distinct conditioning samples drawn from dataset +n_reps: 1 # Number of times to replicate each sample +l_prefix: 0 # Prefix length: num steps to use for conditioning +top_p: 1. # Nucleus sampling +temp: 1. # Temperature +split: val # If conditioning, which split of the data to use ['val' | 'test'] +save_dir: sashimi/samples # Save directory. Pass in 'null' (None) to save in Hydra directory to ensure that samples are not overridden +load_data: true # Load the dataset (set to false to disable if not conditioning) +decode: audio # Decoding mode ['audio' | 'text' | None]. The pretrained WikiText-103 model currently does not generate correctly diff --git a/configs/loader/default.yaml b/configs/loader/default.yaml new file mode 100644 index 0000000..738bdec --- /dev/null +++ b/configs/loader/default.yaml @@ -0,0 +1,4 @@ +batch_size: 50 +num_workers: 4 +pin_memory: True +drop_last: True # We set this to true because of the recurrent state mechanism \ No newline at end of file diff --git a/configs/loader/imresolution.yaml b/configs/loader/imresolution.yaml new file mode 100644 index 0000000..7d31943 --- /dev/null +++ b/configs/loader/imresolution.yaml @@ -0,0 +1,10 @@ +batch_size: 50 +num_workers: 4 +pin_memory: True +drop_last: True # We set this to true because of the recurrent state mechanism + +train_resolution: 1 +eval_resolutions: + - 1 +img_size: null +channels_last: true diff --git a/configs/loader/resolution.yaml b/configs/loader/resolution.yaml new file mode 100644 index 0000000..a99c97c --- /dev/null +++ b/configs/loader/resolution.yaml @@ -0,0 +1,8 @@ +batch_size: 50 +num_workers: 4 +pin_memory: True +drop_last: True # We set this to true because of the recurrent state mechanism + +train_resolution: 1 +eval_resolutions: + - 1 diff --git a/configs/loader/tbptt.yaml b/configs/loader/tbptt.yaml new file mode 100644 index 0000000..73572e1 --- /dev/null +++ b/configs/loader/tbptt.yaml @@ -0,0 +1,8 @@ +_name_: tbptt +chunk_len: null # [ None | int ] chunk length for tbptt (used by TBPTTDataLoader) +overlap_len: null # [ None | int ] overlap length for tbptt (used by TBPTTDataLoader) + +batch_size: 50 +num_workers: 4 +pin_memory: True +drop_last: True # We set this to true because of the recurrent state mechanism diff --git a/configs/model/README.md b/configs/model/README.md new file mode 100644 index 0000000..718be43 --- /dev/null +++ b/configs/model/README.md @@ -0,0 +1,27 @@ + +The `model/` configs largely follow the structure of the `src/models/` code folder. + +## Backbones +Top-level configs use the model backbone structure specified by this repository. +These models consist of **backbones** that are composed of repeatable blocks of core **layers**. +The backbones include a simple isotropic residual backbone (in the style of ResNets and Transformers) (`base.yaml`) and variations of UNet structures (`unet.yaml`, `sashimi.yaml`). + +## Layers +Layers configs are defined in `model/layer/`. Each one of these instantiates a `src.models.sequence.base.SequenceModule` which maps an input sequence to output sequence, and can be passed into the various backbones. +Older versions of HiPPO focused on RNNs, and defined a flexible RNN layer in `model/layer/rnn.yaml`. This RNN accepts any RNN cell, with example configs in `model/layer/cell/`. + +## Examples +Some examples of full models are provided which combine a backbone with a choice of inner layer, such as `convnet1d.yaml` (a simple 1D residual convnet), `s4.yaml` (basic isotropic S4 model), and `transformer.yaml` (isotropic Transformer model composed of alternating layers of self-attention and feed-forward network). + +## Other Baselines + +Other baseline models are included that do not necessarily follow this structure. +``` +baseline/ Miscellaneous baselines from the literature +nonaka/ 1-D CNN models ported from the paper [Nonaka, Seita] + "In-depth Benchmarking of Deep Neural Network Architectures for ECG Diagnosis" + (https://github.com/seitalab/dnn_ecg_comparison) +segmentation/ Segmentation models (preliminary) +timm/ Ports of timm ResNet and ConvNext models +vit/ Ports of vit models +``` diff --git a/configs/model/base.yaml b/configs/model/base.yaml index e054f9c..461f0f8 100644 --- a/configs/model/base.yaml +++ b/configs/model/base.yaml @@ -1,15 +1,21 @@ -# _target_: models.sequence.SequenceModel defaults: - layer: s4 + _name_: model -prenorm: False -transposed: False +prenorm: true +transposed: false n_layers: 4 d_model: 256 residual: R pool: - _name_: sample - pool: 1 + _name_: pool + stride: 1 expand: 1 norm: layer -dropout: 0.2 +dropout: 0.0 +tie_dropout: false +track_norms: true # Logs to wandb + +# Optional encoder/decoder, e.g. add positional embeddings or padding masks +encoder: null +decoder: null diff --git a/configs/model/baseline/ckconv.yaml b/configs/model/baseline/ckconv.yaml new file mode 100644 index 0000000..0b51512 --- /dev/null +++ b/configs/model/baseline/ckconv.yaml @@ -0,0 +1,13 @@ +_name_: ckconv +d_model: 30 # taken from Table 8 +num_blocks: 2 # always 2 blocks +kernelnet_hidden_channels: 32 +kernelnet_activation_function: Sine +kernelnet_norm_type: LayerNorm +dim_linear: 1 # 1 / 2 +bias: True # False / True +omega_0: 40.0 # good values in [0, 70.] +dropout: 0.2 +weight_dropout: 0.0 +pool: False # always set to False by CKConv +wd: 0. diff --git a/configs/model/baseline/lipschitzrnn.yaml b/configs/model/baseline/lipschitzrnn.yaml new file mode 100644 index 0000000..70a867e --- /dev/null +++ b/configs/model/baseline/lipschitzrnn.yaml @@ -0,0 +1,11 @@ +_name_: lipschitzrnn +d_model: 128 +eps: 0.03 +beta: .75 +gamma: .001 +gated: False +init_std: 1 +alpha: 1 +model: LipschitzRNN +solver: euler +chunk: 1 diff --git a/configs/model/baseline/lstm.yaml b/configs/model/baseline/lstm.yaml new file mode 100644 index 0000000..c2c1bd6 --- /dev/null +++ b/configs/model/baseline/lstm.yaml @@ -0,0 +1,4 @@ +_name_: lstm +d_model: 256 +d_hidden: ${.d_model} +n_layers: 1 diff --git a/configs/model/baseline/odelstm.yaml b/configs/model/baseline/odelstm.yaml new file mode 100644 index 0000000..7a12f89 --- /dev/null +++ b/configs/model/baseline/odelstm.yaml @@ -0,0 +1,4 @@ +_name_: odelstm +d_model: 64 +solver_type: dopri5 +return_sequences: true diff --git a/configs/model/baseline/resnet2d.yaml b/configs/model/baseline/resnet2d.yaml new file mode 100644 index 0000000..7e742b1 --- /dev/null +++ b/configs/model/baseline/resnet2d.yaml @@ -0,0 +1,2 @@ +_name_: torch/resnet2d +variant: 18 # resnet18 diff --git a/configs/model/baseline/samplernn.yaml b/configs/model/baseline/samplernn.yaml new file mode 100644 index 0000000..d22b49d --- /dev/null +++ b/configs/model/baseline/samplernn.yaml @@ -0,0 +1,13 @@ +_name_: samplernn +frame_sizes: + - 16 + - 4 +n_rnn: 2 +d_hidden: 1024 +bits: ${..dataset.bits} +learn_h0: true +d_model: 256 +weight_norm: true +reproduce: true +quantization: ${..dataset.quantization} +layer: gru diff --git a/configs/model/baseline/stackedrnn_baseline.yaml b/configs/model/baseline/stackedrnn_baseline.yaml new file mode 100644 index 0000000..a606f14 --- /dev/null +++ b/configs/model/baseline/stackedrnn_baseline.yaml @@ -0,0 +1,9 @@ +_name_: stackedrnn_baseline +d_model: 256 +d_hidden: 1024 +n_layers: 2 +learn_h0: false +rnn_type: gru +skip_connections: true +weight_norm: false +dropout: 0.0 \ No newline at end of file diff --git a/configs/model/baseline/unicornn.yaml b/configs/model/baseline/unicornn.yaml new file mode 100644 index 0000000..d4c900c --- /dev/null +++ b/configs/model/baseline/unicornn.yaml @@ -0,0 +1,7 @@ +# _target_: models.baselines.unicornn.UnICORNN +_name_: unicornn +d_model: 128 # hidden size of recurrent net +dt: 0.046 # default of HR from repo +alpha: 10.0 # default of HR from repo +n_layers: 3 # default of HR from repo +drop: 0.1 diff --git a/configs/model/baseline/wavenet.yaml b/configs/model/baseline/wavenet.yaml new file mode 100644 index 0000000..6c6c54e --- /dev/null +++ b/configs/model/baseline/wavenet.yaml @@ -0,0 +1,9 @@ +_name_: wavenet +layers: 10 +blocks: 4 +dilation_channels: 64 +residual_channels: 64 +skip_channels: 1024 +end_channels: 512 +kernel_size: 2 +classes: 256 \ No newline at end of file diff --git a/configs/model/convnet1d.yaml b/configs/model/convnet1d.yaml new file mode 100644 index 0000000..e8a5f89 --- /dev/null +++ b/configs/model/convnet1d.yaml @@ -0,0 +1,11 @@ +defaults: + - base + - override layer: conv1d + +transposed: true +n_repeat: 3 +d_model: 64 +pool: + stride: 4 + expand: 2 +norm: batch diff --git a/configs/model/convnet2d.yaml b/configs/model/convnet2d.yaml new file mode 100644 index 0000000..f86d172 --- /dev/null +++ b/configs/model/convnet2d.yaml @@ -0,0 +1,4 @@ +# _target_: models.sequence.SequenceModel +defaults: + - convnet1d + - override layer: conv2d diff --git a/configs/model/layer/cell/exprnn.yaml b/configs/model/layer/cell/exprnn.yaml new file mode 100644 index 0000000..c7a865a --- /dev/null +++ b/configs/model/layer/cell/exprnn.yaml @@ -0,0 +1,7 @@ +_name_: exprnn # Can change to gru to allow for orthogonal + gating options, i.e. the GORU +d_model: 256 +hidden_activation: modrelu +orthogonal: True +ortho_args: + method: dtriv # 'cayley' | 'exprnn' | 'dtriv' + init: cayley # 'henaff' | 'cayley' diff --git a/configs/model/layer/cell/goru.yaml b/configs/model/layer/cell/goru.yaml new file mode 100644 index 0000000..c2dc160 --- /dev/null +++ b/configs/model/layer/cell/goru.yaml @@ -0,0 +1,12 @@ +# Example of composing our RNN classes to build more advanced models, +# such as the Gated Orthogonal Recurrent Unit +# https://arxiv.org/abs/1706.02761 +_name_: gru +d_model: 256 +hidden_activation: modrelu +gate: G +reset: G +orthogonal: True +ortho_args: + method: dtriv # 'cayley' | 'exprnn' | 'dtriv' + init: cayley # 'henaff' | 'cayley' diff --git a/configs/model/layer/cell/gru.yaml b/configs/model/layer/cell/gru.yaml new file mode 100644 index 0000000..471cd9b --- /dev/null +++ b/configs/model/layer/cell/gru.yaml @@ -0,0 +1,6 @@ +_name_: gru +d_model: 256 +hidden_activation: tanh +gate: G +reset: G +orthogonal: False diff --git a/configs/model/layer/cell/hippo-glagt.yaml b/configs/model/layer/cell/hippo-glagt.yaml new file mode 100644 index 0000000..cb6ef01 --- /dev/null +++ b/configs/model/layer/cell/hippo-glagt.yaml @@ -0,0 +1,10 @@ +# Config for the HiPPO-GLagT (Translated Generalized Laguerre) cell +_name_: lagt +d_model: 256 +memory_size: 1 +memory_order: -1 +dt: 1.0 # Effective at different dt than HiPPO-LegT, explained in HTTYH paper +discretization: bilinear +measure_args: + alpha: 0.0 + beta: 0.01 diff --git a/configs/model/layer/cell/hippo-lagt.yaml b/configs/model/layer/cell/hippo-lagt.yaml new file mode 100644 index 0000000..d084d32 --- /dev/null +++ b/configs/model/layer/cell/hippo-lagt.yaml @@ -0,0 +1,9 @@ +# Config for the HiPPO-LagT (Translated Laguerre) cell +_name_: lagt +d_model: 256 +memory_size: 1 +memory_order: -1 +dt: 1.0 # Effective at different dt than HiPPO-LegT, explained in HTTYH paper +discretization: bilinear +measure_args: + beta: 1.0 diff --git a/configs/model/layer/cell/hippo-legs.yaml b/configs/model/layer/cell/hippo-legs.yaml new file mode 100644 index 0000000..e24fa73 --- /dev/null +++ b/configs/model/layer/cell/hippo-legs.yaml @@ -0,0 +1,6 @@ +# Config for the HiPPO-LegS (Scaled Legendre) cell +_name_: legs +d_model: 256 # hidden size +memory_size: 1 +memory_order: -1 +discretization: bilinear diff --git a/configs/model/layer/cell/hippo-legt.yaml b/configs/model/layer/cell/hippo-legt.yaml new file mode 100644 index 0000000..6c068f3 --- /dev/null +++ b/configs/model/layer/cell/hippo-legt.yaml @@ -0,0 +1,7 @@ +# Config for the HiPPO-LegT (Translated Legendre) cell +_name_: legt +d_model: 256 +memory_size: 1 +memory_order: -1 +dt: 0.01 +discretization: bilinear diff --git a/configs/model/layer/cell/hippo-timestamp.yaml b/configs/model/layer/cell/hippo-timestamp.yaml new file mode 100644 index 0000000..3fab5e6 --- /dev/null +++ b/configs/model/layer/cell/hippo-timestamp.yaml @@ -0,0 +1,6 @@ +# Config for the HiPPO-LegS (Scaled Legendre) cell with timestamped input +_name_: tlsi +d_model: 256 # hidden size +memory_size: 1 +memory_order: -1 +discretization: bilinear diff --git a/configs/model/layer/cell/lmu.yaml b/configs/model/layer/cell/lmu.yaml new file mode 100644 index 0000000..6abc17a --- /dev/null +++ b/configs/model/layer/cell/lmu.yaml @@ -0,0 +1,8 @@ +_name_: lmu +d_model: 256 +memory_size: 1 +memory_order: -1 +dt: 1.0 +theta: 784 +discretization: zoh +gate: N diff --git a/configs/model/layer/cell/rnn.yaml b/configs/model/layer/cell/rnn.yaml new file mode 100644 index 0000000..7ed6800 --- /dev/null +++ b/configs/model/layer/cell/rnn.yaml @@ -0,0 +1,4 @@ +_name_: rnn +d_model: 256 +hidden_activation: tanh +orthogonal: False diff --git a/configs/model/layer/cell/sru.yaml b/configs/model/layer/cell/sru.yaml new file mode 100644 index 0000000..c1bb557 --- /dev/null +++ b/configs/model/layer/cell/sru.yaml @@ -0,0 +1,14 @@ +_name_: sru +d_model: 256 +residual: H + +# Below illustrates creating a full model using the SRURNN layer instead of SRU cell + +# # @package model +# _name_: model +# layer: +# _name_: sru +# d_model: 256 +# feedback: True +# n_layers: 3 +# residual: H diff --git a/configs/model/layer/conv1d.yaml b/configs/model/layer/conv1d.yaml new file mode 100644 index 0000000..d0849a8 --- /dev/null +++ b/configs/model/layer/conv1d.yaml @@ -0,0 +1,12 @@ +_name_: conv1d +d_output: null +kernel_size: 25 +stride: 1 +padding: 12 +dilation: 1 +groups: 1 +bias: true +padding_mode: zeros +activation: gelu +dropout: null +transposed: True diff --git a/configs/model/layer/conv2d.yaml b/configs/model/layer/conv2d.yaml new file mode 100644 index 0000000..551cb07 --- /dev/null +++ b/configs/model/layer/conv2d.yaml @@ -0,0 +1,12 @@ +_name_: conv2d +# d_output: null +depthwise: false +kernel_size: 3 +stride: 1 +# padding: 1 # Determined by kernel_size +dilation: 1 +# groups: 1 # Set automatically by depthwise flag +bias: true +padding_mode: zeros +dropout: null +transposed: true diff --git a/configs/model/layer/ff.yaml b/configs/model/layer/ff.yaml index d7f0c72..a961c5e 100644 --- a/configs/model/layer/ff.yaml +++ b/configs/model/layer/ff.yaml @@ -2,3 +2,5 @@ _name_: ff expand: 4 dropout: null transposed: False +dropout: 0.0 +tie_dropout: ${model.tie_dropout,null} diff --git a/configs/model/layer/rnn.yaml b/configs/model/layer/rnn.yaml index 9c370ba..6a30ccf 100644 --- a/configs/model/layer/rnn.yaml +++ b/configs/model/layer/rnn.yaml @@ -1 +1,5 @@ +defaults: + - cell: rnn + _name_: rnn +return_output: true diff --git a/configs/model/layer/s4.yaml b/configs/model/layer/s4.yaml index e4604cd..ea759d6 100644 --- a/configs/model/layer/s4.yaml +++ b/configs/model/layer/s4.yaml @@ -3,25 +3,22 @@ d_state: 64 channels: 1 bidirectional: false activation: gelu -postact: null +postact: glu initializer: null weight_norm: false hyper_act: null dropout: ${..dropout} # Same as null +tie_dropout: ${oc.select:model.tie_dropout,null} +mode: nplr measure: legs rank: 1 dt_min: 0.001 dt_max: 0.1 -trainable: - dt: true - A: true - P: true - B: true -lr: 0.001 -mode: nplr +lr: + dt: 0.001 + A: 0.001 + B: 0.001 n_ssm: 1 -liquid: 0 -resample: false deterministic: false # Special C init l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to 1 and kernel will automatically resize -verbose: true \ No newline at end of file +verbose: true diff --git a/configs/model/layer/s4d.yaml b/configs/model/layer/s4d.yaml index ecbc434..669314d 100644 --- a/configs/model/layer/s4d.yaml +++ b/configs/model/layer/s4d.yaml @@ -1,12 +1,5 @@ -_name_: s4d -d_state: 64 -channels: 1 -bidirectional: false -activation: gelu -postact: null -dropout: ${..dropout} # Same as null -scaling: inverse -dt_min: 0.001 -dt_max: 0.1 -lr: 0.001 -n_ssm: 1 +defaults: + - s4 + +mode: diag +measure: diag-lin diff --git a/configs/model/layer/s4d_example.yaml b/configs/model/layer/s4d_example.yaml new file mode 100644 index 0000000..f7beb52 --- /dev/null +++ b/configs/model/layer/s4d_example.yaml @@ -0,0 +1,7 @@ +# Config for standalone example s4d file for testing purposes +_name_: s4d +d_state: 64 +dropout: ${..dropout} # Same as null +dt_min: 0.001 +dt_max: 0.1 +lr: 0.001 diff --git a/configs/model/layer/s4ff.yaml b/configs/model/layer/s4ff.yaml index 72f6444..3add389 100644 --- a/configs/model/layer/s4ff.yaml +++ b/configs/model/layer/s4ff.yaml @@ -1,4 +1,5 @@ # TODO there has to be a way to compose this +# [22-02] TODO needs update - _name_: s4 d_state: 64 measure: legs @@ -10,11 +11,7 @@ B: 2 C: 1 dt: 1 - lr: - dt: ${optimizer.lr} # 0.0005 - A: ${.dt} - B: ${.dt} - C: null + lr: 0.0005 cache: False weight_decay: 0.0 weight_norm: False diff --git a/configs/model/layer/s4nd.yaml b/configs/model/layer/s4nd.yaml new file mode 100644 index 0000000..3bac1f5 --- /dev/null +++ b/configs/model/layer/s4nd.yaml @@ -0,0 +1,26 @@ +_name_: s4nd +d_state: 64 +channels: 1 +bidirectional: true +activation: gelu +postact: glu +initializer: null +weight_norm: false +hyper_act: null +trank: 1 +dropout: ${..dropout} # Same as null +tie_dropout: ${..tie_dropout} +measure: legs +rank: 1 +dt_min: 0.001 +dt_max: 0.1 +lr: + dt: 0.001 + A: 0.001 + B: 0.001 +n_ssm: 1 +deterministic: false # Special C init +l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to 1 and kernel will automatically resize +verbose: true +linear: false +# linear_output: True diff --git a/configs/model/layer/s4s4ff.yaml b/configs/model/layer/s4s4ff.yaml index c0da354..4aa08cb 100644 --- a/configs/model/layer/s4s4ff.yaml +++ b/configs/model/layer/s4s4ff.yaml @@ -1,20 +1,16 @@ -# [22-02] TODO needs update - _name_: s4 l_max: ${dataset.__l_max} postact: glu dropout: ${...dropout} # Same as null - lr: 0.0005 - hurwitz: false - tie_state: false + lr: ${optimizer.lr} + n_ssm: 1 - _name_: s4 l_max: ${dataset.__l_max} postact: glu dropout: ${...dropout} # Same as null - lr: 0.0005 - hurwitz: false - tie_state: false + lr: ${optimizer.lr} + n_ssm: 1 - _name_: ff expand: 4 activation: gelu dropout: ${...dropout} # Same as null - # transposed: False # Set by backbone diff --git a/configs/model/layer/sru.yaml b/configs/model/layer/sru.yaml new file mode 100644 index 0000000..f741f9f --- /dev/null +++ b/configs/model/layer/sru.yaml @@ -0,0 +1,10 @@ +# @package model +# Run with python -m train model/layer=sru +_name_: model +layer: + _name_: sru + d_model: 256 + return_output: true + feedback: true +n_layers: 3 +residual: H diff --git a/configs/model/layer/standalone.yaml b/configs/model/layer/standalone.yaml index 27780c4..305e630 100644 --- a/configs/model/layer/standalone.yaml +++ b/configs/model/layer/standalone.yaml @@ -1,15 +1,20 @@ +# Config for the standalone s4.py module for testing purposes +# Should be equivalent to the normal s4.yaml layer _name_: standalone d_state: 64 channels: 1 bidirectional: false activation: gelu -postact: null +postact: glu dropout: ${..dropout} # Same as null +mode: nplr measure: legs dt_min: 0.001 dt_max: 0.1 -trainable: true -lr: 0.001 -n_ssm: 1 +lr: + dt: 0.001 + A: 0.001 + B: 0.001 +n_ssm: null l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to 1 and kernel will automatically resize verbose: true diff --git a/configs/model/layer/transformer.yaml b/configs/model/layer/transformer.yaml index e828ea6..5245e03 100644 --- a/configs/model/layer/transformer.yaml +++ b/configs/model/layer/transformer.yaml @@ -6,7 +6,6 @@ add_zero_attn: False kdim: null vdim: null - batch_first: True - _name_: ff expand: 4 activation: gelu diff --git a/configs/model/layer/vit.yaml b/configs/model/layer/vit.yaml new file mode 100644 index 0000000..ec438a5 --- /dev/null +++ b/configs/model/layer/vit.yaml @@ -0,0 +1,7 @@ +_name_: vit +num_heads: 8 +qkv_bias: False +qk_scale: null +attn_drop: 0.0 +packed_linear: true +linear_cfg: null diff --git a/configs/model/nonaka/inception.yaml b/configs/model/nonaka/inception.yaml new file mode 100644 index 0000000..8dbbfcc --- /dev/null +++ b/configs/model/nonaka/inception.yaml @@ -0,0 +1,14 @@ +_name_: nonaka/inception +num_classes: 2 +input_channels: 8 +kernel_size: 40 +depth: 6 +bottleneck_size: 32 +nb_filters: 32 +use_residual: True +lin_ftrs_head: null +ps_head: 0.5 +bn_final_head: False +bn_head: True +act_head: "relu" +concat_pooling: True diff --git a/configs/model/nonaka/resnet.yaml b/configs/model/nonaka/resnet.yaml new file mode 100644 index 0000000..3e8febf --- /dev/null +++ b/configs/model/nonaka/resnet.yaml @@ -0,0 +1,16 @@ +_name_: nonaka/resnet18 +num_classes: 2 # Set to d_input +input_channels: 3 # Set to d_output +kernel_size: 3 +inplanes: 64 +fix_feature_dim: True +kernel_size_stem : null +stride_stem: 2 +pooling_stem: True +stride: 2 +lin_ftrs_head: null +ps_head: 0.5 +bn_final_head: False +bn_head: True +act_head: relu +concat_pooling: True diff --git a/configs/model/nonaka/xresnet.yaml b/configs/model/nonaka/xresnet.yaml new file mode 100644 index 0000000..1bab930 --- /dev/null +++ b/configs/model/nonaka/xresnet.yaml @@ -0,0 +1,15 @@ +_name_: nonaka/xresnet50 +input_channels: 3 +num_classes: 1000 +stem_szs: [32, 32, 64] +kernel_size: 5 +kernel_size_stem: 5 +widen: 1.0 +sa: False +# act_cls: nn.ReLU +lin_ftrs_head: null +ps_head: 0.5 +bn_final_head: False +bn_head: True +act_head: "relu" +concat_pooling: True diff --git a/configs/model/s4.yaml b/configs/model/s4.yaml index 8fe85fc..c1aa6be 100644 --- a/configs/model/s4.yaml +++ b/configs/model/s4.yaml @@ -1,15 +1,6 @@ -# _target_: models.sequence.SequenceModel defaults: - - layer: s4 -_name_: model -prenorm: False -transposed: True -n_layers: 4 -d_model: 256 -residual: R -pool: - _name_: sample - stride: 1 - expand: 1 -norm: layer -dropout: 0.0 + - base + - override layer: s4 + +transposed: false # Actually faster than "true" +tie_dropout: true diff --git a/configs/model/transformer.yaml b/configs/model/transformer.yaml index 4a32282..fe2b0cb 100644 --- a/configs/model/transformer.yaml +++ b/configs/model/transformer.yaml @@ -1,33 +1,8 @@ -# _target_: models.sequence.SequenceModel -_name_: model -layer: - - _name_: mha - n_heads: 8 - causal: True - dropout: null - bias: True - add_bias_kv: False - add_zero_attn: False - kdim: null - vdim: null - - _name_: ff - expand: 4 - dropout: null - transposed: False -n_layers: 16 -d_model: 512 -residual: R -prenorm: False -pool: - _name_: sample - pool: 1 - expand: 1 -norm: layer -dropout: 0.1 -# init: -# init: normal # Parameter initializer to use -# init_range: 0.1 # Parameters initialized by U(-init_range, init_range) -# init_std: 0.02 # Parameters initialized by N(0, init_std) +# Large Transformer model used as baseline for WikiText-103 +defaults: + - base + - override layer: transformer + encoder: _name_: position - dropout: 0.1 + dropout: ${..dropout} diff --git a/configs/pipeline/adding.yaml b/configs/pipeline/adding.yaml index 0ea0ee3..6eef2ac 100644 --- a/configs/pipeline/adding.yaml +++ b/configs/pipeline/adding.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - /trainer: default - - /loader: torch + - /loader: default - /dataset: adding - /task: regression - /optimizer: adamw diff --git a/configs/pipeline/cifar.yaml b/configs/pipeline/cifar.yaml index cf214d5..302d8d8 100644 --- a/configs/pipeline/cifar.yaml +++ b/configs/pipeline/cifar.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - /trainer: default - - /loader: torch + - /loader: default - /dataset: cifar - /task: multiclass_classification - /optimizer: adamw diff --git a/configs/pipeline/copying.yaml b/configs/pipeline/copying.yaml index e547bf1..fed89d0 100644 --- a/configs/pipeline/copying.yaml +++ b/configs/pipeline/copying.yaml @@ -3,3 +3,5 @@ defaults: - adding - override /dataset: copying - override /task: multiclass_classification + +encoder: embedding diff --git a/configs/pipeline/delay.yaml b/configs/pipeline/delay.yaml new file mode 100644 index 0000000..aea4ad1 --- /dev/null +++ b/configs/pipeline/delay.yaml @@ -0,0 +1,6 @@ +# @package _global_ +defaults: + - adding + - override /dataset: delay + +# encoder: embedding diff --git a/configs/pipeline/ema.yaml b/configs/pipeline/ema.yaml new file mode 100644 index 0000000..6dfdb9a --- /dev/null +++ b/configs/pipeline/ema.yaml @@ -0,0 +1,5 @@ +# @package _global_ +defaults: + - integrator + - override /dataset: ema + - override /scheduler: constant diff --git a/configs/pipeline/imdb.yaml b/configs/pipeline/imdb.yaml index 5598969..ed62bcb 100644 --- a/configs/pipeline/imdb.yaml +++ b/configs/pipeline/imdb.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - /trainer: default - - /loader: torch + - /loader: default - /dataset: imdb - /task: multiclass_classification - /optimizer: adamw diff --git a/configs/pipeline/informer.yaml b/configs/pipeline/informer.yaml index 75073a8..24546b7 100644 --- a/configs/pipeline/informer.yaml +++ b/configs/pipeline/informer.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - /trainer: default - - /loader: torch + - /loader: default - /dataset: etth - /task: regression - /optimizer: adamw diff --git a/configs/pipeline/reconstruct.yaml b/configs/pipeline/reconstruct.yaml new file mode 100644 index 0000000..cee41a8 --- /dev/null +++ b/configs/pipeline/reconstruct.yaml @@ -0,0 +1,6 @@ +# @package _global_ +defaults: + - adding + - override /dataset: reconstruct + +# encoder: embedding diff --git a/configs/pipeline/sc.yaml b/configs/pipeline/sc.yaml index d1fa5fb..7eabca0 100644 --- a/configs/pipeline/sc.yaml +++ b/configs/pipeline/sc.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - /trainer: default - - /loader: torch + - /loader: resolution - /dataset: sc - /task: multiclass_classification - /optimizer: adamw @@ -21,4 +21,4 @@ loader: train_resolution: 1 eval_resolutions: - 1 - - 2 \ No newline at end of file + - 2 diff --git a/configs/pipeline/wt103.yaml b/configs/pipeline/wt103.yaml index 5a1dd15..7571f99 100644 --- a/configs/pipeline/wt103.yaml +++ b/configs/pipeline/wt103.yaml @@ -11,7 +11,6 @@ train: mode: min task: - # _target_: tasks.tasks.AdaptiveLMTask _name_: adaptivelm init_scale: 0.5 # null to get transformer-xl init bias_scale: 1.0 diff --git a/configs/scheduler/cosine.yaml b/configs/scheduler/cosine.yaml index 574a849..9e77632 100644 --- a/configs/scheduler/cosine.yaml +++ b/configs/scheduler/cosine.yaml @@ -1,8 +1,8 @@ # @package _global_ train: - interval: step + interval: epoch scheduler: # _target_: torch.optim.lr_scheduler.CosineAnnealingLR _name_: cosine - T_max: 40000 # Max number of training steps for LR scheduler - eta_min: 0.001 # Min learning rate for cosine scheduler + T_max: 100 # Max number of epochs steps for LR scheduler + eta_min: 1e-6 # Min learning rate for cosine scheduler diff --git a/configs/scheduler/timm_cosine.yaml b/configs/scheduler/timm_cosine.yaml new file mode 100644 index 0000000..b212a4b --- /dev/null +++ b/configs/scheduler/timm_cosine.yaml @@ -0,0 +1,12 @@ +# @package _global_ +train: + interval: epoch + monitor: ??? # must be specified +scheduler: + _name_: timm_cosine + t_initial: 300 + lr_min: 1e-5 + cycle_decay: 0.1 # changed from decay_rate in timm 0.5.4 + warmup_lr_init: 1e-6 + warmup_t: 10 + cycle_limit: 1 \ No newline at end of file diff --git a/configs/task/forecasting.yaml b/configs/task/forecasting.yaml new file mode 100644 index 0000000..0c53562 --- /dev/null +++ b/configs/task/forecasting.yaml @@ -0,0 +1,3 @@ +_name_: forecasting +loss: mse +metrics: null \ No newline at end of file diff --git a/configs/task/lm.yaml b/configs/task/lm.yaml index b4927e8..9fc94c3 100644 --- a/configs/task/lm.yaml +++ b/configs/task/lm.yaml @@ -1,8 +1,5 @@ -_target_: tasks.tasks.GeneralTask -encoder: - _name_: linear -decoder: - _name_: feature -loss: cross_entropy -metrics: - - accuracy +_name_: lm +tied: false +rescale: true +# loss: cross_entropy # Handled by task: cross entropy loss +metrics: bpb # Bits per byte diff --git a/configs/task/multiclass_classification.yaml b/configs/task/multiclass_classification.yaml index 8956c3a..2bcf962 100644 --- a/configs/task/multiclass_classification.yaml +++ b/configs/task/multiclass_classification.yaml @@ -3,4 +3,4 @@ _name_: base loss: cross_entropy metrics: - accuracy -torchmetrics: null \ No newline at end of file +torchmetrics: null diff --git a/configs/task/multilabel_classification.yaml b/configs/task/multilabel_classification.yaml new file mode 100644 index 0000000..2e779e0 --- /dev/null +++ b/configs/task/multilabel_classification.yaml @@ -0,0 +1,9 @@ +# _target_: +_name_: base +loss: binary_cross_entropy +metrics: null +torchmetrics: + - AUROC + - Precision + - Recall + - F1 diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml index 4a41ec5..da0b75c 100644 --- a/configs/trainer/default.yaml +++ b/configs/trainer/default.yaml @@ -1,14 +1,11 @@ -# _target_: pytorch_lightning.Trainer - -gpus: 1 # set `1` to train on GPU, `0` to train on CPU only -accumulate_grad_batches: 1 +gpus: 1 # set `1` to train on GPU, `0` to train on CPU only +accumulate_grad_batches: 1 # Gradient accumulation every n batches max_epochs: 200 -# accelerator: ddp # controlled by train.ddp instead + # accelerator: ddp # Automatically set if gpus > 1 gradient_clip_val: 0.0 log_every_n_steps: 10 -limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run -limit_val_batches: 1.0 # train on full dataset, can be used to toggle quick run -weights_summary: top # Set to 'full' to see every layer +limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run +limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run +weights_summary: top # Set to 'full' to see every layer progress_bar_refresh_rate: 1 -track_grad_norm: -1 # Set to 2 to track norms of gradients -resume_from_checkpoint: null +track_grad_norm: -1 # Set to 2 to track norms of gradients diff --git a/configs/trainer/full.yaml b/configs/trainer/full.yaml index af79276..eff63f0 100644 --- a/configs/trainer/full.yaml +++ b/configs/trainer/full.yaml @@ -4,6 +4,7 @@ checkpoint_callback: True default_root_dir: null gradient_clip_val: 0.0 +gradient_clip_algorithm: "norm" # norm, value process_position: 0 num_nodes: 1 num_processes: 1 @@ -34,7 +35,6 @@ weights_summary: "top" weights_save_path: null num_sanity_val_steps: 2 truncated_bptt_steps: null -resume_from_checkpoint: null profiler: null benchmark: False deterministic: False diff --git a/configs/trainer/lm.yaml b/configs/trainer/lm.yaml index 9985b57..beb29e0 100644 --- a/configs/trainer/lm.yaml +++ b/configs/trainer/lm.yaml @@ -1,5 +1,3 @@ -# _target_: pytorch_lightning.Trainer - accumulate_grad_batches: 1 # accelerator: null # set to 'ddp' for distributed # amp_backend: native # 'native' | 'apex' @@ -11,7 +9,6 @@ precision: 16 progress_bar_refresh_rate: 1 weights_summary: top # Set to 'full' to see every layer track_grad_norm: -1 # Set to 2 to track norms of gradients -# limit_train_batches: 0.9999 # For some reason I get hanging issues on DDP, but removing one batch fixes it... limit_train_batches: 1.0 limit_val_batches: 1.0 # We use the dataloader from Transformer-XL to ensure adjacent minibatches @@ -19,4 +16,3 @@ limit_val_batches: 1.0 # So that dataloader has to deal with DDP, and we don't want PL to handle # that. replace_sampler_ddp: False -resume_from_checkpoint: null diff --git a/example.py b/example.py index c3dd8dd..c820c50 100644 --- a/example.py +++ b/example.py @@ -2,8 +2,8 @@ Train an S4 model on sequential CIFAR10 / sequential MNIST with PyTorch for demonstration purposes. This code borrows heavily from https://github.com/kuangliu/pytorch-cifar. -This file only depends on the standalone S4 layer -available in s4.py at src/models/sequence/ss/standalone. +This file only depends on the standalone S4 layer +available in src/models/s4/ * Train standard sequential CIFAR: python -m example @@ -16,7 +16,7 @@ This backbone is a good starting point for many problems, although some tasks (especially generation) may require using other backbones. -The default CIFAR10 model trained by this file should get +The default CIFAR10 model trained by this file should get 89+% accuracy on the CIFAR10 test set in 80 epochs. Each epoch takes approximately 7m20s on a T4 GPU (will be much faster on V100 / A100). @@ -32,26 +32,38 @@ import os import argparse -from src.models.sequence.ss.standalone.s4 import S4 +from src.models.s4.s4 import S4 +from src.models.s4.s4d import S4D from tqdm.auto import tqdm +# Dropout broke in PyTorch 1.11 +if tuple(map(int, torch.__version__.split('.')[:2])) == (1, 11): + print("WARNING: Dropout is bugged in PyTorch 1.11. Results may be worse.") + dropout_fn = nn.Dropout +if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 12): + dropout_fn = nn.Dropout1d +else: + dropout_fn = nn.Dropout2d + + parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') # Optimizer parser.add_argument('--lr', default=0.01, type=float, help='Learning rate') parser.add_argument('--weight_decay', default=0.01, type=float, help='Weight decay') # Scheduler -parser.add_argument('--patience', default=10, type=float, help='Patience for learning rate scheduler') +# parser.add_argument('--patience', default=10, type=float, help='Patience for learning rate scheduler') +parser.add_argument('--epochs', default=100, type=float, help='Training epochs') # Dataset parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'cifar10'], type=str, help='Dataset') parser.add_argument('--grayscale', action='store_true', help='Use grayscale CIFAR10') # Dataloader parser.add_argument('--num_workers', default=4, type=int, help='Number of workers to use for dataloader') -parser.add_argument('--batch_size', default=128, type=int, help='Batch size') +parser.add_argument('--batch_size', default=64, type=int, help='Batch size') # Model parser.add_argument('--n_layers', default=4, type=int, help='Number of layers') -parser.add_argument('--d_model', default=512, type=int, help='Model dimension') -parser.add_argument('--dropout', default=0.2, type=float, help='Dropout') +parser.add_argument('--d_model', default=128, type=int, help='Model dimension') +parser.add_argument('--dropout', default=0.1, type=float, help='Dropout') parser.add_argument('--prenorm', action='store_true', help='Prenorm') # General parser.add_argument('--resume', '-r', action='store_true', help='Resume from checkpoint') @@ -75,7 +87,7 @@ def split_train_val(train, val_split): return train, val if args.dataset == 'cifar10': - + if args.grayscale: transform = transforms.Compose([ transforms.Grayscale(), @@ -94,15 +106,15 @@ def split_train_val(train, val_split): transform_train = transform_test = transform trainset = torchvision.datasets.CIFAR10( - root='./data', train=True, download=True, transform=transform_train) + root='./data/cifar/', train=True, download=True, transform=transform_train) trainset, _ = split_train_val(trainset, val_split=0.1) - + valset = torchvision.datasets.CIFAR10( - root='./data', train=True, download=True, transform=transform_test) + root='./data/cifar/', train=True, download=True, transform=transform_test) _, valset = split_train_val(valset, val_split=0.1) testset = torchvision.datasets.CIFAR10( - root='./data', train=False, download=True, transform=transform_test) + root='./data/cifar/', train=False, download=True, transform=transform_test) d_input = 3 if not args.grayscale else 1 d_output = 10 @@ -128,6 +140,7 @@ def split_train_val(train, val_split): d_input = 1 d_output = 10 +else: raise NotImplementedError # Dataloaders trainloader = torch.utils.data.DataLoader( @@ -140,11 +153,11 @@ def split_train_val(train, val_split): class S4Model(nn.Module): def __init__( - self, - d_input, - d_output=10, - d_model=256, - n_layers=4, + self, + d_input, + d_output=10, + d_model=256, + n_layers=4, dropout=0.2, prenorm=False, ): @@ -161,17 +174,10 @@ def __init__( self.dropouts = nn.ModuleList() for _ in range(n_layers): self.s4_layers.append( - S4( - d_model=d_model, - l_max=1024, - bidirectional=True, - postact='glu', - dropout=dropout, - transposed=True, - ) + S4D(d_model, dropout=dropout, transposed=True, lr=min(0.001, args.lr)) ) self.norms.append(nn.LayerNorm(d_model)) - self.dropouts.append(nn.Dropout2d(dropout)) + self.dropouts.append(dropout_fn(dropout)) # Linear decoder self.decoder = nn.Linear(d_model, d_output) @@ -181,7 +187,7 @@ def forward(self, x): Input x is shape (B, L, d_input) """ x = self.encoder(x) # (B, L, d_input) -> (B, L, d_model) - + x = x.transpose(-1, -2) # (B, L, d_model) -> (B, d_model, L) for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts): # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L) @@ -190,7 +196,7 @@ def forward(self, x): if self.prenorm: # Prenorm z = norm(z.transpose(-1, -2)).transpose(-1, -2) - + # Apply S4 block: we ignore the state input and output z, _ = layer(z) @@ -217,17 +223,16 @@ def forward(self, x): # Model print('==> Building model..') model = S4Model( - d_input=d_input, - d_output=d_output, - d_model=args.d_model, - n_layers=args.n_layers, + d_input=d_input, + d_output=d_output, + d_model=args.d_model, + n_layers=args.n_layers, dropout=args.dropout, prenorm=args.prenorm, ) model = model.to(device) if device == 'cuda': - model = torch.nn.DataParallel(model) cudnn.benchmark = True if args.resume: @@ -239,34 +244,30 @@ def forward(self, x): best_acc = checkpoint['acc'] start_epoch = checkpoint['epoch'] -def setup_optimizer(model, lr, weight_decay, patience): +def setup_optimizer(model, lr, weight_decay, epochs): """ S4 requires a specific optimizer setup. - The S4 layer (A, B, C, dt) parameters typically - require a smaller learning rate (typically 0.001), with no weight decay. + The S4 layer (A, B, C, dt) parameters typically + require a smaller learning rate (typically 0.001), with no weight decay. - The rest of the model can be trained with a higher learning rate (e.g. 0.004, 0.01) + The rest of the model can be trained with a higher learning rate (e.g. 0.004, 0.01) and weight decay (if desired). """ # All parameters in the model all_parameters = list(model.parameters()) - + # General parameters don't contain the special _optim key params = [p for p in all_parameters if not hasattr(p, "_optim")] # Create an optimizer with the general parameters - optimizer = optim.AdamW( - params, - lr=lr, - weight_decay=weight_decay, - ) + optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay) # Add parameters with special hyperparameters hps = [getattr(p, "_optim") for p in all_parameters if hasattr(p, "_optim")] hps = [ - dict(s) for s in set(frozenset(hp.items()) for hp in hps) + dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps))) ] # Unique dicts for hp in hps: params = [p for p in all_parameters if getattr(p, "_optim", None) == hp] @@ -275,9 +276,10 @@ def setup_optimizer(model, lr, weight_decay, patience): ) # Create a lr scheduler - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=patience, factor=0.2) - - # Print optimizer info + # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=patience, factor=0.2) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs) + + # Print optimizer info keys = sorted(set([k for hp in hps for k in hp.keys()])) for i, g in enumerate(optimizer.param_groups): group_hps = {k: g.get(k, None) for k in keys} @@ -290,7 +292,7 @@ def setup_optimizer(model, lr, weight_decay, patience): criterion = nn.CrossEntropyLoss() optimizer, scheduler = setup_optimizer( - model, lr=args.lr, weight_decay=args.weight_decay, patience=args.patience + model, lr=args.lr, weight_decay=args.weight_decay, epochs=args.epochs ) ############################################################################### @@ -318,7 +320,7 @@ def train(): correct += predicted.eq(targets).sum().item() pbar.set_description( - 'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' % + 'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (batch_idx, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total) ) @@ -342,7 +344,7 @@ def eval(epoch, dataloader, checkpoint=False): correct += predicted.eq(targets).sum().item() pbar.set_description( - 'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' % + 'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total) ) @@ -362,7 +364,7 @@ def eval(epoch, dataloader, checkpoint=False): return acc -pbar = tqdm(range(start_epoch, start_epoch+200)) +pbar = tqdm(range(start_epoch, args.epochs)) for epoch in pbar: if epoch == 0: pbar.set_description('Epoch: %d' % (epoch)) @@ -371,5 +373,6 @@ def eval(epoch, dataloader, checkpoint=False): train() val_acc = eval(epoch, valloader, checkpoint=True) eval(epoch, testloader) - scheduler.step(val_acc) - + scheduler.step() + # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}") + diff --git a/experiments.md b/experiments.md new file mode 100644 index 0000000..fa3dae6 --- /dev/null +++ b/experiments.md @@ -0,0 +1,106 @@ +This README provides configs for various experiments in the S4 papers. + +As documented in the main README, adding `wandb=null` to any command line turns off logging. + +Some of these datasets may require downloading and preparing data, documented in the [src/dataloaders](./src/dataloaders/) subdirectory. + +## Long Range Arena (LRA) + +The latest LRA results are reported in the [HTTYH](https://arxiv.org/abs/2206.12037) paper, which achieves over 86% average. + +``` +python -m train experiment=lra/s4-lra-listops +python -m train experiment=lra/s4-lra-imdb +python -m train experiment=lra/s4-lra-cifar +python -m train experiment=lra/s4-lra-aan +python -m train experiment=lra/s4-lra-pathfinder +python -m train experiment=lra/s4-lra-pathx +``` + +To help reproduce results and sanity check, this table lists approximate final performance, intermediate performance, and timing information. + + +| | listops | imdb | aan | cifar | pathfinder | pathx | +| --- | --- | --- | --- | --- | --- | --- | +| **Final Accuracy** | 59.5 | 86.5 | 91.0 | 88.5 | 94.0 | 96.0 | +| **acc @ epoch** | 50 @ 10 | 80 @ 10 | 80 @ 10 | 80 @ 20 | 90 @ 20 | 92 @ 10 | +| **time / epoch (GPU)** | 15m (T4) | 17m (T4) | 23m (A100) | 2m (A100) | 7m (A100) | 56m (A100) | + +### V1 +The configs for the original version of the S4 paper (ICLR 2022) can be run with the following commands. +``` +python -m train experiment=lra/old/s4-lra-listops +python -m train experiment=lra/old/s4-lra-imdb +python -m train experiment=lra/old/s4-lra-cifar +python -m train experiment=lra/old/s4-lra-aan +python -m train experiment=lra/old/s4-lra-pathfinder +python -m train experiment=lra/old/s4-lra-pathx +``` + +NOTE: These configs are meant for the first version of the S4 model, which is saved in a tag: `git checkout v1` + +## CIFAR-10 + +``` +python -m train experiment=cifar/s4-cifar +``` + +The above command line reproduces our best sequential CIFAR model. +Note that it is possible to get fairly good results with much smaller models. +The small [ablation models](#s4d-ablations) are one example, and the +[example.py](../example.py) script is another example. + +## Speech Commands (SC) + +The latest SC config reported in the S4D paper can be run with +``` +python -m train experiment=sc/s4-sc +``` + +### SC10 +The original S4 paper used a smaller 10-way classification task used in [prior](https://arxiv.org/abs/2005.08926) [work](https://arxiv.org/abs/2102.02611). + +This version can be toggled either with `dataset=sc dataset.all_classes=false` or `dataset=sc10`. + +The original S4 config can be run using V1 of this code using +``` +python -m train experiment=old/s4-sc +``` + +## WikiText-103 + +V3 re-trained the WikiText-103 experiment with the latest model and a larger context size. +The trained checkpoint can be found at [TODO]. +``` +python -m train experiment=lm/s4-wt103 +``` + +The default settings require 8 GPUs with 40GB memory. Modifications can be made by decreasing batch size and accumulating gradients, e.g. add `loader.batch_size=4 trainer.accumulate_grad_batches=2` to the command line. + +Autoregressive generation can be performed with this checkpoint following the instructions in the main [README](README.md#generation) + +## Time Series Forecasting + +The ETTH, ETTM, Weather, and ECL experiments originally from the [Informer]() paper are supported. +Download the [data](https://drive.google.com/file/d/1XqpxE6cthIxKYviSmR703yU45vdQ1oHT/view?usp=sharing) to `./data`, and unzip `informer.zip` inside that folder. + +``` +python -m train experiment=forecasting/s4-informer-{etth,ettm,ecl,weather} +``` + +## S4D Ablations + +The [S4D](https://arxiv.org/abs/2206.11893) paper uses small models on varied tasks to perform extensive ablations. +``` +python -m train experiment=cifar/s4-cifar-ablation +python -m train experiment=bidmc/s4-bidmc-ablation +python -m train experiment=sc/s4-sc-ablation +``` + +## SaShiMi + +Configs for models and baselines from the SaShiMi paper can be found under [configs/audio/](configs/audio/) and run with +``` +python -m train experiment=audio/{sashimi,samplernn,wavenet}-{sc09,youtubemix,beethoven} +``` +More documentation can be found in the [SaShiMi README](sashimi/README.md). diff --git a/generate.py b/generate.py new file mode 100644 index 0000000..bda0b64 --- /dev/null +++ b/generate.py @@ -0,0 +1,202 @@ +import argparse +import os + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +import torchaudio + +import hydra +from omegaconf import OmegaConf +from torch.distributions import Categorical +from tqdm.auto import tqdm + +from src import utils +from src.dataloaders.audio import mu_law_decode +from src.models.baselines.samplernn import SampleRNN +from src.models.baselines.wavenet import WaveNetModel +from src.models.sequence.ss.s4 import S4 +from train import SequenceLightningModule + +@torch.inference_mode() +def generate( + model, + batch, + tau=1.0, + l_prefix=0, + T=None, + debug=False, + top_p=1.0, + benchmark=False, + return_logprobs=False, +): + x, _, *_ = batch # (B, L) + x = x.to('cuda') + T = x.shape[1] if T is None else T + + # Set up the initial state + model._reset_state(batch, device='cuda') + + # First sample + x_t = x[:, 0] + y_all = [] + logprobs = np.zeros(x.shape[0]) + entropy = np.zeros(x.shape[0]) + + if debug: + y_raw = [] + + # Generation loop + for t in tqdm(range(T)): + + # Step through the model with the current sample + y_t = model.step(x_t) + + # Handle special loss functions such as ProjectedAdaptiveSoftmax + if hasattr(model.loss, "compute_logits"): y_t = model.loss.compute_logits(y_t) + + if debug: + y_raw.append(y_t.detach().cpu()) + + # Output distribution + probs = F.softmax(y_t, dim=-1) + + # Optional: nucleus sampling + if top_p < 1.0: + sorted_probs = probs.sort(dim=-1, descending=True) + csum_probs = sorted_probs.values.cumsum(dim=-1) > top_p + csum_probs[..., 1:] = csum_probs[..., :-1].clone() + csum_probs[..., 0] = 0 + indices_to_remove = torch.zeros_like(csum_probs) + indices_to_remove[torch.arange(sorted_probs.indices.shape[0])[:, None].repeat(1, sorted_probs.indices.shape[1]).flatten(), sorted_probs.indices.flatten()] = csum_probs.flatten() + y_t = y_t + indices_to_remove.int() * (-1e20) + + # Sample from the distribution + y_t = Categorical(logits=y_t/tau).sample() + + # Feed back to the model + if t < l_prefix-1: + x_t = x[:, t+1] + else: + x_t = y_t + + # Calculate the log-likelihood + if return_logprobs: + probs = probs.squeeze(1) + if len(y_t.shape) > 1: + logprobs += torch.log(probs[torch.arange(probs.shape[0]), y_t.squeeze(1)]).cpu().numpy() + else: + logprobs += torch.log(probs[torch.arange(probs.shape[0]), y_t]).cpu().numpy() + entropy += -(probs * (probs + 1e-6).log()).sum(dim=-1).cpu().numpy() + + y_all.append(x_t.cpu()) + # y_all.append(y_t.cpu()) + + y_all = torch.stack(y_all, dim=1) + + if isinstance(model.model, WaveNetModel) and not benchmark: + y_all = y_all[model.model.receptive_field:] + + if not return_logprobs: + if debug: + y_raw = torch.stack(y_raw) + return y_all, y_raw + return y_all + else: + assert not debug + return y_all, logprobs, entropy + + +@hydra.main(config_path="configs", config_name="generate.yaml") +def main(config: OmegaConf): + ### See configs/generate.yaml for descriptions of generation flags ### + + # Load train config from existing Hydra experiment + if config.experiment_path is not None: + experiment_config = OmegaConf.load(os.path.join(config.experiment_path, '.hydra', 'config.yaml')) + # config = OmegaConf.merge(config, experiment_config) + config.model = experiment_config.model + config.task = experiment_config.task + config.encoder = experiment_config.encoder + config.decoder = experiment_config.decoder + config.dataset = experiment_config.dataset + config.loader = experiment_config.loader + + # Special override flags + if not config.load_data: + OmegaConf.update(config, "train.disable_dataset", True) + + OmegaConf.update(config, "loader.batch_size", config.n_samples) + + # Create the Lightning Module - same as train.py + + config = utils.train.process_config(config) + utils.train.print_config(config, resolve=True) + + print("Loading model...") + assert torch.cuda.is_available(), 'Use a GPU for generation.' + + if config.train.seed is not None: + pl.seed_everything(config.train.seed, workers=True) + + if not config.experiment_path: + ckpt_path = config.checkpoint_path + else: + ckpt_path = os.path.join(config.experiment_path, config.checkpoint_path) + model = SequenceLightningModule.load_from_checkpoint(ckpt_path, config=config) + model.to('cuda') + + # Setup: required for S4 modules in SaShiMi + for module in model.modules(): + if hasattr(module, 'setup_step'): module.setup_step() + model.eval() + + if config.load_data: + # Get the eval dataloaders + eval_dataloaders = model.val_dataloader() + dl = eval_dataloaders[0] if config.split == 'val' else eval_dataloaders[1] + + # Construct a batch + x, _, *_ = next(iter(dl)) + batch = (x.repeat(config.n_reps, 1), None, None) + else: + assert config.l_prefix == 0, 'Only unconditional generation when data is not loaded.' + batch = (torch.zeros(config.n_samples * config.n_reps, 1).to(torch.long) + 128, None, None) + + # Handle save directory intelligently + if config.save_dir: + save_dir = hydra.utils.to_absolute_path(config.save_dir) + else: + save_dir = os.path.join(os.getcwd(), "samples/") + os.makedirs(save_dir, exist_ok=True) + + # Generate + y, logprobs, _ = generate( + model, # lightning module (SequenceLightningModule from `train.py`) + batch, # pass data to condition the generation + l_prefix=config.l_prefix, # length of conditioning prefix + T=config.l_sample, # length of generated sequence + top_p=config.top_p, # nucleus sampling: always set to 1.0 for SaShiMi experiments + tau=config.temp, # temperature: always set to 1.0 for SaShiMi experiments + return_logprobs=True, # calc exact likelihoods + ) + # Sort based on likelihoods and save + y = y[np.argsort(logprobs.flatten())] + + # Decode quantization + if config.decode == 'audio': + print("Saving samples into:", save_dir) + y = mu_law_decode(y) + for i, d in enumerate(y): + filename = f'{save_dir}/unconditional_{config.dataset._name_}_{config.model._name_}_len_{config.l_sample/16000.:.2f}s_gen_{i+1}.wav' + torchaudio.save(filename, d.unsqueeze(0), 16000) + np.save(f'{save_dir}/unconditional_{config.dataset._name_}_{config.model._name_}_len_{config.l_sample/16000.:.2f}s_logprobs.npy', logprobs) + elif config.decode == 'text': + y = [model.dataset.vocab.get_symbols(_y) for _y in y] + breakpoint() + else: pass + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index c2c38a1..3da2423 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,20 +2,26 @@ numpy scipy pandas sklearn -sktime matplotlib tqdm rich -pytorch-lightning +pytorch-lightning==1.5.10 hydra-core -munch omegaconf wandb einops opt_einsum -cmake -torchtext -datasets -transformers -pytorch-fast-transformers -# pykeops==1.5 # Install to enable fast Cauchy kernel through pykeops +cmake # For pykeops support +pykeops # If there are installation problems with pykeops==2.x, try pykeops==1.5 +transformers # For some schedulers + +# Model specific packges +# pytorch-fast-transformers # for Performer + +# Dataset specific packages +torchtext # LRA +datasets # LRA +# gluonts +# timm +# lightning-bolts +# sktime # BIDMC diff --git a/sashimi/README.md b/sashimi/README.md index 3930d1f..2cc6e8c 100644 --- a/sashimi/README.md +++ b/sashimi/README.md @@ -7,6 +7,7 @@ ## Table of Contents - [Standalone Implementation](#standalone-implementation) +- [SaShiMi+DiffWave](#diffwave) - [Datasets](#datasets) - [Model Training](#model-training) - [Audio Generation](#audio-generation) @@ -33,17 +34,18 @@ Samples of SaShiMi and baseline audio can be found [online](https://hazyresearch.stanford.edu/sashimi-examples). ## Standalone Implementation -We provide a standalone PyTorch implementation of the SaShiMi architecture backbone in `state-spaces/sashimi/sashimi.py`, which you can use in your own code. Note that you'll need to also copy over the standalone S4 layer implementation, which can be found at `state-spaces/src/models/sequence/ss/standalone/s4.py`. +We provide a standalone PyTorch implementation of the SaShiMi architecture backbone in [sashimi/sashimi.py](sashimi.py), which you can use in your own code. Note that you'll need to also copy over the standalone S4 layer implementation, which can be found at [src/models/s4/s4.py](../src/models/s4/). +Note that our experiments do not use this standalone and instead use the modular model construction detailed in [src/models/README.md](../src/models/), so this standalone is less tested; if running experiments from this codebase, it is recommended to use the normal model ([Model Training](#model-training)). You can treat the SaShiMi module as a sequence-to-sequence map taking `(batch, seq, dim)` inputs to `(batch, seq, dim)` outputs i.e. ```python sashimi = Sashimi().cuda() # refer to docstring for arguments x = torch.randn(batch_size, seq_len, dim).cuda() # Run forward -y, _ = sashimi(x) # y.shape == x.shape +y = sashimi(x) # y.shape == x.shape ``` -If you use SaShiMi for autoregressive generation, you can convert it to a recurrent model at inference time and then step it to generate samples one at a time. +If you use SaShiMi for autoregressive generation, you can convert it to a recurrent model at inference time and then step it to generate samples one at a time. See also the main [README](README.md#generation) for the main generation script. ```python with torch.no_grad(): sashimi.eval() @@ -60,14 +62,9 @@ with torch.no_grad(): ys = torch.stack(ys, dim=1) # ys.shape == x.shape ``` -We also include a modified DiffWave SaShiMi backbone for diffusion models. This can be enabled by simply passing in the `diffwave=True` argument to the `Sashimi` constructor (refer to the model docstring for more details). -```python -sashimi = Sashimi(diffwave=True, unet=True).cuda() # we recommend turning on the unet option -x = torch.randn(batch_size, seq_len, dim).cuda() -diffusion_steps = torch.randint(0, max_steps, (batch_size, 1)).cuda() -# Run forward -y, _ = sashimi((x, diffusion_steps)) # y.shape == x.shape -``` +## DiffWave + +The DiffWave and DiffWave+SaShiMi experiments used an alternative pipeline to handle diffusion logic, and is supported in another codebase located here: https://github.com/albertfgu/diffwave-sashimi ## Datasets You can download the Beethoven, YouTubeMix and SC09 datasets from the following links on the Huggingface Hub. Details about the datasets can be found in the README files on the respective dataset pages. @@ -84,23 +81,23 @@ Details about the training-validation-test splits used are also included in the SaShiMi models rely on the same training framework as S4 (see the [README](../README.md) for details). To reproduce our results or train new SaShiMi models, you can use the following commands: ```bash # Train SaShiMi models on YouTubeMix, Beethoven and SC09 -python -m train experiment=sashimi-youtubemix wandb=null -python -m train experiment=sashimi-beethoven wandb=null -python -m train experiment=sashimi-sc09 wandb=null +python -m train experiment=audio/sashimi-youtubemix wandb=null +python -m train experiment=audio/sashimi-beethoven wandb=null +python -m train experiment=audio/sashimi-sc09 wandb=null ``` If you encounter GPU OOM errors on either Beethoven or YouTubeMix, we recommend reducing the sequence length used for training by setting `dataset.sample_len` to a lower value e.g. `dataset.sample_len=120000`. For SC09, we recommend reducing batch size if GPU memory is an issue, by setting `loader.batch_size` to a lower value. We also include implementations of SampleRNN and WaveNet models, which can be trained easily using the following commands: ```bash # Train SampleRNN models on YouTubeMix, Beethoven and SC09 -python -m train experiment=samplernn-youtubemix wandb=null -python -m train experiment=samplernn-beethoven wandb=null -python -m train experiment=samplernn-sc09 wandb=null +python -m train experiment=audio/samplernn-youtubemix wandb=null +python -m train experiment=audio/samplernn-beethoven wandb=null +python -m train experiment=audio/samplernn-sc09 wandb=null # Train WaveNet models on YouTubeMix, Beethoven and SC09 -python -m train experiment=wavenet-youtubemix wandb=null -python -m train experiment=wavenet-beethoven wandb=null -python -m train experiment=wavenet-sc09 wandb=null +python -m train experiment=audio/wavenet-youtubemix wandb=null +python -m train experiment=audio/wavenet-beethoven wandb=null +python -m train experiment=audio/wavenet-sc09 wandb=null ``` Audio generation models are generally slow to train, e.g. YouTubeMix SaShiMi models take up to a week to train on a single V100 GPU. @@ -108,20 +105,23 @@ Audio generation models are generally slow to train, e.g. YouTubeMix SaShiMi mod ## Audio Generation +To generate audio, use the `state-spaces/generation.py` script. +More instructions can be found in the main [README](../README.md#generation). + ### Download Checkpoints We provide checkpoints for SaShiMi, SampleRNN and WaveNet on YouTubeMix and SC09 on the [Huggingface Hub](https://huggingface.co/krandiash/sashimi-release). The checkpoint files are named `checkpoints/_.pt` and are provided for use with our generation script at `state-spaces/sashimi/generation.py`. ### Unconditional Generation -To generate audio, you can use the `state-spaces/sashimi/generation.py` script. First, put the checkpoints you downloaded at `state-spaces/sashimi/checkpoints/`. +First, put the checkpoints you downloaded at `state-spaces/checkpoints/`. -Then, run the following command to generate audio (see the `--help` flag for more details): +Then, run the following command to generate audio ```bash -python -m sashimi.generation --model --dataset --sample_len +python -m generate experiment=- l_sample= load_data=false ``` For example, to generate 32 unconditional samples of 1 second 16kHz audio from the SaShiMi model on YouTubeMix, run the following command: ```bash -python -m sashimi.generation --model sashimi --dataset youtubemix --n_samples 32 --sample_len 16000 +python -m generate experiment=sashimi-youtubemix n_samples=32 l_sample=16000 load_data=false ``` The generated `.wav` files will be saved to `sashimi/samples/`. You can generate audio for all models and datasets in a similar way. @@ -132,7 +132,7 @@ The generated `.wav` files will be saved to `sashimi/samples/`. You can generate ### Conditional Generation You can also generate conditional samples, e.g. to generate 32 samples conditioned on 0.5 seconds of audio from the SaShiMi model on YouTubeMix, run the following command: ```bash -python -m sashimi.generation --model sashimi --dataset youtubemix --n_samples 8 --n_reps 4 --sample_len 16000 --prefix 8000 --load_data +python -m generate experiment=sashimi-youtubemix n_samples=8 n_reps=4 l_sample=16000 l_prefix=8000 ``` The `prefix` flag specifies the number of steps to condition on. The script selects the first `n_samples` examples of the specified `split` (defaults to `val`) of the dataset. `n_reps` specifies how many generated samples will condition on a prefix from a single example (i.e. the total number of generated samples is `n_samples x n_reps`). @@ -462,4 +462,4 @@ To post HITs, the steps are: To access results, you can go to the "Manage" page and you should see "Batches in progress" populated with the HITs you posted. You should be able to download a CSV file of the results for each batch. You can plug this into the Jupyter Notebooks we include under `state-spaces/sashimi/mturk/mos/` to get final results. -> _Note:_ for SC09 HITs, we strongly recommend posting HITs for all models at the same time to reduce the possibility of scoring discrepancies due to different populations of workers. \ No newline at end of file +> _Note:_ for SC09 HITs, we strongly recommend posting HITs for all models at the same time to reduce the possibility of scoring discrepancies due to different populations of workers. diff --git a/sashimi/sashimi.py b/sashimi/sashimi.py index 0d69360..f7b67ed 100644 --- a/sashimi/sashimi.py +++ b/sashimi/sashimi.py @@ -2,89 +2,21 @@ SaShiMi backbone. Use this backbone in your own models. You'll also need to copy over the -standalone S4 layer, which can be found at -`state-spaces/src/models/sequence/ss/standalone/s4.py`. +standalone S4 layer, which can be found at `state-spaces/src/models/s4/` It's Raw! Audio Generation with State-Space Models -Karan Goel, Albert Gu, Chris Donahue, Christopher Re. +Karan Goel, Albert Gu, Chris Donahue, Christopher Re. """ import sys -import warnings sys.path.append('../') -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from src.models.sequence.ss.standalone.s4 import LinearActivation, S4 - - -def swish(x): - return x * torch.sigmoid(x) - - -def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in): - """ - Embed a diffusion step $t$ into a higher dimensional space - E.g. the embedding vector in the 128-dimensional space is - [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))] - - Taken from https://github.com/philsyn/DiffWave-Vocoder - - Parameters: - diffusion_steps (torch.long tensor, shape=(batchsize, 1)): - diffusion steps for batch data - diffusion_step_embed_dim_in (int, default=128): - dimensionality of the embedding space for discrete diffusion steps - - Returns: - the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)): - """ - - assert diffusion_step_embed_dim_in % 2 == 0 - - half_dim = diffusion_step_embed_dim_in // 2 - _embed = np.log(10000) / (half_dim - 1) - _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda() - _embed = diffusion_steps * _embed - diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1) - - return diffusion_step_embed - -class Conv(nn.Module): - """ - Dilated conv layer with kaiming_normal initialization - from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py - """ - def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1): - super(Conv, self).__init__() - self.padding = dilation * (kernel_size - 1) // 2 - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding) - self.conv = nn.utils.weight_norm(self.conv) - nn.init.kaiming_normal_(self.conv.weight) - - def forward(self, x): - out = self.conv(x) - return out - - -class ZeroConv1d(nn.Module): - """ - Conv1x1 layer with zero initialization - From https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed - """ - def __init__(self, in_channel, out_channel): - super(ZeroConv1d, self).__init__() - self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0) - self.conv.weight.data.zero_() - self.conv.bias.data.zero_() - - def forward(self, x): - out = self.conv(x) - return out +from src.models.s4.s4 import LinearActivation, S4 class DownPool(nn.Module): def __init__(self, d_input, expand, pool): @@ -96,10 +28,9 @@ def __init__(self, d_input, expand, pool): d_input * pool, self.d_output, transposed=True, - weight_norm=True, ) - def forward(self, x, **kwargs): + def forward(self, x): x = rearrange(x, '... h (l s) -> ... (h s) l', s=self.pool) x = self.linear(x) return x, None @@ -125,27 +56,25 @@ def default_state(self, *args, **kwargs): class UpPool(nn.Module): - def __init__(self, d_input, expand, pool, causal=True): + def __init__(self, d_input, expand, pool): super().__init__() self.d_output = d_input // expand self.pool = pool - self.causal = causal self.linear = LinearActivation( d_input, self.d_output * pool, transposed=True, - weight_norm=True, ) - def forward(self, x, **kwargs): + def forward(self, x, skip=None): x = self.linear(x) - - if self.causal: - # Shift to ensure causality - x = F.pad(x[..., :-1], (1, 0)) + x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality x = rearrange(x, '... (h s) l -> ... h (l s)', s=self.pool) + + if skip is not None: + x = x + skip return x, None def step(self, x, state, **kwargs): @@ -184,7 +113,7 @@ def __init__(self, d_model, expand=2, dropout=0.0): super().__init__() input_linear = LinearActivation( - d_model, + d_model, d_model * expand, transposed=True, activation='gelu', @@ -193,7 +122,7 @@ def __init__(self, d_model, expand=2, dropout=0.0): dropout = nn.Dropout2d(dropout) if dropout > 0.0 else nn.Identity() output_linear = LinearActivation( d_model * expand, - d_model, + d_model, transposed=True, activation=None, activate=False, @@ -205,7 +134,7 @@ def __init__(self, d_model, expand=2, dropout=0.0): output_linear, ) - def forward(self, x, **kwargs): + def forward(self, x): return self.ff(x), None def default_state(self, *args, **kwargs): @@ -219,8 +148,8 @@ def step(self, x, state, **kwargs): class ResidualBlock(nn.Module): def __init__( - self, - d_model, + self, + d_model, layer, dropout=0.0, ): @@ -239,15 +168,15 @@ def __init__( self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout2d(dropout) if dropout > 0.0 else nn.Identity() - def forward(self, x, **kwargs): + def forward(self, x): """ Input x is shape (B, d_input, L) """ z = x - + # Prenorm z = self.norm(z.transpose(-1, -2)).transpose(-1, -2) - + # Apply layer: we ignore the state input and output for training z, _ = self.layer(z) @@ -277,159 +206,54 @@ def step(self, x, state, **kwargs): return x, state -class DiffWaveS4Block(nn.Module): - """ - Modified DiffWave block that uses S4. - - Taken from https://github.com/philsyn/DiffWave-Vocoder - """ - def __init__(self, - d_model, - diffusion_step_embed_dim_out=512, - unconditional=False, - mel_upsample=[16, 16], - ): - super().__init__() - self.d_model = d_model - - # the layer-specific fc for diffusion step embedding - self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.d_model) - - self.layer = S4( - d_model, - bidirectional=True, - hurwitz=True, # use the Hurwitz parameterization for stability - tie_state=True, # tie SSM parameters across d_state in the S4 layer - trainable={ - 'dt': True, - 'A': True, - 'P': True, - 'B': True, - }, # train all internal S4 parameters - ) - self.norm = nn.LayerNorm(d_model) - - self.unconditional = unconditional - if not self.unconditional: - # add mel spectrogram upsampler and conditioner conv1x1 layer - self.upsample_conv2d = torch.nn.ModuleList() - for s in mel_upsample: - conv_trans2d = torch.nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s)) - conv_trans2d = torch.nn.utils.weight_norm(conv_trans2d) - torch.nn.init.kaiming_normal_(conv_trans2d.weight) - self.upsample_conv2d.append(conv_trans2d) - self.mel_conv = Conv(80, self.d_model, kernel_size=1) # 80 is mel bands - - def forward(self, x, diffusion_step_embed, mel_spec=None): - y = x - B, C, L = x.shape - assert C == self.d_model - - y = self.norm(y.transpose(-1, -2)).transpose(-1, -2) - - # add in diffusion step embedding - part_t = self.fc_t(diffusion_step_embed) - y = y + part_t.unsqueeze(-1) - - # S4 layer - y, _ = self.layer(y) - - # add mel spectrogram as (local) conditioner - if mel_spec is not None: - assert not self.unconditional - # Upsample spectrogram to size of audio - mel_spec = torch.unsqueeze(mel_spec, dim=1) - mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4) - mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4) - mel_spec = torch.squeeze(mel_spec, dim=1) - - assert(mel_spec.size(2) >= L) - if mel_spec.size(2) > L: - mel_spec = mel_spec[:, :, :L] - - mel_spec = self.mel_conv(mel_spec) - y = y + mel_spec - - # Residual - y = x + y - - return y, None - class Sashimi(nn.Module): def __init__( self, - d_model=64, - n_layers=8, - pool=[4, 4], - expand=2, - ff=2, + d_model=64, + n_layers=8, + pool=[4, 4], + expand=2, + ff=2, bidirectional=False, glu=True, unet=False, - diffwave=False, dropout=0.0, - **kwargs, ): """ - SaShiMi model backbone. + SaShiMi model backbone. Args: d_model: dimension of the model. We generally use 64 for all our experiments. - n_layers: number of (Residual (S4) --> Residual (FF)) blocks at each pooling level. - We use 8 layers for our experiments, although we found that increasing layers even further generally + n_layers: number of (Residual (S4) --> Residual (FF)) blocks at each pooling level. + We use 8 layers for our experiments, although we found that increasing layers even further generally improves performance at the expense of training / inference speed. - pool: pooling factor at each level. Pooling shrinks the sequence length at lower levels. + pool: pooling factor at each level. Pooling shrinks the sequence length at lower levels. We experimented with a pooling factor of 4 with 1 to 4 tiers of pooling and found 2 tiers to be best. It's possible that a different combination of pooling factors and number of tiers may perform better. expand: expansion factor when pooling. Features are expanded (i.e. the model becomes wider) at lower levels of the architecture. We generally found 2 to perform best (among 2, 4). ff: expansion factor for the FF inverted bottleneck. We generally found 2 to perform best (among 2, 4). - bidirectional: use bidirectional S4 layers. Bidirectional layers are suitable for use with non-causal models + bidirectional: use bidirectional S4 layers. Bidirectional layers are suitable for use with non-causal models such as diffusion models like DiffWave. glu: use gated linear unit in the S4 layers. Adds parameters and generally improves performance. - unet: use a unet-like architecture, adding (Residual (S4) --> Residual (FF)) layers before downpooling. + unet: use a unet-like architecture, adding (Residual (S4) --> Residual (FF)) layers before downpooling. All else fixed, this slows down inference (and slightly slows training), but generally improves performance. We use this variant when dropping in SaShiMi into diffusion models, and this should generally be preferred for non-autoregressive models. - diffwave: switch to DiffWave model with SaShiMi backbone. We use this variant for our diffusion - models. Note that S4 is bidirectional by default in this variant, and we recommend switching - on the `unet` argument as well. Additional kwargs for - - `diffusion_step_embed_dim_in` (default 128) - - `diffusion_step_embed_dim_mid` (default 512) - - `diffusion_step_embed_dim_out` (default 512) - - `unconditional` (default False) - - `mel_upsample` (default [16, 16]) - can be passed in to control the SaShiMi diffusion model. dropout: dropout rate. Default to 0.0, since we haven't found settings where SaShiMi overfits. """ super().__init__() self.d_model = H = d_model self.unet = unet - self.diffwave = diffwave - - # Bidirectional S4 layers are always used in DiffWave - bidirectional = bidirectional or diffwave - - if self.diffwave and not self.unet: - warnings.warn("DiffWave is not recommended without UNet. Consider using UNet instead.") def s4_block(dim): layer = S4( - d_model=dim, + d_model=dim, d_state=64, bidirectional=bidirectional, postact='glu' if glu else None, dropout=dropout, transposed=True, - # hurwitz=True, # use the Hurwitz parameterization for stability - # tie_state=True, # tie SSM parameters across d_state in the S4 layer - trainable={ - 'dt': True, - 'A': True, - 'P': True, - 'B': True, - }, # train all internal S4 parameters - ) return ResidualBlock( d_model=dim, @@ -449,38 +273,6 @@ def ff_block(dim): dropout=dropout, ) - if diffwave: - # Setup for DiffWave SaShiMi model - # Borrows code from https://github.com/philsyn/DiffWave-Vocoder - - self.diffusion_step_embed_dim_in = kwargs.get('diffusion_step_embed_dim_in', 128) - self.diffusion_step_embed_dim_mid = kwargs.get('diffusion_step_embed_dim_mid', 512) - self.diffusion_step_embed_dim_out = kwargs.get('diffusion_step_embed_dim_out', 512) - in_channels = 1 - out_channels = 1 - - # Initial conv1x1 with relu - self.init_conv = nn.Sequential(Conv(in_channels, d_model, kernel_size=1), nn.ReLU()) - - # the shared two fc layers for diffusion step embedding - self.fc_t1 = nn.Linear(self.diffusion_step_embed_dim_in, self.diffusion_step_embed_dim_mid) - self.fc_t2 = nn.Linear(self.diffusion_step_embed_dim_mid, self.diffusion_step_embed_dim_out) - - # Final conv1x1 -> relu -> zeroconv1x1 - self.final_conv = nn.Sequential( - Conv(d_model, d_model, kernel_size=1), - nn.ReLU(), - ZeroConv1d(d_model, out_channels), - ) - - def s4_block(dim): - return DiffWaveS4Block( - d_model=dim, - diffusion_step_embed_dim_out=self.diffusion_step_embed_dim_out, - unconditional=kwargs.get('unconditional', False), - mel_upsample=kwargs.get('mel_upsample', [16, 16]), - ) - # Down blocks d_layers = [] for p in pool: @@ -493,26 +285,26 @@ def s4_block(dim): # Add sequence downsampling and feature expanding d_layers.append(DownPool(H, expand, p)) H *= expand - + # Center block c_layers = [] for _ in range(n_layers): c_layers.append(s4_block(H)) if ff > 0: c_layers.append(ff_block(H)) - + # Up blocks u_layers = [] for p in pool[::-1]: block = [] H //= expand - block.append(UpPool(H * expand, expand, p, causal=not bidirectional)) + block.append(UpPool(H * expand, expand, p)) for _ in range(n_layers): block.append(s4_block(H)) if ff > 0: block.append(ff_block(H)) u_layers.append(nn.ModuleList(block)) - + self.d_layers = nn.ModuleList(d_layers) self.c_layers = nn.ModuleList(c_layers) self.u_layers = nn.ModuleList(u_layers) @@ -520,56 +312,34 @@ def s4_block(dim): assert H == d_model - def forward(self, x, state=None, mel_spec=None): + def forward(self, x, state=None): """ input: (batch, length, d_input) output: (batch, length, d_output) """ - if self.diffwave: - audio, diffusion_steps = x - x = audio - # BLD -> BDL - x = x.transpose(1, 2) - - x = self.init_conv(x) - - diffusion_step_embed = calc_diffusion_step_embedding( - diffusion_steps, - self.diffusion_step_embed_dim_in, - ) - diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) - diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) - - # Additional kwargs to pass onto the DiffWaveS4Block - layer_kwargs = dict(diffusion_step_embed=diffusion_step_embed, mel_spec=mel_spec) - else: - # BLD -> BDL - x = x.transpose(1, 2) - - # No additional kwargs to pass onto the S4 & FF blocks - layer_kwargs = dict() + x = x.transpose(1, 2) # Down blocks outputs = [] outputs.append(x) for layer in self.d_layers: - x, _ = layer(x, **layer_kwargs) + x, _ = layer(x) outputs.append(x) # Center block for layer in self.c_layers: - x, _ = layer(x, **layer_kwargs) + x, _ = layer(x) x = x + outputs.pop() # add a skip connection to the last output of the down block # Up blocks for block in self.u_layers: if self.unet: for layer in block: - x, _ = layer(x, **layer_kwargs) + x, _ = layer(x) x = x + outputs.pop() # skip connection else: for layer in block: - x, _ = layer(x, **layer_kwargs) + x, _ = layer(x) if isinstance(layer, UpPool): # Before modeling layer in the block x = x + outputs.pop() @@ -580,9 +350,6 @@ def forward(self, x, state=None, mel_spec=None): x = x.transpose(1, 2) # (batch, length, expand) x = self.norm(x) - if self.diffwave: - x = self.final_conv(x.transpose(1, 2)).transpose(1, 2) - return x, None # required to return a state def default_state(self, *args, **kwargs): @@ -654,21 +421,20 @@ def setup_rnn(self, mode='dense'): Convert the SaShiMi model to a RNN for autoregressive generation. Args: - mode: S4 recurrence mode. Using `diagonal` can speed up generation by 10-20%. - `linear` should be faster theoretically but is slow in practice since it + mode: S4 recurrence mode. Using `diagonal` can speed up generation by 10-20%. + `linear` should be faster theoretically but is slow in practice since it dispatches more operations (could benefit from fused operations). Note that `diagonal` could potentially be unstable if the diagonalization is numerically unstable (although we haven't encountered this case in practice), while `dense` should always be stable. """ assert mode in ['dense', 'diagonal', 'linear'] for module in self.modules(): - if hasattr(module, 'setup_step'): module.setup_step(mode) + if hasattr(module, 'setup_step'): module.setup_step(mode=mode) if __name__ == '__main__': from tqdm.auto import tqdm - # Example: SaShiMi for autoregressive modeling model = Sashimi(n_layers=2).cuda() # Print parameter count print(sum(p.numel() for p in model.parameters())) @@ -679,7 +445,7 @@ def setup_rnn(self, mode='dense'): # Forward in convolutional mode: used for training SaShiMi x = torch.randn(3, 10240, 64).cuda() y, _ = model(x) - + # Setup the SaShiMi RNN model.setup_rnn('diagonal') @@ -689,22 +455,8 @@ def setup_rnn(self, mode='dense'): for i in tqdm(range(10240)): y_, state = model.step(x[:, i], state) ys.append(y_.detach().cpu()) - + ys = torch.stack(ys, dim=1) + breakpoint() print(y.shape, ys.shape) - - - # Example: SaShiMi for diffusion modeling - model = Sashimi(n_layers=2, diffwave=True, unet=True).cuda() - # Print parameter count - print(sum(p.numel() for p in model.parameters())) - - model.eval() - - with torch.no_grad(): - # Forward (only) in convolutional mode - x = torch.randn(3, 10240, 1).cuda() - steps = torch.randint(0, 4, (3, 1)).cuda() - y, _ = model((x, steps)) - print(y.shape) diff --git a/src/callbacks/progressive_resizing.py b/src/callbacks/progressive_resizing.py new file mode 100644 index 0000000..f57c9ed --- /dev/null +++ b/src/callbacks/progressive_resizing.py @@ -0,0 +1,118 @@ +import numpy as np +from pytorch_lightning.callbacks import Callback + +import src.utils as utils +from src.utils import registry + + +class ProgressiveResizing(Callback): + + def __init__(self, stage_params: list): + """ + stage_params is a list of dicts + e.g. stage_params = [ + {'resolution': 4, 'epochs': 50}, # 32 x 32 + {'resolution': 2, 'epochs': 30}, # 64 x 64 + {'resolution': 1, 'epochs': 20}, # 128 x 128 + ] + """ + super().__init__() + assert len(stage_params) > 0, 'No stages specified' + assert all([{'resolution', 'epochs'} <= set(stage.keys()) for stage in stage_params]), \ + 'stage_params must contain keys: resolution and epochs' + + self.stage_params = stage_params + self.stage_epochs_cume = np.cumsum([stage['epochs'] for stage in stage_params]) + + self._current_stage = 0 + + def _verify_stages(self, trainer, model): + # Double-check that stage parameters are correct, otherwise we'll fail in the middle of training + for stage in self.stage_params: + if hasattr(stage, 'scheduler'): + # Verify that we can actually create the scheduler when we need to update it in each stage + scheduler = utils.instantiate(registry.scheduler, {**model.hparams.scheduler, **stage['scheduler']}, trainer.optimizers[0]) + del scheduler + + def on_train_start(self, trainer, model) -> None: + # Verify all the stage parameters are correct + self._verify_stages(trainer, model) + + print(f"Training starts at {trainer.current_epoch}") + if trainer.current_epoch == 0: + # Update the model to the first stage + self._update_to_current_stage(trainer, model) + else: + # Preemption or resumption of progressive resizing + # Update the stage to the current one + self._current_stage = int(np.searchsorted(self.stage_epochs_cume - 1, trainer.current_epoch)) + self._starting_stage = np.any(trainer.current_epoch == self.stage_epochs_cume) + + print("Progressive Resizing: Restarting at Stage {}".format(self._current_stage)) + if self._starting_stage: + self._update_lr_scheduler(trainer, model) + + # Set the dataloader and model + self._update_dataloaders(trainer, model) + self._update_model(trainer, model) + + return super().on_train_start(trainer, model) + + def _update_lr_scheduler(self, trainer, model): + if not hasattr(self.stage_params[self._current_stage], 'scheduler'): + # No scheduler specified, so don't update the current scheduler + return + + assert len(trainer.lr_schedulers) == 1 + # Reinitialize the scheduler + # We don't need to carry over information from the last scheduler e.g. the last_epoch property, + # because that will mess with the new scheduler when we step it + hparams = {**model.hparams.scheduler, **self.stage_params[self._current_stage]['scheduler']} + + # Note that passing in the optimizer below is okay: the scheduler will be reinitialized and doesn't seem to inherit any current lr info from the optimizer + trainer.lr_schedulers[0]['scheduler'] = utils.instantiate(registry.scheduler, hparams, trainer.optimizers[0]) + + print("\tChanged scheduler to {}".format(hparams)) + + def _update_dataloaders(self, trainer, model): + # Set the train resolution and reset the dataloader + model.hparams.loader.train_resolution = self.stage_params[self._current_stage]['resolution'] + trainer.reset_train_dataloader(model) + + print('\tChanged resolution to {}'.format(self.stage_params[self._current_stage]['resolution'])) + + def _update_model(self, trainer, model): + if not hasattr(self.stage_params[self._current_stage], 'bandlimit'): + return + + # Update the bandlimit value for the model: this is a hack to make sure the model is updated + # Iterate over all the modules + for module in model.modules(): + if hasattr(module, 'bandlimit'): + module.bandlimit = self.stage_params[self._current_stage]['bandlimit'] + + print('\tChanged bandlimit to {}'.format(self.stage_params[self._current_stage]['bandlimit'])) + + def _update_to_current_stage(self, trainer, model): + print("Progressive Resizing: Moving to Stage {}".format(self._current_stage)) + # Update the train dataloader, model and scheduler + self._update_dataloaders(trainer, model) + self._update_model(trainer, model) + self._update_lr_scheduler(trainer, model) + + + def on_train_epoch_end(self, trainer, model): + """ + Check to see if new stage is reached for the next epoch, and if so, prepare the new stage by + changing the dataloader. + + (We do next epoch so that the dataloader is prepared before the next epoch) + """ + next_epoch = trainer.current_epoch + 1 + + # Check if stage should be increased + if next_epoch >= self.stage_epochs_cume[self._current_stage] and self._current_stage < len(self.stage_params) - 1: + self._current_stage += 1 + self._update_to_current_stage(trainer, model) + + return super().on_train_epoch_end(trainer, model) diff --git a/src/dataloaders/README.md b/src/dataloaders/README.md new file mode 100644 index 0000000..513e1cb --- /dev/null +++ b/src/dataloaders/README.md @@ -0,0 +1,110 @@ +# Overview + +Basic datasets including MNIST, CIFAR, and Speech Commands will auto-download. Source code for these datamodules are in [basic.py](basic.py). + +By default, data is downloaded to `./data/` by default, where `.` is the top level directory of this repository (e.g. 'state-spaces'). + +- [Data Preparation](#data-preparation) - Instructions for downloading other datasets +- [Adding a Dataset](#adding-a-dataset-wip) - Basic instructions for adding new datasets + +## Advanced Usage + +After downloading and preparing data, the paths can be configured in several ways. + +1. Suppose that it is desired to download all data to a different folder, for example a different disk. +The data path can be configured by setting the environment variable `DATA_PATH`, which defaults to `./data`. + +2. For fine-grained control over the path of a particular dataset, set `dataset.data_dir` in the config. For example, if the LRA ListOps files are located in `/home/lra/listops-1000/` instead of the default `./data/listops/`, +pass in `+dataset.data_dir=/home/lra/listops-1000` on the command line or modify the config file directly. + +3. As a simple workaround, softlinks can be set, e.g. `ln -s /home/lra/listops-1000 ./data/listops` + + +# Data Preparation + +Datasets that must be manually downloaded include [LRA](#long-range-arena-lra), [WikiText-103](#wikitext-103), [BIDMC](#bidmc), and [other audio datasets](#other-audio) used in SaShiMi. + +By default, these should go under `$DATA_PATH/`, which defaults to `./data`. For the remainder of this README, these are used interchangeably. + +## Long Range Arena (LRA) + +LRA can be downloaded from the [GitHub page](https://github.com/google-research/long-range-arena). +These datasets should be organized as follows: +``` +$DATA_PATH/ + pathfinder/ + pathfinder32/ + pathfinder64/ + pathfinder128/ + pathfinder256/ + aan/ + listops/ +``` +The other two datasets in the suite ("Image" or grayscale sequential CIFAR-10; "Text" or char-level IMDB sentiment classification) are both auto-downloaded. + +## Speech Commands (SC) + +The full SC dataset is auto-downloaded into `./data/SpeechCommands/`. +Specific subsets such as the SC10 subset can be toggled in the config or command line. + +For the SC09 audio generation dataset, copy the digit subclasses of the `./data/SpeechCommands` folder into `data/sc09/{zero,one,two,three,four,five,six,seven,eight,nine}`. Also copy the `./data/SpeechCommands/{validation_list,test_list}.txt` files. + +## WikiText-103 + +The WikiText-103 language modeling dataset can be downloaded by the `getdata.sh` script from the [Transformer-XL codebase](https://github.com/kimiyoung/transformer-xl). +By default, the datamodule looks for it under `$DATA_PATH/wt103`. + +## BIDMC + +See [prepare/bidmc/README.md](prepare/bidmc/README.md) + +## Other Audio + +Instructions for other audio datasets used by the SaShiMi paper, including Beethoven and YoutubeMix, +can be found in the [SaShiMi README](../../sashimi/). + +# Adding a Dataset [WIP] +Datasets generally consist of two components. + +1. The first is the `torch.utils.data.Dataset` class which defines the raw data, or (data, target) pairs. + +2. The second is a [SequenceDataset](src/dataloaders/base.py) class, which defines how to set up the dataset as well as the dataloaders. This class is very similar to PyTorch Lightning's `LightningDataModule` and satisfies an interface described below. + +Datasets are sometimes defined in the [datasets/](./datasets/) subfolder, while Datamodules are all defined in the top-level files in this folder and imported by [__init__.py](./__init__.py). + +Basic examples of datamodules are provided [here](./basic.py). + +## SequenceDataset [WIP] + +TODO: +- Add documentation for adding a new dataset +- Restructure folder so that each dataset is in its own file +- Use Hydra to instantiate datamodules + + diff --git a/src/dataloaders/__init__.py b/src/dataloaders/__init__.py index 642e052..6eb3cc5 100644 --- a/src/dataloaders/__init__.py +++ b/src/dataloaders/__init__.py @@ -1,2 +1,2 @@ -from .datasets import SequenceDataset -from . import lm, et \ No newline at end of file +from . import audio, basic, et, lm, lra, synthetic, ts +from .base import SequenceDataset diff --git a/src/dataloaders/audio.py b/src/dataloaders/audio.py index a394c61..83d811f 100644 --- a/src/dataloaders/audio.py +++ b/src/dataloaders/audio.py @@ -1,11 +1,16 @@ -import torch -import torchaudio -import numpy as np +"""Audio datasets and utilities.""" import os - from os import listdir from os.path import join +import torch +import torchaudio +from torch import nn +from torch.nn import functional as F + +from src.dataloaders.base import default_data_path, SequenceDataset, deprecated + + def minmax_scale(tensor, range_min=0, range_max=1): """ Min-max scaling to [0, 1]. @@ -21,7 +26,7 @@ def quantize(samples, bits=8, epsilon=0.01): q_levels = 1 << bits samples *= q_levels - epsilon samples += epsilon / 2 - return samples.long() + return samples.long() def dequantize(samples, bits=8): """ @@ -90,7 +95,7 @@ class AbstractAudioDataset(torch.utils.data.Dataset): def __init__( self, bits=8, - sample_len=None, + sample_len=None, quantization='linear', return_type='autoregressive', drop_last=True, @@ -124,7 +129,7 @@ def __init__( def setup(self): return NotImplementedError("Must assign a list of filepaths to self.file_names.") - + def __getitem__(self, index): # Load signal if self.sample_len is not None: @@ -145,7 +150,7 @@ def __getitem__(self, index): # Transpose the signal to get (L, 1) seq = seq.transpose(0, 1) - + # Unsqueeze to (1, L, 1) seq = seq.unsqueeze(0) @@ -167,7 +172,7 @@ def __getitem__(self, index): if self.context_len is not None: y = y[self.context_len:] # Trim the signal if self.pad_len is not None: - x = torch.cat((torch.zeros(self.pad_len, dtype=torch.long) + self.zero, x)) # Pad the signal + x = torch.cat((torch.zeros(self.pad_len, dtype=self.qtype) + self.zero, x)) # Pad the signal return x, y elif self.return_type is None: return qseq @@ -184,12 +189,12 @@ def create_examples(self, sample_len: int): ] if sample_len is not None: - # Reorganize files into a flat list of (file_name, start_frame) pairs + # Reorganize files into a flat list of (file_name, start_frame) pairs # so that consecutive items are separated by sample_len self.examples = [] for file_name, metadata in zip(self.file_names, self.metadata): # Update the sample_len if resampling to target_sr is required - # This is because the resampling will change the length of the signal + # This is because the resampling will change the length of the signal # so we need to adjust the sample_len accordingly (e.g. if downsampling # the sample_len will need to be increased) sample_len_i = sample_len @@ -199,9 +204,9 @@ def create_examples(self, sample_len: int): margin = metadata.num_frames % sample_len_i for start_frame in range(0, metadata.num_frames - margin, sample_len_i): self.examples.append((file_name, start_frame, sample_len_i)) - + if margin > 0 and not self.drop_last: - # Last (leftover) example is shorter than sample_len, and equal to the margin + # Last (leftover) example is shorter than sample_len, and equal to the margin # (must be padded in collate_fn) self.examples.append((file_name, metadata.num_frames - margin, margin)) else: @@ -211,22 +216,27 @@ def create_quantizer(self, quantization: str): if quantization == 'linear': self.quantizer = linear_encode self.dequantizer = linear_decode + self.qtype = torch.long elif quantization == 'mu-law': self.quantizer = mu_law_encode self.dequantizer = mu_law_decode + self.qtype = torch.long + elif quantization is None: + self.quantizer = lambda x, bits: x + self.dequantizer = lambda x, bits: x + self.qtype = torch.float else: raise ValueError('Invalid quantization type') - class QuantizedAudioDataset(AbstractAudioDataset): """ Adapted from https://github.com/deepsound-project/samplernn-pytorch/blob/master/dataset.py """ def __init__( - self, - path, - bits=8, + self, + path, + bits=8, ratio_min=0, ratio_max=1, sample_len=None, @@ -262,6 +272,105 @@ def setup(self): int(self.ratio_min * len(file_names)) : int(self.ratio_max * len(file_names)) ] +class QuantizedAutoregressiveAudio(SequenceDataset): + _name_ = 'qautoaudio' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 1 << self.bits + + @property + def l_output(self): + return self.sample_len + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'path': None, + 'bits': 8, + 'sample_len': None, + 'train_percentage': 0.88, + 'quantization': 'linear', + 'drop_last': False, + 'context_len': None, + 'pad_len': None, + } + + def setup(self): + from src.dataloaders.audio import QuantizedAudioDataset + assert self.path is not None or self.data_dir is not None, "Pass a path to a folder of audio: either `data_dir` for full directory or `path` for relative path." + if self.data_dir is None: + self.data_dir = default_data_path / self.path + + self.dataset_train = QuantizedAudioDataset( + path=self.data_dir, + bits=self.bits, + ratio_min=0, + ratio_max=self.train_percentage, + sample_len=self.sample_len, + quantization=self.quantization, + drop_last=self.drop_last, + context_len=self.context_len, + pad_len=self.pad_len, + ) + + self.dataset_val = QuantizedAudioDataset( + path=self.data_dir, + bits=self.bits, + ratio_min=self.train_percentage, + ratio_max=self.train_percentage + (1 - self.train_percentage) / 2, + sample_len=self.sample_len, + quantization=self.quantization, + drop_last=self.drop_last, + context_len=self.context_len, + pad_len=self.pad_len, + ) + + self.dataset_test = QuantizedAudioDataset( + path=self.data_dir, + bits=self.bits, + ratio_min=self.train_percentage + (1 - self.train_percentage) / 2, + ratio_max=1, + sample_len=self.sample_len, + quantization=self.quantization, + drop_last=self.drop_last, + context_len=self.context_len, + pad_len=self.pad_len, + ) + + def collate_fn(batch): + x, y, *z = zip(*batch) + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + if self.pad_len is None: + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len) - max_length) + else: + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len + self.pad_len) - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=self.dataset_train.zero) + y = nn.utils.rnn.pad_sequence( + y, + padding_value=-100, # pad with -100 to ignore these locations in cross-entropy loss + batch_first=True, + ) + return x, y, {"lengths": lengths} + + if not self.drop_last: + self._collate_fn = collate_fn # TODO not tested + class SpeechCommands09(AbstractAudioDataset): CLASSES = [ @@ -280,9 +389,9 @@ class SpeechCommands09(AbstractAudioDataset): CLASS_TO_IDX = dict(zip(CLASSES, range(len(CLASSES)))) def __init__( - self, - path, - bits=8, + self, + path, + bits=8, split='train', sample_len=16000, quantization='linear', # [linear, mu-law] @@ -310,7 +419,7 @@ def __init__( def setup(self): with open(join(self.path, 'validation_list.txt')) as f: validation_files = set([line.rstrip() for line in f.readlines()]) - + with open(join(self.path, 'testing_list.txt')) as f: test_files = set([line.rstrip() for line in f.readlines()]) @@ -322,12 +431,12 @@ def setup(self): for file_name in listdir(join(self.path, class_name)) if file_name.endswith('.wav') ] - + # Keep files based on the split if self.split == 'train': self.file_names = [ - join(self.path, class_name, file_name) - for class_name, file_name in self.file_names + join(self.path, class_name, file_name) + for class_name, file_name in self.file_names if join(class_name, file_name) not in validation_files and join(class_name, file_name) not in test_files ] @@ -350,3 +459,599 @@ def __getitem__(self, index): if self.dequantize: x = self.dequantizer(x).unsqueeze(1) return x, y, *z + +class SpeechCommands09Autoregressive(SequenceDataset): + _name_ = 'sc09' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 1 << self.bits + + @property + def l_output(self): + return self.sample_len + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'bits': 8, + 'quantization': 'mu-law', + 'dequantize': False, + 'pad_len': None, + } + + def setup(self): + from src.dataloaders.audio import SpeechCommands09 + self.data_dir = self.data_dir or default_data_path / self._name_ + + self.dataset_train = SpeechCommands09( + path=self.data_dir, + bits=self.bits, + split='train', + quantization=self.quantization, + dequantize=self.dequantize, + pad_len=self.pad_len, + ) + + self.dataset_val = SpeechCommands09( + path=self.data_dir, + bits=self.bits, + split='validation', + quantization=self.quantization, + dequantize=self.dequantize, + pad_len=self.pad_len, + ) + + self.dataset_test = SpeechCommands09( + path=self.data_dir, + bits=self.bits, + split='test', + quantization=self.quantization, + dequantize=self.dequantize, + pad_len=self.pad_len, + ) + + self.sample_len = self.dataset_train.sample_len + + def _collate_fn(self, batch): + x, y, *z = zip(*batch) + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + if self.pad_len is None: + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len) - max_length) + else: + pad_length = 0 # int(self.sample_len + self.pad_len - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero if not self.dequantize else 0., + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=self.dataset_train.zero if not self.dequantize else 0.) + y = nn.utils.rnn.pad_sequence( + y, + padding_value=-100, # pad with -100 to ignore these locations in cross-entropy loss + batch_first=True, + ) + return x, y, {"lengths": lengths} + +class MaestroDataset(AbstractAudioDataset): + + YEARS = [2004, 2006, 2008, 2009, 2011, 2013, 2014, 2015, 2017, 2018] + SPLITS = ['train', 'validation', 'test'] + + def __init__( + self, + path, + bits=8, + split='train', + sample_len=None, + quantization='linear', + return_type='autoregressive', + drop_last=False, + target_sr=16000, + ): + super().__init__( + bits=bits, + sample_len=sample_len, + quantization=quantization, + return_type=return_type, + split=split, + path=path, + drop_last=drop_last, + target_sr=target_sr, + ) + + def setup(self): + import pandas as pd + from natsort import natsorted + + self.path = str(self.path) + + # Pull out examples in the specified split + df = pd.read_csv(self.path + '/maestro-v3.0.0.csv') + df = df[df['split'] == self.split] + + file_names = [] + for filename in df['audio_filename'].values: + filepath = os.path.join(self.path, filename) + assert os.path.exists(filepath) + file_names.append(filepath) + self.file_names = natsorted(file_names) + +class MaestroAutoregressive(SequenceDataset): + _name_ = 'maestro' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 1 << self.bits + + @property + def l_output(self): + return self.sample_len + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'bits': 8, + 'sample_len': None, + 'quantization': 'mu-law', + } + + def setup(self): + from src.dataloaders.audio import MaestroDataset + self.data_dir = self.data_dir or default_data_path / self._name_ / 'maestro-v3.0.0' + + self.dataset_train = MaestroDataset( + path=self.data_dir, + bits=self.bits, + split='train', + sample_len=self.sample_len, + quantization=self.quantization, + ) + + self.dataset_val = MaestroDataset( + path=self.data_dir, + bits=self.bits, + split='validation', + sample_len=self.sample_len, + quantization=self.quantization, + ) + + self.dataset_test = MaestroDataset( + path=self.data_dir, + bits=self.bits, + split='test', + sample_len=self.sample_len, + quantization=self.quantization, + ) + + def _collate_fn(self, batch): + x, y, *z = zip(*batch) + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + pad_length = int(min(max(1024, 2**max_length.log2().ceil()), self.sample_len) - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=self.dataset_train.zero) + y = nn.utils.rnn.pad_sequence( + y, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + return x, y, {"lengths": lengths} + +class LJSpeech(QuantizedAudioDataset): + + def __init__( + self, + path, + bits=8, + ratio_min=0, + ratio_max=1, + sample_len=None, + quantization='linear', # [linear, mu-law] + return_type='autoregressive', # [autoregressive, None] + drop_last=False, + target_sr=None, + use_text=False, + ): + super().__init__( + bits=bits, + sample_len=sample_len, + quantization=quantization, + return_type=return_type, + drop_last=drop_last, + target_sr=target_sr, + path=path, + ratio_min=ratio_min, + ratio_max=ratio_max, + use_text=use_text, + ) + + def setup(self): + import pandas as pd + from sklearn.preprocessing import LabelEncoder + super().setup() + + self.vocab_size = None + if self.use_text: + self.transcripts = {} + with open(str(self.path.parents[0] / 'metadata.csv'), 'r') as f: + for line in f: + index, raw_transcript, normalized_transcript = line.rstrip('\n').split("|") + self.transcripts[index] = normalized_transcript + # df = pd.read_csv(self.path.parents[0] / 'metadata.csv', sep="|", header=None) + # self.transcripts = dict(zip(df[0], df[2])) # use normalized transcripts + + self.tok_transcripts = {} + self.vocab = set() + for file_name in self.file_names: + # Very simple tokenization, character by character + # Capitalization is ignored for simplicity + file_name = file_name.split('/')[-1].split('.')[0] + self.tok_transcripts[file_name] = list(self.transcripts[file_name].lower()) + self.vocab.update(self.tok_transcripts[file_name]) + + # Fit a label encoder mapping characters to numbers + self.label_encoder = LabelEncoder() + self.label_encoder.fit(list(self.vocab)) + # add a token for padding, no additional token for UNK (our dev/test set contain no unseen characters) + self.vocab_size = len(self.vocab) + 1 + + # Finalize the tokenized transcripts + for file_name in self.file_names: + file_name = file_name.split('/')[-1].split('.')[0] + self.tok_transcripts[file_name] = torch.tensor(self.label_encoder.transform(self.tok_transcripts[file_name])) + + + def __getitem__(self, index): + item = super().__getitem__(index) + if self.use_text: + file_name, _, _ = self.examples[index] + tok_transcript = self.tok_transcripts[file_name.split('/')[-1].split('.')[0]] + return *item, tok_transcript + return item + +class LJSpeechAutoregressive(SequenceDataset): + _name_ = 'ljspeech' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 1 << self.bits + + @property + def l_output(self): + return self.sample_len + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'bits': 8, + 'sample_len': None, + 'quantization': 'mu-law', + 'train_percentage': 0.88, + 'use_text': False, + } + + def setup(self): + from src.dataloaders.audio import LJSpeech + self.data_dir = self.data_dir or default_data_path / self._name_ / 'LJSpeech-1.1' / 'wavs' + + self.dataset_train = LJSpeech( + path=self.data_dir, + bits=self.bits, + ratio_min=0, + ratio_max=self.train_percentage, + sample_len=self.sample_len, + quantization=self.quantization, + target_sr=16000, + use_text=self.use_text, + ) + + self.dataset_val = LJSpeech( + path=self.data_dir, + bits=self.bits, + ratio_min=self.train_percentage, + ratio_max=self.train_percentage + (1 - self.train_percentage) / 2, + sample_len=self.sample_len, + quantization=self.quantization, + target_sr=16000, + use_text=self.use_text, + ) + + self.dataset_test = LJSpeech( + path=self.data_dir, + bits=self.bits, + ratio_min=self.train_percentage + (1 - self.train_percentage) / 2, + ratio_max=1, + sample_len=self.sample_len, + quantization=self.quantization, + target_sr=16000, + use_text=self.use_text, + ) + + self.vocab_size = self.dataset_train.vocab_size + + def _collate_fn(self, batch): + x, y, *z = zip(*batch) + + if self.use_text: + tokens = z[0] + text_lengths = torch.tensor([len(e) for e in tokens]) + tokens = nn.utils.rnn.pad_sequence( + tokens, + padding_value=self.vocab_size - 1, + batch_first=True, + ) + else: + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len) - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=self.dataset_train.zero) + y = nn.utils.rnn.pad_sequence( + y, + padding_value=-100, # pad with -100 to ignore these locations in cross-entropy loss + batch_first=True, + ) + if self.use_text: + return x, y, {"lengths": lengths, "tokens": tokens, "text_lengths": text_lengths} + else: + return x, y, {"lengths": lengths} + +class _SpeechCommands09Classification(SpeechCommands09): + + def __init__( + self, + path, + bits=8, + split='train', + sample_len=16000, + quantization='linear', # [linear, mu-law] + drop_last=False, + target_sr=None, + **kwargs, + ): + super().__init__( + bits=bits, + sample_len=sample_len, + quantization=quantization, + return_type=None, + split=split, + drop_last=drop_last, + target_sr=target_sr, + path=path, + **kwargs, + ) + + def __getitem__(self, index): + x = super().__getitem__(index) + x = mu_law_decode(x) + y = torch.tensor(self.CLASS_TO_IDX[self.file_names[index].split("/")[-2]]) + return x, y + +class SpeechCommands09Classification(SequenceDataset): + _name_ = 'sc09cls' + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 10 + + @property + def l_output(self): + return 0 + + @property + def n_tokens(self): + return 1 << self.bits + + @property + def init_defaults(self): + return { + 'bits': 8, + 'quantization': 'mu-law', + } + + def setup(self): + from src.dataloaders.audio import _SpeechCommands09Classification + self.data_dir = self.data_dir or default_data_path / 'sc09' + + self.dataset_train = _SpeechCommands09Classification( + path=self.data_dir, + bits=self.bits, + split='train', + quantization=self.quantization, + ) + + self.dataset_val = _SpeechCommands09Classification( + path=self.data_dir, + bits=self.bits, + split='validation', + quantization=self.quantization, + ) + + self.dataset_test = _SpeechCommands09Classification( + path=self.data_dir, + bits=self.bits, + split='test', + quantization=self.quantization, + ) + + self.sample_len = self.dataset_train.sample_len + + def collate_fn(self, batch): + x, y, *z = zip(*batch) + assert len(z) == 0 + lengths = torch.tensor([len(e) for e in x]) + max_length = lengths.max() + pad_length = int(min(2**max_length.log2().ceil(), self.sample_len) - max_length) + x = nn.utils.rnn.pad_sequence( + x, + padding_value=self.dataset_train.zero, + batch_first=True, + ) + x = F.pad(x, (0, pad_length), value=0.)#self.dataset_train.zero) + y = torch.tensor(y) + return x, y, {"lengths": lengths} + +@deprecated +class SpeechCommandsGeneration(SequenceDataset): + _name_ = "scg" + + init_defaults = { + "mfcc": False, + "dropped_rate": 0.0, + "length": 16000, + "all_classes": False, + "discrete_input": False, + } + + @property + def n_tokens(self): + return 256 if self.discrete_input else None + + def init(self): + if self.mfcc: + self.d_input = 20 + self.L = 161 + else: + self.d_input = 1 + self.L = self.length + + if self.dropped_rate > 0.0: + self.d_input += 1 + + self.d_output = 256 + self.l_output = self.length + + def setup(self): + from src.dataloaders.datasets.sc import _SpeechCommandsGeneration + + # TODO refactor with data_dir argument + self.dataset_train = _SpeechCommandsGeneration( + partition="train", + length=self.length, # self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=default_data_path, + all_classes=self.all_classes, + discrete_input=self.discrete_input, + ) + + self.dataset_val = _SpeechCommandsGeneration( + partition="val", + length=self.length, # self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=default_data_path, + all_classes=self.all_classes, + discrete_input=self.discrete_input, + ) + + self.dataset_test = _SpeechCommandsGeneration( + partition="test", + length=self.length, # self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=default_data_path, + all_classes=self.all_classes, + discrete_input=self.discrete_input, + ) + + @classmethod + def _return_callback(cls, return_value, *args, **kwargs): + x, y, *z = return_value + return x, y.long(), *z + +@deprecated +class Music(SequenceDataset): + _name_ = "music" + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return 256 + + @property + def l_output(self): + return self.sample_rate * self.sample_len + + @property + def n_tokens(self): + return 256 if self.discrete_input else None + + @property + def init_defaults(self): + return { + "sample_len": 1, + "sample_rate": 16000, + "train_percentage": 0.88, + "discrete_input": False, + } + + def init(self): + return + + def setup(self): + from src.dataloaders.music import _Music + + self.music_class = _Music( + path=default_data_path, + sample_len=self.sample_len, # In seconds + sample_rate=self.sample_rate, + train_percentage=self.train_percentage, # Use settings from SampleRNN paper + discrete_input=self.discrete_input, + ) + + self.dataset_train = self.music_class.get_data("train") + self.dataset_test = self.music_class.get_data("test") + self.dataset_val = self.music_class.get_data("val") + + @classmethod + def _return_callback(cls, return_value, *args, **kwargs): + x, y, *z = return_value + return x, y.long(), *z diff --git a/src/dataloaders/base.py b/src/dataloaders/base.py new file mode 100644 index 0000000..ce73a64 --- /dev/null +++ b/src/dataloaders/base.py @@ -0,0 +1,348 @@ +""" Datasets for core experimental results """ + +import os +import pickle +from functools import partial +from pathlib import Path + +import numpy as np +import torch +import torchaudio.functional as TF +import torchvision +from einops import rearrange +from einops.layers.torch import Rearrange +from src.utils import is_list, permutations +from torch.nn import functional as F + +def deprecated(cls_or_func): + def _deprecated(*args, **kwargs): + print(f"{cls_or_func} is deprecated") + return cls_or_func(*args, **kwargs) + return _deprecated + +# Default data path is environment variable or hippo/data +if (default_data_path := os.getenv("DATA_PATH")) is None: + default_data_path = Path(__file__).parent.parent.parent.absolute() + default_data_path = default_data_path / "data" +else: + default_data_path = Path(default_data_path).absolute() + +class DefaultCollateMixin: + """Controls collating in the DataLoader + + The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments, constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor. + """ + + @classmethod + def _collate_callback(cls, x, *args, **kwargs): + """ + Modify the behavior of the default _collate method. + """ + return x + + _collate_arg_names = [] + + @classmethod + def _return_callback(cls, return_value, *args, **kwargs): + """ + Modify the return value of the collate_fn. + Assign a name to each element of the returned tuple beyond the (x, y) pairs + See InformerSequenceDataset for an example of this being used + """ + x, y, *z = return_value + assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset" + return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)} + + @classmethod + def _collate(cls, batch, *args, **kwargs): + # From https://github.com/pyforch/pytorch/blob/master/torch/utils/data/_utils/collate.py + elem = batch[0] + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + x = torch.stack(batch, dim=0, out=out) + + # Insert custom functionality into the collate_fn + x = cls._collate_callback(x, *args, **kwargs) + + return x + else: + return torch.tensor(batch) + + @classmethod + def _collate_fn(cls, batch, *args, **kwargs): + """ + Default collate function. + Generally accessed by the dataloader() methods to pass into torch DataLoader + + Arguments: + batch: list of (x, y) pairs + args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback + """ + x, y, *z = zip(*batch) + + x = cls._collate(x, *args, **kwargs) + y = cls._collate(y) + z = [cls._collate(z_) for z_ in z] + + return_value = (x, y, *z) + return cls._return_callback(return_value, *args, **kwargs) + + # List of loader arguments to pass into collate_fn + collate_args = [] + + def _dataloader(self, dataset, **loader_args): + collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args} + loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args} + loader_cls = loader_registry[loader_args.pop("_name_", None)] + return loader_cls( + dataset=dataset, + collate_fn=partial(self._collate_fn, **collate_args), + **loader_args, + ) + + +class SequenceResolutionCollateMixin(DefaultCollateMixin): + """self.collate_fn(resolution) produces a collate function that subsamples elements of the sequence""" + + @classmethod + def _collate_callback(cls, x, resolution=None): + if resolution is None: + pass + elif is_list(resolution): # Resize to first resolution, then apply resampling technique + # Sample to first resolution + x = x.squeeze(-1) # (B, L) + L = x.size(1) + x = x[:, ::resolution[0]] # assume length is first axis after batch + _L = L // resolution[0] + for r in resolution[1:]: + x = TF.resample(x, _L, L//r) + _L = L // r + x = x.unsqueeze(-1) # (B, L, 1) + else: + # Assume x is (B, L_0, L_1, ..., L_k, C) for x.ndim > 2 and (B, L) for x.ndim = 2 + assert x.ndim >= 2 + n_resaxes = max(1, x.ndim - 2) # [AG 22/07/02] this line looks suspicious... are there cases with 2 axes? + # rearrange: b (l_0 res_0) (l_1 res_1) ... (l_k res_k) ... -> res_0 res_1 .. res_k b l_0 l_1 ... + lhs = "b " + " ".join([f"(l{i} res{i})" for i in range(n_resaxes)]) + " ..." + rhs = " ".join([f"res{i}" for i in range(n_resaxes)]) + " b " + " ".join([f"l{i}" for i in range(n_resaxes)]) + " ..." + x = rearrange(x, lhs + " -> " + rhs, **{f'res{i}': resolution for i in range(n_resaxes)}) + x = x[tuple([0] * n_resaxes)] + + return x + + @classmethod + def _return_callback(cls, return_value, resolution=None): + return *return_value, {"rate": resolution} + + + collate_args = ['resolution'] + +class ImageResolutionCollateMixin(SequenceResolutionCollateMixin): + """self.collate_fn(resolution, img_size) produces a collate function that resizes inputs to size img_size/resolution""" + + _interpolation = torchvision.transforms.InterpolationMode.BILINEAR + _antialias = True + + @classmethod + def _collate_callback(cls, x, resolution=None, img_size=None, channels_last=True): + if x.ndim < 4: + return super()._collate_callback(x, resolution=resolution) + if img_size is None: + x = super()._collate_callback(x, resolution=resolution) + else: + x = rearrange(x, 'b ... c -> b c ...') if channels_last else x + _size = round(img_size/resolution) + x = torchvision.transforms.functional.resize( + x, + size=[_size, _size], + interpolation=cls._interpolation, + antialias=cls._antialias, + ) + x = rearrange(x, 'b c ... -> b ... c') if channels_last else x + return x + + @classmethod + def _return_callback(cls, return_value, resolution=None, img_size=None, channels_last=True): + return *return_value, {"rate": resolution} + + collate_args = ['resolution', 'img_size', 'channels_last'] + +class TBPTTDataLoader(torch.utils.data.DataLoader): + """ + Adapted from https://github.com/deepsound-project/samplernn-pytorch + """ + + def __init__( + self, + dataset, + batch_size, + chunk_len, + overlap_len, + *args, + **kwargs + ): + super().__init__(dataset, batch_size, *args, **kwargs) + assert chunk_len is not None and overlap_len is not None, "TBPTTDataLoader: chunk_len and overlap_len must be specified." + + # Zero padding value, given by the dataset + self.zero = dataset.zero if hasattr(dataset, "zero") else 0 + + # Size of the chunks to be fed into the model + self.chunk_len = chunk_len + + # Keep `overlap_len` from the previous chunk (e.g. SampleRNN requires this) + self.overlap_len = overlap_len + + def __iter__(self): + for batch in super().__iter__(): + x, y, z = batch # (B, L) (B, L, 1) {'lengths': (B,)} + + # Pad with self.overlap_len - 1 zeros + pad = lambda x, val: torch.cat([x.new_zeros((x.shape[0], self.overlap_len - 1, *x.shape[2:])) + val, x], dim=1) + x = pad(x, self.zero) + y = pad(y, 0) + z = { k: pad(v, 0) for k, v in z.items() if v.ndim > 1 } + _, seq_len, *_ = x.shape + + reset = True + + for seq_begin in list(range(self.overlap_len - 1, seq_len, self.chunk_len))[:-1]: + from_index = seq_begin - self.overlap_len + 1 + to_index = seq_begin + self.chunk_len + # TODO: check this + # Ensure divisible by overlap_len + if self.overlap_len > 0: + to_index = min(to_index, seq_len - ((seq_len - self.overlap_len + 1) % self.overlap_len)) + + x_chunk = x[:, from_index:to_index] + if len(y.shape) == 3: + y_chunk = y[:, seq_begin:to_index] + else: + y_chunk = y + z_chunk = {k: v[:, from_index:to_index] for k, v in z.items() if len(v.shape) > 1} + + yield (x_chunk, y_chunk, {**z_chunk, "reset": reset}) + + reset = False + + def __len__(self): + raise NotImplementedError() + + +# class SequenceDataset(LightningDataModule): +# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just provide our own class with the same core methods as LightningDataModule (e.g. setup) +class SequenceDataset(DefaultCollateMixin): + registry = {} + _name_ = NotImplementedError("Dataset must have shorthand name") + + # Since subclasses do not specify __init__ which is instead handled by this class + # Subclasses can provide a list of default arguments which are automatically registered as attributes + # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features of this class such as the _name_ and d_input/d_output + @property + def init_defaults(self): + return {} + + # https://www.python.org/dev/peps/pep-0487/#subclass-registration + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.registry[cls._name_] = cls + + def __init__(self, _name_, data_dir=None, **dataset_cfg): + assert _name_ == self._name_ + self.data_dir = Path(data_dir).absolute() if data_dir is not None else None + + # Add all arguments to self + init_args = self.init_defaults.copy() + init_args.update(dataset_cfg) + for k, v in init_args.items(): + setattr(self, k, v) + + # The train, val, test datasets must be set by `setup()` + self.dataset_train = self.dataset_val = self.dataset_test = None + + self.init() + + def init(self): + """Hook called at end of __init__, override this instead of __init__""" + pass + + def setup(self): + """This method should set self.dataset_train, self.dataset_val, and self.dataset_test.""" + raise NotImplementedError + + def split_train_val(self, val_split): + """ + Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair. + """ + train_len = int(len(self.dataset_train) * (1.0 - val_split)) + self.dataset_train, self.dataset_val = torch.utils.data.random_split( + self.dataset_train, + (train_len, len(self.dataset_train) - train_len), + generator=torch.Generator().manual_seed( + getattr(self, "seed", 42) + ), # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us + ) + + def train_dataloader(self, **kwargs): + return self._train_dataloader(self.dataset_train, **kwargs) + + def _train_dataloader(self, dataset, **kwargs): + if dataset is None: return + kwargs['shuffle'] = 'sampler' not in kwargs # shuffle cant be True if we have custom sampler + return self._dataloader(dataset, **kwargs) + + def val_dataloader(self, **kwargs): + return self._eval_dataloader(self.dataset_val, **kwargs) + + def test_dataloader(self, **kwargs): + return self._eval_dataloader(self.dataset_test, **kwargs) + + def _eval_dataloader(self, dataset, **kwargs): + if dataset is None: return + # Note that shuffle=False by default + return self._dataloader(dataset, **kwargs) + + def __str__(self): + return self._name_ + +class ResolutionSequenceDataset(SequenceDataset, SequenceResolutionCollateMixin): + + def _train_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): + if train_resolution is None: train_resolution = [1] + if not is_list(train_resolution): train_resolution = [train_resolution] + assert len(train_resolution) == 1, "Only one train resolution supported for now." + return super()._train_dataloader(dataset, resolution=train_resolution[0], **kwargs) + + def _eval_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): + if dataset is None: return + if eval_resolutions is None: eval_resolutions = [1] + if not is_list(eval_resolutions): eval_resolutions = [eval_resolutions] + + dataloaders = [] + for resolution in eval_resolutions: + dataloaders.append(super()._eval_dataloader(dataset, resolution=resolution, **kwargs)) + + return ( + { + None if res == 1 else str(res): dl + for res, dl in zip(eval_resolutions, dataloaders) + } + if dataloaders is not None else None + ) + +class ImageResolutionSequenceDataset(ResolutionSequenceDataset, ImageResolutionCollateMixin): + pass + + + +# Registry for dataloader class +loader_registry = { + "tbptt": TBPTTDataLoader, + None: torch.utils.data.DataLoader, # default case +} diff --git a/src/dataloaders/basic.py b/src/dataloaders/basic.py new file mode 100644 index 0000000..187a733 --- /dev/null +++ b/src/dataloaders/basic.py @@ -0,0 +1,271 @@ +"""Implementation of basic benchmark datasets used in S4 experiments: MNIST, CIFAR10 and Speech Commands.""" +import numpy as np +import torch +import torchvision +from einops.layers.torch import Rearrange +from src.utils import permutations + +from src.dataloaders.base import default_data_path, ImageResolutionSequenceDataset, ResolutionSequenceDataset, SequenceDataset + + +class MNIST(SequenceDataset): + _name_ = "mnist" + d_input = 1 + d_output = 10 + l_output = 0 + L = 784 + + @property + def init_defaults(self): + return { + "permute": True, + "val_split": 0.1, + "seed": 42, # For train/val split + } + + def setup(self): + self.data_dir = self.data_dir or default_data_path / self._name_ + + transform_list = [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Lambda(lambda x: x.view(self.d_input, self.L).t()), + ] # (L, d_input) + if self.permute: + # below is another permutation that other works have used + # permute = np.random.RandomState(92916) + # permutation = torch.LongTensor(permute.permutation(784)) + permutation = permutations.bitreversal_permutation(self.L) + transform_list.append( + torchvision.transforms.Lambda(lambda x: x[permutation]) + ) + # TODO does MNIST need normalization? + # torchvision.transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs + transform = torchvision.transforms.Compose(transform_list) + self.dataset_train = torchvision.datasets.MNIST( + self.data_dir, + train=True, + download=True, + transform=transform, + ) + self.dataset_test = torchvision.datasets.MNIST( + self.data_dir, + train=False, + transform=transform, + ) + self.split_train_val(self.val_split) + + def __str__(self): + return f"{'p' if self.permute else 's'}{self._name_}" + + +class CIFAR10(ImageResolutionSequenceDataset): + _name_ = "cifar" + d_output = 10 + l_output = 0 + + @property + def init_defaults(self): + return { + "permute": None, + "grayscale": False, + "tokenize": False, # if grayscale, tokenize into discrete byte inputs + "augment": False, + "cutout": False, + "rescale": None, + "random_erasing": False, + "val_split": 0.1, + "seed": 42, # For validation split + } + + @property + def d_input(self): + if self.grayscale: + if self.tokenize: + return 256 + else: + return 1 + else: + assert not self.tokenize + return 3 + + def setup(self): + img_size = 32 + if self.rescale: + img_size //= self.rescale + + if self.grayscale: + preprocessors = [ + torchvision.transforms.Grayscale(), + torchvision.transforms.ToTensor(), + ] + permutations_list = [ + torchvision.transforms.Lambda( + lambda x: x.view(1, img_size * img_size).t() + ) # (L, d_input) + ] + + if self.tokenize: + preprocessors.append( + torchvision.transforms.Lambda(lambda x: (x * 255).long()) + ) + permutations_list.append(Rearrange("l 1 -> l")) + else: + preprocessors.append( + torchvision.transforms.Normalize( + mean=122.6 / 255.0, std=61.0 / 255.0 + ) + ) + else: + preprocessors = [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + permutations_list = [ + torchvision.transforms.Lambda( + Rearrange("z h w -> (h w) z", z=3, h=img_size, w=img_size) + ) # (L, d_input) + ] + + # Permutations and reshaping + if self.permute == "br": + permutation = permutations.bitreversal_permutation(img_size * img_size) + print("bit reversal", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "snake": + permutation = permutations.snake_permutation(img_size, img_size) + print("snake", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "hilbert": + permutation = permutations.hilbert_permutation(img_size) + print("hilbert", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "transpose": + permutation = permutations.transpose_permutation(img_size, img_size) + transform = torchvision.transforms.Lambda( + lambda x: torch.cat([x, x[permutation]], dim=-1) + ) + permutations_list.append(transform) + elif self.permute == "2d": # h, w, c + permutation = torchvision.transforms.Lambda( + Rearrange("(h w) c -> h w c", h=img_size, w=img_size) + ) + permutations_list.append(permutation) + elif self.permute == "2d_transpose": # c, h, w + permutation = torchvision.transforms.Lambda( + Rearrange("(h w) c -> c h w", h=img_size, w=img_size) + ) + permutations_list.append(permutation) + + # Augmentation + if self.augment: + augmentations = [ + torchvision.transforms.RandomCrop( + img_size, padding=4, padding_mode="symmetric" + ), + torchvision.transforms.RandomHorizontalFlip(), + ] + + post_augmentations = [] + if self.cutout: + post_augmentations.append(Cutout(1, img_size // 2)) + pass + if self.random_erasing: + # augmentations.append(RandomErasing()) + pass + else: + augmentations, post_augmentations = [], [] + transforms_train = ( + augmentations + preprocessors + post_augmentations + permutations_list + ) + transforms_eval = preprocessors + permutations_list + + transform_train = torchvision.transforms.Compose(transforms_train) + transform_eval = torchvision.transforms.Compose(transforms_eval) + self.dataset_train = torchvision.datasets.CIFAR10( + f"{default_data_path}/{self._name_}", + train=True, + download=True, + transform=transform_train, + ) + self.dataset_test = torchvision.datasets.CIFAR10( + f"{default_data_path}/{self._name_}", train=False, transform=transform_eval + ) + + if self.rescale: + print(f"Resizing all images to {img_size} x {img_size}.") + self.dataset_train.data = self.dataset_train.data.reshape((self.dataset_train.data.shape[0], 32 // self.rescale, self.rescale, 32 // self.rescale, self.rescale, 3)).max(4).max(2).astype(np.uint8) + self.dataset_test.data = self.dataset_test.data.reshape((self.dataset_test.data.shape[0], 32 // self.rescale, self.rescale, 32 // self.rescale, self.rescale, 3)).max(4).max(2).astype(np.uint8) + + self.split_train_val(self.val_split) + + def __str__(self): + return f"{'p' if self.permute else 's'}{self._name_}" + +class SpeechCommands(ResolutionSequenceDataset): + _name_ = "sc" + + @property + def init_defaults(self): + return { + "mfcc": False, + "dropped_rate": 0.0, + "length": 16000, + "all_classes": False, + } + + @property + def d_input(self): + _d_input = 20 if self.mfcc else 1 + _d_input += 1 if self.dropped_rate > 0.0 else 0 + return _d_input + + @property + def d_output(self): + return 10 if not self.all_classes else 35 + + @property + def l_output(self): + return 0 + + @property + def L(self): + return 161 if self.mfcc else self.length + + + def setup(self): + self.data_dir = self.data_dir or default_data_path # TODO make same logic as other classes + + from src.dataloaders.datasets.sc import _SpeechCommands + + # TODO refactor with data_dir argument + self.dataset_train = _SpeechCommands( + partition="train", + length=self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) + + self.dataset_val = _SpeechCommands( + partition="val", + length=self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) + + self.dataset_test = _SpeechCommands( + partition="test", + length=self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) diff --git a/src/dataloaders/datasets/adding.py b/src/dataloaders/datasets/adding.py new file mode 100644 index 0000000..d2e59aa --- /dev/null +++ b/src/dataloaders/datasets/adding.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def torch_adding_data(L, batch_shape=()): + assert L >= 2 + mid = L//2 + idx0 = torch.randint(low=0, high=mid, size=batch_shape) + idx1 = torch.randint(low=0, high=L-mid, size=batch_shape) + + idx = torch.cat((F.one_hot(idx0, mid), F.one_hot(idx1, L-mid)), dim=-1).float() # (batch_shape, L) + unif = torch.empty(batch_shape+(L,)) + unif.uniform_(0., 1.) + + x = torch.stack((unif, idx), dim=-1) # (batch_shape, L, 2) + y = torch.sum(unif*idx, dim=-1, keepdim=True) # (batch_shape, 1) + + return x, y + +def adding_static_dataset(L, samples): + all_x, all_y = torch_adding_data(L, batch_shape=(samples,)) + print("Constructing Adding dataset of shape", all_x.shape) + ds = torch.utils.data.TensorDataset(all_x, all_y) + return ds diff --git a/src/dataloaders/datasets/celeba.py b/src/dataloaders/datasets/celeba.py new file mode 100644 index 0000000..8eea29a --- /dev/null +++ b/src/dataloaders/datasets/celeba.py @@ -0,0 +1,166 @@ +from functools import partial +import torch +import os +import PIL +from typing import Any, Callable, List, Optional, Union, Tuple +from torchvision.datasets import VisionDataset +try: + import gdown + DOWNLOAD = True +except ImportError: + DOWNLOAD = False +import numpy as np + +class _CelebA(VisionDataset): + """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. + + Args: + root (string): Root directory where images are downloaded to. + split (string): One of {'train', 'valid', 'test', 'all'}. + Accordingly dataset is selected. + target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, + or ``landmarks``. Can also be a list to output a tuple with all specified target types. + The targets represent: + + - ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes + - ``identity`` (int): label for each person (data points with the same identity are the same person) + - ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) + - ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, + righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) + + Defaults to ``attr``. If empty, ``None`` will be returned as target. + + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + + base_folder = "celeba" + file_list = [ + # File ID MD5 Hash Filename + ("1cNIac61PSA_LqDFYFUeyaQYekYPc75NH", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), + ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), + ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), + ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), + ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), + ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), + ] + + def __init__( + self, + root: str, + task: str = None, + split: str = "train", + target_type: Union[List[str], str] = "attr", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + import pandas + super(_CelebA, self).__init__(root, transform=transform, + target_transform=target_transform) + self.split = split + if isinstance(target_type, list): + self.target_type = target_type + else: + self.target_type = [target_type] + + if not self.target_type and self.target_transform is not None: + raise RuntimeError('target_transform is specified but target_type is empty') + + if download: + self.download() + + split_map = { + "train": 0, + "valid": 1, + "test": 2, + "all": None, + "hq": None, + } + split_ = split_map[split] + + if split == 'hq': + fn = partial(os.path.join, self.root) + else: + fn = partial(os.path.join, self.root, self.base_folder) + + splits = pandas.read_csv(fn("list_eval_partition.csv"), header=0, index_col=0) + attr = pandas.read_csv(fn("list_attr_celeba.csv"), header=0, index_col=0) + mask = slice(None) if split_ is None else (splits['partition'] == split_) + + if split == 'hq': + filenames = os.listdir(fn('train')) + os.listdir(fn('val')) + self.filename = [fn('train', f) for f in os.listdir(fn('train'))] + [fn('val', f) for f in os.listdir(fn('val'))] + self.attr = torch.as_tensor(attr.loc[filenames].values) + else: + self.filename = splits[mask].index.values + self.attr = torch.as_tensor(attr[mask].values) + + self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} + self.attr_names = list(attr.columns) + + self.task = task + if task: + self.task_idx = int(np.where(np.array(self.attr_names) == task)[0]) + + def download(self) -> None: + import zipfile + if not DOWNLOAD: + raise ImportError("Must install gdown.") + + if os.path.exists(os.path.join(self.root, self.base_folder, 'img_align_celeba')): + print('Files already downloaded and verified') + return + + for (file_id, md5, filename) in self.file_list: + gdown.download(f'https://drive.google.com/uc?id={file_id}', os.path.join(self.root, self.base_folder, filename), quiet=False) + + with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: + f.extractall(os.path.join(self.root, self.base_folder)) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + if self.split == 'hq': + X = PIL.Image.open(self.filename[index]) + else: + X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) + + target: Any = [] + for t in self.target_type: + if t == "attr": + target.append(self.attr[index, :]) + elif t == "identity": + target.append(self.identity[index, 0]) + elif t == "bbox": + target.append(self.bbox[index, :]) + elif t == "landmarks": + target.append(self.landmarks_align[index, :]) + else: + # TODO: refactor with utils.verify_str_arg + raise ValueError("Target type \"{}\" is not recognized.".format(t)) + + if self.transform is not None: + X = self.transform(X) + + if target: + target = tuple(target) if len(target) > 1 else target[0] + + if self.target_transform is not None: + target = self.target_transform(target) + else: + target = None + + if self.task: + return X, torch.eye(2, dtype=int)[target[self.task_idx]] + return X, target # torch.eye(2, dtype=int)[target] + + def __len__(self) -> int: + return len(self.attr) + + def extra_repr(self) -> str: + lines = ["Target type: {target_type}", "Split: {split}"] + return '\n'.join(lines).format(**self.__dict__) diff --git a/src/dataloaders/datasets/copying.py b/src/dataloaders/datasets/copying.py new file mode 100644 index 0000000..b3b7ba6 --- /dev/null +++ b/src/dataloaders/datasets/copying.py @@ -0,0 +1,90 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +def np_copying_data(L, M, A, batch_shape=()): + seq = np.random.randint(low=1, high=A-1, size=batch_shape+(M,)) + zeros_x = np.zeros(batch_shape+(L,)) + markers = (A-1) * np.ones(batch_shape+(M,)) + zeros_y = np.zeros(batch_shape+(M+L,)) + + x_ = np.concatenate([seq, zeros_x, markers], axis=-1) + y_ = np.concatenate([zeros_y, seq], axis=-1) + x = F.one_hot(torch.tensor(x_, dtype=torch.int64), A).float() + y = torch.tensor(y_, dtype=torch.int64) + return x, y + +def torch_copying_data(L, M, A, variable=False, batch_shape=(), one_hot=False): + tokens = torch.randint(low=1, high=A-1, size=batch_shape+(M,)) + if variable: + total_batch = int(np.prod(batch_shape)) + inds = torch.stack([ + torch.randperm(L+M)[:M] + for _ in range(total_batch) + ], 0) + inds = inds.reshape(batch_shape+(M,)) + inds, _ = inds.sort() + else: + inds = torch.arange(M).repeat(batch_shape+(1,)) + zeros_x = torch.zeros(batch_shape+(M+L,), dtype=torch.long) + zeros_x.scatter_(-1, inds, tokens) + markers = (A-1) * torch.ones(batch_shape+(M,), dtype=torch.long) + + x_ = torch.cat([zeros_x, markers], dim=-1) + y_ = torch.cat([tokens], dim=-1) + if one_hot: x = F.one_hot(x_, A).float() + else: x = x_ + y = y_ + return x, y + +def torch_copying_lag_data(L, M, A, batch_shape=()): + x = torch.randint(low=1, high=A-1, size=batch_shape+(L,)) + y = F.pad(x, (M, 0))[..., :L] + return x, y + +class CopyingTrainDataset(torch.utils.data.Dataset): + def __init__(self, L, M, A, samples, lag=False, variable=False, one_hot=False): + """ + L: number of noise tokens + M: number of memorization tokens + A: size of dictionary + """ + super().__init__() + self.L = L + self.M = M + self.A = A + self.samples = samples + self.variable = variable + self.one_hot = one_hot + self.lag = lag + + def __getitem__(self, idx): + assert 0 <= idx < self.samples + if self.lag: + x, y = torch_copying_lag_data(self.L, self.M, self.A) + else: + x, y = torch_copying_data(self.L, self.M, self.A, variable=self.variable, one_hot=self.one_hot) + return x, y + + def __len__(self): + return self.samples + + +class CopyingEvalDataset(torch.utils.data.TensorDataset): + def __init__(self, L, M, A, samples, lag=None, variable=False, one_hot=False): + self.L = L + self.M = M + self.A = A + self.samples = samples + if lag: + all_x, all_y = torch_copying_lag_data(self.L, self.M, self.A, batch_shape=(self.samples,)) + else: + all_x, all_y = torch_copying_data(self.L, self.M, self.A, batch_shape=(self.samples,), variable=variable, one_hot=one_hot) + super().__init__(all_x, all_y) + +def copying_static_dataset(L, M, A, variable, samples): + all_x, all_y = torch_copying_data(L, M, A, variable, batch_shape=(samples,)) + print("Constructing Copying dataset of shape", all_x.shape) + ds = torch.utils.data.TensorDataset(all_x, all_y) + return ds diff --git a/src/dataloaders/datasets/delay.py b/src/dataloaders/datasets/delay.py new file mode 100644 index 0000000..8bf64ef --- /dev/null +++ b/src/dataloaders/datasets/delay.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from src.dataloaders.utils.signal import whitesignal + + +class DelayTrainDataset(torch.utils.data.Dataset): + def __init__(self, samples, l_seq=1024, n_lag=1, l_lag=None, dt=1e-3, freq=1.0): + """ + """ + super().__init__() + self.L = l_seq + self.dt = dt + self.freq = freq + self.samples = samples + self.l_lag = l_lag or l_seq // n_lag + self.n_lag = n_lag + + def __getitem__(self, idx): + assert 0 <= idx < self.samples + x = torch.FloatTensor(whitesignal(self.L*self.dt, self.dt, self.freq)) # (l_seq) + y = torch.stack([ + F.pad(x[:self.L-i*self.l_lag], (i*self.l_lag, 0)) + for i in range(1, self.n_lag+1) + ], dim=-1) # (l_seq, n_lag) + x = x.unsqueeze(-1) + return x, y + + def __len__(self): + return self.samples + + +class DelayEvalDataset(torch.utils.data.TensorDataset): + def __init__(self, samples, l_seq=1024, n_lag=1, l_lag=None, dt=1e-3, freq=1.0): + self.L = l_seq + self.dt = dt + self.freq = freq + self.samples = samples + self.l_lag = l_lag or l_seq // n_lag + self.n_lag = n_lag + + X = torch.FloatTensor(whitesignal(self.L*self.dt, self.dt, self.freq, batch_shape=(self.samples,))) # (samples, l_seq, 1) + Y = torch.stack([ + F.pad(X[:, :self.L-i*self.l_lag], (i*self.l_lag, 0)) # manually subtract from self.L otherwise error in i=0 case + for i in range(1, self.n_lag+1) + ], dim=-1) # (batch, l_seq, n_lag) + X = X.unsqueeze(-1) # (batch, l_seq, 1) + + super().__init__(X, Y) diff --git a/src/dataloaders/datasets/music.py b/src/dataloaders/datasets/music.py new file mode 100644 index 0000000..4ab49cb --- /dev/null +++ b/src/dataloaders/datasets/music.py @@ -0,0 +1,245 @@ +""" +RNN Vocal Generation Model + +Blizzard, Music, and Huckleberry Finn data feeders. +""" + +import numpy as np +#import scikits.audiolab + +import random +import time +import os +import glob + +import torch +import sklearn +from scipy.io import wavfile + +def normalize01(data): + """To range [0., 1.]""" + data -= np.min(data) + data /= np.max(data) + return data + +def mu_law_encode(audio, bits=8): + """ + Perform mu-law companding transformation. + """ + mu = torch.tensor(2**bits - 1) + + # Audio must be min-max scaled between -1 and 1 + audio = 2 * minmax_scale(audio) - 1 + + # Perform mu-law companding transformation. + numerator = torch.log1p(mu * torch.abs(audio + 1e-8)) + denominator = torch.log1p(mu) + encoded = torch.sign(audio) * (numerator / denominator) + + # Quantize signal to the specified number of levels. + return ((encoded + 1) / 2 * mu + 0.5).long() + +def mu_law_decode(encoded, bits=8): + """ + Perform inverse mu-law transformation. + """ + mu = 2**bits - 1 + # Invert the quantization + x = (encoded.float() / mu) * 2 - 1 + + # Invert the mu-law transformation + x = torch.sign(x) * ((1 + mu)**(torch.abs(x)) - 1) / mu + return x + +def minmax_scale(tensor): + min_val = torch.amin(tensor, dim=(1, 2), keepdim=True) + max_val = torch.amax(tensor, dim=(1, 2), keepdim=True) + return (tensor - min_val) / (max_val - min_val + 1e-6) + +EPSILON = 1e-2 + +def linear_quantize(samples, q_levels): + samples = samples.clone() + # samples -= samples.min(dim=-2)[0].unsqueeze(1).expand_as(samples) + # samples /= samples.max(dim=-2)[0].unsqueeze(1).expand_as(samples) + samples = minmax_scale(samples) + samples *= q_levels - EPSILON + samples += EPSILON / 2 + return samples.long() + +def linear_dequantize(samples, q_levels): + return samples.float() / (q_levels / 2) - 1 + +def q_zero(q_levels): + return q_levels // 2 + +ITEM_LIST = [ + "BeethovenPianoSonataNo.1", + "BeethovenPianoSonataNo.2", + "BeethovenPianoSonataNo.3", + "BeethovenPianoSonataNo.4", + "BeethovenPianoSonataNo.5", + "BeethovenPianoSonataNo.6", + "BeethovenPianoSonataNo.7", + "BeethovenPianoSonataNo.8", + "BeethovenPianoSonataNo.9", + "BeethovenPianoSonataNo.10", + "BeethovenPianoSonataNo.11", + "BeethovenPianoSonataNo.12", + "BeethovenPianoSonata13", + "BeethovenPianoSonataNo.14moonlight", + "BeethovenPianoSonata15", + "BeethovenPianoSonata16", + "BeethovenPianoSonata17", + "BeethovenPianoSonataNo.18", + "BeethovenPianoSonataNo.19", + "BeethovenPianoSonataNo.20", + "BeethovenPianoSonataNo.21Waldstein", + "BeethovenPianoSonata22", + "BeethovenPianoSonataNo.23", + "BeethovenPianoSonataNo.24", + "BeethovenPianoSonataNo.25", + "BeethovenPianoSonataNo.26", + "BeethovenPianoSonataNo.27", + "BeethovenPianoSonataNo.28", + "BeethovenPianoSonataNo.29", + "BeethovenPianoSonataNo.30", + "BeethovenPianoSonataNo.31", + "BeethovenPianoSonataNo.32", +] + +def download_all_data(path): + print('Downloading data to ' + path) + if not os.path.exists(path): + os.system('mkdir ' + path) + for item in ITEM_LIST: + os.system("wget -r -H -nc -nH --cut-dir=1 -A .ogg -R *_vbr.mp3 -e robots=off -P " + path + " -l1 'http://archive.org/download/" + item + "'") + os.system("mv " + os.path.join(path, item, '*.ogg') + " " + path) + os.system("rm -rf " + os.path.join(path, item)) + for f in os.listdir(path): + filepath = os.path.join(path, f) + os.system("ffmpeg -y -i " + filepath + " -ar 16000 -ac 1 " + filepath[:-4] + ".wav") + os.system("rm " + filepath) + print('Data download done') + +class _Music(): + def __init__( + self, + path, + sample_len = 1, # in seconds + sample_rate = 16000, + train_percentage = 0.9, + discrete_input=False, + samplernn_proc=True, + ): + self.sample_len = sample_len + self.sample_rate = sample_rate + self.discrete_input = discrete_input + self.samplernn_proc = samplernn_proc + + self.music_data_path = os.path.join(path, 'music_data') + if not os.path.exists(self.music_data_path): + download_all_data(self.music_data_path) + + self.all_data = self.get_all_data() + self.tensor = self.build_slices(self.all_data) + self.train, self.val, self.test = self.split_data(self.tensor, train_percentage) + self.train_X, self.val_X, self.test_X, self.train_y, self.val_y, self.test_y = self.make_x_y(self.train, self.val, self.test) + + + def get_all_data(self): + from librosa.core import load + # TODO: There are going to be boundary errors here! + all_data = np.array([]) + for f in os.listdir(self.music_data_path): + # sr, data = wavfile.read(os.path.join(self.music_data_path, f)) + data, _ = load(os.path.join(self.music_data_path, f), sr=None, mono=True) + # assert(sr == self.sample_rate) + all_data = np.append(all_data, data) + + # # if not self.samplernn_proc: + # # Convert all data to range [-1, 1] + # all_data = all_data.astype('float64') + # all_data = normalize01(all_data) + # all_data = 2. * all_data - 1. + + return all_data + + def build_slices(self, data): + num_samples_per_slice = self.sample_rate * self.sample_len + + truncated_len = len(data) - len(data) % num_samples_per_slice + + return torch.tensor(data[:truncated_len].reshape(-1, num_samples_per_slice), dtype=torch.float32) + + # tensor = torch.zeros([len(data) // num_samples_per_slice, num_samples_per_slice], dtype=torch.float32) + # for i in range(len(data) // num_samples_per_slice): + # tensor[i] = torch.tensor(data[i * num_samples_per_slice : (i + 1) * num_samples_per_slice]) + # return tensor + + def split_data(self, tensor, train_percentage): + train, test = sklearn.model_selection.train_test_split( + tensor, + train_size=train_percentage, + random_state=0, + shuffle=True + ) + val, test = sklearn.model_selection.train_test_split( + test, + train_size=0.5, + random_state=0, + shuffle=True + ) + train = torch.swapaxes(train.unsqueeze(1).squeeze(-1), 1, 2) + val = torch.swapaxes(val.unsqueeze(1).squeeze(-1), 1, 2) + test = torch.swapaxes(test.unsqueeze(1).squeeze(-1), 1, 2) + return train, val, test + + def make_x_y(self, train, val, test): + + if not self.samplernn_proc: + train_y, val_y, test_y = mu_law_encode(train), mu_law_encode(val), mu_law_encode(test) + if not self.discrete_input: + train_X, val_X, test_X = torch.roll(mu_law_decode(train_y), 1, 1), torch.roll(mu_law_decode(val_y), 1, 1), torch.roll(mu_law_decode(test_y), 1, 1) + train_X[:, 0, :], val_X[:, 0, :], test_X[:, 0, :] = 0, 0, 0 + else: + train_X, val_X, test_X = torch.roll(train_y, 1, 1), torch.roll(val_y, 1, 1), torch.roll(test_y, 1, 1) + train_X[:, 0, :], val_X[:, 0, :], test_X[:, 0, :] = 128, 128, 128 + else: + train_y, val_y, test_y = linear_quantize(train, 256), linear_quantize(val, 256), linear_quantize(test, 256) + # train_y, val_y, test_y = mu_law_encode(train), mu_law_encode(val), mu_law_encode(test) + if not self.discrete_input: + raise NotImplementedError + else: + train_X, val_X, test_X = torch.roll(train_y, 1, 1), torch.roll(val_y, 1, 1), torch.roll(test_y, 1, 1) + train_X[:, 0, :], val_X[:, 0, :], test_X[:, 0, :] = 128, 128, 128 + + return train_X, val_X, test_X, train_y, val_y, test_y + + def get_data(self, partition): + if partition == 'train': + return MusicTensorDataset(self.train_X, self.train_y) + elif partition == 'val': + return MusicTensorDataset(self.val_X, self.val_y) + elif partition == 'test': + return MusicTensorDataset(self.test_X, self.test_y) + +class MusicTensorDataset(torch.utils.data.TensorDataset): + + def __getitem__(self, index): + data = self.tensors[0][index] + target = self.tensors[1][index] + if data.dtype == torch.float32: + return data, target + else: + return data.squeeze(-1), target + # Rejection sampling to remove "bad samples" that are essentially constant audio + # if data.dtype == torch.float32: + # if torch.std(data[1:]) < 1e-5: + # return self.__getitem__(np.random.randint(0, len(self.tensors[0]))) + # return data, target + # else: + # if (data[1:] - data[1]).abs().sum() < 1e-5: + # return self.__getitem__(np.random.randint(0, len(self.tensors[0]))) + # return data.squeeze(-1), target + diff --git a/src/dataloaders/datasets/reconstruct.py b/src/dataloaders/datasets/reconstruct.py new file mode 100644 index 0000000..4b95c2a --- /dev/null +++ b/src/dataloaders/datasets/reconstruct.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from src.dataloaders.utils.signal import whitesignal + + +class ReconstructTrainDataset(torch.utils.data.Dataset): + def __init__(self, samples, l_seq=1024, l_mem=1024, dt=1e-3, freq=1.0, seed=0): + """ + """ + super().__init__() + self.L = l_seq + self.l_mem = l_mem + self.dt = dt + self.freq = freq + self.samples = samples + + def __getitem__(self, idx): + assert 0 <= idx < self.samples + x = torch.FloatTensor(whitesignal(self.L*self.dt, self.dt, self.freq)) + x = x.unsqueeze(-1) + y = x[-self.l_mem:, 0] + return x, y + + def __len__(self): + return self.samples + + +class ReconstructEvalDataset(torch.utils.data.TensorDataset): + def __init__(self, samples, l_seq=1024, l_mem=1024, dt=1e-3, freq=1.0, seed=0): + self.L = l_seq + self.l_mem = l_mem + self.dt = dt + self.freq = freq + self.samples = samples + + X = [] + X = torch.FloatTensor(whitesignal(self.L*self.dt, self.dt, self.freq, batch_shape=(self.samples,))) + X = X[..., None] + Y = X[:, -self.l_mem:, 0] + + super().__init__(X, Y) diff --git a/src/dataloaders/datasets/sc.py b/src/dataloaders/datasets/sc.py new file mode 100644 index 0000000..73a598d --- /dev/null +++ b/src/dataloaders/datasets/sc.py @@ -0,0 +1,524 @@ +""" +Adapted from https://github.com/dwromero/ckconv/blob/dc84dceb490cab2f2ddf609c380083367af21890/datasets/speech_commands.py +which is +adapted from https://github.com/patrick-kidger/NeuralCDE/blob/758d3a7134e3a691013e5cc6b7f68f277e9e6b69/experiments/datasets/speech_commands.py +""" +import os +import pathlib +import tarfile +import urllib.request + +import sklearn.model_selection +import torch +import torch.nn.functional as F +import torchaudio + + +def pad(channel, maxlen): + channel = torch.tensor(channel) + out = torch.full((maxlen,), channel[-1]) + out[: channel.size(0)] = channel + return out + + +def subsample(X, y, subsample_rate): + if subsample_rate != 1: + X = X[:, ::subsample_rate, :] + return X, y + + +def save_data(dir, **tensors): + for tensor_name, tensor_value in tensors.items(): + torch.save(tensor_value, str(dir / tensor_name) + ".pt") + + +def load_data(dir): + tensors = {} + for filename in os.listdir(dir): + if filename.endswith(".pt"): + tensor_name = filename.split(".")[0] + tensor_value = torch.load(str(dir / filename)) + tensors[tensor_name] = tensor_value + return tensors + + +def normalise_data(X, y): + train_X, _, _ = split_data(X, y) + out = [] + for Xi, train_Xi in zip(X.unbind(dim=-1), train_X.unbind(dim=-1)): + train_Xi_nonan = train_Xi.masked_select(~torch.isnan(train_Xi)) + mean = train_Xi_nonan.mean() # compute statistics using only training data. + std = train_Xi_nonan.std() + out.append((Xi - mean) / (std + 1e-5)) + out = torch.stack(out, dim=-1) + return out + +def normalize_all_data(X_train, X_val, X_test): + + for i in range(X_train.shape[-1]): + mean = X_train[:, :, i].mean() + std = X_train[:, :, i].std() + X_train[:, :, i] = (X_train[:, :, i] - mean) / (std + 1e-5) + X_val[:, :, i] = (X_val[:, :, i] - mean) / (std + 1e-5) + X_test[:, :, i] = (X_test[:, :, i] - mean) / (std + 1e-5) + + return X_train, X_val, X_test + +def minmax_scale(tensor): + min_val = torch.amin(tensor, dim=(1, 2), keepdim=True) + max_val = torch.amax(tensor, dim=(1, 2), keepdim=True) + return (tensor - min_val) / (max_val - min_val) + +def mu_law_encode(audio, bits=8): + """ + Perform mu-law companding transformation. + """ + mu = torch.tensor(2**bits - 1) + + # Audio must be min-max scaled between -1 and 1 + audio = 2 * minmax_scale(audio) - 1 + + # Perform mu-law companding transformation. + numerator = torch.log1p(mu * torch.abs(audio)) + denominator = torch.log1p(mu) + encoded = torch.sign(audio) * (numerator / denominator) + + # Quantize signal to the specified number of levels. + return ((encoded + 1) / 2 * mu + 0.5).to(torch.int32) + +def mu_law_decode(encoded, bits=8): + """ + Perform inverse mu-law transformation. + """ + mu = 2**bits - 1 + # Invert the quantization + x = (encoded / mu) * 2 - 1 + + # Invert the mu-law transformation + x = torch.sign(x) * ((1 + mu)**(torch.abs(x)) - 1) / mu + return x + +def split_data(tensor, stratify): + # 0.7/0.15/0.15 train/val/test split + ( + train_tensor, + testval_tensor, + train_stratify, + testval_stratify, + ) = sklearn.model_selection.train_test_split( + tensor, + stratify, + train_size=0.7, + random_state=0, + shuffle=True, + stratify=stratify, + ) + + val_tensor, test_tensor = sklearn.model_selection.train_test_split( + testval_tensor, + train_size=0.5, + random_state=1, + shuffle=True, + stratify=testval_stratify, + ) + return train_tensor, val_tensor, test_tensor + + +class _SpeechCommands(torch.utils.data.TensorDataset): + + SUBSET_CLASSES = [ + "yes", + "no", + "up", + "down", + "left", + "right", + "on", + "off", + "stop", + "go", + ] + ALL_CLASSES = [ + "bed", + "cat", + "down", + "five", + "forward", + "go", + "house", + "left", + "marvin", + "no", + "on", + "right", + "sheila", + "tree", + "up", + "visual", + "yes", + "backward", + "bird", + "dog", + "eight", + "follow", + "four", + "happy", + "learn", + "nine", + "off", + "one", + "seven", + "six", + "stop", + "three", + "two", + "wow", + "zero", + ] + + def __init__( + self, + partition: str, # `train`, `val`, `test` + length: int, # sequence length + mfcc: bool, # whether to use MFCC features (`True`) or raw features + sr: int, # subsampling rate: default should be 1 (no subsampling); keeps every kth sample + dropped_rate: float, # rate at which samples are dropped, lies in [0, 100.] + path: str, + all_classes: bool = False, + gen: bool = False, # whether we are doing speech generation + discrete_input: bool = False, # whether we are using discrete inputs + ): + self.dropped_rate = dropped_rate + self.all_classes = all_classes + self.gen = gen + self.discrete_input = discrete_input + + self.root = pathlib.Path(path) # pathlib.Path("./data") + base_loc = self.root / "SpeechCommands" / "processed_data" + + + if mfcc: + data_loc = base_loc / "mfcc" + elif gen: + data_loc = base_loc / "gen" + else: + data_loc = base_loc / "raw" + + if self.dropped_rate != 0: + data_loc = pathlib.Path( + str(data_loc) + "_dropped{}".format(self.dropped_rate) + ) + + if self.all_classes: + data_loc = pathlib.Path(str(data_loc) + "_all_classes") + + if self.discrete_input: + data_loc = pathlib.Path(str(data_loc) + "_discrete") + + if os.path.exists(data_loc): + pass + else: + self.download() + if not self.all_classes: + train_X, val_X, test_X, train_y, val_y, test_y = self._process_data(mfcc) + else: + train_X, val_X, test_X, train_y, val_y, test_y = self._process_all(mfcc) + + if not os.path.exists(base_loc): + os.mkdir(base_loc) + if not os.path.exists(data_loc): + os.mkdir(data_loc) + save_data( + data_loc, + train_X=train_X, + val_X=val_X, + test_X=test_X, + train_y=train_y, + val_y=val_y, + test_y=test_y, + ) + + X, y = self.load_data(data_loc, partition) # (batch, length, 1) + if self.gen: y = y.transpose(1, 2) + + if not mfcc and not self.gen: + X = F.pad(X, (0, 0, 0, length-16000)) + + # Subsample + if not mfcc: + X, y = subsample(X, y, sr) + + if self.discrete_input: + X = X.long().squeeze() + + super(_SpeechCommands, self).__init__(X, y) + + def download(self): + root = self.root + base_loc = root / "SpeechCommands" + loc = base_loc / "speech_commands.tar.gz" + if os.path.exists(loc): + return + if not os.path.exists(root): + os.mkdir(root) + if not os.path.exists(base_loc): + os.mkdir(base_loc) + urllib.request.urlretrieve( + "http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz", loc + ) # TODO: Add progress bar + with tarfile.open(loc, "r") as f: + f.extractall(base_loc) + + def _process_all(self, mfcc): + assert self.dropped_rate == 0, "Dropped rate must be 0 for all classes" + base_loc = self.root / "SpeechCommands" + + with open(base_loc / "validation_list.txt", "r") as f: + validation_list = set([line.rstrip() for line in f]) + + with open(base_loc / "testing_list.txt", "r") as f: + testing_list = set([line.rstrip() for line in f]) + + train_X, val_X, test_X = [], [], [] + train_y, val_y, test_y = [], [], [] + + batch_index = 0 + y_index = 0 + for foldername in self.ALL_CLASSES: + print(foldername) + loc = base_loc / foldername + for filename in os.listdir(loc): + audio, _ = torchaudio.load( + loc / filename, channels_first=False, + ) + audio = ( + audio / 2 ** 15 + ) + # Pad: A few samples are shorter than the full length + audio = F.pad(audio, (0, 0, 0, 16000 - audio.shape[0])) + + if str(foldername + '/' + filename) in validation_list: + val_X.append(audio) + val_y.append(y_index) + elif str(foldername + '/' + filename) in testing_list: + test_X.append(audio) + test_y.append(y_index) + else: + train_X.append(audio) + train_y.append(y_index) + + batch_index += 1 + y_index += 1 + # print("Full data: {} samples".format(len(X))) + train_X = torch.stack(train_X) + val_X = torch.stack(val_X) + test_X = torch.stack(test_X) + train_y = torch.tensor(train_y, dtype=torch.long) + val_y = torch.tensor(val_y, dtype=torch.long) + test_y = torch.tensor(test_y, dtype=torch.long) + + # If MFCC, then we compute these coefficients. + if mfcc: + train_X = torchaudio.transforms.MFCC( + log_mels=True, n_mfcc=20, melkwargs=dict(n_fft=200, n_mels=64) + )(train_X.squeeze(-1)).detach() + + val_X = torchaudio.transforms.MFCC( + log_mels=True, n_mfcc=20, melkwargs=dict(n_fft=200, n_mels=64) + )(val_X.squeeze(-1)).detach() + + test_X = torchaudio.transforms.MFCC( + log_mels=True, n_mfcc=20, melkwargs=dict(n_fft=200, n_mels=64) + )(test_X.squeeze(-1)).detach() + # X is of shape (batch, channels=20, length=161) + else: + train_X = train_X.unsqueeze(1).squeeze(-1) + val_X = val_X.unsqueeze(1).squeeze(-1) + test_X = test_X.unsqueeze(1).squeeze(-1) + # X is of shape (batch, channels=1, length=16000) + + # Normalize data + if mfcc: + train_X, val_X, test_X = normalize_all_data(train_X.transpose(1, 2), val_X.transpose(1, 2), test_X.transpose(1, 2)) + train_X = train_X.transpose(1, 2) + val_X = val_X.transpose(1, 2) + test_X = test_X.transpose(1, 2) + else: + train_X, val_X, test_X = normalize_all_data(train_X, val_X, test_X) + + # Print the shape of all tensors in one line + print( + "Train: {}, Val: {}, Test: {}".format( + train_X.shape, val_X.shape, test_X.shape + ) + ) + + return ( + train_X, + val_X, + test_X, + train_y, + val_y, + test_y, + ) + + + def _process_data(self, mfcc): + base_loc = self.root / "SpeechCommands" + if self.gen: + X = torch.empty(35628, 16000, 1) + y = torch.empty(35628, dtype=torch.long) + else: + X = torch.empty(34975, 16000, 1) + y = torch.empty(34975, dtype=torch.long) + + batch_index = 0 + y_index = 0 + for foldername in self.SUBSET_CLASSES: + loc = base_loc / foldername + for filename in os.listdir(loc): + audio, _ = torchaudio.load( + loc / filename, channels_first=False, + ) + # audio, _ = torchaudio.load_wav( + # loc / filename, channels_first=False, normalization=False + # ) # for forward compatbility if they fix it + audio = ( + audio / 2 ** 15 + ) # Normalization argument doesn't seem to work so we do it manually. + + # A few samples are shorter than the full length; for simplicity we discard them. + if len(audio) != 16000: + continue + + X[batch_index] = audio + y[batch_index] = y_index + batch_index += 1 + y_index += 1 + if self.gen: + assert batch_index == 35628, "batch_index is {}".format(batch_index) + else: + assert batch_index == 34975, "batch_index is {}".format(batch_index) + + # If MFCC, then we compute these coefficients. + if mfcc: + X = torchaudio.transforms.MFCC( + log_mels=True, n_mfcc=20, melkwargs=dict(n_fft=200, n_mels=64) + )(X.squeeze(-1)).detach() + # X is of shape (batch=34975, channels=20, length=161) + else: + X = X.unsqueeze(1).squeeze(-1) + # X is of shape (batch=34975, channels=1, length=16000) + + # If dropped is different than zero, randomly drop that quantity of data from the dataset. + if self.dropped_rate != 0: + generator = torch.Generator().manual_seed(56789) + X_removed = [] + for Xi in X: + removed_points = ( + torch.randperm(X.shape[-1], generator=generator)[ + : int(X.shape[-1] * float(self.dropped_rate) / 100.0) + ] + .sort() + .values + ) + Xi_removed = Xi.clone() + Xi_removed[:, removed_points] = float("nan") + X_removed.append(Xi_removed) + X = torch.stack(X_removed, dim=0) + + # Normalize data + if mfcc: + X = normalise_data(X.transpose(1, 2), y).transpose(1, 2) + else: + X = normalise_data(X, y) + + # Once the data is normalized append times and mask values if required. + if self.dropped_rate != 0: + # Get mask of possitions that are deleted + mask_exists = (~torch.isnan(X[:, :1, :])).float() + X = torch.where(~torch.isnan(X), X, torch.Tensor([0.0])) + X = torch.cat([X, mask_exists], dim=1) + + train_X, val_X, test_X = split_data(X, y) + train_y, val_y, test_y = split_data(y, y) + + if self.gen: + train_y, val_y, test_y = train_X, val_X, test_X + train_y, val_y, test_y = mu_law_encode(train_y), mu_law_encode(val_y), mu_law_encode(test_y) + # train_X, val_X, test_X = train_X[..., :-1], val_X[..., :-1], test_X[..., :-1] + # # Prepend zero to train_X, val_X, test_X + # train_X = torch.cat([torch.zeros(train_X.shape[0], 1, train_X.shape[2]), train_X], dim=1) + + # train_X, val_X, test_X = torch.roll(train_X, 1, 2), torch.roll(val_X, 1, 2), torch.roll(test_X, 1, 2) + if not self.discrete_input: + train_X, val_X, test_X = torch.roll(mu_law_decode(train_y), 1, 2), torch.roll(mu_law_decode(val_y), 1, 2), torch.roll(mu_law_decode(test_y), 1, 2) + else: + train_X, val_X, test_X = torch.roll(train_y, 1, 2), torch.roll(val_y, 1, 2), torch.roll(test_y, 1, 2) + train_X[..., 0], val_X[..., 0], test_X[..., 0] = 0, 0, 0 + + assert(train_y.shape == train_X.shape) + + return ( + train_X, + val_X, + test_X, + train_y, + val_y, + test_y, + ) + + @staticmethod + def load_data(data_loc, partition): + + tensors = load_data(data_loc) + if partition == "train": + X = tensors["train_X"] + y = tensors["train_y"] + elif partition == "val": + X = tensors["val_X"] + y = tensors["val_y"] + elif partition == "test": + X = tensors["test_X"] + y = tensors["test_y"] + else: + raise NotImplementedError("the set {} is not implemented.".format(set)) + + return X.transpose(1, 2), y + +class _SpeechCommandsGeneration(_SpeechCommands): + SUBSET_CLASSES = [ + "zero", + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + ] + + def __init__( + self, + partition: str, # `train`, `val`, `test` + length: int, # sequence length + mfcc: bool, # whether to use MFCC features (`True`) or raw features + sr: int, # subsampling rate: default should be 1 (no subsampling); keeps every kth sample + dropped_rate: float, # rate at which samples are dropped, lies in [0, 100.] + path: str, + all_classes: bool = False, + discrete_input: bool = False, + ): + super(_SpeechCommandsGeneration, self).__init__( + partition = partition, + length = length, + mfcc = mfcc, + sr = sr, + dropped_rate = dropped_rate, + path = path, + all_classes = all_classes, + gen = True, + discrete_input = discrete_input, + ) diff --git a/src/dataloaders/et.py b/src/dataloaders/et.py index 6db2e01..47444eb 100644 --- a/src/dataloaders/et.py +++ b/src/dataloaders/et.py @@ -17,7 +17,7 @@ import warnings warnings.filterwarnings("ignore") -from src.dataloaders.datasets import SequenceDataset, default_data_path +from src.dataloaders.base import SequenceDataset, default_data_path class TimeFeature: @@ -266,6 +266,7 @@ def __init__( self.cols = cols self.eval_stamp = eval_stamp self.eval_mask = eval_mask + self.forecast_horizon = self.pred_len self.root_path = root_path self.data_path = data_path @@ -485,13 +486,7 @@ def l_output(self): def _get_data_filename(self, variant): return self.variants[variant] - @staticmethod - def collate_fn(batch, resolution): - x, y, *z = zip(*batch) - x = torch.stack(x, dim=0)[:, ::resolution] - y = torch.stack(y, dim=0) - z = [torch.stack(e, dim=0)[:, ::resolution] for e in z] - return x, y, *z + _collate_arg_names = ["mark", "mask"] # Names of the two extra tensors that the InformerDataset returns def setup(self): self.data_dir = self.data_dir or default_data_path / 'informer' / self._name_ diff --git a/src/dataloaders/lm.py b/src/dataloaders/lm.py index c8198f4..38c55da 100644 --- a/src/dataloaders/lm.py +++ b/src/dataloaders/lm.py @@ -34,12 +34,10 @@ log = src.utils.train.get_logger(__name__) -from src.dataloaders.datasets import SequenceDataset, default_data_path -from src.dataloaders.vocabulary import OpenAIVocab, Vocab +from src.dataloaders.base import SequenceDataset, default_data_path +from src.dataloaders.utils.vocabulary import OpenAIVocab, Vocab import src.utils as utils -# from tasks.legacy.tasks import LMPerplexity, LMBPC -# TODO: create a package so we don't have to mess with sys.path? project_root = Path(__file__).parent.parent.absolute() data_path = Path(__file__).absolute().parent / 'data' @@ -54,10 +52,6 @@ def __init__( batch_size, l_max, batch_first=True, - # device="cpu", - # mem_len=None, - # ext_len=None, - # warmup=True, n_context=1, n_epoch_double=0, pad_last=False, @@ -73,16 +67,11 @@ def __init__( self.batch_size = batch_size self.l_max = l_max self.batch_first = batch_first - # self.ext_len = ext_len if ext_len is not None else 0 - # self.mem_len = mem_len - # self.warmup = warmup self.pad_last = pad_last self.roll_seed = roll_seed self.n_context = n_context self.n_epoch_double = n_epoch_double - # self.device = device - # self.last_iter = None # AG: this isn't in original repo and doesn't appear to be used self.epoch = -1 # DDP @@ -127,18 +116,15 @@ def roll(self, seed): row = torch.cat((row[shift:], row[:shift])) self.data[i, :] = row - def get_batch(self, i, l_max=None): + def get_batch(self, i): """ Get batch starting at token index i """ - # if l_max is None: l_max = self.l_max - # seq_len = min(l_max, self.data.size(0) - 1 - i) end_idx = min(i + self.l_inc, self.data.size(-1)-1) - # beg_idx = max(0, i - self.ext_len) - beg_idx = max(0, end_idx - self.l_max) + beg_idx = max(0, i + self.l_inc - self.l_max) seq_len = end_idx - i - data = self.data[..., beg_idx:end_idx] # .to(self.device, non_blocking=True) - target = self.data[..., i+1 : end_idx+1] # .to( self.device, non_blocking=True) + data = self.data[..., beg_idx:end_idx] + target = self.data[..., i+1 : end_idx+1] if self.pad_last and seq_len < self.l_inc: data = F.pad(data, (0, self.l_inc - seq_len)) # (batch_size, l_inc) @@ -149,19 +135,16 @@ def get_batch(self, i, l_max=None): data = data.transpose(0, 1).contiguous() # (n_batch, l_sequence) target = target.transpose(0, 1).contiguous() - # [21-09-19] Unsqueeze the last dimension so that shape is always (n_batch, l_seq, d_input) - data = data - target = target - return data, target, seq_len + return data, target, {"l_output": seq_len} # Return length of desired output - def get_fixlen_iter(self, start=0): # AG: Don't see start ever used? + def get_fixlen_iter(self, start=0): if start != 0: start += self.l_max for i in range(start, self.data.size(-1) - 1, self.l_inc): self.last_iter = i yield self.get_batch(i) - def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): + def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): # NOTE: NOT TESTED l_max = self.l_max + max_deviation * std i = start while True: @@ -191,6 +174,7 @@ def __len__(self): class LMShuffledIterator(object): + # NOTE: Not tested def __init__( self, data, batch_size, l_max, device="cpu", ext_len=None, shuffle=False ): @@ -279,6 +263,7 @@ def __iter__(self): class LMMultiFileIterator(LMShuffledIterator): + # NOTE: Not tested def __init__( self, paths, @@ -319,7 +304,6 @@ def __iter__(self): yield batch -# class WikiText2(LightningDataModule): class WikiText2(SequenceDataset): _name_ = "wt2" @@ -327,88 +311,18 @@ class WikiText2(SequenceDataset): vocab_kwargs = {"special": [""], "lower_case": False} encode_kwargs = {"ordered": True} - # Embedding arguments (adaptive softmax / word embeddings) - # default_task = { - # 'adaptive': False, - # 'div_val': 1, - # 'cutoffs': [], - # 'tie_weights': False, - # 'tie_projs': [False], - # } - @property - def default_task(self): - return { - '_target_': 'tasks.tasks.LMTask', - 'tied': False, - 'rescale': True, - # init_cfg, - 'metrics': ['ppl'], - 'init_cfg': { - 'init': 'normal', # Parameter initializer to use - 'init_range': 0.01, # Parameters initialized by U(-init_range, init_range) - 'init_std': 0.02, # Parameters initialized by N(0, init_std) - 'proj_init_std': 0.01, # Separate std for projection params - } - } - - # Task class / constructor - # task_cls = LMPerplexity - - # @property - # def l_output(self): - # return self.l_max - init_defaults = { # Dataset arguments 'l_max': 512, 'bpe': False, 'roll_seed': 42, 'test_split': True, - # Task / Embedding arguments - # 'task': None, } @property def n_tokens(self): return len(self.vocab) - # def __init__( - # self, - # data_dir, - # d_embed, init_cfg, task=None, # Task / Embedding arguments - # bpe=False, - # l_max=None, - # pad_last=False, - # roll_seed=42, - # eval={ - # 'l_max': None, - # 'pad_last': False, - # 'roll_seed': None, - # }, - # **kwargs, - # # TODO kwargs is here to absorb things like 'num_workers' and 'pin_memory' which should really be part of every dataset - # ): - # super().__init__() - # if data_dir is None: self.data_dir = Path(data_dir) / self._name_ - # # self.d_embed = d_embed - # # self.init_cfg = init_cfg - # if bpe: - # self.vocab = OpenAIVocab() - # else: - # self.vocab = Vocab(**self.vocab_kwargs) - - # # Loader arguments - # assert l_max is not None - # self.l_max = l_max - # self.pad_last = pad_last - # self.roll_seed = roll_seed - - # self.eval = DictConfig(eval) - # if self.eval.l_max is None: self.eval.l_max = self.l_max - - # if task is not None: - # self.task.update(task) - def prepare_data(self): # [21-09-23] probably broken if not self.data_dir.exists(): @@ -423,8 +337,6 @@ def prepare_data(self): def setup(self, stage=None): # [21-09-10 AG]: TODO shouldn't this tokenization happen in the prepare_data? since we're caching it it doesn't really matter, but still if self.data_dir is None: self.data_dir = default_data_path / self._name_ - # self.d_embed = d_embed - # self.init_cfg = init_cfg if self.bpe: self.vocab = OpenAIVocab() else: @@ -452,8 +364,6 @@ def setup(self, stage=None): # [21-09-10 AG]: TODO shouldn't this tokenization h # Define task print("Vocab size:", len(self.vocab)) - # self.task = self.task_cls(len(self.vocab), self.d_embed, init_cfg=self.init_cfg, **self.task_args) - # self.d_input = self.d_output = self.d_embed def _vocab_count(self): self.vocab.count_file(self.data_dir / "train.txt") @@ -511,15 +421,6 @@ def _eval_dataloader(self, dataset, eval=None, **loader_args): def val_dataloader(self, **kwargs): return self._eval_dataloader(self.valid, **kwargs) - # return LMOrderedIterator( - # self.valid, - # batch_size, - # **self.eval, - # ) - # for k in train_args: - # if eval_args.get(k, None) is None: - # eval_args[k] = v - # return LMOrderedIterator(self.valid, **eval_args) def test_dataloader(self, **kwargs): return self._eval_dataloader(self.test, **kwargs) @@ -528,23 +429,6 @@ def test_dataloader(self, **kwargs): class WikiText103(WikiText2): _name_ = "wt103" - @property - def default_task(self): - return { - # 'adaptive': True, - '_target_': 'tasks.tasks.AdaptiveLMTask', - 'div_val': 1, - 'cutoffs': [19997, 39997, 199997], - 'tie_weights': True, - 'tie_projs': [False] + [True, True, True], # * len(cutoffs), - 'init_cfg': { - 'init': 'normal', # Parameter initializer to use - 'init_range': 0.01, # Parameters initialized by U(-init_range, init_range) - 'init_std': 0.02, # Parameters initialized by N(0, init_std) - 'proj_init_std': 0.01, # Separate std for projection params - } - } - def _vocab_count(self): print(self.data_dir) self.vocab.count_file(self.data_dir / "train.txt") @@ -555,35 +439,16 @@ class PennTreeBank(WikiText2): _name_ = "ptb" vocab_kwargs = {"special": [""], "lower_case": True} - # task_cls = LMBPC - class EnWik8(WikiText2): _name_ = "enwik8" vocab_kwargs = {} encode_kwargs = {"ordered": True, "add_eos": False} - # task_cls = LMBPC - @property - def default_task(self): - return { - '_target_': 'tasks.tasks.LMTask', - 'tied': False, - 'rescale': True, - # init_cfg, - 'metrics': ['ppl'], - # 'init_cfg': { - # 'init': 'normal', # Parameter initializer to use - # 'init_range': 0.01, # Parameters initialized by U(-init_range, init_range) - # 'init_std': 0.02, # Parameters initialized by N(0, init_std) - # 'proj_init_std': 0.01, # Separate std for projection params - # } - } class Text8(EnWik8): _name_ = "text8" - # task_cls = LMBPC class LM1B(WikiText2): @@ -640,158 +505,3 @@ def val_dataloader(self, *args, **kwargs): def test_dataloader(self, *args, **kwargs): return LMShuffledIterator(self.test, *args, **kwargs) - - - -class Corpus(object): - # AG: only used in get_lm_corpus which is only called in the unit test - def __init__(self, path, dataset, vocab, *args, **kwargs): - self.dataset = dataset - if vocab == "word": - self.vocab = Vocab(*args, **kwargs) - elif vocab == "bpe": - self.vocab = OpenAIVocab() - else: - raise RuntimeError("Unsupported vocab") - - if self.dataset in ["ptb", "wt2", "enwik8", "text8"]: - self.vocab.count_file(os.path.join(path, "train.txt")) - self.vocab.count_file(os.path.join(path, "valid.txt")) - self.vocab.count_file(os.path.join(path, "test.txt")) - elif self.dataset == "wt103": - self.vocab.count_file(os.path.join(path, "train.txt")) - elif self.dataset == "lm1b": - train_path_pattern = os.path.join( - path, - "1-billion-word-language-modeling-benchmark-r13output", - "training-monolingual.tokenized.shuffled", - "news.en-*", - ) - train_paths = glob.glob(train_path_pattern) - # the vocab will load from file when build_vocab() is called - - self.vocab.build_vocab() - - if self.dataset in ["ptb", "wt2", "wt103"]: - self.train = self.vocab.encode_file( - os.path.join(path, "train.txt"), ordered=True - ) - self.valid = self.vocab.encode_file( - os.path.join(path, "valid.txt"), ordered=True - ) - self.test = self.vocab.encode_file( - os.path.join(path, "test.txt"), ordered=True - ) - elif self.dataset in ["enwik8", "text8"]: - self.train = self.vocab.encode_file( - os.path.join(path, "train.txt"), ordered=True, add_eos=False - ) - self.valid = self.vocab.encode_file( - os.path.join(path, "valid.txt"), ordered=True, add_eos=False - ) - self.test = self.vocab.encode_file( - os.path.join(path, "test.txt"), ordered=True, add_eos=False - ) - elif self.dataset == "lm1b": - self.train = train_paths - self.valid = self.vocab.encode_file( - os.path.join(path, "valid.txt"), - ordered=False, - add_double_eos=True, - ) - self.test = self.vocab.encode_file( - os.path.join(path, "test.txt"), - ordered=False, - add_double_eos=True, - ) - - def get_iterator(self, split, *args, **kwargs): - if split == "train": - if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: - data_iter = LMOrderedIterator(self.train, *args, **kwargs) - elif self.dataset == "lm1b": - kwargs["shuffle"] = True - data_iter = LMMultiFileIterator( - self.train, self.vocab, *args, **kwargs - ) - elif split in ["valid", "test"]: - data = self.valid if split == "valid" else self.test - if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: - data_iter = LMOrderedIterator(data, *args, **kwargs) - elif self.dataset == "lm1b": - data_iter = LMShuffledIterator(data, *args, **kwargs) - - return data_iter - -def get_lm_corpus(data_dir, name, vocab): - if vocab == "word": - fn = os.path.join(data_dir, "cache.pt") - elif vocab == "bpe": - fn = os.path.join(data_dir, "cache.pt.bpe") - else: - raise RuntimeError("Unsupported vocab") - - if os.path.exists(fn): - logging.info("Loading cached dataset...") - corpus = torch.load(fn) - else: - logging.info("Producing dataset {}...".format(name)) - kwargs = {} - if name in ["wt103", "wt2"]: - kwargs["special"] = [""] - kwargs["lower_case"] = False - elif name == "ptb": - kwargs["special"] = [""] - kwargs["lower_case"] = True - elif name == "lm1b": - kwargs["special"] = [] - kwargs["lower_case"] = False - kwargs["vocab_file"] = os.path.join(data_dir, "1b_word_vocab.txt") - elif name in ["enwik8", "text8"]: - pass - - corpus = Corpus(data_dir, name, vocab, **kwargs) - # with distributed.sync_workers() as rank: - # if rank == 0: - # torch.save(corpus, fn) - - return corpus - - -def tokenize_raw(text, lang="en"): - # AG: Not used? - import sacremoses - - mt = sacremoses.MosesTokenizer(lang) - text = mt.tokenize(text, return_str=True) - text = re.sub(r""", '"', text) - text = re.sub(r"'", "'", text) - text = re.sub(r"(\d)\.(\d)", r"\1 @.@ \2", text) - text = re.sub(r"(\d),(\d)", r"\1 @,@ \2", text) - text = re.sub(r"(\w)-(\w)", r"\1 @-@ \2", text) - return text - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="unit test") - parser.add_argument( - "--datadir", - type=str, - default="../data/text8", - help="location of the data corpus", - ) - parser.add_argument( - "--dataset", - type=str, - default="text8", - choices=["ptb", "wt2", "wt103", "lm1b", "enwik8", "text8"], - help="dataset name", - ) - args = parser.parse_args() - - logging.basicConfig(level=logging.INFO) - - corpus = get_lm_corpus(args.datadir, args.dataset, vocab="word") - logging.info("Vocab size : {}".format(len(corpus.vocab.idx2sym))) diff --git a/src/dataloaders/lra.py b/src/dataloaders/lra.py new file mode 100644 index 0000000..58accdb --- /dev/null +++ b/src/dataloaders/lra.py @@ -0,0 +1,687 @@ +"""Long Range Arena datasets""" +import io +import logging +import os +import pickle +from pathlib import Path + +import torch +from torch import nn +import torch.nn.functional as F +import torchtext +import torchvision +from einops.layers.torch import Rearrange, Reduce +from PIL import Image # Only used for Pathfinder +from datasets import DatasetDict, Value, load_dataset + +from src.dataloaders.base import default_data_path, SequenceDataset, ImageResolutionSequenceDataset + + +class IMDB(SequenceDataset): + _name_ = "imdb" + d_output = 2 + l_output = 0 + + @property + def init_defaults(self): + return { + "l_max": 4096, + "level": "char", + "min_freq": 15, + "seed": 42, + "val_split": 0.0, + "append_bos": False, + "append_eos": True, + # 'max_vocab': 135, + "n_workers": 4, # Only used for tokenizing dataset before caching + } + + @property + def n_tokens(self): + return len(self.vocab) + + def prepare_data(self): + if self.cache_dir is None: # Just download the dataset + load_dataset(self._name_, cache_dir=self.data_dir) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + """If cache_dir is not None, we'll cache the processed dataset there.""" + self.data_dir = self.data_dir or default_data_path / self._name_ + self.cache_dir = self.data_dir / "cache" + assert self.level in [ + "word", + "char", + ], f"level {self.level} not supported" + + if stage == "test" and hasattr(self, "dataset_test"): + return + dataset, self.tokenizer, self.vocab = self.process_dataset() + print( + f"IMDB {self.level} level | min_freq {self.min_freq} | vocab size {len(self.vocab)}" + ) + dataset.set_format(type="torch", columns=["input_ids", "label"]) + + # Create all splits + dataset_train, self.dataset_test = dataset["train"], dataset["test"] + if self.val_split == 0.0: + # Use test set as val set, as done in the LRA paper + self.dataset_train, self.dataset_val = dataset_train, None + else: + train_val = dataset_train.train_test_split( + test_size=self.val_split, seed=self.seed + ) + self.dataset_train, self.dataset_val = ( + train_val["train"], + train_val["test"], + ) + + def _collate_fn(self, batch): + xs, ys = zip(*[(data["input_ids"], data["label"]) for data in batch]) + lengths = torch.tensor([len(x) for x in xs]) + xs = nn.utils.rnn.pad_sequence( + xs, padding_value=self.vocab[""], batch_first=True + ) + ys = torch.tensor(ys) + return xs, ys, {"lengths": lengths} + + # self._collate_fn = collate_batch + + def process_dataset(self): + cache_dir = ( + None if self.cache_dir is None else self.cache_dir / self._cache_dir_name + ) + if cache_dir is not None: + if cache_dir.is_dir(): + return self._load_from_cache(cache_dir) + + dataset = load_dataset(self._name_, cache_dir=self.data_dir) + dataset = DatasetDict(train=dataset["train"], test=dataset["test"]) + if self.level == "word": + tokenizer = torchtext.data.utils.get_tokenizer( + "spacy", language="en_core_web_sm" + ) + else: # self.level == 'char' + tokenizer = list # Just convert a string to a list of chars + # Account for and tokens + l_max = self.l_max - int(self.append_bos) - int(self.append_eos) + tokenize = lambda example: {"tokens": tokenizer(example["text"])[:l_max]} + dataset = dataset.map( + tokenize, + remove_columns=["text"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + vocab = torchtext.vocab.build_vocab_from_iterator( + dataset["train"]["tokens"], + min_freq=self.min_freq, + specials=( + ["", ""] + + ([""] if self.append_bos else []) + + ([""] if self.append_eos else []) + ), + ) + vocab.set_default_index(vocab[""]) + + numericalize = lambda example: { + "input_ids": vocab( + ([""] if self.append_bos else []) + + example["tokens"] + + ([""] if self.append_eos else []) + ) + } + dataset = dataset.map( + numericalize, + remove_columns=["tokens"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + + if cache_dir is not None: + self._save_to_cache(dataset, tokenizer, vocab, cache_dir) + return dataset, tokenizer, vocab + + def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): + cache_dir = self.cache_dir / self._cache_dir_name + logger = logging.getLogger(__name__) + logger.info(f"Saving to cache at {str(cache_dir)}") + dataset.save_to_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "wb") as f: + pickle.dump(tokenizer, f) + with open(cache_dir / "vocab.pkl", "wb") as f: + pickle.dump(vocab, f) + + def _load_from_cache(self, cache_dir): + assert cache_dir.is_dir() + logger = logging.getLogger(__name__) + logger.info(f"Load from cache at {str(cache_dir)}") + dataset = DatasetDict.load_from_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "rb") as f: + tokenizer = pickle.load(f) + with open(cache_dir / "vocab.pkl", "rb") as f: + vocab = pickle.load(f) + return dataset, tokenizer, vocab + + @property + def _cache_dir_name(self): + return f"l_max-{self.l_max}-level-{self.level}-min_freq-{self.min_freq}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" + +class TabularDataset(torch.utils.data.Dataset): + def __init__( + self, + path, + format, + col_idx=None, + skip_header=False, + csv_reader_params=None, + ): + """ + col_idx: the indices of the columns. + """ + if csv_reader_params is None: + csv_reader_params = {} + format = format.lower() + assert format in ["tsv", "csv"] + with io.open(os.path.expanduser(path), encoding="utf8") as f: + if format == "csv": + reader = torchtext.utils.unicode_csv_reader(f, **csv_reader_params) + elif format == "tsv": + reader = torchtext.utils.unicode_csv_reader( + f, delimiter="\t", **csv_reader_params + ) + else: + reader = f + if skip_header: + next(reader) + self._data = [ + line if col_idx is None else [line[c] for c in col_idx] + for line in reader + ] + + def __len__(self): + return len(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + +# LRA tokenizer renames ']' to 'X' and delete parentheses as their tokenizer removes +# non-alphanumeric characters. +# https://github.com/google-research/long-range-arena/blob/264227cbf9591e39dd596d2dc935297a2070bdfe/lra_benchmarks/listops/input_pipeline.py#L46 +def listops_tokenizer(s): + return s.translate({ord("]"): ord("X"), ord("("): None, ord(")"): None}).split() + + +class ListOps(SequenceDataset): + _name_ = "listops" + d_output = 10 + l_output = 0 + + @property + def init_defaults(self): + return { + "l_max": 2048, + "append_bos": False, + "append_eos": True, + # 'max_vocab': 20, # Actual size 18 + "n_workers": 4, # Only used for tokenizing dataset + } + + @property + def n_tokens(self): + return len(self.vocab) + + @property + def _cache_dir_name(self): + return f"l_max-{self.l_max}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" + + def init(self): + if self.data_dir is None: + self.data_dir = default_data_path / self._name_ + self.cache_dir = self.data_dir / self._cache_dir_name + + def prepare_data(self): + if self.cache_dir is None: + for split in ["train", "val", "test"]: + split_path = self.data_dir / f"basic_{split}.tsv" + if not split_path.is_file(): + raise FileNotFoundError( + f""" + File {str(split_path)} not found. + To get the dataset, download lra_release.gz from + https://github.com/google-research/long-range-arena, + then unzip it with tar -xvf lra_release.gz. + Then point data_dir to the listops-1000 directory. + """ + ) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + if stage == "test" and hasattr(self, "dataset_test"): + return + dataset, self.tokenizer, self.vocab = self.process_dataset() + self.vocab_size = len(self.vocab) + dataset.set_format(type="torch", columns=["input_ids", "Target"]) + self.dataset_train, self.dataset_val, self.dataset_test = ( + dataset["train"], + dataset["val"], + dataset["test"], + ) + + def collate_batch(batch): + xs, ys = zip(*[(data["input_ids"], data["Target"]) for data in batch]) + lengths = torch.tensor([len(x) for x in xs]) + xs = nn.utils.rnn.pad_sequence( + xs, padding_value=self.vocab[""], batch_first=True + ) + ys = torch.tensor(ys) + return xs, ys, {"lengths": lengths} + + self._collate_fn = collate_batch + + def process_dataset(self): + cache_dir = ( + None if self.cache_dir is None else self.cache_dir / self._cache_dir_name + ) + if cache_dir is not None: + if cache_dir.is_dir(): + return self._load_from_cache(cache_dir) + + dataset = load_dataset( + "csv", + data_files={ + "train": str(self.data_dir / "basic_train.tsv"), + "val": str(self.data_dir / "basic_val.tsv"), + "test": str(self.data_dir / "basic_test.tsv"), + }, + delimiter="\t", + keep_in_memory=True, + ) + + tokenizer = listops_tokenizer + + # Account for and tokens + l_max = self.l_max - int(self.append_bos) - int(self.append_eos) + tokenize = lambda example: {"tokens": tokenizer(example["Source"])[:l_max]} + dataset = dataset.map( + tokenize, + remove_columns=["Source"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + vocab = torchtext.vocab.build_vocab_from_iterator( + dataset["train"]["tokens"], + specials=( + ["", ""] + + ([""] if self.append_bos else []) + + ([""] if self.append_eos else []) + ), + ) + vocab.set_default_index(vocab[""]) + + numericalize = lambda example: { + "input_ids": vocab( + ([""] if self.append_bos else []) + + example["tokens"] + + ([""] if self.append_eos else []) + ) + } + dataset = dataset.map( + numericalize, + remove_columns=["tokens"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + + if cache_dir is not None: + self._save_to_cache(dataset, tokenizer, vocab, cache_dir) + return dataset, tokenizer, vocab + + def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): + cache_dir = self.cache_dir / self._cache_dir_name + logger = logging.getLogger(__name__) + logger.info(f"Saving to cache at {str(cache_dir)}") + dataset.save_to_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "wb") as f: + pickle.dump(tokenizer, f) + with open(cache_dir / "vocab.pkl", "wb") as f: + pickle.dump(vocab, f) + + def _load_from_cache(self, cache_dir): + assert cache_dir.is_dir() + logger = logging.getLogger(__name__) + logger.info(f"Load from cache at {str(cache_dir)}") + dataset = DatasetDict.load_from_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "rb") as f: + tokenizer = pickle.load(f) + with open(cache_dir / "vocab.pkl", "rb") as f: + vocab = pickle.load(f) + return dataset, tokenizer, vocab + +class PathFinderDataset(torch.utils.data.Dataset): + """Path Finder dataset.""" + + # There's an empty file in the dataset + blacklist = {"pathfinder32/curv_baseline/imgs/0/sample_172.png"} + + def __init__(self, data_dir, transform=None): + """ + Args: + data_dir (string): Directory with all the images. + transform (callable, optional): Optional transform to be applied + on a sample. + """ + self.data_dir = Path(data_dir).expanduser() + assert self.data_dir.is_dir(), f"data_dir {str(self.data_dir)} does not exist" + self.transform = transform + samples = [] + # for diff_level in ['curv_baseline', 'curv_contour_length_9', 'curv_contour_length_14']: + for diff_level in ["curv_contour_length_14"]: + path_list = sorted( + list((self.data_dir / diff_level / "metadata").glob("*.npy")), + key=lambda path: int(path.stem), + ) + assert path_list, "No metadata found" + for metadata_file in path_list: + with open(metadata_file, "r") as f: + for metadata in f.read().splitlines(): + metadata = metadata.split() + image_path = Path(diff_level) / metadata[0] / metadata[1] + if ( + str(Path(self.data_dir.stem) / image_path) + not in self.blacklist + ): + label = int(metadata[3]) + samples.append((image_path, label)) + self.samples = samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + path, target = self.samples[idx] + # https://github.com/pytorch/vision/blob/9b29f3f22783112406d9c1a6db47165a297c3942/torchvision/datasets/folder.py#L247 + with open(self.data_dir / path, "rb") as f: + sample = Image.open(f).convert("L") # Open in grayscale + if self.transform is not None: + sample = self.transform(sample) + return sample, target + +class PathFinder(ImageResolutionSequenceDataset): + _name_ = "pathfinder" + d_input = 1 + d_output = 2 + l_output = 0 + + @property + def n_tokens(self): + if self.tokenize: + return 256 + + @property + def init_defaults(self): + return { + "resolution": 32, + "sequential": True, + "tokenize": False, + "pool": 1, + "val_split": 0.1, + "test_split": 0.1, + "seed": 42, # Controls the train/val/test split + } + + def default_transforms(self): + transform_list = [torchvision.transforms.ToTensor()] + if self.pool > 1: + transform_list.append( + Reduce( + "1 (h h2) (w w2) -> 1 h w", + "mean", + h2=self.pool, + w2=self.pool, + ) + ) + if self.tokenize: + transform_list.append( + torchvision.transforms.Lambda(lambda x: (x * 255).long()) + ) + else: + transform_list.append(torchvision.transforms.Normalize(mean=0.5, std=0.5)) + if self.sequential: + # If tokenize, it makes more sense to get rid of the channel dimension + transform_list.append( + Rearrange("1 h w -> (h w)") + if self.tokenize + else Rearrange("1 h w -> (h w) 1") + ) + else: + transform_list.append(Rearrange("1 h w -> h w 1")) + return torchvision.transforms.Compose(transform_list) + + def prepare_data(self): + if not self.data_dir.is_dir(): + raise FileNotFoundError( + f""" + Directory {str(self.data_dir)} not found. + To get the dataset, download lra_release.gz from + https://github.com/google-research/long-range-arena, + then unzip it with tar -xvf lra_release.gz. + Then point data_dir to the pathfinderX directory, where X is either 32, 64, 128, or 256. + """ + ) + + def setup(self, stage=None): + if self.data_dir is None: + self.data_dir = ( + default_data_path / self._name_ / f"pathfinder{self.resolution}" + ) + + if stage == "test" and hasattr(self, "dataset_test"): + return + # [2021-08-18] TD: I ran into RuntimeError: Too many open files. + # https://github.com/pytorch/pytorch/issues/11201 + torch.multiprocessing.set_sharing_strategy("file_system") + dataset = PathFinderDataset(self.data_dir, transform=self.default_transforms()) + len_dataset = len(dataset) + val_len = int(self.val_split * len_dataset) + test_len = int(self.test_split * len_dataset) + train_len = len_dataset - val_len - test_len + ( + self.dataset_train, + self.dataset_val, + self.dataset_test, + ) = torch.utils.data.random_split( + dataset, + [train_len, val_len, test_len], + generator=torch.Generator().manual_seed(self.seed), + ) + +class AAN(SequenceDataset): + _name_ = "aan" + d_output = 2 # Use accuracy instead of binary_accuracy + l_output = 0 + + @property + def n_tokens(self): + return len(self.vocab) + + @property + def init_defaults(self): + return { + "l_max": 4000, + # 'max_vocab': 100, # Full size 98 + "append_bos": False, + "append_eos": True, + "n_workers": 4, # For tokenizing only + } + + @property + def _cache_dir_name(self): + return f"l_max-{self.l_max}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" + + def init(self): + if self.data_dir is None: + self.data_dir = default_data_path / self._name_ + self.cache_dir = self.data_dir / self._cache_dir_name + + def prepare_data(self): + if self.cache_dir is None: + for split in ["train", "eval", "test"]: + split_path = self.data_dir / f"new_aan_pairs.{split}.tsv" + if not split_path.is_file(): + raise FileNotFoundError( + f""" + File {str(split_path)} not found. + To get the dataset, download lra_release.gz from + https://github.com/google-research/long-range-arena, + then unzip it with tar -xvf lra_release.gz. + Then point data_dir to the tsv_data directory. + """ + ) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + if stage == "test" and hasattr(self, "dataset_test"): + return + + # [2021-08-18] TD: I ran into RuntimeError: Too many open files. + # https://github.com/pytorch/pytorch/issues/11201 + torch.multiprocessing.set_sharing_strategy("file_system") + + dataset, self.tokenizer, self.vocab = self.process_dataset() + # self.vocab_size = len(self.vocab) + print("AAN vocab size:", len(self.vocab)) + + dataset.set_format(type="torch", columns=["input_ids1", "input_ids2", "label"]) + self.dataset_train, self.dataset_val, self.dataset_test = ( + dataset["train"], + dataset["val"], + dataset["test"], + ) + + def collate_batch(batch): + xs1, xs2, ys = zip( + *[ + (data["input_ids1"], data["input_ids2"], data["label"]) + for data in batch + ] + ) + lengths1 = torch.tensor([len(x) for x in xs1]) + lengths2 = torch.tensor([len(x) for x in xs2]) + xs1 = nn.utils.rnn.pad_sequence( + xs1, padding_value=self.vocab[""], batch_first=True + ) + xs2 = nn.utils.rnn.pad_sequence( + xs2, padding_value=self.vocab[""], batch_first=True + ) + # Pad both to same length + # Shape (batch, length) + L = max(xs1.size(1), xs2.size(1)) + xs1 = F.pad(xs1, (0, L-xs1.size(1)), value=self.vocab[""]) + xs2 = F.pad(xs2, (0, L-xs2.size(1)), value=self.vocab[""]) + ys = torch.tensor(ys) + # return xs1, xs2, ys, lengths1, lengths2 + + # Concatenate two batches + xs = torch.cat([xs1, xs2], dim=0) + lengths = torch.cat([lengths1, lengths2], dim=0) + return xs, ys, {"lengths": lengths} + + self._collate_fn = collate_batch + + def process_dataset(self): + cache_dir = ( + None if self.cache_dir is None else self.cache_dir / self._cache_dir_name + ) + if cache_dir is not None: + if cache_dir.is_dir(): + return self._load_from_cache(cache_dir) + + dataset = load_dataset( + "csv", + data_files={ + "train": str(self.data_dir / "new_aan_pairs.train.tsv"), + "val": str(self.data_dir / "new_aan_pairs.eval.tsv"), + "test": str(self.data_dir / "new_aan_pairs.test.tsv"), + }, + delimiter="\t", + column_names=["label", "input1_id", "input2_id", "text1", "text2"], + keep_in_memory=True, + ) # True) + dataset = dataset.remove_columns(["input1_id", "input2_id"]) + new_features = dataset["train"].features.copy() + new_features["label"] = Value("int32") + dataset = dataset.cast(new_features) + + tokenizer = list # Just convert a string to a list of chars + # Account for and tokens + l_max = self.l_max - int(self.append_bos) - int(self.append_eos) + tokenize = lambda example: { + "tokens1": tokenizer(example["text1"])[:l_max], + "tokens2": tokenizer(example["text2"])[:l_max], + } + dataset = dataset.map( + tokenize, + remove_columns=["text1", "text2"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + vocab = torchtext.vocab.build_vocab_from_iterator( + dataset["train"]["tokens1"] + dataset["train"]["tokens2"], + specials=( + ["", ""] + + ([""] if self.append_bos else []) + + ([""] if self.append_eos else []) + ), + ) + vocab.set_default_index(vocab[""]) + + encode = lambda text: vocab( + ([""] if self.append_bos else []) + + text + + ([""] if self.append_eos else []) + ) + numericalize = lambda example: { + "input_ids1": encode(example["tokens1"]), + "input_ids2": encode(example["tokens2"]), + } + dataset = dataset.map( + numericalize, + remove_columns=["tokens1", "tokens2"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + + if cache_dir is not None: + self._save_to_cache(dataset, tokenizer, vocab, cache_dir) + return dataset, tokenizer, vocab + + def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): + cache_dir = self.cache_dir / self._cache_dir_name + logger = logging.getLogger(__name__) + logger.info(f"Saving to cache at {str(cache_dir)}") + dataset.save_to_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "wb") as f: + pickle.dump(tokenizer, f) + with open(cache_dir / "vocab.pkl", "wb") as f: + pickle.dump(vocab, f) + + def _load_from_cache(self, cache_dir): + assert cache_dir.is_dir() + logger = logging.getLogger(__name__) + logger.info(f"Load from cache at {str(cache_dir)}") + dataset = DatasetDict.load_from_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "rb") as f: + tokenizer = pickle.load(f) + with open(cache_dir / "vocab.pkl", "rb") as f: + vocab = pickle.load(f) + return dataset, tokenizer, vocab diff --git a/src/dataloaders/prepare/bidmc/README.md b/src/dataloaders/prepare/bidmc/README.md index 4438852..48a542b 100644 --- a/src/dataloaders/prepare/bidmc/README.md +++ b/src/dataloaders/prepare/bidmc/README.md @@ -10,6 +10,7 @@ Blood Oxygen Saturation: https://zenodo.org/record/4001464/files/BIDMC32SpO2_TRAIN.ts https://zenodo.org/record/4001464/files/BIDMC32SpO2_TEST.ts -0. Working directory `datasets/healthcare` -1. Download the above datasets into `data/RR`, `data/HR`, `data/SpO2` -2. Run `python process_data.py` +1. Create folder `data/bidmc` (relative to repo base) +2. Download the above datasets into `data/bidmc/RR`, `data/bidmc/HR`, `data/bidmc/SpO2` +3. Copy processing scripts `cp src/dataloaders/prepare/bidmc/{process_data.py,data_loader.py} data/bidmc` +4. Run script `cd data/bidmc && python process_data.py` diff --git a/src/dataloaders/prepare/bidmc/process_data.py b/src/dataloaders/prepare/bidmc/process_data.py index 43b2f55..4862f29 100644 --- a/src/dataloaders/prepare/bidmc/process_data.py +++ b/src/dataloaders/prepare/bidmc/process_data.py @@ -4,11 +4,10 @@ from sklearn.model_selection import train_test_split import sktime -from sktime.utils.data_io import load_from_tsfile_to_dataframe +from sktime.datasets import load_from_tsfile_to_dataframe import data_loader as data DATA_PATH = "data/" -# DATASET = 'RR' def split_data( @@ -39,12 +38,12 @@ def _to_numpy(X): def process_data(DATASET, shuffle=True, seed=0): X_train_orig, y_train_orig = data.load_from_tsfile_to_dataframe( - os.path.join(f"{DATA_PATH}/{DATASET}/BIDMC32{DATASET}_TRAIN.ts"), + os.path.join(f"{DATASET}/BIDMC32{DATASET}_TRAIN.ts"), replace_missing_vals_with="NaN", ) X_test_orig, y_test_orig = data.load_from_tsfile_to_dataframe( - os.path.join(f"{DATA_PATH}/{DATASET}/BIDMC32{DATASET}_TEST.ts"), + os.path.join(f"{DATASET}/BIDMC32{DATASET}_TEST.ts"), replace_missing_vals_with="NaN", ) @@ -53,7 +52,7 @@ def process_data(DATASET, shuffle=True, seed=0): ) split = "reshuffle" if shuffle else "original" - data_dir = os.path.join(DATA_PATH, DATASET, split) + data_dir = os.path.join(DATASET, split) os.makedirs(data_dir, exist_ok=True) np.save(os.path.join(data_dir, "trainx.npy"), _to_numpy(X_train)) np.save(os.path.join(data_dir, "trainy.npy"), y_train) @@ -63,7 +62,7 @@ def process_data(DATASET, shuffle=True, seed=0): np.save(os.path.join(data_dir, "testy.npy"), y_test) for f in ["trainx", "trainy", "validx", "validy", "testx", "testy"]: - df = np.load(f"{DATA_PATH}/{DATASET}/{split}/{f}.npy") + df = np.load(f"{DATASET}/{split}/{f}.npy") print(f, df.shape, df.dtype) diff --git a/src/dataloaders/synthetic.py b/src/dataloaders/synthetic.py new file mode 100644 index 0000000..157b5f0 --- /dev/null +++ b/src/dataloaders/synthetic.py @@ -0,0 +1,209 @@ +"""Synthetic datasets""" + +import numpy as np +import torch +import torchvision +from einops.layers.torch import Rearrange +from src.utils import permutations + +from src.dataloaders.base import SequenceDataset + + +class Copying(SequenceDataset): + _name_ = "copying" + + @property + def init_defaults(self): + return { + "l_noise": 100, # number of padding tokens + "l_memorize": 10, # number of tokens to memorize + "n_tokens": 10, # alphabet size + "lag": False, + "variable": False, # Randomly distribute memorization tokens throughout sequence instead of frontloading them + "one_hot": False, + "static": False, # Use a static dataset of size n_train, otherwise always use random data with n_train per epoch + "n_train": 10000, + "n_eval": 1000, + } + + @property + def d_input(self): + return self.n_tokens + + @property + def d_output(self): + return self.n_tokens + + @property + def l_output(self): + return self.l_noise if self.lag else self.l_memorize + + def setup(self): + from .datasets.copying import CopyingEvalDataset, CopyingTrainDataset + + if self.static: train_cls = CopyingEvalDataset + else: train_cls = CopyingTrainDataset + + self.dataset_train = train_cls( + self.l_noise, + self.l_memorize, + self.n_tokens, + samples=self.n_train, + lag=self.lag, + variable=self.variable, + one_hot=self.one_hot, + ) + self.dataset_val = CopyingEvalDataset( + self.l_noise, + self.l_memorize, + self.n_tokens, + samples=self.n_eval, + lag=self.lag, + variable=self.variable, + one_hot=self.one_hot, + ) + self.dataset_test = None + + + def __str__(self): + return f"{self._name_}{self.l_noise}{'v' if self.variable else ''}" + + +class Adding(SequenceDataset): + _name_ = "adding" + d_input = 2 + d_output = 1 + l_output = 0 + + @property + def init_defaults(self): + return { + "l_max": 1000, + "n_samples": 50000, + "val_split": 0.1, + } + + def setup(self): + from .datasets.adding import adding_static_dataset + + self.dataset_train = adding_static_dataset(self.l_max, self.n_samples) + self.dataset_test = None + self.split_train_val(self.val_split) + + def __str__(self): + return f"{self._name_}{self.l_max}" + + +class Reconstruct(SequenceDataset): + _name_ = "reconstruct" + + @property + def init_defaults(self): + return { + "l_seq": 1024, # length of total sequence + "l_mem": 512, # length to reconstruct + "dt": 0.001, + "freq": 1.0, + "seed": 0, + "static": False, # Use a static dataset of size n_train, otherwise always use random data with n_train per epoch + "n_train": 10000, + "n_eval": 1000, + } + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + return self.l_mem + + @property + def l_output(self): + return 0 + + def setup(self): + from .datasets.reconstruct import ReconstructEvalDataset, ReconstructTrainDataset + + if self.static: train_cls = ReconstructEvalDataset + else: train_cls = ReconstructTrainDataset + + self.dataset_train = train_cls( + samples=self.n_train, + l_seq=self.l_seq, + l_mem=self.l_mem, + dt=self.dt, + freq=self.freq, + seed=self.seed, + ) + self.dataset_val = ReconstructEvalDataset( + samples=self.n_eval, + l_seq=self.l_seq, + l_mem=self.l_mem, + dt=self.dt, + freq=self.freq, + seed=self.seed, + ) + self.dataset_test = None + + def __str__(self): + raise NotImplementedError + + +class Delay(SequenceDataset): + _name_ = "delay" + + @property + def init_defaults(self): + return { + "l_seq": 1024, # length of total sequence + "n_lag": 1, # length to reconstruct + "l_lag": None, # length to reconstruct + "dt": 0.001, + "freq": 100.0, + "static": False, # Use a static dataset of size n_train, otherwise always use random data with n_train per epoch + "n_train": 10000, + "n_eval": 1000, + } + + @property + def d_input(self): + return 1 + + @property + def d_output(self): + # NOTE: To reproduce numbers from HTTYH paper, set this equal to 4. There was a bug in the implementation at the time + return self.n_lag + + @property + def l_output(self): + return self.l_seq + + def setup(self): + from .datasets.delay import DelayEvalDataset, DelayTrainDataset + + if self.static: train_cls = DelayEvalDataset + else: train_cls = DelayTrainDataset + + self.dataset_train = train_cls( + samples=self.n_train, + l_seq=self.l_seq, + n_lag=self.n_lag, + l_lag=self.l_lag, + dt=self.dt, + freq=self.freq, + ) + self.dataset_val = DelayEvalDataset( + samples=self.n_eval, + l_seq=self.l_seq, + n_lag=self.n_lag, + l_lag=self.l_lag, + dt=self.dt, + freq=self.freq, + ) + self.dataset_test = None + + + def __str__(self): + return f"{self._name_}{self.l_noise}{'v' if self.variable else ''}" + diff --git a/src/dataloaders/ts.py b/src/dataloaders/ts.py new file mode 100644 index 0000000..4384306 --- /dev/null +++ b/src/dataloaders/ts.py @@ -0,0 +1,66 @@ +"""Time series datasets, especially for medical time series""" + + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from src.dataloaders.base import default_data_path, SequenceDataset, deprecated + +class BIDMC(SequenceDataset): + """BIDMC datasets for Respiratory Rate / Heart Rate / Oxygen Saturation regression""" + + _name_ = "bidmc" + d_input = 2 + + @property + def d_output(self): + return 2 if self.prediction else 1 + + @property + def l_output(self): + return 4000 if self.prediction else 0 + + @property + def init_defaults(self): + return { + "target": "RR", # 'RR' | 'HR' | 'SpO2' + "prediction": False, + "reshuffle": True, + } + + def setup(self): + self.data_dir = self.data_dir or default_data_path / self._name_ + + split = "reshuffle" if self.reshuffle else "original" + # X: (dataset_size, length, d_input) + # y: (dataset_size) + X_train = np.load(self.data_dir / self.target / split / "trainx.npy") + y_train = np.load(self.data_dir / self.target / split / "trainy.npy") + X_val = np.load(self.data_dir / self.target / split / "validx.npy") + y_val = np.load(self.data_dir / self.target / split / "validy.npy") + X_test = np.load(self.data_dir / self.target / split / "testx.npy") + y_test = np.load(self.data_dir / self.target / split / "testy.npy") + + if self.prediction: + y_train = np.pad(X_train[:, 1:, :], ((0, 0), (0, 1), (0, 0))) + y_val = np.pad(X_val[:, 1:, :], ((0, 0), (0, 1), (0, 0))) + y_test = np.pad(X_test[:, 1:, :], ((0, 0), (0, 1), (0, 0))) + + self.dataset_train = torch.utils.data.TensorDataset( + torch.FloatTensor(X_train), torch.FloatTensor(y_train) + ) + + self.dataset_val = torch.utils.data.TensorDataset( + torch.FloatTensor(X_val), torch.FloatTensor(y_val) + ) + + self.dataset_test = torch.utils.data.TensorDataset( + torch.FloatTensor(X_test), torch.FloatTensor(y_test) + ) + + def __str__(self): + split = "reshuffle" if self.reshuffle else "original" + return f"BIDMC{self.target}_{split}" + diff --git a/src/dataloaders/utils/cifar_augmentations.py b/src/dataloaders/utils/cifar_augmentations.py new file mode 100644 index 0000000..3c063ed --- /dev/null +++ b/src/dataloaders/utils/cifar_augmentations.py @@ -0,0 +1,138 @@ +""" +Borrowed from https://github.com/hysts/pytorch_image_classification/tree/9ff4248905850c68aa9c09c17914307eb81769e7/pytorch_image_classification/transforms +""" +import torch +import numpy as np +import PIL +import PIL.Image +from PIL.Image import Image + + +class NpNormalize: + def __init__(self, mean: np.ndarray, std: np.ndarray): + self.mean = np.array(mean) + self.std = np.array(std) + + def __call__(self, image: PIL.Image.Image) -> np.ndarray: + image = np.asarray(image).astype(np.float32) / 255. + image = (image - self.mean) / self.std + return image + + +class Cutout(object): + """Randomly mask out one or more patches from an image. + Args: + n_holes (int): Number of patches to cut out of each image. + length (int): The length (in pixels) of each square patch. + """ + + def __init__(self, n_holes, length): + self.n_holes = n_holes + self.length = length + + def __call__(self, img): + """ + Args: + img (Tensor): Tensor image of size (C, H, W). + Returns: + Tensor: Image with n_holes of dimension length x length cut out of it. + """ + h = img.size(1) + w = img.size(2) + + mask = np.ones((h, w), np.float32) + + for n in range(self.n_holes): + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. + + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img = img * mask + + return img + + +# +# class Cutout: +# def __init__(self, p=1.0, mask_size=16, cutout_inside=False, mask_color=0): +# # https://github.com/hysts/pytorch_image_classification/blob/9ff4248905850c68aa9c09c17914307eb81769e7/configs/augmentations/cifar/cutout.yaml +# self.p = p +# self.mask_size = mask_size +# self.cutout_inside = cutout_inside +# self.mask_color = mask_color +# +# self.mask_size_half = self.mask_size // 2 +# self.offset = 1 if self.mask_size % 2 == 0 else 0 +# +# def __call__(self, image: np.ndarray) -> np.ndarray: +# image = np.asarray(image).copy() +# +# if np.random.random() > self.p: +# return image +# +# h, w = image.shape[:2] +# +# if self.cutout_inside: +# cxmin = self.mask_size_half +# cxmax = w + self.offset - self.mask_size_half +# cymin = self.mask_size_half +# cymax = h + self.offset - self.mask_size_half +# else: +# cxmin, cxmax = 0, w + self.offset +# cymin, cymax = 0, h + self.offset +# +# cx = np.random.randint(cxmin, cxmax) +# cy = np.random.randint(cymin, cymax) +# xmin = cx - self.mask_size_half +# ymin = cy - self.mask_size_half +# xmax = xmin + self.mask_size +# ymax = ymin + self.mask_size +# xmin = max(0, xmin) +# ymin = max(0, ymin) +# xmax = min(w, xmax) +# ymax = min(h, ymax) +# image[ymin:ymax, xmin:xmax] = self.mask_color +# return image + + +class RandomErasing: + def __init__(self, p=0.5, max_attempt=20, sl=0.02, sh=0.4, rl=0.3, rh=1. / 0.3): + # https://github.com/hysts/pytorch_image_classification/blob/9ff4248905850c68aa9c09c17914307eb81769e7/configs/augmentations/cifar/random_erasing.yaml + self.p = 0.5 + self.max_attempt = 20 + self.sl, self.sh = 0.02, 0.4 + self.rl = 0.3 + self.rh = 1. / 0.3 + + def __call__(self, image: np.ndarray) -> np.ndarray: + image = np.asarray(image).copy() + + if np.random.random() > self.p: + return image + + h, w = image.shape[:2] + image_area = h * w + + for _ in range(self.max_attempt): + mask_area = np.random.uniform(self.sl, self.sh) * image_area + aspect_ratio = np.random.uniform(self.rl, self.rh) + mask_h = int(np.sqrt(mask_area * aspect_ratio)) + mask_w = int(np.sqrt(mask_area / aspect_ratio)) + + if mask_w < w and mask_h < h: + x0 = np.random.randint(0, w - mask_w) + y0 = np.random.randint(0, h - mask_h) + x1 = x0 + mask_w + y1 = y0 + mask_h + image[y0:y1, x0:x1] = np.random.uniform(0, 1) + break + + return image diff --git a/src/dataloaders/utils/signal.py b/src/dataloaders/utils/signal.py new file mode 100644 index 0000000..effe0b5 --- /dev/null +++ b/src/dataloaders/utils/signal.py @@ -0,0 +1,32 @@ +import numpy as np + +def whitesignal(period, dt, freq, rms=0.5, batch_shape=()): + """ + Produces output signal of length period / dt, band-limited to frequency freq + Output shape (*batch_shape, period/dt) + Adapted from the nengo library + """ + + if freq is not None and freq < 1. / period: + raise ValueError(f"Make ``{freq=} >= 1. / {period=}`` to produce a non-zero signal",) + + nyquist_cutoff = 0.5 / dt + if freq > nyquist_cutoff: + raise ValueError(f"{freq} must not exceed the Nyquist frequency for the given dt ({nyquist_cutoff:0.3f})") + + n_coefficients = int(np.ceil(period / dt / 2.)) + shape = batch_shape + (n_coefficients + 1,) + sigma = rms * np.sqrt(0.5) + coefficients = 1j * np.random.normal(0., sigma, size=shape) + coefficients[..., -1] = 0. + coefficients += np.random.normal(0., sigma, size=shape) + coefficients[..., 0] = 0. + + set_to_zero = np.fft.rfftfreq(2 * n_coefficients, d=dt) > freq + coefficients *= (1-set_to_zero) + power_correction = np.sqrt(1. - np.sum(set_to_zero, dtype=float) / n_coefficients) + if power_correction > 0.: coefficients /= power_correction + coefficients *= np.sqrt(2 * n_coefficients) + signal = np.fft.irfft(coefficients, axis=-1) + signal = signal - signal[..., :1] # Start from 0 + return signal diff --git a/src/dataloaders/utils/timm_mixup.py b/src/dataloaders/utils/timm_mixup.py new file mode 100644 index 0000000..333a9c6 --- /dev/null +++ b/src/dataloaders/utils/timm_mixup.py @@ -0,0 +1,22 @@ +import torch + +from timm.data import Mixup +from timm.data.mixup import mixup_target + + +class TimmMixup(Mixup): + """ Wrap timm.data.Mixup that avoids the assert that batch size must be even. + """ + def __call__(self, x, target, *args): + if self.mode == 'elem': + lam = self._mix_elem(x) + elif self.mode == 'pair': + # We move the assert from the beginning of the function to here + assert len(x) % 2 == 0, 'Batch size should be even when using this' + lam = self._mix_pair(x) + else: + lam = self._mix_batch(x) + # Another change is to set the right device here + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, + device=target.device) + return x, target, *args \ No newline at end of file diff --git a/src/dataloaders/utils/video_loader.py b/src/dataloaders/utils/video_loader.py new file mode 100644 index 0000000..03a04e4 --- /dev/null +++ b/src/dataloaders/utils/video_loader.py @@ -0,0 +1,52 @@ +import cv2 +# from skvideo.io import VideoCapture +# import skvideo.io +import torch +import torch.utils.data as data +from torchvision.datasets.folder import DatasetFolder +from PIL import Image + +import os +import os.path +import sys + +""" +Custom video dataloader for imagenet-p dataset, which comes in .mp4 + +""" + + +class VideoFolder(DatasetFolder): + + def __init__(self, root, transform=None, target_transform=None, loader=None): + super(VideoFolder, self).__init__( + root, loader, ['.mp4'], transform=transform, target_transform=target_transform) + + self.vids = self.samples + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + path, target = self.samples[index] + + # cap = VideoCapture(path) + cap = cv2.VideoCapture(path) + + frames = [] + + while True: + # Capture frame-by-frame + ret, frame = cap.read() + + if not ret: break + + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(self.transform(Image.fromarray(frame)).unsqueeze(0)) + + cap.release() + + return torch.cat(frames, 0), target \ No newline at end of file diff --git a/src/dataloaders/utils/vocabulary.py b/src/dataloaders/utils/vocabulary.py new file mode 100644 index 0000000..b2fa7bf --- /dev/null +++ b/src/dataloaders/utils/vocabulary.py @@ -0,0 +1,237 @@ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os +from collections import Counter +from collections import OrderedDict + +import torch + +import src.utils as utils + + +class Vocab(object): + def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, + delimiter=None, vocab_file=None): + self.counter = Counter() + self.special = special + self.min_freq = min_freq + self.max_size = max_size + self.lower_case = lower_case + self.delimiter = delimiter + self.vocab_file = vocab_file + + def tokenize(self, line, add_eos=False, add_double_eos=False): + line = line.strip() + # convert to lower case + if self.lower_case: + line = line.lower() + + # empty delimiter '' will evaluate False + if self.delimiter == '': + symbols = line + else: + symbols = line.split(self.delimiter) + + if add_double_eos: # lm1b + return [''] + symbols + [''] + elif add_eos: + return symbols + [''] + else: + return symbols + + def count_file(self, path, verbose=False, add_eos=False): + if verbose: + print('counting file {} ...'.format(path)) + assert os.path.exists(path) + + sents = [] + with open(path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + symbols = self.tokenize(line, add_eos=add_eos) + self.counter.update(symbols) + sents.append(symbols) + + return sents + + def count_sents(self, sents, verbose=False): + """ + sents : a list of sentences, each a list of tokenized symbols + """ + if verbose: + print('counting {} sents ...'.format(len(sents))) + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + self.counter.update(symbols) + + def _build_from_file(self, vocab_file): + self.idx2sym = [] + self.sym2idx = OrderedDict() + + with open(vocab_file, 'r', encoding='utf-8') as f: + for line in f: + symb = line.strip().split()[0] + self.add_symbol(symb) + self.unk_idx = self.sym2idx[''] + + def build_vocab(self): + if self.vocab_file: + print('building vocab from {}'.format(self.vocab_file)) + self._build_from_file(self.vocab_file) + print('final vocab size {}'.format(len(self))) + else: + print('building vocab with min_freq={}, max_size={}'.format( + self.min_freq, self.max_size)) + self.idx2sym = [] + self.sym2idx = OrderedDict() + + for sym in self.special: + self.add_special(sym) + + for sym, cnt in self.counter.most_common(self.max_size): + if cnt < self.min_freq: + break + self.add_symbol(sym) + + print('final vocab size {} from {} unique tokens'.format( + len(self), len(self.counter))) + + def encode_file(self, path, ordered=False, verbose=False, add_eos=True, + add_double_eos=False): + if verbose: + print('encoding file {} ...'.format(path)) + assert os.path.exists(path) + encoded = [] + with open(path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + symbols = self.tokenize(line, add_eos=add_eos, + add_double_eos=add_double_eos) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def encode_sents(self, sents, ordered=False, verbose=False): + if verbose: + print('encoding {} sents ...'.format(len(sents))) + encoded = [] + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def add_special(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) + + def add_symbol(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + + def get_sym(self, idx): + assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) + return self.idx2sym[idx] + + def get_idx(self, sym): + if sym in self.sym2idx: + return self.sym2idx[sym] + else: + # print('encounter unk {}'.format(sym)) + assert '' not in sym + assert hasattr(self, 'unk_idx') + return self.sym2idx.get(sym, self.unk_idx) + + def get_symbols(self, indices): + return [self.get_sym(idx) for idx in indices] + + def get_indices(self, symbols): + return [self.get_idx(sym) for sym in symbols] + + def convert_to_tensor(self, symbols): + return torch.LongTensor(self.get_indices(symbols)) + + def convert_to_sent(self, indices, exclude=None): + if exclude is None: + return ' '.join([self.get_sym(idx) for idx in indices]) + else: + return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) + + def __len__(self): + return len(self.idx2sym) + + +# Class OpenAIVocab has been adapted from +# https://github.com/cybertronai/transformer-xl/blob/master/utils/vocabulary.py +class OpenAIVocab(Vocab): + def __init__(self, max_size=None, vocab_file=None): + from transformers import GPT2Tokenizer + self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + self.EOT = self.tokenizer.encoder['<|endoftext|>'] + self.max_size = max_size + self.vocab_file = vocab_file + + pad = 8 + vocab_size = len(self.tokenizer) + padded_vocab_size = (vocab_size + pad - 1) // pad * pad + for i in range(0, padded_vocab_size - vocab_size): + token = f'madeupword{i:09d}' + self.tokenizer.add_tokens([token]) + + def __len__(self): + return len(self.tokenizer) + + def count_file(self, path, verbose=False, add_eos=False): + # TODO: train from scratch, respect self.max_size + pass + + def build_vocab(self): + pass + + def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False) -> torch.LongTensor: + cached = path + '.bpe' + if os.path.exists(cached): + return torch.load(cached) + print(f'encoding file {path} ...') + assert os.path.exists(path), f"{path} doesn't exist" + + with open(path, encoding='utf-8') as f: + # Suppress warnings about length. + with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull): + out = torch.LongTensor(self.tokenizer.encode(f.read()) + [self.EOT]) + with utils.distributed.sync_workers() as rank: + if rank == 0: + torch.save(out, cached) + return out + + def tokenize(self, line, add_eos=False, add_double_eos=False): + return self.tokenizer.encode(line) + + def convert_to_tensor(self, symbols): + return torch.LongTensor(symbols) diff --git a/src/models/README.md b/src/models/README.md new file mode 100644 index 0000000..97e87ec --- /dev/null +++ b/src/models/README.md @@ -0,0 +1,82 @@ +This repository provides a modular and flexible implementation of general deep sequence models. + +``` +baselines/ Ported baseline models +functional/ Mathematical utilities +hippo/ Utilities for defining HiPPO operators +nn/ Standalone neural network components (nn.Module) +s4/ Standalone S4 modules +sequence/ Modular sequence model interface +``` + +# HiPPO + +HiPPO is the mathematical framework upon which the papers HiPPO, LSSL, and S4 are built on. +HiPPO operators are defined in [hippo/hippo.py](hippo/hippo.py). +Function reconstruction experiments and visualizations are presented in [hippo/visualizations.py](hippo/visualizations.py). + +# S4 + +Standalone implementations of S4 can be found inside [s4/](s4/) (see the README for usage). + +# Modular Sequence Model Interface + +This README provides a basic overview of the model source code. +It is recommended to see the [config README](../../configs/model/README.md) for running experiments with these models. + + +## SequenceModule +The SequenceModule class ([sequence/base.py](sequence/base.py)) is the abstract interface that all sequence models adhere to. +In this codebase, sequence models are defined as a sequence-to-sequence map of shape `(batch size, sequence length, model dimension)` to `(batch size, sequence length, output dimension)`. + +The SequenceModule comes with other methods such as `step` which is meant for autoregressive settings, and logic to carry optional hidden states (for stateful models such as RNNs or S4). + +To add a new model to this codebase, subclass `SequenceModule` and implement the required methods. + +## SequenceModel +The `SequenceModel` class ([sequence/model.py](sequence/model.py)) is the main backbone with configurable options for residual function, normalization placement, etc. + +SequenceModel accepts a black box config for a layer. Compatible layers are SequenceModules (i.e. composable sequence transformations) found under `sequence/`. + +## Example Layers + +### S4 + +The S4 module is found at [sequence/ss/s4.py](sequence/ss/s4.py). + +Standalone versions are in the folder [s4/](s4/). + +### LSSL + +The LSSL is the predecessor of S4. It is currently not recommended for use, but the model can be found at [sequence/ss/lssl.py](sequence/ss/lssl.py). + +It can be run by adding `model/layer=lssl` to the command line, or `model/layer=lssl model.layer.learn=0` for the LSSL-fixed model which does not train $A, B, \Delta$. + +### RNNs + +This codebase also contains a modular implementation of many RNN cells. +These include HiPPO-RNN cells from the original [HiPPO paper](https://arxiv.org/abs/2008.07669). + +Some examples include `model=rnn/hippo-legs` and `model=rnn/hippo-legt` for HiPPO variants from the original [paper](https://arxiv.org/abs/2008.07669), or `model=rnn/gru` for a GRU reimplementation, etc. + +An exception is `model=lstm` to use the PyTorch LSTM. + +Example command (reproducing the Permuted MNIST number from the HiPPO paper, which was SotA at the time): +``` +python train.py pipeline=mnist model=rnn/hippo-legs model.cell_args.hidden_size=512 train.epochs=50 train.batch_size=100 train.lr=0.001 +``` + +# Baselines +Other sequence models are easily incorporated into this repository, +and several other baselines have been ported. + +These include CNNs such as [CKConv](https://arxiv.org/abs/2102.02611) and continuous-time/RNN models such as [UnICORNN](https://arxiv.org/abs/2102.02611) and [LipschitzRNN](https://arxiv.org/abs/2006.12070). + +Models and datasets can be flexibly interchanged. +Examples: +``` +python -m train pipeline=cifar model=ckconv +python -m train pipeline=mnist model=lipschitzrnn +``` + + diff --git a/src/models/baselines/gru.py b/src/models/baselines/gru.py new file mode 100644 index 0000000..fe977cf --- /dev/null +++ b/src/models/baselines/gru.py @@ -0,0 +1,57 @@ +""" Wrapper around nn.GRU to make it compatible with our RNN interface. Similar to lstm.TorchLSTM """ + +import torch +from torch import nn +from src.models.sequence import SequenceModule, TransposedModule +from einops import rearrange +import src.models.nn.utils as U + +@TransposedModule +class TorchGRU(nn.GRU, SequenceModule): + """ Wrapper around nn.GRU to make it compatible with our RNN interface """ + + def __init__(self, d_model, d_hidden, n_layers=1, learn_h0=False, **kwargs): + # Rename input_size, hidden_size to d_input, d_model + # Set batch_first as default as per this codebase's convention + self.d_model = d_model + self.d_hidden = d_hidden + self.n_layers = n_layers + self.learn_h0 = learn_h0 + super().__init__(d_model, d_hidden, num_layers=n_layers, batch_first=True, **kwargs) + + self.num_directions = 2 if self.bidirectional else 1 + + if self.learn_h0: + self.h0 = nn.Parameter(torch.zeros(self.num_layers * self.num_directions, 1, self.hidden_size)) + + def step(self, x, state): + raise NotImplementedError + + def default_state(self, *batch_shape, device=None): + """ + Snippet from nn.LSTM source + # https://pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#LSTM + """ + if not self.learn_h0: + h_zeros = torch.zeros(self.num_layers * self.num_directions, + *batch_shape, self.hidden_size, + dtype=torch.float, device=device) + else: + h_zeros = self.h0.expand(self.num_layers * self.num_directions, *batch_shape, self.hidden_size) + + return h_zeros + + @property + def d_state(self): + return self.n_layers * self.d_hidden + + @property + def d_output(self): + return self.d_hidden + + @property + def state_to_tensor(self): + if self.n_layers == 1: + return lambda state: state[0] + else: + return lambda state: rearrange(state[0], 'd b h -> b (d h)') diff --git a/src/models/baselines/lipschitzrnn.py b/src/models/baselines/lipschitzrnn.py index 7317394..6f040e3 100644 --- a/src/models/baselines/lipschitzrnn.py +++ b/src/models/baselines/lipschitzrnn.py @@ -1,4 +1,7 @@ -""" LipschitzRNN https://github.com/erichson/LipschitzRNN """ +"""Adapted from LipschitzRNN https://github.com/erichson/LipschitzRNN. + +Original code left as comments +""" import numpy as np import torch @@ -9,10 +12,6 @@ from copy import deepcopy -# from tools import * - -# import torchdiffeq -#from torchdiffeq import odeint_adjoint as odeint from torchdiffeq import odeint as odeint def gaussian_init_(n_units, std=1): @@ -208,4 +207,4 @@ def forward(self, x, *args, **kwargs): return h.unsqueeze(1), None - + diff --git a/src/models/baselines/lstm.py b/src/models/baselines/lstm.py new file mode 100644 index 0000000..70fe4b5 --- /dev/null +++ b/src/models/baselines/lstm.py @@ -0,0 +1,62 @@ +""" Wrapper around nn.LSTM to make it compatible with our RNN interface """ + +import torch +from torch import nn +from src.models.sequence import SequenceModule, TransposedModule +from einops import rearrange +import src.models.nn.utils as U + +@TransposedModule +class TorchLSTM(nn.LSTM, SequenceModule): + """ Wrapper around nn.LSTM to make it compatible with our RNN interface """ + + def __init__(self, d_model, d_hidden, n_layers=1, learn_h0=False, **kwargs): + # Rename input_size, hidden_size to d_input, d_model + # Set batch_first as default as per this codebase's convention + self.d_model = d_model + self.d_hidden = d_hidden + self.n_layers = n_layers + self.learn_h0 = learn_h0 + super().__init__(d_model, d_hidden, num_layers=n_layers, batch_first=True, **kwargs) + + self.num_directions = 2 if self.bidirectional else 1 + self.real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size + + if learn_h0: + self.h0 = nn.Parameter(torch.zeros(self.num_layers * self.num_directions, 1, self.real_hidden_size)) + self.c0 = nn.Parameter(torch.zeros(self.num_layers * self.num_directions, 1, self.hidden_size)) + + def step(self, x, state): + raise NotImplementedError("Needs to be implemented.") + + def default_state(self, *batch_shape, device=None): + """ + Snippet from nn.LSTM source + # https://pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#LSTM + """ + if not self.learn_h0: + h_zeros = torch.zeros(self.num_layers * self.num_directions, + *batch_shape, self.real_hidden_size, + dtype=torch.float, device=device) + c_zeros = torch.zeros(self.num_layers * self.num_directions, + *batch_shape, self.hidden_size, + dtype=torch.float, device=device) + else: + h_zeros = self.h0.expand(self.num_layers * self.num_directions, *batch_shape, self.real_hidden_size) + c_zeros = self.c0.expand(self.num_layers * self.num_directions, *batch_shape, self.hidden_size) + return (h_zeros, c_zeros) + + @property + def d_state(self): + return self.n_layers * self.d_model + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + if self.n_layers == 1: + return lambda state: state[0] + else: + return lambda state: rearrange(state[0], 'd b h -> b (d h)') diff --git a/src/models/baselines/nonaka/LICENSE b/src/models/baselines/nonaka/LICENSE new file mode 100644 index 0000000..81af939 --- /dev/null +++ b/src/models/baselines/nonaka/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + public_ptbxl + Copyright (C) 2020 Patrick Wagner + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + public_ptbxl Copyright (C) 2020 Patrick Wagner + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. \ No newline at end of file diff --git a/src/models/baselines/nonaka/README.md b/src/models/baselines/nonaka/README.md new file mode 100644 index 0000000..3f3b691 --- /dev/null +++ b/src/models/baselines/nonaka/README.md @@ -0,0 +1,8 @@ + +All code adapted from codebase: https://github.com/seitalab/dnn_ecg_comparison +for the paper +``` +Nonaka, Seita. +"In-depth Benchmarking of Deep Neural Network Architectures for ECG Diagnosis" +``` + diff --git a/src/models/baselines/nonaka/basic_conv1d.py b/src/models/baselines/nonaka/basic_conv1d.py new file mode 100644 index 0000000..ee4e3b3 --- /dev/null +++ b/src/models/baselines/nonaka/basic_conv1d.py @@ -0,0 +1,468 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from typing import Optional, Collection + +############################################################################################################################################## +# utility functions + + +def listify(p=None, q=None): + # https://github.com/fastai/fastai1/blob/master/fastai/core.py#L129 + "Make `p` listy and the same length as `q`." + if p is None: + p = [] + elif isinstance(p, str): + p = [p] + elif not isinstance(p, list): + p = [p] + # Rank 0 tensors in PyTorch are Iterable but don't have a length. + else: + try: + a = len(p) + except: + p = [p] + n = q if type(q) == int else len(p) if q is None else len(q) + if len(p) == 1: + p = p * n + assert len(p) == n, f"List len mismatch ({len(p)} vs {n})" + return list(p) + + +def bn_drop_lin( + n_in: int, + n_out: int, + bn: bool = True, + p: float = 0.0, + actn: Optional[nn.Module] = None, +): + # https://github.com/fastai/fastai_old/blob/master/fastai_do_not_use/layers.py + "`n_in`->bn->dropout->linear(`n_in`,`n_out`)->`actn`" + layers = [nn.BatchNorm1d(n_in)] if bn else [] + if p != 0: + layers.append(nn.Dropout(p)) + layers.append(nn.Linear(n_in, n_out)) + if actn is not None: + layers.append(actn) + return layers + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +def _conv1d( + in_planes, + out_planes, + kernel_size=3, + stride=1, + dilation=1, + act="relu", + bn=True, + drop_p=0, +): + lst = [] + if drop_p > 0: + lst.append(nn.Dropout(drop_p)) + lst.append( + nn.Conv1d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + dilation=dilation, + bias=not (bn), + ) + ) + if bn: + lst.append(nn.BatchNorm1d(out_planes)) + if act == "relu": + lst.append(nn.ReLU(True)) + if act == "elu": + lst.append(nn.ELU(True)) + if act == "prelu": + lst.append(nn.PReLU(True)) + return nn.Sequential(*lst) + + +def _fc(in_planes, out_planes, act="relu", bn=True): + lst = [nn.Linear(in_planes, out_planes, bias=not (bn))] + if bn: + lst.append(nn.BatchNorm1d(out_planes)) + if act == "relu": + lst.append(nn.ReLU(True)) + if act == "elu": + lst.append(nn.ELU(True)) + if act == "prelu": + lst.append(nn.PReLU(True)) + return nn.Sequential(*lst) + + +def cd_adaptiveconcatpool(relevant, irrelevant, module): + mpr, mpi = module.mp.attrib(relevant, irrelevant) + apr, api = module.ap.attrib(relevant, irrelevant) + return torch.cat([mpr, apr], 1), torch.cat([mpi, api], 1) + + +def attrib_adaptiveconcatpool(self, relevant, irrelevant): + return cd_adaptiveconcatpool(relevant, irrelevant, self) + + +class AdaptiveConcatPool1d(nn.Module): + "Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`." + + def __init__(self, sz: Optional[int] = None): + "Output will be 2*sz or 2 if sz is None" + super().__init__() + sz = sz or 1 + self.ap, self.mp = nn.AdaptiveAvgPool1d(sz), nn.AdaptiveMaxPool1d(sz) + + def forward(self, x): + return torch.cat([self.mp(x), self.ap(x)], 1) + + def attrib(self, relevant, irrelevant): + return attrib_adaptiveconcatpool(self, relevant, irrelevant) + + +class SqueezeExcite1d(nn.Module): + """squeeze excite block as used for example in LSTM FCN""" + + def __init__(self, channels, reduction=16): + super().__init__() + channels_reduced = channels // reduction + self.w1 = torch.nn.Parameter( + torch.randn(channels_reduced, channels).unsqueeze(0) + ) + self.w2 = torch.nn.Parameter( + torch.randn(channels, channels_reduced).unsqueeze(0) + ) + + def forward(self, x): + # input is bs,ch,seq + z = torch.mean(x, dim=2, keepdim=True) # bs,ch + intermed = F.relu( + torch.matmul(self.w1, z) + ) # (1,ch_red,ch * bs,ch,1) = (bs, ch_red, 1) + s = F.sigmoid( + torch.matmul(self.w2, intermed) + ) # (1,ch,ch_red * bs, ch_red, 1=bs, ch, 1 + return s * x # bs,ch,seq * bs, ch,1 = bs,ch,seq + + +def weight_init(m): + """call weight initialization for model n via n.appy(weight_init)""" + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + if isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if isinstance(m, SqueezeExcite1d): + stdv1 = math.sqrt(2.0 / m.w1.size[0]) + nn.init.normal_(m.w1, 0.0, stdv1) + stdv2 = math.sqrt(1.0 / m.w2.size[1]) + nn.init.normal_(m.w2, 0.0, stdv2) + + +def create_head1d( + nf: int, + nc: int, + lin_ftrs: Optional[Collection[int]] = None, + ps=0.5, + bn_final: bool = False, + bn: bool = True, + act="relu", + concat_pooling=True, +): + "Model head that takes `nf` features, runs through `lin_ftrs`, and about `nc` classes; added bn and act here" + lin_ftrs = ( + [2 * nf if concat_pooling else nf, nc] + if lin_ftrs is None + else [2 * nf if concat_pooling else nf] + lin_ftrs + [nc] + ) # was [nf, 512,nc] + ps = listify(ps) + if len(ps) == 1: + ps = [ps[0] / 2] * (len(lin_ftrs) - 2) + ps + actns = [nn.ReLU(inplace=True) if act == "relu" else nn.ELU(inplace=True)] * ( + len(lin_ftrs) - 2 + ) + [None] + layers = [AdaptiveConcatPool1d() if concat_pooling else nn.MaxPool1d(2), Flatten()] + for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns): + layers += bn_drop_lin(ni, no, bn, p, actn) + if bn_final: + layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01)) + return nn.Sequential(*layers) + + +############################################################################################################################################## +# basic convolutional architecture + + +class basic_conv1d(nn.Sequential): + """basic conv1d""" + + def __init__( + self, + filters=[128, 128, 128, 128], + kernel_size=3, + stride=2, + dilation=1, + pool=0, + pool_stride=1, + squeeze_excite_reduction=0, + num_classes=2, + input_channels=8, + act="relu", + bn=True, + headless=False, + split_first_layer=False, + drop_p=0.0, + lin_ftrs_head=None, + ps_head=0.5, + bn_final_head=False, + bn_head=True, + act_head="relu", + concat_pooling=True, + ): + layers = [] + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * len(filters) + for i in range(len(filters)): + layers_tmp = [] + + layers_tmp.append( + _conv1d( + input_channels if i == 0 else filters[i - 1], + filters[i], + kernel_size=kernel_size[i], + stride=(1 if (split_first_layer is True and i == 0) else stride), + dilation=dilation, + act="none" + if ( + (headless is True and i == len(filters) - 1) + or (split_first_layer is True and i == 0) + ) + else act, + bn=False if (headless is True and i == len(filters) - 1) else bn, + drop_p=(0.0 if i == 0 else drop_p), + ) + ) + if split_first_layer is True and i == 0: + layers_tmp.append( + _conv1d( + filters[0], + filters[0], + kernel_size=1, + stride=1, + act=act, + bn=bn, + drop_p=0.0, + ) + ) + # layers_tmp.append(nn.Linear(filters[0],filters[0],bias=not(bn))) + # layers_tmp.append(_fc(filters[0],filters[0],act=act,bn=bn)) + if pool > 0 and i < len(filters) - 1: + layers_tmp.append( + nn.MaxPool1d(pool, stride=pool_stride, padding=(pool - 1) // 2) + ) + if squeeze_excite_reduction > 0: + layers_tmp.append(SqueezeExcite1d(filters[i], squeeze_excite_reduction)) + layers.append(nn.Sequential(*layers_tmp)) + + # head + # layers.append(nn.AdaptiveAvgPool1d(1)) + # layers.append(nn.Linear(filters[-1],num_classes)) + # head #inplace=True leads to a runtime error see ReLU+ dropout https://discuss.pytorch.org/t/relu-dropout-inplace/13467/5 + self.headless = headless + if headless is True: + head = nn.Sequential(nn.AdaptiveAvgPool1d(1), Flatten()) + else: + head = create_head1d( + filters[-1], + nc=num_classes, + lin_ftrs=lin_ftrs_head, + ps=ps_head, + bn_final=bn_final_head, + bn=bn_head, + act=act_head, + concat_pooling=concat_pooling, + ) + layers.append(head) + + super().__init__(*layers) + + def get_layer_groups(self): + return (self[2], self[-1]) + + def get_output_layer(self): + if self.headless is False: + return self[-1][-1] + else: + return None + + def set_output_layer(self, x): + if self.headless is False: + self[-1][-1] = x + + +############################################################################################ +# convenience functions for basic convolutional architectures + + +def fcn(filters=[128] * 5, num_classes=2, input_channels=8): + filters_in = filters + [num_classes] + return basic_conv1d( + filters=filters_in, + kernel_size=3, + stride=1, + pool=2, + pool_stride=2, + input_channels=input_channels, + act="relu", + bn=True, + headless=True, + ) + + +def fcn_wang( + num_classes=2, + input_channels=8, + lin_ftrs_head=None, + ps_head=0.5, + bn_final_head=False, + bn_head=True, + act_head="relu", + concat_pooling=True, +): + return basic_conv1d( + filters=[128, 256, 128], + kernel_size=[8, 5, 3], + stride=1, + pool=0, + pool_stride=2, + num_classes=num_classes, + input_channels=input_channels, + act="relu", + bn=True, + lin_ftrs_head=lin_ftrs_head, + ps_head=ps_head, + bn_final_head=bn_final_head, + bn_head=bn_head, + act_head=act_head, + concat_pooling=concat_pooling, + ) + + +def schirrmeister( + num_classes=2, + input_channels=8, + lin_ftrs_head=None, + ps_head=0.5, + bn_final_head=False, + bn_head=True, + act_head="relu", + concat_pooling=True, +): + return basic_conv1d( + filters=[25, 50, 100, 200], + kernel_size=10, + stride=3, + pool=3, + pool_stride=1, + num_classes=num_classes, + input_channels=input_channels, + act="relu", + bn=True, + headless=False, + split_first_layer=True, + drop_p=0.5, + lin_ftrs_head=lin_ftrs_head, + ps_head=ps_head, + bn_final_head=bn_final_head, + bn_head=bn_head, + act_head=act_head, + concat_pooling=concat_pooling, + ) + + +def sen( + filters=[128] * 5, + num_classes=2, + input_channels=8, + squeeze_excite_reduction=16, + drop_p=0.0, + lin_ftrs_head=None, + ps_head=0.5, + bn_final_head=False, + bn_head=True, + act_head="relu", + concat_pooling=True, +): + return basic_conv1d( + filters=filters, + kernel_size=3, + stride=2, + pool=0, + pool_stride=0, + input_channels=input_channels, + act="relu", + bn=True, + num_classes=num_classes, + squeeze_excite_reduction=squeeze_excite_reduction, + drop_p=drop_p, + lin_ftrs_head=lin_ftrs_head, + ps_head=ps_head, + bn_final_head=bn_final_head, + bn_head=bn_head, + act_head=act_head, + concat_pooling=concat_pooling, + ) + + +def basic1d( + filters=[128] * 5, + kernel_size=3, + stride=2, + dilation=1, + pool=0, + pool_stride=1, + squeeze_excite_reduction=0, + num_classes=2, + input_channels=8, + act="relu", + bn=True, + headless=False, + drop_p=0.0, + lin_ftrs_head=None, + ps_head=0.5, + bn_final_head=False, + bn_head=True, + act_head="relu", + concat_pooling=True, +): + return basic_conv1d( + filters=filters, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + pool=pool, + pool_stride=pool_stride, + squeeze_excite_reduction=squeeze_excite_reduction, + num_classes=num_classes, + input_channels=input_channels, + act=act, + bn=bn, + headless=headless, + drop_p=drop_p, + lin_ftrs_head=lin_ftrs_head, + ps_head=ps_head, + bn_final_head=bn_final_head, + bn_head=bn_head, + act_head=act_head, + concat_pooling=concat_pooling, + ) diff --git a/src/models/baselines/nonaka/inception.py b/src/models/baselines/nonaka/inception.py new file mode 100644 index 0000000..da69aea --- /dev/null +++ b/src/models/baselines/nonaka/inception.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from src.models.baselines.nonaka.basic_conv1d import AdaptiveConcatPool1d, create_head1d + +######################################################################################################## +# Inception time inspired by https://github.com/hfawaz/InceptionTime/blob/master/classifiers/inception.py and https://github.com/tcapelle/TimeSeries_fastai/blob/master/inception.py + +def conv(in_planes, out_planes, kernel_size=3, stride=1): + "convolution with padding" + return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=(kernel_size-1)//2, bias=False) + +def noop(x): return x + +class InceptionBlock1d(nn.Module): + def __init__(self, ni, nb_filters, kss, stride=1, act='linear', bottleneck_size=32): + super().__init__() + self.bottleneck = conv(ni, bottleneck_size, 1, stride) if (bottleneck_size>0) else noop + + self.convs = nn.ModuleList([conv(bottleneck_size if (bottleneck_size>0) else ni, nb_filters, ks) for ks in kss]) + self.conv_bottle = nn.Sequential(nn.MaxPool1d(3, stride, padding=1), conv(ni, nb_filters, 1)) + self.bn_relu = nn.Sequential(nn.BatchNorm1d((len(kss)+1)*nb_filters), nn.ReLU()) + + def forward(self, x): + #print("block in",x.size()) + bottled = self.bottleneck(x) + out = self.bn_relu(torch.cat([c(bottled) for c in self.convs]+[self.conv_bottle(x)], dim=1)) + return out + +class Shortcut1d(nn.Module): + def __init__(self, ni, nf): + super().__init__() + self.act_fn=nn.ReLU(True) + self.conv=conv(ni, nf, 1) + self.bn=nn.BatchNorm1d(nf) + + def forward(self, inp, out): + #print("sk",out.size(), inp.size(), self.conv(inp).size(), self.bn(self.conv(inp)).size) + #input() + return self.act_fn(out + self.bn(self.conv(inp))) + +class InceptionBackbone(nn.Module): + def __init__(self, input_channels, kss, depth, bottleneck_size, nb_filters, use_residual): + super().__init__() + + self.depth = depth + assert((depth % 3) == 0) + self.use_residual = use_residual + + n_ks = len(kss) + 1 + self.im = nn.ModuleList([InceptionBlock1d(input_channels if d==0 else n_ks*nb_filters,nb_filters=nb_filters,kss=kss, bottleneck_size=bottleneck_size) for d in range(depth)]) + self.sk = nn.ModuleList([Shortcut1d(input_channels if d==0 else n_ks*nb_filters, n_ks*nb_filters) for d in range(depth//3)]) + + def forward(self, x): + + input_res = x + for d in range(self.depth): + x = self.im[d](x) + if self.use_residual and d % 3 == 2: + x = (self.sk[d//3])(input_res, x) + input_res = x.clone() + return x + +class Inception1d(nn.Module): + '''inception time architecture''' + def __init__(self, num_classes=2, input_channels=8, kernel_size=40, depth=6, bottleneck_size=32, nb_filters=32, use_residual=True,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): + super().__init__() + assert(kernel_size>=40) + kernel_size = [k-1 if k%2==0 else k for k in [kernel_size,kernel_size//2,kernel_size//4]] #was 39,19,9 + + layers = [InceptionBackbone(input_channels=input_channels, kss=kernel_size, depth=depth, bottleneck_size=bottleneck_size, nb_filters=nb_filters, use_residual=use_residual)] + + n_ks = len(kernel_size) + 1 + #head + head = create_head1d(n_ks*nb_filters, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling) + layers.append(head) + #layers.append(AdaptiveConcatPool1d()) + #layers.append(Flatten()) + #layers.append(nn.Linear(2*n_ks*nb_filters, num_classes)) + self.layers = nn.Sequential(*layers) + + def forward(self, x, *args, **kwargs): + y = self.layers(x.transpose(-1, -2)) + return y, None + + def get_layer_groups(self): + depth = self.layers[0].depth + if(depth>3): + return ((self.layers[0].im[3:],self.layers[0].sk[1:]),self.layers[-1]) + else: + return (self.layers[-1]) + + def get_output_layer(self): + return self.layers[-1][-1] + + def set_output_layer(self,x): + self.layers[-1][-1] = x + +def inception1d(**kwargs): + """Constructs an Inception model + """ + return Inception1d(**kwargs) diff --git a/src/models/baselines/nonaka/resnet.py b/src/models/baselines/nonaka/resnet.py new file mode 100644 index 0000000..f35d631 --- /dev/null +++ b/src/models/baselines/nonaka/resnet.py @@ -0,0 +1,269 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from src.models.baselines.nonaka.basic_conv1d import create_head1d, Flatten +############################################################################################### +# Standard resnet + +def conv(in_planes, out_planes, stride=1, kernel_size=3): + "convolution with padding" + return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=(kernel_size-1)//2, bias=False) + + +class BasicBlock1d(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, kernel_size=[3,3], downsample=None): + super().__init__() + + # if(isinstance(kernel_size,int)): kernel_size = [kernel_size,kernel_size//2+1] + if(isinstance(kernel_size,int)): kernel_size = [kernel_size, kernel_size] + + self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=kernel_size[0]) + self.bn1 = nn.BatchNorm1d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv(planes, planes,kernel_size=kernel_size[1]) + self.bn2 = nn.BatchNorm1d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x if self.downsample is None else self.downsample(x) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + + x += residual + x = self.relu(x) + + return x + + +class Bottleneck1d(nn.Module): + expansion = 4 + def __init__(self, inplanes, planes, stride=1, kernel_size=3, downsample=None): + super().__init__() + + self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm1d(planes) + self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=stride, + padding=(kernel_size-1)//2, bias=False) + self.bn2 = nn.BatchNorm1d(planes) + self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm1d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet1d(nn.Sequential): + '''1d adaptation of the torchvision resnet''' + def __init__(self, block, layers, kernel_size=3, num_classes=2, input_channels=3, inplanes=64, fix_feature_dim=True, kernel_size_stem = None, stride_stem=2, pooling_stem=True, stride=2, lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): + self.inplanes = inplanes + + layers_tmp = [] + + if(kernel_size_stem is None): + kernel_size_stem = kernel_size[0] if isinstance(kernel_size,list) else kernel_size + #stem + layers_tmp.append(nn.Conv1d(input_channels, inplanes, kernel_size=kernel_size_stem, stride=stride_stem, padding=(kernel_size_stem-1)//2,bias=False)) + layers_tmp.append(nn.BatchNorm1d(inplanes)) + layers_tmp.append(nn.ReLU(inplace=True)) + if(pooling_stem is True): + layers_tmp.append(nn.MaxPool1d(kernel_size=3, stride=2, padding=1)) + #backbone + for i,l in enumerate(layers): + if(i==0): + layers_tmp.append(self._make_layer(block, inplanes, layers[0],kernel_size=kernel_size)) + else: + layers_tmp.append(self._make_layer(block, inplanes if fix_feature_dim else (2**i)*inplanes, layers[i], stride=stride,kernel_size=kernel_size)) + + #head + #layers_tmp.append(nn.AdaptiveAvgPool1d(1)) + #layers_tmp.append(Flatten()) + #layers_tmp.append(nn.Linear((inplanes if fix_feature_dim else (2**len(layers)*inplanes)) * block.expansion, num_classes)) + + head = create_head1d((inplanes if fix_feature_dim else (2**len(layers)*inplanes)) * block.expansion, nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling) + layers_tmp.append(head) + + super().__init__(*layers_tmp) + + def _make_layer(self, block, planes, blocks, stride=1,kernel_size=3): + downsample = None + + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv1d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm1d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, kernel_size, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def get_layer_groups(self): + return (self[6],self[-1]) + + def get_output_layer(self): + return self[-1][-1] + + def set_output_layer(self,x): + self[-1][-1]=x + def forward(self, x, *args, **kwargs): + y = super().forward(x.transpose(-1, -2)) + return y, None + +def resnet1d18(**kwargs): + """Constructs a ResNet-18 model. + """ + return ResNet1d(BasicBlock1d, [2, 2, 2, 2], **kwargs) + +def resnet1d34(**kwargs): + """Constructs a ResNet-34 model. + """ + return ResNet1d(BasicBlock1d, [3, 4, 6, 3], **kwargs) + +def resnet1d50(**kwargs): + """Constructs a ResNet-50 model. + """ + return ResNet1d(Bottleneck1d, [3, 4, 6, 3], **kwargs) + +def resnet1d101(**kwargs): + """Constructs a ResNet-101 model. + """ + return ResNet1d(Bottleneck1d, [3, 4, 23, 3], **kwargs) + +def resnet1d152(**kwargs): + """Constructs a ResNet-152 model. + """ + return ResNet1d(Bottleneck1d, [3, 8, 36, 3], **kwargs) + + +#original used kernel_size_stem = 8 +def resnet1d_wang(**kwargs): + + if(not("kernel_size" in kwargs.keys())): + kwargs["kernel_size"]=[5,3] + if(not("kernel_size_stem" in kwargs.keys())): + kwargs["kernel_size_stem"]=7 + if(not("stride_stem" in kwargs.keys())): + kwargs["stride_stem"]=1 + if(not("pooling_stem" in kwargs.keys())): + kwargs["pooling_stem"]=False + if(not("inplanes" in kwargs.keys())): + kwargs["inplanes"]=128 + + + return ResNet1d(BasicBlock1d, [1, 1, 1], **kwargs) + +def resnet1d(**kwargs): + """Constructs a custom ResNet model. + """ + return ResNet1d(BasicBlock1d, **kwargs) + + +############################################################################################### +# wide resnet adopted from fastai wrn + +def noop(x): return x + +def conv1d(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias=False) -> nn.Conv1d: + "Create `nn.Conv1d` layer: `ni` inputs, `nf` outputs, `ks` kernel size. `padding` defaults to `k//2`." + if padding is None: padding = ks//2 + return nn.Conv1d(ni, nf, kernel_size=ks, stride=stride, padding=padding, bias=bias) + +def _bn1d(ni, init_zero=False): + "Batchnorm layer with 0 initialization" + m = nn.BatchNorm1d(ni) + m.weight.data.fill_(0 if init_zero else 1) + m.bias.data.zero_() + return m + +def bn_relu_conv1d(ni, nf, ks, stride, init_zero=False): + bn_initzero = _bn1d(ni, init_zero=init_zero) + return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv1d(ni, nf, ks, stride)) + +class BasicBlock1dwrn(nn.Module): + def __init__(self, ni, nf, stride, drop_p=0.0, ks=3): + super().__init__() + if(isinstance(ks,int)): + ks = [ks,ks//2+1] + self.bn = nn.BatchNorm1d(ni) + self.conv1 = conv1d(ni, nf, ks[0], stride) + self.conv2 = bn_relu_conv1d(nf, nf, ks[0], 1) + self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None + self.shortcut = conv1d(ni, nf, ks[1], stride) if (ni != nf or stride>1) else noop #adapted to make it work for fix_feature_dim=True + + def forward(self, x): + x2 = F.relu(self.bn(x), inplace=True) + r = self.shortcut(x2) + x = self.conv1(x2) + if self.drop: x = self.drop(x) + x = self.conv2(x) * 0.2 + return x.add_(r) + +def _make_group(N, ni, nf, block, stride, drop_p,ks=3): + return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p,ks=ks) for i in range(N)] + +class WideResNet1d(nn.Sequential): + def __init__(self, input_channels:int, num_groups:int, N:int, num_classes:int, k:int=1, drop_p:float=0.0, start_nf:int=16,fix_feature_dim=True,kernel_size=5,lin_ftrs_head=None, ps_head=0.5, bn_final_head=False, bn_head=True, act_head="relu", concat_pooling=True): + super().__init__() + n_channels = [start_nf] + + for i in range(num_groups): n_channels.append(start_nf if fix_feature_dim else start_nf*(2**i)*k) + + layers = [conv1d(input_channels, n_channels[0], 3, 1)] # conv1 stem + for i in range(num_groups): + layers += _make_group(N, n_channels[i], n_channels[i+1], BasicBlock1dwrn, (1 if i==0 else 2), drop_p,ks=kernel_size) + + #layers += [nn.BatchNorm1d(n_channels[-1]), nn.ReLU(inplace=True), nn.AdaptiveAvgPool1d(1), + # Flatten(), nn.Linear(n_channels[-1], num_classes)] + head = create_head1d(n_channels[-1], nc=num_classes, lin_ftrs=lin_ftrs_head, ps=ps_head, bn_final=bn_final_head, bn=bn_head, act=act_head, concat_pooling=concat_pooling) + layers.append(head) + + super().__init__(*layers) + + def get_layer_groups(self): + return (self[6],self[-1]) + + def get_output_layer(self): + return self[-1][-1] + + def set_output_layer(self,x): + self[-1][-1] = x + + +def wrn1d_22(**kwargs): return WideResNet1d(num_groups=3, N=3, k=6, drop_p=0.,**kwargs) diff --git a/src/models/baselines/nonaka/xresnet.py b/src/models/baselines/nonaka/xresnet.py new file mode 100644 index 0000000..96d319d --- /dev/null +++ b/src/models/baselines/nonaka/xresnet.py @@ -0,0 +1,409 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from src.models.baselines.nonaka.basic_conv1d import create_head1d, Flatten + +from enum import Enum +import re + +# delegates +import inspect + + +def delegates(to=None, keep=False): + "Decorator: replace `**kwargs` in signature with params from `to`" + + def _f(f): + if to is None: + to_f, from_f = f.__base__.__init__, f.__init__ + else: + to_f, from_f = to, f + sig = inspect.signature(from_f) + sigd = dict(sig.parameters) + k = sigd.pop("kwargs") + s2 = { + k: v + for k, v in inspect.signature(to_f).parameters.items() + if v.default != inspect.Parameter.empty and k not in sigd + } + sigd.update(s2) + if keep: + sigd["kwargs"] = k + from_f.__signature__ = sig.replace(parameters=sigd.values()) + return f + + return _f + + +def store_attr(self, nms): + "Store params named in comma-separated `nms` from calling context into attrs in `self`" + mod = inspect.currentframe().f_back.f_locals + for n in re.split(", *", nms): + setattr(self, n, mod[n]) + + +NormType = Enum("NormType", "Batch BatchZero Weight Spectral Instance InstanceZero") + + +def _conv_func(ndim=2, transpose=False): + "Return the proper conv `ndim` function, potentially `transposed`." + assert 1 <= ndim <= 3 + return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d') + + +def init_default(m, func=nn.init.kaiming_normal_): + "Initialize `m` weights with `func` and set `bias` to 0." + if func and hasattr(m, "weight"): + func(m.weight) + with torch.no_grad(): + if getattr(m, "bias", None) is not None: + m.bias.fill_(0.0) + return m + + +def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs): + "Norm layer with `nf` features and `ndim` initialized depending on `norm_type`." + assert 1 <= ndim <= 3 + bn = getattr(nn, f"{prefix}{ndim}d")(nf, **kwargs) + if bn.affine: + bn.bias.data.fill_(1e-3) + bn.weight.data.fill_(0.0 if zero else 1.0) + return bn + + +def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs): + "BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`." + return _get_norm( + "BatchNorm", nf, ndim, zero=norm_type == NormType.BatchZero, **kwargs + ) + + +class ConvLayer(nn.Sequential): + "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers." + + def __init__( + self, + ni, + nf, + ks=3, + stride=1, + padding=None, + bias=None, + ndim=2, + norm_type=NormType.Batch, + bn_1st=True, + act_cls=nn.ReLU, + transpose=False, + init=nn.init.kaiming_normal_, + xtra=None, + **kwargs, + ): + if padding is None: + padding = (ks - 1) // 2 if not transpose else 0 + bn = norm_type in (NormType.Batch, NormType.BatchZero) + inn = norm_type in (NormType.Instance, NormType.InstanceZero) + if bias is None: + bias = not (bn or inn) + conv_func = _conv_func(ndim, transpose=transpose) + conv = init_default( + conv_func( + ni, + nf, + kernel_size=ks, + bias=bias, + stride=stride, + padding=padding, + **kwargs, + ), + init, + ) + if norm_type == NormType.Weight: + conv = torch.nn.utils.weight_norm(conv) + elif norm_type == NormType.Spectral: + conv = torch.nn.utils.spectral_norm(conv) + layers = [conv] + act_bn = [] + if act_cls is not None: + act_bn.append(act_cls()) + if bn: + act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim)) + if inn: + act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim)) + if bn_1st: + act_bn.reverse() + layers += act_bn + if xtra: + layers.append(xtra) + super().__init__(*layers) + + +def AdaptiveAvgPool(sz=1, ndim=2): + "nn.AdaptiveAvgPool layer for `ndim`" + assert 1 <= ndim <= 3 + return getattr(nn, f"AdaptiveAvgPool{ndim}d")(sz) + + +def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False): + "nn.MaxPool layer for `ndim`" + assert 1 <= ndim <= 3 + return getattr(nn, f"MaxPool{ndim}d")(ks, stride=stride, padding=padding) + + +def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False): + "nn.AvgPool layer for `ndim`" + assert 1 <= ndim <= 3 + return getattr(nn, f"AvgPool{ndim}d")( + ks, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + +class ResBlock(nn.Module): + "Resnet block from `ni` to `nh` with `stride`" + + @delegates(ConvLayer.__init__) + def __init__( + self, + expansion, + ni, + nf, + stride=1, + kernel_size=3, + groups=1, + reduction=None, + nh1=None, + nh2=None, + dw=False, + g2=1, + sa=False, + sym=False, + norm_type=NormType.Batch, + act_cls=nn.ReLU, + ndim=2, + pool=AvgPool, + pool_first=True, + **kwargs, + ): + super().__init__() + norm2 = ( + NormType.BatchZero + if norm_type == NormType.Batch + else NormType.InstanceZero + if norm_type == NormType.Instance + else norm_type + ) + if nh2 is None: + nh2 = nf + if nh1 is None: + nh1 = nh2 + nf, ni = nf * expansion, ni * expansion + k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs) + k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs) + layers = ( + [ + ConvLayer( + ni, + nh2, + kernel_size, + stride=stride, + groups=ni if dw else groups, + **k0, + ), + ConvLayer(nh2, nf, kernel_size, groups=g2, **k1), + ] + if expansion == 1 + else [ + ConvLayer(ni, nh1, 1, **k0), + ConvLayer( + nh1, + nh2, + kernel_size, + stride=stride, + groups=nh1 if dw else groups, + **k0, + ), + ConvLayer(nh2, nf, 1, groups=g2, **k1), + ] + ) + self.convs = nn.Sequential(*layers) + convpath = [self.convs] + if reduction: + convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls)) + if sa: + convpath.append(SimpleSelfAttention(nf, ks=1, sym=sym)) + self.convpath = nn.Sequential(*convpath) + idpath = [] + if ni != nf: + idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs)) + if stride != 1: + idpath.insert((1, 0)[pool_first], pool(2, ndim=ndim, ceil_mode=True)) + self.idpath = nn.Sequential(*idpath) + self.act = nn.ReLU(inplace=True) if act_cls is nn.ReLU else act_cls() + + def forward(self, x): + return self.act(self.convpath(x) + self.idpath(x)) + + +######################### adapted from vison.models.xresnet +def init_cnn(m): + if getattr(m, "bias", None) is not None: + nn.init.constant_(m.bias, 0) + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + nn.init.kaiming_normal_(m.weight) + for l in m.children(): + init_cnn(l) + + +class XResNet1d(nn.Sequential): + @delegates(ResBlock) + def __init__( + self, + block, + expansion, + layers, + p=0.0, + input_channels=3, + num_classes=1000, + stem_szs=(32, 32, 64), + kernel_size=5, + kernel_size_stem=5, + widen=1.0, + sa=False, + act_cls=nn.ReLU, + lin_ftrs_head=None, + ps_head=0.5, + bn_final_head=False, + bn_head=True, + act_head="relu", + concat_pooling=True, + **kwargs, + ): + store_attr(self, "block,expansion,act_cls") + stem_szs = [input_channels, *stem_szs] + stem = [ + ConvLayer( + stem_szs[i], + stem_szs[i + 1], + ks=kernel_size_stem, + stride=2 if i == 0 else 1, + act_cls=act_cls, + ndim=1, + ) + for i in range(3) + ] + + # block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)] + block_szs = [ + int(o * widen) for o in [64, 64, 64, 64] + [32] * (len(layers) - 4) + ] + block_szs = [64 // expansion] + block_szs + blocks = [ + self._make_layer( + ni=block_szs[i], + nf=block_szs[i + 1], + blocks=l, + stride=1 if i == 0 else 2, + kernel_size=kernel_size, + sa=sa and i == len(layers) - 4, + ndim=1, + **kwargs, + ) + for i, l in enumerate(layers) + ] + + head = create_head1d( + block_szs[-1] * expansion, + nc=num_classes, + lin_ftrs=lin_ftrs_head, + ps=ps_head, + bn_final=bn_final_head, + bn=bn_head, + act=act_head, + concat_pooling=concat_pooling, + ) + + super().__init__( + *stem, + nn.MaxPool1d(kernel_size=3, stride=2, padding=1), + *blocks, + head, + ) + init_cnn(self) + + def _make_layer(self, ni, nf, blocks, stride, kernel_size, sa, **kwargs): + return nn.Sequential( + *[ + self.block( + self.expansion, + ni if i == 0 else nf, + nf, + stride=stride if i == 0 else 1, + kernel_size=kernel_size, + sa=sa and i == (blocks - 1), + act_cls=self.act_cls, + **kwargs, + ) + for i in range(blocks) + ] + ) + + def get_layer_groups(self): + return (self[3], self[-1]) + + def get_output_layer(self): + return self[-1][-1] + + def set_output_layer(self, x): + self[-1][-1] = x + + def forward(self, x, *args, **kwargs): + y = super().forward(x.transpose(-1, -2)) + return y, None + +# xresnets +def _xresnet1d(expansion, layers, **kwargs): + return XResNet1d(ResBlock, expansion, layers, **kwargs) + + +def xresnet1d18(**kwargs): + return _xresnet1d(1, [2, 2, 2, 2], **kwargs) + + +def xresnet1d34(**kwargs): + return _xresnet1d(1, [3, 4, 6, 3], **kwargs) + + +def xresnet1d50(**kwargs): + return _xresnet1d(4, [3, 4, 6, 3], **kwargs) + + +def xresnet1d101(**kwargs): + return _xresnet1d(4, [3, 4, 23, 3], **kwargs) + + +def xresnet1d152(**kwargs): + return _xresnet1d(4, [3, 8, 36, 3], **kwargs) + + +def xresnet1d18_deep(**kwargs): + return _xresnet1d(1, [2, 2, 2, 2, 1, 1], **kwargs) + + +def xresnet1d34_deep(**kwargs): + return _xresnet1d(1, [3, 4, 6, 3, 1, 1], **kwargs) + + +def xresnet1d50_deep(**kwargs): + return _xresnet1d(4, [3, 4, 6, 3, 1, 1], **kwargs) + + +def xresnet1d18_deeper(**kwargs): + return _xresnet1d(1, [2, 2, 1, 1, 1, 1, 1, 1], **kwargs) + + +def xresnet1d34_deeper(**kwargs): + return _xresnet1d(1, [3, 4, 6, 3, 1, 1, 1, 1], **kwargs) + + +def xresnet1d50_deeper(**kwargs): + return _xresnet1d(4, [3, 4, 6, 3, 1, 1, 1, 1], **kwargs) diff --git a/src/models/baselines/nrde.py b/src/models/baselines/nrde.py index c4b10a0..d73384a 100644 --- a/src/models/baselines/nrde.py +++ b/src/models/baselines/nrde.py @@ -1,4 +1,4 @@ -""" Neural Rough Differential Equations. """ +"""Neural Rough Differential Equations.""" import torch from torch import nn @@ -193,18 +193,3 @@ def __init__(self, input_dim, logsig_dim, num_layers=1, hidden_dim=15): def forward(self, h): return self.net(h).view(-1, self.input_dim, self.logsig_dim) - - -if __name__ == '__main__': - - B = 1 - D = 256 - LS = 512 - L = 1024 - nrde = NeuralRDE(1, LS, D, 1, hidden_hidden_dim=3 * D, num_layers=3) - x = torch.randn(B, 1) - logsig = torch.randn(B, L, LS) - import time - start_time = time.time() - print(nrde.forward((x, logsig))) - print(time.time() - start_time) diff --git a/src/models/baselines/odelstm.py b/src/models/baselines/odelstm.py index 3e99660..3901f74 100644 --- a/src/models/baselines/odelstm.py +++ b/src/models/baselines/odelstm.py @@ -5,23 +5,23 @@ import torch.nn as nn from torchdyn.models import NeuralDE import pytorch_lightning as pl -from pytorch_lightning.metrics.functional import accuracy +from torchmetrics.functional import accuracy class ODELSTMCell(nn.Module): - def __init__(self, d_input, d_model, solver_type="dopri5"): + def __init__(self, d_model, d_hidden, solver_type="dopri5"): super(ODELSTMCell, self).__init__() self.solver_type = solver_type self.fixed_step_solver = solver_type.startswith("fixed_") - self.lstm = nn.LSTMCell(d_input, d_model) + self.lstm = nn.LSTMCell(d_model, d_hidden) # 1 hidden layer NODE self.f_node = nn.Sequential( - nn.Linear(d_model, d_model), + nn.Linear(d_hidden, d_hidden), nn.Tanh(), - nn.Linear(d_model, d_model), + nn.Linear(d_hidden, d_hidden), ) - self.d_input = d_input self.d_model = d_model + self.d_hidden = d_hidden if not self.fixed_step_solver: self.node = NeuralDE(self.f_node, solver=solver_type) else: @@ -77,30 +77,30 @@ def rk4(self, y, delta_t): class ODELSTM(nn.Module): def __init__( self, - d_input, - d_output, d_model, + d_output=None, + d_hidden=None, return_sequences=True, solver_type="dopri5", - l_output=None, - l_max=None, ): super(ODELSTM, self).__init__() - self.d_input = d_input + d_output = d_output or d_model + d_hidden = d_hidden or d_model self.d_model = d_model + self.d_hidden = d_hidden self.d_output = d_output self.return_sequences = return_sequences - self.rnn_cell = ODELSTMCell(d_input, d_model, solver_type=solver_type) - self.fc = nn.Linear(self.d_model, self.d_output) + self.rnn_cell = ODELSTMCell(d_model, d_hidden, solver_type=solver_type) + self.fc = nn.Linear(self.d_hidden, self.d_output) - def forward(self, x, timespans=None, mask=None): + def forward(self, x, state=None, timespans=None, mask=None): device = x.device batch_size = x.size(0) seq_len = x.size(1) hidden_state = [ - torch.zeros((batch_size, self.d_model), device=device), - torch.zeros((batch_size, self.d_model), device=device), + torch.zeros((batch_size, self.d_hidden), device=device), + torch.zeros((batch_size, self.d_hidden), device=device), ] outputs = [] last_output = torch.zeros((batch_size, self.d_output), device=device) @@ -123,7 +123,7 @@ def forward(self, x, timespans=None, mask=None): outputs = torch.stack(outputs, dim=1) # return entire sequence else: outputs = last_output # only last item - return outputs + return outputs, hidden_state class IrregularSequenceLearner(pl.LightningModule): diff --git a/src/models/baselines/resnet.py b/src/models/baselines/resnet.py index a19a32c..855bd98 100644 --- a/src/models/baselines/resnet.py +++ b/src/models/baselines/resnet.py @@ -1,72 +1,35 @@ -""" 2D ResNet baseline, mostly used to test Pathfinder currently. """ +"""2D ResNet baselines from torchvision""" import torch.nn as nn -import math import torchvision.models as models +from einops import rearrange -class Resnet18CelebA(nn.Module): - +class TorchVisionResnet(nn.Module): def __init__( - self, - d_output, - **kwargs, + self, + # d_input, + variant="resnet18", # e.g. [ "resnet18" | "resnet34" | "resnet50" | "wide_resnet50_2" ] ): super().__init__() - if 'l_output' in kwargs and kwargs['l_output'] > 1: - d_output = kwargs['l_output'] - - self.resnet = resnet18(pretrained=False) - self.resnet.fc = nn.Linear(512, d_output) - - def forward(self, x, *args, **kwargs): - # BSC -> BCS - x = x.transpose(1, 2) - # BCS -> BCHW - x = x.view(x.shape[0], 3, 178, 218) - return self.resnet.forward(x) -class ResnetSquare(nn.Module): + self.resnet = getattr(models, variant)(pretrained=False) - def __init__( - self, - d_input, - variant='18', - ): - super().__init__() + # Remove pooling from stem: too much downsizing for CIFAR + self.resnet.maxpool = nn.Identity() - self.d_input = d_input - self.resnet = { - '18': models.resnet18, - '34': models.resnet34, - '50': models.resnet50, - 18: models.resnet18, - 34: models.resnet34, - 50: models.resnet50, - 'wrn': models.wide_resnet50_2, - }[variant](pretrained=False) + # Remove final head: handled by decoder + self.d_output = self.resnet.fc.in_features self.resnet.fc = nn.Identity() - self.d_output = { - '18': 512, - '34': 512, - '50': 2048, - 18: 512, - 34: 512, - 50: 2048, - 'wrn': 2048, - }[variant] + self.resnet.avgpool = nn.Identity() def forward(self, x, *args, **kwargs): - # BSC -> BCS - x = x.transpose(1, 2) - # BCS -> BCHW - n = int(x.size(-1)**.5) - x = x.view(x.shape[0], self.d_input, n, n) - if self.d_input == 1: + x = rearrange(x, 'b ... h -> b h ...') + if x.size(1) == 1: x = x.repeat(1, 3, 1, 1) - elif self.d_input == 3: + elif x.size(1) == 3: pass - else: raise NotImplementedError - y = self.resnet.forward(x) - y = y.unsqueeze(-2) # (B 1 C) + else: + raise NotImplementedError + y = self.resnet(x) return y, None diff --git a/src/models/baselines/samplernn.py b/src/models/baselines/samplernn.py index 858048d..fc53ed2 100644 --- a/src/models/baselines/samplernn.py +++ b/src/models/baselines/samplernn.py @@ -4,9 +4,9 @@ import math import numpy as np +from src.models.baselines.lstm import TorchLSTM +from src.models.baselines.gru import TorchGRU from src.models.sequence.base import SequenceModule -from src.models.sequence.rnns.lstm import TorchLSTM -from src.models.sequence.rnns.gru import TorchGRU from src.models.sequence.ss.s4 import S4 from src.dataloaders.audio import mu_law_decode, linear_decode, q_zero @@ -24,10 +24,10 @@ def d_output(self): return self.d_model if self.output_linear else self.d_hidden def __init__( - self, + self, d_model, d_hidden, - n_layers, + n_layers, learn_h0=False, rnn_type='gru', skip_connections=False, @@ -43,7 +43,7 @@ def __init__( self.learn_h0 = learn_h0 self.skip_connections = skip_connections self.weight_norm = torch.nn.utils.weight_norm if weight_norm else lambda x: x - + self.output_linear = output_linear self.rnn_layers = torch.nn.ModuleList() self.lin_layers = torch.nn.ModuleList() @@ -72,7 +72,7 @@ def __init__( self.rnn_layers.append( RNN(d_model=d_hidden, d_hidden=d_hidden, n_layers=1, learn_h0=learn_h0), ) - + if skip_connections: self.lin_layers.append(self.weight_norm(torch.nn.Linear(d_hidden, d_hidden))) else: @@ -94,13 +94,13 @@ def __init__( for name, module in rnn.named_modules(): if isinstance(module, torch.nn.Linear): setattr(rnn, name, self.weight_norm(module)) - + # Use orthogonal initialization for W_hn if using GRU (weight_hh_l[0]) if rnn_type == 'gru': for rnn in self.rnn_layers: torch.nn.init.orthogonal_(rnn.weight_hh_l0[2 * d_hidden:].data) - + def default_state(self, *batch_shape, device=None): return [ rnn.default_state(*batch_shape, device=device) @@ -135,7 +135,7 @@ def forward(self, inputs, *args, state=None, **kwargs): out = self.output_layer(out) return out, next_states - + class StackedRNNBaseline(SequenceModule): """ @@ -144,11 +144,11 @@ class StackedRNNBaseline(SequenceModule): Marked as the "one_tier" model in the codebase. https://github.com/soroushmehr/sampleRNN_ICLR2017/blob/master/models/one_tier/one_tier.py - Discrete Input (Q_LEVELS) --> - Embedding (EMB_SIZE) --> + Discrete Input (Q_LEVELS) --> + Embedding (EMB_SIZE) --> ----------- (start) this module implements the RNN + Linear Layers backbone ----------- - StackedRNN (N_RNN \in [5], FRAME_SIZE, DIM, LEARNED_H0, WEIGHT_NORM, SKIP_CONNECTIONS) --> + StackedRNN (N_RNN \in [5], FRAME_SIZE, DIM, LEARNED_H0, WEIGHT_NORM, SKIP_CONNECTIONS) --> Linear (DIM, DIM) + ReLU --> Linear (DIM, DIM) + ReLU --> Linear (DIM, DIM) + ReLU --> @@ -201,10 +201,10 @@ def __init__( self.lin1 = torch.nn.utils.weight_norm(self.lin1) self.lin2 = torch.nn.utils.weight_norm(self.lin2) self.lin3 = torch.nn.utils.weight_norm(self.lin3) - + def default_state(self, *batch_shape, device=None): return self.rnn.default_state(*batch_shape, device=device) - + def forward(self, inputs, *args, state=None, **kwargs): outputs = inputs outputs, state = self.rnn(outputs, state) @@ -264,13 +264,13 @@ def d_output(self): return self.d_hidden def __init__( - self, + self, frame_sizes=(16, 4), - n_rnn=2, - d_hidden=1024, + n_rnn=2, + d_hidden=1024, bits=8, - learn_h0=True, - d_model=256, + learn_h0=True, + d_model=256, weight_norm=True, reproduce=True, quantization='linear', @@ -298,12 +298,12 @@ def __init__( ns_frame_samples = map(int, np.cumprod(frame_sizes)) # e.g. (16, 4) -> (16, 64) self.frame_level_rnns = torch.nn.ModuleList([ FrameLevelRNN( - frame_size=frame_size, - n_frame_samples=n_frame_samples, + frame_size=frame_size, + n_frame_samples=n_frame_samples, d_model=d_model, - n_rnn=n_rnn, - d_hidden=d_hidden, - learn_h0=learn_h0, + n_rnn=n_rnn, + d_hidden=d_hidden, + learn_h0=learn_h0, weight_norm=weight_norm, reproduce=reproduce, layer=layer, @@ -312,10 +312,10 @@ def __init__( ]) self.sample_level_mlp = SampleLevelMLP( - frame_size=frame_sizes[0], - d_hidden=d_hidden, + frame_size=frame_sizes[0], + d_hidden=d_hidden, bits=bits, - d_model=d_model, + d_model=d_model, weight_norm=weight_norm, reproduce=reproduce, ) @@ -332,10 +332,10 @@ def step(self, x, state=None, *args, **kwargs): state = self.default_state(batch_size, device=x.device) self._frame_level_outputs = [None for _ in self.frame_level_rnns] self._window = torch.zeros( - batch_size, + batch_size, self.lookback, - x.shape[1] if len(x.shape) == 2 else x.shape[2], - dtype=x.dtype, + x.shape[1] if len(x.shape) == 2 else x.shape[2], + dtype=x.dtype, device=x.device, ) + q_zero(bits=self.bits) self._step_idx = self.lookback @@ -348,15 +348,15 @@ def step(self, x, state=None, *args, **kwargs): # Update window (but on the first step) self._window[:, :-1] = self._window[:, 1:].clone() self._window[:, -1] = x - + new_states = [] - + for (i, rnn), state_ in zip(reversed(list(enumerate(self.frame_level_rnns))), reversed(state)): if self._step_idx % rnn.n_frame_samples != 0: # Don't need to process this rnn new_states.append(state_) continue - + # prev_samples shape: (B, CHUNK_SIZE, D) e.g. (16, 16384, 1) prev_samples = self._window[:, -rnn.n_frame_samples:] @@ -374,7 +374,7 @@ def step(self, x, state=None, *args, **kwargs): prev_samples = self.encoder(prev_samples) prev_samples = prev_samples.contiguous() prev_samples = prev_samples.view(batch_size, -1, rnn.n_frame_samples, self.d_model) - + # upper_tier_conditioning shape: None -> (B, M, D_HIDDEN) [first rnn] # (B, M_{i-1}, D_HIDDEN) -> (B, M_i, D_HIDDEN) [second rnn] if i == len(self.frame_level_rnns) - 1: @@ -444,7 +444,7 @@ def forward(self, inputs, *args, state=None, **kwargs): prev_samples = self.encoder(prev_samples) prev_samples = prev_samples.contiguous() prev_samples = prev_samples.view(batch_size, -1, rnn.n_frame_samples, self.d_model) - + # upper_tier_conditioning shape: None -> (B, M, D_HIDDEN) [first rnn] # (B, M_{i-1}, D_HIDDEN) -> (B, M_i, D_HIDDEN) [second rnn] upper_tier_conditioning, new_state = rnn(prev_samples, upper_tier_conditioning, state_) @@ -482,13 +482,13 @@ def concat_init(tensor, inits): class FrameLevelRNN(torch.nn.Module): def __init__( - self, - frame_size, - n_frame_samples, + self, + frame_size, + n_frame_samples, d_model, - n_rnn, + n_rnn, d_hidden, - learn_h0=True, + learn_h0=True, weight_norm=True, reproduce=False, layer='gru', @@ -573,7 +573,7 @@ def __init__( stride=frame_size, bias=True, ) - + if weight_norm and reproduce: self.input_expand = torch.nn.utils.weight_norm(self.input_expand) @@ -595,7 +595,7 @@ def forward(self, prev_samples, upper_tier_conditioning, state=None): """ if not self.reproduce: # Use strided convolutions to get frame embeddings - # This generalizes the SampleRNN operation to handle non-1D signals + # This generalizes the SampleRNN operation to handle non-1D signals # This reshapes from (B, M_i, FRAME, D_MODEL) -> (B, M_i, D_HIDDEN) prev_samples = prev_samples.view(prev_samples.shape[0], -1, self.d_model) input = self.input_expand(prev_samples.permute(0, 2, 1)).permute(0, 2, 1) @@ -625,11 +625,11 @@ def forward(self, prev_samples, upper_tier_conditioning, state=None): class SampleLevelMLP(torch.nn.Module): def __init__( - self, + self, frame_size, d_hidden, bits=8, - d_model=256, + d_model=256, weight_norm=True, embedding=True, reproduce=False, @@ -652,7 +652,7 @@ def __init__( kernel_size=frame_size, bias=False, ) - + if self.reproduce: self.hidden = torch.nn.Conv1d( in_channels=d_hidden, @@ -661,7 +661,7 @@ def __init__( ) else: self.hidden = torch.nn.Linear(d_hidden, d_hidden) - + if self.reproduce: self.output = torch.nn.Conv1d( in_channels=d_hidden, @@ -707,71 +707,8 @@ def forward(self, prev_samples, upper_tier_conditioning): # Take (B, L', D_MODEL), (B, L, D_HIDDEN) -> (B, D_HIDDEN, L) x = F.relu(self.input(prev_samples) + upper_tier_conditioning) # x: (B, D_HIDDEN, L) -> (B, L, D_HIDDEN) - x = x.permute(0, 2, 1) + x = x.permute(0, 2, 1) x = F.relu(self.hidden(x)) x = self.output(x) return x.contiguous() - - -def test_stacked_rnn(): - rnn = StackedRNN( - d_model=256, - d_hidden=32, - n_layers=4, - skip_connections=True, - dropout=0.0, - output_linear=False, - ) - x = torch.randn(8, 100, 256) - y, states = rnn(x) - assert y.shape == (8, 100, 32) - assert len(states) == 4 - -def test_rnn_baseline(): - rnn = StackedRNNBaseline( - d_model=256, - d_hidden=32, - n_layers=4, - learned_h0=True, - weight_norm=True, - skip_connections=True, - dropout=0.0, - ) - x = torch.randn(8, 100, 256) - y, states = rnn(x) - assert y.shape == (8, 100, 256) - assert len(states) == 4 - -def test_sample_rnn(): - rnn = SampleRNN( - frame_sizes=(16, 4), - n_rnn=1, - d_hidden=1024, - bits=8, - learn_h0=True, - d_model=256, - weight_norm=True, - reproduce=True, - quantization='linear', - ).cuda() - x = torch.randint(0, 255, (2, 1023, 1), dtype=torch.long).cuda() - y, states = rnn(x) - # assert y.shape == (8, 960, 256) - # assert len(states) == 4 - - with torch.no_grad(): - y_i, state = rnn.step(x[:, :rnn.lookback, :], state=None) - ys = [y_i] - for i in range(rnn.lookback, x.shape[1]): - x_i = x[:, i, :] - y_i, state = rnn.step(x_i, state) - ys.append(y_i) - y_ = torch.stack(ys).squeeze().transpose(0, 1) - breakpoint() - -if __name__ == "__main__": - # test_stacked_rnn() - # test_rnn_baseline() - test_sample_rnn() - pass diff --git a/src/models/baselines/unicornn.py b/src/models/baselines/unicornn.py index 073abf4..7a91b05 100644 --- a/src/models/baselines/unicornn.py +++ b/src/models/baselines/unicornn.py @@ -8,6 +8,9 @@ import torch.nn as nn from torch.autograd import Function from torch.nn import Parameter +from collections import namedtuple + +from src.models.sequence.base import SequenceModule, TransposedModule try: from cupy.cuda import function @@ -17,9 +20,6 @@ except ImportError: _unicornn_available = False -from collections import namedtuple - -import pdb UnICORNN_CODE = """ extern "C" { @@ -43,13 +43,13 @@ const float * __restrict__ weight_hh, const float * __restrict__ hy_initial, const float * __restrict__ hz_initial, float * __restrict__ hy_final, float * __restrict__ hz_final, - const int len, const int batch, const int d_model, const float * __restrict__ c, + const int len, const int batch, const int d_model, const float * __restrict__ c, double dt, double alpha, float * __restrict__ hy_all) { int ncols = batch*d_model; int col = blockIdx.x * blockDim.x + threadIdx.x; - if (col >= ncols) return; + if (col >= ncols) return; const float weight_hh_cur = *(weight_hh + (col%d_model)); const float c_cur = *(c + (col%d_model)); float hy = *(hy_initial + col); @@ -68,30 +68,30 @@ *(hz_final + col) = hz; } __global__ void unicornn_bwd(const float * __restrict__ x, - const float * __restrict__ weight_hh, const float * __restrict__ hy_final, + const float * __restrict__ weight_hh, const float * __restrict__ hy_final, const float * __restrict__ hz_final, - const float * __restrict__ grad_h, - const int len, const int batch, const int d_model, const float * __restrict__ c, + const float * __restrict__ grad_h, + const int len, const int batch, const int d_model, const float * __restrict__ c, double dt, double alpha, float * __restrict__ grad_x, float * __restrict__ grad_weight_hh, float * __restrict__ grad_c) - { + { int ncols = batch*d_model; int col = blockIdx.x * blockDim.x + threadIdx.x; - if (col >= ncols) return; + if (col >= ncols) return; const float weight_hh_cur = *(weight_hh + (col%d_model)); const float c_cur = *(c + (col%d_model)); float gweight_hh = 0; float gc = 0; - const float *xp = x+col + (len-1)*ncols; + const float *xp = x+col + (len-1)*ncols; float *gxp = grad_x + col + (len-1)*ncols; const float *ghp = grad_h + col + (len-1)*ncols; float delta_z = 0; - float delta_y = (*ghp); + float delta_y = (*ghp); float delta_dt = 0; float hy = *(hy_final + col); float hz = *(hz_final + col); for (int row = len-1; row >= 0; --row) - { + { delta_dt = delta_y*dt*sigmoid_grad(c_cur)*hz; // reconstruct hidden states based on the final hidden state using adjoint symplectic Euler: hy=hy-dt*sigmoid(c_cur)*hz; @@ -291,7 +291,8 @@ def forward(self, x): return y -class UnICORNN(nn.Module): +@TransposedModule +class UnICORNN(SequenceModule): def __init__( self, # d_input, @@ -301,7 +302,7 @@ def __init__( dt, alpha, n_layers, - drop=0.1, + dropout=0.1, **kwargs ): if not _unicornn_available: @@ -312,7 +313,7 @@ def __init__( super(UnICORNN, self).__init__() self.d_model = d_model self.d_output = d_model - self.drop = drop + self.dropout = dropout self.nlayers = n_layers # self.l_output = l_output self.DIs = nn.ModuleList() @@ -351,17 +352,10 @@ def forward(self, input, *args, **kwargs): ) rnnoutputs["outlayer%d" % x] = self.RNNs[x](rnnoutputs["dilayer%d" % x]) rnnoutputs["outlayer%d" % x] = dropout_overtime( - rnnoutputs["outlayer%d" % x], self.drop, self.training + rnnoutputs["outlayer%d" % x], self.dropout, self.training ) - # temp = rnnoutputs["outlayer%d" % (len(self.RNNs) - 1)][-1] - # output = self.classifier(temp) output = rnnoutputs["outlayer%d" % (len(self.RNNs) - 1)] output = output.transpose(0, 1) - # if self.l_output == 0: - # output = output[:, -1] - # else: - # output = output[:, -self.l_output :] - return output diff --git a/src/models/baselines/wavegan.py b/src/models/baselines/wavegan.py index 4244129..325a860 100644 --- a/src/models/baselines/wavegan.py +++ b/src/models/baselines/wavegan.py @@ -1,15 +1,16 @@ """ Ported implementation of WaveGAN Discriminator https://github.com/chrisdonahue/wavegan Several modifications have been made to integrate this better with this codebase, and to add extra options. + +DEPRECATED as of July 22 (V3 release); this type of generic ConvNet is subsumed by the standard model backbone and Conv1d layer (see config convnet1d.yaml) """ import math import torch import torch.nn as nn -from torch.nn import Parameter import torch.nn.functional as F import torch.utils.data -# from params import * +from einops import rearrange, reduce, repeat from src.models.sequence import SequenceModule from src.models.nn.components import Normalization @@ -20,7 +21,7 @@ def __init__(self, d, layer, norm='none', dropout=0.0): self.d = d self.layer = layer self.norm = Normalization(d, transposed=True, _name_=norm) - self.drop = nn.Dropout2d(dropout) + self.drop = nn.Dropout(dropout) def forward(self, x): y = self.layer(x) @@ -36,7 +37,6 @@ def __init__( input_channels, output_channels, kernel_size, - causal=True, stride=4, # padding=12, # alpha=0.2, @@ -47,7 +47,6 @@ def __init__( super().__init__() layers = [] # Residual convolution layers - # padding = (kernel_size-1, 0) if causal else (kernel_size-1)//2 padding = (kernel_size-1)//2 for _ in range(n_layers-1): layers.append(ResidualBlock( @@ -68,7 +67,7 @@ def __init__( # else nn.Identity() ) # self.alpha = alpha - layers.append(nn.Dropout2d(dropout)) + layers.append(nn.Dropout(dropout)) self.layers = nn.Sequential(*layers) def forward(self, x): @@ -79,14 +78,15 @@ class WaveGANDiscriminator(SequenceModule): def __init__( self, d_model=1, - d_output=10, + d_output=35, l_output=0, # Unused, absorbs argument from sequence model_size=64, n_layers=1, + n_blocks=5, kernel_size=25, # alpha=0.2, norm='none', - causal=True, # Currently doesn't work + pool=False, verbose=False, l_max=16384, # use_batch_norm=False, @@ -94,8 +94,6 @@ def __init__( ): super().__init__() assert kernel_size % 2 == 1, f"Only odd kernel sizes supported" - # assert l_max in [16384, 32768, 65536] # used to predict longer utterances - # assert l_max == 16384 # only support up to 16k sequences for now self.d_model = d_model # c self.d_output = d_output @@ -103,11 +101,9 @@ def __init__( self.l_output = l_output self.l_max = 2 ** math.ceil(math.log2(l_max)) - print(self.l_max) self.model_size = model_size # d - # self.use_batch_norm = use_batch_norm - # self.alpha = alpha + self.pool = pool self.verbose = verbose conv_layers = [ @@ -118,12 +114,14 @@ def __init__( stride=4, # padding=12, # use_batch_norm=use_batch_norm, - causal=causal, norm=norm, n_layers=n_layers, # alpha=alpha, dropout=dropout, - ), + ) + ] + for _ in range(1, n_blocks): + conv_layers.append( Conv1DBlock( model_size, 2 * model_size, @@ -131,90 +129,21 @@ def __init__( stride=4, # padding=12, # use_batch_norm=use_batch_norm, - causal=causal, - norm=norm, - n_layers=n_layers, - # alpha=alpha, - dropout=dropout, - ), - Conv1DBlock( - 2 * model_size, - 4 * model_size, - kernel_size, - stride=4, - # padding=12, - # use_batch_norm=use_batch_norm, - causal=causal, - norm=norm, - n_layers=n_layers, - # alpha=alpha, - dropout=dropout, - ), - Conv1DBlock( - 4 * model_size, - 8 * model_size, - kernel_size, - stride=4, - # padding=12, - # use_batch_norm=use_batch_norm, - causal=causal, norm=norm, n_layers=n_layers, # alpha=alpha, dropout=dropout, - ), - Conv1DBlock( - 8 * model_size, - 16 * model_size, - kernel_size, - stride=4, - # padding=12, - # use_batch_norm=use_batch_norm, - causal=causal, - norm=norm, - n_layers=n_layers, - # alpha=alpha, - dropout=dropout, - ), - ] - self.causal = causal - # self.fc_d_input = 256 * model_size - if self.causal: - self.fc_d_input = 16*model_size - else: - self.fc_d_input = self.l_max // 64 * model_size - - # Logic for very long sequences from WaveGAN code - # if l_max == 32768: - # conv_layers.append( - # Conv1D( - # 16 * model_size, - # 32 * model_size, - # kernel_size, - # stride=2, - # padding=12, - # use_batch_norm=use_batch_norm, - # alpha=alpha, - # ) - # ) - # self.fc_d_input = 480 * model_size - # elif l_max == 65536: - # conv_layers.append( - # Conv1D( - # 16 * model_size, - # 32 * model_size, - # kernel_size, - # stride=4, - # padding=12, - # use_batch_norm=use_batch_norm, - # alpha=alpha, - # ) - # ) - # self.fc_d_input = 512 * model_size - + ) + ) + model_size *= 2 self.conv_layers = nn.ModuleList(conv_layers) - self.fc1 = nn.Linear(self.fc_d_input, self.d_output) + if pool: + self.fc = nn.Linear(model_size, self.d_output) + else: + # self.fc_d_input = self.l_max // 64 * model_size + self.fc_d_input = self.l_max // 4**(n_blocks) * model_size # total length * channels after all conv layers + self.fc1 = nn.Linear(self.fc_d_input, self.d_output) for m in self.modules(): if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear): @@ -231,40 +160,11 @@ def forward(self, x, *args, **kwargs): x = conv(x) if self.verbose: print(x.shape) - if self.causal: - x = self.fc1(x.transpose(-1, -2)) # (B, L, output) - - if self.l_output == 0: - return x[:, -1, :], None - else: - return x[:, -self.l_output:, :], None + assert self.l_output == 0 + if self.pool: + x = reduce(x, 'b c l -> b c', 'mean') + x = self.fc(x) else: - assert self.l_output == 0 x = x.reshape(-1, self.fc_d_input) - if self.verbose: - print(x.shape) - return self.fc1(x), None - - -if __name__ == "__main__": - # from torch.autograd import Variable - - channels = 3 - classes = 10 - for l_max in [1024, 4096, 16000]: - - D = WaveGANDiscriminator( - d_model=channels, - d_output=10, - verbose=True, - # use_batch_norm=True, - norm='batch', - causal=False, - n_layers=2, - dropout=0.1, - l_max=l_max, - ) - out2 = D(torch.randn(10, l_max, channels)) - print(out2.shape) - assert out2.shape == (10, classes) - print("==========================") + x = self.fc1(x) + return x, None diff --git a/src/models/baselines/wavenet.py b/src/models/baselines/wavenet.py index 2665bea..a721e5e 100644 --- a/src/models/baselines/wavenet.py +++ b/src/models/baselines/wavenet.py @@ -17,7 +17,6 @@ def mu_law_expansion(data, mu): s = np.sign(data) * (np.exp(np.abs(data) * np.log(mu + 1)) - 1) / mu return s -# def dilate(x, dilation, init_dilation=1, pad_start=True): def dilate(x, dilation, init_dilation=1): """ :param x: Tensor of size (N, C, L), where N is the input dilation, C is the number of channels, and L is the input length @@ -35,7 +34,6 @@ def dilate(x, dilation, init_dilation=1): new_l = int(np.ceil(l / dilation_factor) * dilation_factor) if new_l != l: l = new_l - # x = constant_pad_1d(x, new_l, dimension=2, pad_start=pad_start) x = constant_pad_1d(x, new_l) l_old = int(round(l / dilation_factor)) @@ -94,7 +92,7 @@ def reset(self, device): def constant_pad_1d( input, target_size, -): +): cp1d = torch.nn.ConstantPad1d((target_size - input.size(-1), 0), 0) return cp1d(input) @@ -124,7 +122,7 @@ def d_output(self): def default_state(self, *batch_shape, device=None): return None - + def __init__( self, layers=10, @@ -222,6 +220,17 @@ def __init__( self.receptive_field = receptive_field + # print("Receptive field: {}".format(self.receptive_field)) + + ### TODO + # This piece of code used to go in the generation script to set up the WaveNet in autoregressive mode + # Instead of being in the generation script, it should go as part of this __init__ or default_state() + # if isinstance(model.model, WaveNetModel) and not benchmark: + # l_prefix += model.model.receptive_field + # T += model.model.receptive_field + # if x.shape[1] == 1: + # x = x.repeat(1, l_prefix + 1) + ######### def wavenet(self, input, dilation_func): @@ -280,15 +289,15 @@ def queue_dilate(self, input, dilation, init_dilation, i): queue.enqueue(input) x = queue.dequeue(num_deq=self.kernel_size, dilation=dilation) - + return x - def forward(self, input, state=None): + def forward(self, input, state=None, **kwargs): # BLD -> BDL input = input.transpose(1, 2).contiguous() x = self.wavenet( - input, + input, dilation_func=self.wavenet_dilate, ) @@ -302,7 +311,7 @@ def step(self, x, state=None): x = x.unsqueeze(1).unsqueeze(1) elif len(x.shape) == 2: x = x.unsqueeze(1) - + if state is None: # Reset dilated queues for queue in self.dilated_queues: @@ -313,37 +322,3 @@ def step(self, x, state=None): x = x.transpose(1, 2).contiguous() return x, self.dilated_queues - -def test_wavenet(): - wavenet = WaveNetModel( - layers=10, - blocks=4, - dilation_channels=32, - residual_channels=32, - skip_channels=256, - end_channels=256, - classes=256, - # output_length=16000, - kernel_size=2, - ).cuda() - - print(wavenet) - print(wavenet.parameter_count()) - print(wavenet.receptive_field) - # BLD - x = torch.randn(7, 4093 + 16, 256).cuda() - y, _ = wavenet(x) - print(y.shape) - - with torch.no_grad(): - state = None - ys = [] - for i in range(x.shape[1]): - y_i, state = wavenet.step(x[:, i, :], state) - ys.append(y_i) - y_ = torch.stack(ys).squeeze().transpose(0, 1) - breakpoint() - # assert y.shape == (8, 16000, 256) - -if __name__ == "__main__": - test_wavenet() diff --git a/src/models/functional/cauchy.py b/src/models/functional/cauchy.py index eebf63e..bedbccd 100644 --- a/src/models/functional/cauchy.py +++ b/src/models/functional/cauchy.py @@ -1,4 +1,4 @@ -""" pykeops implementations of the core Cauchy kernel used in the S3 algorithm. +"""pykeops implementations of the Cauchy matrix multiplication used in the S4 algorithm. The interface of the Cauchy multiplication is: v: (N) @@ -7,13 +7,6 @@ Return: y (L) y_k = \sum_i v_i / (z_i - w_k) """ -if __name__ == '__main__': - import sys - import pathlib - p = pathlib.Path().absolute() - print("Adding path: ", p) - sys.path.append(str(p)) - import math import torch @@ -36,7 +29,7 @@ def _broadcast_dims(*tensors): def _c2r(x): return torch.view_as_real(x) def _r2c(x): return torch.view_as_complex(x) -def cauchy_slow(v, z, w, conj=True): +def cauchy_naive(v, z, w, conj=True): """ v: (..., N) z: (..., L) @@ -49,19 +42,6 @@ def cauchy_slow(v, z, w, conj=True): cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L) return torch.sum(cauchy_matrix, dim=-2) -def cauchy_lazy(v, z, w, conj=True): - if conj: - v = _conj(v) - w = _conj(w) - v, z, w = _broadcast_dims(v, z, w) - v_l = LazyTensor(rearrange(v, '... N -> ... N 1 1')) - w_l = LazyTensor(rearrange(w, '... N -> ... N 1 1')) - z_l = LazyTensor(rearrange(z, '... L -> ... 1 L 1')) - sub = z_l - w_l # (b N L 1), for some reason it doesn't display the last dimension - div = v_l / sub - s = div.sum(dim=len(v_l.shape)-2) - return s.squeeze(-1) - def cauchy(v, z, w, conj=False): expr = 'ComplexDivide(v, z-w)' cauchy_mult = Genred( @@ -138,99 +118,3 @@ def cauchy_conj(v, z, w, num=2, denom=2): r = 2*cauchy_mult(v, z, w, backend='GPU') return _r2c(r) - -def cauchy_conj_components(v, z, w): - """ Assumes z is pure imaginary (as in S4 with bilinear) """ - - expr_num = 'Imag2Complex(zi*vr) - Real2Complex(vr*wr + vi*wi)' - expr_denom = 'Real2Complex(Square(wr)+Square(wi)-Square(zi)) - Imag2Complex(IntCst(2)*zi*wr)' - cauchy_mult = Genred( - f'ComplexDivide({expr_num}, {expr_denom})', - [ - 'vr = Vj(1)', - 'vi = Vj(1)', - 'wr = Vj(1)', - 'wi = Vj(1)', - 'zi = Vi(1)', - ], - reduction_op='Sum', - axis=1, - ) - - v, z, w = _broadcast_dims(v, z, w) - v = v.unsqueeze(-1) - z = z.unsqueeze(-1) - w = w.unsqueeze(-1) - - v_r, v_i = v.real.contiguous(), v.imag.contiguous() - w_r, w_i = w.real.contiguous(), w.imag.contiguous() - z_i = z.imag.contiguous() - - r = 2*cauchy_mult(v_r, v_i, w_r, w_i, z_i, backend='GPU') - return _r2c(r) - -def cauchy_conj_components_lazy(v, z, w, type=1): - v, z, w = _broadcast_dims(v, z, w) - - v_r, v_i = v.real.contiguous(), v.imag.contiguous() - w_r, w_i = w.real.contiguous(), w.imag.contiguous() - z_i = z.imag.contiguous() - - v_r = LazyTensor(rearrange(v_r, '... N -> ... 1 N 1')) - v_i = LazyTensor(rearrange(v_i, '... N -> ... 1 N 1')) - w_r = LazyTensor(rearrange(w_r, '... N -> ... 1 N 1')) - w_i = LazyTensor(rearrange(w_i, '... N -> ... 1 N 1')) - z_i = LazyTensor(rearrange(z_i, '... L -> ... L 1 1')) - - if type == 1: - num = -v_r*w_r-v_i*w_i + 1j* z_i*v_r - denom = w_r**2+w_i**2-z_i**2 - 2j*w_r*z_i - else: - # z = torch.complex(-w_r, z_i) # Not supported - z = -w_r + 1j* z_i - num = v_r * z - v_i*w_i - denom = z*z + w_i**2 # z**2 is bugged for complex - - r = num / denom - r = 2*r.sum(dim=len(z_i.shape)-1) - return r.squeeze(-1) - -def cauchy_conj2(v, z, w): - expr = 'ComplexDivide(v, z-w) + ComplexDivide(Conj(v), z-Conj(w))' - # expr = 'ComplexDivide(v, z-w)' - cauchy_mult = Genred( - expr, - [ - 'v = Vj(2)', - 'z = Vi(2)', - 'w = Vj(2)', - ], - reduction_op='Sum', - axis=1, - ) - - v, z, w = _broadcast_dims(v, z, w) - if complex: - v = _c2r(v) - z = _c2r(z) - w = _c2r(w) - - r = cauchy_mult(v, z, w, backend='GPU') - return _r2c(r) - - -def trigger_compilation(): - """ Small function to trigger the compilation of a pykeops kernel - - Used in scenarios where we must manually control compilation, e.g. the multi-gpu case (https://github.com/getkeops/keops/issues/168) """ - B = 2 - N = 4 - L = 16 - - w = torch.randn(B, N//2, dtype=torch.cfloat, device='cuda') - v = torch.randn(B, N//2, dtype=torch.cfloat, device='cuda') - z = torch.randn(B, L, dtype=torch.cfloat, device='cuda') - w.requires_grad = True - v.requires_grad = True - - cauchy_conj(v, z, w) diff --git a/src/models/functional/complex.py b/src/models/functional/complex.py index 3160145..a41bff9 100644 --- a/src/models/functional/complex.py +++ b/src/models/functional/complex.py @@ -123,41 +123,8 @@ def backward(ctx, grad): grad_X = ComplexMul.apply(grad, conjugate(Y)).sum_to_size(*X.shape) if ctx.needs_input_grad[1]: grad_Y = ComplexMul.apply(grad, conjugate(X)).sum_to_size(*Y.shape) - # grad_X, grad_Y = ComplexMul.apply(grad, conjugate(Y)), ComplexMul.apply(grad, conjugate(X)) - # # Need to sum over dimensions that were broadcasted - # grad_X = grad_X.sum_to_size(*X.shape) - # grad_Y = grad_Y.sum_to_size(*Y.shape) - # dims_to_sum_X = [-i for i in range(1, X.dim() + 1) if X.shape[-i] != grad.shape[-i]] - # dims_to_sum_Y = [-i for i in range(1, Y.dim() + 1) if Y.shape[-i] != grad.shape[-i]] - # if dims_to_sum_X: # If empty list is passed to sum, it sums all the dimensions - # grad_X = grad_X.sum(dim=dims_to_sum_X, keepdim=True) - # if dims_to_sum_Y: # If empty list is passed to sum, it sums all the dimensions - # grad_Y = grad_Y.sum(dim=dims_to_sum_Y, keepdim=True) - # if grad.dim() > X.dim(): - # grad_X = grad_X.sum(tuple(range(grad.dim() - X.dim()))) - # if grad.dim() > Y.dim(): - # grad_Y = grad_Y.sum(tuple(range(grad.dim() - Y.dim()))) return grad_X, grad_Y complex_mul = ComplexMul.apply if use_cupy else complex_mul_torch if use_pt_native: complex_mul = complex_mul_native - -# @profile -# def complex_mul(X, Y): -# assert X.shape[-1] == 2 and Y.shape[-1] == 2, 'Last dimension must be 2' -# prod = X.unsqueeze(-1) * Y.unsqueeze(-2) -# real = prod[..., 0, 0] - prod[..., 1, 1] -# imag = prod[..., 0, 1] + prod[..., 1, 0] -# return torch.stack( (real, imag), dim=-1) - -# TODO maybe optimizations to be had by wrapping this into a function - - # real = X.select(-1, 0) * Y.select(-1, 0) - X.select(-1, 1) * Y.select(-1, 1) - # imag = X.select(-1, 0) * Y.select(-1, 1) + X.select(-1, 1) * Y.select(-1, 0) - # return torch.stack( (real, imag), dim=-1) - - # return torch.stack( - # (X[..., 0] * Y[..., 0] - X[..., 1] * Y[..., 1], - # X[..., 0] * Y[..., 1] + X[..., 1] * Y[..., 0]), - # dim=-1) diff --git a/src/models/functional/krylov.py b/src/models/functional/krylov.py index 3ebe003..3d986e0 100644 --- a/src/models/functional/krylov.py +++ b/src/models/functional/krylov.py @@ -1,4 +1,4 @@ -""" Compute a Krylov function efficiently. (S3 renames the Krylov function to a "state space kernel") +""" Compute a Krylov function efficiently. (S4 renames the Krylov function to a "state space kernel") A : (N, N) b : (N,) @@ -6,7 +6,6 @@ Return: [c^T A^i b for i in [L]] """ - import torch import torch.nn.functional as F from einops import rearrange, repeat @@ -92,6 +91,7 @@ def krylov(L, A, b, c=None, return_power=False): else: return x +@torch.no_grad() def power(L, A, v=None): """ Compute A^L and the scan sum_i A^i v_i @@ -108,7 +108,10 @@ def power(L, A, v=None): L //= 2 if L == 0: break l *= 2 - powers.append(powers[-1] @ powers[-1]) + if v is None: + powers = [powers[-1] @ powers[-1]] + else: + powers.append(powers[-1] @ powers[-1]) if v is None: return I diff --git a/src/models/functional/toeplitz.py b/src/models/functional/toeplitz.py index 3fe5d84..af00739 100644 --- a/src/models/functional/toeplitz.py +++ b/src/models/functional/toeplitz.py @@ -7,12 +7,9 @@ """ import torch -# import torch.nn as nn +import torch.nn as nn import torch.nn.functional as F -# from model.complex import complex_mul -# from pytorch_memlab import profile - def construct_toeplitz(v, f=0.0): """Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{n-1} @ v] @@ -158,118 +155,3 @@ def causal_convolution(u, v, fast=True, pad=False): return triangular_toeplitz_multiply_padded(u, v) if pad and fast: return triangular_toeplitz_multiply_padded_fast(u, v) - -def _fft(x, N): return torch.fft.rfft(F.pad(x, (0, 2*N-x.shape[-1])), n=2*N, dim=-1) -def _ifft(x, N): return torch.fft.irfft(x, n=2*N, dim=-1)[..., :N] - -def causal_convolution_inverse(u): - """ Invert the causal convolution/polynomial/triangular Toeplitz matrix represented by u. - - This is easiest in the polynomial view: - https://www.csa.iisc.ac.in/~chandan/courses/CNT/notes/lec5.pdf - The idea is that - h = g^{-1} (mod x^m) => 2h - gh^2 = g^{-1} (mod x^{2m}) - - # TODO this can be numerically unstable if input is "poorly conditioned", - # for example if u[0] is magnitudes different from the rest of u - """ - N = u.shape[-1] - v = u[..., :1].reciprocal() - while v.shape[-1] < N: - M = v.shape[-1] - v_f = _fft(v, 2*M) - u_f = _fft(u[..., :2*M], 2*M) - _v = -_ifft(u_f * v_f**2, 2*M) - _v[..., :M] = _v[..., :M] + 2*v - v = _v - # TODO contiguous? - v = v[..., :N] - return v - -""" Below are experimental functions for improving the stability of LSSL/S3 algorithm. Currently not used anywhere. """ - -def causal_convolution_inverse_wrong(u, v): - """ Solve u * x = v. Initial attempt by inverting the multiplication algorithm, which I think doesn't work. """ - n = u.shape[-1] - u_expand = F.pad(u, (0, n)) - v_expand = F.pad(v, (0, n)) - u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1) - v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1) - uv_f = v_f / u_f - x = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n] - return x - -def construct_toeplitz_log(v): - n = v.shape[-1] - a = torch.arange(n, device=v.device) - b = -a - indices = a[:, None] + b[None] - K = v[..., indices] - K[..., indices < 0] = -100.0 - return K - -def _logsumexp(x, dim=-1): - """ logsumexp for complex """ - m = torch.max(torch.real(x), dim=dim, keepdim=True)[0] - x = x - m - x = torch.log(torch.sum(torch.exp(x), dim=dim)) - x = x + m.squeeze(dim) - return x - -def causal_convolution_inverse_log(u, N=-1): - """ Invert the causal convolution/polynomial/triangular Toeplitz matrix represented by u. - - This is easiest in the polynomial view: - https://www.csa.iisc.ac.in/~chandan/courses/CNT/notes/lec5.pdf - The idea is that - h = g^{-1} (mod x^m) => 2h - gh^2 = g^{-1} (mod x^{2m}) - - # TODO this can be numerically unstable if input is "poorly conditioned", - # for example if u[0] is magnitudes different from the rest of u - """ - if N < 0: - N = u.shape[-1] - v = - u[..., :1] - while v.shape[-1] < N: - M = v.shape[-1] - _v = F.pad(v, (0, M), value=-100.0) - _v_ = construct_toeplitz_log(_v) - u_ = u[..., :2*M] if u.shape[-1] >= 2*M else F.pad(u, (0, 2*M-u.shape[-1]), value=-100.0) - _u = _logsumexp(_v_ + u_, dim=-1) - _u = _logsumexp(_v_ + _u, dim=-1) - _u = _u + torch.log(-torch.ones_like(_u)) - _v = _v + torch.log(2.0 * torch.ones_like(_u)) - v = _logsumexp(torch.stack([_v, _u], dim=-1), dim=-1) - # TODO contiguous? - v = v[..., :N] - - check = _logsumexp(construct_toeplitz_log(v) + F.pad(u, (0, N-u.shape[-1]), value=-100.0)) - print("check", check, torch.exp(check)) - return v - - - -if __name__ == '__main__': - a = torch.tensor([1., 2, 3, 4], requires_grad=True) - b = torch.tensor([5., 6, 7, 8], requires_grad=True) - a.retain_grad() - b.retain_grad() - x = triangular_toeplitz_multiply_padded(F.pad(a, (0, 4)), F.pad(b, (0, 4)))[:4] - print(x) # [5 16 34 60] - x = x.sum() - x.backward() - print(x, a.grad, b.grad) # [26 18 11 5] [10 6 3 1] - -if __name__ == '__main__': - N = 4 - a = torch.randn(N) - construct_toeplitz(a) - print(a) - b = causal_convolution_inverse(a) - print("inverse", b) - print("check", causal_convolution(a, b)) - i = torch.zeros(N) - i[0] = 1.0 - b = causal_convolution_inverse_wrong(a, i) - print(b) - print(causal_convolution(a, b)) diff --git a/src/models/functional/unroll.py b/src/models/functional/unroll.py index 41ff3b7..a28a478 100644 --- a/src/models/functional/unroll.py +++ b/src/models/functional/unroll.py @@ -211,9 +211,7 @@ def variable_unroll_sequential(A, u, s=None, variable=True): outputs = [] for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)): # s = F.linear(s, A_) + u_ - # print("shapes", A_.shape, s.shape, has_batch) s = batch_mult(A_.unsqueeze(0), s.unsqueeze(0), has_batch)[0] - # breakpoint() s = s + u_ outputs.append(s) @@ -314,11 +312,9 @@ def variable_unroll_toeplitz_sequential(A, u, s=None, variable=True, pad=False): if pad: n = A.shape[-1] - # print("shapes", A.shape, u.shape) A = F.pad(A, (0, n)) u = F.pad(u, (0, n)) s = F.pad(s, (0, n)) - # print("shapes", A.shape, u.shape) ret = variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply_padded, variable=True) ret = ret[..., :n] return ret @@ -344,7 +340,7 @@ def variable_unroll_general(A, u, s, op, compose_op=None, sequential_op=None, va compose_op = op uneven = u.shape[0] % 2 == 1 - has_batch = len(u.shape) >= len(A.shape) + # has_batch = len(u.shape) >= len(A.shape) u_0 = u[0::2, ...] u_1 = u[1::2, ...] @@ -412,11 +408,9 @@ def variable_unroll_toeplitz(A, u, s=None, variable=True, recurse_limit=8, pad=F if pad: n = A.shape[-1] - # print("shapes", A.shape, u.shape) A = F.pad(A, (0, n)) u = F.pad(u, (0, n)) s = F.pad(s, (0, n)) - # print("shapes", A.shape, u.shape) op = triangular_toeplitz_multiply_padded ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit) ret = ret[..., :n] @@ -425,185 +419,3 @@ def variable_unroll_toeplitz(A, u, s=None, variable=True, recurse_limit=8, pad=F op = triangular_toeplitz_multiply ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit) return ret - - - -### Testing - -def test_correctness(): - print("Testing Correctness\n====================") - - # Test sequential unroll - L = 3 - A = torch.Tensor([[1, 1], [1, 0]]) - u = torch.ones((L, 2)) - x = unroll(A, u) - assert torch.isclose(x, torch.Tensor([[1., 1.], [3., 2.], [6., 4.]])).all() - - # Test utilities - assert torch.isclose(shift_up(x), torch.Tensor([[0., 0.], [1., 1.], [3., 2.]])).all() - assert torch.isclose(interleave(x, x), torch.Tensor([[1., 1.], [1., 1.], [3., 2.], [3., 2.], [6., 4.], [6., 4.]])).all() - - # Test parallel unroll - x = parallel_unroll_recursive(A, u) - assert torch.isclose(x, torch.Tensor([[1., 1.], [3., 2.], [6., 4.]])).all() - - # Powers - L = 12 - A = torch.Tensor([[1, 0, 0], [2, 1, 0], [3, 3, 1]]) - u = torch.ones((L, 3)) - x = parallel_unroll_recursive(A, u) - print("recursive", x) - x = parallel_unroll_recursive_br(A, u) - print("recursive_br", x) - x = parallel_unroll_iterative(A, u) - print("iterative_br", x) - - - A = A.repeat((L, 1, 1)) - s = torch.zeros(3) - print("A shape", A.shape) - x = variable_unroll_sequential(A, u, s) - print("variable_unroll", x) - x = variable_unroll(A, u, s) - print("parallel_variable_unroll", x) - - -def generate_data(L, N, B=None, cuda=True): - A = torch.eye(N) + torch.normal(0, 1, size=(N, N)) / (N**.5) / L - u = torch.normal(0, 1, size=(L, B, N)) - - - # device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') - device = torch.device('cuda:0') if cuda else torch.device('cpu') - A = A.to(device) - u = u.to(device) - return A, u - -def test_stability(): - print("Testing Stability\n====================") - L = 256 - N = L // 2 - B = 100 - A, u = generate_data(L, N, B) - - x = unroll(A, u) - x1 = parallel_unroll_recursive(A, u) - x2 = parallel_unroll_recursive_br(A, u) - x3 = parallel_unroll_iterative(A, u) - print("norm error", torch.norm(x-x1)) - print("norm error", torch.norm(x-x2)) - print("norm error", torch.norm(x-x3)) - # print(x-x1) - # print(x-x2) - # print(x-x3) - print("max error", torch.max(torch.abs(x-x1))) - print("max error", torch.max(torch.abs(x-x2))) - print("max error", torch.max(torch.abs(x-x3))) - - A = A.repeat((L, 1, 1)) - x = variable_unroll_sequential(A, u) - x_ = variable_unroll(A, u) - # x_ = variable_unroll_matrix_sequential(A, u) - x_ = variable_unroll_matrix(A, u) - print(x-x_) - abserr = torch.abs(x-x_) - relerr = abserr/(torch.abs(x)+1e-8) - print("norm abs error", torch.norm(abserr)) - print("max abs error", torch.max(abserr)) - print("norm rel error", torch.norm(relerr)) - print("max rel error", torch.max(relerr)) - -def test_toeplitz(): - from model.toeplitz import construct_toeplitz - def summarize(name, x, x_, showdiff=False): - print(name, "stats") - if showdiff: - print(x-x_) - abserr = torch.abs(x-x_) - relerr = abserr/(torch.abs(x)+1e-8) - print(" norm abs error", torch.norm(abserr)) - print(" max abs error", torch.max(abserr)) - print(" norm rel error", torch.norm(relerr)) - print(" max rel error", torch.max(relerr)) - - print("Testing Toeplitz\n====================") - L = 512 - N = L // 2 - B = 100 - A, u = generate_data(L, N, B) - - A = A[..., 0] - A = construct_toeplitz(A) - - # print("SHAPES", A.shape, u.shape) - - # Static A - x = unroll(A, u) - x_ = variable_unroll(A, u, variable=False) - summarize("nonvariable matrix original", x, x_, showdiff=False) - x_ = variable_unroll_matrix(A, u, variable=False) - summarize("nonvariable matrix general", x, x_, showdiff=False) - x_ = variable_unroll_toeplitz(A[..., 0], u, variable=False) - summarize("nonvariable toeplitz", x, x_, showdiff=False) - - # Sequential - A = A.repeat((L, 1, 1)) - for _ in range(1): - x_ = variable_unroll_sequential(A, u) - summarize("variable unroll sequential", x, x_, showdiff=False) - x_ = variable_unroll_matrix_sequential(A, u) - summarize("variable matrix sequential", x, x_, showdiff=False) - x_ = variable_unroll_toeplitz_sequential(A[..., 0], u, pad=True) - summarize("variable toeplitz sequential", x, x_, showdiff=False) - - # Parallel - for _ in range(1): - x_ = variable_unroll(A, u) - summarize("variable matrix original", x, x_, showdiff=False) - x_ = variable_unroll_matrix(A, u) - summarize("variable matrix general", x, x_, showdiff=False) - x_ = variable_unroll_toeplitz(A[..., 0], u, pad=True, recurse_limit=8) - summarize("variable toeplitz", x, x_, showdiff=False) - -def test_speed(variable=False, it=1): - print("Testing Speed\n====================") - N = 256 - L = 1024 - B = 100 - A, u = generate_data(L, N, B) - As = A.repeat((L, 1, 1)) - - u.requires_grad=True - As.requires_grad=True - for _ in range(it): - x = unroll(A, u) - x = torch.sum(x) - x.backward() - - x = parallel_unroll_recursive(A, u) - x = torch.sum(x) - x.backward() - - # parallel_unroll_recursive_br(A, u) - # parallel_unroll_iterative(A, u) - - for _ in range(it): - if variable: - x = variable_unroll_sequential(As, u, variable=True, recurse_limit=16) - x = torch.sum(x) - x.backward() - x = variable_unroll(As, u, variable=True, recurse_limit=16) - x = torch.sum(x) - x.backward() - else: - variable_unroll_sequential(A, u, variable=False, recurse_limit=16) - variable_unroll(A, u, variable=False, recurse_limit=16) - -# TODO refactor using benchmark util - -if __name__ == '__main__': - # test_correctness() - test_stability() - # test_toeplitz() - # test_speed(variable=True, it=100) diff --git a/src/models/functional/vandermonde.py b/src/models/functional/vandermonde.py new file mode 100644 index 0000000..d36ce37 --- /dev/null +++ b/src/models/functional/vandermonde.py @@ -0,0 +1,134 @@ +"""pykeops implementations of the Vandermonde matrix multiplication kernel used in the S4D kernel.""" +import math +import torch + +from einops import rearrange, repeat +from opt_einsum import contract + +import os + +try: + import pykeops + from pykeops.torch import LazyTensor, Genred +except: + pass + +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) +def _broadcast_dims(*tensors): + max_dim = max([len(tensor.shape) for tensor in tensors]) + tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors] + return tensors + +def _c2r(x): return torch.view_as_real(x) +def _r2c(x): return torch.view_as_complex(x) + +def vandermonde_naive(v, x, L, conj=True): + """ + v: (..., N) + x: (..., N) + returns: (..., L) \sum v x^l + """ + if conj: + x = _conj(x) + v = _conj(v) + vandermonde_matrix = x.unsqueeze(-1) ** torch.arange(L).to(x) # (... N L) + vandermonde_prod = torch.sum(v.unsqueeze(-1) * vandermonde_matrix, dim=-2) # (... L) + return vandermonde_prod + +def log_vandermonde_naive(v, x, L, conj=True): + """ + v: (..., N) + x: (..., N) + returns: (..., L) \sum v x^l + """ + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + vandermonde_prod = contract('... n, ... n l -> ... l', v, vandermonde_matrix) # (... L) + if conj: + return 2*vandermonde_prod.real + else: + return vandermonde_prod + +def log_vandermonde_lazy(v, x, L, conj=True): + if conj: + v = _conj(v) + x = _conj(x) + l = torch.arange(L).to(x) + v, x, l = _broadcast_dims(v, x, l) + v_l = LazyTensor(rearrange(v, '... N -> ... N 1 1')) + x_l = LazyTensor(rearrange(x, '... N -> ... N 1 1')) + l_l = LazyTensor(rearrange(l, '... L -> ... 1 L 1')) + # exp + vand = (x_l * l_l).exp() + s = (v_l*vand).sum(dim=len(v_l.shape)-2) + return s.squeeze(-1) + +def log_vandermonde(v, x, L, conj=True): + expr = 'ComplexMult(v, ComplexExp(ComplexMult(x, l)))' + vandermonde_mult = Genred( + expr, + [ + 'v = Vj(2)', + 'x = Vj(2)', + 'l = Vi(2)', + ], + reduction_op='Sum', + axis=1, + ) + + l = torch.arange(L).to(x) + v, x, l = _broadcast_dims(v, x, l) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(v, x, l, backend='GPU') + if conj: + return 2*_r2c(r).real + else: + return _r2c(r) + +def log_vandermonde_transpose_naive(u, v, x, L): + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + vandermonde_prod = contract('... l, ... n, ... n l -> ... n', u.to(x), v.to(x), vandermonde_matrix) # (... L) + return vandermonde_prod + +def log_vandermonde_transpose(u, v, x, L): + """ + u: ... H L + v: ... H N + x: ... H N + Returns: ... H N + + V = Vandermonde(a, L) : (H N L) + contract_L(V * u * v) + """ + expr = 'ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))' + vandermonde_mult = Genred( + expr, + [ + 'u = Vj(2)', + 'v = Vi(2)', + 'x = Vi(2)', + 'l = Vj(2)', + ], + reduction_op='Sum', + axis=1, + ) + + l = torch.arange(L).to(x) + u, v, x, l = _broadcast_dims(u, v, x, l) + u = _c2r(u) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(u, v, x, l, backend='GPU') + return _r2c(r) + +def _log_vandermonde_matmul(x, L): + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + return vandermonde_matrix + +def log_vandermonde_matmul(v, K): + prod = contract('...n, ...nl -> ...l', v, K) + return 2*prod.real diff --git a/src/models/hippo/hippo.py b/src/models/hippo/hippo.py index a391cd3..679a0c0 100644 --- a/src/models/hippo/hippo.py +++ b/src/models/hippo/hippo.py @@ -14,8 +14,7 @@ def embed_c2r(A): np.pad(A, ((0, 0), (1, 0), (0, 0), (1,0))) return rearrange(A, 'm x n y -> (m x) (n y)') -# TODO take in 'torch' option to return torch instead of numpy, which converts the shape of B from (N, 1) to (N) -# TODO remove tlagt +# TODO take in 'torch' option to return torch instead of numpy, and converts the shape of B from (N, 1) to (N) def transition(measure, N, **measure_args): """ A, B transition matrices for different measures @@ -30,11 +29,6 @@ def transition(measure, N, **measure_args): b = measure_args.get('beta', 1.0) A = np.eye(N) / 2 - np.tril(np.ones((N, N))) B = b * np.ones((N, 1)) - elif measure == 'tlagt': - # beta = 1 corresponds to no tilt - b = measure_args.get('beta', 1.0) - A = (1.-b)/2 * np.eye(N) - np.tril(np.ones((N, N))) - B = b * np.ones((N, 1)) # Generalized Laguerre # alpha 0, beta small is most stable (limits to the 'lagt' measure) # alpha 0, beta 1 has transition matrix A = [lower triangular 1] @@ -57,7 +51,6 @@ def transition(measure, N, **measure_args): A = -A # Halve again for timescale correctness - # A, B = A/2, B/2 A *= 0.5 B *= 0.5 # LMU: equivalent to LegT up to normalization @@ -88,20 +81,16 @@ def transition(measure, N, **measure_args): B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) A += .5 * B*B[None, :, 0] B = B / 2.0 - elif measure == 'fourier_old': - freqs = np.arange(N//2) - d = np.stack([freqs, np.zeros(N//2)], axis=-1).reshape(-1)[:-1] - A = 2*np.pi*(np.diag(d, 1) - np.diag(d, -1)) - A = A - embed_c2r(np.ones((N//2, N//2))) - B = embed_c2r(np.ones((N//2, 1)))[..., :1] - elif measure == 'fourier_diag': + elif measure in ['fourier_diag', 'foud']: freqs = np.arange(N//2) d = np.stack([freqs, np.zeros(N//2)], axis=-1).reshape(-1)[:-1] A = 2*np.pi*(-np.diag(d, 1) + np.diag(d, -1)) - # A = A - 0.5*embed_c2r(np.ones((N//2, N//2))) A = A - .5 * np.eye(N) - B = embed_c2r(np.ones((N//2, 1)))[..., :1] - elif measure == 'fourier': + B = np.zeros(N) + B[0::2] = 2**.5 + B[0] = 1 + B = B[:, None] + elif measure in ['fourier', 'fout']: freqs = np.arange(N//2) d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) @@ -134,7 +123,6 @@ def transition(measure, N, **measure_args): # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case A = A - B[:, None] * B[None, :] * 2 B = B[:, None] * 2 - elif measure == 'random': A = np.random.randn(N, N) / N B = np.random.randn(N, 1) @@ -164,15 +152,7 @@ def rank_correction(measure, N, rank=1, dtype=torch.float): elif measure == 'lagt': assert rank >= 1 P = .5**.5 * torch.ones(1, N, dtype=dtype) - elif measure == 'fourier_old': - P = torch.ones(N, dtype=dtype) # (N) - P0 = P.clone() - P0[0::2] = 0. - P1 = P.clone() - P1[1::2] = 0. - P = torch.stack([P0, P1], dim=0) # (2 N) - P = torch.zeros(1, N, dtype=dtype) - elif measure == 'fourier': + elif measure in ['fourier', 'fout']: P = torch.zeros(N) P[0::2] = 2**.5 P[0] = 1 @@ -188,7 +168,7 @@ def rank_correction(measure, N, rank=1, dtype=torch.float): P[0::2] = 2**.5 P[0] = 1 P = 2**.5 * P.unsqueeze(0) - elif measure in ['fourier_diag', 'legsd']: + elif measure in ['fourier_diag', 'foud', 'legsd']: P = torch.zeros(1, N, dtype=dtype) else: raise NotImplementedError @@ -202,8 +182,6 @@ def initial_C(measure, N, dtype=torch.float): if measure == 'legt': C = (torch.arange(N, dtype=dtype)*2+1)**.5 * (-1)**torch.arange(N) - elif measure == 'fourier_old': - C = torch.ones(N, dtype=dtype) # (N) elif measure == 'fourier': C = torch.zeros(N) C[0::2] = 2**.5 @@ -214,12 +192,13 @@ def initial_C(measure, N, dtype=torch.float): return C -def nplr(measure, N, rank=1, dtype=torch.float): +def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True): """ Return w, p, q, V, B such that (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V i.e. A = V[w - p q^*]V^*, B = V B """ - assert dtype == torch.float or torch.cfloat + assert dtype == torch.float or torch.double + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble A, B = transition(measure, N) A = torch.as_tensor(A, dtype=dtype) # (N, N) @@ -227,11 +206,6 @@ def nplr(measure, N, rank=1, dtype=torch.float): P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3) - w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) - # V w V^{-1} = A - - # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) - # We require AP to be nearly skew-symmetric _A = AP + AP.transpose(-1, -2) @@ -239,9 +213,21 @@ def nplr(measure, N, rank=1, dtype=torch.float): print("WARNING: HiPPO matrix not skew symmetric", err) + # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately + # Imaginary part can use eigh instead of eig + w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) + + # Diagonalize in double precision + if diagonalize_precision: AP = AP.to(torch.double) + # w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) + w_im, V = torch.linalg.eigh(AP*-1j) # (..., N) (..., N, N) + if diagonalize_precision: w_im, V = w_im.to(cdtype), V.to(cdtype) + w = w_re + 1j * w_im + # Check: V w V^{-1} = A + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + # Only keep half of each conjugate pair - # w = w[..., 0::2].contiguous() - # V = V[..., 0::2].contiguous() _, idx = torch.sort(w.imag) w_sorted = w[idx] V_sorted = V[:, idx] @@ -257,115 +243,16 @@ def nplr(measure, N, rank=1, dtype=torch.float): V[1, -1] = 2**-0.5 * 1j _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2) - # assert torch.allclose(2*_AP.real, AP, atol=1e-5) if ((err := torch.sum((2*_AP.real-AP)**2)/N) > 1e-5): print("Warning: Diagonalization of A matrix not numerically precise - error", err) # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) - - # # Override eigenvectors for 0 eigenvalues, to make them conjugate pairs - # breakpoint() - # rotate = torch.tensor([[1, 1], [1j, -1j]]) / 2**.5 - # # rotate = torch.tensor([[1, -1j], [1, 1j]]) / 2**.5 - # V_rot = (V.view(N, N//2, 2) @ rotate).view(N, N) # rotate every pair of eigenvectors - # V = torch.where(w.repeat(N, 1) == 0, V_rot, V) - - V_inv = V.conj().transpose(-1, -2) - C = initial_C(measure, N, dtype=dtype) + # C = initial_C(measure, N, dtype=dtype) B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B - C = contract('ij, j -> i', V_inv, C.to(V)) # V^* C + # C = contract('ij, j -> i', V_inv, C.to(V)) # V^* C P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P - - return w, P, B, C, V - -def random_dplr(N, rank=1, H=1, dtype=torch.float, real_scale=1.0, imag_scale=1.0, scaling='inverse', random_real=False, random_imag=False, normalize=True): - assert dtype == torch.float or torch.double - # batch_shape = (H, N//2) if H is not None else (N//2,) - dtype = torch.cfloat if dtype == torch.float else torch.cdouble - # w = -torch.exp(torch.randn(N//2)) + 1j*torch.randn(N//2) - # w = -torch.exp(torch.randn(N//2)) + 1j*2*torch.tensor(np.pi)*N*torch.rand(N//2) # try larger eigenvalue spread - - pi = torch.tensor(np.pi) - if random_real: - real_part = torch.rand(H, N//2) - else: - real_part = .5 * torch.ones(H, N//2) - if random_imag: - imag_part = N//2 * torch.rand(H, N//2) - else: - imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H) - - real_part = real_scale * real_part - if scaling == 'random': - imag_part = torch.randn(H, N//2) - elif scaling == 'linear': - imag_part = pi * imag_part - elif scaling == 'inverse': # Based on asymptotics of the default HiPPO matrix - # intercept = torch.log(N//2)/torch.log(2) * 2./3. - # log_imag_part = intercept + 2. * torch.atanh((1+imag_part*2)/N*2-1) - # imag_part = torch.exp(log_imag_part) - # intercept = torch.log(N//2) - .5 - # imag_part = torch.exp(2. * torch.atanh((1+imag_part*2)/N*2-1)) - imag_part = 1/pi * N * (N/(1+2*imag_part)-1) - elif scaling == 'inverse2': # Based on asymptotics of the default HiPPO matrix - # intercept = torch.log(N//2)/torch.log(2) * 2./3. - # log_imag_part = intercept + 2. * torch.atanh((1+imag_part*2)/N*2-1) - # imag_part = torch.exp(log_imag_part) - # intercept = torch.log(N//2) - .5 - # imag_part = torch.exp(2. * torch.atanh((1+imag_part*2)/N*2-1)) - imag_part = 1/pi * N * (N/(1+imag_part)-1) - elif scaling == 'quadratic': - imag_part = 1/pi * (1+2*imag_part)**2 - else: raise NotImplementedError - imag_part = imag_scale * imag_part - w = -real_part + 1j * imag_part - - - - - # w = -torch.rand(N//2) + 1j*2*torch.tensor(np.pi)*N*torch.rand(N//2) # try larger eigenvalue spread - # w = -1 + torch.arange(N//2) * 1j * 2 * torch.tensor(np.pi) - P = torch.randn(rank, H, N//2, dtype=dtype) - # p = torch.zeros(rank, N//2, dtype=dtype) - B = torch.randn(H, N//2, dtype=dtype) - # B = torch.ones(N//2, dtype=dtype) - C = torch.randn(H, N//2, dtype=dtype) - V = torch.eye(N, dtype=dtype)[..., :N//2] # Only used in testing - - if normalize: # TODO can normalize the full matrix with rank correction too - norm = -B/w # (H, N) # Result if you integrate the kernel with constant 1 function - zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector - B = B / zeta**.5 - - return w, P, B, C, V - -def test_nplr(): - N = 4 - measure = 'fourier_decay' - w, P, B, C, V = nplr(measure, N, rank=1) - w = torch.cat([w, w.conj()], dim=-1) - V = torch.cat([V, V.conj()], dim=-1) - B = torch.cat([B, B.conj()], dim=-1) - P = torch.cat([P, P.conj()], dim=-1) - Q = P - # q = torch.cat([q, q.conj()], dim=-1) - A = torch.diag_embed(w) - contract('... r p, ... r q -> ... p q', P, Q.conj()) - - A = contract('ij, jk, kl -> ... il', V, A, V.conj().transpose(-1,-2)) # Ap^{-1} = V @ w^{-1} @ V^T - B = contract('ij, ... j -> ... i', V, B) - print(A.real) - print(B.real) - -if __name__ == '__main__': - from benchmark import utils - - torch.set_printoptions(precision=3) - - device = 'cuda' # 'cpu' - device = torch.device(device) - - # benchmark_krylov(measure='legs', rank=1) - test_nplr() + # return w, P, B, C, V + return w, P, B, V diff --git a/src/models/hippo/transition.py b/src/models/hippo/transition.py index 7bbbb58..a302e94 100644 --- a/src/models/hippo/transition.py +++ b/src/models/hippo/transition.py @@ -1,6 +1,6 @@ """ Utilities to calculate the transitions of the HiPPO ODE x' = Ax + Bu and discrete-time recurrence approximation. -Note that this logic was heavily used in the LSSL, but is no longed needed for S3. +Note that these modules were heavily used in LSSL, but is no longed needed for S4. """ import torch @@ -578,6 +578,27 @@ def __init__(self, N, diag_scale=2, diag_add=True, **kwargs): # print(self.A) +class JacTriDInverseAdaptiveTransition(TriDInverseAdaptiveTransition): + def __init__(self, N, halve=False, double_B=True, **kwargs): + # print(diag_scale, kwargs) + p = torch.sqrt(2*torch.arange(N)+2) + dl = _diag(N, -1.) + du = _diag(N, 0.) + d = torch.ones(N) + if halve: + c = - .5 * torch.ones(N) + else: + c = 0.0 * torch.ones(N) + + if double_B: + B = 2 * torch.ones(N) + else: + B = torch.ones(N) + + super().__init__(N, dl, d, du, p, p, c, B, **kwargs) + # print(self.A) + + class ChebITriDInverseAdaptiveTransition(TriDInverseAdaptiveTransition): def __init__(self, N, **kwargs): # p = torch.sqrt(1+2*torch.arange(N)) diff --git a/src/models/hippo/visualizations.py b/src/models/hippo/visualizations.py new file mode 100644 index 0000000..b2208e8 --- /dev/null +++ b/src/models/hippo/visualizations.py @@ -0,0 +1,501 @@ +""" Standalone implementation of HiPPO operators. + +Contains experiments for the function reconstruction experiment in original HiPPO paper, as well as new animations from "How to Train Your HiPPO" + +This file ports the notebook notebooks/hippo_function_approximation.ipynb, which is recommended if Jupyter is supported +""" + +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data as data +import numpy as np +from scipy import signal +from scipy import linalg as la +from scipy import special as ss +from einops import rearrange, repeat, reduce + +import src.models.functional.unroll as unroll # Not necessary, can comment out and set fast=False in HiPPO modules + +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation + +import seaborn as sns +sns.set(rc={ + "figure.dpi":300, + 'savefig.dpi':300, + 'animation.html':'jshtml', + 'animation.embed_limit':100, # Max animation size in Mb +}) +# sns.set_context('notebook') +sns.set_style('ticks') # or 'whitegrid' + +device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +# HiPPO matrices +def transition(measure, N, **measure_args): + # Laguerre (translated) + if measure == 'lagt': + b = measure_args.get('beta', 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + # Legendre (translated) + elif measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + # Legendre (scaled) + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + elif measure == 'fourier': + freqs = np.arange(N//2) + d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] + A = 2*np.pi*(-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2 + B[0] = 2**.5 + A = A - B[:, None] * B[None, :] + # A = A - np.eye(N) + B *= 2**.5 + B = B[:, None] + + return A, B + +def measure(method, c=0.0): + if method == 'legt': + fn = lambda x: np.heaviside(x, 0.0) * np.heaviside(1.0-x, 0.0) + elif method == 'legs': + fn = lambda x: np.heaviside(x, 1.0) * np.exp(-x) + elif method == 'lagt': + fn = lambda x: np.heaviside(x, 1.0) * np.exp(-x) + elif method in ['fourier']: + fn = lambda x: np.heaviside(x, 1.0) * np.heaviside(1.0-x, 1.0) + else: raise NotImplementedError + fn_tilted = lambda x: np.exp(c*x) * fn(x) + return fn_tilted + +def basis(method, N, vals, c=0.0, truncate_measure=True): + """ + vals: list of times (forward in time) + returns: shape (T, N) where T is length of vals + """ + if method == 'legt': + eval_matrix = ss.eval_legendre(np.arange(N)[:, None], 2*vals-1).T + eval_matrix *= (2*np.arange(N)+1)**.5 * (-1)**np.arange(N) + elif method == 'legs': + _vals = np.exp(-vals) + eval_matrix = ss.eval_legendre(np.arange(N)[:, None], 1-2*_vals).T # (L, N) + eval_matrix *= (2*np.arange(N)+1)**.5 * (-1)**np.arange(N) + elif method == 'lagt': + vals = vals[::-1] + eval_matrix = ss.eval_genlaguerre(np.arange(N)[:, None], 0, vals) + eval_matrix = eval_matrix * np.exp(-vals / 2) + eval_matrix = eval_matrix.T + elif method == 'fourier': + cos = 2**.5 * np.cos(2*np.pi*np.arange(N//2)[:, None]*(vals)) # (N/2, T/dt) + sin = 2**.5 * np.sin(2*np.pi*np.arange(N//2)[:, None]*(vals)) # (N/2, T/dt) + cos[0] /= 2**.5 + eval_matrix = np.stack([cos.T, sin.T], axis=-1).reshape(-1, N) # (T/dt, N) +# print("eval_matrix shape", eval_matrix.shape) + + if truncate_measure: + eval_matrix[measure(method)(vals) == 0.0] = 0.0 + + p = torch.tensor(eval_matrix) + p *= np.exp(-c*vals)[:, None] # [::-1, None] + return p + + +class HiPPOScale(nn.Module): + """ Vanilla HiPPO-LegS model (scale invariant instead of time invariant) """ + def __init__(self, N, method='legs', max_length=1024, discretization='bilinear'): + """ + max_length: maximum sequence length + """ + super().__init__() + self.N = N + A, B = transition(method, N) + B = B.squeeze(-1) + A_stacked = np.empty((max_length, N, N), dtype=A.dtype) + B_stacked = np.empty((max_length, N), dtype=B.dtype) + for t in range(1, max_length + 1): + At = A / t + Bt = B / t + if discretization == 'forward': + A_stacked[t - 1] = np.eye(N) + At + B_stacked[t - 1] = Bt + elif discretization == 'backward': + A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, np.eye(N), lower=True) + B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True) + elif discretization == 'bilinear': + A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True) + B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, Bt, lower=True) + else: # ZOH + A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t))) + B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True) + self.register_buffer('A_stacked', torch.Tensor(A_stacked)) # (max_length, N, N) + self.register_buffer('B_stacked', torch.Tensor(B_stacked)) # (max_length, N) + + vals = np.linspace(0.0, 1.0, max_length) + self.eval_matrix = torch.Tensor((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T ) + + def forward(self, inputs, fast=True): + """ + inputs : (length, ...) + output : (length, ..., N) where N is the order of the HiPPO projection + """ + + L = inputs.shape[0] + + inputs = inputs.unsqueeze(-1) + u = torch.transpose(inputs, 0, -2) + u = u * self.B_stacked[:L] + u = torch.transpose(u, 0, -2) # (length, ..., N) + + if fast: + result = unroll.variable_unroll_matrix(self.A_stacked[:L], u) + return result + + c = torch.zeros(u.shape[1:]).to(inputs) + cs = [] + for t, f in enumerate(inputs): + c = F.linear(c, self.A_stacked[t]) + self.B_stacked[t] * f + cs.append(c) + return torch.stack(cs, dim=0) + + def reconstruct(self, c): + a = self.eval_matrix.to(c) @ c.unsqueeze(-1) + return a + +class HiPPO(nn.Module): + """ Linear time invariant x' = Ax + Bu """ + def __init__(self, N, method='legt', dt=1.0, T=1.0, discretization='bilinear', scale=False, c=0.0): + """ + N: the order of the HiPPO projection + dt: discretization step size - should be roughly inverse to the length of the sequence + """ + super().__init__() + self.method = method + self.N = N + self.dt = dt + self.T = T + self.c = c + + A, B = transition(method, N) + A = A + np.eye(N)*c + self.A = A + self.B = B.squeeze(-1) + self.measure_fn = measure(method) + + C = np.ones((1, N)) + D = np.zeros((1,)) + dA, dB, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization) + + dB = dB.squeeze(-1) + + self.register_buffer('dA', torch.Tensor(dA)) # (N, N) + self.register_buffer('dB', torch.Tensor(dB)) # (N,) + + self.vals = np.arange(0.0, T, dt) + self.eval_matrix = basis(self.method, self.N, self.vals, c=self.c) # (T/dt, N) + self.measure = measure(self.method)(self.vals) + + + def forward(self, inputs, fast=True): + """ + inputs : (length, ...) + output : (length, ..., N) where N is the order of the HiPPO projection + """ + + inputs = inputs.unsqueeze(-1) + u = inputs * self.dB # (length, ..., N) + + if fast: + dA = repeat(self.dA, 'm n -> l m n', l=u.size(0)) + return unroll.variable_unroll_matrix(dA, u) + + c = torch.zeros(u.shape[1:]).to(inputs) + cs = [] + for f in inputs: + c = F.linear(c, self.dA) + self.dB * f + cs.append(c) + return torch.stack(cs, dim=0) + + + + def reconstruct(self, c, evals=None): # TODO take in a times array for reconstruction + """ + c: (..., N,) HiPPO coefficients (same as x(t) in S4 notation) + output: (..., L,) + """ + if evals is not None: + eval_matrix = basis(self.method, self.N, evals) + else: + eval_matrix = self.eval_matrix + + m = self.measure[self.measure != 0.0] + + c = c.unsqueeze(-1) + y = eval_matrix.to(c) @ c + return y.squeeze(-1).flip(-1) + + +### Synthetic data generation + +def whitesignal(period, dt, freq, rms=0.5, batch_shape=()): + """ + Produces output signal of length period / dt, band-limited to frequency freq + Output shape (*batch_shape, period/dt) + Adapted from the nengo library + """ + + if freq is not None and freq < 1. / period: + raise ValueError(f"Make ``{freq=} >= 1. / {period=}`` to produce a non-zero signal",) + + nyquist_cutoff = 0.5 / dt + if freq > nyquist_cutoff: + raise ValueError(f"{freq} must not exceed the Nyquist frequency for the given dt ({nyquist_cutoff:0.3f})") + + n_coefficients = int(np.ceil(period / dt / 2.)) + shape = batch_shape + (n_coefficients + 1,) + sigma = rms * np.sqrt(0.5) + coefficients = 1j * np.random.normal(0., sigma, size=shape) + coefficients[..., -1] = 0. + coefficients += np.random.normal(0., sigma, size=shape) + coefficients[..., 0] = 0. + + set_to_zero = np.fft.rfftfreq(2 * n_coefficients, d=dt) > freq + coefficients *= (1-set_to_zero) + power_correction = np.sqrt(1. - np.sum(set_to_zero, dtype=float) / n_coefficients) + if power_correction > 0.: coefficients /= power_correction + coefficients *= np.sqrt(2 * n_coefficients) + signal = np.fft.irfft(coefficients, axis=-1) + signal = signal - signal[..., :1] # Start from 0 + return signal + + +def plot(T, dt, N, freq): + np.random.seed(0) + vals = np.arange(0.0, T, dt) + + u = whitesignal(T, dt, freq=freq) + u = torch.tensor(u, dtype=torch.float) + u = u.to(device) + + plt.figure(figsize=(16, 8)) + offset = 0.0 + plt.plot(vals, u.cpu()+offset, 'k', linewidth=1.0) + + # Linear Time Invariant (LTI) methods x' = Ax + Bu + lti_methods = [ + 'legs', + 'legt', + 'fourier', + ] + + for method in lti_methods: + hippo = HiPPO(method=method, N=N, dt=dt, T=T).to(device) + u_hippo = hippo.reconstruct(hippo(u))[-1].cpu() + plt.plot(vals[-len(u_hippo):], u_hippo, label=method) + + # Original HiPPO-LegS, which uses time-varying SSM x' = 1/t [ Ax + Bu] + # we call this "linear scale invariant" + lsi_methods = ['legs'] + for method in lsi_methods: + hippo = HiPPOScale(N=N, method=method, max_length=int(T/dt)).to(device) + u_hippo = hippo.reconstruct(hippo(u))[-1].cpu() + plt.plot(vals[-len(u_hippo):], u_hippo, label=method+' (scaled)') + + + # plt.xlabel('Time (normalized)', labelpad=-10) + plt.legend() + plt.savefig(f'function_approximation.pdf', bbox_inches='tight') + plt.show() + plt.close() + + +# Animation code from HTTYH + +def plt_lines(x, y, color, size, label=None): + return plt.plot(x, y, color, linewidth=size, label=label)[0] + +def update_lines(ln, x, y): + ln.set_data(x, y) + +def animate_hippo( + method, + T=5, dt=5e-4, N=64, freq=20.0, + interval=100, + plot_hippo=False, hippo_offset=0.0, label_hippo=False, + plot_measure=False, measure_offset=-3.0, label_measure=False, + plot_coeff=None, coeff_offset=3.0, + plot_s4=False, s4_offset=6.0, + plot_hippo_type='line', plot_measure_type='line', plot_coeff_type='line', + size=1.0, + plot_legend=True, plot_xticks=True, plot_box=True, + plot_vline=False, + animate_u=False, + seed=2, +): + np.random.seed(seed) + + vals = np.arange(0, int(T/dt)+1) + L = int(T/dt)+1 + + u = torch.FloatTensor(whitesignal(T, dt, freq=freq)) + u = F.pad(u, (1, 0)) + u = u + torch.FloatTensor(np.sin(1.5*np.pi/T*np.arange(0, T+dt, dt))) # add 3/4 of a sin cycle + u = u.to(device) + + hippo = HiPPO(method=method, N=N, dt=dt, T=T).to(device) + coef_hippo = hippo(u).cpu().numpy() + h_hippo = hippo.reconstruct(hippo(u)).cpu().numpy() + u = u.cpu().numpy() + + fig, ax = plt.subplots(figsize=(12, 4)) + + if animate_u: + ln_u = plt_lines([], [], 'k', size, label='Input $u(t)$') + else: + plt_lines(vals, u, 'k', size, label='Input $u(t)$') + + if plot_hippo: + label_args = {'label': 'HiPPO reconstruction'} if label_hippo else {} + ln = plt_lines([], [], size=size, color='red', **label_args) + + if plot_measure: + label_args = {'label': 'HiPPO Measure'} if label_measure else {} + ln_measure = plt_lines(vals, np.zeros(len(vals))+measure_offset, size=size, color='green', **label_args) + + if plot_coeff is None: plot_coeff = [] + if isinstance(plot_coeff, int): plot_coeff = [plot_coeff] + if len(plot_coeff) > 0: + ln_coeffs = [ + plt_lines([], [], size=size, color='blue') + for _ in plot_coeff + ] + plt_lines([], [], size=size, color='blue', label='State $x(t)$') # For the legend + + + ### Y AXIS LIMITS + if plot_measure: + min_y = measure_offset + else: + min_y = np.min(u) + + if len(plot_coeff) > 0: + max_u = np.max(u) + coeff_offset + else: + max_u = np.max(u) + + + C = np.random.random(N) + s4 = np.sum(coef_hippo * C, axis=-1) + max_s4 = 0.0 + if plot_s4: + ln_s4 = plt_lines([], [], size=size, color='red', label='Output $y(t)$') + max_s4 = np.max(s4)+s4_offset + + if plot_vline: + ln_vline = ax.axvline(0, ls='-', color='k', lw=1) + + if plot_legend: + plt.legend(loc='upper left', fontsize='x-small') + + + def init(): + left_endpoint = vals[0] + ax.set_xlim(left_endpoint, vals[-1]+1) + ax.set_ylim(min_y, max(max_u, max_s4)) + ax.set_yticks([]) + if not plot_xticks: ax.set_xticks([]) + if not plot_box: plt.box(False) + return [] # ln, + + def update(frame): + if animate_u: + xdata = np.arange(frame) + ydata = u[:frame] + update_lines(ln_u, xdata, ydata) + + m = np.zeros(len(vals)) + m[:frame] = hippo.measure_fn(np.arange(frame)*dt)[::-1] + xdata = vals + if plot_measure: + update_lines(ln_measure, xdata, m+measure_offset) + + if plot_hippo: + ydata = h_hippo[frame] + hippo_offset + m2 = hippo.measure_fn(np.arange(len(ydata))*dt)[::-1] + # Remove reconstruction where measure is 0 + ydata[m2 == 0.0] = np.nan + xdata = np.arange(frame-len(ydata), frame) + update_lines(ln, xdata, ydata) + + if len(plot_coeff) > 0: + for coeff, ln_coeff in zip(plot_coeff, ln_coeffs): + update_lines(ln_coeff, np.arange(frame), coef_hippo[:frame, coeff] + coeff_offset) + if plot_s4: # Only scale case; scale case should copy plot_hippo logic + update_lines(ln_s4, np.arange(0, frame), s4[:frame] + s4_offset) + + if plot_vline: + ln_vline.set_xdata([frame, frame]) + + return [] + + ani = FuncAnimation(fig, update, + frames=np.arange(0, int(T*1000/interval)+1)*int(interval/1000/dt), + interval=interval, + init_func=init, blit=True) + + return ani + + +if __name__ == '__main__': + plot(T=3, dt=1e-3, N=64, freq=3.0) + + # Visualize HiPPO online reconstruction + + ani = animate_hippo( + 'legs', # Try 'legt' or 'fourier' + T=5, dt=5e-4, N=64, interval=100, + # T=1, dt=1e-3, N=64, interval=200, # Faster rendering for testing + size=1.0, + + animate_u=True, + plot_hippo=True, hippo_offset=0.0, label_hippo=True, + plot_s4=False, s4_offset=6.0, + plot_measure=True, measure_offset=-3.0, label_measure=True, + plot_coeff=[], coeff_offset=3.0, + plot_legend=True, plot_xticks=True, plot_box=True, + plot_vline=True, + ) + ani.save('hippo_legs.gif') + + # Visualize S4 + + ani = animate_hippo( + 'legs', # Try 'legt' or 'fourier' + T=5, dt=5e-4, N=64, interval=100, + size=1.0, + + animate_u=True, + plot_hippo=False, hippo_offset=0.0, label_hippo=True, + plot_s4=True, s4_offset=6.0, + plot_measure=False, measure_offset=-3.0, label_measure=True, + plot_coeff=[0,1,2,3], coeff_offset=3.0, + plot_legend=True, plot_xticks=True, plot_box=True, + plot_vline=True, + ) + ani.save('s4_legs.gif') diff --git a/src/models/nn/__init__.py b/src/models/nn/__init__.py index 67f10de..aee8113 100644 --- a/src/models/nn/__init__.py +++ b/src/models/nn/__init__.py @@ -1 +1 @@ -from .components import LinearActivation, Activation, Normalization +from .components import LinearActivation, Activation, Normalization, DropoutNd diff --git a/src/models/nn/adaptive_softmax.py b/src/models/nn/adaptive_softmax.py index b900fe0..4ac9e2f 100644 --- a/src/models/nn/adaptive_softmax.py +++ b/src/models/nn/adaptive_softmax.py @@ -55,7 +55,7 @@ def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, self.n_clusters = len(self.cutoffs) - 1 self.head_size = self.shortlist_size + self.n_clusters - # [21-09-15 AG]: bake the first False into the definition, just as [0] is built into the cutoffs + # bake the first False into the definition, just as [0] is built into the cutoffs if tie_projs is None: tie_projs = [] elif isinstance(tie_projs, bool): tie_projs = [tie_projs] * len(cutoffs) else: tie_projs = list(tie_projs) @@ -89,7 +89,6 @@ def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, nn.Parameter(torch.zeros(d_proj, d_embed)) ) else: - # self.out_projs = [None] * len(self.cutoffs) self.out_projs.append(None) self.out_layers_biases.append( @@ -222,6 +221,7 @@ def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) + # First term accounts for cluster probabilities logprob_i = head_logprob_i[:, -i] \ + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) @@ -230,10 +230,72 @@ def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args else: nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) - offset += logprob_i.size(0) + offset += logprob_i.size(0) # TODO This should be a bug in the original implementation; it should go into the continue case above as well return nll.mean() # TODO maybe cases for length or padding_mask + def compute_logits(self, hidden): + """Compute full vector of logits + + Adapted from https://github.com/kimiyoung/transformer-xl/issues/88 + """ + hidden = hidden.reshape(-1, hidden.size(-1)) + + if self.n_clusters == 0: + logits = self._compute_logit(hidden, self.out_layers_weights[0], + self.out_layers_biases[0], self.get_out_proj(0)) + return logits + else: + # construct weights and biases + weights, biases = [], [] + for i in range(len(self.cutoffs)): + if self.div_val == 1: + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + weight_i = self.out_layers_weights[0][l_idx:r_idx] + bias_i = self.out_layers_biases[0][l_idx:r_idx] + else: + weight_i = self.out_layers_weights[i] + bias_i = self.out_layers_biases[i] + + if i == 0: + weight_i = torch.cat( + [weight_i, self.cluster_weight], dim=0) + bias_i = torch.cat( + [bias_i, self.cluster_bias], dim=0) + + weights.append(weight_i) + biases.append(bias_i) + + head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0) + + head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) + head_logprob = F.log_softmax(head_logit, dim=1) + + out_full_logps = [head_logprob[:, :self.cutoffs[0]]] + offset = 0 + cutoff_values = [0] + self.cutoffs + + for i in range(1, len(cutoff_values) - 1): + l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] + head_logprob_i = head_logprob # .index_select(0, indices_i) + + if i == 0: + logprob_i = head_logprob_i + else: + weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i) + + hidden_i = hidden # .index_select(0, indices_i) + + tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) + tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) + logprob_i = head_logprob_i[:, -i].view(-1, 1) + tail_logprob_i + + offset += logprob_i.size(0) + out_full_logps.append(logprob_i) + out_full_logps = torch.cat(out_full_logps, dim = 1) + # print(torch.sum(out_full_ps), out_full_ps.shape) + return out_full_logps + class AdaptiveEmbedding(nn.Module): """ Copy of transformers.AdaptiveEmbedding that works with fp16 by replacing the index_put_ operation @@ -277,7 +339,7 @@ def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_val=1, ini # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale) _init_proj(self.emb_projs[-1], d_proj, init_scale) - def forward(self, inp, *args, **kwargs): + def forward(self, inp): if self.div_val == 1: embed = self.emb_layers[0](inp) embed = self.drop(embed) @@ -285,9 +347,9 @@ def forward(self, inp, *args, **kwargs): embed = F.linear(embed, self.emb_projs[0]) else: param = next(self.parameters()) - inp_flat = inp.view(-1) + inp_flat = inp.reshape(-1) - # Changes + # Changes from original impl # emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device) embeddings = [] indices = torch.zeros_like(inp_flat) # empty should work as long as cutoffs[-1] > max token @@ -340,7 +402,3 @@ def _init_weight(weight, d : int, init_scale : Optional[float], default=None): _init_embed = functools.partial(_init_weight, default=0.02) _init_proj = functools.partial(_init_weight, default=0.01) - -### Just for this codebase, we need to squeeze the last dimension because inputs are always given as (B, L, D) instead of (B, L) -import src.models.nn.utils as U -# AdaptiveEmbedding = U.Squeeze(AdaptiveEmbedding) diff --git a/src/models/nn/components.py b/src/models/nn/components.py index 21d37a1..df91e17 100644 --- a/src/models/nn/components.py +++ b/src/models/nn/components.py @@ -2,13 +2,98 @@ from functools import partial import math +from typing import ForwardRef import torch import torch.nn as nn +import torch.nn.functional as F from einops import rearrange from opt_einsum import contract from src.models.nn.exprnn.orthogonal import modrelu + +def stochastic_depth(input: torch.tensor, p: float, mode: str, training: bool = True): + """ + Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth" + `_ used for randomly dropping residual + branches of residual architectures. + + Args: + input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one + being its batch i.e. a batch with ``N`` rows. + p (float): probability of the input to be zeroed. + mode (str): ``"batch"`` or ``"row"``. + ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes + randomly selected rows from the batch. + training: apply stochastic depth if is ``True``. Default: ``True`` + + Returns: + Tensor[N, ...]: The randomly zeroed tensor. + """ + if p < 0.0 or p > 1.0: + raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p)) + if mode not in ["batch", "row"]: + raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) + if not training or p == 0.0: + return input + + survival_rate = 1.0 - p + if mode == "row": + size = [input.shape[0]] + [1] * (input.ndim - 1) + else: + size = [1] * input.ndim + noise = torch.empty(size, dtype=input.dtype, device=input.device) + noise = noise.bernoulli_(survival_rate).div_(survival_rate) + return input * noise + +class StochasticDepth(nn.Module): + """ + See :func:`stochastic_depth`. + """ + def __init__(self, p: float, mode: str) -> None: + # TODO(karan): need to upgrade to torchvision==0.11.0 to use StochasticDepth directly + # from torchvision.ops import StochasticDepth + super().__init__() + self.p = p + self.mode = mode + + def forward(self, input): + return stochastic_depth(input, self.p, self.mode, self.training) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'p=' + str(self.p) + tmpstr += ', mode=' + str(self.mode) + tmpstr += ')' + return tmpstr + +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True, transposed=True): + """ + tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) + self.p = p + self.tie = tie + self.transposed = transposed + self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + + def forward(self, X): + """ X: (batch, dim, lengths...) """ + if self.training: + if not self.transposed: X = rearrange(X, 'b d ... -> b ... d') + # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow + mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape + # mask = self.binomial.sample(mask_shape) + mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p + X = X * mask * (1.0/(1-self.p)) + if not self.transposed: X = rearrange(X, 'b ... d -> b d ...') + return X + return X + + def Activation(activation=None, size=None, dim=-1): if activation in [ None, 'id', 'identity', 'linear' ]: return nn.Identity() @@ -26,6 +111,10 @@ def Activation(activation=None, size=None, dim=-1): return nn.Sigmoid() elif activation == 'modrelu': return Modrelu(size) + elif activation == 'sqrelu': + return SquaredReLU() + elif activation == 'ln': + return TransposedLN(dim) else: raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) @@ -71,6 +160,7 @@ def LinearActivation( """ Returns a linear nn.Module with control over axes order, initialization, and activation """ # Construct core module + # linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear linear_cls = TransposedLinear if transposed else nn.Linear if activation == 'glu': d_output *= 2 linear = linear_cls(d_input, d_output, bias=bias, **kwargs) @@ -88,13 +178,19 @@ def LinearActivation( linear = nn.utils.weight_norm(linear) if activate and activation is not None: - activation = Activation(activation, d_output, dim=-2 if transposed else -1) + activation = Activation(activation, d_output, dim=1 if transposed else -1) linear = nn.Sequential(linear, activation) return linear +class SquaredReLU(nn.Module): + def forward(self, x): + return F.relu(x)**2 + class TransposedLinear(nn.Module): - """ Linear module on the second-to-last dimension """ + """ Linear module on the second-to-last dimension + Assumes shape (B, D, L), where L can be 1 or more axis + """ def __init__(self, d_input, d_output, bias=True): super().__init__() @@ -104,18 +200,22 @@ def __init__(self, d_input, d_output, bias=True): # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent if bias: - self.bias = nn.Parameter(torch.empty(d_output, 1)) + self.bias = nn.Parameter(torch.empty(d_output)) bound = 1 / math.sqrt(d_input) nn.init.uniform_(self.bias, -bound, bound) + setattr(self.bias, "_optim", {"weight_decay": 0.0}) else: self.bias = 0.0 def forward(self, x): - return contract('... u l, v u -> ... v l', x, self.weight) + self.bias + num_axis = len(x.shape[2:]) # num_axis in L, for broadcasting bias + y = contract('b u ..., v u -> b v ...', x, self.weight) + self.bias.view(-1, *[1]*num_axis) + return y class TransposedLN(nn.Module): - """ LayerNorm module over second-to-last dimension + """ LayerNorm module over second dimension + Assumes shape (B, D, L), where L can be 1 or more axis This is slow and a dedicated CUDA/Triton implementation shuld provide substantial end-to-end speedup """ @@ -125,15 +225,20 @@ def __init__(self, d, scalar=True): if self.scalar: self.m = nn.Parameter(torch.zeros(1)) self.s = nn.Parameter(torch.ones(1)) + setattr(self.m, "_optim", {"weight_decay": 0.0}) + setattr(self.s, "_optim", {"weight_decay": 0.0}) else: self.ln = nn.LayerNorm(d) def forward(self, x): if self.scalar: - s, m = torch.std_mean(x, dim=-2, unbiased=False, keepdim=True) + # calc. stats over D dim / channels + s, m = torch.std_mean(x, dim=1, unbiased=False, keepdim=True) y = (self.s/s) * (x-m+self.m) else: - y = self.ln(x.transpose(-1,-2)).transpose(-1,-2) + # move channel to last axis, apply layer_norm, then move channel back to second axis + _x = self.ln(rearrange(x, 'b d ... -> b ... d')) + y = rearrange(_x, 'b ... d -> b d ...') return y class Normalization(nn.Module): @@ -146,6 +251,7 @@ def __init__( ): super().__init__() self.transposed = transposed + self._name_ = _name_ if _name_ == 'layer': self.channel = True # Normalize over channel dimension @@ -163,18 +269,111 @@ def __init__( norm_args = {'affine': True, 'track_running_stats': True} norm_args.update(kwargs) self.norm = nn.BatchNorm1d(d, **norm_args) + elif _name_ == 'group': + self.channel = False + self.norm = nn.GroupNorm(1, d, *kwargs) elif _name_ == 'none': self.channel = True self.norm = nn.Identity() else: raise NotImplementedError def forward(self, x): + # Handle higher dimension logic + shape = x.shape + if self.transposed: + x = rearrange(x, 'b d ... -> b d (...)') + else: + x = rearrange(x, 'b ... d -> b (...)d ') + # The cases of LayerNorm / no normalization are automatically handled in all cases # Instance/Batch Norm work automatically with transposed axes if self.channel or self.transposed: - return self.norm(x) + x = self.norm(x) else: x = x.transpose(-1, -2) x = self.norm(x) x = x.transpose(-1, -2) - return x + + x = x.view(shape) + return x + + def step(self, x, **kwargs): + assert self._name_ in ["layer", "none"] + if self.transposed: x = x.unsqueeze(-1) + x = self.forward(x) + if self.transposed: x = x.squeeze(-1) + return x + +class TSNormalization(nn.Module): + + def __init__(self, method, horizon): + super().__init__() + + self.method = method + self.horizon = horizon + + + def forward(self, x): + # x must be BLD + if self.method == 'mean': + self.scale = x.abs()[:, :-self.horizon].mean(dim=1)[:, None, :] + return x / self.scale + elif self.method == 'last': + self.scale = x.abs()[:, -self.horizon-1][:, None, :] + return x / self.scale + return x + +class TSInverseNormalization(nn.Module): + + def __init__(self, method, normalizer): + super().__init__() + + self.method = method + self.normalizer = normalizer + + def forward(self, x): + if self.method == 'mean' or self.method == 'last': + return x * self.normalizer.scale + return x + +class ReversibleInstanceNorm1dInput(nn.Module): + def __init__(self, d, transposed=False): + super().__init__() + # BLD if transpoed is False, otherwise BDL + self.transposed = transposed + self.norm = nn.InstanceNorm1d(d, affine=True, track_running_stats=False) + + def forward(self, x): + # Means, stds + if not self.transposed: + x = x.transpose(-1, -2) + + self.s, self.m = torch.std_mean(x, dim=-1, unbiased=False, keepdim=True) + self.s += 1e-4 + + x = (x - self.m) / self.s + # x = self.norm.weight.unsqueeze(-1) * x + self.norm.bias.unsqueeze(-1) + + if not self.transposed: + return x.transpose(-1, -2) + return x + +class ReversibleInstanceNorm1dOutput(nn.Module): + + def __init__(self, norm_input): + super().__init__() + self.transposed = norm_input.transposed + self.weight = norm_input.norm.weight + self.bias = norm_input.norm.bias + self.norm_input = norm_input + + def forward(self, x): + if not self.transposed: + x = x.transpose(-1, -2) + + # x = (x - self.bias.unsqueeze(-1))/self.weight.unsqueeze(-1) + x = x * self.norm_input.s + self.norm_input.m + + if not self.transposed: + return x.transpose(-1, -2) + return x diff --git a/src/models/nn/dxt.py b/src/models/nn/dxt.py index 39de73b..a9813bc 100644 --- a/src/models/nn/dxt.py +++ b/src/models/nn/dxt.py @@ -1,4 +1,7 @@ -""" Implementations of several types of Discrete Sin/Cosine Transforms with various reductions to FFT. """ +"""Implementations of several types of Discrete Sin/Cosine Transforms with various reductions to FFT. + +Currently not used by S4 +""" import torch import torch.nn as nn @@ -20,7 +23,7 @@ def __init__(self, N, norm='backward'): self.register_buffer('P', P) # TODO take care of normalization - Q = np.exp(-1j * np.pi / (2 * self.N) * np.arange(self.N)) + Q = np.exp(-1j * np.pi / (2 * self.N) * np.arange(self.N)) Q = torch.tensor(Q, dtype=torch.cfloat) self.register_buffer('Q', Q) # half shift diff --git a/src/models/nn/gate.py b/src/models/nn/gate.py index 43e4392..d0a531f 100644 --- a/src/models/nn/gate.py +++ b/src/models/nn/gate.py @@ -120,11 +120,9 @@ def forward_diff(self, x): def backward_diff(self, x): return x / (1+x) - # return 1 / (1+1/x) def trapezoid(self, x): return x / (1 + x/2) - # return 1 / (.5 + 1/x) def zoh(self, x): return 1 - torch.exp(-x) diff --git a/src/models/nn/utils.py b/src/models/nn/utils.py index 13eae78..6fd5dda 100644 --- a/src/models/nn/utils.py +++ b/src/models/nn/utils.py @@ -1,53 +1,101 @@ -""" Utility wrappers around modules to let them handle Tuples and extra arguments """ +""" Utility wrappers around modules to let them handle Args and extra arguments """ -# import torch +import inspect +from functools import wraps +import torch from torch import nn - -def TupleModule(module): - """ Wrap a nn.Module class with two features: - - discard extra arguments in the forward pass - - return a tuple +def wrap_kwargs(f): """ - # TODO maybe possible with functools.wraps - class WrappedModule(module): - def forward(self, x, *args, **kwargs): - y = super().forward(x) - return y if isinstance(y, tuple) else (y,) - # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically - WrappedModule.__name__ = module.__name__ - return WrappedModule - -def Squeeze(module, dim=-1): - """ Wrap a nn.Module to squeeze a dimension. - Use for e.g. Embeddings, because our sequence API assumes a feature dimension while nn.Embedding does not + Given a callable f that can consume some named arguments, + wrap it with a kwargs that passes back any unused args + + EXAMPLES + -------- + + Basic usage: + def foo(x, y=None): + return x + + wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) + + -------- + + The wrapped function can return its own argument dictionary, + which gets merged with the new kwargs. + def foo(x, y=None): + return x, {} + wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) + + def foo(x, y=None): + return x, {"y": y, "z": None} + wrap_kwargs(foo)(0, y=1, z=2) == (0, {'y': 1, 'z': 2}) + + -------- + + The wrapped function can have its own kwargs parameter: + def foo(x, y=None, **kw_args): + return x, {} + wrap_kwargs(foo)(0, y=1, z=2) == (0, {}) + + -------- + + Partial functions and modules work automatically: + class Module: + def forward(self, x, y=0): + return x, {"y": y+1} + + m = Module() + + wrap_kwargs(m.forward)(0, y=1, z=2) == (0, {'y': 2, 'z': 2}) + """ - class WrappedModule(module): - def forward(self, x, *args, **kwargs): - assert x.size(dim) == 1 - x = x.squeeze(dim) - y = super().forward(x) - return y - # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically - WrappedModule.__name__ = module.__name__ - return WrappedModule - -# TODO maybe call these TupleIdentity etc. instead? -Identity = TupleModule(nn.Identity) -Embedding = TupleModule(nn.Embedding) -# Embedding = TupleModule(Squeeze(nn.Embedding)) -Linear = TupleModule(nn.Linear) - -def TupleSequential(*modules): - """ Similar to TupleModule: - - Discard extra arguments in forward pass - - Return a Tuple + sig = inspect.signature(f) + # Check if f already has kwargs + has_kwargs = any([ + param.kind == inspect.Parameter.VAR_KEYWORD + for param in sig.parameters.values() + ]) + if has_kwargs: + @wraps(f) + def f_kwargs(*args, **kwargs): + y = f(*args, **kwargs) + if isinstance(y, tuple) and isinstance(y[-1], dict): + return y + else: + return y, {} + else: + param_kwargs = inspect.Parameter("kwargs", kind=inspect.Parameter.VAR_KEYWORD) + sig_kwargs = inspect.Signature(parameters=list(sig.parameters.values())+[param_kwargs]) + @wraps(f) + def f_kwargs(*args, **kwargs): + bound = sig_kwargs.bind(*args, **kwargs) + if "kwargs" in bound.arguments: + kwargs = bound.arguments.pop("kwargs") + else: + kwargs = {} + y = f(**bound.arguments) + if isinstance(y, tuple) and isinstance(y[-1], dict): + return *y[:-1], {**y[-1], **kwargs} + else: + return y, kwargs + return f_kwargs + +def discard_kwargs(f): + if f is None: return None + f_kwargs = wrap_kwargs(f) + @wraps(f) + def f_(*args, **kwargs): + return f_kwargs(*args, **kwargs)[0] + return f_ + +def PassthroughSequential(*modules): + """Special Sequential module that chains kwargs. Semantics are the same as nn.Sequential, with extra convenience features: - Discard None modules - Flatten inner Sequential modules - - Discard extra Identity modules - - If only one Module, extract it to top level + - In case with 0 or 1 Module, rename the class for ease of inspection """ def flatten(module): if isinstance(module, nn.Sequential): @@ -56,34 +104,21 @@ def flatten(module): return [module] modules = flatten(nn.Sequential(*modules)) - modules = [module for module in modules if module if not None and not isinstance(module, nn.Identity)] + modules = [module for module in modules if module if not None] class Sequential(nn.Sequential): - def forward(self, x, *args, **kwargs): - # layer_args = [] - x = x, + def forward(self, x, **kwargs): for layer in self: - x = layer(*(x + args), **kwargs) # Always a tuple - # args = tuple(layer_args) + args - return x # Returns a tuple + x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs) + return x, kwargs + + def step(self, x, **kwargs): + for layer in self: + x, kwargs = wrap_kwargs(layer.step)(x, **kwargs) + return x, kwargs if len(modules) == 0: - return Identity() + Sequential.__name__ = "Identity" elif len(modules) == 1: - return modules[0] - else: - return Sequential(*modules) - -def Transpose(module_cls): - class TransposedModule(module_cls): - def __init__(self, *args, transposed=False, **kwargs): - super().__init__(*args, **kwargs) - self.transposed = transposed - - def forward(self, x, *args, **kwargs): - if self.transposed: x = x.transpose(-1, -2) - y, *z = super().forward(x, *args, **kwargs) - if self.transposed: y = y.transpose(-1, -2) - return y, *z - TransposedModule.__name__ = module_cls.__name__ - return TransposedModule + Sequential.__name__ = type(modules[0]).__name__ + return Sequential(*modules) diff --git a/src/models/s4/README.md b/src/models/s4/README.md new file mode 100644 index 0000000..808693c --- /dev/null +++ b/src/models/s4/README.md @@ -0,0 +1,101 @@ +This folder contains several standalone implementations of S4 variants. +The file [s4.py](./s4.py) contains the full implementation of S4 with all available options, which subsumes several variants of S4. +Other standalone implementations are documented below. + +## Full S4(D) Model + + +`s4.py` is a standalone implementation of the full S4(D) model with all options, which are documented inside the class. + +The corresponding [config](/configs/model/layer/s4.yaml) also lists all available options. + +### S4 + +S4 is characterized by the arguments `mode=nplr` (the Normal Plus Low-Rank kernel described in the original S4 paper) and `measure=legs` (the HiPPO-LegS matrix), which are both set by default. +Alternative measures are supported, such as `measure=fout` which is the S4-FouT model described in [HTTYH](https://arxiv.org/abs/2206.12037). + + +### S4D + +S4D is activated by the argument `mode=diag` which uses the diagonal kernel. +Pass in `measure=diag-lin` or `measure=diag-inv` for S4D-Lin or S4D-Inv. +Other options described in the S4D paper include +- `disc={'bilinear','zoh'}`: Bilinear vs. ZOH discretization +- `lr.B={0.0,None}`: frozen vs. trainable $B$ parameter (requires custom optimizer to register the hook) +- `real_type={'exp','relu','none'}`: parameterization of real part of $A$ + +### Usage and Features + +#### Convolution Mode +The `forward` pass of the module maps a sequence of shape `(B, H, L) -> (B, H, L)` (batch size, hidden dimension, sequence length). The forward pass first constructs a convolution kernel using the algorithm described in the S4(D) papers, then convolves using the FFT. + +#### Recurrent Mode +The `step` method of the module maps `(B, H) -> (B, H)`. This represents a single step or "unroll" of the model like an RNN. + +#### Sample Rate Change +The `rate` argument in the forward pass multiplies the internal step size $\Delta$. +For example, a model trained on audio signals at 16000Hz using the default `rate=1.0` can be used to process audio signals at 8000Hz *without retraining* by passing in `rate=2.0`. + +#### State Forwarding +The forward pass of the model accepts an optional initial state of shape `(B, H, N)`. +The model will then compute "forward" the state through the sequence, returning the final state as well as the output. + +Note that this is equivalent to using `step` repeatedly, but is much faster by combining both recurrent and convolutional mode. + +**It is recommended to use S4D for this feature. The S4 implementation is currently not optimized.** + +### Other Variants + +#### DSS + +[DSS](https://arxiv.org/abs/2203.14343) is the first diagonal SSM variant. It has two main characteristics: +1. *Computation* - uses a "softmax" which combines ZOH discretization + normalization over sequence length +2. *Initialization* - uses an [approximation](https://arxiv.org/abs/2206.11893) to the HiPPO matrix (also called HiPPO-LegS) + +This model is equivalent to setting the options +``` +S4(mode='diag', disc='dss', measure='diag-legs') +``` +Performance should be similar to S4D, but it may consume more memory. + +#### GSS + +[GSS](https://arxiv.org/abs/2206.13947) is another variant specialized for language modeling on TPUs. +It has two main characteristics: +1. *Gating* - Incorporates an additional multiplicative feedforward branch. Additionally, it bottlenecks the dimension of the input to the SSM. These changes are largely motivated by efficiently on TPUs, which is better suited for large feedforward matmuls rather than the FFT convolutions used by the SSM. +2. *Simplified kernel* - Matrix $A$ is randomly initialized, matrix $B=1$ and step size $\Delta=1.0$ are frozen. + +These modifications can all be flexibly toggled. The full GSS layer is roughly equivalent to the following options. +``` +S4( + gate=4, # Multiplicative gating layer that also expands dimension by factor of 4 + bottleneck=4, # Reduce dimension of SSM by factor of 4 + measure='diag-rand', # Randomly initialize A + dt_min=1.0, dt_max=1.0, # Initialize dt to 1.0 + lr={'dt': 0.0, 'B': 0.0}, # Freeze B and dt +) +``` + + +## Minimal S4D + +`s4d.py` contains a minimal implementation of the S4D layer. This file is primarily for pedagogical purposes to illustrate the simplicity of the core SSM principles behind S4. + + +This S4D layer is equivalent to using the full S4 layer with specific settings, and stripping out all extra features: + +``` +S4(mode='diag', measure='diag-lin', bidirectional=False, disc='zoh', real_type='exp') +``` + +The `example.py` script incorporates this into a simple deep neural network backbone to achieve 88% on sequential CIFAR with a model of 200K parameters. + +## Simple S4 + +TODO: Merge branch and document + + +## LSSL +[lssl.py](./lssl.py) is an implementation of the [predecessor](https://arxiv.org/abs/2110.13985) of S4. diff --git a/src/models/s4/lssl.py b/src/models/s4/lssl.py new file mode 100644 index 0000000..f52bb84 --- /dev/null +++ b/src/models/s4/lssl.py @@ -0,0 +1,302 @@ +""" Standalone file implementing the Linear State Space Layer (LSSL). + +The main Module is StateSpace, which implements a sequence-to-sequence described by the state space model. +Part of this model involves discretizing a continous-time system into a discrete transition, +implemented by the modules LegTTransition and LegTTransitionDense. +These perform equivalent computations; the latter is slower and simpler using Pytorch primitives, and is the default. +If speed is an issue for large models, try switching to the former module. +This requires compiling a Pytorch C extension by going into the folder `extensions/legt/` and running `python setup.py install`. +""" + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +try: + from extensions.legt.legt import legt_gbt_forward, legt_gbt_backward, legt_gbt_forward_t, legt_gbt_backward_t +except: + pass + +def triangular_toeplitz_multiply(u, v): + n = u.shape[-1] + u_expand = F.pad(u, (0, n)) + v_expand = F.pad(v, (0, n)) + u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1) + v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1) + uv_f = u_f * v_f + output = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n] + return output + +def krylov(L, A, b): + """ Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. """ + + x = b.unsqueeze(-1) # (..., N, 1) + A_ = A + + done = L == 1 + while not done: + # Save memory on last iteration + l = x.shape[-1] + if L - l <= l: + done = True + _x = x[..., :L-l] + else: _x = x + + _x = A_ @ _x + x = torch.cat([x, _x], dim=-1) # there might be a more efficient way of ordering axes + if not done: A_ = A_ @ A_ + + assert x.shape[-1] == L + + x = x.contiguous() + return x + + +def hippo(N): + """ Return the HiPPO-LegT state matrices """ + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + return A, B + +class AdaptiveTransition(nn.Module): + """ General class which supports discretizing a state space equation x' = Ax + Bu + + Different subclasses can compute the forward and inverse mults in different ways + This particular method is specialized to the HiPPO-LegT transition for simplicity + """ + + def __init__(self, N): + """ + N: State space order, size of HiPPO matrix + """ + + super().__init__() + self.N = N + + A, B = hippo(N) + A = torch.as_tensor(A, dtype=torch.float) + B = torch.as_tensor(B, dtype=torch.float)[:, 0] + self.register_buffer('A', A) + self.register_buffer('B', B) + + # Register some common buffers + # (helps make sure every subclass has access to them on the right device) + I = torch.eye(N) + self.register_buffer('I', I) + + + def forward_mult(self, u, delta): + """ Computes (I + delta A) u + + A: (n, n) + u: (..., n) + delta: (...) or scalar + + output: (..., n) + """ + raise NotImplementedError + + def inverse_mult(self, u, delta): # TODO swap u, delta everywhere + """ Computes (I - d A)^-1 u """ + raise NotImplementedError + + def forward_diff(self, d, u, v): + """ Computes the 'forward diff' or Euler update rule: (I - d A)^-1 u + d B v + d: (...) + u: (..., N) + v: (...) + """ + v = d * v + v = v.unsqueeze(-1) * self.B + x = self.forward_mult(u, d) + x = x + v + return x + + def backward_diff(self, d, u, v): + """ Computes the 'forward diff' or Euler update rule: (I - d A)^-1 u + d (I - d A)^-1 B v + d: (...) + u: (..., N) + v: (...) + """ + v = d * v + v = v.unsqueeze(-1) * self.B + x = u + v + x = self.inverse_mult(x, d) + return x + + def bilinear(self, dt, u, v, alpha=.5): + """ Computes the bilinear (aka trapezoid or Tustin's) update rule. + + (I - d/2 A)^-1 (I + d/2 A) u + d B (I - d/2 A)^-1 B v + + dt: (...) + u: (..., N) + v: (...) + """ + x = self.forward_mult(u, (1-alpha)*dt) + v = dt * v + v = v.unsqueeze(-1) * self.B + x = x + v + x = self.inverse_mult(x, (alpha)*dt) + return x + + def gbt_A(self, dt, alpha=.5): + """ Compute the transition matrices associated with bilinear transform + + dt: (...) + returns: (..., N, N) + """ + # solve (N, ...) parallel problems of size N + dims = len(dt.shape) + I = self.I.view([self.N] + [1]*dims + [self.N]) + A = self.bilinear(dt, I, dt.new_zeros(*dt.shape), alpha=alpha) # (N, ..., N) + A = rearrange(A, 'n ... m -> ... m n', n=self.N, m=self.N) + return A + + def gbt_B(self, dt, alpha=.5): + B = self.bilinear(dt, dt.new_zeros(*dt.shape, self.N), dt.new_ones(1), alpha=alpha) # (..., N) + return B + + +class LegTTransitionDense(AdaptiveTransition): + """ Slower and memory inefficient version via manual matrix mult/inv """ + + def forward_mult(self, u, delta, transpose=False): + if isinstance(delta, torch.Tensor): + delta = delta.unsqueeze(-1) + A_ = self.A.transpose(-1, -2) if transpose else self.A + x = (A_ @ u.unsqueeze(-1)).squeeze(-1) + x = u + delta * x + + return x + + def inverse_mult(self, u, delta, transpose=False): + """ Computes (I - d A)^-1 u """ + + if isinstance(delta, torch.Tensor): + delta = delta.unsqueeze(-1).unsqueeze(-1) + _A = self.I - delta * self.A + if transpose: _A = _A.transpose(-1, -2) + + # x = torch.linalg.solve(_A, u.unsqueeze(-1)).squeeze(-1) # this can run out of memory + xs = [] + for _A_, u_ in zip(*torch.broadcast_tensors(_A, u.unsqueeze(-1))): + x_ = torch.linalg.solve(_A_, u_[...,:1]).squeeze(-1) + xs.append(x_) + x = torch.stack(xs, dim=0) + + return x + + +class LegTTransition(AdaptiveTransition): + """ Fast version using cuSPARSE tridiagonal solver """ + + def forward_mult(self, u, delta, transpose=False): + if transpose: return legt_gbt_forward_t(delta, u, transpose=True) + else: return legt_gbt_forward(delta, u) + + def inverse_mult(self, u, delta, transpose=False): + if transpose: return legt_gbt_backward_t(-delta, u, transpose=True) + else: return legt_gbt_backward(-delta, u) + + +class StateSpace(nn.Module): + """ Computes a state space layer. + + Simulates the state space ODE + x' = Ax + Bu + y = Cx + Du + + - A single state space computation maps a 1D function u to a 1D function y + - For an input of H features, each feature is independently run through the state space + with a different timescale / sampling rate / discretization step size. + """ + def __init__( + self, + d, # hidden dimension, also denoted H below + order=-1, # order of the state space, i.e. dimension N of the state x + dt_min=1e-3, # discretization step size - should be roughly inverse to the length of the sequence + dt_max=1e-1, + channels=1, # denoted by M below + dropout=0.0, + ): + super().__init__() + self.H = d + self.N = order if order > 0 else d + + # Construct transition + # self.transition = LegTTransition(self.N) # NOTE use this line for speed + self.transition = LegTTransitionDense(self.N) + + self.M = channels + + self.C = nn.Parameter(torch.randn(self.H, self.M, self.N)) + self.D = nn.Parameter(torch.randn(self.H, self.M)) + + # Initialize timescales + log_dt = torch.rand(self.H) * (math.log(dt_max)-math.log(dt_min)) + math.log(dt_min) + self.register_buffer('dt', torch.exp(log_dt)) + + # Cached Krylov (convolution filter) + self.k = None + + self.activation_fn = nn.GELU() + self.dropout = nn.Dropout(dropout) + + self.output_linear = nn.Linear(self.M * self.H, self.H) + + + def forward(self, u): # absorbs return_output and transformer src mask + """ + u: (L, B, H) or (length, batch, hidden) + Returns: (L, B, H) + """ + + # We need to compute the convolution filter if first pass or length changes + if self.k is None or u.shape[0] > self.k.shape[-1]: + A = self.transition.gbt_A(self.dt) # (..., N, N) + B = self.transition.gbt_B(self.dt) # (..., N) + self.k = krylov(u.shape[0], A, B) # (H, N, L) + + # Convolution + y = self.linear_system_from_krylov(u, self.k[..., :u.shape[0]]) # (L, B, H, M) + + # Dropout + y = self.dropout(self.activation_fn(y)) + + # Linear + y = rearrange(y, 'l b h m -> l b (h m)') # (L, B, H*M) + y = self.output_linear(y) # (L, B, H) + return y + + def linear_system_from_krylov(self, u, k): + """ + Computes the state-space system y = Cx + Du from Krylov matrix K(A, B) + + u: (L, B, ...) ... = H + C: (..., M, N) ... = H + D: (..., M) + k: (..., N, L) Krylov matrix representing b, Ab, A^2b... + + y: (L, B, ..., M) + """ + + + k = self.C @ k # (..., M, L) + + k = rearrange(k, '... m l -> m ... l') + k = k.to(u) # if training in half precision, need to go back to float32 for the fft + k = k.unsqueeze(1) # (M, 1, ..., L) + + v = u.unsqueeze(-1).transpose(0, -1) # (1, B, ..., L) + y = triangular_toeplitz_multiply(k, v) # (M, B, ..., L) + y = y.transpose(0, -1) # (L, B, ..., M) + y = y + u.unsqueeze(-1) * self.D # (L, B, ..., M) + return y diff --git a/src/models/s4/s4.py b/src/models/s4/s4.py new file mode 100644 index 0000000..9fc1035 --- /dev/null +++ b/src/models/s4/s4.py @@ -0,0 +1,1568 @@ +""" Standalone version of Structured (Sequence) State Space (S4) model. """ + +import logging +from functools import partial +import math +import numpy as np +from scipy import special as ss +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorch_lightning.utilities import rank_zero_only +from einops import rearrange, repeat +import opt_einsum as oe + +contract = oe.contract +contract_expression = oe.contract_expression + + +def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: + """Initializes multi-GPU-friendly python logger.""" + + logger = logging.getLogger(name) + logger.setLevel(level) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger +log = get_logger(__name__) + +""" Cauchy and Vandermonde kernels """ + +try: # Try CUDA extension + from extensions.cauchy.cauchy import cauchy_mult + has_cauchy_extension = True +except: + log.warn( + "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%" + ) + has_cauchy_extension = False + +try: # Try pykeops + import pykeops + from pykeops.torch import Genred + has_pykeops = True + log.info("Pykeops installation found.") + + def _broadcast_dims(*tensors): + max_dim = max([len(tensor.shape) for tensor in tensors]) + tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors] + return tensors + + def cauchy_conj(v, z, w): + """ Pykeops version """ + expr_num = 'z * ComplexReal(v) - Real2Complex(Sum(v * w))' + expr_denom = 'ComplexMult(z-w, z-Conj(w))' + + cauchy_mult = Genred( + f'ComplexDivide({expr_num}, {expr_denom})', + [ + 'v = Vj(2)', + 'z = Vi(2)', + 'w = Vj(2)', + ], + reduction_op='Sum', + axis=1, + ) + + v, z, w = _broadcast_dims(v, z, w) + v = _c2r(v) + z = _c2r(z) + w = _c2r(w) + + r = 2*cauchy_mult(v, z, w, backend='GPU') + return _r2c(r) + + def log_vandermonde(v, x, L): + expr = 'ComplexMult(v, ComplexExp(ComplexMult(x, l)))' + vandermonde_mult = Genred( + expr, + [ + 'v = Vj(2)', + 'x = Vj(2)', + 'l = Vi(2)', + ], + reduction_op='Sum', + axis=1, + ) + + l = torch.arange(L).to(x) + v, x, l = _broadcast_dims(v, x, l) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(v, x, l, backend='GPU') + return 2*_r2c(r).real + + def log_vandermonde_transpose(u, v, x, L): + """ + u: ... H L + v: ... H N + x: ... H N + Returns: ... H N + + V = Vandermonde(a, L) : (H N L) + contract_L(V * u * v) + """ + expr = 'ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))' + vandermonde_mult = Genred( + expr, + [ + 'u = Vj(2)', + 'v = Vi(2)', + 'x = Vi(2)', + 'l = Vj(2)', + ], + reduction_op='Sum', + axis=1, + ) + + l = torch.arange(L).to(x) + u, v, x, l = _broadcast_dims(u, v, x, l) + u = _c2r(u) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(u, v, x, l, backend='GPU') + return _r2c(r) + +except ImportError: + has_pykeops = False + if not has_cauchy_extension: + log.warning( + "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency." + ) + def cauchy_naive(v, z, w): + """ + v, w: (..., N) + z: (..., L) + returns: (..., L) + """ + cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L) + return torch.sum(cauchy_matrix, dim=-2) + + # Vandermonde functions + log.error( + "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency." + ) + def log_vandermonde(v, x, L): + """ + v: (..., N) + x: (..., N) + returns: (..., L) \sum v x^l + """ + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + vandermonde_prod = contract('... n, ... n l -> ... l', v, vandermonde_matrix) # (... L) + return 2*vandermonde_prod.real + + def log_vandermonde_transpose(u, v, x, L): + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + vandermonde_prod = contract('... l, ... n, ... n l -> ... n', u.to(x), v.to(x), vandermonde_matrix) # (... L) + return vandermonde_prod + +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) +_c2r = torch.view_as_real +_r2c = torch.view_as_complex +if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10): + _resolve_conj = lambda x: x.conj().resolve_conj() +else: + _resolve_conj = lambda x: x.conj() + + + +""" Simple nn.Module components """ + +def Activation(activation=None, dim=-1): + if activation in [ None, 'id', 'identity', 'linear' ]: + return nn.Identity() + elif activation == 'tanh': + return nn.Tanh() + elif activation == 'relu': + return nn.ReLU() + elif activation == 'gelu': + return nn.GELU() + elif activation in ['swish', 'silu']: + return nn.SiLU() + elif activation == 'glu': + return nn.GLU(dim=dim) + elif activation == 'sigmoid': + return nn.Sigmoid() + else: + raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) + +def LinearActivation( + d_input, d_output, bias=True, + transposed=False, + activation=None, + activate=False, # Apply activation as part of this module + **kwargs, + ): + """ Returns a linear nn.Module with control over axes order, initialization, and activation """ + + # Construct core module + linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear + if activation == 'glu': d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + if activate and activation is not None: + activation = Activation(activation, dim=-2 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True, transposed=True): + """ + tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) + self.p = p + self.tie = tie + self.transposed = transposed + self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + + def forward(self, X): + """ X: (batch, dim, lengths...) """ + if self.training: + if not self.transposed: X = rearrange(X, 'b d ... -> b ... d') + mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape + mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p + X = X * mask * (1.0/(1-self.p)) + if not self.transposed: X = rearrange(X, 'b ... d -> b d ...') + return X + return X + +""" Misc functional utilities """ + +def power(L, A, v=None): + """ Compute A^L and the scan sum_i A^i v_i + + A: (..., N, N) + v: (..., N, L) + """ + + I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + + powers = [A] + l = 1 + while True: + if L % 2 == 1: I = powers[-1] @ I + L //= 2 + if L == 0: break + l *= 2 + powers.append(powers[-1] @ powers[-1]) + + if v is None: return I + + # Invariants: + # powers[-1] := A^l + # l := largest po2 at most L + + # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A + # We do this reverse divide-and-conquer for efficiency reasons: + # 1) it involves fewer padding steps for non-po2 L + # 2) it involves more contiguous arrays + + # Take care of edge case for non-po2 arrays + # Note that this initial step is a no-op for the case of power of 2 (l == L) + k = v.size(-1) - l + v_ = powers.pop() @ v[..., l:] + v = v[..., :l] + v[..., :k] = v[..., :k] + v_ + + # Handle reduction for power of 2 + while v.size(-1) > 1: + v = rearrange(v, '... (z l) -> ... z l', z=2) + v = v[..., 0, :] + powers.pop() @ v[..., 1, :] + return I, v.squeeze(-1) + + +""" HiPPO utilities """ + +def transition(measure, N): + """ A, B transition matrices for different measures """ + # Legendre (translated) + if measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + + # Halve again for timescale correctness + A *= 0.5 + B *= 0.5 + # Legendre (scaled) + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + elif measure == 'legsd': + # Essentially equivalent to S4D-LegS + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + A += .5 * B*B[None, :, 0] + B = B / 2.0 + elif measure in ['fourier_diag', 'foud']: + # Essentially equivalent to S4D-Lin + freqs = np.arange(N//2) + d = np.stack([freqs, np.zeros(N//2)], axis=-1).reshape(-1)[:-1] + A = 2*np.pi*(-np.diag(d, 1) + np.diag(d, -1)) + A = A - .5 * np.eye(N) + B = np.zeros(N) + B[0::2] = 2**.5 + B[0] = 1 + B = B[:, None] + elif measure in ['fourier', 'fout']: + freqs = np.arange(N//2) + d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - B[:, None] * B[None, :] + B = B[:, None] + else: + raise NotImplementedError + + return A, B + +def rank_correction(measure, N, rank=1, dtype=torch.float): + """ Return low-rank matrix L such that A + L is normal """ + + if measure == 'legs': + assert rank >= 1 + P = torch.sqrt(.5+torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) + elif measure == 'legt': + assert rank >= 2 + P = torch.sqrt(1+2*torch.arange(N, dtype=dtype)) # (N) + P0 = P.clone() + P0[0::2] = 0. + P1 = P.clone() + P1[1::2] = 0. + P = torch.stack([P0, P1], dim=0) # (2 N) + P *= 2**(-0.5) # Halve the rank correct just like the original matrix was halved + elif measure in ['fourier', 'fout']: + P = torch.zeros(N) + P[0::2] = 2**.5 + P[0] = 1 + P = P.unsqueeze(0) + elif measure in ['fourier_diag', 'foud', 'legsd']: + P = torch.zeros(1, N, dtype=dtype) + else: raise NotImplementedError + + d = P.size(0) + if rank > d: + P = torch.cat([P, torch.zeros(rank-d, N, dtype=dtype)], dim=0) # (rank N) + return P + +def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True): + """ Return w, p, q, V, B such that + (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V + i.e. A = V[w - p q^*]V^*, B = V B + """ + assert dtype == torch.float or torch.double + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble + + A, B = transition(measure, N) + A = torch.as_tensor(A, dtype=dtype) # (N, N) + B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) + + P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) + AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3) + + # We require AP to be nearly skew-symmetric + _A = AP + AP.transpose(-1, -2) + if (err := torch.sum((_A - _A[0,0]*torch.eye(N))**2) / N) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5): + print("WARNING: HiPPO matrix not skew symmetric", err) + + + # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately + # Imaginary part can use eigh instead of eig + w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) + + # Diagonalize in double precision + if diagonalize_precision: AP = AP.to(torch.double) + w_im, V = torch.linalg.eigh(AP*-1j) # (..., N) (..., N, N) + if diagonalize_precision: w_im, V = w_im.to(cdtype), V.to(cdtype) + w = w_re + 1j * w_im + # Check: V w V^{-1} = A + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + + # Only keep half of each conjugate pair + _, idx = torch.sort(w.imag) + w_sorted = w[idx] + V_sorted = V[:, idx] + + # There is an edge case when eigenvalues can be 0, which requires some machinery to handle + # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case) + V = V_sorted[:, :N//2] + w = w_sorted[:N//2] + assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A" + if w[-1].abs() < 1e-4: + V[:, -1] = 0. + V[0, -1] = 2**-0.5 + V[1, -1] = 2**-0.5 * 1j + + _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2) + if ((err := torch.sum((2*_AP.real-AP)**2)/N) > 1e-5): + print("Warning: Diagonalization of A matrix not numerically precise - error", err) + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + V_inv = V.conj().transpose(-1, -2) + + B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B + P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P + + return w, P, B, V + +def dplr(scaling, N, rank=1, H=1, dtype=torch.float, real_scale=1.0, imag_scale=1.0, random_real=False, random_imag=False, normalize=False, diagonal=True, random_B=False): + assert dtype == torch.float or torch.double + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + + pi = torch.tensor(math.pi) + if random_real: + real_part = torch.rand(H, N//2) + else: + real_part = .5 * torch.ones(H, N//2) + if random_imag: + imag_part = N//2 * torch.rand(H, N//2) + else: + imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H) + + real_part = real_scale * real_part + if scaling == 'random': + imag_part = torch.randn(H, N//2) + elif scaling == 'real': + imag_part = 0 * imag_part + real_part = 1 + repeat(torch.arange(N//2), 'n -> h n', h=H) + elif scaling in ['linear', 'lin']: + imag_part = pi * imag_part + elif scaling in ['inverse', 'inv']: # Based on asymptotics of the default HiPPO matrix + imag_part = 1/pi * N * (N/(1+2*imag_part)-1) + elif scaling in ['inverse2', 'inv2']: + imag_part = 1/pi * N * (N/(1+imag_part)-1) + elif scaling in ['quadratic', 'quad']: + imag_part = 1/pi * (1+2*imag_part)**2 + elif scaling in ['legs', 'hippo']: + w, _, _, _ = nplr('legsd', N) + imag_part = w.imag + + else: raise NotImplementedError + imag_part = imag_scale * imag_part + w = -real_part + 1j * imag_part + + # Initialize B + if random_B: + B = torch.randn(H, N//2, dtype=dtype) + else: + B = torch.ones(H, N//2, dtype=dtype) + + if normalize: + norm = -B/w # (H, N) # Result if you integrate the kernel with constant 1 function + zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector + B = B / zeta**.5 + + P = torch.randn(rank, H, N//2, dtype=dtype) + if diagonal: P = P * 0.0 + V = torch.eye(N, dtype=dtype)[: :N//2] # Only used in testing + V = repeat(V, 'n m -> h n m', h=H) + + return w, P, B, V + +def ssm(measure, N, R, H, **ssm_args): + """Dispatcher to create single SSM initialization + + N: state size + R: rank (for DPLR parameterization) + H: number of independent SSM copies + """ + + if measure == "dplr": + w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) + elif measure.startswith("diag"): + args = measure.split("-") + assert args[0] == "diag" and len(args) > 1 + scaling = args[1] + w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args) + else: + w, P, B, V = nplr(measure, N, R, **ssm_args) + w = repeat(w, 'n -> s n', s=H) + P = repeat(P, 'r n -> r s n', s=H) + B = repeat(B, 'n -> s n', s=H) + V = repeat(V, 'n m -> s n m', s=H) + return w, P, B, V + +combinations = { + 'hippo': ['legs', 'fourier'], + 'diag': ['diag-inv', 'diag-lin'], + 'all': ['legs', 'fourier', 'diag-inv', 'diag-lin'], +} + +def combination(measures, N, R, S, **ssm_args): + if isinstance(measures, str): + measures = combinations[measures] if measures in combinations else [measures] + + assert S % len(measures) == 0, f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures" + w, P, B, V = zip( + *[ssm(measure, N, R, S // len(measures), **ssm_args) for measure in measures] + ) + w = torch.cat(w, dim=0) # (S N) + P = torch.cat(P, dim=1) # (R S N) + B = torch.cat(B, dim=0) # (S N) + V = torch.cat(V, dim=0) # (S N N) + return w, P, B, V + + +class OptimModule(nn.Module): + """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """ + + def register(self, name, tensor, lr=None): + """Register a tensor with a configurable learning rate and 0 weight decay""" + + if lr == 0.0: + self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) + + optim = {"weight_decay": 0.0} + if lr is not None: optim["lr"] = lr + setattr(getattr(self, name), "_optim", optim) + +class SSKernelNPLR(OptimModule): + """ Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR) + """ + + @torch.no_grad() + def _setup_C(self, L): + """ Construct C~ from C + + Two modes are supported: go directly to length L if self.L is 1, or length is doubled + """ + + if self.L.item() == 0: + if self.verbose: log.info(f"S4: Initializing kernel to length {L}") + double_length = False + elif L > self.L.item(): # 2*int(self.L) == L: + if self.verbose: log.info(f"S4: Doubling length from L = {self.L.item()} to {2*self.L.item()}") + double_length = True + L = self.L.item() # Convenience for the math below + else: return + + C = _r2c(self.C) + dA, _ = self._setup_state() + dA_L = power(L, dA) + # Multiply C by I - dA_L + C_ = _conj(C) + prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) + if double_length: prod = -prod # Multiply by I + dA_L instead + C_ = C_ - prod + C_ = C_[..., :self.N] # Take conjugate pairs again + self.C.copy_(_c2r(C_)) + + self.L = 2*self.L if double_length else self.L+L # Preserve type/device + + def _omega(self, L, dtype, device, cache=True): + """ Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform + This should be called everytime the internal length self.L changes """ + + # Use cached if available + if cache and hasattr(self, 'omega') and self.omega.size(-1) == L//2+1: + return self.omega, self.z + + omega = torch.tensor( + np.exp(-2j * np.pi / (L)), dtype=dtype, device=device + ) # \omega_{2L} + omega = omega ** torch.arange(0, L // 2 + 1, device=device) + z = 2 * (1 - omega) / (1 + omega) + + # Cache if necessary + if cache: + self.omega = omega + self.z = z + return omega, z + + def __init__( + self, + w, P, B, C, log_dt, + L=None, # starting/maximum length of kernel + lr=None, + verbose=False, + keops=False, + real_type='exp', # ['none' | 'exp' | 'relu' | sigmoid'] + real_tolerance=1e-3, + bandlimit=None, + ): + """ + L: Maximum length; this module computes an SSM kernel of length L + A is represented by diag(w) - PP^* + w: (S, N) diagonal part + P: (R, S, N) low-rank part + + B: (S, N) + C: (C, H, N) + dt: (H) timescale per feature + lr: [dict | float | None] hook to set lr of special parameters (A, B, dt) + + Dimensions: + N (or d_state): state size + H (or d_model): total SSM copies + S (or n_ssm): number of trainable copies of (A, B, dt); must divide H + R (or rank): rank of low-rank part + C (or channels): system is 1-dim to C-dim + + The forward pass of this Module returns a tensor of shape (C, H, L) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + super().__init__() + self.verbose = verbose + self.keops = keops + self.bandlimit = bandlimit + self.real_type = real_type + self.real_tolerance = real_tolerance + + # Rank of low-rank correction + self.rank = P.shape[-3] + assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = w.size(-1) + + # Check different SSM inits + assert w.size(-2) == P.size(-2) == B.size(-2) # n_ssm + assert self.H % w.size(0) == 0 + self.n_ssm = w.size(0) + self.broadcast = self.H // w.size(0) # Each trainable SSM needs to be duplicated this many times + + # Broadcast everything to correct shapes + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) + B = B.unsqueeze(0) # (1, 1, N) + + # Register parameters + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + if lr is None or isinstance(lr, float): lr_dict = {} + else: lr_dict, lr = lr, None + self.register("log_dt", log_dt, lr_dict.get('dt', lr)) + self.register("B", _c2r(B), lr_dict.get('B', lr)) + self.register("P", _c2r(P), lr_dict.get('A', lr)) + self.register("inv_w_real", self._w_init(w.real), lr_dict.get('A', lr)) + self.register("w_imag", w.imag, lr_dict.get('A', lr)) + + self.l_max = L + self.register_buffer('L', torch.tensor(0)) # Internal length + + def _w_init(self, w_real): + w_real = torch.clamp(w_real, max=-self.real_tolerance) + if self.real_type == 'none': + return -w_real + elif self.real_type == 'exp': + return torch.log(-w_real) # Some of the HiPPO methods have real part 0 + elif self.real_type == 'relu': + return -w_real + elif self.real_type == 'sigmoid': + return torch.logit(-w_real) + elif self.real_type == 'softplus': + return torch.log(torch.exp(-w_real)-1) + else: raise NotImplementedError + + def _w(self): + # Get the internal w (diagonal) parameter + if self.real_type == 'none': + w_real = -self.inv_w_real + elif self.real_type == 'exp': + w_real = -torch.exp(self.inv_w_real) + elif self.real_type == 'relu': + w_real = -F.relu(self.inv_w_real) + elif self.real_type == 'sigmoid': + w_real = -F.sigmoid(self.inv_w_real) + elif self.real_type == 'softplus': + w_real = -F.softplus(self.inv_w_real) + else: raise NotImplementedError + w = w_real + 1j * self.w_imag + return w + + def forward(self, state=None, rate=1.0, L=None): + """ + state: (B, H, N) initial state + rate: sampling rate factor + L: target length + + returns: + (C, H, L) convolution kernel (generally C=1) + (B, H, L) output from initial state + """ + + # Initialize C~ if necessary (done in forward pass so it's on the correct device) + if self.L.item() == 0 and self.l_max is not None and self.l_max > 0: + self._setup_C(self.l_max) + + # Handle sampling rate logic + # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate + if L is None: + L = round(self.L.item() / rate) + + # Increase the internal length if needed + continuous_L = round(rate*L) + while continuous_L > self.L.item(): + self._setup_C(continuous_L) + discrete_L = round(self.L.item()/rate) + + dt = torch.exp(self.log_dt) * rate + B = _r2c(self.B) + C = _r2c(self.C) + P = _r2c(self.P) + Q = P.conj() + w = self._w() # (n_ssm, N) + + # Address bandlimiting + if self.bandlimit is not None: + freqs = w.imag.abs() / (2*math.pi) # (H, N) + freqs = dt[:, None] / rate * freqs # (H, N) + mask = torch.where(freqs < self.bandlimit * .5, 1, 0) + C = C * mask + + # Get FFT nodes of right length + omega, z = self._omega(discrete_L, dtype=w.dtype, device=w.device, cache=(rate==1.0)) + + # Broadcast parameters to same hidden features H + B = repeat(B, '1 t n -> 1 (v t) n', v=self.broadcast) + P = repeat(P, 'r t n -> r (v t) n', v=self.broadcast) + Q = repeat(Q, 'r t n -> r (v t) n', v=self.broadcast) + w = repeat(w, 't n -> (v t) n', v=self.broadcast) + + # Augment B + if state is not None: + # Have to "unbilinear" the state to put it into the same "type" as B + # Compute 1/dt * (I + dt/2 A) @ state + + # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way + s = _conj(state) if state.size(-1) == self.N else state # (B H N) + sA = ( + s * _conj(w) # (B H N) + - contract('bhm, rhm, rhn -> bhn', s, _conj(Q), _conj(P)) + ) + s = s / dt.unsqueeze(-1) + sA / 2 + s = s[..., :self.N] + + B = torch.cat([s, B], dim=-3) # (B+1, H, N) + + # Incorporate dt into A + w = w * dt.unsqueeze(-1) # (H N) + + # Stack B and p, C and q for convenient batching + B = torch.cat([B, P], dim=-3) # (B+1+R, H, N) + C = torch.cat([C, Q], dim=-3) # (C+R, H, N) + + # Incorporate B and C batch dimensions + v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N) + + # Calculate resolvent at omega + if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops: + r = cauchy_mult(v, z, w, symmetric=True) + elif has_pykeops: + r = cauchy_conj(v, z, w) + else: + r = cauchy_naive(v, z, w) + r = r * dt[None, None, :, None] # (B+1+R, C+R, H, L) + + # Low-rank Woodbury correction + if self.rank == 1: + k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (1 + r[-1:, -1:, :, :]) + elif self.rank == 2: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank :, :, :] + r10 = r[-self.rank :, : -self.rank, :, :] + r11 = r[-self.rank :, -self.rank :, :, :] + det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[:1, 1:, :, :] * r11[1:, :1, :, :] + s = ( + r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] + + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] + - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] + - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] + ) + s = s / det + k_f = r00 - s + else: + r00 = r[:-self.rank, :-self.rank, :, :] + r01 = r[:-self.rank, -self.rank:, :, :] + r10 = r[-self.rank:, :-self.rank, :, :] + r11 = r[-self.rank:, -self.rank:, :, :] + r11 = rearrange(r11, "a b h n -> h n a b") + r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11) + r11 = rearrange(r11, "h n a b -> a b h n") + k_f = r00 - torch.einsum("i j h n, j k h n, k l h n -> i l h n", r01, r11, r10) + + # Final correction for the bilinear transform + k_f = k_f * 2 / (1 + omega) + + # Move from frequency to coefficients + k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L) + + # # Truncate to target length + k = k[..., :L] + + if state is not None: + k_state = k[:-1, :, :, :] # (B, C, H, L) + else: + k_state = None + k_B = k[-1, :, :, :] # (C H L) + + return k_B, k_state + + @torch.no_grad() + def _setup_linear(self): + """ Create parameters that allow fast linear stepping of state """ + w = self._w() + B = _r2c(self.B) # (H N) + P = _r2c(self.P) + Q = P.conj() + + # Repeat w shape properly + B = repeat(B, '1 t n -> 1 (v t) n', v=self.broadcast) + P = repeat(P, 'r t n -> r (v t) n', v=self.broadcast) + Q = repeat(Q, 'r t n -> r (v t) n', v=self.broadcast) + w = repeat(w, 't n -> (v t) n', v=self.broadcast) + + # Prepare Linear stepping + dt = torch.exp(self.log_dt) + D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) + R = (torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2*contract('r h n, h n, s h n -> h r s', Q, D, P).real) # (H R R) + Q_D = rearrange(Q*D, 'r h n -> h r n') + try: + R = torch.linalg.solve(R, Q_D) # (H R N) + except: + R = torch.tensor(np.linalg.solve(R.to(Q_D).contiguous().detach().cpu(), Q_D.contiguous().detach().cpu())).to(Q_D) + R = rearrange(R, 'h r n -> r h n') + + self.step_params = { + "D": D, # (H N) + "R": R, # (R H N) + "P": P, # (R H N) + "Q": Q, # (R H N) + "B": B, # (1 H N) + "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) + } + + def _step_state_linear(self, u=None, state=None): + """ + Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization. + + Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster + + u: (H) input + state: (H, N/2) state with conjugate pairs + Optionally, the state can have last dimension N + Returns: same shape as state + """ + C = _r2c(self.C) # View used for dtype/device + + if u is None: # Special case used to find dA + u = torch.zeros(self.H, dtype=C.dtype, device=C.device) + if state is None: # Special case used to find dB + state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) + + step_params = self.step_params.copy() + if state.size(-1) == self.N: # Only store half of the conjugate pairs; should be true by default + # There should be a slightly faster way using conjugate symmetry + contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', _conj(p), _conj(x), _conj(y))[..., :self.N] # inner outer product + else: + assert state.size(-1) == 2*self.N + step_params = {k: _conj(v) for k, v in step_params.items()} + # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping + contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', p, x, y) # inner outer product + D = step_params["D"] # (H N) + E = step_params["E"] # (H N) + R = step_params["R"] # (R H N) + P = step_params["P"] # (R H N) + Q = step_params["Q"] # (R H N) + B = step_params["B"] # (1 H N) + + new_state = E * state - contract_fn(P, Q, state) # (B H N) + new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) + new_state = D * (new_state - contract_fn(P, R, new_state)) + + return new_state + + def _setup_state(self): + """ Construct dA and dB for discretized state equation """ + + # Construct dA and dB by using the stepping + self._setup_linear() + C = _r2c(self.C) # Just returns a view that we use for finding dtype/device + + state = torch.eye(2*self.N, dtype=C.dtype, device=C.device).unsqueeze(-2) # (N 1 N) + dA = self._step_state_linear(state=state) + dA = rearrange(dA, "n h m -> h m n") + + u = C.new_ones(self.H) + dB = self._step_state_linear(u=u) + dB = _conj(dB) + dB = rearrange(dB, '1 h n -> h n') # (H N) + return dA, dB + + def _step_state(self, u, state): + """ Must be called after self.default_state() is used to construct an initial state! """ + next_state = self.state_contraction(self.dA, state) + self.input_contraction(self.dB, u) + return next_state + + def _setup_step(self, mode='dense'): + """ Set up dA, dB, dC discretized parameters for stepping """ + self.dA, self.dB = self._setup_state() + + # Calculate original C + C = _conj(_r2c(self.C)) # (H C N) + if self.L.item() == 0: + dC = C + else: + # self.C represents C_tilde + dA_L = power(self.L.item(), self.dA) + I = torch.eye(self.dA.size(-1)).to(dA_L) + + dC = torch.linalg.solve( + I - dA_L.transpose(-1, -2), + C.unsqueeze(-1), + ).squeeze(-1) + self.dC = dC + + # Do special preprocessing for different step modes + + self._step_mode = mode + if mode == 'linear': + # Linear case: special step function for the state, we need to handle output + # use conjugate symmetry by default, which affects the output projection + self.dC = 2*self.dC[:, :, :self.N] + elif mode == 'diagonal': + # Eigendecomposition of the A matrix + L, V = torch.linalg.eig(self.dA) + V_inv = torch.linalg.inv(V) + # Check that the eigendedecomposition is correct + if self.verbose: + print("Diagonalization error:", torch.dist(V @ torch.diag_embed(L) @ V_inv, self.dA)) + + # Change the parameterization to diagonalize + self.dA = L + self.dB = contract('h n m, h m -> h n', V_inv, self.dB) + self.dC = contract('h n m, c h n -> c h m', V, self.dC) + + elif mode == 'dense': + pass + else: raise NotImplementedError("NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}") + + def default_state(self, *batch_shape): + C = _r2c(self.C) + N = C.size(-1) + H = C.size(-2) + + # Cache the tensor contractions we will later do, for efficiency + # These are put in this function because they depend on the batch size + step_mode = getattr(self, "_step_mode", "dense") # Used in default_state, which is called without _setup_step() in forward_state() + if step_mode != 'linear': + N *= 2 + + if step_mode == 'diagonal': + self.state_contraction = contract_expression( + "h n, ... h n -> ... h n", + (H, N), + batch_shape + (H, N), + ) + else: + # Dense (quadratic) case: expand all terms + self.state_contraction = contract_expression( + "h m n, ... h n -> ... h m", + (H, N, N), + batch_shape + (H, N), + ) + + self.input_contraction = contract_expression( + "h n, ... h -> ... h n", + (H, N), # self.dB.shape + batch_shape + (H,), + ) + + self.output_contraction = contract_expression( + "c h n, ... h n -> ... c h", + (C.shape[0], H, N), # self.dC.shape + batch_shape + (H, N), + ) + + state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + """ Must have called self._setup_step() and created state with self.default_state() before calling this """ + + if self._step_mode == 'linear': + new_state = self._step_state_linear(u, state) + else: + new_state = self._step_state(u, state) + y = self.output_contraction(self.dC, new_state) + return y.real, new_state + +class SSKernelDiag(OptimModule): + """Version using (complex) diagonal state matrix (S4D)""" + + def __init__( + self, + A, B, C, log_dt, + L=None, + disc='bilinear', + real_type='exp', + lr=None, + bandlimit=None, + ): + + super().__init__() + self.L = L + self.disc = disc + self.bandlimit = bandlimit + self.real_type = real_type + + # Rank of low-rank correction + assert A.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = A.size(-1) + assert A.size(-2) == B.size(-2) # Number of independent SSMs trained + assert self.H % A.size(-2) == 0 + self.n_ssm = A.size(-2) + self.repeat = self.H // A.size(0) + + self.channels = C.shape[0] + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + + # Register parameters + if lr is None or isinstance(lr, float): lr_dict = {} + else: lr_dict, lr = lr, None + + self.register("log_dt", log_dt, lr_dict.get('dt', lr)) + self.register("A", _c2r(A), lr_dict.get('A', lr)) + self.register("B", _c2r(B), lr_dict.get('B', lr)) + self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr)) + self.register("A_imag", A.imag, lr_dict.get('A', lr)) + + def _A_init(self, A_real): + A_real = torch.clamp(A_real, max=-1e-4) + if self.real_type == 'none': + return -A_real + elif self.real_type == 'exp': + return torch.log(-A_real) # Some of the HiPPO methods have real part 0 + elif self.real_type == 'relu': + return -A_real + elif self.real_type == 'sigmoid': + return torch.logit(-A_real) + elif self.real_type == 'softplus': + return torch.log(torch.exp(-A_real)-1) + else: raise NotImplementedError + + def _A(self): + # Get the internal A (diagonal) parameter + if self.real_type == 'none': + A_real = -self.inv_A_real + elif self.real_type == 'exp': + A_real = -torch.exp(self.inv_A_real) + elif self.real_type == 'relu': + # JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it + A_real = -F.relu(self.inv_A_real)-1e-4 + elif self.real_type == 'sigmoid': + A_real = -F.sigmoid(self.inv_A_real) + elif self.real_type == 'softplus': + A_real = -F.softplus(self.inv_A_real) + else: raise NotImplementedError + A = A_real + 1j * self.A_imag + return A + + def forward(self, L, state=None, rate=1.0, u=None): + """ + state: (B, H, N) initial state + rate: sampling rate factor + L: target length + + returns: + (C, H, L) convolution kernel (generally C=1) + (B, H, L) output from initial state + """ + + dt = torch.exp(self.log_dt) * rate # (H) + C = _r2c(self.C) # (C H N) + A = self._A() # (H N) + + B = _r2c(self.B) + B = repeat(B, 't n -> 1 (v t) n', v=self.repeat) + + if self.bandlimit is not None: + freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N) + mask = torch.where(freqs < self.bandlimit * .5, 1, 0) + C = C * mask + + # Incorporate dt into A + A = repeat(A, 't n -> (v t) n', v=self.repeat) + dtA = A * dt.unsqueeze(-1) # (H N) + + + # Augment B with state + if state is not None: + s = state / dt.unsqueeze(-1) + if self.disc == 'bilinear': + s = s * (1. + dtA/2) + elif self.disc == 'zoh': + s = s * dtA * dtA.exp() / (dtA.exp() - 1.) + B = torch.cat([s, B], dim=-3) # (1+B H N) + + C = (B[:, None, :, :] * C).view(-1, self.H, self.N) + if self.disc == 'zoh': + # Power up + C = C * (torch.exp(dtA)-1.) / A + K = log_vandermonde(C, dtA, L) # (H L) + elif self.disc == 'bilinear': + C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A + dA = (1. + dtA/2) / (1. - dtA/2) + K = log_vandermonde(C, dA.log(), L) + elif self.disc == 'dss': + # Implementation from DSS meant for case when real eigenvalues can be positive + P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] + A_gt_0 = A.real > 0 # [N] + if A_gt_0.any(): + with torch.no_grad(): + P_max = dtA * (A_gt_0 * (L-1)) # [H N] + P = P - P_max.unsqueeze(-1) # [H N L] + S = P.exp() # [H N L] + + dtA_neg = dtA * (1 - 2*A_gt_0) # [H N] + num = dtA_neg.exp() - 1 # [H N] + den = (dtA_neg * L).exp() - 1 # [H N] + + # Inline reciprocal function for DSS logic + x = den * A + x_conj = _resolve_conj(x) + r = x_conj / (x*x_conj + 1e-7) + + C = C * num * r # [C H N] + K = contract('chn,hnl->chl', C, S).float() + else: assert False, f"{self.disc} not supported" + + K = K.view(-1, self.channels, self.H, L) # (1+B C H L) + if state is not None: + K_state = K[:-1, :, :, :] # (B C H L) + else: + K_state = None + K = K[-1, :, :, :] # (C H L) + return K, K_state + + def _setup_step(self): + # These methods are organized like this to be compatible with the NPLR kernel interface + dt = torch.exp(self.log_dt) # (H) + B = _r2c(self.B) # (H N) + C = _r2c(self.C) # (C H N) + self.dC = C + A = self._A() # (H N) + + # Incorporate dt into A + dtA = A * dt.unsqueeze(-1) # (H N) + if self.disc == 'zoh': + self.dA = torch.exp(dtA) # (H N) + self.dB = B * (torch.exp(dtA)-1.) / A # (C H N) + elif self.disc == 'bilinear': + self.dA = (1. + dtA/2) / (1. - dtA/2) + self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A + + + def default_state(self, *batch_shape): + C = _r2c(self.C) + state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + next_state = contract("h n, b h n -> b h n", self.dA, state) \ + + contract("h n, b h -> b h n", self.dB, u) + y = contract("c h n, b h n -> b c h", self.dC, next_state) + return 2*y.real, next_state + + def forward_state(self, u, state): + self._setup_step() + AL = self.dA ** u.size(-1) + u = u.flip(-1).to(self.dA).contiguous() # (B H L) + v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1)) + next_state = AL * state + v + return next_state + + +class SSKernel(nn.Module): + """Wrapper around SSKernel parameterizations. + + The SSKernel is expected to support the interface + forward() + default_state() + _setup_step() + step() + """ + + def __init__( + self, + H, + N=64, + L=None, + measure="legs", + rank=1, + channels=1, + dt_min=0.001, + dt_max=0.1, + deterministic=False, + lr=None, + mode="nplr", + n_ssm=None, + verbose=False, + measure_args={}, + **kernel_args, + ): + """State Space Kernel which computes the convolution kernel $\\bar{K}$ + + H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. + N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. + L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. + measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) + rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt" + channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead + dt_min, dt_max: min and max values for the step size dt (\Delta) + mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing + n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H + lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. + """ + super().__init__() + self.N = N + self.H = H + dtype, cdtype = torch.float, torch.cfloat + self.channels = channels + self.n_ssm = n_ssm if n_ssm is not None else H + self.mode = mode + self.verbose = verbose + self.kernel_args = kernel_args + + # Generate dt + if deterministic: + log_dt = torch.exp(torch.linspace(math.log(dt_min), math.log(dt_max), H)) + else: + log_dt = torch.rand(self.H, dtype=dtype) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + # Compute the preprocessed representation + w, P, B, V = combination(measure, self.N, rank, self.n_ssm, **measure_args) + + # Broadcast C to have H channels + if deterministic: + C = torch.zeros(channels, self.H, self.N, dtype=cdtype) + C[:, :, :1] = 1. + C = contract('hmn, chn -> chm', V.conj().transpose(-1, -2), C) # V^* C + else: + C = torch.randn(channels, self.H, self.N//2, dtype=cdtype) + + # Broadcast other parameters to have n_ssm copies + assert self.n_ssm % B.size(-2) == 0 \ + and self.n_ssm % P.size(-2) == 0 \ + and self.n_ssm % w.size(-2) == 0 + # Broadcast tensors to n_ssm copies + # These will be the parameters, so make sure tensors are materialized and contiguous + B = repeat(B, 't n -> (v t) n', v=self.n_ssm // B.size(-2)).clone().contiguous() + P = repeat(P, 'r t n -> r (v t) n', v=self.n_ssm // P.size(-2)).clone().contiguous() + w = repeat(w, 't n -> (v t) n', v=self.n_ssm // w.size(-2)).clone().contiguous() + C = C.contiguous() + + if mode == "nplr": + self.kernel = SSKernelNPLR( + w, P, B, C, + log_dt, L=L, + lr=lr, + verbose=verbose, + **kernel_args, + ) + elif mode == "diag": + C = C * repeat(B, 't n -> (v t) n', v=H//self.n_ssm) + self.kernel = SSKernelDiag( + w, B, C, log_dt, L=L, + lr=lr, + **kernel_args, + ) + else: raise NotImplementedError(f"{mode=} is not valid") + + def forward(self, state=None, L=None, rate=None): + return self.kernel(state=state, L=L, rate=rate) + + @torch.no_grad() + def forward_state(self, u, state): + """ Forward the state through a sequence, i.e. computes the state after passing chunk through SSM + + state: (B, H, N) + u: (B, H, L) + + Returns: (B, H, N) + """ + + if hasattr(self.kernel, "forward_state"): + return self.kernel.forward_state(u, state) + + dA, dB = self.kernel._setup_state() # Construct dA, dB matrices + # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) + + conj = state.size(-1) != dA.size(-1) + if conj: state = _conj(state) + + v = contract('h n, b h l -> b h n l', dB, u.flip(-1)) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) + AL, v = power(u.size(-1), dA, v) + next_state = contract("h m n, b h n -> b h m", AL, state) + next_state = next_state + v + + if conj: next_state = next_state[..., : next_state.size(-1) // 2] + return next_state + + def _setup_step(self, **kwargs): + # This method is intended to be private so that setting up an S4 module with + # ``` + # if hasattr(module, 'setup_step'): module.setup_step() + # ``` + # will not trigger this method multiple times + self.kernel._setup_step(**kwargs) + + def step(self, u, state, **kwargs): + y, state = self.kernel.step(u, state, **kwargs) + return y, state + + def default_state(self, *args, **kwargs): + return self.kernel.default_state(*args, **kwargs) + +class S4(nn.Module): + def __init__( + self, + d_model, + d_state=64, + l_max=None, + channels=1, + bidirectional=False, + # Arguments for position-wise feedforward components + activation='gelu', + postact='glu', + hyper_act=None, + dropout=0.0, tie_dropout=False, + bottleneck=None, + gate=None, + transposed=True, + verbose=False, + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel + channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models + bidirectional: if True, convolution kernel will be two-sided + + Position-wise feedforward components: + -------------------- + activation: activation in between SS and FF + postact: activation after FF + hyper_act: use a "hypernetwork" multiplication (experimental) + dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d + + Other arguments: + -------------------- + transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension] + gate: add gated activation (GSS) + bottleneck: reduce SSM dimension (GSS) + + See the class SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + + super().__init__() + if verbose: + log.info(f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})") + + self.d_model = d_model + self.H = d_model + self.N = d_state + self.L = l_max + self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + + self.gate = gate + self.bottleneck = bottleneck + + if bottleneck is not None: + self.H = self.H // bottleneck + self.input_linear = LinearActivation( + self.d_model, + self.H, + transposed=self.transposed, + activation=activation, + activate=True, + ) + + if gate is not None: + self.input_gate = LinearActivation( + self.d_model, + self.d_model * gate, + transposed=self.transposed, + activation=activation, + activate=True, + ) + self.output_gate = LinearActivation( + self.d_model * gate, + self.d_model, + transposed=self.transposed, + activation=None, + activate=False, + ) + + # optional multiplicative modulation GLU-style + # https://arxiv.org/abs/2002.05202 + self.hyper = hyper_act is not None + if self.hyper: + channels *= 2 + self.hyper_activation = Activation(hyper_act) + + self.D = nn.Parameter(torch.randn(channels, self.H)) + + if self.bidirectional: + channels *= 2 + + + # SSM Kernel + self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=channels, verbose=verbose, **kernel_args) + + # Pointwise + self.activation = Activation(activation) + dropout_fn = DropoutNd if tie_dropout else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + # position-wise output transform to mix features + self.output_linear = LinearActivation( + self.H*self.channels, + self.d_model*(1 if self.gate is None else self.gate), + transposed=self.transposed, + activation=postact, + activate=True, + ) + + def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): + """ + u: (B H L) if self.transposed else (B L H) + state: (H N) never needed unless you know what you're doing + + Returns: same shape as u + """ + if not self.transposed: u = u.transpose(-1, -2) + L = u.size(-1) + + # Mask out padding tokens + if isinstance(lengths, int): + if lengths != L: + lengths = torch.tensor(lengths, dtype=torch.long, device=u.device) + else: + lengths = None + if lengths is not None: + assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)] + mask = torch.where(torch.arange(L, device=lengths.device) < lengths[:, None, None], 1., 0.) + u = u * mask + + if self.gate is not None: + v = self.input_gate(u) + if self.bottleneck is not None: + u = self.input_linear(u) + + # Compute SS Kernel + L_kernel = L if self.L is None else min(L, round(self.L / rate)) + k, k_state = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L) + + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) + k = F.pad(k0, (0, L)) \ + + F.pad(k1.flip(-1), (L, 0)) \ + + k_f = torch.fft.rfft(k, n=L_kernel+L) # (C H L) + u_f = torch.fft.rfft(u, n=L_kernel+L) # (B H L) + y_f = contract('bhl,chl->bchl', u_f, k_f) + y = torch.fft.irfft(y_f, n=L_kernel+L)[..., :L] # (B C H L) + + + # Compute D term in state space equation - essentially a skip connection + y = y + contract('bhl,ch->bchl', u, self.D) + + # Compute state update + if state is not None: + assert not self.bidirectional, "Bidirectional not supported with state forwarding" + y = y + k_state # + next_state = self.kernel.forward_state(u, state) + else: + next_state = None + + # Optional hyper-network multiplication + if self.hyper: + y, yh = rearrange(y, 'b (s c) h l -> s b c h l', s=2) + y = self.hyper_activation(yh) * y + + # Reshape to flatten channels + y = rearrange(y, '... c h l -> ... (c h) l') + + y = self.dropout(self.activation(y)) + + if not self.transposed: y = y.transpose(-1, -2) + + y = self.output_linear(y) + + if self.gate is not None: + y = self.output_gate(y * v) + + return y, next_state + + def setup_step(self, **kwargs): + self.kernel._setup_step(**kwargs) + + def step(self, u, state): + """ Step one time step as a recurrent model. Intended to be used during validation. + + u: (B H) + state: (B H N) + Returns: output (B H), state (B H N) + """ + assert not self.training + + y, next_state = self.kernel.step(u, state) # (B C H) + y = y + u.unsqueeze(-2) * self.D + y = rearrange(y, 'b c h -> b (c h)') + y = self.activation(y) + if self.transposed: + y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) + else: + y = self.output_linear(y) + return y, next_state + + def default_state(self, *batch_shape, device=None): + # kernel is not a SequenceModule so it doesn't need to adhere to same interface + # the kernel will know the device of its own parameters + return self.kernel.default_state(*batch_shape) + + @property + def d_output(self): + return self.d_model diff --git a/src/models/s4/s4d.py b/src/models/s4/s4d.py new file mode 100644 index 0000000..603ce7e --- /dev/null +++ b/src/models/s4/s4d.py @@ -0,0 +1,114 @@ +""" Standalone version of Structured (Sequence) State Space (S4) model. """ + + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +import opt_einsum as oe + +from src.models.nn import DropoutNd + +_c2r = torch.view_as_real +_r2c = torch.view_as_complex + +class S4DKernel(nn.Module): + """Wrapper around SSKernelDiag that generates the diagonal SSM parameters + """ + + def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None): + super().__init__() + # Generate dt + H = d_model + log_dt = torch.rand(H) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + C = torch.randn(H, N // 2, dtype=torch.cfloat) + self.C = nn.Parameter(_c2r(C)) + self.register("log_dt", log_dt, lr) + + log_A_real = torch.log(0.5 * torch.ones(H, N//2)) + A_imag = math.pi * repeat(torch.arange(N//2), 'n -> h n', h=H) + self.register("log_A_real", log_A_real, lr) + self.register("A_imag", A_imag, lr) + + def forward(self, L): + """ + returns: (..., c, L) where c is number of channels (default 1) + """ + + # Materialize parameters + dt = torch.exp(self.log_dt) # (H) + C = _r2c(self.C) # (H N) + A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N) + + # Vandermonde multiplication + dtA = A * dt.unsqueeze(-1) # (H N) + K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L) + C = C * (torch.exp(dtA)-1.) / A + K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real + + return K + + def register(self, name, tensor, lr=None): + """Register a tensor with a configurable learning rate and 0 weight decay""" + + if lr == 0.0: + self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) + + optim = {"weight_decay": 0.0} + if lr is not None: optim["lr"] = lr + setattr(getattr(self, name), "_optim", optim) + + +class S4D(nn.Module): + + def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args): + super().__init__() + + self.h = d_model + self.n = d_state + self.d_output = self.h + self.transposed = transposed + + self.D = nn.Parameter(torch.randn(self.h)) + + # SSM Kernel + self.kernel = S4DKernel(self.h, N=self.n, **kernel_args) + + # Pointwise + self.activation = nn.GELU() + # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11 + dropout_fn = DropoutNd + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + + # position-wise output transform to mix features + self.output_linear = nn.Sequential( + nn.Conv1d(self.h, 2*self.h, kernel_size=1), + nn.GLU(dim=-2), + ) + + def forward(self, u, **kwargs): # absorbs return_output and transformer src mask + """ Input and output shape (B, H, L) """ + if not self.transposed: u = u.transpose(-1, -2) + L = u.size(-1) + + # Compute SSM Kernel + k = self.kernel(L=L) # (H L) + + # Convolution + k_f = torch.fft.rfft(k, n=2*L) # (H L) + u_f = torch.fft.rfft(u, n=2*L) # (B H L) + y = torch.fft.irfft(u_f*k_f, n=2*L)[..., :L] # (B H L) + + # Compute D term in state space equation - essentially a skip connection + y = y + u * self.D.unsqueeze(-1) + + y = self.dropout(self.activation(y)) + y = self.output_linear(y) + if not self.transposed: y = y.transpose(-1, -2) + return y, None # Return a dummy state to satisfy this repo's interface, but this can be modified diff --git a/src/models/sequence/__init__.py b/src/models/sequence/__init__.py index 72e2a39..ef89dad 100644 --- a/src/models/sequence/__init__.py +++ b/src/models/sequence/__init__.py @@ -1,5 +1,4 @@ -from .base import SequenceModule +from .base import SequenceModule, TransposedModule from .model import SequenceModel from .unet import SequenceUNet from .ff import FF -# from .pool import Downpool, Uppool diff --git a/src/models/sequence/attention/linear.py b/src/models/sequence/attention/linear.py index b7b2528..27072d1 100644 --- a/src/models/sequence/attention/linear.py +++ b/src/models/sequence/attention/linear.py @@ -10,7 +10,7 @@ from fast_transformers.feature_maps import elu_feature_map from fast_transformers.masking import TriangularCausalMask -from models.sequence.base import SequenceModule +from models.sequence.base import SequenceModule, TransposedModule import src.models.nn.utils as U try: @@ -125,6 +125,7 @@ def forward(self, query, key, value, attn_mask=None, key_padding_mask=None, need out = rearrange(out, 'b h s d -> b s (h d)') return out, attn +@TransposedModule class Performer(SequenceModule): """ [21-09-29] TODO the MHA class should take options for attention like full, performer, etc. Currently this is essentially duplicated from MultiheadAttention class """ def __init__(self, d_model, n_heads, *args, causal=True, **kwargs): @@ -145,5 +146,3 @@ def forward(self, src, attn_mask=None, key_padding_mask=None, state=None, **kwar def step(self, x, state): raise NotImplementedError - -Performer = U.Transpose(Performer) diff --git a/src/models/sequence/attention/performer.py b/src/models/sequence/attention/performer.py index 6c6a7e9..8bb9da4 100644 --- a/src/models/sequence/attention/performer.py +++ b/src/models/sequence/attention/performer.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/idiap/fast-transformers/blob/master/fast_transformers/feature_maps/fourier_features.py +# Adapted from https://github.com/HazyResearch/zoo +# in turn adapted from https://github.com/idiap/fast-transformers/blob/master/fast_transformers/feature_maps/fourier_features.py import math import torch @@ -6,10 +7,6 @@ from fast_transformers.feature_maps.base import FeatureMap -# from src.models.modules.attention.performer_utils import ( -# gaussian_orthogonal_random_matrix, -# softmax_kernel -# ) def orthogonal_matrix_chunk(cols, device=None): unstructured_block = torch.randn((cols, cols), device=device) q, r = torch.linalg.qr(unstructured_block) diff --git a/src/models/sequence/base.py b/src/models/sequence/base.py index 2fe37b9..4f8a4ff 100644 --- a/src/models/sequence/base.py +++ b/src/models/sequence/base.py @@ -1,34 +1,70 @@ from torch import nn +import functools class SequenceModule(nn.Module): - """ Abstract sequence model class. All layers that the backbones use must adhere to this + """Abstract sequence model class. All models must adhere to this interface - A sequence model is a layer that transforms an input of shape - (n_batch, l_sequence, d_input) to (n_batch, l_sequence, d_output) + A SequenceModule is generally a model that transforms an input of shape + (n_batch, l_sequence, d_model) to (n_batch, l_sequence, d_output) - Additionally, it returns a "state" which can be any additional information - For example, RNN and SSM layers may return their hidden state, - while some types of transformer layers (e.g. Transformer-XL) may want to pass through state as well + REQUIRED methods and attributes + forward, d_model, d_output: controls standard forward pass, a sequence-to-sequence transformation + __init__ should also satisfy the following interface; see SequenceIdentity for an example + def __init__(self, d_model, transposed=False, **kwargs) - - default_state receives a batch_shape with device and returns an initial state - - step simulates a single step of the sequence (e.g. one unroll for an RNN). It receives a state and single input (n_batch, d_input) and returns a state and output (n_batch, d_output) - - forward is a sequence-to-sequence transformation that receives an optional state + OPTIONAL methods + default_state, step: allows stepping the model recurrently with a hidden state + state_to_tensor, d_state: allows decoding from hidden state """ - # def __init__(self, transposed=False, *args, **kwargs): - # """ model should support regular (B, L, H) and transposed (B, H, L) axes ordering """ - # self.transposed = transposed + @property + def d_model(self): + """Model dimension (generally same as input dimension). + + This attribute is required for all SequenceModule instantiations. + It is used by the rest of the pipeline (e.g. model backbone, encoder) to track the internal shapes of the full model. + """ + if getattr(self, "_d_model", None) is None: + raise NotImplementedError("SequenceModule instantiation must set d_model") + return self._d_model + + @d_model.setter + def d_model(self, d): + self._d_model = d @property def d_output(self): + """Output dimension of model. + + This attribute is required for all SequenceModule instantiations. + It is used by the rest of the pipeline (e.g. model backbone, decoder) to track the internal shapes of the full model. + """ + if getattr(self, "_d_output", None) is None: + raise NotImplementedError("SequenceModule instantiation must specify d_output for decoder") return self._d_output + @d_output.setter def d_output(self, d): self._d_output = d + def forward(self, x, state=None, **kwargs): + """Forward pass of sequence model, a sequence-to-sequence transformation with an optional state. + + Generally, this should map a tensor of shape (batch, length, self.d_model) to (batch, length, self.d_output) + + Additionally, it returns a "state" which can be any additional information + For example, RNN and SSM layers may return their hidden state, + while some types of transformer layers (e.g. Transformer-XL) may want to pass a state as well + """ + return x, None + @property def state_to_tensor(self): - """ Returns a function mapping a state to a single tensor, in case one wants to use the hidden state instead of the output for final prediction """ + """Returns a function mapping a state to a single tensor. + + This method should be implemented if one wants to use the hidden state instead of the output sequence for final prediction. + Currently only used with the StateDecoder. + """ return lambda _: None @property @@ -36,45 +72,60 @@ def d_state(self): """ Returns dimension of output of self.state_to_tensor """ return None - @property - def transposed(self): - return self._transposed - @transposed.setter - def transposed(self, x): - self._transposed = x + def default_state(self, *batch_shape, device=None): + """Create initial state for a batch of inputs.""" - def default_state(self, *batch_shape, device=None): # TODO device shouldn't be needed; models should store their own initial state at initialization return None - def step(self, x, state=None, *args, **kwargs): - return x, state + def step(self, x, state=None, **kwargs): + """Step the model recurrently for one step of the input sequence. - def forward(self, x, state=None, *args, **kwargs): - return x, state + For example, this should correspond to unrolling an RNN for one step. + If the forward pass has signature (B, L, H1) -> (B, L, H2), + this method should generally have signature (B, H1) -> (B, H2) with an optional recurrent state. + """ + raise NotImplementedError -def Transpose(module): - """ Wrap a SequenceModule class to transpose the forward pass """ - # TODO maybe possible with functools.wraps - class WrappedModule(module): +def TransposedModule(module): + """Wrap a SequenceModule class to accept transposed parameter, handle state, absorb kwargs""" + # https://stackoverflow.com/a/65470430/1980685 + @functools.wraps(module, updated=()) + class TransposedModule(module): def __init__(self, *args, transposed=False, **kwargs): super().__init__(*args, **kwargs) self.transposed = transposed - def forward(self, x, *args, **kwargs): + def forward(self, x, state=None, **kwargs): if self.transposed: x = x.transpose(-1, -2) - x, state = super().forward(x) + x, next_state = super().forward(x, state) # Don't use kwarg because nn.LSTM + next_state = None if state is None else next_state if self.transposed: x = x.transpose(-1,-2) - return x, state + return x, next_state # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically - WrappedModule.__name__ = module.__name__ - return WrappedModule + # TransposedModule.__name__ = module.__name__ # functools wraps is better solution + return TransposedModule +@TransposedModule class SequenceIdentity(SequenceModule): - def __init__(self, d_model, dropout=0.0): + """Simple SequenceModule for testing purposes""" + + def __init__(self, d_model, dropout=0.0, **kwargs): + """Default interface for SequenceModule + + d_model: input dimension (sometimes denoted H for hidden dimension) + transposed: if True, inputs have axis ordering (B, H, L) instead of (B, H, L) + """ super().__init__() + self.d_model = d_model self.d_output = d_model - def forward(self, x, state=None, *args, **kwargs): + + def forward(self, x, state=None): + return x, state + + def default_state(self, *batch_shape, device=None): + return None + + def step(self, x, state=None, **kwargs): return x, state -SequenceIdentity = Transpose(SequenceIdentity) diff --git a/src/models/sequence/block.py b/src/models/sequence/block.py index 6e25871..723b575 100644 --- a/src/models/sequence/block.py +++ b/src/models/sequence/block.py @@ -9,8 +9,9 @@ from torch import nn +from functools import partial import src.utils as utils -from src.models.nn.components import Normalization +from src.models.nn.components import Normalization, StochasticDepth, DropoutNd from src.models.sequence import SequenceModule from src.models.sequence.pool import registry as pool_registry from src.models.nn.residual import registry as residual_registry @@ -24,10 +25,13 @@ def __init__( i_layer=None, # Only needs to be passed into certain residuals like Decay prenorm=True, dropout=0.0, + tie_dropout=False, + transposed=False, layer=None, # Config for black box module residual=None, # Config for residual function norm=None, # Config for normalization layer pool=None, + drop_path=0., ): super().__init__() @@ -35,6 +39,7 @@ def __init__( self.d_input = d_input self.layer = utils.instantiate(registry.layer, layer, d_input) self.prenorm = prenorm + self.transposed = transposed # Residual # d_residual is the output dimension after residual @@ -59,13 +64,12 @@ def __init__( self.pool = utils.instantiate(pool_registry, pool, self.d_residual, transposed=self.transposed) # Dropout - drop_cls = nn.Dropout2d if self.transposed else nn.Dropout - self.drop = drop_cls(dropout) if dropout > 0.0 else nn.Identity() + dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout + self.drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() + # Stochastic depth + self.drop_path = StochasticDepth(drop_path, mode='row') if drop_path > 0.0 else nn.Identity() - @property - def transposed(self): - return getattr(self.layer, 'transposed', False) @property def d_output(self): @@ -82,49 +86,44 @@ def state_to_tensor(self): def default_state(self, *args, **kwargs): return self.layer.default_state(*args, **kwargs) - def forward(self, x, *args, state=None, **kwargs): + def forward(self, x, state=None, **kwargs): y = x # Pre-norm if self.norm is not None and self.prenorm: y = self.norm(y) - # Black box module - y, state = self.layer(y, *args, state=state, **kwargs) + # Black box layer + y, state = self.layer(y, state=state, **kwargs) # Residual - if self.residual is not None: x = self.residual(x, self.drop(y), self.transposed) + if self.residual is not None: y = self.residual(x, self.drop_path(self.drop(y)), self.transposed) # Post-norm - if self.norm is not None and not self.prenorm: x = self.norm(x) + if self.norm is not None and not self.prenorm: y = self.norm(y) # Pool - # x = pool.downpool(x, self.pool, self.expand, self.transposed) - if self.pool is not None: x = self.pool(x) + if self.pool is not None: y = self.pool(y) - return x, state + return y, state - def step(self, x, state, *args, **kwargs): # TODO needs fix for transpose logic + def step(self, x, state, **kwargs): y = x # Pre-norm if self.norm is not None and self.prenorm: - if self.transposed: y = y.unsqueeze(-1) - y = self.norm(y) # TODO transpose seems wrong - if self.transposed: y = y.squeeze(-1) + y = self.norm.step(y) - # Black box module - y, state = self.layer.step(y, state, *args, **kwargs) + # Black box layer + y, state = self.layer.step(y, state, **kwargs) # Residual - if self.residual is not None: x = self.residual(x, y, transposed=False) # TODO this would not work with concat + if self.residual is not None: y = self.residual(x, y, transposed=False) # NOTE this would not work with concat residual function (catformer) # Post-norm if self.norm is not None and not self.prenorm: - if self.transposed: y = y.unsqueeze(-1) - x = self.norm(x)#.step(x) - if self.transposed: y = y.squeeze(-1) + y = self.norm.step(y) # Pool - if self.pool is not None: x = self.pool(x) + if self.pool is not None: y = self.pool(y) - return x, state + return y, state diff --git a/src/models/sequence/convs/conv1d.py b/src/models/sequence/convs/conv1d.py new file mode 100644 index 0000000..c6316f4 --- /dev/null +++ b/src/models/sequence/convs/conv1d.py @@ -0,0 +1,31 @@ +""" Wrapper around nn.Conv1d to adhere to SequenceModule interface. """ + +import torch +from torch import nn + +from src.models.sequence.base import SequenceModule +from src.models.nn import Activation + +class Conv1d(SequenceModule): + """ Simple wrapper for nn.Conv1d """ + def __init__(self, d_model, *args, d_output=None, activation='gelu', dropout=0.0, transposed=True, **kwargs): + # Accepted kwargs passed into Conv1d interface + # torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None) + super().__init__() + + self.d_model = d_model + if d_output is None: d_output = d_model + self.d_output = d_output + self.transposed = transposed + self.conv1d = nn.Conv1d(d_model, d_output, *args, **kwargs) + self.activation = Activation(activation) + + def forward(self, x, resolution=None, state=None, *args, **kwargs): + if not self.transposed: x = x.transpose(-1, -2) + y = self.conv1d(x) + if not self.transposed: y = y.transpose(-1, -2) + y = self.activation(y) + return y, None + + def step(self, x, state): + raise NotImplementedError diff --git a/src/models/sequence/convs/conv2d.py b/src/models/sequence/convs/conv2d.py new file mode 100644 index 0000000..199fbb1 --- /dev/null +++ b/src/models/sequence/convs/conv2d.py @@ -0,0 +1,46 @@ +""" Wrapper around nn.Conv2d to adhere to SequenceModule interface. """ + +import torch +from torch import nn + +from src.models.sequence.base import SequenceModule +from src.models.nn import Activation, DropoutNd + +class Conv2d(SequenceModule): + """ Simple wrapper for nn.Conv1d """ + def __init__(self, d_model, d_output=None, activation='gelu', depthwise=False, dropout=0.0, tie_dropout=False, transposed=True, **kwargs): + # kwargs passed into Conv2d interface: + # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None) + super().__init__() + + valid_kwargs = ["in_channels","out_channels","kernel_size","stride", + "padding","padding_mode","dilation","groups","bias"] + kwargs = {k:v for k,v in kwargs.items() if k in valid_kwargs} + + self.d_model = d_model + if d_output is None: d_output = d_model + self.d_output = d_output + self.transposed = transposed + self.depthwise = depthwise + + if self.depthwise: + self.conv2d = nn.Conv2d(d_model, d_model, padding='same', groups=d_model, **kwargs) + self.linear = nn.Conv2d(d_model, d_output, 1, 1) + else: + self.conv2d = nn.Conv2d(d_model, d_output, padding='same', **kwargs) + self.linear = nn.Identity() + dropout_fn = DropoutNd if tie_dropout else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + self.activation = Activation(activation) + + def forward(self, x, resolution=None, state=None, *args, **kwargs): + if not self.transposed: x = x.transpose(-1, -2) + y = self.conv2d(x) + y = self.activation(y) # NOTE doesn't work with glu + y = self.dropout(y) + y = self.linear(y) + if not self.transposed: y = y.transpose(-1, -2) + return y, None + + def step(self, x, state): + raise NotImplementedError diff --git a/src/models/sequence/ff.py b/src/models/sequence/ff.py index 2330ff7..b96d885 100644 --- a/src/models/sequence/ff.py +++ b/src/models/sequence/ff.py @@ -1,11 +1,12 @@ """ Implementation of FFN block in the style of Transformers """ +from functools import partial from torch import nn from src.models.sequence.base import SequenceModule -from src.models.nn import LinearActivation +from src.models.nn import LinearActivation, DropoutNd class FF(SequenceModule): - def __init__(self, d_input, expand=2, d_output=None, transposed=False, activation='gelu', initializer=None, dropout=0.0): + def __init__(self, d_input, expand=2, d_output=None, transposed=False, activation='gelu', initializer=None, dropout=0.0, tie_dropout=False): super().__init__() self.d_output = d_input if d_output is None else d_output self.transposed = transposed @@ -18,7 +19,8 @@ def __init__(self, d_input, expand=2, d_output=None, transposed=False, activatio initializer=initializer, activate=True, ) - dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout + dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout + # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() linear2 = LinearActivation( @@ -38,7 +40,7 @@ def __init__(self, d_input, expand=2, d_output=None, transposed=False, activatio def forward(self, x, *args, **kwargs): return self.ff(x), None - def step(self, x, state): + def step(self, x, state, **kwargs): # x: [batch, d_input] if self.transposed: # expects: [batch, d_input, seq_len] diff --git a/src/models/sequence/mha.py b/src/models/sequence/mha.py index 1f540bd..1f32e14 100644 --- a/src/models/sequence/mha.py +++ b/src/models/sequence/mha.py @@ -1,10 +1,14 @@ """ Wrapper around nn.MultiheadAttention to adhere to SequenceModule interface. """ import torch +import torch.nn.functional as F from torch import nn -from models.sequence.base import SequenceModule +import hydra +from models.sequence.base import SequenceModule, TransposedModule import src.models.nn.utils as U +from einops import rearrange +@TransposedModule class MultiheadAttention(SequenceModule): """ Simple wrapper for MultiheadAttention """ def __init__(self, d_model, n_heads, *args, causal=True, **kwargs): @@ -22,7 +26,7 @@ def forward(self, src, attn_mask=None, key_padding_mask=None, state=None, **kwar diagonal=1) # attn_mask, key_padding_mask = state # Note that this returns None for the second argument - y, z = self.mha(src, src, src, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, **kwargs) + y, _ = self.mha(src, src, src, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) return y, None def step(self, x, state): @@ -30,4 +34,89 @@ def step(self, x, state): # x: (B, D) y, z = self.mha(src, src, src, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, **kwargs) -MultiheadAttention = U.Transpose(MultiheadAttention) + +class VitAttention(SequenceModule): + """Copied from implementation for ViT: only used for ViT model + + This attention class makes several simplifying assumptions (commonly satisfied in vision + applications): + 1. q = k = v + 2. No masks: no attention mask, no key padding mask + 3. Embed dimension = Input dimension, i.e. projection matrices are square. + """ + + @property + def d_output(self): + return self.dim + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + # proj_drop=0., + packed_linear=True, + linear_cfg=None, + **kwargs, + ): + """packed_linear: whether to pack all 3 q_proj, k_proj, v_proj into 2 matrix. + This option is to be compatible with T2T-ViT pretrained weights, where there's only one + projection weight matrix. + """ + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + if linear_cfg is not None: + packed_linear = False + self.packed_linear = packed_linear + if packed_linear: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + else: + if linear_cfg is None: + linear_cfg = {'_target_': 'torch.nn.Linear'} + self.q_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, + _recursive_=False) + self.k_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, + _recursive_=False) + self.v_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, + _recursive_=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + # Removing this dropout because we do this in SequenceResidualBlock + # self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, state=None): + B, N, C = x.shape + if self.packed_linear: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + else: + q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x) + q, k, v = [rearrange(x, 'b n (h d) -> b h n d', h=self.num_heads) for x in (q, k, v)] + + # attn = (q @ k.transpose(-2, -1) * self.scale) + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = q.size() + _, _, k_seq_len, _ = k.size() + q = rearrange(q, 'b h t d -> (b h) t d') + k = rearrange(k, 'b h s d -> (b h) d s') + # Preallocate attn_weights for `baddbmm` + attn = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=q.dtype, device=q.device) + attn = rearrange(torch.baddbmm(attn, q, k, beta=0, alpha=self.scale), + '(b h) t s -> b h t s', h = self.num_heads) + + attn = F.softmax(attn, dim=-1, dtype=v.dtype) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + # x = self.proj_drop(x) + return x, None diff --git a/src/models/sequence/model.py b/src/models/sequence/model.py index 9e7c02d..2af7201 100644 --- a/src/models/sequence/model.py +++ b/src/models/sequence/model.py @@ -1,53 +1,46 @@ -""" Core deep sequence model backbone, in the style of ResNets / Transformers. +""" Isotropic deep sequence model backbone, in the style of ResNets / Transformers. The SequenceModel class implements a generic (batch, length, d_input) -> (batch, length, d_output) transformation """ -import functools +from functools import partial import torch import torch.nn as nn from einops import rearrange -from omegaconf import DictConfig from src.utils.config import to_list, to_dict -# from src.models.sequence.rnns import rnn # [21-09-13] I get a baffling error where hydra claims circular import if I _remove_ this line. This import doesn't even appear to be used at all in this file from src.models.sequence.block import SequenceResidualBlock from src.models.sequence.base import SequenceModule -from src.models.nn.components import Normalization -from src.models.nn.initialization import weights_init -from src.tasks import encoders, decoders +from src.models.nn.components import Normalization, DropoutNd class SequenceModel(SequenceModule): def __init__( self, - d_model, # Resize input (useful for deep models with residuals) - n_layers=1, # Number of layers - transposed=False, - dropout=0.0, # Residual dropout parameter - prenorm=True, - layer=None, # layer config, must be specified - residual=None, # Residual config - norm=None, # Normalization config (e.g. layer vs batch) - pool=None, - init=None, - verbose=False, - track_norms=True, - dropinp=0.0, + d_model, # Resize input (useful for deep models with residuals) + n_layers=1, # Number of layers + transposed=False, # Transpose inputs so each layer receives (batch, dim, length) + dropout=0.0, # Dropout parameter applied on every residual and every layer + tie_dropout=False, # Tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d + prenorm=True, # Pre-norm vs. post-norm + n_repeat=1, # Each layer is repeated n times per stage before applying pooling + layer=None, # Layer config, must be specified + residual=None, # Residual config + norm=None, # Normalization config (e.g. layer vs batch) + pool=None, # Config for pooling layer per stage + track_norms=True, # Log norms of each layer output + dropinp=0.0, # Input dropout ): super().__init__() # Save arguments needed for forward pass self.d_model = d_model self.transposed = transposed - self.verbose = verbose self.track_norms = track_norms - self._forward = False - if dropinp > 0.0: - self.drop = nn.Dropout2d(dropinp) if self.transposed else nn.Dropout(dropinp) - else: - self.drop = nn.Identity() + # Input dropout (not really used) + dropout_fn = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout + self.drop = dropout_fn(dropinp) if dropinp > 0.0 else nn.Identity() layer = to_list(layer, recursive=False) @@ -60,19 +53,20 @@ def __init__( _layer['transposed'] = transposed # Duplicate layers - layers = layer * n_layers + layers = layer * n_layers * n_repeat # Instantiate layers _layers = [] d = d_model for l, layer in enumerate(layers): - block = SequenceResidualBlock(d, l+1, prenorm=prenorm, dropout=dropout, layer=layer, residual=residual, norm=norm, pool=pool) + # Pool at the end of every n_repeat blocks + pool_cfg = pool if (l+1) % n_repeat == 0 else None + block = SequenceResidualBlock(d, l+1, prenorm=prenorm, dropout=dropout, tie_dropout=tie_dropout, transposed=transposed, layer=layer, residual=residual, norm=norm, pool=pool_cfg) _layers.append(block) d = block.d_output self.d_output = d self.layers = nn.ModuleList(_layers) - if prenorm: if norm is None: self.norm = None @@ -83,18 +77,9 @@ def __init__( else: self.norm = nn.Identity() - # Initializer hook - if init is not None: - self.apply(functools.partial(weights_init, init_cfg=init)) - def forward(self, inputs, *args, state=None, **kwargs): """ Inputs assumed to be (batch, sequence, dim) """ - # Debug - if self.verbose and not self._forward: - print("Model: unused kwargs", kwargs) - self._forward = True - - if self.transposed: inputs = rearrange(inputs, 'b l d -> b d l') + if self.transposed: inputs = rearrange(inputs, 'b ... d -> b d ...') inputs = self.drop(inputs) # Track norms @@ -105,12 +90,12 @@ def forward(self, inputs, *args, state=None, **kwargs): prev_states = [None] * len(self.layers) if state is None else state next_states = [] for layer, prev_state in zip(self.layers, prev_states): - outputs, state = layer(outputs, *args, state=prev_state, **kwargs) # TODO handle state + outputs, state = layer(outputs, *args, state=prev_state, **kwargs) next_states.append(state) if self.track_norms: output_norms.append(torch.mean(outputs.detach() ** 2)) - outputs = self.norm(outputs) + if self.norm is not None: outputs = self.norm(outputs) - if self.transposed: outputs = rearrange(outputs, 'b d l -> b l d') + if self.transposed: outputs = rearrange(outputs, 'b d ... -> b ... d') if self.track_norms: metrics = to_dict(output_norms, recursive=False) @@ -136,26 +121,14 @@ def fn(state): def default_state(self, *batch_shape, device=None): return [layer.default_state(*batch_shape, device=device) for layer in self.layers] - def step(self, x, state): - """ - Step one time step as a recurrent model. Intended to be used during validation. - - u: (B H) - state: (B H N) - Returns: output (B H), state (B H N) - """ - - # if self.transposed: x = rearrange(x, 'b l d -> b d l') - + def step(self, x, state, **kwargs): # Apply layers prev_states = [None] * len(self.layers) if state is None else state next_states = [] for layer, prev_state in zip(self.layers, prev_states): - x, state = layer.step(x, state=prev_state) + x, state = layer.step(x, state=prev_state, **kwargs) next_states.append(state) - - x = self.norm(x) - # if self.transposed: x = rearrange(x, 'b d l -> b l d') + x = self.norm(x) return x, next_states diff --git a/src/models/sequence/pool.py b/src/models/sequence/pool.py index 518d8de..8bdd4b1 100644 --- a/src/models/sequence/pool.py +++ b/src/models/sequence/pool.py @@ -1,4 +1,4 @@ -""" Implements downsampling and upsampling on sequences """ +"""Implements downsampling and upsampling on sequences.""" import torch from torch import nn @@ -14,24 +14,20 @@ expand: Repeat on the feature dimension """ -def downsample(x, stride=1, expand=1, average=False, transposed=False): +def downsample(x, stride=1, expand=1, transposed=False): if x is None: return None if stride > 1: - # TODO higher dimension stuff + assert x.ndim == 3, "Downsampling with higher-dimensional inputs is currently not supported. It is recommended to use average or spectral pooling instead." if transposed: - # einops appears slower than F - # if average: x = reduce(x, '... (l s) -> ... l', 'mean', s=stride) - if average: x = F.avg_pool1d(x, stride, stride) - else: x = x[..., 0::stride] + x = x[..., 0::stride] else: - if average: x = reduce(x, '... (l s) h -> ... l h', 'mean', s=stride) - else: x = x[..., 0::stride, :] + x = x[..., 0::stride, :] if expand > 1: if transposed: - x = repeat(x, '... d l -> ... (d e) l', e=expand) + x = repeat(x, 'b d ... -> b (d e) ...', e=expand) else: - x = repeat(x, '... d -> ... (d e)', e=expand) + x = repeat(x, 'b ... d -> b ... (d e)', e=expand) return x @@ -55,7 +51,6 @@ def __init__(self, d_input, stride=1, expand=1, transposed=True): self.d_input = d_input self.stride = stride self.expand = expand - # self.average = average self.transposed = transposed def forward(self, x): @@ -76,11 +71,67 @@ def __init__(self, d_input, stride=1, expand=1, transposed=True): self.d_input = d_input self.stride = stride self.expand = expand - # self.average = average self.transposed = transposed def forward(self, x): - return downsample(x, self.stride, self.expand, True, self.transposed) + if not self.transposed: + x = rearrange(x, 'b ... d -> b d ...') + # einops appears slower than F + if x.ndim == 3: + x = F.avg_pool1d(x, self.stride, self.stride) + elif x.ndim == 4: + x = F.avg_pool2d(x, self.stride, self.stride) + else: + # Reduction string e.g. "b d (l1 2) (l2 2) -> b d l1 l2" + reduce_str = "b d " + " ".join([f"(l{i} {self.stride})" for i in range(x.ndim-2)]) \ + + " -> b d " + " ".join([f"l{i}" for i in range(x.ndim-2)]) + x = reduce(x, reduce_str, 'mean') + + if self.expand > 1: + x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) + if not self.transposed: + x = rearrange(x, 'b d ... -> b ... d') + return x + + def step(self, x, state, **kwargs): + if self.stride > 1 or self.expand > 1: + raise NotImplementedError + return x, state + + @property + def d_output(self): + return self.d_input * self.expand + +class DownSpectralPool(SequenceModule): + def __init__(self, d_input, stride=1, expand=1, transposed=True): + super().__init__() + self.d_input = d_input + self.stride = stride + self.expand = expand + self.transposed = transposed + + def forward(self, x): + """ + x: (B, L..., D) + """ + if not self.transposed: + x = rearrange(x, 'b ... d -> b d ...') + shape = x.shape[2:] + x_f = torch.fft.ifftn(x, s=shape) + + for axis, l in enumerate(shape): + assert l % self.stride == 0, 'input length must be divisible by stride' + new_l = l // self.stride + idx = torch.cat([torch.arange(0, new_l-new_l//2), l+torch.arange(-new_l//2, 0)]).to(x_f.device) + x_f = torch.index_select(x_f, 2+axis, idx) + x = torch.fft.ifftn(x_f, s=[l//self.stride for l in shape]) + x = x.real + + if self.expand > 1: + x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) + if not self.transposed: + x = rearrange(x, 'b d ... -> b ... d') + return x def step(self, x, state, **kwargs): if self.stride > 1 or self.expand > 1: @@ -124,10 +175,6 @@ def __init__(self, d_input, stride=1, expand=1, transposed=True): d_input * stride, d_input * expand, transposed=transposed, - # initializer=initializer, - # weight_norm = weight_norm, - # activation=activation, - # activate=True if activation is not None else False, ) def forward(self, x): @@ -168,13 +215,16 @@ def forward(self, x): if self.transposed: x = self.pool(x) +# TODO DownPool/UpPool are currently used by unet/sashimi backbones +# DownLinearPool is used by the registry (for isotropic backbone) +# DownPool is essentially the same as DownLinearPool. These should be consolidated class DownPool(SequenceModule): def __init__(self, d_input, d_output=None, expand=None, stride=1, transposed=True, weight_norm=True, initializer=None, activation=None): super().__init__() assert (d_output is None) + (expand is None) == 1 if d_output is None: d_output = d_input * expand - self._d_output = d_output + self.d_output = d_output self.stride = stride self.transposed = transposed @@ -212,14 +262,11 @@ def step(self, x, state, **kwargs): # TODO needs fix in transpose ca, **kwargsse else: return None, state - def default_state(self, *args, **kwargs): + def default_state(self, *batch_shape, device=None): return [] - @property - def d_output(self): return self._d_output - -class UpPool(SequenceModule): # TODO subclass SequenceModule +class UpPool(SequenceModule): def __init__(self, d_input, d_output, stride, transposed=True, weight_norm=True, initializer=None, activation=None): super().__init__() @@ -279,21 +326,5 @@ def d_output(self): return self._d_output 'sample': DownSample, 'pool': DownAvgPool, 'linear': DownLinearPool, - # 'pool': DownPool, + 'spectral': DownSpectralPool, } - -if __name__ == '__main__': - from benchmark import utils - - a = torch.ones(50, 256, 1024) - a, = utils.convert_data(a) - stride = 4 - - y0 = downsample(a, stride=stride, average=True, transposed=True) - y1 = F.avg_pool1d(a, stride, stride) - - print(y0.shape, y1.shape) - print(y0 - y1) - - utils.benchmark(downsample, a, stride, 1, True, True, repeat=100, desc='einops') - utils.benchmark(F.avg_pool1d, a, stride, stride, repeat=100, desc='torch') diff --git a/src/models/sequence/rnns/qrnn.py b/src/models/sequence/rnns/qrnn.py index 8d24af9..5b55561 100644 --- a/src/models/sequence/rnns/qrnn.py +++ b/src/models/sequence/rnns/qrnn.py @@ -1,6 +1,6 @@ """ Implements variant of HiPPO-RNN that doesn't feed the hidden and memory states into each other time-wise, instead using simpler linear recurrences in time and letting them interact depthwise. -[21-10-22] AG: This was old experimental code. It should still work (perhaps with some minimal modifications), but there is no reason to use this now. This was the initial step toward "deep linear parallelizable" versions of the HiPPO RNN which culminated in LSSL and S3. +[21-10-22] AG: This was old experimental code. It should still work (perhaps with some minimal modifications), but there is not much reason to use this now. This was the initial step toward "deep linear parallelizable" versions of the HiPPO RNN which culminated in LSSL and S3. """ import torch @@ -28,7 +28,6 @@ def __init__(self, order, measure, dt, discretization='bilinear'): A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization) - # self.register_buffer('A', torch.Tensor(A-np.eye(self.order))) self.register_buffer('A', torch.Tensor(A)) self.register_buffer('B', torch.Tensor(B)) diff --git a/src/models/sequence/rnns/rnn.py b/src/models/sequence/rnns/rnn.py index 617ec2d..50a507c 100644 --- a/src/models/sequence/rnns/rnn.py +++ b/src/models/sequence/rnns/rnn.py @@ -1,8 +1,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F -# from utils.config import instantiate_name import src.utils as utils from src.models.sequence.rnns.cells import CellBase from src.models.sequence import SequenceModule @@ -72,7 +70,6 @@ def d_output(self): @property def state_to_tensor(self): """ Convert state into a single tensor output """ - # return self.cell.state_to_tensor(state) return self.cell.state_to_tensor diff --git a/src/models/sequence/rnns/sru.py b/src/models/sequence/rnns/sru.py index a756f01..dbb2f97 100644 --- a/src/models/sequence/rnns/sru.py +++ b/src/models/sequence/rnns/sru.py @@ -10,7 +10,8 @@ from src.models.sequence.rnns.cells import CellBase from src.models.nn import LinearActivation -from src.models.sequence.base import SequenceModule +import src.models.nn.utils as U +from src.models.sequence.base import SequenceModule, TransposedModule class SRUCell(CellBase): """ Implementation of the pure SRU cell that works with the models.rnn.RNN class """ @@ -33,14 +34,14 @@ def default_architecture(self): def __init__( self, d_input, d_model, - skip='H', # Highway, Residual, None + residual='H', # Highway, Residual, None offset=True, # whether to use previous or current cell to compute highway gate **kwargs ): self.offset = offset - self.skip = skip - assert self.skip in ['H', 'R', 'N'] + self.residual = residual + assert self.residual in ['H', 'R', 'N'] super().__init__(d_input, d_model, **kwargs) @@ -51,15 +52,15 @@ def reset_parameters(self): self.W_fc = nn.Parameter(torch.randn(self.d_model)) # highway - if self.skip == 'H': + if self.residual == 'H': self.W_rx = LinearActivation(self.d_input, self.d_model, bias=True, initializer=self.initializers['rx'], activation='sigmoid') self.W_rc = nn.Parameter(torch.randn(self.d_model)) # resize input if self.d_input != self.d_model: - self.skip_transform = nn.Linear(self.d_input, self.d_model) + self.residual_transform = nn.Linear(self.d_input, self.d_model) else: - self.skip_transform = nn.Identity() + self.residual_transform = nn.Identity() def forward(self, x, c): @@ -67,14 +68,14 @@ def forward(self, x, c): g = torch.sigmoid(self.W_fx(x) + self.W_fc * c) c_ = (1.-g) * c + g * self.W(x) - if self.skip == 'H': + if self.residual == 'H': if self.offset: r = torch.sigmoid(self.W_rx(x) + self.W_rc * c) else: r = torch.sigmoid(self.W_rx(x) + self.W_rc * c_) - h = (1-r) * self.skip_transform(x) + r * c_ - elif self.skip == 'R': - h = c_ + self.skip_transform(x) + h = (1-r) * self.residual_transform(x) + r * c_ + elif self.residual == 'R': + h = c_ + self.residual_transform(x) else: h = c_ @@ -92,7 +93,7 @@ def __init__(self, d_model, feedback=True): if self.feedback: self.W_fc = nn.Parameter(torch.randn(self.d_model)) - def forward(self, f, u): + def forward(self, f, u, state=None): """ f, u: (batch, length, dim) """ @@ -101,41 +102,46 @@ def forward(self, f, u): if not self.feedback: f = torch.sigmoid(f) - c = f.new_zeros(f.shape[..., 1:, :], requires_grad=False) + if state is None: + c = f.new_zeros((f.shape[0], f.shape[2]), requires_grad=False) + else: + assert state.shape == (f.shape[0], f.shape[2]) + c = state cs = [] for f_, u_ in zip(torch.unbind(f, dim=-2), torch.unbind(u, dim=-2)): if self.feedback: f_ = torch.sigmoid(f_ + self.W_fc * c) c = (1.-f_) * c + f_ * u_ cs.append(c) - return torch.stack(cs, dim=0) + return torch.stack(cs, dim=1), c +@TransposedModule class SRURNN(SequenceModule): """ Full RNN layer implementing the SRU (not just a Cell) """ - def __init__(self, d_input, d_model, feedback=True, return_output=True, dropout=0.0): + def __init__(self, d_input, d_model=None, feedback=True, return_output=True, dropout=0.0): super().__init__() + if d_model is None: d_model = d_input self.d_input = d_input self.d_model = d_model self.return_output = return_output self.W_fused = LinearActivation(d_input, 2*self.d_model, bias=True) - self.C = SRURNNGate(d_model, feedback) + self.C = SRURNNGate(d_model, feedback=feedback) if dropout > 0.0: raise NotImplementedError("Dropout currently not supported for SRU") - def forward(self, x, return_output=True): + def forward(self, x, state=None): ufr = self.W_fused(x) ufr = rearrange(ufr, 'b l (c d) -> b l c d', c=2) u, fx = torch.unbind(ufr, dim=2) # (B, L, H) - c = self.C(fx, u) # (B, L, H) - state = c[..., -1, :] + y, c = self.C(fx, u, state=state) # (B, L, H) if self.return_output: - return c, state + return y, c else: - return None, state + return None, c @property def d_state(self): @@ -148,4 +154,4 @@ def d_output(self): @property def state_to_tensor(self): return lambda state: state - # TODO haven't checked the default_state, step functions + # TODO default_state, step functions diff --git a/src/models/sequence/sashimi.py b/src/models/sequence/sashimi.py index 62c641e..d837951 100644 --- a/src/models/sequence/sashimi.py +++ b/src/models/sequence/sashimi.py @@ -11,11 +11,11 @@ class Sashimi(SequenceModule): def __init__( self, - d_model, - n_layers, - pool=[], - expand=1, - ff=2, + d_model, + n_layers, + pool=[], + expand=1, + ff=2, prenorm=False, dropout=0.0, dropres=0.0, @@ -47,7 +47,7 @@ def __init__( layer_cfg = layer.copy() layer_cfg['dropout'] = dropout layer_cfg['transposed'] = self.transposed - # layer_cfg['initializer'] = initializer + layer_cfg['initializer'] = initializer layer_cfg['l_max'] = L ff_cfg = { @@ -65,6 +65,7 @@ def _residual(d, i, layer): i, prenorm=prenorm, dropout=dropres, + transposed=self.transposed, layer=layer, residual=residual if residual is not None else 'R', norm=norm, @@ -102,7 +103,7 @@ def _residual(d, i, layer): if ff > 0: block.append(_residual(H, i+1, ff_cfg)) u_layers.append(nn.ModuleList(block)) - + self.u_layers = nn.ModuleList(u_layers) assert H == d_model @@ -126,7 +127,7 @@ def _residual(d, i, layer): def d_output(self): return self.d_model - def forward(self, x, state=None): + def forward(self, x, state=None, **kwargs): """ input: (batch, length, d_input) output: (batch, length, d_output) diff --git a/src/models/sequence/ss/dplr.py b/src/models/sequence/ss/dplr.py new file mode 100644 index 0000000..2192417 --- /dev/null +++ b/src/models/sequence/ss/dplr.py @@ -0,0 +1,104 @@ +"""Initializations of structured state space models""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +import src.models.hippo.hippo as hippo + +def dplr(scaling='linear', N=64, rank=1, H=1, dtype=torch.float, real_scale=1.0, imag_scale=1.0, random_real=False, random_imag=False, normalize=False, diagonal=True, random_B=False): + assert dtype == torch.float or torch.double + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + + pi = torch.tensor(math.pi) + if random_real: + real_part = torch.rand(H, N//2) + else: + real_part = .5 * torch.ones(H, N//2) + if random_imag: + imag_part = N//2 * torch.rand(H, N//2) + else: + imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H) + + real_part = real_scale * real_part + if scaling == 'random': + imag_part = torch.randn(H, N//2) + elif scaling == 'real': + imag_part = 0 * imag_part + real_part = 1 + repeat(torch.arange(N//2), 'n -> h n', h=H) + elif scaling in ['linear', 'lin']: + imag_part = pi * imag_part + elif scaling in ['inverse', 'inv']: # Based on asymptotics of the default HiPPO matrix + imag_part = 1/pi * N * (N/(1+2*imag_part)-1) + elif scaling in ['inverse2', 'inv2']: + imag_part = 1/pi * N * (N/(1+imag_part)-1) + elif scaling in ['quadratic', 'quad']: + imag_part = 1/pi * (1+2*imag_part)**2 + elif scaling in ['legs', 'hippo']: + w, _, _, _ = hippo.nplr('legsd', N) + imag_part = w.imag + + else: raise NotImplementedError + imag_part = imag_scale * imag_part + w = -real_part + 1j * imag_part + + # Initialize B + if random_B: + B = torch.randn(H, N//2, dtype=dtype) + else: + B = torch.ones(H, N//2, dtype=dtype) + + if normalize: + norm = -B/w # (H, N) # Result if you integrate the kernel with constant 1 function + zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector + B = B / zeta**.5 + + P = torch.randn(rank, H, N//2, dtype=dtype) + if diagonal: P = P * 0.0 + V = torch.eye(N, dtype=dtype)[:, :N//2] # Only used in testing + V = repeat(V, 'n m -> h n m', h=H) + + return w, P, B, V + +def ssm(measure, N, R, H, **ssm_args): + """Dispatcher to create single SSM initialization + + N: state size + R: rank (for DPLR parameterization) + H: number of independent SSM copies + """ + + if measure == "dplr": + w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) + elif measure.startswith("diag"): + args = measure.split("-") + assert args[0] == "diag" and len(args) > 1 + scaling = args[1] + w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args) + else: + w, P, B, V = hippo.nplr(measure, N, R, **ssm_args) + w = repeat(w, 'n -> s n', s=H) + P = repeat(P, 'r n -> r s n', s=H) + B = repeat(B, 'n -> s n', s=H) + V = repeat(V, 'n m -> s n m', s=H) + return w, P, B, V + +combinations = { + 'hippo': ['legs', 'fourier'], + 'diag': ['diag-inv', 'diag-lin'], + 'all': ['legs', 'fourier', 'diag-inv', 'diag-lin'], +} + +def combination(measures, N, R, S, **ssm_args): + if isinstance(measures, str): + measures = combinations[measures] if measures in combinations else [measures] + + assert S % len(measures) == 0, f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures" + w, P, B, V = zip( + *[ssm(measure, N, R, S // len(measures), **ssm_args) for measure in measures] + ) + w = torch.cat(w, dim=0) # (S N) + P = torch.cat(P, dim=1) # (R S N) + B = torch.cat(B, dim=0) # (S N) + V = torch.cat(V, dim=0) # (S N N) + return w, P, B, V diff --git a/src/models/sequence/ss/kernel.py b/src/models/sequence/ss/kernel.py index 35cb845..f4fc822 100644 --- a/src/models/sequence/ss/kernel.py +++ b/src/models/sequence/ss/kernel.py @@ -1,59 +1,58 @@ -""" Core S4 convolution kernel implementing the 'normal plus low-rank' algorithm. +"""SSM convolution kernels. -The main module is SSKernelNPLR, which stores parameters A, B, C, dt, and calling it creates the SSM convolution kernel bar{K}. +SSKernelNPLR is the S4 kernel, implementing the 'normal plus low-rank' algorithm from the original S4 paper. This stores parameters A, B, C, dt, and calling it creates the SSM convolution kernel bar{K}. A much simpler version SSKernelSlow is included for illustration purposes: it has the same output, but uses the naive algorithm which is much slower. This module is meant for testing and exposition, to understand what the State Space Kernel actually does. -HiPPOSSKernel specializes the SSKernels to specific instantiations of HiPPO matrices. -""" - -if __name__ == "__main__": - import sys - import pathlib +SSKernelDiag is the S4D kernel, a simpler algorithm for computing the kernel for the case of diagonal state matrices A. - p = pathlib.Path().absolute() - print("Adding path: ", p) - sys.path.append(str(p)) +SSKernel wraps these with common options and handles the initialization. +""" import math import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -import scipy.fft from einops import rearrange, repeat from opt_einsum import contract, contract_expression -from omegaconf import DictConfig import src.models.hippo.hippo as hippo +import src.models.sequence.ss.dplr as dplr from src.models.functional.krylov import krylov, power - import src.utils.train log = src.utils.train.get_logger(__name__) -try: +try: # Try CUDA extension from extensions.cauchy.cauchy import cauchy_mult has_cauchy_extension = True - log.info("CUDA extension for cauchy multiplication found.") + log.info("CUDA extension for Cauchy multiplication found.") except: log.warn( - "CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%" + "CUDA extension for Cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%" ) has_cauchy_extension = False try: import pykeops from src.models.functional.cauchy import cauchy_conj + from src.models.functional.vandermonde import log_vandermonde, log_vandermonde_transpose + has_pykeops = True log.info("Pykeops installation found.") except ImportError: has_pykeops = False - from src.models.functional.cauchy import cauchy_slow + from src.models.functional.cauchy import cauchy_naive + from src.models.functional.vandermonde import log_vandermonde_naive as log_vandermonde + from src.models.functional.vandermonde import log_vandermonde_transpose_naive as log_vandermonde_transpose if not has_cauchy_extension: log.error( - "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency." + "Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for memory efficiency." ) + log.error( + "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency." + ) @@ -69,63 +68,25 @@ else: _resolve_conj = lambda x: x.conj() -def bilinear(dt, A, B=None): - """ - dt: (...) timescales - A: (... N N) - B: (... N) - """ - N = A.shape[-1] - I = torch.eye(N).to(A) - A_backwards = I - dt[:, None, None] / 2 * A - A_forwards = I + dt[:, None, None] / 2 * A - - if B is None: - dB = None - else: - dB = dt[..., None] * torch.linalg.solve( - A_backwards, B.unsqueeze(-1) - ).squeeze(-1) # (... N) - - dA = torch.linalg.solve(A_backwards, A_forwards) # (... N N) - return dA, dB - class OptimModule(nn.Module): """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """ - def register(self, name, tensor, trainable=False, lr=None, wd=None): - """Utility method: register a tensor as a buffer or trainable parameter""" + def register(self, name, tensor, lr=None): + """Register a tensor with a configurable learning rate and 0 weight decay""" - if trainable: - self.register_parameter(name, nn.Parameter(tensor)) - else: + if lr == 0.0: self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) - optim = {} - if trainable and lr is not None: - optim["lr"] = lr - if trainable and wd is not None: - optim["weight_decay"] = wd - if len(optim) > 0: + optim = {"weight_decay": 0.0} + if lr is not None: optim["lr"] = lr setattr(getattr(self, name), "_optim", optim) - class SSKernelNPLR(OptimModule): - """ Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR) - - The class name stands for 'State-Space SSKernel for Normal Plus Low-Rank'. - The parameters of this function are as follows. - - A: (... N N) the state matrix - B: (... N) input matrix - C: (... N) output matrix - dt: (...) timescales / discretization step size - p, q: (... P N) low-rank correction to A, such that Ap=A+pq^T is a normal matrix - - The forward pass of this Module returns: - (... L) that represents represents FFT SSKernel_L(A^dt, B^dt, C) - + """ + Stores a representation of and computes the SSKernel function K_L(dt, A, B, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR) """ @torch.no_grad() @@ -181,26 +142,32 @@ def __init__( self, w, P, B, C, log_dt, L=None, # starting/maximum length of kernel - trainable=None, lr=None, - lr_dt=None, verbose=False, keops=False, - fast_gate=False, - quadrature=None, + real_type='exp', # ['none' | 'exp' | 'relu' | sigmoid'] + real_tolerance=1e-3, + bandlimit=None, ): """ L: Maximum length; this module computes an SSM kernel of length L - w: (n_ssm, N) - p: (r, n_ssm, N) low-rank correction to A - A represented by diag(w) - pq^* + A is represented by diag(w) - PP^* + w: (S, N) diagonal part + P: (R, S, N) low-rank part - B: (n_ssm, N) + B: (S, N) + C: (C, H, N) dt: (H) timescale per feature - C: (C, H, N) system is 1-D to c-D (channels) + lr: [dict | float | None] hook to set lr of special parameters (A, B, dt) - trainable: toggle which of the parameters is trainable - lr: add hook to set lr of hippo parameters specially (everything besides C) + Dimensions: + N (or d_state): state size + H (or d_model): total SSM copies + S (or n_ssm): number of trainable copies of (A, B, dt); must divide H + R (or rank): rank of low-rank part + C (or channels): system is 1-dim to C-dim + + The forward pass of this Module returns a tensor of shape (C, H, L) Note: tensor shape N here denotes half the true state size, because of conjugate symmetry """ @@ -208,7 +175,9 @@ def __init__( super().__init__() self.verbose = verbose self.keops = keops - self.fast_gate = fast_gate + self.bandlimit = bandlimit + self.real_type = real_type + self.real_tolerance = real_tolerance # Rank of low-rank correction self.rank = P.shape[-3] @@ -217,48 +186,67 @@ def __init__( self.N = w.size(-1) # Check different SSM inits - assert w.size(-2) == P.size(-2) == B.size(-2) # Number of copies + assert w.size(-2) == P.size(-2) == B.size(-2) # n_ssm assert self.H % w.size(0) == 0 self.n_ssm = w.size(0) - self.copies = self.H // w.size(0) - if lr_dt is None: lr_dt = lr + self.broadcast = self.H // w.size(0) # Each trainable SSM needs to be duplicated this many times # Broadcast everything to correct shapes - C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (H, C, N) + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) B = B.unsqueeze(0) # (1, 1, N) # Register parameters self.C = nn.Parameter(_c2r(_resolve_conj(C))) - train = False - if trainable is None: trainable = {} - if trainable == False: trainable = {} - if trainable == True: trainable, train = {}, True - self.register("log_dt", log_dt, trainable.get('dt', train), lr_dt, 0.0) - self.register("B", _c2r(B), trainable.get('B', train), lr, 0.0) - self.register("P", _c2r(P), trainable.get('P', train), lr, 0.0) - log_w_real = torch.log(-w.real + 1e-3) # Some of the HiPPO methods have real part 0 - w_imag = w.imag - self.register("log_w_real", log_w_real, trainable.get('A', 0), lr, 0.0) - self.register("w_imag", w_imag, trainable.get('A', train), lr, 0.0) + if lr is None or isinstance(lr, float): lr_dict = {} + else: lr_dict, lr = lr, None + self.register("log_dt", log_dt, lr_dict.get('dt', lr)) + self.register("B", _c2r(B), lr_dict.get('B', lr)) + self.register("P", _c2r(P), lr_dict.get('A', lr)) + self.register("inv_w_real", self._w_init(w.real), lr_dict.get('A', lr)) + self.register("w_imag", w.imag, lr_dict.get('A', lr)) self.l_max = L self.register_buffer('L', torch.tensor(0)) # Internal length - self.quadrature = quadrature + def _w_init(self, w_real): + w_real = torch.clamp(w_real, max=-self.real_tolerance) + if self.real_type == 'none': + return -w_real + elif self.real_type == 'exp': + return torch.log(-w_real) # Some of the HiPPO methods have real part 0 + elif self.real_type == 'relu': + return -w_real + elif self.real_type == 'sigmoid': + return torch.logit(-w_real) + elif self.real_type == 'softplus': + return torch.log(torch.exp(-w_real)-1) + else: raise NotImplementedError def _w(self): # Get the internal w (diagonal) parameter - w_real = -torch.exp(self.log_w_real) - w_imag = self.w_imag - w = w_real + 1j * w_imag + if self.real_type == 'none': + w_real = -self.inv_w_real + elif self.real_type == 'exp': + w_real = -torch.exp(self.inv_w_real) + elif self.real_type == 'relu': + w_real = -F.relu(self.inv_w_real) + elif self.real_type == 'sigmoid': + w_real = -F.sigmoid(self.inv_w_real) + elif self.real_type == 'softplus': + w_real = -F.softplus(self.inv_w_real) + else: raise NotImplementedError + w = w_real + 1j * self.w_imag return w def forward(self, state=None, rate=1.0, L=None): """ - state: (..., s, N) extra tensor that augments B + state: (B, H, N) initial state rate: sampling rate factor + L: target length - returns: (..., c+s, L) + returns: + (C, H, L) convolution kernel (generally C=1) + (B, H, L) output from initial state """ # Initialize C~ if necessary (done in forward pass so it's on the correct device) @@ -266,11 +254,7 @@ def forward(self, state=None, rate=1.0, L=None): self._setup_C(self.l_max) # Handle sampling rate logic - # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) sampling rate rate - # If either are not passed in, assume we're not asked to change the scale of our kernel - assert not (rate is None and L is None) - if rate is None: - rate = self.L.item() / L + # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate if L is None: L = round(self.L.item() / rate) @@ -280,22 +264,28 @@ def forward(self, state=None, rate=1.0, L=None): self._setup_C(continuous_L) discrete_L = round(self.L.item()/rate) - if self.fast_gate: dt = torch.exp(torch.sinh(self.log_dt)) * rate - else: dt = torch.exp(self.log_dt) * rate + dt = torch.exp(self.log_dt) * rate B = _r2c(self.B) C = _r2c(self.C) P = _r2c(self.P) Q = P.conj() - w = self._w() + w = self._w() # (S, N) where S=n_ssm + + # Address bandlimiting + if self.bandlimit is not None: + freqs = w.imag.abs() / (2*math.pi) # (H, N) + freqs = dt[:, None] / rate * freqs # (H, N) + mask = torch.where(freqs < self.bandlimit * .5, 1, 0) + C = C * mask # Get FFT nodes of right length omega, z = self._omega(discrete_L, dtype=w.dtype, device=w.device, cache=(rate==1.0)) # Broadcast parameters to same hidden features H - B = repeat(B, '1 t n -> 1 (v t) n', v=self.copies) - P = repeat(P, 'r t n -> r (v t) n', v=self.copies) - Q = repeat(Q, 'r t n -> r (v t) n', v=self.copies) - w = repeat(w, 't n -> (v t) n', v=self.copies) + B = repeat(B, '1 t n -> 1 (v t) n', v=self.broadcast) + P = repeat(P, 'r t n -> r (v t) n', v=self.broadcast) + Q = repeat(Q, 'r t n -> r (v t) n', v=self.broadcast) + w = repeat(w, 't n -> (v t) n', v=self.broadcast) # Augment B if state is not None: @@ -311,19 +301,17 @@ def forward(self, state=None, rate=1.0, L=None): s = s / dt.unsqueeze(-1) + sA / 2 s = s[..., :self.N] - B = torch.cat([s, B], dim=-3) # (s+1, H, N) + B = torch.cat([s, B], dim=-3) # (B+1, H, N) # Incorporate dt into A w = w * dt.unsqueeze(-1) # (H N) # Stack B and p, C and q for convenient batching - B = torch.cat([B, P], dim=-3) # (s+1+r, H, N) - C = torch.cat([C, Q], dim=-3) # (c+r, H, N) + B = torch.cat([B, P], dim=-3) # (B+1+R, H, N) + C = torch.cat([C, Q], dim=-3) # (C+R, H, N) # Incorporate B and C batch dimensions - v = B.unsqueeze(-3) * C.unsqueeze(-4) # (s+1+r, c+r, H, N) - # w = w[None, None, ...] # (1, 1, H, N) - # z = z[None, None, None, ...] # (1, 1, 1, L) + v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N) # Calculate resolvent at omega if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops: @@ -331,8 +319,8 @@ def forward(self, state=None, rate=1.0, L=None): elif has_pykeops: r = cauchy_conj(v, z, w) else: - r = cauchy_slow(v, z, w) - r = r * dt[None, None, :, None] # (S+1+R, C+R, H, L) + r = cauchy_naive(v, z, w) + r = r * dt[None, None, :, None] # (B+1+R, C+R, H, L) # Low-rank Woodbury correction if self.rank == 1: @@ -365,44 +353,35 @@ def forward(self, state=None, rate=1.0, L=None): k_f = k_f * 2 / (1 + omega) # Move from frequency to coefficients - k = torch.fft.irfft(k_f, n=discrete_L) # (S+1, C, H, L) + k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L) # # Truncate to target length k = k[..., :L] if state is not None: - k_state = k[:-1, :, :, :] # (S, C, H, L) + k_state = k[:-1, :, :, :] # (B, C, H, L) else: k_state = None k_B = k[-1, :, :, :] # (C H L) - if self.quadrature == 'trapezoid': - w = torch.ones(*k_B.shape).to(k_B) * dt[None, :, None] - w[..., 0] /= 2. - w[..., -1] /= 2 - k_B = k_B * w - elif self.quadrature == 'simpson': - w = torch.ones(*k_B.shape).to(k_B) * dt[None, :, None] / 3. - w[..., 1:-1:2] *= 4 - w[..., 2:-1:2] *= 2 - k_B = k_B * w - return k_B, k_state @torch.no_grad() def double_length(self): - # if self.verbose: log.info(f"S4: Doubling length from L = {self.L} to {2*self.L}") self._setup_C(2*self.L) @torch.no_grad() def _check(self): """Check if A, B, C parameters and vanilla SSKernel construction can be recovered""" - self.setup_step() + # assert self.L > 0, "Set up module first" + + K = self.forward(L=self.l_max)[0] - K = krylov(self.L, self.dA, self.dB, self.dC) + self._setup_step() + K_ = krylov(self.l_max, self.dA, self.dB, self.dC) - diff = K - self.forward(L=self.L)[0] + diff = K - K_ print("checking DPLR Kernel construction", torch.sum(diff ** 2)) @torch.no_grad() @@ -414,27 +393,27 @@ def _setup_linear(self): Q = P.conj() # Repeat w shape properly - B = repeat(B, '1 t n -> 1 (v t) n', v=self.copies) - P = repeat(P, 'r t n -> r (v t) n', v=self.copies) - Q = repeat(Q, 'r t n -> r (v t) n', v=self.copies) - w = repeat(w, 't n -> (v t) n', v=self.copies) + B = repeat(B, '1 t n -> 1 (v t) n', v=self.broadcast) + P = repeat(P, 'r t n -> r (v t) n', v=self.broadcast) + Q = repeat(Q, 'r t n -> r (v t) n', v=self.broadcast) + w = repeat(w, 't n -> (v t) n', v=self.broadcast) # Prepare Linear stepping dt = torch.exp(self.log_dt) D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) - R = (torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2*contract('r h n, h n, s h n -> h r s', Q, D, P).real) # (H r r) + R = (torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2*contract('r h n, h n, s h n -> h r s', Q, D, P).real) # (H R R) Q_D = rearrange(Q*D, 'r h n -> h r n') try: - R = torch.linalg.solve(R.to(Q_D), Q_D) # (H r N) - except torch._C._LinAlgError: - R = torch.tensor(np.linalg.solve(R.to(Q_D).cpu(), Q_D.cpu())).to(Q_D) + R = torch.linalg.solve(R, Q_D) # (H R N) + except: + R = torch.tensor(np.linalg.solve(R.to(Q_D).contiguous().detach().cpu(), Q_D.contiguous().detach().cpu())).to(Q_D) R = rearrange(R, 'h r n -> r h n') self.step_params = { "D": D, # (H N) - "R": R, # (r H N) - "P": P, # (r H N) - "Q": Q, # (r H N) + "R": R, # (R H N) + "P": P, # (R H N) + "Q": Q, # (R H N) "B": B, # (1 H N) "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) } @@ -468,9 +447,9 @@ def _step_state_linear(self, u=None, state=None): contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', p, x, y) # inner outer product D = step_params["D"] # (H N) E = step_params["E"] # (H N) - R = step_params["R"] # (r H N) - P = step_params["P"] # (r H N) - Q = step_params["Q"] # (r H N) + R = step_params["R"] # (R H N) + P = step_params["P"] # (R H N) + Q = step_params["Q"] # (R H N) B = step_params["B"] # (1 H N) new_state = E * state - contract_fn(P, Q, state) # (B H N) @@ -489,7 +468,6 @@ def _setup_state(self): state = torch.eye(2*self.N, dtype=C.dtype, device=C.device).unsqueeze(-2) # (N 1 N) dA = self._step_state_linear(state=state) dA = rearrange(dA, "n h m -> h m n") - # self.dA = dA # (H N N) u = C.new_ones(self.H) dB = self._step_state_linear(u=u) @@ -502,20 +480,23 @@ def _step_state(self, u, state): next_state = self.state_contraction(self.dA, state) + self.input_contraction(self.dB, u) return next_state - - def setup_step(self, mode='dense'): + def _setup_step(self, mode='dense'): """ Set up dA, dB, dC discretized parameters for stepping """ self.dA, self.dB = self._setup_state() # Calculate original C - dA_L = power(self.L, self.dA) - I = torch.eye(self.dA.size(-1)).to(dA_L) C = _conj(_r2c(self.C)) # (H C N) - - dC = torch.linalg.solve( - I - dA_L.transpose(-1, -2), - C.unsqueeze(-1), - ).squeeze(-1) + if self.L.item() == 0: + dC = C + else: + # self.C represents C_tilde + dA_L = power(self.L.item(), self.dA) + I = torch.eye(self.dA.size(-1)).to(dA_L) + + dC = torch.linalg.solve( + I - dA_L.transpose(-1, -2), + C.unsqueeze(-1), + ).squeeze(-1) self.dC = dC # Do special preprocessing for different step modes @@ -542,7 +523,6 @@ def setup_step(self, mode='dense'): pass else: raise NotImplementedError("NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}") - def default_state(self, *batch_shape): C = _r2c(self.C) N = C.size(-1) @@ -550,10 +530,11 @@ def default_state(self, *batch_shape): # Cache the tensor contractions we will later do, for efficiency # These are put in this function because they depend on the batch size - if self._step_mode !='linear': + step_mode = getattr(self, "_step_mode", "dense") # Used in default_state, which is called without _setup_step() in forward_state() + if step_mode != 'linear': N *= 2 - if self._step_mode == 'diagonal': + if step_mode == 'diagonal': self.state_contraction = contract_expression( "h n, ... h n -> ... h n", (H, N), @@ -583,14 +564,14 @@ def default_state(self, *batch_shape): return state def step(self, u, state): - """ Must have called self.setup_step() and created state with self.default_state() before calling this """ + """ Must have called self._setup_step() and created state with self.default_state() before calling this """ if self._step_mode == 'linear': new_state = self._step_state_linear(u, state) else: new_state = self._step_state(u, state) y = self.output_contraction(self.dC, new_state) - return y, new_state + return y.real, new_state class SSKernelSlow(OptimModule): @@ -606,7 +587,30 @@ class SSKernelSlow(OptimModule): Result is expected to be equal to SSKernelNPLR(L, w, P, B, C, log_dt, P)() if A = w - PP^* """ - def __init__(self, A, B, C, log_dt, L=None, trainable=None, lr=None): + @staticmethod + def bilinear(dt, A, B=None): + """ + dt: (...) timescales + A: (... N N) + B: (... N) + """ + N = A.shape[-1] + I = torch.eye(N).to(A) + A_backwards = I - dt[:, None, None] / 2 * A + A_forwards = I + dt[:, None, None] / 2 * A + + if B is None: + dB = None + else: + dB = dt[..., None] * torch.linalg.solve( + A_backwards, B.unsqueeze(-1) + ).squeeze(-1) # (... N) + + dA = torch.linalg.solve(A_backwards, A_forwards) # (... N N) + return dA, dB + + + def __init__(self, A, B, C, log_dt, L=None, lr=None): super().__init__() self.L = L self.N = A.size(-1) @@ -615,20 +619,18 @@ def __init__(self, A, B, C, log_dt, L=None, trainable=None, lr=None): C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) # Register parameters - train = False - if trainable is None: trainable = {} - if trainable == False: trainable = {} - if trainable == True: trainable, train = {}, True - self.register("log_dt", log_dt, trainable.get('dt', train), lr) - self.register("A", A, trainable.get('A', train), lr) - self.register("B", B, trainable.get('B', train), lr) + if lr is None or isinstance(lr, float): lr_dict = {} + else: lr_dict, lr = lr, None + self.register("log_dt", log_dt, lr_dict.get('dt', lr)) + self.register("A", _c2r(A), lr_dict.get('A', lr)) + self.register("B", _c2r(B), lr_dict.get('B', lr)) # NOTE leaving in complex form for convenience, which means it currently won't work with DDP and might have incorrect param count # This class shouldn't be used for anything other than testing and simple ablations, so this is fine # self.register("C", C.conj().resolve_conj(), True, None, wd=None) self.C = nn.Parameter(_resolve_conj(C)) # Cache if nothing is trained - self.trainable = trainable.get('dt', train) or trainable.get('A', train) or trainable.get('B', train) + self.trainable = lr_dict.get('dt', lr) > 0.0 or lr_dict.get('A', lr) > 0.0 or lr_dict.get('B', lr) > 0.0 self.K = None # Compute in forward pass since that ensures correct device def forward(self, state=None, rate=1.0, L=None): @@ -638,18 +640,20 @@ def forward(self, state=None, rate=1.0, L=None): assert rate == 1.0 and L is not None if self.trainable: - dA, dB = bilinear(torch.exp(self.log_dt), self.A, self.B) + dA, dB = SSKernelSlow.bilinear(torch.exp(self.log_dt), self.A, self.B) k = krylov(L, dA, dB, self.C) # (H L) else: if self.K is None: - dA, dB = bilinear(torch.exp(self.log_dt), self.A, self.B) + dA, dB = SSKernelSlow.bilinear(torch.exp(self.log_dt), self.A, self.B) self.K = krylov(L, dA, dB) # (H N L) k = contract('hnl,chn->chl', self.K[..., :L], self.C) + k = k.float() if state is not None: state = state.to(self.dA) state = contract("... n m, ... m -> ... n", self.dA, state) k_state = krylov(L, self.dA, state.unsqueeze(-3), self.C) + k_state = k_state.float() else: k_state = None return k, k_state @@ -660,195 +664,186 @@ def default_state(self, *batch_shape): return state def _setup_state(self): - self.dA, self.dB = bilinear(torch.exp(self.log_dt), self.A, self.B) + dA, dB = SSKernelSlow.bilinear(torch.exp(self.log_dt), self.A, self.B) + return dA, dB - def setup_step(self): - self._setup_state() + def _setup_step(self): + self.dA, self.dB = self._setup_state() self.dC = self.C def step(self, u, state): next_state = contract("h m n, b h n -> b h m", self.dA, state) \ + contract("h n, b h -> b h n", self.dB, u) y = contract("c h n, b h n -> b c h", self.dC, next_state) - return y, next_state + return y.real, next_state class SSKernelDiag(OptimModule): - """ Version using (complex) diagonal state matrix. Main difference is this uses the ZOH instead of Bilinear transform. Note that it is slower and less memory efficient than the NPLR kernel because of lack of kernel support. - - """ + """Version using (complex) diagonal state matrix (S4D)""" def __init__( self, - w, C, log_dt, + A, B, C, log_dt, L=None, disc='bilinear', - trainable=None, + real_type='exp', lr=None, - quadrature=None, + bandlimit=None, ): super().__init__() self.L = L self.disc = disc + self.bandlimit = bandlimit + self.real_type = real_type # Rank of low-rank correction - assert w.size(-1) == C.size(-1) + assert A.size(-1) == C.size(-1) self.H = log_dt.size(-1) - self.N = w.size(-1) - assert self.H % w.size(0) == 0 - self.copies = self.H // w.size(0) + self.N = A.size(-1) + assert A.size(-2) == B.size(-2) # Number of independent SSMs trained + assert self.H % A.size(-2) == 0 + self.n_ssm = A.size(-2) + self.repeat = self.H // A.size(0) - # Broadcast everything to correct shapes - C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (H, C, N) + self.channels = C.shape[0] + self.C = nn.Parameter(_c2r(_resolve_conj(C))) # Register parameters - # C is a regular parameter, not part of state - self.C = nn.Parameter(_c2r(_resolve_conj(C))) - train = False - if trainable is None: trainable = {} - if trainable == False: trainable = {} - if trainable == True: trainable, train = {}, True - self.register("log_dt", log_dt, trainable.get('dt', train), lr, 0.0) - - if self.disc in ['bilinear', 'zoh', 'foh']: - log_w_real = torch.log(-w.real + 1e-3) # Some of the HiPPO methods have real part 0 - w_imag = w.imag - self.register("log_w_real", log_w_real, trainable.get('A', 0), lr, 0.0) - self.register("w_imag", w_imag, trainable.get('A', train), lr, 0.0) - elif self.disc == 'dss': - self.register("w", _c2r(w), trainable.get('A', train), lr, 0.0) + if lr is None or isinstance(lr, float): lr_dict = {} + else: lr_dict, lr = lr, None + + self.register("log_dt", log_dt, lr_dict.get('dt', lr)) + self.register("A", _c2r(A), lr_dict.get('A', lr)) + self.register("B", _c2r(B), lr_dict.get('B', lr)) + self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr)) + self.register("A_imag", A.imag, lr_dict.get('A', lr)) + + def _A_init(self, A_real): + A_real = torch.clamp(A_real, max=-1e-4) + if self.real_type == 'none': + return -A_real + elif self.real_type == 'exp': + return torch.log(-A_real) # Some of the HiPPO methods have real part 0 + elif self.real_type == 'relu': + return -A_real + elif self.real_type == 'sigmoid': + return torch.logit(-A_real) + elif self.real_type == 'softplus': + return torch.log(torch.exp(-A_real)-1) else: raise NotImplementedError - self.quadrature = quadrature - - - def _w(self): - # Get the internal w (diagonal) parameter - if self.disc != 'dss': - w_real = -torch.exp(self.log_w_real) - w_imag = self.w_imag - w = w_real + 1j * w_imag - else: - w = _r2c(self.w) # (..., N) - w = repeat(w, 't n -> (v t) n', v=self.copies) # (H N) - return w + def _A(self): + # Get the internal A (diagonal) parameter + if self.real_type == 'none': + A_real = -self.inv_A_real + elif self.real_type == 'exp': + A_real = -torch.exp(self.inv_A_real) + elif self.real_type == 'relu': + # JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it + A_real = -F.relu(self.inv_A_real)-1e-4 + elif self.real_type == 'sigmoid': + A_real = -F.sigmoid(self.inv_A_real) + elif self.real_type == 'softplus': + A_real = -F.softplus(self.inv_A_real) + else: raise NotImplementedError + A = A_real + 1j * self.A_imag + return A - def forward(self, state=None, rate=1.0, L=None): + def forward(self, L, state=None, rate=1.0, u=None): """ - state: (..., s, N) extra tensor that augments B + state: (B, H, N) initial state rate: sampling rate factor + L: target length - returns: (..., c+s, L) + returns: + (C, H, L) convolution kernel (generally C=1) + (B, H, L) output from initial state """ - # Handle sampling rate logic - # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) sampling rate rate - # If either are not passed in, assume we're not asked to change the scale of our kernel - assert not (rate is None and L is None) - if rate is None: - rate = self.L / L - if L is None: - L = round(self.L / rate) dt = torch.exp(self.log_dt) * rate # (H) C = _r2c(self.C) # (C H N) - w = self._w() # (H N) + A = self._A() # (H N) + + B = _r2c(self.B) + B = repeat(B, 't n -> 1 (v t) n', v=self.repeat) + if self.bandlimit is not None: + freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N) + mask = torch.where(freqs < self.bandlimit * .5, 1, 0) + C = C * mask # Incorporate dt into A - dtA = w * dt.unsqueeze(-1) # (H N) + A = repeat(A, 't n -> (v t) n', v=self.repeat) + dtA = A * dt.unsqueeze(-1) # (H N) + + # Augment B with state + if state is not None: + s = state / dt.unsqueeze(-1) + if self.disc == 'bilinear': + s = s * (1. + dtA/2) + elif self.disc == 'zoh': + s = s * dtA * dtA.exp() / (dtA.exp() - 1.) + B = torch.cat([s, B], dim=-3) # (1+B H N) + + C = (B[:, None, :, :] * C).view(-1, self.H, self.N) if self.disc == 'zoh': # Power up - K = dtA.unsqueeze(-1) * torch.arange(L, device=w.device) # (H N L) - C = C * (torch.exp(dtA)-1.) / w - K = contract('chn, hnl -> chl', C, torch.exp(K)) - K = 2*K.real - elif self.disc == 'foh': - # Power up - K = dtA.unsqueeze(-1) * torch.arange(L, device=w.device) # (H N L) - K_exp = torch.exp(K) - C = C / (dt.unsqueeze(-1) * w ** 2) - exp_dA = torch.exp(dtA) - C_0 = - (exp_dA - 1. - dtA * exp_dA) * C # kernel for conv with u_k - C_1 = (exp_dA - 1. - dtA) * C # kernel for conv with u_{k+1} - K_0 = contract('chn, hnl -> chl', C_0, K_exp) - K_1 = contract('chn, hnl -> chl', C_1, K_exp) - K_0 = 2*K_0.real - K_1 = 2*K_1.real - K = K_1 - K[..., -1] = 0. - K[..., 1:] += K_0[..., :-1] + C = C * (torch.exp(dtA)-1.) / A + K = log_vandermonde(C, dtA, L) # (H L) elif self.disc == 'bilinear': + C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A dA = (1. + dtA/2) / (1. - dtA/2) - K = dA.unsqueeze(-1) ** torch.arange(L, device=w.device) # (H N L) - C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / w - K = contract('chn, hnl -> chl', C, K) - K = 2*K.real - else: + K = log_vandermonde(C, dA.log(), L) + elif self.disc == 'dss': # Implementation from DSS meant for case when real eigenvalues can be positive - P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] - w_gt_0 = w.real > 0 # [N] - if w_gt_0.any(): + P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] + A_gt_0 = A.real > 0 # [N] + if A_gt_0.any(): with torch.no_grad(): - P_max = dtA * (w_gt_0 * (L-1)) # [H N] - P = P - P_max.unsqueeze(-1) # [H N L] - S = P.exp() # [H N L] + P_max = dtA * (A_gt_0 * (L-1)) # [H N] + P = P - P_max.unsqueeze(-1) # [H N L] + S = P.exp() # [H N L] - dtA_neg = dtA * (1 - 2*w_gt_0) # [H N] - # S.sum(-1) == den / num - num = dtA_neg.exp() - 1 # [H N] - den = (dtA_neg * L).exp() - 1 # [H N] + dtA_neg = dtA * (1 - 2*A_gt_0) # [H N] + num = dtA_neg.exp() - 1 # [H N] + den = (dtA_neg * L).exp() - 1 # [H N] # Inline reciprocal function for DSS logic - x = den * w + x = den * A x_conj = _resolve_conj(x) r = x_conj / (x*x_conj + 1e-7) C = C * num * r # [C H N] K = contract('chn,hnl->chl', C, S).float() + else: assert False, f"{self.disc} not supported" + K = K.view(-1, self.channels, self.H, L) # (1+B C H L) + if state is not None: + K_state = K[:-1, :, :, :] # (B C H L) + else: + K_state = None + K = K[-1, :, :, :] # (C H L) + return K, K_state - if self.quadrature == 'trapezoid': - w = torch.ones(*K.shape).to(K) * dt[None, :, None] - w[..., 0] /= 2. - w[..., -1] /= 2 - K = K * w - elif self.quadrature == 'simpson': - w = torch.ones(*K.shape).to(K) * dt[None, :, None] / 3. - w[..., 1:-1:2] *= 4 - w[..., 2:-1:2] *= 2 - K = K * w - - return K, None - - def setup_step(self): + def _setup_step(self): + # These methods are organized like this to be compatible with the NPLR kernel interface dt = torch.exp(self.log_dt) # (H) + B = _r2c(self.B) # (H N) C = _r2c(self.C) # (C H N) - w = self._w() # (H N) + self.dC = C + A = self._A() # (H N) # Incorporate dt into A - dtA = w * dt.unsqueeze(-1) # (H N) + dtA = A * dt.unsqueeze(-1) # (H N) if self.disc == 'zoh': self.dA = torch.exp(dtA) # (H N) - self.dC = C * (torch.exp(dtA)-1.) / w # (C H N) + self.dB = B * (torch.exp(dtA)-1.) / A # (C H N) elif self.disc == 'bilinear': self.dA = (1. + dtA/2) / (1. - dtA/2) - self.dC = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / w - self.dB = self.dC.new_ones(self.H, self.N) # (H N) - - # if self.disc == 'zoh': - # # Power up - # K = dtA.unsqueeze(-1) * torch.arange(L, device=w.device) # (H N L) - # C = C * (torch.exp(dtA)-1.) / w - # K = contract('chn, hnl -> chl', C, torch.exp(K)) - # K = 2*K.real - # elif self.disc == 'bilinear': - # dA = (1. + dtA/2) / (1. - dtA/2) - # K = dA.unsqueeze(-1) ** torch.arange(L, device=w.device) # (H N L) - # C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / w - # K = contract('chn, hnl -> chl', C, K) - # K = 2*K.real + self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A + def default_state(self, *batch_shape): C = _r2c(self.C) @@ -861,22 +856,21 @@ def step(self, u, state): y = contract("c h n, b h n -> b c h", self.dC, next_state) return 2*y.real, next_state + def forward_state(self, u, state): + self._setup_step() + AL = self.dA ** u.size(-1) + u = u.flip(-1).to(self.dA).contiguous() # (B H L) + v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1)) + next_state = AL * state + v + return next_state - # @staticmethod - # def reciprocal(x, epsilon=1e-7, clamp=False): - # """ returns 1 / x, with bounded norm """ - # x_conj = x.conj() - # norm_sq = (x*x_conj).real.clamp(epsilon) if clamp else (x*x_conj + epsilon) - # return x_conj / norm_sq - - -class HippoSSKernel(nn.Module): - """Wrapper around SSKernel that generates A, B, C, dt according to HiPPO arguments. +class SSKernel(nn.Module): + """Wrapper around SSKernel parameterizations. The SSKernel is expected to support the interface forward() default_state() - setup_step() + _setup_step() step() """ @@ -887,42 +881,47 @@ def __init__( L=None, measure="legs", rank=1, - channels=1, # 1-dim to C-dim map; can think of C as having separate "heads" + channels=1, dt_min=0.001, dt_max=0.1, deterministic=False, - trainable=None, # Dictionary of options to train various HiPPO parameters - lr=None, # Hook to set LR of hippo parameters differently - lr_dt=None, - mode="nplr", # 'slow' for complex naive version, 'real' for real naive version - n_ssm=1, # Copies of the ODE parameters A and B. Must divide H - precision=1, # 1 (single) or 2 (double) for the kernel - resample=False, # If given inputs of different lengths, adjust the sampling rate. Note that L should always be provided in this case, as it assumes that L is the true underlying length of the continuous signal + lr=None, + mode="nplr", + n_ssm=None, verbose=False, - fast_gate=False, - diag_tilt=0.0, - rank_weight=1.0, measure_args={}, **kernel_args, ): + """State Space Kernel which computes the convolution kernel $\\bar{K}$ + + H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. + N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. + L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. + measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) + rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt" + channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead + dt_min, dt_max: min and max values for the step size dt (\Delta) + mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing + n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H + lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. + """ super().__init__() self.N = N self.H = H - assert not (resample and L is None), "Cannot have sampling rate adjustment and no base length passed in" - self.precision = precision - dtype = torch.double if self.precision == 2 else torch.float - cdtype = torch.cfloat if dtype == torch.float else torch.cdouble - self.rate = None if resample else 1.0 + dtype, cdtype = torch.float, torch.cfloat self.channels = channels - self.n_ssm = n_ssm + self.n_ssm = n_ssm if n_ssm is not None else H + self.mode = mode + self.verbose = verbose + self.kernel_args = kernel_args # Generate dt - log_dt = torch.rand(self.H, dtype=dtype) * ( - math.log(dt_max) - math.log(dt_min) - ) + math.log(dt_min) - - if fast_gate: - log_dt = torch.asinh(log_dt) + if deterministic: + log_dt = torch.exp(torch.linspace(math.log(dt_min), math.log(dt_max), H)) + else: + log_dt = torch.rand(self.H, dtype=dtype) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) # Compute the preprocessed representation if mode == "real": # For testing and ablation purposes @@ -932,39 +931,26 @@ def __init__( B = torch.as_tensor(B, dtype=dtype)[:, 0] # Generate C - C = torch.randn(channels, self.H, self.N, dtype=dtype) + if deterministic: + C = torch.zeros(channels, self.H, self.N, dtype=dtype) + C[..., :1] = 1.0 + else: + C = torch.randn(channels, self.H, self.N, dtype=dtype) self.kernel = SSKernelSlow( A, B, C, log_dt, L=L, - trainable=trainable, lr=lr, + lr=lr, ) else: - # Generate low rank correction p for the measure - if measure == "random": - w, P, B, C, _ = hippo.random_dplr(self.N, rank=rank, H=n_ssm, dtype=dtype, **measure_args) - elif measure == 'hippo': - w0, P0, B0, C0, _ = hippo.nplr('legs', self.N, rank, dtype=dtype) - w1, P1, B1, C1, _ = hippo.nplr('fourier', self.N, rank, dtype=dtype) - w = torch.stack([w0, w1], dim=0) - P = torch.stack([P0, P1], dim=1) - B = torch.stack([B0, B1], dim=0) - C = torch.stack([C0, C1], dim=0) - else: - w, P, B, C, _ = hippo.nplr(measure, self.N, rank, dtype=dtype) - w = w.unsqueeze(0) # (s N), s is num SSM copies - P = P.unsqueeze(1) # (r s N) - B = B.unsqueeze(0) # (s N) - C = C.unsqueeze(0) # (H N) - - # Handle extra options - w = w - diag_tilt - P = P * rank_weight + w, P, B, V = dplr.combination(measure, self.N, rank, self.n_ssm, **measure_args) # Broadcast C to have H channels if deterministic: - C = repeat(C, 't n -> c (v t) n', c=channels, v=self.H // C.size(0)) + C = torch.zeros(channels, self.H, self.N, dtype=cdtype) + C[:, :, :1] = 1. + C = contract('hmn, chn -> chm', V.conj().transpose(-1, -2), C) # V^* C else: - C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) + C = torch.randn(channels, self.H, self.N//2, dtype=cdtype) # Broadcast other parameters to have n_ssm copies assert self.n_ssm % B.size(-2) == 0 \ @@ -975,314 +961,73 @@ def __init__( B = repeat(B, 't n -> (v t) n', v=self.n_ssm // B.size(-2)).clone().contiguous() P = repeat(P, 'r t n -> r (v t) n', v=self.n_ssm // P.size(-2)).clone().contiguous() w = repeat(w, 't n -> (v t) n', v=self.n_ssm // w.size(-2)).clone().contiguous() + C = C.contiguous() if mode == "nplr": self.kernel = SSKernelNPLR( w, P, B, C, log_dt, L=L, - trainable=trainable, - lr=lr, lr_dt=lr_dt, + lr=lr, verbose=verbose, - fast_gate=fast_gate, **kernel_args, ) elif mode == "diag": C = C * repeat(B, 't n -> (v t) n', v=H//self.n_ssm) self.kernel = SSKernelDiag( - w, C, log_dt, L=L, - trainable=trainable, + w, B, C, log_dt, L=L, lr=lr, **kernel_args, ) - elif mode == "slow": # Testing only + elif mode == "slow": # Mainly for testing A = torch.diag_embed(_conj(w)) \ - contract("... r p, ... r q -> ... p q", _conj(P), _conj(P).conj()) self.kernel = SSKernelSlow( A, _conj(B), _conj(C), log_dt, L=L, - trainable=trainable, lr=lr, + lr=lr, ) else: raise NotImplementedError(f"{mode=} is not valid") - self.B = B - self.C = C - self.w = w - self.log_dt = log_dt - def forward(self, state=None, L=None): - k, k_state = self.kernel(state=state, rate=self.rate, L=L) - k_state = None if k_state is None else k_state.float() - return k.float(), k_state + def forward(self, state=None, L=None, rate=None): + return self.kernel(state=state, L=L, rate=rate) @torch.no_grad() def forward_state(self, u, state): """ Forward the state through a sequence, i.e. computes the state after passing chunk through SSM - state: (..., H, N) - u: (..., H, L) + state: (B, H, N) + u: (B, H, L) - Returns: (..., H, N) + Returns: (B, H, N) """ - self.kernel._setup_state() - dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) + if hasattr(self.kernel, "forward_state"): + return self.kernel.forward_state(u, state) + + dA, dB = self.kernel._setup_state() # Construct dA, dB matrices + # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) conj = state.size(-1) != dA.size(-1) if conj: state = _conj(state) - v = contract('h n, ... h l -> ... h n l', dB, u.flip(-1)) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) + v = contract('h n, b h l -> b h n l', dB, u.flip(-1)) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) AL, v = power(u.size(-1), dA, v) - next_state = contract("... m n, ... n -> ... m", AL, state) + next_state = contract("h m n, b h n -> b h m", AL, state) next_state = next_state + v if conj: next_state = next_state[..., : next_state.size(-1) // 2] return next_state - def setup_step(self): - self.kernel.setup_step() + def _setup_step(self, **kwargs): + # This method is intended to be private so that setting up an S4 module with + # ``` + # if hasattr(module, 'setup_step'): module.setup_step() + # ``` + # will not trigger this method multiple times + self.kernel._setup_step(**kwargs) def step(self, u, state, **kwargs): - u, state = self.kernel.step(u, state, **kwargs) - return u.float(), state + y, state = self.kernel.step(u, state, **kwargs) + return y, state def default_state(self, *args, **kwargs): return self.kernel.default_state(*args, **kwargs) - - - -""" Tests below """ - - -def generate_kernel(H, N, L, measure="legs", rank=1): - A, B = hippo.transition(measure, N) - A = torch.as_tensor(A, dtype=torch.float) - B = torch.as_tensor(B, dtype=torch.float)[:, 0] - # _C = torch.ones(1, H, N) - _C = torch.randn(1, H, N) - log_dt = torch.log((1 + 10 * torch.arange(H) / H) * 1 / L) - - # kernel slow real - kernel_real = SSKernelSlow(A, B, _C, log_dt, L=L) - kernel_real.to(device) - kernel_real.setup_step() - - # kernel slow complex - w, p, B, C, V = hippo.nplr(measure, N, rank=rank) - C = contract( - "ij, ... j -> ... i", V.conj().transpose(-1, -2), _C.to(V) - ) # V^* B - A = torch.diag_embed(_conj(w)) - contract( - "... r p, ... r q -> ... p q", _conj(p), _conj(p).conj() - ) - kernel_slow = SSKernelSlow(A, _conj(B), _conj(C), log_dt, L=L) - kernel_slow.to(device) - kernel_slow.setup_step() - - print("kernel real vs kernel complex", kernel_real()[0] - kernel_slow()[0]) - kernel = SSKernelNPLR(w.unsqueeze(0), p.unsqueeze(1), B.unsqueeze(0), C.unsqueeze(0).unsqueeze(0), log_dt, L=L, verbose=True) - kernel.to(device) # TODO need to add this line for custom CUDA kernel - kernel.setup_step() - kernel._check() - - print("kernel slow vs kernel fast", kernel_slow(L=L)[0] - kernel(L=L)[0]) - - # print(f"dA \nslow:\n{kernel_slow.dA}\nfast:\n{kernel.dA}") - # print("dC real slow fast:", kernel_real.dC, kernel_slow.dC, kernel.dC) - - return kernel_real.to(device), kernel_slow.to(device), kernel.to(device) - - -def benchmark_kernel(): - N = 64 - L = 4096 - H = 256 - - kernel_real, kernel_slow, kernel = generate_kernel(H, N, L) - - utils.compare_outputs(kernel_slow(), kernel(), full=False, relative=True) - - utils.benchmark_forward(100, kernel_slow, desc="kernel fft manual") - utils.benchmark_forward(100, kernel, desc="kernel fft rank") - utils.benchmark_backward(100, kernel_slow, desc="kernel fft manual") - utils.benchmark_backward(100, kernel, desc="kernel fft rank") - - utils.benchmark_memory(kernel_slow, desc="kernel fft manual") - utils.benchmark_memory(kernel, desc="kernel fft rank") - - -def test_step(diagonal=False, **kwargs): - B = 1 - L = 8 - N = 4 - H = 3 - - kernel_real, kernel_slow, kernel = generate_kernel(H, N, L, **kwargs) - - print("=====TESTING SLOW STEP=====") - kernel_slow.setup_step() - state = kernel_slow.default_state(B) - u = torch.ones(B, H, L).to(device) - ys = [] - for u_ in torch.unbind(u, dim=-1): - y_, state = kernel_slow.step(u_, state=state) - ys.append(y_) - print("state", state, state.shape) - y = torch.stack(ys, dim=-1) - print("y", y, y.shape) - - print("=======TESTING STEP=======") - kernel.setup_step(mode='dense') - state = kernel.default_state(B)# torch.zeros(B, H, N).to(device).to(torch.cfloat) - u = torch.ones(B, H, L).to(device) - ys = [] - for u_ in torch.unbind(u, dim=-1): - y_, state = kernel.step(u_, state=state) - ys.append(y_) - print("state", state, state.shape) - y = torch.stack(ys, dim=-1) - print("y", y, y.shape) - - print("=====TESTING LINEAR STEP=====") - kernel.setup_step(mode='linear') - state = kernel.default_state(B) - u = torch.ones(B, H, L).to(device) - ys = [] - for u_ in torch.unbind(u, dim=-1): - y_, state = kernel.step(u_, state=state) - ys.append(y_) - print("state", state, state.shape) - y = torch.stack(ys, dim=-1) - print("y", y, y.shape) - - if diagonal: - print("=====TESTING DIAGONAL STEP=====") - kernel.setup_step(mode='diagonal') - state = kernel.default_state(B) - u = torch.ones(B, H, L).to(device) - ys = [] - for u_ in torch.unbind(u, dim=-1): - y_, state = kernel.step(u_, state=state) - ys.append(y_) - print("state", state, state.shape) - y = torch.stack(ys, dim=-1) - print("y", y, y.shape) - - -@torch.inference_mode() -def benchmark_step(): - B = 1024 - L = 16 - N = 64 - H = 1024 - - _, _, kernel = generate_kernel(H, N, L) - kernel.setup_step() - - print("Benchmarking Step") - state = torch.zeros(B, H, N).to(device) - u = torch.ones(B, H).to(device) - utils.benchmark_forward(16, kernel.step, u, state, linear=False, desc="dense step") - - print("Benchmarking Linear Step") - state = torch.zeros(B, H, N).to(device) # .to(torch.cfloat) - u = torch.ones(B, H).to(device) - utils.benchmark_forward(16, kernel.step, u, state, linear=True, desc="linear step") - - state = torch.zeros(B, H, N // 2).to(device) # .to(torch.cfloat) - u = torch.ones(B, H).to(device) - utils.benchmark_forward( - 16, kernel.step, u, state, linear=True, desc="linear step conj" - ) - - -def test_double(): - # torch.set_printoptions(sci_mode=False, linewidth=160) - L = 8 - N = 4 - H = 3 - - _, kernel_slow, kernel = generate_kernel(H, N, L, "legs", 1) - - print("Testing Length Doubling") - print("=======================") - print("Original:") - k = kernel.forward()[0] - # print(k, k.shape) - kernel._check() - - print("Doubled:") - kernel.double_length() - k_ = kernel.forward()[0] - # print(k, k_.shape) - print("Doubling error:", torch.sum((k_[..., :k.size(-1)] - k)**2)) - - -def test_state(): - B = 1 - N = 4 - L = 4 - H = 3 - kernel_real, kernel_slow, kernel = generate_kernel(H, N, L) - - state = torch.ones(B, H, N // 2, device=device, dtype=torch.cfloat) - - k, k_state = kernel_slow.forward(state=state) - print("k slow", k) - print("k_state slow", k_state) - - k, k_state = kernel.forward(state=state) - print("k", k) - print("k_state", k_state) - - -def test_diag(): - bs = 1 - L = 8 - N = 4 - H = 3 - - w, P, B, C, _ = hippo.nplr('legs', N, 1) - w = w.unsqueeze(0) # (s N), s is num SSM copies - P = P.unsqueeze(1) # (r s N) - B = B.unsqueeze(0) # (s N) - # C = repeat(C, 'n -> 1 h n', h=H) - C = torch.randn(1, H, N//2, dtype=torch.cfloat) - # C = C.unsqueeze(0).unsqueeze(0) # (1 H N) - log_dt = torch.log((1 + 10 * torch.arange(H) / H) * 1 / L) - - kernel = SSKernelDiag(w, C, log_dt, L=L) - kernel.to(device) - kernel.setup_step() - - K, _ = kernel(L=8) - print(torch.cumsum(K, axis=-1)) - - - print("=====TESTING DIAG STEP=====") - kernel.setup_step() - state = kernel.default_state(bs) - u = torch.ones(bs, H, L).to(device) - ys = [] - for u_ in torch.unbind(u, dim=-1): - y_, state = kernel.step(u_, state=state) - ys.append(y_) - print("state", state, state.shape) - y = torch.stack(ys, dim=-1) - print("y", y, y.shape) - - -if __name__ == "__main__": - from benchmark import utils - - device = "cuda" # 'cpu' - device = torch.device(device) - - torch.set_printoptions(sci_mode=False, linewidth=160) - - has_cauchy_extension = False # turn off CUDA kernel for ease of testing, don't have to move between devices - - # generate_kernel(3, 4, 8, measure='legt', rank=2) - # benchmark_kernel() - # test_double() - # test_step(diagonal=True, measure='legt', rank=2) # TODO this needs to be fixed - # benchmark_step() - # test_state() - test_diag() diff --git a/src/models/sequence/ss/linear_system_recurrence.py b/src/models/sequence/ss/linear_system_recurrence.py index 2236690..50d4e84 100644 --- a/src/models/sequence/ss/linear_system_recurrence.py +++ b/src/models/sequence/ss/linear_system_recurrence.py @@ -387,104 +387,3 @@ def backward(ctx, dy): linearsystemstepsize = LinearSystemStepsizeFunction.apply - - -def _abs_err(x, y): - x_ = x.detach().cpu().numpy() - y_ = y.detach().cpu().numpy() - return (y_ - x_) / x_ - - -def test_linear_system(L, batch, dim, N, M, stepsize=False): - from models.hippo import transition # for testing - - # Define A, B, C, D - A = torch.eye(N) - B = torch.ones(N) - C = torch.ones(dim, M, N, requires_grad=True).to(device) - D = torch.ones(dim, M, requires_grad=True).to(device) - C.retain_grad() - D.retain_grad() - - # Create u and dt - u = torch.arange(L, dtype=torch.float, requires_grad=True).to(device) - u = u.unsqueeze(-1).unsqueeze(-1).repeat((1, batch, dim)) # (L, B, D) - u.retain_grad() - - dt = torch.ones(L, batch, dim) * 0.001 # for LegT - # dt = torch.ones_like(u, requires_grad=True).to(device) * 0.001 # for LegT - # dt = torch.ones_like(u, requires_grad=True).to(device) * 0.1 # for LagT - # dt.retain_grad() - - # Construct model - transition = transition.ManualAdaptiveTransition(N, A, B).to(device) - # transition = transition.ConstantBilinearTransition(N, A, B, dt[0]).to(device) - # transition = transition.LegTAdaptiveTransition(N).to(device) - # transition = transition.LagTCumsumAdaptiveTransition(N).to(device) - dt = dt.to(device) - if stepsize: - hippo = LinearSystemStepsize(N, transition, C, D) # .to(device) - dt.requires_grad_(True) - dt.retain_grad() - else: - hippo = LinearSystem(N, transition, C, D) # .to(device) - - # Autograd - if stepsize: - y, x = hippo.forward(dt, u) - else: - y, x = hippo.forward(dt, u) - x.retain_grad() - y.retain_grad() - z = y.sum() - z.backward(retain_graph=True) - # print(f"{y=}") - - # Manual adjoint - if stepsize: - du, ddt, dC, dD = hippo.adjoint(y.grad, x, dt, u) - print("du", u.grad, "\nerror", _abs_err(u.grad, du)) - print("ddt", dt.grad, "\nerror", _abs_err(dt.grad, ddt)) - print("dC", C.grad, "\nerror", _abs_err(C.grad, dC)) - print("dD", D.grad, "\nerror", _abs_err(D.grad, dD)) - - print("Function vs Module abs error") - u.grad.zero_() - dt.grad.zero_() - C.grad.zero_() - D.grad.zero_() - y_ = linearsystemstepsize(None, dt, u, C, D, transition) - print(f"y", y_ - y) - y_.sum().backward() - print("du", u.grad - du) - print("ddt", dt.grad - ddt) - print("dC", C.grad - dC) - print("dD", D.grad - dD) - else: - du = hippo.adjoint_input(y.grad, dt) - dC, dD = hippo.adjoint_projection(y.grad, dt, u) - print("du", u.grad, "\nerror", _abs_err(u.grad, du)) - print("dC", C.grad, "\nerror", _abs_err(C.grad, dC)) - print("dD", D.grad, "\nerror", _abs_err(D.grad, dD)) - - print("Function vs Module abs error") - u.grad.zero_() - C.grad.zero_() - D.grad.zero_() - y_ = linearsystem(None, dt, u, C, D, transition) - print(f"y", y_ - y) - y_.sum().backward() - print("du", u.grad - du) - print("dC", C.grad - dC) - print("dD", D.grad - dD) - - -if __name__ == "__main__": - L = 8 - B = 1 - D = 2 - N = 8 - M = 1 - - test_linear_system(L, B, D, N, M, False) - # test_linear_system(L, B, D, N, M, True) diff --git a/src/models/sequence/ss/lssl.py b/src/models/sequence/ss/lssl.py index 30b6384..6659fd7 100644 --- a/src/models/sequence/ss/lssl.py +++ b/src/models/sequence/ss/lssl.py @@ -1,4 +1,4 @@ -""" Implementation of LSSL module. Succeeded by S3. """ +"""Implementation of LSSL module. Succeeded by S4.""" import math import torch @@ -11,7 +11,7 @@ from src.models.functional.krylov import krylov from src.models.hippo import transition, hippo from src.models.functional.toeplitz import causal_convolution -from src.models.sequence.base import SequenceModule +from src.models.sequence.base import SequenceModule, TransposedModule import src.models.nn.utils as U def linear_system_from_krylov(u, C, D, k): @@ -303,7 +303,6 @@ def forward(self, u, *args, state=None, **kwargs): if self.ff: y = self.output_linear(y) # (L, B, H) - # y = self.drop(y) # moved to residual y = y.transpose(0, 1) # Back to (B, L, H) as expected return y, next_state @@ -338,21 +337,4 @@ def d_output(self): def state_to_tensor(self): return lambda state: state -LSSL = U.Transpose(Platypus) - -if __name__ == '__main__': - device = torch.device('cuda') - - N = 8 - B = 1 - d = 5 - L = 10 - u = torch.randn(L, B, d).to(device) - measure = 'identity' - # measure = 'legt' - dt_min = 1e-3 - dt_max = 1e0 - - platypus = Platypus(d, N, measure=measure, init='constant').to(device) - y, _ = platypus(u) - print(y, y.shape) +LSSL = TransposedModule(Platypus) diff --git a/src/models/sequence/ss/lssl_recurrent.py b/src/models/sequence/ss/lssl_recurrent.py index d015487..d974a5d 100644 --- a/src/models/sequence/ss/lssl_recurrent.py +++ b/src/models/sequence/ss/lssl_recurrent.py @@ -103,20 +103,3 @@ def d_output(self): @property def state_to_tensor(self): return lambda state: state - -if __name__ == '__main__': - device = torch.device('cuda') - - N = 8 - B = 1 - d = 5 - L = 10 - u = torch.randn(L, B, d).to(device) - measure = 'identity' - # measure = 'legt' - dt_min = 1e-3 - dt_max = 1e0 - - hippo = RecurrentLSSL(d, N, measure=measure, dt_min=dt_min, dt_max=dt_max, init='constant').to(device) - y, _ = hippo(u) - print(y, y.shape) diff --git a/src/models/sequence/ss/s4.py b/src/models/sequence/ss/s4.py index 62ac9c5..30ef38e 100644 --- a/src/models/sequence/ss/s4.py +++ b/src/models/sequence/ss/s4.py @@ -1,23 +1,11 @@ -if __name__ == "__main__": - import sys - import pathlib - - p = pathlib.Path().absolute() - print("Adding path: ", p) - sys.path.append(str(p)) - -import math import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.utils as U +from functools import partial from einops import rearrange, repeat -from omegaconf import DictConfig import opt_einsum as oe -import numpy as np -import itertools -import time -import math + optimized = True if optimized: @@ -25,45 +13,57 @@ else: contract = torch.einsum -from src.models.sequence.ss.kernel import HippoSSKernel, _conj -from src.models.nn import LinearActivation, Activation, Normalization - +from src.models.sequence.ss.kernel import SSKernel +from src.models.nn import LinearActivation, Activation, DropoutNd class S4(nn.Module): - requires_length = True - def __init__( - self, - d_model, - d_state=64, - l_max=1, # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer than sequence. However, this can be marginally slower if the true length is not a power of 2 - channels=1, # maps 1-dim to C-dim - bidirectional=False, - # Arguments for FF - activation="gelu", # activation in between SS and FF - ln=False, # Extra normalization - postact=None, # activation after FF - initializer=None, # initializer on FF - weight_norm=False, # weight normalization on FF - hyper_act=None, # Use a "hypernetwork" multiplication - dropout=0.0, - transposed=True, # axis ordering (B, L, D) or (B, D, L) - verbose=False, - shift=False, - linear=False, - liquid=0, - allcombs=True, - # SSM Kernel arguments - **kernel_args, - ): + self, + d_model, + d_state=64, + l_max=None, + channels=1, + bidirectional=False, + # Arguments for position-wise feedforward components + activation='gelu', + postact='glu', + initializer=None, + weight_norm=False, + hyper_act=None, + dropout=0.0, tie_dropout=False, + bottleneck=None, + gate=None, + transposed=True, + verbose=False, + shift=False, + linear=False, + # SSM Kernel arguments + **kernel_args, + ): """ d_state: the dimension of the state, also denoted by N - l_max: the maximum sequence length, also denoted by L - if this is not known at model creation, set l_max=1 - channels: can be interpreted as a number of "heads" - bidirectional: bidirectional - dropout: standard dropout argument - transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) [B=batch size, L=sequence length, H=hidden dimension] + l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel + channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models + bidirectional: if True, convolution kernel will be two-sided + + Position-wise feedforward components: + -------------------- + activation: activation in between SS and FF + postact: activation after FF + initializer: initializer on FF + weight_norm: weight normalization on FF + hyper_act: use a "hypernetwork" multiplication (experimental) + dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d + + Other arguments: + -------------------- + transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension] + gate: add gated activation (GSS) + bottleneck: reduce SSM dimension (GSS) + shift: experimental option, shouldn't affect results + linear: Remove pointwise components so that the entire module is a linear SSM + + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" Other options are all experimental and should not need to be configured """ @@ -71,30 +71,53 @@ def __init__( super().__init__() if verbose: import src.utils.train - log = src.utils.train.get_logger(__name__) log.info(f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})") - log = src.utils.train.get_logger(__name__) - if liquid == 1: - raise ValueError("Illegal argument (liquid=1). Valid options are 0 (vanilla S4) and n>1 (liquid S4 with up to n-term)") - if liquid >= 1: - log.info(f"Constructing liquid-S4 with degree={liquid}, allcombs={allcombs}") - else: - log.info( - f"Using plain S4 (to enable liquid-S4 run with model.layer.liquid=1 argument)" - ) - - self.h = d_model - self.n = d_state + self.d_model = d_model + self.H = d_model + self.N = d_state + self.L = l_max self.bidirectional = bidirectional - self.ln = ln self.channels = channels self.transposed = transposed self.shift = shift self.linear = linear - self.liquid = liquid - self.allcombs = allcombs + + self.gate = gate + self.bottleneck = bottleneck + + if bottleneck is not None: + self.H = self.H // bottleneck + self.input_linear = LinearActivation( + self.d_model, + self.H, + transposed=self.transposed, + initializer=initializer, + activation=activation, + activate=True, + weight_norm=weight_norm, + ) + + if gate is not None: + self.input_gate = LinearActivation( + self.d_model, + self.d_model * gate, + transposed=self.transposed, + initializer=initializer, + activation=activation, + activate=True, + weight_norm=weight_norm, + ) + self.output_gate = LinearActivation( + self.d_model * gate, + self.d_model, + transposed=self.transposed, + initializer=initializer, + activation=None, + activate=False, + weight_norm=weight_norm, + ) # optional multiplicative modulation GLU-style # https://arxiv.org/abs/2002.05202 @@ -103,171 +126,123 @@ def __init__( channels *= 2 self.hyper_activation = Activation(hyper_act) - self.D = nn.Parameter(torch.randn(channels, self.h)) + self.D = nn.Parameter(torch.randn(channels, self.H)) if self.bidirectional: channels *= 2 + # SSM Kernel - self.kernel = HippoSSKernel( - self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args - ) + self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=channels, verbose=verbose, **kernel_args) # Pointwise if not self.linear: self.activation = Activation(activation) - dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout + # dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout # Broken in torch==1.11 + dropout_fn = DropoutNd if tie_dropout else nn.Dropout self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() - if self.ln: - self.norm = Normalization(self.h * self.channels, transposed=transposed) - else: - self.norm = nn.Identity() - # position-wise output transform to mix features if not self.linear: self.output_linear = LinearActivation( - self.h * self.channels, - self.h, + self.H*self.channels, + self.d_model*(1 if self.gate is None else self.gate), transposed=self.transposed, initializer=initializer, activation=postact, activate=True, weight_norm=weight_norm, ) - self._allcombs_index_cache = None - def forward( - self, u, state=None, **kwargs - ): # absorbs return_output and transformer src mask + + + def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask """ u: (B H L) if self.transposed else (B L H) state: (H N) never needed unless you know what you're doing Returns: same shape as u """ - if not self.transposed: - u = u.transpose(-1, -2) + if not self.transposed: u = u.transpose(-1, -2) L = u.size(-1) + # Mask out padding tokens + # TODO handle option for mask - instead of lengths, which assumes suffix padding + if isinstance(lengths, int): + if lengths != L: + lengths = torch.tensor(lengths, dtype=torch.long, device=u.device) + else: + lengths = None + if lengths is not None: + assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)] + mask = torch.where(torch.arange(L, device=lengths.device) < lengths[:, None, None], 1., 0.) + u = u * mask + + if self.gate is not None: + v = self.input_gate(u) + if self.bottleneck is not None: + u = self.input_linear(u) + # Compute SS Kernel - k, k_state = self.kernel(L=L, state=state) # (C H L) (B C H L) + L_kernel = L if self.L is None else min(L, round(self.L / rate)) + k, k_state = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L) # Convolution if self.bidirectional: - k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2) - k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0)) + k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) + k = F.pad(k0, (0, L)) \ + + F.pad(k1.flip(-1), (L, 0)) \ + if self.shift: # Try flip and pad to correct for potential off-by-one - k_f = torch.fft.rfft(F.pad(k.flip(-1), (L, 0)), n=2 * L) # (C H L) - u_f = torch.fft.rfft(F.pad(u.flip(-1), (L, 0)), n=2 * L) # (B H L) - y_f = contract( - "bhl,chl->bchl", u_f, k_f - ) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) - y = torch.fft.irfft(y_f, n=2 * L)[..., L:].flip(-1) # (B C H L) + k_f = torch.fft.rfft(F.pad(k.flip(-1), (L, 0)), n=2*L) # (C H L) + u_f = torch.fft.rfft(F.pad(u.flip(-1), (L, 0)), n=2*L) # (B H L) + y_f = contract('bhl,chl->bchl', u_f, k_f) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) + y = torch.fft.irfft(y_f, n=L_kernel+L)[..., L:].flip(-1) # (B C H L) else: - k_f = torch.fft.rfft(k, n=2 * L) # (C H L) - u_f = torch.fft.rfft(u, n=2 * L) # (B H L) - y_f = contract( - "bhl,chl->bchl", u_f, k_f - ) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) - y = torch.fft.irfft(y_f, n=2 * L)[..., :L] # (B C H L) + k_f = torch.fft.rfft(k, n=L_kernel+L) # (C H L) + u_f = torch.fft.rfft(u, n=L_kernel+L) # (B H L) + y_f = contract('bhl,chl->bchl', u_f, k_f) + y = torch.fft.irfft(y_f, n=L_kernel+L)[..., :L] # (B C H L) - # Compute D term in state space equation - essentially a skip connection - y = y + contract( - "bhl,ch->bchl", u, self.D - ) # u.unsqueeze(-3) * self.D.unsqueeze(-1) - - seq_len = int(y.size(-1)) - if self._allcombs_index_cache is None: - self._allcombs_index_cache = [] - for p in range(2,self.liquid+1): - selected_count = 1 - for n in range(2,seq_len): - count = math.comb(n,p) - if count >= seq_len: - selected_count = n - break - indices = range(seq_len-selected_count,seq_len) - indices = list(itertools.combinations(indices, p)) - # print(f"p={p}, seq_len={seq_len}, selected_count={selected_count}",) - # print(f"{len(indices)=}") - if len(indices) != seq_len: - # select exactly amount to match sequence length dimension - indices = indices[-seq_len:] - indices = torch.LongTensor(indices) - self._allcombs_index_cache.append((p,indices)) - - dt = torch.exp(self.kernel.log_dt.to(u.device)) - B = _conj(self.kernel.B).to(u.device) - dC = _conj(self.kernel.C).to(u.device) - w = _conj(self.kernel.w).to(u.device) - dB = torch.diag_embed(1.0 / (1.0 - 0.5 * dt[:, None] * w)) # (256,64,64) - - dB = dt[:, None] * contract("dab,db->da", dB, B) - us = u - for i in range(self.liquid-1): - # print(f"[Liquid={self.liquid}] Generating degree {i+1} input polynomial") - if self.allcombs: - p,indices = self._allcombs_index_cache[i] - us = u[..., indices[:, 0]] - for j in range(1,p): - us = us*u[..., indices[:, j]] - if us.size(-1) != u.size(-1): - us = F.pad(us, (0, u.size(-1) - us.size(-1))) - else: - us_shift = torch.nn.functional.pad(us[..., :-1], (1, 0), "constant", 0) - us = us * us_shift - dB1 = dB.unsqueeze(2) - dB2 = dB.unsqueeze(1) - dB = (dB1 * dB2).sum(2) - dCB = contract("abc,bc->ab", dC, dB).unsqueeze(2) - if self.bidirectional: - fwd, bwd = dCB.unbind(0) - fwd, bwd = fwd.unsqueeze(0), bwd.unsqueeze(0) - y = ( - y - + (us * fwd).unsqueeze(1).float() - + (us.flip(2) * bwd).unsqueeze(1).float() - ) - else: - y = y + (us * dCB).unsqueeze(1).float() + + # Compute D term in state space equation - essentially a skip connection + y = y + contract('bhl,ch->bchl', u, self.D) # Compute state update if state is not None: - assert ( - not self.bidirectional - ), "Bidirectional not supported with state forwarding" - y = y + k_state + assert not self.bidirectional, "Bidirectional not supported with state forwarding" + y = y + k_state # next_state = self.kernel.forward_state(u, state) else: next_state = None # Optional hyper-network multiplication if self.hyper: - y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2) + y, yh = rearrange(y, 'b (s c) h l -> s b c h l', s=2) y = self.hyper_activation(yh) * y # Reshape to flatten channels - y = rearrange(y, "... c h l -> ... (c h) l") + y = rearrange(y, '... c h l -> ... (c h) l') if not self.linear: y = self.dropout(self.activation(y)) - if not self.transposed: - y = y.transpose(-1, -2) + if not self.transposed: y = y.transpose(-1, -2) if not self.linear: - y = self.norm(y) y = self.output_linear(y) + if self.gate is not None: + y = self.output_gate(y * v) + return y, next_state - def setup_step(self): - self.kernel.setup_step() + def setup_step(self, **kwargs): + self.kernel._setup_step(**kwargs) def step(self, u, state): - """Step one time step as a recurrent model. Intended to be used during validation. + """ Step one time step as a recurrent model. Intended to be used during validation. u: (B H) state: (B H N) @@ -275,9 +250,9 @@ def step(self, u, state): """ assert not self.training - y, next_state = self.kernel.step(u, state) # (B C H) + y, next_state = self.kernel.step(u, state) # (B C H) y = y + u.unsqueeze(-2) * self.D - y = rearrange(y, "... c h -> ... (c h)") + y = rearrange(y, 'b c h -> b (c h)') y = self.activation(y) if self.transposed: y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) @@ -286,83 +261,18 @@ def step(self, u, state): return y, next_state def default_state(self, *batch_shape, device=None): + # kernel is not a SequenceModule so it doesn't need to adhere to same interface + # the kernel will know the device of its own parameters return self.kernel.default_state(*batch_shape) @property def d_state(self): - return self.h * self.n + return self.H * self.N @property def d_output(self): - return self.h + return self.d_model @property def state_to_tensor(self): - return lambda state: rearrange("... h n -> ... (h n)", state) - - -def test_state(random_init=False, **kwargs): - # B = 1 - # H = 64 - # N = 64 - # L = 1024 - B = 2 - H = 3 - N = 4 - L = 8 - s4 = S4(H, d_state=N, l_max=L, **kwargs) - s4.to(device) - s4.eval() - for module in s4.modules(): - if hasattr(module, "setup_step"): - module.setup_step() - - u = torch.ones(B, H, L).to(device) - initial_state = s4.default_state(B) - if random_init: - if initial_state.size(-1) == N: - initial_state = initial_state[..., : N // 2] - initial_state = torch.randn_like(initial_state) - initial_state = torch.cat([initial_state, initial_state.conj()], dim=-1) - - state = initial_state.clone() - y, final_state = s4(u, state=state) - print("output:\n", y, y.shape) - print("final state:\n", final_state, final_state.shape) - - # Use Stepping - state = initial_state.clone() - ys = [] - for u_ in torch.unbind(u, dim=-1): - y_, state = s4.step(u_, state=state) - ys.append(y_) - ys = torch.stack(ys, dim=-1) - print("step outputs:\n", ys) - print("step final state:\n", state) - - # Use Chunking - - chunks = 4 - state = initial_state.clone() - ys = [] - for u_ in u.chunk(chunks, dim=-1): - y_, state = s4(u_, state=state) - ys.append(y_) - ys = torch.cat(ys, dim=-1) - print("chunk outputs:\n", ys) - print("chunk final state:\n", state) - print("chunk output error:") - utils.compare_outputs(y, ys) - print("chunk final state error:") - utils.compare_outputs(final_state, state) - - -if __name__ == "__main__": - from benchmark import utils - - torch.manual_seed(42) - - device = "cuda" # 'cpu' - device = torch.device(device) - - test_state(random_init=True, mode="nplr", measure="legt", rank=2) \ No newline at end of file + return lambda state: rearrange('... h n -> ... (h n)', state) diff --git a/src/models/sequence/unet.py b/src/models/sequence/unet.py index 5fd69f2..a0e92ee 100644 --- a/src/models/sequence/unet.py +++ b/src/models/sequence/unet.py @@ -17,7 +17,6 @@ from src.models.sequence.block import SequenceResidualBlock - class SequenceUNet(SequenceModule): """ layer is a Namespace that specifies '_name_', referring to a constructor, and a list of arguments to that layer constructor. This layer must subscribe to the interface (i) takes a hidden dimension H and sequence length L (ii) forward pass transforms input sequence of shape (B, H, L) to output (B, H, L) @@ -35,6 +34,7 @@ def __init__( initializer=None, l_max=-1, transposed=True, + act_pool=None, ): super().__init__() self.d_model = d_model @@ -67,6 +67,7 @@ def _residual(d, i, layer): i, # temporary placeholder for i_layer prenorm=prenorm, dropout=dropres, + transposed=self.transposed, layer=layer, residual=residual if residual is not None else 'R', norm=norm, @@ -81,7 +82,7 @@ def _residual(d, i, layer): if ff > 0: d_layers.append(_residual(H, i+1, ff_cfg)) # Add sequence downsampling and feature expanding - d_layers.append(DownPool(H, H*expand, pool=p, transposed=self.transposed)) # TODO take expansion argument instead + d_layers.append(DownPool(H, H*expand, stride=p, transposed=self.transposed, activation=act_pool)) L //= p layer_cfg['l_max'] = L H *= expand @@ -100,7 +101,7 @@ def _residual(d, i, layer): H //= expand L *= p layer_cfg['l_max'] = L - u_layers.append(UpPool(H*expand, H, pool=p, transposed=self.transposed)) # TODO + u_layers.append(UpPool(H*expand, H, stride=p, transposed=self.transposed, activation=act_pool)) for i in range(n_layers): u_layers.append(_residual(H, i+1, layer_cfg)) @@ -111,14 +112,11 @@ def _residual(d, i, layer): self.norm = nn.LayerNorm(H) - # @property - # def transposed(self): - # return len(self.d_layers) > 0 and self.d_layers[0].transposed @property def d_output(self): return self.d_model - def forward(self, x, state=None): + def forward(self, x, state=None, **kwargs): """ input: (batch, length, d_input) output: (batch, length, d_output) @@ -199,52 +197,3 @@ def cache_all(self): for layer in modules: if hasattr(layer, 'cache_all'): layer.cache_all() -def prepare_generation(model): - model.eval() - if hasattr(model, 'cache_all'): model.cache_all() - -@torch.inference_mode() -def generate_recurrent(model, batch_size=None, x=None): - from src.tasks.mixture import mixture_sample -# TODO incorporate normalization function for dataset -# TODO handle or document non-mixture case - """ generate remaining L-L' samples given x: (B, L', C) a context for the model """ - - if x is None: - assert batch_size is not None - x = torch.zeros(batch_size, model.d_model, device=device) - state = model.default_state(batch_size, device=device) - else: raise NotImplementedError("Conditional generation not implemented yet") - - xs = [] - for i in range(model.L): - print("pixel", i) - x, state = model.step(x, state) - x = mixture_sample(x) - # TODO postprocess: clamp, divide into buckets, renormalize - x = x.unsqueeze(-1) - xs.append(x) - sample = torch.stack(xs, dim=1) - print("recurrent sample shape", sample.shape) - -@torch.no_grad() -def generate_global(model, batch_size=None, x=None, length=None): - from tasks.mixture import mixture_sample - """ generate remaining L-L' samples given x: (B, L', C) a context for the model """ - - if x is None: - assert batch_size is not None - x = torch.zeros(batch_size, model.L, model.d_input, device=device) - else: raise NotImplementedError("Conditional generation not implemented yet") - - if length is None: length = model.L - for i in range(length): - print("pixel", i) - y = model(x) - y = torch.cat([y, y.new_zeros(batch_size, 1, model.d_output)], dim=1) # TODO handle sequence shape properly - z = mixture_sample(y[:, i, :]) - # TODO postprocess: clamp, divide into buckets, renormalize - z = z.unsqueeze(-1) - x[:, i, :] = z - print("global sample shape", x.shape) - diff --git a/src/tasks/decoders.py b/src/tasks/decoders.py index 37e8046..175d603 100644 --- a/src/tasks/decoders.py +++ b/src/tasks/decoders.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange +from einops import rearrange, reduce import src.models.nn.utils as U import src.utils as utils @@ -10,13 +10,15 @@ log = src.utils.train.get_logger(__name__) + class Decoder(nn.Module): - """ This class doesn't do much but just signals the interface that Decoders are expected to adhere to + """This class doesn't do much but just signals the interface that Decoders are expected to adhere to TODO: is there a way to enforce the signature of the forward method? """ - def forward(x, y, state, *args, **kwargs): + + def forward(self, x, **kwargs): """ - x: input tensor + x: (batch, length, dim) input tensor state: additional state from the model backbone *args, **kwargs: additional info from the dataset @@ -24,10 +26,19 @@ def forward(x, y, state, *args, **kwargs): y: output tensor *args: other arguments to pass into the loss function """ - return x, + return x + + def step(self, x): + """ + x: (batch, dim) + """ + return self.forward(x.unsqueeze(1)).squeeze(1) + class SequenceDecoder(Decoder): - def __init__(self, d_model, d_output=None, l_output=None, use_lengths=False, mode='last'): + def __init__( + self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last" + ): super().__init__() self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output) @@ -50,58 +61,71 @@ def __init__(self, d_model, d_output=None, l_output=None, use_lengths=False, mod if mode == 'ragged': assert not use_lengths - def forward(self, x, state, l_batch=None, *args, **kwargs): + def forward(self, x, state=None, lengths=None, l_output=None): """ x: (n_batch, l_seq, d_model) Returns: (n_batch, l_output, d_output) """ - if self.l_output is None: - if isinstance(l_batch, int): # Override by pass in - l_output = l_batch + if l_output is not None: + assert isinstance(l_output, int) # Override by pass in else: - # Grab entire output + # Grab entire output l_output = x.size(-2) squeeze = False else: l_output = self.l_output squeeze = self.squeeze - if self.mode == 'last': + if self.mode == "last": restrict = lambda x: x[..., -l_output:, :] - elif self.mode == 'first': + elif self.mode == "first": restrict = lambda x: x[..., :l_output, :] - elif self.mode == 'pool': - restrict = lambda x: (torch.cumsum(x, dim=-2) / torch.arange(1, 1+x.size(-2), device=x.device, dtype=x.dtype).unsqueeze(-1))[..., -l_output:, :] + elif self.mode == "pool": + restrict = lambda x: ( + torch.cumsum(x, dim=-2) + / torch.arange( + 1, 1 + x.size(-2), device=x.device, dtype=x.dtype + ).unsqueeze(-1) + )[..., -l_output:, :] + def restrict(x): L = x.size(-2) s = x.sum(dim=-2, keepdim=True) if l_output > 1: - c = torch.cumsum(x[..., -(l_output-1):, :].flip(-2), dim=-2) + c = torch.cumsum(x[..., -(l_output - 1) :, :].flip(-2), dim=-2) c = F.pad(c, (0, 0, 1, 0)) - s = s - c # (B, l_output, D) + s = s - c # (B, l_output, D) s = s.flip(-2) - denom = torch.arange(L-l_output+1, L+1, dtype=x.dtype, device=x.device) + denom = torch.arange( + L - l_output + 1, L + 1, dtype=x.dtype, device=x.device + ) s = s / denom return s - elif self.mode == 'sum': + + elif self.mode == "sum": restrict = lambda x: torch.cumsum(x, dim=-2)[..., -l_output:, :] - # TODO use same restrict function as pool case\ + # TODO use same restrict function as pool case elif self.mode == 'ragged': - assert l_batch is not None, "l_batch must be provided for ragged mode" + assert lengths is not None, "lengths must be provided for ragged mode" # remove any additional padding (beyond max length of any sequence in the batch) - restrict = lambda x: x[..., : max(l_batch), :] - else: raise NotImplementedError("Mode must be ['last' | 'first' | 'pool' | 'sum']") + restrict = lambda x: x[..., : max(lengths), :] + else: + raise NotImplementedError( + "Mode must be ['last' | 'first' | 'pool' | 'sum']" + ) # Restrict to actual length of sequence if self.use_lengths: - assert l_batch is not None - x = torch.stack([ - restrict(out[..., :length, :]) - for out, length - in zip(torch.unbind(x, dim=0), l_batch) - ], dim=0) + assert lengths is not None + x = torch.stack( + [ + restrict(out[..., :length, :]) + for out, length in zip(torch.unbind(x, dim=0), lengths) + ], + dim=0, + ) else: x = restrict(x) @@ -110,76 +134,126 @@ def restrict(x): x = x.squeeze(-2) x = self.output_transform(x) - return x, + + return x + + def step(self, x, state=None): + # Ignore all length logic + return self.output_transform(x) + +class NDDecoder(Decoder): + """Decoder for single target (e.g. classification or regression)""" + def __init__( + self, d_model, d_output=None, mode="pool" + ): + super().__init__() + + assert mode in ["pool", "full"] + self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output) + + self.mode = mode + + def forward(self, x, state=None): + """ + x: (n_batch, l_seq, d_model) + Returns: (n_batch, l_output, d_output) + """ + + if self.mode == 'pool': + x = reduce(x, 'b ... h -> b h', 'mean') + x = self.output_transform(x) + return x class StateDecoder(Decoder): - """ Use the output state to decode (useful for stateful models such as RNNs or perhaps Transformer-XL if it gets implemented """ + """Use the output state to decode (useful for stateful models such as RNNs or perhaps Transformer-XL if it gets implemented""" + def __init__(self, d_model, state_to_tensor, d_output): super().__init__() self.output_transform = nn.Linear(d_model, d_output) self.state_transform = state_to_tensor - def forward(self, x, state, *args, **kwargs): - return self.output_transform(self.state_transform(state)), + def forward(self, x, state=None): + return self.output_transform(self.state_transform(state)) + class RetrievalHead(nn.Module): - def __init__(self, d_input, d_model, n_classes, nli=True, activation='relu'): + def __init__(self, d_input, d_model, n_classes, nli=True, activation="relu"): super().__init__() self.nli = nli - if activation == 'relu': + if activation == "relu": activation_fn = nn.ReLU() - elif activation == 'gelu': + elif activation == "gelu": activation_fn = nn.GELU() - else: raise NotImplementedError + else: + raise NotImplementedError - if self.nli: # Architecture from https://github.com/mlpen/Nystromformer/blob/6539b895fa5f798ea0509d19f336d4be787b5708/reorganized_code/LRA/model_wrapper.py#L74 + if ( + self.nli + ): # Architecture from https://github.com/mlpen/Nystromformer/blob/6539b895fa5f798ea0509d19f336d4be787b5708/reorganized_code/LRA/model_wrapper.py#L74 self.classifier = nn.Sequential( - nn.Linear(4*d_input, d_model), + nn.Linear(4 * d_input, d_model), activation_fn, nn.Linear(d_model, n_classes), ) - else: # Head from https://github.com/google-research/long-range-arena/blob/ad0ff01a5b3492ade621553a1caae383b347e0c1/lra_benchmarks/models/layers/common_layers.py#L232 + else: # Head from https://github.com/google-research/long-range-arena/blob/ad0ff01a5b3492ade621553a1caae383b347e0c1/lra_benchmarks/models/layers/common_layers.py#L232 self.classifier = nn.Sequential( - nn.Linear(2*d_input, d_model), + nn.Linear(2 * d_input, d_model), activation_fn, nn.Linear(d_model, d_model // 2), activation_fn, nn.Linear(d_model // 2, n_classes), ) - def forward(self, x): # , state, *args, **kwargs): + def forward(self, x): """ x: (2*batch, dim) """ - outs = rearrange(x, '(z b) d -> z b d', z=2) - outs0, outs1 = outs[0], outs[1] # (n_batch, d_input) + outs = rearrange(x, "(z b) d -> z b d", z=2) + outs0, outs1 = outs[0], outs[1] # (n_batch, d_input) if self.nli: - features = torch.cat([outs0, outs1, outs0-outs1, outs0*outs1], dim=-1) # (batch, dim) + features = torch.cat( + [outs0, outs1, outs0 - outs1, outs0 * outs1], dim=-1 + ) # (batch, dim) else: - features = torch.cat([outs0, outs1], dim=-1) # (batch, dim) + features = torch.cat([outs0, outs1], dim=-1) # (batch, dim) logits = self.classifier(features) return logits + class RetrievalDecoder(Decoder): - """ Combines the standard FeatureDecoder to extract a feature before passing through the RetrievalHead """ - def __init__(self, d_input, n_classes, d_model=None, nli=True, activation='relu', *args, **kwargs): + """Combines the standard FeatureDecoder to extract a feature before passing through the RetrievalHead""" + + def __init__( + self, + d_input, + n_classes, + d_model=None, + nli=True, + activation="relu", + *args, + **kwargs + ): super().__init__() - if d_model is None: d_model = d_input - # self.feature = FeatureDecoder(d_input, None, *args, **kwargs) - self.feature = SequenceDecoder(d_input, d_output=None, l_output=0, *args, **kwargs) - self.retrieval = RetrievalHead(d_input, d_model, n_classes, nli=nli, activation=activation) - - def forward(self, x, state, *args, **kwargs): - x, = self.feature(x, state, *args, **kwargs) + if d_model is None: + d_model = d_input + self.feature = SequenceDecoder( + d_input, d_output=None, l_output=0, *args, **kwargs + ) + self.retrieval = RetrievalHead( + d_input, d_model, n_classes, nli=nli, activation=activation + ) + + def forward(self, x, state=None, **kwargs): + x = self.feature(x, state=state, **kwargs) x = self.retrieval(x) - return x, + return x class PackedDecoder(Decoder): - def forward(self, x, state, *args, **kwargs): - x, _ = nn.utils.rnn.pad_packed_sequence(x - , batch_first=True) - return x, + def forward(self, x, state=None): + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + return x + # For every type of encoder/decoder, specify: # - constructor class @@ -187,44 +261,59 @@ def forward(self, x, state, *args, **kwargs): # - list of attributes to grab from model registry = { - 'id': U.Identity, # U.TupleModule(nn.Identity), - 'sequence': SequenceDecoder, - 'retrieval': RetrievalDecoder, - 'state': StateDecoder, - 'pack': PackedDecoder, + "stop": Decoder, + "id": nn.Identity, + "linear": nn.Linear, + "sequence": SequenceDecoder, + "nd": NDDecoder, + "retrieval": RetrievalDecoder, + "state": StateDecoder, + "pack": PackedDecoder, } model_attrs = { - 'sequence': ['d_output'], - 'feature': ['d_output'], - 'retrieval': ['d_output'], - 'state': ['d_state', 'state_to_tensor'], + "linear": ["d_output"], + "sequence": ["d_output"], + "nd": ["d_output"], + "retrieval": ["d_output"], + "state": ["d_state", "state_to_tensor"], + "forecast": ["d_output"], } dataset_attrs = { - 'sequence': ['d_output', 'l_output'], - 'feature': ['d_output'], - 'retrieval': ['d_output'], - # TODO change d_output to n_classes? - 'state': ['d_output'], + "linear": ["d_output"], + "sequence": ["d_output", "l_output"], + "nd": ["d_output"], + "retrieval": ["d_output"], + "state": ["d_output"], + "forecast": ["d_output", "l_output"], } + def _instantiate(decoder, model=None, dataset=None): - """ Instantiate a single decoder """ - if decoder is None: return U.Identity() + """Instantiate a single decoder""" + if decoder is None: + return None - if isinstance(decoder, str): name = decoder - else: name = decoder['_name_'] + if isinstance(decoder, str): + name = decoder + else: + name = decoder["_name_"] # Extract arguments from attribute names - dataset_args = utils.config.extract_attrs_from_obj(dataset, *dataset_attrs.get(name, [])) + dataset_args = utils.config.extract_attrs_from_obj( + dataset, *dataset_attrs.get(name, []) + ) model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, [])) # Instantiate decoder obj = utils.instantiate(registry, decoder, *model_args, *dataset_args) return obj + def instantiate(decoder, model=None, dataset=None): - """ Instantiate a full decoder config, e.g. handle list of configs + """Instantiate a full decoder config, e.g. handle list of configs Note that arguments are added in reverse order compared to encoder (model first, then dataset) """ decoder = utils.to_list(decoder) - return U.TupleSequential(*[_instantiate(d, model=model, dataset=dataset) for d in decoder]) + return U.PassthroughSequential( + *[_instantiate(d, model=model, dataset=dataset) for d in decoder] + ) diff --git a/src/tasks/encoders.py b/src/tasks/encoders.py index 84a45c3..8ff0ba1 100644 --- a/src/tasks/encoders.py +++ b/src/tasks/encoders.py @@ -1,23 +1,27 @@ +import datetime import math from typing import ForwardRef import torch from torch import nn import torch.nn.functional as F +from einops import rearrange import src.models.nn.utils as U import src.utils as utils import src.utils.config - +from src.models.sequence.block import SequenceResidualBlock +from src.models.nn.components import Normalization class Encoder(nn.Module): - """This class doesn't do much but just signals the interface that Encoder are expected to adhere to - TODO: is there a way to enforce the signature of the forward method? + """Encoder abstraction + Accepts a tensor and optional kwargs. Outside of the main tensor, all other arguments should be kwargs. + Returns a tensor and optional kwargs. + Encoders are combined via U.PassthroughSequential which passes these kwargs through in a pipeline. The resulting kwargs are accumulated and passed into the model backbone. - NOTE: all encoders return a *tuple* where the first argument is a tensor and the rest are additional parameters to be passed into the model """ - def forward(self, x, *args): + def forward(self, x, **kwargs): """ x: input tensor *args: additional info from the dataset (e.g. sequence lengths) @@ -26,7 +30,8 @@ def forward(self, x, *args): y: output tensor *args: other arguments to pass into the model backbone """ - return (x,) + return x, {} + # Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py @@ -47,7 +52,7 @@ class PositionalEncoder(Encoder): >>> pos_encoder = PositionalEncoder(d_model) """ - def __init__(self, d_model, dropout=0.1, max_len=16384, pe_init=None, causal=True): + def __init__(self, d_model, dropout=0.1, max_len=16384, pe_init=None): super().__init__() self.dropout = nn.Dropout(p=dropout) if pe_init is not None: @@ -62,12 +67,11 @@ def __init__(self, d_model, dropout=0.1, max_len=16384, pe_init=None, causal=Tru ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) - # pe = pe.unsqueeze(1) # Comment this out to handle (B, L, D) instead of (L, B, D) self.register_buffer("pe", pe) self.attn_mask = None - def forward(self, x, seq_len=None, *args, **kwargs): + def forward(self, x): r"""Inputs of forward function Args: x: the sequence fed to the positional encoder model (required). @@ -79,47 +83,23 @@ def forward(self, x, seq_len=None, *args, **kwargs): padding_mask: """ - # TODO currently not used, but maybe will be someday - # e.g. attn_mask is defined directly in each attention layer - - # if self.attn_mask is None or self.attn_mask.shape[-1] != x.size(-2): - # # self.attn_mask = TriangularCausalMask(len(src), device=src.device) - # self.attn_mask = torch.triu(torch.ones(x.size(-2), x.size(-2), - # dtype=torch.bool, device=x.device), - # diagonal=1) - - # padding_mask = None - # if seq_len is not None and seq_len < x.size(-2): - # padding_mask = LengthMask( - # torch.full( - # (x.size(-2),), - # seq_len, - # device=x.device, - # dtype=torch.long, - # ), - # max_len=x.size(-2), - # ) - # else: - # padding_mask = None - x = x + self.pe[: x.size(-2)] - # return self.dropout(x), self.attn_mask, padding_mask - return (self.dropout(x),) + return self.dropout(x) class ClassEmbedding(Encoder): - # Should also be able to define this by subclassing TupleModule(Embedding) + # Should also be able to define this by subclassing Embedding def __init__(self, n_classes, d_model): super().__init__() self.embedding = nn.Embedding(n_classes, d_model) - def forward(self, x, y, *args, **kwargs): + def forward(self, x, y): x = x + self.embedding(y).unsqueeze(-2) # (B, L, D) - return (x,) + return x class Conv1DEncoder(Encoder): - def __init__(self, d_input, d_model, kernel_size, stride, padding=0): + def __init__(self, d_input, d_model, kernel_size=25, stride=1, padding='same'): super().__init__() self.conv = nn.Conv1d( in_channels=d_input, @@ -129,10 +109,94 @@ def __init__(self, d_input, d_model, kernel_size, stride, padding=0): padding=padding, ) - def forward(self, x, *args): + def forward(self, x): # BLD -> BLD x = self.conv(x.transpose(1, 2)).transpose(1, 2) - return (x,) + return x + +class LayerEncoder(Encoder): + """Use an arbitary SequenceModule layer""" + + def __init__(self, d_model, prenorm=False, norm='layer', layer=None): + super().__init__() + + # Simple stack of blocks + layer["transposed"] = False + self.layer = SequenceResidualBlock( + d_input=d_model, + prenorm=prenorm, + layer=layer, + residual='R', + norm=norm, + pool=None, + ) + + def forward(self, x): + x, _ = self.layer(x) # Discard state + return x + + +class TimestampEmbeddingEncoder(Encoder): + """ + General time encoder for Pandas Timestamp objects (encoded as torch tensors). + See MonashDataset for an example of how to return time features as 'z's. + """ + + cardinalities = { + 'day': (1, 31), + 'hour': (0, 23), + 'minute': (0, 59), + 'second': (0, 59), + 'month': (1, 12), + 'year': (1950, 2010), # (1800, 3000) used to be (1970, datetime.datetime.now().year + 1) but was not enough for all datasets in monash + 'dayofweek': (0, 6), + 'dayofyear': (1, 366), + 'quarter': (1, 4), + 'week': (1, 53), + 'is_month_start': (0, 1), + 'is_month_end': (0, 1), + 'is_quarter_start': (0, 1), + 'is_quarter_end': (0, 1), + 'is_year_start': (0, 1), + 'is_year_end': (0, 1), + 'is_leap_year': (0, 1), + } + + def __init__(self, d_model, table=False, features=None): + super().__init__() + self.table = table + self.ranges = {k: max_val - min_val + 2 for k, (min_val, max_val) in self.cardinalities.items()} # padding for null included + + if features is None: + pass + else: + self.cardinalities = {k: v for k, v in self.cardinalities.items() if k in features} + + if table: + self.embedding = nn.ModuleDict({ + attr: nn.Embedding(maxval - minval + 2, d_model, padding_idx=0) + for attr, (minval, maxval) in self.cardinalities.items() + }) + else: + self.embedding = nn.ModuleDict({ + attr: nn.Linear(1, d_model) + for attr in self.cardinalities + }) + + + + def forward(self, x, timestamps=None): + for attr in timestamps: + mask = timestamps[attr] == -1 + timestamps[attr] = timestamps[attr] - self.cardinalities[attr][0] + timestamps[attr][mask] = 0 + if self.table: + x = x + self.embedding[attr](timestamps[attr].to(torch.long)) + else: + x = x + self.embedding[attr]((2 * timestamps[attr] / self.ranges[attr] - 1).unsqueeze(-1)) + + #x = x + self.embedding(timestamps[attr].to(torch.float)).unsqueeze(1) + return x class TimeEncoder(Encoder): @@ -148,7 +212,8 @@ def __init__(self, n_tokens_time, d_model, timeenc=0): self.encoders = nn.Linear(len(n_tokens_time), d_model) self.mask_embed = nn.Embedding(2, d_model) - def forward(self, x, mark, mask, *args, **kwargs): + def forward(self, x, mark=None, mask=None): + assert mark is not None and mask is not None, "Extra arguments should be returned by collate function" if self.timeenc == 0: assert mark.size(-1) == len(self.encoders) embeddings = [ @@ -158,19 +223,16 @@ def forward(self, x, mark, mask, *args, **kwargs): else: time_encode = self.encoders(mark) mask_encode = self.mask_embed(mask.squeeze(-1)) - return (x + time_encode + mask_encode,) # (B, L, d_model) + return x + time_encode + mask_encode # (B, L, d_model) class PackedEncoder(Encoder): def forward(self, x, len_batch=None): assert len_batch is not None x = nn.utils.rnn.pack_padded_sequence( - x, - len_batch.cpu(), - enforce_sorted=False, - batch_first=True, + x, len_batch.cpu(), enforce_sorted=False, batch_first=True, ) - return (x,) + return x class OneHotEncoder(Encoder): @@ -179,8 +241,46 @@ def __init__(self, n_tokens, d_model): assert n_tokens <= d_model self.d_model = d_model - def forward(self, x, *args, **kwargs): - return (F.one_hot(x.squeeze(-1), self.d_model).float(),) + def forward(self, x): + return F.one_hot(x.squeeze(-1), self.d_model).float() + + +class Conv2DPatchEncoder(Encoder): + + """ + For encoding images into a sequence of patches. + """ + + def __init__(self, d_input, d_model, filter_sizes, flat=False): + + """ + d_input: dim of encoder input (data dimension) + d_model: dim of encoder output (model dimension) + filter_sizes: tuple with fh, fw + flat: if image is flattened from dataloader (like in cifar), + then we need to reshape back to 2D before conv + """ + + fh, fw = filter_sizes + self.flat = flat + + super().__init__() + assert len(filter_sizes) == 2 + + self.encoder = nn.Conv2d(d_input, d_model, kernel_size=(fh, fw), stride=(fh, fw)) + + def forward(self, x): + + """ + x shape expected = [b, h, w, c] + returns tuple with x, with new shape = [b, seq_len, c_out] + + """ + + x = rearrange(x, 'b h w c -> b c h w') + x = self.encoder(x) + x = rearrange(x, 'b c h w -> b (h w) c') + return x # For every type of encoder/decoder, specify: @@ -189,15 +289,19 @@ def forward(self, x, *args, **kwargs): # - list of attributes to grab from model registry = { - "id": U.Identity, - "embedding": U.Embedding, - "linear": U.Linear, + "stop": Encoder, + "id": nn.Identity, + "embedding": nn.Embedding, + "linear": nn.Linear, "position": PositionalEncoder, "class": ClassEmbedding, "pack": PackedEncoder, "time": TimeEncoder, "onehot": OneHotEncoder, "conv1d": Conv1DEncoder, + "patch2d": Conv2DPatchEncoder, + "timestamp_embedding": TimestampEmbeddingEncoder, + "layer": LayerEncoder, } dataset_attrs = { "embedding": ["n_tokens"], @@ -206,6 +310,7 @@ def forward(self, x, *args, **kwargs): "time": ["n_tokens_time"], "onehot": ["n_tokens"], "conv1d": ["d_input"], + "patch2d": ["d_input"], } model_attrs = { "embedding": ["d_model"], @@ -215,13 +320,16 @@ def forward(self, x, *args, **kwargs): "time": ["d_model"], "onehot": ["d_model"], "conv1d": ["d_model"], + "patch2d": ["d_model"], + "timestamp_embedding": ["d_model"], + "layer": ["d_model"], } def _instantiate(encoder, dataset=None, model=None): """Instantiate a single encoder""" if encoder is None: - return U.Identity() + return None if isinstance(encoder, str): name = encoder else: @@ -240,6 +348,6 @@ def _instantiate(encoder, dataset=None, model=None): def instantiate(encoder, dataset=None, model=None): encoder = utils.to_list(encoder) - return U.TupleSequential( + return U.PassthroughSequential( *[_instantiate(e, dataset=dataset, model=model) for e in encoder] ) diff --git a/src/tasks/metrics.py b/src/tasks/metrics.py index b0e4bf5..bd07e6d 100644 --- a/src/tasks/metrics.py +++ b/src/tasks/metrics.py @@ -1,10 +1,42 @@ import math import torch import torch.nn.functional as F -from src.tasks.mixture import mixture_loss, mixture_loss_kd from sklearn.metrics import f1_score, roc_auc_score from functools import partial +def _student_t_map(mu, sigma, nu): + sigma = F.softplus(sigma) + nu = 2.0 + F.softplus(nu) + return mu.squeeze(axis=-1), sigma.squeeze(axis=-1), nu.squeeze(axis=-1) + +def student_t_loss(outs, y): + mu, sigma, nu = outs[..., 0], outs[..., 1], outs[..., 2] + mu, sigma, nu = _student_t_map(mu, sigma, nu) + y = y.squeeze(axis=-1) + + nup1_half = (nu + 1.0) / 2.0 + part1 = 1.0 / nu * torch.square((y - mu) / sigma) + Z = ( + torch.lgamma(nup1_half) + - torch.lgamma(nu / 2.0) + - 0.5 * torch.log(math.pi * nu) + - torch.log(sigma) + ) + + ll = Z - nup1_half * torch.log1p(part1) + return -ll.mean() + +def gaussian_ll_loss(outs, y): + mu, sigma = outs[..., 0], outs[..., 1] + y = y.squeeze(axis=-1) + sigma = F.softplus(sigma) + ll = -1.0 * ( + torch.log(sigma) + + 0.5 * math.log(2 * math.pi) + + 0.5 * torch.square((y - mu) / sigma) + ) + return -ll.mean() + def binary_cross_entropy(logits, y): # BCE loss requires squeezing last dimension of logits so it has the same shape as y # requires y to be float, since it's overloaded to represent a probability @@ -21,6 +53,12 @@ def cross_entropy(logits, y): return F.cross_entropy(logits, y) +def soft_cross_entropy(logits, y, **kwargs): + logits = logits.view(-1, logits.shape[-1]) + # target is now 2d (no target flattening) + return F.cross_entropy(logits, y, **kwargs) + + def accuracy(logits, y): logits = logits.view(-1, logits.shape[-1]) if y.numel() > logits.shape[0]: @@ -38,6 +76,7 @@ def accuracy_at_k(logits, y, k=1): y = y.view(-1) return torch.topk(logits, k, dim=-1)[1].eq(y.unsqueeze(-1)).any(dim=-1).float().mean() + def f1_binary(logits, y): logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) @@ -87,6 +126,7 @@ def mse(outs, y, len_batch=None): return F.mse_loss(outs, y) else: # Computes the loss of the first `lens` items in the batches + # TODO document the use case of this mask = torch.zeros_like(outs, dtype=torch.bool) for i, l in enumerate(len_batch): mask[i, :l, :] = 1 @@ -94,6 +134,9 @@ def mse(outs, y, len_batch=None): y_masked = torch.masked_select(y, mask) return F.mse_loss(outs_masked, y_masked) +def forecast_rmse(outs, y, len_batch=None): + # TODO: generalize, currently for Monash dataset + return torch.sqrt(F.mse_loss(outs, y, reduction='none').mean(1)).mean() def mae(outs, y, len_batch=None): # assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1 @@ -119,10 +162,6 @@ def loss(x, y, loss_fn): return loss_fn(x, y) -def rmse(x, y, loss_fn): - return loss_fn(x, y) ** 0.5 # NOTE this isn't exactly correct - - def bpb(x, y, loss_fn): """ bits per byte (image density estimation, speech generation, char LM) """ return loss_fn(x, y) / math.log(2) @@ -142,19 +181,36 @@ def ppl(x, y, loss_fn): 'accuracy@5': partial(accuracy_at_k, k=5), 'accuracy@10': partial(accuracy_at_k, k=10), "eval_loss": loss, - "mixture": mixture_loss, - "mixture_kd": mixture_loss_kd, "mse": mse, "mae": mae, + "forecast_rmse": forecast_rmse, "f1_binary": f1_binary, "f1_macro": f1_macro, "f1_micro": f1_micro, "roc_auc_macro": roc_auc_macro, "roc_auc_micro": roc_auc_micro, + "soft_cross_entropy": soft_cross_entropy, # only for pytorch 1.10+ + "student_t": student_t_loss, + "gaussian_ll": gaussian_ll_loss, } + +try: + from segmentation_models_pytorch.utils.functional import iou + from segmentation_models_pytorch.losses.focal import focal_loss_with_logits + + def iou_with_logits(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + return iou(pr.sigmoid(), gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels) + + output_metric_fns["iou"] = partial(iou, threshold=0.5) + output_metric_fns["iou_with_logits"] = partial(iou_with_logits, threshold=0.5) + output_metric_fns["focal_loss"] = focal_loss_with_logits +except ImportError: + pass + loss_metric_fns = { "loss": loss, "bpb": bpb, "ppl": ppl, } metric_fns = {**output_metric_fns, **loss_metric_fns} # TODO py3.9 + diff --git a/src/tasks/tasks.py b/src/tasks/tasks.py index 25873cf..17dbe55 100644 --- a/src/tasks/tasks.py +++ b/src/tasks/tasks.py @@ -1,18 +1,20 @@ from typing import Optional, List, Tuple import math import functools +import collections import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from omegaconf import ListConfig +from src.models.nn.components import ReversibleInstanceNorm1dInput, ReversibleInstanceNorm1dOutput, \ + TSNormalization, TSInverseNormalization from src.models.nn.adaptive_softmax import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax -from src.models.nn.initialization import weights_init_embedding import src.tasks.metrics as M import src.models.nn.utils as U import torchmetrics as tm -from src.utils.config import to_list, instantiate_partial, instantiate +from src.utils.config import to_list, instantiate class BaseTask: @@ -25,7 +27,7 @@ class BaseTask: encoder = None decoder = None - def __init__(self, dataset=None, model=None, loss=None, metrics=None, torchmetrics=None): + def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None): """ This class is allowed to grab attributes directly off a constructed dataset and model object """ self.dataset = dataset self.model = model @@ -36,8 +38,16 @@ def __init__(self, dataset=None, model=None, loss=None, metrics=None, torchmetri self.torchmetric_names = to_list(torchmetrics) self._tracked_torchmetrics = {} + # The decoder might pass through arguments that the loss needs (e.g. sequence lengths) + # but might also pass through extraneous arguments (e.g. sampling rate) + # Wrap loss and metrics so that they accept kwargs and + # Create loss function self.loss = instantiate(M.output_metric_fns, loss, partial=True) + self.loss = U.discard_kwargs(self.loss) + if loss_val is not None: + self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True) + self.loss_val = U.discard_kwargs(self.loss_val) def _init_torchmetrics(self, prefix): """ @@ -45,7 +55,7 @@ def _init_torchmetrics(self, prefix): """ self._tracked_torchmetrics[prefix] = {} for name in self.torchmetric_names: - if name in ['AUROC', 'StatScores', 'Precision', 'Recall', 'F1']: + if name in ['AUROC', 'StatScores', 'Precision', 'Recall', 'F1', 'F1Score']: self._tracked_torchmetrics[prefix][name] = getattr(tm, name)(average='macro', num_classes=self.dataset.d_output, compute_on_step=False).to('cuda') elif '@' in name: k = int(name.split('@')[1]) @@ -89,20 +99,25 @@ def torchmetrics(self, x, y, prefix): self._init_torchmetrics(prefix) for name in self.torchmetric_names: + if name.startswith('Accuracy'): + if len(x.shape) > 2: + # Multi-dimensional, multi-class + self._tracked_torchmetrics[prefix][name].update(x.transpose(1, 2), y.squeeze()) + continue self._tracked_torchmetrics[prefix][name].update(x, y) - def metrics(self, x, y): + def metrics(self, x, y, **kwargs): """ Metrics are just functions output metrics are a function of output and target loss metrics are a function of loss (e.g. perplexity) """ output_metrics = { - name: M.output_metric_fns[name](x, y) + name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs) for name in self.metric_names if name in M.output_metric_fns } loss_metrics = { - name: M.loss_metric_fns[name](x, y, self.loss) + name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs) for name in self.metric_names if name in M.loss_metric_fns } return {**output_metrics, **loss_metrics} @@ -116,40 +131,139 @@ def forward(self, x): return x * self.c class LMTask(BaseTask): - def __init__(self, tied=False, rescale=True, init=None, **kwargs): + def __init__(self, tied=False, rescale=True, **kwargs): super().__init__(loss='cross_entropy', **kwargs) n_tokens = self.dataset.n_tokens d_model = self.model.d_model d_output = self.model.d_output if rescale: - scale = U.TupleModule(Scalar)(math.sqrt(d_model)) + scale = Scalar(math.sqrt(d_model)) else: - scale = U.Identity() + scale = None - embedding = U.Embedding(n_tokens, d_model) + embedding = nn.Embedding(n_tokens, d_model) nn.init.normal_(embedding.weight, mean=0, std=d_model**-.5) - encoder = nn.Sequential( + encoder = U.PassthroughSequential( embedding, scale, ) self.encoder = encoder - decoder = U.TupleModule(nn.Linear)(d_output, n_tokens) + decoder = nn.Linear(d_output, n_tokens) self.decoder = decoder if tied: assert d_model == d_output self.decoder.weight = self.encoder[0].weight - if init is not None: - self.encoder.apply(functools.partial(weights_init_embedding, init_cfg=init)) - class ForecastingTask(BaseTask): + class DummyModule(nn.Module): + + def forward(self, *args): + return args + + def __init__(self, norm='mean', **kwargs): + super().__init__(**kwargs) + + if norm == 'revnorm': + self.encoder = ReversibleInstanceNorm1dInput(self.dataset.d_input, transposed=False) + self.decoder = ReversibleInstanceNorm1dOutput(self.encoder) + elif norm == 'mean': + self.encoder = TSNormalization(method='mean', horizon=self.dataset.dataset_train.forecast_horizon) + self.decoder = TSInverseNormalization(method='mean', normalizer=self.encoder) + elif norm == 'last': + self.encoder = TSNormalization(method='last', horizon=self.dataset.dataset_train.forecast_horizon) + self.decoder = TSInverseNormalization(method='last', normalizer=self.encoder) + else: + self.encoder = None + self.decoder = None + + try: + if hasattr(self.dataset.dataset_train, 'mean'): + self.mean = torch.tensor(self.dataset.dataset_train.mean) + self.std = torch.tensor(self.dataset.dataset_train.std) + elif hasattr(self.dataset.dataset_train, 'standardization'): + self.mean = torch.tensor(self.dataset.dataset_train.standardization['means']) + self.std = torch.tensor(self.dataset.dataset_train.standardization['stds']) + else: + self.mean = None + self.std = None + except AttributeError: + raise AttributeError('Dataset does not have mean/std attributes') + self.mean = torch.tensor(self.dataset.dataset_train.standardization['means']) + self.std = torch.tensor(self.dataset.dataset_train.standardization['stds']) + + if hasattr(self.dataset.dataset_train, 'log_transform'): + self.log_transform = self.dataset.dataset_train.log_transform + else: + self.log_transform = False + print("Log Transform", self.log_transform) + + def metrics(self, x, y, state=None, timestamps=None, ids=None): # Explicit about which arguments the decoder might pass through, but can future-proof with **kwargs + if self.mean is not None: + means = self.mean[ids].to(x.device) + stds = self.std[ids].to(x.device) + x_ = x * stds[:, None, None] + means[:, None, None] + y_ = y * stds[:, None, None] + means[:, None, None] + else: + x_ = x + y_ = y + + if self.log_transform: + x_ = torch.exp(x_) + y_ = torch.exp(y_) + + return super().metrics(x_, y_) + +class VideoTask(BaseTask): def __init__(self, **kwargs): super().__init__(**kwargs) - + # self._y_to_logits = {} + self._vid_to_logits = {} + self._vid_to_label = {} + + # TODO needed to extract the first element of y, which includes the video idea; there should be a cleaner pattern to this + import copy + loss_fn = copy.deepcopy(self.loss) + self.loss = lambda x, y: loss_fn(x, y[0]) + if hasattr(self, 'loss_val'): + loss_val_fn = copy.deepcopy(self.loss_val) + self.loss_val = lambda x, y: loss_val_fn(x, y[0]) + + def metrics(self, logits, y, **kwargs): + labels, vids = y + return super().metrics(logits, labels, **kwargs) + + def torchmetrics(self, logits, y, prefix): + """ + logits: (batch, n_classes) + y = tuple of labels and video ids + labels: (batch) + vids: (batch) + """ + for _logits, _label, _vid in zip(logits, y[0], y[1]): + _vid = _vid.item() + # Check that labels are consistent per video id + assert self._vid_to_label[prefix].get(_vid, _label) == _label + self._vid_to_label[prefix][_vid] = _label + + self._vid_to_logits[prefix][_vid].append(_logits) + + def _reset_torchmetrics(self, prefix): + self._vid_to_logits[prefix] = collections.defaultdict(list) + self._vid_to_label[prefix] = {} + + def get_torchmetrics(self, prefix): + vid_to_average_logits = {vid: torch.mean(torch.stack(logits, dim=0), dim=0) for vid, logits in self._vid_to_logits[prefix].items()} + # y is (label, vid) pair + all_labels = torch.stack(list(self._vid_to_label[prefix].values()), dim=0) # (n_videos) + all_logits = torch.stack(list(vid_to_average_logits.values()), dim=0) # (n_videos, n_classes) + m = M.accuracy(all_logits, all_labels) + return {'aggregate_accuracy': m} + + class AdaptiveLMTask(BaseTask): def __init__( self, @@ -168,7 +282,7 @@ def __init__( d_model = self.model.d_model d_output = self.model.d_output - encoder = U.TupleModule(AdaptiveEmbedding)( + encoder = AdaptiveEmbedding( n_tokens, d_model, d_model, @@ -196,12 +310,39 @@ def __init__( dropout=dropsoft, ) - self.encoder = U.TupleSequential(encoder, self.encoder) + self.encoder = encoder self.loss = loss +class ImageNetTask(BaseTask): + """ + Imagenet training uses mixup augmentations, which require a separate loss for train and val, + which we overide the base task here. + """ + + def __init__(self, **kwargs): + import hydra + + super().__init__( + dataset=kwargs.get("dataset", None), + model=kwargs.get("model", None), + loss=kwargs.get("loss", None), # we still create the base loss here, but will overide below + metrics=kwargs.get("metrics", None), + torchmetrics=kwargs.get("torchmetrics", None) + ) + + # if using mixup, overide loss (train) and loss_val, otherwise + # we have just one loss from the base task above + if "loss_val" in kwargs and "loss_train" in kwargs: + self.loss = hydra.utils.instantiate(kwargs.get("loss_train")) + self.loss_val = hydra.utils.instantiate(kwargs.get('loss_val')) + + registry = { 'base': BaseTask, 'lm': LMTask, 'adaptivelm': AdaptiveLMTask, + 'imagenet': ImageNetTask, + 'forecasting': ForecastingTask, + 'video': VideoTask, } diff --git a/src/utils/__init__.py b/src/utils/__init__.py index de2e624..960c2b9 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,2 +1 @@ from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate -from .config import dictconfig_filter_keys diff --git a/src/utils/config.py b/src/utils/config.py index 8461cc9..09e7f1a 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -119,94 +119,3 @@ def omegaconf_filter_keys(d, fn=None): return d -""" OLD CODE BELOW """ - -""" Special case instantiators - subsumed by instantiate """ - - -def instantiate_name(registry, config, *args, **kwargs): - if config is None: - return None - if isinstance(config, str): - obj = hydra.utils.instantiate({"_target_": config}, *args, **kwargs) - return obj - name = config.pop("_name_") - config["_target_"] = registry[name] - obj = hydra.utils.instantiate(config, *args, **kwargs) - config["_name_"] = name - return obj - - -def instantiate_cls(registry, config, *args, **kwargs): - if config is None: - return None - if isinstance(config, str): - obj = registry[config](*args, **kwargs) - return obj - name = config.pop("_name_") - cls = registry[name] - obj = cls(*args, **config, **kwargs) - config["_name_"] = name - return obj - - -# TODO is there a way ot combining all these cases? -def instantiate_partial(registry, config, *args, **kwargs): - if config is None: - return None - if isinstance(config, str): - obj = functools.partial(registry[config], *args, **kwargs) - return obj - name = config.pop("_name_") - fn = registry[name] - obj = functools.partial(fn, *args, **config, **kwargs) - config["_name_"] = name - return obj - - -""" Legacy infra utilities - currently not used """ - - -def dictconfig_filter_keys(d: DictConfig, fn: Optional[Callable] = None) -> DictConfig: - """Only keep keys where fn(key) is True. Support nested DictConfig. - # TODO can make this inplace? - """ - if fn is None: - fn = lambda _: True - return DictConfig( - { - k: dictconfig_filter_keys(v, fn) if isinstance(v, DictConfig) else v - for k, v in d.items() - if fn(k) - } - ) - - -# from munch import Munch -def remove_postfix(text, postfix): - if text.endswith(postfix): - return text[: -len(postfix)] - return text - - -# pytorch-lightning returns pytorch 0-dim tensor instead of python scalar -def to_scalar(x): - return x.item() if isinstance(x, torch.Tensor) else x - - -def dictconfig_to_munch(d): - """Convert object of type OmegaConf to Munch so Wandb can log properly - Support nested dictionary. - """ - return Munch( - { - k: dictconfig_to_munch(v) if isinstance(v, DictConfig) else v - for k, v in d.items() - } - ) - - -def munch_to_dictconfig(m): - return DictConfig( - {k: munch_to_dictconfig(v) if isinstance(v, Munch) else v for k, v in m.items()} - ) diff --git a/src/utils/optim/schedulers.py b/src/utils/optim/schedulers.py index dbe8ecf..35e6d87 100644 --- a/src/utils/optim/schedulers.py +++ b/src/utils/optim/schedulers.py @@ -1,9 +1,11 @@ -""" Deprecated optimizers. These have been superceded by various wrappers from torch and huggingface """ +"""Custom learning rate schedulers""" import math import warnings import torch +from timm.scheduler import CosineLRScheduler + # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR): @@ -58,3 +60,28 @@ def lr_lambda(step): return 1. if step > warmup_step else (step + 1) / warmup_step return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) + + +class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): + """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. + It supports resuming as well. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._last_epoch = -1 + self.step(epoch=0) + + def step(self, epoch=None): + if epoch is None: + self._last_epoch += 1 + else: + self._last_epoch = epoch + # We call either step or step_update, depending on whether we're using the scheduler every + # epoch or every step. + # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set + # scheduler interval to "step", then the learning rate update will be wrong. + if self.t_in_epochs: + super().step(epoch=self._last_epoch) + else: + super().step_update(num_updates=self._last_epoch) diff --git a/src/utils/optim_groups.py b/src/utils/optim_groups.py new file mode 100644 index 0000000..b935a8f --- /dev/null +++ b/src/utils/optim_groups.py @@ -0,0 +1,144 @@ +"""Utilities for special optimizer hyperparameters. + +group_parameters_for_optimizer is a modification of timm's optimizer logic, which is currently unused +add_optimizer_hooks is an improved version that uses this codebase's _optim dictionary +""" + +import inspect + +import torch.nn as nn + +import hydra + + +def add_optimizer_hooks( + model, + bias_weight_decay=False, + normalization_weight_decay=False, +): + """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with + attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for + normalization parameters if normalization_weight_decay==False + """ + + # Separate out all parameters to those that will and won't experience regularizing weight decay + blacklist_weight_modules = (nn.Embedding, ) + if not normalization_weight_decay: + blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + # Not compatible with Pytorch 1.8.1 + # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, + nn.GroupNorm, nn.SyncBatchNorm, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.LayerNorm, nn.LocalResponseNorm) + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + if (not bias_weight_decay and pn.endswith('bias')) \ + or getattr(p, '_no_weight_decay', False) \ + or isinstance(m, blacklist_weight_modules): + setattr(p, "_optim", {"weight_decay": 0.0}) + + +def group_parameters_for_optimizer( + model, + optimizer_cfg, + bias_weight_decay=False, + normalization_weight_decay=False, +): + """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with + attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for + normalization parameters if normalization_weight_decay==False + """ + # Get the weight decay from the config, or from the default value of the optimizer constructor + # if it's not specified in the config. + if 'weight_decay' in optimizer_cfg: + weight_decay = optimizer_cfg.weight_decay + else: + # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value + signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_)) + if 'weight_decay' in signature.parameters: + weight_decay = signature.parameters['weight_decay'].default + if weight_decay is inspect.Parameter.empty: + weight_decay = 0.0 + else: + weight_decay = 0.0 + + # If none of the parameters have weight decay anyway, and there are no parameters with special + # optimization params + if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()): + return model.parameters() + + skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set() + skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords') + else set()) + + # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + special = set() + whitelist_weight_modules = (nn.Linear, ) + blacklist_weight_modules = (nn.Embedding, ) + if not normalization_weight_decay: + blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + # Not compatible with Pytorch 1.8.1 + # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, + nn.GroupNorm, nn.SyncBatchNorm, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.LayerNorm, nn.LocalResponseNorm) + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + if not p.requires_grad: + continue # frozen weights + if hasattr(p, '_optim'): + special.add(fpn) + elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords): + no_decay.add(fpn) + elif getattr(p, '_no_weight_decay', False): + no_decay.add(fpn) + elif not bias_weight_decay and pn.endswith('bias'): + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + # special case the position embedding parameter in the root GPT module as not decayed + if 'pos_emb' in param_dict: + no_decay.add('pos_emb') + + # In case of parameter sharing, some parameters show up in decay but are not in param_dict.keys() + decay &= param_dict.keys() + decay |= (param_dict.keys() - no_decay - special) + # validate that we considered every parameter + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" + + if weight_decay == 0.0 or not no_decay: + param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))], + "weight_decay": weight_decay}] + else: + param_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + # Add parameters with special hyperparameters + # Unique dicts + hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)] + for hp in hps: + params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp] + param_groups.append({"params": params, **hp}) + + return param_groups diff --git a/src/utils/permutations.py b/src/utils/permutations.py index f171271..b8f6a0d 100644 --- a/src/utils/permutations.py +++ b/src/utils/permutations.py @@ -178,11 +178,3 @@ def binary2gray(binary, axis=-1): gray = np.logical_xor(binary, shifted) return gray - -if __name__ == '__main__': - print(snake_permutation(4, 5)) - print(transpose_permutation(4, 5)) - - # perm = decode(list(range(16)), 2, 2) - # print(perm) - print(hilbert_permutation(4)) diff --git a/src/utils/registry.py b/src/utils/registry.py index d8d0628..f0b597e 100644 --- a/src/utils/registry.py +++ b/src/utils/registry.py @@ -19,60 +19,44 @@ } model = { + # Backbones from this repo "model": "src.models.sequence.SequenceModel", "unet": "src.models.sequence.SequenceUNet", - "lstm": "src.models.sequence.rnns.lstm.TorchLSTM", - "convnet": "src.models.sequence.convnet", - "ckconv": "src.models.baselines.ckconv.ClassificationCKCNN", + "sashimi": "src.models.sequence.sashimi.Sashimi", + # Baseline RNNs + "lstm": "src.models.baselines.lstm.TorchLSTM", + "gru": "src.models.baselines.gru.TorchGRU", "unicornn": "src.models.baselines.unicornn.UnICORNN", - "resnet": "src.models.baselines.resnet.ResnetSquare", - "wrn": "src.models.baselines.resnet.WideResNet", "odelstm": "src.models.baselines.odelstm.ODELSTM", "lipschitzrnn": "src.models.baselines.lipschitzrnn.RnnModels", - "wavegan": "src.models.baselines.wavegan.WaveGANDiscriminator", - "denseinception": "src.models.baselines.dense_inception.DenseInception", "stackedrnn": "src.models.baselines.samplernn.StackedRNN", "stackedrnn_baseline": "src.models.baselines.samplernn.StackedRNNBaseline", "samplernn": "src.models.baselines.samplernn.SampleRNN", - "dcgru": "src.models.baselines.dcgru.DCRNNModel_classification", - "dcgru_ss": "src.models.baselines.dcgru.DCRNNModel_nextTimePred", - "vit": "models.baselines.vit.ViT", - "snet": "src.models.sequence.snet.SequenceSNet", - "sashimi": "src.models.sequence.sashimi.Sashimi", + # Baseline CNNs + "ckconv": "src.models.baselines.ckconv.ClassificationCKCNN", + "wavegan": "src.models.baselines.wavegan.WaveGANDiscriminator", # DEPRECATED "wavenet": "src.models.baselines.wavenet.WaveNetModel", - # ViT Variants (note: small variant is taken from Tri, differs from original) - "vit_s_16": "src.models.baselines.vit_all.vit_small_patch16_224", - "vit_b_16": "src.models.baselines.vit_all.vit_base_patch16_224", - "convnext": "src.models.baselines.convnext.convnext", - "convnext_timm_base": "src.models.baselines.convnext_timm.convnext_base", - "convnext_timm_small": "src.models.baselines.convnext_timm.convnext_small", - "convnext_timm_tiny": "src.models.baselines.convnext_timm.convnext_tiny", - "convnext_timm_micro": "src.models.baselines.convnext_timm.convnext_micro", - "convnext_timm_orig": "src.models.baselines.convnext_timm_orig.convnext_base", - "convnext_timm_orig_v0": "src.models.baselines.convnext_timm_orig_v0.convnext_base", - "resnet50_timm": "src.models.baselines.resnet_timm.resnet50", - "s4nd_unet": "src.models.sequence.unet_nd.S4NDUNet", - "convnext_timm_tiny_3d": "src.models.baselines.convnext_timm.convnext3d_tiny", + "torch/resnet2d": "src.models.baselines.resnet.TorchVisionResnet", + # Nonaka 1D CNN baselines + "nonaka/resnet18": "src.models.baselines.nonaka.resnet.resnet1d18", + "nonaka/inception": "src.models.baselines.nonaka.inception.inception1d", + "nonaka/xresnet50": "src.models.baselines.nonaka.xresnet.xresnet1d50", } layer = { "id": "src.models.sequence.base.SequenceIdentity", "lstm": "src.models.sequence.rnns.lstm.TorchLSTM", - "sru": "src.models.sequence.rnns.sru.SRURNN", # TODO not updated + "sru": "src.models.sequence.rnns.sru.SRURNN", "lssl": "src.models.sequence.ss.lssl.LSSL", "s4": "src.models.sequence.ss.s4.S4", - "standalone": "src.models.sequence.ss.standalone.s4.S4", - "s4d": "src.models.sequence.ss.standalone.s4d.S4D", - "s4_2d": "src.models.sequence.ss.s4_2d.StateSpace2D", - "s4nd": "src.models.sequence.ss.s4_nd.S4ND", + "standalone": "src.models.s4.s4.S4", + "s4d": "src.models.s4.s4d.S4D", "ff": "src.models.sequence.ff.FF", "rnn": "src.models.sequence.rnns.rnn.RNN", "mha": "src.models.sequence.mha.MultiheadAttention", - "conv1d": "src.models.sequence.conv1d.Conv1d", - "attsimp": "src.models.sequence.mha.AttentionSimple", + "conv1d": "src.models.sequence.convs.conv1d.Conv1d", + "conv2d": "src.models.sequence.convs.conv2d.Conv2d", "performer": "src.models.sequence.attention.linear.Performer", - "s4_2dconv": "src.models.sequence.ss.s4_2dconv.S42DConv" - # 'packedrnn': 'models.sequence.rnns.packedrnn.PackedRNN', } callbacks = { @@ -84,14 +68,5 @@ "swa": "pytorch_lightning.callbacks.StochasticWeightAveraging", "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", - "progressive_learning": "src.callbacks.progressive_learning.ProgressiveLearning", -} - -layer_decay = { - 'convnext_timm_tiny': 'src.models.baselines.convnext_timm.get_num_layer_for_convnext_tiny', -} - -model_state_hook = { - 'convnext_timm_tiny_2d_to_3d': 'src.models.baselines.convnext_timm.convnext_timm_tiny_2d_to_3d', - 'convnext_timm_tiny_s4nd_2d_to_3d': 'src.models.baselines.convnext_timm.convnext_timm_tiny_s4nd_2d_to_3d', + "progressive_resizing": "src.callbacks.progressive_resizing.ProgressiveResizing", } diff --git a/src/utils/train.py b/src/utils/train.py index 34c63e5..2085c93 100644 --- a/src/utils/train.py +++ b/src/utils/train.py @@ -14,7 +14,6 @@ # Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging -# [21-09-17 AG] doesn't appear to be used class LoggingContext: def __init__(self, logger, level=None, handler=None, close=True): self.logger = logger @@ -53,7 +52,6 @@ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: return logger -# def extras(config: DictConfig) -> None: def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_keys, this is no longer in place """A couple of optional utilities, controlled by main config file: - disabling warnings @@ -79,13 +77,11 @@ def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_k log.info("Disabling python warnings! ") warnings.filterwarnings("ignore") - # set if if config.get("debug"): log.info("Running in debug mode! ") config.trainer.fast_dev_run = True - # # force debugger friendly configuration if - # if config.trainer.get("fast_dev_run"): + # force debugger friendly configuration log.info("Forcing debugger friendly configuration! ") # Debuggers don't like GPUs or multiprocessing if config.trainer.get("gpus"): @@ -103,16 +99,8 @@ def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_k @rank_zero_only def print_config( config: DictConfig, - # fields: Sequence[str] = ( - # "trainer", - # "model", - # "datamodule", - # "train", - # "callbacks", - # "logger", - # "seed", - # ), resolve: bool = True, + save_cfg=True, ) -> None: """Prints content of DictConfig using Rich library and its tree structure. Args: @@ -138,8 +126,9 @@ def print_config( rich.print(tree) - with open("config_tree.txt", "w") as fp: - rich.print(tree, file=fp) + if save_cfg: + with open("config_tree.txt", "w") as fp: + rich.print(tree, file=fp) def log_optimizer(logger, optimizer, keys): """ Log values of particular keys from the optimizer's param groups """ @@ -150,64 +139,3 @@ def log_optimizer(logger, optimizer, keys): f"Optimizer group {i}", f"{len(g['params'])} tensors", ] + [f"{k} {v}" for k, v in group_hps.items()])) - # print(f"Optimizer group {i} | {len(g['params'])} tensors | lr {g['lr']} | wd {g.get('weight_decay', None)}") - - - -""" Old code """ - -def resume(config): - pl.seed_everything(config.train.seed, workers=True) # TODO what happens if None? - - trainer = create_trainer(config) - # Because we do model creation in setup(), we have to create model manually again - # model = SequenceLightningModule.load_from_checkpoint(path) - model = create_model(config, SequenceLightningModule) - - # [21-09-18] - # The order that PL calls its hooks is frustratingly opaque - # (1) If resuming from checkpoint, configure_optimizers() is not called - # So we need to manually create the model, move it to device, and call the hook - # (2) However, for some incredibly bizarre reason, it seems that if on_post_move_to_device is called, the model also calls configure_optimizers - # hopefully this doesn't mess with the optimizer checkpoint - # This currently doesn't seem to break anything, but is very annoying to reason about and who knows if it'll change in future versions - model.setup() - model = model.to('cuda') - model.on_post_move_to_device() - # My best guess to the order of hooks is something like: - # (1) .setup() - # (2) .to(device) / .configure_optimizers() - # (3) .load_state_dict (note that checkpoint tensors know their device) - # (4) .validate() or .train() - # Unfortunately, I can't find a hook in between .to(device) and .load_state_dict where we can call the submodule processing - # (since PL is not properly calling the post_move_to_device hook as of 1.4.7) - - trainer.fit(model) - -def resume_manual(config): - ### Alternatively to the Trainer(resume_from_checkpoint=) argument, we can explicitly restore trainer and model state - trainer = pl.Trainer(resume_from_checkpoint=path) - ### Model - import pathlib - path = Path(__file__).absolute().parent / config.train.resume - checkpoint = torch.load(path) - # Move to device explicitly so we can set up submodules (e.g. Krylov) and load the saved model - model = model.to('cuda') - model.setup() - for module in model.modules(): - if hasattr(module, 'setup'): module.setup() - - model.load_state_dict(checkpoint['state_dict']) - # delattr(model, 'setup') # Trick to prevent model from being set up multiple times, but runs into a Python bug LOL https://discuss.python.org/t/why-do-setattr-and-delattr-raise-an-attributeerror-in-this-case/7836/4 - - ### Optimizers - optimizers, lr_schedulers, _ = trainer.init_optimizers(model) # third arg is optimizer_frequencies https://github.com/PyTorchLightning/pytorch-lightning/blob/c66d30a4aa9615cf1b81e76e416c162bf9d2f0a3/pytorch_lightning/trainer/optimizers.py#L28 - for optimizer, optimizer_state in zip(optimizers, checkpoint['optimizer_states']): - optimizer.load_state_dict(optimizer_state) - - trainer.model = model - trainer.optimizers = optimizers - trainer.lr_schedulers = lr_schedulers - # trainer.restore_training_state(checkpoint) # Found in https://github.com/PyTorchLightning/pytorch-lightning/issues/2613 but doesn't work anymore - - trainer.test(trainer.model) diff --git a/train.py b/train.py index d1f3440..e1d85dd 100644 --- a/train.py +++ b/train.py @@ -1,28 +1,122 @@ +import copy import os -import sys -import traceback -from typing import List, Optional, Callable +import random +import time +from functools import partial, wraps +from typing import Callable, List, Optional + +import hydra import numpy as np -import pytorch_lightning.callbacks +import pytorch_lightning as pl import torch import torch.nn as nn -import pytorch_lightning as pl +import wandb +from hydra.utils import get_original_cwd +from omegaconf import DictConfig, OmegaConf from pytorch_lightning.loggers import WandbLogger -from pytorch_lightning.utilities import rank_zero_only -import hydra -from omegaconf import OmegaConf, DictConfig +from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn +from tqdm.auto import tqdm +import src.models.nn.utils as U import src.utils as utils import src.utils.train -from src.utils.optim.ema import build_ema_optimizer -from src.utils import registry -from src.tasks import encoders, decoders, tasks -import src.models.nn.utils as U from src.dataloaders import SequenceDataset # TODO make registry -from tqdm.auto import tqdm +from src.tasks import decoders, encoders, tasks +from src.utils import registry +from src.utils.optim.ema import build_ema_optimizer +from src.utils.optim_groups import add_optimizer_hooks log = src.utils.train.get_logger(__name__) +# Turn on TensorFloat32 (speeds up large model training substantially) +import torch.backends +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +# Lots of annoying hacks to get WandbLogger to continuously retry on failure +class DummyExperiment: + """Dummy experiment.""" + + def nop(self, *args, **kw): + pass + + def __getattr__(self, _): + return self.nop + + def __getitem__(self, idx) -> "DummyExperiment": + # enables self.logger.experiment[0].add_image(...) + return self + + def __setitem__(self, *args, **kwargs) -> None: + pass + + +def rank_zero_experiment(fn: Callable) -> Callable: + """Returns the real experiment on rank 0 and otherwise the DummyExperiment.""" + + @wraps(fn) + def experiment(self): + @rank_zero_only + def get_experiment(): + return fn(self) + + return get_experiment() or DummyExperiment() + + return experiment + + +class CustomWandbLogger(WandbLogger): + + def __init__(self, *args, **kwargs): + """Modified logger that insists on a wandb.init() call and catches wandb's error if thrown.""" + + super().__init__(*args, **kwargs) + + @property + @rank_zero_experiment + def experiment(self): + r""" + Actual wandb object. To use wandb features in your + :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. + Example:: + .. code-block:: python + self.logger.experiment.some_wandb_function() + """ + if self._experiment is None: + if self._offline: + os.environ["WANDB_MODE"] = "dryrun" + + attach_id = getattr(self, "_attach_id", None) + if wandb.run is not None: + # wandb process already created in this instance + rank_zero_warn( + "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" + " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`." + ) + self._experiment = wandb.run + elif attach_id is not None and hasattr(wandb, "_attach"): + # attach to wandb process referenced + self._experiment = wandb._attach(attach_id) + else: + # create new wandb process + while True: + try: + self._experiment = wandb.init(**self._wandb_init) + break + except Exception as e: + print("wandb Exception:\n", e) + t = random.randint(30, 60) + print(f"Sleeping for {t} seconds") + time.sleep(t) + + # define default x-axis + if getattr(self._experiment, "define_metric", None): + self._experiment.define_metric("trainer/global_step") + self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) + + return self._experiment + class SequenceLightningModule(pl.LightningModule): def __init__(self, config): @@ -37,15 +131,9 @@ def __init__(self, config): # Passing in config expands it one level, so can access by self.hparams.train instead of self.hparams.config.train self.save_hyperparameters(config, logger=False) + # Dataset arguments self.dataset = SequenceDataset.registry[self.hparams.dataset._name_]( - **{ - # Arguments for configuring dataloader when using TBPTT - "tbptt": self.hparams.train.state.mode == 'tbptt', - "chunk_len": self.hparams.train.state.chunk_len, - "overlap_len": self.hparams.train.state.overlap_len, - # Dataset arguments - **self.hparams.dataset, - } + **self.hparams.dataset ) # Check hparams @@ -53,7 +141,8 @@ def __init__(self, config): # PL has some bugs, so add hooks and make sure they're only called once self._has_setup = False - self._has_on_post_move_to_device = False + + self.setup() ## Added by KS def setup(self, stage=None): if not self.hparams.train.disable_dataset: @@ -72,15 +161,21 @@ def setup(self, stage=None): encoder_cfg = utils.to_list(self.hparams.encoder) + utils.to_list( self.hparams.model.pop("encoder", None) ) - decoder_cfg = utils.to_list(self.hparams.model.pop("decoder", None)) + utils.to_list(self.hparams.decoder) + decoder_cfg = utils.to_list( + self.hparams.model.pop("decoder", None) + ) + utils.to_list(self.hparams.decoder) # Instantiate model self.model = utils.instantiate(registry.model, self.hparams.model) + if (name := self.hparams.train.post_init_hook['_name_']) is not None: + kwargs = self.hparams.train.post_init_hook.copy() + del kwargs['_name_'] + for module in self.modules(): + if hasattr(module, name): + getattr(module, name)(**kwargs) # Instantiate the task - if "task" not in self.hparams: # TODO maybe don't need this? - self.hparams.task = self.dataset.default_task - self.task = task = utils.instantiate( + self.task = utils.instantiate( tasks.registry, self.hparams.task, dataset=self.dataset, model=self.model ) @@ -89,18 +184,36 @@ def setup(self, stage=None): encoder_cfg, dataset=self.dataset, model=self.model ) decoder = decoders.instantiate( - self.hparams.decoder, model=self.model, dataset=self.dataset + decoder_cfg, model=self.model, dataset=self.dataset ) # Extract the modules so they show up in the top level parameter count - self.encoder = U.TupleSequential(task.encoder, encoder) - self.decoder = U.TupleSequential(decoder, task.decoder) - self.loss = task.loss - self.metrics = task.metrics + self.encoder = U.PassthroughSequential(self.task.encoder, encoder) + self.decoder = U.PassthroughSequential(decoder, self.task.decoder) + self.loss = self.task.loss + self.loss_val = self.task.loss + if hasattr(self.task, 'loss_val'): + self.loss_val = self.task.loss_val + self.metrics = self.task.metrics # Handle state logic self._initialize_state() + def load_state_dict(self, state_dict, strict=True): + if self.hparams.train.pretrained_model_state_hook['_name_'] is not None: + model_state_hook = utils.instantiate( + registry.model_state_hook, + self.hparams.train.pretrained_model_state_hook.copy(), + partial=True, + ) + # Modify the checkpoint['state_dict'] inside model_state_hook e.g. to inflate 2D convs to 3D convs + state_dict = model_state_hook(self.model, state_dict) + + print("Custom load_state_dict function is running.") + + # note, it needs to return something from the normal function we overrided + return super().load_state_dict(state_dict, strict=strict) + def _check_config(self): assert self.hparams.train.state.mode in [None, "none", "null", "reset", "bptt", "tbptt"] assert ( @@ -113,17 +226,14 @@ def _check_config(self): or isinstance(n, int) and n >= 0 ) - assert ( - not (self.hparams.train.state.mode == 'tbptt') or - (self.hparams.train.state.chunk_len is not None and - self.hparams.train.state.overlap_len is not None) - ), "If tbptt is True, chunk_len and overlap_len must be specified." def _initialize_state(self): + """Called at model setup and start of epoch to completely reset state""" self._state = None self._memory_chunks = [] def _reset_state(self, batch, device=None): + """Called to construct default_state when necessary, e.g. during BPTT""" device = device or batch[0].device self._state = self.model.default_state(*batch[0].shape[:1], device=device) @@ -142,14 +252,14 @@ def _detach_state(self, state): raise NotImplementedError def _process_state(self, batch, batch_idx, train=True): - """Handle logic for state context. This is unused for all current S3 experiments""" - + """Handle logic for state context.""" # Number of context steps key = "n_context" if train else "n_context_eval" n_context = self.hparams.train.state.get(key) - # Don't need to do anything if 0 context steps + # Don't need to do anything if 0 context steps. Make sure there is no state if n_context == 0 and self.hparams.train.state.mode not in ['tbptt']: + self._initialize_state() return # Reset state if needed @@ -168,8 +278,8 @@ def _process_state(self, batch, batch_idx, train=True): self._memory_chunks = self._memory_chunks[-n_context:] elif self.hparams.train.state.mode == 'tbptt': - _, _, *z = batch - reset = z[-1] # if tbptt, last element of z should be whether to reset state! + _, _, z = batch + reset = z["reset"] if reset: self._reset_state(batch) else: @@ -181,52 +291,43 @@ def on_epoch_start(self): def forward(self, batch): """Passes a batch through the encoder, backbone, and decoder""" # z holds arguments such as sequence length - x, y, *z = batch - # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs - x, *w = self.encoder(x, *z) - x, state = self.model(x, *w, state=self._state) - self._state = state - x, *w = self.decoder(x, state, *z) - return x, y, *w - - @torch.inference_mode() - def forward_recurrence(self, batch, k=1): - """This is a bit hacky; not part of the main train loop, only used to benchmark speed of recurrent view""" - x, y, *z = batch - T = x.shape[1] - - if k > 1: - x = torch.cat([x] * k, dim=0) - - self._state = self.model.default_state(*x.shape[:1], device="cuda") - - x_all = [] - w_all = [] - for t in tqdm(range(T)): - - x_t = x[:, t] - x_t = x_t.to("cuda") + x, y, *z = batch # z holds extra dataloader info such as resolution + if len(z) == 0: + z = {} + else: + assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments" + z = z[0] - x_t, *w_t = self.encoder(x_t) - x_t, state = self.model.step(x_t, state=self._state) - self._state = state - x_t, *w_t = self.decoder(x_t, state) + x, w = self.encoder(x, **z) # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs + x, state = self.model(x, **w, state=self._state) + self._state = state + x, w = self.decoder(x, state=state, **z) + return x, y, w - x_all.append(x_t) - w_all.append(w_t) - return torch.stack(x_all), y, *[torch.stack(w_) for w_ in zip(*w_all)] + def step(self, x_t): + x_t, *_ = self.encoder(x_t) # Potential edge case for encoders that expect (B, L, H)? + x_t, state = self.model.step(x_t, state=self._state) + self._state = state + # x_t = x_t[:, None, ...] # Dummy length + # x_t, *_ = self.decoder(x_t, state=state) + # x_t = x_t[:, 0, ...] + x_t, *_ = self.decoder.step(x_t, state=state) + return x_t def _shared_step(self, batch, batch_idx, prefix="train"): self._process_state(batch, batch_idx, train=(prefix == "train")) - x, y, *w = self.forward(batch) + x, y, w = self.forward(batch) # Loss - loss = self.loss(x, y, *w) + if prefix == 'train': + loss = self.loss(x, y, **w) + else: + loss = self.loss_val(x, y, **w) # Metrics - metrics = self.metrics(x, y) + metrics = self.metrics(x, y, **w) metrics["loss"] = loss metrics = {f"{prefix}/{k}": v for k, v in metrics.items()} @@ -351,10 +452,15 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): def configure_optimizers(self): + # Set zero weight decay for some params + if 'optimizer_param_grouping' in self.hparams.train: + add_optimizer_hooks(self.model, **self.hparams.train.optimizer_param_grouping) + # Normal parameters all_params = list(self.parameters()) params = [p for p in all_params if not hasattr(p, "_optim")] + # Construct optimizer, add EMA if necessary if self.hparams.train.ema > 0.0: optimizer = utils.instantiate( @@ -365,27 +471,62 @@ def configure_optimizers(self): polyak=self.hparams.train.ema, ) else: - optimizer = utils.instantiate( - registry.optimizer, self.hparams.optimizer, params - ) + optimizer = utils.instantiate(registry.optimizer, self.hparams.optimizer, params) del self.hparams.optimizer._name_ # Add parameters with special hyperparameters hps = [getattr(p, "_optim") for p in all_params if hasattr(p, "_optim")] hps = [ - dict(s) for s in set(frozenset(hp.items()) for hp in hps) + # dict(s) for s in set(frozenset(hp.items()) for hp in hps) + dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps))) + # dict(s) for s in dict.fromkeys(frozenset(hp.items()) for hp in hps) ] # Unique dicts + print("Hyperparameter groups", hps) for hp in hps: params = [p for p in all_params if getattr(p, "_optim", None) == hp] optimizer.add_param_group( {"params": params, **self.hparams.optimizer, **hp} ) + ### Layer Decay ### + + if self.hparams.train.layer_decay['_name_'] is not None: + get_num_layer = utils.instantiate( + registry.layer_decay, + self.hparams.train.layer_decay['_name_'], + partial=True, + ) + + # Go through all parameters and get num layer + layer_wise_groups = {} + num_max_layers = 0 + for name, p in self.named_parameters(): + # Get layer id for each parameter in the model + layer_id = get_num_layer(name) + + # Add to layer wise group + if layer_id not in layer_wise_groups: + layer_wise_groups[layer_id] = { + 'params': [], + 'lr': None, + 'weight_decay': self.hparams.optimizer.weight_decay + } + layer_wise_groups[layer_id]['params'].append(p) + + if layer_id > num_max_layers: num_max_layers = layer_id + + # Update lr for each layer + for layer_id, group in layer_wise_groups.items(): + group['lr'] = self.hparams.optimizer.lr * (self.hparams.train.layer_decay.decay ** (num_max_layers - layer_id)) + + # Reset the torch optimizer's param groups + optimizer.param_groups = [] + for layer_id, group in layer_wise_groups.items(): + optimizer.add_param_group(group) + # Print optimizer info for debugging - keys = set( - [k for hp in hps for k in hp.keys()] - ) # Get the set of special hparams + keys = set([k for hp in hps for k in hp.keys()]) # Special hparams utils.train.log_optimizer(log, optimizer, keys) # Configure scheduler @@ -434,7 +575,12 @@ def _eval_dataloaders(self): test_loader_names += [name + "/ema" for name in test_loader_names] test_loaders = test_loaders + test_loaders - return val_loader_names + test_loader_names, val_loaders + test_loaders + # adding option to only have val loader at eval (eg if test is duplicate) + if self.hparams.train.get("remove_test_loader_in_eval", None) is not None: + return val_loader_names, val_loaders + # default behavior is to add test loaders in eval + else: + return val_loader_names + test_loader_names, val_loaders + test_loaders def val_dataloader(self): val_loader_names, val_loaders = self._eval_dataloaders() @@ -447,8 +593,7 @@ def test_dataloader(self): return test_loaders -### pytorch-lightning utils and entrypoint - +### pytorch-lightning utils and entrypoint ### def create_trainer(config, **kwargs): callbacks: List[pl.Callback] = [] @@ -460,7 +605,7 @@ def create_trainer(config, **kwargs): # Can pass in config_exclude_keys='wandb' to remove certain groups import wandb - logger = WandbLogger( + logger = CustomWandbLogger( config=utils.to_dict(config, recursive=True), settings=wandb.Settings(start_method="fork"), **config.wandb, @@ -477,6 +622,7 @@ def create_trainer(config, **kwargs): # Configure ddp automatically if config.trainer.gpus > 1: + print("ddp automatically configured, more than 1 gpu used!") kwargs["plugins"] = [ pl.plugins.DDPPlugin( find_unused_parameters=True, @@ -485,6 +631,14 @@ def create_trainer(config, **kwargs): ] kwargs["accelerator"] = "ddp" + # Add ProgressiveResizing callback + if config.callbacks.get("progressive_resizing", None) is not None: + num_stages = len(config.callbacks.progressive_resizing.stage_params) + print(f"Progressive Resizing: {num_stages} stages") + for i, e in enumerate(config.callbacks.progressive_resizing.stage_params): + # Stage params are resolution and epochs, pretty print + print(f"\tStage {i}: {e['resolution']} @ {e['epochs']} epochs") + kwargs.update(config.trainer) trainer = pl.Trainer( logger=logger, @@ -499,58 +653,20 @@ def train(config): pl.seed_everything(config.train.seed, workers=True) trainer = create_trainer(config) model = SequenceLightningModule(config) - trainer.fit(model) - if config.train.test: - trainer.test(model) - - if trainer.is_global_zero: - for cb in trainer.callbacks: - # print("cb type: ", str(type(cb))) - if isinstance(cb, pytorch_lightning.callbacks.ModelCheckpoint): - top = cb.best_model_score.cpu().item() - print("top score: ", top) - # TODO: dynamically determine which path is best to store the global summary file - dirname = os.path.dirname(os.path.abspath(__file__)) - with open(f"{dirname}/global_summary.txt", "a") as f: - f.write(f"{100*top:0.3f}: ") - cmd_line = " ".join(sys.argv[1:]) - f.write("liquid ") - f.write(f"python3 train.py {cmd_line}") - f.write("\n") - - -def benchmark_step(config): - """Utility function to benchmark speed of 'stepping', i.e. recurrent view. Unused for main train logic""" - pl.seed_everything(config.train.seed, workers=True) - - model = SequenceLightningModule(config) - model.setup() - model.to("cuda") - print("Num Parameters: ", sum(p.numel() for p in model.parameters())) - print( - "Num Trainable Parameters: ", - sum(p.numel() for p in model.parameters() if p.requires_grad), - ) - model._on_post_move_to_device() - for module in model.modules(): - if hasattr(module, "setup_step"): - module.setup_step() - model.eval() + # Run initial validation epoch (useful for debugging, finetuning) + if config.train.validate_at_start: + print("Running validation before training") + trainer.validate(model) - val_dataloaders = model.val_dataloader() - dl = val_dataloaders[0] if utils.is_list(val_dataloaders) else val_dataloaders + if config.train.ckpt is not None: + trainer.fit(model, ckpt_path=config.train.ckpt) + else: + trainer.fit(model) + if config.train.test: + trainer.test(model) - import benchmark - for batch in dl: - benchmark.utils.benchmark( - model.forward_recurrence, - batch, - config.train.benchmark_step_k, - T=config.train.benchmark_step_T, - ) - break @hydra.main(config_path="configs", config_name="config.yaml") @@ -565,12 +681,8 @@ def main(config: OmegaConf): # Pretty print config using Rich library utils.train.print_config(config, resolve=True) - if config.train.benchmark_step: - benchmark_step(config) - exit() - train(config) if __name__ == "__main__": - main() \ No newline at end of file + main()