Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory optimization for online feature extraction of long recordings #3038

Merged
merged 7 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/feat/feature-window.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct FrameExtractionOptions {
bool snip_edges;
bool allow_downsample;
bool allow_upsample;
int max_feature_vectors;
FrameExtractionOptions():
samp_freq(16000),
frame_shift_ms(10.0),
Expand All @@ -61,6 +62,7 @@ struct FrameExtractionOptions {
blackman_coeff(0.42),
snip_edges(true),
allow_downsample(false),
max_feature_vectors(-1),
allow_upsample(false) { }

void Register(OptionsItf *opts) {
Expand Down Expand Up @@ -92,6 +94,10 @@ struct FrameExtractionOptions {
opts->Register("allow-downsample", &allow_downsample,
"If true, allow the input waveform to have a higher frequency than "
"the specified --sample-frequency (and we'll downsample).");
opts->Register("max-feature-vectors", &max_feature_vectors,
"Memory optimization. If larger than 0, periodically remove feature "
"vectors so that only this number of the latest feature vectors is "
"retained.");
opts->Register("allow-upsample", &allow_upsample,
"If true, allow the input waveform to have a lower frequency than "
"the specified --sample-frequency (and we'll upsample).");
Expand Down
40 changes: 40 additions & 0 deletions src/feat/online-feature-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,45 @@ void TestOnlineAppendFeature() {
}
}

void TestRecyclingVector() {
RecyclingVector full_vec;
RecyclingVector shrinking_vec(10);
for (int i = 0; i != 100; ++i) {
Vector <BaseFloat> data(1);
data.Set(i);
full_vec.PushBack(new Vector<BaseFloat>(data));
shrinking_vec.PushBack(new Vector<BaseFloat>(data));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add KALDI_ASSERTs for Size() at this point. The size is computed non-trivially in the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

KALDI_ASSERT(full_vec.Size() == 100);
KALDI_ASSERT(shrinking_vec.Size() == 100);

// full_vec should contain everything
for (int i = 0; i != 100; ++i) {
Vector <BaseFloat> *data = full_vec.At(i);
KALDI_ASSERT(data != nullptr);
KALDI_ASSERT((*data)(0) == static_cast<BaseFloat>(i));
}

// shrinking_vec may throw an exception for the first 90 elements
int caught_exceptions = 0;
for (int i = 0; i != 90; ++i) {
try {
shrinking_vec.At(i);
} catch (const std::runtime_error &) {
++caught_exceptions;
}
}
// it may actually store a bit more elements for performance efficiency considerations
KALDI_ASSERT(caught_exceptions >= 80);

// shrinking_vec should contain the last 10 elements
for (int i = 90; i != 100; ++i) {
Vector <BaseFloat> *data = shrinking_vec.At(i);
KALDI_ASSERT(data != nullptr);
KALDI_ASSERT((*data)(0) == static_cast<BaseFloat>(i));
}
}

} // end namespace kaldi

int main() {
Expand All @@ -387,6 +426,7 @@ int main() {
TestOnlinePlp();
TestOnlineTransform();
TestOnlineAppendFeature();
TestRecyclingVector();
}
std::cout << "Test OK.\n";
}
47 changes: 41 additions & 6 deletions src/feat/online-feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,54 @@

namespace kaldi {

RecyclingVector::RecyclingVector(int items_to_hold) :
items_to_hold_(items_to_hold == 0 ? -1 : items_to_hold),
first_available_index_(0) {
}

RecyclingVector::~RecyclingVector() {
for (auto *item : items_) {
delete item;
}
}

Vector<BaseFloat> *RecyclingVector::At(int index) const {
if (index < first_available_index_) {
KALDI_ERR << "Attempted to retrieve feature vector that was "
"already removed by the RecyclingVector (index = " << index << "; "
<< "first_available_index = " << first_available_index_ << "; "
<< "size = " << Size() << ")";
}
// 'at' does size checking.
return items_.at(index - first_available_index_);
}

void RecyclingVector::PushBack(Vector<BaseFloat> *item) {
if (items_.size() == items_to_hold_) {
delete items_.front();
items_.pop_front();
++first_available_index_;
}
items_.push_back(item);
}

int RecyclingVector::Size() const {
return first_available_index_ + items_.size();
}


template<class C>
void OnlineGenericBaseFeature<C>::GetFrame(int32 frame,
VectorBase<BaseFloat> *feat) {
// 'at' does size checking.
feat->CopyFromVec(*(features_.at(frame)));
feat->CopyFromVec(*(features_.At(frame)));
};

template<class C>
OnlineGenericBaseFeature<C>::OnlineGenericBaseFeature(
const typename C::Options &opts):
computer_(opts), window_function_(computer_.GetFrameOptions()),
input_finished_(false), waveform_offset_(0) { }
input_finished_(false), waveform_offset_(0),
features_(opts.frame_opts.max_feature_vectors) { }

template<class C>
void OnlineGenericBaseFeature<C>::AcceptWaveform(BaseFloat sampling_rate,
Expand Down Expand Up @@ -63,11 +99,10 @@ template<class C>
void OnlineGenericBaseFeature<C>::ComputeFeatures() {
const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions();
int64 num_samples_total = waveform_offset_ + waveform_remainder_.Dim();
int32 num_frames_old = features_.size(),
int32 num_frames_old = features_.Size(),
num_frames_new = NumFrames(num_samples_total, frame_opts,
input_finished_);
KALDI_ASSERT(num_frames_new >= num_frames_old);
features_.resize(num_frames_new, NULL);

Vector<BaseFloat> window;
bool need_raw_log_energy = computer_.NeedRawLogEnergy();
Expand All @@ -81,7 +116,7 @@ void OnlineGenericBaseFeature<C>::ComputeFeatures() {
// note: this online feature-extraction code does not support VTLN.
BaseFloat vtln_warp = 1.0;
computer_.Compute(raw_log_energy, vtln_warp, &window, this_feature);
features_[frame] = this_feature;
features_.PushBack(this_feature);
}
// OK, we will now discard any portion of the signal that will not be
// necessary to compute frames in the future.
Expand Down
38 changes: 32 additions & 6 deletions src/feat/online-feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,36 @@ namespace kaldi {
/// @{


/// This class serves as a storage for feature vectors with an option to limit
/// the memory usage by removing old elements. The deleted frames indices are
/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
/// provides the indices as if no deletion was being performed.
/// This is useful when processing very long recordings which would otherwise
/// cause the memory to eventually blow up when the features are not being removed.
class RecyclingVector {
public:
/// By default it does not remove any elements.
RecyclingVector(int items_to_hold = -1);

/// The ownership is being retained by this collection - do not delete the item.
Vector<BaseFloat> *At(int index) const;

/// The ownership of the item is passed to this collection - do not delete the item.
void PushBack(Vector<BaseFloat> *item);

/// This method returns the size as if no "recycling" had happened,
/// i.e. equivalent to the number of times the PushBack method has been called.
int Size() const;

~RecyclingVector();

private:
std::deque<Vector<BaseFloat>*> items_;
int items_to_hold_;
int first_available_index_;
};


/// This is a templated class for online feature extraction;
/// it's templated on a class like MfccComputer or PlpComputer
/// that does the basic feature extraction.
Expand All @@ -61,7 +91,7 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature {
return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
}

virtual int32 NumFramesReady() const { return features_.size(); }
virtual int32 NumFramesReady() const { return features_.Size(); }

virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat);

Expand All @@ -88,10 +118,6 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature {
ComputeFeatures();
}

~OnlineGenericBaseFeature() {
DeletePointers(&features_);
}

private:
// This function computes any additional feature frames that it is possible to
// compute from 'waveform_remainder_', which at this point may contain more
Expand All @@ -107,7 +133,7 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature {

// features_ is the Mfcc or Plp or Fbank features that we have already computed.

std::vector<Vector<BaseFloat>*> features_;
RecyclingVector features_;

// True if the user has called "InputFinished()"
bool input_finished_;
Expand Down