The Python Oracle

Decorating with @tf.function changes if condition output

Full question
https://stackoverflow.com/questions/5718...

--

Content licensed under CC BY-SA
https://meta.stackexchange.com/help/lice...




ACCEPTED ANSWER

Score 1


It's a little bit of a head-scratcher, but once we understand that tf.function maps python ops and control flow to a tf graph, whereas the bare function just executes eagerly, we can pick through it and it makes a lot more sense.
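As a quick illustration of that difference, here is a minimal sketch of mine (not from the question) using tf.executing_eagerly():

import tensorflow as tf

def eager_fn():
    print("eager?", tf.executing_eagerly())  # True: the body runs line by line

@tf.function
def graph_fn():
    print("eager?", tf.executing_eagerly())  # False: the body is traced into a graph

eager_fn()   # eager? True
graph_fn()   # eager? False (printed while tracing)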

I have tweaked your example to illustrate what's going on. Consider test1 and test2 below:

import tensorflow as tf

@tf.function
def test1(a):
    print_op = tf.print(tf.size(a))  # graph op: prints when the graph runs
    print("python print size: {}".format(tf.size(a)))  # python side effect: runs at trace time
    if tf.math.not_equal(tf.size(a), 0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

def test2(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a), 0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

These are identical to one another except for the @tf.function decorator.

Now executing test2(tf.Variable([[]])) gives us:

0
python print size: 0

which is the behaviour I assume you'd expect, whereas test1(tf.Variable([[]])) gives:

python print size: Tensor("Size_1:0", shape=(), dtype=int32)
fail
0

There are a couple of things (beyond the fail) about this output that you might find surprising (both are demonstrated in the sketch after this list):

  • The print() statement prints out a (yet to be evaluated) tensor rather than a zero
  • The order of the print() and the tf.print() output has been reversed
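Both of these fall out of tracing: the python print() runs once, while the function is being traced, whereas tf.print() is baked into the graph and runs on every call. A minimal sketch of mine (separate from the question's code) makes the split visible:

import tensorflow as tf

@tf.function
def traced(x):
    print("trace-time print:", x)   # python side effect: runs only while tracing
    tf.print("run-time print:", x)  # graph op: runs on every execution
    return x

traced(tf.constant(1))  # both lines print; x is a symbolic Tensor in the first
traced(tf.constant(2))  # only the tf.print line; no retracing for the same signature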

This is because, by adding @tf.function, we no longer have a python function but a tf graph mapped from the function code by autograph. This means that, at the point the if condition is evaluated, we have not yet executed tf.math.not_equal(tf.size(a), 0); we just have a Tensor object which, in python, is truthy:

class MyClass:
  pass
my_obj = MyClass()
if (my_obj):
  print ("my_obj evaluates to true") ## outputs "my_obj evaluates to true"

This means we get to the print('fail') statement in test1 before having evaluated tf.math.not_equal(tf.size(a),0).
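You can confirm this by printing the condition's type at trace time. A small sketch of mine (the function name is made up):

import tensorflow as tf

@tf.function
def show_condition(a):
    cond = tf.math.not_equal(tf.size(a), 0)
    # at trace time this is a symbolic Tensor, not a python bool,
    # so `if cond:` falls back to python truthiness
    print(type(cond), cond)
    return None

show_condition(tf.Variable([[]]))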

So what's the fix?

Well, if we remove the call to the python-only print() function in the if block and replace it with an autograph-friendly tf.print() statement, then autograph will seamlessly convert our if ... else ... logic to a graph-friendly tf.cond that ensures everything happens in the correct order:

@tf.function
def test3(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a), 0):
        tf.print('fail')  # graph-friendly: autograph can now convert this branch
    with tf.control_dependencies([print_op]):
        return None

Executing test3(tf.Variable([[]])) now gives:

python print size: Tensor("Size_1:0", shape=(), dtype=int32)
0

The python print() still fires at trace time with a symbolic tensor, but fail never appears, because the tf.cond that autograph generates only runs that branch when the condition is actually true.
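For intuition, what autograph does here is roughly equivalent to writing the tf.cond by hand. The sketch below is a hypothetical hand-rolled version of mine, not autograph's literal output:

import tensorflow as tf

@tf.function
def test3_by_hand(a):
    # both branches of tf.cond must return matching structures,
    # so each returns a dummy constant alongside any side effects
    def true_fn():
        tf.print('fail')
        return tf.constant(0)

    def false_fn():
        return tf.constant(0)

    tf.cond(tf.math.not_equal(tf.size(a), 0), true_fn, false_fn)
    tf.print(tf.size(a))

test3_by_hand(tf.Variable([[]]))  # prints 0; the 'fail' branch never runs

If you're curious what autograph actually generates, tf.autograph.to_code(test3.python_function) will print the converted source.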