sklearn2java's Introduction

SkLearn2Java

This project takes ML models that scikit-learn has exported as text and makes them usable in Java.

Support

  • The tree.DecisionTreeClassifier is supported
    • Supports predict()
    • Supports predict_proba() when export_text() is configured with show_weights=True
  • The ensemble.RandomForestClassifier is supported
    • Supports predict()
    • Supports predict_proba() when export_text() is configured with show_weights=True
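
To illustrate why show_weights=True matters for predict_proba(): with that flag, each leaf line in the export carries per-class sample weights, and those can be normalized into class probabilities. The following is a minimal sketch of that idea, not the library's implementation. It assumes a leaf line of the form `weights: [0.00, 49.00, 5.00] class: 1`; the class `LeafWeights` and its methods are hypothetical names used only for illustration.

```java
import java.util.Arrays;

final class LeafWeights {
    // Extracts the bracketed weights from a leaf line such as
    // "|   |--- weights: [0.00, 49.00, 5.00] class: 1" (assumed format).
    static double[] parseWeights(String leafLine) {
        int open = leafLine.indexOf('[');
        int close = leafLine.indexOf(']', open);
        if (open < 0 || close < 0) {
            throw new IllegalArgumentException("no weights in: " + leafLine);
        }
        String[] parts = leafLine.substring(open + 1, close).split(",");
        double[] weights = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            weights[i] = Double.parseDouble(parts[i].trim());
        }
        return weights;
    }

    // Divides each weight by the total so the values sum to 1,
    // giving one probability per class.
    static double[] toProbabilities(double[] weights) {
        double total = Arrays.stream(weights).sum();
        if (total <= 0) {
            throw new IllegalArgumentException("weights must sum to a positive value");
        }
        return Arrays.stream(weights).map(w -> w / total).toArray();
    }
}
```

Without show_weights=True a leaf line only names the winning class, so there is nothing to normalize, which is consistent with predict_proba() requiring the flag.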

Installing

Importing Maven Dependency

<dependency>
  <groupId>rocks.vilaverde</groupId>
  <artifactId>scikit-learn-2-java</artifactId>
  <version>1.1.0</version>
</dependency>

DecisionTreeClassifier

As an example, a DecisionTreeClassifier model is trained on the Iris dataset and exported using export_text() from sklearn.tree, as shown below:

>>> import sys
>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.tree import export_text
>>> iris = load_iris()
>>> X = iris['data']
>>> y = iris['target']
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
>>> decision_tree = decision_tree.fit(X, y)
>>> r = export_text(decision_tree, feature_names=iris['feature_names'], show_weights=True, max_depth=sys.maxsize)
>>> print(r)

|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2

The exported text can then be executed in Java. When calling export_text, it is recommended to set max_depth to sys.maxsize so that the tree isn't truncated; by default export_text stops at a depth of 10.

Java Example

In this example the iris model exported using export_text is parsed, features are created as a Java Map and the decision tree is asked to predict the class.

    Reader tree = getTrainedModel("iris.model");
    final Classifier<Integer> decisionTree = DecisionTreeClassifier.parse(tree,
                PredictionFactory.INTEGER);

    Features features = Features.of("sepal length (cm)",
                "sepal width (cm)",
                "petal length (cm)",
                "petal width (cm)");
    FeatureVector fv = features.newSample();
    fv.add(0, 3.0).add(1, 5.0).add(2, 4.0).add(3, 2.0);
    
    Integer prediction = decisionTree.predict(fv);
    System.out.println(prediction.toString());
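
The example above calls getTrainedModel, which this README does not define. A minimal sketch of such a helper, assuming the export_text() output was saved to a file on disk, could look like the following. The class name `ModelFiles` and the file-based approach are assumptions made for illustration, not part of the library.

```java
import java.io.IOException;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;

final class ModelFiles {
    // Hypothetical helper: opens the text file holding the export_text()
    // output and returns it as a buffered, UTF-8 decoded Reader.
    static Reader getTrainedModel(String path) throws IOException {
        return Files.newBufferedReader(Paths.get(path), StandardCharsets.UTF_8);
    }
}
```

The caller owns the returned Reader and should close it after parsing, for example with try-with-resources.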

RandomForestClassifier

To use a RandomForestClassifier that has been trained on the Iris dataset, each of the estimators
in the classifier needs to be exported using export_text from sklearn.tree, as shown below:

>>> from sklearn import datasets
>>> from sklearn import tree
>>> from sklearn.ensemble import RandomForestClassifier
>>> 
>>> import os
>>> import sys
>>> 
>>> iris = datasets.load_iris()
>>> X = iris.data
>>> y = iris.target
>>> 
>>> clf = RandomForestClassifier(n_estimators = 50, n_jobs=8)
>>> model = clf.fit(X, y)
>>> 
>>> os.makedirs('/tmp/estimators', exist_ok=True)  # make sure the output directory exists
>>> for i, t in enumerate(clf.estimators_):
...     with open(os.path.join('/tmp/estimators', "iris-" + str(i) + ".txt"), "w") as file1:
...         text_representation = tree.export_text(t, feature_names=iris.feature_names, show_weights=True, decimals=4, max_depth=sys.maxsize)
...         file1.write(text_representation)

Once all the estimators are exported into /tmp/estimators, you can create a TAR archive, for example:

cd /tmp/estimators
tar -czvf /tmp/iris.tgz .

Then you can use the RandomForestClassifier class to parse the TAR archive.

    import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
    ...
    
    TarArchiveInputStream tree = getArchive("iris.tgz");
    final Classifier<Double> decisionTree = RandomForestClassifier.parse(tree,
                PredictionFactory.DOUBLE);
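
The getArchive helper used above is not defined in this README. Because the archive was created with `tar -czvf`, it is gzip-compressed, and TarArchiveInputStream does not decompress gzip on its own. The sketch below shows one way to open the decompressed byte stream using only the JDK. The class `ArchiveFiles` and method `openGzip` are hypothetical names, and whether the library expects an already-decompressed stream is an assumption.

```java
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.zip.GZIPInputStream;

final class ArchiveFiles {
    // Hypothetical helper: opens a gzip-compressed file (such as a .tgz made
    // with `tar -czvf`) and returns the decompressed byte stream.
    static InputStream openGzip(String path) throws IOException {
        InputStream raw = new BufferedInputStream(Files.newInputStream(Paths.get(path)));
        try {
            return new GZIPInputStream(raw);
        } catch (IOException e) {
            raw.close(); // not a valid gzip stream; release the file handle
            throw e;
        }
    }
}
```

With Apache Commons Compress on the classpath, the result can be wrapped as `new TarArchiveInputStream(ArchiveFiles.openGzip("/tmp/iris.tgz"))` before being passed to RandomForestClassifier.parse.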

Testing

Testing was done using models exported from scikit-learn version 1.1.3; models exported from newer versions of scikit-learn should also work.

sklearn2java's People

Contributors

dependabot[bot], dvilaverde

Forkers

ericwittmann

sklearn2java's Issues

Add support for multiple samples when calling predict() and predict_proba()

Is your feature request related to a problem? Please describe.
It would be nice to support predictions on multiple samples. This would be especially useful when there are many samples, since using multiple threads could speed up execution.

Describe the solution you'd like
An API similar to scikit-learn's predict() and predict_proba(), where an array of feature vectors is provided and an array of results is returned. There is no requirement to use Java arrays; a List would work as well.
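
As a rough illustration of the requested behaviour, the sketch below applies a single-sample prediction function to a List of samples using a parallel stream. It is not an existing library API: `BatchPredict` and `predictAll` are hypothetical names, and the approach assumes the underlying predict call is safe to invoke from multiple threads, which the project does not state.

```java
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

final class BatchPredict {
    // Applies a single-sample prediction function to every sample, possibly
    // on several threads; results come back in the same order as the input.
    static <S, P> List<P> predictAll(Function<S, P> predictOne, List<S> samples) {
        return samples.parallelStream()
                .map(predictOne)
                .collect(Collectors.toList());
    }
}
```

Under those assumptions, a call might look like `BatchPredict.predictAll(decisionTree::predict, samples)`, where `samples` is a List of feature vectors.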
