contrib's Introduction

torch-contrib

This repository contains reviewed implementations of ideas from recent machine learning papers.

Installation

pip:

pip install torchcontrib

From source:

python setup.py install
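
Most of the issues below revolve around the SWA optimizer wrapper, so a quick smoke test after installing might look like the following. This is only a minimal sketch assuming manual averaging mode and the update_swa / swap_swa_sgd / bn_update calls that appear in the issues; the toy model, data, and warm-up schedule are placeholders.

import torch
from torch import nn, optim
from torchcontrib.optim import SWA

# Toy model with a batch-norm layer so bn_update has statistics to recompute.
model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
loader = [(torch.randn(8, 10), torch.randn(8, 2)) for _ in range(5)]  # placeholder data
loss_fn = nn.MSELoss()

base_opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
opt = SWA(base_opt)                     # manual mode: we decide when to average

for epoch in range(20):
    for input, target in loader:
        opt.zero_grad()
        loss_fn(model(input), target).backward()
        opt.step()
    if epoch >= 10:                     # placeholder warm-up period
        opt.update_swa()                # fold current weights into the running average

opt.swap_swa_sgd()                      # put the averaged weights into the model
opt.bn_update(loader, model)            # recompute batch-norm statistics for them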

contrib's People

Contributors

andrewgordonwilson, apaszke, balandat, crowsonkb, izmailovpavel, mdraw, soumith, ssnl

contrib's Issues

SWA test broken

test_swa fails with the following:

test/test_swa.py::TestSWA::test_swa
  /Users/balandat/Code/contrib/torchcontrib/optim/swa.py:200: UserWarning: SWA wasn't applied to param tensor([[-1.3171e-01,  3.3898e-01,  3.9688e-01,  9.4115e-01, -1.0891e-01],
          [-6.1716e-01, -9.9478e-02, -3.5716e-01,  6.9384e-01, -2.1573e-01],
          [-9.2800e-01,  2.5287e-01, -9.5098e-01, -3.5055e-01,  6.5104e-01],
          [ 6.6169e-01,  7.8536e-01,  6.6425e-01, -1.6037e+00, -6.1215e-01],
          [ 5.0540e-01, -8.8359e-01,  9.3367e-01, -3.3240e-01,  3.8069e-01],
          [ 1.0775e-01,  3.0351e-01,  1.3170e+00, -5.3968e-01,  1.8907e-01],
          [-4.6874e-02,  9.1908e-01, -1.7464e+00, -7.5991e-01, -4.2205e-01],
          [ 1.0197e+00,  1.0610e-01,  8.9256e-01,  5.9761e-01, -1.9891e-03],
          [ 3.0417e+00, -8.8183e-01, -4.5786e-01, -6.0704e-01,  3.2223e-01],
          [-1.6168e+00,  9.2239e-01, -2.0307e-01, -1.2542e-01,  1.5070e+00]],
         requires_grad=True); skipping it
    warnings.warn(

test/test_swa.py::TestSWA::test_swa
  /Users/balandat/Code/contrib/torchcontrib/optim/swa.py:200: UserWarning: SWA wasn't applied to param tensor([ 0.5962, -0.0670,  0.4809, -1.0233, -1.6968,  0.0400,  0.7110, -0.4441,
          -0.3767,  0.0970], requires_grad=True); skipping it
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/warnings.html
================================================================================================================== short test summary info ===================================================================================================================
FAILED test/test_swa.py::TestSWA::test_swa - AssertionError: tensor(0.1980, grad_fn=<MaxBackward1>) not less than or equal to 1e-05 :

Wanted feature signaling / voting

An idea:

When someone releases a code package using PyTorch (to accompany a paper or whatever), they could put a file in the repository root that highlights the features the package implements and the missing PyTorch features that the authors think would be useful in PyTorch core / contrib.

Then a simple script could crawl all of GitHub under the PyTorch tag and merge these signaling files (potentially weighted by each repo's star / fork count). The PyTorch team would then have easy analytics on the features users miss (a natural GitHub-native UserVoice alternative) and could prioritize them if they decide to. PyTorch could potentially curate a list of feature names so that no hand merging is needed.
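
A rough sketch of such a crawler follows. Everything here is an assumption for illustration: the signaling file name pytorch-features.json, its format (a JSON list of feature names), and the weighting by star count; the GitHub search and raw-content endpoints are the standard public ones.

import collections
import requests

# Hypothetical convention: each repo keeps a "pytorch-features.json" in its root,
# containing a JSON list of wished-for PyTorch feature names.
SIGNAL_FILE = "pytorch-features.json"

def crawl(pages=5):
    votes = collections.Counter()
    for page in range(1, pages + 1):
        # Search public repos tagged with the "pytorch" topic.
        resp = requests.get(
            "https://api.github.com/search/repositories",
            params={"q": "topic:pytorch", "sort": "stars", "page": page},
        )
        for repo in resp.json().get("items", []):
            raw_url = ("https://raw.githubusercontent.com/"
                       f"{repo['full_name']}/{repo['default_branch']}/{SIGNAL_FILE}")
            r = requests.get(raw_url)
            if r.status_code != 200:
                continue
            for feature in r.json():
                # Weight each wish by the repo's star count.
                votes[feature] += repo["stargazers_count"]
    return votes.most_common()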

PyTorch could also analyze which NumPy / SciPy functions popular PyTorch-based packages use.

@soumith @apaszke

swa, type mismatch

I use swa.py in this way. Is this the proper way?

optimizer0 = optim.SGD(model.parameters(), 1e-4, momentum=0.9, weight_decay=weight_decay)
optimizer = SWA(optimizer0)

for epoch in range(num_epoch):
    train(train_loader, model, criterion, optimizer, epoch)
    ###########
    optimizer.swap_swa_sgd()
    avg_loss, avg_acc = validate(val_loader, model, criterion)
    optimizer.swap_swa_sgd()
    ###########
    if epoch == 0:
        optimizer.bn_update(train_loader, model, device='cuda')
    if get_learning_rate(optimizer) < 5e-7 or is_lowest_loss:
        if epoch < 8:
            torch.save(state, './model/checkpoint_%s' % file_name + '_%s.pth.tar' % epoch)
        else:
            optimizer.bn_update(train_loader, model, device='cuda')
            optimizer.update_swa()
            optimizer.swap_swa_sgd()
            torch.save(state, './model/checkpoint_%s' % file_name + '_%s.pth.tar' % epoch)
            optimizer.swap_swa_sgd()

And when calling bn_update(), I get:

   swa.py", line 302
    model(input)
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 141, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ts/code05_bin_shuff/snetv27.py", line 118, in forward
    x = self.conv1(x)
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/anaconda3/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 320, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same

I have not used PyTorch for long, and this is the first time I have seen torch.cuda.DoubleTensor.
Calling model(input.type(torch.cuda.FloatTensor)) fixes the error, but how does it happen?
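
One common way this happens (an assumption about this particular setup, not something visible in the traceback): the dataset yields float64 NumPy arrays, and torch.from_numpy preserves that dtype, so the forward passes that bn_update runs over train_loader feed DoubleTensor inputs to Float weights. A small illustration:

import numpy as np
import torch

x = np.random.rand(4, 3, 32, 32)   # NumPy defaults to float64
t = torch.from_numpy(x)            # dtype is preserved -> DoubleTensor
print(t.dtype)                     # torch.float64

t = torch.from_numpy(x).float()    # cast once here instead of in every forward call
print(t.dtype)                     # torch.float32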

PS: There are lots of warnings during model.eval(), such as

"SWA wasn't applied to param {}; skipping it".format(p)

Should I fix this, or not?

entropy of a tensor

Can we have a feature enhancement to compute the entropy of a tensor, in the same way we can compute the mean, std, etc.?
In more detail:
It would be super useful to have a function that computes the entropy of a tensor. Ideally that function could compute the entropy in different ways depending on what is needed, meaning we can specify a dimension and compute the entropy along it, e.g. computing the entropy of a tensor channel-wise.
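
A sketch of what such a function might look like, assuming the input is non-negative and should be normalized into a probability distribution along the chosen dimension (the name, signature, and semantics are just one possible choice):

import torch

def entropy(t, dim=-1, eps=1e-12):
    # Normalize along `dim` so the values form a probability distribution,
    # then compute Shannon entropy H = -sum(p * log(p)) along that dimension.
    p = t / t.sum(dim=dim, keepdim=True).clamp_min(eps)
    return -(p * (p + eps).log()).sum(dim=dim)

# Example: channel-wise entropy of an (N, C, H, W) activation tensor.
x = torch.rand(2, 8, 16, 16)
print(entropy(x.flatten(2), dim=-1).shape)  # torch.Size([2, 8])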

This request was first made here, and it was suggested that it would be more useful to open it in this repository.

swa with CyclicLR will get error

If I use SWA with CyclicLR, I get the following error:

 File "/usr/local/lib/python3.5/dist-packages/torch/optim/lr_scheduler.py", line 586, in __init__
    if 'momentum' not in optimizer.defaults:
AttributeError: 'SWA' object has no attribute 'defaults'
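
For reference, a minimal way this can be reproduced (the model and hyperparameters are placeholders): as the traceback shows, CyclicLR inspects optimizer.defaults to handle momentum cycling, and the SWA wrapper does not expose a defaults attribute, so constructing the scheduler on the wrapper fails.

import torch
from torch import nn, optim
from torchcontrib.optim import SWA

model = nn.Linear(10, 2)
base_opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
opt = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.005)

# CyclicLR checks `'momentum' not in optimizer.defaults` when cycle_momentum=True
# (the default), and the SWA wrapper has no `defaults` attribute, so this raises
# the AttributeError reported above.
scheduler = optim.lr_scheduler.CyclicLR(opt, base_lr=0.001, max_lr=0.01)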

How can we contribute?

Thank you for creating the repository, I think it will be very helpful to many of us. I got a few questions:

  1. What will the process to contribute to this repository look like?
  2. Will there be a list of ideas that should be implemented?

Best,
Max

The position of bn_update?

If the model has batch normalization layers, where should I call bn_update()?

for _ in range(100):
    opt.zero_grad()
    loss_fn(model(input), target).backward()
    opt.step()
opt.swap_swa_sgd()
opt.bn_update(train_loader, model)

Is this ordering correct?

Or is the following right?

for _ in range(100):
    opt.zero_grad()
    loss_fn(model(input), target).backward()
    opt.step()
opt.bn_update(train_loader, model)
opt.swap_swa_sgd()
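
For what it is worth: bn_update() recomputes the batch-norm running statistics by doing forward passes over train_loader with whatever weights the model currently holds, so it presumably has to run after swap_swa_sgd() has placed the averaged weights into the model, i.e. the first ordering; in the second ordering the statistics would be computed for the pre-swap SGD weights.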

Contribution Ideas/Wishlist

Most core functions are already in PyTorch. Here are some which may be better suited for this repository.

Shake-Shake Regularization

This regularizer greatly reduces error rates on CIFAR-10 and CIFAR-100. Here is a PyTorch implementation.

Non-Correlating Multiplicative Noise

NCMN is like Shake-Shake regularization but requires far fewer epochs. They have a PyTorch implementation.

Random Erasure Data Augmentation

This data augmentation technique appears to require less tuning than Cutout and is used in some computer vision papers.
An implementation is here.

GELU

This paper introduced the GELU and Swish (under a different name). The GELU is used in some NLP work such as BERT and in NLP libraries.

import torch
from torch.nn import Module

class GELU(Module):
    def __init__(self, fast=True):
        super(GELU, self).__init__()
        self.fast = fast

    def forward(self, x):
        if self.fast:
            # Sigmoid approximation: x * sigmoid(1.702 * x)
            return x * torch.sigmoid(1.702 * x)
        else:
            # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            return 0.5 * x * (1 + torch.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))

Swish

This work reinvented x * sigmoid(x) over a year later and greatly popularized the nonlinearity. Later versions add a learnable parameter due to novelty concerns, but its value is less clear. Hence the implementation might look like

class Swish(Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

Convolutional Block Attention Module

This iterates on the Squeeze-and-Excitation module by adding a spatial component.
A PyTorch implementation is here. According to the authors, the CBAM usually works better than the BAM.

To One-Hot

There is no simple, built-in function to convert a list of n labels into an n x num_classes one-hot matrix.
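
A sketch of what such a helper might look like (scatter-based; the name and argument order are just one possible choice):

import torch

def to_one_hot(labels, num_classes):
    # labels: 1-D LongTensor of class indices, shape (n,)
    # returns: (n, num_classes) float matrix with a 1 in each label's column
    one_hot = torch.zeros(labels.size(0), num_classes, device=labels.device)
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)

print(to_one_hot(torch.tensor([0, 2, 1]), num_classes=3))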

SSIM

SSIM is a differentiable loss function used in some computer vision work. A PyTorch implementation is here.

Since this includes several disparate ideas, feel free to close.

[Bug] Unit test fails on multi-GPU setup

Running test_swa.py on a device with multiple GPUs results in the following:

Test output:
> test_swa (test.test_swa.TestSWA) ... ERROR
>
> ======================================================================
> ERROR: test_swa (test.test_swa.TestSWA)
> ----------------------------------------------------------------------
> Traceback (most recent call last):
>   File "/data/users/balandat/fbsource/fbcode/buck-out/opt/gen/pytorch/contrib/test_torchcontrib#binary,link-tree/test/test_swa.py", line 313, in test_swa
>     lambda weight, bias: constructor([weight, bias]))
>   File "/data/users/balandat/fbsource/fbcode/buck-out/opt/gen/pytorch/contrib/test_torchcontrib#binary,link-tree/test/test_swa.py", line 238, in _test_basic_cases
>     constructor
>   File "/data/users/balandat/fbsource/fbcode/buck-out/opt/gen/pytorch/contrib/test_torchcontrib#binary,link-tree/test/test_swa.py", line 131, in _test_basic_cases_template
>     optimizer.step(fn)
>   File "/data/users/balandat/fbsource/fbcode/buck-out/opt/gen/pytorch/contrib/test_torchcontrib#binary,link-tree/torchcontrib/optim/swa.py", line 206, in step
>     loss = self.optimizer.step(closure)
>   File "/data/users/balandat/fbsource/fbcode/buck-out/opt/gen/pytorch/contrib/test_torchcontrib#binary,link-tree/torch/optim/lbfgs.py", line 427, in step
>     self._add_grad(t, d)
>   File "/data/users/balandat/fbsource/fbcode/buck-out/opt/gen/pytorch/contrib/test_torchcontrib#binary,link-tree/torch/optim/lbfgs.py", line 264, in _add_grad
>     p.data.add_(step_size, update[offset:offset + numel].view_as(p.data))
> RuntimeError: expected device cuda:1 and dtype Double but got device cuda:0 and dtype Double

Memory problem

In distributed training, the memory usage of the first GPU is twice that of the others. But before SWA is applied, memory usage is the same across GPUs.

SWA wasn't applied to param {}; skipping it".format(p))

Hello, I get this warning and have no clue what I am doing wrong. Here is my code:

for _ in range(self.K_epochs):

    # Evaluating old actions and values:
    logprobs, values, dist_entropy = self.policy.evaluate(old_states, old_actions)
    # Finding the ratio (pi_theta / pi_theta__old):
    advantages = self.calculate_advantages(reward_batch, values.detach())
    ratios = torch.exp(logprobs - old_logprobs)

    # Finding Surrogate Loss:
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
    actor_loss = -torch.min(surr1, surr2)
    critic_loss = 0.5*self.MseLoss(values, returns)
    loss = actor_loss + critic_loss + 0.5*self.MseLoss(values, returns) - 0.01*dist_entropy

    # take gradient step
    self.optimizer.zero_grad()
    loss.mean().backward()
    self.SWAoptim.step()
self.SWAoptim.swap_swa_sgd()
# Copy new weights into old policy:
self.policy_old.load_state_dict(self.policy.state_dict())
self.loSS.append(loss.mean().item())

I get the following warning:

/home/murtaza/.local/lib/python2.7/site-packages/torchcontrib/optim/swa.py:191: UserWarning: SWA wasn't applied to param Parameter containing:
tensor([[-0.0556,  0.1067,  0.0519, -0.1137,  0.0632, -0.0402,  0.0576, -0.0704,
         -0.0888, -0.1129, -0.0102,  0.0503,  0.0469, -0.0822, -0.1028, -0.0354,
         -0.0007,  0.0863, -0.0221, -0.1036,  0.0431,  0.0164,  0.0004, -0.1106,
          0.0466, -0.0283, -0.0954, -0.1001, -0.0113,  0.0089,  0.0471, -0.0335,
          0.0501,  0.0773,  0.1195, -0.0987,  0.0455, -0.0468, -0.0520, -0.1011,
         -0.0373, -0.0642,  0.0105,  0.0455,  0.0452, -0.0569, -0.0551, -0.1137,
         -0.0057,  0.0203,  0.0088,  0.0077,  0.0917, -0.1203,  0.0266,  0.0904,
         -0.0180,  0.0097, -0.0717, -0.0547, -0.0954,  0.1197,  0.0836,  0.0938]],
       requires_grad=True); skipping it
  "SWA wasn't applied to param {}; skipping it".format(p))
