Comments (20)
Your usage is actually not correct. In TensorFlow-based code, the graph-building function should be called only once, unless you intend to reuse the graph. I've modified it for your case. Please use the following code:
import os
import time

import cv2
import numpy as np
import tensorflow as tf

from inpaint_model import InpaintCAModel

# assumes `args` provides image_height, image_width, checkpoint_dir
# and flist, as in test.py
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess = tf.Session(config=sess_config)
model = InpaintCAModel()
input_image_ph = tf.placeholder(
    tf.float32, shape=(1, args.image_height, args.image_width*2, 3))
output = model.build_server_graph(input_image_ph)
output = (output + 1.) * 127.5
output = tf.reverse(output, [-1])
output = tf.saturate_cast(output, tf.uint8)
# load pretrained model
vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = []
for var in vars_list:
    vname = var.name
    from_name = vname
    var_value = tf.contrib.framework.load_variable(
        args.checkpoint_dir, from_name)
    assign_ops.append(tf.assign(var, var_value))
sess.run(assign_ops)
print('Model loaded.')
with open(args.flist, 'r') as f:
    lines = f.read().splitlines()
t = time.time()
for line in lines:
    image, mask, out = line.split()
    base = os.path.basename(mask)
    image = cv2.imread(image)
    mask = cv2.imread(mask)
    image = cv2.resize(image, (args.image_width, args.image_height))
    mask = cv2.resize(mask, (args.image_width, args.image_height))
    assert image.shape == mask.shape
    h, w, _ = image.shape
    grid = 4
    image = image[:h//grid*grid, :w//grid*grid, :]
    mask = mask[:h//grid*grid, :w//grid*grid, :]
    print('Shape of image: {}'.format(image.shape))
    image = np.expand_dims(image, 0)
    mask = np.expand_dims(mask, 0)
    input_image = np.concatenate([image, mask], axis=2)
    # run inference through the prebuilt graph
    result = sess.run(output, feed_dict={input_image_ph: input_image})
    print('Processed: {}'.format(out))
    cv2.imwrite(out, result[0][:, :, ::-1])
print('Time total: {}'.format(time.time() - t))
from generative_inpainting.
"We have not found perceptual loss (reconstruction loss on VGG features), style loss (squared Frobenius norm of Gram matrix computed on the VGG features) [21] and total variation (TV) loss bring noticeable improvements for image inpainting in our framework, thus are not used."
You will need to implement the VGG16 perceptual loss yourself.
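To make the quoted losses concrete, here is a minimal NumPy sketch of the math only: a reconstruction (L2) loss on feature maps, and a style loss as the squared Frobenius norm of the Gram-matrix difference. The extract_features function below is a hypothetical stand-in for a real VGG16 feature extractor (e.g. activations from a few conv layers), not the repo's API or the paper's exact formulation.

```python
import numpy as np

def extract_features(image):
    # Hypothetical stand-in for VGG16 activations: average-pool the
    # image in 4x4 patches and treat the result as one feature map.
    h, w, c = image.shape
    feat = image[:h // 4 * 4, :w // 4 * 4, :]
    feat = feat.reshape(h // 4, 4, w // 4, 4, c).mean(axis=(1, 3))
    return feat  # shape (h//4, w//4, c)

def perceptual_loss(pred, target):
    # Reconstruction (L2) loss on feature maps instead of raw pixels.
    fp, ft = extract_features(pred), extract_features(target)
    return float(np.mean((fp - ft) ** 2))

def gram_matrix(feat):
    # Flatten spatial dims; the Gram matrix captures channel correlations.
    h, w, c = feat.shape
    flat = feat.reshape(h * w, c)
    return flat.T @ flat / (h * w)

def style_loss(pred, target):
    # Squared Frobenius norm of the difference of Gram matrices.
    gp = gram_matrix(extract_features(pred))
    gt = gram_matrix(extract_features(target))
    return float(np.sum((gp - gt) ** 2))
```

With a real VGG16, you would sum these terms over several layers and swap the pooling stand-in for the network's activations.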
from generative_inpainting.
Oh, thank you, I have found the answer: just set the parameter reuse=tf.AUTO_REUSE:
output = model.build_server_graph(input_image, reuse=tf.AUTO_REUSE)
TensorFlow will then automatically reuse the existing graph variables.
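For intuition: with reuse=tf.AUTO_REUSE, a second call to the build function returns the variables created by the first call instead of raising an error. Conceptually this is get-or-create semantics, sketched here in plain Python (the VariableStore class is illustrative only, not TensorFlow's implementation):

```python
class VariableStore:
    """Illustrative get-or-create store mimicking AUTO_REUSE semantics."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, initializer):
        # AUTO_REUSE: return the existing variable if the name is taken,
        # otherwise create it. (Without reuse, requesting an existing
        # name a second time would be an error.)
        if name not in self._vars:
            self._vars[name] = initializer()
        return self._vars[name]

store = VariableStore()
a = store.get_variable("conv1/kernel", lambda: [0.0] * 4)
b = store.get_variable("conv1/kernel", lambda: [1.0] * 4)
assert a is b  # the second call reuses; its initializer is ignored
```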
from generative_inpainting.
This code should be added to the master branch!
from generative_inpainting.
One correction to the code above: the build call should be:
output = model.build_server_graph(FLAGS, input_image_ph)
from generative_inpainting.
It would be even more efficient if you build the graph ONCE with a placeholder and feed your images with sess.run. A related issue can be found at #8.
from generative_inpainting.
Hello JiahuiYu,
Thank you for your quick response. Did you mean sess.run?
I'm reading your source code to understand what you have done.
from generative_inpainting.
Sorry, typo.
from generative_inpainting.
Hello JiahuiYu,
Thank you for your response. I'm building the graph.
In the inpaint.yml file, at the # loss legacy line, I found the VGG_MOEL_FILE you have configured. I have read your paper, and it does not mention transfer learning. So I wonder: can we use the VGG16 network for transfer learning?
Thank you for your help.
from generative_inpainting.
Thank you for your fast response.
I have used your pretrained model to apply transfer learning, it saved me a lot of time on a new training set.
I am reading your paper again; I think it's a great paper.
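For anyone following along: the transfer-learning step above amounts to restoring the pretrained checkpoint's variables by name before continuing training on the new dataset, much like the assign_ops loop does at inference time. A hypothetical sketch of the name-matching logic in plain Python (restore_mapping is an illustrative helper, not part of the repo):

```python
def restore_mapping(model_var_names, checkpoint_var_names):
    """Pair each model variable with the checkpoint variable of the same
    name. Variables missing from the checkpoint (e.g. a newly added
    layer) are skipped and keep their fresh initialization."""
    restored, skipped = {}, []
    ckpt = set(checkpoint_var_names)
    for name in model_var_names:
        if name in ckpt:
            restored[name] = name  # assign checkpoint value to model var
        else:
            skipped.append(name)
    return restored, skipped
```

In TensorFlow the restored pairs would then become tf.assign ops (or a Saver restricted to var_list), after which training simply continues from the pretrained weights.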
from generative_inpainting.
Hello JiahuiYu,
Thank you for your awesome code. I have tried to modify and build the graph, but unfortunately I could not get it to build.
I found that you use the build_server_graph function, but I don't understand it well. Could you please add some code showing how you build the graph and feed images into it one by one?
Thank you in advance.
from generative_inpainting.
Here is my code at the moment (using a for loop):
# prepare folder paths
input_folder = args.test_dir + "/input"
mask_folder = args.test_dir + "/mask"
output_folder = (args.test_dir + "/output_" + args.checkpoint_dir.split("/")[1]
                 + "_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
if not os.path.exists(output_folder):
    os.makedirs(output_folder)
# start sess configuration
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
# build the model object once, outside the loop
model = InpaintCAModel()
dir_files = os.listdir(input_folder)
dir_files.sort()
for file_inter in dir_files:
    sess = tf.Session(config=sess_config)
    base_file_name = os.path.basename(file_inter)
    image = cv2.imread(input_folder + "/" + base_file_name)
    mask = cv2.imread(mask_folder + "/" + base_file_name)
    assert image.shape == mask.shape
    h, w, _ = image.shape
    grid = 1
    image = image[:h//grid*grid, :w//grid*grid, :]
    mask = mask[:h//grid*grid, :w//grid*grid, :]
    print('Shape of image: {}'.format(image.shape))
    image = np.expand_dims(image, 0)
    mask = np.expand_dims(mask, 0)
    input_image = np.concatenate([image, mask], axis=2)
    input_image = tf.constant(input_image, dtype=tf.float32)
    output = model.build_server_graph(input_image, reuse=tf.AUTO_REUSE)
    output = (output + 1.) * 127.5
    output = tf.reverse(output, [-1])
    output = tf.saturate_cast(output, tf.uint8)
    # load pretrained model
    vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = []
    for var in vars_list:
        vname = var.name
        from_name = vname
        var_value = tf.contrib.framework.load_variable(
            args.checkpoint_dir, from_name)
        assign_ops.append(tf.assign(var, var_value))
    sess.run(assign_ops)
    print('Model loaded.')
    result = sess.run(output)
    # write to output folder
    cv2.imwrite(output_folder + "/" + base_file_name, result[0][:, :, ::-1])
    sess.close()
from generative_inpainting.
Hi JiahuiYu,
Thank you very much for your code and your contribution. I am so excited to check it out.
Thank you again!
from generative_inpainting.
Hi JiahuiYu,
Wow, it worked. Thank you very much, you have saved me tons of time!
from generative_inpainting.
No problem. :)
from generative_inpainting.
This code should be added to the master branch!
@Bingmang Is the code added to the for loop of test.py? Thank you
from generative_inpainting.
I have made this thread open so others can have a reference.
from generative_inpainting.
@TrinhQuocNguyen Thank you very much for your discussions about training a new model!
And could you give me more instructions on pre-training a model with transfer learning? Thanks a lot!
from generative_inpainting.
great!
from generative_inpainting.
Hey, I've been trying for days to customize some parts. Can you explain how to access the model and run model.summary()?
from generative_inpainting.