updated ApplyExt and FSA example; inputs now applying correctly.
rcoreilly committed Oct 20, 2024
1 parent 00fbdba commit 3b46a51
Showing 2 changed files with 83 additions and 85 deletions.
79 changes: 60 additions & 19 deletions examples/deep_fsa/deep_fsa.go
@@ -9,6 +9,7 @@ package main
//go:generate core generate -add-types

import (
"log"
"os"

"cogentcore.org/core/base/mpi"
@@ -63,19 +64,19 @@ var ParamSets = params.Sets{
"Layer.Inhib.ActAvg.Fixed": "true", // simpler to have everything fixed, for replicability
"Layer.Act.Init.Decay": "0", // essential to have all layers no decay
}},
{Sel: ".Hidden", Desc: "fix avg act",
{Sel: ".SuperLayer", Desc: "fix avg act",
Params: params.Params{
"Layer.Inhib.ActAvg.Fixed": "true",
}},
{Sel: ".Back", Desc: "top-down back-pathways MUST have lower relative weight scale, otherwise network hallucinates",
{Sel: ".BackPath", Desc: "top-down back-pathways MUST have lower relative weight scale, otherwise network hallucinates",
Params: params.Params{
"Path.WtScale.Rel": "0.2",
}},
{Sel: "TRCLayer", Desc: "standard weight is .3 here for larger distributed reps. no learn",
{Sel: ".PulvinarLayer", Desc: "standard weight is .3 here for larger distributed reps. no learn",
Params: params.Params{
"Layer.TRC.DriveScale": "0.8", // using .8 for localist layer
"Layer.Pulvinar.DriveScale": "0.8", // using .8 for localist layer
}},
{Sel: "CTCtxtPath", Desc: "no weight balance on CT context paths -- makes a diff!",
{Sel: ".CTCtxtPath", Desc: "no weight balance on CT context paths -- makes a diff!",
Params: params.Params{
"Path.Learn.WtBal.On": "false", // this should be true for larger DeepLeabra models -- e.g., sg..
}},
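
A note on the selector changes in this hunk: params selectors follow a CSS-like convention, where a leading "." matches a style class, "#" matches a specific name, and a bare token matches a type name, so ".Hidden" → ".SuperLayer" and "TRCLayer" → ".PulvinarLayer" track the renamed classes in the updated API. A toy sketch of that matching rule follows; the matches function and its arguments are illustrative, not the params API.

package main

import (
	"fmt"
	"strings"
)

// matches applies the CSS-like selector convention described above.
func matches(sel, name, class, typ string) bool {
	switch {
	case strings.HasPrefix(sel, "#"):
		return sel[1:] == name // specific object by name
	case strings.HasPrefix(sel, "."):
		return sel[1:] == class // any object carrying the class
	default:
		return sel == typ // any object of the type
	}
}

func main() {
	// the new class selector matches the pulvinar layer...
	fmt.Println(matches(".PulvinarLayer", "HiddenP", "PulvinarLayer", "Layer")) // true
	// ...while the old type-style selector no longer matches anything
	fmt.Println(matches("TRCLayer", "HiddenP", "PulvinarLayer", "Layer")) // false
}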
@@ -152,7 +153,7 @@ type RunConfig struct {
NZero int `default:"2"`

// total number of trials per epoch. Should be an even multiple of NData.
- NTrials int `default:"32"`
+ NTrials int `default:"100"`

// how often to run through all the test patterns, in terms of training epochs.
// can use 0 or -1 for no testing.
@@ -207,6 +208,12 @@ type Config struct {
// log debugging information
Debug bool

+ // InputNames are the names of the input letters
+ InputNames []string
+
+ // InputNameMap maps each input name to its index in InputNames
+ InputNameMap map[string]int
+
// parameter related configuration options
Params ParamConfig `display:"add-fields"`

@@ -266,6 +273,7 @@ type Sim struct {
// New creates new blank elements and initializes defaults
func (ss *Sim) New() {
econfig.Config(&ss.Config, "config.toml")
+ ss.Config.InputNames = []string{"B", "T", "S", "X", "V", "P", "E"}
ss.Net = leabra.NewNetwork("RA25")
ss.Params.Config(ParamSets, ss.Config.Params.Sheet, ss.Config.Params.Tag, ss.Net)
ss.Stats.Init()
@@ -302,6 +310,13 @@ func (ss *Sim) ConfigEnv() {
tst = ss.Envs.ByMode(etime.Test).(*FSAEnv)
}

+ if ss.Config.InputNameMap == nil {
+ ss.Config.InputNameMap = make(map[string]int, len(ss.Config.InputNames))
+ for i, nm := range ss.Config.InputNames {
+ ss.Config.InputNameMap[nm] = i
+ }
+ }
+
// note: names must be standard here!
trn.Name = etime.Train.String()
trn.Seq.Max = 25 // 25 sequences per epoch training
@@ -500,17 +515,27 @@ func (ss *Sim) ConfigLoops() {
func (ss *Sim) ApplyInputs() {
ctx := &ss.Context
net := ss.Net
+ net.InitExt()
+
ev := ss.Envs.ByMode(ctx.Mode).(*FSAEnv)
ev.Step()
- lays := net.LayersByType(leabra.InputLayer, leabra.TargetLayer)
- net.InitExt()
ss.Stats.SetString("TrialName", ev.String())
- for _, lnm := range lays {
- ly := ss.Net.LayerByName(lnm)
- pats := ev.State(ly.Name)
- if pats != nil {
- ly.ApplyExt(pats)

+ in := ss.Net.LayerByName("Input")
+ trg := ss.Net.LayerByName("Targets")
+ clrmsk, setmsk, _ := in.ApplyExtFlags()
+ ns := ev.NNext.Values[0]
+ for i := 0; i < ns; i++ {
+ lbl := ev.NextLabels.Values[i]
+ li, ok := ss.Config.InputNameMap[lbl]
+ if !ok {
+ log.Printf("Input label: %v not found in InputNames list of labels\n", lbl)
+ continue
+ }
+ if i == 0 {
+ in.ApplyExtValue(li, 1, clrmsk, setmsk, false)
+ }
+ trg.ApplyExtValue(li, 1, clrmsk, setmsk, false)
+ }
- }
}

@@ -577,13 +602,29 @@ func (ss *Sim) NetViewCounters(tm etime.Times) {
// TrialStats computes the trial-level statistics.
// Aggregation is done directly from log data.
func (ss *Sim) TrialStats() {
- out := ss.Net.LayerByName("Output")
-
- ss.Stats.SetFloat("CorSim", float64(out.CosDiff.Cos))
-
- sse, avgsse := out.MSE(0.5) // 0.5 = per-unit tolerance -- right side of .5
+ inp := ss.Net.LayerByName("HiddenP")
+ trg := ss.Net.LayerByName("Targets")
+ ss.Stats.SetFloat("CorSim", float64(inp.CosDiff.Cos))
+ sse := 0.0
+ gotOne := false
+ for ni := range inp.Neurons {
+ inn := &inp.Neurons[ni]
+ tgn := &trg.Neurons[ni]
+ if tgn.Act > 0.5 {
+ if inn.ActM > 0.4 {
+ gotOne = true
+ }
+ } else {
+ if inn.ActM > 0.5 {
+ sse += float64(inn.ActM)
+ }
+ }
+ }
+ if !gotOne {
+ sse += 1
+ }
ss.Stats.SetFloat("SSE", sse)
ss.Stats.SetFloat("AvgSSE", avgsse)
ss.Stats.SetFloat("AvgSSE", sse)
if sse > 0 {
ss.Stats.SetFloat("TrlErr", 1)
} else {
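
The new ApplyInputs follows a simple localist scheme: each FSA letter corresponds to one unit, the first label reported by the environment is the actual next input and drives the Input layer, and every grammatically valid next label is set on the Targets layer. TrialStats scores the prediction to match: a trial counts as correct if any valid-target unit is active in the minus phase (ActM > 0.4), while off-target activity above 0.5 accumulates into SSE. Below is a minimal, self-contained sketch of the label-to-unit mapping; only InputNames is taken from the code above, and the environment, layers, and flag machinery are elided.

package main

import (
	"fmt"
	"log"
)

// inputNames mirrors ss.Config.InputNames in the diff above.
var inputNames = []string{"B", "T", "S", "X", "V", "P", "E"}

func main() {
	// build the name -> unit index map, as ConfigEnv now does
	nameMap := make(map[string]int, len(inputNames))
	for i, nm := range inputNames {
		nameMap[nm] = i
	}

	// suppose the FSA reports these valid next letters; the first is the
	// actual next input, the rest are alternatives the grammar allows
	next := []string{"T", "P"}
	input := make([]float32, len(inputNames))   // localist Input layer
	targets := make([]float32, len(inputNames)) // localist Targets layer
	for i, lbl := range next {
		li, ok := nameMap[lbl]
		if !ok {
			log.Printf("label %q not in InputNames", lbl)
			continue
		}
		if i == 0 {
			input[li] = 1 // only the actual next letter drives the input
		}
		targets[li] = 1 // every valid next letter is a target
	}
	fmt.Println(input)   // [0 1 0 0 0 0 0]
	fmt.Println(targets) // [0 1 0 0 0 1 0]
}

The map lookup is what makes the unknown-label guard in ApplyInputs cheap: a bad letter is logged and skipped instead of panicking on an out-of-range index.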
89 changes: 23 additions & 66 deletions leabra/layer.go
@@ -135,6 +135,23 @@ func (ly *Layer) ApplyExt(ext tensor.Tensor) {
}
}

+ // ApplyExtValue applies the given external value to the neuron at index lni,
+ // using the clear and set flag masks and toTarg from ApplyExtFlags.
+ func (ly *Layer) ApplyExtValue(lni int, val float32, clear, set []enums.BitFlag, toTarg bool) {
+ nrn := &ly.Neurons[lni]
+ if nrn.IsOff() {
+ return
+ }
+ if toTarg {
+ nrn.Targ = val
+ } else {
+ nrn.Ext = val
+ }
+ nrn.SetFlag(false, clear...)
+ nrn.SetFlag(true, set...)
+ }

// ApplyExt2D applies 2D tensor external input
func (ly *Layer) ApplyExt2D(ext tensor.Tensor) {
clear, set, toTarg := ly.ApplyExtFlags()
@@ -145,17 +162,7 @@ func (ly *Layer) ApplyExt2D(ext tensor.Tensor) {
idx := []int{y, x}
vl := float32(ext.Float(idx))
i := ly.Shape.Offset(idx)
- nrn := &ly.Neurons[i]
- if nrn.IsOff() {
- continue
- }
- if toTarg {
- nrn.Targ = vl
- } else {
- nrn.Ext = vl
- }
- nrn.SetFlag(false, clear...)
- nrn.SetFlag(true, set...)
+ ly.ApplyExtValue(i, vl, clear, set, toTarg)
}
}
}
@@ -172,17 +179,7 @@ func (ly *Layer) ApplyExt2Dto4D(ext tensor.Tensor) {
idx := []int{y, x}
vl := float32(ext.Float(idx))
ui := tensor.Projection2DIndex(&ly.Shape, false, y, x)
- nrn := &ly.Neurons[ui]
- if nrn.IsOff() {
- continue
- }
- if toTarg {
- nrn.Targ = vl
- } else {
- nrn.Ext = vl
- }
- nrn.SetFlag(false, clear...)
- nrn.SetFlag(true, set...)
+ ly.ApplyExtValue(ui, vl, clear, set, toTarg)
}
}
}
@@ -201,17 +198,7 @@ func (ly *Layer) ApplyExt4D(ext tensor.Tensor) {
idx := []int{yp, xp, yn, xn}
vl := float32(ext.Float(idx))
i := ly.Shape.Offset(idx)
- nrn := &ly.Neurons[i]
- if nrn.IsOff() {
- continue
- }
- if toTarg {
- nrn.Targ = vl
- } else {
- nrn.Ext = vl
- }
- nrn.SetFlag(false, clear...)
- nrn.SetFlag(true, set...)
+ ly.ApplyExtValue(i, vl, clear, set, toTarg)
}
}
}
@@ -225,18 +212,8 @@ func (ly *Layer) ApplyExt1DTsr(ext tensor.Tensor) {
clear, set, toTarg := ly.ApplyExtFlags()
mx := min(ext.Len(), len(ly.Neurons))
for i := 0; i < mx; i++ {
- nrn := &ly.Neurons[i]
- if nrn.IsOff() {
- continue
- }
vl := float32(ext.Float1D(i))
- if toTarg {
- nrn.Targ = vl
- } else {
- nrn.Ext = vl
- }
- nrn.SetFlag(false, clear...)
- nrn.SetFlag(true, set...)
+ ly.ApplyExtValue(i, vl, clear, set, toTarg)
}
}

@@ -247,18 +224,8 @@ func (ly *Layer) ApplyExt1D(ext []float64) {
clear, set, toTarg := ly.ApplyExtFlags()
mx := min(len(ext), len(ly.Neurons))
for i := 0; i < mx; i++ {
- nrn := &ly.Neurons[i]
- if nrn.IsOff() {
- continue
- }
vl := float32(ext[i])
- if toTarg {
- nrn.Targ = vl
- } else {
- nrn.Ext = vl
- }
- nrn.SetFlag(false, clear...)
- nrn.SetFlag(true, set...)
+ ly.ApplyExtValue(i, vl, clear, set, toTarg)
}
}

@@ -272,18 +239,8 @@ func (ly *Layer) ApplyExt1D32(ext []float32) {
clear, set, toTarg := ly.ApplyExtFlags()
mx := min(len(ext), len(ly.Neurons))
for i := 0; i < mx; i++ {
- nrn := &ly.Neurons[i]
- if nrn.IsOff() {
- continue
- }
vl := ext[i]
- if toTarg {
- nrn.Targ = vl
- } else {
- nrn.Ext = vl
- }
- nrn.SetFlag(false, clear...)
- nrn.SetFlag(true, set...)
+ ly.ApplyExtValue(i, vl, clear, set, toTarg)
}
}

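
The layer.go side of the commit is a pure consolidation: six ApplyExt* variants previously repeated the same per-neuron block, and each now computes an index and delegates to the new ApplyExtValue. Below is a reduced, self-contained sketch of the same refactor shape; the types and names here are simplified stand-ins, not the leabra API.

package main

import "fmt"

// neuron and layer are toy stand-ins for leabra.Neuron and leabra.Layer.
type neuron struct {
	ext, targ float32
	off       bool
}

type layer struct{ neurons []neuron }

// applyExtValue mirrors Layer.ApplyExtValue: write val to Ext or Targ
// for one neuron, skipping neurons that are turned off.
func (ly *layer) applyExtValue(i int, val float32, toTarg bool) {
	nrn := &ly.neurons[i]
	if nrn.off {
		return
	}
	if toTarg {
		nrn.targ = val
	} else {
		nrn.ext = val
	}
}

// applyExt1D mirrors the refactored Layer.ApplyExt1D: the variant only
// handles indexing and bounds; the per-neuron work lives in one place.
func (ly *layer) applyExt1D(ext []float64, toTarg bool) {
	mx := min(len(ext), len(ly.neurons))
	for i := 0; i < mx; i++ {
		ly.applyExtValue(i, float32(ext[i]), toTarg)
	}
}

func main() {
	ly := &layer{neurons: make([]neuron, 3)}
	ly.applyExt1D([]float64{0.2, 0.8}, false)
	fmt.Println(ly.neurons) // [{0.2 0 false} {0.8 0 false} {0 0 false}]
}

Centralizing the IsOff check and the Ext-vs-Targ branch in one function means a future change, such as actually buffering applied values for a GPU path, lands in one place instead of six.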
