Skip to content

Commit

Permalink
Updated blanche experiment files
Browse files Browse the repository at this point in the history
  • Loading branch information
Tenzin Chan committed Apr 18, 2024
1 parent 8583be2 commit 58e19b5
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 33 deletions.
154 changes: 121 additions & 33 deletions scripts/blanche_expt.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,133 @@
import Pkg
Pkg.activate("..")
import DrWatson
DrWatson.@quickactivate
import MEFK, NPZ, Flux, CUDA
include("extract_blanche.jl")

# NOTE(review): truncated diff residue — this old `uniquecount` definition is
# cut off mid-body by the commit view (its closing `end`s and return
# expression appear further down, interleaved with `model_to_cpu`).
# Not valid Julia as shown.
function uniquecount(data)
# Count occurrences of each distinct row of `data`.
unique_array = unique(data, dims=1)
counts = Dict(unique_array[i, :] => 0 for i in 1:size(unique_array)[1])
for i in 1:size(data)[1]
counts[data[i, :]] += 1

# Copy a trained MEFK model's parameters from device (e.g. CUDA) memory back
# to host `Array`s, using the model types' converting constructors.
#
# `params` is accepted for call-site compatibility but is not used here.
# NOTE(review): the diff view had a stray line from the removed `uniquecount`
# spliced into this body, plus large commented-out constructor calls; both
# removed. The `isa` chain is kept (rather than multiple dispatch) to match
# the file's existing single-entry-point style.
function model_to_cpu(model, params)
    if isa(model, MEFK.MEF3T)
        cpumodel = MEFK.MEF3T(model, Array)
    elseif isa(model, MEFK.MEF2T)
        cpumodel = MEFK.MEF2T(model, Array)
    else
        cpumodel = MEFK.MEFMPNK(model, Array)
    end
    return cpumodel
end

# Build overlapping sliding windows over the rows of `data`.
#
# Each output row is the column-major flattening of `win_sz` consecutive
# input rows, so the result has `size(data, 1) - win_sz + 1` rows.
function window(data, win_sz::Int)
    n_win = size(data, 1) - win_sz + 1
    flat = map(1:n_win) do start
        vec(data[start:start+win_sz-1, :])
    end
    return reduce(hcat, flat)' |> Array
end

# NOTE(review): removed-code residue from the commit diff — this old script
# entry point is truncated (the outer `if` is never closed) and overlaps the
# new `train_on_data` definition below. Not valid Julia as shown.
if abspath(PROGRAM_FILE) == @__FILE__
# Load spikes, transpose to (time, neuron), and narrow to Int8.
data = NPZ.npzread("../data/exp_raw/blanche_140000_area18.npz")["spikes_arr"][1, :, :]' .|> Int8
win_sz = 10
array_cast = CUDA.cu
data = window(data, win_sz)
# NOTE(review): calls `uniquecount`, whose truncated body appears above.
data, counts = uniquecount(data)
for i in 1:size(data)[1]
# Give the all-zero pattern a fixed count of 1000, then stop searching.
if sum(data[i, :]) == 0
println(i)
counts[i] = 1000
break
end
end
# Train a MEFK model on windowed spike data and return the trained model.
# NOTE(review): this span is a commit-diff rendering that interleaves removed
# and added lines, so it is not valid Julia as shown (e.g. `model` is built
# twice, `max_iter = 100` shadows the argument, and there are surplus
# trailing `model`/`end` lines). Suspect lines are flagged below.
function train_on_data(data, max_iter, winsz, batchsize, array_cast)
println("windowing")
# NOTE(review): `window` (above) returns a single matrix; the two-value
# destructure suggests a different window/count helper in the merged file.
data, counts = window(data, winsz)
n = size(data)[2]
# model = MEFK.MEFMPNK(n, 2)
# NOTE(review): likely the removed (old) line — superseded by MEF2T below.
model = MEFK.MEF3T(n; array_cast=array_cast)
println(size(data))
# Prep dataloader
println("prep loader")
counts = reshape(counts, (length(counts), 1))
loader = Flux.Data.DataLoader((data', counts'); batchsize=batchsize, partial=true)
model = MEFK.MEF2T(n; array_cast=array_cast)
optim = Flux.setup(Flux.Adam(0.1), model)
# NOTE(review): shadows the `max_iter` argument — probably removed-line residue.
max_iter = 100

for i in 1:max_iter
# loss, grads = model(data, counts, i==1)
# NOTE(review): the next three lines look like the old full-batch path that
# the minibatch DataLoader loop below replaced.
loss, grads = model(data, counts)
loss = 0
grads = [0, 0, 0]
for (d, c) in loader
# Materialize the transposed minibatch as a dense Array before the model call.
d = d' |> Array
l, _ = model(d, c[:])
loss += l
end
println(loss)
# Accumulated gradients are pulled from the model, then reset for next epoch.
grads = MEFK.retrieve_reset_gradients!(model; reset_grad=true)
Flux.update!(optim, model, grads)
end
end
# NOTE(review): surplus `end`/`model`/`end` — in the merged file the function
# presumably ends by returning `model`.
model
end


