gorgonia / gorgonia

Gorgonia is a library that helps facilitate machine learning in Go.
https://gorgonia.org/
Apache License 2.0
5.4k stars 431 forks source link

OneHot op #559

Open JoshPattman opened 6 months ago

JoshPattman commented 6 months ago

I am writing some code that requires an op that takes a vector of indices (ints) and returns a matrix of one-hot vectors. I didn't make this a pull request because I am not 100% sure it has enough error checking, or that some of the functions are implemented correctly. I also might just be missing the fact that this is already implemented.

I think, if it's not already there, this would be a good addition to make Gorgonia a bit easier to use.

package main

import (
    "fmt"
    "hash"
    "hash/fnv"

    "github.com/chewxy/hm"
    "gorgonia.org/gorgonia"
    "gorgonia.org/tensor"
)

// OneHot adds a one-hot encoding op to the graph that x belongs to.
//
// x is the node holding the class indices, numClasses is the width of each
// one-hot row and dType is the element type of the resulting tensor. Only
// tensor.Int, tensor.Float64, tensor.Float32 and tensor.Bool are handled by
// the op, so any other dType is rejected here rather than at execution time.
//
// An error is returned when numClasses is not positive, when dType is not
// supported, or when gorgonia.ApplyOp fails.
func OneHot(x *gorgonia.Node, numClasses int, dType tensor.Dtype) (*gorgonia.Node, error) {
    if numClasses <= 0 {
        return nil, fmt.Errorf("onehot: numClasses must be positive, got %d", numClasses)
    }

    switch dType {
    case tensor.Int, tensor.Float64, tensor.Float32, tensor.Bool:
        // supported element types
    default:
        return nil, fmt.Errorf("onehot: unsupported dtype %v", dType)
    }

    op := &oneHotOp{numClasses: numClasses, dType: dType}
    return gorgonia.ApplyOp(op, x)
}

// Compile-time checks that *oneHotOp satisfies the gorgonia interfaces it is
// meant to implement.
var (
    _ gorgonia.Op   = (*oneHotOp)(nil)
    _ gorgonia.SDOp = (*oneHotOp)(nil)
)

// oneHotOp is the op behind OneHot. It carries the configuration needed to
// turn a vector of class indices into a matrix of one-hot rows.
type oneHotOp struct {
    // numClasses is the size of the second dimension of the output.
    numClasses int
    // dType is the element type of the output tensor.
    dType      tensor.Dtype
}

// DiffWRT implements gorgonia.SDOp.
//
// The returned slice has one entry per input and every entry is false: the op
// is not differentiable with respect to any of its inputs.
func (op *oneHotOp) DiffWRT(inputs int) []bool {
    wrt := make([]bool, inputs)
    return wrt
}

// SymDiff implements gorgonia.SDOp.
//
// The op has no derivative (DiffWRT reports false for every input), so this
// method is not expected to be called. If it is called anyway, an error is
// returned instead of panicking so the caller can handle it.
func (*oneHotOp) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) {
    return nil, fmt.Errorf("oneHotOp: symbolic differentiation is not supported")
}

// Arity implements gorgonia.Op.
//
// The op takes a single input: the vector of indices.
func (op *oneHotOp) Arity() int {
    const numInputs = 1
    return numInputs
}

// CallsExtern implements gorgonia.Op. It always reports false.
func (op *oneHotOp) CallsExtern() bool {
    const callsExtern = false
    return callsExtern
}

