# Packages (apologies, this is a dependency heavy script!)
# pacman::p_load loads each package, installing any that are missing.
# NOTE(review): cfcausal appears to be distributed via GitHub
# (lihualei71/cfcausal) rather than CRAN; if p_load cannot find it, install it
# first (e.g. pacman::p_load_gh("lihualei71/cfcausal")) -- confirm.
pacman::p_load("mlbench", "grf", "cfcausal", "caret",
               "reshape2", "glmnet", "tidyverse", "ggthemes",
               "ggdist", "ggpubr")
This post briefly demonstrates how to generate conformal uncertainty bands (intervals) in R
using the cfcausal
package, wrapped around the local linear forest estimator from the grf
package. I don’t spend much time here explaining how conformal intervals are constructed; for that, see (Lei et al. 2018) and (Samii 2019).
For the purpose of this demo, we’ll use a modified version of the data generating process (DGP) from (Friedman 1991), namely:
\[y = 10 \sin(\pi x_1 x_2) + 20 (x_3 - 0.5)^2 + 10 x_4 + 5 x_5 + e\] where \(e \sim N(0,1)\) and \(x_1, \dots, x_{10} \sim U(0,1)\). In addition to the five covariates that are related to the outcome, there are 25 noise covariates unrelated to it: the five unused uniform covariates \(x_6, \dots, x_{10}\) from the original Friedman DGP, plus 20 additional covariates drawn from \(N(0,1)\). This particular DGP has proved challenging for the Gaussian confidence intervals constructed using the standard local linear forest method (see the Monte Carlo simulations presented in (Friedberg et al. 2020), equation 7).
The DGP is simulated below:
# Set seed for reproducibility
set.seed(1995)
# Simulate the DGP. Note: the order of the random draws below (junk first,
# then the Friedman data) is what the seed reproduces -- don't reorder.
p <- 20  # 20 additional predictor vars
n <- 500 # sample size
junk <- matrix(rnorm(n * p), n, p) # junk predictors, all N(0, 1)
# Friedman (1991) MARS DGP: a list with matrix `x` (10 columns) and outcome `y`
data <- mlbench::mlbench.friedman1(n)
# Flatten into one data frame: y, x.1..x.10, then the junk columns X1..X20
data <- data.frame(y = data$y, x = data$x, junk) # ...adding more junk
#
# Split train/test set
# createDataPartition returns a one-column matrix of row indices (list = FALSE)
trainIndex <- caret::createDataPartition(data$y, p = 0.7, list = FALSE) # 70/30 split
train <- data[trainIndex, ]
test <- data[-trainIndex, ]
#
# Convenience objects: outcome vectors and covariate data frames.
# Covariates are selected by dropping the `y` column by name, rather than by
# position, so this keeps working if the column order ever changes.
Y <- train$y
Y.test <- test$y
X <- train[, names(train) != "y"]
X.test <- test[, names(test) != "y"]
#
Now we can move on to training the local linear forest. We’ll enable the local linear split feature and use a cross-validated lasso to select the correction features.
# local linear regression forest with LL splits enabled
c.forest.ll <- grf::ll_regression_forest(
  X = as.matrix(X), Y = Y,
  # tune.parameters = "all", # can't have this with ll splits, IRL you'd want to do a custom tuning loop
  enable.ll.split = TRUE, ll.split.weight.penalty = TRUE,
  num.trees = 4000, # upping from default for stable variance estimates
  seed = 1995
)
#
# Select covariates for local linear correction
lasso.mod <- glmnet::cv.glmnet(as.matrix(X), Y, alpha = 1, nfolds = 20) # cross-validated lasso
# Indices of the covariates with nonzero coefficients at the CV-chosen lambda
lasso.coef <- predict(lasso.mod, type = "nonzero")
selected <- lasso.coef[, 1]
#
# Out-of-sample preds; could also look at OOB preds (leave out the testing set and use X instead).
# X.test is passed as a matrix, matching how the forest was trained.
preds.ll <- predict(c.forest.ll, as.matrix(X.test), estimate.variance = TRUE,
                    linear.correction.variables = selected)
# Data frame for plots later, adding built-in grf uncertainty intervals
plot.df <- data.frame(
  ll.preds = preds.ll$predictions,
  ll.upper = preds.ll$predictions + 1.96 * sqrt(preds.ll$variance.estimates), # grf 95% confidence intervals
  ll.lower = preds.ll$predictions - 1.96 * sqrt(preds.ll$variance.estimates),
  Y = Y.test
)
# Column indicating whether the uncertainty band contains the true value (for plotting).
# Levels are fixed to "0"/"1" so the colour mapping is stable even if only one occurs.
plot.df$ll.cover <- factor(
  as.integer(plot.df$Y >= plot.df$ll.lower & plot.df$Y <= plot.df$ll.upper),
  levels = c(0, 1)
)
To generate the conformal intervals, we first need to set up a function that fits the local linear forest and returns its predictions, in the form cfcausal expects.
# Set up the llf function to plug in to cfcausal; same settings and seed as above.
#
# cfcausal calls this with the data it wants the model trained on (Y, X) and
# the points to predict (Xtest); extra `...` arguments go to the forest.
# Returns a numeric vector of predictions for Xtest.
llRF <- function(Y, X, Xtest, ...) {
  # Fit on the data cfcausal hands us -- NOT the global c.forest.ll, which was
  # trained on every training row (including the calibration fold). Scoring
  # calibration points with a model that saw them makes the conformal
  # residuals too small and the intervals invalid.
  fit <- grf::ll_regression_forest(as.matrix(X), Y, enable.ll.split = TRUE,
                                   ll.split.weight.penalty = TRUE,
                                   num.trees = 4000, seed = 1995, ...)
  # Same lasso selection procedure as above, on the same training data
  lasso.mod <- glmnet::cv.glmnet(as.matrix(X), Y, alpha = 1, nfolds = 20)
  lasso.coef <- predict(lasso.mod, type = "nonzero")
  selected <- lasso.coef[, 1]
  # Out-of-sample preds; grf variance estimates are not needed here
  res <- predict(fit, as.matrix(Xtest), estimate.variance = FALSE,
                 linear.correction.variables = selected)
  as.numeric(res$predictions)
}
We can then feed that function into the cfcausal::conformal
function to generate standard (unweighted) split-conformal intervals.
# Set up the conformal prediction object, plugging in our llf estimator
c.test <- cfcausal::conformal(X = X, Y = Y, type = "mean", side = "two",
                              wtfun = NULL,  # unweighted
                              outfun = llRF, # our custom output function
                              useCV = FALSE) # split conformal; set useCV = TRUE for CV+ instead
