Skip to content

Commit

Permalink
[src] In 'OnlineNnet2FeaturePipelineInfo', pass in global CMVN stats …
Browse files Browse the repository at this point in the history
…as value not filename. (kaldi-asr#4264)

- so that the global_cmvn_stats are kaldi::Matrix and not a std::string with filename
- this allows to assign the matrix directly from C++ without relying on OS file system
- this is handy for kaldi integrators, but it is a potentially non backward compatible change

Co-authored-by: Karel Vesely <[email protected]>
  • Loading branch information
KarelVesely84 and Karel Vesely committed Sep 15, 2020
1 parent 2b627b3 commit e41ba8e
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 43 deletions.
10 changes: 4 additions & 6 deletions src/cudafeat/online-batched-feature-pipeline-cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,10 @@ OnlineBatchedFeaturePipelineCuda::OnlineBatchedFeaturePipelineCuda(
}

if (info_.use_cmvn) {
KALDI_ASSERT(info_.global_cmvn_stats_rxfilename != "");

Matrix<double> global_cmvn_stats;
ReadKaldiObject(info_.global_cmvn_stats_rxfilename, &global_cmvn_stats);

OnlineCmvnState cmvn_state(global_cmvn_stats);
if (info_.global_cmvn_stats.NumCols() == 0) {
KALDI_ERR << "global_cmvn_stats for OnlineCmvn must be non-empty.";
}
OnlineCmvnState cmvn_state(info_.global_cmvn_stats);
CudaOnlineCmvnState cu_cmvn_state(cmvn_state);

// TODO do we want to parameterize stats coarsening factor?
Expand Down
11 changes: 6 additions & 5 deletions src/cudafeat/online-cuda-feature-pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ OnlineCudaFeaturePipeline::OnlineCudaFeaturePipeline(
}

if (info_.use_cmvn) {
KALDI_ASSERT(info_.global_cmvn_stats_rxfilename != "");
ReadKaldiObject(info_.global_cmvn_stats_rxfilename, &global_cmvn_stats);
OnlineCmvnState cmvn_state(global_cmvn_stats);
if (info_.global_cmvn_stats.NumCols() == 0) {
KALDI_ERR << "global_cmvn_stats for OnlineCmvn must be non-empty.";
}
OnlineCmvnState cmvn_state(info_.global_cmvn_stats);
CudaOnlineCmvnState cu_cmvn_state(cmvn_state);
cmvn = new CudaOnlineCmvn(info_.cmvn_opts, cu_cmvn_state);
}
}

if (info_.use_ivectors) {
OnlineIvectorExtractionConfig ivector_extraction_opts;
Expand All @@ -53,7 +54,7 @@ OnlineCudaFeaturePipeline::OnlineCudaFeaturePipeline(
ivector_extraction_opts.greedy_ivector_extractor = true;

ivector = new IvectorExtractorFastCuda(ivector_extraction_opts);
}
}
}

