v0.4.0 fix z1&z2 bug
kimmo1019 committed May 15, 2024
1 parent ead29ff commit bdddc2c
Showing 4 changed files with 47 additions and 50 deletions.
6 changes: 3 additions & 3 deletions src/CausalEGM.egg-info/PKG-INFO
@@ -1,7 +1,7 @@
Metadata-Version: 2.1
Name: CausalEGM
-Version: 0.3.4
-Summary: CausalEGM: a general causal inference framework by encoding generative modeling
+Version: 0.4.0
+Summary: CausalEGM: an encoding generative modeling approach to dimension reduction and covariate adjustment in causal inference with observational studies
Home-page: https://github.com/SUwonglab/CausalEGM
Author: Qiao Liu
Author-email: [email protected]
@@ -12,4 +12,4 @@ Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE

-Understanding and characterizing causal effect has become essential in observational studies while it is still challenging if the confounders are high-dimensional. In this article, we develop a general framework CausalEGM, for estimating causal effect by encoding generative modeling, which can be applied in both binary and continuous treatment settings. In the potential outcome framework with unconfoundedness, we build a bidirectional transformation between the high-dimensional confounders space and a low-dimensional latent space where the density is known (e.g., Gaussian). Through this, CausalEGM enables simultaneously decoupling the dependencies of confounders on both treatment and outcome, and mapping the confounders to the low-dimensional latent space. By conditioning on the low-dimensional latent features, CausalEGM is able to estimate the causal effect for each individual or estimate the average causal effect within a population. Our theoretical analysis shows that the excess risk for CausalEGM can be bounded through empirical process theory. Under an assumption on encoder-decoder networks, the consistency of the estimate can also be guaranteed. In a series of experiments, CausalEGM demonstrates superior performance against existing methods in both binary and continuous settings. Specifically, we find CausalEGM to be substantially more powerful than competing methods in the presence of large sample size and high dimensional confounders. CausalEGM is freely available at https://github.com/SUwonglab/CausalEGM.
+In this article, we develop CausalEGM, a deep learning framework for nonlinear dimension reduction and generative modeling of the dependency among covariate features affecting treatment and response. CausalEGM can be used for estimating causal effects in both binary and continuous treatment settings. By learning a bidirectional transformation between the high-dimensional covariate space and a low-dimensional latent space and then modeling the dependencies of different subsets of the latent variables on the treatment and response, CausalEGM can extract the latent covariate features that affect both treatment and response. By conditioning on these features, one can mitigate the confounding effect of the high dimensional covariate on the estimation of the causal relation between treatment and response. In a series of experiments, the proposed method is shown to achieve superior performance over existing methods in both binary and continuous treatment settings. The improvement is substantial when the sample size is large and the covariate is of high dimension. Finally, we established excess risk bounds and consistency results for our method, and discuss how our approach is related to and improves upon other dimension reduction approaches in causal inference. CausalEGM is freely available at https://github.com/SUwonglab/CausalEGM.
2 changes: 1 addition & 1 deletion src/build/lib/CausalEGM/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '0.3.4'
+__version__ = '0.4.0'
from .causalEGM import CausalEGM, VariationalCausalEGM
from .util import Base_sampler, Semi_acic_sampler, Sim_Hirano_Imbens_sampler, Sim_Sun_sampler, Sim_Colangelo_sampler, Semi_Twins_sampler
85 changes: 41 additions & 44 deletions src/build/lib/CausalEGM/causalEGM.py
@@ -44,9 +44,9 @@ def __init__(self, params, timestamp=None, random_seed=None):
self.dv_net = Discriminator(input_dim=params['v_dim'],model_name='dv_net',
nb_units=params['dv_units'])

-self.f_net = BaseFullyConnectedNet(input_dim=1+params['z_dims'][0]+params['z_dims'][2],
+self.f_net = BaseFullyConnectedNet(input_dim=1+params['z_dims'][0]+params['z_dims'][1],
 output_dim = 1, model_name='f_net', nb_units=params['f_units'])
-self.h_net = BaseFullyConnectedNet(input_dim=params['z_dims'][0]+params['z_dims'][1],
+self.h_net = BaseFullyConnectedNet(input_dim=params['z_dims'][0]+params['z_dims'][2],
 output_dim = 1, model_name='h_net', nb_units=params['h_units'])

self.g_e_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.5, beta_2=0.9)
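For readers tracing the z1/z2 bug this commit fixes: CausalEGM partitions the encoded latent vector into four blocks, where (per the paper's setup) z0 drives both treatment and outcome, z1 drives the outcome only, z2 the treatment only, and z3 neither. Before this fix, the outcome network f_net was wired to z2 and the treatment network h_net to z1, i.e. the two blocks were swapped. A minimal NumPy sketch of the corrected wiring; the `z_dims` values here are illustrative assumptions, not taken from the repository:

```python
import numpy as np

# Assumed example dims: z0 confounds both treatment and outcome,
# z1 affects outcome only, z2 affects treatment only, z3 neither.
z_dims = [3, 2, 2, 1]
z = np.random.randn(5, sum(z_dims))

# Split the encoded latent vector into its blocks (mirrors the slicing above).
z0 = z[:, :z_dims[0]]
z1 = z[:, z_dims[0]:sum(z_dims[:2])]
z2 = z[:, sum(z_dims[:2]):sum(z_dims[:3])]

# Corrected wiring (this commit): the outcome model f_net sees [z0, z1, x];
# the treatment model h_net sees [z0, z2]. Before the fix, z1 and z2 were swapped.
x = np.random.randn(5, 1)
f_input = np.concatenate([z0, z1, x], axis=-1)  # width 1 + z_dims[0] + z_dims[1]
h_input = np.concatenate([z0, z2], axis=-1)     # width z_dims[0] + z_dims[2]
```

The concatenation widths match the corrected `input_dim` arguments of `f_net` and `h_net` above.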
@@ -98,8 +98,8 @@ def initialize_nets(self, print_summary = False):
self.e_net(np.zeros((1, self.params['v_dim'])))
self.dz_net(np.zeros((1, sum(self.params['z_dims']))))
self.dv_net(np.zeros((1, self.params['v_dim'])))
-self.f_net(np.zeros((1, 1+self.params['z_dims'][0]+self.params['z_dims'][2])))
-self.h_net(np.zeros((1, self.params['z_dims'][0]+self.params['z_dims'][1])))
+self.f_net(np.zeros((1, 1+self.params['z_dims'][0]+self.params['z_dims'][1])))
+self.h_net(np.zeros((1, self.params['z_dims'][0]+self.params['z_dims'][2])))
if print_summary:
print(self.g_net.summary())
print(self.h_net.summary())
@@ -161,8 +161,8 @@ def train_gen_step(self, data_z, data_v, data_x, data_y):
g_loss_adv = -tf.reduce_mean(data_dv_)
e_loss_adv = -tf.reduce_mean(data_dz_)

-data_y_ = self.f_net(tf.concat([data_z0, data_z2, data_x], axis=-1))
-data_x_ = self.h_net(tf.concat([data_z0, data_z1], axis=-1))
+data_y_ = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))
+data_x_ = self.h_net(tf.concat([data_z0, data_z2], axis=-1))
if self.params['binary_treatment']:
data_x_ = tf.sigmoid(data_x_)
l2_loss_x = tf.reduce_mean((data_x_ - data_x)**2)
@@ -204,7 +204,16 @@ def train_disc_step(self, data_z, data_v):
with tf.GradientTape(persistent=True) as disc_tape:
data_v_ = self.g_net(data_z)
data_z_ = self.e_net(data_v)
+data_z_hat = data_z*epsilon_z + data_z_*(1-epsilon_z)
+data_v_hat = data_v*epsilon_v + data_v_*(1-epsilon_v)
+
+with tf.GradientTape() as gp_tape_z:
+    gp_tape_z.watch(data_z_hat)
+    data_dz_hat = self.dz_net(data_z_hat)
+with tf.GradientTape() as gp_tape_v:
+    gp_tape_v.watch(data_v_hat)
+    data_dv_hat = self.dv_net(data_v_hat)

data_dv_ = self.dv_net(data_v_)
data_dz_ = self.dz_net(data_z_)

@@ -215,16 +224,12 @@
dv_loss = -tf.reduce_mean(data_dv) + tf.reduce_mean(data_dv_)

#gradient penalty for z
-data_z_hat = data_z*epsilon_z + data_z_*(1-epsilon_z)
-data_dz_hat = self.dz_net(data_z_hat)
-grad_z = tf.gradients(data_dz_hat, data_z_hat)[0] #(bs,z_dim)
+grad_z = gp_tape_z.gradient(data_dz_hat, data_z_hat) #(bs,z_dim)
grad_norm_z = tf.sqrt(tf.reduce_sum(tf.square(grad_z), axis=1))#(bs,)
gpz_loss = tf.reduce_mean(tf.square(grad_norm_z - 1.0))

#gradient penalty for v
-data_v_hat = data_v*epsilon_v + data_v_*(1-epsilon_v)
-data_dv_hat = self.dv_net(data_v_hat)
-grad_v = tf.gradients(data_dv_hat, data_v_hat)[0] #(bs,v_dim)
+grad_v = gp_tape_v.gradient(data_dv_hat, data_v_hat) #(bs,v_dim)
grad_norm_v = tf.sqrt(tf.reduce_sum(tf.square(grad_v), axis=1))#(bs,)
gpv_loss = tf.reduce_mean(tf.square(grad_norm_v - 1.0))
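The restructured discriminator step replaces `tf.gradients`, which is unavailable under TensorFlow 2's eager execution, with explicit `GradientTape`s that watch the interpolated samples. A minimal, self-contained sketch of the same WGAN-GP gradient-penalty pattern; the toy critic and dimensions are illustrative, not the repository's networks:

```python
import tensorflow as tf

# Toy stand-in for a discriminator/critic network (illustrative only).
critic = tf.keras.Sequential([tf.keras.layers.Dense(8, activation='relu'),
                              tf.keras.layers.Dense(1)])

def gradient_penalty(real, fake):
    # Interpolate between real and generated samples.
    eps = tf.random.uniform([tf.shape(real)[0], 1], 0., 1.)
    x_hat = real * eps + fake * (1. - eps)
    # tf.gradients only works in graph mode; in TF2 an explicit tape that
    # watches x_hat is the idiomatic replacement (what this commit does).
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        d_hat = critic(x_hat)
    grad = tape.gradient(d_hat, x_hat)                       # (bs, dim)
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1))  # (bs,)
    # Penalize deviation of the critic's gradient norm from 1 (Lipschitz target).
    return tf.reduce_mean(tf.square(grad_norm - 1.0))

gp = gradient_penalty(tf.random.normal([4, 3]), tf.random.normal([4, 3]))
```

Moving the interpolation and the critic evaluation inside dedicated tapes, as the commit does, is what makes the penalty differentiable end-to-end in eager mode.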

@@ -316,13 +321,13 @@ def train(self, data=None, data_file=None, sep='\t', header=0, normalize=False,
ckpt_save_path = self.ckpt_manager.save(batch_idx)
#print('Saving checkpoint for iteration {} at {}'.format(batch_idx, ckpt_save_path))
if self.params['save_res'] and batch_idx > 0 and batch_idx % batches_per_save == 0:
-self.save('{}/causal_pre_at_{}.{}'.format(self.save_dir, batch_idx, save_format), self.best_causal_pre)
+self.save('{}/causal_pre_at_{}.{}'.format(self.save_dir, batch_idx, save_format), causal_pre)
if self.params['save_res']:
self.save('{}/causal_pre_final.{}'.format(self.save_dir,save_format), self.best_causal_pre)

if self.params['binary_treatment']:
self.ATE = np.mean(self.best_causal_pre)
-print('The average treatment effect (ATE) is ', self.ATE)
+print('The average treatment effect (ATE) is', self.ATE)

def evaluate(self, data, nb_intervals=200):
"""Internal evaluation in the training process of CausalEGM.
@@ -345,29 +350,28 @@ def evaluate(self, data, nb_intervals=200):
Float denoting outcome reconstruction loss.
"""
data_x, data_y, data_v = data
-data_z = self.z_sampler.get_batch(len(data_x))
data_z_ = self.e_net.predict(data_v,verbose=0)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
-data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
-data_x_pred = self.h_net.predict(tf.concat([data_z0, data_z1], axis=-1),verbose=0)
+data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
+data_x_pred = self.h_net.predict(tf.concat([data_z0, data_z2], axis=-1),verbose=0)
if self.params['binary_treatment']:
data_x_pred = tf.sigmoid(data_x_pred)
mse_x = np.mean((data_x-data_x_pred)**2)
mse_y = np.mean((data_y-data_y_pred)**2)
if self.params['binary_treatment']:
#individual treatment effect (ITE) && average treatment effect (ATE)
-y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z2, np.ones((len(data_x),1))], axis=-1),verbose=0)
-y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z2, np.zeros((len(data_x),1))], axis=-1),verbose=0)
+y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z1, np.ones((len(data_x),1))], axis=-1),verbose=0)
+y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z1, np.zeros((len(data_x),1))], axis=-1),verbose=0)
ite_pre = y_pred_pos-y_pred_neg
return ite_pre, mse_x, mse_y
else:
#average dose response function (ADRF)
dose_response = []
for x in np.linspace(self.params['x_min'], self.params['x_max'], nb_intervals):
data_x = np.tile(x, (len(data_x), 1))
-y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
+y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
dose_response.append(np.mean(y_pred))
return np.array(dose_response), mse_x, mse_y
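The corrected `evaluate` logic computes, for binary treatments, an individual treatment effect as the difference between counterfactual predictions at x=1 and x=0, and for continuous treatments an average dose-response curve over a grid. A toy sketch of that estimand arithmetic, with a hypothetical linear function standing in for the trained f_net:

```python
import numpy as np

# Illustrative stand-in for the trained outcome model f_net (not the real network):
# outcome depends on shared features z0 and a treatment effect of exactly 2.0 per unit.
def f(z0, z1, x):
    return 0.5 * z0.sum(axis=1, keepdims=True) + 2.0 * x

z0 = np.random.randn(100, 3)   # latent block affecting both treatment and outcome
z1 = np.random.randn(100, 2)   # latent block affecting outcome only

# Binary treatment: ITE per sample, ATE as its mean (mirrors evaluate()/getCATE()).
ite = f(z0, z1, np.ones((100, 1))) - f(z0, z1, np.zeros((100, 1)))
ate = ite.mean()

# Continuous treatment: average dose-response function (ADRF) over a dose grid
# (mirrors the np.linspace loop in evaluate()/getADRF()).
adrf = [f(z0, z1, np.full((100, 1), x)).mean() for x in np.linspace(0, 1, 5)]
```

With this toy f, every sample's ITE is exactly 2.0 and the ADRF rises by 2.0 across the unit dose interval, which makes the estimand definitions easy to check by hand.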

@@ -389,12 +393,10 @@ def predict(self, data_x, data_v):
assert len(data_x) == len(data_v)
if len(data_x.shape)==1:
data_x = data_x.reshape(-1,1)
-data_z = self.z_sampler.get_batch(len(data_x))
data_z_ = self.e_net.predict(data_v,verbose=0)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
-data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
+data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
return np.squeeze(data_y_pred)

def getADRF(self, x_list, data_v=None):
@@ -415,16 +417,14 @@
"""
if data_v is None:
data_v = self.data_sampler.load_all()[-1]
-data_z = self.z_sampler.get_batch(len(data_v))
data_z_ = self.e_net.predict(data_v,verbose=0)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
if not self.params['binary_treatment']:
dose_response = []
for x in x_list:
data_x = np.tile(x, (len(data_v), 1))
-y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
+y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
dose_response.append(np.mean(y_pred))
return np.array(dose_response)
else:
@@ -446,14 +446,13 @@ def getCATE(self,data_v):
Numpy.ndarray (1-D) denoting the predicted CATE values with shape [nb_sample, ].
"""
assert data_v.shape[1] == self.params['v_dim']
-data_z = self.z_sampler.get_batch(len(data_v))
data_z_ = self.e_net.predict(data_v,verbose=0)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
if self.params['binary_treatment']:
-y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z2, np.ones((len(data_v),1))], axis=-1),verbose=0)
-y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z2, np.zeros((len(data_v),1))], axis=-1),verbose=0)
+y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z1, np.ones((len(data_v),1))], axis=-1),verbose=0)
+y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z1, np.zeros((len(data_v),1))], axis=-1),verbose=0)
cate_pre = y_pred_pos-y_pred_neg
return np.squeeze(cate_pre)
else:
@@ -508,9 +507,9 @@ def __init__(self, params, timestamp=None, random_seed=None):
self.e_net = BaseFullyConnectedNet(input_dim=params['v_dim'],output_dim = 2*sum(params['z_dims']),
model_name='e_net', nb_units=params['e_units'])

-self.f_net = BaseFullyConnectedNet(input_dim=1+params['z_dims'][0]+params['z_dims'][2],
+self.f_net = BaseFullyConnectedNet(input_dim=1+params['z_dims'][0]+params['z_dims'][1],
 output_dim = 1, model_name='f_net', nb_units=params['f_units'])
-self.h_net = BaseFullyConnectedNet(input_dim=params['z_dims'][0]+params['z_dims'][1],
+self.h_net = BaseFullyConnectedNet(input_dim=params['z_dims'][0]+params['z_dims'][2],
 output_dim = 1, model_name='h_net', nb_units=params['h_units'])

self.g_e_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.5, beta_2=0.9)
@@ -558,8 +557,8 @@ def initialize_nets(self, print_summary = False):

self.g_net(np.zeros((1, sum(self.params['z_dims']))))
self.e_net(np.zeros((1, self.params['v_dim'])))
-self.f_net(np.zeros((1, 1+self.params['z_dims'][0]+self.params['z_dims'][2])))
-self.h_net(np.zeros((1, self.params['z_dims'][0]+self.params['z_dims'][1])))
+self.f_net(np.zeros((1, 1+self.params['z_dims'][0]+self.params['z_dims'][1])))
+self.h_net(np.zeros((1, self.params['z_dims'][0]+self.params['z_dims'][2])))
if print_summary:
print(self.g_net.summary())
print(self.h_net.summary())
@@ -612,10 +611,9 @@ def train_step(self, data_z, data_v, data_x, data_y):
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
-data_z3 = data_z_[:-self.params['z_dims'][3]:]

-data_y_ = self.f_net(tf.concat([data_z0, data_z2, data_x], axis=-1))
-data_x_ = self.h_net(tf.concat([data_z0, data_z1], axis=-1))
+data_y_ = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))
+data_x_ = self.h_net(tf.concat([data_z0, data_z2], axis=-1))
if self.params['binary_treatment']:
data_x_ = tf.sigmoid(data_x_)
l2_loss_x = tf.reduce_mean((data_x_ - data_x)**2)
@@ -635,7 +633,7 @@ def train_step(self, data_z, data_v, data_x, data_y):
def sample(self, eps=None):
"""Generate data by decoder."""
if eps is None:
-eps = tf.random.normal(shape=(100, sum(params['z_dims'])))
+eps = tf.random.normal(shape=(100, sum(self.params['z_dims'])))
return self.g_net(eps)
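The one-line `sample` fix above swaps the free name `params` for `self.params`; the original would raise a NameError (or silently read an unrelated global) whenever `eps` was omitted. A small sketch of the failure mode and fix, using a hypothetical stand-in class rather than the real model:

```python
import numpy as np

class Sampler:
    def __init__(self, params):
        self.params = params

    def sample(self, eps=None):
        # Before the fix this read sum(params['z_dims']) -- a NameError unless a
        # global `params` happened to exist; self.params is the intended source.
        if eps is None:
            eps = np.random.normal(size=(100, sum(self.params['z_dims'])))
        return eps  # stand-in for decoding via self.g_net(eps)

s = Sampler({'z_dims': [3, 2, 2, 1]})
out = s.sample()
```

Bugs of this kind often go unnoticed in notebooks, where a module-level `params` left over from an earlier cell masks the missing `self.`.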

def encode(self, v):
@@ -724,7 +722,7 @@ def train(self, data=None, data_file=None, sep='\t', header=0, normalize=False,
ckpt_save_path = self.ckpt_manager.save(batch_idx)
#print('Saving checkpoint for iteration {} at {}'.format(batch_idx, ckpt_save_path))
if self.params['save_res'] and batch_idx > 0 and batch_idx % batches_per_save == 0:
-self.save('{}/causal_pre_at_{}.{}'.format(self.save_dir, batch_idx, save_format), self.best_causal_pre)
+self.save('{}/causal_pre_at_{}.{}'.format(self.save_dir, batch_idx, save_format), causal_pre)
if self.params['save_res']:
self.save('{}/causal_pre_final.{}'.format(self.save_dir,save_format), self.best_causal_pre)
if self.params['binary_treatment']:
@@ -752,30 +750,29 @@ def evaluate(self, data, nb_intervals=200):
Float denoting outcome reconstruction loss.
"""
data_x, data_y, data_v = data
-data_z = self.z_sampler.get_batch(len(data_x))
mean, logvar = self.encode(data_v)
data_z_ = self.reparameterize(mean, logvar)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
-data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
-data_x_pred = self.h_net.predict(tf.concat([data_z0, data_z1], axis=-1),verbose=0)
+data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
+data_x_pred = self.h_net.predict(tf.concat([data_z0, data_z2], axis=-1),verbose=0)
if self.params['binary_treatment']:
data_x_pred = tf.sigmoid(data_x_pred)
mse_x = np.mean((data_x-data_x_pred)**2)
mse_y = np.mean((data_y-data_y_pred)**2)
if self.params['binary_treatment']:
#individual treatment effect (ITE) && average treatment effect (ATE)
-y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z2, np.ones((len(data_x),1))], axis=-1),verbose=0)
-y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z2, np.zeros((len(data_x),1))], axis=-1),verbose=0)
+y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z1, np.ones((len(data_x),1))], axis=-1),verbose=0)
+y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z1, np.zeros((len(data_x),1))], axis=-1),verbose=0)
ite_pre = y_pred_pos-y_pred_neg
return ite_pre, mse_x, mse_y
else:
#average dose response function (ADRF)
dose_response = []
for x in np.linspace(self.params['x_min'], self.params['x_max'], nb_intervals):
data_x = np.tile(x, (len(data_x), 1))
-y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
+y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
dose_response.append(np.mean(y_pred))
return np.array(dose_response), mse_x, mse_y

