diff --git a/docs/neural_modeling/plot_04_v1_cells.py b/docs/neural_modeling/plot_04_v1_cells.py
index d09b58f2..1a86a11a 100644
--- a/docs/neural_modeling/plot_04_v1_cells.py
+++ b/docs/neural_modeling/plot_04_v1_cells.py
@@ -10,21 +10,25 @@
 To execute this notebook locally, ensure you download the necessary
 [utility functions](https://github.com/flatironinstitute/nemos/tree/main/docs/neural_modeling/examples_utils)
 into the same directory as this notebook.
 """
-
+import jax
 import matplotlib.pyplot as plt
 import numpy as np
 import pynapple as nap
 from examples_utils import data

 import nemos as nmo
+import jax.numpy as jnp
+
+# suppress the jax-to-numpy conversion warning in pynapple
+nap.nap_config.suppress_conversion_warnings = True

-# configure plots some
+# plot style configuration
 plt.style.use("examples_utils/nemos.mplstyle")

 # %%
 # ## Data Streaming
-#
+# Download the data from a remote server and store it locally for analysis.
 path = data.download_data("m691l1.nwb", "https://osf.io/xesdm/download", '../data')
@@ -55,23 +59,30 @@
 # %%
-# There are 73 neurons recorded together in V1. To fit the GLM faster, we
-# will focus on one neuron.
+# There are 73 neurons recorded together in V1. To keep the model small, we
+# will restrict the population to the units firing at 5 Hz or more; later we
+# will model one of them (unit 34) and use the rest as coupled inputs.

 print(units)

-# this returns TsGroup with one neuron only
-spikes = units[[34]]
+# this returns a TsGroup containing only the high-firing units
+spikes = units[units.rate >= 5.0]

 # %%
-# How could we predict neuron's response to white noise stimulus?
+# How could we predict the neuron's response to a white noise stimulus?
 #
 # - we could fit the instantaneous spatial response. that is, just predict
-# neuron's response to a given frame of white noise. this will give an x by y
+# the neuron's response to a given frame of white noise. this will give an x-pixel by y-pixel
 # filter. implicitly assumes that there's no temporal info: only matters what
 # we've just seen
 #
 # - could fit spatiotemporal filter. instead of an x by y that we use
-# independently on each frame, fit (x, y, t) over, say 100 msecs. and then
+# independently on each frame, fit (x, y, t) over, say 130 msecs. and then
 # fit each of these independently (like in head direction example)
 #
-# - that's a lot of parameters! can simplify by assumping that the response is
+# - could reduce the dimensionality by using a bank of k two-dimensional basis
+# functions of shape (x-pixel, y-pixel, k). we can "project" each (x-pixel, y-pixel)
+# stimulus image onto the basis by computing the dot product of the two, leaving
+# us with a k-dimensional vector, where k can be much smaller than the number of pixels
+#
+# - that's still a lot of parameters! can simplify by assuming that the response is
 # separable: fit a single (x, y) filter and then modulate it over time. this
 # wouldn't catch e.g., direction-selectivity because it assumes that phase
 # preference is constant over time
@@ -79,11 +90,7 @@
 # - could make use of our knowledge of V1 and try to fit a more complex
 # functional form, e.g., a Gabor.
 #
-# That last one is very non-linear and thus non-convex. we'll do the third one.
-#
-# in this example, we'll fit the spatial filter outside of the GLM framework,
-# using spike-triggered average, and then we'll use the GLM to fit the temporal
-# timecourse.
+# That last one is very non-linear and thus non-convex. We'll go with the third
+# option: a spatiotemporal filter expressed in a low-dimensional basis.
 #
 # ## Spike-triggered average
 #
@@ -97,7 +104,7 @@
 # V1). Pynapple makes this easy:

-sta = nap.compute_event_trigger_average(spikes, stimulus, binsize=0.025,
+sta = nap.compute_event_trigger_average(spikes[[20, 34]], stimulus, binsize=0.025,
                                         windowsize=(-0.15, 0.0))
 # %%
 #
@@ -112,146 +119,211 @@
 sta[1, 0]
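+
+# %%
+# As an optional sanity check (a rough sketch, not part of the original
+# analysis), we can recompute the zero-lag slice of the STA by hand: bin unit
+# 34's spikes at the STA resolution and average the stimulus frames shown in
+# each bin, weighted by the spike count. Up to binning conventions, this
+# should resemble the slice of `sta` nearest to lag 0 for that unit.
+
+counts_34 = spikes[34].count(0.025, stimulus.time_support)
+# hold each frame's value forward to the count bins (same helper used later)
+frames = data.fill_forward(counts_34, stimulus)
+manual_sta = np.einsum("t,thw->hw", counts_34.d, frames.d[:]) / counts_34.d.sum()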

 # %%
-# we can easily plot this
+# We can easily plot the STA for unit 34.

 fig, axes = plt.subplots(1, len(sta), figsize=(3*len(sta),3))
 for i, t in enumerate(sta.t):
-    axes[i].imshow(sta[i,0], vmin = np.min(sta), vmax = np.max(sta),
+    axes[i].imshow(sta[i, 1], vmin=np.min(sta), vmax=np.max(sta),
                    cmap='Greys_r')
     axes[i].set_title(str(t)+" s")

 # %%
 #
-# that looks pretty reasonable for a V1 simple cell: localized in space,
-# orientation, and spatial frequency. that is, looks Gabor-ish
-#
-# To convert this to the spatial filter we'll use for the GLM, let's take the
-# average across the bins that look informative: -.125 to -.05
-
-# mkdocs_gallery_thumbnail_number = 3
-receptive_field = np.mean(sta.get(-0.125, -0.05), axis=0)[0]
-
-fig, ax = plt.subplots(1, 1, figsize=(4,4))
-ax.imshow(receptive_field, cmap='Greys_r')
-
-# %%
-#
-# This receptive field gives us the spatial part of the linear response: it
-# gives a map of weights that we use for a weighted sum on an image. There are
-# multiple ways of performing this operation:
-
-# element-wise multiplication and sum
-print((receptive_field * stimulus[0]).sum())
-# dot product of flattened versions
-print(np.dot(receptive_field.flatten(), stimulus[0].flatten()))
+# That looks pretty reasonable for a V1 simple cell: localized in space,
+# orientation, and spatial frequency. That is, it looks Gabor-ish.
+# The receptive field gives us the spatial part of the linear response: a map
+# of weights that we use for a weighted sum on each stimulus image.
+#
+# ## Firing rate model
+#
+# What we want is to model the log-firing rate as a linear combination of the
+# past stimuli $\bm{x}\_t$ over a fixed window; here $\bm{x}\_t$ is an array
+# representing the flattened image, of shape `(nm, )`, where n and m are the
+# number of pixels along the x and y axes of the noise stimulus.
+# Mathematically, this can be expressed as
+# $$
+# \log \mu\_t = \sum\_{i=0}^{L-1} \bm{\beta}\_i \cdot \bm{x}\_{t-i},
+# $$
+# where each $\bm{\beta}\_i$ is a vector of coefficients, also of shape `(nm, )`,
+# and $L$ is the number of past frames in the window. This is quite a lot of
+# coefficients: a window of 130 ms at 10 ms resolution on a 51x51 image gives
+# 51^2 x 13 = 33813 of them.
+# We can use a basis set to reduce the dimensionality: first we create a bank of
+# 15 x 15 = 225 two-dimensional basis functions, each evaluated on the 51x51
+# pixel grid, which reduces the problem to 15^2 x 13 = 2925 coefficients.

+# define a two-dimensional basis as a product of two "RaisedCosineBasisLinear" bases
+n_bas = 15
+basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_bas) ** 2
+
+# evaluate the basis on a (51, 51) grid of points
+X, Y, basis_eval = basis.evaluate_on_grid(51, 51)
+
+print(basis_eval.shape)
+
+# plot the basis set
+fig, axs = plt.subplots(n_bas, n_bas, figsize=(10, 8))
+for i in range(n_bas):
+    for j in range(n_bas):
+        axs[i, j].contourf(X, Y, basis_eval[..., i*n_bas + j])
+        axs[i, j].set_xticks([])
+        axs[i, j].set_yticks([])
+
+# %%
+# Now we can project the stimulus onto the basis.

+# project the stimulus onto the basis
+projected_stim = nap.TsdFrame(
+    t=stimulus.t,
+    # stimulus.d[:] reads the HDF5 array into memory (needed for jax)
+    d=jnp.einsum("tnm, nmk -> tk", stimulus.d[:], basis_eval),
+    time_support=stimulus.time_support
+)
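+
+# %%
+# As a quick, optional check (not part of the original tutorial), we can gauge
+# how much spatial detail survives the projection by reconstructing one frame
+# from its 225 coefficients. The raised cosine bank is not orthonormal, so a
+# least-squares fit (rather than the plain dot product above) gives the best
+# reconstruction the basis can support.
+
+B = basis_eval.reshape(-1, n_bas**2)              # (51*51, 225) design matrix
+frame = np.asarray(stimulus.d[0]).ravel()         # one flattened stimulus frame
+coef, *_ = np.linalg.lstsq(B, frame, rcond=None)  # least-squares coefficients
+recon = (B @ coef).reshape(51, 51)                # back to image space
+
+fig, axs = plt.subplots(1, 2, figsize=(6, 3))
+axs[0].imshow(np.asarray(stimulus.d[0]), cmap="Greys_r")
+axs[0].set_title("original frame")
+axs[1].imshow(recon, cmap="Greys_r")
+axs[1].set_title("basis reconstruction")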

 # %%
+# Additionally, we can jointly model the all-to-all functional connectivity of
+# the V1 population.
 #
-# When performing this operation on multiple stimuli, things become slightly
-# more complicated. For loops on the above methods would work, but would be
-# slow. Reshaping and using the dot product is one common method, as are
-# methods like `np.tensordot`.
-#
-# We'll use einsum to do this, which is a convenient way of representing many
-# different matrix operations:
+# !!! note
+#     See the [head direction](#plot_02_head_direction.py) tutorial for a
+#     detailed overview of how to infer functional connectivity with a GLM.

-filtered_stimulus = np.einsum('t h w, h w -> t', stimulus, receptive_field)
+# Define a basis for the coupling filters
+basis_coupling = nmo.basis.RaisedCosineBasisLog(3, mode="conv", window_size=20)

 # %%
-#
-# This notation says: take these arrays with dimensions `(t,h,w)` and `(h,w)`
-# and multiply and sum to get an array of shape `(t,)`. This performs the same
-# operations as above.
-#
-# And this remains a pynapple object, so we can easily visualize it!
+# Since the number of parameters and samples is still quite large, we can try a
+# stochastic optimization approach, in which we update the parameters using a
+# few random samples from the time series of predictors and counts at each
+# iteration. Let's define the parameters that we are going to use to sample the
+# features and spike counts.

-fig, ax = plt.subplots(1, 1, figsize=(12,4))
-ax.plot(filtered_stimulus)
+# the sampling rate for counts and features
+bin_size = 0.01  # 10 ms

-# %%
-#
-# But what is this? It's how much each frame in the video should drive our
-# neuron, based on the receptive field we fit using the spike-triggered
-# average.
-#
-# This, then, is the spatial component of our input, as described above.
-#
-# ## Preparing data for NeMoS
-#
-# We'll now use the GLM to fit the temporal component. To do that, let's get
-# this and our spike counts into the proper format for NeMoS:
+# the window size used for prediction (130 ms of video)
+prediction_window = 0.13  # duration of the window in sec

-# grab spikes from when we were showing our stimulus, and bin at 1 msec
-# resolution
-bin_size = .001
-counts = spikes[34].restrict(filtered_stimulus.time_support).count(bin_size)
-print(counts.rate)
-print(filtered_stimulus.rate)
+# number of past frames used for predicting the current firing rate
+lags = int(np.ceil(prediction_window / bin_size))

-# %%
-#
-# Hold on, our stimulus is at a much lower rate than what we want for our rates
-# -- in previous neural_modeling, our input has been at a higher rate than our spikes,
-# and so we used `bin_average` to down-sample to the appropriate rate. When the
-# input is at a lower rate, we need to think a little more carefully about how
-# to up-sample.
+# the duration of the data chunk that we will use at each iteration
+batch_size = 10  # seconds

-print(counts[:5])
-print(filtered_stimulus[:5])
-
-# %%
-#
-# What was the visual input to the neuron at time 0.005? It was the same input
-# as time 0. At time 0.0015? Same thing, up until we pass time 0.025017. Thus,
-# we want to "fill forward" the values of our input, and we have pynapple
-# convenience function to do so:
-filtered_stimulus = data.fill_forward(counts, filtered_stimulus)
-filtered_stimulus
+# define a function that returns the chunk of data from "time" to "time + batch_size"
+def batcher(time: float):
+    # get the stimulus in a 10 sec interval plus one prediction window
+    ep = nap.IntervalSet(start=time, end=time + batch_size + prediction_window)

-# %%
-#
-# We can see that the time points are now aligned, and we've filled forward the
-# values the way we'd like.
+    # count the spikes of the neuron that we are fitting (unit 34)
+    y = spikes[34].count(bin_size, ep)
+
+    # up-sample the projected stimulus to the 0.01 sec count resolution
+    x = data.fill_forward(y, projected_stim.restrict(ep))
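+    # note: `x` now has one row per count bin in `ep`, holding the 225-dim
+    # basis projection of the most recent stimulus frame (values are carried
+    # forward between frames, since the video frame rate is coarser than the
+    # 10 ms count bins)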
-#
-# Now, similar to the [head direction tutorial](../02_head_direction), we'll
-# use the log-stretched raised cosine basis to create the predictor for our
-# GLM:
+    # function that shifts the stimulus by `lag` bins and crops the result
+    def roll_and_crop(x, lag):
+        return jnp.roll(x, lag, axis=0)[:-lags]

-window_size = 100
-basis = nmo.basis.RaisedCosineBasisLog(8, mode="conv", window_size=window_size)
+    # vectorize the function over the lags
+    roll = jax.vmap(roll_and_crop, in_axes=(None, 0), out_axes=1)
+
+    # roll and reshape to get the predictors
+    features = roll(x.d, -jnp.arange(lags)).reshape(x.shape[0] - lags, -1)

-convolved_input = basis.compute_features(filtered_stimulus)
+    # convolve the counts with the basis to get the coupling features
+    coupling = basis_coupling.compute_features(spikes.count(bin_size, ep))[lags:]

-# %%
-#
-# convolved_input has shape (n_time_pts, n_features * n_basis_funcs), because
-# n_features is the singleton dimension from filtered_stimulus.
-#
-# ## Fitting the GLM
-#
-# Now we're ready to fit the model! Let's do it, same as before:
+    # concatenate the features and return features and counts
+    return np.hstack((coupling, features)), y[lags:]

-model = nmo.glm.GLM(regularizer=nmo.regularizer.UnRegularized(solver_name="LBFGS"))
-model.fit(convolved_input, counts)

 # %%
-#
-# We have our coefficients for each of our 8 basis functions, let's combine
-# them to get the temporal time course of our input:
+# We are now ready to learn the model parameters.
+
+# instantiate two models: one that estimates the functional connectivity and one that does not
+model_coupled = nmo.glm.GLM(
+    regularizer=nmo.regularizer.Lasso(
+        regularizer_strength=0.01,
+        solver_kwargs={"stepsize": 0.001, "acceleration": False}
+    )
+)
+
+model_uncoupled = nmo.glm.GLM(
+    regularizer=nmo.regularizer.Lasso(
+        regularizer_strength=0.01,
+        solver_kwargs={"stepsize": 0.001, "acceleration": False}
+    )
+)
+
+# initialize the solvers on the first batch
+X, Y = batcher(0)
+
+# initialize the parameters of the coupled model
+params_cp, state_cp = model_coupled.initialize_solver(X, Y)
+
+# initialize the uncoupled model (drop the columns corresponding to the coupling features)
+n_coupling_coef = len(spikes) * basis_coupling.n_basis_funcs
+params_uncp, state_uncp = model_uncoupled.initialize_solver(X[:, n_coupling_coef:], Y)

-time, basis_kernels = basis.evaluate_on_grid(window_size)
-time *= bin_size * window_size
-temp_weights = np.einsum('b, t b -> t', model.coef_, basis_kernels)
-plt.plot(time, temp_weights)
-plt.xlabel("time[sec]")
-plt.ylabel("amplitude")

 # %%
-#
-# When taken together, the results of the GLM and the spike-triggered average
-# give us the linear component of our LNP model: the separable spatio-temporal
-# filter.
+# Finally, we can run a loop that grabs a chunk of data and updates the model
+# parameters.
+
+# run the stochastic gradient descent for a few iterations
+np.random.seed(123)
+
+for k in range(500):
+    if k % 50 == 0:
+        print(f"iter {k}")
+
+    # select a random start time in the recording
+    time = np.random.uniform(0, 2400)
+
+    # grab a 10 sec batch starting from `time`
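+    # (start times are drawn uniformly over the session, which we assume lasts
+    # about 2400 s, so batches may overlap across iterations)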
+    X, Y = batcher(time)
+
+    # update the parameters of the coupled model
+    params_cp, state_cp = model_coupled.update(params_cp, state_cp, X, Y)
+
+    # update the uncoupled model, dropping the feature columns that correspond
+    # to the coupling filters
+    params_uncp, state_uncp = model_uncoupled.update(params_uncp, state_uncp, X[:, n_coupling_coef:], Y)
+
+
+# %%
+# We can now plot the receptive fields estimated by the two models.
+
+# get the coefficients of the spatiotemporal filters
+coeff_coupled = model_coupled.coef_[n_coupling_coef:]
+coeff_uncoupled = model_uncoupled.coef_
+
+# weight the basis by the coefficients to get the estimated receptive fields
+rf_coupled = np.einsum("lk,ijk->lij", coeff_coupled.reshape(lags, -1), basis_eval)
+rf_uncoupled = np.einsum("lk,ijk->lij", coeff_uncoupled.reshape(lags, -1), basis_eval)
+
+# compare the receptive fields on a common color scale per model
+mn1, mx1 = rf_uncoupled.min(), rf_uncoupled.max()
+mn2, mx2 = rf_coupled.min(), rf_coupled.max()
+
+fig1, axs1 = plt.subplots(1, lags, figsize=(10, 1.5))
+fig2, axs2 = plt.subplots(1, lags, figsize=(10, 1.5))
+fig1.suptitle("uncoupled model RF")
+fig2.suptitle("coupled model RF")
+for i in range(lags):
+    axs1[i].set_title(f"{prediction_window - i * bin_size:.2} s")
+    axs1[i].imshow(rf_uncoupled[i], vmin=mn1, vmax=mx1, cmap="Greys")
+    axs1[i].set_xticks([])
+    axs1[i].set_yticks([])
+
+    axs2[i].set_title(f"{prediction_window - i * bin_size:.2} s")
+    axs2[i].imshow(rf_coupled[i], vmin=mn2, vmax=mx2, cmap="Greys")
+    axs2[i].set_xticks([])
+    axs2[i].set_yticks([])
+fig1.tight_layout()
+fig2.tight_layout()
+
+# %%
+# This batched approach allows for the estimation of a neuron's receptive field
+# with high temporal resolution, even with long recordings and high-dimensional
+# stimuli such as images. Additionally, if the stimulus data is stored in large
+# HDF5 or Zarr files, you can leverage pynapple's "lazy-loading" capabilities
+# (details available
+# [here](https://pynapple-org.github.io/pynapple/generated/api_guide/tutorial_pynapple_nwb/))
+# to read each data chunk directly from disk.
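+
+# %%
+# As a rough sketch of that lazy-loading workflow (hypothetical file and
+# dataset names; not executed as part of this tutorial), pynapple can wrap an
+# on-disk HDF5 dataset directly, so each `restrict` call inside `batcher`
+# would read only the requested chunk of frames from disk:
+#
+# ```python
+# import h5py
+#
+# f = h5py.File("path/to/stimulus.h5", "r")                # hypothetical file
+# lazy_stim = nap.TsdTensor(t=frame_times, d=f["frames"])  # no full read here
+# chunk = lazy_stim.restrict(nap.IntervalSet(0, 10))       # loads ~10 s only
+# ```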