Skip to content

Commit

Permalink
Fix demo/video_gpuaccel_demo.py scripts (#10568)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjkkkjjj committed Jun 30, 2023
1 parent b8e4573 commit c5c8aa0
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions demo/video_gpuaccel_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def prefetch_batch_input_shape(model: nn.Module, ori_wh: Tuple[int,
test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
data = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
data = test_pipeline(data)
_, data_sample = model.data_preprocessor([data], False)
data['inputs'] = [data['inputs']]
data['data_samples'] = [data['data_samples']]
data_sample = model.data_preprocessor(data, False)['data_samples']
batch_input_shape = data_sample[0].batch_input_shape
return batch_input_shape

Expand All @@ -69,8 +71,8 @@ def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
'scale_factor': (batch_input_shape[0] / ori_shape[0],
batch_input_shape[1] / ori_shape[1])
})
frame_resize = torch.from_numpy(frame_resize).permute((2, 0, 1))
data = {'inputs': frame_resize, 'data_sample': data_sample}
frame_resize = torch.from_numpy(frame_resize).permute((2, 0, 1)).cuda()
data = {'inputs': [frame_resize], 'data_samples': [data_sample]}
return data


Expand Down Expand Up @@ -112,7 +114,7 @@ def main():
for i, (frame_resize, frame_origin) in enumerate(
zip(track_iter_progress(video_resize), video_origin)):
data = pack_data(frame_resize, batch_input_shape, ori_shape)
result = model.test_step([data])[0]
result = model.test_step(data)[0]

visualizer.add_datasample(
name='video',
Expand Down

0 comments on commit c5c8aa0

Please sign in to comment.