Skip to content

Commit

Permalink
Merge pull request #2362 from MengzhangLI/scipy_1.x
Browse files Browse the repository at this point in the history
[Enhance] Make scipy as a default dependency in runtime in dev-1.x
  • Loading branch information
MeowZheng committed Nov 30, 2022
2 parents aefcab3 + 9251100 commit 383826f
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
6 changes: 1 addition & 5 deletions mmseg/models/backbones/beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,14 @@
from mmengine.model.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmengine.runner.checkpoint import _load_checkpoint
from scipy import interpolate
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple

from mmseg.registry import MODELS
from ..utils import PatchEmbed
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer

try:
from scipy import interpolate
except ImportError:
interpolate = None


class BEiTAttention(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ mmcls>=1.0.0rc0
numpy
packaging
prettytable
scipy
7 changes: 5 additions & 2 deletions tests/test_models/test_backbones/test_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,11 @@ def test_beit_init():
}
}
model = BEiT(img_size=(512, 512))
with pytest.raises(AttributeError):
model.resize_rel_pos_embed(ckpt)
# If scipy is installed, this AttributeError would not be raised.
from mmengine.utils import is_installed
if not is_installed('scipy'):
with pytest.raises(AttributeError):
model.resize_rel_pos_embed(ckpt)

# pretrained=None
# init_cfg=123, whose type is unsupported
Expand Down
7 changes: 5 additions & 2 deletions tests/test_models/test_backbones/test_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,11 @@ def test_mae_init():
}
}
model = MAE(img_size=(512, 512))
with pytest.raises(AttributeError):
model.resize_rel_pos_embed(ckpt)
# If scipy is installed, this AttributeError would not be raised.
from mmengine.utils import is_installed
if not is_installed('scipy'):
with pytest.raises(AttributeError):
model.resize_rel_pos_embed(ckpt)

# test resize abs pos embed
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
Expand Down

0 comments on commit 383826f

Please sign in to comment.