Monadic Go+Torch API
In the monadic pattern, we define a new type that wraps the libtorch functions. The type Torch has two fields, err and tensors, which cut the boilerplate needed for error handling and memory management.
type Torch struct {
    err     error
    tensors []Tensor
}

func (torch *Torch) Error() error {
    return torch.err
}

func (torch *Torch) Reclaim() {
    for _, t := range torch.tensors {
        Delete(t) // `Delete` is an imaginary cgo wrapper function that calls C++ delete
    }
    torch.tensors = []Tensor{}
}

func (torch *Torch) Relu(tensor Tensor) Tensor {
    if torch.err != nil {
        return Tensor{} // An earlier call failed, so skip this one
    }
    t, err := C.relu(tensor) // `C.relu` is the imaginary cgo wrapper function
    if err != nil {
        torch.err = err // Record line number etc.
        return Tensor{}
    }
    // Register `Tensor`s in reverse order for later reclaiming
    torch.tensors = append([]Tensor{t}, torch.tensors...)
    return t
}

// Define `MaxPool2d`, `Dropout`, and other functions with a similar pattern as `Relu`
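To see the short-circuiting and reverse-order reclaiming in action without libtorch, here is a minimal runnable sketch. Everything in it is a stand-in: `Tensor` is just an integer id, `Relu` is a fake operator that fails on the zero tensor, the `apply` helper and the `deleted` field are hypothetical names added for illustration, and "deleting" only records the id.

```go
package main

import (
	"errors"
	"fmt"
)

// Tensor is a hypothetical stand-in for a libtorch tensor handle.
type Tensor struct{ id int }

// Torch records the first error and every tensor it has produced.
type Torch struct {
	err     error
	tensors []Tensor
	deleted []int // ids passed to the stand-in delete, for illustration only
}

// apply runs op unless an earlier call already failed, then registers the result.
func (torch *Torch) apply(op func() (Tensor, error)) Tensor {
	if torch.err != nil {
		return Tensor{}
	}
	t, err := op()
	if err != nil {
		torch.err = err
		return Tensor{}
	}
	torch.tensors = append([]Tensor{t}, torch.tensors...)
	return t
}

// Relu is a fake operator: it rejects the zero tensor, otherwise returns a new handle.
func (torch *Torch) Relu(x Tensor) Tensor {
	return torch.apply(func() (Tensor, error) {
		if x.id == 0 {
			return Tensor{}, errors.New("relu: invalid tensor")
		}
		return Tensor{id: x.id + 1}, nil
	})
}

func (torch *Torch) Error() error { return torch.err }

// Reclaim "deletes" the registered tensors, newest first.
func (torch *Torch) Reclaim() {
	for _, t := range torch.tensors {
		torch.deleted = append(torch.deleted, t.id)
	}
	torch.tensors = nil
}

func main() {
	var ok Torch
	y := ok.Relu(ok.Relu(Tensor{id: 1}))
	ok.Reclaim()
	fmt.Println(y.id, ok.Error(), ok.deleted) // 3 <nil> [3 2]

	var bad Torch
	z := bad.Relu(bad.Relu(Tensor{})) // the inner call fails, the outer one is skipped
	fmt.Println(z.id, bad.Error())    // 0 relu: invalid tensor
}
```

The caller checks `Error()` once after the whole chain instead of after every operator, which is the point of the pattern.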
With the monadic API, a user-defined Module struct should embed a Torch object. As a result, users can write much neater code for their own Modules:
Monadic Model Definition
type Net struct {
    torch.Torch
    // Define conv1, conv2, fc1, fc2 etc.
}

func (net *Net) Forward(x torch.Tensor) torch.Tensor {
    x = net.Relu(net.MaxPool2d(net.conv1.Forward(x), 2))
    x = net.Relu(
        net.MaxPool2d(net.conv2_drop.Forward(net.conv2.Forward(x)), 2))
    x = x.View([]int{-1, 320})
    x = net.Relu(net.fc1.Forward(x))
    x = net.Dropout(x, /*p=*/0.5, /*training=*/is_training())
    x = net.fc2.Forward(x)
    return net.LogSoftmax(x, /*dim=*/1)
}
As we can see, the code is much shorter than the same example in the original post.
NOTE: users also have to pass the tensors slice of Torch to submodules such as net.conv1, net.conv2, etc. in some way, so that the submodules can register their Tensors for later reclaiming.
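One possible way to do this, sketched below under assumptions, is to hand each submodule a pointer to the parent's Torch rather than a copy of the slice. In Go, appending to a slice that was passed by value does not update the caller's slice header, so sharing the owner by pointer is the simpler route. The `Linear` type, the `register` helper, and `NewNet` are hypothetical names, and `Tensor` is a stand-in without libtorch behind it.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for a libtorch tensor handle.
type Tensor struct{ id int }

type Torch struct {
	err     error
	tensors []Tensor
}

// register records t so that a later Reclaim could free it.
func (torch *Torch) register(t Tensor) Tensor {
	torch.tensors = append([]Tensor{t}, torch.tensors...)
	return t
}

// Linear is a hypothetical submodule that borrows its parent's Torch.
type Linear struct {
	torch *Torch
}

// Forward fakes a layer: it multiplies the id by 10 and registers the result
// with the parent, unless the parent already holds an error.
func (l *Linear) Forward(x Tensor) Tensor {
	if l.torch.err != nil {
		return Tensor{}
	}
	return l.torch.register(Tensor{id: x.id * 10})
}

type Net struct {
	Torch
	fc1 *Linear
}

// NewNet wires the submodule to the embedded Torch of the new Net.
func NewNet() *Net {
	net := &Net{}
	net.fc1 = &Linear{torch: &net.Torch}
	return net
}

func (net *Net) Forward(x Tensor) Tensor {
	return net.fc1.Forward(x)
}

func main() {
	net := NewNet()
	y := net.Forward(Tensor{id: 1})
	fmt.Println(y.id, len(net.tensors)) // 10 1
}
```

The tensor created inside `fc1` ends up in the registry of `net`, so a single `Reclaim` on the model covers the submodules too.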
Monadic Train Loop
// We need this nested function to make `defer` work as expected.
func step(batch *Batch) {
    // `model` is an instance of the `Net` struct above
    defer model.Reclaim() // `delete` all tensors at the end of scope as the C++ version
    data := model.ToDevice(batch.Data, device)
    target := model.ToDevice(batch.Target, device)
    optimizer.zero_grad()
    output := model.Forward(data)
    loss := model.NllLoss(output, target)
    loss.Backward()
    optimizer.Step()
    // ...
}

for batch := range data_loader {
    step(batch)
    if model.Error() != nil { // Use the monad API to handle errors
        return ...
    }
}
As we can see, this code is also much shorter than the same example in the original post, and it includes complete error handling.
Another option is to put the defer model.Reclaim() in the optimizer.Step() function (because users have to pass the parameters of model to the optimizer after all). The advantage is that users don't have to write a step function; the drawback is that users cannot access any tensors after the call to optimizer.Step().
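A rough sketch of that alternative, with made-up names: the optimizer keeps a pointer to the model's Torch (the `owner` field is an assumption) and reclaims inside `Step`. `Tensor`, `track`, and the `freed` counter are stand-ins for illustration only.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for a libtorch tensor handle.
type Tensor struct{ id int }

type Torch struct {
	tensors []Tensor
	freed   int // count of reclaimed tensors, for illustration only
}

// track registers a tensor for later reclaiming.
func (torch *Torch) track(t Tensor) Tensor {
	torch.tensors = append([]Tensor{t}, torch.tensors...)
	return t
}

func (torch *Torch) Reclaim() {
	torch.freed += len(torch.tensors)
	torch.tensors = nil
}

// Optimizer holds a pointer to the Torch that owns the model's tensors.
type Optimizer struct {
	owner *Torch
}

// Step would update the parameters, then reclaims everything registered so far.
func (o *Optimizer) Step() {
	defer o.owner.Reclaim()
	// ... the parameter update would go here ...
}

func main() {
	var model Torch
	opt := &Optimizer{owner: &model}
	for batch := 1; batch <= 2; batch++ {
		model.track(Tensor{id: batch})       // e.g. the output
		model.track(Tensor{id: batch + 100}) // e.g. the loss
		opt.Step()
		fmt.Println(len(model.tensors), model.freed) // 0 2, then 0 4
	}
}
```

After `Step` returns, the registry is empty, which is exactly why the loss or output can no longer be used at that point.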
Another direction is to define a general monad, like:
type Monad struct {
    err     error
    tensors []Tensor
}

// Get runs f unless an earlier call failed, and registers the returned Tensor.
func (m *Monad) Get(f func() (Tensor, error)) Tensor {
    if m.err != nil {
        return Tensor{}
    }
    t, e := f()
    if e != nil {
        m.err = e
        return Tensor{}
    }
    m.tensors = append([]Tensor{t}, m.tensors...)
    return t
}

// Do runs f unless an earlier call failed, and records its error.
func (m *Monad) Do(f func() error) {
    if m.err != nil {
        return
    }
    if e := f(); e != nil {
        m.err = e
    }
}

func (m *Monad) Error() error {
    return m.err
}

func (m *Monad) Reclaim() {
    for _, t := range m.tensors {
        C.Delete(t)
    }
    m.tensors = nil
}
A user can write the forward function and the train loop above in Go+ (with the new lambda syntax) like:
for batch := range data_loader {
    var m torch.Monad
    // NOTE: `defer` runs when the enclosing function returns, not at the end of each
    // iteration; wrap the loop body in a function (like `step` above) to reclaim per batch.
    defer m.Reclaim() // `delete` all tensors at the end of scope as the C++ version
    // `model` is an instance of the `Net` struct above
    data := m.Get(=>model.ToDevice(batch.Data, device))
    target := m.Get(=>model.ToDevice(batch.Target, device))
    m.Do(=>optimizer.zero_grad())
    output := m.Get(=>model.Forward(data))
    loss := m.Get(=>model.NllLoss(output, target))
    m.Do(=>loss.Backward())
    m.Do(=>optimizer.Step())
    if m.Error() != nil { // Use the monad API to handle errors
        return ...
    }
}
The benefit is that the API is much easier to design; otherwise, we have to pass a pointer to the monad everywhere to collect Tensors and errors. The pitfall is that it imposes some mental overhead on users.
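For comparison, here is how the general monad might look in plain Go, where each lambda becomes an explicit func literal. This is a sketch under assumptions: `Tensor` is a stand-in, the operators are faked inside the closures, `Reclaim` only counts, and the `step` signature with its `reclaimed` counter is hypothetical. Each batch runs in its own function so that the deferred `Reclaim` fires per batch.

```go
package main

import (
	"errors"
	"fmt"
)

// Tensor is a hypothetical stand-in for a libtorch tensor handle.
type Tensor struct{ id int }

type Monad struct {
	err     error
	tensors []Tensor
}

// Get runs f unless an earlier call failed, and registers the returned Tensor.
func (m *Monad) Get(f func() (Tensor, error)) Tensor {
	if m.err != nil {
		return Tensor{}
	}
	t, e := f()
	if e != nil {
		m.err = e
		return Tensor{}
	}
	m.tensors = append([]Tensor{t}, m.tensors...)
	return t
}

// Do runs f unless an earlier call failed, and records its error.
func (m *Monad) Do(f func() error) {
	if m.err != nil {
		return
	}
	if e := f(); e != nil {
		m.err = e
	}
}

func (m *Monad) Error() error { return m.err }

// Reclaim only counts here; a real binding would delete each tensor.
func (m *Monad) Reclaim() int {
	n := len(m.tensors)
	m.tensors = nil
	return n
}

// step handles one batch in its own function, so the deferred Reclaim runs
// when the batch is done rather than when the whole loop finishes.
func step(batch int, reclaimed *int) error {
	var m Monad
	defer func() { *reclaimed += m.Reclaim() }()
	data := m.Get(func() (Tensor, error) { return Tensor{id: batch}, nil })
	m.Do(func() error {
		if data.id == 2 {
			return errors.New("backward failed") // simulated failure on batch 2
		}
		return nil
	})
	_ = m.Get(func() (Tensor, error) { return Tensor{id: -batch}, nil }) // skipped after an error
	return m.Error()
}

func main() {
	reclaimed := 0
	for batch := 1; batch <= 3; batch++ {
		if err := step(batch, &reclaimed); err != nil {
			fmt.Println("batch", batch, "failed:", err) // batch 2 failed: backward failed
			break
		}
	}
	fmt.Println("reclaimed", reclaimed) // reclaimed 3
}
```

The func literals show the "mind overhead" mentioned above: every call has to be wrapped, and the Go+ lambda syntax mainly shortens that wrapping.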
Maybe we can introduce smart pointers in Go+?
We finally chose the SetTensorFinalizer solution rather than the Monad pattern.
> We finally chose the SetTensorFinalizer solution rather than the Monad pattern.
Where can I find the details of the SetTensorFinalizer solution?
@xushiwei It is in https://github.com/wangkuiyi/gotorch/blob/develop/tensor_gc.go
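For readers who have not used finalizers in Go, the general idea can be sketched with the standard `runtime.SetFinalizer`. This is not the code from the linked file, only an illustration of finalizer-based reclamation under assumptions: `handle`, `freed`, `newHandle`, `useAndDrop`, and `collect` are hypothetical names, and the finalizer only bumps a counter where a real binding would call the C++ delete.

```go
package main

import (
	"fmt"
	"runtime"
	"sync/atomic"
	"time"
)

// handle is a hypothetical stand-in for a wrapper around a C++ tensor pointer.
type handle struct {
	id  int
	buf [64]byte // padding so each handle gets its own allocation
}

var freed int64 // number of handles released by finalizers

// newHandle allocates a handle and asks the runtime to run the finalizer
// after the handle becomes unreachable.
func newHandle(id int) *handle {
	h := &handle{id: id}
	runtime.SetFinalizer(h, func(h *handle) {
		atomic.AddInt64(&freed, 1) // a real binding would call the C++ delete here
	})
	return h
}

// useAndDrop creates n handles and keeps no reference to any of them.
func useAndDrop(n int) {
	for i := 0; i < n; i++ {
		_ = newHandle(i)
	}
}

// collect forces garbage collections until at least want finalizers have run,
// giving up after about two seconds.
func collect(want int64) bool {
	deadline := time.Now().Add(2 * time.Second)
	for time.Now().Before(deadline) {
		runtime.GC()
		if atomic.LoadInt64(&freed) >= want {
			return true
		}
		time.Sleep(10 * time.Millisecond)
	}
	return false
}

func main() {
	useAndDrop(3)
	fmt.Println("all finalizers ran:", collect(3))
}
```

Unlike the monad, this needs no registry and no wrapper calls in user code, but the Go runtime decides when finalizers run, so the release of native memory is tied to garbage collection rather than to a scope.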