CausalTransformer


Causal Transformer for estimating counterfactual outcomes over time.


The project is built with the following Python libraries:

  1. PyTorch Lightning - deep learning models
  2. Hydra - simplified command-line argument management
  3. MLflow - experiment tracking

Installation

First, create a virtual environment and install all the requirements:

pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt

MLflow Setup / Connection

To start an experiment-tracking server, run:

mlflow server --port=5000

To access the MLflow web UI with all the experiments, connect via SSH:

ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>

Then, open http://localhost:5000 in a local browser.
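With the tunnel in place, scripts on the local machine can also be pointed at the server explicitly via the standard MLFLOW_TRACKING_URI environment variable (a minimal sketch; the project's Hydra configs may additionally expose their own tracking-URI setting):

export MLFLOW_TRACKING_URI=http://localhost:5000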

Experiments

The main training script is universal across models and datasets. For details on mandatory arguments, see the main configuration file config/config.yaml and the other files in the config/ folder.

A generic run with logging and a fixed random seed looks as follows (where <training-type> is one of enc_dec, gnet, rmsn, or multi):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_<training-type>.py +dataset=<dataset> +backbone=<backbone> exp.seed=10 exp.logging=True
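For instance, a concrete instantiation for the multi-input model on the tumor-growth simulator might look as follows (a sketch; the device index and seed are arbitrary):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_multi.py +dataset=cancer_sim +backbone=ct exp.seed=10 exp.logging=True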

Backbones (baselines)

One needs to choose a backbone and then fill in the model-specific hyperparameters (they are left blank in the configs).

The best hyperparameters are already saved for each model and dataset; one can access them via +backbone/<backbone>_hparams/cancer_sim_<balancing_objective>=<coeff_value> or +backbone/<backbone>_hparams/mimic3_real=diastolic_blood_pressure.
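For example, CT could be trained on the real-world MIMIC-III data with its saved hyperparameters roughly like this (a sketch combining the config groups above):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_multi.py +dataset=mimic3_real +backbone=ct +backbone/ct_hparams/mimic3_real=diastolic_blood_pressure exp.seed=10 exp.logging=True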

For CT, EDCT, and CRN, several adversarial balancing objectives are available:

  • counterfactual domain confusion loss (this paper): exp.balancing=domain_confusion
  • gradient reversal (originally in CRN, but can be used for all the methods): exp.balancing=grad_reverse
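The balancing objective is appended to the training command like any other override, e.g. (a sketch):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_multi.py +dataset=cancer_sim +backbone=ct exp.balancing=domain_confusion exp.seed=10 exp.logging=True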

To train a decoder (for CRN and RMSNs), use the flag model.train_decoder=True.
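For CRN, this might look as follows (a sketch; it assumes the CRN backbone config is named crn):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_enc_dec.py +dataset=cancer_sim +backbone=crn model.train_decoder=True exp.seed=10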

To perform manual hyperparameter tuning, use the flag model.<sub_model>.tune_hparams=True, and then see model.<sub_model>.hparams_grid. Use model.<sub_model>.tune_range to specify the number of trials for random search.
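A hypothetical tuning run could then look as follows (a sketch; the <sub_model> name depends on the training type, and the value 50 is an arbitrary trial budget):

PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 \
python3 runnables/train_enc_dec.py +dataset=cancer_sim +backbone=crn model.<sub_model>.tune_hparams=True model.<sub_model>.tune_range=50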

Datasets

One needs to specify a dataset or dataset generator (and some additional parameters, e.g., setting $\gamma$ for cancer_sim via dataset.coeff=1.0):

  • Synthetic Tumor Growth Simulator: +dataset=cancer_sim
  • MIMIC III Semi-synthetic Simulator (multiple treatments and outcomes): +dataset=mimic3_synthetic
  • MIMIC III Real-world dataset: +dataset=mimic3_real

Before running MIMIC-III experiments, place the MIMIC-Extract dataset (all_hourly_data.h5) into data/processed/.
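For example (a sketch; adjust the source path to wherever MIMIC-Extract was downloaded):

mkdir -p data/processed
cp /path/to/all_hourly_data.h5 data/processed/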

Example of running the Causal Transformer on the synthetic tumor-growth generator with $\gamma = [1.0, 2.0, 3.0]$ and different random seeds (15 subruns in total), using the saved best hyperparameters:

PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_multi.py -m +dataset=cancer_sim +backbone=ct +backbone/ct_hparams/cancer_sim_domain_conf='0','1','2' exp.seed=10,101,1010,10101,101010

Updated results

Self- and cross-attention bug

New results for the semi-synthetic and real-world experiments after fixing a bug in the self- and cross-attention mechanisms (#7). The bug affected only Tables 1 and 2 and Figure 5 of the paper (https://arxiv.org/pdf/2204.07258.pdf). Nevertheless, the performance of CT after the fix did not change drastically.

Table 1 (updated). Results for semi-synthetic data for $\tau$-step-ahead prediction based on real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.

| Model | $\tau = 1$ | $\tau = 2$ | $\tau = 3$ | $\tau = 4$ | $\tau = 5$ | $\tau = 6$ | $\tau = 7$ | $\tau = 8$ | $\tau = 9$ | $\tau = 10$ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| MSMs | 0.37 ± 0.01 | 0.57 ± 0.03 | 0.74 ± 0.06 | 0.88 ± 0.03 | 1.14 ± 0.10 | 1.95 ± 1.48 | 3.44 ± 4.57 | > 10.0 | > 10.0 | > 10.0 |
| RMSNs | 0.24 ± 0.01 | 0.47 ± 0.01 | 0.60 ± 0.01 | 0.70 ± 0.02 | 0.78 ± 0.04 | 0.84 ± 0.05 | 0.89 ± 0.06 | 0.94 ± 0.08 | 0.97 ± 0.09 | 1.00 ± 0.11 |
| CRN | 0.30 ± 0.01 | 0.48 ± 0.02 | 0.59 ± 0.02 | 0.65 ± 0.02 | 0.68 ± 0.02 | 0.71 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.02 |
| G-Net | 0.34 ± 0.01 | 0.67 ± 0.03 | 0.83 ± 0.04 | 0.94 ± 0.04 | 1.03 ± 0.05 | 1.10 ± 0.05 | 1.16 ± 0.05 | 1.21 ± 0.06 | 1.25 ± 0.06 | 1.29 ± 0.06 |
| EDCT (GR; $\lambda = 1$) | 0.29 ± 0.01 | 0.46 ± 0.01 | 0.56 ± 0.01 | 0.62 ± 0.01 | 0.67 ± 0.01 | 0.70 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.01 |
| CT ($\alpha = 0$) (ours, fixed) | 0.20 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.52 ± 0.01 | 0.54 ± 0.01 | 0.56 ± 0.01 | 0.57 ± 0.01 | 0.59 ± 0.01 | 0.60 ± 0.01 |
| CT (ours, fixed) | 0.21 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.53 ± 0.01 | 0.54 ± 0.01 | 0.55 ± 0.01 | 0.57 ± 0.01 | 0.58 ± 0.01 | 0.59 ± 0.01 |

Table 2 (updated). Results for experiments with real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.

| Model | $\tau = 1$ | $\tau = 2$ | $\tau = 3$ | $\tau = 4$ | $\tau = 5$ |
| --- | --- | --- | --- | --- | --- |
| MSMs | 6.37 ± 0.26 | 9.06 ± 0.41 | 11.89 ± 1.28 | 13.12 ± 1.25 | 14.44 ± 1.12 |
| RMSNs | 5.20 ± 0.15 | 9.79 ± 0.31 | 10.52 ± 0.39 | 11.09 ± 0.49 | 11.64 ± 0.62 |
| CRN | 4.84 ± 0.08 | 9.15 ± 0.16 | 9.81 ± 0.17 | 10.15 ± 0.19 | 10.40 ± 0.21 |
| G-Net | 5.13 ± 0.05 | 11.88 ± 0.20 | 12.91 ± 0.26 | 13.57 ± 0.30 | 14.08 ± 0.31 |
| CT (ours, fixed) | 4.60 ± 0.08 | 9.01 ± 0.21 | 9.58 ± 0.19 | 9.89 ± 0.21 | 10.12 ± 0.22 |

Figure 6 (updated). Subnetwork importance scores based on the semi-synthetic benchmark (higher values correspond to higher importance of subnetwork connectivity via cross-attention). Shown: RMSE differences between the model with an isolated subnetwork and the full CT, means ± standard errors.


Last active entry zeroing bug

New results after fixing a bug in the synthetic tumor-growth simulator: the outcome corresponding to the last entry of every time series was zeroed.

Table 9 (updated). Normalized RMSE for one-step-ahead prediction. Shown: mean and standard deviation over five runs (lower is better). Parameter $\gamma$ is the amount of time-varying confounding: higher values mean larger treatment assignment bias.

| Model | $\gamma = 0$ | $\gamma = 1$ | $\gamma = 2$ | $\gamma = 3$ | $\gamma = 4$ |
| --- | --- | --- | --- | --- | --- |
| MSMs | 1.091 ± 0.115 | 1.202 ± 0.108 | 1.383 ± 0.090 | 1.647 ± 0.121 | 1.981 ± 0.232 |
| RMSNs | 0.834 ± 0.072 | 0.860 ± 0.025 | 1.000 ± 0.134 | 1.131 ± 0.057 | 1.434 ± 0.148 |
| CRN | 0.755 ± 0.059 | 0.788 ± 0.057 | 0.881 ± 0.066 | 1.062 ± 0.088 | 1.358 ± 0.167 |
| G-Net | 0.795 ± 0.066 | 0.841 ± 0.038 | 0.946 ± 0.083 | 1.057 ± 0.146 | 1.319 ± 0.248 |
| CT ($\alpha = 0$) (ours) | 0.772 ± 0.051 | 0.783 ± 0.071 | 0.862 ± 0.052 | 1.062 ± 0.119 | 1.331 ± 0.217 |
| CT (ours) | 0.770 ± 0.049 | 0.783 ± 0.071 | 0.864 ± 0.059 | 1.098 ± 0.097 | 1.413 ± 0.259 |

Table 10 (updated). Normalized RMSE for $\tau$-step-ahead prediction (here: random trajectories setting). Shown: mean and standard deviation over five runs (lower is better). Parameter $\gamma$ is the amount of time-varying confounding: higher values mean larger treatment assignment bias.

| $\tau$ | Model | $\gamma = 0$ | $\gamma = 1$ | $\gamma = 2$ | $\gamma = 3$ | $\gamma = 4$ |
| --- | --- | --- | --- | --- | --- | --- |
| 2 | MSMs | 0.975 ± 0.063 | 1.183 ± 0.146 | 1.428 ± 0.274 | 1.673 ± 0.431 | 1.884 ± 0.637 |
| 2 | RMSNs | 0.825 ± 0.057 | 0.851 ± 0.043 | 0.861 ± 0.078 | 0.993 ± 0.126 | 1.269 ± 0.294 |
| 2 | CRN | 0.761 ± 0.058 | 0.760 ± 0.037 | 0.805 ± 0.050 | 2.045 ± 1.491 | 1.209 ± 0.192 |
| 2 | G-Net | 1.006 ± 0.082 | 0.994 ± 0.086 | 1.185 ± 0.077 | 1.083 ± 0.145 | 1.243 ± 0.202 |
| 2 | CT ($\alpha = 0$) (ours) | 0.766 ± 0.029 | 0.781 ± 0.066 | 0.814 ± 0.078 | 0.944 ± 0.144 | 1.191 ± 0.316 |
| 2 | CT (ours) | 0.762 ± 0.028 | 0.781 ± 0.058 | 0.818 ± 0.091 | 1.001 ± 0.150 | 1.163 ± 0.233 |
| 3 | MSMs | 0.937 ± 0.060 | 1.133 ± 0.158 | 1.344 ± 0.262 | 1.525 ± 0.400 | 1.564 ± 0.545 |
| 3 | RMSNs | 0.824 ± 0.043 | 0.871 ± 0.036 | 0.857 ± 0.109 | 1.020 ± 0.140 | 1.267 ± 0.298 |
| 3 | CRN | 0.769 ± 0.057 | 0.777 ± 0.037 | 0.826 ± 0.077 | 1.789 ± 1.108 | 1.356 ± 0.330 |
| 3 | G-Net | 1.103 ± 0.092 | 1.097 ± 0.095 | 1.355 ± 0.107 | 1.225 ± 0.184 | 1.382 ± 0.242 |
| 3 | CT ($\alpha = 0$) (ours) | 0.766 ± 0.037 | 0.806 ± 0.060 | 0.828 ± 0.106 | 0.996 ± 0.185 | 1.335 ± 0.465 |
| 3 | CT (ours) | 0.762 ± 0.036 | 0.807 ± 0.056 | 0.838 ± 0.120 | 1.072 ± 0.196 | 1.283 ± 0.312 |
| 4 | MSMs | 0.845 ± 0.060 | 1.022 ± 0.149 | 1.196 ± 0.233 | 1.325 ± 0.363 | 1.308 ± 0.482 |
| 4 | RMSNs | 0.780 ± 0.046 | 0.834 ± 0.040 | 0.814 ± 0.123 | 0.988 ± 0.146 | 1.169 ± 0.269 |
| 4 | CRN | 0.734 ± 0.061 | 0.743 ± 0.037 | 0.805 ± 0.096 | 1.567 ± 0.825 | 1.327 ± 0.293 |
| 4 | G-Net | 1.092 ± 0.090 | 1.074 ± 0.098 | 1.385 ± 0.117 | 1.212 ± 0.202 | 1.358 ± 0.253 |
| 4 | CT ($\alpha = 0$) (ours) | 0.730 ± 0.042 | 0.776 ± 0.056 | 0.802 ± 0.119 | 0.983 ± 0.208 | 1.394 ± 0.563 |
| 4 | CT (ours) | 0.726 ± 0.041 | 0.777 ± 0.054 | 0.810 ± 0.128 | 1.075 ± 0.220 | 1.302 ± 0.356 |
| 5 | MSMs | 0.747 ± 0.056 | 0.896 ± 0.136 | 1.038 ± 0.210 | 1.128 ± 0.320 | 1.155 ± 0.448 |
| 5 | RMSNs | 0.717 ± 0.053 | 0.775 ± 0.041 | 0.747 ± 0.124 | 0.922 ± 0.141 | 1.057 ± 0.246 |
| 5 | CRN | 0.678 ± 0.062 | 0.692 ± 0.037 | 0.761 ± 0.104 | 1.410 ± 0.604 | 1.242 ± 0.239 |
| 5 | G-Net | 1.033 ± 0.086 | 1.014 ± 0.097 | 1.358 ± 0.118 | 1.160 ± 0.199 | 1.285 ± 0.242 |
| 5 | CT ($\alpha = 0$) (ours) | 0.673 ± 0.044 | 0.722 ± 0.052 | 0.748 ± 0.124 | 0.931 ± 0.213 | 1.405 ± 0.648 |
| 5 | CT (ours) | 0.669 ± 0.043 | 0.723 ± 0.053 | 0.751 ± 0.125 | 1.036 ± 0.238 | 1.264 ± 0.389 |
| 6 | MSMs | 0.647 ± 0.055 | 0.778 ± 0.123 | 0.894 ± 0.188 | 0.952 ± 0.284 | 1.060 ± 0.432 |
| 6 | RMSNs | 0.646 ± 0.058 | 0.702 ± 0.043 | 0.675 ± 0.121 | 0.847 ± 0.132 | 0.947 ± 0.225 |
| 6 | CRN | 0.614 ± 0.057 | 0.631 ± 0.035 | 0.706 ± 0.104 | 1.308 ± 0.438 | 1.132 ± 0.194 |
| 6 | G-Net | 0.963 ± 0.083 | 0.942 ± 0.090 | 1.321 ± 0.118 | 1.092 ± 0.183 | 1.195 ± 0.223 |
| 6 | CT ($\alpha = 0$) (ours) | 0.609 ± 0.042 | 0.657 ± 0.046 | 0.684 ± 0.122 | 0.864 ± 0.201 | 1.383 ± 0.699 |
| 6 | CT (ours) | 0.605 ± 0.040 | 0.657 ± 0.047 | 0.685 ± 0.119 | 0.979 ± 0.249 | 1.201 ± 0.419 |

Table 11 (updated). Normalized RMSE for $\tau$-step-ahead prediction (here: single sliding treatment setting). Shown: mean and standard deviation over five runs (lower is better). Parameter $\gamma$ is the amount of time-varying confounding: higher values mean larger treatment assignment bias.

| $\tau$ | Model | $\gamma = 0$ | $\gamma = 1$ | $\gamma = 2$ | $\gamma = 3$ | $\gamma = 4$ |
| --- | --- | --- | --- | --- | --- | --- |
| 2 | MSMs | 1.362 ± 0.109 | 1.612 ± 0.172 | 1.939 ± 0.365 | 2.290 ± 0.545 | 2.468 ± 1.058 |
| 2 | RMSNs | 0.742 ± 0.043 | 0.760 ± 0.047 | 0.827 ± 0.056 | 0.957 ± 0.106 | 1.276 ± 0.240 |
| 2 | CRN | 0.671 ± 0.066 | 0.666 ± 0.052 | 0.741 ± 0.042 | 1.668 ± 1.184 | 1.151 ± 0.166 |
| 2 | G-Net | 1.021 ± 0.067 | 1.009 ± 0.092 | 1.271 ± 0.075 | 1.113 ± 0.149 | 1.257 ± 0.227 |
| 2 | CT ($\alpha = 0$) (ours) | 0.685 ± 0.050 | 0.679 ± 0.044 | 0.714 ± 0.053 | 0.875 ± 0.105 | 1.072 ± 0.315 |
| 2 | CT (ours) | 0.681 ± 0.052 | 0.677 ± 0.044 | 0.713 ± 0.042 | 0.908 ± 0.122 | 1.274 ± 0.366 |
| 3 | MSMs | 1.679 ± 0.132 | 1.953 ± 0.208 | 2.302 ± 0.437 | 2.640 ± 0.639 | 2.622 ± 1.132 |
| 3 | RMSNs | 0.783 ± 0.053 | 0.792 ± 0.047 | 0.889 ± 0.050 | 1.086 ± 0.175 | 1.382 ± 0.286 |
| 3 | CRN | 0.700 ± 0.078 | 0.692 ± 0.046 | 0.818 ± 0.051 | 1.959 ± 1.032 | 1.360 ± 0.225 |
| 3 | G-Net | 1.253 ± 0.079 | 1.226 ± 0.104 | 1.611 ± 0.102 | 1.383 ± 0.200 | 1.574 ± 0.328 |
| 3 | CT ($\alpha = 0$) (ours) | 0.707 ± 0.053 | 0.711 ± 0.038 | 0.770 ± 0.043 | 0.969 ± 0.119 | 1.261 ± 0.462 |
| 3 | CT (ours) | 0.703 ± 0.055 | 0.712 ± 0.040 | 0.770 ± 0.032 | 1.010 ± 0.119 | 1.536 ± 0.450 |
| 4 | MSMs | 1.871 ± 0.145 | 2.145 ± 0.227 | 2.489 ± 0.471 | 2.791 ± 0.681 | 2.615 ± 1.142 |
| 4 | RMSNs | 0.821 ± 0.079 | 0.837 ± 0.058 | 0.963 ± 0.106 | 1.216 ± 0.240 | 1.416 ± 0.304 |
| 4 | CRN | 0.734 ± 0.087 | 0.722 ± 0.041 | 0.898 ± 0.068 | 2.201 ± 0.967 | 1.573 ± 0.255 |
| 4 | G-Net | 1.390 ± 0.087 | 1.347 ± 0.112 | 1.819 ± 0.133 | 1.544 ± 0.243 | 1.769 ± 0.413 |
| 4 | CT ($\alpha = 0$) (ours) | 0.729 ± 0.056 | 0.749 ± 0.033 | 0.826 ± 0.046 | 1.053 ± 0.147 | 1.426 ± 0.574 |
| 4 | CT (ours) | 0.726 ± 0.057 | 0.748 ± 0.036 | 0.822 ± 0.036 | 1.089 ± 0.122 | 1.762 ± 0.523 |
| 5 | MSMs | 1.963 ± 0.155 | 2.221 ± 0.231 | 2.547 ± 0.479 | 2.810 ± 0.684 | 2.542 ± 1.122 |
| 5 | RMSNs | 0.855 ± 0.099 | 0.889 ± 0.074 | 1.030 ± 0.165 | 1.349 ± 0.326 | 1.434 ± 0.299 |
| 5 | CRN | 0.769 ± 0.094 | 0.755 ± 0.039 | 0.976 ± 0.082 | 2.361 ± 1.000 | 1.730 ± 0.292 |
| 5 | G-Net | 1.477 ± 0.092 | 1.430 ± 0.119 | 1.963 ± 0.157 | 1.667 ± 0.275 | 1.907 ± 0.471 |
| 5 | CT ($\alpha = 0$) (ours) | 0.758 ± 0.055 | 0.788 ± 0.036 | 0.875 ± 0.056 | 1.118 ± 0.172 | 1.560 ± 0.663 |
| 5 | CT (ours) | 0.756 ± 0.057 | 0.786 ± 0.039 | 0.870 ± 0.048 | 1.154 ± 0.111 | 1.922 ± 0.569 |
| 6 | MSMs | 1.970 ± 0.155 | 2.205 ± 0.228 | 2.509 ± 0.469 | 2.732 ± 0.662 | 2.422 ± 1.084 |
| 6 | RMSNs | 0.889 ± 0.112 | 0.936 ± 0.091 | 1.081 ± 0.211 | 1.473 ± 0.433 | 1.436 ± 0.290 |
| 6 | CRN | 0.807 ± 0.097 | 0.790 ± 0.035 | 1.047 ± 0.092 | 2.480 ± 1.078 | 1.827 ± 0.326 |
| 6 | G-Net | 1.538 ± 0.091 | 1.493 ± 0.121 | 2.062 ± 0.172 | 1.758 ± 0.286 | 1.994 ± 0.500 |
| 6 | CT ($\alpha = 0$) (ours) | 0.790 ± 0.058 | 0.827 ± 0.036 | 0.915 ± 0.063 | 1.177 ± 0.193 | 1.654 ± 0.704 |
| 6 | CT (ours) | 0.789 ± 0.059 | 0.821 ± 0.034 | 0.909 ± 0.054 | 1.205 ± 0.100 | 2.052 ± 0.608 |

Project based on the cookiecutter data science project template. #cookiecutterdatascience
