Skip to content

Commit

Permalink
fix VideoDataset bug when worker > 0
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoyue-zephyrus committed Jun 17, 2019
1 parent 9294377 commit 1b31844
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions mmaction/datasets/video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def label(self):
return int(self._data[1])



class VideoDataset(Dataset):
def __init__(self,
ann_file,
Expand Down Expand Up @@ -173,9 +172,9 @@ def _set_group_flag(self):
# if img_info['width'] / img_info['height'] > 1:
self.flag[i] = 1

def _load_image(self, directory, modality, idx):
def _load_image(self, video_reader, directory, modality, idx):
if modality in ['RGB', 'RGBDiff']:
return [self.video_frames[idx - 1]]
return [video_reader[idx - 1]]
elif modality == 'Flow':
raise NotImplementedError
else:
Expand Down Expand Up @@ -242,7 +241,8 @@ def _get_test_indices(self, record):
self.old_length // self.new_step, dtype=int)
return offsets + 1, skip_offsets

def _get_frames(self, record, image_tmpl, modality, indices, skip_offsets):
def _get_frames(self, record, video_reader, image_tmpl,
modality, indices, skip_offsets):
if self.use_decord:
if modality not in ['RGB', 'RGBDiff']:
raise NotImplementedError
Expand All @@ -253,22 +253,22 @@ def _get_frames(self, record, image_tmpl, modality, indices, skip_offsets):
# TODO: a more elegant way need!
while (attempts < 5):
try:
self.video_reader.seek(p)
video_reader.seek(p)
break
except:
except EOFError:
attempts += 1
p -= 1
for i, ind in enumerate(
range(0, self.old_length, self.new_step)):
if (skip_offsets[i] > 0 and
p + skip_offsets[i] <= record.num_frames):
self.video_reader.skip_frames(skip_offsets[i])
seg_imgs = [self.video_reader.next().asnumpy()]
video_reader.skip_frames(skip_offsets[i])
seg_imgs = [video_reader.next().asnumpy()]
else:
seg_imgs = [self.video_reader.next().asnumpy()]
seg_imgs = [video_reader.next().asnumpy()]
images.extend(seg_imgs)
if p + self.new_step < record.num_frames:
self.video_reader.skip_frames(self.new_step)
video_reader.skip_frames(self.new_step)
return images
else:
images = list()
Expand All @@ -278,10 +278,12 @@ def _get_frames(self, record, image_tmpl, modality, indices, skip_offsets):
range(0, self.old_length, self.new_step)):
if p + skip_offsets[i] <= record.num_frames:
seg_imgs = self._load_image(
video_reader,
osp.join(self.img_prefix, record.path),
modality, p + skip_offsets[i])
else:
seg_imgs = self._load_image(
video_reader,
osp.join(self.img_prefix, record.path),
modality, p)
images.extend(seg_imgs)
Expand All @@ -292,13 +294,13 @@ def _get_frames(self, record, image_tmpl, modality, indices, skip_offsets):
def __getitem__(self, idx):
record = self.video_infos[idx]
if self.use_decord:
self.video_reader = decord.VideoReader('{}.{}'.format(
video_reader = decord.VideoReader('{}.{}'.format(
osp.join(self.img_prefix, record.path), self.video_ext))
record.num_frames = len(self.video_reader)
record.num_frames = len(video_reader)
else:
self.video_frames = mmcv.VideoReader('{}.{}'.format(
video_reader = mmcv.VideoReader('{}.{}'.format(
osp.join(self.img_prefix, record.path), self.video_ext))
record.num_frames = len(self.video_frames)
record.num_frames = len(video_reader)
if self.test_mode:
segment_indices, skip_offsets = self._get_test_indices(record)
else:
Expand All @@ -312,8 +314,8 @@ def __getitem__(self, idx):
# handle the first modality
modality = self.modalities[0]
image_tmpl = self.image_tmpls[0]
img_group = self._get_frames(
record, image_tmpl, modality, segment_indices, skip_offsets)
img_group = self._get_frames(record, video_reader, image_tmpl,
modality, segment_indices, skip_offsets)

flip = True if np.random.rand() < self.flip_ratio else False
if (self.img_scale_dict is not None and
Expand Down

0 comments on commit 1b31844

Please sign in to comment.