brain-research / self-attention-gan
License: Apache License 2.0
When I try to run the line energy = torch.bmm(proj_query, proj_key)
, the program fails with RuntimeError: CUDA out of memory. My graphics card has 12 GB of memory, and I am looking for a way to reduce the size of intermediate variables such as energy,
which in my case is 1 x 65536 x 65536. I've already used torch.no_grad()
and split the intermediate matrices into smaller sub-matrices, then used del
to release the memory, but it doesn't seem to work. Could you please share some tips for this kind of problem? (My batch size is 1, the input size is 256 x 256.)
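For readers hitting the same wall: one common workaround (my suggestion, not something from this repo) is to never materialize the full N x N attention matrix and instead process query rows in chunks, so only a (chunk, N) slice exists at any time. A minimal NumPy sketch of the idea, with toy sizes:

```python
import numpy as np

def chunked_attention(q, k, v, chunk=1024):
    """Compute softmax(q @ k.T) @ v without building the full N x N matrix.

    q, k, v: (N, C) arrays. Only a (chunk, N) slice of the attention
    matrix is held in memory at any time.
    """
    n = q.shape[0]
    out = np.empty_like(v)
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        energy = q[start:stop] @ k.T                  # (chunk, N) slice only
        energy -= energy.max(axis=1, keepdims=True)   # numerical stability
        attn = np.exp(energy)
        attn /= attn.sum(axis=1, keepdims=True)
        out[start:stop] = attn @ v
    return out

# Tiny toy check against the unchunked computation.
rng = np.random.default_rng(0)
q = rng.standard_normal((64, 8))
k = rng.standard_normal((64, 8))
v = rng.standard_normal((64, 8))
energy = q @ k.T
full = np.exp(energy - energy.max(axis=1, keepdims=True))
full /= full.sum(axis=1, keepdims=True)
assert np.allclose(chunked_attention(q, k, v, chunk=16), full @ v)
```

The same pattern applies in PyTorch (chunking the first operand of `torch.bmm` and concatenating the results), trading one large allocation for several small ones.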
It seems that you create the variable 'reuse_vars' in the build_model_single_gpu function. However, I do not see where this variable is used to reuse variables across your multiple GPUs. Could you please check that? Thank you so much!
Hello,
Will the use of Conditional Batch Normalization in the generator cancel Spectral Normalization?
Hello,
Which parameters do I need to change to make this train and evaluate on one GPU?
I am currently getting an OOM (Resource Exhausted) error when I try to train on a single GTX 1080.
I tried setting num_towers=1 in train_imagenet.py but this did not help.
I cannot find any code in this package related to "self-attention". Instead, I find a lot of code about DCGAN, which puzzles me. It's quite disappointing!
Has anybody coded to visualize attention maps for a query point as shown in the paper?
Can training be distributed across multiple machines, e.g. 4 machines with 4 GPUs each?
```python
import tensorflow as tf

def non_local_block(x):
    # Only the channel count has to be static here; the batch dimension may
    # be None at graph-construction time, so the reshapes below use -1.
    _, height, width, in_channels = x.get_shape().as_list()
    hw = height * width

    # Feature gating before the embeddings (as in the original snippet).
    g1 = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(x)
    g1 = tf.math.multiply(g1, x)

    # theta (query), phi (key) and g (value) embeddings via 1x1 convolutions.
    g = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)
    phi = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)
    theta = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)

    # Flatten the spatial dimensions: (batch, h*w, channels).
    g_x = tf.reshape(g, [-1, hw, in_channels])
    phi_x = tf.reshape(phi, [-1, hw, in_channels])
    theta_x = tf.reshape(theta, [-1, hw, in_channels])

    # Pairwise similarities between all spatial locations: (batch, h*w, h*w),
    # normalized once with a softmax over the last axis.
    f = tf.matmul(theta_x, phi_x, transpose_b=True)
    f = tf.nn.softmax(f, axis=-1)

    # Attention-weighted sum of the values, then restore the spatial layout.
    y = tf.matmul(f, g_x)
    y = tf.reshape(y, [-1, height, width, in_channels])
    return y
```
I have implemented this non-local attention block as shown in the code above, but the problem is that when I use it in a network the batch size is always None, so the matrix multiplications and reshapes raise errors.
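A common fix for the None-batch problem (a general TensorFlow idiom, not something specific to this repo) is to never hard-code the batch size in reshapes and let it be inferred with -1. The same trick in plain NumPy terms:

```python
import numpy as np

def flatten_spatial(x):
    """(batch, h, w, c) -> (batch, h*w, c) without naming the batch size."""
    _, h, w, c = x.shape
    # -1 lets the first dimension be inferred, mirroring tf.reshape on a
    # tensor whose batch dimension is None at graph-build time.
    return x.reshape(-1, h * w, c)

x = np.zeros((3, 4, 5, 8))
assert flatten_spatial(x).shape == (3, 20, 8)
```

Batched matmuls (`tf.matmul` on 3-D tensors) then broadcast over whatever the batch dimension turns out to be at run time, so no squeeze of axis 0 is needed.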
Hello and thank you for the repository!
Training the model on the ImageNet dataset will take a lot of time. Could someone upload a pre-trained model?
Hi,
I couldn't find the implementation of the attention layer inside the network models. In the SAGAN paper it is mentioned that they added the self-attention mechanism at different stages and compared the variants with each other. Could you please let me know where that is handled?
Bests,
Samaneh
Hello,
Is the GAN trained with a fixed learning rate?
The discriminator LR: 0.0004
The generator LR: 0.0001
Are these learning rates decayed? If so, where may I find the implementation?
There are a couple of max_pooling2d()
layers inside the attention layer sn_non_local_block_sim() which reduce the number of local features by a factor of 4, as in downsampled_num = location_num // 4. However, no downsampling step is reported in the original paper.
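To make the effect of that pooling concrete: a 2x2 max-pool halves each spatial dimension, so the number of key/value locations drops by 4 and the attention matrix shrinks from N x N to N x (N/4). A quick back-of-the-envelope check (feature-map size assumed for illustration):

```python
# Attention-matrix sizes with and without the //4 downsampling of keys/values.
location_num = 64 * 64                # N: one query per spatial location (assumed 64x64 map)
downsampled_num = location_num // 4   # after the 2x2 max-pool
full = location_num * location_num    # entries without pooling
pooled = location_num * downsampled_num
print(full, pooled, full // pooled)   # pooling cuts the entry count by 4x
```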
Also, the first two sn_conv1x1()
layers, which stand for Wg
and Wf
in the paper, have equal sizes C/8 x C,
but the third one, standing for Wh,
has shape C/2 x C,
while it should also be C/8 x C. Similarly for the last conv layer.
Is there a reason for these discrepancies?
Related #8
All generator and discriminator types implemented here are made of either block()
or block_no_sn()
modules, which either way internally contain a residual connection x_0 + x
by default. However, in the associated paper, residual vs. attentional blocks are compared as if the two architectures were exclusive, one or the other. So, does the attentional architecture reported in the paper also include residual blocks, or does this implementation not fully follow the reported architectures?
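For context, the SAGAN paper itself adds the attention output back to the input through a learned scalar, y = gamma * o(x) + x, with gamma initialized to 0, so attention and residual connections are not mutually exclusive there. A minimal NumPy sketch of that residual form (the attention output is just a placeholder array here):

```python
import numpy as np

def attention_residual(x, attention_out, gamma=0.0):
    """y = gamma * o(x) + x, the residual form SAGAN wraps around attention.

    gamma starts at 0, so the block is initially an identity mapping and
    the network can learn to rely on attention gradually during training.
    """
    return gamma * attention_out + x

x = np.ones((2, 3))
o = np.full((2, 3), 5.0)
assert np.array_equal(attention_residual(x, o, gamma=0.0), x)  # identity at init
```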
Thanks.
I have trained the model using your code and I have the checkpoints saved (I ran it for 300 epochs). How can I evaluate on image files (.png) that I have in ./data folder?
Thanks in advance.
Hello,
Should the attention map be transposed? I can't see that in the papers! Also, I think you should use dim=0 in softmax.
self-attention-gan/non_local.py
Line 77 in ad9612e
Is there a mistake in the reshape operation linked above? Shouldn't it be

```python
attn_g = tf.reshape(attn_g, [batch_size, h // 2, w // 2, num_channels // 2])
attn_g = tf.depth_to_space(attn_g, [batch_size, h, w, num_channels // 8])
attn_g = sn_conv1x1(attn_g, num_channels, update_collection, init, 'sn_conv_attn')
```

instead of attn_g = tf.reshape(attn_g, [batch_size, h, w, num_channels // 2])
?
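Whether or not line 77 is a bug, the two candidate operations arrange pixels differently: a plain reshape keeps the channel count at C/2, while depth_to_space with block size 2 moves a factor of 4 from channels into the spatial dimensions. A NumPy equivalent of that rearrangement (my sketch, not the repo's code) makes the shape bookkeeping explicit:

```python
import numpy as np

def depth_to_space(x, block=2):
    """NumPy equivalent of tf.depth_to_space for NHWC tensors."""
    b, h, w, c = x.shape
    assert c % (block * block) == 0
    x = x.reshape(b, h, w, block, block, c // (block * block))
    x = x.transpose(0, 1, 3, 2, 4, 5)   # interleave blocks into rows/cols
    return x.reshape(b, h * block, w * block, c // (block * block))

x = np.arange(2 * 2 * 2 * 8).reshape(2, 2, 2, 8)
y = depth_to_space(x, block=2)
assert y.shape == (2, 4, 4, 2)          # (h, w) doubled, channels / 4
```

Note that the real `tf.depth_to_space` takes a scalar block_size as its second argument, not a shape list.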
Hi, this self-attention block is different from the original paper!
In the paper's Figure 2 there is no max-pool.
Why do you use it? And is it good or bad for performance?
The generator and discriminator each come in two types, test and baseline.
What do they do?
The function _extract_image_and_label appears twice in utils_ori.py. Is there a specific reason for this, or is it just there by mistake?
The evaluation script seems to be stuck at the line "Creating CudaSolver handles for stream 0x1f1ffea0" for more than a day.
I trained a model on my own dataset and saved the weights.
The eval script hasn't started generating any samples yet.
Hi,
I am having trouble understanding the right way to visualize the attention maps. Let's say the attention block is in the last layer and the image has w = 128, h = 128; that means the attention map has dimensions N x N with N = w * h.
If I want to visualize the attention map for the midpoint, for example, which part of the attention map should I access?
The only idea I had was to take either the row or the column n:
attention_map[:,n] or attention_map[n,:]
Could you explain how to correctly access the attention map for a specific point?
Thanks in advance
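For what it's worth, under the usual row-normalized convention (softmax over the last axis, which is an assumption about this repo's code), row n holds the weights that query point n places on every location, so attention_map[n, :] reshaped back to (h, w) is the heat map for that point. A NumPy sketch with toy sizes:

```python
import numpy as np

h, w = 4, 4
n_loc = h * w
rng = np.random.default_rng(0)

# Row-normalized attention: attn[i, j] = weight query i puts on location j.
attn = rng.random((n_loc, n_loc))
attn /= attn.sum(axis=1, keepdims=True)

# Flattened index of the query point (y, x), then its heat map over the image.
y, x = h // 2, w // 2
n = y * w + x
heat = attn[n, :].reshape(h, w)   # weights this query assigns to all locations
assert np.isclose(heat.sum(), 1.0)
```

The column attention_map[:, n] instead answers the opposite question: how much every other query attends *to* point n.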
Hi,
I was wondering if you could upload the trained weights of the generator that I can directly use for running the test code.
Thanks in advance.