Skip to content

Commit

Permalink
log images and compute inception every 10 epochs
Browse files Browse the repository at this point in the history
  • Loading branch information
Frederikravnborg committed Jul 1, 2024
1 parent 648413b commit 03f25bc
Show file tree
Hide file tree
Showing 10 changed files with 298 additions and 22 deletions.
27 changes: 11 additions & 16 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,25 +125,20 @@ def training_step(self, batch, batch_idx):
def validation_step(self, batch, batch_idx):
loss = self.get_loss(batch, batch_idx)
self.log("val_loss", loss)

# Generate and log images
with torch.no_grad():
generated_images = self.denoise_sample(batch.to(self.device), torch.tensor([self.t_range - 1], device=self.device))
self.generated_images.append(generated_images)

return loss

def on_validation_epoch_end(self):
# Generate noise
noise = torch.randn((32, 3, 32, 32), device=self.device) # Adjust dimensions according to your dataset
generated_images = self.denoise_sample(noise, self.t_range)

# Log generated images
self.log_images(generated_images, self.current_epoch)

# Calculate and log Inception Score
inception_score, inception_std = self.calculate_inception_score(generated_images)
wandb.log({'inception_score': inception_score, 'inception_score_std': inception_std})
if self.current_epoch % 10 == 0:
# Generate noise
noise = torch.randn((32, 3, 32, 32), device=self.device)
generated_images = self.denoise_sample(noise, self.t_range)

# Log generated images
self.log_images(generated_images, self.current_epoch)

# Calculate and log Inception Score
inception_score, inception_std = self.calculate_inception_score(generated_images)
wandb.log({'inception_score': inception_score, 'inception_score_std': inception_std})


def calculate_inception_score(self, images, splits=10):
Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@
wandb_name = f'{diffusion_steps}_steps'

if mode == 0: # Local
diffusion_steps = 10
diffusion_steps = 4
dataset_choice = "CIFAR"
max_epoch = 10
max_epoch = 100
batch_size = 128
train_fraction = 2
val_fraction = 2
train_fraction = 1
val_fraction = 1
continue_training = False
ckpt_path = '/Users/fredmac/Documents/DTU-FredMac/pytorch-diffusion/checkpoints/06.30-22.28.05/10_steps-epoch=00-loss=0.00.ckpt'
wandb_name = f'local_{diffusion_steps}_steps'
Expand Down
2 changes: 1 addition & 1 deletion wandb/latest-run
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"_timestamp": 1719834016.8627992, "_runtime": 71.721036195755, "_step": 73, "inception_score": 1.0858374835132234, "inception_score_std": 0.013026577973352192, "train_loss": 0.9565205574035645, "epoch": 9, "trainer/global_step": 19, "val_loss": 0.9546266198158264}
{"_timestamp": 1719834016.8627992, "_runtime": 71.721036195755, "_step": 73, "inception_score": 1.0858374835132234, "inception_score_std": 0.013026577973352192, "train_loss": 0.9565205574035645, "epoch": 9, "trainer/global_step": 19, "val_loss": 0.9546266198158264, "_wandb": {"runtime": 75}}
Binary file modified wandb/run-20240701_133905-o7olzhlr/run-o7olzhlr.wandb
Binary file not shown.
143 changes: 143 additions & 0 deletions wandb/run-20240701_140615-bllvndcc/files/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
wandb_version: 1

