TensorFlow Estimator extra outputs: smuggling tensors out
Estimator Output Smuggling
While the TensorFlow Estimator framework has a lot of appeal, since it hides much
of the training / evaluation / prediction mechanics, the price of this convenience is often
paid in flexibility when working with models dynamically (i.e. in research mode). In particular,
it would be very convenient to be able to look at the values of several different output tensors created
by a model (other than just the ones designated ‘label’, etc).
The Estimator framework has now bifurcated, based on using either
the ‘OLD style’ x, y, batch_size input parameters,
or feeding information to the model using the ‘NEW style’ input_fn method
(which is more flexible, and doesn’t complain about DEPRECATION).
This post shows how to ‘smuggle’ out tensor results from a model that has been integrated into the Estimator framework. The
key parts are the SMUGGLE TENSORS OUT HERE sections in the model, and the subsequent .evaluate or .predict call
(depending on which style you’re using).
NB: It seems that once you’ve run the model ‘NEW style’ with a specific batch_size, the model becomes specialized
w.r.t. that batch_size and so no longer accepts ‘OLD style’ batches. This issue probably warrants further exploration -
except that ‘NEW style’ is clearly the better, more modern and more flexible way to go.
OLD STYLE runs (uses features and integer_labels PLAIN)
The following also illustrates the logical process required to find the magic incantation
tf.contrib.metrics.streaming_concat that pulls all the right stuff together.
The predictions['input_grad'] becomes the value of labels that gets concatenated into
the mnist_classifier.evaluate() results.
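To make the accumulation idea concrete, here is a minimal framework-free sketch. It is not TensorFlow code and not the listing from this post: StreamingConcat, evaluate and toy_model_fn are hypothetical stand-ins, and the 'input_grad' values are made up. It only mimics the value / update pair that tf.contrib.metrics.streaming_concat returns, where the update runs once per batch and the value holds everything concatenated so far.

```python
class StreamingConcat:
    """Toy stand-in for a streaming-concatenation metric (hypothetical, not TensorFlow)."""

    def __init__(self):
        self._rows = []

    def update(self, batch):
        # Plays the role of the metric's update op: run once per evaluation batch.
        self._rows.extend(batch)

    def value(self):
        # Plays the role of the metric's value tensor: everything seen so far.
        return list(self._rows)


def toy_model_fn(features):
    # Hypothetical model: 'input_grad' is just 2 * x per example, for illustration.
    return {
        'class': [int(x > 0) for x in features],
        'input_grad': [2.0 * x for x in features],
    }


def evaluate(model_fn, batches):
    # Assumed evaluation loop: feed each batch, update the metric, read it at the end.
    metric = StreamingConcat()
    for features in batches:
        predictions = model_fn(features)
        metric.update(predictions['input_grad'])
    return {'input_grad': metric.value()}


# Two batches of different sizes; the smuggled values come back as one sequence.
results = evaluate(toy_model_fn, [[1.0, -2.0], [3.0]])
```

The point of the sketch is that an evaluation "metric" does not have to be a scalar summary: if it simply appends whatever it is given, the per-example tensor values survive the trip through the evaluate call.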
NEW STYLE runs (uses features dictionary and integer_labels PLAIN)
This is considerably easier, since the features dictionary allows one to smuggle more values IN,
and the (non-DEPRECATED) new style also allows one to use the outputs parameter of
Estimator.predict(), which means that only the tensors specified get calculated…
And here is the function that does a .predict() to get the extra tensor value out. Because
outputs is specified, no superfluous computations are done.
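The effect of naming the outputs can be illustrated with a small framework-free analogy. This is a sketch under stated assumptions, not the Estimator API: the predict helper, output_fns and the two toy output functions are hypothetical, and the calls list exists only to show which outputs were actually computed.

```python
calls = []  # records which outputs were actually computed


def class_fn(features):
    calls.append('class')
    return [int(x > 0) for x in features]


def input_grad_fn(features):
    calls.append('input_grad')
    return [2.0 * x for x in features]  # made-up values for illustration


# Hypothetical registry of named model outputs.
output_fns = {'class': class_fn, 'input_grad': input_grad_fn}


def predict(output_fns, features, outputs=None):
    """Evaluate only the requested outputs; all of them when outputs is None."""
    names = list(output_fns) if outputs is None else outputs
    return {name: output_fns[name](features) for name in names}


# Ask only for the smuggled tensor; the 'class' output is never evaluated.
smuggled = predict(output_fns, [1.0, -2.0], outputs=['input_grad'])
```

After this runs, calls contains only 'input_grad', which is the analogue of the unrequested tensors never being computed.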