How can I change the shape of a variable in TensorFlow?
The tf.Variable class is the recommended way to create variables, but it restricts your ability to change a variable's shape after it has been created.
If you need to change the shape of a variable, you can do the following (e.g. for a 32-bit floating-point tensor):
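    import tensorflow as tf  # TensorFlow 1.x API

    # validate_shape=False is needed because the placeholder has no static shape;
    # the placeholder must be fed when the variable's initializer is run.
    var = tf.Variable(tf.placeholder(tf.float32), validate_shape=False)
    # ...
    new_value = ...  # Tensor or numpy array.
    change_shape_op = tf.assign(var, new_value, validate_shape=False)
    # ...
    sess.run(change_shape_op)  # Changes the shape of `var` to new_value's shape.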
Note that this feature is not in the documented public API, so it is subject to change. If you do find yourself needing to use this feature, let us know, and we can investigate a way to support it moving forward.
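On TensorFlow 2.x, where sessions and tf.assign are no longer the default API, a comparable effect can be sketched with the shape argument of the tf.Variable constructor: passing tf.TensorShape(None) leaves the variable's static shape unspecified, so later assign calls may change it. This is a minimal sketch assuming TensorFlow 2.x with eager execution.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# shape=tf.TensorShape(None) gives the variable an unspecified static shape,
# so later assign() calls may store values of a different shape.
var = tf.Variable([1.0, 2.0, 3.0], shape=tf.TensorShape(None))
print(tf.shape(var).numpy())  # [3]

# Replace the 3-element vector with a 2x2 matrix.
var.assign([[1.0, 2.0], [3.0, 4.0]])
print(tf.shape(var).numpy())  # [2 2]
```

Because the static shape is unspecified, var.shape reports an unknown shape (its rank is None); tf.shape(var) reads the current dynamic shape instead.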
Take a look at the "Shapes and Shaping" section of the TensorFlow documentation. It describes the different shape transformations available.
The most commonly used function is probably tf.reshape, which is similar to its NumPy equivalent. It lets you specify any shape you want, as long as the total number of elements stays the same. The documentation includes several examples.
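A minimal sketch of the element-count rule, assuming TensorFlow 2.x with eager execution and NumPy installed. The helper try_reshape is hypothetical, written only for this illustration; it relies on eager-mode tf.reshape raising tf.errors.InvalidArgumentError when the counts do not match.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def try_reshape(t, shape):
    """Hypothetical helper: reshape t, or return None if the element counts differ."""
    try:
        return tf.reshape(t, shape)
    except tf.errors.InvalidArgumentError:
        return None


x = tf.range(12)            # 12 elements: 0, 1, ..., 11
y = tf.reshape(x, [3, 4])   # valid: 3 * 4 == 12

# Like NumPy, tf.reshape fills the new shape in row-major (C) order.
print(np.array_equal(y.numpy(), np.arange(12).reshape(3, 4)))  # True

print(try_reshape(x, [5, 2]))  # None: 5 * 2 == 10, not 12
```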
The documentation lists these methods for reshaping:
- reshape (changes the shape while keeping the number of elements)
- squeeze (removes dimensions of size 1 from the shape of a tensor)
- expand_dims (adds dimensions of size 1)

It also lists methods to get the shape, size and rank of your tensor (tf.shape, tf.size, tf.rank). Probably the most used is reshape, and here is a code example with a couple of edge cases involving -1:
    import tensorflow as tf  # TensorFlow 1.x API

    v1 = tf.Variable([
        [1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12]
    ])  # shape [3, 4], 12 elements

    v2 = tf.reshape(v1, [2, 6])
    v3 = tf.reshape(v1, [2, 2, -1])  # -1 is inferred as 12 / (2 * 2) = 3
    v4 = tf.reshape(v1, [-1])        # flattens to shape [12]
    # v5 = tf.reshape(v1, [2, 4, -1]) will fail, because 12 / (2 * 4) is not an integer
    v6 = tf.reshape(v1, [1, 4, 1, 3, 1])
    v6_shape = tf.shape(v6)
    v6_squeezed = tf.squeeze(v6)     # removes every dimension of size 1
    v6_squeezed_shape = tf.shape(v6_squeezed)

    init = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()

    sess = tf.Session()
    sess.run(init)
    a, b, c, d, e, f, g = sess.run([v2, v3, v4, v6, v6_shape, v6_squeezed, v6_squeezed_shape])
    # print all variables to see what is there
    print(e)  # shape of v6: [1 4 1 3 1]
    print(g)  # shape of v6_squeezed: [4 3]
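The example above uses the TensorFlow 1.x session API and calls squeeze without arguments. As a complementary sketch, assuming TensorFlow 2.x with eager execution, the following exercises expand_dims, squeeze with an axis argument, and the tf.shape, tf.size and tf.rank inspection functions:

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

m = tf.reshape(tf.range(12), [3, 4])  # shape [3, 4]

# expand_dims inserts a dimension of size 1 at the given axis.
front = tf.expand_dims(m, axis=0)   # shape [1, 3, 4]
back = tf.expand_dims(m, axis=-1)   # shape [3, 4, 1]

# squeeze removes size-1 dimensions; axis restricts which ones are removed.
padded = tf.reshape(m, [1, 4, 1, 3, 1])
squeezed_all = tf.squeeze(padded)               # shape [4, 3]
squeezed_first = tf.squeeze(padded, axis=[0])   # shape [4, 1, 3, 1]

# Shape inspection helpers.
print(tf.shape(squeezed_first).numpy())  # [4 1 3 1]
print(tf.size(squeezed_first).numpy())   # 12
print(tf.rank(squeezed_first).numpy())   # 4
```

Passing axis to squeeze is useful when some size-1 dimensions are meaningful (for example a batch of one) and should be kept.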