This function calculates counterfactual predictions for each level of a specified treatment variable in a fitted generalized linear model (GLM). It is designed to support the assessment of treatment effects within a causal inference framework by predicting each record's outcome under each treatment.


predict_counterfactuals(object, trt)



object: a fitted glm object for which counterfactual predictions are desired.


trt: a string giving the name of the treatment variable in the model formula. It must be one of the predictor variables in the linear predictor used to fit object.


An updated glm object with one additional component, counterfactual.predictions.

This component is a tibble with two columns, cf_pred_0 and cf_pred_1, containing the counterfactual predictions under each level of the treatment variable. A descriptive label attribute explains the counterfactual scenario associated with each column.


The function works by creating two new datasets from the original data used to fit the GLM. In each dataset, the treatment variable is set to one of its levels for every record (e.g., every patient), while all other covariates keep their observed values.

The fitted GLM is then used to predict the response for each dataset, giving each record's predicted outcome (on the response scale, e.g., a probability for a binomial model) under each treatment condition.

The results are stored in a tidy format and appended to the original model object for further analysis or inspection.
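The steps above can be sketched in base R as follows. This is a hypothetical illustration, not the package's implementation: the helper name cf_predict_sketch is made up, it assumes a two-level factor treatment and a glm fitted with an explicit data argument, and it returns a base data frame rather than a tibble. The simulated dataset stands in for trial01.

```r
# Hypothetical sketch of the counterfactual-prediction steps; not the package's code.
# Assumes: trt is a two-level factor, and glm() was called with `data =`.
cf_predict_sketch <- function(object, trt) {
  # Start from the data passed to glm() and drop rows that glm() omitted
  data <- object$data
  if (!is.null(object$na.action)) {
    data <- data[-as.integer(object$na.action), , drop = FALSE]
  }
  levs <- levels(data[[trt]])
  stopifnot(length(levs) == 2)

  # One dataset per treatment level: set trt to that level for every record,
  # keep all other covariates as observed, and predict on the response scale
  preds <- lapply(levs, function(lv) {
    d <- data
    d[[trt]] <- factor(rep(lv, nrow(d)), levels = levs)
    unname(stats::predict(object, newdata = d, type = "response"))
  })

  # Two-column data frame with a descriptive label on each column
  out <- data.frame(cf_pred_0 = preds[[1]], cf_pred_1 = preds[[2]])
  attr(out$cf_pred_0, "label") <- paste("Counterfactual prediction,", trt, "=", levs[1])
  attr(out$cf_pred_1, "label") <- paste("Counterfactual prediction,", trt, "=", levs[2])

  object$counterfactual.predictions <- out
  object
}

# Simulated stand-in data with one incomplete record
set.seed(1)
n <- 200
sim <- data.frame(trtp = factor(rep(c("0", "1"), each = n / 2)), bl_cov = rnorm(n))
sim$aval <- rbinom(n, 1, plogis(-0.2 + 0.5 * (sim$trtp == "1") + 0.8 * sim$bl_cov))
sim$bl_cov[5] <- NA  # dropped by glm()'s default na.omit

fit <- glm(aval ~ trtp + bl_cov, family = "binomial", data = sim)
fit_cf <- cf_predict_sketch(fit, "trtp")
head(fit_cf$counterfactual.predictions)
```

For records actually observed on a given level, the counterfactual prediction for that level equals the model's fitted value, which is a quick sanity check on any implementation of this approach.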

To average the counterfactual predictions across records, use average_predictions().
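Averaging each column over all records gives an estimated mean outcome under each treatment level, and a contrast such as the difference between the two averages is one common summary of the treatment effect. The sketch below is a hypothetical illustration, not average_predictions() itself; average_cf_sketch and the toy input values are made up for demonstration.

```r
# Hypothetical sketch: average each counterfactual column over all records.
# Not the package's average_predictions(); assumes columns cf_pred_0 and cf_pred_1.
average_cf_sketch <- function(cf) {
  c(cf_pred_0 = mean(cf$cf_pred_0), cf_pred_1 = mean(cf$cf_pred_1))
}

# Toy input for illustration only
cf_toy <- data.frame(cf_pred_0 = c(0.5, 0.3), cf_pred_1 = c(0.4, 0.2))
avg <- average_cf_sketch(cf_toy)
avg
avg[["cf_pred_1"]] - avg[["cf_pred_0"]]  # difference of the two averages
```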

# Preparing data and fitting a GLM
trial01$trtp <- factor(trial01$trtp)
fit1 <- glm(aval ~ trtp + bl_cov, family = "binomial", data = trial01)

# Generating counterfactual predictions
fit2 <- predict_counterfactuals(fit1, "trtp")
#> Warning: There is 1 record omitted from the original data due to missing values, please check if they should be imputed prior to model fitting.

# Accessing the counterfactual predictions
fit2$counterfactual.predictions
#> # A tibble: 267 × 2
#>    cf_pred_0 cf_pred_1
#>        <dbl>     <dbl>
#>  1     0.533     0.463
#>  2     0.537     0.468
#>  3     0.481     0.413
#>  4     0.510     0.441
#>  5     0.428     0.362
#>  6     0.474     0.406
#>  7     0.490     0.421
#>  8     0.496     0.427
#>  9     0.514     0.445
#> 10     0.490     0.422
#> # ℹ 257 more rows
