Skip to content

Commit

Permalink
Fix z1/z2 exchange bug: f_net now takes (z0, z1) and h_net takes (z0, z2)
Browse files (browse the repository at this point in the history)
  • Loading branch information
kimmo1019 authored and kimmo1019 committed May 10, 2024
1 parent 564e0e4 commit 791624e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 38 deletions.
66 changes: 29 additions & 37 deletions src/CausalEGM/causalEGM.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def __init__(self, params, timestamp=None, random_seed=None):
self.dv_net = Discriminator(input_dim=params['v_dim'],model_name='dv_net',
nb_units=params['dv_units'])

self.f_net = BaseFullyConnectedNet(input_dim=1+params['z_dims'][0]+params['z_dims'][2],
self.f_net = BaseFullyConnectedNet(input_dim=1+params['z_dims'][0]+params['z_dims'][1],
output_dim = 1, model_name='f_net', nb_units=params['f_units'])
self.h_net = BaseFullyConnectedNet(input_dim=params['z_dims'][0]+params['z_dims'][1],
self.h_net = BaseFullyConnectedNet(input_dim=params['z_dims'][0]+params['z_dims'][2],
output_dim = 1, model_name='h_net', nb_units=params['h_units'])

self.g_e_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.5, beta_2=0.9)
Expand Down Expand Up @@ -98,8 +98,8 @@ def initialize_nets(self, print_summary = False):
self.e_net(np.zeros((1, self.params['v_dim'])))
self.dz_net(np.zeros((1, sum(self.params['z_dims']))))
self.dv_net(np.zeros((1, self.params['v_dim'])))
self.f_net(np.zeros((1, 1+self.params['z_dims'][0]+self.params['z_dims'][2])))
self.h_net(np.zeros((1, self.params['z_dims'][0]+self.params['z_dims'][1])))
self.f_net(np.zeros((1, 1+self.params['z_dims'][0]+self.params['z_dims'][1])))
self.h_net(np.zeros((1, self.params['z_dims'][0]+self.params['z_dims'][2])))
if print_summary:
print(self.g_net.summary())
print(self.h_net.summary())
Expand Down Expand Up @@ -161,8 +161,8 @@ def train_gen_step(self, data_z, data_v, data_x, data_y):
g_loss_adv = -tf.reduce_mean(data_dv_)
e_loss_adv = -tf.reduce_mean(data_dz_)

data_y_ = self.f_net(tf.concat([data_z0, data_z2, data_x], axis=-1))
data_x_ = self.h_net(tf.concat([data_z0, data_z1], axis=-1))
data_y_ = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))
data_x_ = self.h_net(tf.concat([data_z0, data_z2], axis=-1))
if self.params['binary_treatment']:
data_x_ = tf.sigmoid(data_x_)
l2_loss_x = tf.reduce_mean((data_x_ - data_x)**2)
Expand Down Expand Up @@ -316,7 +316,7 @@ def train(self, data=None, data_file=None, sep='\t', header=0, normalize=False,
ckpt_save_path = self.ckpt_manager.save(batch_idx)
#print('Saving checkpoint for iteration {} at {}'.format(batch_idx, ckpt_save_path))
if self.params['save_res'] and batch_idx > 0 and batch_idx % batches_per_save == 0:
self.save('{}/causal_pre_at_{}.{}'.format(self.save_dir, batch_idx, save_format), self.best_causal_pre)
self.save('{}/causal_pre_at_{}.{}'.format(self.save_dir, batch_idx, save_format), causal_pre)
if self.params['save_res']:
self.save('{}/causal_pre_final.{}'.format(self.save_dir,save_format), self.best_causal_pre)

Expand Down Expand Up @@ -345,29 +345,28 @@ def evaluate(self, data, nb_intervals=200):
Float denoting outcome reconstruction loss.
"""
data_x, data_y, data_v = data
data_z = self.z_sampler.get_batch(len(data_x))
data_z_ = self.e_net.predict(data_v,verbose=0)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
data_x_pred = self.h_net.predict(tf.concat([data_z0, data_z1], axis=-1),verbose=0)
data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
data_x_pred = self.h_net.predict(tf.concat([data_z0, data_z2], axis=-1),verbose=0)
if self.params['binary_treatment']:
data_x_pred = tf.sigmoid(data_x_pred)
mse_x = np.mean((data_x-data_x_pred)**2)
mse_y = np.mean((data_y-data_y_pred)**2)
if self.params['binary_treatment']:
#individual treatment effect (ITE) && average treatment effect (ATE)
y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z2, np.ones((len(data_x),1))], axis=-1),verbose=0)
y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z2, np.zeros((len(data_x),1))], axis=-1),verbose=0)
y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z1, np.ones((len(data_x),1))], axis=-1),verbose=0)
y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z1, np.zeros((len(data_x),1))], axis=-1),verbose=0)
ite_pre = y_pred_pos-y_pred_neg
return ite_pre, mse_x, mse_y
else:
#average dose response function (ADRF)
dose_response = []
for x in np.linspace(self.params['x_min'], self.params['x_max'], nb_intervals):
data_x = np.tile(x, (len(data_x), 1))
y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
dose_response.append(np.mean(y_pred))
return np.array(dose_response), mse_x, mse_y

Expand All @@ -389,12 +388,10 @@ def predict(self, data_x, data_v):
assert len(data_x) == len(data_v)
if len(data_x.shape)==1:
data_x = data_x.reshape(-1,1)
data_z = self.z_sampler.get_batch(len(data_x))
data_z_ = self.e_net.predict(data_v,verbose=0)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
return np.squeeze(data_y_pred)

def getADRF(self, x_list, data_v=None):
Expand All @@ -415,16 +412,14 @@ def getADRF(self, x_list, data_v=None):
"""
if data_v is None:
data_v = self.data_sampler.load_all()[-1]
data_z = self.z_sampler.get_batch(len(data_v))
data_z_ = self.e_net.predict(data_v,verbose=0)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
if not self.params['binary_treatment']:
dose_response = []
for x in x_list:
data_x = np.tile(x, (len(data_v), 1))
y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
dose_response.append(np.mean(y_pred))
return np.array(dose_response)
else:
Expand All @@ -446,14 +441,13 @@ def getCATE(self,data_v):
Numpy.ndarray (1-D) denoting the predicted CATE values with shape [nb_sample, ].
"""
assert data_v.shape[1] == self.params['v_dim']
data_z = self.z_sampler.get_batch(len(data_v))
data_z_ = self.e_net.predict(data_v,verbose=0)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
if self.params['binary_treatment']:
y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z2, np.ones((len(data_v),1))], axis=-1),verbose=0)
y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z2, np.zeros((len(data_v),1))], axis=-1),verbose=0)
y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z1, np.ones((len(data_v),1))], axis=-1),verbose=0)
y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z1, np.zeros((len(data_v),1))], axis=-1),verbose=0)
cate_pre = y_pred_pos-y_pred_neg
return np.squeeze(cate_pre)
else:
Expand Down Expand Up @@ -508,9 +502,9 @@ def __init__(self, params, timestamp=None, random_seed=None):
self.e_net = BaseFullyConnectedNet(input_dim=params['v_dim'],output_dim = 2*sum(params['z_dims']),
model_name='e_net', nb_units=params['e_units'])

self.f_net = BaseFullyConnectedNet(input_dim=1+params['z_dims'][0]+params['z_dims'][2],
self.f_net = BaseFullyConnectedNet(input_dim=1+params['z_dims'][0]+params['z_dims'][1],
output_dim = 1, model_name='f_net', nb_units=params['f_units'])
self.h_net = BaseFullyConnectedNet(input_dim=params['z_dims'][0]+params['z_dims'][1],
self.h_net = BaseFullyConnectedNet(input_dim=params['z_dims'][0]+params['z_dims'][2],
output_dim = 1, model_name='h_net', nb_units=params['h_units'])

self.g_e_optimizer = tf.keras.optimizers.Adam(params['lr'], beta_1=0.5, beta_2=0.9)
Expand Down Expand Up @@ -558,8 +552,8 @@ def initialize_nets(self, print_summary = False):

self.g_net(np.zeros((1, sum(self.params['z_dims']))))
self.e_net(np.zeros((1, self.params['v_dim'])))
self.f_net(np.zeros((1, 1+self.params['z_dims'][0]+self.params['z_dims'][2])))
self.h_net(np.zeros((1, self.params['z_dims'][0]+self.params['z_dims'][1])))
self.f_net(np.zeros((1, 1+self.params['z_dims'][0]+self.params['z_dims'][1])))
self.h_net(np.zeros((1, self.params['z_dims'][0]+self.params['z_dims'][2])))
if print_summary:
print(self.g_net.summary())
print(self.h_net.summary())
Expand Down Expand Up @@ -612,10 +606,9 @@ def train_step(self, data_z, data_v, data_x, data_y):
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
data_z3 = data_z_[:-self.params['z_dims'][3]:]

data_y_ = self.f_net(tf.concat([data_z0, data_z2, data_x], axis=-1))
data_x_ = self.h_net(tf.concat([data_z0, data_z1], axis=-1))
data_y_ = self.f_net(tf.concat([data_z0, data_z1, data_x], axis=-1))
data_x_ = self.h_net(tf.concat([data_z0, data_z2], axis=-1))
if self.params['binary_treatment']:
data_x_ = tf.sigmoid(data_x_)
l2_loss_x = tf.reduce_mean((data_x_ - data_x)**2)
Expand All @@ -635,7 +628,7 @@ def train_step(self, data_z, data_v, data_x, data_y):
def sample(self, eps=None):
"""Generate data by decoder."""
if eps is None:
eps = tf.random.normal(shape=(100, sum(params['z_dims'])))
eps = tf.random.normal(shape=(100, sum(self.params['z_dims'])))
return self.g_net(eps)

def encode(self, v):
Expand Down Expand Up @@ -724,7 +717,7 @@ def train(self, data=None, data_file=None, sep='\t', header=0, normalize=False,
ckpt_save_path = self.ckpt_manager.save(batch_idx)
#print('Saving checkpoint for iteration {} at {}'.format(batch_idx, ckpt_save_path))
if self.params['save_res'] and batch_idx > 0 and batch_idx % batches_per_save == 0:
self.save('{}/causal_pre_at_{}.{}'.format(self.save_dir, batch_idx, save_format), self.best_causal_pre)
self.save('{}/causal_pre_at_{}.{}'.format(self.save_dir, batch_idx, save_format), causal_pre)
if self.params['save_res']:
self.save('{}/causal_pre_final.{}'.format(self.save_dir,save_format), self.best_causal_pre)
if self.params['binary_treatment']:
Expand Down Expand Up @@ -752,30 +745,29 @@ def evaluate(self, data, nb_intervals=200):
Float denoting outcome reconstruction loss.
"""
data_x, data_y, data_v = data
data_z = self.z_sampler.get_batch(len(data_x))
mean, logvar = self.encode(data_v)
data_z_ = self.reparameterize(mean, logvar)
data_z0 = data_z_[:,:self.params['z_dims'][0]]
data_z1 = data_z_[:,self.params['z_dims'][0]:sum(self.params['z_dims'][:2])]
data_z2 = data_z_[:,sum(self.params['z_dims'][:2]):sum(self.params['z_dims'][:3])]
data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
data_x_pred = self.h_net.predict(tf.concat([data_z0, data_z1], axis=-1),verbose=0)
data_y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
data_x_pred = self.h_net.predict(tf.concat([data_z0, data_z2], axis=-1),verbose=0)
if self.params['binary_treatment']:
data_x_pred = tf.sigmoid(data_x_pred)
mse_x = np.mean((data_x-data_x_pred)**2)
mse_y = np.mean((data_y-data_y_pred)**2)
if self.params['binary_treatment']:
#individual treatment effect (ITE) && average treatment effect (ATE)
y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z2, np.ones((len(data_x),1))], axis=-1),verbose=0)
y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z2, np.zeros((len(data_x),1))], axis=-1),verbose=0)
y_pred_pos = self.f_net.predict(tf.concat([data_z0, data_z1, np.ones((len(data_x),1))], axis=-1),verbose=0)
y_pred_neg = self.f_net.predict(tf.concat([data_z0, data_z1, np.zeros((len(data_x),1))], axis=-1),verbose=0)
ite_pre = y_pred_pos-y_pred_neg
return ite_pre, mse_x, mse_y
else:
#average dose response function (ADRF)
dose_response = []
for x in np.linspace(self.params['x_min'], self.params['x_max'], nb_intervals):
data_x = np.tile(x, (len(data_x), 1))
y_pred = self.f_net.predict(tf.concat([data_z0, data_z2, data_x], axis=-1),verbose=0)
y_pred = self.f_net.predict(tf.concat([data_z0, data_z1, data_x], axis=-1),verbose=0)
dose_response.append(np.mean(y_pred))
return np.array(dose_response), mse_x, mse_y

Expand Down
2 changes: 1 addition & 1 deletion src/configs/Semi_acic.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
dataset: Semi_acic
output_dir: '.'
v_dim: 177
z_dims: [3,3,6,6]
z_dims: [3,6,3,6]
lr: 0.0002
alpha: 1
beta: 1
Expand Down

0 comments on commit 791624e

Please sign in to comment.