Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add option for mixed precision dtype to enable running on cpus #31

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions roma/models/model_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
}

def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864):
def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
if isinstance(coarse_res, int):
coarse_res = (coarse_res, coarse_res)
if isinstance(upsample_res, int):
Expand All @@ -26,12 +26,12 @@ def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[in
dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
map_location=device)
model = roma_model(resolution=coarse_res, upsample_preds=True,
weights=weights,dinov2_weights = dinov2_weights,device=device)
weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
model.upsample_res = upsample_res
print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
return model

def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864):
def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
if isinstance(coarse_res, int):
coarse_res = (coarse_res, coarse_res)
if isinstance(upsample_res, int):
Expand All @@ -47,7 +47,7 @@ def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int
dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
map_location=device)
model = roma_model(resolution=coarse_res, upsample_preds=True,
weights=weights,dinov2_weights = dinov2_weights,device=device)
weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
model.upsample_res = upsample_res
print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
return model
6 changes: 4 additions & 2 deletions roma/models/model_zoo/roma_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import warnings
import torch.nn as nn
import torch
from roma.models.matcher import *
from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
from roma.models.encoders import *

def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, **kwargs):
def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs):
# roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
#torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul TODO: these probably ruin stuff, should be careful
#torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
Expand Down Expand Up @@ -146,7 +147,8 @@ def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_w
amp = True),
amp = True,
use_vgg = True,
dinov2_weights = dinov2_weights
dinov2_weights = dinov2_weights,
amp_dtype=amp_dtype,
)
h,w = resolution
symmetric = True
Expand Down