# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from copy import copy
import aesara
import numpy as np
from aesara import function as aesara_function
from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
from pymc_experimental.bart.bart import BARTRV
from pymc_experimental.bart.tree import LeafNode, SplitNode, Tree
from pymc.model import modelcontext
from pymc.step_methods.arraystep import ArrayStepShared, Competence
_log = logging.getLogger("pymc")
class PGBART(ArrayStepShared):
"""
Particle Gibss BART sampling step.
Parameters
----------
vars: list
List of value variables for sampler
num_particles : int
Number of particles for the conditional SMC sampler. Defaults to 40
max_stages : int
Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
batch : int or tuple
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
during tuning and after tuning. If a tuple is passed the first element is the batch size
during tuning and the second the batch size after tuning.
model: PyMC Model
Optional model for sampling step. Defaults to None (taken from context).
"""
name = "pgbart"
default_blocked = False
generates_stats = True
stats_dtypes = [{"variable_inclusion": object, "bart_trees": object}]
def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", model=None):
model = modelcontext(model)
initial_values = model.compute_initial_point()
if vars is None:
vars = model.value_vars
else:
vars = [model.rvs_to_values.get(var, var) for var in vars]
vars = inputvars(vars)
value_bart = vars[0]
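        # Recover the BART op from its value variable; it carries the data (X, Y) and the
        # prior hyperparameters (m, alpha, k, split_prior) used below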
self.bart = model.values_to_rvs[value_bart].owner.op
self.X = self.bart.X
self.Y = self.bart.Y
self.missing_data = np.any(np.isnan(self.X))
self.m = self.bart.m
self.alpha = self.bart.alpha
self.k = self.bart.k
self.alpha_vec = self.bart.split_prior
if self.alpha_vec is None:
self.alpha_vec = np.ones(self.X.shape[1])
self.init_mean = self.Y.mean()
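        # Scale of the Gaussian used to draw leaf values; it shrinks as 1/sqrt(m) so the
        # sum of m trees keeps a roughly constant spread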
# if data is binary
Y_unique = np.unique(self.Y)
if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
self.mu_std = 6 / (self.k * self.m**0.5)
# maybe we need to check for count data
else:
self.mu_std = (2 * self.Y.std()) / (self.k * self.m**0.5)
self.num_observations = self.X.shape[0]
self.num_variates = self.X.shape[1]
self.available_predictors = list(range(self.num_variates))
self.sum_trees = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX)
self.a_tree = Tree.init_tree(
leaf_node_value=self.init_mean / self.m,
idx_data_points=np.arange(self.num_observations, dtype="int32"),
)
self.mean = fast_mean()
self.normal = NormalSampler()
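        # Prior probability that a node stays a leaf, indexed by node depth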
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
self.ssv = SampleSplittingVariable(self.alpha_vec)
self.tune = True
if batch == "auto":
batch = max(1, int(self.m * 0.1))
self.batch = (batch, batch)
else:
if isinstance(batch, (tuple, list)):
self.batch = batch
else:
self.batch = (batch, batch)
self.log_num_particles = np.log(num_particles)
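        # Particles 0 and 1 (the previous tree and a copy with resampled leaf values) are
        # never resampled; only the remaining indices take part in the SMC resampling step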
self.indices = list(range(2, num_particles))
self.len_indices = len(self.indices)
self.max_stages = max_stages
shared = make_shared_replacements(initial_values, vars, model)
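        # Compiled function returning the data log-likelihood for a given sum-of-trees vector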
self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
self.all_particles = []
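        # One starting particle per tree; each begins as a single-leaf tree predicting init_mean / m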
for i in range(self.m):
self.a_tree.leaf_node_value = self.init_mean / self.m
p = ParticleTree(self.a_tree)
self.all_particles.append(p)
self.all_trees = np.array([p.tree for p in self.all_particles])
super().__init__(vars, shared)
def astep(self, _):
variable_inclusion = np.zeros(self.num_variates, dtype="int")
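        # Indexing with ~self.tune selects batch[0] (since ~True == -2) while tuning
        # and batch[1] (since ~False == -1) afterwards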
tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
for tree_id in tree_ids:
# Generate an initial set of SMC particles
# at the end of the algorithm we return one of these particles as the new tree
particles = self.init_particles(tree_id)
# Compute the sum of trees without the old tree, that we are attempting to replace
self.sum_trees_noi = self.sum_trees - particles[0].tree.predict_output()
            # Resample leaf values for particle 1, which is a copy of the old tree
particles[1].sample_leafs(
self.sum_trees,
self.X,
self.mean,
self.m,
self.normal,
self.mu_std,
)
            # The old tree and the one with new leaves do not grow, so we update their weights only once
self.update_weight(particles[0], old=True)
self.update_weight(particles[1], old=True)
for _ in range(self.max_stages):
# Sample each particle (try to grow each tree), except for the first two
stop_growing = True
for p in particles[2:]:
tree_grew = p.sample_tree(
self.ssv,
self.available_predictors,
self.prior_prob_leaf_node,
self.X,
self.missing_data,
self.sum_trees,
self.mean,
self.m,
self.normal,
self.mu_std,
)
if tree_grew:
self.update_weight(p)
if p.expansion_nodes:
stop_growing = False
if stop_growing:
break
# Normalize weights
W_t, normalized_weights = self.normalize(particles[2:])
# Resample all but first two particles
new_indices = np.random.choice(
self.indices, size=self.len_indices, p=normalized_weights
)
particles[2:] = particles[new_indices]
# Set the new weights
for p in particles[2:]:
p.log_weight = W_t
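            # Once the trees have stopped growing, the weight of each particle is its log-likelihood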
for p in particles[2:]:
p.log_weight = p.old_likelihood_logp
_, normalized_weights = self.normalize(particles)
# Get the new tree and update
new_particle = np.random.choice(particles, p=normalized_weights)
new_tree = new_particle.tree
self.all_trees[tree_id] = new_tree
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
self.all_particles[tree_id] = new_particle
self.sum_trees = self.sum_trees_noi + new_tree.predict_output()
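            # While tuning, bias the splitting-variable prior toward variables already used;
            # after tuning, record variable inclusion counts for the sampler statistics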
if self.tune:
self.ssv = SampleSplittingVariable(self.alpha_vec)
for index in new_particle.used_variates:
self.alpha_vec[index] += 1
else:
for index in new_particle.used_variates:
variable_inclusion[index] += 1
stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
return self.sum_trees, [stats]
    def normalize(self, particles):
"""Use logsumexp trick to get W_t and softmax to get normalized_weights."""
log_w = np.array([p.log_weight for p in particles])
log_w_max = log_w.max()
log_w_ = log_w - log_w_max
w_ = np.exp(log_w_)
w_sum = w_.sum()
W_t = log_w_max + np.log(w_sum) - self.log_num_particles
normalized_weights = w_ / w_sum
# stabilize weights to avoid assigning exactly zero probability to a particle
normalized_weights += 1e-12
return W_t, normalized_weights
    def init_particles(self, tree_id: int) -> np.ndarray:
"""Initialize particles."""
p = self.all_particles[tree_id]
particles = [p]
particles.append(copy(p))
for _ in self.indices:
particles.append(ParticleTree(self.a_tree))
return np.array(particles)
    def update_weight(self, particle, old=False):
"""
Update the weight of a particle.
Since the prior is used as the proposal,the weights are updated additively as the ratio of
the new and old log-likelihoods.
"""
new_likelihood = self.likelihood_logp(self.sum_trees_noi + particle.tree.predict_output())
if old:
particle.log_weight = new_likelihood
particle.old_likelihood_logp = new_likelihood
else:
particle.log_weight += new_likelihood - particle.old_likelihood_logp
particle.old_likelihood_logp = new_likelihood
    @staticmethod
def competence(var, has_grad):
"""PGBART is only suitable for BART distributions."""
dist = getattr(var.owner, "op", None)
if isinstance(dist, BARTRV):
return Competence.IDEAL
return Competence.INCOMPATIBLE
class ParticleTree:
"""Particle tree."""
def __init__(self, tree):
        self.tree = tree.copy()  # keeps the tree we are currently working on
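        # The root node (index 0) is the first candidate for expansion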
self.expansion_nodes = [0]
self.log_weight = 0
self.old_likelihood_logp = 0
self.used_variates = []
def sample_tree(
self,
ssv,
available_predictors,
prior_prob_leaf_node,
X,
missing_data,
sum_trees,
mean,
m,
normal,
mu_std,
):
tree_grew = False
if self.expansion_nodes:
index_leaf_node = self.expansion_nodes.pop(0)
# Probability that this node will remain a leaf node
prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
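            # Attempt to split this node with probability 1 - prob_leaf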
if prob_leaf < np.random.random():
index_selected_predictor = grow_tree(
self.tree,
index_leaf_node,
ssv,
available_predictors,
X,
missing_data,
sum_trees,
mean,
m,
normal,
mu_std,
)
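                # grow_tree returns None when the leaf could not be split
                # (no available splitting values)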
if index_selected_predictor is not None:
new_indexes = self.tree.idx_leaf_nodes[-2:]
self.expansion_nodes.extend(new_indexes)
self.used_variates.append(index_selected_predictor)
tree_grew = True
return tree_grew
def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std):
sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std)
class SampleSplittingVariable:
def __init__(self, alpha_vec):
"""
Sample splitting variables proportional to `alpha_vec`.
This is equivalent to compute the posterior mean of a Dirichlet-Multinomial model.
This enforce sparsity.
"""
self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))
def rvs(self):
r = np.random.random()
for i, v in self.enu:
if r <= v:
return i
def compute_prior_probability(alpha):
"""
Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
Taken from equation 19 in [Rockova2018].
Parameters
----------
alpha : float
Returns
-------
list with probabilities for leaf nodes
References
----------
.. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
arXiv, `link <https://arxiv.org/abs/1810.00787>`__
"""
prior_leaf_prob = [0]
depth = 1
while prior_leaf_prob[-1] < 1:
prior_leaf_prob.append(1 - alpha**depth)
depth += 1
return prior_leaf_prob
def grow_tree(
tree,
index_leaf_node,
ssv,
available_predictors,
X,
missing_data,
sum_trees,
mean,
m,
normal,
mu_std,
):
current_node = tree.get_node(index_leaf_node)
idx_data_points = current_node.idx_data_points
index_selected_predictor = ssv.rvs()
selected_predictor = available_predictors[index_selected_predictor]
available_splitting_values = X[idx_data_points, selected_predictor]
if missing_data:
idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
available_splitting_values = available_splitting_values[
~np.isnan(available_splitting_values)
]
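    # If there are no available splitting values, the function implicitly returns None
    # and the leaf is left unchanged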
if available_splitting_values.size > 0:
idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
split_value = available_splitting_values[idx_selected_splitting_values]
new_idx_data_points = get_new_idx_data_points(
split_value, idx_data_points, selected_predictor, X
)
current_node_children = (
current_node.get_idx_left_child(),
current_node.get_idx_right_child(),
)
new_nodes = []
for idx in range(2):
idx_data_point = new_idx_data_points[idx]
node_value = draw_leaf_value(
sum_trees[idx_data_point],
X[idx_data_point, selected_predictor],
mean,
m,
normal,
mu_std,
)
new_node = LeafNode(
index=current_node_children[idx],
value=node_value,
idx_data_points=idx_data_point,
)
new_nodes.append(new_node)
new_split_node = SplitNode(
index=index_leaf_node,
idx_split_variable=selected_predictor,
split_value=split_value,
)
# update tree nodes and indexes
tree.delete_node(index_leaf_node)
tree.set_node(index_leaf_node, new_split_node)
tree.set_node(new_nodes[0].index, new_nodes[0])
tree.set_node(new_nodes[1].index, new_nodes[1])
return index_selected_predictor
def sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std):
for idx in tree.idx_leaf_nodes:
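        # Skip the root node (index 0): a tree with a single leaf has no parent split node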
if idx > 0:
leaf = tree[idx]
idx_data_points = leaf.idx_data_points
parent_node = tree[leaf.get_idx_parent_node()]
selected_predictor = parent_node.idx_split_variable
node_value = draw_leaf_value(
sum_trees[idx_data_points],
X[idx_data_points, selected_predictor],
mean,
m,
normal,
mu_std,
)
leaf.value = node_value
def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
left_idx = X[idx_data_points, selected_predictor] <= split_value
left_node_idx_data_points = idx_data_points[left_idx]
right_node_idx_data_points = idx_data_points[~left_idx]
return left_node_idx_data_points, right_node_idx_data_points
def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
"""Draw Gaussian distributed leaf values."""
if Y_mu_pred.size == 0:
return 0
else:
norm = normal.random() * mu_std
if Y_mu_pred.size == 1:
mu_mean = Y_mu_pred.item() / m
else:
mu_mean = mean(Y_mu_pred) / m
draw = norm + mu_mean
return draw
def fast_mean():
"""If available use Numba to speed up the computation of the mean."""
try:
from numba import jit
except ImportError:
return np.mean
@jit
def mean(a):
count = a.shape[0]
suma = 0
for i in range(count):
suma += a[i]
return suma / count
return mean
def discrete_uniform_sampler(upper_value):
"""Draw from the uniform distribution with bounds [0, upper_value).
This is the same and np.random.randit(upper_value) but faster.
"""
return int(np.random.random() * upper_value)
class NormalSampler:
"""Cache samples from a standard normal distribution."""
def __init__(self):
self.size = 1000
self.cache = []
def random(self):
if not self.cache:
self.update()
return self.cache.pop()
def update(self):
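        # Refill the cache with a fresh batch of standard normal draws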
self.cache = np.random.normal(loc=0.0, scale=1, size=self.size).tolist()
def logp(point, out_vars, vars, shared):
"""Compile Aesara function of the model and the input and output variables.
Parameters
----------
out_vars: List
containing :class:`pymc.Distribution` for the output variables
vars: List
containing :class:`pymc.Distribution` for the input variables
shared: List
containing :class:`aesara.tensor.Tensor` for depended shared data
"""
out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
f = aesara_function([inarray0], out_list[0])
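    # Skip input validation in the compiled function for speed; callers must pass
    # correctly typed arrays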
f.trust_input = True
return f