CI (License MIT 1.0)

Disclaimer: This is just a fun experiment I conducted for my own curiosity and entertainment. It's not intended to be useful for anything else.

treebomination

Treebomination is a way to convert a sklearn.tree.DecisionTreeRegressor into a (roughly) equivalent tf.keras.Model.

When is this helpful?

  • You irrationally dislike decision trees (e.g., for their stepwise behavior) and feel neural networks (with their smoothness) are much cooler. 🤪
  • You want to prove a point about neural networks. 👨‍🏫
  • You think that converting the tree to a NN and then fine-tuning it might decrease the value of your loss metric on your test set. 🪄
  • You have a well-working decision tree but want to only use TensorFlow or frugally-deep in production. 💾
  • You want to back up the claims of your marketing department about your team using "AI". 👨‍💼

When is it not useful?

  • You care about the performance of your predictions. 🐌
  • You care about the precision of your results. 🏹
  • You care about the size of your final application.

So, it is highly recommended to not actually use this for anything serious.

If you seriously consider replacing trees with NNs, I recommend having a look at the following projects instead:

Usage

from sklearn.tree import DecisionTreeRegressor
from treebomination import treebominate

my_decision_tree_regressor = DecisionTreeRegressor()
# ... training ...
model = treebominate(my_decision_tree_regressor)

Origin story

From some unbridled thoughts:

  • A decision tree is a fancy way of having nested if statements.
  • A simple logistic regression on a one-dimensional input acts like a fuzzy threshold (or an if statement).
  • A neuron in an artificial neural network can act as a single logistic regression node.
  • A sigmoid (activation function) with a steeper slope ("edginess") acts like a less-fuzzy threshold:

edginess = 1:

smooth_sigmoid

edginess = 10:

steep_sigmoid
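The effect of the slope can also be checked numerically. The following is a minimal sketch in plain Python, not code from treebomination; the function name `edgy_sigmoid` and its parameters are made up for illustration.

```python
import math


def edgy_sigmoid(x, threshold=0.0, edginess=1.0):
    """Logistic function centered at `threshold`; larger `edginess` gives a steeper step."""
    z = edginess * (x - threshold)
    # Numerically stable evaluation: never call math.exp with a large positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


# Evaluate the fuzzy threshold at a few points around 0 for increasing edginess.
for edginess in (1, 10, 100):
    values = [edgy_sigmoid(x, edginess=edginess) for x in (-0.5, -0.1, 0.1, 0.5)]
    print(edginess, [round(v, 3) for v in values])
```

With a higher edginess, the outputs on both sides of the threshold move closer to 0 and 1, which is the "less-fuzzy if statement" behavior described above.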

So the following idea arose: There should be a morphism from binary decision trees to neural networks. It should™ be possible to emulate every decision tree with a neural network, i.e., derive the network architecture from the tree and initialize the weights and biases such that the output of the network is similar to the output of the tree.

Structure of the generated neural networks

There might be much more intelligent ways to "encode" a decision tree as a neural network, but treebomination uses the following approach.

Each decision node from the tree is simulated by two neurons (each one represented as a dense layer with a singleton shape), one per branch. The threshold-ish behavior results from the neuron having a very high input weight (steep and "sudden" sigmoid) and the bias chosen such that the "middle" of the sigmoid falls onto the (scaled) threshold value. The output values (booleans, encoded as fuzzy 0 and 1) signal whether this path of the tree is taken. For subsequent neurons, this incoming signal is multiplied onto their output value, such that not-taken paths are silenced for further output. The final "leaf" neurons use linear activation, have a bias of 0, and have their initial weight set to the intended output value. Since only one of these final neurons gets an input signal, their outputs can be combined by summing them up.

Initializing the neural network this way makes it output (almost) the exact same predictions as the tree does.
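The scheme can be illustrated without TensorFlow. The sketch below is not treebomination's implementation; it only mimics the arithmetic described above in plain Python. The tuple-based tree encoding, the toy thresholds and leaf values, and the function names are assumptions made for this example, and input scaling is left out.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


# Hypothetical toy tree: a leaf is a float, a decision node is
# (feature_index, threshold, left_subtree, right_subtree),
# where "left" means x[feature_index] <= threshold.
TREE = (0, 7.5,
        (1, 1000.0, 100.0, 200.0),
        (1, 2000.0, 300.0, 400.0))


def hard_predict(node, x):
    """Ordinary decision-tree prediction: nested if statements."""
    while isinstance(node, tuple):
        feature, threshold, left, right = node
        node = left if x[feature] <= threshold else right
    return node


def soft_predict(node, x, edginess=100.0, signal=1.0):
    """Emulate the tree with steep sigmoids, multiplicative gating and summed leaves."""
    if not isinstance(node, tuple):
        # "Leaf neuron": linear activation, bias 0, weight equal to the leaf value.
        return signal * node
    feature, threshold, left, right = node
    # Two "neurons" per decision node, one per branch.
    go_left = sigmoid(edginess * (threshold - x[feature]))
    go_right = sigmoid(edginess * (x[feature] - threshold))
    # The incoming path signal is multiplied onto each branch, silencing not-taken paths.
    return (soft_predict(left, x, edginess, signal * go_left)
            + soft_predict(right, x, edginess, signal * go_right))


sample = [7.0, 1500.0]
print(hard_predict(TREE, sample), soft_predict(TREE, sample))
```

With a high edginess, the soft prediction matches the hard one unless an input lies very close to a threshold. With a low edginess, neighboring leaves blend into each other, which is the smoothness a later fine-tuning step could exploit.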

So a simple tree like this

simple_tree

results in the following NN:

simple_nn

A DecisionTreeRegressor with a higher max_depth (3 in the case below) like the following:

|--- feature_3 <= 7.50
|   |--- feature_3 <= 6.50
|   |   |--- feature_15 <= 1131.50
|   |   |   |--- value: [115593.60]
|   |   |--- feature_15 >  1131.50
|   |   |   |--- value: [149818.43]
|   |--- feature_3 >  6.50
|   |   |--- feature_15 <= 2093.50
|   |   |   |--- value: [197758.96]
|   |   |--- feature_15 >  2093.50
|   |   |   |--- value: [284680.23]
|--- feature_3 >  7.50
|   |--- feature_3 <= 8.50
|   |   |--- feature_15 <= 1928.00
|   |   |   |--- value: [250284.08]
|   |   |--- feature_15 >  1928.00
|   |   |   |--- value: [314964.80]
|   |--- feature_3 >  8.50
|   |   |--- feature_31 <= 517.50
|   |   |   |--- value: [372716.17]
|   |   |--- feature_31 >  517.50
|   |   |   |--- value: [745000.00]

results in a ridiculously complex abomination of a neural-network architecture.

model

In reality, trees are often much deeper than that, which results not only in a very large (and slow) model but also in less precise predictions.

But hey, at least in this toy example (trained on the numerical features from the Kaggle competition "House Prices - Advanced Regression Techniques", see tests), the R² score of the NN (0.766) is slightly higher than that of the tree (0.765). With a quick re-training on the same data, it even improves a bit more (to 0.770). 🎉
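The scores quoted above are coefficients of determination. For reference, here is a minimal plain-Python sketch of how such a score is computed; the function name `r2` and the toy numbers are made up for illustration and are unrelated to the house-price data.

```python
def r2(y_true, y_pred):
    """Coefficient of determination: 1 - (residual sum of squares / total sum of squares)."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Toy targets and predictions, purely illustrative.
targets = [1.0, 2.0, 3.0, 4.0]
predictions = [1.1, 1.9, 3.2, 3.8]
print(r2(targets, predictions))
```

A score of 1 means perfect predictions, and a score of 0 means the model is no better than always predicting the mean of the targets.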

Contributors

dobiasd
