Comments (7)
Thanks a lot for your excellent code; it's quite easy to understand compared with the official implementation.
I have a question: according to your blog about DAN (https://zhuanlan.zhihu.com/p/27657910), there is a weight parameter β for the multi-kernel that needs to be learned, but I could not find it in the DAN code you provided. Is it 'fix_sigma'?
import torch

def mmd_linear(f_of_X, f_of_Y):
    delta = f_of_X - f_of_Y
    loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
    return loss

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0 - total1) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)  # /len(kernel_val)
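For context, here is a minimal, self-contained sketch of how a multi-kernel Gaussian kernel like the one above is typically turned into an MMD loss between a source and a target batch. The function and variable names (`mmd_rbf`, `src`, `tgt`) and the batch shapes are assumptions for illustration, not part of the original code; summing the per-bandwidth kernels amounts to fixing every β weight to 1.

```python
import torch

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    # Same idea as the kernel above: a sum of `kernel_num` RBF kernels whose
    # bandwidths form a geometric series around the mean pairwise distance.
    n_samples = source.size(0) + target.size(0)
    total = torch.cat([source, target], dim=0)
    L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    return sum(torch.exp(-L2_distance / (bandwidth * kernel_mul ** i))
               for i in range(kernel_num))

def mmd_rbf(source, target, **kwargs):
    # Biased MMD^2 estimate: split the joint kernel matrix into
    # source-source, target-target, and cross blocks.
    b = source.size(0)
    k = guassian_kernel(source, target, **kwargs)
    return torch.mean(k[:b, :b]) + torch.mean(k[b:, b:]) - 2 * torch.mean(k[:b, b:])

# Hypothetical batches of 64-dim features from two domains.
src = torch.randn(32, 64)
tgt = torch.randn(32, 64) + 1.0  # shifted target distribution
print(mmd_rbf(src, tgt).item())  # larger domain shift -> larger MMD
```

This is the usual way the summed kernel matrix is consumed: only its block means enter the loss, so no per-kernel β is ever learned.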
from transferlearning.
@Xavierxhq Labels are not used in the computation, since MMD is an unsupervised method. We just use MMD to get the domain distance.
@jindongwang Thanks a lot for the reply!
@Xavierxhq Sorry, I forgot this was your thread and just typed my question; I thought this was a BBS. Sorry again.
@guizilaile23 Yes. The code is a little bit different from the ones in DAN. But the ideas are the same.
@guizilaile23 No big deal; this way I can learn more from your question, which I should thank you for.
As I understand it, the bandwidth corresponds to gamma (the denominator inside the exponential) in the paper. There is no trace of the coefficients β from Gretton's paper; the kernels are just added up.
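As a concrete sketch of that reading (the numbers are assumed for illustration): with kernel_mul=2.0 and kernel_num=5, the five bandwidths form a geometric series centered on the mean pairwise distance, and summing the kernels is equivalent to fixing every β in Gretton's combination to 1.

```python
# Illustration only: how guassian_kernel derives its bandwidth list.
mean_dist = 8.0          # assumed mean pairwise squared distance
kernel_mul, kernel_num = 2.0, 5
base = mean_dist / kernel_mul ** (kernel_num // 2)   # 8.0 / 4 = 2.0
bandwidth_list = [base * kernel_mul ** i for i in range(kernel_num)]
print(bandwidth_list)    # [2.0, 4.0, 8.0, 16.0, 32.0] -- centered on 8.0
```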
Related Issues (20)
- MMD distance backward issue HOT 3
- Question about the dataset in the BDA code HOT 2
- Question about the source_clf computation in TransferNet HOT 1
- About reproducing the results HOT 1
- The reproductions of the comparison algorithms are missing some steps HOT 1
- The program reports that the loss_funcs module is missing HOT 1
- Our new test-time adaptation algorithm for segmentation HOT 3
- About loading the model code in transferlearning/code/DeepDA HOT 2
- How to use DIFEX for single domain generalization? HOT 2
- code add HOT 2
- A problem encountered when using the DANN method in DG
- The fine-tuned model on the Office-31 webcam domain is missing HOT 1
- Is fine-tuning necessary? HOT 2
- Hello! In the AdaRNN comparison experiments, how are MMD-RNN and DANN-RNN implemented, how are the source and target domains of MMD-RNN defined, and where does AdaRNN improve over these two baselines? HOT 1
- Feature Request: add the ADDA code and a comparison with other methods. HOT 1
- Hello Prof. Wang, which folder contains the code for the paper "Domain Generalization for Activity Recognition via Adaptive Feature Fusion"? I searched for a long time but could not find it. HOT 1
- Time series domain adaptation benchmark and datasets HOT 1
- The cross_dataset file for "DIVERSIFY: A General Framework for Time Series Out-of-distribution Detection and Generalization" HOT 1
- A-distance issue in BDA
- About the domain classification accuracy of the domain discriminator in DANN HOT 2