Comments (3)
Can you share your code? It's a bit hard to debug without seeing it.
from infogan.
The only part I changed is train_all in infoGAN.py, where I commented out the tqdm and plt calls, because I am not running in a notebook and the remote server cannot display plt figures.
The code is now:
#############################
def train_all(self, train_loader):
    nll = nn.NLLLoss().cuda()
    mse = nn.MSELoss().cuda()
    bce = nn.BCELoss().cuda()
    # plt.figure(0, figsize = (32, 32))
    for epoch in range(100):
        # pb = tqdm_notebook(total = train_loader.dataset.data_tensor.size()[0])
        # pb.set_description('Epoch ' + str(epoch + 1))
        for i, (data, targets) in enumerate(train_loader):
            ones = Variable(torch.ones(data.size()[0], 1)).cuda()
            zeros = Variable(torch.zeros(data.size()[0], 1)).cuda()
            z_dict = self.get_z(data.size()[0])
            z = torch.cat([z_dict[k] for k in z_dict.keys()], dim=1)
            data = Variable(data.float().cuda(async=True)) / 255.0
            targets = Variable(targets.float().cuda(async=True))
            # Forward pass on real MNIST
            out_dis, hid = self.dis(data)
            c1 = F.log_softmax(self.Q_cat(hid))
            loss_dis = mse(out_dis, ones) - torch.sum(targets * c1) / (torch.sum(targets) + 1e-3)  # Loss for real MNIST
            # Forward pass on generated MNIST
            out_gen = self.gen(z)
            out_dis, hid = self.dis(out_gen)
            # Loss for generated MNIST
            loss_dis = loss_dis + mse(out_dis, zeros)
            # loss_dis = loss_dis
            # Zero gradient buffers for gen and Q_cat and backward pass
            self.dis.zero_grad()
            self.Q_cat.zero_grad()
            loss_dis.backward(retain_graph=True)  # We need PyTorch to retain the graph buffers so we can run backward again later
            self.d_optim.step()  # Apply the discriminator's update now since we have to delete its gradients later
            # And backward pass and loss for generator and update
            self.gen.zero_grad()
            loss_gen = mse(out_dis, ones)
            loss_gen.backward(retain_graph=True)
            self.dis.zero_grad()
            # Don't want the gradients of the generator's objective in the discriminator
            # Forward pass and loss for latent codes
            # loss_q = 0
            c1 = F.log_softmax(self.Q_cat(hid))
            loss_q = nll(c1, torch.max(z_dict['cat'], dim=1)[1])
            if self.c2_len:
                c2 = self.Q_con(hid)
                # Multiply by 0.5 as we treat targets as Gaussian (and there's a coefficient of 0.5 when we take logs)
                loss_q += 0.5 * mse(c2, z_dict['con'])
                self.Q_con.zero_grad()  # Zero gradient buffers before the backward pass
            if self.c3_len:
                c3 = F.sigmoid(self.Q_bin(hid))
                loss_q += bce(c3, z_dict['bin'])
                self.Q_bin.zero_grad()  # Zero gradient buffers before the backward pass
            # Backward pass for latent code objective
            loss_q.backward()
            # Do the updates for everything
            self.d_optim.step()
            self.g_optim.step()
            self.qcat_optim.step()
            if self.c2_len:
                self.qcon_optim.step()
            if self.c3_len:
                self.qbin_optim.step()
            # pb.update(data.size()[0])
            # pb.set_postfix(loss_dis = loss_dis.cpu().data.numpy()[0], loss_gen = loss_gen.cpu().data.numpy()[0], loss_q = loss_q.cpu().data.numpy()[0])
        print("epoch %d discriminator loss for generated MNIST %f") % (epoch, loss_dis.cpu().data.numpy()[0])
        print("epoch %d generation loss %f") % (epoch, loss_gen.cpu().data.numpy()[0])
        print("epoch %d latent loss %f") % (epoch, loss_q.cpu().data.numpy()[0])
        # pb.close()
        # plt.subplot(10, 10, epoch + 1)
        # bp()
        x = np.squeeze(np.transpose(out_gen.cpu().data.numpy(), (0, 2, 3, 1)))
        true_x = np.squeeze(np.transpose(data.cpu().data.numpy(), (0, 2, 3, 1)))
        merged_x = merge(x[:64], [8, 8])
        merged_true = merge(true_x[:64], [8, 8])
        # bp()
        scipy.misc.imsave("./samples/gn_" + str(epoch) + ".png", merged_x)
        scipy.misc.imsave("./samples/gr_" + str(epoch) + ".png", merged_true)
#############################
I am using torch 0.2.0_4 and Python 2.7. Thanks for your reply.
Sorry I'm taking a while; I've been busy. I wrote all of my code for Python 3, so it may be an incompatibility issue, but otherwise I'm unable to help. InfoGAN is fairly complex, so I suggest you read up more on PyTorch first.
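As a rough illustration of the kind of Python 2 versus Python 3 difference mentioned here, the sketch below looks at two constructs that appear in the posted train_all: the `print("...") % (...)` lines and the `async=True` keyword argument. It is not a diagnosis of the reported problem, which the thread does not describe. The helper names `format_loss` and `async_is_reserved` are hypothetical and exist only for this example.

```python
from __future__ import print_function  # print becomes a function on Python 2.7 too

import sys


def format_loss(name, epoch, value):
    """Hypothetical helper: apply % to the string itself, then print the result."""
    return "epoch %d %s %f" % (epoch, name, value)


def async_is_reserved():
    """Hypothetical helper: check whether `async` can still be a keyword argument."""
    try:
        compile("f(async=True)", "<sketch>", "eval")
    except SyntaxError:
        return True
    return False


# With print as a function, print(...) returns None, so applying % to its
# result raises TypeError. Under the Python 2 print statement the same line
# works, because % is applied to the string before printing.
try:
    print("epoch %d latent loss %f") % (0, 0.5)
    old_style_failed = False
except TypeError:
    old_style_failed = True

# Formatting inside the call behaves the same on Python 2.7 and Python 3.
print(format_loss("latent loss", 3, 0.25))
print("old-style print raised TypeError:", old_style_failed)
print("`async` is a reserved word on this interpreter:", async_is_reserved())
```

On interpreters where `async` is reserved, a call such as `data.float().cuda(async=True)` cannot even be parsed; later PyTorch releases renamed that argument to `non_blocking`.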