When I try to train the CVAE on data I generated myself, I get the error posted below.
Note that I'm running the code in Python 3.7, which is not officially supported. However, I got the same error when I ran the code locally in Python 3.6. That local machine has no graphics card, so I was running on the CPU, which is not supported for training, if I remember correctly.
As it is a reshaping error, I guess it is caused by the data format I'm using, i.e. the training data isn't quite in the correct shape. I deduced that your code expects the training data to have shape (number of training samples, number of detectors, number of samples per time series). My custom training data contains only the keys ['rand_pars', 'snrs', 'x_data', 'y_data_noisefree', 'y_data_noisy', 'y_normscale'] with respective shapes [(9,), (1000, 3), (1000, 1, 9), (1000, 3, 256), (1000, 3, 256), ()].
For the test data, I deduced that the code expects single samples and thus data of shape (number of detectors, number of samples per time series). My test set therefore contains the same keys as the training set, with respective shapes [(9,), (3,), (1, 9), (3, 256), (3, 256), ()].
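To make the shape description above easier to verify, here is a minimal sketch of a shape check for the training file. It assumes the data has been loaded into a dict of NumPy arrays; the helper `check_training_shapes` and the constants are hypothetical, and the expected shapes are only the ones I deduced above, not something confirmed by the code base.

```python
import numpy as np

# Trailing dimensions as deduced in this post (assumptions, not confirmed):
N_DET, N_SAMP, N_PAR = 3, 256, 9  # detectors, samples per time series, parameters


def check_training_shapes(data):
    """Return a list of shape mismatches; an empty list means all keys match."""
    n = np.shape(data["x_data"])[0]  # number of training samples
    expected = {
        "rand_pars": (N_PAR,),
        "snrs": (n, N_DET),
        "x_data": (n, 1, N_PAR),
        "y_data_noisefree": (n, N_DET, N_SAMP),
        "y_data_noisy": (n, N_DET, N_SAMP),
        "y_normscale": (),
    }
    problems = []
    for key, shape in expected.items():
        if key not in data:
            problems.append(f"missing key {key}")
            continue
        actual = np.shape(data[key])
        if actual != shape:
            problems.append(f"{key}: expected {shape}, got {actual}")
    return problems


# Dummy data with the shapes listed above (values are placeholders).
data = {
    "rand_pars": np.zeros(9),
    "snrs": np.zeros((1000, 3)),
    "x_data": np.zeros((1000, 1, 9)),
    "y_data_noisefree": np.zeros((1000, 3, 256)),
    "y_data_noisy": np.zeros((1000, 3, 256)),
    "y_normscale": np.float64(1.0),
}
print(check_training_shapes(data))
```

With the shapes listed above the check returns an empty list, so if the error comes from the data, it would have to be from a shape expectation that differs from what I deduced.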
WARNING:tensorflow:AutoGraph could not transform <function train.<locals>.truncnorm at 0x14c61ddb7c80> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Bad argument number for Name: 4, expecting 3
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
... Training Inference Model
Traceback (most recent call last):
File "/work/marlin.schaefer/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1365, in _do_call
return fn(*args)
File "/work/marlin.schaefer/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1350, in _run_fn
target_list, run_metadata)
File "/work/marlin.schaefer/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1443, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 128 values, but the requested shape has 64
[[{{node Reshape}}]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/work/marlin.schaefer/projects/collab_glasgow/vitamin_b/vitamin_b/run_vitamin.py", line 1471, in <module>
train(params,bounds,fixed_vals)
File "/work/marlin.schaefer/projects/collab_glasgow/vitamin_b/vitamin_b/run_vitamin.py", line 888, in train
XS_all,snrs_test)
File "/work/marlin.schaefer/projects/collab_glasgow/vitamin_b/vitamin_b/models/CVAE_model.py", line 734, in train
session.run(minimize, feed_dict={bs_ph:bs, x_ph:next_x_data, y_ph:next_y_data, ramp:rmp})
File "/work/marlin.schaefer/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 958, in run
run_metadata_ptr)
File "/work/marlin.schaefer/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1181, in _run
feed_dict_tensor, options, run_metadata)
File "/work/marlin.schaefer/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1359, in _do_run
run_metadata)
File "/work/marlin.schaefer/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1384, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 128 values, but the requested shape has 64
[[node Reshape (defined at /projects/collab_glasgow/vitamin_b/vitamin_b/models/CVAE_model.py:618) ]]
Original stack trace for 'Reshape':
File "/projects/collab_glasgow/vitamin_b/vitamin_b/run_vitamin.py", line 1471, in <module>
train(params,bounds,fixed_vals)
File "/projects/collab_glasgow/vitamin_b/vitamin_b/run_vitamin.py", line 888, in train
XS_all,snrs_test)
File "/projects/collab_glasgow/vitamin_b/vitamin_b/models/CVAE_model.py", line 618, in train
con = tf.reshape(tf.math.reciprocal(temp_var_r2_sky),[bs_ph]) # modelling wrapped scale output as log variance - only 1 concentration parameter for all sky
File "/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py", line 193, in reshape
result = gen_array_ops.reshape(tensor, shape, name)
File "/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 8087, in reshape
"Reshape", tensor=tensor, shape=shape, name=name)
File "/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 744, in _apply_op_helper
attrs=attr_protos, op_def=op_def)
File "/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3327, in _create_op_internal
op_def=op_def)
File "/envs/vitamin3/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1791, in __init__
self._traceback = tf_stack.extract_stack()
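The failing node reshapes its input to `[bs_ph]`, so the requested size of 64 is presumably the batch size, while the input holds 128 values, i.e. two values per sample. The sketch below only reproduces that size mismatch with NumPy; the variable names are hypothetical, and the factor of two is an assumption read off the error message, not something taken from `CVAE_model.py`.

```python
import numpy as np

batch_size = 64        # assumed: the "requested shape has 64" in the error
values_per_sample = 2  # assumed: would explain the 128 values in the input

# Stand-in for the tensor passed to tf.reshape(..., [bs_ph]).
temp = np.ones((batch_size, values_per_sample))

try:
    np.reshape(temp, (batch_size,))  # 128 values cannot fill a shape of 64
    reshape_ok = True
except ValueError as err:
    reshape_ok = False
    print(err)
```

If this reading is right, the reshape at line 618 would only succeed when the tensor feeding it has exactly one value per sample.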