Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix settings and low distance warning #18

Merged
merged 3 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix a false low distance warning due to plotting
  • Loading branch information
tautomer committed Nov 8, 2022
commit 5262b8f5dea31f43ea1cbcb65d360f753cd1b636
26 changes: 13 additions & 13 deletions hippynn/layers/hiplayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def warn_if_under(distance, threshold):
d_count = distance < threshold
d_frac = d_count.to(distance.dtype).mean()
warnings.warn(
f"Provided distances are underneath sensitivity range!\n"
"Provided distances are underneath sensitivity range!\n"
f"Minimum distance in current batch: {dmin}\n"
f"Threshold distance for warning: {threshold}.\n"
f"Fraction of pairs under the threshold: {d_frac}"
Expand Down Expand Up @@ -53,8 +53,8 @@ def __init__(self, n_dist, min_dist_soft, max_dist_soft, hard_max_dist, cutoff_t
init_sigma = min_dist_soft * 2 * n_dist # pulled from theano code
self.sigma.data.fill_(init_sigma)

def forward(self, distflat):
if settings.WARN_LOW_DISTANCES:
def forward(self, distflat, allow_warning=True):
if allow_warning and settings.WARN_LOW_DISTANCES:
with torch.no_grad():
mu, argmin = self.mu.min(dim=1)
sig = self.sigma[:, argmin]
Expand All @@ -65,7 +65,7 @@ def forward(self, distflat):
mu_ds = self.mu
sig_ds = self.sigma

nondim = (distflat_ds ** -1 - mu_ds ** -1) ** 2 / (sig_ds ** -2)
nondim = (distflat_ds**-1 - mu_ds**-1) ** 2 / (sig_ds**-2)
base_sense = torch.exp(-0.5 * nondim)

total_sense = base_sense * self.cutoff(distflat).unsqueeze(1)
Expand All @@ -82,18 +82,18 @@ def __init__(self, n_dist, min_dist_soft, max_dist_soft, hard_max_dist, cutoff_t
init_sigma = min_dist_soft * 2 * n_dist
self.sigma.data.fill_(init_sigma)

def forward(self, distflat):
if settings.WARN_LOW_DISTANCES:
def forward(self, distflat, allow_warning=True):
if allow_warning and settings.WARN_LOW_DISTANCES:
with torch.no_grad():
# Warn if distance is less than the -inside- edge of the shortest sensitivity function
mu, argmin = self.mu.min(dim=1)
sig = self.sigma[:, argmin]
thresh = (mu ** -1 - sig ** -1) ** -1
thresh = (mu**-1 - sig**-1) ** -1

warn_if_under(distflat, thresh)
distflat_ds = distflat.unsqueeze(1)

nondim = (distflat_ds ** -1 - self.mu ** -1) ** 2 / (self.sigma ** -2)
nondim = (distflat_ds**-1 - self.mu**-1) ** 2 / (self.sigma**-2)
base_sense = torch.exp(-0.5 * nondim)

total_sense = base_sense * self.cutoff(distflat).unsqueeze(1)
Expand Down Expand Up @@ -225,7 +225,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs)
env_features_vec = custom_kernels.envsum(sense_vec, in_features, pair_first, pair_second)
env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in)
features_out_vec = torch.mm(env_features_vec, weights_rs)
features_out_vec = features_out_vec.reshape(n_atoms_real,3,self.nf_out)
features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out)
features_out_vec = torch.square(features_out_vec).sum(dim=1) + 1e-30
features_out_vec = torch.sqrt(features_out_vec)
features_out_vec = features_out_vec * self.vecscales.unsqueeze(0)
Expand All @@ -246,8 +246,8 @@ def __init__(self, nf_in, nf_out, n_dist, mind_soft, maxd_soft, hard_cutoff, sen
torch.nn.init.normal_(self.quadscales.data)
# upper indices of flattened 3x3 array minus the (3,3) component
# which is not needed for a traceless tensor
upper_ind = torch.as_tensor([0, 1, 2, 4, 5],dtype=torch.int64)
self.register_buffer('upper_ind',upper_ind,persistent=False) # Static, not part of module state
upper_ind = torch.as_tensor([0, 1, 2, 4, 5], dtype=torch.int64)
self.register_buffer("upper_ind", upper_ind, persistent=False) # Static, not part of module state

def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs):

Expand All @@ -270,7 +270,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs)
env_features_vec = env_features_vec.reshape(n_atoms_real * 3, self.n_dist * self.nf_in)
features_out_vec = torch.mm(env_features_vec, weights_rs)
# Norm and scale
features_out_vec = features_out_vec.reshape(n_atoms_real,3,self.nf_out)
features_out_vec = features_out_vec.reshape(n_atoms_real, 3, self.nf_out)
features_out_vec = torch.square(features_out_vec).sum(dim=1) + 1e-30
features_out_vec = torch.sqrt(features_out_vec)
features_out_vec = features_out_vec * self.vecscales.unsqueeze(0)
Expand All @@ -282,7 +282,7 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs, coord_pairs)
tr = torch.diagonal(rhatsquad, dim1=1, dim2=2).sum(dim=1) / 3.0 # Add divide by 3 early to save flops
tr = tr.unsqueeze(1).unsqueeze(2) * torch.eye(3, dtype=tr.dtype, device=tr.device).unsqueeze(0)
rhatsquad = rhatsquad - tr
rhatsqflat = rhatsquad.reshape(-1, 9)[:,self.upper_ind] # Upper-diagonal part
rhatsqflat = rhatsquad.reshape(-1, 9)[:, self.upper_ind] # Upper-diagonal part
sense_quad = sense_vals.unsqueeze(1) * rhatsqflat.unsqueeze(2)
sense_quad = sense_quad.reshape(-1, self.n_dist * 5)
# Weights
Expand Down
3 changes: 2 additions & 1 deletion hippynn/plotting/plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def plt_fn(self):
with torch.autograd.no_grad():
mu = list(self.sensitivity.parameters())[0]
r_range = torch.linspace(*self.r_params, dtype=mu.dtype, device=mu.device)
sense_out = self.sensitivity(r_range).cpu().data.numpy()
# allow_warning=False to disable the false low distance warning
sense_out = self.sensitivity(r_range, allow_warning=False).cpu().data.numpy()
r_range = r_range.cpu().data.numpy()
for sense_func in sense_out.transpose():
plt.plot(r_range, sense_func, c="r")
Expand Down