
Comments (7)

guizilaile23 commented on July 24, 2024

Thanks a lot for your excellent code; it is much easier to understand than the official implementation.
I have a question: according to your blog about DAN (https://zhuanlan.zhihu.com/p/27657910), there is a weight parameter β for the multi-kernel MMD that needs to be learned, but I could not find it in the DAN code you provided. Is it 'fix_sigma'?

import torch

def mmd_linear(f_of_X, f_of_Y):
    # MMD with a linear kernel: the squared distance between the source and
    # target feature means (assumes the two batches have the same size).
    delta = f_of_X - f_of_Y
    loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
    return loss

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    # Pairwise squared Euclidean distances between all source and target samples.
    L2_distance = ((total0 - total1) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Bandwidth heuristic: mean of the pairwise distances (the diagonal is zero).
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)  # / len(kernel_val)
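
For reference, this is roughly how I turn the kernel matrix returned by guassian_kernel into an MMD estimate (the function name mmd_rbf and the equal-batch-size assumption are my own, not necessarily the repo's):

def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    # Assumes the source and target batches have the same size.
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target, kernel_mul=kernel_mul,
                              kernel_num=kernel_num, fix_sigma=fix_sigma)
    # Split the joint (2n x 2n) kernel matrix into its four blocks.
    XX = kernels[:batch_size, :batch_size]   # source vs. source
    YY = kernels[batch_size:, batch_size:]   # target vs. target
    XY = kernels[:batch_size, batch_size:]   # source vs. target
    YX = kernels[batch_size:, :batch_size]   # target vs. source
    # Biased estimate of MMD^2: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)].
    return torch.mean(XX + YY - XY - YX)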


jindongwang commented on July 24, 2024

@Xavierxhq Labels are not used in the computation, since MMD is an unsupervised method. We just use MMD to measure the domain distance.
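
For intuition, here is a minimal sketch of how such a training objective is typically put together (the model interface, lambda_mmd, and mmd_rbf used here are placeholders for illustration, not the exact code in this repo):

import torch.nn.functional as F

def dan_step(model, x_src, y_src, x_tgt, lambda_mmd=1.0):
    # Labels enter only the source classification loss; the MMD term
    # compares feature distributions and never sees a label.
    feat_src, logits_src = model(x_src)   # placeholder: model returns (features, logits)
    feat_tgt, _ = model(x_tgt)            # no target labels are needed
    cls_loss = F.cross_entropy(logits_src, y_src)
    transfer_loss = mmd_rbf(feat_src, feat_tgt)   # or mmd_linear(feat_src, feat_tgt)
    return cls_loss + lambda_mmd * transfer_loss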


Xavierxhq commented on July 24, 2024

@jindongwang Thanks a lot for the reply!


guizilaile23 commented on July 24, 2024

@Xavierxhq Sorry, dude, I forgot this is your thread and just typed my question here; I thought it was a BBS. Sorry again.


jindongwang commented on July 24, 2024

@guizilaile23 Yes. The code is a little different from the one in DAN, but the idea is the same.


Xavierxhq commented on July 24, 2024

@guizilaile23 No big deal; this way I can learn more from your question, so I should be the one thanking you.


ChengYeung1222 commented on July 24, 2024


As I understand it, the bandwidth in the guassian_kernel code above plays the role of gamma (the denominator inside the exponential) in the paper, and there is no counterpart to the β coefficients from Gretton's paper: the kernels are simply summed with equal weight.
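
To make that concrete, this is where β would enter if one followed Gretton's formulation literally; the helper weighted_kernel_sum and its beta argument below are hypothetical, not part of the repo's code:

def weighted_kernel_sum(kernel_val, beta=None):
    # kernel_val: the list of per-bandwidth kernel matrices built inside
    # guassian_kernel above. In Gretton's MK-MMD the beta_u are learned,
    # non-negative, and sum to 1; the code above effectively fixes every beta_u to 1.
    if beta is None:
        beta = [1.0] * len(kernel_val)
    return sum(b * k for b, k in zip(beta, kernel_val))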

