# Bring Your Own PyTorch Lightning Model

The InnerEye toolbox is capable of training any PyTorch Lightning (PL) model inside of AzureML, making use of all the usual InnerEye toolbox features:
- Working with different models in the same codebase, and selecting one by name
- Distributed training in AzureML
- Logging via AzureML's native capabilities
- Training on a local GPU machine or inside of AzureML without code changes
- Supplying commandline overrides for model configuration elements, to quickly queue many jobs

This can be used by:
- Defining a special container class that encapsulates the PyTorch Lightning model to train, and the data that should be used for training and testing.
- Adding essential trainer parameters like the number of epochs to that container.
- Invoking the InnerEye runner and providing the name of the container class, like this: `python InnerEye/ML/runner.py --model=MyContainer`. To train in AzureML, just add a `--azureml` flag.

There is a fully working example [HelloContainer](https://github.com/microsoft/InnerEye-DeepLearning/tree/main/InnerEye/ML/configs/other/HelloContainer.py) that implements a simple 1-dimensional regression model from data stored in a CSV file. You can run it from the command line via `python InnerEye/ML/runner.py --model=HelloContainer`.

## Setup

In order to use these capabilities, you need to implement a class deriving from `LightningContainer`. This class encapsulates everything that is needed for training with PyTorch Lightning:
- The `create_model` method needs to return a subclass of `LightningModule` that has all the usual PyTorch Lightning methods required for training, like the `training_step` and `forward` methods. This object needs to adhere to additional constraints, see below.
- The `get_data_module` method of the container needs to return a `LightningDataModule` that has the data loaders for training and validation data.
- The optional `get_inference_data_module` returns a `LightningDataModule` that is used to read the data for inference (that is, evaluating the trained model). By default, this returns the same data as `get_data_module`, but you can override this for special models like segmentation models that are trained on equal-sized image patches, but evaluated on full images of varying size.

Your class needs to be defined in a Python file in the `InnerEye/ML/configs` folder, otherwise it won't be picked up correctly. If you'd like to have your model defined in a different folder, please specify the Python namespace via the `--model_configs_namespace` argument. For example, use `--model_configs_namespace=My.Own.configs` if your model configuration classes reside in folder `My/Own/configs` from the repository root.

### Cross Validation

If you are doing cross validation, you need to ensure that:
- The `LightningDataModule` returned by your container's `get_data_module` method takes into account the number of cross validation splits, and the cross validation split index, when preparing the data.
- The model logs the metric `val/Loss` in its `validation_step` method.

You can find a working example of handling cross validation in the [HelloContainer](https://github.com/microsoft/InnerEye-DeepLearning/tree/main/InnerEye/ML/configs/other/HelloContainer.py) class. A minimal sketch of the data module side of this pattern is shown below.
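The following is only an illustrative sketch, not InnerEye's actual implementation: it assumes the container passes its cross validation settings (named `number_of_cross_validation_splits` and `cross_validation_split_index` here, as in `HelloContainer`) into the data module, that the dataset is a map-style PyTorch dataset, and that scikit-learn's `KFold` (which requires at least 2 splits) is used to create the folds:

```python
import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset, Subset
from pytorch_lightning import LightningDataModule


class MyCrossValDataModule(LightningDataModule):
    def __init__(self, dataset: Dataset, number_of_cross_validation_splits: int,
                 cross_validation_split_index: int) -> None:
        super().__init__()
        # Create the same deterministic set of folds in every cross validation child run,
        # then pick the fold that belongs to this run's split index.
        splitter = KFold(n_splits=number_of_cross_validation_splits, shuffle=True, random_state=42)
        splits = list(splitter.split(np.arange(len(dataset))))
        train_indices, val_indices = splits[cross_validation_split_index]
        self.train_data = Subset(dataset, train_indices.tolist())
        self.val_data = Subset(dataset, val_indices.tolist())

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        return DataLoader(self.train_data, batch_size=5, num_workers=5)

    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        return DataLoader(self.val_data, batch_size=5, num_workers=5)
```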
*Example*: a complete container, with its Lightning model and data module:
```python
from pathlib import Path
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer

class MyLightningModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = ...

    def training_step(self, *args, **kwargs):
        ...

    def forward(self, *args, **kwargs):
        ...

    def configure_optimizers(self):
        ...

    def test_step(self, *args, **kwargs):
        ...

class MyDataModule(LightningDataModule):
    def __init__(self, root_path: Path):
        super().__init__()
        # All data should be read from the folder given in self.root_path
        self.root_path = root_path

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        train_dataset = ...
        return DataLoader(train_dataset, batch_size=5, num_workers=5)

    def val_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        val_dataset = ...
        return DataLoader(val_dataset, batch_size=5, num_workers=5)

    def test_dataloader(self, *args, **kwargs) -> DataLoader:
        # The data should be read off self.root_path
        test_dataset = ...
        return DataLoader(test_dataset, batch_size=5, num_workers=5)

class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_dataset_id = "folder_name_in_azure_blob_storage"
        self.local_dataset = Path("/some/local/path")
        self.num_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```

Where does the data for training come from?
- When training a model on a local box or VM, the data is read from the `local_dataset` folder that you define in the container.
- When training a model in AzureML, the code searches for a folder called `folder_name_in_azure_blob_storage` in Azure blob storage. That folder is then downloaded or mounted. The local download path is then copied over the `local_dataset` field in the container, and hence you can always read data from `self.local_dataset`.
- Alternatively, you can use the `prepare_data` method of a `LightningDataModule` to download data from the web, for example. In this case, you don't need to define either of the `local_dataset` or `azure_dataset_id` fields. A sketch of this pattern is shown further below.

In the above example, training is done for 42 epochs. After the model is trained, it will be evaluated on the test set, via PyTorch Lightning's [built-in test functionality](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html?highlight=trainer.test#test). See below for an alternative way of running the evaluation on the test set.

### Data loaders

The example above creates `DataLoader` objects from a dataset. When creating those, you need to specify a batch size (how many samples from your dataset will go into one minibatch), and a number of worker processes. Note that, by default, data loading happens in the main process, meaning that your GPU sits idle while the CPU reads data from disk. When you specify a number of workers, the data loader spawns processes that pre-fetch data from disk and put it into a queue, ready for the GPU to pick up as soon as it has finished processing the current minibatch. For more details, please see the documentation for [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). There is also a [tutorial describing the foundations of datasets and data loaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).
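As mentioned under "Where does the data for training come from?", a data module can also fetch its own data in the standard `prepare_data` hook, without relying on `azure_dataset_id` or `local_dataset`. The following is a minimal sketch only; the URL, file name and class name are hypothetical placeholders:

```python
from pathlib import Path
from urllib.request import urlretrieve
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule


class SelfDownloadingDataModule(LightningDataModule):
    def __init__(self, download_folder: Path) -> None:
        super().__init__()
        self.download_folder = download_folder
        self.csv_file = download_folder / "training_data.csv"  # hypothetical file name

    def prepare_data(self, *args, **kwargs) -> None:
        # Lightning calls this once (per node) before training starts: download the raw data if not yet present.
        self.download_folder.mkdir(parents=True, exist_ok=True)
        if not self.csv_file.exists():
            urlretrieve("https://example.com/training_data.csv", str(self.csv_file))  # hypothetical URL

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        train_dataset = ...  # build a Dataset from self.csv_file
        return DataLoader(train_dataset, batch_size=5, num_workers=5)
```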
### Outputting files during training

The Lightning model returned by `create_model` needs to write its output files to the current working directory. When running the InnerEye toolbox outside of AzureML, the toolbox will change the current working directory to a newly created output folder, with a name that contains the timestamp and the model name. When running the InnerEye toolbox in AzureML, the folder structure will be set up such that all files written to the current working directory are uploaded to Azure blob storage at the end of the AzureML job. The files will then also be available via the AzureML UI.

### Trainer arguments

All arguments that control the PyTorch Lightning `Trainer` object are defined in the class `TrainerParams`. A `LightningContainer` object inherits from this class. The most essential one is the `num_epochs` field, which controls the `max_epochs` argument of the `Trainer`.

Usage example:
```python
from pytorch_lightning import LightningModule, LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer

class MyContainer(LightningContainer):
    def __init__(self):
        super().__init__()
        self.num_epochs = 42

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        return MyDataModule(root_path=self.local_dataset)
```

For further details on how the `TrainerParams` are used, refer to the `create_lightning_trainer` method in [InnerEye/ML/model_training.py](https://github.com/microsoft/InnerEye-DeepLearning/tree/main/InnerEye/ML/model_training.py).

### Optimizer and LR scheduler arguments

There are two possible ways of choosing the optimizer and LR scheduler:
- The Lightning model returned by `create_model` can define its own `configure_optimizers` method, with the same signature as `LightningModule.configure_optimizers`. This is the typical way of configuring it for Lightning models.
- Alternatively, the model can inherit from `LightningModuleWithOptimizer`. This class implements a `configure_optimizers` method that uses settings defined in the `OptimizerParams` class. These settings are all available from the command line, and you can, for example, start a new run with a different learning rate by supplying the additional commandline flag `--l_rate=1e-2`.

### Evaluating the trained model

The InnerEye toolbox provides two possible routes for implementing this.

You can use PyTorch Lightning's built-in capabilities, via the `test_step` method. If the model that is returned by `create_model` implements the `test_step` method, the InnerEye toolbox will use the `trainer.test` method (see [docs](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html?highlight=trainer.test#test)). In this case, the best checkpoint from training will be used. The test data is read via the data loader created by the `test_dataloader` of the `LightningDataModule` that is used for training/validation.
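For this route, a `test_step` can be as simple as computing and logging a test metric. A minimal sketch, assuming that batches arriving from `test_dataloader` are (input, target) tuples and that mean squared error is the metric of interest:

```python
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class MyLightningModel(LightningModule):
    # training_step, forward, configure_optimizers etc. as shown earlier

    def test_step(self, batch, batch_idx):
        # Assumes each batch from test_dataloader() is an (input, target) tuple.
        x, y = batch
        loss = F.mse_loss(self.forward(x), y)
        self.log("test_loss", loss)
        return loss
```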
Alternatively, the model can implement the methods defined in `InnerEyeInference`. In this case, the methods will be called in this order:
```python
model.on_inference_start()
for dataset_split in [Train, Val, Test]:
    model.on_inference_epoch_start(dataset_split, is_ensemble_model=False)
    for batch_idx, item in enumerate(dataloader[dataset_split]):
        model_outputs = model.forward(item)
        model.inference_step(item, batch_idx, model_outputs)
    model.on_inference_epoch_end()
model.on_inference_end()
```

## Overriding properties on the commandline

You can define hyperparameters that affect data and/or model, as in the following code snippet:
```python
import param
from pytorch_lightning import LightningModule
from InnerEye.ML.lightning_container import LightningContainer

class DummyContainerWithParameters(LightningContainer):
    num_layers = param.Integer(default=4)

    def create_model(self) -> LightningModule:
        return MyLightningModel(self.num_layers)
    ...
```
All parameters added in this form will be automatically accessible from the commandline; there is no need to define a separate argument parser. When starting training, you can add a flag like `--num_layers=7`.

## Examples

### Setting only the required fields
```python
from pytorch_lightning import LightningModule, LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer

class Container1(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_dataset_id = "some_folder_in_azure"
        self.num_epochs = 20

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        # This should read data from self.local_dataset. Before training, the data folder "some_folder_in_azure"
        # (given by self.azure_dataset_id) will be downloaded or mounted, and its local path set in
        # self.local_dataset
        return MyDataModule(root_path=self.local_dataset)
```

### Adding additional arguments for the PyTorch Lightning trainer
```python
from typing import Dict, Any
from pytorch_lightning import LightningModule, LightningDataModule
from InnerEye.ML.lightning_container import LightningContainer

class Container2(LightningContainer):
    def __init__(self):
        super().__init__()
        self.azure_dataset_id = "some_folder_in_azure"
        self.num_epochs = 20

    def create_model(self) -> LightningModule:
        return MyLightningModel()

    def get_data_module(self) -> LightningDataModule:
        # This should read data from self.local_dataset. Before training, the data folder "some_folder_in_azure"
        # (given by self.azure_dataset_id) will be downloaded or mounted, and its local path set in
        # self.local_dataset
        return MyDataModule(root_path=self.local_dataset)

    def get_trainer_arguments(self) -> Dict[str, Any]:
        # These arguments will be passed through to the Lightning trainer.
        return {"gradient_clip_val": 1, "limit_train_batches": 10}
```
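Both of these containers are started in the same way as the `HelloContainer` example above, for instance via `python InnerEye/ML/runner.py --model=Container2`, with `--azureml` added to submit the job to AzureML. Hyperparameters defined as `param` fields (like `num_layers` above) can be overridden on the command line as described in the previous section.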