Skip to content

Commit

Permalink
pep8 cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ieee8023 committed Jan 22, 2023
1 parent ab38acb commit 11c53a0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions torchxrayvision/autoencoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"class": "ResNetAE101"
}


class Bottleneck(nn.Module):
expansion = 4

Expand Down Expand Up @@ -171,13 +172,12 @@ def _make_up_block(self, block, init_channels, num_layer, stride=1):
return nn.Sequential(*layers)

def encode(self, x, check_resolution=True):

if check_resolution and hasattr(self, 'weights_metadata'):
resolution = self.weights_metadata['resolution']
if (x.shape[2] != resolution) | (x.shape[3] != resolution):
raise ValueError("Input size ({}x{}) is not the native resolution ({}x{}) for this model. Set check_resolution=False on the encode function to override this error.".format(x.shape[2], x.shape[3], resolution, resolution))



x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
Expand Down
12 changes: 6 additions & 6 deletions torchxrayvision/baseline_models/chestx_det/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class PSPNet(nn.Module):
def __init__(self):

super(PSPNet, self).__init__()

self.transform = torchvision.transforms.Compose([
torchvision.transforms.Normalize(
[0.485, 0.456, 0.406],
Expand All @@ -50,9 +50,9 @@ def __init__(self):
])

self._targets = ['Left Clavicle', 'Right Clavicle', 'Left Scapula', 'Right Scapula',
'Left Lung', 'Right Lung', 'Left Hilus Pulmonis', 'Right Hilus Pulmonis',
'Heart', 'Aorta', 'Facies Diaphragmatica', 'Mediastinum', 'Weasand', 'Spine']
'Left Lung', 'Right Lung', 'Left Hilus Pulmonis', 'Right Hilus Pulmonis',
'Heart', 'Aorta', 'Facies Diaphragmatica', 'Mediastinum', 'Weasand', 'Spine']

model = pspnet(len(self.targets))

url = "https://github.com/mlmed/torchxrayvision/releases/download/v1/pspnet_chestxray_best_model_4.pth"
Expand All @@ -74,7 +74,7 @@ def __init__(self):
except Exception as e:
print("Loading failure. Check weights file:", self.weights_filename_local)
raise (e)

model.eval()
self.model = model
self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False)
Expand All @@ -90,7 +90,7 @@ def forward(self, x):

# expecting values between [-1024,1024]
x = (x + 1024) / 2048

# now between [0,1] for this model preprocessing
x = self.transform(x)
y = self.model(x)
Expand Down

0 comments on commit 11c53a0

Please sign in to comment.