From 159e600d9fc0cb42e75eb10cfa4dce21eac6b84e Mon Sep 17 00:00:00 2001 From: Jacob Kelly Date: Thu, 9 Jul 2020 23:15:51 -0400 Subject: [PATCH] fix physionet --- latent_ode.py | 49 ++++++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/latent_ode.py b/latent_ode.py index 71fe100..8f94d66 100644 --- a/latent_ode.py +++ b/latent_ode.py @@ -16,7 +16,7 @@ from jax.experimental.jet import jet from jax.flatten_util import ravel_pytree -import lib +from lib.optimizers import exponential_decay from lib.ode import odeint from physionet_data import init_physionet_data @@ -57,7 +57,6 @@ rng = jax.random.PRNGKey(seed) dirname = parse_args.dirname count_nfe = not parse_args.no_count_nfe -num_blocks = parse_args.num_blocks ode_kwargs = { "atol": parse_args.atol, "rtol": parse_args.rtol @@ -276,31 +275,31 @@ def init_model(gen_ode_kwargs, "b_init": jnp.zeros } - gru_rnn = hk.transform(wrap_module(LatentGRU, - latent_dim=rec_dim, - n_units=gru_units, - **init_kwargs)) + gru_rnn = hk.without_apply_rng(hk.transform(wrap_module(LatentGRU, + latent_dim=rec_dim, + n_units=gru_units, + **init_kwargs))) gru_rnn_params = gru_rnn.init(rng, *initialization_data_["gru_rnn"]) # note: the ODE-RNN version uses double - rec_to_gen = hk.transform(wrap_module(lambda: hk.Sequential([ + rec_to_gen = hk.without_apply_rng(hk.transform(wrap_module(lambda: hk.Sequential([ lambda x, y: jnp.concatenate((x, y), axis=-1), hk.Linear(50, **init_kwargs), jnp.tanh, hk.Linear(2 * gen_dim, **init_kwargs) - ]))) + ])))) rec_to_gen_params = rec_to_gen.init(rng, *initialization_data_["rec_to_gen"]) - gen_dynamics = hk.transform(wrap_module(GenDynamics, - latent_dim=gen_dim, - units=dynamics_units, - layers=gen_layers)) + gen_dynamics = hk.without_apply_rng(hk.transform(wrap_module(GenDynamics, + latent_dim=gen_dim, + units=dynamics_units, + layers=gen_layers))) gen_dynamics_params = gen_dynamics.init(rng, *initialization_data_["gen_dynamics"]) gen_dynamics_wrap = lambda x, t, params: gen_dynamics.apply(params, x, t) - gen_to_data = hk.transform(wrap_module(hk.Linear, - output_size=data_dim, - **init_kwargs)) + gen_to_data = hk.without_apply_rng(hk.transform(wrap_module(hk.Linear, + output_size=data_dim, + **init_kwargs))) gen_to_data_params = gen_to_data.init(rng, initialization_data_["gen_to_data"]) init_params = { @@ -310,7 +309,7 @@ def init_model(gen_ode_kwargs, "gen_to_data": gen_to_data_params } - def forward(count_nfe_, reg, _method, params, data, data_timesteps, timesteps, mask, num_samples=3): + def forward(count_nfe_, params, data, data_timesteps, timesteps, mask, num_samples=3): """ Forward pass of the model. y are the latent variables of the recognition model @@ -343,7 +342,7 @@ def integrate_sample(z0_): dynamics = gen_dynamics_wrap init_fn = lambda x: x else: - dynamics = augment_dynamics(gen_dynamics_wrap, reg) + dynamics = augment_dynamics(gen_dynamics_wrap) init_fn = aug_init return jax.vmap(lambda z_, t_: odeint(dynamics, init_fn(z_), t_, params["gen_dynamics"], **gen_ode_kwargs), @@ -391,7 +390,7 @@ def scan_fun(prev_state, xi): model = { "forward": partial(forward, False), "params": init_params, - "nfe": lambda *args: partial(forward, count_nfe, reg)(*args)[-1] + "nfe": lambda *args: partial(forward, count_nfe)(*args)[-1] } return model @@ -474,10 +473,10 @@ def run(): forward = lambda *args: model["forward"](*args)[1:] grad_fn = jax.grad(lambda *args: loss_fn(forward, *args)) - lr_schedule = lib.optimizers.exponential_decay(step_size=parse_args.lr, - decay_steps=1, - decay_rate=0.999, - lowest=parse_args.lr / 10) + lr_schedule = exponential_decay(step_size=parse_args.lr, + decay_steps=1, + decay_rate=0.999, + lowest=parse_args.lr / 10) opt_init, opt_update, get_params = optimizers.adamax(step_size=lr_schedule) opt_state = opt_init(model["params"]) @@ -574,7 +573,7 @@ def evaluate_loss(opt_state, ds_test, kl_coef): print(print_str) - outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_info.txt" % (dirname, reg, lam, num_blocks), "a") + outfile = open("%s/reg_%s_lam_%.12e_info.txt" % (dirname, reg, lam), "a") outfile.write(print_str + "\n") outfile.close() @@ -594,14 +593,14 @@ def evaluate_loss(opt_state, ds_test, kl_coef): pickle.dump(fargs, outfile) outfile.close() - outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_iter.txt" % (dirname, reg, lam, num_blocks), "a") + outfile = open("%s/reg_%s_lam_%.12e_iter.txt" % (dirname, reg, lam), "a") outfile.write("Iter: {:04d}\n".format(itr)) outfile.close() meta = { "info": info, "args": parse_args } - outfile = open("%s/reg_%s_lam_%.12e_num_blocks_%d_meta.pickle" % (dirname, reg, lam, num_blocks), "wb") + outfile = open("%s/reg_%s_lam_%.12e_meta.pickle" % (dirname, reg, lam), "wb") pickle.dump(meta, outfile) outfile.close()