_wandb:
desc: null
value:
python_version: 3.9.6
cli_version: 0.17.3
framework: lightning
is_jupyter_run: false
is_kaggle_kernel: false
start_time: 1719835575
t:
1:
- 1
- 9
- 41
- 55
- 103
2:
- 1
- 9
- 41
- 55
- 103
3:
- 7
- 13
- 23
- 66
4: 3.9.6
5: 0.17.3
8:
- 4
- 5
13: darwin-arm64
m:
- 1: trainer/global_step
6:
- 3
- 1: generated/_image_0._type
5: 1
6:
- 1
- 1: generated/_image_0.width
5: 1
6:
- 1
- 1: generated/_image_0.height
5: 1
6:
- 1
- 1: generated/_image_0.format
5: 1
6:
- 1
- 1: generated/_image_0.count
5: 1
6:
- 1
- 1: generated/_image_0.filenames
5: 1
6:
- 1
- 1: generated/_image_0.captions
5: 1
6:
- 1
- 1: generated/_image_16._type
5: 1
6:
- 1
- 1: generated/_image_16.width
5: 1
6:
- 1
- 1: generated/_image_16.height
5: 1
6:
- 1
- 1: generated/_image_16.format
5: 1
6:
- 1
- 1: generated/_image_16.count
5: 1
6:
- 1
- 1: generated/_image_16.filenames
5: 1
6:
- 1
- 1: generated/_image_16.captions
5: 1
6:
- 1
- 1: generated/_image_31._type
5: 1
6:
- 1
- 1: generated/_image_31.width
5: 1
6:
- 1
- 1: generated/_image_31.height
5: 1
6:
- 1
- 1: generated/_image_31.format
5: 1
6:
- 1
- 1: generated/_image_31.count
5: 1
6:
- 1
- 1: generated/_image_31.filenames
5: 1
6:
- 1
- 1: generated/_image_31.captions
5: 1
6:
- 1
- 1: inception_score
5: 1
6:
- 1
- 1: inception_score_std
5: 1
6:
- 1
- 1: train_loss
5: 1
6:
- 1
- 1: epoch
5: 1
6:
- 1
- 1: val_loss
5: 1
6:
- 1
88 changes: 88 additions & 0 deletions wandb/run-20240701_140615-bllvndcc/files/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
GitPython==3.1.43
Jinja2==3.1.4
Markdown==3.6
MarkupSafe==2.1.5
PyYAML==6.0.1
Pygments==2.18.0
Werkzeug==3.0.3
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
appnope==0.1.4
asttokens==2.4.1
async-timeout==4.0.3
attrs==23.2.0
certifi==2024.6.2
charset-normalizer==3.3.2
click==8.1.7
comm==0.2.2
contourpy==1.2.1
cycler==0.12.1
debugpy==1.8.2
decorator==5.1.1
docker-pycreds==0.4.0
exceptiongroup==1.2.1
executing==2.0.1
filelock==3.15.4
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.6.1
gitdb==4.0.11
grpcio==1.64.1
idna==3.7
imageio==2.34.2
importlib_metadata==8.0.0
importlib_resources==6.4.0
ipykernel==6.29.4
ipython==8.18.1
jedi==0.19.1
jupyter_client==8.6.2
jupyter_core==5.7.2
kiwisolver==1.4.5
lightning-utilities==0.11.3.post0
matplotlib-inline==0.1.7
matplotlib==3.9.0
mpmath==1.3.0
multidict==6.0.5
nest-asyncio==1.6.0
networkx==3.2.1
numpy==2.0.0
packaging==24.1
parso==0.8.4
pexpect==4.9.0
pillow==10.3.0
pip==24.1.1
platformdirs==4.2.2
prompt_toolkit==3.0.47
protobuf==4.25.3
psutil==6.0.0
ptyprocess==0.7.0
pure-eval==0.2.2
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytorch-fid==0.3.0
pytorch-lightning==2.3.1
pyzmq==26.0.3
requests==2.32.3
scipy==1.13.1
sentry-sdk==2.7.1
setproctitle==1.3.3
setuptools==58.0.4
six==1.16.0
smmap==5.0.1
stack-data==0.6.3
sympy==1.12.1
tensorboard-data-server==0.7.2
tensorboard==2.17.0
torch==2.3.1
torchmetrics==1.4.0.post0
torchvision==0.18.1
tornado==6.4.1
tqdm==4.66.4
traitlets==5.14.3
typing_extensions==4.12.2
urllib3==2.2.2
wandb==0.17.3
wcwidth==0.2.13
yarl==1.9.4
zipp==3.19.2
49 changes: 49 additions & 0 deletions wandb/run-20240701_140615-bllvndcc/files/wandb-metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"os": "macOS-15.0-arm64-arm-64bit",
"python": "3.9.6",
"heartbeatAt": "2024-07-01T12:06:16.217701",
"startedAt": "2024-07-01T12:06:15.097152",
"docker": null,
"cuda": null,
"args": [],
"state": "running",
"program": "/Users/fredmac/Documents/DTU-FredMac/pytorch-diffusion/train.py",
"codePathLocal": "train.py",
"codePath": "train.py",
"git": {
"remote": "https://github.com/Frederikravnborg/pytorch-diffusion.git",
"commit": "648413b40642a4a5a29b60e735787b17b1962c7a"
},
"email": "[email protected]",
"root": "/Users/fredmac/Documents/DTU-FredMac/pytorch-diffusion",
"host": "mac.students.clients.local",
"username": "fredmac",
"executable": "/Users/fredmac/Documents/DTU-FredMac/pytorch-diffusion/.venv/bin/python",
"cpu_count": 10,
"cpu_count_logical": 10,
"cpu_freq": {
"current": 3504,
"min": 702,
"max": 3504
},
"cpu_freq_per_core": [
{
"current": 3504,
"min": 702,
"max": 3504
}
],
"disk": {
"/": {
"total": 460.4317207336426,
"used": 10.496021270751953
}
},
"gpuapple": {
"type": "Apple M2 Pro",
"vendor": "sppci_vendor_Apple"
},
"memory": {
"total": 16.0
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"_timestamp": 1719835677.853007, "_runtime": 102.74766516685486, "_step": 143, "inception_score": 1.0765033198177902, "inception_score_std": 0.008850658704948819, "train_loss": 0.9674491882324219, "epoch": 57, "trainer/global_step": 57, "val_loss": 0.9580375552177429}
Binary file not shown.

0 comments on commit 03f25bc

Please sign in to comment.