# Generate the uncertainty bands. For a two-sided interval, cfcausal's `alpha`
# is the total miscoverage, so alpha = .05 targets 95% coverage -- the same
# target as the 1.96 * SE grf bands.
ll.preds.conformal <- predict(c.test, X.test, alpha = .05)
# Save out the results
plot.df$ll.upper.c <- ll.preds.conformal$upper
plot.df$ll.lower.c <- ll.preds.conformal$lower
# Column for whether truth is covered or not in a given instance, for plotting
plot.df$ll.cover.c <- factor(
  as.integer(plot.df$Y >= plot.df$ll.lower.c & plot.df$Y <= plot.df$ll.upper.c),
  levels = c(0, 1)
)
Now we can plot a comparison of the local linear forest predictions wrapped in the standard Gaussian intervals and in the conformal intervals (both aiming for 95% coverage). The y-axis shows the model predictions and their uncertainty bands; the x-axis shows the true values of y in the test set.
# llf plot with grf uncertainty bands.
# Colours are mapped by level name ("0" = missed, "1" = covered) so blue always
# means "covers the truth". The coverage label is computed from the data
# rather than hard-coded.
ll.plot <- plot.df %>%
  ggplot(aes(y = ll.preds, x = Y, ymin = ll.lower, ymax = ll.upper, color = ll.cover)) +
  geom_pointinterval(alpha = .5, shape = 1) +
  scale_x_continuous(limits = c(0, 30), expand = c(0, 0)) +
  scale_y_continuous(limits = c(0, 30), expand = c(0, 0)) +
  geom_abline(linewidth = .75, intercept = 0, slope = 1) +
  ylab("LLF Predictions (Y), 95% Confidence Intervals") +
  xlab("Y (Real, Test Set)") +
  ggtitle("Local Linear Forests") +
  ggthemes::theme_few() +
  scale_color_manual(values = c("0" = "firebrick1", "1" = "dodgerblue")) +
  guides(color = guide_legend(title = "Cover Truth? (Blue = Yes)")) +
  annotate("text", x = 20, y = 5,
           label = sprintf("Coverage == %.2f", mean(plot.df$ll.cover == "1")),
           parse = TRUE)
# llf plot with conformal bands.
# Colours are mapped by level name ("0" = missed, "1" = covered) so blue always
# means "covers the truth". The coverage label is computed from the data
# rather than hard-coded.
ll.plot.c <- plot.df %>%
  ggplot(aes(y = ll.preds, x = Y, ymin = ll.lower.c, ymax = ll.upper.c, color = ll.cover.c)) +
  geom_pointinterval(alpha = .5, shape = 1) +
  scale_x_continuous(limits = c(0, 30), expand = c(0, 0)) +
  scale_y_continuous(limits = c(0, 30), expand = c(0, 0)) +
  geom_abline(linewidth = .75, intercept = 0, slope = 1) +
  ylab("LLF Predictions (Y), 95% Conformal Intervals") +
  xlab("Y (Real, Test Set)") +
  ggtitle("LL Forests w/ Conformal Bands") +
  ggthemes::theme_few() +
  guides(color = guide_legend(title = "Cover Truth? (Blue = Yes)")) +
  scale_color_manual(values = c("0" = "firebrick1", "1" = "dodgerblue")) +
  annotate("text", x = 20, y = 5,
           label = sprintf("Coverage == %.2f", mean(plot.df$ll.cover.c == "1")),
           parse = TRUE)
# Side-by-side comparison of the two plots, sharing one legend at the bottom
ggpubr::ggarrange(ll.plot, ll.plot.c, nrow = 1, common.legend = TRUE, legend = "bottom")
As the figure makes clear, the conformal bands achieve much better coverage than the Gaussian confidence bands. Indeed, in this particular example the conformal approach reaches exactly the desired 95% coverage, while the Gaussian confidence bands achieve only 28%.
References
Citation
@online{t._rametta2023,
author = {T. Rametta, Jack},
title = {Conformalized {Local} {Linear} {Forests}},
date = {2023-08-12},
url = {https://cetialphafive.github.io/jrametta/posts/2023-08-12-conformal/},
langid = {en}
}