[Test] Add more test cases #45

Merged: 16 commits, Jan 29, 2022
update: test_pc_grad_optimizers
kozistr committed Jan 29, 2022
commit 7d707b8bd0b635ae482e8b5c3bc14f433019cf00
tests/test_optimizers.py (30 changes: 26 additions & 4 deletions)
@@ -40,6 +40,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class MultiHeadLogisticRegression(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(2, 2)
+        self.head1 = nn.Linear(2, 1)
+        self.head2 = nn.Linear(2, 1)
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = self.fc1(x)
+        x = F.relu(x)
+        return self.head1(x), self.head2(x)
+
+
 def make_dataset(num_samples: int = 100, dims: int = 2, seed: int = 42) -> Tuple[torch.Tensor, torch.Tensor]:
     rng = np.random.RandomState(seed)
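For reference, a minimal usage sketch of the newly added model (illustrative only, not part of the diff). Both heads share fc1, so gradients from two head-specific losses can conflict on the shared parameters, which is the multi-task situation PCGrad is designed to handle:

import torch

model = MultiHeadLogisticRegression()
x = torch.randn(4, 2)  # batch of 4 samples, 2 features
y1, y2 = model(x)      # one logit per head
assert y1.shape == y2.shape == (4, 1)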

@@ -181,6 +194,9 @@ def test_sam_optimizers(optimizer_config):
         loss_fn(y_data, model(x_data)).backward()
         optimizer.second_step(zero_grad=True)
 
+        if init_loss == np.inf:
+            init_loss = loss
+
     assert init_loss > 2.0 * loss
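For context, the hunk above sits inside SAM's two-step update loop; the added lines capture the loss of the first iteration as init_loss. A minimal sketch of the full loop body: second_step appears in the context lines above, while first_step is assumed from the usual SAM API and not verified against this repository's source:

loss = loss_fn(y_data, model(x_data))
loss.backward()
optimizer.first_step(zero_grad=True)   # ascend to the adversarial weights

loss_fn(y_data, model(x_data)).backward()
optimizer.second_step(zero_grad=True)  # apply the sharpness-aware update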


@@ -190,9 +206,9 @@ def test_pc_grad_optimizers(optimizer_config):
 
     x_data, y_data = make_dataset()
 
-    model: nn.Module = LogisticRegression()
+    model: nn.Module = MultiHeadLogisticRegression()
     loss_fn_1: nn.Module = nn.BCEWithLogitsLoss()
-    loss_fn_2: nn.Module = nn.BCEWithLogitsLoss()
+    loss_fn_2: nn.Module = nn.L1Loss()
 
     optimizer_class, config, iterations = optimizer_config
     optimizer = PCGrad(optimizer_class(model.parameters(), **config))
@@ -201,8 +217,14 @@ def test_pc_grad_optimizers(optimizer_config):
     init_loss: float = np.inf
     for _ in range(iterations):
         optimizer.zero_grad()
-        y_pred = model(x_data)
-        loss1, loss2 = loss_fn_1(y_pred, y_data), loss_fn_2(y_pred, y_data)
+        y_pred_1, y_pred_2 = model(x_data)
+        loss1, loss2 = loss_fn_1(y_pred_1, y_data), loss_fn_2(y_pred_2, y_data)
+
+        loss = (loss1 + loss2) / 2.0
+        if init_loss == np.inf:
+            init_loss = loss
+
         optimizer.pc_backward([loss1, loss2])
         optimizer.step()
 
     assert init_loss > 2.0 * loss
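For context, PCGrad (Yu et al., "Gradient Surgery for Multi-Task Learning", NeurIPS 2020) resolves conflicts between per-task gradients: whenever two gradients point in opposing directions (negative dot product), the conflicting component is removed by projecting one gradient onto the normal plane of the other. A minimal, self-contained sketch of that projection rule (illustrative only; the PCGrad wrapper exercised by the test is the repository's actual implementation):

import torch

def project_conflicting(g_i: torch.Tensor, g_j: torch.Tensor) -> torch.Tensor:
    # If g_i conflicts with g_j (negative dot product), drop the
    # component of g_i that points against g_j.
    dot = torch.dot(g_i, g_j)
    if dot < 0:
        g_i = g_i - (dot / g_j.norm() ** 2) * g_j
    return g_i

g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-1.0, 1.0])
print(project_conflicting(g1, g2))  # tensor([0.5000, 0.5000]), orthogonal to g2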