Sounds like an issue with the batching process, here (line 210 in ce60739).
In the working example, batches is zero, which is why it works.
In the failing example, batches is one, which affects the Xs, Policies and Values tensors.
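A minimal sketch of the batch arithmetic, assuming the loop follows the batchStart/batchEnd computation shown in the Train function. batchRanges is a hypothetical helper written only for illustration; it is not part of agogo. It shows that with batches set to zero the loop body never runs, so no batch is ever sliced out of the tensors and no training step executes.

```go
package main

import "fmt"

// batchRanges returns the [start, end) row ranges that a Train-style loop
// visits for a given batch size and number of batches. It mirrors the
// batchStart/batchEnd arithmetic in Train; it is a hypothetical helper.
func batchRanges(batchSize, batches int) [][2]int {
	var ranges [][2]int
	for bat := 0; bat < batches; bat++ {
		batchStart := bat * batchSize
		batchEnd := batchStart + batchSize
		ranges = append(ranges, [2]int{batchStart, batchEnd})
	}
	return ranges
}

func main() {
	// batches == 0: the loop body never runs, so nothing is sliced or trained.
	fmt.Println(len(batchRanges(100, 0))) // 0
	// batches == 1: rows [0, 100) are taken from Xs, policies and values.
	fmt.Println(batchRanges(100, 1)) // [[0 100]]
}
```

In other words, the zero-batch case "works" only because it never exercises the code path that fails.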
The error is raised from this function in the dualnet package (line 16 in ce60739).
I will try to set up a test on this function to reproduce the error so we can investigate further.
Here is a test case to ease the debugging process:
package dual

import (
	"testing"

	"gorgonia.org/tensor"
)

func TestTrain(t *testing.T) {
	features := 2
	height := 3
	width := 3
	actionSpace := 10
	batchSize := 100
	batchesZero := 0
	batchesOne := 1
	conf := DefaultConf(height, width, actionSpace)
	conf.BatchSize = batchSize
	conf.Features = features
	conf.K = 3
	conf.SharedLayers = 3
	type args struct {
		d          *Dual
		Xs         *tensor.Dense
		policies   *tensor.Dense
		values     *tensor.Dense
		batches    int
		iterations int
	}
	d := &Dual{Config: conf}
	if err := d.Init(); err != nil {
		t.Fatalf("%+v", err)
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{
			"issue #2 batch one (not working)",
			args{
				d:          d,
				Xs:         tensor.New(tensor.WithBacking(make([]float32, batchSize*batchesOne*features*height*width)), tensor.WithShape(batchSize*batchesOne, features, height, width)),
				policies:   tensor.New(tensor.WithBacking(make([]float32, batchSize*batchesOne*actionSpace)), tensor.WithShape(batchSize*batchesOne, actionSpace)),
				values:     tensor.New(tensor.WithBacking(make([]float32, batchSize*batchesOne)), tensor.WithShape(batchSize*batchesOne)),
				batches:    batchesOne,
				iterations: 100,
			},
			false,
		},
		{
			"issue #2 batches zero",
			args{
				d:          d,
				Xs:         tensor.New(tensor.WithBacking(make([]float32, batchSize*batchesZero*features*height*width)), tensor.WithShape(batchSize*batchesZero, features, height, width)),
				policies:   tensor.New(tensor.WithBacking(make([]float32, batchSize*batchesZero*actionSpace)), tensor.WithShape(batchSize*batchesZero, actionSpace)),
				values:     tensor.New(tensor.WithBacking(make([]float32, batchSize*batchesZero)), tensor.WithShape(batchSize*batchesZero)),
				batches:    batchesZero,
				iterations: 100,
			},
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := Train(tt.args.d, tt.args.Xs, tt.args.policies, tt.args.values, tt.args.batches, tt.args.iterations); (err != nil) != tt.wantErr {
				t.Errorf("Train() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}
Train() error = PC: 246: PC 246. Failed to execute instruction Aᵀ{0, 2, 3, 1} [CPU144] CPU144 false true false: Failed to carry op.Do(): Dimension mismatch. Expected 2, got 4, wantErr false
FAIL
It looks like the Reset method of the TapeMachine is leaking some state from one run into the next.
Actually, this change makes the previous test pass:
// Train is a basic trainer.
func Train(d *Dual, Xs, policies, values *tensor.Dense, batches, iterations int) error {
- m := G.NewTapeMachine(d.g, G.BindDualValues(d.Model()...))
- model := G.NodesToValueGrads(d.Model())
- solver := G.NewVanillaSolver(G.WithLearnRate(0.1))
var s slicer
for i := 0; i < iterations; i++ {
// var cost float32
for bat := 0; bat < batches; bat++ {
+ m := G.NewTapeMachine(d.g, G.BindDualValues(d.Model()...))
+ model := G.NodesToValueGrads(d.Model())
+ solver := G.NewVanillaSolver(G.WithLearnRate(0.1))
batchStart := bat * d.Config.BatchSize
batchEnd := batchStart + d.Config.BatchSize
@@ -38,7 +38,7 @@ func Train(d *Dual, Xs, policies, values *tensor.Dense, batches, iterations int)
if err := solver.Step(model); err != nil {
return err
}
- m.Reset()
+ //m.Reset()
tensor.ReturnTensor(Xs2)
tensor.ReturnTensor(π)
tensor.ReturnTensor(v)
I guess I should open an issue in gorgonia.
cc @chewxy
You're right. It appears Reset is leaking something. Also, looking back, wow, the Gorgonia library has changed quite a bit.