
Comments (3)

HenryJia commented on August 30, 2024

Can you share your code? It's a bit hard to debug without seeing it.


aubreychen9012 commented on August 30, 2024

The only part I changed is train_all in infoGAN.py: I commented out the tqdm and plt calls, since I was not running in a notebook and the remote server does not allow plt to open a display.
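
(Aside: if plotting is still wanted on a headless server, matplotlib can be switched to its non-interactive Agg backend so figures are written to disk instead of displayed. A minimal sketch, not part of the original code:)

#############################

# Hypothetical sketch: headless plotting via matplotlib's Agg backend.
# The backend must be selected before pyplot is imported.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 8))
plt.plot([1, 2, 3], [1.0, 0.7, 0.5])     # e.g. a loss value per epoch (dummy numbers)
fig.savefig('./samples/loss_curve.png')  # written to disk, no display needed
plt.close(fig)

#############################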

The code now becomes:
#############################

def train_all(self, train_loader):

    nll = nn.NLLLoss().cuda()
    mse = nn.MSELoss().cuda()
    bce = nn.BCELoss().cuda()

    # plt.figure(0, figsize = (32, 32))
    for epoch in range(100):
        # pb = tqdm_notebook(total = train_loader.dataset.data_tensor.size()[0])
        # pb.set_description('Epoch ' + str(epoch + 1))
        for i, (data, targets) in enumerate(train_loader):
            ones = Variable(torch.ones(data.size()[0], 1)).cuda()
            zeros = Variable(torch.zeros(data.size()[0], 1)).cuda()

            z_dict = self.get_z(data.size()[0])
            z = torch.cat([z_dict[k] for k in z_dict.keys()], dim=1)

            data = Variable(data.float().cuda(async=True)) / 255.0
            targets = Variable(targets.float().cuda(async=True))

            # Forward pass on real MNIST
            out_dis, hid = self.dis(data)
            c1 = F.log_softmax(self.Q_cat(hid))
            loss_dis = mse(out_dis, ones) - torch.sum(targets * c1) / (torch.sum(targets) + 1e-3)  # Loss for real MNIST

            # Forward pass on generated MNIST
            out_gen = self.gen(z)
            out_dis, hid = self.dis(out_gen)

            # Loss for generated MNIST
            loss_dis = loss_dis + mse(out_dis, zeros)

            # Zero gradient buffers for dis and Q_cat and backward pass
            self.dis.zero_grad()
            self.Q_cat.zero_grad()
            loss_dis.backward(retain_graph=True)  # We need PyTorch to retain the graph buffers so we can run backward again later
            self.d_optim.step()  # Apply the discriminator's update now since we have to delete its gradients later

            # Backward pass and loss for generator and update
            self.gen.zero_grad()
            loss_gen = mse(out_dis, ones)
            loss_gen.backward(retain_graph=True)
            self.dis.zero_grad()  # Don't want the gradients of the generator's objective in the discriminator

            # Forward pass and loss for latent codes
            c1 = F.log_softmax(self.Q_cat(hid))
            loss_q = nll(c1, torch.max(z_dict['cat'], dim=1)[1])

            if self.c2_len:
                c2 = self.Q_con(hid)
                loss_q += 0.5 * mse(c2, z_dict['con'])  # Multiply by 0.5 as we treat targets as Gaussian (and there's a coefficient of 0.5 when we take logs)
                self.Q_con.zero_grad()  # Zero gradient buffers before the backward pass
            if self.c3_len:
                c3 = F.sigmoid(self.Q_bin(hid))
                loss_q += bce(c3, z_dict['bin'])
                self.Q_bin.zero_grad()  # Zero gradient buffers before the backward pass

            # Backward pass for latent code objective
            loss_q.backward()

            # Do the updates for everything
            self.d_optim.step()
            self.g_optim.step()
            self.qcat_optim.step()

            if self.c2_len:
                self.qcon_optim.step()
            if self.c3_len:
                self.qbin_optim.step()

            # pb.update(data.size()[0])
            # pb.set_postfix(loss_dis = loss_dis.cpu().data.numpy()[0], loss_gen = loss_gen.cpu().data.numpy()[0], loss_q = loss_q.cpu().data.numpy()[0])

        print("epoch %d discriminator loss for generated MNIST %f") % (epoch, loss_dis.cpu().data.numpy()[0])
        print("epoch %d generation loss %f") % (epoch, loss_gen.cpu().data.numpy()[0])
        print("epoch %d latent loss %f") % (epoch, loss_q.cpu().data.numpy()[0])
        # pb.close()
        # plt.subplot(10, 10, epoch + 1)
        # bp()
        x = np.squeeze(np.transpose(out_gen.cpu().data.numpy(), (0, 2, 3, 1)))
        true_x = np.squeeze(np.transpose(data.cpu().data.numpy(), (0, 2, 3, 1)))
        merged_x = merge(x[:64], [8, 8])
        merged_true = merge(true_x[:64], [8, 8])
        scipy.misc.imsave("./samples/gn_" + str(epoch) + ".png", merged_x)
        scipy.misc.imsave("./samples/gr_" + str(epoch) + ".png", merged_true)

#############################

I am using torch 0.2.0_4 and Python 2.7. Thanks for your reply.


HenryJia commented on August 30, 2024

Sorry I'm taking a while; I've been busy. I wrote all of my code for Python 3, so it may be an incompatibility issue, but otherwise I'm unable to help. InfoGAN is fairly complex, so I suggest you read up more on PyTorch first.

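For reference, a few spots in the pasted train_all do depend on Python 2 / torch 0.2 behaviour, which is consistent with the incompatibility explanation above. A minimal sketch of the Python 3 equivalents, assuming a newer PyTorch (>= 0.4) rather than the version used in the thread:

#############################

import torch

# print is a function in Python 3, so the % formatting must go inside the call;
# print("fmt") % args applies % to print's return value (None) and raises a TypeError.
epoch, loss_dis = 0, torch.tensor(0.25)  # stand-ins for the values inside train_all
print("epoch %d discriminator loss for generated MNIST %f" % (epoch, loss_dis.item()))

# `async` became a reserved keyword in Python 3.7, and newer torch spells the
# argument non_blocking; .item() replaces loss.cpu().data.numpy()[0] for
# extracting a Python float from a 0-d loss tensor (torch >= 0.4).
if torch.cuda.is_available():
    data = torch.randn(4, 1, 28, 28).cuda(non_blocking=True) / 255.0

#############################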
