-
-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify video_domain_adapter #292
base: main
Are you sure you want to change the base?
Changes from 1 commit
7ccd345
d955f73
1cecdf2
f9d0577
046ef98
77f1b0f
8a8581b
23b0e8e
f993f8d
60951d4
76f3e72
feaf72a
f5bc2b7
63c5be9
f89d8fc
b845a88
ef74b72
b43802c
ba6f5c5
bdf9cbb
3ea4678
1540051
de0e6cd
cf1638b
a2b3ce8
4470413
37aeaac
a95a185
ab23896
40861fc
dc4b990
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
from pathlib import Path | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
from PIL import Image | ||
|
||
|
@@ -16,14 +17,13 @@ class VideoRecord(object): | |
represents a video sample's metadata. | ||
|
||
Args: | ||
root_datapath: the system path to the root folder | ||
of the videos. | ||
row: A list with four or more elements where 1) The first | ||
element is the path to the video sample's frames excluding | ||
the root_datapath prefix 2) The second element is the starting frame id of the video | ||
3) The third element is the inclusive ending frame id of the video | ||
4) The fourth element is the label index. | ||
5) any following elements are labels in the case of multi-label classification | ||
root_datapath (Path, optional): the system path to the root folder of the videos. | ||
row (tuple, optional): A list with four or more elements where | ||
1) The first element is the path to the video sample's frames excluding the root_datapath prefix. | ||
2) The second element is the starting frame id of the video. | ||
3) The third element is the inclusive ending frame id of the video. | ||
4) The fourth element is the label index. | ||
5) Any following elements are labels in the case of multi-label classification. | ||
""" | ||
|
||
def __init__(self, row, root_datapath): | ||
|
@@ -56,6 +56,35 @@ def label(self): | |
return [int(label_id) for label_id in self._data[3:]] | ||
|
||
|
||
class VideoFeatureRecord(object): | ||
""" | ||
Helper class for class VideoFeatureDataset. This class represents a video feature vector. | ||
|
||
Args: | ||
index (int): the index of the video feature vector. | ||
row (pandas.Series, optional): A series with information of feature vector. | ||
num_segments (int): the number of segments to split the video into. | ||
""" | ||
|
||
def __init__(self, index, row, num_segments): | ||
self._data = row | ||
self._index = index | ||
self._n_seg = num_segments | ||
|
||
@property | ||
def num_frames(self): | ||
return int(self._n_seg) | ||
|
||
@property | ||
def label(self): | ||
if ("verb_class" in self._data) and ("noun_class" in self._data): | ||
return int(self._data.verb_class), int(self._data.noun_class) | ||
elif ("verb_class" in self._data) and ("noun_class" not in self._data): | ||
return [int(self._data.verb_class)] | ||
else: | ||
return 0, 0 | ||
|
||
|
||
class VideoFrameDataset(torch.utils.data.Dataset): | ||
r""" | ||
A highly efficient and adaptable dataset class for videos. | ||
|
@@ -97,44 +126,47 @@ class VideoFrameDataset(torch.utils.data.Dataset): | |
might be ``jumping\0052\`` or ``sample1\`` or ``00053\``. | ||
|
||
Args: | ||
root_path: The root path in which video folders lie. | ||
root_path (str, Path): root path in which video folders lie. | ||
this is ROOT_DATA from the description above. | ||
annotationfile_path: The .txt annotation file containing | ||
annotationfile_path (str, Path): .txt annotation file containing | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
one row per video sample as described above. | ||
image_modality: Image modality (RGB or Optical Flow). | ||
num_segments: The number of segments the video should | ||
be divided into to sample frames from. | ||
frames_per_segment: The number of frames that should | ||
image_modality (str): image modality (RGB or Optical Flow). | ||
num_segments (int): number of segments the video should be divided into to sample frames from. | ||
Default is 1 in image mode and 5 in feature vector mode. | ||
frames_per_segment (int): number of frames that should | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These lines do not seem to exceed 120 characters. Please check for the same problem in the other places. |
||
be loaded per segment. For each segment's | ||
frame-range, a random start index or the | ||
center is chosen, from which frames_per_segment | ||
consecutive frames are loaded. | ||
imagefile_template: The image filename template that video frame files | ||
imagefile_template (str): image filename template that video frame files | ||
have inside of their video folders as described above. | ||
transform: Transform pipeline that receives a list of PIL images/frames. | ||
random_shift: Whether the frames from each segment should be taken | ||
transform (Compose, optional): transform pipeline that receives a list of PIL images/frames. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. torchvision.transforms.Compose There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated. |
||
random_shift (bool): whether the frames from each segment should be taken | ||
consecutively starting from the center of the segment, or | ||
consecutively starting from a random location inside the | ||
segment range. | ||
test_mode: Whether this is a test dataset. If so, chooses | ||
test_mode (bool): whether this is a test dataset. If so, chooses | ||
frames from segments with random_shift=False. | ||
input_type (str): type of input. (options: 'image' or 'feature') | ||
num_data_load (int): number of the data to load. (only used in feature vector mode) | ||
total_segments (int): total number of segments a video is divided into. (only used in feature vector mode) | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
root_path: str, | ||
annotationfile_path: str, | ||
image_modality: str = "rgb", | ||
num_segments: int = 3, | ||
frames_per_segment: int = 1, | ||
imagefile_template: str = "img_{:05d}.jpg", | ||
root_path, | ||
annotationfile_path, | ||
image_modality="rgb", | ||
num_segments=1, | ||
frames_per_segment=1, | ||
imagefile_template="img_{:05d}.jpg", | ||
transform=None, | ||
random_shift: bool = True, | ||
test_mode: bool = False, | ||
input_type: str = "image", | ||
num_data_load: int = None, | ||
total_segments: int = 25, | ||
random_shift=True, | ||
test_mode=False, | ||
input_type="image", | ||
num_data_load=None, | ||
total_segments=25, | ||
): | ||
super(VideoFrameDataset, self).__init__() | ||
|
||
|
@@ -196,7 +228,19 @@ def _read_feature_vector(self): | |
self._data = dict(zip(data_narrations, data_features)) | ||
|
||
def _parse_list(self): | ||
self.video_list = [VideoRecord(x.strip().split(" "), self.root_path) for x in open(self.annotationfile_path)] | ||
if self.input_type == "image": | ||
self.video_list = [ | ||
VideoRecord(x.strip().split(" "), self.root_path) for x in open(self.annotationfile_path) | ||
] | ||
elif self.input_type == "feature": | ||
label_file = pd.read_pickle(self.annotationfile_path).reset_index() | ||
self.video_list = [ | ||
VideoFeatureRecord(i, row[1], self.total_segments) for i, row in enumerate(label_file.iterrows()) | ||
] | ||
# repeat the list if the length is less than num_data_load (especially for target data) | ||
n_repeat = self.num_data_load // len(self.video_list) | ||
n_left = self.num_data_load % len(self.video_list) | ||
self.video_list = self.video_list * n_repeat + self.video_list[:n_left] | ||
|
||
def _get_random_indices(self, record): | ||
""" | ||
|
@@ -311,8 +355,11 @@ def _get(self, record, indices): | |
seg_img = self._load_feature_vector(frame_index, record.segment_id) | ||
images.extend(seg_img) | ||
image_indices.append(frame_index) | ||
if frame_index < record.end_frame: | ||
frame_index += 1 | ||
|
||
if self.input_type == "image": | ||
frame_index = frame_index + 1 if frame_index < record.end_frame else frame_index | ||
else: # feature vector does not have record.end_frame. | ||
frame_index = frame_index + 1 if frame_index < record.num_frames else frame_index | ||
|
||
if self.input_type == "image" and self.transform is not None: | ||
images = self.transform(images) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is `Path` a Python variable type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to pathlib.Path.