Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Update PL to 1.3.8 #531

Merged
merged 27 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
few fixes
  • Loading branch information
melanibe committed Jul 9, 2021
commit 26dea77e759e5f0ee21e550c54a89045495f698b
2 changes: 1 addition & 1 deletion InnerEye/ML/SSL/lightning_modules/ssl_classifier_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def training_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) ->

def validation_step(self, batch: Any, batch_id: int, *args: Any, **kwargs: Any) -> None: # type: ignore
loss = self.shared_step(batch, is_training=False)
self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=False)
self.log('val/loss', loss, on_step=False, on_epoch=True, sync_dist=True)
for metric in self.val_metrics:
self.log(f"val/{metric.name}", metric, on_epoch=True, on_step=False)

Expand Down
2 changes: 1 addition & 1 deletion InnerEye/ML/lightning_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def log_on_epoch(self,
if isinstance(value, numbers.Number):
value = torch.tensor(value, dtype=torch.float, device=self.device)
prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
sync_dist = True if sync_dist_override is None else sync_dist_override
sync_dist = self.use_sync_dist if sync_dist_override is None else sync_dist_override
self.log(prefix + metric_name, value,
sync_dist=sync_dist,
on_step=False, on_epoch=True,
Expand Down
4 changes: 2 additions & 2 deletions InnerEye/ML/lightning_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def compute_metrics(self, cropped_sample: CroppedSample, segmentation: torch.Ten
self.log_on_epoch(name=MetricType.SUBJECT_COUNT,
value=num_subjects,
is_training=is_training,
reduce_fx=sum,
sync_dist_op=None)
reduce_fx=torch.sum,
sync_dist_op="sum")

def training_or_validation_epoch_end(self, is_training: bool) -> None:
"""
Expand Down