OnlineCudaFeaturePipeline::~OnlineCudaFeaturePipeline() {
Expand Down
12 changes: 7 additions & 5 deletions src/online2/online-nnet2-feature-pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ OnlineNnet2FeaturePipelineInfo::OnlineNnet2FeaturePipelineInfo(
use_cmvn = (config.cmvn_config != "");
if (use_cmvn) {
ReadConfigFromFile(config.cmvn_config, &cmvn_opts);
global_cmvn_stats_rxfilename = config.global_cmvn_stats_rxfilename;
if (global_cmvn_stats_rxfilename == "")
if (config.global_cmvn_stats_rxfilename == "")
KALDI_ERR << "--global-cmvn-stats option is required "
<< " when --cmvn-config is specified.";
ReadKaldiObject(config.global_cmvn_stats_rxfilename, &global_cmvn_stats);
}

if (config.ivector_extraction_config != "") {
Expand Down Expand Up @@ -119,9 +119,11 @@ OnlineNnet2FeaturePipeline::OnlineNnet2FeaturePipeline(
}

if (info_.use_cmvn) {
KALDI_ASSERT(info.global_cmvn_stats_rxfilename != "");
ReadKaldiObject(info.global_cmvn_stats_rxfilename, &global_cmvn_stats_);
OnlineCmvnState initial_state(global_cmvn_stats_);
if (info_.global_cmvn_stats.NumCols() == 0) {
KALDI_ERR << "global_cmvn_stats for OnlineCmvn must be non-empty, "
<< "please assign it to OnlineNnet2FeaturePipelineInfo.";
}
OnlineCmvnState initial_state(info_.global_cmvn_stats);
cmvn_feature_ = new OnlineCmvn(info_.cmvn_opts, initial_state,
feature_plus_optional_pitch_);
feature_plus_optional_cmvn_ = cmvn_feature_;
Expand Down
4 changes: 1 addition & 3 deletions src/online2/online-nnet2-feature-pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ struct OnlineNnet2FeaturePipelineInfo {
/// and the OnlineCmvn is added to the feature preparation pipeline.
bool use_cmvn;
OnlineCmvnOptions cmvn_opts; /// Options for online cmvn, read from config file.
std::string global_cmvn_stats_rxfilename; /// Filename used for reading global
/// cmvn stats in OnlineCmvn.
Matrix<double> global_cmvn_stats; /// Matrix with global cmvn stats in OnlineCmvn.

/// If the user specified --ivector-extraction-config, we assume we're using
/// iVectors as an extra input to the neural net. Actually, we don't
Expand Down Expand Up @@ -300,7 +299,6 @@ class OnlineNnet2FeaturePipeline: public OnlineFeatureInterface {

OnlineCmvn *cmvn_feature_;
Matrix<BaseFloat> lda_mat_; /// LDA matrix, if supplied
Matrix<double> global_cmvn_stats_; /// Global CMVN stats.

/// feature_plus_optional_pitch_ is the base_feature_ appended (OnlineAppendFeature)
/// with pitch_feature_, if used; otherwise, points to the same address as
Expand Down
12 changes: 6 additions & 6 deletions src/online2bin/online2-wav-nnet2-am-compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ int main(int argc, char *argv[]) {
bool pad_input = true;
bool online = true;

// feature_config includes configuration for the iVector adaptation,
// feature_opts includes configuration for the iVector adaptation,
// as well as the basic features.
OnlineNnet2FeaturePipelineConfig feature_config;
OnlineNnet2FeaturePipelineConfig feature_opts;
ParseOptions po(usage);
po.Register("apply-log", &apply_log, "Apply a log to the result of the computation "
"before outputting.");
Expand All @@ -70,7 +70,7 @@ int main(int argc, char *argv[]) {
"in the file given to --ivector-extraction-config, and "
"--chunk-length=-1.");

feature_config.Register(&po);
feature_opts.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() != 4) {
po.PrintUsage();
Expand All @@ -82,16 +82,16 @@ int main(int argc, char *argv[]) {
wav_rspecifier = po.GetArg(3),
features_or_loglikes_wspecifier = po.GetArg(4);

OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);
if (!online) {
feature_info.ivector_extractor_info.use_most_recent_ivector = true;
feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
chunk_length_secs = -1.0;
}

Matrix<double> global_cmvn_stats;
if (feature_info.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
if (feature_opts.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_opts.global_cmvn_stats_rxfilename,
&global_cmvn_stats);

TransitionModel trans_model;
Expand Down
12 changes: 6 additions & 6 deletions src/online2bin/online2-wav-nnet2-latgen-faster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ int main(int argc, char *argv[]) {

OnlineEndpointConfig endpoint_config;

// feature_config includes configuration for the iVector adaptation,
// feature_opts includes configuration for the iVector adaptation,
// as well as the basic features.
OnlineNnet2FeaturePipelineConfig feature_config;
OnlineNnet2FeaturePipelineConfig feature_opts;
OnlineNnet2DecodingConfig nnet2_decoding_config;

BaseFloat chunk_length_secs = 0.05;
Expand All @@ -127,7 +127,7 @@ int main(int argc, char *argv[]) {
po.Register("num-threads-startup", &g_num_threads,
"Number of threads used when initializing iVector extractor.");

feature_config.Register(&po);
feature_opts.Register(&po);
nnet2_decoding_config.Register(&po);
endpoint_config.Register(&po);

Expand All @@ -144,16 +144,16 @@ int main(int argc, char *argv[]) {
wav_rspecifier = po.GetArg(4),
clat_wspecifier = po.GetArg(5);

OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);
if (!online) {
feature_info.ivector_extractor_info.use_most_recent_ivector = true;
feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
chunk_length_secs = -1.0;
}

Matrix<double> global_cmvn_stats;
if (feature_info.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
if (feature_opts.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_opts.global_cmvn_stats_rxfilename,
&global_cmvn_stats);

TransitionModel trans_model;
Expand Down
12 changes: 6 additions & 6 deletions src/online2bin/online2-wav-nnet2-latgen-threaded.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ int main(int argc, char *argv[]) {

OnlineEndpointConfig endpoint_config;

// feature_config includes configuration for the iVector adaptation,
// feature_opts includes configuration for the iVector adaptation,
// as well as the basic features.
OnlineNnet2FeaturePipelineConfig feature_config;
OnlineNnet2FeaturePipelineConfig feature_opts;
OnlineNnet2DecodingThreadedConfig nnet2_decoding_config;

BaseFloat chunk_length_secs = 0.05;
Expand Down Expand Up @@ -131,7 +131,7 @@ int main(int argc, char *argv[]) {
po.Register("num-threads-startup", &g_num_threads,
"Number of threads used when initializing iVector extractor. ");

feature_config.Register(&po);
feature_opts.Register(&po);
nnet2_decoding_config.Register(&po);
endpoint_config.Register(&po);

Expand All @@ -148,16 +148,16 @@ int main(int argc, char *argv[]) {
wav_rspecifier = po.GetArg(4),
clat_wspecifier = po.GetArg(5);

OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
OnlineNnet2FeaturePipelineInfo feature_info(feature_opts);

if (modify_ivector_config) {
feature_info.ivector_extractor_info.use_most_recent_ivector = true;
feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
}

Matrix<double> global_cmvn_stats;
if (feature_info.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
if (feature_opts.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_opts.global_cmvn_stats_rxfilename,
&global_cmvn_stats);

TransitionModel trans_model;
Expand Down
4 changes: 2 additions & 2 deletions src/online2bin/online2-wav-nnet3-latgen-faster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ int main(int argc, char *argv[]) {
}

Matrix<double> global_cmvn_stats;
if (feature_info.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
if (feature_opts.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_opts.global_cmvn_stats_rxfilename,
&global_cmvn_stats);

TransitionModel trans_model;
Expand Down
4 changes: 2 additions & 2 deletions src/online2bin/online2-wav-nnet3-latgen-grammar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ int main(int argc, char *argv[]) {
}

Matrix<double> global_cmvn_stats;
if (feature_info.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
if (feature_opts.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_opts.global_cmvn_stats_rxfilename,
&global_cmvn_stats);

TransitionModel trans_model;
Expand Down
4 changes: 2 additions & 2 deletions src/online2bin/online2-wav-nnet3-wake-word-decoder-faster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ int main(int argc, char *argv[]) {
}

Matrix<double> global_cmvn_stats;
if (feature_info.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_info.global_cmvn_stats_rxfilename,
if (feature_opts.global_cmvn_stats_rxfilename != "")
ReadKaldiObject(feature_opts.global_cmvn_stats_rxfilename,
&global_cmvn_stats);

TransitionModel trans_model;
Expand Down

0 comments on commit e41ba8e

Please sign in to comment.