Decorating with @tf.function changes if condition output
--------------------------------------------------
Music by Eric Matyas
https://www.soundimage.org
Track title: Ancient Construction
--
Chapters
00:00 Decorating With @tf.function Changes If Condition Output
00:37 Accepted Answer Score 1
02:47 Thank you
--
Full question
https://stackoverflow.com/questions/5718...
--
Content licensed under CC BY-SA
https://meta.stackexchange.com/help/lice...
--
Tags
#python #tensorflow #tensorflow20
#avk47
ACCEPTED ANSWER
Score 1
It's a little bit of a head-scratcher, but once we understand that tf.function maps Python ops and control flow onto a TensorFlow graph, whereas the bare function just executes eagerly, we can pick through it and it makes a lot more sense.
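To see that difference in isolation, here is a minimal sketch (eager_fn and graph_fn are illustrative names, not from the question): a Python-level print runs on every eager call, but only at trace time under tf.function:

import tensorflow as tf

def eager_fn(x):
    # Python-level side effect: runs on every eager call,
    # but only while tf.function is tracing
    print("python sees:", x)
    return x + 1

graph_fn = tf.function(eager_fn)

eager_fn(tf.constant(1))   # prints a concrete eager tensor
graph_fn(tf.constant(1))   # prints a symbolic Tensor during tracing
graph_fn(tf.constant(2))   # same input signature, no retrace: no python print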
I have tweaked your example to illustrate what's going on. Consider test1 and test2 below:
import tensorflow as tf

@tf.function
def test1(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a), 0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None
def test2(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a), 0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None
These are identical to one another except for the @tf.function decorator.
Now executing test2(tf.Variable([[]])) gives us:
0
python print size: 0
which is the behaviour I assume you'd expect. Whereas test1(tf.Variable([[]])) gives:
python print size: Tensor("Size_1:0", shape=(), dtype=int32)
fail
0
There are a couple of things (beyond the fail) about this output that you might find surprising:
- The print() statement prints out a (yet-to-be-evaluated) Tensor rather than a zero
- The order of the print() and the tf.print() output has been reversed
This is because, by adding @tf.function, we no longer have a Python function but instead a TensorFlow graph mapped from the function code by autograph. This means that, at the point the if condition is evaluated, we have not yet executed tf.math.not_equal(tf.size(a), 0); we just have an object (an instance of Tensor) which, in Python, is truthy:
class MyClass:
    pass

my_obj = MyClass()

if my_obj:
    print("my_obj evaluates to true")  # outputs "my_obj evaluates to true"
This means we get to the print('fail') statement in test1 before having evaluated tf.math.not_equal(tf.size(a),0).
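If you want to see what autograph is actually doing to the function body, tf.autograph.to_code will show the generated source (a quick sketch; check is an illustrative name, and the generated code varies by TF version):

import tensorflow as tf

def check(a):
    if tf.math.not_equal(tf.size(a), 0):
        print('fail')

# Show the graph-compatible source that autograph generates for check
print(tf.autograph.to_code(check))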
So what's the fix?
Well, if we remove the call to the Python-only print() function in the if block and replace it with an autograph-friendly tf.print() statement, then autograph will seamlessly convert our if ... else ... logic to a graph-friendly tf.cond statement that ensures everything happens in the correct order:
@tf.function
def test3(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a), 0):
        tf.print('fail')
    with tf.control_dependencies([print_op]):
        return None
test3(tf.Variable([[]]))
python print size: Tensor("Size_1:0", shape=(), dtype=int32)
0
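For reference, the tf.cond that autograph builds for us could be written out by hand roughly like this (a sketch only, not the exact code autograph emits; note that both branches must return the same structure):

import tensorflow as tf

a = tf.Variable([[]])

result = tf.cond(
    tf.math.not_equal(tf.size(a), 0),
    true_fn=lambda: tf.constant('fail'),
    false_fn=lambda: tf.constant('ok'),
)
print(result)  # tf.Tensor(b'ok', shape=(), dtype=string)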