A Tour of Symbolic PyMC¶
- Author: Brandon T. Willard
- Date: 2019-08-03
Introduction¶
In this document we'll cover the basics of the Symbolic PyMC package while implementing a symbolic "search-and-replace" that changes TensorFlow graphs like tf.matmul(A, x + y) into tf.matmul(A, x) + tf.matmul(A, y). In other words, we'll demonstrate how to implement the distributive property of matrix multiplication so that it can be applied to arbitrary TensorFlow graphs.
Symbolic PyMC allows one to implement rewrite rules like the distributive property–and many other sophisticated manipulations of graphs–by providing flexible, pure Python versions of core operations in symbolic computation. These operations are then combined and orchestrated through the relational programming DSL miniKanren.
More specifically, we’ll introduce the basic unification and reification operations and explicitly show how they relate to graph manipulation and the modeling of high-level mathematical relations. Along the way, we’ll cover some of the necessary details behind TensorFlow graphs and how they’re modeled by meta graph objects in Symbolic PyMC.
We start by creating a graph of our target expression, i.e. tf.matmul(A, x + y), in TensorFlow. We need to do this in order to determine exactly what we're searching for and, later, what to put in its place.
import numpy as np
import tensorflow as tf
from IPython.lib.pretty import pprint
from tensorflow.python.eager.context import graph_mode
with graph_mode():
    # Matrix
    A_tf = tf.compat.v1.placeholder(tf.float32, name='A',
                                    shape=tf.TensorShape([None, None]))

    # Column vectors
    x_tf = tf.compat.v1.placeholder(tf.float32, name='x',
                                    shape=tf.TensorShape([None, 1]))
    y_tf = tf.compat.v1.placeholder(tf.float32, name='y',
                                    shape=tf.TensorShape([None, 1]))

    # The multiplication
    z_tf = tf.matmul(A_tf, x_tf + y_tf)
A text print-out of the full TensorFlow graph is provided by the debug print function tf_dprint.
from symbolic_pymc.tensorflow.printing import tf_dprint
tf_dprint(z_tf)
Tensor(MatMul):0, dtype=float32, shape=[None, 1], "MatMul:0"
| Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| Tensor(AddV2):0, dtype=float32, shape=[None, 1], "add:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
The print-out above shows us the underlying operators (e.g. MatMul, Placeholder, AddV2) and their arguments.
To "match", or search for, combinations of TensorFlow operations (in other words, graphs) like the one above, we use unification; to "replace" parts of a graph (more precisely, to produce copies with replaced parts), we use reification. Symbolic PyMC provides support for these operations on TensorFlow (and Theano) graphs via meta objects and expression-tuples.
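Before applying these ideas to TensorFlow graphs, here is a minimal sketch of both operations on plain Python tuples, which the unification package also supports:

from unification import unify, reify, var

# Unification: match a pattern containing a logic variable against a value.
s = unify((1, var('a'), 3), (1, 2, 3), {})
# {~a: 2}

# Reification: substitute logic variables using the resulting bindings.
reify((var('a'), var('a')), s)
# (2, 2)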
Meta Objects¶
Meta objects model the essential components of TensorFlow graphs, while allowing one to use input that isn’t normally valid. More specifically, we can construct meta graphs that contain logic variables. Later, those logic variables can be replaced with other objects that allow the meta graph to be converted into a real TensorFlow graph.
Existing TensorFlow graphs can be converted to their meta graph equivalents with the mt helper object.
from symbolic_pymc.tensorflow.meta import mt
z_mt = mt(z_tf)
tf_dprint(z_mt)
Tensor(MatMul):0, dtype=float32, shape=[None, 1], "MatMul:0"
| Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| Tensor(AddV2):0, dtype=float32, shape=[None, 1], "add:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
A meta graph can be converted to a TensorFlow graph using its reify method.
tf_dprint(z_mt.reify())
Tensor(MatMul):0, dtype=float32, shape=[None, 1], "MatMul:0"
| Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| Tensor(AddV2):0, dtype=float32, shape=[None, 1], "add:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
The mt object also makes it easier to construct meta graphs by hand.
from unification import unify, reify, var
with graph_mode():
    add_mt = mt.add(1, var('a'))
pprint(add_mt)
TFlowMetaTensor(
op=TFlowMetaOp(
op_def=TFlowMetaOpDef(Add),
node_def=TFlowMetaNodeDef(op='Add', name='Add', attr={'T': ~_6}),
inputs=(TFlowMetaTensor(
op=TFlowMetaOp(
op_def=TFlowMetaOpDef(Const),
node_def=TFlowMetaNodeDef(
op='Const',
name='Const',
attr={'value': HashableNDArray(1, dtype=int32), 'dtype': 'int32'}),
inputs=()),
value_index=0,
dtype=tf.int32),
~a)),
value_index=0,
dtype=tf.int32)
Above, we created a graph of 1 plus a unification logic variable with the name 'a'. This wouldn't be possible with a standard TensorFlow graph.
Also, because one of the elements in the graph is a logic variable, the graph cannot be converted into a TensorFlow graph. Instead, if we attempt to use the meta graph's reify method, we are simply given the meta graph back.
pprint(add_mt.reify())
TFlowMetaTensor(
op=TFlowMetaOp(
op_def=TFlowMetaOpDef(Add),
node_def=TFlowMetaNodeDef(op='Add', name='Add', attr={'T': ~_6}),
inputs=(TFlowMetaTensor(
op=TFlowMetaOp(
op_def=TFlowMetaOpDef(Const),
node_def=TFlowMetaNodeDef(
op='Const',
name='Const',
attr={'value': HashableNDArray(1, dtype=int32), 'dtype': 'int32'}),
inputs=()),
value_index=0,
dtype=tf.int32),
~a)),
value_index=0,
dtype=tf.int32)
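As a sketch of the substitution step this enables, we can use reify with an explicit mapping, just as later sections do with larger patterns (assuming, as those sections demonstrate, that unification's reify dispatches on Symbolic PyMC's meta objects):

# A sketch: replace the logic variable ~a with an existing meta tensor;
# here we reuse the meta constant 1 that is already the op's first input.
one_mt = add_mt.op.inputs[0]
add_filled_mt = reify(add_mt, {var('a'): one_mt})

# add_filled_mt is the same meta graph with ~a replaced. Once no logic
# variables remain (note the attribute variable ~_6 above), its reify()
# method can produce a real TensorFlow graph.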
S-expressions¶
As an alternative to full meta graph conversion, we can also convert TensorFlow graphs into an S-expression-like form using the etuples package (https://github.com/pythological/etuples).
from etuples import etuple, etuplize
z_sexp = etuplize(z_tf)
pprint(z_sexp)
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(MatMul),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'MatMul',
'MatMul',
{'T': 'float32', 'transpose_a': False, 'transpose_b': False})),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'A',
{'dtype': 'float32',
'shape': TFlowMetaTensorShape(dims=(None, None))}))),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(AddV2),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'AddV2',
'add',
{'T': 'float32'})),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'x',
{'dtype': 'float32',
'shape': TFlowMetaTensorShape(dims=(None, 1))}))),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'y',
{'shape': TFlowMetaTensorShape(dims=(None, 1)),
'dtype': 'float32'})))))
An etuple is like a normal tuple, except that its first element is a Callable and the remaining elements are the Callable's arguments. As above, a pretty-printed etuple looks like a tuple prefixed by an e.
By working with etuples, we can use arbitrary Python functions in conjunction with meta graphs and logic variable arguments. Basically, an etuple can be manipulated until all of its constituent logic variables and meta objects are eventually replaced with valid arguments to the function/operator. At that point, the etuple can be evaluated.
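For instance, here is a minimal sketch with a plain Python callable instead of a TensorFlow function:

from operator import add

# An etuple representing add(1, 2): the callable followed by its arguments.
e_add = etuple(add, 1, 2)

# Evaluating the etuple calls the function.
e_add.eval_obj
# 3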
Next, we create an etuple that applies the TensorFlow function tf.add to logic variable arguments.
x_lv, y_lv = var('x'), var('y')
add_tf_pat = etuple(tf.add, x_lv, y_lv)
Normally, it wouldn't be possible to call this function with these argument types, as the following demonstrates.
try:
    tf.add(x_lv, 1)
except ValueError as e:
    print(str(e))
Attempt to convert a value (~x) with an unsupported type (<class 'unification.variable.Var'>) to a Tensor.
We'll get the same error if we attempt to evaluate the etuple by accessing its ExpressionTuple.eval_obj property. However, after performing a simple manipulation that replaces the logic variables with valid inputs to tf.add, we are able to evaluate the etuple and obtain a TF Tensor result, as demonstrated below.
add_pat_new = reify(add_tf_pat, {x_lv: 1, y_lv: 1})
pprint(add_pat_new)
e(<function tensorflow.python.ops.gen_math_ops.add(x, y, name=None)>, 1, 1)
pprint(add_pat_new.eval_obj)
<tf.Tensor: shape=(), dtype=int32, numpy=2>
Working with S-expressions is much like manipulating a subset of the Python AST, so, when using etuples, one is, in effect, metaprogramming (e.g. by automating the production and evaluation of TensorFlow-using Python code).
As a matter of fact, etuples could be recast as ast.Expr and ast.Call objects that, through the use of eval, could achieve the same results, albeit without the more convenient tuple-like structuring.
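To make the analogy concrete, here is a rough sketch (Python 3.8+; the correspondence is approximate) of the ast counterpart of an etuple like e(add, 1, 2):

import ast
from operator import add

# Roughly the ast.Call counterpart of etuple(add, 1, 2).
call_expr = ast.Expression(
    body=ast.Call(func=ast.Name(id='add', ctx=ast.Load()),
                  args=[ast.Constant(1), ast.Constant(2)],
                  keywords=[]))
ast.fix_missing_locations(call_expr)

# Compiling and evaluating it reproduces etuple(add, 1, 2).eval_obj.
eval(compile(call_expr, '<sketch>', 'eval'), {'add': add})
# 3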
Meta Operators and their Parameters¶
In the S-expression print-out above, the etuple form of our matrix multiplication graph, z_sexp, has a symbolic_pymc.tensorflow.meta.TFlowMetaOperator in the function/operator position. The following prints only that operator part of the etuple.
pprint(z_sexp[0])
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(MatMul),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'MatMul',
'MatMul',
{'T': 'float32', 'transpose_a': False, 'transpose_b': False}))
A TFlowMetaOperator is an abstraction combining the TF OpDef and NodeDef that, when paired with operator arguments, comprise a valid TF Operation.
When we call mt.add, we're imitating the TF user-level API function tf.add. Behind the scenes, tf.add obtains the OpDef, creates the NodeDef, and produces an Operation. Since we can't directly use helper functions like tf.add with our logic variables, the meta objects have to recreate the same process, and that's what TFlowMetaOperator does.
More importantly, it does so in a way that allows for some intercession, so that logic variables can be used. For instance, TF Operations are necessarily assigned unique names, so, if we wanted to match graphs produced by tf.add, we would either need to know the explicit names of its Operations or use logic variables in their place. The NodeDef holds the name value, so we could set that property, or the entire NodeDef, to a logic variable and match any name.
The same goes for the extra options associated with an Operation's OpDef. Notice that the NodeDef in the meta operator for tf.matmul has a dict containing transpose_* entries. These are the default values for the TF function tf.matmul (see its signature, printed below).
pprint(tf.matmul)
<function tensorflow.python.ops.math_ops.matmul(a, b, transpose_a=False, transpose_b=False, adjoint_a=False, adjoint_b=False, a_is_sparse=False, b_is_sparse=False, name=None)>
Meta operators make it easier to set an entire NodeDef to a logic variable so that one can find graphs based only on the high-level operations they perform (e.g. multiplication). Furthermore, this separates the high-level operator's arguments from its parameters. Take the matrix multiplication above: at the mathematical level, matrix multiplication takes only the objects it's multiplying as arguments, not any "transpose" parameters.
When we want to make general statements about the properties of a mathematical operator, this conflation of arguments and parameters requires extra work to undo. Let's say we wanted to programmatically state that addition is commutative, so that our matching process could consider any order of arguments. If we followed TensorFlow's convention, we would, at minimum, need special logic to determine which arguments the statement applies to.
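By contrast, when arguments stand alone, commutativity is just a pair of argument-swapped patterns. Here is an illustrative sketch using the etuple tools from above (it ignores NodeDef parameters entirely):

u_lv, v_lv = var('u'), var('v')

# Commutativity as a pair of argument-swapped patterns.
swap_in_pat = etuple(tf.add, u_lv, v_lv)
swap_out_pat = etuple(tf.add, v_lv, u_lv)

s_swap = unify(swap_in_pat, etuple(tf.add, 1, 2), {})
reify(swap_out_pat, s_swap)
# e(tf.add, 2, 1), i.e. the arguments swapped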
We'll see examples of TFlowMetaOperator's use in the sections that follow.
Unification and Reification¶
With the ability to use logic variables and TensorFlow graphs together, we can now “search” or “match” arbitrary graphs using unification and produce new graphs by replacing logic variables using reification.
We start by making “patterns” or templates for the subgraphs we would like to match. Patterns, in this case, take the form of meta graphs or S-expr graphs with the desired structure and logic variables in place of “unknown” or arbitrary terms that we might like to reference elsewhere.
The following builds an S-expression that evaluates to a graph in which two terms are matrix-multiplied.
from symbolic_pymc.tensorflow.meta import TFlowMetaOperator
A_lv, B_lv = var('A'), var('B')
node_def_lv = var('node_def')
matmul_op_mt = TFlowMetaOperator('matmul', node_def_lv)
matmul_pat_mt = matmul_op_mt(A_lv, B_lv)
matmul_pat = etuplize(matmul_pat_mt)
Above, we created a meta graph, matmul_pat_mt, from a meta TF MatMul operator and a logic-variable NodeDef; we then applied that meta operator to two logic variable arguments.
The logic variable node_def_lv is there to match the parameters of tf.matmul, e.g. transpose_a, transpose_b, and the name parameter.
Again, by setting the NodeDef in our meta operator to a logic variable, we allow unification with any matrix multiplication (e.g. not just ones named "blah", or ones with transposed second arguments).
“Matching” a graph against our pattern is actually called unification. Unification of two graphs implies unification of all sub-graphs and elements between them. When unification is successful, it returns a map of logic variables and their unified values. If there are no logic variables in the graphs, it simply returns an empty map. If unification fails, it returns False (at least in the implementation we use, though not necessarily in general).
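These three outcomes are easy to see with plain tuples:

# Success with no logic variables in either graph: an empty map.
unify((1, 2), (1, 2), {})
# {}

# Success with a logic variable: a map of bindings.
unify((1, var('q')), (1, 2), {})
# {~q: 2}

# Failure: False.
unify((1, var('q')), (2, 2), {})
# False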
Unification¶
We can perform the unification using the function unify. The result is a dict mapping logic variables to their unified values.
s = unify(matmul_pat, z_sexp, {})
pprint(s)
{~node_def: e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'MatMul',
'MatMul',
{'T': 'float32', 'transpose_a': False, 'transpose_b': False}),
~A: e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'A',
{'dtype': 'float32',
'shape': TFlowMetaTensorShape(dims=(None, None))}))),
~B: e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(AddV2),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'AddV2',
'add',
{'T': 'float32'})),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'x',
{'dtype': 'float32',
'shape': TFlowMetaTensorShape(dims=(None, 1))}))),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'y',
{'shape': TFlowMetaTensorShape(dims=(None, 1)),
'dtype': 'float32'}))))}
Reification¶
Using reify, we can “fill in” (i.e. replace) the logic variables of our “pattern” with the matches obtained by unify and held within the variable s, or we can specify our own substitutions based on that information.
In the following, we simply change the 'name' value in the NodeDef and create a new graph with that value. The end result is a version of the original graph, z_sexp, with a new name.
s[var('node_def')] = s[var('node_def')][:2] + ("a_new_name",) + s[var('node_def')][3:]
z_sexp_re = reify(matmul_pat, s)
pprint(z_sexp_re)
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(MatMul),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'MatMul',
'a_new_name',
{'T': 'float32', 'transpose_a': False, 'transpose_b': False})),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'A',
{'dtype': 'float32',
'shape': TFlowMetaTensorShape(dims=(None, None))}))),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(AddV2),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'AddV2',
'add',
{'T': 'float32'})),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'x',
{'dtype': 'float32',
'shape': TFlowMetaTensorShape(dims=(None, 1))}))),
e(
e(
symbolic_pymc.tensorflow.meta.TFlowMetaOperator,
TFlowMetaOpDef(Placeholder),
e(
symbolic_pymc.tensorflow.meta.TFlowMetaNodeDef,
'Placeholder',
'y',
{'shape': TFlowMetaTensorShape(dims=(None, 1)),
'dtype': 'float32'})))))
Finishing our Implementation¶
We can also reify an entirely different graph using the values extracted from the graph z_sexp. In this case, we create an “output” pattern graph to complement our “input” pattern graph, matmul_pat.
If we combine our matrix multiplication and addition etuple patterns, we can extract all the arguments needed as input to a distributed multiplication pattern.
add_op_mt = TFlowMetaOperator('addv2', var('add_node_def'))
output_pat_mt = add_op_mt(matmul_op_mt(A_lv, x_lv), matmul_op_mt(A_lv, y_lv))
output_pat = etuplize(output_pat_mt)
With the logic variables A_lv, x_lv, and y_lv mapped to their template-corresponding objects in another graph, we can reify output_pat and obtain a “transformed” version of said graph.
Using our earlier unification results, we only need to reify our output pattern, output_pat, with those mappings. However, since our output pattern refers to the logic variables x_lv and y_lv, we'll need to unify those logic variables with the appropriate terms in the graph. The following unifies the remaining terms by simply extracting the B argument of the matrix multiplication and unifying it with a pattern for tensor addition.
add_pat = etuple(etuplize(add_op_mt), x_lv, y_lv)
s_add = unify(s[B_lv], add_pat, s)
z_new = reify(output_pat, s_add)
tf_dprint(z_new.eval_obj)
Tensor(AddV2):0, dtype=float32, shape=~_11, "add:0"
| Tensor(MatMul):0, dtype=float32, shape=~_12, "a_new_name:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| Tensor(MatMul):0, dtype=float32, shape=~_13, "a_new_name:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
As we've seen, using only the basics of unification and reification provided by Symbolic PyMC, one can extract specific elements from TensorFlow graphs and use them to implement mathematical identities/relations. Through clever use of multiple mathematical relations, one can, for example, construct graph optimizations that turn large classes of user-defined statistical models into computationally tractable reformulations. Similarly, one can construct “normal forms” for models, making it possible to determine whether or not a user-defined model is suitable for a specific sampler.
Next, we’ll introduce another major element of Symbolic PyMC that orchestrates and simplifies sequences of unifications like we used earlier, provides control-flow-like capabilities, produces fully reified results of arbitrary form, and does so within a genuinely declarative formalism that carries much of the same power as logic programming: miniKanren!
Relational Programming in miniKanren¶
As mentioned at the end of the last section, Symbolic PyMC uses a Python implementation of the embedded domain-specific language miniKanren, provided by the kanren package, to orchestrate more sophisticated uses of unification and reification. For a quick intro, see the basic introduction provided by the kanren package. We'll cover most of the same basic material here, but not all.
To start, miniKanren uses goals (in the same sense as logic programming) to assert relations, and the run function evaluates those goals and allows one to specify the exact number and form of reified outputs desired from the states that satisfy the goals.
In their most basic form, miniKanren states are simply the substitution maps returned by unification, which–in the normal course of operation–aren’t dealt with directly.
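The correspondence is direct; for example, the state in which the goal eq(var('q'), 1) succeeds is exactly the map that unify produces:

# The miniKanren state underlying eq(var('q'), 1).
unify(var('q'), 1, {})
# {~q: 1}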
The Basic Goals¶
Normally, a user will only need to construct compound goals from a basic set of primitives. Arguably, the most primitive goal is the equivalence relation under unification, denoted eq in Python.
In the following, we ask for all successful results/reifications (signified by the 0 argument) of the logic variable var('q') under the goal eq(var('q'), 1), i.e. we unify var('q') with 1.
from kanren import run, eq
q_lv = var('q')
mk_res = run(0, q_lv, eq(q_lv, 1))
pprint(mk_res)
(1,)
Since miniKanren's run always returns a stream of results, we obtain a tuple containing the reified value of q_lv under the one possible state for which our stated goal successfully evaluates.
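The first argument to run controls how many results are drawn from that stream, with 0 meaning “all of them”. A small sketch using kanren's membero goal:

from kanren import membero

# All members of the collection succeed as results...
run(0, q_lv, membero(q_lv, (1, 2, 3)))
# (1, 2, 3)

# ...but we can request only the first two.
run(2, q_lv, membero(q_lv, (1, 2, 3)))
# (1, 2)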
The other basic primitives represent conjunction and disjunction of miniKanren goals: lall and lany, respectively.
from kanren import lall, lany
mk_res = run(0, q_lv, lall(eq(q_lv, 1), eq(q_lv, 2)))
pprint(mk_res)
()
Above, we used lall to form the conjunction of two unification goals. Since we requested that the same logic variable be unified with both 1 and 2 simultaneously (i.e. in the same state), which isn't possible, we got back an empty stream of results, indicating failure.
Goal disjunction, lany, will split a state stream across goals, producing new distinct states for each.
mk_res = run(0, q_lv, lany(eq(q_lv, 1), eq(q_lv, 2)))
pprint(mk_res)
(1, 2)
The goal disjunction results above show that the logic variable q_lv can be unified with either 1 or 2 under the two unification goals.
A common pattern of disjunction and conjunction is called conde, and it mirrors the Lisp function cond, which is effectively a type of compound if ... elif ... elif .... Specifically, conde([x_1, ...], ..., [y_1, ...]) is the same as lany(lall(x_1, ...), ..., lall(y_1, ...)), i.e. a disjunction of goal conjunctions.
from kanren import conde
r_lv = var('r')
mk_res = run(0, [q_lv, r_lv],
             conde(
                 [eq(q_lv, 1), eq(r_lv, 10)],
                 [eq(q_lv, 2), eq(r_lv, 20)],
             ))
pprint(mk_res)
([1, 10], [2, 20])
Above, we introduced another logic variable, r_lv, and requested the reified values of a list containing both logic variables. The output reflects the idea that if q_lv is “equal” to 1, then r_lv is “equal” to 10, etc. Unlike normal conditionals, the clauses/branches aren't mutually exclusive; instead, each one is realized whenever the goals in its branch can succeed.
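We can also verify the earlier claim that conde is simply a disjunction of conjunctions:

res_conde = run(0, q_lv, conde([eq(q_lv, 1)], [eq(q_lv, 2)]))
res_lany = run(0, q_lv, lany(lall(eq(q_lv, 1)), lall(eq(q_lv, 2))))

# Both forms produce the same results.
assert res_conde == res_lany == (1, 2)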
The following demonstrates how conde can behave more like a traditional conditional statement.
mk_res = run(0, [q_lv, r_lv],
             lall(eq(q_lv, 1),
                  conde(
                      [eq(q_lv, 1), eq(r_lv, 10)],
                      [eq(q_lv, 2), eq(r_lv, 20)],
                  )))
pprint(mk_res)
([1, 10],)
A Better Implementation¶
Since miniKanren uses unification and reification, we can apply its basic goals to TensorFlow graphs, as we did earlier, and reproduce the entire implementation in a much more concise manner.
mk_res = run(1, output_pat,
             eq(matmul_pat, z_sexp),
             eq(add_pat, B_lv))
tf_dprint(mk_res[0].eval_obj)
Tensor(AddV2):0, dtype=float32, shape=~_14, "add:0"
| Tensor(MatMul):0, dtype=float32, shape=~_15, "MatMul:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| Tensor(MatMul):0, dtype=float32, shape=~_16, "MatMul:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
We didn't need to use the goal conjunction operator lall explicitly above, because all remaining goal arguments to run are automatically applied in conjunction.
When combinations of miniKanren goals comprise logical units, we can wrap their construction in functions, which we call goal constructors.
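For instance, here is a toy goal constructor (the name and relation are purely illustrative): it simply returns a composed goal.

def squared_pairo(x, y):
    """Construct a goal relating a few numbers to their squares."""
    return conde([eq(x, 2), eq(y, 4)],
                 [eq(x, 3), eq(y, 9)])

run(0, q_lv, squared_pairo(3, q_lv))
# (9,)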
Goal Constructors¶
Using our distributive law example, we can create a goal constructor that creates our combined pattern and applies it in one go. In this case, we'll construct goals that operate on meta graphs instead of etuples.
def distributeo(in_g, out_g):
    """Construct a goal relating a matrix multiplication to its distributed form.

    Specifically, A * (x + y) == A * x + A * y
    """
    matmul_op_mt = TFlowMetaOperator('matmul', var())
    add_op_mt = TFlowMetaOperator('addv2', var())

    A_lv, x_lv, y_lv = var(), var(), var()

    mul_pat_mt = matmul_op_mt(A_lv, add_op_mt(x_lv, y_lv))
    dist_pat_mt = mt.addv2(mt.matmul(A_lv, x_lv), mt.matmul(A_lv, y_lv))

    return lall(eq(in_g, mul_pat_mt),
                eq(out_g, dist_pat_mt))
Our goal constructor represents the relation for distributing matrix multiplication over addition. In this sense, it can be run both ways: it can “expand” a multiplication by distributing it through an addition, and it can “contract” by doing the opposite.
First, we use the relation to “expand” the multiplication.
q_lv = var()
mk_res = run(1, q_lv, distributeo(z_mt, q_lv))
z_expanded_mt = mk_res[0]
tf_dprint(z_expanded_mt)
Tensor(AddV2):0, dtype=~_27, shape=~_28, "AddV2:0"
| Tensor(MatMul):0, dtype=~_25, shape=~_29, "MatMul:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| Tensor(MatMul):0, dtype=~_26, shape=~_30, "MatMul:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
Now we “contract” the graph using the previously “expanded” result.
q_lv = var()
mk_res = run(1, q_lv, distributeo(q_lv, z_expanded_mt))
z_contracted_mt = mk_res[0]
tf_dprint(z_contracted_mt)
Tensor(MatMul):0, dtype=~_38, shape=~_42, "~_44"
| Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| Tensor(AddV2):0, dtype=~_37, shape=~_45, "~_47"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
Graph-based Goals¶
In most situations, one won't be operating on the exact graph one wants to match. Instead, the desired graphs will be subgraphs of much larger ones.
Symbolic PyMC introduces some miniKanren goals that apply other goals throughout graphs until a fixed-point is reached. This sequence of operations is generally necessary for graph simplification and rewriting.
In the following, we create a new graph that contains tf.matmul(A, x + y) as a subgraph. Using graph_applyo, our distributeo relation is applied throughout the graph until the applicable subgraph is found (and replaced).
from kanren.graph import graph_applyo
with graph_mode():
    z_graph_mt = (np.array(2.0, dtype='float32') *
                  mt.matmul(mt(A_tf), mt(x_tf) + mt(y_tf)) +
                  np.array(1.0, dtype='float32'))

tf_dprint(z_graph_mt)
Tensor(AddV2):0, dtype=float32, shape=[None, 1], "add_2:0"
| Tensor(Mul):0, dtype=float32, shape=[None, 1], "mul:0"
| | Tensor(Const):0, dtype=float32, shape=[], "mul/x:0"
| | | 2.
| | Tensor(MatMul):0, dtype=float32, shape=[None, 1], "MatMul_1:0"
| | | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | | Tensor(AddV2):0, dtype=float32, shape=[None, 1], "add_1:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
| Tensor(Const):0, dtype=float32, shape=[], "add_2/y:0"
| | 1.
with graph_mode():
    q_lv = var()
    mk_res = run(1, q_lv, graph_applyo(distributeo, z_graph_mt, q_lv))
    z_graph_expanded_mt = mk_res[0].eval_obj

tf_dprint(z_graph_expanded_mt)
Tensor(AddV2):0, dtype=float32, shape=~_197, "add_2:0"
| Tensor(Mul):0, dtype=float32, shape=~_198, "mul:0"
| | Tensor(Const):0, dtype=float32, shape=[], "mul/x:0"
| | | 2.
| | Tensor(AddV2):0, dtype=~_156, shape=~_199, "AddV2:0"
| | | Tensor(MatMul):0, dtype=~_154, shape=~_200, "MatMul:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "x:0"
| | | Tensor(MatMul):0, dtype=~_155, shape=~_201, "MatMul:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None, None], "A:0"
| | | | Tensor(Placeholder):0, dtype=float32, shape=[None, 1], "y:0"
| Tensor(Const):0, dtype=float32, shape=[], "add_2/y:0"
| | 1.
The first result from graph_applyo is the graph with all applications of distributeo applied. The other goal results are all the successful applications leading up to the first one. In other words, we're given the entire sequence of all possible applications of distributeo throughout the graph. Since run computes results lazily, we don't have to compute all those graphs unless we actually request them.
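For example, a sketch of requesting only the first two rewrites (reusing the goal and graph from above):

with graph_mode():
    q_lv = var()
    # Only the first two results are forced from the lazy stream.
    first_two = run(2, q_lv, graph_applyo(distributeo, z_graph_mt, q_lv))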
Discussion¶
As the development of Symbolic PyMC goes on, the process of using the above elements will become easier and computationally more efficient. Much of the boilerplate work can be removed without affecting the extensibility of Symbolic PyMC and kanren.
For instance, the need to manually replace NodeDefs with logic variables can be handled by context managers like enable_lvar_defaults, or by updates to the defaults of meta object creation.
Likewise, there are tools available in Symbolic PyMC that make it easier to determine which components are unequal between two meta objects (e.g. symbolic_pymc.utils.meta_parts_unequal).
Symbolic PyMC's library of relevant mathematical and statistical relations is intended to evolve over time. These relations will reflect properties that are useful for reformulating statistical models into computationally more efficient equivalent forms, possibly conditional on, or used to determine, the explicit estimation procedures employed in PyMC.