fix physionet
jacobjinkelly committed Jul 10, 2020
1 parent 04680eb commit 159e600
Showing 1 changed file with 24 additions and 25 deletions.
latent_ode.py: 49 changes (24 additions, 25 deletions)
@@ -16,7 +16,7 @@
 from jax.experimental.jet import jet
 from jax.flatten_util import ravel_pytree
 
-import lib
+from lib.optimizers import exponential_decay
 from lib.ode import odeint
 from physionet_data import init_physionet_data

@@ -57,7 +57,6 @@
 rng = jax.random.PRNGKey(seed)
 dirname = parse_args.dirname
 count_nfe = not parse_args.no_count_nfe
-num_blocks = parse_args.num_blocks
 ode_kwargs = {
     "atol": parse_args.atol,
     "rtol": parse_args.rtol
@@ -276,31 +275,31 @@ def init_model(gen_ode_kwargs,
         "b_init": jnp.zeros
     }
 
-    gru_rnn = hk.transform(wrap_module(LatentGRU,
-                                       latent_dim=rec_dim,
-                                       n_units=gru_units,
-                                       **init_kwargs))
+    gru_rnn = hk.without_apply_rng(hk.transform(wrap_module(LatentGRU,
+                                                            latent_dim=rec_dim,
+                                                            n_units=gru_units,
+                                                            **init_kwargs)))
     gru_rnn_params = gru_rnn.init(rng, *initialization_data_["gru_rnn"])
 
     # note: the ODE-RNN version uses double
-    rec_to_gen = hk.transform(wrap_module(lambda: hk.Sequential([
+    rec_to_gen = hk.without_apply_rng(hk.transform(wrap_module(lambda: hk.Sequential([
         lambda x, y: jnp.concatenate((x, y), axis=-1),
         hk.Linear(50, **init_kwargs),
         jnp.tanh,
         hk.Linear(2 * gen_dim, **init_kwargs)
-    ])))
+    ]))))
     rec_to_gen_params = rec_to_gen.init(rng, *initialization_data_["rec_to_gen"])
 
-    gen_dynamics = hk.transform(wrap_module(GenDynamics,
-                                            latent_dim=gen_dim,
-                                            units=dynamics_units,
-                                            layers=gen_layers))
+    gen_dynamics = hk.without_apply_rng(hk.transform(wrap_module(GenDynamics,
+                                                                 latent_dim=gen_dim,
+                                                                 units=dynamics_units,
+                                                                 layers=gen_layers)))
     gen_dynamics_params = gen_dynamics.init(rng, *initialization_data_["gen_dynamics"])
     gen_dynamics_wrap = lambda x, t, params: gen_dynamics.apply(params, x, t)
 
-    gen_to_data = hk.transform(wrap_module(hk.Linear,
-                                           output_size=data_dim,
-                                           **init_kwargs))
+    gen_to_data = hk.without_apply_rng(hk.transform(wrap_module(hk.Linear,
+                                                                output_size=data_dim,
+                                                                **init_kwargs)))
     gen_to_data_params = gen_to_data.init(rng, initialization_data_["gen_to_data"])
 
     init_params = {
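Note on the recurring change in this hunk: wrapping each transformed module in hk.without_apply_rng removes the RNG argument from apply, so deterministic modules are called as module.apply(params, *inputs) rather than module.apply(params, rng, *inputs). A minimal standalone sketch of the difference, using a toy net that is not part of this repo:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def net(x):
        return hk.Linear(4)(x)

    x = jnp.ones((1, 3))
    rng = jax.random.PRNGKey(0)

    # Plain hk.transform: apply takes (params, rng, *args).
    f = hk.transform(net)
    params = f.init(rng, x)
    y1 = f.apply(params, None, x)  # the rng slot must still be filled

    # hk.without_apply_rng: apply takes (params, *args); no rng slot.
    g = hk.without_apply_rng(hk.transform(net))
    params = g.init(rng, x)
    y2 = g.apply(params, x)

This is why the context line gen_dynamics.apply(params, x, t) above passes only params and inputs.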
@@ -310,7 +309,7 @@ def init_model(gen_ode_kwargs,
         "gen_to_data": gen_to_data_params
     }
 
-    def forward(count_nfe_, reg, _method, params, data, data_timesteps, timesteps, mask, num_samples=3):
+    def forward(count_nfe_, params, data, data_timesteps, timesteps, mask, num_samples=3):
         """
         Forward pass of the model.
         y are the latent variables of the recognition model
@@ -343,7 +342,7 @@ def integrate_sample(z0_):
             dynamics = gen_dynamics_wrap
             init_fn = lambda x: x
         else:
-            dynamics = augment_dynamics(gen_dynamics_wrap, reg)
+            dynamics = augment_dynamics(gen_dynamics_wrap)
             init_fn = aug_init
         return jax.vmap(lambda z_, t_: odeint(dynamics, init_fn(z_), t_,
                                               params["gen_dynamics"], **gen_ode_kwargs),
@@ -391,7 +390,7 @@ def scan_fun(prev_state, xi):
     model = {
         "forward": partial(forward, False),
         "params": init_params,
-        "nfe": lambda *args: partial(forward, count_nfe, reg)(*args)[-1]
+        "nfe": lambda *args: partial(forward, count_nfe)(*args)[-1]
     }
 
     return model
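The model dict pre-binds forward's leading flag with functools.partial: "forward" fixes count_nfe_ to False, while "nfe" fixes it to the runtime count_nfe setting and keeps only the last element of the returned tuple. A toy illustration of the pattern, with a stand-in forward that is not the real one:

    from functools import partial

    def forward(count_nfe_, params, data):
        nfe = 42 if count_nfe_ else 0   # stand-in for the solver's step count
        return "preds", 0.123, nfe      # nfe is last, hence the [-1] below

    model = {
        "forward": partial(forward, False),
        "nfe": lambda *args: partial(forward, True)(*args)[-1],
    }
    print(model["nfe"]({"w": 1.0}, None))  # -> 42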
@@ -474,10 +473,10 @@ def run():
     forward = lambda *args: model["forward"](*args)[1:]
     grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))
 
-    lr_schedule = lib.optimizers.exponential_decay(step_size=parse_args.lr,
-                                                   decay_steps=1,
-                                                   decay_rate=0.999,
-                                                   lowest=parse_args.lr / 10)
+    lr_schedule = exponential_decay(step_size=parse_args.lr,
+                                    decay_steps=1,
+                                    decay_rate=0.999,
+                                    lowest=parse_args.lr / 10)
     opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_schedule)
     opt_state = opt_init(model["params"])
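For reference, a plausible reading of the schedule configured above, assuming the project's lib.optimizers.exponential_decay mirrors the stock jax example-libraries schedule with an added lowest floor (the real definition lives in lib/optimizers.py):

    def exponential_decay(step_size, decay_steps, decay_rate, lowest):
        # Decay by decay_rate every decay_steps iterations, clipped at lowest.
        def schedule(i):
            return max(step_size * decay_rate ** (i / decay_steps), lowest)
        return schedule

    lr = 1e-3  # hypothetical stand-in for parse_args.lr
    sched = exponential_decay(lr, decay_steps=1, decay_rate=0.999, lowest=lr / 10)
    # sched(0) == lr; after roughly 2300 steps the lr / 10 floor takes over.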

@@ -574,7 +573,7 @@ def evaluate_loss(opt_state, ds_test, kl_coef):
 
                 print(print_str)
 
-                outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_info.txt" % (dirname, reg, lam, num_blocks), "a")
+                outfile = open("%s/reg_%s_lam_%.12e_info.txt" % (dirname, reg, lam), "a")
                 outfile.write(print_str + "\n")
                 outfile.close()

@@ -594,14 +593,14 @@ def evaluate_loss(opt_state, ds_test, kl_coef):
                 pickle.dump(fargs, outfile)
                 outfile.close()
 
-                outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_iter.txt" % (dirname, reg, lam, num_blocks), "a")
+                outfile = open("%s/reg_%s_lam_%.12e_iter.txt" % (dirname, reg, lam), "a")
                 outfile.write("Iter: {:04d}\n".format(itr))
                 outfile.close()
                 meta = {
                     "info": info,
                     "args": parse_args
                 }
-                outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_meta.pickle" % (dirname, reg, lam, num_blocks), "wb")
+                outfile = open("%s/reg_%s_lam_%.12e_meta.pickle" % (dirname, reg, lam), "wb")
                 pickle.dump(meta, outfile)
                 outfile.close()
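Concretely, dropping the num_blocks field (the variable no longer exists after this commit) shortens every artifact name; with hypothetical values for dirname, reg, and lam, the new format strings expand like this:

    dirname, reg, lam = "out", "r2", 1e-2  # hypothetical values
    print("%s/reg_%s_lam_%.12e_iter.txt" % (dirname, reg, lam))
    # out/reg_r2_lam_1.000000000000e-02_iter.txt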

