
Comments (7)

schnecki commented on May 20, 2024

Btw, shouldn't the weights be initialized with random numbers close to 0 and in the range [-1, 1]? For me, when I run randomNetwork in the IO monad, the weights are initialized with arbitrary numbers, yielding huge outputs for e.g. FullyConnected. This causes issues with some activation layers like tanh.

I could fix it with the following snippet for the fully connected layer. But how is it supposed to be initialized?

...
s1 <- getRandomR (-1, 1)
s2 <- getRandomR (-1, 1)
let wB = 1/5000 * randomVector  s1 Uniform   -- shrink the bias vector into [0, 1/5000]
    wN = 1/5000 * uniformSample s2 (-1) 1    -- shrink the weight matrix into [-1/5000, 1/5000]
...

HuwCampbell commented on May 20, 2024

They actually are (edit: within the [-1, 1] range).
It's currently:

randomFullyConnected = do
    s1 <- getRandom
    s2 <- getRandom
    let wB = randomVector  s1 Uniform * 2 - 1  -- bias vector, rescaled into [-1, 1]
        wN = uniformSample s2 (-1) 1           -- weight matrix, drawn in [-1, 1]
        bm = konst 0                           -- bias momentum, starts at zero
        mm = konst 0                           -- weight momentum, starts at zero
    return $ FullyConnected (FullyConnected' wB wN) (FullyConnected' bm mm)

The s1 and s2 are integer seeds which are fed to randomVector and uniformSample from hmatrix. The (-1) 1 arguments specify the range; Uniform gives values between 0 and 1, hence the scaling I have done.
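As a quick sanity check of the rescaling (a minimal sketch against hmatrix's static interface; the seed 42 and length 1000 are arbitrary choices, not grenade code):

{-# LANGUAGE DataKinds #-}
import Numeric.LinearAlgebra (RandDist (Uniform), toList)
import Numeric.LinearAlgebra.Static (R, extract, randomVector)

-- draw 1000 uniform values in [0, 1] and rescale them into [-1, 1]
sample :: R 1000
sample = randomVector 42 Uniform * 2 - 1

main :: IO ()
main = do
  let xs = toList (extract sample)
  print (minimum xs, maximum xs)  -- both extremes land inside [-1, 1]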

HuwCampbell commented on May 20, 2024

But yes, they could be scaled down a bit or drawn from something other than a uniform distribution, and some more normalisation could also be done. This is one of the reasons I am interested in being a bit smarter about the initialisation.

schnecki commented on May 20, 2024

True, you're right. I investigated the observed problem a little more. It happens when I stack several FullyConnected layers with ReLU activations and end with a FullyConnected layer followed by tanh. The wBs lift the values out of the range [-1, 1], and since ReLU passes large positive values through unchanged, the inputs to tanh are almost all >> 1 or << -1. As a result, tanh returns 1 (or -1) for most values, and learning from there fails (or at least would take forever).
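A toy calculation (plain Haskell, no grenade involved; the constant weight 0.5 is just a stand-in for the typical magnitude of weights drawn uniformly from [-1, 1]) makes the blow-up visible:

relu :: Double -> Double
relu = max 0

-- one dense layer with 100 units, every weight fixed at 0.5
layer :: [Double] -> [Double]
layer xs = replicate 100 (relu (0.5 * sum xs))

main :: IO ()
main = do
  let out = iterate layer (replicate 100 0.1) !! 3
  print (head out)         -- 12500.0: pre-activations far outside [-1, 1]
  print (tanh (head out))  -- 1.0: tanh is fully saturated, so gradients vanish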

schnecki commented on May 20, 2024

So, I'm quite busy for another 2-3 weeks. After that I could take the ticket, if that's OK with everyone. But first I will have to look into the literature on how to properly initialize the weights.
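For reference, the two scales that come up most often in that literature are the following (standard formulas from Glorot & Bengio 2010 and He et al. 2015, nothing grenade-specific):

-- Glorot/Xavier: sample weights uniformly from [-b, b]; suits tanh-like units
xavierBound :: Int -> Int -> Double
xavierBound nIn nOut = sqrt (6 / fromIntegral (nIn + nOut))

-- He et al.: sample weights from a normal distribution with this standard
-- deviation; suits ReLU units
heStdDev :: Int -> Double
heStdDev nIn = sqrt (2 / fromIntegral nIn)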

HuwCampbell commented on May 20, 2024

I think that should be fine.

schnecki commented on May 20, 2024

Hi @HuwCampbell,
first of all I'd like to say sorry for taking so long. I have finally found time to work on the initialization.

You can find the current version here: https://github.com/schnecki/grenade

Before going into initialization: I implemented a GNum (Grenade Num) class which can be used to scale networks (which might make sense after initialization) and to add networks (e.g. for slow adaptation in reinforcement learning, where you keep a training network and a target network); see Grenade.Core.Network. It can also be used on Gradients to process batches in parallel (parMap) and then sum the gradients before training the network, which was faster for me. Furthermore, I had to add NFData instances to prevent memory leaks. A sketch of the idea follows below.
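Roughly (a simplified sketch with assumed operator names and a toy instance, not the exact code from the fork):

{-# LANGUAGE FlexibleInstances #-}

class GNum a where
  (|*) :: Rational -> a -> a  -- scale every weight / gradient entry
  (|+) :: a -> a -> a         -- add two networks / gradients pointwise

-- toy instance: a list of doubles standing in for a layer's weights
instance GNum [Double] where
  r |* xs = map (fromRational r *) xs
  (|+)    = zipWith (+)

-- slow target-network adaptation in reinforcement learning, tau in (0, 1)
softUpdate :: GNum a => Rational -> a -> a -> a
softUpdate tau online target = (tau |* online) |+ ((1 - tau) |* target)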

Coming back to weight initialization, see Grenade.Core.WeightInitialization and Grenade.Core.Layer. All layers now ask the Grenade.Core.WeightInitialization module for a random vector/matrix when initializing. The actual generation of random data is therefore encapsulated in that module, and adding a new weight-initialization method only requires changes there. By the way, a simple experiment showed that weight initialization makes a huge difference: see the feedforwardweightinit.hs example and test different settings.
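In spirit, the dispatch boils down to something like this (a simplified sketch; WeightInitMethod, randomVectorWith and the constructors are placeholder names, not necessarily the module's actual exports):

{-# LANGUAGE DataKinds, ScopedTypeVariables #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, natVal)
import Numeric.LinearAlgebra (RandDist (Uniform))
import Numeric.LinearAlgebra.Static (R, dvmap, randomVector)
import System.Random (randomIO)

data WeightInitMethod = UniformInit | Xavier

-- every layer asks this one function for its random weights, so a new
-- initialization method only requires changes here
randomVectorWith :: forall n. KnownNat n => WeightInitMethod -> IO (R n)
randomVectorWith method = do
  seed <- randomIO
  let n = fromIntegral (natVal (Proxy :: Proxy n)) :: Double
      u = randomVector seed Uniform :: R n  -- entries in [0, 1]
  pure $ case method of
    UniformInit -> dvmap (\x -> x * 2 - 1) u  -- rescale into [-1, 1]
    Xavier      -> dvmap (\x -> (x * 2 - 1) * sqrt (6 / (n + n))) u  -- assumes nIn = nOut = n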

My goal so far was backward compatibility, which worked out quite nicely. I specialized the MonadRandom m constraint to IO, which shouldn't be a problem for most people; otherwise, the new randomNetworkWith function can be called.

The Variate class you proposed can only provide uniformly distributed values, so that does not work here.
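Normally distributed draws (as needed for e.g. He-style initialization) have to be derived from uniform ones instead, for instance with the Box–Muller transform:

-- Box–Muller: two independent Uniform(0,1] draws become two independent
-- standard-normal draws
boxMuller :: Double -> Double -> (Double, Double)
boxMuller u1 u2 = (r * cos t, r * sin t)
  where
    r = sqrt (-2 * log u1)  -- u1 must lie in (0, 1] so that log u1 is finite
    t = 2 * pi * u2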

P.S.: By the way, which training algorithm is implemented? Adam?
