diff --git a/r-package/grf/pkgdown/_pkgdown.yml b/r-package/grf/pkgdown/_pkgdown.yml index 0397d18fb..63562c011 100644 --- a/r-package/grf/pkgdown/_pkgdown.yml +++ b/r-package/grf/pkgdown/_pkgdown.yml @@ -16,6 +16,8 @@ navbar: href: articles/categorical_inputs.html - text: "Causal forest with time-to-event data" href: articles/survival.html + - text: "Cross-fold validation of heterogeneity" + href: articles/rate_cv.html - text: "Estimating ATEs on a new target population" href: articles/ate_transport.html - text: "Estimating conditional means" diff --git a/r-package/grf/vignettes/rate.Rmd b/r-package/grf/vignettes/rate.Rmd index 9a03416cc..165a6aa27 100644 --- a/r-package/grf/vignettes/rate.Rmd +++ b/r-package/grf/vignettes/rate.Rmd @@ -26,7 +26,7 @@ This vignette gives a brief introduction to how the *Rank-Weighted Average Treat ## Treatment prioritization rules We are in the familiar experimental setting (or unconfounded observational study) and are interested in the problem of determining which individuals to assign a binary treatment $W=\{0, 1\}$ and the associated value of this treatment allocation strategy. Given some subject characteristics $X_i$ we have access to a subject-specific treatment *prioritization rule* $S(X_i)$ which assigns scores to subjects. This prioritization rule should give a high score to units which we believe to have a large benefit of treatment and a low score to units with a low benefit of treatment. By benefit of treatment, we mean the difference in outcomes $Y$ from receiving the treatment given some subject characteristics $X_i$, as given by the conditional average treatment effect (CATE) -$$\tau(X_i) = E[Y_i(1) - Y_i(0) \,|\, X_i = x],$$ +$$\tau(x) = E[Y_i(1) - Y_i(0) \,|\, X_i = x],$$ where $Y(1)$ and $Y(0)$ are potential outcomes corresponding to the two treatment states. You might ask: why the general focus on an arbitrary rule $S(X_i)$ when you define benefit as measured by $\tau(X_i)$? Isn't it obvious that the estimated CATEs would serve the best purpose for treatment targeting? The answer is that this general problem formulation is quite convenient, as in some settings we may lack sufficient data to power an accurate CATE estimator and have to rely on other approaches to target treatment. Examples of other approaches are heuristics derived by domain experts or simpler models predicting risk scores (where risk is defined as $P[Y = 1 | X_i = x]$ and $Y \in \{0, 1\}$ with $Y=1$ being an adverse outcome), which is quite common in clinical applications. Consequently, in finite samples we may sometimes do better by relying on simpler rules which are correlated with the CATEs, than on noisy and complicated CATE estimates (remember: CATE estimation is a hard statistical task, and by focusing on a general rule we may circumvent some of the problems of obtaining accurate non-parametric point estimates by instead asking for estimates that *rank* units according to treatment benefit). Also, even if you have an accurate CATE estimator, there may be many to choose from (neural nets/random forests/various metalearners/etc). The question is: given a set of treatment prioritization rules $S(X_i)$, which one (if any) should we use? @@ -196,6 +196,8 @@ plot(rate.accord, xlab = "Treated fraction", main = "TOC evaluated on ACCORD\n t In this semi-synthetic example both AUTOCs are insignificant at conventional levels, suggesting there is no evidence of significant HTEs in the two trials. 
Note: this can also be attributed to a) low power, as perhaps the sample size is not large enough to detect HTEs, b) that the HTE estimator does not detect them, or c) the heterogeneity in the treatment effects along observable predictor variables are negligible. For a broader analysis comparing different prioritization strategies on the SPRINT and ACCORD datasets, see Yadlowsky et al. (2021).

+For a discussion of RATE estimation strategies that do not rely on a single train/test split, we refer to [this vignette](https://grf-labs.github.io/grf/articles/rate_cv.html).
+
## Funding

Development of the RATE functionality in GRF was supported in part by the award 5R01HL144555 from the National Institutes of Health.
diff --git a/r-package/grf/vignettes/rate_cv.Rmd b/r-package/grf/vignettes/rate_cv.Rmd
new file mode 100644
index 000000000..b9efb4088
--- /dev/null
+++ b/r-package/grf/vignettes/rate_cv.Rmd
@@ -0,0 +1,441 @@
+---
+title: "Cross-fold validation of heterogeneity"
+output: rmarkdown::html_vignette
+vignette: >
+  %\VignetteIndexEntry{rate_cv}
+  %\VignetteEngine{knitr::rmarkdown}
+  %\VignetteEncoding{UTF-8}
+---
+
+```{r, include = FALSE}
+knitr::opts_chunk$set(
+  collapse = TRUE,
+  comment = "#>"
+)
+set.seed(123)
+```
+
+```{r setup}
+library(grf)
+```
+This vignette gives a brief overview of strategies for evaluating CATE estimators via *Rank-Weighted Average Treatment Effects* (*RATEs*) that avoid relying on a single training and evaluation split. We demonstrate how to use causal forests to detect heterogeneity via the RATE with observational or experimental data and a binary treatment assignment, and describe two proposals: a *sequential fold estimation* procedure and a heuristic for one-sided tests based on *out-of-bag estimates*.
+
+## Cross-fold estimation of RATE metrics
+As covered in [this vignette](https://grf-labs.github.io/grf/articles/rate.html), RATE metrics can be used as a test for the presence of treatment heterogeneity. A common use case is to use a training set to obtain estimates of the conditional average treatment effect (CATE) and then evaluate these on a held-out test set using the RATE. Concretely, on a *training* set we can obtain some estimated CATE function $\hat \tau(\cdot)$; then on a *test* set we can obtain an estimate of the RATE and form a test for whether we detect heterogeneity via:
+
+$H_0$: RATE $=0$ (No heterogeneity detected by $\hat \tau(\cdot)$)
+
+$H_A$: RATE $\neq 0$.
+
+RATE estimates satisfy a central limit theorem, which motivates a simple $t$-test of this hypothesis using
+
+$$
+t = \frac{\widehat{RATE}}{\sqrt{Var(\widehat{RATE})}}.
+$$
+
+Using a significance level $\alpha$ of 5%, we reject $H_0$ if the $p$-value $2\Phi(-|t|)$ is less than 5%.
+
+A drawback of this approach is that we are only utilizing a subset of our data to train a CATE estimator. In regimes with small sample sizes and hard-to-detect heterogeneity, this loss of data can be costly in terms of power. One appealing alternative could be to use cross-fold estimation, where we divide the data into $K$ folds and then alternate which subset we use for training a CATE estimator and which subset we use to estimate the RATE.
+The figure below illustrates the case when using $K=5$ folds: fit a CATE estimator on the blue sample, and evaluate the RATE on the red sample.
+
+```{r, echo=FALSE}
+plot(NULL, xlim = c(1, 6), ylim = c(1, 7), xaxt = 'n', yaxt = 'n',
+     xlab = "Fold ID",
+     ylab = "Estimation step",
+     main = "5-fold estimation"
+)
+
+lines(c(1, 6), c(5,5), lwd = 4, col = "blue")
+lines(c(1, 2), c(5,5), lwd = 4, col = "red")
+text(1.5, 5.25, offset = 0, "1", col = "red")
+text(2.5, 5.25, offset = 0, "2", col = "blue")
+text(3.5, 5.25, offset = 0, "3", col = "blue")
+text(4.5, 5.25, offset = 0, "4", col = "blue")
+text(5.5, 5.25, offset = 0, "5", col = "blue")
+
+lines(c(1, 6), c(4,4), lwd = 4, col = "blue")
+lines(c(2, 3), c(4,4), lwd = 4, col = "red")
+text(1.5, 4.25, offset = 0, "1", col = "blue")
+text(2.5, 4.25, offset = 0, "2", col = "red")
+text(3.5, 4.25, offset = 0, "3", col = "blue")
+text(4.5, 4.25, offset = 0, "4", col = "blue")
+text(5.5, 4.25, offset = 0, "5", col = "blue")
+
+lines(c(1, 6), c(3,3), lwd = 4, col = "blue")
+lines(c(3, 4), c(3,3), lwd = 4, col = "red")
+text(1.5, 3.25, offset = 0, "1", col = "blue")
+text(2.5, 3.25, offset = 0, "2", col = "blue")
+text(3.5, 3.25, offset = 0, "3", col = "red")
+text(4.5, 3.25, offset = 0, "4", col = "blue")
+text(5.5, 3.25, offset = 0, "5", col = "blue")
+
+lines(c(1, 6), c(2,2), lwd = 4, col = "blue")
+lines(c(4, 5), c(2,2), lwd = 4, col = "red")
+text(1.5, 2.25, offset = 0, "1", col = "blue")
+text(2.5, 2.25, offset = 0, "2", col = "blue")
+text(3.5, 2.25, offset = 0, "3", col = "blue")
+text(4.5, 2.25, offset = 0, "4", col = "red")
+text(5.5, 2.25, offset = 0, "5", col = "blue")
+
+lines(c(1, 6), c(1,1), lwd = 4, col = "blue")
+lines(c(5, 6), c(1,1), lwd = 4, col = "red")
+text(1.5, 1.25, offset = 0, "1", col = "blue")
+text(2.5, 1.25, offset = 0, "2", col = "blue")
+text(3.5, 1.25, offset = 0, "3", col = "blue")
+text(4.5, 1.25, offset = 0, "4", col = "blue")
+text(5.5, 1.25, offset = 0, "5", col = "red")
+
+text(3, 6.5, "Train", col = "blue")
+text(4, 6.5, "Test", col = "red")
+```
+
+This strategy would give us 5 RATE estimates along with 5 $t$-values, and the hope is that by aggregating these we increase our ability (power) to detect heterogeneity. On its own, each of these $t$-values yields a valid test. The **problem** with aggregating them is that they are **not independent**, as the overlapping fold construction induces correlation.
+
+We can see this play out in a simple simulation exercise. Below we define a simple function that splits the data into 5 folds, forms a RATE estimate and $t$-value on each fold, and then computes a $p$-value by aggregating the test statistics.
+
+```{r}
+# Estimate RATE on K random folds to get K t-statistics, then return a p-value
+# for H0 using the sum of these t-values.
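+# Under H0, if the K per-fold t-statistics were independent and approximately
+# N(0, 1), then their sum divided by sqrt(K) would also be approximately N(0, 1);
+# that is the aggregation used below. As the simulation afterwards shows, the
+# overlapping training folds make the t-statistics dependent, so this scaling
+# understates their variance.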
+rate_cv_naive = function(X, Y, W, num.folds = 5) {
+  fold.id = sample(rep(1:num.folds, length = nrow(X)))
+  samples.by.fold = split(seq_along(fold.id), fold.id)
+
+  t.statistics = c()
+
+  # Form AIPW scores for estimating RATE
+  nuisance.forest = causal_forest(X, Y, W)
+  DR.scores = get_scores(nuisance.forest)
+
+  for (k in 1:num.folds) {
+    train = unlist(samples.by.fold[-k])
+    test = samples.by.fold[[k]]
+
+    cate.forest = causal_forest(X[train, ], Y[train], W[train])
+
+    cate.hat.test = predict(cate.forest, X[test, ])$predictions
+
+    rate.fold = rank_average_treatment_effect.fit(DR.scores[test], cate.hat.test)
+    t.statistics = c(t.statistics, rate.fold$estimate / rate.fold$std.err)
+  }
+
+  p.value.naive = 2 * pnorm(-abs(sum(t.statistics) / sqrt(num.folds)))
+
+  p.value.naive
+}
+```
+Next, we simulate synthetic data with no treatment heterogeneity, and form a Monte Carlo estimate of our naive procedure's rejection probability, using a 5% significance level:
+
+```{r, eval=FALSE}
+res = t(replicate(500, {
+  n = 1500
+  p = 5
+  X = matrix(rnorm(n * p), n, p)
+  W = rbinom(n, 1, 0.5)
+  Y = pmin(X[, 3], 0) + rnorm(n)
+
+  pval.naive = rate_cv_naive(X, Y, W)
+
+  c(
+    reject = pval.naive <= 0.05
+  )
+}))
+paste("Rejection rate:", mean(res) * 100, "%")
+#> [1] "Rejection rate: 16.8 %"
+```
+At a 5% significance level, any test whose rejection rate under the null exceeds 5% is an invalid test.
+
+## Sequential cross-fold estimation
+It turns out, however, that there exists a simple algorithmic fix that addresses the shortcomings of naive cross-fold validation. This approach, which we call **sequential** cross-fold estimation, breaks up the correlation between RATE estimates across folds by using as training data only the folds that precede each test fold. As illustrated in the figure below, this strategy gives us $K - 1$ separate $t$-values.
+
+```{r, echo=FALSE}
+plot(NULL, xlim = c(1, 6), ylim = c(1, 6), xaxt = 'n', yaxt = 'n',
+     xlab = "Fold ID",
+     ylab = "Estimation step",
+     main = "Sequential 5-fold estimation"
+)
+
+lines(c(1, 2), c(4,4), lwd = 4, col = "blue")
+lines(c(2, 3), c(4,4), lwd = 4, col = "red")
+text(1.5, 4.25, offset = 0, "1", col = "blue")
+text(2.5, 4.25, offset = 0, "2", col = "red")
+
+lines(c(1, 3), c(3,3), lwd = 4, col = "blue")
+lines(c(3, 4), c(3,3), lwd = 4, col = "red")
+text(1.5, 3.25, offset = 0, "1", col = "blue")
+text(2.5, 3.25, offset = 0, "2", col = "blue")
+text(3.5, 3.25, offset = 0, "3", col = "red")
+
+lines(c(1, 4), c(2,2), lwd = 4, col = "blue")
+lines(c(4, 5), c(2,2), lwd = 4, col = "red")
+text(1.5, 2.25, offset = 0, "1", col = "blue")
+text(2.5, 2.25, offset = 0, "2", col = "blue")
+text(3.5, 2.25, offset = 0, "3", col = "blue")
+text(4.5, 2.25, offset = 0, "4", col = "red")
+
+lines(c(1, 5), c(1,1), lwd = 4, col = "blue")
+lines(c(5, 6), c(1,1), lwd = 4, col = "red")
+text(1.5, 1.25, offset = 0, "1", col = "blue")
+text(2.5, 1.25, offset = 0, "2", col = "blue")
+text(3.5, 1.25, offset = 0, "3", col = "blue")
+text(4.5, 1.25, offset = 0, "4", col = "blue")
+text(5.5, 1.25, offset = 0, "5", col = "red")
+
+text(3, 5.5, "Train", col = "blue")
+text(4, 5.5, "Test", col = "red")
+```
+
+This construction makes our $t$-values satisfy a martingale property and allows for valid aggregation (Luedtke & van der Laan (2016), Wager (2024)). The steps are:
+
+1. Randomly split the data into $k=1,\ldots,K$ evenly sized folds.
+
+2. For each fold $k=2,\ldots,K$:
+    + Obtain an estimated CATE function using $\color{blue}{\text{training}}$ data from folds $1,\ldots,k-1$.
+    + Form a RATE estimate using data from the $k$-th $\color{red}{\text{test}}$ fold and compute a $t$-statistic.
+
+3. Aggregate the $K-1$ different $t$-statistics and use this as a final test statistic.
+
+This procedure is coded in the simple function below.
+```{r}
+# Estimate RATE on K-1 random sequential folds to get K-1 t-statistics, then return a p-value
+# for H0 using the sum of these t-values.
+rate_sequential = function(X, Y, W, num.folds = 5) {
+  fold.id = sample(rep(1:num.folds, length = nrow(X)))
+  samples.by.fold = split(seq_along(fold.id), fold.id)
+
+  t.statistics = c()
+
+  # Form AIPW scores for estimating RATE
+  nuisance.forest = causal_forest(X, Y, W)
+  DR.scores = get_scores(nuisance.forest)
+
+  for (k in 2:num.folds) {
+    train = unlist(samples.by.fold[1:(k - 1)])
+    test = samples.by.fold[[k]]
+
+    cate.forest = causal_forest(X[train, ], Y[train], W[train])
+
+    cate.hat.test = predict(cate.forest, X[test, ])$predictions
+
+    rate.fold = rank_average_treatment_effect.fit(DR.scores[test], cate.hat.test)
+    t.statistics = c(t.statistics, rate.fold$estimate / rate.fold$std.err)
+  }
+
+  p.value = 2 * pnorm(-abs(sum(t.statistics) / sqrt(num.folds - 1)))
+
+  p.value
+}
+```
+
+If we repeat the check on data with no heterogeneity, we can see that this construction gives us the correct rejection rate (within the Monte Carlo error of our repetitions):
+```{r, eval=FALSE}
+res = t(replicate(500, {
+  n = 1500
+  p = 5
+  X = matrix(rnorm(n * p), n, p)
+  W = rbinom(n, 1, 0.5)
+  Y = pmin(X[, 3], 0) + rnorm(n)
+
+  pval.sequential = rate_sequential(X, Y, W)
+
+  c(
+    reject.sequential = pval.sequential <= 0.05
+  )
+}))
+paste("Rejection rate:", mean(res) * 100, "%")
+#> [1] "Rejection rate: 5.4 %"
+```
+
+Now that we have a valid test for treatment heterogeneity that uses several folds, let's try it out on a simulated example and examine how it performs in detecting treatment effect heterogeneity. The baseline we compare with is sample splitting, which is proven to be valid in the RATE paper, but which potentially has lower power since only a fraction of the data is used to fit the CATE estimator.
+
+Below we generate synthetic observational data (Setup A in Nie & Wager (2021)) where treatment effect heterogeneity is hard to detect. Remember, in this case we want our test to reject $H_0$ of no heterogeneity.
+
+```{r, eval=FALSE}
+res = t(replicate(500, {
+  # Simulate synthetic data with "weak" treatment heterogeneity
+  n = 2500
+  p = 5
+  data = generate_causal_data(n, p, sigma.tau = 0.15, sigma.noise = 1, dgp = "nw1")
+  X = data$X
+  Y = data$Y
+  W = data$W
+
+  pval.sequential = rate_sequential(X, Y, W)
+
+  # Compare with a single train/test split
+  train = sample(n, n / 2)
+  test = -train
+  cf.train = causal_forest(X[train,], Y[train], W[train])
+  cf.test = causal_forest(X[test,], Y[test], W[test])
+  rate = rank_average_treatment_effect(cf.test, predict(cf.train, X[test, ])$predictions)
+  tstat.split = rate$estimate / rate$std.err
+  pval.split = 2 * pnorm(-abs(tstat.split))
+
+  c(
+    reject.sequential = pval.sequential <= 0.05,
+    reject.split = pval.split <= 0.05
+  )
+}))
+colMeans(res) * 100
+#> reject.sequential      reject.split
+#>              56.6              47.0
+```
+
+Reassuringly, in this example, sequential cross-fold estimation improves the power to detect treatment heterogeneity over our baseline.
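+
+In practice, applying the procedure to a single data set is straightforward once `rate_sequential` is defined. The short sketch below reuses the simulated design from the power exercise (the sample size and DGP settings are purely illustrative); the returned $p$-value can then be compared to the chosen significance level:
+
+```{r, eval=FALSE}
+# Illustrative usage of rate_sequential() on a single simulated data set.
+data = generate_causal_data(2500, 5, sigma.tau = 0.15, sigma.noise = 1, dgp = "nw1")
+p.value = rate_sequential(data$X, data$Y, data$W, num.folds = 5)
+p.value
+```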
+
+## Heuristic: out-of-bag estimation with one-sided tests
+The fundamental challenge with estimating CATEs and evaluating the presence of heterogeneity using the same data is overfitting: we are constrained by the fact that we can't estimate and evaluate something on the same data set. We saw that naive cross-fold estimation did not overcome this challenge due to between-fold correlations (which bias the variance of the aggregated statistic downwards and lead to over-rejection). A *heuristic* to salvage this naive aggregation is to use **one-sided** $p$-values. If we are willing to commit to a one-sided test, for example RATE $>0$, this can serve as a useful strategy.
+
+Random forests, however, offer a more computationally efficient alternative to simple cross-fold estimation. The forest-native version of leave-one-out estimation is *out-of-bag* (OOB) estimation. For a given unit $i$, these are estimates that are formed using all the trees in the random forest grown on subsamples that do not include the $i$-th unit. OOB-estimated RATEs paired with one-sided tests can serve as a useful heuristic when we want to use all our data to fit and evaluate heterogeneity.
+
+In the example below we compare a two-sided test for RATE $\neq 0$ with a one-sided test for RATE $>0$ at a 5% level.
+
+```{r, eval=FALSE}
+res = t(replicate(500, {
+  n = 1500
+  p = 5
+  X = matrix(rnorm(n * p), n, p)
+  W = rbinom(n, 1, 0.5)
+  Y = pmin(X[, 3], 0) + rnorm(n)
+
+  c.forest = causal_forest(X, Y, W)
+  tau.hat.oob = predict(c.forest)$predictions
+  rate.oob = rank_average_treatment_effect(c.forest, tau.hat.oob)
+
+  t.stat.oob = rate.oob$estimate / rate.oob$std.err
+  # Compute a two-sided p-value Pr(>|t|)
+  p.val = 2 * pnorm(-abs(t.stat.oob))
+  # Compute a one-sided p-value Pr(>t)
+  p.val.onesided = pnorm(t.stat.oob, lower.tail = FALSE)
+
+  c(
+    reject.oob = p.val <= 0.05,
+    reject.oob.onesided = p.val.onesided <= 0.05
+  )
+}))
+colMeans(res) * 100
+#> reject.oob reject.oob.onesided
+#>       30.4                 4.0
+```
+
+We see that the one-sided test gives us a correct rejection rate under the null of no heterogeneity.
+
+If we repeat the simple power exercise from above, we see that we achieve essentially the same power as sequential aggregation.
+```{r, eval=FALSE}
+res = t(replicate(500, {
+  # Simulate synthetic data with "weak" treatment heterogeneity
+  n = 2500
+  p = 5
+  data = generate_causal_data(n, p, sigma.tau = 0.15, sigma.noise = 1, dgp = "nw1")
+  X = data$X
+  Y = data$Y
+  W = data$W
+
+  c.forest = causal_forest(X, Y, W)
+  tau.hat.oob = predict(c.forest)$predictions
+  rate.oob = rank_average_treatment_effect(c.forest, tau.hat.oob)
+  t.stat.oob = rate.oob$estimate / rate.oob$std.err
+  # Compute a one-sided p-value Pr(>t)
+  p.val.onesided = pnorm(t.stat.oob, lower.tail = FALSE)
+
+  c(
+    reject.oob.onesided = p.val.onesided <= 0.05
+  )
+}))
+mean(res) * 100
+#> [1] 54.0
+```
+
+## References
+Chernozhukov, Victor, Mert Demirer, Esther Duflo, and Ivan Fernandez-Val. Generic machine learning inference on heterogeneous treatment effects in randomized experiments, with an application to immunization in India. _Econometrica, forthcoming_.
+
+Luedtke, Alexander R., and Mark J. van der Laan. Statistical inference for the mean outcome under a possibly non-unique optimal treatment strategy. _Annals of Statistics 44(2), 2016_.
+
+Nie, Xinkun, and Stefan Wager. Quasi-oracle estimation of heterogeneous treatment effects. _Biometrika 108(2), 2021_.
+
+Wager, Stefan. Sequential Validation of Treatment Heterogeneity.
_arXiv preprint arXiv:2405.05534_ ([arxiv](https://arxiv.org/abs/2405.05534)).
+
+## Appendix: Median aggregation of p-values
+One possible way to salvage `rate_cv_naive` is to keep the traditional $K$-fold construction, but change the way we aggregate the $t$-statistics from each fold. This is the strategy Chernozhukov et al. (2024) pursue. If, instead of computing a $p$-value based on the aggregated $t$-statistics, we take the median of the $p$-values based on each $t$-statistic, we get a valid test.
+
+```{r}
+# This function does the same thing as `rate_cv_naive`, except it returns
+# the median p-value.
+rate_cv_median = function(X, Y, W, num.folds = 5) {
+  fold.id = sample(rep(1:num.folds, length = nrow(X)))
+  samples.by.fold = split(seq_along(fold.id), fold.id)
+
+  t.statistics = c()
+
+  nuisance.forest = causal_forest(X, Y, W)
+  DR.scores = get_scores(nuisance.forest)
+
+  for (k in 1:num.folds) {
+    train = unlist(samples.by.fold[-k])
+    test = samples.by.fold[[k]]
+
+    cate.forest = causal_forest(X[train, ], Y[train], W[train])
+
+    cate.hat.test = predict(cate.forest, X[test, ])$predictions
+
+    rate.fold = rank_average_treatment_effect.fit(DR.scores[test], cate.hat.test)
+    t.statistics = c(t.statistics, rate.fold$estimate / rate.fold$std.err)
+  }
+
+  p.value.median = median(2 * pnorm(-abs(t.statistics)))
+
+  p.value.median
+}
+```
+
+This construction gives us a valid rejection rate at a 5% significance level, but is quite conservative:
+```{r, eval=FALSE}
+res = t(replicate(500, {
+  n = 1500
+  p = 5
+  X = matrix(rnorm(n * p), n, p)
+  W = rbinom(n, 1, 0.5)
+  Y = pmin(X[, 3], 0) + rnorm(n)
+
+  rate.cv.median = rate_cv_median(X, Y, W)
+
+  c(
+    reject.median = rate.cv.median <= 0.05
+  )
+}))
+paste("Rejection rate:", mean(res) * 100, "%")
+#> [1] "Rejection rate: 0.2 %"
+```
+
+If we repeat the power simulation from above, this algorithm unfortunately does not yield higher power than sample splitting. A drawback of this approach is that if all the $t$-statistics are positive but not significant on their own, then `rate_sequential` can take this into account, while `rate_cv_median` cannot.
+```{r, eval=FALSE}
+res = t(replicate(500, {
+  # Simulate synthetic data with "weak" treatment heterogeneity
+  n = 2500
+  p = 5
+  data = generate_causal_data(n, p, sigma.tau = 0.15, sigma.noise = 1, dgp = "nw1")
+  X = data$X
+  Y = data$Y
+  W = data$W
+
+  pval.median = rate_cv_median(X, Y, W)
+
+  # Compare with sample splitting
+  train = sample(n, n / 2)
+  test = -train
+  cf.train = causal_forest(X[train,], Y[train], W[train])
+  cf.test = causal_forest(X[test,], Y[test], W[test])
+  rate = rank_average_treatment_effect(cf.test, predict(cf.train, X[test, ])$predictions)
+  tstat.split = rate$estimate / rate$std.err
+  pval.split = 2 * pnorm(-abs(tstat.split))
+
+  c(
+    reject.median = pval.median <= 0.05,
+    reject.split = pval.split <= 0.05
+  )
+}))
+colMeans(res) * 100
+#> reject.median reject.split
+#>          14.6         47.8
+```
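+
+To see the intuition behind this last point, consider a stylized example (the numbers below are made up purely for illustration): suppose each of four evaluation folds produces a $t$-statistic of about 1.2, none of which is significant on its own. Summing and rescaling as in `rate_sequential` yields a significant aggregate statistic, whereas the median of the two-sided $p$-values does not come close to the 5% threshold:
+
+```{r}
+# Four hypothetical, modest, all-positive t-statistics (illustrative values only).
+t.stats = c(1.2, 1.2, 1.2, 1.2)
+
+# Sum-based aggregation as in rate_sequential: sum / sqrt(4) = 2.4,
+# two-sided p-value of about 0.016.
+2 * pnorm(-abs(sum(t.stats) / sqrt(length(t.stats))))
+
+# Median aggregation as in rate_cv_median: median two-sided p-value of about 0.23.
+median(2 * pnorm(-abs(t.stats)))
+```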