Simplification Example
- Author: Brandon T. Willard
- Date: 2019-09-08
1 Introduction
In this example, we’ll illustrate the effect of algebraic graph simplifications using the log-likelihood of a hierarchical normal-normal model.
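Concretely (our reading of the code that follows), the hierarchical normal-normal structure is

\[
X \sim \operatorname{N}(0, 1), \qquad
Y \mid X = x \sim \operatorname{N}\left(x, \tau^2\right),
\]

with \(\tau\) treated as a free parameter.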
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.eager.context import graph_mode
from tensorflow.python.framework.ops import disable_tensor_equality
from symbolic_pymc.tensorflow.printing import tf_dprint
disable_tensor_equality()
We start with the graph normalizations/simplifications native to
TensorFlow via its grappler module. In
the listing below, we create a helper function that
applies grappler simplifications to a graph.
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework import importer
from tensorflow.python.framework import meta_graph
from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import tf_optimizer
try:
    gcluster = cluster.Cluster()
except tf.errors.UnavailableError:
    # No cluster information is available; `OptimizeGraph` accepts
    # `cluster=None` and falls back to its default behavior
    gcluster = None

config = config_pb2.ConfigProto()
def normalize_tf_graph(graph_output, new_graph=True, verbose=False):
    """Use grappler to normalize a graph.

    Arguments
    =========
    graph_output: Tensor
        A tensor we want to consider as "output" of a `FuncGraph`.
    new_graph: bool
        If `True`, import the optimized graph into a fresh `Graph`;
        otherwise, import it into the current default graph.
    verbose: bool
        Enable verbose grappler output.

    Returns
    =======
    The tensor corresponding to `graph_output` in the simplified graph.
    """
train_op = graph_output.graph.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.clear()
train_op.extend([graph_output])
metagraph = meta_graph.create_meta_graph_def(graph=graph_output.graph)
optimized_graphdef = tf_optimizer.OptimizeGraph(
config, metagraph, verbose=verbose, cluster=gcluster)
output_name = graph_output.name
if new_graph:
optimized_graph = ops.Graph()
else:
optimized_graph = ops.get_default_graph()
del graph_output
with optimized_graph.as_default():
importer.import_graph_def(optimized_graphdef, name="")
opt_graph_output = optimized_graph.get_tensor_by_name(output_name)
return opt_graph_output
The listing below creates our model and normalizes it.
def tfp_normal_log_prob(x, loc, scale):
    """Compute a normal log-density without its `log(scale)` scaling term."""
    log_unnormalized = -0.5 * tf.math.squared_difference(
        x / scale, loc / scale)
    log_normalization = 0.5 * np.log(2. * np.pi)
    # log_normalization += tf.math.log(scale)
    return log_unnormalized - log_normalization
with graph_mode(), tf.Graph().as_default() as demo_graph:
x_tf = tf.compat.v1.placeholder(tf.float32, name='value_x',
shape=tf.TensorShape([None]))
tau_tf = tf.compat.v1.placeholder(tf.float32, name='tau',
shape=tf.TensorShape([None]))
y_tf = tf.compat.v1.placeholder(tf.float32, name='value_y',
shape=tf.TensorShape([None]))
X_tfp = tfp.distributions.normal.Normal(0.0, 1.0, name='X')
z_tf = x_tf + tau_tf * y_tf
hier_norm_lik = tf.math.log(z_tf)
# Unscaled normal log-likelihood
log_unnormalized = -0.5 * tf.math.squared_difference(
z_tf / tau_tf, x_tf / tau_tf)
log_normalization = 0.5 * np.log(2. * np.pi)
hier_norm_lik += log_unnormalized - log_normalization
hier_norm_lik += X_tfp.log_prob(x_tf)
hier_norm_lik = normalize_tf_graph(hier_norm_lik)
Above, we used an unscaled version of the normal log-likelihood. This is because we're emulating the effect of applying a substitution like \(Y \to x + \tau \epsilon \sim \operatorname{N}\left(x, \tau^2\right)\). This has the same effect as subtracting a \(\log(\tau)\) term; however, the result will produce equivalent, but not equal, graphs when we compare with the manually created, fully transformed graph below.
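To make the equivalence explicit (a short derivation, not in the original): the scaled log-density of \(\operatorname{N}\left(x, \tau^2\right)\) evaluated at \(z = x + \tau y\) is

\[
-\frac{1}{2}\left(\frac{z - x}{\tau}\right)^2 - \log(\tau) - \frac{1}{2}\log(2\pi)
= -\frac{1}{2} y^2 - \log(\tau) - \frac{1}{2}\log(2\pi),
\]

i.e. the standard normal log-density of \(y\) minus the same \(\log(\tau)\) term, so dropping \(\log(\tau)\) on both sides leaves the unscaled form used above.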
tf_dprint(hier_norm_lik)
Tensor(AddV2):0, dtype=float32, shape=[None], "add_2:0"
| Tensor(Sub):0, dtype=float32, shape=[None], "X_1/log_prob/sub:0"
| | Tensor(Mul):0, dtype=float32, shape=[None], "X_1/log_prob/mul:0"
| | | Tensor(SquaredDifference):0, dtype=float32, shape=[None], "X_1/log_prob/SquaredDifference:0"
| | | | Tensor(Mul):0, dtype=float32, shape=[None], "X_1/log_prob/truediv:0"
| | | | | Tensor(Const):0, dtype=float32, shape=[], "ConstantFolding/X_1/log_prob/truediv_recip:0"
| | | | | | 1.
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_x:0"
| | | | Tensor(Const):0, dtype=float32, shape=[], "X_1/log_prob/truediv_1:0"
| | | | | 0.
| | | Tensor(Const):0, dtype=float32, shape=[], "mul_1/x:0"
| | | | -0.5
| | Tensor(Const):0, dtype=float32, shape=[], "sub/y:0"
| | | 0.9189385
| Tensor(AddV2):0, dtype=float32, shape=[None], "add_1:0"
| | Tensor(Log):0, dtype=float32, shape=[None], "Log:0"
| | | Tensor(AddV2):0, dtype=float32, shape=[None], "add:0"
| | | | Tensor(Mul):0, dtype=float32, shape=[None], "mul:0"
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "tau:0"
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_y:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_x:0"
| | Tensor(Sub):0, dtype=float32, shape=[None], "sub:0"
| | | Tensor(Mul):0, dtype=float32, shape=[None], "mul_1:0"
| | | | Tensor(SquaredDifference):0, dtype=float32, shape=[None], "SquaredDifference:0"
| | | | | Tensor(RealDiv):0, dtype=float32, shape=[None], "truediv:0"
| | | | | | Tensor(AddV2):0, dtype=float32, shape=[None], "add:0"
| | | | | | | ...
| | | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "tau:0"
| | | | | Tensor(RealDiv):0, dtype=float32, shape=[None], "truediv_1:0"
| | | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_x:0"
| | | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "tau:0"
| | | | Tensor(Const):0, dtype=float32, shape=[], "mul_1/x:0"
| | | | | -0.5
| | | Tensor(Const):0, dtype=float32, shape=[], "sub/y:0"
| | | | 0.9189385
From the printed graph above, we can see
that grappler is not applying enough algebraic
simplifications (e.g. it doesn't remove multiplications by \(1\) or reduce the
\(\left(\mu + x - \mu \right)^2\)-style term
in SquaredDifference).
**Does missing this simplification amount to anything practical?**
The evaluation below demonstrates the difference between our model without the simplification and a manually constructed model with the simplification applied.
with graph_mode(), demo_graph.as_default():
Z_tfp = tfp.distributions.normal.Normal(0.0, 1.0, name='Y_trans')
hn_manually_simplified_lik = tf.math.log(z_tf)
hn_manually_simplified_lik += Z_tfp.log_prob(y_tf)
hn_manually_simplified_lik += X_tfp.log_prob(x_tf)
hn_manually_simplified_lik = normalize_tf_graph(hn_manually_simplified_lik)
tf_dprint(hn_manually_simplified_lik)
Tensor(AddV2):0, dtype=float32, shape=[None], "add_4:0"
| Tensor(Sub):0, dtype=float32, shape=[None], "X_2/log_prob/sub:0"
| | Tensor(Mul):0, dtype=float32, shape=[None], "X_2/log_prob/mul:0"
| | | Tensor(SquaredDifference):0, dtype=float32, shape=[None], "X_2/log_prob/SquaredDifference:0"
| | | | Tensor(Mul):0, dtype=float32, shape=[None], "X_2/log_prob/truediv:0"
| | | | | Tensor(Const):0, dtype=float32, shape=[], "ConstantFolding/Y_trans_1/log_prob/truediv_recip:0"
| | | | | | 1.
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_x:0"
| | | | Tensor(Const):0, dtype=float32, shape=[], "Y_trans_1/log_prob/truediv_1:0"
| | | | | 0.
| | | Tensor(Const):0, dtype=float32, shape=[], "Y_trans_1/log_prob/mul/x:0"
| | | | -0.5
| | Tensor(Const):0, dtype=float32, shape=[], "Y_trans_1/log_prob/add:0"
| | | 0.9189385
| Tensor(AddV2):0, dtype=float32, shape=[None], "add_3:0"
| | Tensor(Log):0, dtype=float32, shape=[None], "Log_1:0"
| | | Tensor(AddV2):0, dtype=float32, shape=[None], "add:0"
| | | | Tensor(Mul):0, dtype=float32, shape=[None], "mul:0"
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "tau:0"
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_y:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_x:0"
| | Tensor(Sub):0, dtype=float32, shape=[None], "Y_trans_1/log_prob/sub:0"
| | | Tensor(Mul):0, dtype=float32, shape=[None], "Y_trans_1/log_prob/mul:0"
| | | | Tensor(SquaredDifference):0, dtype=float32, shape=[None], "Y_trans_1/log_prob/SquaredDifference:0"
| | | | | Tensor(Mul):0, dtype=float32, shape=[None], "Y_trans_1/log_prob/truediv:0"
| | | | | | Tensor(Const):0, dtype=float32, shape=[], "ConstantFolding/Y_trans_1/log_prob/truediv_recip:0"
| | | | | | | 1.
| | | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_y:0"
| | | | | Tensor(Const):0, dtype=float32, shape=[], "Y_trans_1/log_prob/truediv_1:0"
| | | | | | 0.
| | | | Tensor(Const):0, dtype=float32, shape=[], "Y_trans_1/log_prob/mul/x:0"
| | | | | -0.5
| | | Tensor(Const):0, dtype=float32, shape=[], "Y_trans_1/log_prob/add:0"
| | | | 0.9189385
test_point = {x_tf.name: np.r_[1.0],
tau_tf.name: np.r_[1e-9],
y_tf.name: np.r_[1000.1]}
with tf.compat.v1.Session(graph=hn_manually_simplified_lik.graph).as_default():
hn_manually_simplified_val = hn_manually_simplified_lik.eval(test_point)
with tf.compat.v1.Session(graph=hier_norm_lik.graph).as_default():
hn_unsimplified_val = hier_norm_lik.eval(test_point)
_ = np.subtract(hn_unsimplified_val, hn_manually_simplified_val)
[39299.97]
The output above shows exactly how large
the discrepancy can be for carefully chosen parameter values. More
specifically, as tau_tf gets smaller and the magnitude
of the difference x_tf - y_tf gets larger, the
discrepancy can increase. Since such parameter values are likely to be visited
during sampling, we should address this missing simplification.
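The mechanism is ordinary floating-point rounding: forming \(z = x + \tau y\) rounds most of the \(\tau y\) contribution away at \(x\)'s magnitude, and dividing by \(\tau\) amplifies that rounding error. A minimal NumPy sketch of the effect (not in the original; the exact values depend on rounding details):

import numpy as np

x, tau, y = np.float32(1.0), np.float32(1e-9), np.float32(1000.1)

# Forming `z` rounds `tau * y` into the few low-order bits of `x`...
z = x + tau * y
# ...and dividing by `tau` amplifies that rounding error
naive = np.square(z / tau - x / tau)
exact = np.square(y)  # the algebraically equivalent value

# The two values differ wildly, despite being equal in exact arithmetic
print(naive, exact)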
In the listing below, we create a goal that performs the
aforementioned simplification for SquaredDifference.
from functools import partial
from collections.abc import Sequence
from unification import var
from kanren import run, eq, lall, conde
from kanren.facts import fact
from kanren.assoccomm import eq_comm, commutative
from kanren.graph import walko
from etuples import etuple, etuplize
from etuples.core import ExpressionTuple
from symbolic_pymc.meta import enable_lvar_defaults
from symbolic_pymc.tensorflow.meta import mt, TFlowMetaOperator
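# Declare the SquaredDifference op commutative in its inputs so that
# `eq_comm` will consider argument permutations when unifying against it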
fact(commutative, TFlowMetaOperator(mt.SquaredDifference.op_def, var()))
def recenter_sqrdiffo(in_g, out_g):
"""Create a goal that essentially reduces `(a / d - (a + d * c) / d)**2` to `d**2`"""
a_sqd_lv, b_sqd_lv, d_sqd_lv = var(), var(), var()
with enable_lvar_defaults('names'):
# Pattern: (a / d - b / d)**2
target_sqrdiff_lv = mt.SquaredDifference(
mt.Realdiv(a_sqd_lv, d_sqd_lv),
mt.Realdiv(b_sqd_lv, d_sqd_lv))
# Pattern: d * c + a
c_sqd_lv = var()
b_part_lv = mt.AddV2(mt.Mul(d_sqd_lv, c_sqd_lv), a_sqd_lv)
# Replacement: c**2
simplified_sqrdiff_lv = mt.SquaredDifference(
c_sqd_lv,
0.0
)
reshape_lv = var()
simplified_sqrdiff_reshaped_lv = mt.SquaredDifference(
mt.reshape(c_sqd_lv, reshape_lv),
0.0
)
with enable_lvar_defaults('names'):
b_sqd_reshape_lv = mt.Reshape(b_part_lv, reshape_lv)
res = lall(
# input == (a / d - b / d)**2 must be "true"
eq_comm(in_g, target_sqrdiff_lv),
# "and"
conde([
# "if" b == d * c + a is "true"
eq(b_sqd_lv, b_part_lv),
# "then" output == (c - 0)**2 is also "true"
eq(out_g, simplified_sqrdiff_lv)
# "or"
], [
# We have to use this to cover some variation also not
# sufficiently/consistently "normalized" by `grappler`.
# "if" b == reshape(d * c + a, ?) is "true"
eq_comm(b_sqd_lv, b_sqd_reshape_lv),
# "then" output == (reshape(c, ?) - 0)**2 is also "true"
eq(out_g, simplified_sqrdiff_reshaped_lv)
]))
return res
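For readers unfamiliar with these miniKanren constructs, the following toy example (not from the original; it uses only the kanren and unification packages imported above) illustrates how `lall` composes goals conjunctively and `conde` disjunctively, mirroring the structure of recenter_sqrdiffo:

from unification import var
from kanren import run, eq, lall, conde

q = var()

# `lall` is a logical "and"; `conde` is an "or" over lists of conjoined
# goals.  Only the `q == 2` branch is consistent with the enclosing
# conjunction, so `run` returns a single result.
print(run(0, q, lall(eq(q, 2), conde([eq(q, 1)], [eq(q, 2)]))))
# (2,)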
We apply the simplification and print the resulting graph below.
from kanren.graph import reduceo
with graph_mode(), hier_norm_lik.graph.as_default():
q = var()
res = run(1, q,
reduceo(lambda x, y: walko(recenter_sqrdiffo, x, y),
hier_norm_lik, q))
with graph_mode(), tf.Graph().as_default() as result_graph:
hn_simplified_tf = res[0].eval_obj.reify()
hn_simplified_tf = normalize_tf_graph(hn_simplified_tf)
# tf_dprint(hier_norm_lik.graph.get_tensor_by_name('SquaredDifference:0'))
tf_dprint(hn_simplified_tf)
Tensor(AddV2):0, dtype=float32, shape=[None], "add_2_1:0"
| Tensor(Sub):0, dtype=float32, shape=[None], "X_1/log_prob/sub:0"
| | Tensor(Mul):0, dtype=float32, shape=[None], "X_1/log_prob/mul:0"
| | | Tensor(SquaredDifference):0, dtype=float32, shape=[None], "X_1/log_prob/SquaredDifference:0"
| | | | Tensor(Mul):0, dtype=float32, shape=[None], "X_1/log_prob/truediv:0"
| | | | | Tensor(Const):0, dtype=float32, shape=[], "ConstantFolding/X_1/log_prob/truediv_recip:0"
| | | | | | 1.
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_x:0"
| | | | Tensor(Const):0, dtype=float32, shape=[], "X_1/log_prob/truediv_1:0"
| | | | | 0.
| | | Tensor(Const):0, dtype=float32, shape=[], "mul_1/x:0"
| | | | -0.5
| | Tensor(Const):0, dtype=float32, shape=[], "sub/y:0"
| | | 0.9189385
| Tensor(AddV2):0, dtype=float32, shape=[None], "add_1_1:0"
| | Tensor(Log):0, dtype=float32, shape=[None], "Log:0"
| | | Tensor(AddV2):0, dtype=float32, shape=[None], "add:0"
| | | | Tensor(Mul):0, dtype=float32, shape=[None], "mul:0"
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "tau:0"
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_y:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_x:0"
| | Tensor(Sub):0, dtype=float32, shape=[None], "sub_1:0"
| | | Tensor(Mul):0, dtype=float32, shape=[None], "mul_1_1:0"
| | | | Tensor(SquaredDifference):0, dtype=float32, shape=[None], "SquaredDifference_1:0"
| | | | | Tensor(Const):0, dtype=float32, shape=[], "X_1/log_prob/truediv_1:0"
| | | | | | 0.
| | | | | Tensor(Placeholder):0, dtype=float32, shape=[None], "value_y:0"
| | | | Tensor(Const):0, dtype=float32, shape=[], "mul_1/x:0"
| | | | | -0.5
| | | Tensor(Const):0, dtype=float32, shape=[], "sub/y:0"
| | | | 0.9189385
After applying our simplification, the evaluation below numerically demonstrates that the difference is gone and that our transform produces a graph equivalent to the manually simplified one.
with tf.compat.v1.Session(graph=hn_simplified_tf.graph).as_default():
hn_simplified_val = hn_simplified_tf.eval(test_point)
_ = np.subtract(hn_manually_simplified_val, hn_simplified_val)
[0.]