Skip to content

Commit

Permalink
Fixed bug in NCC loss
Browse files Browse the repository at this point in the history
  • Loading branch information
cwmok committed Nov 2, 2023
1 parent 00bde08 commit 6a0a637
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
17 changes: 9 additions & 8 deletions Code/C2FViT_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def __init__(self, win=7, eps=1e-5):

def forward(self, I, J):
ndims = 3
win_size = self.w_temp
win_size_1d = self.w_temp

# set window size
if self.win is None:
Expand All @@ -703,13 +703,14 @@ def forward(self, I, J):

# compute filters
# compute local sums via convolution
I_sum = conv_fn(I, weight, padding=int(win_size/2))
J_sum = conv_fn(J, weight, padding=int(win_size/2))
I2_sum = conv_fn(I2, weight, padding=int(win_size/2))
J2_sum = conv_fn(J2, weight, padding=int(win_size/2))
IJ_sum = conv_fn(IJ, weight, padding=int(win_size/2))

# compute cross correltorch. Sin win_size = np.prod(self.win)
I_sum = conv_fn(I, weight, padding=int(win_size_1d/2))
J_sum = conv_fn(J, weight, padding=int(win_size_1d/2))
I2_sum = conv_fn(I2, weight, padding=int(win_size_1d/2))
J2_sum = conv_fn(J2, weight, padding=int(win_size_1d/2))
IJ_sum = conv_fn(IJ, weight, padding=int(win_size_1d/2))

# compute cross correltorch.
win_size = win_size_1d**ndims
u_I = I_sum/win_size
u_J = J_sum/win_size

Expand Down
2 changes: 1 addition & 1 deletion Code/Train_C2FViT_template_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def train():
parser = ArgumentParser()
parser.add_argument("--modelname", type=str,
dest="modelname",
default='C2FViT_affine_COM_template_matching_semi_',
default='C2FViT_affine_COM_template_matching_',
help="Model name")
parser.add_argument("--lr", type=float,
dest="lr", default=1e-4, help="learning rate")
Expand Down

0 comments on commit 6a0a637

Please sign in to comment.