wangkuiyi / gotorch

A Go idiomatic binding to the C++ core of PyTorch
MIT License
312 stars 35 forks source link

torch.nn.Module in Go #28

Closed wangkuiyi closed 4 years ago

wangkuiyi commented 4 years ago

The PyTorch API has a key concept: torch.nn.Module. Many built-in and user-defined models are classes derived from torch.nn.Module, and the only method a subclass must override is forward(x).

Usually, a torch.nn.Module-derived class has data members that hold the model parameters. For example, nn.Linear, PyTorch's implementation of the fully-connected layer, has W and B -- the weights and the bias, respectively.

In Go/Go+, the counterpart of a Python base class is an interface. So we provide type Module interface to mimic torch.nn.Module.

We also need a way to free a model's tensors when the model's life is over, since each tensor wraps memory owned by the C++ core.

wangkuiyi commented 4 years ago

Failed Proposal 1: Using Reflection

I tried an idea that uses Go's reflection mechanism to walk the fields of a Module-derived type recursively and free every field of type Tensor. Here is an example program that illustrates the idea.

However, it doesn't work well because it requires every submodule field in a Module-derived struct to be exported. In the following example, if we rename Ln2 to ln2, the function FreeModel panics. The reason is that reflect.Value.Interface refuses to return a value obtained through an unexported field.

package main

import (
    "fmt"
    "reflect"
)

//
// Basic data type Tensor.
//

// Tensor stands in for a GoTorch tensor handle; in this sketch it holds no data.
type Tensor struct{}

// Close frees the tensor's memory. Here it only prints a message.
func (Tensor) Close() {
	const msg = "Free up tensor memory..."
	fmt.Println(msg)
}

//
// Basic Tensor operations like those in ATen.
//

// torch_RandN mimics torch.RandN (torch.randn in PyTorch). In this sketch it
// ignores size and returns an empty placeholder Tensor.
func torch_RandN(size []int) Tensor {
	var t Tensor
	return t
}

// torch_MM mimics torch.MM (matrix multiplication). In this sketch it ignores
// both operands and returns an empty placeholder Tensor.
func torch_MM(x, y Tensor) Tensor {
	var product Tensor
	return product
}

// torch_Relu mimics torch.Relu, the rectified linear unit. In this sketch it
// ignores x and returns an empty placeholder Tensor.
func torch_Relu(x Tensor) Tensor {
	var activated Tensor
	return activated
}

//
// Module is like torch.nn.Module.
//

// Module is the Go counterpart of torch.nn.Module: anything that can map an
// input tensor to an output tensor.
type Module interface {
	// Forward computes the module's output for input.
	Forward(input Tensor) Tensor
}

//
// Linear implements Module, like torch.nn.Linear implements torch.nn.Module.
//

// Linear holds the weight matrix P of a fully-connected layer. P is exported
// so that the reflection-based FreeModel can reach it.
type Linear struct {
	P Tensor
}

// NewLinear returns a Linear layer whose weight P is produced by torch_RandN
// with shape {in, out}.
func NewLinear(in, out int) *Linear {
	shape := []int{in, out}
	lin := new(Linear)
	lin.P = torch_RandN(shape)
	return lin
}

// Forward multiplies the input x by the weight P.
func (lin *Linear) Forward(x Tensor) Tensor {
	weight := lin.P
	return torch_MM(x, weight)
}

//
// MyModel implements Module too.  It is a user-defined model.
//

// MyModel is a user-defined model made of two stacked Linear layers. The
// fields are exported so that the reflection-based FreeModel can reach them.
type MyModel struct {
	Ln1, Ln2 *Linear
}

// NewMyModel builds a MyModel that maps 50 inputs through 100 hidden units to
// 10 outputs.
func NewMyModel() *MyModel {
	m := new(MyModel)
	m.Ln1 = NewLinear(50, 100)
	m.Ln2 = NewLinear(100, 10)
	return m
}

// Forward applies Ln1, ReLU, Ln2 and ReLU, in that order.
func (mm *MyModel) Forward(x Tensor) Tensor {
	h := torch_Relu(mm.Ln1.Forward(x))
	h = mm.Ln2.Forward(h)
	return torch_Relu(h)
}