// Do implements gorgonia.Op.
//
// It expects exactly one input: a one-dimensional value whose Data() is a
// []int of class indices. It returns a new tensor of shape
// (len(indices), op.numClasses) and element type op.dType, in which entry
// (i, indices[i]) is set to 1 (or true for tensor.Bool). All other entries
// keep the value tensor.New initialised them with.
//
// An error is returned when the input count, shape or data type is wrong,
// when op.numClasses is not positive, when op.dType is not one of
// Int/Float64/Float32/Bool, when an index lies outside [0, op.numClasses),
// or when writing into the output tensor fails.
func (op *oneHotOp) Do(inp ...gorgonia.Value) (gorgonia.Value, error) {
    if len(inp) != 1 {
        return nil, fmt.Errorf("oneHotOp: expected 1 input, got %d", len(inp))
    }
    if op.numClasses <= 0 {
        return nil, fmt.Errorf("oneHotOp: numClasses must be positive, got %d", op.numClasses)
    }

    shape := inp[0].Shape()
    if len(shape) != 1 {
        return nil, fmt.Errorf("oneHotOp: expected a vector of indices, got shape %v", shape)
    }
    batchSize := shape[0]

    // Fetch and type-check the backing data once instead of on every iteration.
    indices, ok := inp[0].Data().([]int)
    if !ok {
        return nil, fmt.Errorf("oneHotOp: expected indices backed by []int, got %T", inp[0].Data())
    }
    if len(indices) < batchSize {
        return nil, fmt.Errorf("oneHotOp: shape reports %d indices but data holds %d", batchSize, len(indices))
    }

    // Pick the "hot" value up front so an unsupported dtype is reported
    // instead of silently producing a tensor with no entry set.
    var one interface{}
    switch op.dType {
    case tensor.Int:
        one = int(1)
    case tensor.Float64:
        one = float64(1)
    case tensor.Float32:
        one = float32(1)
    case tensor.Bool:
        one = true
    default:
        return nil, fmt.Errorf("oneHotOp: unsupported dtype %v", op.dType)
    }

    tens := tensor.New(tensor.WithShape(batchSize, op.numClasses), tensor.Of(op.dType))
    for i, index := range indices[:batchSize] {
        if index < 0 || index >= op.numClasses {
            return nil, fmt.Errorf("oneHotOp: index %d at position %d is out of range [0, %d)", index, i, op.numClasses)
        }
        if err := tens.SetAt(one, i, index); err != nil {
            return nil, fmt.Errorf("oneHotOp: setting element (%d, %d): %w", i, index, err)
        }
    }
    return tens, nil
}

// InferShape implements gorgonia.Op.
//
// The output shape is the shape of the single input with op.numClasses
// appended as a trailing dimension. The input shape is cloned first, so the
// caller's shape is never modified.
//
// An error is returned when the number of inputs is not 1 or when the input
// is not a tensor.Shape.
func (op *oneHotOp) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
    if len(inputs) != 1 {
        return nil, fmt.Errorf("oneHotOp: expected 1 input shape, got %d", len(inputs))
    }

    in, ok := inputs[0].(tensor.Shape)
    if !ok {
        return nil, fmt.Errorf("oneHotOp: expected input of type tensor.Shape, got %T", inputs[0])
    }

    out := in.Clone()
    out = append(out, op.numClasses)
    return out, nil
}

// OverwritesInput implements gorgonia.Op. It always returns -1.
func (op *oneHotOp) OverwritesInput() int {
    const overwritten = -1
    return overwritten
}

// ReturnsPtr implements gorgonia.Op. It always reports false.
func (op *oneHotOp) ReturnsPtr() bool {
    const returnsPtr = false
    return returnsPtr
}

// String implements gorgonia.Op. It returns the fixed name of the op.
func (op *oneHotOp) String() string {
    const name = "OneHotOp"
    return name
}

// Type implements gorgonia.Op.
//
// The op is typed as a function from a 1-dimensional tensor of tensor.Int to
// a 2-dimensional tensor of op.dType. The output element type follows
// op.dType so that the declared type agrees with the tensor Do builds; it was
// previously fixed to tensor.Float64 regardless of the requested dtype.
func (op *oneHotOp) Type() hm.Type {
    ohTypeInput := gorgonia.TensorType{
        Dims: 1,
        Of:   tensor.Int,
    }
    ohTypeOutput := gorgonia.TensorType{
        Dims: 2,
        Of:   op.dType,
    }
    return hm.NewFnType(ohTypeInput, ohTypeOutput)
}

// WriteHash implements gorgonia.Op.
//
// It writes a textual fingerprint of the op into h. The fingerprint includes
// numClasses and dType so that two ops with different configurations do not
// produce the same hash. A constant format string is used; passing
// op.String() as the format would misbehave if the name ever contained a '%'.
func (op *oneHotOp) WriteHash(h hash.Hash) {
    fmt.Fprintf(h, "OneHotOp(%d,%v)", op.numClasses, op.dType)
}

// Hashcode implements gorgonia.Op.
//
// It returns the 32-bit FNV-1a hash of whatever WriteHash emits, instead of
// panicking. Ops for which WriteHash writes the same bytes therefore share
// the same hashcode.
func (op *oneHotOp) Hashcode() uint32 {
    h := fnv.New32a()
    op.WriteHash(h)
    return h.Sum32()
}