PyTorch Lightning from Beginner to Expert (1)

MNIST Handwritten Digit Classification

  • Import the required libraries

    import torch
    import torch.nn.functional as F
    from torch import nn
    from torch.nn import Module, Conv2d, Linear
    from torch.optim import Adam
    from torch.utils.data import Dataset, DataLoader
    from torchvision import datasets, transforms
    
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import Callback
    
    import os, gc, time
    import math, random
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from tqdm import tqdm
    from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
  • Define the hyperparameters

    epochs = 10
    batch_size = 128
  • Create the datasets and data loaders (a quick batch check follows the code below)

    train_transform = transforms.Compose([
       transforms.ToTensor(),
       transforms.Normalize((0.1307,), (0.3081,))
    ])
    val_transform = transforms.Compose([
       transforms.ToTensor(),
       transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST("data", train=True, download=True, transform=train_transform)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    
    val_dataset = datasets.MNIST("data", train=False, transform=val_transform)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
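
A quick sanity check before training (my addition, not part of the original tutorial): pull one batch from the loader to confirm the tensor shapes and that Normalize((0.1307,), (0.3081,)) roughly standardizes the pixels.

    # fetch a single batch and inspect it (illustrative)
    x, y = next(iter(train_dataloader))
    print(x.shape)  # torch.Size([128, 1, 28, 28]) -- one batch of grayscale images
    print(y.shape)  # torch.Size([128]) -- integer labels 0..9
    print(round(x.mean().item(), 3), round(x.std().item(), 3))  # roughly 0 and 1 after normalization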
  • Define the loss function and metrics (a small verification snippet follows the code below)

    class CrossEntropyLoss(Module): # same behavior as the official PyTorch implementation
        def __init__(self):
            super().__init__()
    
        def forward(self, input, target, reduction="mean"):
            N, C = input.size()
    
            if target.dim() > 1:
                one_hot = target
            else:
                one_hot = torch.zeros((N, C), dtype=input.dtype, device=input.device)
                one_hot.scatter_(1, target.reshape(N, 1), 1)
    
            loss = -(one_hot * F.log_softmax(input, 1)).sum(1)
            if reduction == "mean":
                return loss.mean(0)
            elif reduction == "sum":
                return loss.sum(0)
            else:
                return loss
    
    
    class ClassificationMetric(object): # accumulates predictions and computes metrics
        def __init__(self, accuracy=True, recall=True, precision=True, f1=True, average="macro"):
            self.accuracy = accuracy
            self.recall = recall
            self.precision = precision
            self.f1 = f1
            self.average = average
    
            self.preds = []
            self.target = []
    
        def reset(self): # clear the stored results
            self.preds.clear()
            self.target.clear()
            gc.collect()
    
        def update(self, preds, target): # accumulate one batch of predictions and targets
            preds = list(preds.cpu().detach().argmax(1).numpy())
            target = list(target.cpu().detach().argmax(1).numpy()) if target.dim() > 1 else list(target.cpu().detach().numpy())
            self.preds += preds
            self.target += target
    
        def compute(self): # compute metrics over everything accumulated since the last reset
            # use the union of observed classes so classes that were never predicted still count
            labels = sorted(set(self.target) | set(self.preds))
            metrics = []
            if self.accuracy:
                metrics.append(accuracy_score(self.target, self.preds))
            if self.recall:
                metrics.append(recall_score(self.target, self.preds, labels=labels, average=self.average))
            if self.precision:
                metrics.append(precision_score(self.target, self.preds, labels=labels, average=self.average))
            if self.f1:
                metrics.append(f1_score(self.target, self.preds, labels=labels, average=self.average))
            self.reset()
            return metrics
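
Because the custom loss is supposed to match the official implementation, it is easy to verify against F.cross_entropy on random logits; the same snippet exercises ClassificationMetric end to end. This check is my addition, not part of the original tutorial:

    # verify the custom loss against the built-in one (illustrative)
    logits = torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))
    print(torch.allclose(CrossEntropyLoss()(logits, labels), F.cross_entropy(logits, labels)))  # expected: True
    
    # exercise the metric on the same fake batch
    metric = ClassificationMetric(recall=False, precision=False)
    metric.update(logits, labels)
    acc, f1 = metric.compute()
    print(acc, f1)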
  • Define the model and training logic (a shape check follows the code below)

    class CustomModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # network layers
            self.conv1 = Conv2d(1, 10, 5)
            self.conv2 = Conv2d(10, 20, 3)
            self.fc1 = Linear(20*10*10, 500)
            self.fc2 = Linear(500, 10)
            # loss functions
            self.train_criterion = CrossEntropyLoss()
            self.val_criterion = CrossEntropyLoss()
            # metrics (accuracy and F1 only)
            self.train_metric = ClassificationMetric(recall=False, precision=False)
            self.val_metric = ClassificationMetric(recall=False, precision=False)
            # per-epoch history, consumed by the callbacks below for printing and plotting
            self.history = {
                "loss": [], "acc": [], "f1": [],
                "val_loss": [], "val_acc": [], "val_f1": [],
            }
            
        def forward(self, x):
            in_size = x.size(0)
            out = self.conv1(x)
            out = F.relu(out)
            out = F.max_pool2d(out, 2, 2)
            out = self.conv2(out)
            out = F.relu(out)
            out = out.view(in_size, -1)
            out = self.fc1(out)
            out = F.relu(out)
            out = self.fc2(out)
            out = F.log_softmax(out, dim=1)
            return out
        
        def training_step(self, batch, idx):
            x, y = batch
            _y = self(x)
            # compute the loss
            loss = self.train_criterion(_y, y)
            # accumulate predictions for the epoch-level metrics
            self.train_metric.update(_y, y)
            return loss
    
        def training_epoch_end(self, outs):
            # average the loss over all training batches
            loss = 0.
            for out in outs:
                loss += out["loss"].cpu().detach().item()
            loss /= len(outs)
            # compute the epoch metrics
            acc, f1 = self.train_metric.compute()
            # record everything to the history
            self.history["loss"].append(loss)
            self.history["acc"].append(acc)
            self.history["f1"].append(f1)
    
        def validation_step(self, batch, idx):
            x, y = batch
            _y = self(x)
            val_loss = self.val_criterion(_y, y)
            self.val_metric.update(_y, y)
            return val_loss
    
        def validation_epoch_end(self, outs):
            val_loss = sum(outs).item() / len(outs)
            val_acc, val_f1 = self.val_metric.compute()
    
            self.history["val_loss"].append(val_loss)
            self.history["val_acc"].append(val_acc)
            self.history["val_f1"].append(val_f1)
    
        def configure_optimizers(self):
            # set up the optimizer (Adam with default hyperparameters)
            return Adam(self.parameters())
    model = CustomModel()
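
A dummy forward pass (my addition) confirms the hard-coded Linear(20*10*10, 500) input size: a 28x28 image becomes 24x24 after the 5x5 conv, 12x12 after pooling, and 10x10 after the 3x3 conv, so the flattened features are 20*10*10 = 2000.

    # shape check with a dummy batch (illustrative)
    dummy = torch.randn(2, 1, 28, 28)
    print(model(dummy).shape)  # torch.Size([2, 10]) -- log-probabilities for the 10 digits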
  • Define the progress bar and learning-curve callbacks (I don't like the official progress bar, so I wrote my own)

    class FlexibleTqdm(Callback):
        def __init__(self, steps, column_width=10):
            super(FlexibleTqdm, self).__init__()
            self.steps = steps
            self.column_width = column_width
            self.info = "\rEpoch_%d %s%% [%s]"
    
        def on_train_start(self, trainer, module):
            history = module.history
            self.row = "-" * (self.column_width + 1) * (len(history) + 2) + "-"
            title = "|"
            title += "epoch".center(self.column_width) + "|"
            title += "time".center(self.column_width) + "|"
            for i in history.keys():
                title += i.center(self.column_width) + "|"
            print(self.row)
            print(title)
            print(self.row)
    
        def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx):
            current_index = int((batch_idx + 1) * 100 / self.steps)
            bar = ["."] * 100 # renamed from tqdm to avoid shadowing the imported name
            for i in range(current_index - 1):
                bar[i] = "="
            if current_index:
                bar[current_index - 1] = ">"
            print(self.info % (module.current_epoch, str(current_index).rjust(3), "".join(bar)), end="")
    
        def on_epoch_start(self, trainer, module):
            print(self.info % (module.current_epoch, "  0", "." * 100), end="")
            self.begin = time.perf_counter()
    
        def on_epoch_end(self, trainer, module):
            self.end = time.perf_counter()
            history = module.history
            detail = "\r|"
            detail += str(module.current_epoch).center(self.column_width) + "|"
            detail += ("%d" % (self.end - self.begin)).center(self.column_width) + "|"
            for j in history.values():
                detail += ("%.06f" % j[-1]).center(self.column_width) + "|"
            print("\r" + " " * 120, end="")
            print(detail)
            print(self.row)
            
            
    class LearningCurve(Callback):
        def __init__(self, figsize=(12, 4), names=("loss", "acc", "f1")):
            super(LearningCurve, self).__init__()
            self.figsize = figsize
            self.names = names
    
        def on_fit_end(self, trainer, module):
            history = module.history
            plt.figure(figsize=self.figsize)
            for i, j in enumerate(self.names):
                plt.subplot(1, len(self.names), i + 1)
                plt.title(j + "/val_" + j)
                plt.plot(history[j], "--o", color='r', label=j)
                plt.plot(history["val_" + j], "-*", color='g', label="val_" + j)
                plt.legend()
            plt.show()
  • Set the trainer parameters (I don't like loggers, so everything is turned off; a Lightning 2.x equivalent follows the code below)

    trainer_params = {
        "gpus": 1,
        "max_epochs": epochs,  # 1000
        "checkpoint_callback": False,  # True
        "logger": False,  # TensorBoardLogger
        "progress_bar_refresh_rate": 0,  # 1
        "num_sanity_val_steps": 0,  # 2
        "callbacks": [
            FlexibleTqdm(len(train_dataset) // batch_size, column_width=12), # requires progress_bar_refresh_rate=0 above to disable the built-in progress bar
            LearningCurve(figsize=(12, 4), names=("loss", "acc", "f1")),
        ],  # None
    }
    trainer = pl.Trainer(**trainer_params)
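
Note that these flags belong to the PyTorch Lightning 1.x API: gpus, checkpoint_callback, and progress_bar_refresh_rate were removed in 2.0 (and the *_epoch_end hooks used above were replaced by on_*_epoch_end, so the module and callbacks would need porting too). A rough 2.x equivalent of the trainer setup, as an untested adaptation rather than a drop-in replacement:

    # approximate PyTorch Lightning 2.x equivalent (my adaptation, not from the original tutorial)
    trainer = pl.Trainer(
        accelerator="gpu", devices=1,   # replaces gpus=1
        max_epochs=epochs,
        enable_checkpointing=False,     # replaces checkpoint_callback=False
        logger=False,
        enable_progress_bar=False,      # replaces progress_bar_refresh_rate=0
        num_sanity_val_steps=0,
        callbacks=[
            FlexibleTqdm(len(train_dataset) // batch_size, column_width=12),
            LearningCurve(figsize=(12, 4), names=("loss", "acc", "f1")),
        ],
    )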
  • Start training (a note on saving the trained weights follows)

    trainer.fit(model, train_dataloader, val_dataloader)
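
Since checkpointing is disabled, the trained weights only live in memory once fit() returns. Lightning's standard checkpoint API makes persisting them a one-liner each way (the filename here is my own choice):

    # save and reload the trained weights manually (illustrative)
    trainer.save_checkpoint("mnist-final.ckpt")
    model = CustomModel.load_from_checkpoint("mnist-final.ckpt")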

Training results