decisiontreehugger

🙈 what if we 😳 learned how ML algorithms actually work, instead of just scikit-learning to instantiate, fit, and predict 😘

Examples use a preprocessed dataset of the Titanic passenger manifest. Recreate the process or roll your own for other data by following the Decision Tree approach outlined in this Kaggle companion notebook.

usage

Start by establishing training and test data:

>>> import pandas
>>>
>>> train = pandas.read_csv('./example_data/titanic_train_preprocessed.csv')
>>> test  = pandas.read_csv('./example_data/titanic_test_preprocessed.csv')

>>> train
           Age     Fare  Embarked_C  Embarked_Q  Embarked_S  Cabin_A  Cabin_B  ...  Master  Miss  Mr  Mrs  Officer  Royalty  Survived
0    22.000000   7.2500           0           0           1        0        0  ...       0     0   1    0        0        0       0.0
1    38.000000  71.2833           1           0           0        0        0  ...       0     0   0    1        0        0       1.0
2    26.000000   7.9250           0           0           1        0        0  ...       0     1   0    0        0        0       1.0
3    35.000000  53.1000           0           0           1        0        0  ...       0     0   0    1        0        0       1.0
4    35.000000   8.0500           0           0           1        0        0  ...       0     0   1    0        0        0       0.0
..         ...      ...         ...         ...         ...      ...      ...  ...     ...   ...  ..  ...      ...      ...       ...
886  27.000000  13.0000           0           0           1        0        0  ...       0     0   0    0        1        0       0.0
887  19.000000  30.0000           0           0           1        0        1  ...       0     1   0    0        0        0       1.0
888  29.881138  23.4500           0           0           1        0        0  ...       0     1   0    0        0        0       0.0
889  26.000000  30.0000           1           0           0        0        0  ...       0     0   1    0        0        0       1.0
890  32.000000   7.7500           0           1           0        0        0  ...       0     0   1    0        0        0       0.0

[891 rows x 29 columns]


>>> test
           Age      Fare  Embarked_C  Embarked_Q  Embarked_S  Cabin_A  Cabin_B  ...  Pclass_3  Master  Miss  Mr  Mrs  Officer  Royalty
0    34.500000    7.8292           0           1           0        0        0  ...         1       0     0   1    0        0        0
1    47.000000    7.0000           0           0           1        0        0  ...         1       0     0   0    1        0        0
2    62.000000    9.6875           0           1           0        0        0  ...         0       0     0   1    0        0        0
3    27.000000    8.6625           0           0           1        0        0  ...         1       0     0   1    0        0        0
4    22.000000   12.2875           0           0           1        0        0  ...         1       0     0   0    1        0        0
..         ...       ...         ...         ...         ...      ...      ...  ...       ...     ...   ...  ..  ...      ...      ...
413  29.881138    8.0500           0           0           1        0        0  ...         1       0     0   1    0        0        0
414  39.000000  108.9000           1           0           0        0        0  ...         0       0     0   0    0        0        1
415  38.500000    7.2500           0           0           1        0        0  ...         1       0     0   1    0        0        0
416  29.881138    8.0500           0           0           1        0        0  ...         1       0     0   1    0        0        0
417  29.881138   22.3583           1           0           0        0        0  ...         1       1     0   0    0        0        0

[418 rows x 28 columns]

To get a baseline sense of expected behaviour, let's look at the overall survival rate for our training set, treating the boolean Survived column as the target:

>>> passengers = train.Survived
>>> survived   = sum(p for p in passengers if p == 1.0)  # count of survivors
>>>
>>> survived / len(passengers)
0.3838383838383838
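
Equivalently, since the mean of a 0/1 column is its positive rate, pandas can produce the same baseline directly:

>>> train.Survived.mean()
0.3838383838383838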

proto-model

The ProtoTree is a rudimentary proto-model for the DecisionTree. We aren't creating and evaluating a decision tree here, but rather verifying the foundational integrity of our process: fit against a target column, then emit per-row class probabilities.

>>> from models.base import ProtoTree
>>>
>>> dt = ProtoTree()
>>> dt.fit(data=train, target='Survived')
>>>
>>> # we expect the aforementioned survival rate, consistent across all rows:
>>> predictions = dt.predict(test)
>>> predictions[:3]
array([[0.61616162, 0.38383838],
       [0.61616162, 0.38383838],
       [0.61616162, 0.38383838]])

We receive the reassuring but useless projection of a 0.38383838 survival rate for our test data, indicating that the training data probabilities have been processed correctly.
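
For reference, here is a minimal sketch of what such a proto-model might look like (hypothetical; the real models.base.ProtoTree may differ in detail). It memorises the training target's class distribution and repeats it for every row:

import numpy


class ProtoTree:
    """Baseline model: predicts the training class priors for every row."""

    def fit(self, data, target):
        # proportion of each class in the target column, ordered by class label
        self.priors = data[target].value_counts(normalize=True).sort_index().to_numpy()

    def predict(self, data):
        # one identical row of class probabilities per input row
        return numpy.tile(self.priors, (len(data), 1))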

decision tree

Some context on the design parametrisation of the decision tree:

properties:

  • contains a root node

  • each node may have a left and right branch

  • bottom-layer nodes do not have branches

considerations:

  • prioritise the most 'efficient' conditions at the top of the tree

  • branches can be recursively implemented decision trees rather than separately articulated (see the sketch after this list)

  • how do we create the tree to have optimal splits?

    • ideal split: an ideal split in binary classification produces homogeneous branches

    • impurity: perfect homogeneity is unrealistic - how can we decrease impurity in a child node w.r.t. its parent node?

      • Gini impurity

      • cross-entropy / information gain (logarithmic calculation)
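
As referenced above, a sketch of that recursive structure (hypothetical names and layout, assuming axis-aligned threshold splits; not the repo's actual implementation):

class Node:
    """One decision: route a row left or right on a feature threshold."""

    def __init__(self, feature=None, threshold=None, prediction=None):
        self.feature = feature        # column to split on (None for a leaf)
        self.threshold = threshold    # rows with value <= threshold go left
        self.prediction = prediction  # class probabilities, set on leaves only
        self.left = None              # left branch: itself a (sub)tree
        self.right = None             # right branch: itself a (sub)tree

    def predict_row(self, row):
        # bottom-layer nodes have no branches: return the stored distribution
        if self.left is None and self.right is None:
            return self.prediction
        branch = self.left if row[self.feature] <= self.threshold else self.right
        return branch.predict_row(row)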

How splitting behaviour is decided by calculation methods for impurity/misclassification rate:

Gini impurity:

Randomly pick a data point: its target attribute has k possible cases, each with a number of occurrences in the dataset. Now classify that point at random, choosing each case in proportion to how often it occurs. Gini impurity is the sum of the probabilities of all incorrect classification events.

For example, assume we have received the following dataset of Yes and No votes for some proposition. Since we're working with a binary classification problem, we'll use an example attribute/feature with k=2 cases.

case         count  proportion
Yes              6         60%
No               4         40%
total (k=2)     10        100%

Here, proportion also represents probability; i.e. we have a 40% chance of randomly selecting a No case.

Given the case distribution in our training data, we would e.g. randomly select Yes 60% of the time, and we would also randomly classify it as Yes 60% of the time. Therefore, the event probability of classifying Yes as Yes is 0.6 * 0.6 = 0.36. The full series of classification event probabilities would be as follows:

randomly selected  selection probability  randomly classified as  classification event probability  correctness
Yes                0.6                    Yes                     0.6 * 0.6 = 0.36                  ✅
Yes                0.6                    No                      0.6 * 0.4 = 0.24                  ❌
No                 0.4                    Yes                     0.4 * 0.6 = 0.24                  ❌
No                 0.4                    No                      0.4 * 0.4 = 0.16                  ✅

The Gini impurity would be the sum of all incorrect classification event probabilities:

0.24 + 0.24 = 0.48

Canonically, it is expressed by the following equation:

G = Σ_{i=1..k} p(i) * (1 - p(i))

where k is the number of cases in the target attribute, and p(i) is the probability of case i existing at that node.

So, let's test whether our event table and the Gini impurity equation yield the same result:

G = p(1)*(1 - p(1)) + p(2)*(1 - p(2))
  = (0.6)*(1 - 0.6) + (0.4)*(1 - 0.4)  = 0.48
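
The same check as a throwaway code sketch (a standalone helper, not part of the repo's models):

>>> def gini(counts):
...     """Gini impurity from raw case counts."""
...     total = sum(counts)
...     return sum((count / total) * (1 - count / total) for count in counts)
...
>>> gini([6, 4])     # the Yes/No vote example
0.48
>>> gini([10, 0])    # a perfectly homogeneous node
0.0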

information gain:

Cross-entropy of a node:

E = -Σ_{i=1..k} p(i) * log2(p(i))

Information gain of a split is the parent's entropy minus the row-weighted entropy of its children:

IG = E(parent) - Σ_{children} (r / R) * E(child)

where k is the number of cases in the target attribute, r is the row count in the node, and R is the row count for the entire dataset.
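
And a matching sketch (again hypothetical standalone helpers, with each node given as a list of per-case counts):

>>> import math
>>>
>>> def entropy(counts):
...     """Cross-entropy of a node from raw case counts."""
...     total = sum(counts)
...     return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)
...
>>> def information_gain(parent, children):
...     """Parent entropy minus child entropies weighted by row share."""
...     R = sum(sum(child) for child in children)
...     return entropy(parent) - sum(sum(child) / R * entropy(child) for child in children)
...
>>> # a perfect split of the ten votes recovers all of the parent's entropy:
>>> round(information_gain([6, 4], [[6, 0], [0, 4]]), 3)
0.971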

references
