-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Tenzin Chan
committed
Apr 18, 2024
1 parent
8583be2
commit 58e19b5
Showing
3 changed files
with
234 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,133 @@ | ||
import Pkg | ||
Pkg.activate("..") | ||
import DrWatson | ||
DrWatson.@quickactivate | ||
import MEFK, NPZ, Flux, CUDA | ||
include("extract_blanche.jl") | ||
|
||
function uniquecount(data) | ||
unique_array = unique(data, dims=1) | ||
counts = Dict(unique_array[i, :] => 0 for i in 1:size(unique_array)[1]) | ||
for i in 1:size(data)[1] | ||
counts[data[i, :]] += 1 | ||
|
||
function model_to_cpu(model, params) | ||
if isa(model, MEFK.MEF3T) | ||
cpumodel = MEFK.MEF3T(model, Array) | ||
#cpumodel = MEFK.MEF3T(model.n, | ||
# model.W1 |> Array, | ||
# model.W2 |> Array, | ||
# model.W3 |> Array, | ||
# model.W2_mask |> Array, | ||
# model.W3_mask |> Array, | ||
# [g |> Array for g in model.gradients], | ||
# Array | ||
# ) | ||
elseif isa(model, MEFK.MEF2T) | ||
cpumodel = MEFK.MEF2T(model, Array) | ||
#cpumodel = MEFK.MEF2T(model.n, | ||
# model.W1 |> Array, | ||
# model.W2 |> Array, | ||
# model.W2_mask |> Array, | ||
# [g |> Array for g in model.gradients], | ||
# Array | ||
# ) | ||
else | ||
cpumodel = MEFK.MEFMPNK(model, Array) | ||
#cpumodel = MEFK.MEFMPNK(model.n, | ||
# model.K, | ||
# model.W .|> Array, | ||
# [g |> Array for g in model.grad], | ||
# [inds .|> Array for inds in model.indices], | ||
# [winds .|> Array for winds in model.windices], | ||
# Array | ||
# ) | ||
end | ||
reduce(hcat, keys(counts))' |> Array, values(counts) |> collect .|> Float32 | ||
cpumodel | ||
end | ||
|
||
function window(data, win_sz::Int) | ||
n_sample = size(data)[1] - win_sz + 1 | ||
out = [data[i:i+win_sz-1, :][:] for i in 1:n_sample] | ||
reduce(hcat, out)' |> Array | ||
end | ||
|
||
if abspath(PROGRAM_FILE) == @__FILE__ | ||
data = NPZ.npzread("../data/exp_raw/blanche_140000_area18.npz")["spikes_arr"][1, :, :]' .|> Int8 | ||
win_sz = 10 | ||
array_cast = CUDA.cu | ||
data = window(data, win_sz) | ||
data, counts = uniquecount(data) | ||
for i in 1:size(data)[1] | ||
if sum(data[i, :]) == 0 | ||
println(i) | ||
counts[i] = 1000 | ||
break | ||
end | ||
end | ||
function train_on_data(data, max_iter, winsz, batchsize, array_cast) | ||
println("windowing") | ||
data, counts = window(data, winsz) | ||
n = size(data)[2] | ||
# model = MEFK.MEFMPNK(n, 2) | ||
model = MEFK.MEF3T(n; array_cast=array_cast) | ||
println(size(data)) | ||
# Prep dataloader | ||
println("prep loader") | ||
counts = reshape(counts, (length(counts), 1)) | ||
loader = Flux.Data.DataLoader((data', counts'); batchsize=batchsize, partial=true) | ||
model = MEFK.MEF2T(n; array_cast=array_cast) | ||
optim = Flux.setup(Flux.Adam(0.1), model) | ||
max_iter = 100 | ||
|
||
for i in 1:max_iter | ||
# loss, grads = model(data, counts, i==1) | ||
loss, grads = model(data, counts) | ||
loss = 0 | ||
grads = [0, 0, 0] | ||
for (d, c) in loader | ||
d = d' |> Array | ||
l, _ = model(d, c[:]) | ||
loss += l | ||
end | ||
println(loss) | ||
grads = MEFK.retrieve_reset_gradients!(model; reset_grad=true) | ||
Flux.update!(optim, model, grads) | ||
end | ||
end | ||
model | ||
end | ||
|
||
|
||
function recording_split_train(data, num_split::Int) | ||
split_sz = size(data)[1] / num_split | ||
[data[ceilint(i*split_sz)+1:ceilint((i+1)*split_sz), :] for i in 0:num_split-1] | ||
end | ||
|
||
|
||
function converge_dynamics(model, data) | ||
out_ = MEFK.dynamics(model, data) | ||
out = MEFK.dynamics(model, out_) | ||
while !all(out .== out_) | ||
out_ .= out | ||
out = MEFK.dynamics(model, out) | ||
end | ||
out | ||
end | ||
|
||
|
||
if abspath(PROGRAM_FILE) == @__FILE__ | ||
dir = DrWatson.datadir("exp_raw", "pvc3", "crcns_pvc3_cat_recordings", "spont_activity", "spike_data_area18") | ||
params = Dict("binsz"=>parse(Int, ARGS[1]), "maxiter"=>parse(Int, ARGS[2])) | ||
winszs = [5i for i in 2:10] | ||
max_iter = params["maxiter"] | ||
binsz = params["binsz"] | ||
num_split = parse(Int, ARGS[3]) | ||
|
||
data = extract_bin_spikes(dir, binsz) | ||
data = trim_recording(data, 0.15)#094) | ||
println(size(data)) | ||
data_split = recording_split_train(data, num_split) | ||
println("splits $num_split total $(size(data))") | ||
|
||
dev = parse(Int, ARGS[4]) | ||
CUDA.device!(dev) | ||
array_cast = CUDA.cu | ||
batchsize = parse(Int, ARGS[5]) | ||
|
||
#k = 2 | ||
#params["k"] = k | ||
#model = MEFK.MEFMPNK(n, k; array_cast=array_cast) | ||
#model = MEFK.MEF3T(n; array_cast=array_cast) | ||
println("starting") | ||
for winsz in winszs | ||
params["winsz"] = winsz | ||
# Only doing matrix for now | ||
save_loc = DrWatson.datadir("exp_pro", "matrix", "split", "models", "$(DrWatson.savename(params)).jld2") | ||
if isfile(save_loc) | ||
println("$save_loc exists, skipping") | ||
continue | ||
end | ||
|
||
save_data = Dict() | ||
for i in 1:num_split | ||
model = train_on_data(data_split[i], max_iter, winsz, batchsize, array_cast) | ||
# TODO converge on data_split[i] | ||
input = ordered_window(data_split[i]) | ||
output = converge_dynamics(model, input |> array_cast) |> Array | ||
cpumodel = model_to_cpu(model, params) | ||
save_data["$i"] = Dict("input"=>input, "output"=>output, "net"=>cpumodel) | ||
end | ||
|
||
DrWatson.wsave(save_loc, save_data) | ||
end | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
include("blanche_expt.jl") | ||
import Plots | ||
|
||
|
||
function calc_entropy(dist) | ||
prob = dist ./ sum(dist) | ||
sum(-prob .* log2.(prob)) | ||
end | ||
|
||
|
||
function combine_counts(data, counts) | ||
patt_cnt = DataStructures.DefaultDict{Vector{UInt8}, Float32}(0) | ||
for i in 1:size(data)[1] | ||
patt_cnt[data[i, :]] += counts[i] | ||
end | ||
out = reduce(hcat, patt_cnt |> keys |> collect)' |> Array | ||
cnt = patt_cnt |> values |> collect | ||
out, cnt | ||
end | ||
|
||
|
||
function unique_ind(data) | ||
unique_patterns = unique(data, dims=1) | ||
patt2ids = Dict(unique_patterns[i, :]=>i for i in 1:size(unique_patterns)[1]) | ||
inds = [patt2ids[data[i, :]] for i in 1:size(data)[1]] | ||
patt2ids, inds | ||
end | ||
|
||
|
||
unique_counts(arr) = [count(==(e), arr) for e in unique(arr)] | ||
|
||
|
||
function raster_patterns(patterns) | ||
_, y = unique_ind(patterns) | ||
num_patterns = size(patterns)[1] | ||
x = [i for i in 1:num_patterns] | ||
x, y | ||
end | ||
|
||
|
||
function converge_dynamics(model, data) | ||
out_ = MEFK.dynamics(model, data) | ||
out = MEFK.dynamics(model, out_) | ||
while !all(out .== out_) | ||
out_ .= out | ||
out = MEFK.dynamics(model, out) | ||
end | ||
out | ||
end | ||
|
||
|
||
if abspath(PROGRAM_FILE) == @__FILE__ | ||
numwin = 10 | ||
numbin = 10 | ||
winszs = [5i for i in 2:numwin] | ||
binszs = [500i for i in 1:numbin] | ||
maxiter = 100 | ||
num_split = 10 | ||
base_dir = ["exp_pro", "matrix", "split", "complete"] | ||
ents = zeros(length(binszs), length(winszs), num_split) | ||
raw_ents = zeros(size(ents)) | ||
for (i, binsz) in enumerate(binszs) | ||
for (j, winsz) in enumerate(winszs) | ||
println("win $winsz bin $binsz") | ||
params = Dict("winsz"=>winsz, "binsz"=>binsz, "maxiter"=>maxiter) | ||
save_name = "$(DrWatson.savename(params)).jld2" | ||
data = DrWatson.wload(DrWatson.datadir(base_dir..., save_name)) | ||
|
||
for k in 1:num_split | ||
input = data["$k"]["input"] | ||
output = data["$k"]["output"] | ||
|
||
in_patt, in_ids = unique_ind(input) | ||
out_patt, out_ids = unique_ind(output) | ||
in_count = unique_counts(in_ids) | ||
out_count = unique_counts(out_ids) | ||
raw_ents[i, j, k] = calc_entropy(in_count) | ||
println("raw entropy: $(raw_ents[i, j, k])") | ||
ents[i, j, k] = calc_entropy(out_count) | ||
println("converged entropy: $(ents[i, j, k])") | ||
|
||
#x, y = raster_patterns(out) | ||
#y = log10.(y) | ||
#p = Plots.scatter(x, y) | ||
#Plots.savefig(DrWatson.plotsdir("converged_patterns_$(winsz)_order2.png")) | ||
end | ||
end | ||
end | ||
DrWatson.wsave(DrWatson.datadir("entropy_winbin.jld2"), Dict("model_entropy"=>ents, "raw_entropy"=>raw_ents)) | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import DrWatson | ||
DrWatson.@quickactivate | ||
import Plots | ||
|
||
|
||
function save_plot_hm(hm, loc, min_val, max_val) | ||
hm = sum(hm, dims=3) / size(hm)[3] | ||
hm = hm[:, :, 1] | ||
Plots.heatmap(10:5:50, 500:500:5000, hm, xlabel="window size / number of timesteps", ylabel="bin size / ms", clim=(min_val, max_val)) | ||
Plots.savefig(loc) | ||
end | ||
|
||
|
||
data = DrWatson.wload(DrWatson.datadir("entropy_winbin.jld2")) | ||
raw_ent = data["raw_entropy"] | ||
model_ent = data["model_entropy"] | ||
min_val = min(min(raw_ent...), min(model_ent...)) | ||
max_val = max(max(raw_ent...), max(model_ent...)) | ||
loc = DrWatson.plotsdir("raw_entropy.pdf") | ||
save_plot_hm(raw_ent, loc, min_val, max_val) | ||
loc = DrWatson.plotsdir("model_entropy.pdf") | ||
save_plot_hm(model_ent, loc, min_val, max_val) |