Add CUDA support
tkipf committed Jan 21, 2019
1 parent d8b467f commit 48db9f0
Showing 3 changed files with 19 additions and 9 deletions.
20 changes: 14 additions & 6 deletions modules.py
@@ -6,7 +6,7 @@


 class CompILE(nn.Module):
-    """CompILE model."""
+    """CompILE reference implementation (non-batched, single sample only)."""
     def __init__(self, input_dim, hidden_dim, latent_dim, max_num_segments,
                  temp_b=1., temp_z=1., beta_z=.1, beta_b=.1, prior_rate=3.,
                  latent_dist='gaussian'):
@@ -69,11 +69,15 @@ def get_boundaries(self, encodings, segment_id, evaluate=False):
             ones = torch.ones(encodings.size(0), 1)
             logits_b = None
             sample_b = torch.cat([zeros, ones], dim=1)
+            if encodings.is_cuda:
+                sample_b = sample_b.cuda()
         else:
             hidden = F.relu(self.head_b_1(encodings))
             logits_b = self.head_b_2(hidden).squeeze(-1)
             # Mask out first position with large neg. value.
             neg_inf = torch.ones(encodings.size(0), 1) * utils.NEG_INF
+            if encodings.is_cuda:
+                neg_inf = neg_inf.cuda()
             logits_b = torch.cat([neg_inf, logits_b[:, 1:]], dim=1)
             if not evaluate:
                 sample_b = utils.gumbel_softmax_sample(
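The pattern above is the core of this commit: constant tensors created with torch.ones() live on the CPU by default, so they are moved with .cuda() whenever the encoder output is already on the GPU. A minimal standalone sketch of the same idea (illustrative only; -1e9 stands in for utils.NEG_INF and the shapes are made up):

import torch

encodings = torch.randn(2, 10, 64)  # stand-in for (batch, seq_len, hidden) encodings
neg_inf = torch.ones(encodings.size(0), 1) * -1e9  # created on the CPU by default
if encodings.is_cuda:
    neg_inf = neg_inf.cuda()  # follow the reference tensor to the GPU

# Recent PyTorch versions can also create the constant on the right device directly:
neg_inf = torch.full((encodings.size(0), 1), -1e9, device=encodings.device)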
Expand Down Expand Up @@ -110,7 +114,7 @@ def get_latents(self, encodings, probs_b, evaluate=False):
         return logits_z, sample_z

     def decode(self, sample_z, length):
-        """Decode single time step from latents and copy over seq. length."""
+        """Decode single time step from latents and repeat over full seq."""
         hidden = F.relu(self.decode_1(sample_z))
         pred = self.decode_2(hidden)
         return pred.unsqueeze(1).repeat(1, length, 1)
@@ -135,7 +139,7 @@ def get_segment_probs(self, all_b_samples, all_masks, segment_id):
         return neg_cumsum

     def get_losses(self, inputs):
-        """Get losses (NLL, Kl divergences and ELBO)."""
+        """Get losses (NLL, KL divergences and ELBO)."""
         targets = inputs.view(-1)
         all_encs, all_recs, all_masks, all_b, all_z = self.forward(inputs)

@@ -158,12 +162,14 @@ def get_losses(self, inputs):
                 kl_z += utils.kl_gaussian(mu, log_var).mean(0)
             elif self.latent_dist == 'concrete':
                 kl_z += utils.kl_categorical_uniform(
-                    F.softmax(all_z['logits'][seg_id])).mean(0)
+                    F.softmax(all_z['logits'][seg_id], dim=-1)).mean(0)

-        # KL divergence ob b (first segment only, ignore first and last step).
-        probs_b = F.softmax(all_b['logits'][0])
+        # KL divergence on b (first segment only, ignore first and last step).
+        probs_b = F.softmax(all_b['logits'][0], dim=-1)
         log_prior_b = utils.poisson_categorical_log_prior(
             probs_b.size(1), self.prior_rate)
+        if inputs.is_cuda:
+            log_prior_b = log_prior_b.cuda()
         kl_b = self.max_num_segments * utils.kl_categorical(
             probs_b[:, 1:-1], log_prior_b[:, 1:-1]).mean(0)

@@ -195,6 +201,8 @@ def forward(self, inputs, evaluate=False):

         # Create initial mask.
         mask = torch.ones(inputs.size(0), inputs.size(1), 1)
+        if inputs.is_cuda:
+            mask = mask.cuda()

         all_b = {'logits': [], 'samples': []}
         all_z = {'logits': [], 'samples': []}
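Besides the .cuda() calls, two of the modules.py changes pass an explicit dim to F.softmax. A small illustration of why that matters (not part of the commit): without dim, older PyTorch picks a dimension implicitly and emits a deprecation warning, and for multi-dimensional logits it can normalize over the wrong axis.

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5)          # (batch, num_categories); shapes are illustrative
probs = F.softmax(logits, dim=-1)   # normalize over the last (category) axis
assert torch.allclose(probs.sum(dim=-1), torch.ones(8))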
4 changes: 2 additions & 2 deletions train.py
@@ -23,7 +23,7 @@
 parser.add_argument('--num-segments', type=int, default=3,
                     help='Number of segments in data generation.')

-parser.add_argument('--no-cuda', action='store_true', default=True,
+parser.add_argument('--no-cuda', action='store_true', default=False,
                     help='Disable CUDA training.')
 parser.add_argument('--log-interval', type=int, default=1,
                     help='Logging interval.')
@@ -60,7 +60,7 @@
     model.train()
     loss, nll, kl_z, kl_b = model.get_losses(data)

-    # Run eval.
+    # Run evaluation.
     model.eval()
     acc = model.get_reconstruction_accuracy(data)
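Assuming train.py follows the usual PyTorch pattern, the flag above is folded into a single args.cuda switch that moves the model and data to the GPU before training; those lines are not shown in this diff, so the sketch below is an assumption about the surrounding file rather than a quote from it.

import torch

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if args.cuda:
    model = model.cuda()  # parameters move to the GPU ...
    data = data.cuda()    # ... and inputs follow, so the .is_cuda branches above fire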

4 changes: 3 additions & 1 deletion utils.py
@@ -33,6 +33,8 @@ def gumbel_softmax_sample(logits, temp=1.):
 def gaussian_sample(mu, log_var):
     """Sample from Gaussian distribution."""
     gaussian_noise = torch.randn(mu.size())
+    if mu.is_cuda:
+        gaussian_noise = gaussian_noise.cuda()
     return mu + torch.exp(log_var * 0.5) * gaussian_noise
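gaussian_sample() implements the reparameterization trick, so gradients flow through mu and log_var; the added branch only keeps the noise on the same device as mu. A short usage sketch (not from the repository; shapes are arbitrary and the function is the one defined just above):

import torch

mu = torch.zeros(4, 8, requires_grad=True)
log_var = torch.zeros(4, 8, requires_grad=True)

z = gaussian_sample(mu, log_var)  # z = mu + exp(0.5 * log_var) * noise
z.sum().backward()                # gradients reach mu and log_var through z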


@@ -43,7 +45,7 @@ def kl_gaussian(mu, log_var):

 def kl_categorical_uniform(preds):
     """KL divergence between categorical distribution and uniform prior."""
-    kl_div = preds * torch.log(preds + EPS)  # Up to constant, can be negative.
+    kl_div = preds * torch.log(preds + EPS)  # Constant term omitted.
     return kl_div.sum(1)