# Split the rows of `data` into `num_split` contiguous, roughly equal chunks.
#
# Chunk boundaries are computed with `ceilint` (defined elsewhere in the
# project), so chunk sizes may differ by one row when the split is uneven.
function recording_split_train(data, num_split::Int)
    chunk = size(data, 1) / num_split
    return map(0:num_split-1) do k
        lo = ceilint(k * chunk) + 1
        hi = ceilint((k + 1) * chunk)
        data[lo:hi, :]
    end
end


# Run `MEFK.dynamics` repeatedly until the state stops changing, i.e. until a
# fixed point of the model dynamics is reached; return the converged state.
function converge_dynamics(model, data)
    prev = MEFK.dynamics(model, data)
    curr = MEFK.dynamics(model, prev)
    while any(curr .!= prev)
        prev .= curr
        curr = MEFK.dynamics(model, prev)
    end
    return curr
end


# Script entry point.
# Usage: julia blanche_expt.jl <binsz> <maxiter> <num_split> <cuda_device> <batchsize>
# Trains one model per recording split per window size, runs the converged
# dynamics on each split, and saves inputs/outputs/models to a JLD2 file.
if abspath(PROGRAM_FILE) == @__FILE__
dir = DrWatson.datadir("exp_raw", "pvc3", "crcns_pvc3_cat_recordings", "spont_activity", "spike_data_area18")
params = Dict("binsz"=>parse(Int, ARGS[1]), "maxiter"=>parse(Int, ARGS[2]))
# Window sizes swept: 10, 15, ..., 50 timesteps.
winszs = [5i for i in 2:10]
max_iter = params["maxiter"]
binsz = params["binsz"]
num_split = parse(Int, ARGS[3])

# `extract_bin_spikes` / `trim_recording` come from extract_blanche.jl
# (included above) — assumed to return a (time, neuron) matrix; TODO confirm.
data = extract_bin_spikes(dir, binsz)
data = trim_recording(data, 0.15)#094)
println(size(data))
data_split = recording_split_train(data, num_split)
println("splits $num_split total $(size(data))")

# Select CUDA device and use CUDA.cu to move model arrays onto it.
dev = parse(Int, ARGS[4])
CUDA.device!(dev)
array_cast = CUDA.cu
batchsize = parse(Int, ARGS[5])

#k = 2
#params["k"] = k
#model = MEFK.MEFMPNK(n, k; array_cast=array_cast)
#model = MEFK.MEF3T(n; array_cast=array_cast)
println("starting")
for winsz in winszs
params["winsz"] = winsz
# Only doing matrix for now
save_loc = DrWatson.datadir("exp_pro", "matrix", "split", "models", "$(DrWatson.savename(params)).jld2")
# Skip parameter combinations that already have saved results.
if isfile(save_loc)
println("$save_loc exists, skipping")
continue
end

save_data = Dict()
for i in 1:num_split
model = train_on_data(data_split[i], max_iter, winsz, batchsize, array_cast)
# TODO converge on data_split[i]
# `ordered_window` is presumably from extract_blanche.jl — verify.
input = ordered_window(data_split[i])
# Run dynamics on-device, then bring the converged output back to host.
output = converge_dynamics(model, input |> array_cast) |> Array
cpumodel = model_to_cpu(model, params)
save_data["$i"] = Dict("input"=>input, "output"=>output, "net"=>cpumodel)
end

DrWatson.wsave(save_loc, save_data)
end
end

91 changes: 91 additions & 0 deletions scripts/calculate_entropy.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
include("blanche_expt.jl")
import Plots


# Shannon entropy (in bits) of an unnormalized count/weight vector `dist`.
#
# Zero entries contribute 0 to the sum (the standard 0·log2(0) = 0
# convention); the original `sum(-prob .* log2.(prob))` returned NaN as soon
# as any count was zero.
function calc_entropy(dist)
    probs = dist ./ sum(dist)
    return sum(p -> p > 0 ? -p * log2(p) : zero(p), probs)
end


# Sum `counts` over duplicate rows of `data`.
#
# Returns `(out, cnt)` where `out` stacks the distinct row patterns as rows
# and `cnt[i]` is the total Float32 count accumulated for `out[i, :]`.
#
# Fix: the original used `DataStructures.DefaultDict`, but `DataStructures`
# is never imported in this file or in blanche_expt.jl — a plain stdlib
# `Dict` with `get` does the same job with no extra dependency.
function combine_counts(data, counts)
    patt_cnt = Dict{Vector{UInt8}, Float32}()
    for i in 1:size(data)[1]
        patt = data[i, :]
        patt_cnt[patt] = get(patt_cnt, patt, 0.0f0) + counts[i]
    end
    out = reduce(hcat, collect(keys(patt_cnt)))' |> Array
    cnt = collect(values(patt_cnt))
    return out, cnt
end


