
gtn's Introduction


Deprecated

This repository has been deprecated.

Future development will be in the GTN org repository.

GTN: Automatic Differentiation with WFSTs



What is GTN?

GTN is a framework for automatic differentiation with weighted finite-state transducers. The framework is written in C++ and has bindings to Python.

The goal of GTN is to make adding and experimenting with structure in learning algorithms much simpler. This structure is encoded as weighted automata, either acceptors (WFSAs) or transducers (WFSTs). With gtn you can dynamically construct complex graphs from operations on simpler graphs. Automatic differentiation gives gradients with respect to any input or intermediate graph with a single call to gtn.backward.
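To illustrate what an operation on two graphs computes, here is a minimal pure-Python sketch of intersecting two epsilon-free weighted acceptors with the standard product construction. The dict-based graph format and the `intersect` helper are hypothetical and are not gtn's internal representation or implementation. A product node is a pair of nodes, one from each input. An arc is kept only when both inputs have an arc with the same label, and the weights of the two arcs are added (log semiring).

```python
from collections import deque

# Hypothetical acceptor format (not gtn's): "start"/"accept" are sets of node ids,
# "arcs" is a list of (src, dst, label, weight) tuples. No epsilon arcs are handled.
def intersect(g1, g2):
    index = {}                      # (n1, n2) product pair -> node id in the result
    start, accept, arcs = set(), set(), []
    queue = deque()

    def node_id(pair):
        # Create a result node the first time a product pair is reached.
        if pair not in index:
            index[pair] = len(index)
            queue.append(pair)
            if pair[0] in g1["accept"] and pair[1] in g2["accept"]:
                accept.add(index[pair])
        return index[pair]

    for s1 in g1["start"]:
        for s2 in g2["start"]:
            start.add(node_id((s1, s2)))

    # Breadth-first expansion of reachable product pairs.
    while queue:
        n1, n2 = queue.popleft()
        src = index[(n1, n2)]
        for a1 in g1["arcs"]:
            if a1[0] != n1:
                continue
            for a2 in g2["arcs"]:
                if a2[0] == n2 and a2[2] == a1[2]:  # labels must match
                    dst = node_id((a1[1], a2[1]))
                    arcs.append((src, dst, a1[2], a1[3] + a2[3]))
    return {"start": start, "accept": accept, "arcs": arcs}

# The two graphs from the Quickstart, with every arc weight at 0.0.
g1 = {"start": {0}, "accept": {2},
      "arcs": [(0, 1, 1, 0.0), (0, 1, 2, 0.0), (1, 2, 1, 0.0), (1, 2, 0, 0.0)]}
g2 = {"start": {0}, "accept": {0}, "arcs": [(0, 0, 1, 0.0), (0, 0, 0, 0.0)]}

result = intersect(g1, g2)
print(result["arcs"])  # [(0, 1, 1, 0.0), (1, 2, 1, 0.0), (1, 2, 0, 0.0)]
```

The arc with label 2 is dropped because `g2` has no arc with that label. A full implementation would also need to handle epsilon arcs and transducer input/output labels.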

Also check out the gtn_applications repository, which contains GTN applications to handwriting recognition (HWR), automatic speech recognition (ASR), and more.

Quickstart

First, install the Python bindings.

The following is a minimal example of building two WFSAs with gtn, constructing a simple function on the graphs, and computing gradients.

import gtn

# Make some graphs:
g1 = gtn.Graph()
g1.add_node(True)  # Add a start node
g1.add_node()  # Add an internal node
g1.add_node(False, True)  # Add an accepting node

# Add arcs with (src node, dst node, label):
g1.add_arc(0, 1, 1)
g1.add_arc(0, 1, 2)
g1.add_arc(1, 2, 1)
g1.add_arc(1, 2, 0)

g2 = gtn.Graph()
g2.add_node(True, True)  # A single node that is both start and accepting
g2.add_arc(0, 0, 1)  # Self-loops accepting labels 1 and 0
g2.add_arc(0, 0, 0)

# Compute a function of the graphs:
intersection = gtn.intersect(g1, g2)
score = gtn.forward_score(intersection)

# Visualize the intersected graph:
gtn.draw(intersection, "intersection.pdf")

# Backprop:
gtn.backward(score)

# Print gradients of arc weights 
print(g1.grad().weights_to_list()) # [1.0, 0.0, 0.5, 0.5]
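The gradient values can be checked by hand. The sketch below uses plain Python with no gtn. It enumerates the accepting paths of `g1`, keeps the ones whose labels `g2` accepts (any string over {0, 1}), and computes the forward score as a log-sum-exp over path scores. It then gets each arc's gradient as the total posterior probability of the kept paths that use that arc. The `paths` helper is hypothetical and only works for acyclic graphs. It is not how gtn computes `forward_score` or `backward`.

```python
import math

def paths(arcs, node, accept):
    """Yield every path (a list of arc indices) from node to an accepting node.

    Assumes the graph is acyclic; otherwise this recursion would not terminate.
    """
    if node in accept:
        yield []
    for i, (src, dst, _, _) in enumerate(arcs):
        if src == node:
            for rest in paths(arcs, dst, accept):
                yield [i] + rest

# g1 from the example: (src, dst, label, weight), all weights 0.0.
g1_arcs = [(0, 1, 1, 0.0), (0, 1, 2, 0.0), (1, 2, 1, 0.0), (1, 2, 0, 0.0)]
allowed = {0, 1}  # labels on g2's self-loops

# Paths that survive intersection with g2 (g2's weights are all 0.0).
kept = [p for p in paths(g1_arcs, 0, {2})
        if all(g1_arcs[i][2] in allowed for i in p)]
scores = [sum(g1_arcs[i][3] for i in p) for p in kept]

# Forward score in the log semiring: log-sum-exp over path scores.
m = max(scores)
score = m + math.log(sum(math.exp(s - m) for s in scores))

# d(score)/d(weight_i) = sum of posteriors of kept paths through arc i.
grad = [0.0] * len(g1_arcs)
for p, s in zip(kept, scores):
    for i in p:
        grad[i] += math.exp(s - score)

print(round(score, 4), [round(g, 4) for g in grad])  # 0.6931 [1.0, 0.0, 0.5, 0.5]
```

Two paths survive (labels 1,1 and 1,0). Both use arc 0, so it gets gradient 1.0. Arcs 2 and 3 each lie on one of the two equally weighted paths and get 0.5. Arc 1 (label 2) is removed by the intersection and gets 0.0.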

Installation

Requirements

  • A C++ compiler with good C++14 support (e.g. g++ >= 5)
  • cmake >= 3.5.1, and make

Python

Install the Python bindings with

pip install gtn

Building C++ from source

First, clone the project:

git clone git@github.com:facebookresearch/gtn.git && cd gtn

Create a build directory and run CMake and make:

mkdir -p build && cd build
cmake ..
make -j $(nproc)

Run tests with:

make test

Run make install to install.

Python bindings from source

Setting up your environment:

conda create -n gtn_env
conda activate gtn_env

Required dependencies:

cd bindings/python
conda install setuptools

Use one of the following commands for installation:

python setup.py install

or, to install in editable mode (for dev):

python setup.py develop

Python binding tests can be run with make test, or, from the repository root, with

python -m unittest discover bindings/python/test

Run a simple example (also from the repository root):

python bindings/python/examples/simple_graph.py

License

GTN is licensed under an MIT license. See LICENSE.

gtn's People

Contributors

awni, csukuangfj, facebook-github-bot, jacobkahn, shubho, vineelpratap


gtn's Issues

Full end-to-end example for RNN-CTC-WFST

Do you have any plans to open-source a complete example of an RNN-CTC-WFST recognizer? Currently, most existing examples cover only parts of the pipeline rather than a full recognizer.

Thank you!

set_weights has some numerical instability around zero values

Hi @awni!

I followed an implementation similar to the TransducerLossFunction in gtn_applications. I'm wondering whether you have also experienced instability when using these lines:

cpu_data = transition_params.cpu().contiguous()
transitions.set_weights(cpu_data.data_ptr())

I ran into a problem where a zero value in a tensor could become a very large, very small, or even NaN value after the round trip gpu_tensor --> cpu_tensor --> gtn --> numpy.

I reproduced this behavior here: https://colab.research.google.com/drive/1AuZJlukSEwadrH6cokrqEjABBWuHJ8wQ?usp=sharing

For example, two of the zero values in my GPU tensor changed while the rest did not. This does not happen if the tensor stays on the CPU.

Do you have any suggestions on how to prevent this?

Copying @siddalmia @sw005320 as well.

GTN name

So what does the name mean? My only guess is FSM with each letter advanced by one (F to G, S to T, M to N)?

LogAdd and -Ofast

if (a == kNegInf) {

Just a warning for the future (you guys probably know this already, but just in case): comparisons with infinities can fail if you ever compile with -Ofast, because it enables -ffast-math, which lets the compiler assume that infinities and NaNs never occur. I seem to remember situations where -Ofast is implied, possibly at high optimization levels like -O3, but I'm not sure.

We encountered this in Kaldi, and it's very hard to work around because the compiler is very clever at detecting implicit checks for infinity and optimizing them out. I'll close the issue after creating it, since it's more of an FYI than an issue.
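To show what a guard like `if (a == kNegInf)` protects against, here is a Python sketch of a numerically stable log-add with explicit negative-infinity checks. It is an illustration of the general technique, not gtn's C++ code. Python is not affected by -Ofast. The point is to show the result you get when the check is missing, which is what a compiler that removes it would effectively produce. The helper names `logadd` and `logadd_unguarded` are hypothetical.

```python
import math

NEG_INF = -math.inf

def logadd(a: float, b: float) -> float:
    """Stable log(exp(a) + exp(b)) with explicit -inf guards."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

def logadd_unguarded(a: float, b: float) -> float:
    """Same formula without the guards, as if the comparisons were optimized out."""
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))

print(logadd(NEG_INF, NEG_INF))            # -inf
print(logadd_unguarded(NEG_INF, NEG_INF))  # nan, because -inf - (-inf) is nan
```

Without the guard, adding two log-zero scores, which is common when a node is unreachable, yields NaN instead of -inf, and the NaN then propagates through the rest of the forward pass.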

Differentiable beam search with GTN-WFST

Hi,

I noticed that in 2019, the authors published a paper on a differentiable beam search decoder. With GTN, is it possible to implement an RNN-CTC-WFST beam search speech recognizer or handwriting recognizer using PyTorch?

Previously (both in Kaldi and in the 2019 paper), extra C++, shell, or Python code was needed to implement such an end-to-end recognizer. It would be great if everything could be done with PyTorch. If that's possible, is there any chance that a full end-to-end demo or toy example could be provided in the future?

Thanks!

Including GTN graphs as trainable parameters in a model

Hi!

I would like to keep a gtn.Graph object in my model that is updated with each optimizer.step(). I may be wrong, but I doubt that the PyTorch optimizer would treat the weights in this graph as parameters and update them with that call.

For example, the code in gtn_applications runs a model to get emission probabilities and computes the CTC loss from gtn.intersect(g_emissions, g_criterion). This works because g_emissions is created anew each time from the tensor of emissions.

Now, if I want to insert a new_graph so that my CTC loss comes from gtn.intersect(gtn.intersect(g_emissions, new_graph), g_criterion), is there a way to keep the parameters of new_graph in a gtn.Graph and have them updated by PyTorch's optimizer.step()? Or does new_graph need to be recreated each time from tensors with a set_weights call?
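The second option the question describes (keep the weights outside the graph and rebuild the graph weights from them on every step) can be sketched in plain Python, without gtn or PyTorch. In this hypothetical setup, a list `params` stands in for a `torch.nn.Parameter`. Each step rebuilds path scores from it and computes a normalized loss, the log-score of all paths minus the log-score of paths matching a target label sequence. The loss gradients (arc posteriors) are used for a plain gradient-descent update that stands in for `optimizer.step()`. The graph, the target, the learning rate, and every helper name are illustrative assumptions, not gtn_applications code.

```python
import math

# Hypothetical acyclic acceptor (same shape as g1 in the Quickstart): (src, dst, label).
ARCS = [(0, 1, 1), (0, 1, 2), (1, 2, 1), (1, 2, 0)]
START, ACCEPT = 0, {2}
TARGET = [1, 0]  # hypothetical target label sequence

def paths(node):
    """All arc-index paths from node to an accepting node (graph is acyclic)."""
    if node in ACCEPT:
        yield []
    for i, (src, dst, _) in enumerate(ARCS):
        if src == node:
            for rest in paths(dst):
                yield [i] + rest

def is_target(path):
    return [ARCS[i][2] for i in path] == TARGET

def score_and_grad(weights, keep):
    """Log-sum-exp of path scores over paths passing `keep`, and its gradient."""
    ps = [p for p in paths(START) if keep(p)]
    scores = [sum(weights[i] for i in p) for p in ps]
    m = max(scores)
    total = m + math.log(sum(math.exp(s - m) for s in scores))
    grad = [0.0] * len(weights)
    for p, s in zip(ps, scores):
        for i in p:
            grad[i] += math.exp(s - total)  # posterior of path p, credited to arc i
    return total, grad

params = [0.0] * len(ARCS)  # stands in for a trainable parameter tensor
losses = []
for step in range(50):
    # "Rebuild the graph" from params each step, then score it.
    num, g_num = score_and_grad(params, is_target)
    den, g_den = score_and_grad(params, lambda p: True)
    losses.append(den - num)  # negative log-probability of the target paths
    # Plain gradient descent in place of optimizer.step().
    params = [w - 0.5 * (gd - gn) for w, gd, gn in zip(params, g_den, g_num)]

print(round(losses[0], 4), losses[-1] < losses[0])  # 1.3863 True
```

Keeping the parameters as the single source of truth and copying them into the graph on every forward pass mirrors how g_emissions is rebuilt from the emissions tensor each step.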
