|
def call(self, inputs):
    """Apply the layer's filters to ``inputs`` via multiplication in the Fourier domain.

    The kernel is zero-padded to the spatial size of the input. Both tensors are
    transformed with a 2-D real FFT over their last two axes, multiplied
    element-wise and summed over input channels for every output filter. The
    result is then transformed back with an inverse real FFT. Because the
    product is taken between discrete spectra, the result is a circular
    (wrap-around) convolution of the input with the kernel.

    Args:
        inputs: 4-D real tensor. It is transposed with ``"bhwc->bchw"`` unless
            ``self.isChannelFirst`` is set, i.e. it is expected as
            ``(batch, height, width, channels)`` in channel-last mode and as
            ``(batch, channels, height, width)`` otherwise.

    Returns:
        Real tensor in the same channel layout as ``inputs``, with
        ``self.filters`` channels and the same spatial size as ``inputs``.

    Raises:
        ValueError: if the layer has not been built yet.
    """
    if not self.built:
        raise ValueError('This model has not yet been built.')

    # rfft2d/irfft2d operate over the last two axes, so work channel-first.
    if not self.isChannelFirst:
        inputs = tf.einsum("bhwc->bchw", inputs)

    # Zero-pad the kernel (bottom/right) to the input's spatial size so the two
    # spectra have identical shapes and can be multiplied element-wise.
    signal_shape = tf.shape(inputs)
    kernel_shape = tf.shape(self.kernel)
    height_pad = signal_shape[2] - kernel_shape[2]
    width_pad = signal_shape[3] - kernel_shape[3]
    # One [before, after] pair per kernel axis; only the two spatial axes grow.
    paddings = [
        [0, 0],
        [0, 0],
        [0, height_pad],
        [0, width_pad],
    ]
    # Kernel is indexed as (filters, in_channels, kh, kw) by the einsum below;
    # NOTE(review): assumes axis 1 equals the input channel count — confirm in build().
    kernels_padded = tf.pad(self.kernel, paddings)

    # Spatial size of the signal. It is passed to irfft2d because, without it,
    # irfft2d infers the last length as 2*(n-1), which loses one pixel when the
    # input width is odd.
    fft_length = signal_shape[2:]

    inputs_F = tf.signal.rfft2d(inputs)           # (batch, in_ch, h, w//2+1)
    kernels_F = tf.signal.rfft2d(kernels_padded)  # (filters, in_ch, h, w//2+1)
    # A mathematically exact cross-correlation would use tf.math.conj(kernels_F);
    # it is omitted because the filter is learned anyway.

    # For each output filter o: sum over input channels i of inputs_F[b, i] * kernels_F[o, i].
    # One contraction replaces the former per-filter loop + tf.concat, which also
    # started from an uninitialized np.ndarray accumulator (wrong dtype, garbage data).
    outputs_F = tf.einsum("bihw,oihw->bohw", inputs_F, kernels_F)

    # Back to the spatial domain with the original (h, w) restored exactly.
    output = tf.signal.irfft2d(outputs_F, fft_length=fft_length)

    if self.use_bias:
        # NOTE(review): `output` is channel-first here; a bias of shape (filters,)
        # would broadcast over the width axis, not the channel axis. Confirm the
        # bias is created with shape (filters, 1, 1) (or equivalent) in build().
        output += self.bias

    # Restore the caller's original channel layout.
    if not self.isChannelFirst:
        output = tf.einsum("bchw->bhwc", output)

    return output