
Commit 582a937

remove unused code to support higher PyTorch versions, and fix the code for using the CPU device
wangyingming committed Nov 2, 2021
1 parent 6274fca commit 582a937
Showing 2 changed files with 2 additions and 19 deletions.
datasets/data_prefetcher.py (4 changes: 2 additions & 2 deletions)
@@ -17,7 +17,7 @@ def __init__(self, loader, device, prefetch=True):
         self.loader = iter(loader)
         self.prefetch = prefetch
         self.device = device
-        if prefetch:
+        if prefetch and self.device=='cuda':
             self.stream = torch.cuda.Stream()
             self.preload()

@@ -50,7 +50,7 @@ def preload(self):
         # else:

     def next(self):
-        if self.prefetch:
+        if self.prefetch and self.device=='cuda':
             torch.cuda.current_stream().wait_stream(self.stream)
             samples = self.next_samples
             targets = self.next_targets
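Why the guard: torch.cuda.Stream() raises on a CPU-only build or machine, so both the stream construction and the stream-synchronized path in next() must be skipped when running on CPU. Below is a minimal sketch of the resulting prefetcher pattern; the class and method names come from the diff, but the preload() body and tensor-valued samples/targets are assumptions (the repository's actual preload and target structures may differ):

import torch

class data_prefetcher:
    # Overlaps host-to-device copies with compute via a side CUDA stream;
    # on CPU it degrades to plain synchronous iteration.
    def __init__(self, loader, device, prefetch=True):
        self.loader = iter(loader)
        self.prefetch = prefetch
        self.device = device
        # torch.cuda.Stream() would raise on a CPU-only build,
        # hence the device check added by this commit.
        if prefetch and self.device == 'cuda':
            self.stream = torch.cuda.Stream()
            self.preload()

    def preload(self):
        try:
            self.next_samples, self.next_targets = next(self.loader)
        except StopIteration:
            self.next_samples, self.next_targets = None, None
            return
        # Issue the copies on the side stream so they overlap with compute.
        with torch.cuda.stream(self.stream):
            self.next_samples = self.next_samples.to(self.device, non_blocking=True)
            self.next_targets = self.next_targets.to(self.device, non_blocking=True)

    def next(self):
        if self.prefetch and self.device == 'cuda':
            # Block the default stream until the async copies have finished.
            torch.cuda.current_stream().wait_stream(self.stream)
            samples, targets = self.next_samples, self.next_targets
            self.preload()
        else:
            # CPU path fixed by this commit: no streams, just fetch and move.
            samples, targets = next(self.loader)
            samples = samples.to(self.device)
            targets = targets.to(self.device)
        return samples, targets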
models/row_column_decoupled_attention.py (17 changes: 0 additions & 17 deletions)
@@ -21,8 +21,6 @@
 from torch.nn import grad  # noqa: F401

 from torch._jit_internal import boolean_dispatch, List, Optional, _overload
-from torch._overrides import has_torch_function, handle_torch_function
-

 Tensor = torch.Tensor

@@ -107,21 +105,6 @@ def multi_head_rcda_forward(query_row, # type: Tensor
         - attn_output_weights: :math:`(N, L, HW)` where N is the batch size,
           L is the target sequence length, HW is the source sequence length.
     """
-    if not torch.jit.is_scripting():
-        tens_ops = (query_row,query_col, key_row, key_col, value, in_proj_weight, in_proj_bias, bias_k_row,bias_k_col, bias_v,
-                    out_proj_weight, out_proj_bias)
-        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-            return handle_torch_function(
-                multi_head_rcda_forward, tens_ops, query_row,query_col, key_row, key_col, value,
-                embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
-                bias_k_row,bias_k_col, bias_v, add_zero_attn, dropout_p, out_proj_weight,
-                out_proj_bias, training=training, key_padding_mask=key_padding_mask,
-                need_weights=need_weights, attn_mask=attn_mask,
-                use_separate_proj_weight=use_separate_proj_weight,
-                q_row_proj_weight=q_row_proj_weight, q_col_proj_weight=q_col_proj_weight,
-                k_row_proj_weight=k_row_proj_weight, k_col_proj_weight=k_col_proj_weight,
-                v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
-

     bsz, tgt_len, embed_dim = query_row.size()
     src_len_row = key_row.size()[2]
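Background on the deletion: from torch._overrides import ... breaks on newer releases because the private torch._overrides module became the public torch.overrides in PyTorch 1.7. The removed block only provided __torch_function__ dispatch for Tensor-like subclasses, which this model never passes in, so deleting it is the simplest way to support higher PyTorch versions. If that dispatch hook were ever needed again, a version-tolerant import along these lines would be an alternative (a sketch, not what this commit does):

try:
    # public location since PyTorch 1.7
    from torch.overrides import has_torch_function, handle_torch_function
except ImportError:
    # private location in PyTorch < 1.7
    from torch._overrides import has_torch_function, handle_torch_function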
