Comments (6)

shendiaomo avatar shendiaomo commented on August 23, 2024

Monadic Go+Torch API

In a monadic pattern, we have to define a new type that encapsulates the libtorch functions. The type Torch has two fields, err and tensors, which cut down the boilerplate code for error handling and memory management.

type Torch struct {
    err error
    tensors []Tensor
}

func (torch *Torch) Error() error {
    return torch.err
}

func (torch *Torch) Reclaim() {
    for _, t := range torch.tensors {
        Delete(t)  // `Delete` is an imaginary cgo wrapper function that calls C++ delete
    }
    torch.tensors = nil
}

func (torch *Torch) Relu(tensor Tensor) Tensor {
    if torch.err != nil {
        return Tensor{}
    }
    t, err := C.relu(tensor) // `C.relu` is the imaginary cgo wrapper function
    if err != nil {
        torch.err = err // Record line number etc.
        return Tensor{}
    }
    // Register `Tensor`s in reverse order for later reclaiming
    torch.tensors = append([]Tensor{t}, torch.tensors...)
    return t
}

// Define `MaxPool2d`, `Dropout`, and other functions with a similar pattern as `Relu`
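
To make the pattern concrete, here is a minimal runnable sketch that replaces the cgo calls with fakes. `Tensor` is just an ID here, `Fail` is a made-up op that always errors, and the `freed` slice stands in for the C++ delete; none of these are part of gotorch. It only illustrates the short-circuit on error and the reverse-order reclaim.

```go
package main

import (
	"errors"
	"fmt"
)

// Tensor is a stand-in for a libtorch tensor handle; here it only carries an ID.
type Tensor struct{ id int }

// freed records IDs in the order they are released (a stand-in for C++ delete).
var freed []int

// Torch collects the first error and every tensor created through it.
type Torch struct {
	err     error
	tensors []Tensor
	nextID  int
}

func (t *Torch) Error() error { return t.err }

// apply runs op unless an earlier call already failed, and registers the result.
func (t *Torch) apply(op func() (Tensor, error)) Tensor {
	if t.err != nil {
		return Tensor{}
	}
	out, err := op()
	if err != nil {
		t.err = err
		return Tensor{}
	}
	t.tensors = append([]Tensor{out}, t.tensors...) // newest first
	return out
}

// Relu is a fake op: it always succeeds and returns a fresh tensor.
func (t *Torch) Relu(x Tensor) Tensor {
	return t.apply(func() (Tensor, error) {
		t.nextID++
		return Tensor{id: t.nextID}, nil
	})
}

// Fail is a fake op that always errors, to show the short-circuit.
func (t *Torch) Fail(x Tensor) Tensor {
	return t.apply(func() (Tensor, error) {
		return Tensor{}, errors.New("fake libtorch failure")
	})
}

// Reclaim releases every registered tensor, newest first.
func (t *Torch) Reclaim() {
	for _, x := range t.tensors {
		freed = append(freed, x.id)
	}
	t.tensors = nil
}

func main() {
	var th Torch
	x := th.Relu(Tensor{})
	x = th.Relu(x)
	x = th.Fail(x) // records the error
	x = th.Relu(x) // skipped because an error is already recorded
	th.Reclaim()
	fmt.Println(th.Error(), freed, x)
}
```

The caller checks `Error()` once at the end instead of after every op, which is the whole point of the pattern.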

With the monadic API, a user-defined Module struct should embed a Torch object. As a result, users can write much neater code for their own Modules:

Monadic Model Definition

type Net struct {
    torch.Torch
    // Define conv1, conv2, fc1, fc2 etc.
}

func (net *Net) Forward(x torch.Tensor) torch.Tensor {
    x = net.Relu(net.MaxPool2d(net.conv1.Forward(x), 2))
    x = net.Relu(
        net.MaxPool2d(net.conv2_drop.Forward(net.conv2.Forward(x)), 2))
    x = x.View([]int{-1, 320})
    x = net.Relu(net.fc1.Forward(x))
    x = net.Dropout(x, /*p=*/0.5, /*training=*/is_training())
    x = net.fc2.Forward(x)
    return net.LogSoftmax(x, /*dim=*/1)
}

As we can see, the code is much shorter than the same example in the original post.
NOTE: users also have to pass the tensors slice of Torch to submodules such as net.conv1 and net.conv2 in some way, so that the Tensors those submodules create are registered for later reclaiming.
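
One possible way to do the sharing mentioned in the NOTE, as a sketch only: each submodule keeps a pointer to the parent's Torch, so everything registers into one slice. The `Linear` type, `NewNet`, and `newTensor` are hypothetical names for illustration, and `Net` embeds `*Torch` (a pointer) here rather than a `Torch` value so that the submodules and the parent see the same object.

```go
package main

import "fmt"

// Tensor is a stand-in for a libtorch tensor handle.
type Tensor struct{ id int }

// Torch holds the tensors registered for reclaiming (error handling omitted).
type Torch struct {
	tensors []Tensor
	nextID  int
}

// newTensor creates a fake tensor and registers it for later reclaiming.
func (t *Torch) newTensor() Tensor {
	t.nextID++
	out := Tensor{id: t.nextID}
	t.tensors = append([]Tensor{out}, t.tensors...)
	return out
}

// Reclaim drops all registered tensors and reports how many there were.
func (t *Torch) Reclaim() int {
	n := len(t.tensors)
	t.tensors = nil
	return n
}

// Linear is a hypothetical submodule that keeps a pointer to its parent's Torch.
type Linear struct{ torch *Torch }

// Forward produces a new tensor registered in the shared Torch.
func (l *Linear) Forward(x Tensor) Tensor { return l.torch.newTensor() }

// Net embeds *Torch and hands the same pointer to every submodule.
type Net struct {
	*Torch
	fc1, fc2 *Linear
}

func NewNet() *Net {
	t := &Torch{}
	return &Net{Torch: t, fc1: &Linear{torch: t}, fc2: &Linear{torch: t}}
}

func (n *Net) Forward(x Tensor) Tensor {
	return n.fc2.Forward(n.fc1.Forward(x))
}

func main() {
	net := NewNet()
	net.Forward(Tensor{})
	fmt.Println("reclaimed:", net.Reclaim())
}
```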

Monadic Train Loop

// We need this nested function to make `defer` work as expected.
func step(batch *Batch) {
    // `model` is an instance of the `Net` struct above
    defer model.Reclaim() // `delete` all tensors at the end of the scope, as the C++ version does
    data := model.ToDevice(batch.Data, device)
    target := model.ToDevice(batch.Target, device)
    optimizer.zero_grad()
    output := model.Forward(data)
    loss := model.NllLoss(output, target)
    loss.Backward()
    optimizer.Step()
    // ...
}

for batch := range data_loader {
    step(batch)
    if model.Error() != nil {  // Use the monad API to handle errors
        return ...
    }
}

As we can see, the code is also much shorter than the same example in the original post, and it includes complete error handling.
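
The reason for the nested `step` function is that Go runs deferred calls when the enclosing function returns, not at the end of each loop iteration. A small sketch, independent of any tensor code (the function names are made up for illustration), shows the difference by counting how many cleanups have run by the time the loop finishes:

```go
package main

import "fmt"

// deferInLoop defers directly inside the loop body. The deferred cleanups only
// run when this function returns, so none have run when the result is computed.
func deferInLoop(n int) int {
	cleaned := 0
	for i := 0; i < n; i++ {
		defer func() { cleaned++ }()
	}
	return cleaned
}

// deferInStep moves the body into a nested function, so each deferred cleanup
// runs at the end of its own iteration.
func deferInStep(n int) int {
	cleaned := 0
	step := func() {
		defer func() { cleaned++ }()
	}
	for i := 0; i < n; i++ {
		step()
	}
	return cleaned
}

func main() {
	fmt.Println(deferInLoop(3), deferInStep(3)) // prints "0 3"
}
```

With `defer model.Reclaim()` placed directly in the training loop, tensors from every batch would pile up until the surrounding function exits.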

Another option is to put the defer model.Reclaim() in the optimizer.Step() function (users have to pass the model's parameters to the optimizer anyway). The advantage is that users don't have to write a step function; the drawback is that they cannot access any tensors after the call to optimizer.Step().

from gotorch.

shendiaomo avatar shendiaomo commented on August 23, 2024

Another direction is to define a general monad, like:

type Monad struct {
    err error
    tensors []Tensor
}

func (m *Monad) Get(f func() (Tensor, error)) Tensor {
    if m.err != nil {
        return Tensor{}
    }
    t, e := f()
    if e != nil {
        m.err = e
        return Tensor{}
    } 
    m.tensors = append([]Tensor{t}, m.tensors...) // newest first, for later reclaiming
    return t
}

func (m *Monad) Do(f func() error) {
    if m.err != nil {
        return
    }
    if e := f(); e != nil {
        m.err = e
    }
}

func (m *Monad) Error() error {
    return m.err
}

func (m *Monad) Reclaim() {
    for _, t := range m.tensors {
        C.Delete(t)
    }
    m.tensors = nil
}

A user can write the forward function and the train loop above in Go+ (with the new lambda syntax) like:

for batch := range data_loader {
    var m torch.Monad
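    // NOTE: `defer` runs when the enclosing function returns, not at the end of each
    // iteration, so this body needs a nested function like `step` above.
    defer m.Reclaim() // `delete` all tensors at the end of the scope, as the C++ version does
    // `model` is an instance of the `Net` struct above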
    data := m.Get(=>model.ToDevice(batch.Data, device))
    target := m.Get(=>model.ToDevice(batch.Target, device))
    m.Do(=>optimizer.zero_grad())
    output := m.Get(=>model.Forward(data))
    loss := m.Get(=>model.NllLoss(output, target))
    m.Do(=>loss.Backward())
    m.Do(=>optimizer.Step())
    if m.Error() != nil {  // Use the monad API to handle errors
        return ...
    }
}

The benefit is that the API is much easier to design; otherwise, we have to pass a pointer to the monad everywhere to collect Tensors and errors. The pitfall is that it imposes some mental overhead on users.
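
For comparison, the same idea can be tried in plain Go with function literals in place of the Go+ lambdas. This is only a sketch under assumptions: `toDevice`, `forward`, `backward`, and the `optimizerSteps` counter are fakes invented for illustration, and `Reclaim` just counts tensors instead of calling into C++.

```go
package main

import (
	"errors"
	"fmt"
)

// Tensor is a stand-in for a libtorch tensor handle.
type Tensor struct{ id int }

// Monad mirrors the struct above: the first error plus the tensors to release.
type Monad struct {
	err     error
	tensors []Tensor
}

// Get runs f unless an earlier call failed, and registers the returned tensor.
func (m *Monad) Get(f func() (Tensor, error)) Tensor {
	if m.err != nil {
		return Tensor{}
	}
	t, e := f()
	if e != nil {
		m.err = e
		return Tensor{}
	}
	m.tensors = append([]Tensor{t}, m.tensors...)
	return t
}

// Do runs f unless an earlier call failed, and records its error, if any.
func (m *Monad) Do(f func() error) {
	if m.err != nil {
		return
	}
	m.err = f()
}

func (m *Monad) Error() error { return m.err }

// Reclaim drops the registered tensors and reports how many there were.
func (m *Monad) Reclaim() int {
	n := len(m.tensors)
	m.tensors = nil
	return n
}

// optimizerSteps counts how often the fake optimizer step actually ran.
var optimizerSteps int

// Fake operations standing in for ToDevice, Forward, and Backward.
func toDevice(id int) (Tensor, error)  { return Tensor{id: id}, nil }
func forward(x Tensor) (Tensor, error) { return Tensor{id: x.id + 1}, nil }
func backward(loss Tensor, fail bool) error {
	if fail {
		return errors.New("fake backward failure")
	}
	return nil
}

// runStep is one training step; it returns the number of tensors released
// and the first error recorded by the monad.
func runStep(failBackward bool) (released int, err error) {
	var m Monad
	data := m.Get(func() (Tensor, error) { return toDevice(1) })
	loss := m.Get(func() (Tensor, error) { return forward(data) })
	m.Do(func() error { return backward(loss, failBackward) })
	m.Do(func() error { optimizerSteps++; return nil }) // skipped after a failure
	released = m.Reclaim()
	return released, m.Error()
}

func main() {
	fmt.Println(runStep(false))
	fmt.Println(runStep(true))
}
```

The function literals are noticeably wordier than `=>expr`, which is part of why the lambda syntax matters for this design.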


xushiwei avatar xushiwei commented on August 23, 2024

Maybe we can introduce smart pointers in Go+?


wangkuiyi avatar wangkuiyi commented on August 23, 2024

In the end, we chose the SetTensorFinalizer solution rather than the Monad pattern.


xushiwei avatar xushiwei commented on August 23, 2024

> In the end, we chose the SetTensorFinalizer solution rather than the Monad pattern.

Where can I find the details of the SetTensorFinalizer solution?


wangkuiyi avatar wangkuiyi commented on August 23, 2024

@xushiwei It is in https://github.com/wangkuiyi/gotorch/blob/develop/tensor_gc.go

