Commit ef35a38ad29c@2024-04-07_09-03-46: fix ape-e with fsdp
shenyunhang committed Apr 7, 2024
1 parent 358e86e commit cd8ce26
Showing 14 changed files with 899 additions and 78 deletions.
2 changes: 2 additions & 0 deletions ape/engine/defaults.py
@@ -108,6 +108,8 @@ def create_fsdp_model(model, *, fp16_compression=False, **kwargs):
         # _module_classes_to_ignore=(MultiScaleDeformableAttention,),
     )
 
+    model = model.to(param_dtype)
+
     fsdp = FSDP(
         model,
         auto_wrap_policy=auto_wrap_policy,
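
The added cast aligns the module's parameters with the mixed-precision param_dtype before FSDP wraps it, so unwrapped or ignored submodules end up in the same dtype as the FSDP-managed ones. A minimal sketch of the pattern, assuming an fp16 MixedPrecision policy and an already-initialized process group (the helper name and policy values are illustrative, not the repo's exact configuration):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def create_fsdp_model_sketch(model: nn.Module) -> FSDP:
    # Assumed policy: keep params, gradients, and buffers in fp16.
    param_dtype = torch.float16
    mixed_precision = MixedPrecision(
        param_dtype=param_dtype,
        reduce_dtype=param_dtype,
        buffer_dtype=param_dtype,
    )
    # The fix: cast the module before wrapping so every submodule matches
    # the param_dtype that FSDP will shard and compute in.
    model = model.to(param_dtype)
    return FSDP(model, mixed_precision=mixed_precision)
```
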
8 changes: 4 additions & 4 deletions ape/evaluation/oideval.py
@@ -53,7 +53,7 @@ def compute_average_precision(precision, recall):
 
     if not isinstance(precision, np.ndarray) or not isinstance(recall, np.ndarray):
         raise ValueError("precision and recall must be numpy array")
-    if precision.dtype != np.float or recall.dtype != np.float:
+    if precision.dtype != float or recall.dtype != float:
         raise ValueError("input must be float numpy array.")
     if len(precision) != len(recall):
         raise ValueError("precision and recall must be of the same size.")
@@ -448,8 +448,8 @@ def accumulate(self):
            tps = np.logical_and(dt_m, np.logical_not(dt_ig))
            fps = np.logical_and(np.logical_not(dt_m), np.logical_not(dt_ig))
 
-           tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)
-           fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float)
+           tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
+           fp_sum = np.cumsum(fps, axis=1).astype(dtype=float)
 
            dt_pointers[cat_idx][area_idx] = {
                "tps": tps,
@@ -479,7 +479,7 @@ def accumulate(self):
                    pr[i - 1] = pr[i]
 
                mAP = compute_average_precision(
-                   np.array(pr, np.float).reshape(-1), np.array(rc, np.float).reshape(-1)
+                   np.array(pr, float).reshape(-1), np.array(rc, float).reshape(-1)
                )
                precision[iou_thr_idx, :, cat_idx, area_idx] = mAP
 
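
These edits replace the `np.float` alias, which was deprecated in NumPy 1.20 and removed in NumPy 1.24, with the builtin `float`, which NumPy interprets as `float64`. A quick standalone check of the equivalence:

```python
import numpy as np

tps = np.array([[True, False, True]])
# builtin float maps to float64 when used as a NumPy dtype
tp_sum = np.cumsum(tps, axis=1).astype(dtype=float)
assert tp_sum.dtype == np.float64
```
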
6 changes: 4 additions & 2 deletions ape/layers/fuse_helper.py
@@ -219,8 +219,10 @@ def __init__(
         self.gamma_l = nn.Parameter(init_values * torch.ones((l_dim)), requires_grad=True)
 
     def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
-        v = self.layer_norm_v(v.float())
-        l = self.layer_norm_l(l.float())
+        # v = self.layer_norm_v(v.float())
+        # l = self.layer_norm_l(l.float())
+        v = self.layer_norm_v(v)
+        l = self.layer_norm_l(l)
         delta_v, delta_l = self.attn(
             v, l, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l
         )
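
Dropping the explicit `.float()` upcast keeps the visual and language features in whatever reduced precision the FSDP policy has put the LayerNorm weights in, instead of forcing an fp32 path through the fusion block. A small standalone illustration (the shapes are made up, and the bf16 cast merely stands in for the FSDP-managed param_dtype; this is not the repo's fusion layer):

```python
import torch
import torch.nn as nn

# Stand-in for an FSDP-managed LayerNorm whose weights live in the
# reduced-precision param_dtype (fp16 or bf16 in practice).
layer_norm_v = nn.LayerNorm(256, dtype=torch.bfloat16)

v = torch.randn(2, 100, 256, dtype=torch.bfloat16)  # visual tokens from the low-precision pipeline
out = layer_norm_v(v)  # input dtype matches the parameters; no implicit upcast
assert out.dtype == torch.bfloat16
```
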
2 changes: 1 addition & 1 deletion ape/layers/vision_language_fusion.py
@@ -45,7 +45,7 @@ def __init__(
 
     def forward(self, v, l, attention_mask_v=None, attention_mask_l=None):
         if self.use_checkpoint and self.training:
-            return checkpoint.checkpoint(self.b_attn, v, l, attention_mask_v, attention_mask_l)
+            return checkpoint.checkpoint(self.b_attn, v, l, attention_mask_v, attention_mask_l, use_reentrant=False)
         else:
             return self.b_attn(v, l, attention_mask_v, attention_mask_l)
 
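
Passing `use_reentrant=False` switches to PyTorch's non-reentrant activation checkpointing, which composes better with FSDP and supports keyword arguments to the wrapped callable. A self-contained sketch with a toy block (the module here is illustrative, not the repo's attention layer):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Recompute the block's activations during backward instead of storing them.
y = checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```
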
2 changes: 2 additions & 0 deletions ape/modeling/ape_deta/deformable_detr.py
@@ -76,6 +76,7 @@ def __init__(
         text_feature_batch_repeat: bool = True,
         text_feature_bank: bool = False,
         text_feature_bank_reset: bool = False,
+        text_feature_bank_random_size: bool = False,
         text_feature_reduce_type: str = "last",
         text_feature_reduce_before_fusion: bool = True,
         expression_cumulative_gt_class: bool = True,
@@ -276,6 +277,7 @@ def __init__(
 
         self.text_feature_bank = text_feature_bank
         self.text_feature_bank_reset = text_feature_bank_reset
+        self.text_feature_bank_random_size = text_feature_bank_random_size
         if self.text_feature_bank:
             features_phrase_bank = torch.zeros(
                 (
5 changes: 5 additions & 0 deletions ape/modeling/ape_deta/deformable_detr_segm.py
@@ -324,6 +324,11 @@ def forward(self, batched_inputs, do_postprocess=True):
                    : max(len(text_list), self.criterion[dataset_id].num_classes)
                ]
 
+               if self.text_feature_bank and self.text_feature_bank_random_size:
+                   features_l = features_l[
+                       : random.randint(len(text_list), len(features_l))
+                   ]
+
                if self.text_feature_batch_repeat:
                    features_l = features_l.unsqueeze(0).repeat(len(batched_inputs), 1, 1)
                else:
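
With the new flag enabled, the stacked text-feature bank is truncated to a random length between the number of live prompts and the full bank size, so training sees a varying count of negative class embeddings. A hypothetical standalone version of that slice (bank size and prompt list are made up):

```python
import random

import torch

features_l = torch.randn(1203, 256)         # full text-feature bank (e.g. LVIS-sized)
text_list = ["person", "dog", "surfboard"]  # prompts active for this batch

keep = random.randint(len(text_list), len(features_l))  # inclusive on both endpoints
features_l = features_l[:keep]
assert len(text_list) <= features_l.shape[0] <= 1203
```
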
7 changes: 7 additions & 0 deletions ape/modeling/ape_deta/deformable_detr_segm_vl.py
@@ -1,4 +1,5 @@
 import copy
+import random
 import math
 import os
 import time
@@ -324,6 +325,12 @@ def forward(self, batched_inputs, do_postprocess=True):
                    : max(len(text_list), self.criterion[dataset_id].num_classes)
                ]
 
+               if self.text_feature_bank and self.text_feature_bank_random_size:
+                   text_feature_bank_size = random.randint(len(text_list), len(features_l))
+                   features_l = features_l[
+                       : text_feature_bank_size
+                   ]
+
                if self.text_feature_batch_repeat:
                    features_l = features_l.unsqueeze(0).repeat(len(batched_inputs), 1, 1)
                else:
89 changes: 55 additions & 34 deletions ape/modeling/ape_deta/deformable_transformer.py
@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
 
 from ape.layers import MultiScaleDeformableAttention
 from detrex.layers import (
@@ -59,12 +60,7 @@ def __init__(
         else:
             self.post_norm_layer = None
 
-        if use_act_checkpoint:
-            from fairscale.nn.checkpoint import checkpoint_wrapper
-
-            for i, layer in enumerate(self.layers):
-                layer = checkpoint_wrapper(layer)
-                self.layers[i] = layer
+        self.use_checkpoint = use_act_checkpoint
 
     def forward(
         self,
@@ -80,16 +76,30 @@ def forward(
     ):
 
         for layer in self.layers:
-            query = layer(
-                query,
-                key,
-                value,
-                query_pos=query_pos,
-                attn_masks=attn_masks,
-                query_key_padding_mask=query_key_padding_mask,
-                key_padding_mask=key_padding_mask,
-                **kwargs,
-            )
+            if self.use_checkpoint and self.training:
+                query = checkpoint.checkpoint(
+                    layer,
+                    query,
+                    key,
+                    value,
+                    query_pos=query_pos,
+                    attn_masks=attn_masks,
+                    query_key_padding_mask=query_key_padding_mask,
+                    key_padding_mask=key_padding_mask,
+                    use_reentrant=False,
+                    **kwargs,
+                )
+            else:
+                query = layer(
+                    query,
+                    key,
+                    value,
+                    query_pos=query_pos,
+                    attn_masks=attn_masks,
+                    query_key_padding_mask=query_key_padding_mask,
+                    key_padding_mask=key_padding_mask,
+                    **kwargs,
+                )
 
         if self.post_norm_layer is not None:
             query = self.post_norm_layer(query)
@@ -144,12 +154,7 @@ def __init__(
         self.bbox_embed = None
         self.class_embed = None
 
-        if use_act_checkpoint:
-            from fairscale.nn.checkpoint import checkpoint_wrapper
-
-            for i, layer in enumerate(self.layers):
-                layer = checkpoint_wrapper(layer)
-                self.layers[i] = layer
+        self.use_checkpoint = use_act_checkpoint
 
     def forward(
         self,
@@ -179,18 +184,34 @@ def forward(
                 assert reference_points.shape[-1] == 2
                 reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
 
-            output = layer(
-                output,
-                key,
-                value,
-                query_pos=query_pos,
-                key_pos=key_pos,
-                attn_masks=attn_masks,
-                query_key_padding_mask=query_key_padding_mask,
-                key_padding_mask=key_padding_mask,
-                reference_points=reference_points_input,
-                **kwargs,
-            )
+            if self.use_checkpoint and self.training:
+                output = checkpoint.checkpoint(
+                    layer,
+                    output,
+                    key,
+                    value,
+                    query_pos=query_pos,
+                    key_pos=key_pos,
+                    attn_masks=attn_masks,
+                    query_key_padding_mask=query_key_padding_mask,
+                    key_padding_mask=key_padding_mask,
+                    reference_points=reference_points_input,
+                    use_reentrant=False,
+                    **kwargs,
+                )
+            else:
+                output = layer(
+                    output,
+                    key,
+                    value,
+                    query_pos=query_pos,
+                    key_pos=key_pos,
+                    attn_masks=attn_masks,
+                    query_key_padding_mask=query_key_padding_mask,
+                    key_padding_mask=key_padding_mask,
+                    reference_points=reference_points_input,
+                    **kwargs,
+                )
 
         if self.bbox_embed is not None:
             tmp = self.bbox_embed[layer_idx](output)
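
Instead of rewrapping each layer with fairscale's `checkpoint_wrapper` at construction time, the encoder and decoder now keep plain modules and call `torch.utils.checkpoint.checkpoint` per layer during training, so FSDP auto-wrapping sees ordinary submodules rather than fairscale wrappers. A minimal sketch of the pattern with a toy layer stack (layer sizes and class name are illustrative):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class TinyEncoder(nn.Module):
    def __init__(self, num_layers: int = 2, use_act_checkpoint: bool = True):
        super().__init__()
        # Keep a plain ModuleList; no wrapper modules are inserted here.
        self.layers = nn.ModuleList(nn.Linear(32, 32) for _ in range(num_layers))
        self.use_checkpoint = use_act_checkpoint

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                # Checkpoint only at training time, with the non-reentrant backend.
                query = checkpoint.checkpoint(layer, query, use_reentrant=False)
            else:
                query = layer(query)
        return query


enc = TinyEncoder().train()
out = enc(torch.randn(4, 32, requires_grad=True))
out.sum().backward()
```
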
91 changes: 56 additions & 35 deletions ape/modeling/ape_deta/deformable_transformer_vl.py
@@ -3,6 +3,7 @@
 
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
 
 from ape.layers import MultiScaleDeformableAttention
 from detrex.layers import (
@@ -28,7 +29,7 @@ def __init__(
         post_norm: bool = False,
         num_feature_levels: int = 4,
         vl_layer=None,
-        use_act_checkpoint=False,
+        use_act_checkpoint: bool = False,
         pytorch_attn=False,
     ):
         super(DeformableDetrTransformerEncoderVL, self).__init__(
@@ -63,12 +64,7 @@ def __init__(
 
         self.vl_layers = nn.ModuleList([copy.deepcopy(vl_layer) for _ in range(num_layers)])
 
-        if use_act_checkpoint:
-            from fairscale.nn.checkpoint import checkpoint_wrapper
-
-            for i, layer in enumerate(self.layers):
-                layer = checkpoint_wrapper(layer)
-                self.layers[i] = layer
+        self.use_checkpoint = use_act_checkpoint
 
     def forward(
         self,
@@ -93,16 +89,30 @@ def forward(
                attention_mask_v=query_key_padding_mask,
                attention_mask_l=attention_mask_l,
            )
-           query = layer(
-               query,
-               key,
-               value,
-               query_pos=query_pos,
-               attn_masks=attn_masks,
-               query_key_padding_mask=query_key_padding_mask,
-               key_padding_mask=key_padding_mask,
-               **kwargs,
-           )
+           if self.use_checkpoint and self.training:
+               query = checkpoint.checkpoint(
+                   layer,
+                   query,
+                   key,
+                   value,
+                   query_pos=query_pos,
+                   attn_masks=attn_masks,
+                   query_key_padding_mask=query_key_padding_mask,
+                   key_padding_mask=key_padding_mask,
+                   use_reentrant=False,
+                   **kwargs,
+               )
+           else:
+               query = layer(
+                   query,
+                   key,
+                   value,
+                   query_pos=query_pos,
+                   attn_masks=attn_masks,
+                   query_key_padding_mask=query_key_padding_mask,
+                   key_padding_mask=key_padding_mask,
+                   **kwargs,
+               )
 
        if self.post_norm_layer is not None:
            query = self.post_norm_layer(query)
@@ -160,12 +170,7 @@ def __init__(
         self.bbox_embed = None
         self.class_embed = None
 
-        if use_act_checkpoint:
-            from fairscale.nn.checkpoint import checkpoint_wrapper
-
-            for i, layer in enumerate(self.layers):
-                layer = checkpoint_wrapper(layer)
-                self.layers[i] = layer
+        self.use_checkpoint = use_act_checkpoint
 
         self.look_forward_twice = look_forward_twice
 
@@ -197,18 +202,34 @@ def forward(
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
 
-           output = layer(
-               output,
-               key,
-               value,
-               query_pos=query_pos,
-               key_pos=key_pos,
-               attn_masks=attn_masks,
-               query_key_padding_mask=query_key_padding_mask,
-               key_padding_mask=key_padding_mask,
-               reference_points=reference_points_input,
-               **kwargs,
-           )
+           if self.use_checkpoint and self.training:
+               output = checkpoint.checkpoint(
+                   layer,
+                   output,
+                   key,
+                   value,
+                   query_pos=query_pos,
+                   key_pos=key_pos,
+                   attn_masks=attn_masks,
+                   query_key_padding_mask=query_key_padding_mask,
+                   key_padding_mask=key_padding_mask,
+                   reference_points=reference_points_input,
+                   use_reentrant=False,
+                   **kwargs,
+               )
+           else:
+               output = layer(
+                   output,
+                   key,
+                   value,
+                   query_pos=query_pos,
+                   key_pos=key_pos,
+                   attn_masks=attn_masks,
+                   query_key_padding_mask=query_key_padding_mask,
+                   key_padding_mask=key_padding_mask,
+                   reference_points=reference_points_input,
+                   **kwargs,
+               )
 
        if self.bbox_embed is not None:
            tmp = self.bbox_embed[layer_idx](output)
6 changes: 4 additions & 2 deletions ape/modeling/backbone/vit_eva_clip.py
@@ -38,7 +38,9 @@ def forward(self, x: torch.Tensor):
     from apex.normalization import FusedLayerNorm
 except:
     FusedLayerNorm = LayerNorm
-    print("apex.normalization.FusedLayerNorm not found, will use pytorch implementations")
+    # print("apex.normalization.FusedLayerNorm not found, will use pytorch implementations")
+
+has_sdp_kernel = hasattr(torch.backends.cuda, "sdp_kernel")
 
 
 logger = logging.getLogger(__name__)
@@ -256,7 +258,7 @@ def forward(self, x, rel_pos_bias=None, attn_mask=None):
         q = self.rope(q).type_as(v)
         k = self.rope(k).type_as(v)
 
-        if True:
+        if has_sdp_kernel:
             with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
                 x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.xattn_drop, scale=self.scale)
                 x = x.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
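
Guarding on `hasattr(torch.backends.cuda, "sdp_kernel")` keeps the attention block importable on PyTorch builds that predate the `sdp_kernel` context manager (which was later deprecated in favor of `torch.nn.attention.sdpa_kernel`, so treat this as version-dependent). A hedged sketch of the same feature-detection idea with a manual-attention fallback (tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

has_sdp_kernel = hasattr(torch.backends.cuda, "sdp_kernel")

q = k = v = torch.randn(1, 4, 16, 32)  # (batch, heads, tokens, head_dim)

if has_sdp_kernel:
    # Newer builds: pick SDPA backends explicitly, then run fused attention.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
        x = F.scaled_dot_product_attention(q, k, v)
else:
    # Older builds: plain softmax attention as a fallback.
    attn = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    x = attn.softmax(dim=-1) @ v
```
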