Compute Symbolic Closed-form Posteriors
- Author: Brandon T. Willard
- Date: 2019-11-24
import numpy as np
import theano
import theano.tensor as tt
import pymc3 as pm
from functools import partial
from unification import var
from kanren import run
from kanren.graph import reduceo, walko
from symbolic_pymc.theano.printing import tt_pprint
from symbolic_pymc.theano.pymc3 import model_graph
from symbolic_pymc.relations.theano.conjugates import conjugate
# Theano settings for this example: an empty `cxx` means no C++ compiler is
# used, and test values are not evaluated while the graph is being built.
theano.config.cxx = ''
theano.config.compute_test_value = 'ignore'

# Symbolic inputs, each declared next to the concrete test value it carries.
# `a` and `R` are passed to the prior on `beta` below.
a_tt = tt.vector('a')
a_tt.tag.test_value = np.r_[1., 0.]

R_tt = tt.matrix('R')
R_tt.tag.test_value = np.diag([10., 10.])

# `F` multiplies `beta` to form the first argument of `Y`; `V` is the second.
F_t_tt = tt.matrix('F')
F_t_tt.tag.test_value = np.c_[-2., 1.]

V_tt = tt.matrix('V')
V_tt.tag.test_value = np.diag([0.5])

# The single observed value, wrapped as a tensor and labelled `y`.
y_tt = tt.as_tensor_variable(np.r_[-3.])
y_tt.name = 'y'

with pm.Model() as model:
    # Multivariate normal prior on the two-element parameter `beta`.
    beta_rv = pm.MvNormal('beta', a_tt, R_tt, shape=(2,))

    # Observed multivariate normal whose first argument is `F` dotted with
    # the prior variable `beta`.
    E_y_rv = F_t_tt.dot(beta_rv)
    Y_rv = pm.MvNormal('Y', E_y_rv, V_tt, observed=y_tt)

# Build a graph of the model with the observed variable as its output.
fgraph = model_graph(model, output_vars=[Y_rv])
def conjugate_graph(graph):
    """Rewrite `graph` with the `conjugate` relation and return the result.

    A single miniKanren query is run whose goal is `reduceo` applied to
    `walko` partially applied to `conjugate`, relating `graph` to a fresh
    logic variable.  Only one answer is requested.

    Parameters
    ----------
    graph
        The term to rewrite.  NOTE(review): presumably a symbolic-pymc
        meta graph such as an output of `model_graph` -- confirm.

    Returns
    -------
    The value of ``.eval_obj.reify()`` on the single query answer.
    NOTE(review): presumably a Theano graph -- confirm.

    Raises
    ------
    ValueError
        If the query does not yield exactly one answer (the unpacking
        below fails).
    """
    q_lv = var('q')

    # `walko` carries `conjugate` across the terms of the graph; `reduceo`
    # wraps that relation for the query.
    walk_conjugate = partial(walko, conjugate)
    answers = run(1, q_lv, reduceo(walk_conjugate, graph, q_lv))

    # Exactly one answer is expected from `run(1, ...)`.
    (reduced,) = answers
    return reduced.eval_obj.reify()
# Apply the conjugate relations to the model graph's first output.
fgraph_conj = conjugate_graph(fgraph.outputs[0])
Before
>>> print(tt_pprint(fgraph))
F in R**(N^F_0 x N^F_1), a in R**(N^a_0), R in R**(N^R_0 x N^R_1)
V in R**(N^V_0 x N^V_1)
beta ~ N(a, R) in R**(N^beta_0), Y ~ N((F * beta), V) in R**(N^Y_0)
Y = [-3.]
After
>>> print(tt_pprint(fgraph_conj))
a in R**(N^a_0), R in R**(N^R_0 x N^R_1), F in R**(N^F_0 x N^F_1)
c in R**(N^c_0 x N^c_1), d in R**(N^d_0 x N^d_1)
V in R**(N^V_0 x N^V_1), e in R**(N^e_0 x N^e_1)
b ~ N((a + (((R * F.T) * c) * ([-3.] - (F * a)))), (R - ((((R * F.T) * d) * (V + (F * (R * F.T)))) * ((R * F.T) * e).T))) in R**(N^b_0)
b