
murenn's Introduction

MuReNN: Multiresolution Neural Networks

@inproceedings{lostanlen2023fitting,
  title={Fitting Auditory Filterbanks with Multiresolution Neural Networks},
  author={Lostanlen, Vincent and Haider, Daniel and Han, Han and Lagrange, Mathieu and Balazs, Peter and Ehler, Martin},
  booktitle={Proceedings of the IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA)},
  pages={1--5},
  year={2023},
  organization={IEEE}
}

murenn's People

Contributors

danedane-haider, lostanlen, xir4n

Forkers

banalasaritha

murenn's Issues

`plot_receptive_field`: all questions answered

Some questions came up after the review of #49 (which closed #48):

  • should the theoretical bound be T * (2**J) or half of that? (I think it's half)
  • should the magnitude of the gradient be squared? (maybe not ...)
  • should we aim for a ~ 1/T power law? or ~ 1/sqrt(T)? or a constant? This has implications for the way we define the Conv1D operator in between DTCWT Direct and Inverse

Let's discuss and then make the appropriate edits to the script.
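To make the second question concrete, one hypothetical way to summarize a gradient profile is its RMS width (the square root of its second central moment). The function name rms_width is an assumption, not part of the plotting script. For a Gaussian-shaped profile, squaring the gradient shrinks the measured width by a factor of sqrt(2), so the choice rescales the width rather than changing its shape.

```python
import numpy as np

def rms_width(grad, squared=True):
    """RMS width, in samples, of a gradient profile (hypothetical receptive-field measure)."""
    weights = np.abs(grad) ** 2 if squared else np.abs(grad)
    weights = weights / weights.sum()  # normalize the profile to sum to one
    n = np.arange(len(grad))
    center = np.sum(n * weights)  # centroid of the profile
    return np.sqrt(np.sum((n - center) ** 2 * weights))

# Gaussian-shaped gradient profile with a standard deviation of 20 samples
n = np.arange(-200, 201)
grad = np.exp(-n**2 / (2 * 20.0**2))
print(rms_width(grad, squared=False))  # approximately 20
print(rms_width(grad, squared=True))   # approximately 20 / sqrt(2) = 14.14
```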

Input data shape for `DTCWT`

Currently, DTCWT requires the input to be a tensor of shape (B, C, 2**J), where J is the number of wavelet scales, which is inconvenient for real data. The NumPy dtcwt package is less strict: instead of a length of 2**J, it only requires the input length to be even. To accommodate this, I modified our implementation so that it accepts such inputs and gives the same results.
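For comparison, a caller-side workaround is to pad the signal up to the next multiple of 2**J before the transform. This is a minimal NumPy sketch of that alternative, not the modification described above; pad_to_multiple is a hypothetical helper, and splitting the padding between both ends is an assumption.

```python
import numpy as np

def pad_to_multiple(x, J, mode="constant"):
    """Pad the last axis of x so that its length is a multiple of 2**J (hypothetical helper)."""
    N = x.shape[-1]
    target = -(-N // 2**J) * 2**J  # round N up to the next multiple of 2**J
    extra = target - N
    left = extra // 2  # split the padding between both ends
    pad_width = [(0, 0)] * (x.ndim - 1) + [(left, extra - left)]
    return np.pad(x, pad_width, mode=mode)

x = np.random.randn(2, 1, 44100)  # (B, C, N) with N not a multiple of 2**8
y = pad_to_multiple(x, J=8)
print(y.shape)  # (2, 1, 44288)
```

Passing mode="reflect" instead of zero-padding may reduce boundary artifacts, at the cost of a less predictable impulse response near the edges.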

depend on `dtcwt` rather than `pytorch_wavelets`?

Since #3, we use dtcwt in unit tests to check for approximate equality of DTCWT's between NumPy and PyTorch.

How about we take advantage of this to drop our dependency on pytorch_wavelets and depend on dtcwt instead?

Energy issue in `MuReNNDirect`

  1. the output signal's energy should be at approximately the same level across scales
  2. a large number of negative values appear after down_j

see code below

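import matplotlib.pyplot as plt
import torch
from murenn import MuReNNDirect  # assumed import path

J = 8
Q = 1
T = 10
murenn_layer = MuReNNDirect(
    J=J,
    Q=Q,
    T=T,
    in_channels=1,
)
# X: input batch of shape (B, 1, N), defined elsewhere
# output shape (B, C, Q, J, time)
Y = murenn_layer(X).detach()

# permute to (C, Q, J, time, B)
Y = Y.permute((1, 2, 3, 4, 0))

# reshape to (C*Q*J, time*B): one row per channel, filter, and scale
C = Y.shape[0]
Y = Y.reshape((C*Q*J, -1))
# plt.boxplot draws one box per column, so transpose to get one box per scale
plt.boxplot(Y.T.numpy())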

result: [image not available: box plot of the output values for each scale]
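For the first point, one option is to estimate a gain per scale on a reference input (for example white noise) and store it as a fixed normalization. This is a minimal NumPy sketch, assuming the outputs of each scale are available as separate arrays; equalize_band_energy is a hypothetical name, and it does not address the negative values after down_j.

```python
import numpy as np

def equalize_band_energy(bands, eps=1e-12):
    """Rescale each band so that its mean energy per coefficient is 1.

    bands: list of arrays, one per scale; lengths may differ after decimation.
    Returns the rescaled bands and the per-band gains.
    """
    gains = [1.0 / np.sqrt(np.mean(np.abs(b) ** 2) + eps) for b in bands]
    return [g * b for g, b in zip(gains, bands)], gains

# Example: three scales with very different energies, as measured on a reference input
rng = np.random.default_rng(0)
bands = [s * rng.standard_normal(n) for s, n in [(0.1, 512), (3.0, 256), (50.0, 128)]]
scaled, gains = equalize_band_energy(bands)
print([round(float(np.mean(b ** 2)), 6) for b in scaled])  # [1.0, 1.0, 1.0]
```

Once estimated, the gains would be frozen, so that the normalization stays a fixed linear scaling and does not depend on each input.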

Padding modes

Problem related to the DTCWTForward output length
Description of the problem:
When decomposing tensors with DTCWTForward, the input length is not always an integer multiple of the output length at each level. Refer to the code snippet below:

import murenn
import torch

T = 2**10
x = torch.zeros((1, 1, T), dtype=torch.float32)
x[0, 0, T//2] = 1  # unit impulse at the center
tfm = murenn.DTCWTForward(J=7, alternate_gh=True, include_scale=False)
x_phis, x_psis = tfm(x)

# number of band-pass coefficients at each level
print([x_psi.shape[-1] for x_psi in x_psis])
# [1024, 1026, 515, 259, 131, 67, 35]

Solution:
Adjusted the padding parameter: still zero-padding, but with a different amount of padding applied to the input so that the output length is "correct."
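The general rule behind such a fix: a "valid" convolution with a K-tap filter shortens the signal by K - 1 samples, so the padding must total K - 1 samples for the output to keep the input length, after which decimation by 2 halves it exactly. This generic NumPy sketch is not the murenn code; same_pad and the 5-tap filter are hypothetical.

```python
import numpy as np

def same_pad(x, K):
    """Zero-pad x so that a 'valid' convolution with a K-tap filter returns len(x) samples."""
    left = (K - 1) // 2
    right = K - 1 - left  # total padding K - 1; the extra sample goes right if K is even
    return np.pad(x, (left, right))

x = np.zeros(1024)
x[512] = 1.0
h = np.array([1.0, 4.0, 6.0, 4.0, 1.0]) / 16  # hypothetical 5-tap low-pass filter
y = np.convolve(same_pad(x, len(h)), h, mode="valid")
y_down = y[::2]  # decimate by 2
print(len(y), len(y_down))  # 1024 512
```

With an odd K, the impulse response stays centered on the input impulse (here y[512] equals the middle tap 6/16).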

`MuReNNDirect`

with parameters J, Q, T, padding_mode,
and any other parameters you deem necessary from either DTCWT or Conv1D

good luck! 🍀

Test gradient of Direct composed with Forward

Example

import murenn
import torch

N = 2**10
x = torch.zeros(1, 1, N, requires_grad=True)
J = 8

dtcwt = murenn.DTCWT(J=J)
idtcwt = murenn.DTCWTInverse(J=J)

x_phi, x_psis = dtcwt(x)
y = idtcwt((x_phi, x_psis))
# gradient of one output sample with respect to the whole input
y0 = y[0, 0, N//2]
y0.backward()

Check that x.grad is a Dirac ☺️
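Why a Dirac is expected: if the inverse composed with the direct transform is the identity, its Jacobian is the identity matrix, so the gradient of y[n0] with respect to x is the unit impulse at n0. This NumPy sketch uses a one-level orthonormal Haar transform as a stand-in for the DTCWT; haar_analysis and haar_synthesis are illustrative names.

```python
import numpy as np

def haar_analysis(x):
    # one level of the orthonormal Haar transform (stand-in for the DTCWT)
    lo = (x[0::2] + x[1::2]) / np.sqrt(2)
    hi = (x[0::2] - x[1::2]) / np.sqrt(2)
    return lo, hi

def haar_synthesis(lo, hi):
    y = np.empty(2 * len(lo))
    y[0::2] = (lo + hi) / np.sqrt(2)
    y[1::2] = (lo - hi) / np.sqrt(2)
    return y

N = 16
# The composition is linear: its Jacobian has the responses to unit impulses as columns
jacobian = np.stack([haar_synthesis(*haar_analysis(e)) for e in np.eye(N)], axis=1)
n0 = N // 2
grad = jacobian[n0]  # gradient of y[n0] with respect to x
print(np.nonzero(np.abs(grad) > 1e-12)[0])  # [8]
```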

Args for DTCWTInverse

Currently, DTCWTDirect returns two outputs, while DTCWTInverse.forward() takes a single argument, which is a tuple, for example:

J = 8
dtcwt = murenn.DTCWT(J=J)
idtcwt = murenn.DTCWTInverse(J=J)
x_phi, x_psis = dtcwt(x)
y = idtcwt((x_phi, x_psis))

It would be more convenient for DTCWTInverse to take two arguments:

J = 8
dtcwt = murenn.DTCWT(J=J)
idtcwt = murenn.DTCWTInverse(J=J)
x_phi, x_psis = dtcwt(x)
y = idtcwt(x_phi, x_psis)
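Since nn.Module.__call__ forwards all positional arguments to forward(), one hypothetical way to support both call styles during a transition is an optional second argument. InverseSketch is a plain-Python stand-in, not the murenn class, and _reconstruct is a placeholder for the actual inverse transform.

```python
class InverseSketch:
    """Hypothetical stand-in for DTCWTInverse accepting both
    idtcwt(x_phi, x_psis) and the older idtcwt((x_phi, x_psis))."""

    def forward(self, x_phi, x_psis=None):
        if x_psis is None:
            # legacy call style: a single (x_phi, x_psis) tuple
            x_phi, x_psis = x_phi
        return self._reconstruct(x_phi, x_psis)

    def _reconstruct(self, x_phi, x_psis):
        # placeholder for the actual inverse transform
        return (x_phi, list(x_psis))

    # nn.Module dispatches __call__ to forward; mimic that here
    __call__ = forward

idtcwt = InverseSketch()
print(idtcwt(1, [2, 3]) == idtcwt((1, [2, 3])))  # True
```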

`DTCWTInverse`

Now that we have a compliant DTCWTForward, it would be good to write the inverse operator in PyTorch. This is necessary for:

  • analyzing the effective receptive field of a MuReNN layer
  • analyzing various other quantities like phase distortion
  • making generative 1D models

We should try to follow the semantics of pytorch_wavelets to the extent that this is possible

`MuReNNDirect.to_conv1d()`

I suggest we have a method in MuReNNDirect that would compute the single-resolution equivalent impulse response of the MuReNN layer. This would be helpful for visualization in the Fourier domain, for receptive fields, and for comparing computational costs.

The to_conv1d method should return a complex-valued PyTorch tensor of shape (Q*J, (2**J)*T). It would involve IDTCWT. This can be done efficiently by introducing an object @property which would be computed on demand.

   DTCWT        conv1d        IDTCWT
δ -------> ψ_j --------> w_j -------> y_j

I volunteer to do this.
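The "computed on demand" part can be sketched with functools.cached_property: the layer is probed once with a unit impulse (the δ in the diagram) and the result is stored. LinearLayerSketch and its single-output op are hypothetical; a real to_conv1d would return one complex row per (q, j).

```python
from functools import cached_property

import numpy as np

class LinearLayerSketch:
    """Hypothetical linear layer whose equivalent filter is computed on first access."""

    def __init__(self, op, length):
        self.op = op  # any linear map from arrays of shape (length,) to (length,)
        self.length = length

    @cached_property
    def impulse_response(self):
        # Probe the layer with a centered unit impulse
        delta = np.zeros(self.length)
        delta[self.length // 2] = 1.0
        return self.op(delta)

layer = LinearLayerSketch(lambda v: np.convolve(v, [1.0, 2.0, 1.0], mode="same"), 9)
print(layer.impulse_response)  # [0. 0. 0. 1. 2. 1. 0. 0. 0.]
```

One caveat: cached_property never recomputes, so in a trainable module the cached response goes stale after a weight update; a plain @property recomputes on every access at a higher cost.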

`MuReNNDirect`: Conv1D should be before abs2

Compare with the WASPAA paper implementation:

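    def forward(self, x):
        # reshape to (B, 1, N); assumes a single input channel
        x = x.reshape(x.shape[0], 1, x.shape[-1])
        # keep only the band-pass coefficients of the DTCWT
        _, x_levels = self.tfm.forward(x)
        Ux = []

        for j_psi in range(1+self.J_psi):
            # complex coefficients at level j_psi, rescaled by 2**j_psi
            x_level = x_levels[j_psi].type(torch.complex64) / (2**j_psi)
            # Conv1D first: the same real-valued filters on real and imaginary parts
            Wx_real = self.psis[j_psi](x_level.real)
            Wx_imag = self.psis[j_psi](x_level.imag)
            # then abs2: squared modulus of the filtered complex signal (already real)
            Ux_j = Wx_real * Wx_real + Wx_imag * Wx_imag
            # crop every level to the length of the first one
            if j_psi == 0:
                N_j = Ux_j.shape[-1]
            else:
                Ux_j = Ux_j[:, :, :N_j]
            Ux.append(Ux_j)

        # stack the levels along the channel axis
        Ux = torch.cat(Ux, axis=1)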

https://github.com/lostanlen/lostanlen2023waspaa/blob/main/student.py
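Because the Conv1D kernels are real-valued, filtering the real and imaginary parts separately and then summing squares is exactly |w * z|^2 on the complex coefficients, i.e. Conv1D before abs2. Swapping the order gives a different result in general. A NumPy check on random data; the signal length and kernel size are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.standard_normal(64) + 1j * rng.standard_normal(64)  # complex wavelet coefficients
w = rng.standard_normal(5)  # real-valued Conv1D kernel

# As in the snippet: filter real and imaginary parts separately, then sum squares
wr = np.convolve(z.real, w, mode="same")
wi = np.convolve(z.imag, w, mode="same")
conv_then_abs2 = wr**2 + wi**2

# The same quantity written as |w * z|^2 on the complex signal
direct = np.abs(np.convolve(z, w, mode="same")) ** 2

# Reversed order: abs2 first, then Conv1D
abs2_then_conv = np.convolve(np.abs(z) ** 2, w, mode="same")

print(np.allclose(conv_then_abs2, direct))          # True
print(np.allclose(conv_then_abs2, abs2_then_conv))  # False
```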

`murenn.dtcwt.Downsampling`

In the spirit of designing MuReNN for audio classification, it would be good to have a time downsampling operator which applies the low-pass filter recursively J times and returns the low-pass filtered signal x_phi.

down = Downsampling(J=J)
x_phi = down(x)

This x_phi is the same as in x_phi, x_psis = dtcwt(x), yet the computation is more efficient since we don't also compute the band-pass coefficients x_psis.

This operator is not invertible but should remain differentiable. There should be Downsampling.forward and Downsampling.backward. For the __init__, we can probably inherit from the DTCWT parent class.

@xir4n let me know what you think! thank you
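A minimal NumPy sketch of the recursion, assuming a 2-tap averaging filter in place of the DTCWT low-pass filters; haar_downsample is a hypothetical name. Only the low-pass branch is computed, so a signal of length N costs fewer than N additions in total, and the operator stays linear (hence differentiable).

```python
import numpy as np

def haar_downsample(x, J):
    """Hypothetical stand-in for Downsampling: low-pass filter and decimate J times.

    Uses a 2-tap averaging filter instead of the DTCWT low-pass filters.
    The length of the last axis must be divisible by 2**J.
    """
    for _ in range(J):
        # average neighboring pairs: low-pass filtering and decimation by 2 in one step
        x = 0.5 * (x[..., 0::2] + x[..., 1::2])
    return x

x = np.ones((1, 1, 1024))
x_phi = haar_downsample(x, J=8)
print(x_phi.shape)  # (1, 1, 4)
```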

Need for custom `backward` in `FWD_J1` et al.

import murenn
import torch

padding_mode = 'reflect'
N = 1024
x = torch.zeros(1, 1, N, requires_grad=True)

Phi = murenn.DTCWTForward(J=8, padding_mode=padding_mode)
Phi_inv = murenn.DTCWTInverse(J=8, padding_mode=padding_mode)
x_phi, x_psi = Phi(x)

y = Phi_inv((x_phi, x_psi))
y0 = y[0, 0, 1 + (N//2)]
y0.backward()

NotImplementedError: You must implement either the backward or vjp method for your custom autograd.Function to use it with backward mode AD.
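For a linear operator such as filtering, the backward pass is the adjoint (transpose) of the forward pass. In a torch.autograd.Function, backward(ctx, grad_output) must return this adjoint applied to grad_output, with one return value per input of forward (None for non-tensor inputs such as padding_mode). The NumPy sketch below covers zero padding with an odd-length filter only; with reflect padding, the adjoint must also fold the gradients of the mirrored samples back onto the interior. filter_forward and filter_backward are hypothetical names.

```python
import numpy as np

def filter_forward(x, h):
    # zero-padded 'same' convolution with an odd-length filter h
    return np.convolve(x, h, mode="same")

def filter_backward(grad_y, h):
    # adjoint of filter_forward: correlate with h, i.e. convolve with h reversed
    return np.convolve(grad_y, h[::-1], mode="same")

rng = np.random.default_rng(0)
h = rng.standard_normal(7)
x = rng.standard_normal(100)
g = rng.standard_normal(100)

# Dot-product test: <A x, g> must equal <x, A^T g> for a correct backward
lhs = np.dot(filter_forward(x, h), g)
rhs = np.dot(x, filter_backward(g, h))
print(np.isclose(lhs, rhs))  # True
```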

`DTCWT.hz_to_octs` and `subbands`

This follows a discussion we had yesterday. It would be useful for teacher–student training (WASPAA 2023) and, more generally, for understanding which frequencies are covered by each level.

This could be achieved with bisection search (https://docs.python.org/3/library/bisect.html)

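import bisect

def hz_to_octs(self, frequencies, sr=1.0):
    # subbands decreases from 0.5 (Nyquist) to 0; bisect needs ascending order
    edges = self.subbands[::-1]
    n_levels = len(edges) - 1
    js = []
    for frequency in frequencies:
        i = bisect.bisect_right(edges, frequency / sr)  # frequency in cycles/sample
        # level j covers [subbands[j+1], subbands[j]]; clip Nyquist and 0 into range
        js.append(min(max(n_levels - i, 0), n_levels - 1))
    return js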

where subbands is a property of the DTCWT

@property
def subbands(self):
   # ...
   return subbands

It should be a decreasing list starting with 0.5 (Nyquist) and ending with zero. The interval covered by wavelet level j would be [fmin, fmax] where fmin is subbands[j+1] and fmax is subbands[j]. Your previous work on displaying wavelets in the Fourier domain can potentially help for this.
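Building on the Fourier-domain plots, one way to obtain subbands is to sample each level's magnitude response on a frequency grid and place an edge wherever the dominant level changes. This NumPy sketch assumes each level dominates a single contiguous frequency interval and that level 0 is the highest; subbands_from_responses is a hypothetical name, and the brick-wall responses are synthetic test data.

```python
import numpy as np

def subbands_from_responses(freqs, responses):
    """Estimate decreasing band edges [0.5, ..., 0.0] from sampled magnitude responses.

    freqs: ascending grid in cycles per sample, from 0 to 0.5.
    responses: array of shape (n_levels, n_freqs), level 0 being the highest band.
    """
    dominant = np.argmax(responses, axis=0)  # strongest level at each frequency
    change = np.nonzero(np.diff(dominant))[0]  # last grid index before each switch
    # crossover frequency: midpoint between the two grid points around each switch
    crossovers = 0.5 * (freqs[change] + freqs[change + 1])
    return [0.5] + sorted(crossovers.tolist(), reverse=True) + [0.0]

# Synthetic brick-wall responses with edges at 0.25 and 0.125 cycles per sample
freqs = np.linspace(0, 0.5, 1025)
responses = np.stack([
    freqs > 0.25,
    (freqs > 0.125) & (freqs <= 0.25),
    freqs <= 0.125,
]).astype(float)
print(subbands_from_responses(freqs, responses))  # approximately [0.5, 0.25, 0.125, 0.0]
```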

Receptive field example

It would be good to show that MuReNN amplifies the effective receptive field of a Conv1D layer, by way of an example Python script
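The mechanism can be illustrated with the noble identity: a T-tap filter applied after decimation by 2**j acts on the input like a dilated filter spanning (T - 1) * 2**j + 1 samples, ignoring the support of the wavelet filters themselves. In this NumPy sketch, decimation is plain subsampling x[::M] (no anti-aliasing filter) to isolate the identity; dilate is a hypothetical helper.

```python
import numpy as np

def dilate(w, M):
    """Insert M - 1 zeros between the taps of w: H(z) becomes H(z**M)."""
    w_dilated = np.zeros((len(w) - 1) * M + 1)
    w_dilated[::M] = w
    return w_dilated

rng = np.random.default_rng(0)
T, j = 10, 4
M = 2 ** j
x = rng.standard_normal(1024)  # length divisible by M
w = rng.standard_normal(T)

# Noble identity: filtering at the decimated rate equals dilated filtering, then decimating
low_rate = np.convolve(x[::M], w)
full_rate = np.convolve(x, dilate(w, M))[::M]
print(np.allclose(low_rate, full_rate))  # True
print(len(dilate(w, M)))  # 145, i.e. (T - 1) * 2**j + 1
```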

`RuntimeError` in multichannel IDTCWT

I think that our implementation of INV_J1.forward has a bug when passing a multichannel input.

Here's how I found the problem (note ch=3 instead of the more customary ch=1):

import murenn
import torch

N = 2**10
J = 8
ch = 3

x = torch.zeros(1, ch, N)
dtcwt = murenn.DTCWT(J=J)
idtcwt = murenn.DTCWTInverse(J=J)
x_phi, x_psis = dtcwt(x)
y = idtcwt((x_phi, x_psis))

Error message:

File ~/MuReNN/murenn/murenn/murenn/dtcwt/transform_funcs.py:178, in INV_J1.forward(ctx, lo, hi_r, hi_i, g0, g1, padding_mode)
    175 ctx.mode = mode_to_int(padding_mode)
    177 # Apply dual low-pass filtering
--> 178 x0 = torch.nn.functional.conv1d(pad_(lo, g0, padding_mode), g0_rep)
    180 # Apply dual high-pass filtering
    181 hi = torch.stack((hi_r, hi_i), dim=-1).view(b, ch, T)
Given groups=1, weight of size [7, 1, 7], expected input[1, 7, 8198] to have 1 channels, but got 7 channels instead
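The message reports a weight with 7 output channels applied with groups=1 to a 7-channel input, which suggests that one kernel copy per channel is being used without setting groups. Two standard PyTorch ways to apply the same 1-D kernel to every channel are sketched below; this is a hedged illustration, not the actual INV_J1 code, and both function names are hypothetical.

```python
import torch
import torch.nn.functional as F

def filter_grouped(x, h):
    """Filter each channel of x, shape (B, C, N), with the same 1-D kernel h, shape (K,)."""
    C, K = x.shape[1], h.shape[0]
    # one copy of the kernel per channel: weight (C, 1, K) with groups=C
    weight = h.view(1, 1, K).repeat(C, 1, 1)
    return F.conv1d(x, weight, padding=K // 2, groups=C)

def filter_folded(x, h):
    """Same result by folding the channels into the batch dimension."""
    B, C, N = x.shape
    K = h.shape[0]
    y = F.conv1d(x.reshape(B * C, 1, N), h.view(1, 1, K), padding=K // 2)
    return y.reshape(B, C, -1)

x = torch.randn(1, 3, 64)  # ch = 3, as in the failing example
h = torch.randn(7)
print(filter_grouped(x, h).shape)  # torch.Size([1, 3, 64])
print(torch.allclose(filter_grouped(x, h), filter_folded(x, h), atol=1e-6))  # True
```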

normalization issue for `DTCWTDirect`

the coefficients don't seem to be normalized correctly when setting alternate_gh=True:

import matplotlib.pyplot as plt
import murenn
import torch

N = 44100
x = torch.zeros(1, 1, N)
x[0, 0, N//2] = 1  # unit impulse
J = 8

dtcwt = murenn.DTCWT(J=J, alternate_gh=False)
idtcwt = murenn.DTCWTInverse(J=J, alternate_gh=False)

x_phi, x_psis = dtcwt(x)
y_j = []

# Reconstruct each band-pass level j alone, with all other coefficients set to zero
for j in range(J):
    y_phi = x_phi * 0
    y_psis = [x_psis[k] * (j==k) for k in range(J)]
    y_j.append(idtcwt(y_phi, y_psis).squeeze())

# Reconstruct the low-pass part alone, with all band-pass coefficients set to zero
lp_psis = [torch.zeros_like(x_psis[k]) for k in range(J)]
y_lp = idtcwt(x_phi, lp_psis).squeeze()

# Magnitude spectrum of each reconstruction on a logarithmic frequency axis
plt.figure(figsize=(10, 3))
for j in range(J):
    y_jhat = torch.fft.fft(y_j[j])
    plt.semilogx(range(N), torch.abs(y_jhat), label=f'j={j+1}')
y_lphat = torch.fft.fft(y_lp)
plt.semilogx(range(N), torch.abs(y_lphat), label='lp')
plt.grid(linestyle='--', alpha=0.5)
plt.xlim(0.5, N//2)
plt.title('Frequency responses')
plt.legend()

output: [image not available: frequency responses of each band and of the low-pass part]
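Two sanity checks for normalization, sketched with a multilevel orthonormal Haar transform standing in for the DTCWT: (1) the per-band reconstructions must add up to the input, so their spectra add up to the input spectrum; (2) for an orthonormal transform, their energies add up to the input energy. Check (1) only relies on linearity and perfect reconstruction; check (2) relies on orthonormality and is not expected to hold exactly for the redundant DTCWT. All function names are illustrative.

```python
import numpy as np

def haar_forward(x, J):
    # multilevel orthonormal Haar transform (stand-in for the DTCWT)
    his = []
    lo = x
    for _ in range(J):
        his.append((lo[0::2] - lo[1::2]) / np.sqrt(2))
        lo = (lo[0::2] + lo[1::2]) / np.sqrt(2)
    return lo, his

def haar_inverse(lo, his):
    for hi in reversed(his):  # from the coarsest level to the finest
        y = np.empty(2 * len(lo))
        y[0::2] = (lo + hi) / np.sqrt(2)
        y[1::2] = (lo - hi) / np.sqrt(2)
        lo = y
    return lo

N, J = 1024, 8
x = np.zeros(N)
x[N // 2] = 1.0
lo, his = haar_forward(x, J)

# Reconstruct each band alone, as in the loop above, plus the low-pass residual
y_j = [haar_inverse(np.zeros_like(lo), [h * (k == j) for k, h in enumerate(his)]) for j in range(J)]
y_lp = haar_inverse(lo, [np.zeros_like(h) for h in his])

print(np.allclose(sum(y_j) + y_lp, x))  # True
energy = sum(np.sum(y ** 2) for y in y_j) + np.sum(y_lp ** 2)
print(np.isclose(energy, np.sum(x ** 2)))  # True
```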
