Kernel FDA

This repository implements Kernel Fisher Discriminant Analysis (Kernel FDA) as described in https://arxiv.org/abs/1906.09436. FDA, which is equivalent to Linear Discriminant Analysis (LDA), is a classification method that projects vectors onto a lower-dimensional subspace. This subspace is chosen to maximize between-class scatter and minimize within-class scatter, which makes it effective for classification. Kernel FDA extends regular FDA to nonlinear subspaces by using the kernel trick. The MNIST example and the Colab Notebook demonstrate 97% accuracy while training on only one-seventh of the MNIST dataset.
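For reference, "maximize between-class scatter and minimize within-class scatter" corresponds to the textbook Fisher criterion below, for C classes with n_c samples each, class means μ_c and overall mean μ. The notation is a standard one chosen here for illustration and may not match the paper's:

```latex
S_B = \sum_{c=1}^{C} n_c\,(\mu_c - \mu)(\mu_c - \mu)^\top, \qquad
S_W = \sum_{c=1}^{C} \sum_{i:\,y_i = c} (x_i - \mu_c)(x_i - \mu_c)^\top

J(w) = \frac{w^\top S_B\, w}{w^\top S_W\, w}, \qquad
S_B\, w = \lambda\, S_W\, w
```

The Fisher directions are the generalized eigenvectors with the largest eigenvalues λ. Kernel FDA applies the same criterion after mapping each x_i to φ(x_i) and writing each direction as w = Σ_i α_i φ(x_i), so only kernel values k(x_i, x_j) = ⟨φ(x_i), φ(x_j)⟩ are ever needed.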

Because FDA and Kernel FDA classify vectors by comparing their projections in the Fisher subspace to class centroids, adding a new class is just a matter of adding a new centroid. This model is therefore implemented here with the aim of using Kernel FDA as a oneshot learning algorithm.
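The centroid comparison can be sketched in a few lines of NumPy. This is a minimal illustration, not kfda's code: it assumes the vectors have already been projected into the Fisher subspace, assumes Euclidean distance, and the helper names fit_centroids and predict_nearest_centroid are hypothetical.

```python
import numpy as np

def fit_centroids(Z, y):
    """Mean projected vector for each class label."""
    labels = np.unique(y)
    centroids = np.stack([Z[y == c].mean(axis=0) for c in labels])
    return labels, centroids

def predict_nearest_centroid(Z, labels, centroids):
    """Assign each projected vector the label of its closest centroid (Euclidean)."""
    # distances has shape (n_samples, n_classes)
    distances = np.linalg.norm(Z[:, None, :] - centroids[None, :, :], axis=2)
    return labels[np.argmin(distances, axis=1)]

# Toy 2-D points standing in for Fisher-subspace projections
Z_train = np.array([[0.0, 0.1], [0.2, -0.1], [5.0, 5.1], [4.8, 4.9]])
y_train = np.array([0, 0, 1, 1])
labels, centroids = fit_centroids(Z_train, y_train)
print(predict_nearest_centroid(np.array([[0.1, 0.0], [5.2, 5.0]]), labels, centroids))  # [0 1]
```

Since prediction only reads the centroids array, supporting an extra class amounts to appending one row to it and one entry to labels.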

Installation

kfda is available on PyPI:

pip install kfda

Usage

Kfda follows scikit-learn's estimator interface.

  • Initializing: cls = Kfda(n_components=2, kernel='linear') creates a classifier that uses a linear kernel with 2 components. Use Kfda(n_components=2, kernel='poly', degree=2) for a polynomial kernel of degree 2. See https://scikit-learn.org/stable/modules/metrics.html#polynomial-kernel for a list of kernels and their parameters, or the source code docstrings for a complete description of the parameters.

  • Fitting: cls.fit(X, y)

  • Prediction: cls.predict(X)

  • Scoring: cls.score(X, y)

  • Introducing new classes without retraining (fewshot learning): cls.fit_additional(X, y)
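The kernel names used when initializing point to the scikit-learn pairwise kernels on the linked page. As a hedged sketch of what three of them compute, here are NumPy versions that follow scikit-learn's documented formulas and defaults (degree=3, coef0=1, gamma=1/n_features). kfda's own defaults may differ, and these function names are illustrative rather than kfda's API.

```python
import numpy as np

def linear_kernel(X, Y):
    # k(x, y) = <x, y>
    return X @ Y.T

def polynomial_kernel(X, Y, degree=3, gamma=None, coef0=1.0):
    # k(x, y) = (gamma * <x, y> + coef0) ** degree
    if gamma is None:
        gamma = 1.0 / X.shape[1]  # scikit-learn's default
    return (gamma * (X @ Y.T) + coef0) ** degree

def rbf_kernel(X, Y, gamma=None):
    # k(x, y) = exp(-gamma * ||x - y||^2)
    if gamma is None:
        gamma = 1.0 / X.shape[1]
    # Squared Euclidean distance between every row of X and every row of Y
    sq_dists = (X ** 2).sum(axis=1)[:, None] + (Y ** 2).sum(axis=1)[None, :] - 2 * X @ Y.T
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))  # clip tiny negative round-off

X = np.array([[1.0, 2.0], [3.0, 4.0]])
print(polynomial_kernel(X, X, degree=2).shape)  # (2, 2): one entry per pair of samples
```

Each call on the training set yields an n_samples by n_samples Gram matrix, which is the source of the memory cost described under Caveats.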

Oneshot Learning

Oneshot learning means that an algorithm can learn a new class from as few as one sample. This is possible for Kernel FDA because it finds a subspace that purposefully spreads out distinct classes. Introducing a new label simply involves adding another centroid in this subspace for use in prediction. See the Colab Notebook or the oneshot example for demonstrations.
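A minimal sketch of the oneshot step, assuming a trained model can be reduced to a fixed projection plus a table of centroids. The projection matrix W, the centroid dictionary and the helpers add_class_oneshot and predict are all hypothetical stand-ins, not kfda internals. The point is that the projection is left untouched and the single new sample's projection becomes the new centroid.

```python
import numpy as np

# Hypothetical fixed projection onto a 2-D "Fisher subspace" (stands in for a trained model)
W = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])

def project(X):
    return X @ W

# Centroids for the classes seen during training, keyed by label
centroids = {"a": np.array([0.0, 0.0]), "b": np.array([5.0, 0.0])}

def add_class_oneshot(label, x_new):
    """Register a new class from one sample: its projection becomes the centroid.
    W is not retrained."""
    centroids[label] = project(x_new[None, :])[0]

def predict(x):
    """Label of the centroid nearest to the projection of x."""
    z = project(x[None, :])[0]
    return min(centroids, key=lambda c: np.linalg.norm(z - centroids[c]))

add_class_oneshot("c", np.array([0.0, 5.0, 2.0]))
print(predict(np.array([0.3, 4.6, -1.0])))  # c
```

This only works well if the learned subspace already separates unseen classes, which is the property the paragraph above relies on.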

Examples

See the examples folder for scripts covering MNIST, faces, and oneshot learning.

After running them, you can load corresponding pairs of the generated *embeddings.tsv and *labels.tsv files into TensorFlow's Embedding Projector to visualize the embeddings. For example, running mnist.py and then loading mnist_test_embeddings.tsv and mnist_test_labels.tsv shows the following with the UMAP visualizer, where each color is a different digit:

MNIST Kernel FDA embeddings

The clear separation between classes illustrates the effectiveness of the classifier.
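The Embedding Projector takes vectors as a tab-separated file with one row per point and, for a single metadata column, a labels file with one value per line and no header. The function below is a hypothetical sketch of writing such a pair with NumPy. It is not taken from the kfda examples; the file naming only mirrors mnist_test_embeddings.tsv and mnist_test_labels.tsv.

```python
import numpy as np

def write_projector_tsv(embeddings, labels, prefix):
    """Write <prefix>_embeddings.tsv (one tab-separated vector per row) and
    <prefix>_labels.tsv (one label per row, no header, as required for a
    single-column metadata file)."""
    emb_path = f"{prefix}_embeddings.tsv"
    lab_path = f"{prefix}_labels.tsv"
    np.savetxt(emb_path, np.asarray(embeddings), delimiter="\t")
    with open(lab_path, "w") as f:
        for label in labels:
            f.write(f"{label}\n")
    return emb_path, lab_path
```

For instance, write_projector_tsv(Z_test, y_test, "mnist_test") would produce a pair ready to load, where Z_test stands for whatever projected test vectors your code holds.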

Notebook

Another place to see example usage is the Colab Notebook. The notebook walks through training, evaluation, and oneshot usage.

Caveats

As with kernel SVMs, the most significant constraint of Kernel FDA is memory usage during training. Training a Kernel FDA classifier requires creating matrices of size n_samples by n_samples, so the memory requirement grows as O(n_samples^2).
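To put O(n_samples^2) in concrete terms, the helper below (a hypothetical name, not part of kfda) computes the size of one dense n_samples by n_samples matrix, assuming 64-bit floats. Training may hold more than one such matrix at a time, so treat the result as a lower bound.

```python
def kernel_matrix_bytes(n_samples, bytes_per_entry=8):
    """Bytes needed for one dense n_samples x n_samples matrix (float64 by default)."""
    return n_samples * n_samples * bytes_per_entry

for n in (10_000, 60_000):
    print(f"{n} samples: {kernel_matrix_bytes(n) / 1e9:.1f} GB per matrix")
```

At 10,000 training samples this is 0.8 GB per matrix; at 60,000 samples (the full standard MNIST training split) it reaches 28.8 GB.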

The accuracy, while high (0.97 on MNIST), seems to be limited by the training set size. With a training size of 10000 and a testing size of 60000, performance on MNIST averages around 0.97 accuracy using 9 Fisher directions (the maximum for 10 digit classes, since FDA yields at most n_classes - 1 directions) and the RBF kernel:

from kfda import Kfda
cls = Kfda(kernel='rbf', n_components=9)
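The value 9 follows from the rank of the between-class scatter matrix S_B. It is built from C class means whose size-weighted deviations from the overall mean sum to zero, so its rank is at most C - 1, and any direction beyond that has zero between-class variance. The NumPy check below uses made-up synthetic data and the linear S_B; the kernelized between-class matrix is bounded the same way because it is also assembled from C centered class means.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, n_features, per_class = 10, 20, 30

# Synthetic data: each class gets its own random mean offset plus unit Gaussian noise
y = np.repeat(np.arange(n_classes), per_class)
class_offsets = rng.normal(size=(n_classes, n_features))
X = class_offsets[y] + rng.normal(size=(y.size, n_features))

# Between-class scatter: sum over classes of n_c (mu_c - mu)(mu_c - mu)^T
mu = X.mean(axis=0)
S_B = np.zeros((n_features, n_features))
for c in range(n_classes):
    diff = X[y == c].mean(axis=0) - mu
    S_B += per_class * np.outer(diff, diff)

# Count eigenvalues that are not numerically zero
eigvals = np.linalg.eigvalsh(S_B)
rank = int((eigvals > 1e-9 * eigvals.max()).sum())
print(rank)  # 9, i.e. n_classes - 1
```

Asking for more than C - 1 components would therefore add directions that carry no class-separating information.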

Accuracy could likely be improved without increasing the training size by implementing invariant kernels that implicitly handle scale and rotation, removing the need for an extended dataset.
