How do you use TensorFlow GraphKeys to get all weights?
--
Music by Eric Matyas
https://www.soundimage.org
Track title: Puzzle Game Looping
--
Chapters
00:00 Question
01:14 Accepted answer (Score 6)
01:54 Answer 2 (Score 3)
02:12 Thank you
--
Full question
https://stackoverflow.com/questions/4525...
Question links:
https://www.tensorflow.org/versions/r0.1...
Answer 1 links:
[docs]: https://www.tensorflow.org/api_docs/pyth...
--
Content licensed under CC BY-SA
https://meta.stackexchange.com/help/lice...
--
Tags
#python #machinelearning #tensorflow
#avk47
--
ACCEPTED ANSWER
Score 6
By default, every variable is added to the tf.GraphKeys.GLOBAL_VARIABLES collection. A convenient approach is to place each weight in the tf.GraphKeys.WEIGHTS collection when you create it (TensorFlow 1.x), like this:
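import tensorflow as tf  # TensorFlow 1.x graph-mode API
# collections= sets the full list of collections the variable is added to
w = tf.Variable([1, 2, 3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
w2 = tf.Variable([11, 22, 32], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)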
You can then fetch them with:
tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
which returns the weights:
[<tf.Variable 'Variable:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'Variable_1:0' shape=(3,) dtype=float32_ref>]
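One detail worth knowing about the approach above: in TensorFlow 1.x, the collections= argument replaces the default [GLOBAL_VARIABLES] list rather than extending it, so variables created that way are not in GLOBAL_VARIABLES and, as far as I know, are therefore skipped by global_variables_initializer(). The sketch below contrasts that with registering a variable afterwards via add_to_collection, and also shows that get_collection_ref returns the graph's live list while get_collection returns a copy. It assumes the tf.compat.v1 API inside an explicit tf.Graph so it should also run under TF 2.x; the aliases WEIGHTS and GLOBAL are just local shorthands.

```python
import tensorflow as tf

# Sketch only: tf.compat.v1 graph API inside an explicit Graph, so collections
# work on TF 1.x (1.13+) and TF 2.x alike.
graph = tf.Graph()
with graph.as_default():
    WEIGHTS = tf.compat.v1.GraphKeys.WEIGHTS            # local alias (assumption)
    GLOBAL = tf.compat.v1.GraphKeys.GLOBAL_VARIABLES    # local alias (assumption)

    # Option A (as in the answer): collections= replaces the default
    # [GLOBAL_VARIABLES], so w lands in WEIGHTS but not in GLOBAL_VARIABLES.
    w = tf.compat.v1.Variable([1, 2, 3], collections=[WEIGHTS], dtype=tf.float32)

    # Option B: keep the default collections, then also register in WEIGHTS.
    w2 = tf.compat.v1.Variable([11, 22, 32], dtype=tf.float32)
    tf.compat.v1.add_to_collection(WEIGHTS, w2)

    # get_collection_ref returns the graph's own (mutable) list;
    # get_collection returns a fresh copy of it.
    weights_ref = tf.compat.v1.get_collection_ref(WEIGHTS)
    weights_copy = tf.compat.v1.get_collection(WEIGHTS)
    global_vars = tf.compat.v1.get_collection(GLOBAL)
```

Option B is usually the safer choice when you also rely on the standard initializers or savers that read GLOBAL_VARIABLES.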
ANSWER 2
Score 3
From the docs:
The following standard keys are defined, but their collections are not automatically populated as many of the others are:
WEIGHTS
BIASES
ACTIVATIONS
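To illustrate what that doc quote means in practice, here is a minimal sketch, assuming the tf.compat.v1 graph API: variables created with default settings show up in the trainable-variables collection, while WEIGHTS and BIASES stay empty until you fill them yourself. The variable names kernel and bias are hypothetical, and the name-based filter at the end is only a workaround that assumes your layers follow that naming convention (as tf.layers/Keras dense layers do with kernel).

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    keys = tf.compat.v1.GraphKeys
    # Hypothetical layer parameters, created with the default collections only.
    kernel = tf.compat.v1.Variable(tf.zeros([3, 2]), name="kernel")
    bias = tf.compat.v1.Variable(tf.zeros([2]), name="bias")

    # WEIGHTS and BIASES are defined keys, but nothing populates them for us.
    weights_before = tf.compat.v1.get_collection(keys.WEIGHTS)
    biases_before = tf.compat.v1.get_collection(keys.BIASES)

    # TRAINABLE_VARIABLES *is* populated by default (trainable=True).
    trainable = tf.compat.v1.trainable_variables()

    # Workaround (assumption): pick weights out by a naming convention.
    kernels_by_name = [v for v in trainable if "kernel" in v.name]
```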