
Comments (4)

zbw0329 commented on June 25, 2024

What are your settings for (g_alpha, g_num_mix, g_prob, r_beta, r_prob, r_num_mix, r_decay)?

from un-mix.

szq0214 commented on June 25, 2024

Hi @zbw0329, thanks very much for your interest in this work! It has been accepted at AAAI 2022, and we just updated the camera-ready version on arXiv. For the settings on SimCLR: we use 1.0 for the beta distribution parameter (used to sample lam) of both g and r, 0.5 as the probability of choosing g or r in each iteration, and 2 as the number of images mixed for both g and r.
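As a rough illustration, the settings above could be collected in a small config object like the one below. This is only a sketch: the `cfg` container and the `num_mix` field name are assumptions made for illustration (only `cfg.beta` and `cfg.prob` appear in the snippet in this thread), and g and r are taken to be the blending branch and the `rand_bbox` branch of that snippet, respectively.

```python
from types import SimpleNamespace

import numpy as np

# Hypothetical config container; field names other than beta/prob are assumptions.
cfg = SimpleNamespace(
    beta=1.0,    # Beta(beta, beta) parameter used to sample lam, same for g and r
    prob=0.5,    # probability of taking the g branch instead of the r branch
    num_mix=2,   # number of images mixed together, for both g and r
)

rng = np.random.default_rng(0)

# With beta = 1.0, Beta(1, 1) is the uniform distribution on [0, 1].
lam = rng.beta(cfg.beta, cfg.beta)

# One coin flip per iteration decides which mixing branch is used.
use_global = rng.random() < cfg.prob
```

With `prob=0.5`, each branch is chosen in about half of the iterations.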

I'm preparing the code and will push it to this repo within a few days.

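In the meantime, you can insert the following code inside the training loop of each epoch (the rand_bbox helper shown at the end must be defined before the loop runs):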

      r = np.random.rand(1)
      # generate mixed sample
      cfg.beta = 1.0
      lam = np.random.beta(cfg.beta, cfg.beta)
      images_reverse = torch.flip(samples[0], (0,))
      if r < cfg.prob:
          mixed_images = lam * samples[0] + (1 - lam) * images_reverse
          mixed_images_flip = torch.flip(mixed_images, (0,))
      else:
          mixed_images = samples[0].clone()
          bbx1, bby1, bbx2, bby2 = rand_bbox(samples[0].size(), lam)
          mixed_images[:, :, bbx1:bbx2, bby1:bby2] = images_reverse[:, :, bbx1:bbx2, bby1:bby2]
          mixed_images_flip = torch.flip(mixed_images, (0,))
          # adjust lambda to exactly match pixel ratio
          lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (samples[0].size()[-1] * samples[0].size()[-2]))

      optimizer.zero_grad()
      loss_o = model(samples)
      loss_m1 = model([samples[1], mixed_images])
      loss_m2 = model([samples[1], mixed_images_flip])
      loss = loss_o + lam*loss_m1 + (1-lam)*loss_m2
      loss.backward()

      # function of rand_bbox for r
      def rand_bbox(size, lam):
          W = size[2]
          H = size[3]
          cut_rat = np.sqrt(1. - lam)
          # np.int was removed in NumPy 1.24; the built-in int does the same here
          cut_w = int(W * cut_rat)
          cut_h = int(H * cut_rat)
      
          # uniform
          cx = np.random.randint(W)
          cy = np.random.randint(H)
      
          bbx1 = np.clip(cx - cut_w // 2, 0, W)
          bby1 = np.clip(cy - cut_h // 2, 0, H)
          bbx2 = np.clip(cx + cut_w // 2, 0, W)
          bby2 = np.clip(cy + cut_h // 2, 0, H)
      
          return bbx1, bby1, bbx2, bby2
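To see why lam is recomputed after the box is drawn, here is a small NumPy-only sketch of the region branch. It is not the repo's code: the explicit `rng` argument, the array shapes, and the variable names are assumptions chosen so the example runs on its own. Because the box is clipped at the image border, its area can be smaller than the sampled (1 - lam) fraction, so lam is adjusted to the actual pixel ratio.

```python
import numpy as np


def rand_bbox(size, lam, rng):
    """Sample a box covering at most (1 - lam) of the image, clipped to the borders."""
    dim2, dim3 = size[2], size[3]                 # the two spatial dimensions
    cut_rat = np.sqrt(1.0 - lam)
    cut2, cut3 = int(dim2 * cut_rat), int(dim3 * cut_rat)
    c2, c3 = int(rng.integers(dim2)), int(rng.integers(dim3))  # uniform box centre
    lo2 = int(np.clip(c2 - cut2 // 2, 0, dim2))
    lo3 = int(np.clip(c3 - cut3 // 2, 0, dim3))
    hi2 = int(np.clip(c2 + cut2 // 2, 0, dim2))
    hi3 = int(np.clip(c3 + cut3 // 2, 0, dim3))
    return lo2, lo3, hi2, hi3


rng = np.random.default_rng(0)
images = rng.random((4, 3, 32, 32))               # toy batch in N, C, H, W layout
images_reverse = images[::-1]                     # batch in reversed order

lam = 0.75
lo2, lo3, hi2, hi3 = rand_bbox(images.shape, lam, rng)

# Paste the box from the reversed batch into a copy of the original batch.
mixed_images = images.copy()
mixed_images[:, :, lo2:hi2, lo3:hi3] = images_reverse[:, :, lo2:hi2, lo3:hi3]

# Adjust lam so that it matches the fraction of pixels kept from the original image.
lam_adjusted = 1 - (hi2 - lo2) * (hi3 - lo3) / (images.shape[-1] * images.shape[-2])
```

Here `lam_adjusted` is never smaller than the sampled `lam`, since clipping can only shrink the box.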


szq0214 commented on June 25, 2024

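@zbw0329, this is a simple but somewhat costly implementation. The forward pass of samples[1] can be reused across the three loss terms, and the output for mixed_images_flip can be obtained by flipping the output of mixed_images along the batch dimension instead of running another forward pass.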
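The second point can be checked on a toy example. The sketch below uses a hypothetical per-sample linear `encode` function as a stand-in for the real backbone (an assumption, not the repo's model) and shows that encoding the flipped batch gives the same features as flipping the encoded batch. This holds when each sample is processed independently of its position in the batch; the `model` in the snippet above returns a loss, so in practice the reuse would have to be implemented inside the model.

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.standard_normal((3 * 8 * 8, 16))    # toy projection matrix


def encode(batch):
    """Toy per-sample encoder: flatten each image and project it to 16 features."""
    return batch.reshape(batch.shape[0], -1) @ weights


mixed_images = rng.standard_normal((4, 3, 8, 8))
mixed_images_flip = mixed_images[::-1]            # same batch, reversed order

features = encode(mixed_images)                   # one forward pass
features_flip_direct = encode(mixed_images_flip)  # the extra pass we want to avoid
features_flip_reused = features[::-1]             # reuse: flip the outputs instead

same = np.allclose(features_flip_direct, features_flip_reused)
```

Under these assumptions `same` is true, so the second forward pass on the mixed batch is redundant.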


zbw0329 commented on June 25, 2024

OK, thanks a lot!

