Skip to content

Commit

Permalink
modified: __init__.py
Browse files Browse the repository at this point in the history
	modified:   cmhe_torch.py
	modified:   cmhe_utilities.py
  • Loading branch information
chiragnagpal committed Jun 28, 2022
1 parent 304d615 commit cc0f00e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
1 change: 0 additions & 1 deletion auton_survival/models/cmhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def fit(self, x, t, e, a, vsize=0.15, val_data=None,
optimizer: str
The choice of the gradient based optimization method. One of
'Adam', 'RMSProp' or 'SGD'.
"""

processed_data = self._preprocess_training_data(x, t, e, a,
Expand Down
41 changes: 38 additions & 3 deletions auton_survival/models/cmhe/cmhe_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ class DeepCMHETorch(torch.nn.Module):

def _init_dcmhe_layers(self, lastdim):

self.expert = torch.nn.Linear(lastdim, self.k, bias=False)
self.z_gate = torch.nn.Linear(lastdim, self.k, bias=False)
self.phi_gate = torch.nn.Linear(lastdim, self.g, bias=False)

self.expert = IdentifiableLinear(lastdim, self.k, bias=False)
self.z_gate = IdentifiableLinear(lastdim, self.k, bias=False)
self.phi_gate = IdentifiableLinear(lastdim, self.g, bias=False)
# self.expert = torch.nn.Linear(lastdim, self.k, bias=False)
# self.z_gate = torch.nn.Linear(lastdim, self.k, bias=False)
# self.phi_gate = torch.nn.Linear(lastdim, self.g, bias=False)
self.omega = torch.nn.Parameter(torch.rand(self.g)-0.5)

def __init__(self, k, g, inputdim, layers=None, gamma=100,
Expand Down Expand Up @@ -96,3 +100,34 @@ def forward(self, x, a):
logp_joint_hrs[:, i, j] = log_hrs[:, i] + (j!=2)*a*self.omega[j]

return logp_jointlatent_gate, logp_joint_hrs

class IdentifiableLinear(torch.nn.Module):

"""
Softmax and LogSoftmax with K classes in pytorch are over specfied and lead to
issues of mis-identifiability. This class is a wrapper for linear layers that
are correctly specified with K-1 columns. The output of this layer for the Kth
class is all zeros. This allows direct application of pytorch.nn.LogSoftmax
and pytorch.nn.Softmax.
"""

def __init__(self, in_features, out_features, bias=True):

super(IdentifiableLinear, self).__init__()

assert out_features>0; "Output features must be greater than 0"

self.out_features = out_features
self.in_features = in_features
self.linear = torch.nn.Linear(in_features, max(out_features-1, 1), bias=bias)

@property
def weight(self):
return self.linear.weight

def forward(self, x):
if self.out_features == 1:
return self.linear(x).reshape(-1, 1)
else:
zeros = torch.zeros(len(x), 1, device=x.device)
return torch.cat([self.linear(x), zeros], dim=1)
1 change: 0 additions & 1 deletion auton_survival/models/cmhe/cmhe_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,6 @@ def predict_latent_z(model, x):
def predict_latent_phi(model, x):

model, _ = model

x = model.embedding(x)

p_phi_gate = torch.nn.Softmax(dim=1)(model.phi_gate(x)).detach().numpy()
Expand Down

0 comments on commit cc0f00e

Please sign in to comment.