From f6805381b14e16571b5c578280344e01994ee8cc Mon Sep 17 00:00:00 2001 From: Albert Gu Date: Wed, 27 Apr 2022 07:54:02 -0700 Subject: [PATCH] add standalone version of S4 with some options stripped out, and diagonal version --- configs/model/layer/s4_simple.yaml | 19 + .../sequence/ss/standalone/s4_simple.py | 1001 +++++++++++++++++ src/utils/registry.py | 1 + 3 files changed, 1021 insertions(+) create mode 100644 configs/model/layer/s4_simple.yaml create mode 100644 src/models/sequence/ss/standalone/s4_simple.py diff --git a/configs/model/layer/s4_simple.yaml b/configs/model/layer/s4_simple.yaml new file mode 100644 index 0000000..b92075f --- /dev/null +++ b/configs/model/layer/s4_simple.yaml @@ -0,0 +1,19 @@ +_name_: simple +d_state: 64 +channels: 1 +bidirectional: false +activation: gelu +postact: null +dropout: ${..dropout} # Same as null +measure: legs +dt_min: 0.001 +dt_max: 0.1 +trainable: + dt: true + A: true + P: true + B: true +lr: 0.001 +mode: nplr +l_max: ${oc.select:dataset.__l_max,1} # Grab dataset length if exists, otherwise set to 1 and kernel will automatically resize +verbose: false diff --git a/src/models/sequence/ss/standalone/s4_simple.py b/src/models/sequence/ss/standalone/s4_simple.py new file mode 100644 index 0000000..712866d --- /dev/null +++ b/src/models/sequence/ss/standalone/s4_simple.py @@ -0,0 +1,1001 @@ +""" Standalone version of Structured (Sequence) State Space (S4) model, with additional features removed. """ + + +import logging +from functools import partial +import math +import numpy as np +from scipy import special as ss +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorch_lightning.utilities import rank_zero_only +from einops import rearrange, repeat +import opt_einsum as oe + +contract = oe.contract +contract_expression = oe.contract_expression + + +def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: + """Initializes multi-GPU-friendly python logger.""" + + logger = logging.getLogger(name) + logger.setLevel(level) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger +log = get_logger(__name__) + +""" Cauchy kernel """ + +def cauchy_slow(v, z, w): + """ + v, w: (..., N) + z: (..., L) + returns: (..., L) + """ + cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L) + return torch.sum(cauchy_matrix, dim=-2) + +def _broadcast_dims(*tensors): + max_dim = max([len(tensor.shape) for tensor in tensors]) + tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors] + return tensors + +_c2r = torch.view_as_real +_r2c = torch.view_as_complex +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) +if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10): + _resolve_conj = lambda x: x.conj().resolve_conj() +else: + _resolve_conj = lambda x: x.conj() + + + +""" simple nn.Module components """ + +def Activation(activation=None, dim=-1): + if activation in [ None, 'id', 'identity', 'linear' ]: + return nn.Identity() + elif activation == 'tanh': + return nn.Tanh() + elif activation == 'relu': + return nn.ReLU() + elif activation == 'gelu': + return nn.GELU() + elif activation in ['swish', 'silu']: + return nn.SiLU() + elif activation == 'glu': + return nn.GLU(dim=dim) + elif activation == 'sigmoid': + return nn.Sigmoid() + else: + raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) + +def get_initializer(name, activation=None): + if activation in [ None, 'id', 'identity', 'linear', 'modrelu' ]: + nonlinearity = 'linear' + elif activation in ['relu', 'tanh', 'sigmoid']: + nonlinearity = activation + elif activation in ['gelu', 'swish', 'silu']: + nonlinearity = 'relu' # Close to ReLU so approximate with ReLU's gain + else: + raise NotImplementedError(f"get_initializer: activation {activation} not supported") + + if name == 'uniform': + initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) + elif name == 'normal': + initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) + elif name == 'xavier': + initializer = torch.nn.init.xavier_normal_ + elif name == 'zero': + initializer = partial(torch.nn.init.constant_, val=0) + elif name == 'one': + initializer = partial(torch.nn.init.constant_, val=1) + else: + raise NotImplementedError(f"get_initializer: initializer type {name} not supported") + + return initializer + + +def LinearActivation( + d_input, d_output, bias=True, + zero_bias_init=False, + transposed=False, + initializer=None, + activation=None, + activate=False, # Apply activation as part of this module + weight_norm=False, + **kwargs, + ): + """ Returns a linear nn.Module with control over axes order, initialization, and activation """ + + # Construct core module + linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear + if activation == 'glu': d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + # Initialize weight + if initializer is not None: + get_initializer(initializer, activation)(linear.weight) + + # Initialize bias + if bias and zero_bias_init: + nn.init.zeros_(linear.bias) + + # Weight norm + if weight_norm: + linear = nn.utils.weight_norm(linear) + + if activate and activation is not None: + activation = Activation(activation, dim=-2 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + +""" Misc functional utilities """ + +def krylov(L, A, b, c=None, return_power=False): + """ + Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. + + If return_power=True, return A^{L-1} as well + """ + # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises + + x = b.unsqueeze(-1) # (..., N, 1) + A_ = A + + AL = None + if return_power: + AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) + _L = L-1 + + done = L == 1 + # loop invariant: _L represents how many indices left to compute + while not done: + if return_power: + if _L % 2 == 1: AL = A_ @ AL + _L //= 2 + + # Save memory on last iteration + l = x.shape[-1] + if L - l <= l: + done = True + _x = x[..., :L-l] + else: _x = x + + _x = A_ @ _x + x = torch.cat([x, _x], dim=-1) # there might be a more efficient way of ordering axes + if not done: A_ = A_ @ A_ + + assert x.shape[-1] == L + + if c is not None: + x = torch.einsum('...nl, ...n -> ...l', x, c) + x = x.contiguous() # WOW!! + if return_power: + return x, AL + else: + return x + +def power(L, A, v=None): + """ Compute A^L and the scan sum_i A^i v_i + + A: (..., N, N) + v: (..., N, L) + """ + + I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + + powers = [A] + l = 1 + while True: + if L % 2 == 1: I = powers[-1] @ I + L //= 2 + if L == 0: break + l *= 2 + powers.append(powers[-1] @ powers[-1]) + + if v is None: return I + + # Invariants: + # powers[-1] := A^l + # l := largest po2 at most L + + # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A + # We do this reverse divide-and-conquer for efficiency reasons: + # 1) it involves fewer padding steps for non-po2 L + # 2) it involves more contiguous arrays + + # Take care of edge case for non-po2 arrays + # Note that this initial step is a no-op for the case of power of 2 (l == L) + k = v.size(-1) - l + v_ = powers.pop() @ v[..., l:] + v = v[..., :l] + v[..., :k] = v[..., :k] + v_ + + # Handle reduction for power of 2 + while v.size(-1) > 1: + v = rearrange(v, '... (z l) -> ... z l', z=2) + v = v[..., 0, :] + powers.pop() @ v[..., 1, :] + return I, v.squeeze(-1) + + +""" HiPPO utilities """ + +def embed_c2r(A): + A = rearrange(A, '... m n -> ... m () n ()') + A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + \ + np.pad(A, ((0, 0), (1, 0), (0, 0), (1,0))) + return rearrange(A, 'm x n y -> (m x) (n y)') + +def transition(measure, N, **measure_args): + """ A, B transition matrices for different measures + + measure: the type of measure + legt - Legendre (translated) + legs - Legendre (scaled) + glagt - generalized Laguerre (translated) + lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization + """ + # Laguerre (translated) + if measure == 'lagt': + b = measure_args.get('beta', 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + # Generalized Laguerre + # alpha 0, beta small is most stable (limits to the 'lagt' measure) + # alpha 0, beta 1 has transition matrix A = [lower triangular 1] + elif measure == 'glagt': + alpha = measure_args.get('alpha', 0.0) + beta = measure_args.get('beta', 0.01) + A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) + B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] + + L = np.exp(.5 * (ss.gammaln(np.arange(N)+alpha+1) - ss.gammaln(np.arange(N)+1))) + A = (1./L[:, None]) * A * L[None, :] + B = (1./L[:, None]) * B * np.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2) + # Legendre (translated) + elif measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + # Legendre (scaled) + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + elif measure == 'fourier': + freqs = np.arange(N//2) + d = np.stack([freqs, np.zeros(N//2)], axis=-1).reshape(-1)[:-1] + A = 2*np.pi*(np.diag(d, 1) - np.diag(d, -1)) + A = A - embed_c2r(np.ones((N//2, N//2))) + B = embed_c2r(np.ones((N//2, 1)))[..., :1] + elif measure == 'random': + A = np.random.randn(N, N) / N + B = np.random.randn(N, 1) + elif measure == 'diagonal': + A = -np.diag(np.exp(np.random.randn(N))) + B = np.random.randn(N, 1) + else: + raise NotImplementedError + + return A, B + +def rank_correction(measure, N, rank=1, dtype=torch.float): + """ Return low-rank matrix L such that A + L is normal """ + + if measure == 'legs': + assert rank >= 1 + P = torch.sqrt(.5+torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) + elif measure == 'legt': + assert rank >= 2 + P = torch.sqrt(1+2*torch.arange(N, dtype=dtype)) # (N) + P0 = P.clone() + P0[0::2] = 0. + P1 = P.clone() + P1[1::2] = 0. + P = torch.stack([P0, P1], dim=0) # (2 N) + elif measure == 'lagt': + assert rank >= 1 + P = .5**.5 * torch.ones(1, N, dtype=dtype) + elif measure == 'fourier': + P = torch.ones(N, dtype=dtype) # (N) + P0 = P.clone() + P0[0::2] = 0. + P1 = P.clone() + P1[1::2] = 0. + P = torch.stack([P0, P1], dim=0) # (2 N) + else: raise NotImplementedError + + d = P.size(0) + if rank > d: + P = torch.cat([P, torch.zeros(rank-d, N, dtype=dtype)], dim=0) # (rank N) + return P + +def nplr(measure, N, rank=1, dtype=torch.float): + """ Return w, p, q, V, B such that + (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V + i.e. A = V[w - p q^*]V^*, B = V B + """ + assert dtype == torch.float or torch.cfloat + if measure == 'random': + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + # w = torch.randn(N//2, dtype=dtype) + w = -torch.exp(torch.randn(N//2)) + 1j*torch.randn(N//2) + P = torch.randn(rank, N//2, dtype=dtype) + B = torch.randn(N//2, dtype=dtype) + V = torch.eye(N, dtype=dtype)[..., :N//2] # Only used in testing + return w, P, B, V + + A, B = transition(measure, N) + A = torch.as_tensor(A, dtype=dtype) # (N, N) + B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) + + P = rank_correction(measure, N, rank=rank, dtype=dtype) + AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3) + w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) + # V w V^{-1} = A + + # Only keep one of the conjugate pairs + w = w[..., 0::2].contiguous() + V = V[..., 0::2].contiguous() + + V_inv = V.conj().transpose(-1, -2) + + B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B + P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P + + + return w, P, B, V + + + +class SSKernelNPLR(nn.Module): + """Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR) + + The class name stands for 'State-Space SSKernel for Normal Plus Low-Rank'. + The parameters of this function are as follows. + + A: (... N N) the state matrix + B: (... N) input matrix + C: (... N) output matrix + dt: (...) timescales / discretization step size + p, q: (... P N) low-rank correction to A, such that Ap=A+pq^T is a normal matrix + + The forward pass of this Module returns: + (... L) that represents represents FFT SSKernel_L(A^dt, B^dt, C) + + """ + + @torch.no_grad() + def _setup_C(self, double_length=False): + """ Construct C~ from C + + double_length: current C is for length L, convert it to length 2L + """ + C = _r2c(self.C) + self._setup_state() + dA_L = power(self.L, self.dA) + # Multiply C by I - dA_L + C_ = _conj(C) + prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) + if double_length: prod = -prod # Multiply by I + dA_L instead + C_ = C_ - prod + C_ = C_[..., :self.N] # Take conjugate pairs again + self.C.copy_(_c2r(C_)) + + if double_length: + self.L *= 2 + self._omega(self.L, dtype=C.dtype, device=C.device, cache=True) + + def _omega(self, L, dtype, device, cache=True): + """ Calculate (and cache) FFT nodes and their "unprocessed" them with the bilinear transform + This should be called everytime the internal length self.L changes """ + omega = torch.tensor( + np.exp(-2j * np.pi / (L)), dtype=dtype, device=device + ) # \omega_{2L} + omega = omega ** torch.arange(0, L // 2 + 1, device=device) + z = 2 * (1 - omega) / (1 + omega) + if cache: + self.register_buffer("omega", _c2r(omega)) + self.register_buffer("z", _c2r(z)) + return omega, z + + def __init__( + self, + L, w, P, B, C, log_dt, + trainable=None, + lr=None, + verbose=False, + ): + """ + L: Maximum length; this module computes an SSM kernel of length L + w: (N) + p: (r, N) low-rank correction to A + q: (r, N) + A represented by diag(w) - pq^* + + B: (N) + dt: (H) timescale per feature + C: (H, C, N) system is 1-D to c-D (channels) + + trainable: toggle which of the parameters is trainable + lr: add hook to set lr of hippo parameters specially (everything besides C) + length_correction: multiply C by (I - dA^L) - can be turned off when L is large for slight speedup at initialization (only relevant when N large as well) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + super().__init__() + self.verbose = verbose + + # Rank of low-rank correction + self.rank = P.shape[-2] + assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = w.size(-1) + + # Broadcast everything to correct shapes + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (H, C, N) + H = 1 + B = repeat(B, 'n -> 1 h n', h=H) + P = repeat(P, 'r n -> r h n', h=H) + w = repeat(w, 'n -> h n', h=H) + + # Cache Fourier nodes every time we set up a desired length + self.L = L + if self.L is not None: + self._omega(self.L, dtype=C.dtype, device=C.device, cache=True) + + # Register parameters + # C is a regular parameter, not state + # self.C = nn.Parameter(_c2r(C.conj().resolve_conj())) + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + train = False + if trainable is None: trainable = {} + if trainable == False: trainable = {} + if trainable == True: trainable, train = {}, True + self.register("log_dt", log_dt, trainable.get('dt', train), lr, 0.0) + self.register("B", _c2r(B), trainable.get('B', train), lr, 0.0) + self.register("P", _c2r(P), trainable.get('P', train), lr, 0.0) + log_w_real = torch.log(-w.real + 1e-3) # Some of the HiPPO methods have real part 0 + w_imag = w.imag + self.register("log_w_real", log_w_real, trainable.get('A', 0), lr, 0.0) + self.register("w_imag", w_imag, trainable.get('A', train), lr, 0.0) + self.Q = None + + self._setup_C() + + def _w(self): + # Get the internal w (diagonal) parameter + w_real = -torch.exp(self.log_w_real) + w_imag = self.w_imag + w = w_real + 1j * w_imag + return w + + def forward(self, L=None): + """ + returns: (..., c, L) + """ + if L is None: + L = self.L + + # Increase the internal length if needed + while L > self.L: + self.double_length() + + dt = torch.exp(self.log_dt) + B = _r2c(self.B) + C = _r2c(self.C) + P = _r2c(self.P) + Q = P.conj() if self.Q is None else _r2c(self.Q) + w = self._w() + + omega, z = _r2c(self.omega), _r2c(self.z) # (..., L) + + B = repeat(B, '... 1 n -> ... h n', h=self.H) + P = repeat(P, '... 1 n -> ... h n', h=self.H) + Q = repeat(Q, '... 1 n -> ... h n', h=self.H) + + + # Incorporate dt into A + w = w * dt.unsqueeze(-1) # (H N) + + # Stack B and p, C and q for convenient batching + B = torch.cat([B, P], dim=-3) # (s+1+r, H, N) + C = torch.cat([C, Q], dim=-3) # (c+r, H, N) + + # Incorporate B and C batch dimensions + v = B.unsqueeze(-3) * C.unsqueeze(-4) # (s+1+r, c+r, H, N) + # w = w[None, None, ...] # (1, 1, H, N) + # z = z[None, None, None, ...] # (1, 1, 1, L) + + # Calculate resolvent at omega + r = cauchy_slow(v, z, w) + r = r * dt[None, None, :, None] # (S+1+R, C+R, H, L) + + # Low-rank Woodbury correction + k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / (1 + r[-1:, -1:, :, :]) + + # Final correction for the bilinear transform + k_f = k_f * 2 / (1 + omega) + + # Move from frequency to coefficients + k = torch.fft.irfft(k_f) # (S+1, C, H, L) + + # Truncate to target length + k = k[..., :L] + k_B = k[-1, :, :, :] # (C H L) + return k_B + + @torch.no_grad() + def double_length(self): + if self.verbose: log.info(f"S4: Doubling length from L = {self.L} to {2*self.L}") + self._setup_C(double_length=True) + + def _setup_linear(self): + """ Create parameters that allow fast linear stepping of state """ + w = self._w() + B = _r2c(self.B) # (H N) + P = _r2c(self.P) + Q = P.conj() if self.Q is None else _r2c(self.Q) + + # Prepare Linear stepping + dt = torch.exp(self.log_dt) + D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) + R = (torch.eye(self.rank, dtype=w.dtype, device=w.device) + 2*contract('r h n, h n, s h n -> h r s', Q, D, P).real) # (H r r) + Q_D = rearrange(Q*D, 'r h n -> h r n') + R = torch.linalg.solve(R.to(Q_D), Q_D) # (H r N) + R = rearrange(R, 'h r n -> r h n') + + self.step_params = { + "D": D, # (H N) + "R": R, # (r H N) + "P": P, # (r H N) + "Q": Q, # (r H N) + "B": B, # (1 H N) + "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) + } + + def _step_state_linear(self, u=None, state=None): + """ + Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization. + + Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster + + u: (H) input + state: (H, N/2) state with conjugate pairs + Optionally, the state can have last dimension N + Returns: same shape as state + """ + C = _r2c(self.C) # View used for dtype/device + + if u is None: # Special case used to find dA + u = torch.zeros(self.H, dtype=C.dtype, device=C.device) + if state is None: # Special case used to find dB + state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) + + step_params = self.step_params.copy() + if state.size(-1) == self.N: # Only store half of the conjugate pairs; should be true by default + # There should be a slightly faster way using conjugate symmetry + contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', _conj(p), _conj(x), _conj(y))[..., :self.N] # inner outer product + else: + assert state.size(-1) == 2*self.N + step_params = {k: _conj(v) for k, v in step_params.items()} + # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping + contract_fn = lambda p, x, y: contract('r h n, r h m, ... h m -> ... h n', p, x, y) # inner outer product + D = step_params["D"] # (H N) + E = step_params["E"] # (H N) + R = step_params["R"] # (r H N) + P = step_params["P"] # (r H N) + Q = step_params["Q"] # (r H N) + B = step_params["B"] # (1 H N) + + new_state = E * state - contract_fn(P, Q, state) # (B H N) + new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) + new_state = D * (new_state - contract_fn(P, R, new_state)) + + return new_state + + def _setup_state(self): + """ Construct dA and dB for discretized state equation """ + + # Construct dA and dB by using the stepping + self._setup_linear() + C = _r2c(self.C) # Just returns a view that we use for finding dtype/device + + state = torch.eye(2*self.N, dtype=C.dtype, device=C.device).unsqueeze(-2) # (N 1 N) + dA = self._step_state_linear(state=state) + dA = rearrange(dA, "n h m -> h m n") + self.dA = dA # (H N N) + + u = C.new_ones(self.H) + dB = self._step_state_linear(u=u) + dB = _conj(dB) + self.dB = rearrange(dB, '1 h n -> h n') # (H N) + + def register(self, name, tensor, trainable=False, lr=None, wd=None): + """Utility method: register a tensor as a buffer or trainable parameter""" + + if trainable: + self.register_parameter(name, nn.Parameter(tensor)) + else: + self.register_buffer(name, tensor) + + optim = {} + if trainable and lr is not None: + optim["lr"] = lr + if trainable and wd is not None: + optim["weight_decay"] = wd + if len(optim) > 0: + setattr(getattr(self, name), "_optim", optim) + + +class SSKernelDiag(nn.Module): + """ Version using (complex) diagonal state matrix. Main difference is this uses the ZOH instead of Bilinear transform. Note that it is slower and less memory efficient than the NPLR kernel because of lack of kernel support. + + """ + + def __init__( + self, + w, C, log_dt, + L=None, + disc='bilinear', + trainable=None, + lr=None, + # tie_state=False, + ): + """ + L: Maximum length; this module computes an SSM kernel of length L + w: (N) + p: (r, N) low-rank correction to A + q: (r, N) + A represented by diag(w) - pq^* + + B: (N) + dt: (H) timescale per feature + C: (H, C, N) system is 1-D to c-D (channels) + + hurwitz: tie pq and ensure w has negative real part + trainable: toggle which of the parameters is trainable + lr: add hook to set lr of hippo parameters specially (everything besides C) + tie_state: tie all state parameters across the H hidden features + length_correction: multiply C by (I - dA^L) - can be turned off when L is large for slight speedup at initialization (only relevant when N large as well) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + super().__init__() + # self.tie_state = tie_state + self.L = L + self.disc = disc + + # Rank of low-rank correction + assert w.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = w.size(-1) + + # Broadcast everything to correct shapes + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (H, C, N) + + # Register parameters + # C is a regular parameter, not state + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + train = False + if trainable is None: trainable = {} + if trainable == False: trainable = {} + if trainable == True: trainable, train = {}, True + self.register("log_dt", log_dt, trainable.get('dt', train), lr, 0.0) + + if self.disc in ['bilinear', 'zoh']: + log_w_real = torch.log(-w.real + 1e-3) # Some of the HiPPO methods have real part 0 + w_imag = w.imag + self.register("log_w_real", log_w_real, trainable.get('A', 0), lr, 0.0) + self.register("w_imag", w_imag, trainable.get('A', train), lr, 0.0) + elif self.disc == 'dss': + self.register("w", _c2r(w), trainable.get('A', train), lr, 0.0) + else: raise NotImplementedError + + + def _w(self): + # Get the internal w (diagonal) parameter + if self.disc != 'dss': + w_real = -torch.exp(self.log_w_real) + w_imag = self.w_imag + w = w_real + 1j * w_imag + else: + w = _r2c(self.w) # (..., N) + return w + + def forward(self, state=None, L=None): + """ + state: (..., s, N) extra tensor that augments B + rate: sampling rate factor + + returns: (..., c+s, L) + """ + # Handle sampling rate logic + # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) sampling rate rate + # If either are not passed in, assume we're not asked to change the scale of our kernel + if L is None: + L = self.L + + dt = torch.exp(self.log_dt) # (H) + C = _r2c(self.C) # (C H N) + w = self._w() # (1 N) or (H N) + + + # Incorporate dt into A + dtA = w * dt.unsqueeze(-1) # (H N) + + if self.disc == 'zoh': + # Power up + K = dtA.unsqueeze(-1) * torch.arange(L, device=w.device) # (H N L) + C = C * (torch.exp(dtA)-1.) / w + K = contract('chn, hnl -> chl', C, torch.exp(K)) + K = 2*K.real + elif self.disc == 'bilinear': + dA = (1. + dtA/2) / (1. - dtA/2) + K = dA.unsqueeze(-1) ** torch.arange(L, device=w.device) # (H N L) + C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / w + K = contract('chn, hnl -> chl', C, K) + K = 2*K.real + else: + # Implementation from DSS meant for case when real eigenvalues can be positive + P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] + w_gt_0 = w.real > 0 # [N] + if w_gt_0.any(): + with torch.no_grad(): + P_max = dtA * (w_gt_0 * (L-1)) # [H N] + P = P - P_max.unsqueeze(-1) # [H N L] + S = P.exp() # [H N L] + + dtA_neg = dtA * (1 - 2*w_gt_0) # [H N] + # S.sum(-1) == den / num + num = dtA_neg.exp() - 1 # [H N] + den = (dtA_neg * L).exp() - 1 # [H N] + C = C * num * type(self).reciprocal(den * w) # [C H N] + + K = contract('chn,hnl->chl', C, S).float() + return K + + @staticmethod + def reciprocal(x, epsilon=1e-7, clamp=False): + """ returns 1 / x, with bounded norm """ + x_conj = x.conj() + norm_sq = (x*x_conj).real.clamp(epsilon) if clamp else (x*x_conj + epsilon) + return x_conj / norm_sq + + def register(self, name, tensor, trainable=False, lr=None, wd=None): + """Utility method: register a tensor as a buffer or trainable parameter""" + + if trainable: + self.register_parameter(name, nn.Parameter(tensor)) + else: + self.register_buffer(name, tensor) + + optim = {} + if trainable and lr is not None: + optim["lr"] = lr + if trainable and wd is not None: + optim["weight_decay"] = wd + if len(optim) > 0: + setattr(getattr(self, name), "_optim", optim) + + +class HippoSSKernel(nn.Module): + """Wrapper around SSKernel that generates A, B, C, dt according to HiPPO arguments. + + The SSKernel is expected to support the interface + forward() + default_state() + setup_step() + step() + """ + + def __init__( + self, + H, + N=64, + L=1, + measure="legs", + rank=1, + channels=1, # 1-dim to C-dim map; can think of C as having separate "heads" + dt_min=0.001, + dt_max=0.1, + trainable=None, # Dictionary of options to train various HiPPO parameters + lr=None, # Hook to set LR of hippo parameters differently + precision=1, # 1 (single) or 2 (double) for the kernel + mode='nplr', + verbose=False, + **kernel_args, + ): + super().__init__() + self.N = N + self.H = H + L = L or 1 + self.precision = precision + dtype = torch.double if self.precision == 2 else torch.float + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble + self.channels = channels + + # Generate dt + log_dt = torch.rand(self.H, dtype=dtype) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + w, p, B, _ = nplr(measure, self.N, rank, dtype=dtype) + C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) + if mode == "nplr": + self.kernel = SSKernelNPLR( + L, w, p, B, C, + log_dt, + trainable=trainable, + lr=lr, + verbose=verbose, + **kernel_args, + ) + elif mode == "diag": + C = C * B + self.kernel = SSKernelDiag( + w, C, log_dt, L=L, + trainable=trainable, + lr=lr, + **kernel_args, + ) + + def forward(self, L=None): + k = self.kernel(L=L) + return k.float() + + def step(self, u, state, **kwargs): + u, state = self.kernel.step(u, state, **kwargs) + return u.float(), state + + def default_state(self, *args, **kwargs): + return self.kernel.default_state(*args, **kwargs) + +class S4(nn.Module): + + def __init__( + self, + d_model, + d_state=64, + l_max=1, # Maximum length of sequence. Fine if not provided: the kernel will keep doubling in length until longer than sequence. However, this can be marginally slower if the true length is not a power of 2 + channels=1, # maps 1-dim to C-dim + bidirectional=False, + # Arguments for FF + activation='gelu', # activation in between SS and FF + postact=None, # activation after FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + hyper_act=None, # Use a "hypernetwork" multiplication + dropout=0.0, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + verbose=False, + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum sequence length, also denoted by L + if this is not known at model creation, set l_max=1 + channels: can be interpreted as a number of "heads" + bidirectional: bidirectional + dropout: standard dropout argument + transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) [B=batch size, L=sequence length, H=hidden dimension] + + Other options are all experimental and should not need to be configured + """ + + super().__init__() + if verbose: + import src.utils.train + log = src.utils.train.get_logger(__name__) + log.info(f"Constructing S4 (H, N, L) = ({d_model}, {d_state}, {l_max})") + + self.h = d_model + self.n = d_state + self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + + # optional multiplicative modulation GLU-style + # https://arxiv.org/abs/2002.05202 + self.hyper = hyper_act is not None + if self.hyper: + channels *= 2 + self.hyper_activation = Activation(hyper_act) + + self.D = nn.Parameter(torch.randn(channels, self.h)) + + if self.bidirectional: + channels *= 2 + + + # SSM Kernel + self.kernel = HippoSSKernel(self.h, N=self.n, L=l_max, channels=channels, verbose=verbose, **kernel_args) + + # Pointwise + self.activation = Activation(activation) + dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + + # position-wise output transform to mix features + self.output_linear = LinearActivation( + self.h*self.channels, + self.h, + transposed=self.transposed, + initializer=initializer, + activation=postact, + activate=True, + weight_norm=weight_norm, + ) + + + def forward(self, u, **kwargs): # absorbs return_output and transformer src mask + """ + u: (B H L) if self.transposed else (B L H) + state: (H N) never needed unless you know what you're doing + + Returns: same shape as u + """ + if not self.transposed: u = u.transpose(-1, -2) + L = u.size(-1) + + # Compute SS Kernel + k = self.kernel(L=L) # (C H L) (B C H L) + + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) + k = F.pad(k0, (0, L)) \ + + F.pad(k1.flip(-1), (L, 0)) \ + + k_f = torch.fft.rfft(k, n=2*L) # (C H L) + u_f = torch.fft.rfft(u, n=2*L) # (B H L) + y_f = contract('bhl,chl->bchl', u_f, k_f) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) + y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) + + + # Compute D term in state space equation - essentially a skip connection + y = y + contract('bhl,ch->bchl', u, self.D) # u.unsqueeze(-3) * self.D.unsqueeze(-1) + + # Optional hyper-network multiplication + if self.hyper: + y, yh = rearrange(y, 'b (s c) h l -> s b c h l', s=2) + y = self.hyper_activation(yh) * y + + # Reshape to flatten channels + y = rearrange(y, '... c h l -> ... (c h) l') + + y = self.dropout(self.activation(y)) + + if not self.transposed: y = y.transpose(-1, -2) + + y = self.output_linear(y) + + return y, None + + @property + def d_output(self): + return self.h + + @property + def state_to_tensor(self): + return lambda state: rearrange('... h n -> ... (h n)', state) + diff --git a/src/utils/registry.py b/src/utils/registry.py index 1d3ffcf..c2d6b7b 100644 --- a/src/utils/registry.py +++ b/src/utils/registry.py @@ -41,6 +41,7 @@ "lssl": "src.models.sequence.ss.lssl.LSSL", "s4": "src.models.sequence.ss.s4.S4", "standalone": "src.models.sequence.ss.standalone.s4.S4", + "simple": "src.models.sequence.ss.standalone.s4_simple.S4", "ff": "src.models.sequence.ff.FF", "rnn": "src.models.sequence.rnns.rnn.RNN", "mha": "src.models.sequence.mha.MultiheadAttention",