
RuntimeError: The size of tensor a (125) must match the size of tensor b (128) at non-singleton dimension 2 #80

Closed
Hr-Song opened this issue Aug 21, 2020 · 3 comments · Fixed by #82


Hr-Song commented Aug 21, 2020

A RuntimeError occurred when I tried to run inference on my own dataset with the newest fast_scnn. The error has never occurred when using other models from this repository on the same images.
Here is the traceback:

```
Traceback (most recent call last):
  File "image_inference_box.py", line 116, in <module>
    main()
  File "image_inference_box.py", line 42, in main
    result = inference_segmentor(model, img)
  File "/home/lzhpc/mmsegmentation-master/mmseg/apis/inference.py", line 95, in inference_segmentor
    result = model(return_loss=False, rescale=True, **data)
  File "/home/lzhpc/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/lzhpc/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 84, in new_func
    return old_func(*args, **kwargs)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/segmentors/base.py", line 124, in forward
    return self.forward_test(img, img_metas, **kwargs)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/segmentors/base.py", line 106, in forward_test
    return self.simple_test(imgs[0], img_metas[0], **kwargs)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/segmentors/encoder_decoder.py", line 261, in simple_test
    seg_logit = self.inference(img, img_meta, rescale)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/segmentors/encoder_decoder.py", line 246, in inference
    seg_logit = self.whole_inference(img, img_meta, rescale)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/segmentors/encoder_decoder.py", line 213, in whole_inference
    seg_logit = self.encode_decode(img, img_meta)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/segmentors/encoder_decoder.py", line 87, in encode_decode
    x = self.extract_feat(img)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/segmentors/encoder_decoder.py", line 79, in extract_feat
    x = self.backbone(img)
  File "/home/lzhpc/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/backbones/fast_scnn.py", line 381, in forward
    lower_res_features)
  File "/home/lzhpc/anaconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/lzhpc/mmsegmentation-master/mmseg/models/backbones/fast_scnn.py", line 249, in forward
    out = higher_res_feature + lower_res_feature
RuntimeError: The size of tensor a (125) must match the size of tensor b (128) at non-singleton dimension 2
```

johnzja (Contributor) commented Aug 21, 2020

Could you please show us your config file and the size of your input images?
The tensor size at dimension 2 is the image height. If it mismatches, it is likely that the (H, W) of the input image is not a multiple of 32. In our tests we use size=(1024, 2048), and it seems to work.
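If divisibility is the cause, the size arithmetic can be sketched in a few lines. This is a rough sketch assuming the standard Fast-SCNN layout (a higher-resolution branch produced by three stride-2 convolutions, i.e. 1/8 scale, and a global branch that applies two more stride-2 convolutions to reach 1/32 and is upsampled ×4 before the fusion add); `fusion_sizes` is a hypothetical helper for illustration, not an mmsegmentation API:

```python
def ceil_div(x, d):
    """Ceiling division: a stride-2 conv with 'same'-style padding
    maps a spatial size of x to ceil(x / 2)."""
    return -(-x // d)

def fusion_sizes(side):
    """Sizes of one spatial dimension at the Fast-SCNN fusion point."""
    hi = side
    for _ in range(3):   # higher-res branch: three stride-2 convs, /8 overall
        hi = ceil_div(hi, 2)
    lo = hi
    for _ in range(2):   # global branch: two more stride-2 convs, /32 overall
        lo = ceil_div(lo, 2)
    lo *= 4              # upsampled x4 before the fusion add
    return hi, lo

print(fusion_sizes(1024))  # (128, 128): 1024 is a multiple of 32, branches agree
print(fusion_sizes(1000))  # (125, 128): mismatch -> exactly this RuntimeError
```

With a side of 1000, the higher-res branch ends at 125 while the global branch, rounded up at each downsampling, comes back at 128, which matches the 125-vs-128 error above.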

Hr-Song (Author) commented Aug 21, 2020

The size of the input images was (2000, 2048), but in my config file the test image scale was set to (1024, 1024).
Here is my config file:

```python
norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='FastSCNN',
        downsample_dw_channels=(32, 48),
        global_in_channels=64,
        global_block_channels=(64, 96, 128),
        global_block_strides=(2, 2, 1),
        global_out_channels=128,
        higher_in_channels=64,
        lower_in_channels=128,
        fusion_out_channels=128,
        out_indices=(0, 1, 2),
        norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01),
        align_corners=False),
    decode_head=dict(
        type='DepthwiseSeparableFCNHead',
        in_channels=128,
        channels=128,
        concat_input=False,
        num_classes=2,
        in_index=-1,
        norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01),
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=128,
            channels=32,
            num_convs=1,
            num_classes=2,
            in_index=-2,
            norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01),
            concat_input=False,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
        dict(
            type='FCNHead',
            in_channels=64,
            channels=32,
            num_convs=1,
            num_classes=2,
            in_index=-3,
            norm_cfg=dict(type='SyncBN', requires_grad=True, momentum=0.01),
            concat_input=False,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))
    ])
train_cfg = dict()
test_cfg = dict(mode='whole')
dataset_type = 'PascalVOCDataset'
data_root = 'data/VOCdevkit/VOC2012'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=False)
crop_size = (512, 512)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(2048, 2048), ratio_range=(0.5, 0.8)),
    dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=False),
    dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 1024),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=False),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=4,
    train=dict(
        type='PascalVOCDataset',
        data_root='data/VOCdevkit/VOC2012',
        img_dir='JPEGImages',
        ann_dir='SegmentationClassPNG',
        split='ImageSets/Segmentation/trainval.txt',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(
                type='Resize', img_scale=(2048, 2048), ratio_range=(0.5, 0.8)),
            dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(type='PhotoMetricDistortion'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=False),
            dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg'])
        ]),
    val=dict(
        type='PascalVOCDataset',
        data_root='data/VOCdevkit/VOC2012',
        img_dir='JPEGImages',
        ann_dir='SegmentationClassPNG',
        split='ImageSets/Segmentation/val.txt',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1024, 1024),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=False),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='PascalVOCDataset',
        data_root='data/VOCdevkit/VOC2012',
        img_dir='JPEGImages',
        ann_dir='SegmentationClassPNG',
        split='ImageSets/Segmentation/test.txt',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1024, 1024),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=False),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]))
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook')
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(type='SGD', lr=0.12, momentum=0.9, weight_decay=4e-05)
optimizer_config = dict()
lr_config = dict(policy='poly', power=0.9, min_lr=0.0001, by_epoch=False)
total_iters = 20000
checkpoint_config = dict(by_epoch=False, interval=2000)
evaluation = dict(interval=2000, metric='mIoU')
work_dir = './work_dirs/fast_scnn_4x8_80k_lr0.12_yiwu'
gpu_ids = range(0, 1)
```
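Worth noting: with `keep_ratio=True`, `Resize` picks a single scale factor that fits the image inside `img_scale`, so a (2000, 2048) image resized into (1024, 1024) comes out at (1000, 1024), and 1000 is not a multiple of 32. A minimal sketch of that computation, assuming rounding like mmcv's `rescale_size` (the `rescale` helper here is illustrative, not the library function):

```python
def rescale(h, w, img_scale):
    # One scale factor so the long edge fits max(img_scale) and the
    # short edge fits min(img_scale), preserving the aspect ratio.
    scale = min(max(img_scale) / max(h, w), min(img_scale) / min(h, w))
    return int(h * scale + 0.5), int(w * scale + 0.5)

# (2000, 2048) fit into (1024, 1024): scale = 0.5 -> (1000, 1024).
# 1000 is not divisible by 32, so the two Fast-SCNN branches disagree.
print(rescale(2000, 2048, (1024, 1024)))  # (1000, 1024)
```

If this is the failure mode, possible workarounds (assumptions on my part, not necessarily what the eventual fix did) are choosing a test `img_scale` whose rescaled sides are multiples of 32, or padding test images, e.g. with `dict(type='Pad', size_divisor=32)` in the pipeline.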

xvjiarui (Collaborator)

Fixed by #82
