Removing some distill code and adding tiny_imagenet hack
rjha18 committed Jan 3, 2024
1 parent 016bb8a commit 2d2b32d
Showing 2 changed files with 4 additions and 91 deletions.
5 changes: 4 additions & 1 deletion modules/base_utils/datasets.py
@@ -575,7 +575,10 @@ def get_matching_datasets(
                                   seed=seed)

     train_dataset = Subset(train_data, np.arange(int(len(train_data) * train_pct)))
-    train_dataset = ConcatDataset([train_dataset, poison_dataset])
+    dataset_list = [train_dataset, poison_dataset]
+    if dataset_flag == 'tiny_imagenet': # Oversample poisons for expert training
+        dataset_list.extend([poison_dataset] * 9)
+    train_dataset = ConcatDataset(dataset_list)

     if train_pct < 1.0:
         mtt_distill_dataset = Subset(distill_dataset, np.arange(int(len(distill_dataset) * train_pct)))
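For context, a minimal self-contained sketch of the concatenation behaviour the new lines rely on, using stand-in TensorDatasets; the dataset sizes, image shape, and names below are illustrative, not the repository's real loaders:

import numpy as np
import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset

# Stand-ins for the real clean and poisoned training sets.
clean = TensorDataset(torch.randn(100, 3, 64, 64), torch.zeros(100, dtype=torch.long))
poison = TensorDataset(torch.randn(10, 3, 64, 64), torch.ones(10, dtype=torch.long))

train_pct = 1.0
train_subset = Subset(clean, np.arange(int(len(clean) * train_pct)))

# As in the patch: for tiny_imagenet the poison set is appended 9 extra times,
# so each poison sample is drawn 10x as often per epoch as it would be otherwise.
dataset_list = [train_subset, poison]
dataset_list.extend([poison] * 9)
train_dataset = ConcatDataset(dataset_list)

assert len(train_dataset) == len(clean) + 10 * len(poison)  # 100 + 100 samples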
90 changes: 0 additions & 90 deletions modules/base_utils/util.py
@@ -162,13 +162,6 @@ def clf_correct(y_pred: torch.Tensor, y: torch.Tensor):
     return correct


-def distill_correct(y_pred: torch.Tensor, y: torch.Tensor):
-    y_hat = y_pred.argmax(1)
-    y_true = y.argmax(1)
-    correct = (y_hat == y_true).long().cpu().sum()
-    return correct
-
-
 def clf_eval(model: torch.nn.Module, data: Union[DataLoader, Dataset]):
     device = get_module_device(model)
     dataloader, _ = either_dataloader_dataset_to_both(data, eval=True)
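As a reference point, here is a tiny worked example of the check the deleted distill_correct performed: top-1 agreement between student logits and a soft (probability) target, both reduced with argmax. The tensors below are illustrative:

import torch

# The removed helper argmax'd both the student logits and the soft target
# distribution, then counted matches on the CPU.
student_logits = torch.tensor([[2.0, 0.1, 0.3],
                               [0.2, 1.5, 0.1]])
teacher_probs = torch.tensor([[0.8, 0.1, 0.1],
                              [0.2, 0.3, 0.5]])

y_hat = student_logits.argmax(1)   # tensor([0, 1])
y_true = teacher_probs.argmax(1)   # tensor([0, 2])
correct = (y_hat == y_true).long().cpu().sum()
print(correct.item())              # 1 of 2 predictions agree with the teacher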
@@ -270,89 +263,6 @@ def mini_train(
     return model


-def mini_distill_train(
-    *,
-    student_model: torch.nn.Module,
-    teacher_model: torch.nn.Module,
-    distill_data: Union[DataLoader, Dataset],
-    test_data: Union[Union[DataLoader, Dataset],
-                     Iterable[Union[DataLoader, Dataset]]] = None,
-    batch_size=32,
-    opt: optim.Optimizer,
-    scheduler,
-    epochs: int,
-    alpha: float = 0.0,
-    temperature: float = 1.0,
-    i_pct: float = None,
-    record: bool = False
-):
-    device = get_module_device(student_model)
-    dataloader, _ = either_dataloader_dataset_to_both(distill_data,
-                                                      batch_size=batch_size)
-    n = len(dataloader.dataset)
-    total_examples = epochs * n
-
-    if test_data:
-        num_sets = 1
-        if isinstance(test_data, Iterable):
-            num_sets = len(test_data)
-        else:
-            test_data = [test_data]
-        acc_loss = [[] for _ in range(num_sets)]
-
-    with make_pbar(total=total_examples) as pbar:
-        for _ in range(1, epochs + 1):
-            train_epoch_loss, train_epoch_correct = 0, 0
-            student_model.train()
-            teacher_model.eval()
-            for data in dataloader:
-                if i_pct is None:
-                    x, y = data
-                else:
-                    x, y_prime, y = data
-                    y_prime = y_prime.to(device)
-                x, y = x.to(device), y.to(device)
-                minibatch_size = len(x)
-                student_model.zero_grad()
-                student_y_pred = student_model(x)
-                teacher_y_pred = torch.nn.functional.softmax(teacher_model(x), dim=1)
-                if i_pct is not None:
-                    teacher_y_pred = (i_pct * teacher_y_pred) + ((1 - i_pct) * y_prime)
-                loss = clf_loss(student_y_pred, teacher_y_pred.argmax(axis=1))
-                correct = distill_correct(student_y_pred, teacher_y_pred)
-                loss.backward()
-                opt.step()
-                train_epoch_correct += int(correct.item())
-                train_epoch_loss += float(loss.item())
-                pbar.update(minibatch_size)
-
-            lr = get_mean_lr(opt)
-            if scheduler:
-                scheduler.step()
-
-            pbar_postfix = {
-                "acc": "%.2f" % (train_epoch_correct / n * 100),
-                "loss": "%.4g" % (train_epoch_loss / n),
-                "lr": "%.3g" % lr,
-            }
-            if test_data:
-                for i, dataset in enumerate(test_data):
-                    acc, loss = clf_eval(student_model, dataset)
-                    pbar_postfix.update(
-                        {
-                            "acc" + str(i): "%.2f" % (acc * 100),
-                            "loss" + str(i): "%.4g" % loss,
-                        }
-                    )
-                    if record:
-                        acc_loss[i].append((acc, loss))
-
-            pbar.set_postfix(**pbar_postfix)
-    if record:
-        return student_model, *acc_loss
-    return student_model
-
-
 def get_train_info(
     params,
     train_flag,
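For anyone pinned to an older revision, a hypothetical call to the removed mini_distill_train, reconstructed from its signature alone; the toy models, datasets, and optimizer settings below are assumptions, not code from this repository:

import torch
import torch.optim as optim
from torch.utils.data import TensorDataset

# Illustrative stand-ins; any nn.Module pair and map-style datasets would do.
student = torch.nn.Linear(8, 2)
teacher = torch.nn.Linear(8, 2)
distill_set = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
test_set = TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,)))

opt = optim.SGD(student.parameters(), lr=0.1, momentum=0.9)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=5)

# With record=True and a one-element test_data list, the old helper returned
# the trained student plus one list of per-epoch (acc, loss) pairs.
student, test_curve = mini_distill_train(
    student_model=student,
    teacher_model=teacher,
    distill_data=distill_set,
    test_data=[test_set],
    batch_size=16,
    opt=opt,
    scheduler=sched,
    epochs=5,
    i_pct=None,   # no mixing of teacher predictions with provided soft labels
    record=True,
)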
