
Comments (2)

OmarJay1 commented on August 18, 2024

I looked at your code and can't seem to find anything wrong with it. Maybe there's something I'm not getting about the Matrix organization. I also tried the example given in README.md and couldn't get it to give accurate results either.

The following is what I get when I examine the data. Am I missing something?

// Compare the actual responses with the network predictions, column by column
Matrix pred = net.predict(x);

for (int col = 0; col < 100; col++)
{
	std::cout << "Actual:\t" << y(0, col) << "\t" << y(1, col) << "\tPred:\t" << pred(0, col) << "\t" << pred(1, col) << "\n";
}

[Epoch 0, batch 0] Loss = 0.306261
[Epoch 1, batch 0] Loss = 0.305863
[Epoch 2, batch 0] Loss = 0.305571
[Epoch 3, batch 0] Loss = 0.305317
[Epoch 4, batch 0] Loss = 0.305056
[Epoch 5, batch 0] Loss = 0.304763
[Epoch 6, batch 0] Loss = 0.304438
[Epoch 7, batch 0] Loss = 0.304063
[Epoch 8, batch 0] Loss = 0.303648
[Epoch 9, batch 0] Loss = 0.303203
Actual: -0.594592 -0.559557 Pred: -0.0108888 -0.00331937
Actual: -0.601184 0.171911 Pred: -0.0151105 0.011997
Actual: 0.391766 -0.335063 Pred: -0.0120812 0.0071503
Actual: 0.397565 0.491379 Pred: -0.0176556 0.0154349
Actual: 0.490341 -0.813654 Pred: -0.0124328 0.000167552
Actual: -0.582873 -0.424421 Pred: -0.0160829 0.00661926
Actual: 0.124912 -0.328288 Pred: -0.016426 0.00983986
Actual: -0.358806 0.320353 Pred: -0.0151985 0.00863794
Actual: 0.269021 0.217261 Pred: -0.0171463 0.0139211
Actual: 0.731071 -0.109165 Pred: -0.00500451 0.00823758
Actual: -0.0208441 0.998169 Pred: -0.0116118 0.00654761
Actual: -0.87286 -0.931883 Pred: -0.0249642 0.00230164
Actual: -0.395856 -0.109287 Pred: -0.0135956 0.00817519
Actual: -0.56798 0.350383 Pred: -0.0184759 0.0143652
Actual: -0.267251 -0.120212 Pred: -0.020773 0.00400116
Actual: -0.562059 -0.880123 Pred: -0.0222947 0.0010379
Actual: 0.941649 0.0643025 Pred: -0.00497992 0.00557524
Actual: 0.361248 -0.810785 Pred: -0.0214937 0.00512891
Actual: 0.407331 0.0719932 Pred: -0.0205746 0.0059913
Actual: 0.846858 -0.96118 Pred: -0.0170448 0.00248724
Actual: 0.443281 0.697684 Pred: 0.000145709 0.00842001
Actual: 0.304117 0.058565 Pred: -0.0108235 0.0061853
Actual: -0.974303 0.576281 Pred: -0.0261767 0.00766655
Actual: 0.137913 -0.894711 Pred: -0.0161439 0.00761143
Actual: -0.630055 -0.296426 Pred: -0.00916252 0.00844913
Actual: 0.436079 0.689993 Pred: -0.0157206 0.0072009
Actual: 0.620045 -0.382855 Pred: -0.0111943 0.00363827
Actual: 0.676809 -0.894467 Pred: -0.0103445 0.00890045
Actual: 0.714225 -0.950133 Pred: -0.0166271 0.00632329
Actual: 0.0177923 -0.19071 Pred: -0.0173963 0.00688036
Actual: -0.511155 0.86877 Pred: -0.0104109 0.0135536
Actual: -0.26957 0.371136 Pred: -0.0183217 0.0110357
Actual: -0.917051 -0.889096 Pred: -0.0222241 0.000982962
Actual: -0.659658 0.301065 Pred: -0.0180802 0.0051002
Actual: -0.294229 -0.0690634 Pred: -0.0158001 0.00998637
Actual: -0.962096 0.0678426 Pred: -0.0166481 0.0105297
Actual: -0.169408 -0.730949 Pred: -0.0108766 0.00872367
Actual: -0.0792566 0.621998 Pred: -0.00995024 0.00982201
Actual: 0.876644 -0.943663 Pred: -0.0139434 -0.00218531
Actual: 0.630543 -0.265358 Pred: -0.00409277 0.0100282
Actual: -0.558458 0.766289 Pred: -0.0185546 0.0026624
Actual: -0.0528275 0.465682 Pred: -0.00901414 0.0119489
Actual: -0.852474 0.126865 Pred: -0.0300936 0.00649495
Actual: -0.0979339 0.0729698 Pred: -0.0149036 0.00541182
Actual: -0.836909 -0.194372 Pred: -0.0151815 0.00681906
Actual: 0.97412 0.676138 Pred: 0.000142749 0.0178851
Actual: -0.537095 -0.120273 Pred: -0.00977149 0.00998664
Actual: -0.0708335 -0.0143742 Pred: -0.0272016 0.00619372
Actual: -0.111423 0.903195 Pred: -0.0158854 0.00943956
Actual: -0.132115 -0.587878 Pred: -0.0225948 -0.00104554
Actual: -0.927183 -0.615162 Pred: -0.0242762 0.00347298
Actual: -0.958312 -0.134068 Pred: -0.0253801 0.00422208
Actual: 0.289529 0.489059 Pred: -0.00869754 0.0164616
Actual: 0.0019837 0.561327 Pred: -0.0213195 0.00962443
Actual: 0.291299 -0.229225 Pred: -0.00906035 0.0065133
Actual: -0.357036 -0.863399 Pred: -0.0232279 0.00697945
Actual: 0.199072 0.759392 Pred: -0.0131771 0.013273
Actual: -0.29783 0.403974 Pred: -0.0225625 0.00580345
Actual: 0.503708 0.201453 Pred: -0.0175854 0.00230672
Actual: 0.280129 0.0392773 Pred: -0.013257 0.0056122
Actual: -0.755547 -0.222449 Pred: -0.0208501 0.00849499
Actual: -0.0579546 0.52739 Pred: -0.0120977 0.0149936
Actual: -0.575182 0.102878 Pred: -0.0262276 0.00781519
Actual: 0.964415 -0.327189 Pred: -0.00355708 0.00567568
Actual: 0.730033 -0.173742 Pred: -0.00907445 0.0092031
Actual: -0.711966 0.989807 Pred: -0.0200257 0.00757274
Actual: -0.393353 -0.360576 Pred: -0.0186856 0.00786544
Actual: -0.258644 0.283242 Pred: -0.0190717 0.00299429
Actual: 0.769768 0.00845363 Pred: -0.0122089 0.0128042
Actual: 0.0950652 -0.265664 Pred: -0.00716194 0.00137264
Actual: 0.196142 0.440718 Pred: -0.0178735 0.0130729
Actual: 0.511582 -0.197974 Pred: -0.00541223 0.0165225
Actual: 0.6704 -0.900693 Pred: -0.015097 -0.00245243
Actual: -0.0483108 0.253395 Pred: -0.0101348 0.00760942
Actual: -0.0571612 0.577929 Pred: -0.0155326 0.00965104
Actual: 0.100436 0.635731 Pred: -0.0197192 0.00156631
Actual: 0.564135 0.647816 Pred: -0.0110585 0.0127332
Actual: 0.375286 -0.497726 Pred: -0.0159445 0.00922888
Actual: -0.542772 -0.20951 Pred: -0.0125907 0.00831498
Actual: 0.578661 0.704703 Pred: -0.0105018 0.0150356
Actual: 0.291299 0.327921 Pred: -0.013753 0.0107937
Actual: 0.757622 0.0637532 Pred: -0.00509624 0.0149238
Actual: 0.717887 0.962584 Pred: -0.0120862 0.0152628
Actual: 0.832881 0.732353 Pred: -0.011825 0.0101201
Actual: -0.466597 -0.459395 Pred: -0.0190337 0.00211734
Actual: -0.440046 0.561571 Pred: -0.0155961 0.0112913
Actual: -0.170202 -0.150243 Pred: -0.022732 0.00669575
Actual: -0.614673 -0.216041 Pred: -0.0207481 0.00626955
Actual: -0.804193 -0.734794 Pred: -0.0260024 0.00704391
Actual: 0.75866 0.324992 Pred: -0.0101267 0.0109374
Actual: -0.968139 -0.944578 Pred: -0.0216883 0.00387969
Actual: 0.465072 -0.943785 Pred: -0.0101294 0.00464512
Actual: 0.0405591 0.160924 Pred: -0.00623769 0.00949334
Actual: -0.460067 0.302591 Pred: -0.0256378 0.00666636
Actual: -0.0449538 0.257546 Pred: -0.0227666 0.00891267
Actual: -0.67272 -0.0843226 Pred: -0.0209044 0.00459229
Actual: 0.569079 0.918394 Pred: -0.00881892 0.0153782
Actual: -0.754204 0.226356 Pred: -0.0154799 0.00604572
Actual: 0.129612 0.684988 Pred: -0.020261 0.00833592
Actual: 0.824335 0.18833 Pred: -0.00903887 0.00990284


yixuan commented on August 18, 2024

Sorry for my late reply.

The questions both of you asked are mostly about DNN theory, not about the program itself.

Regarding @OmarJay1's observation: this is in fact expected. In my example, x and y were simulated independently, so there is no real correlation between them. The model is essentially using random noise to predict random noise, so the poor accuracy is genuine.
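
For context, the README example generates the data along these lines (a sketch based on the README example; Matrix is Eigen::MatrixXd as in the code below, and Matrix::Random fills a matrix with i.i.d. uniform noise):

Matrix x = Matrix::Random(400, 100);  // predictors: each column is one observation
Matrix y = Matrix::Random(2, 100);    // responses: drawn independently of x

// With x and y independent, the best any model can do is predict values
// near the mean of y (about 0), which matches the near-zero predictions above.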

Regarding @katzb123's XOR function, I can say a bit more. It is true that the XOR function is mathematically simple. Unfortunately, optimizing a nonconvex loss function with gradient-based methods is hard even when the true function is simple.

One common trick in training DNNs is overparameterization: using (far) more parameters than are needed to fit the function. I modified your code a little, and it produced the desired results.

#include <cstdlib>    // for std::srand
#include <iostream>   // for std::cout / std::cin
#include <Eigen/Core>
#include <MiniDNN.h>

using namespace MiniDNN;

typedef Eigen::MatrixXd Matrix;
typedef Eigen::VectorXd Vector;

int main()
{
    std::srand(123);

    Matrix inputs(2, 4);
    inputs << 0, 0, 1, 1,
              0, 1, 0, 1;

    Matrix outputs(1, 4);
    outputs << 0, 1, 1, 0;

    std::cout << "input =\n" << inputs << std::endl;
    std::cout << "output = " << outputs << std::endl;

    // Construct a network object
    Network net;

    // Create layers
    Layer* layer1 = new FullyConnected<ReLU>(2, 100);    // 2 inputs, 100 hidden units
    Layer* layer2 = new FullyConnected<Sigmoid>(100, 1); // 100 hidden units, 1 output

    // Add layers to the network object
    net.add_layer(layer1);
    net.add_layer(layer2);

    // Set output layer
    net.set_output(new RegressionMSE());

    // Optimizer
    RMSProp opt;
    opt.m_lrate = 0.01;

    VerboseCallback callback;
    net.set_callback(callback);

    // Initialize parameters with N(0, 0.01^2) using random seed 123
    net.init(0, 0.01, 123);

    // Fit the model with a batch size of 4, running 500 epochs with random seed 123
    net.fit(opt, inputs, outputs, 4, 500, 123);

    Matrix pred = net.predict(inputs);
    std::cout << pred << std::endl;

    std::cin.get();
    return 0;
}

Output:

input =
0 0 1 1
0 1 0 1
output = 0 1 1 0
[Epoch 0, batch 0] Loss = 0.125015
[Epoch 1, batch 0] Loss = 0.124969
[Epoch 2, batch 0] Loss = 0.124943
[Epoch 3, batch 0] Loss = 0.12491
[Epoch 4, batch 0] Loss = 0.124871
[Epoch 5, batch 0] Loss = 0.124817
[Epoch 6, batch 0] Loss = 0.124751
[Epoch 7, batch 0] Loss = 0.124656
[Epoch 8, batch 0] Loss = 0.124524
[Epoch 9, batch 0] Loss = 0.124371
[Epoch 10, batch 0] Loss = 0.124134
[Epoch 11, batch 0] Loss = 0.12382
[Epoch 12, batch 0] Loss = 0.123449
[Epoch 13, batch 0] Loss = 0.122944
[Epoch 14, batch 0] Loss = 0.122254
[Epoch 15, batch 0] Loss = 0.121505
[Epoch 16, batch 0] Loss = 0.120556
[Epoch 17, batch 0] Loss = 0.119285
[Epoch 18, batch 0] Loss = 0.117998
[Epoch 19, batch 0] Loss = 0.116465
[Epoch 20, batch 0] Loss = 0.114548
[Epoch 21, batch 0] Loss = 0.112635
[Epoch 22, batch 0] Loss = 0.110376
[Epoch 23, batch 0] Loss = 0.107858
[Epoch 24, batch 0] Loss = 0.105298
[Epoch 25, batch 0] Loss = 0.102734
...
[Epoch 490, batch 0] Loss = 4.97365e-05
[Epoch 491, batch 0] Loss = 4.95598e-05
[Epoch 492, batch 0] Loss = 4.94494e-05
[Epoch 493, batch 0] Loss = 4.92948e-05
[Epoch 494, batch 0] Loss = 4.9163e-05
[Epoch 495, batch 0] Loss = 4.90414e-05
[Epoch 496, batch 0] Loss = 4.88854e-05
[Epoch 497, batch 0] Loss = 4.87567e-05
[Epoch 498, batch 0] Loss = 4.86316e-05
[Epoch 499, batch 0] Loss = 4.84858e-05
0.0108134  0.990748  0.991276 0.0103999

The changes I made were:

  1. Changed the first layer's activation function to ReLU
  2. Used more hidden units
  3. Ran for more epochs
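
Since the output layer is a sigmoid, the predictions can be thresholded at 0.5 to recover the binary XOR labels. Appended after the predict() call above, something like this would print them (a small post-processing sketch, not part of the MiniDNN API):

// Round the sigmoid outputs to recover the binary XOR labels
for (int i = 0; i < pred.cols(); i++)
    std::cout << (pred(0, i) > 0.5 ? 1 : 0) << " ";
std::cout << std::endl;  // prints: 0 1 1 0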

