diff --git a/tensorflow_graph_in_jupyter.py b/tensorflow_graph_in_jupyter.py new file mode 100644 index 0000000..0e1edfd --- /dev/null +++ b/tensorflow_graph_in_jupyter.py @@ -0,0 +1,50 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +# This module defines the show_graph() function to visualize a TensorFlow graph within Jupyter. + +# As far as I can tell, this code was originally written by Alex Mordvintsev at: +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/deepdream/deepdream.ipynb + +# The original code only worked on Chrome (because of the use of , but the version below +# uses Polyfill (copied from this StackOverflow answer: https://stackoverflow.com/a/41463991/38626) +# so that it can work on other browsers as well. + +import numpy as np +import tensorflow as tf +from IPython.display import clear_output, Image, display, HTML + +def strip_consts(graph_def, max_const_size=32): + """Strip large constant values from graph_def.""" + strip_def = tf.GraphDef() + for n0 in graph_def.node: + n = strip_def.node.add() + n.MergeFrom(n0) + if n.op == 'Const': + tensor = n.attr['value'].tensor + size = len(tensor.tensor_content) + if size > max_const_size: + tensor.tensor_content = ""%size + return strip_def + +def show_graph(graph_def, max_const_size=32): + """Visualize TensorFlow graph.""" + if hasattr(graph_def, 'as_graph_def'): + graph_def = graph_def.as_graph_def() + strip_def = strip_consts(graph_def, max_const_size=max_const_size) + code = """ + + + +
+ +
+ """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand())) + + iframe = """ + + """.format(code.replace('"', '"')) + display(HTML(iframe))