I have some conceptual questions about the temp_pointwise implementation. I marked three steps in the source code below for my questions. The comments are my understanding, and the code lines are extracted from your source code.
def temp_pointwise(...):
...
# temp_skip(batch_size, ts_feature_value_dim, ts_feature_conv_dim+1, n_measure_of_patient)
# temp_skip is a combination of the temporal convolution output and the skip connection. Each group of ts_feature_conv_dim
# values (12 values per layer) is concatenated with a feature value from the skip connection.
# step 1
temp_skip = cat((point_skip.unsqueeze(2), # B * (F + Zt) * 1 * T
X_temp.view(B, point_skip.shape[1], temp_kernels, T)), # B * (F + Zt) * temp_kernels * T
dim=2) # B * (F + Zt) * (1 + temp_kernels) * T
# point_output(batch_size * n_measure_of_patient, point_size)
# -> view(batch_size, n_measure_of_patient, point_size, 1)
# -> permute(batch_size, point_size, 1, n_measure_of_patient)
# -> X_point_rep(batch_size, point_size, ts_feature_pattern_dim+1, n_measure_of_patient)
# X_point_rep contains a representation of each measurement in a low-dimensional space
# step 2
X_point_rep = point_output.view(B, T, point_size, 1).permute(0, 2, 3, 1).repeat(1, 1, (1 + temp_kernels), 1) # B * point_size * (1 + temp_kernels) * T
# X_combined(batch_size, ts_feature_value_dim + point_size, ts_feature_conv_dim+1, n_measure_of_patient)
# temp_skip and X_point_rep are concatenated along ts_feature_value_dim axis.
# step 3
X_combined = self.relu(cat((temp_skip, X_point_rep), dim=1)) # B * (F + Zt) * (1 + temp_kernels) * T
next_X = X_combined.contiguous().view(B, (point_skip.shape[1] + point_size) * (1 + temp_kernels), T) # B * ((F + Zt + point_size) * (1 + temp_kernels)) * T
...
X_combined = self.relu(cat(
(temp_skip.view(B, point_skip.shape[1] * (temp_kernels+1), T),
point_output.view(B, T, point_size).permute(0, 2, 1) # B * point_size * T
),
    dim=1
))
So temp_skip is flattened so that it can be concatenated with point_output at the ts_feature_value_dim level.
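To check my reading of this final-layer variant, here is a minimal sketch with dummy tensors (all sizes below are made up for illustration) showing that, without the repeat, the concatenation produces F_Zt * (1 + temp_kernels) + point_size channels:

```python
import torch

# Hypothetical sizes, only to check the shape of the flattened concatenation.
B, T = 2, 5            # batch size, number of measurements
F_Zt = 4               # point_skip.shape[1], i.e. F + Zt
temp_kernels = 3
point_size = 6

temp_skip = torch.randn(B, F_Zt, 1 + temp_kernels, T)   # output of step 1
point_output = torch.randn(B * T, point_size)

# flatten temp_skip, then concatenate point_output once (no repeat)
X_combined = torch.relu(torch.cat(
    (temp_skip.view(B, F_Zt * (temp_kernels + 1), T),
     point_output.view(B, T, point_size).permute(0, 2, 1)),  # B * point_size * T
    dim=1))

assert X_combined.shape == (B, F_Zt * (1 + temp_kernels) + point_size, T)
```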
I actually have difficulty understanding the reasoning behind repeating each point_size value (1 + temp_kernels) times in step 2 (X_point_rep). The only reason I can think of is to match the dimensions of temp_skip. But with the repetition, won't next_X contain (1 + temp_kernels) repeated values at dim=1, which adds no information for the network?
Asking about source code in text is a bit difficult, so I am not sure whether I have stated my questions clearly.