hierarchicalsoftmax

A Hierarchical Softmax Framework for PyTorch.

Installation

hierarchicalsoftmax can be installed using pip from the git repository:

pip install git+https://github.com/rbturnbull/hierarchicalsoftmax.git

Usage

Build up a hierarchy tree for your categories using SoftmaxNode instances:

from hierarchicalsoftmax import SoftmaxNode

root = SoftmaxNode("root")
a = SoftmaxNode("a", parent=root)
aa = SoftmaxNode("aa", parent=a)
ab = SoftmaxNode("ab", parent=a)
b = SoftmaxNode("b", parent=root)
ba = SoftmaxNode("ba", parent=b)
bb = SoftmaxNode("bb", parent=b)

The SoftmaxNode class inherits from the anytree Node class, which means that you can use methods from that library to build and interact with your hierarchy tree.

The tree can be rendered as a string with the render method:

root.render(print=True)

This results in a text representation of the tree:

root
├── a
│   ├── aa
│   └── ab
└── b
    ├── ba
    └── bb

The tree can also be rendered to a file using graphviz if it is installed:

root.render(filepath="tree.svg")

An example tree rendering.

Next, add a final layer to your network with the right number of outputs for the softmax layers. You can do this manually by setting the layer's number of output features to root.layer_size, or you can use the HierarchicalSoftmaxLinear or HierarchicalSoftmaxLazyLinear classes:

from torch import nn
from hierarchicalsoftmax import HierarchicalSoftmaxLinear

model = nn.Sequential(
    nn.Linear(in_features=20, out_features=100),
    nn.ReLU(),
    HierarchicalSoftmaxLinear(in_features=100, root=root)
)
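To make root.layer_size more concrete, here is a minimal pure-Python sketch. It assumes, without confirmation from this README, that the output layer holds one logit per child of every non-leaf node and that each non-leaf node owns a contiguous slice of that layer. TREE and softmax_slices are hypothetical stand-ins, not part of hierarchicalsoftmax.

```python
# Hypothetical stand-in for the SoftmaxNode tree built above:
# each key maps a non-leaf node name to the names of its children.
TREE = {
    "root": ["a", "b"],
    "a": ["aa", "ab"],
    "b": ["ba", "bb"],
}


def softmax_slices(tree, root="root"):
    """Assign each non-leaf node a contiguous slice of the output layer.

    ASSUMPTION: one logit per child, slices allocated in pre-order.
    Returns (slices, layer_size).
    """
    slices = {}
    offset = 0
    stack = [root]
    while stack:
        node = stack.pop()
        children = tree.get(node, [])
        if children:
            slices[node] = (offset, offset + len(children))
            offset += len(children)
            # Push children in reverse so they are visited left to right
            stack.extend(reversed(children))
    return slices, offset


slices, layer_size = softmax_slices(TREE)
print(slices)      # {'root': (0, 2), 'a': (2, 4), 'b': (4, 6)}
print(layer_size)  # 6
```

Under this assumption the example tree needs 6 outputs: two for the softmax at root, two at a and two at b. Comparing this with root.layer_size on your own tree is a quick way to check the assumption.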

Once you have the hierarchy tree, you can use the HierarchicalSoftmaxLoss module as the loss function:

from hierarchicalsoftmax import HierarchicalSoftmaxLoss

loss = HierarchicalSoftmaxLoss(root=root)
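The README does not spell out how this loss is computed. One common hierarchical softmax formulation, assumed here rather than taken from the library, adds a cross-entropy term for each softmax on the path from the root down to the target node. The sketch uses plain floats instead of tensors, and TREE, SLICES and path_loss are hypothetical names.

```python
import math

# Hypothetical example tree and the slice of the output vector assumed
# to belong to each non-leaf node (one logit per child).
TREE = {"root": ["a", "b"], "a": ["aa", "ab"], "b": ["ba", "bb"]}
SLICES = {"root": (0, 2), "a": (2, 4), "b": (4, 6)}
PARENT = {child: parent for parent, kids in TREE.items() for child in kids}


def log_softmax(xs):
    """Numerically stable log-softmax of a list of floats."""
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


def path_loss(logits, target):
    """ASSUMED formulation: sum of cross-entropy terms at every
    ancestor softmax on the path root -> target."""
    loss = 0.0
    node = target
    while node in PARENT:
        parent = PARENT[node]
        start, end = SLICES[parent]
        log_probs = log_softmax(logits[start:end])
        # Negative log-probability of choosing `node` at its parent
        loss -= log_probs[TREE[parent].index(node)]
        node = parent
    return loss


logits = [2.0, 0.0, 0.0, 1.0, 0.5, 0.5]
print(round(path_loss(logits, "ab"), 4))  # 0.4402
```

A target at the root contributes no terms, and each extra level adds one term. The library may average or weight these terms differently.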

Metric functions are provided to compute the accuracy and the F1 score:

from hierarchicalsoftmax import greedy_accuracy, greedy_f1_score

accuracy = greedy_accuracy(predictions, targets, root=root)
f1 = greedy_f1_score(predictions, targets, root=root)
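As a rough illustration of what these metrics measure, the sketch below assumes the predictions have already been decoded greedily to leaf names and compares them with the target names. The macro averaging used for F1 is an assumption, and accuracy and macro_f1 are hypothetical helpers, not the library's functions.

```python
def accuracy(predicted, targets):
    """Fraction of positions where the predicted node equals the target."""
    correct = sum(p == t for p, t in zip(predicted, targets))
    return correct / len(targets)


def macro_f1(predicted, targets):
    """Unweighted mean of per-label F1 scores (ASSUMED macro averaging)."""
    labels = sorted(set(predicted) | set(targets))
    scores = []
    for label in labels:
        pairs = list(zip(predicted, targets))
        tp = sum(p == label and t == label for p, t in pairs)
        fp = sum(p == label and t != label for p, t in pairs)
        fn = sum(p != label and t == label for p, t in pairs)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


predicted = ["aa", "ab", "ba", "ba"]
targets = ["aa", "ab", "bb", "ba"]
print(accuracy(predicted, targets))  # 0.75
print(macro_f1(predicted, targets))
```

Here the missed "bb" scores an F1 of 0 and the extra "ba" prediction lowers that label's F1 to 2/3, so the macro F1 (about 0.667) falls below the accuracy.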

The nodes predicted from the final layer of the model can be inferred using the greedy_predictions function, which returns a list of the predicted nodes:

from hierarchicalsoftmax import greedy_predictions

outputs = model(inputs)  # inputs: a batch of feature tensors
inferred_nodes = greedy_predictions(outputs, root=root)
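Greedy inference can be pictured as a walk down the tree: at each non-leaf node, pick the child with the highest score in that node's part of the output, and stop at a leaf. The pure-Python sketch below assumes each non-leaf node owns a contiguous slice of the output vector with one logit per child; that layout, and the names TREE, SLICES and greedy_predict, are hypothetical rather than taken from the library.

```python
# Hypothetical example tree and ASSUMED output layout per non-leaf node.
TREE = {"root": ["a", "b"], "a": ["aa", "ab"], "b": ["ba", "bb"]}
SLICES = {"root": (0, 2), "a": (2, 4), "b": (4, 6)}


def greedy_predict(logits, root="root"):
    """Descend from the root, taking the highest-scoring child at each level."""
    node = root
    while node in TREE:  # leaves have no entry in TREE, so the loop stops there
        start, end = SLICES[node]
        scores = logits[start:end]
        # argmax over the logits equals argmax over their softmax
        best = max(range(len(scores)), key=lambda i: scores[i])
        node = TREE[node][best]
    return node


batch = [
    [2.0, 0.0, 0.0, 1.0, 0.5, 0.5],
    [0.0, 3.0, 0.2, 0.1, 0.1, 0.9],
]
print([greedy_predict(row) for row in batch])  # ['ab', 'bb']
```

Because each decision is made locally, greedy decoding is cheap, but it can miss the leaf with the highest overall path probability when an early choice is close.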

Credits

Contributors: rbturnbull