// FreeModel uses reflection to free the resources of m and its submodules
// recursively. Every field whose type implements Module is freed by a
// recursive call, and every field of type Tensor is closed.
//
// m may be a pointer to a struct or a struct value. A nil m, a nil submodule
// (nil pointer or nil Module interface), and a Module whose dynamic value is
// not a struct have no fields to visit and are skipped instead of panicking.
//
// Limitation: reflect.Value.Interface panics on values read through
// unexported fields. An unexported submodule or Tensor field (e.g. ln2 instead
// of Ln2) therefore still makes FreeModel panic, which is why this
// reflection-based approach was rejected.
func FreeModel(m Module) {
	moduleType := reflect.TypeOf((*Module)(nil)).Elem()
	tensorType := reflect.TypeOf(Tensor{})

	v := reflect.ValueOf(m)
	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return // nil submodule: nothing to free.
		}
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return // also covers the invalid Value produced by a nil interface.
	}

	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := t.Field(i)

		if field.Type.Implements(moduleType) {
			fmt.Printf("Free up %s.%s\n", reflect.TypeOf(m), field.Name)
			// Comma-ok: a nil Module-typed field yields a nil interface,
			// which the recursive call then skips.
			sub, _ := v.Field(i).Interface().(Module)
			FreeModel(sub)
		}

		if field.Type == tensorType {
			fmt.Printf("Close Tensor %s\n", field.Name)
			v.Field(i).Interface().(Tensor).Close()
		}
	}
}

func main() {
    FreeModel(NewMyModel())
}

You can run this program on the Go Playground at https://play.golang.org/p/gtSyNhhGNbF

wangkuiyi commented 4 years ago

Proposal 2: Module.Close()

Alternatively, we can make this less magical and more idiomatic Go by requiring users to define a Close method for each Module-derived type.

package main

import (
    "fmt"
)

//
// Basic data type Tensor.
//

// Tensor stands in for a GoTorch tensor handle; in this sketch it holds no data.
type Tensor struct{}

// Close frees the tensor's memory. Here it only prints a message.
func (Tensor) Close() {
	const msg = "Free up tensor memory..."
	fmt.Println(msg)
}

//
// Basic Tensor operations like those in ATen.
//

// torch_RandN mimics torch.RandN (torch.randn in PyTorch). In this sketch it
// ignores size and returns an empty placeholder Tensor.
func torch_RandN(size []int) Tensor {
	var t Tensor
	return t
}

// torch_MM mimics torch.MM (matrix multiplication). In this sketch it ignores
// both operands and returns an empty placeholder Tensor.
func torch_MM(x, y Tensor) Tensor {
	var product Tensor
	return product
}

// torch_Relu mimics torch.Relu, the rectified linear unit. In this sketch it
// ignores x and returns an empty placeholder Tensor.
func torch_Relu(x Tensor) Tensor {
	var activated Tensor
	return activated
}

//
// Module is like torch.nn.Module.
//

// Module is the Go counterpart of torch.nn.Module. Besides Forward, every
// implementation must provide Close, so models can free their tensors
// explicitly without reflection.
type Module interface {
	// Forward computes the module's output for input.
	Forward(input Tensor) Tensor
	// Close frees the tensors the module holds, including its submodules'.
	Close()
}

//
// Linear implements Module, like torch.nn.Linear implements torch.nn.Module.
//

// linear is the unexported implementation returned by Linear; p is the
// layer's weight tensor.
type linear struct {
	p Tensor
}

// Linear returns a fully-connected Module whose weight is created by
// torch_RandN with shape {in, out}.
func Linear(in, out int) Module {
	var l linear
	l.p = torch_RandN([]int{in, out})
	return &l
}

// Forward multiplies the input x by the weight.
func (lin *linear) Forward(x Tensor) Tensor {
	weight := lin.p
	return torch_MM(x, weight)
}

// Close frees the weight tensor.
func (lin *linear) Close() {
	lin.p.Close()
}

//
// mymodel implements Module too.  It is a user-defined model.
//

// mymodel is a user-defined Module made of two stacked linear layers.
type mymodel struct {
	ln1, ln2 Module
}

// MyModel returns a model that maps 50 inputs through 100 hidden units to 10
// outputs.
func MyModel() Module {
	first := Linear(50, 100)
	second := Linear(100, 10)
	return &mymodel{ln1: first, ln2: second}
}

// Forward applies ln1, ReLU, ln2 and ReLU, in that order.
func (mm *mymodel) Forward(x Tensor) Tensor {
	hidden := torch_Relu(mm.ln1.Forward(x))
	out := mm.ln2.Forward(hidden)
	return torch_Relu(out)
}

// Close frees both submodules, ln1 first.
func (mm *mymodel) Close() {
	for _, sub := range [...]Module{mm.ln1, mm.ln2} {
		sub.Close()
	}
}

// main builds a MyModel and releases its tensors through Module.Close.
func main() {
	m := MyModel()
	defer m.Close()
}

