Skip to content

Commit

Permalink
leabra updated to new looper and looper gui.
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoreilly committed Nov 5, 2024
1 parent fafdc0c commit 0b1015d
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 212 deletions.
78 changes: 37 additions & 41 deletions examples/deep_fsa/deep_fsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"cogentcore.org/core/base/mpi"
"cogentcore.org/core/base/randx"
"cogentcore.org/core/core"
"cogentcore.org/core/enums"
"cogentcore.org/core/icons"
"cogentcore.org/core/math32/vecint"
"cogentcore.org/core/tensor/table"
Expand Down Expand Up @@ -242,7 +243,7 @@ type Sim struct {
Params emer.NetParams `display:"add-fields"`

// contains looper control loops for running sim
Loops *looper.Manager `new-window:"+" display:"no-inline"`
Loops *looper.Stacks `new-window:"+" display:"no-inline"`

// contains computed statistic values
Stats estats.Stats `new-window:"+"`
Expand Down Expand Up @@ -405,35 +406,37 @@ func (ss *Sim) InitRandSeed(run int) {

// ConfigLoops configures the control loops: Training, Testing
func (ss *Sim) ConfigLoops() {
man := looper.NewManager()
ls := looper.NewStacks()

trls := ss.Config.Run.NTrials

man.AddStack(etime.Train).
ls.AddStack(etime.Train).
AddTime(etime.Run, ss.Config.Run.NRuns).
AddTime(etime.Epoch, ss.Config.Run.NEpochs).
AddTime(etime.Trial, trls).
AddTime(etime.Cycle, 100)

man.AddStack(etime.Test).
ls.AddStack(etime.Test).
AddTime(etime.Epoch, 1).
AddTime(etime.Trial, trls).
AddTime(etime.Cycle, 100)

leabra.LooperStdPhases(man, &ss.Context, ss.Net, 75, 99) // plus phase timing
leabra.LooperSimCycleAndLearn(man, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code
leabra.LooperStdPhases(ls, &ss.Context, ss.Net, 75, 99) // plus phase timing
leabra.LooperSimCycleAndLearn(ls, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code

for m, _ := range man.Stacks {
stack := man.Stacks[m]
ls.Stacks[etime.Train].OnInit.Add("Init", func() { ss.Init() })

for m, _ := range ls.Stacks {
stack := ls.Stacks[m]
stack.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() {
ss.ApplyInputs()
})
}

man.GetLoop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun)
ls.Loop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun)

// Train stop early condition
man.GetLoop(etime.Train, etime.Epoch).IsDone["NZeroStop"] = func() bool {
ls.Loop(etime.Train, etime.Epoch).IsDone.AddBool("NZeroStop", func() bool {
// This is calculated in TrialStats
stopNz := ss.Config.Run.NZero
if stopNz <= 0 {
Expand All @@ -442,10 +445,10 @@ func (ss *Sim) ConfigLoops() {
curNZero := ss.Stats.Int("NZero")
stop := curNZero >= stopNz
return stop
}
})

// Add Testing
trainEpoch := man.GetLoop(etime.Train, etime.Epoch)
trainEpoch := ls.Loop(etime.Train, etime.Epoch)
trainEpoch.OnStart.Add("TestAtInterval", func() {
if (ss.Config.Run.TestInterval > 0) && ((trainEpoch.Counter.Cur+1)%ss.Config.Run.TestInterval == 0) {
// Note the +1 so that it doesn't occur at the 0th timestep.
Expand All @@ -456,33 +459,35 @@ func (ss *Sim) ConfigLoops() {
/////////////////////////////////////////////
// Logging

man.GetLoop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() {
ls.Loop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() {
leabra.LogTestErrors(&ss.Logs)
})
man.GetLoop(etime.Train, etime.Epoch).OnEnd.Add("PCAStats", func() {
trnEpc := man.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur
ls.Loop(etime.Train, etime.Epoch).OnEnd.Add("PCAStats", func() {
trnEpc := ls.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur
if ss.Config.Run.PCAInterval > 0 && trnEpc%ss.Config.Run.PCAInterval == 0 {
leabra.PCAStats(ss.Net, &ss.Logs, &ss.Stats)
ss.Logs.ResetLog(etime.Analyze, etime.Trial)
}
})

man.AddOnEndToAll("Log", ss.Log)
leabra.LooperResetLogBelow(man, &ss.Logs)
ls.AddOnEndToAll("Log", func(mode, time enums.Enum) {
ss.Log(mode.(etime.Modes), time.(etime.Times))
})
leabra.LooperResetLogBelow(ls, &ss.Logs)

man.GetLoop(etime.Train, etime.Trial).OnEnd.Add("LogAnalyze", func() {
trnEpc := man.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur
ls.Loop(etime.Train, etime.Trial).OnEnd.Add("LogAnalyze", func() {
trnEpc := ls.Stacks[etime.Train].Loops[etime.Epoch].Counter.Cur
if (ss.Config.Run.PCAInterval > 0) && (trnEpc%ss.Config.Run.PCAInterval == 0) {
ss.Log(etime.Analyze, etime.Trial)
}
})

man.GetLoop(etime.Train, etime.Run).OnEnd.Add("RunStats", func() {
ls.Loop(etime.Train, etime.Run).OnEnd.Add("RunStats", func() {
ss.Logs.RunStats("PctCor", "FirstZero", "LastZero")
})

// Save weights to file, to look at later
man.GetLoop(etime.Train, etime.Run).OnEnd.Add("SaveWeights", func() {
ls.Loop(etime.Train, etime.Run).OnEnd.Add("SaveWeights", func() {
ctrString := ss.Stats.PrintValues([]string{"Run", "Epoch"}, []string{"%03d", "%05d"}, "_")
leabra.SaveWeightsIfConfigSet(ss.Net, ss.Config.Log.SaveWeights, ctrString, ss.Stats.String("RunName"))
})
Expand All @@ -492,19 +497,21 @@ func (ss *Sim) ConfigLoops() {

if !ss.Config.GUI {
if ss.Config.Log.NetData {
man.GetLoop(etime.Test, etime.Trial).Main.Add("NetDataRecord", func() {
ls.Loop(etime.Test, etime.Trial).OnEnd.Add("NetDataRecord", func() {
ss.GUI.NetDataRecord(ss.ViewUpdate.Text)
})
}
} else {
leabra.LooperUpdateNetView(man, &ss.ViewUpdate, ss.Net, ss.NetViewCounters)
leabra.LooperUpdatePlots(man, &ss.GUI)
leabra.LooperUpdateNetView(ls, &ss.ViewUpdate, ss.Net, ss.NetViewCounters)
leabra.LooperUpdatePlots(ls, &ss.GUI)
ls.Stacks[etime.Train].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() })
ls.Stacks[etime.Test].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() })
}

if ss.Config.Debug {
mpi.Println(man.DocString())
mpi.Println(ls.DocString())
}
ss.Loops = man
ss.Loops = ls
}

// ApplyInputs applies input patterns from given environment.
Expand Down Expand Up @@ -542,7 +549,7 @@ func (ss *Sim) ApplyInputs() {
// for the new run value
func (ss *Sim) NewRun() {
ctx := &ss.Context
ss.InitRandSeed(ss.Loops.GetLoop(etime.Train, etime.Run).Counter.Cur)
ss.InitRandSeed(ss.Loops.Loop(etime.Train, etime.Run).Counter.Cur)
ss.Envs.ByMode(etime.Train).Init(0)
ss.Envs.ByMode(etime.Test).Init(0)
ctx.Reset()
Expand Down Expand Up @@ -691,8 +698,7 @@ func (ss *Sim) Log(mode etime.Modes, time etime.Times) {
ss.Logs.LogRow(mode, time, row) // also logs to file, etc
}

////////////////////////////////////////////////////////////////////////////////////////////
// Gui
//////// GUI

// ConfigGUI configures the Cogent Core GUI interface for this simulation.
func (ss *Sim) ConfigGUI() {
Expand All @@ -715,18 +721,8 @@ func (ss *Sim) ConfigGUI() {
}

func (ss *Sim) MakeToolbar(p *tree.Plan) {
ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Init", Icon: icons.Update,
Tooltip: "Initialize everything including network weights, and start over. Also applies current params.",
Active: egui.ActiveStopped,
Func: func() {
ss.Init()
ss.GUI.UpdateWindow()
},
})
ss.GUI.AddLooperCtrl(p, ss.Loops)

ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test})

////////////////////////////////////////////////
tree.Add(p, func(w *core.Separator) {})
ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Reset RunLog",
Icon: icons.Reset,
Expand Down Expand Up @@ -789,7 +785,7 @@ func (ss *Sim) RunNoGUI() {
ss.Init()

mpi.Printf("Running %d Runs starting at %d\n", ss.Config.Run.NRuns, ss.Config.Run.Run)
ss.Loops.GetLoop(etime.Train, etime.Run).Counter.SetCurMaxPlusN(ss.Config.Run.Run, ss.Config.Run.NRuns)
ss.Loops.Loop(etime.Train, etime.Run).Counter.SetCurMaxPlusN(ss.Config.Run.Run, ss.Config.Run.NRuns)

if ss.Config.Run.StartWts != "" { // this is just for testing -- not usually needed
ss.Loops.Step(etime.Train, 1, etime.Trial) // get past NewRun
Expand Down
76 changes: 33 additions & 43 deletions examples/hip/hip.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"cogentcore.org/core/base/errors"
"cogentcore.org/core/base/randx"
"cogentcore.org/core/core"
"cogentcore.org/core/enums"
"cogentcore.org/core/icons"
"cogentcore.org/core/plot/plotcore"
"cogentcore.org/core/tensor/stats/split"
Expand Down Expand Up @@ -195,7 +196,7 @@ type Sim struct {
Params emer.NetParams `display:"add-fields"`

// contains looper control loops for running sim
Loops *looper.Manager `new-window:"+" display:"no-inline"`
Loops *looper.Stacks `new-window:"+" display:"no-inline"`

// contains computed statistic values
Stats estats.Stats `new-window:"+"`
Expand Down Expand Up @@ -395,7 +396,7 @@ func (ss *Sim) Init() {
}

func (ss *Sim) TestInit() {
ss.Loops.ResetCountersByMode(etime.Test)
ss.Loops.InitMode(etime.Test)
tst := ss.Envs.ByMode(etime.Test).(*env.FixedTable)
tst.Init(0)
}
Expand All @@ -410,29 +411,32 @@ func (ss *Sim) InitRandSeed(run int) {

// ConfigLoops configures the control loops: Training, Testing
func (ss *Sim) ConfigLoops() {
man := looper.NewManager()
ls := looper.NewStacks()

trls := ss.TrainAB.Rows
ttrls := ss.TestAll.Rows

man.AddStack(etime.Train).AddTime(etime.Run, ss.Config.NRuns).AddTime(etime.Epoch, ss.Config.NEpochs).AddTime(etime.Trial, trls).AddTime(etime.Cycle, 100)
ls.AddStack(etime.Train).AddTime(etime.Run, ss.Config.NRuns).AddTime(etime.Epoch, ss.Config.NEpochs).AddTime(etime.Trial, trls).AddTime(etime.Cycle, 100)

man.AddStack(etime.Test).AddTime(etime.Epoch, 1).AddTime(etime.Trial, ttrls).AddTime(etime.Cycle, 100)
ls.AddStack(etime.Test).AddTime(etime.Epoch, 1).AddTime(etime.Trial, ttrls).AddTime(etime.Cycle, 100)

leabra.LooperStdPhases(man, &ss.Context, ss.Net, 75, 99) // plus phase timing
leabra.LooperSimCycleAndLearn(man, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code
ss.Net.ConfigLoopsHip(&ss.Context, man)
leabra.LooperStdPhases(ls, &ss.Context, ss.Net, 75, 99) // plus phase timing
leabra.LooperSimCycleAndLearn(ls, ss.Net, &ss.Context, &ss.ViewUpdate) // std algo code
ss.Net.ConfigLoopsHip(&ss.Context, ls)

for m, _ := range man.Stacks {
stack := man.Stacks[m]
ls.Stacks[etime.Train].OnInit.Add("Init", func() { ss.Init() })
ls.Stacks[etime.Test].OnInit.Add("Init", func() { ss.TestInit() })

for m, _ := range ls.Stacks {
stack := ls.Stacks[m]
stack.Loops[etime.Trial].OnStart.Add("ApplyInputs", func() {
ss.ApplyInputs()
})
}

man.GetLoop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun)
ls.Loop(etime.Train, etime.Run).OnStart.Add("NewRun", ss.NewRun)

man.GetLoop(etime.Train, etime.Run).OnEnd.Add("RunDone", func() {
ls.Loop(etime.Train, etime.Run).OnEnd.Add("RunDone", func() {
if ss.Stats.Int("Run") >= ss.Config.NRuns-1 {
ss.RunStats()
expt := ss.Stats.Int("Expt")
Expand All @@ -441,7 +445,7 @@ func (ss *Sim) ConfigLoops() {
})

// Add Testing
trainEpoch := man.GetLoop(etime.Train, etime.Epoch)
trainEpoch := ls.Loop(etime.Train, etime.Epoch)
trainEpoch.OnEnd.Add("TestAtInterval", func() {
if (ss.Config.TestInterval > 0) && ((trainEpoch.Counter.Cur+1)%ss.Config.TestInterval == 0) {
// Note the +1 so that it doesn't occur at the 0th timestep.
Expand All @@ -461,28 +465,33 @@ func (ss *Sim) ConfigLoops() {
})

// early stop
man.GetLoop(etime.Train, etime.Epoch).IsDone["ACMemStop"] = func() bool {
ls.Loop(etime.Train, etime.Epoch).IsDone.AddBool("ACMemStop", func() bool {
// This is calculated in TrialStats
tstEpcLog := ss.Logs.Tables[etime.Scope(etime.Test, etime.Epoch)]
acMem := float32(tstEpcLog.Table.Float("ACMem", ss.Stats.Int("Epoch")))
stop := acMem >= ss.Config.StopMem
return stop
}
})

/////////////////////////////////////////////
// Logging

man.GetLoop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() {
ls.Loop(etime.Test, etime.Epoch).OnEnd.Add("LogTestErrors", func() {
leabra.LogTestErrors(&ss.Logs)
})

man.AddOnEndToAll("Log", ss.Log)
leabra.LooperResetLogBelow(man, &ss.Logs)
ls.AddOnEndToAll("Log", func(mode, time enums.Enum) {
ss.Log(mode.(etime.Modes), time.(etime.Times))
})
leabra.LooperResetLogBelow(ls, &ss.Logs)

leabra.LooperUpdateNetView(man, &ss.ViewUpdate, ss.Net, ss.NetViewCounters)
leabra.LooperUpdatePlots(man, &ss.GUI)
ss.Loops = man
leabra.LooperUpdateNetView(ls, &ss.ViewUpdate, ss.Net, ss.NetViewCounters)
leabra.LooperUpdatePlots(ls, &ss.GUI)

ls.Stacks[etime.Train].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() })
ls.Stacks[etime.Test].OnInit.Add("GUI-Init", func() { ss.GUI.UpdateWindow() })

ss.Loops = ls
// mpi.Println(man.DocString())
}

Expand Down Expand Up @@ -519,7 +528,7 @@ func (ss *Sim) ApplyInputs() {
// for the new run value
func (ss *Sim) NewRun() {
ctx := &ss.Context
ss.InitRandSeed(ss.Loops.GetLoop(etime.Train, etime.Run).Counter.Cur)
ss.InitRandSeed(ss.Loops.Loop(etime.Train, etime.Run).Counter.Cur)
// ss.ConfigPats()
ss.ConfigEnv()
ctx.Reset()
Expand Down Expand Up @@ -681,7 +690,7 @@ func (ss *Sim) NetViewCounters(tm etime.Times) {
// TrialStats computes the trial-level statistics.
// Aggregation is done directly from log data.
func (ss *Sim) TrialStats() {
ss.MemStats(ss.Loops.Mode)
ss.MemStats(ss.Loops.Mode.(etime.Modes))
}

// MemStats computes ActM vs. Target on ECout with binary counts
Expand Down Expand Up @@ -936,27 +945,8 @@ func (ss *Sim) ConfigGUI() {
}

func (ss *Sim) MakeToolbar(p *tree.Plan) {
ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Init", Icon: icons.Update,
Tooltip: "Initialize everything including network weights, and start over. Also applies current params.",
Active: egui.ActiveStopped,
Func: func() {
ss.Init()
ss.GUI.UpdateWindow()
},
})

ss.GUI.AddLooperCtrl(p, ss.Loops, []etime.Modes{etime.Train, etime.Test})
ss.GUI.AddLooperCtrl(p, ss.Loops)

ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Test Init", Icon: icons.Update,
Tooltip: "Initialize the testing process.",
Active: egui.ActiveStopped,
Func: func() {
ss.TestInit()
ss.GUI.UpdateWindow()
},
})

////////////////////////////////////////////////
tree.Add(p, func(w *core.Separator) {})
ss.GUI.AddToolbarItem(p, egui.ToolbarItem{Label: "Reset RunLog",
Icon: icons.Reset,
Expand Down
Loading

0 comments on commit 0b1015d

Please sign in to comment.