emer / axon

Axon is a spiking, biologically-based neural model driven by predictive error-driven learning, for systems-level models of the brain
BSD 3-Clause "New" or "Revised" License
19 stars 7 forks source link

add a one-to-many test case in examples #298

Open rcoreilly opened 1 year ago

rcoreilly commented 1 year ago

Useful for testing noise parameters, etc. The test case used to live in obelisk, where it is now very out of date; it is easier to maintain within axon when doing mass updates. Start with ra25 and add these specializations:

// ConfigPats configures the one-to-many training patterns in ss.Pats.
//
// The table gets ss.NInputs*ss.NOutputs rows, grouped into ss.NInputs runs
// of ss.NOutputs consecutive rows. Within run i:
//   - every row shares the Input pattern generated for the run's first row;
//   - each row keeps its own randomly generated Output pattern;
//   - every row is named by the input index i.
//
// As a result, the Name column identifies the input, not the individual
// input/output pair.
//
// A copy of the generated table is written to random_5x5_25_gen.tsv for
// inspection. That dump is best-effort: a write failure is reported but does
// not abort configuration.
func (ss *Sim) ConfigPats() {
	dt := ss.Pats
	dt.SetMetaData("name", "TrainPats")
	dt.SetMetaData("desc", "Training patterns")
	sch := etable.Schema{
		{Name: "Name", Type: etensor.STRING, CellShape: nil, DimNames: nil},
		{Name: "Input", Type: etensor.FLOAT32, CellShape: []int{5, 5}, DimNames: []string{"Y", "X"}},
		{Name: "Output", Type: etensor.FLOAT32, CellShape: []int{5, 5}, DimNames: []string{"Y", "X"}},
	}
	dt.SetFromSchema(sch, ss.NInputs*ss.NOutputs)

	// Fill every row of both columns with random binary patterns.
	// The arguments (6, 1, 0, 3) are presumably (nOn, onVal, offVal, minDiff);
	// confirm this against the patgen docs.
	// Columns are looked up by name so a schema reorder cannot silently
	// target the wrong column.
	patgen.PermutedBinaryMinDiff(dt.ColByName("Input").(*etensor.Float32), 6, 1, 0, 3)
	patgen.PermutedBinaryMinDiff(dt.ColByName("Output").(*etensor.Float32), 6, 1, 0, 3)

	// Turn the table into a one-to-many mapping: each run of NOutputs rows
	// reuses the Input of its first row.
	for i := 0; i < ss.NInputs; i++ {
		first := i * ss.NOutputs
		name := fmt.Sprintf("%d", i)
		for j := 0; j < ss.NOutputs; j++ {
			row := first + j
			if j > 0 { // row `first` already holds the shared input
				dt.SetCellTensor("Input", row, dt.CellTensor("Input", first))
			}
			dt.SetCellString("Name", row, name)
		}
	}

	if err := dt.SaveCSV("random_5x5_25_gen.tsv", etable.Tab, etable.Headers); err != nil {
		// The dump is only for inspection, so report the failure and carry on.
		fmt.Printf("ConfigPats: saving generated patterns: %v\n", err)
	}
}
// TrialStats records trial-level statistics for the Output layer into ss.Stats:
//   - TrlCorSim: the layer's CorSim correlation (out.Vals[0].CorSim.Cor).
//   - TrlUnitErr: the first value returned by out.PctUnitErr.
//   - TrlClosest / TrlCorrel: the Name and correlation of the ss.Pats row whose
//     Output pattern is closest to the layer's ActM activity.
//   - TrlErr: 0 when that closest row's Name equals the current trial name of
//     the environment for the current mode, otherwise 1.
//
// NOTE(review): with patterns named by input index (one name per group of
// outputs), TrlErr presumably counts a trial as correct when the output matches
// any output associated with the trial's input. Confirm this is intended.
//
// Aggregation to higher levels is done directly from log data.
func (ss *Sim) TrialStats() {
	out := ss.Net.LayerByName("Output").(axon.AxonLayer).AsAxon()
	ss.Stats.SetFloat("TrlCorSim", float64(out.Vals[0].CorSim.Cor))
	ss.Stats.SetFloat("TrlUnitErr", out.PctUnitErr(&ss.Context)[0])

	_, correl, closest := ss.Stats.ClosestPat(ss.Net, "Output", "ActM", 0, ss.Pats, "Output", "Name")
	ss.Stats.SetString("TrlClosest", closest)
	ss.Stats.SetFloat("TrlCorrel", float64(correl))

	// TODO: for each name, record the map of closest rows that are predicted,
	// and the rows associated with that name.
	trialEnv := ss.Envs[ss.Context.Mode.String()].(*env.FixedTable)
	trlErr := 1.0
	if closest == trialEnv.TrialName.Cur {
		trlErr = 0
	}
	ss.Stats.SetFloat("TrlErr", trlErr)
}