Simplification Example

Author

Brandon T. Willard

Date

2019-09-08

1 Introduction

In this example, we’ll illustrate the effect of algebraic graph simplifications using the log-likelihood of a hierarchical normal-normal model.
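
For reference, the hierarchy in question is the standard normal-normal model,

\[
X \sim \operatorname{N}\left(0, 1\right), \qquad
Y \mid X \sim \operatorname{N}\left(X, \tau^{2}\right),
\]

with the scale \(\tau\) treated as given; the graph constructed below works with (an unscaled version of) its log-likelihood.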

import numpy as np
import tensorflow as tf

import tensorflow_probability as tfp

from tensorflow.python.eager.context import graph_mode
from tensorflow.python.framework.ops import disable_tensor_equality

from symbolic_pymc.tensorflow.printing import tf_dprint


disable_tensor_equality()

We start with the graph normalizations and simplifications native to TensorFlow, provided by its grappler module. In grappler-normalize-function-sp, we create a helper function that applies grappler's simplifications to a graph.

from tensorflow.core.protobuf import config_pb2

from tensorflow.python.framework import ops
from tensorflow.python.framework import importer
from tensorflow.python.framework import meta_graph

from tensorflow.python.grappler import cluster
from tensorflow.python.grappler import tf_optimizer


try:
    gcluster = cluster.Cluster()
except tf.errors.UnavailableError:
    # Fall back to no cluster information; `OptimizeGraph` accepts `cluster=None`.
    gcluster = None

config = config_pb2.ConfigProto()


def normalize_tf_graph(graph_output, new_graph=True, verbose=False):
    """Use grappler to normalize a graph.

    Arguments
    =========
    graph_output: Tensor
      A tensor we want to consider as the "output" of a FuncGraph.
    new_graph: bool
      If True, import the optimized graph into a new Graph; otherwise, import
      it into the current default graph.
    verbose: bool
      Passed through to grappler's optimizer.

    Returns
    =======
    The tensor corresponding to `graph_output` in the optimized graph.
    """
    # Registering `graph_output` as a "train op" tells grappler which node(s)
    # must be preserved while the rest of the graph is optimized.
    train_op = graph_output.graph.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.clear()
    train_op.extend([graph_output])

    metagraph = meta_graph.create_meta_graph_def(graph=graph_output.graph)

    optimized_graphdef = tf_optimizer.OptimizeGraph(
        config, metagraph, verbose=verbose, cluster=gcluster)

    output_name = graph_output.name

    if new_graph:
        optimized_graph = ops.Graph()
    else:
        optimized_graph = ops.get_default_graph()
        del graph_output

    with optimized_graph.as_default():
        importer.import_graph_def(optimized_graphdef, name="")

    opt_graph_output = optimized_graph.get_tensor_by_name(output_name)

    return opt_graph_output

hier-normal-graph creates our model and normalizes it.

def tfp_normal_log_prob(x, loc, scale):
    """Compute an unscaled normal log-density, i.e. one that omits the
    `log(scale)` term otherwise present in `tfp.distributions.Normal.log_prob`."""
    log_unnormalized = -0.5 * tf.math.squared_difference(
        x / scale, loc / scale)
    log_normalization = 0.5 * np.log(2. * np.pi)
    # The fully normalized density would also include the following:
    # log_normalization += tf.math.log(scale)
    return log_unnormalized - log_normalization


with graph_mode(), tf.Graph().as_default() as demo_graph:

    x_tf = tf.compat.v1.placeholder(tf.float32, name='value_x',
                                    shape=tf.TensorShape([None]))
    tau_tf = tf.compat.v1.placeholder(tf.float32, name='tau',
                                      shape=tf.TensorShape([None]))
    y_tf = tf.compat.v1.placeholder(tf.float32, name='value_y',
                                    shape=tf.TensorShape([None]))

    X_tfp = tfp.distributions.normal.Normal(0.0, 1.0, name='X')

    z_tf = x_tf + tau_tf * y_tf

    hier_norm_lik = tf.math.log(z_tf)

    # Unscaled normal log-likelihood
    log_unnormalized = -0.5 * tf.math.squared_difference(
        z_tf / tau_tf, x_tf / tau_tf)
    log_normalization = 0.5 * np.log(2. * np.pi)
    hier_norm_lik += log_unnormalized - log_normalization

    hier_norm_lik += X_tfp.log_prob(x_tf)

    hier_norm_lik = normalize_tf_graph(hier_norm_lik)

In hier-normal-graph we used an unscaled version of the normal log-likelihood, i.e. one that omits the \(\log(\tau)\) normalization term. This is because we're emulating the effect of applying a substitution like \(Y \to x + \tau \epsilon \sim \operatorname{N}\left(x, \tau^2\right)\); the log-Jacobian of that substitution accounts for exactly the missing \(\log(\tau)\) term. As a result, the graph above computes the same value as, but is not structurally equal to, the manually transformed graph created in manually-simplified-graph.
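
Concretely, for \(Y = x + \tau \epsilon\) with \(\epsilon \sim \operatorname{N}\left(0, 1\right)\), the change-of-variables identity is

\[
\log p_Y(y)
= -\frac{1}{2}\left(\frac{y - x}{\tau}\right)^2 - \frac{1}{2}\log(2\pi) - \log(\tau)
= \log p_\epsilon\!\left(\frac{y - x}{\tau}\right) - \log(\tau),
\]

so the unscaled term \(-\frac{1}{2}\left(\frac{z - x}{\tau}\right)^2 - \frac{1}{2}\log(2\pi)\) used above is exactly the standard-normal log-density evaluated at \(\epsilon = (z - x)/\tau\), which is what the manually transformed graph in manually-simplified-graph computes directly from value_y.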

tf_dprint(hier_norm_lik)
Tensor(AddV2):0,    dtype=float32,  shape=[None],   "add_2:0"
|  Tensor(Sub):0,   dtype=float32,  shape=[None],   "X_1/log_prob/sub:0"
|  |  Tensor(Mul):0,        dtype=float32,  shape=[None],   "X_1/log_prob/mul:0"
|  |  |  Tensor(SquaredDifference):0,       dtype=float32,  shape=[None],   "X_1/log_prob/SquaredDifference:0"
|  |  |  |  Tensor(Mul):0,  dtype=float32,  shape=[None],   "X_1/log_prob/truediv:0"
|  |  |  |  |  Tensor(Const):0,     dtype=float32,  shape=[],       "ConstantFolding/X_1/log_prob/truediv_recip:0"
|  |  |  |  |  |  1.
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "value_x:0"
|  |  |  |  Tensor(Const):0,        dtype=float32,  shape=[],       "X_1/log_prob/truediv_1:0"
|  |  |  |  |  0.
|  |  |  Tensor(Const):0,   dtype=float32,  shape=[],       "mul_1/x:0"
|  |  |  |  -0.5
|  |  Tensor(Const):0,      dtype=float32,  shape=[],       "sub/y:0"
|  |  |  0.9189385
|  Tensor(AddV2):0, dtype=float32,  shape=[None],   "add_1:0"
|  |  Tensor(Log):0,        dtype=float32,  shape=[None],   "Log:0"
|  |  |  Tensor(AddV2):0,   dtype=float32,  shape=[None],   "add:0"
|  |  |  |  Tensor(Mul):0,  dtype=float32,  shape=[None],   "mul:0"
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "tau:0"
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "value_y:0"
|  |  |  |  Tensor(Placeholder):0,  dtype=float32,  shape=[None],   "value_x:0"
|  |  Tensor(Sub):0,        dtype=float32,  shape=[None],   "sub:0"
|  |  |  Tensor(Mul):0,     dtype=float32,  shape=[None],   "mul_1:0"
|  |  |  |  Tensor(SquaredDifference):0,    dtype=float32,  shape=[None],   "SquaredDifference:0"
|  |  |  |  |  Tensor(RealDiv):0,   dtype=float32,  shape=[None],   "truediv:0"
|  |  |  |  |  |  Tensor(AddV2):0,  dtype=float32,  shape=[None],   "add:0"
|  |  |  |  |  |  |  ...
|  |  |  |  |  |  Tensor(Placeholder):0,    dtype=float32,  shape=[None],   "tau:0"
|  |  |  |  |  Tensor(RealDiv):0,   dtype=float32,  shape=[None],   "truediv_1:0"
|  |  |  |  |  |  Tensor(Placeholder):0,    dtype=float32,  shape=[None],   "value_x:0"
|  |  |  |  |  |  Tensor(Placeholder):0,    dtype=float32,  shape=[None],   "tau:0"
|  |  |  |  Tensor(Const):0,        dtype=float32,  shape=[],       "mul_1/x:0"
|  |  |  |  |  -0.5
|  |  |  Tensor(Const):0,   dtype=float32,  shape=[],       "sub/y:0"
|  |  |  |  0.9189385

From hier-normal-graph-print we can see that grappler does not apply enough algebraic simplifications: for instance, it doesn't remove multiplications by \(1\), and it doesn't reduce the \(\left(\frac{x + \tau y}{\tau} - \frac{x}{\tau}\right)^2\) term inside SquaredDifference to \(y^2\).

**Does missing this simplification amount to anything practical?**

manually-simplified-graph-eval demonstrates the difference between our model without the simplification and a manually constructed model with the simplification (i.e. manually-simplified-graph).

with graph_mode(), demo_graph.as_default():

    Z_tfp = tfp.distributions.normal.Normal(0.0, 1.0, name='Y_trans')

    hn_manually_simplified_lik = tf.math.log(z_tf)
    hn_manually_simplified_lik += Z_tfp.log_prob(y_tf)
    hn_manually_simplified_lik += X_tfp.log_prob(x_tf)

    hn_manually_simplified_lik = normalize_tf_graph(hn_manually_simplified_lik)
tf_dprint(hn_manually_simplified_lik)
Tensor(AddV2):0,    dtype=float32,  shape=[None],   "add_4:0"
|  Tensor(Sub):0,   dtype=float32,  shape=[None],   "X_2/log_prob/sub:0"
|  |  Tensor(Mul):0,        dtype=float32,  shape=[None],   "X_2/log_prob/mul:0"
|  |  |  Tensor(SquaredDifference):0,       dtype=float32,  shape=[None],   "X_2/log_prob/SquaredDifference:0"
|  |  |  |  Tensor(Mul):0,  dtype=float32,  shape=[None],   "X_2/log_prob/truediv:0"
|  |  |  |  |  Tensor(Const):0,     dtype=float32,  shape=[],       "ConstantFolding/Y_trans_1/log_prob/truediv_recip:0"
|  |  |  |  |  |  1.
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "value_x:0"
|  |  |  |  Tensor(Const):0,        dtype=float32,  shape=[],       "Y_trans_1/log_prob/truediv_1:0"
|  |  |  |  |  0.
|  |  |  Tensor(Const):0,   dtype=float32,  shape=[],       "Y_trans_1/log_prob/mul/x:0"
|  |  |  |  -0.5
|  |  Tensor(Const):0,      dtype=float32,  shape=[],       "Y_trans_1/log_prob/add:0"
|  |  |  0.9189385
|  Tensor(AddV2):0, dtype=float32,  shape=[None],   "add_3:0"
|  |  Tensor(Log):0,        dtype=float32,  shape=[None],   "Log_1:0"
|  |  |  Tensor(AddV2):0,   dtype=float32,  shape=[None],   "add:0"
|  |  |  |  Tensor(Mul):0,  dtype=float32,  shape=[None],   "mul:0"
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "tau:0"
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "value_y:0"
|  |  |  |  Tensor(Placeholder):0,  dtype=float32,  shape=[None],   "value_x:0"
|  |  Tensor(Sub):0,        dtype=float32,  shape=[None],   "Y_trans_1/log_prob/sub:0"
|  |  |  Tensor(Mul):0,     dtype=float32,  shape=[None],   "Y_trans_1/log_prob/mul:0"
|  |  |  |  Tensor(SquaredDifference):0,    dtype=float32,  shape=[None],   "Y_trans_1/log_prob/SquaredDifference:0"
|  |  |  |  |  Tensor(Mul):0,       dtype=float32,  shape=[None],   "Y_trans_1/log_prob/truediv:0"
|  |  |  |  |  |  Tensor(Const):0,  dtype=float32,  shape=[],       "ConstantFolding/Y_trans_1/log_prob/truediv_recip:0"
|  |  |  |  |  |  |  1.
|  |  |  |  |  |  Tensor(Placeholder):0,    dtype=float32,  shape=[None],   "value_y:0"
|  |  |  |  |  Tensor(Const):0,     dtype=float32,  shape=[],       "Y_trans_1/log_prob/truediv_1:0"
|  |  |  |  |  |  0.
|  |  |  |  Tensor(Const):0,        dtype=float32,  shape=[],       "Y_trans_1/log_prob/mul/x:0"
|  |  |  |  |  -0.5
|  |  |  Tensor(Const):0,   dtype=float32,  shape=[],       "Y_trans_1/log_prob/add:0"
|  |  |  |  0.9189385
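
# Parameter values chosen to stress the unsimplified graph: dividing by a
# tiny `tau_tf` produces very large intermediate terms in float32.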
test_point = {x_tf.name: np.r_[1.0],
              tau_tf.name: np.r_[1e-9],
              y_tf.name: np.r_[1000.1]}

with tf.compat.v1.Session(graph=hn_manually_simplified_lik.graph).as_default():
    hn_manually_simplified_val = hn_manually_simplified_lik.eval(test_point)

with tf.compat.v1.Session(graph=hier_norm_lik.graph).as_default():
    hn_unsimplified_val = hier_norm_lik.eval(test_point)

_ = np.subtract(hn_unsimplified_val, hn_manually_simplified_val)
[39299.97]

The output of manually-simplified-graph-eval shows just how large the discrepancy can be for suitably chosen parameter values. The unsimplified graph still divides by tau_tf before taking the squared difference, so, as tau_tf gets smaller, the intermediate quantity x_tf / tau_tf becomes so large that float32 round-off swamps the contribution of y_tf and the discrepancy grows. Since such parameter values are likely to be visited during sampling, we should address this missing simplification.
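
To see the floating-point effect in isolation, the following is a minimal standalone sketch (plain NumPy in float32, using the same values as test_point; it is not part of the graphs above) comparing the un-reduced term with its algebraic reduction:

import numpy as np

x, tau, y = np.float32(1.0), np.float32(1e-9), np.float32(1000.1)

# The term as it appears in the unsimplified graph:
# ((x + tau * y) / tau - x / tau)**2
unreduced = np.square((x + tau * y) / tau - x / tau)

# The algebraically reduced term: y**2
reduced = np.square(y)

# In float32 these differ substantially, because x / tau is so large that
# the subtraction loses most of y's contribution to round-off.
print(unreduced, reduced)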

In further-simplify-test-graph we create a goal that performs the aforementioned simplification for SquaredDifference.

from functools import partial
from collections.abc import Sequence

from unification import var

from kanren import run, eq, lall, conde
from kanren.facts import fact
from kanren.assoccomm import eq_comm, commutative
from kanren.graph import walko

from etuples import etuple, etuplize
from etuples.core import ExpressionTuple

from symbolic_pymc.meta import enable_lvar_defaults
from symbolic_pymc.tensorflow.meta import mt, TFlowMetaOperator


# Declare `SquaredDifference` commutative so that `eq_comm` can match its
# arguments in either order.
fact(commutative, TFlowMetaOperator(mt.SquaredDifference.op_def, var()))


def recenter_sqrdiffo(in_g, out_g):
    """Create a goal that essentially reduces `(a / d - (a + d * c) / d)**2` to `d**2`"""
    a_sqd_lv, b_sqd_lv, d_sqd_lv = var(), var(), var()

    with enable_lvar_defaults('names'):
        # Pattern: (a / d - b / d)**2
        target_sqrdiff_lv = mt.SquaredDifference(
            mt.Realdiv(a_sqd_lv, d_sqd_lv),
            mt.Realdiv(b_sqd_lv, d_sqd_lv))

        # Pattern: d * c + a
        c_sqd_lv = var()
        b_part_lv = mt.AddV2(mt.Mul(d_sqd_lv, c_sqd_lv), a_sqd_lv)

    # Replacement: c**2
    simplified_sqrdiff_lv = mt.SquaredDifference(
        c_sqd_lv,
        0.0
    )

    reshape_lv = var()
    simplified_sqrdiff_reshaped_lv = mt.SquaredDifference(
        mt.reshape(c_sqd_lv, reshape_lv),
        0.0
    )

    with enable_lvar_defaults('names'):
        b_sqd_reshape_lv = mt.Reshape(b_part_lv, reshape_lv)

    res = lall(
        # input == (a / d - b / d)**2 must be "true"
        eq_comm(in_g, target_sqrdiff_lv),
        # "and"
        conde([
            # "if" b == d * c + a is "true"
            eq(b_sqd_lv, b_part_lv),
            # "then" output ==  (c - 0)**2 is also "true"
            eq(out_g, simplified_sqrdiff_lv)

            # "or"
        ], [
            # We have to use this to cover some variation also not
            # sufficiently/consistently "normalized" by `grappler`.

            # "if" b == reshape(d * c + a, ?) is "true"
            eq_comm(b_sqd_lv, b_sqd_reshape_lv),
            # "then" output == (reshape(c, ?) - 0)**2 is also "true"
            eq(out_g, simplified_sqrdiff_reshaped_lv)
        ]))
    return res

We apply the simplification in further-simplify-test-graph and print the results in further-simplify-test-graph-print.

from kanren.graph import reduceo


# `walko` applies `recenter_sqrdiffo` to sub-graphs throughout `hier_norm_lik`,
# and `reduceo` re-applies it until a fixed point is reached.
with graph_mode(), hier_norm_lik.graph.as_default():
    q = var()
    res = run(1, q,
              reduceo(lambda x, y: walko(recenter_sqrdiffo, x, y),
                      hier_norm_lik, q))

with graph_mode(), tf.Graph().as_default() as result_graph:
    hn_simplified_tf = res[0].eval_obj.reify()
    hn_simplified_tf = normalize_tf_graph(hn_simplified_tf)
# tf_dprint(hier_norm_lik.graph.get_tensor_by_name('SquaredDifference:0'))
tf_dprint(hn_simplified_tf)
Tensor(AddV2):0,    dtype=float32,  shape=[None],   "add_2_1:0"
|  Tensor(Sub):0,   dtype=float32,  shape=[None],   "X_1/log_prob/sub:0"
|  |  Tensor(Mul):0,        dtype=float32,  shape=[None],   "X_1/log_prob/mul:0"
|  |  |  Tensor(SquaredDifference):0,       dtype=float32,  shape=[None],   "X_1/log_prob/SquaredDifference:0"
|  |  |  |  Tensor(Mul):0,  dtype=float32,  shape=[None],   "X_1/log_prob/truediv:0"
|  |  |  |  |  Tensor(Const):0,     dtype=float32,  shape=[],       "ConstantFolding/X_1/log_prob/truediv_recip:0"
|  |  |  |  |  |  1.
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "value_x:0"
|  |  |  |  Tensor(Const):0,        dtype=float32,  shape=[],       "X_1/log_prob/truediv_1:0"
|  |  |  |  |  0.
|  |  |  Tensor(Const):0,   dtype=float32,  shape=[],       "mul_1/x:0"
|  |  |  |  -0.5
|  |  Tensor(Const):0,      dtype=float32,  shape=[],       "sub/y:0"
|  |  |  0.9189385
|  Tensor(AddV2):0, dtype=float32,  shape=[None],   "add_1_1:0"
|  |  Tensor(Log):0,        dtype=float32,  shape=[None],   "Log:0"
|  |  |  Tensor(AddV2):0,   dtype=float32,  shape=[None],   "add:0"
|  |  |  |  Tensor(Mul):0,  dtype=float32,  shape=[None],   "mul:0"
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "tau:0"
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "value_y:0"
|  |  |  |  Tensor(Placeholder):0,  dtype=float32,  shape=[None],   "value_x:0"
|  |  Tensor(Sub):0,        dtype=float32,  shape=[None],   "sub_1:0"
|  |  |  Tensor(Mul):0,     dtype=float32,  shape=[None],   "mul_1_1:0"
|  |  |  |  Tensor(SquaredDifference):0,    dtype=float32,  shape=[None],   "SquaredDifference_1:0"
|  |  |  |  |  Tensor(Const):0,     dtype=float32,  shape=[],       "X_1/log_prob/truediv_1:0"
|  |  |  |  |  |  0.
|  |  |  |  |  Tensor(Placeholder):0,       dtype=float32,  shape=[None],   "value_y:0"
|  |  |  |  Tensor(Const):0,        dtype=float32,  shape=[],       "mul_1/x:0"
|  |  |  |  |  -0.5
|  |  |  Tensor(Const):0,   dtype=float32,  shape=[],       "sub/y:0"
|  |  |  |  0.9189385

After applying our simplification, simplified-eval-print demonstrates numerically that the discrepancy is gone, i.e. that our transform produces a graph equivalent to the manually simplified graph in manually-simplified-graph.

with tf.compat.v1.Session(graph=hn_simplified_tf.graph).as_default():
    hn_simplified_val = hn_simplified_tf.eval(test_point)

_ = np.subtract(hn_manually_simplified_val, hn_simplified_val)
[0.]