mims-harvard / raincoat
Domain Adaptation for Time Series Under Feature and Label Shifts
Home Page: https://zitniklab.hms.harvard.edu/projects/Raincoat
License: MIT License
Does the code implementation contain the inference step, especially the source-prototype part and the bimodal test used to check for unknown samples?
It seems the RAINCOAT code only implements the alignment and correction stages.
https://github.com/mims-harvard/Raincoat/tree/main/trainers/trainer_uni.py
[Stage 3 Inference]
c_list = self.learn_t(dis2proto_a, dis2proto_c)
print(c_list)
self.trg_true_labels = tar_uni_label_test
acc, f1, H = self.detect_private(dis2proto_a_test, dis2proto_c_test, tar_uni_label_test, c_list)
Num 24
Paper: for c in C^{s} do.
GitHub: for c in {all possible labels in WISDM} do
def detect_private(self, d1, d2, tar_uni_label, c_list):
    diff = np.abs(d2-d1)
    for i in range(6):  # <= not C^{s}
Num 27: CLUSTER(d^{ac} | \hat{y} = c), i.e., in the paper, CLUSTER is based on the predicted label.
However, the GitHub code uses the ground-truth labels (so the function learn_t, as written, should not be usable at inference time).
def learn_t(self, d1, d2):
    diff = np.abs(d2-d1)
    c_list = []
    for i in range(6):
        cat = np.where(self.trg_train_dl.dataset.y_data==i)
        cc = diff[cat]
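The paper-side procedure described above (loop over the source classes C^{s} only, and group d^{ac} by the predicted label \hat{y} instead of the ground truth) can be sketched as follows. This is a hypothetical illustration, not the repository's code: the names `two_means_1d`, `learn_thresholds`, and `detect_unknown`, the 1-D two-means clustering, the midpoint threshold, and the `min_gap` check standing in for the paper's bimodality test are all assumptions. `d1` and `d2` play the roles of `dis2proto_a` and `dis2proto_c` in the snippet above.

```python
import numpy as np


def two_means_1d(values, n_iter=50):
    """1-D k-means with k=2; returns the (low, high) cluster centers."""
    lo, hi = float(values.min()), float(values.max())
    for _ in range(n_iter):
        mid = (lo + hi) / 2.0
        low_part, high_part = values[values <= mid], values[values > mid]
        if low_part.size == 0 or high_part.size == 0:
            break  # all values identical: only one cluster
        new_lo, new_hi = float(low_part.mean()), float(high_part.mean())
        if new_lo == lo and new_hi == hi:
            break  # assignments stopped changing
        lo, hi = new_lo, new_hi
    return lo, hi


def learn_thresholds(d1, d2, y_pred, source_classes, min_gap=0.0):
    """Per-class threshold on |d2 - d1|, grouped by PREDICTED label over C^s."""
    diff = np.abs(np.asarray(d2) - np.asarray(d1))
    y_pred = np.asarray(y_pred)
    thresholds = {}
    for c in source_classes:  # only the source classes C^s
        vals = diff[y_pred == c]
        if vals.size < 2:
            thresholds[c] = np.inf  # too few samples: never reject
            continue
        lo, hi = two_means_1d(vals)
        # Hypothetical stand-in for the bimodality test:
        # reject only when the two cluster centers are clearly separated.
        thresholds[c] = (lo + hi) / 2.0 if hi - lo > min_gap else np.inf
    return thresholds


def detect_unknown(d1, d2, y_pred, thresholds, unknown_label=-1):
    """Relabel samples whose |d2 - d1| exceeds their predicted class's threshold."""
    diff = np.abs(np.asarray(d2) - np.asarray(d1))
    preds = np.asarray(y_pred)
    out = preds.copy()
    for i, c in enumerate(preds):
        if diff[i] > thresholds.get(int(c), np.inf):
            out[i] = unknown_label
    return out


# Toy example: class 0 has a clearly bimodal |d2 - d1|, class 1 does not.
y_pred = np.array([0, 0, 0, 0, 1, 1, 1, 1])
d1 = np.zeros(8)
d2 = np.array([0.1, 0.2, 2.0, 2.1, 0.1, 0.15, 0.12, 0.11])
th = learn_thresholds(d1, d2, y_pred, source_classes=[0, 1], min_gap=0.5)
labels = detect_unknown(d1, d2, y_pred, th)  # -> [0, 0, -1, -1, 1, 1, 1, 1]
```

Because the grouping uses `y_pred`, nothing in this sketch reads target ground-truth labels, which is the property the issue says the current `learn_t` lacks.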
While reading the code, I found that the algorithm in the paper and the GitHub code do not match.
Therefore, we modified the code to follow the paper. However, we could not reproduce the H-score reported in Table 1 of the paper.
How can I reproduce RAINCOAT's universal DA results?
I'd also like to ask whether there is any updated code.
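When comparing against Table 1, it may help to pin down the metric. The H-score in universal DA is commonly defined as the harmonic mean of the accuracy on known (shared) classes and the accuracy of flagging target-private samples as unknown. The sketch below assumes sample-level accuracies and uses -1 as the "unknown" label; whether the paper's evaluation averages per sample or per class is an assumption worth checking in the evaluation code.

```python
import numpy as np


def h_score(y_true, y_pred, unknown_label=-1):
    """Harmonic mean of known-class accuracy and unknown-rejection accuracy."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    known = y_true != unknown_label
    # Accuracy on target samples whose true class is shared with the source.
    acc_known = float(np.mean(y_pred[known] == y_true[known])) if known.any() else 0.0
    # Fraction of target-private samples correctly flagged as unknown.
    acc_unknown = float(np.mean(y_pred[~known] == unknown_label)) if (~known).any() else 0.0
    if acc_known + acc_unknown == 0.0:
        return 0.0
    return 2 * acc_known * acc_unknown / (acc_known + acc_unknown)


# 3 of 4 known samples correct, 1 of 2 unknown samples rejected -> H = 0.6
score = h_score([0, 0, 1, 1, -1, -1], [0, 1, 1, 1, -1, 0])
```

The harmonic mean is low whenever either component is low, so a model that never rejects anything scores H = 0 regardless of its known-class accuracy.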
I ran the code on the Sleep-EDF (EEG) dataset following the authors' instructions and found some bugs:
Best regards.
First of all, many thanks for considering CLUDA in your paper! And of course, all the best wishes for your submission!
Currently, the CLUDA link in your GitHub repo incorrectly points to the AdvSKM paper. Could you kindly fix it? CLUDA will be published at ICLR 2023; until then, you can link to either the OpenReview or the arXiv version of the paper.
Best,
First, let me thank you for your amazing contribution. As stated in the title, I see some inconsistencies between your DIRT-T implementation and the original paper. Your implementation looks more like VADA, which was introduced in the same paper and also serves as the initialization for DIRT-T.
The authors describe DIRT-T as "a recursive extension of VADA, where the act of pseudo-labeling of the target distribution constructs a new 'source' domain". The problem is that I see neither the recursion nor the pseudo-labeling. The classifier h_{n-1} is supposed to serve as the teacher for h_n, with h_0, the initial classifier, being the one produced by VADA.
I may have misunderstood your code; if so, sorry for the inconvenience.
Best regards,
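The recursion described in this issue (a frozen teacher h_{n-1} whose target predictions act as soft pseudo-labels for a student h_n, starting from the VADA classifier h_0) can be illustrated with a toy softmax-regression classifier. This is a simplified, hypothetical sketch, not the repository's or the DIRT-T authors' code: the function `dirt_t_rounds`, the plain gradient steps, and all hyperparameters are assumptions, and the VAT term and the optimizer details of the original method are omitted. Each round minimizes conditional entropy on target data plus KL(h_{n-1} || h_n), then promotes the student to teacher.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def dirt_t_rounds(W0, X_t, n_rounds=3, steps=50, lr=0.1, lam=1.0, beta=1.0):
    """Recursive teacher/student refinement on unlabeled target data X_t.

    W0 plays the role of h_0 (e.g. the VADA classifier). In round n the
    teacher h_{n-1} is frozen and the student h_n, initialized from it,
    minimizes  lam * CondEntropy(h_n) + beta * KL(h_{n-1} || h_n).
    """
    eps = 1e-12
    teacher = W0.copy()
    history = [teacher.copy()]
    for _ in range(n_rounds):
        p_teacher = softmax(X_t @ teacher)  # soft pseudo-labels from h_{n-1}
        W = teacher.copy()
        for _ in range(steps):
            p = softmax(X_t @ W)
            log_p = np.log(p + eps)
            ent = -(p * log_p).sum(axis=1, keepdims=True)
            grad_ent = -p * (log_p + ent)  # d entropy / d logits
            grad_kl = p - p_teacher        # d KL(teacher || student) / d logits
            grad_logits = lam * grad_ent + beta * grad_kl
            W -= lr * X_t.T @ grad_logits / len(X_t)
        history.append(W.copy())
        teacher = W  # h_n becomes the teacher for round n + 1
    return history
```

The KL term keeps each student close to its teacher, so the decision boundary moves gradually across rounds instead of jumping in a single entropy-minimization step; a VADA-only implementation corresponds to stopping at `history[0]`.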
Hi, I have some questions about Section "C.2. Experimental Details", where you provide the key hyperparameters of the RAINCOAT algorithm. I noticed that in Tables 5, 6, 7, and 8, the number of epochs for RAINCOAT and all baseline algorithms is 50. In your code, however, the Alignment and Correction stages of RAINCOAT each run for 50 epochs, which corresponds to E1 = 50 and E2 = 50 in the paper. This seems unfair: the baselines train for only 50 epochs, while RAINCOAT trains for 50 + 50 = 100 epochs in total. For a fair comparison, shouldn't E1 + E2 = 50?
For the Sleep-EDF dataset, the following error is reported:
Traceback (most recent call last):
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/main.py", line 37, in <module>
trainer.train()
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/trainers/trainer.py", line 113, in train
losses = algorithm.update(src_x, src_y, trg_x)
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/algorithms/RAINCOAT.py", line 161, in update
src_feat, out_s = self.feature_extractor(src_x)
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/algorithms/RAINCOAT.py", line 110, in forward
ef = F.relu(self.bn_freq(self.avg(ef).squeeze()))
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 182, in forward
self.eps,
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/functional.py", line 2451, in batch_norm
input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: running_mean should contain 100 elements not 600
For the HHAR dataset, the following error is reported:
Traceback (most recent call last):
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/main.py", line 37, in <module>
trainer.train()
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/trainers/trainer.py", line 113, in train
losses = algorithm.update(src_x, src_y, trg_x)
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/algorithms/RAINCOAT.py", line 174, in update
src_pred = self.classifier(src_feat)
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/models/models.py", line 70, in forward
predictions = self.logits(x)/self.tmp
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x192 and 128x6)
For the WISDM dataset, the following error is reported:
Traceback (most recent call last):
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/main.py", line 37, in <module>
trainer.train()
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/trainers/trainer.py", line 113, in train
losses = algorithm.update(src_x, src_y, trg_x)
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/algorithms/RAINCOAT.py", line 174, in update
src_pred = self.classifier(src_feat)
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/media/qiu/DataDisk/yue/Raincoat-main (1)/Raincoat-main/models/models.py", line 70, in forward
predictions = self.logits(x)/self.tmp
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/qiu/anaconda3/envs/sleepstage/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x256 and 192x6)
Could you please clarify whether some parameters haven't been set, or whether there is another reason for these errors?
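Both kinds of message follow PyTorch conventions that make them easier to read. In `batch_norm`, "should contain N elements" gives the channel dimension of the input (`input.size(1)`), and "not M" gives the size of the layer's `running_mean`, i.e. the `num_features` the BatchNorm layer was constructed with. In `F.linear`, mat1 is the input batch and mat2 is the transposed weight, so `(32x192 and 128x6)` means 192-dimensional features reaching a layer built as `Linear(128, 6)`. In every case, the feature size the encoder actually produces differs from what the layer was configured for. The helper below is a hypothetical convenience (not part of the repository) that decodes the messages; it does not tell you which config field is wrong.

```python
import re


def explain_shape_error(message):
    """Decode the two PyTorch shape errors from the tracebacks above."""
    m = re.search(r"running_mean should contain (\d+) elements not (\d+)", message)
    if m:
        channels, built = int(m.group(1)), int(m.group(2))
        return (f"BatchNorm input has {channels} channels, "
                f"but the layer was built with num_features={built}")
    m = re.search(r"\((\d+)x(\d+) and (\d+)x(\d+)\)", message)
    if m:
        batch, feat, in_features, out_features = map(int, m.groups())
        return (f"encoder emits {feat} features per sample (batch of {batch}), "
                f"but the classifier is Linear(in_features={in_features}, "
                f"out_features={out_features})")
    return "unrecognized error"


print(explain_shape_error("RuntimeError: running_mean should contain 100 elements not 600"))
print(explain_shape_error("mat1 and mat2 shapes cannot be multiplied (32x192 and 128x6)"))
```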
Great work on enhancing time series with frequency information🤩! I noticed that you do not use the inverse FFT (ifft) to transform back to the time domain, and I'm curious why; this is the first time I've seen people use the amplitude and phase directly.
As the paper mentioned:
"RAINCOAT extracts the polar coordinates of frequency coefficients to keep both low-level (ai) and high-level (pi) semantics. The frequency space features eF is a concatenation [ai; pi]."
Well, I am a student just beginning to focus on time series OOD. I'd like to know what kind of information in a time series the low-level (ai) and high-level (pi) semantics you mention correspond to. Recently, I ran some experiments on data augmentation for OOD domain generalization, and I noticed that low-frequency components are more important than high-frequency ones. I want to explore whether there are common invariants in the frequency domain, so amplitude and phase are natural candidates to consider.
I think there must be some connection between low/high frequency and [amplitude and phase], but my math knowledge and experimental experience are limited, so I'm taking this opportunity to ask you🤣
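The polar representation quoted from the paper (eF as a concatenation [ai; pi]) can be illustrated with NumPy. This sketch is not RAINCOAT's encoder: the function names and the number of retained modes are assumptions. It shows two facts relevant to the question: amplitude and phase together are a lossless re-encoding of the retained Fourier coefficients (so an inverse FFT can always recover the signal from them), and keeping only the first few modes acts as a low-pass filter. Amplitude measures how strong each frequency component is; phase encodes where that oscillation sits in time.

```python
import numpy as np


def polar_fft_features(x, n_modes):
    """Return [amplitude; phase] of the first n_modes rFFT coefficients of x."""
    coeffs = np.fft.rfft(x, axis=-1)[..., :n_modes]
    amplitude = np.abs(coeffs)  # strength of each frequency component
    phase = np.angle(coeffs)    # time offset of each component, in radians
    return np.concatenate([amplitude, phase], axis=-1)


def from_polar(amplitude, phase, length):
    """Rebuild complex coefficients from polar form and invert with irfft."""
    coeffs = amplitude * np.exp(1j * phase)
    return np.fft.irfft(coeffs, n=length, axis=-1)  # missing modes are zero-filled


t = np.arange(128)
low = np.sin(2 * np.pi * 3 * t / 128)          # 3 cycles per window
high = 0.2 * np.sin(2 * np.pi * 40 * t / 128)  # 40 cycles per window
x = low + high

feats = polar_fft_features(x, n_modes=65)  # 65 = 128 // 2 + 1, i.e. all modes
a, p = feats[:65], feats[65:]
x_full = from_polar(a, p, 128)             # lossless: recovers x
x_low = from_polar(a[:10], p[:10], 128)    # first 10 modes only: recovers `low`
```

Shifting `x` in time changes only the phases, not the amplitudes, which is one concrete way in which the two halves of [a; p] carry different information.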
In the correct function in algorithms/RAINCOAT.py, the src_y argument is never used, and the correction process is no different from the update training process.
def correct(self, src_x, src_y, trg_x):
    self.coptimizer.zero_grad()
    src_feat, out_s = self.feature_extractor(src_x)
    trg_feat, out_t = self.feature_extractor(trg_x)
    src_recon = self.decoder(src_feat, out_s)
    trg_recon = self.decoder(trg_feat, out_t)
    recons = 1e-4 * (self.recons(trg_recon, trg_x) + self.recons(src_recon, src_x))
    recons.backward()
    self.coptimizer.step()
    return {'recon': recons.item()}
I'm not sure how this code pulls together samples with the same label and rejects unknown samples in the target domain:
class RAINCOAT(Algorithm):
    def __init__(self, configs, hparams, device):
        super(RAINCOAT, self).__init__(configs)
        self.feature_extractor = tf_encoder(configs).to(device)
        self.decoder = tf_decoder(configs).to(device)
        self.classifier = classifier(configs).to(device)
        self.optimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters()) + \
            # list(self.decoder.parameters())+\
            list(self.classifier.parameters()),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )
        self.coptimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters())+list(self.decoder.parameters()),
            lr=0.5*hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )
        self.hparams = hparams
        self.recons = nn.L1Loss(reduction='sum').to(device)
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        self.loss_func = losses.ContrastiveLoss(pos_margin=0.5)
        self.sink = SinkhornDistance(eps=1e-3, max_iter=1000, reduction='sum')

    def update(self, src_x, src_y, trg_x):
        self.optimizer.zero_grad()
        src_feat, out_s = self.feature_extractor(src_x)
        trg_feat, out_t = self.feature_extractor(trg_x)
        src_recon = self.decoder(src_feat, out_s)
        trg_recon = self.decoder(trg_feat, out_t)
        recons = 1e-4*(self.recons(src_recon, src_x)+self.recons(trg_recon, trg_x))
        recons.backward(retain_graph=True)
        dr, _, _ = self.sink(src_feat, trg_feat)
        sink_loss = 1 *dr
        sink_loss.backward(retain_graph=True)
        lossinner = 1 * self.loss_func(src_feat, src_y)
        lossinner.backward(retain_graph=True)
        src_pred = self.classifier(src_feat)
        loss_cls = 1 *self.cross_entropy(src_pred, src_y)
        loss_cls.backward(retain_graph=True)
        self.optimizer.step()
        return {'Src_cls_loss': loss_cls.item(),'Sink': sink_loss.item(), 'inner': lossinner.item()}
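Regarding the question above: in `update`, `lossinner` is `losses.ContrastiveLoss(pos_margin=0.5)` applied to `src_feat` and `src_y`, so it acts on labeled source features only; no term in this snippet uses target labels. A rough NumPy sketch of what a margin-based contrastive loss of this kind computes is below. It assumes L2-normalized embeddings, Euclidean distances, a negative margin of 1.0, and averaging over non-zero pair losses; the function name is hypothetical, and details may differ from the library's actual defaults and reduction.

```python
import numpy as np


def pairwise_contrastive(feats, labels, pos_margin=0.5, neg_margin=1.0):
    """Margin-based contrastive loss on L2-normalized embeddings (sketch)."""
    feats, labels = np.asarray(feats, dtype=float), np.asarray(labels)
    z = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    d = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)  # pairwise distances
    same = labels[:, None] == labels[None, :]
    off_diag = ~np.eye(len(labels), dtype=bool)
    # Same-label pairs are penalized only if farther apart than pos_margin.
    pos = np.maximum(d - pos_margin, 0.0)[same & off_diag]
    # Different-label pairs are penalized only if closer than neg_margin.
    neg = np.maximum(neg_margin - d, 0.0)[~same]

    def avg_nonzero(v):
        nz = v[v > 0]
        return float(nz.mean()) if nz.size else 0.0

    return avg_nonzero(pos) + avg_nonzero(neg)


labels = np.array([0, 0, 1, 1])
separated = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
loss_separated = pairwise_contrastive(separated, labels)          # classes apart: 0.0
loss_collapsed = pairwise_contrastive(np.ones((4, 2)), labels)    # all collapsed: 1.0
```

In this form the loss clusters source features by class; rejecting unknown target samples would have to come from a separate mechanism, such as the prototype-distance test discussed in the inference issue.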