GLmontanari commented on June 10, 2024

Update: I tried to calculate the normalized confusion matrix using this function:

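#include <mlpack/core.hpp>
#include <mlpack/core/data/confusion_matrix.hpp>

// Print the 2-class confusion matrix with each row divided by its sum.
void print_norm_confmat(const arma::Row<size_t>& predictedLabels,
                        const arma::Row<size_t>& testlabel)
{
	arma::mat conf_mat;
	mlpack::data::ConfusionMatrix(predictedLabels, testlabel, conf_mat, 2);

	// Sum of each row (dimension 1 = sum across the columns of each row).
	arma::colvec row_sums = arma::sum(conf_mat, 1);

	// Note: if a row sum is 0, this division produces NaN (0/0).
	conf_mat.row(0) /= row_sums[0];
	conf_mat.row(1) /= row_sums[1];
	conf_mat.print();
}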

I noticed that for some values of max_depth, the second row of the printed matrix is NaN.

rcurtin commented on June 10, 2024

What max_depth values did you try? The extra trees algorithm uses completely random binary splits, so with a maximum depth of only 4, each tree has at most 2^4 = 16 leaves. The whole idea of the extra trees algorithm is that its trees are very large (but random, and thus quick to train). I suspect you will see better performance with a maximum depth closer to 10, and a higher depth may be necessary to achieve very good performance, depending on the dataset.
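To make the depth argument concrete, here is a minimal std-only C++ sketch (not mlpack's implementation) that grows a one-dimensional tree by splitting each node at a threshold drawn uniformly between the node's minimum and maximum, in the style of extremely randomized trees. The function name `GrowRandomTree` and its `levels` parameter (the number of split levels still allowed) are hypothetical. Each level can at most double the number of leaves, so `levels` levels give at most 2^levels leaves.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical illustration, not mlpack's code: grow a 1-D tree in the
// extremely randomized style, splitting each node at a threshold drawn
// uniformly between the node's min and max value. `levels` is the number
// of split levels still allowed; returns the number of leaves produced.
std::size_t GrowRandomTree(const std::vector<double>& values,
                           std::size_t levels, std::mt19937& rng)
{
  if (levels == 0 || values.size() < 2)
    return 1;  // depth budget exhausted or nothing left to separate

  const auto [lo, hi] = std::minmax_element(values.begin(), values.end());
  if (*lo == *hi)
    return 1;  // all values equal: no split possible

  // Completely random threshold in [min, max).
  std::uniform_real_distribution<double> dist(*lo, *hi);
  const double threshold = dist(rng);

  std::vector<double> left, right;
  for (double v : values)
    (v < threshold ? left : right).push_back(v);

  return GrowRandomTree(left, levels - 1, rng) +
         GrowRandomTree(right, levels - 1, rng);
}
```

For example, with 1000 distinct values, `GrowRandomTree(values, 4, rng)` can never return more than 16, while `levels = 10` allows up to 1024 leaves, which is the intuition behind preferring much deeper random trees.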

Also, if you want to save disk space, try saving the model as .bin instead of .xml (the XML representation is very large). But, to some extent, an extra trees model is simply large.
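As a rough illustration of why a text format is so much bigger, this hypothetical sketch (not mlpack's serializer; `ToBinary` and `ToXml` are made-up names) encodes the same doubles as raw bytes and as XML-style elements. The binary form costs exactly `sizeof(double)` bytes per value, while each XML element spends extra bytes on tags and on printing the number as text.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <iomanip>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical illustration: raw binary encoding, sizeof(double) bytes each.
std::string ToBinary(const std::vector<double>& xs)
{
  std::string out(xs.size() * sizeof(double), '\0');
  if (!xs.empty())
    std::memcpy(out.data(), xs.data(), out.size());
  return out;
}

// Hypothetical illustration: XML-style text encoding with one element per
// value, printed with enough digits to round-trip each double.
std::string ToXml(const std::vector<double>& xs)
{
  std::ostringstream os;
  os << std::setprecision(std::numeric_limits<double>::max_digits10);
  os << "<values count=\"" << xs.size() << "\">\n";
  for (double x : xs)
    os << "  <item>" << x << "</item>\n";
  os << "</values>\n";
  return os.str();
}
```

Even the tags alone (`  <item>` and `</item>\n`, 16 bytes) exceed the 8 bytes a double takes in the binary form, before the digits are counted.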

As for your confusion matrix, I am guessing you are seeing NaNs because row_sums[1] is 0: every entry of that row is then 0, and 0/0 evaluates to NaN.
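A guarded version of the normalization avoids the NaNs by skipping rows whose sum is zero. The sketch below uses a plain `std::vector` matrix rather than Armadillo (an assumption made to keep it self-contained; `Matrix` and `NormalizeRows` are hypothetical names). In the original function, the same guard amounts to dividing a row only when the corresponding entry of `row_sums` is greater than 0.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Divide each row of a confusion matrix by its sum. A row with sum 0
// (every count is 0) is left as zeros instead of becoming 0/0 = NaN.
void NormalizeRows(Matrix& m)
{
  for (auto& row : m)
  {
    double sum = 0.0;
    for (double v : row)
      sum += v;
    if (sum == 0.0)
      continue;  // nothing to normalize; avoids dividing by zero
    for (double& v : row)
      v /= sum;
  }
}
```

Leaving an empty row at zero is one design choice; reporting it explicitly as "no samples" would be another, since a zero row otherwise looks like a legitimate result.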

Hope that's helpful; let me know if I can clarify anything.

GLmontanari commented on June 10, 2024

Thank you for your reply! Yes, I benchmarked max_depth values from 0 to 10, and I get reasonable results as soon as max_depth is 5 or more. I also tried the RandomForest algorithm and see the same behavior.
What I don't understand is why it also works with max_depth = 0, which, from reading the code, is the default value; it works if I set it explicitly, too.
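mlpack documents a maximum depth of 0 as "no limit", so the default does not restrict depth at all; trees then keep splitting until other criteria, such as the minimum leaf size, stop them. That would explain why 0 behaves like a very large depth rather than the smallest one. The sketch below is hypothetical (the names and the root-at-depth-1 convention are assumptions, not mlpack's code) and shows that convention on a tree that simply halves its points at every split.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: may a node at `depth` (root = depth 1) still split
// under the limit `maxDepth`? A limit of 0 means "no depth limit".
bool DepthAllowsSplit(std::size_t depth, std::size_t maxDepth)
{
  return maxDepth == 0 || depth < maxDepth;
}

// Count the leaves of a tree that halves its points at every split and
// stops when the depth limit is reached or a child would hold fewer than
// `minLeafSize` points. Assumes minLeafSize >= 1.
std::size_t CountLeaves(std::size_t points, std::size_t depth,
                        std::size_t maxDepth, std::size_t minLeafSize)
{
  assert(minLeafSize >= 1);
  if (!DepthAllowsSplit(depth, maxDepth) || points < 2 * minLeafSize)
    return 1;
  const std::size_t left = points / 2;
  return CountLeaves(left, depth + 1, maxDepth, minLeafSize) +
         CountLeaves(points - left, depth + 1, maxDepth, minLeafSize);
}
```

With 1000 points and `minLeafSize = 1`, `CountLeaves(1000, 1, 0, 1)` splits all the way down to 1000 single-point leaves, whereas `CountLeaves(1000, 1, 4, 1)` stops at 8 leaves.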

GLmontanari commented on June 10, 2024

I have another question: in most cases, the number of trees returned by NumTrees() is not the number of trees I set when training the classifier, like this:
clf.Train(traindata, trainlabel, numClasses, num_trees, mleaf, mgain, max_depth, false);

Is it because max_depth != 0?
