Comments (4)

JingWang18 commented on July 29, 2024

Hi Yang,
Thank you for the explanation. The paper says the feature-norm enlargement should be based on the feature norms computed with the model parameters from the previous iteration, but in your code they all seem to come from the current iteration. I am also still a little confused about the loss term staying constant during training. I modified your code to match my reading of the paper; could you please take a look and point out where I have misunderstood it?
The code starts here:
#######################################################################
import os

import torch
import torch.nn.functional as F
from torch.autograd import Variable

# G (feature extractor), C (classifier), opt_g, opt_f, dataset, batch_size,
# args, and test() are defined elsewhere in the training script.


def get_L2norm_loss_self_driven(x, radius):
    # `radius` holds the (detached) feature norms stored from the previous iteration.
    assert radius.requires_grad == False
    l = ((x.norm(p=2, dim=1)[:radius.size(0)] - radius[:x.size(0)]) ** 2).mean()
    return args.weight_L2norm * l


def get_entropy_loss(p_softmax):
    # Entropy of the target predictions, ignoring near-zero probabilities.
    mask = p_softmax.ge(0.000001)
    mask_out = torch.masked_select(p_softmax, mask)
    entropy = -(torch.sum(mask_out * torch.log(mask_out)))
    return args.weight_entropy * (entropy / float(p_softmax.size(0)))


def train(num_epoch):
    radius_s = torch.ones(1).cuda()
    radius_t = torch.ones(1).cuda()
    for ep in range(num_epoch):
        G.train()
        C.train()

        print(len(dataset.data_loader_A.dataset))
        print(len(dataset.data_loader_B.dataset))

        for batch_idx, data in enumerate(dataset):
            if batch_idx * batch_size >= 2816:
                break

            if args.cuda:
                data1 = data['S']
                target1 = data['S_label']
                data2 = data['T']
                target2 = data['T_label']
                data1, target1 = data1.cuda(), target1.cuda()
                data2, target2 = data2.cuda(), target2.cuda()

            # when pretraining the network with source data only
            data1 = Variable(data1)
            data2 = Variable(data2)
            target1 = Variable(target1)
            target2 = Variable(target2)

            # First update: source classification + source norm enlargement.
            opt_g.zero_grad()
            opt_f.zero_grad()

            s_bottleneck = G(data1)
            s_fc2_emb, s_logit = C(s_bottleneck)

            s_cls_loss = F.nll_loss(F.log_softmax(s_logit, dim=1), target1)
            s_fc2_L2norm_loss = get_L2norm_loss_self_driven(s_fc2_emb, radius_s)

            loss = s_cls_loss + s_fc2_L2norm_loss
            loss.backward()

            opt_g.step()
            opt_f.step()

            # Second update: source + target norm enlargement and target entropy.
            opt_g.zero_grad()
            opt_f.zero_grad()

            s_bottleneck = G(data1)
            t_bottleneck = G(data2)

            s_fc2_emb, s_logit = C(s_bottleneck)
            t_fc2_emb, t_logit = C(t_bottleneck)

            s_cls_loss = F.nll_loss(F.log_softmax(s_logit, dim=1), target1)
            s_fc2_L2norm_loss = get_L2norm_loss_self_driven(s_fc2_emb, radius_s)
            t_fc2_L2norm_loss = get_L2norm_loss_self_driven(t_fc2_emb, radius_t)

            t_prob = F.softmax(t_logit, dim=1)
            t_entropy_loss = get_entropy_loss(t_prob)

            loss = s_cls_loss + s_fc2_L2norm_loss + t_fc2_L2norm_loss + t_entropy_loss
            loss.backward()

            opt_g.step()
            opt_f.step()

            # Store this iteration's feature norms to serve as the
            # "previous-iteration" radii for the next batch.
            radius_s = s_fc2_emb.norm(p=2, dim=1).detach()
            radius_t = t_fc2_emb.norm(p=2, dim=1).detach()

            if batch_idx % args.log_interval == 0:
                print('Train Ep: {} [{}/{} ({:.0f}%)]\tLoss_cls: {:.6f}\tLoss_t: {:.6f}'.format(
                    ep, batch_idx * len(data) / 2, len(dataset.data_loader_A.dataset),
                    100. * batch_idx * len(data) / 2 / len(dataset.data_loader_A.dataset),
                    s_cls_loss.item(), t_fc2_L2norm_loss.item()))

        torch.save(G.state_dict(), os.path.join("Office31_IAFN_" + "_netG_" + '.' + '_' + str(ep) + ".pth"))
        torch.save(C.state_dict(), os.path.join("Office31_IAFN_" + "_netF_" + '.' + '_' + str(ep) + ".pth"))
        test(ep)
#########################################################################

Thank you so much for your patience and help.

jihanyang commented on July 29, 2024

Hello,

Thanks for your interest in our work. According to the equation in our paper, the value of this loss depends only on the hyperparameter \Delta r. Since we use \Delta r = 1.0, the loss value is always 1 and does not change over the course of training. Although the loss value stays constant, its gradient is still backpropagated to update the model, which is exactly what we want. Please refer to Section 3.5 of our paper for more details. If you still have any questions, please feel free to contact us.
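
To make the point concrete, here is a minimal sketch (my own illustration, not the repository's exact code) of a stepwise norm loss in which the target radius is the detached norm of the same batch plus \Delta r: the loss value is always \Delta r^2, yet the gradient with respect to the features is non-zero and enlarges the norms.

import torch

def stepwise_norm_loss(feat, delta_r=1.0):
    # Hypothetical sketch: detached same-batch norm + delta_r as the target radius.
    norms = feat.norm(p=2, dim=1)          # norms under the current parameters
    target = norms.detach() + delta_r      # constant target: no gradient flows through it
    # The value is always delta_r ** 2, but the gradient w.r.t. `feat` is
    # non-zero and pushes every feature norm upward by delta_r.
    return ((norms - target) ** 2).mean()

feat = torch.randn(8, 256, requires_grad=True)
loss = stepwise_norm_loss(feat)
print(loss.item())                 # 1.0 for delta_r = 1.0, at every iteration
loss.backward()
print(feat.grad.abs().sum() > 0)   # True: constant loss value, useful gradient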

Turlan commented on July 29, 2024

@JingWang18 Actually, in the paper, the updated and updating L2 feature norms refer to the same training examples in the current iteration, so the author's implementation matches the paper's description (see the sketch below).
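
In other words, the two readings differ only in where the detached radius comes from. A hypothetical side-by-side sketch (function and variable names are mine, not from either codebase):

import torch

def loss_same_batch(feat, delta_r=1.0):
    # Radius taken (detached) from the same examples in the same forward pass,
    # roughly as in the author's implementation: the loss value stays constant.
    norms = feat.norm(p=2, dim=1)
    return ((norms - (norms.detach() + delta_r)) ** 2).mean()

def loss_previous_batch(feat, prev_radius):
    # Radius stored from the previous iteration, as in the modified code above:
    # it belongs to different examples and older parameters, so the value varies.
    norms = feat.norm(p=2, dim=1)
    n = min(norms.size(0), prev_radius.size(0))
    return ((norms[:n] - prev_radius[:n]) ** 2).mean()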

tarun005 commented on July 29, 2024

The loss doesn't make sense to me. In the paper, you state that the difference is computed between the feature norm obtained with the last iteration's parameters and the radius obtained with the current parameters. But your code computes them differently, which makes the loss constant, so in the end we are only adding a constant value to the loss. What is the reason your model still improves accuracy?
