TensorFlow: How does tf.get_variable work?

If you define a variable with a name that has been defined before, then TensorFlow throws an exception. Hence, it is convenient to use the tf.get_variable() function instead of tf.Variable(). The function tf.get_variable() returns the existing variable with the same name if it exists, and creates the variable with the specified shape and initializer if it does not exist.


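tf.get_variable(name) creates a new variable called name in the TensorFlow graph. If a node with that name already exists in the current scope, a suffix such as _1 is appended so that the new node's name stays unique. (If that existing node was itself created with tf.get_variable and reuse is not enabled, you get a ValueError instead.)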

In your example, you're creating a Python variable called var1.

Its name in the TensorFlow graph, however, is not var1 but Variable:0.

Every node you define has its own name. You can either specify it yourself or leave it to TensorFlow, which then assigns a default (and always unique) one. You can see that name through the name property of the Python variable (i.e. print(var1.name)).

On your second line, you're defining a Python variable var2 whose name in the TensorFlow graph is var1.

The script

import tensorflow as tf  # TensorFlow 1.x (graph mode) API

var1 = tf.Variable(3., dtype=tf.float64)
print(var1.name)
var2 = tf.get_variable("var1", [], dtype=tf.float64)
print(var2.name)

In fact, it prints:

Variable:0
var1:0
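The naming rule at work here can be pictured with a small pure-Python model. This is a hypothetical sketch, not TensorFlow's implementation: the class ToyGraph and its methods are made up for illustration, and the real uniquifier handles details this one ignores. It only captures the rule described above: the first node requesting a name gets it unchanged, later ones get _1, _2 and so on, and an unnamed tf.Variable uses the base name Variable.

```python
class ToyGraph:
    """Hypothetical stand-in for a TF1 graph's node-name bookkeeping.

    Not TensorFlow code: it only models the unique-name rule.
    """

    def __init__(self):
        self._names_in_use = set()

    def unique_name(self, name):
        # The first node asking for a name gets it unchanged;
        # later ones get the first free suffix: name_1, name_2, ...
        candidate, i = name, 0
        while candidate in self._names_in_use:
            i += 1
            candidate = f"{name}_{i}"
        self._names_in_use.add(candidate)
        return candidate

    def variable(self, name="Variable"):
        # Mimics how tf.Variable is named: default base name "Variable",
        # plus ":0" because a variable's name refers to output 0 of its node.
        return self.unique_name(name) + ":0"


g = ToyGraph()
print(g.variable())        # Variable:0
print(g.variable())        # Variable_1:0
print(g.variable("var1"))  # var1:0
print(g.variable("var1"))  # var1_1:0
```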

If you instead want to define a variable (node) called var1 in the TensorFlow graph with tf.Variable(..., name="var1") and then get a reference to that node, you cannot simply use tf.get_variable("var1"), because it will create a new, different variable called var1_1.

This script

import tensorflow as tf

var1 = tf.Variable(3., dtype=tf.float64, name="var1")
print(var1.name)
var2 = tf.get_variable("var1", [], dtype=tf.float64)
print(var2.name)

prints:

var1:0
var1_1:0
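One way to model why the result is var1_1 rather than the existing var1: two separate registries are involved. The graph uniquifies every node name, while tf.get_variable looks variables up in its own store, which tf.Variable never writes to. The pure-Python sketch below is hypothetical (ToyGraph, variable and get_variable are illustrative names, not TensorFlow code); since var1 came from tf.Variable, the modeled get_variable finds nothing in its store and asks the graph for a fresh node, which the graph renames var1_1.

```python
class ToyGraph:
    """Hypothetical model (not TensorFlow code) of two separate registries:
    graph node names, and the store consulted by get_variable()."""

    def __init__(self):
        self._names_in_use = set()  # every node name in the graph
        self._var_store = {}        # requested name -> node name, get_variable() only

    def _unique_name(self, name):
        # Graph-level rule: take the name if free, else the first free name_N.
        candidate, i = name, 0
        while candidate in self._names_in_use:
            i += 1
            candidate = f"{name}_{i}"
        self._names_in_use.add(candidate)
        return candidate + ":0"

    def variable(self, name="Variable"):
        # Like tf.Variable: never touches the store, just takes a unique node name.
        return self._unique_name(name)

    def get_variable(self, name):
        # Like tf.get_variable without reuse: a name already in the store is an
        # error; otherwise a new node is created and recorded under `name`.
        if name in self._var_store:
            raise ValueError(f"Variable {name} already exists")
        self._var_store[name] = self._unique_name(name)
        return self._var_store[name]


g = ToyGraph()
print(g.variable("var1"))      # var1:0   (like tf.Variable(..., name="var1"))
print(g.get_variable("var1"))  # var1_1:0 (the store had no "var1", the graph did)
```

Calling g.get_variable("var1") a second time in this model raises the ValueError, because now the store does know that name.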

To get a reference to the node var1, you need to:

  1. Create it with tf.get_variable instead of tf.Variable. Variables created with tf.Variable can't be shared, while those created with tf.get_variable can.

  2. Know which variable scope var1 belongs to, and re-enter that scope with reuse enabled when requesting the reference.

The code makes this easier to understand:

import tensorflow as tf

# var1 = tf.Variable(3., dtype=tf.float64, name="var1")  # could not be shared
var1 = tf.get_variable(name="var1", shape=(), dtype=tf.float64,
                       initializer=tf.constant_initializer(3.))
current_scope = tf.get_variable_scope()  # the variable scope var1 lives in
print(var1.name)
with tf.variable_scope(current_scope, reuse=True):
    var2 = tf.get_variable("var1", [], dtype=tf.float64)
    print(var2.name)

outputs:

var1:0
var1:0
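The reuse flag is what turns the second tf.get_variable call into a lookup. The plain-Python sketch below models that rule; ToyVarStore, ToyVariable and their methods are hypothetical names, and this is not how TensorFlow implements it. Inside a scope with reuse enabled, get_variable must find an existing variable under the scoped name (otherwise it is an error); outside such a scope, the name must not be taken yet. As in TF1, reuse=True is modeled as inherited by nested scopes.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class ToyVariable:
    name: str
    value: float


class ToyVarStore:
    """Hypothetical model (not TensorFlow code) of get_variable() + variable_scope(reuse=...)."""

    def __init__(self):
        self._vars = {}    # full scoped name -> ToyVariable
        self._scopes = []  # stack of (scope_name, reuse)

    @contextmanager
    def variable_scope(self, name, reuse=False):
        self._scopes.append((name, reuse))
        try:
            yield
        finally:
            self._scopes.pop()

    def get_variable(self, name, value=0.0):
        prefix = "/".join(s for s, _ in self._scopes if s)  # "" adds no prefix
        full = f"{prefix}/{name}" if prefix else name
        reuse = any(r for _, r in self._scopes)  # inner scopes inherit reuse=True
        if reuse:
            if full not in self._vars:
                raise ValueError(f"Variable {full} does not exist")
            return self._vars[full]
        if full in self._vars:
            raise ValueError(f"Variable {full} already exists")
        self._vars[full] = ToyVariable(full + ":0", value)
        return self._vars[full]


store = ToyVarStore()
var1 = store.get_variable("var1", value=3.0)
with store.variable_scope("", reuse=True):  # re-enter the root scope with reuse
    var2 = store.get_variable("var1")
print(var1.name, var2.name, var1 is var2)  # var1:0 var1:0 True
```

Keeping the scope as a stack makes inheritance a simple any() over the enclosing scopes, which mirrors why a reference obtained inside a reuse scope is the very same object, not a copy.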