# Assign an integer id to each distinct row of `data`, in first-appearance
# order, and return `(pattern => id dict, per-row id vector)`.
function unique_ind(data)
    patterns = unique(data, dims=1)
    patt2ids = Dict(patterns[r, :] => r for r in 1:size(patterns, 1))
    row_ids = map(r -> patt2ids[data[r, :]], 1:size(data, 1))
    return patt2ids, row_ids
end


unique_counts(arr) = [count(==(e), arr) for e in unique(arr)]


# Coordinates for a raster plot of pattern occurrences: x is the row index
# (time order), y is the pattern id from `unique_ind`.
function raster_patterns(patterns)
    _, ids = unique_ind(patterns)
    xs = collect(1:size(patterns, 1))
    return xs, ids
end


# Run `MEFK.dynamics` repeatedly until the state stops changing, i.e. until a
# fixed point of the model dynamics is reached; return the converged state.
# NOTE(review): duplicates the definition in blanche_expt.jl (included above).
function converge_dynamics(model, data)
    prev = MEFK.dynamics(model, data)
    curr = MEFK.dynamics(model, prev)
    while any(curr .!= prev)
        prev .= curr
        curr = MEFK.dynamics(model, prev)
    end
    return curr
end


# Script entry point: for every (bin size, window size) pair, load the saved
# experiment results and compute Shannon entropy of the input (raw) and
# converged (model) pattern distributions for each recording split, then save
# both entropy tensors to entropy_winbin.jld2.
if abspath(PROGRAM_FILE) == @__FILE__
numwin = 10
numbin = 10
# Sweep grids: window sizes 10:5:50, bin sizes 500:500:5000.
winszs = [5i for i in 2:numwin]
binszs = [500i for i in 1:numbin]
maxiter = 100
num_split = 10
base_dir = ["exp_pro", "matrix", "split", "complete"]
# ents[i, j, k]: entropy for bin i, window j, split k (model vs raw below).
ents = zeros(length(binszs), length(winszs), num_split)
raw_ents = zeros(size(ents))
for (i, binsz) in enumerate(binszs)
for (j, winsz) in enumerate(winszs)
println("win $winsz bin $binsz")
params = Dict("winsz"=>winsz, "binsz"=>binsz, "maxiter"=>maxiter)
save_name = "$(DrWatson.savename(params)).jld2"
# NOTE(review): loads from "complete" while blanche_expt.jl saves under
# "models" — verify the files were moved/renamed between runs.
data = DrWatson.wload(DrWatson.datadir(base_dir..., save_name))

for k in 1:num_split
input = data["$k"]["input"]
output = data["$k"]["output"]

# Entropy is computed over the distribution of distinct row patterns.
in_patt, in_ids = unique_ind(input)
out_patt, out_ids = unique_ind(output)
in_count = unique_counts(in_ids)
out_count = unique_counts(out_ids)
raw_ents[i, j, k] = calc_entropy(in_count)
println("raw entropy: $(raw_ents[i, j, k])")
ents[i, j, k] = calc_entropy(out_count)
println("converged entropy: $(ents[i, j, k])")

#x, y = raster_patterns(out)
#y = log10.(y)
#p = Plots.scatter(x, y)
#Plots.savefig(DrWatson.plotsdir("converged_patterns_$(winsz)_order2.png"))
end
end
end
DrWatson.wsave(DrWatson.datadir("entropy_winbin.jld2"), Dict("model_entropy"=>ents, "raw_entropy"=>raw_ents))
end

22 changes: 22 additions & 0 deletions scripts/plot_ent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import DrWatson
DrWatson.@quickactivate
import Plots


# Average `hm` over its third (split) dimension and save the heatmap to `loc`.
#
# Axis tick values are hard-coded to match the sweep in calculate_entropy.jl
# (window sizes 10:5:50 timesteps, bin sizes 500:500:5000 ms); the shared
# `clim` keeps both entropy plots on one color scale.
function save_plot_hm(hm, loc, min_val, max_val)
    nsplit = size(hm, 3)
    avg = dropdims(sum(hm, dims=3); dims=3) / nsplit
    Plots.heatmap(10:5:50, 500:500:5000, avg,
                  xlabel="window size / number of timesteps",
                  ylabel="bin size / ms",
                  clim=(min_val, max_val))
    Plots.savefig(loc)
end


# Load the saved entropy tensors and render raw vs model entropy heatmaps on
# a shared color scale so the two PDFs are directly comparable.
data = DrWatson.wload(DrWatson.datadir("entropy_winbin.jld2"))
raw_ent = data["raw_entropy"]
model_ent = data["model_entropy"]
# Use array reductions rather than splatting every element into min/max
# (`min(raw_ent...)` allocates and is slow/fragile for large arrays).
min_val = min(minimum(raw_ent), minimum(model_ent))
max_val = max(maximum(raw_ent), maximum(model_ent))
loc = DrWatson.plotsdir("raw_entropy.pdf")
save_plot_hm(raw_ent, loc, min_val, max_val)
loc = DrWatson.plotsdir("model_entropy.pdf")
save_plot_hm(model_ent, loc, min_val, max_val)

0 comments on commit 58e19b5

Please sign in to comment.