You can run it at https://play.golang.org/p/OBs7UHwWPMQ

QiJune commented 4 years ago

Module is the base class for all neural network modules. User models should subclass this class.

import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    """Two stacked 5x5 convolutions, each followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # 1 input channel -> 20 feature maps, 5x5 kernel.
        self.conv1 = nn.Conv2d(1, 20, 5)
        # 20 feature maps -> 20 feature maps, 5x5 kernel.
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        out = F.relu(self.conv2(hidden))
        return out

def weights_init(m):
    """Re-initialise convolution weights in place.

    Modules whose class name contains "Conv" (e.g. Conv1d, Conv2d,
    ConvTranspose2d) get ``m.weight`` redrawn from a normal distribution with
    mean 0.0 and standard deviation 0.02; every other module is left
    unchanged. Meant to be passed to ``nn.Module.apply``.
    """
    if 'Conv' not in m.__class__.__name__:
        return
    nn.init.normal_(m.weight, 0.0, 0.02)

# Build the model and re-initialise every Conv* submodule with weights_init.
# nn.Module.apply visits each submodule (and the model itself) and returns the
# model, so construction and initialisation can be chained.
model = MyModel().apply(weights_init)

The derived class MyModel inherits data members and methods from the base class, such as the _parameters data member and the apply method. The derived class must also implement the virtual methods defined in the base class, such as forward.

Go does not support inheritance, so it seems that we have to introduce other concepts to represent Python's nn.Module in Go.

In the GoTorch library:

// Tensor stands in for GoTorch's tensor type; this sketch holds no data.
type Tensor struct {
}

// Module is the behavior every GoTorch module implements. It plays the role of
// the virtual forward method of Python's torch.nn.Module; shared state lives
// in Model instead.
type Module interface {
	Forward(input Tensor) Tensor
}

// Model holds the state shared by GoTorch modules: named parameters, buffers
// and submodules. User-defined models embed Model and implement Module.
//
// The zero value is ready to use; each map is allocated on the first
// registration that needs it.
type Model struct {
	parameters map[string]Tensor
	buffers    map[string]Tensor
	modules    map[string]Module
}

// RegisterBuffer records t under name as a buffer of m, mirroring
// torch.nn.Module.register_buffer. Registering an existing name replaces the
// previous tensor.
func (m *Model) RegisterBuffer(name string, t Tensor) {
	if m.buffers == nil {
		m.buffers = make(map[string]Tensor)
	}
	m.buffers[name] = t
}

// RegisterParameter records t under name as a parameter of m, mirroring
// torch.nn.Module.register_parameter. Registering an existing name replaces
// the previous tensor.
func (m *Model) RegisterParameter(name string, t Tensor) {
	if m.parameters == nil {
		m.parameters = make(map[string]Tensor)
	}
	m.parameters[name] = t
}

// RegisterModule records sub under name as a submodule of m, mirroring
// torch.nn.Module.register_module. Registering an existing name replaces the
// previous submodule.
func (m *Model) RegisterModule(name string, sub Module) {
	if m.modules == nil {
		m.modules = make(map[string]Module)
	}
	m.modules[name] = sub
}

// Apply is meant to call fn on m (and, like PyTorch's Module.apply, on its
// submodules). For now it is a stub: it prints a message and never invokes fn.
func Apply(m Module, fn func(Module)) {
	const msg = "apply something to a module"
	fmt.Println(msg)
}

In user code:

// MyModel is a user-defined model. Embedding Model gives it the shared
// parameter/buffer/submodule storage; defining Forward makes it a Module.
type MyModel struct {
	Model
}

// Forward is a placeholder: it prints a message and returns an empty Tensor.
func (mm *MyModel) Forward(x Tensor) Tensor {
	fmt.Println("my model forward")
	var out Tensor
	return out
}

// NewMyModel returns a MyModel with a zero-valued embedded Model.
func NewMyModel() *MyModel {
	return new(MyModel)
}

// main builds a user model, applies a callback to it, prints its registered
// parameters, and runs one forward pass.
func main() {
	model := NewMyModel()
	// Apply takes the function to apply as its second argument; this no-op
	// stands in for a callback such as weights_init in the Python example.
	Apply(model, func(Module) {})
	fmt.Println(model.parameters)
	x := Tensor{}
	model.Forward(x)
}
QiJune commented 4 years ago

Things that need further study:

QiJune commented 4 years ago

Do we need to introduce a Parameter type in Go?

No. The C++ API has no Parameter class either.

shendiaomo commented 4 years ago

Fixed by #95