Source code for pymc_experimental.bart.pgbart

#   Copyright 2020 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import logging

from copy import copy

import aesara
import numpy as np

from aesara import function as aesara_function

from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
from pymc_experimental.bart.bart import BARTRV
from pymc_experimental.bart.tree import LeafNode, SplitNode, Tree
from pymc.model import modelcontext
from pymc.step_methods.arraystep import ArrayStepShared, Competence

_log = logging.getLogger("pymc")


class PGBART(ArrayStepShared):
    """
    Particle Gibbs BART sampling step.

    Parameters
    ----------
    vars: list
        List of value variables for the sampler.
    num_particles : int
        Number of particles for the conditional SMC sampler. Defaults to 40.
    max_stages : int
        Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
    batch : int or tuple
        Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
        both during and after tuning. If a tuple is passed, the first element is the batch size
        during tuning and the second the batch size after tuning.
    model: PyMC Model
        Optional model for sampling step. Defaults to None (taken from context).
    """

    name = "pgbart"
    default_blocked = False
    generates_stats = True
    stats_dtypes = [{"variable_inclusion": object, "bart_trees": object}]

    def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", model=None):
        model = modelcontext(model)
        initial_values = model.compute_initial_point()
        if vars is None:
            vars = model.value_vars
        else:
            vars = [model.rvs_to_values.get(var, var) for var in vars]
            vars = inputvars(vars)
        value_bart = vars[0]
        self.bart = model.values_to_rvs[value_bart].owner.op

        self.X = self.bart.X
        self.Y = self.bart.Y
        self.missing_data = np.any(np.isnan(self.X))
        self.m = self.bart.m
        self.alpha = self.bart.alpha
        self.k = self.bart.k
        self.alpha_vec = self.bart.split_prior
        if self.alpha_vec is None:
            self.alpha_vec = np.ones(self.X.shape[1])

        self.init_mean = self.Y.mean()
        # if data is binary
        Y_unique = np.unique(self.Y)
        if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
            self.mu_std = 6 / (self.k * self.m**0.5)
        # maybe we need to check for count data
        else:
            self.mu_std = (2 * self.Y.std()) / (self.k * self.m**0.5)

        self.num_observations = self.X.shape[0]
        self.num_variates = self.X.shape[1]
        self.available_predictors = list(range(self.num_variates))

        self.sum_trees = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX)
        self.a_tree = Tree.init_tree(
            leaf_node_value=self.init_mean / self.m,
            idx_data_points=np.arange(self.num_observations, dtype="int32"),
        )
        self.mean = fast_mean()

        self.normal = NormalSampler()
        self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
        self.ssv = SampleSplittingVariable(self.alpha_vec)

        self.tune = True

        if batch == "auto":
            batch = max(1, int(self.m * 0.1))
            self.batch = (batch, batch)
        else:
            if isinstance(batch, (tuple, list)):
                self.batch = batch
            else:
                self.batch = (batch, batch)

        self.log_num_particles = np.log(num_particles)
        self.indices = list(range(2, num_particles))
        self.len_indices = len(self.indices)
        self.max_stages = max_stages

        shared = make_shared_replacements(initial_values, vars, model)
        self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)

        self.all_particles = []
        for i in range(self.m):
            self.a_tree.leaf_node_value = self.init_mean / self.m
            p = ParticleTree(self.a_tree)
            self.all_particles.append(p)
        self.all_trees = np.array([p.tree for p in self.all_particles])
        super().__init__(vars, shared)

    def astep(self, _):
        variable_inclusion = np.zeros(self.num_variates, dtype="int")

        # ``~self.tune`` maps True to -2 and False to -1, so this picks the tuning batch
        # size while tuning and the post-tuning batch size afterwards
        tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
        for tree_id in tree_ids:
            # Generate an initial set of SMC particles
            # at the end of the algorithm we return one of these particles as the new tree
            particles = self.init_particles(tree_id)
            # Compute the sum of trees without the old tree, that we are attempting to replace
            self.sum_trees_noi = self.sum_trees - particles[0].tree.predict_output()
            # Resample leaf values for particle 1 which is a copy of the old tree
            particles[1].sample_leafs(
                self.sum_trees,
                self.X,
                self.mean,
                self.m,
                self.normal,
                self.mu_std,
            )

            # The old tree and the one with new leaves do not grow, so we update their weights only once
            self.update_weight(particles[0], old=True)
            self.update_weight(particles[1], old=True)
            for _ in range(self.max_stages):
                # Sample each particle (try to grow each tree), except for the first two
                stop_growing = True
                for p in particles[2:]:
                    tree_grew = p.sample_tree(
                        self.ssv,
                        self.available_predictors,
                        self.prior_prob_leaf_node,
                        self.X,
                        self.missing_data,
                        self.sum_trees,
                        self.mean,
                        self.m,
                        self.normal,
                        self.mu_std,
                    )
                    if tree_grew:
                        self.update_weight(p)
                    if p.expansion_nodes:
                        stop_growing = False
                if stop_growing:
                    break
                # Normalize weights
                W_t, normalized_weights = self.normalize(particles[2:])

                # Resample all but the first two particles
                new_indices = np.random.choice(
                    self.indices, size=self.len_indices, p=normalized_weights
                )
                particles[2:] = particles[new_indices]

                # Set the new weights
                for p in particles[2:]:
                    p.log_weight = W_t

            for p in particles[2:]:
                p.log_weight = p.old_likelihood_logp

            _, normalized_weights = self.normalize(particles)
            # Get the new tree and update
            new_particle = np.random.choice(particles, p=normalized_weights)
            new_tree = new_particle.tree
            self.all_trees[tree_id] = new_tree
            new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
            self.all_particles[tree_id] = new_particle
            self.sum_trees = self.sum_trees_noi + new_tree.predict_output()

            if self.tune:
                self.ssv = SampleSplittingVariable(self.alpha_vec)
                for index in new_particle.used_variates:
                    self.alpha_vec[index] += 1
            else:
                for index in new_particle.used_variates:
                    variable_inclusion[index] += 1

        stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
        return self.sum_trees, [stats]
    def normalize(self, particles):
        """Use logsumexp trick to get W_t and softmax to get normalized_weights."""
        log_w = np.array([p.log_weight for p in particles])
        log_w_max = log_w.max()
        log_w_ = log_w - log_w_max
        w_ = np.exp(log_w_)
        w_sum = w_.sum()
        W_t = log_w_max + np.log(w_sum) - self.log_num_particles
        normalized_weights = w_ / w_sum
        # stabilize weights to avoid assigning exactly zero probability to a particle
        normalized_weights += 1e-12

        return W_t, normalized_weights
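    # Illustrative note (added comment, not part of the original module): subtracting
    # ``log_w_max`` before exponentiating is what keeps this numerically stable. For
    # example, raw log weights of [-1000.0, -1001.0] both underflow to 0.0 under a direct
    # ``np.exp``, while ``np.exp([0.0, -1.0]) = [1.0, 0.368]`` recovers the correct
    # normalized weights of roughly [0.731, 0.269].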
    def init_particles(self, tree_id: int) -> np.ndarray:
        """Initialize particles."""
        # particle 0 is the current tree for ``tree_id`` and particle 1 a copy of it;
        # the remaining particles start from the base stump ``a_tree`` and are grown by SMC
        p = self.all_particles[tree_id]
        particles = [p]
        particles.append(copy(p))

        for _ in self.indices:
            particles.append(ParticleTree(self.a_tree))

        return np.array(particles)
    def update_weight(self, particle, old=False):
        """
        Update the weight of a particle.

        Since the prior is used as the proposal, the log-weight is updated additively by the
        difference between the new and old log-likelihoods (i.e. the weight is scaled by the
        likelihood ratio).
        """
        new_likelihood = self.likelihood_logp(self.sum_trees_noi + particle.tree.predict_output())
        if old:
            particle.log_weight = new_likelihood
            particle.old_likelihood_logp = new_likelihood
        else:
            particle.log_weight += new_likelihood - particle.old_likelihood_logp
            particle.old_likelihood_logp = new_likelihood
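    # Illustrative note (added comment, not part of the original module): ``old=True`` is
    # used in ``astep`` for the two particles that are never grown (the previous tree and
    # its copy with resampled leaf values), so their log-weight is set directly to the
    # current log-likelihood instead of being incremented.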
    @staticmethod
    def competence(var, has_grad):
        """PGBART is only suitable for BART distributions."""
        dist = getattr(var.owner, "op", None)
        if isinstance(dist, BARTRV):
            return Competence.IDEAL
        return Competence.INCOMPATIBLE
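A minimal usage sketch (not part of the original module). The data, variable names, and the assumption that ``BART`` is importable from ``pymc_experimental.bart`` are illustrative; when the step is registered with PyMC, ``pm.sample`` can pick ``PGBART`` automatically through the ``competence`` method above, but here it is passed explicitly:

import numpy as np
import pymc as pm

from pymc_experimental.bart import BART  # assumed export; adjust to the installed package layout
from pymc_experimental.bart.pgbart import PGBART

rng = np.random.default_rng(0)
X_data = rng.uniform(size=(100, 3))
Y_data = np.sin(3 * X_data[:, 0]) + rng.normal(scale=0.1, size=100)

with pm.Model() as model:
    mu = BART("mu", X=X_data, Y=Y_data, m=50)  # sum-of-trees mean function
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=Y_data)
    idata = pm.sample(step=PGBART([mu], num_particles=20), chains=1)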
class ParticleTree:
    """Particle tree."""

    def __init__(self, tree):
        self.tree = tree.copy()  # keeps the tree we care about at the moment
        self.expansion_nodes = [0]
        self.log_weight = 0
        self.old_likelihood_logp = 0
        self.used_variates = []

    def sample_tree(
        self,
        ssv,
        available_predictors,
        prior_prob_leaf_node,
        X,
        missing_data,
        sum_trees,
        mean,
        m,
        normal,
        mu_std,
    ):
        tree_grew = False
        if self.expansion_nodes:
            index_leaf_node = self.expansion_nodes.pop(0)
            # Probability that this node will remain a leaf node
            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]

            if prob_leaf < np.random.random():
                index_selected_predictor = grow_tree(
                    self.tree,
                    index_leaf_node,
                    ssv,
                    available_predictors,
                    X,
                    missing_data,
                    sum_trees,
                    mean,
                    m,
                    normal,
                    mu_std,
                )
                if index_selected_predictor is not None:
                    new_indexes = self.tree.idx_leaf_nodes[-2:]
                    self.expansion_nodes.extend(new_indexes)
                    self.used_variates.append(index_selected_predictor)
                    tree_grew = True

        return tree_grew

    def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std):
        sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std)


class SampleSplittingVariable:
    def __init__(self, alpha_vec):
        """
        Sample splitting variables proportionally to `alpha_vec`.

        This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model.
        This enforces sparsity.
        """
        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

    def rvs(self):
        r = np.random.random()
        for i, v in self.enu:
            if r <= v:
                return i


def compute_prior_probability(alpha):
    """
    Calculate the probability of a node being a LeafNode (1 - p(being SplitNode)).

    Taken from equation 19 in [Rockova2018].
    For example, with ``alpha=0.95`` the first values are [0, 0.05, 0.0975, 0.142625, ...],
    so deeper nodes are increasingly likely to remain leaves.

    Parameters
    ----------
    alpha : float

    Returns
    -------
    list with probabilities for leaf nodes

    References
    ----------
    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
        arXiv, `link <https://arxiv.org/abs/1810.00787>`__
    """
    prior_leaf_prob = [0]
    depth = 1
    while prior_leaf_prob[-1] < 1:
        prior_leaf_prob.append(1 - alpha**depth)
        depth += 1
    return prior_leaf_prob


def grow_tree(
    tree,
    index_leaf_node,
    ssv,
    available_predictors,
    X,
    missing_data,
    sum_trees,
    mean,
    m,
    normal,
    mu_std,
):
    current_node = tree.get_node(index_leaf_node)
    idx_data_points = current_node.idx_data_points

    index_selected_predictor = ssv.rvs()
    selected_predictor = available_predictors[index_selected_predictor]
    available_splitting_values = X[idx_data_points, selected_predictor]
    if missing_data:
        idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
        available_splitting_values = available_splitting_values[
            ~np.isnan(available_splitting_values)
        ]

    if available_splitting_values.size > 0:
        idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
        split_value = available_splitting_values[idx_selected_splitting_values]

        new_idx_data_points = get_new_idx_data_points(
            split_value, idx_data_points, selected_predictor, X
        )
        current_node_children = (
            current_node.get_idx_left_child(),
            current_node.get_idx_right_child(),
        )

        new_nodes = []
        for idx in range(2):
            idx_data_point = new_idx_data_points[idx]
            node_value = draw_leaf_value(
                sum_trees[idx_data_point],
                X[idx_data_point, selected_predictor],
                mean,
                m,
                normal,
                mu_std,
            )

            new_node = LeafNode(
                index=current_node_children[idx],
                value=node_value,
                idx_data_points=idx_data_point,
            )
            new_nodes.append(new_node)

        new_split_node = SplitNode(
            index=index_leaf_node,
            idx_split_variable=selected_predictor,
            split_value=split_value,
        )

        # update tree nodes and indexes
        tree.delete_node(index_leaf_node)
        tree.set_node(index_leaf_node, new_split_node)
        tree.set_node(new_nodes[0].index, new_nodes[0])
        tree.set_node(new_nodes[1].index, new_nodes[1])

        return index_selected_predictor


def sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std):
    for idx in tree.idx_leaf_nodes:
        if idx > 0:
            leaf = tree[idx]
            idx_data_points = leaf.idx_data_points
            parent_node = tree[leaf.get_idx_parent_node()]
            selected_predictor = parent_node.idx_split_variable
            node_value = draw_leaf_value(
                sum_trees[idx_data_points],
                X[idx_data_points, selected_predictor],
                mean,
                m,
                normal,
                mu_std,
            )
            leaf.value = node_value


def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
    left_idx = X[idx_data_points, selected_predictor] <= split_value
    left_node_idx_data_points = idx_data_points[left_idx]
    right_node_idx_data_points = idx_data_points[~left_idx]

    return left_node_idx_data_points, right_node_idx_data_points


def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
    """Draw Gaussian distributed leaf values."""
    if Y_mu_pred.size == 0:
        return 0
    else:
        norm = normal.random() * mu_std
        if Y_mu_pred.size == 1:
            mu_mean = Y_mu_pred.item() / m
        else:
            mu_mean = mean(Y_mu_pred) / m

        draw = norm + mu_mean
        return draw


def fast_mean():
    """If available, use Numba to speed up the computation of the mean."""
    try:
        from numba import jit
    except ImportError:
        return np.mean

    @jit
    def mean(a):
        count = a.shape[0]
        suma = 0
        for i in range(count):
            suma += a[i]
        return suma / count

    return mean


def discrete_uniform_sampler(upper_value):
    """Draw an integer from the uniform distribution with bounds [0, upper_value).

    This is the same as np.random.randint(upper_value) but faster.
""" return int(np.random.random() * upper_value) class NormalSampler: """Cache samples from a standard normal distribution.""" def __init__(self): self.size = 1000 self.cache = [] def random(self): if not self.cache: self.update() return self.cache.pop() def update(self): self.cache = np.random.normal(loc=0.0, scale=1, size=self.size).tolist() def logp(point, out_vars, vars, shared): """Compile Aesara function of the model and the input and output variables. Parameters ---------- out_vars: List containing :class:`pymc.Distribution` for the output variables vars: List containing :class:`pymc.Distribution` for the input variables shared: List containing :class:`aesara.tensor.Tensor` for depended shared data """ out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared) f = aesara_function([inarray0], out_list[0]) f.trust_input = True return f