Skip to content

Commit

Permalink
[src] Refactor online decoder; get grammar decoding work in online case.
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey committed Sep 1, 2018
1 parent a39b15c commit d0c68a6
Show file tree
Hide file tree
Showing 18 changed files with 897 additions and 1,609 deletions.
10 changes: 6 additions & 4 deletions src/decoder/grammar-fst.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,10 @@ GrammarFst::ExpandedState *GrammarFst::ExpandStateUserDefined(
}


void GrammarFst::Write(std::ostream &os) const {
void GrammarFst::Write(std::ostream &os, bool binary) const {
using namespace kaldi;
bool binary = true;
if (!binary)
KALDI_ERR << "GrammarFst::Write only supports binary mode.";
int32 format = 1,
num_ifsts = ifsts_.size();
WriteToken(os, binary, "<GrammarFst>");
Expand Down Expand Up @@ -414,11 +415,12 @@ static ConstFst<StdArc> *ReadConstFstFromStream(std::istream &is) {



void GrammarFst::Read(std::istream &is) {
void GrammarFst::Read(std::istream &is, bool binary) {
using namespace kaldi;
if (!binary)
KALDI_ERR << "GrammarFst::Read only supports binary mode.";
if (top_fst_ != NULL)
Destroy();
bool binary = true;
int32 format = 1, num_ifsts;
ExpectToken(is, binary, "<GrammarFst>");
ReadBasicType(is, binary, &format);
Expand Down
10 changes: 6 additions & 4 deletions src/decoder/grammar-fst.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,13 @@ class GrammarFst {
GrammarFst(): top_fst_(NULL) { }

// This Write function allows you to dump a GrammarFst to disk as a single
// object. It only supports binary mode.
void Write(std::ostream &os) const;
// object. It only supports binary mode, but the option is allowed for
// compatibility with other Kaldi read/write functions (it will crash if
// binary == false).
void Write(std::ostream &os, bool binary) const;

// Reads the format that Write() outputs.
void Read(std::istream &os);
// Reads the format that Write() outputs. Will crash if binary is false.
void Read(std::istream &os, bool binary);

StateId Start() const {
// the top 32 bits of the 64-bit state-id will be zero.
Expand Down
179 changes: 101 additions & 78 deletions src/decoder/lattice-faster-decoder.cc

Large diffs are not rendered by default.

190 changes: 138 additions & 52 deletions src/decoder/lattice-faster-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,124 @@ struct LatticeFasterDecoderConfig {
}
};

namespace decoder {
// We will template the decoder on the token type as well as the FST type; this
// is a mechanism so that we can use the same underlying decoder code for
// versions of the decoder that support quickly getting the best path
// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
// those that do not (LatticeFasterDecoder).


// ForwardLinks are the links from a token to a token on the next frame,
// or sometimes on the current frame (for input-epsilon links).
template <typename Token>
struct ForwardLink {
  using Label = fst::StdArc::Label;

  // Destination token of this link [or NULL if it represents a final-state].
  Token *next_tok;
  // Input label on the arc.
  Label ilabel;
  // Output label on the arc.
  Label olabel;
  // Graph cost of traversing the arc (contains LM, etc.).
  BaseFloat graph_cost;
  // Acoustic cost (pre-scaled) of traversing the arc.
  BaseFloat acoustic_cost;
  // Next element in the singly-linked list of forward arcs (arcs in the
  // state-level lattice) leaving a token.
  ForwardLink *next;

  // Memberwise constructor: every argument is stored, unchanged, in the
  // field it corresponds to (in declaration order).
  ForwardLink(Token *dest_tok, Label in_label, Label out_label,
              BaseFloat graph_cost_in, BaseFloat acoustic_cost_in,
              ForwardLink *next_link)
      : next_tok(dest_tok),
        ilabel(in_label),
        olabel(out_label),
        graph_cost(graph_cost_in),
        acoustic_cost(acoustic_cost_in),
        next(next_link) {}
};


struct StdToken {
using ForwardLinkT = ForwardLink<StdToken>;
using Token = StdToken;

// Standard token type for LatticeFasterDecoder. Each active HCLG
// (decoding-graph) state on each frame has one token.

// tot_cost is the total (LM + acoustic) cost from the beginning of the
// utterance up to this point. (but see cost_offset_, which is subtracted
// to keep it in a good numerical range).
BaseFloat tot_cost;

// exta_cost is >= 0. After calling PruneForwardLinks, this equals
// the minimum difference between the cost of the best path, and the cost of
// this is on, and the cost of the absolute best path, under the assumption
// that any of the currently active states at the decoding front may
// eventually succeed (e.g. if you were to take the currently active states
// one by one and compute this difference, and then take the minimum).
BaseFloat extra_cost;

// 'links' is the head of singly-linked list of ForwardLinks, which is what we
// use for lattice generation.
ForwardLinkT *links;

//'next' is the next in the singly-linked list of tokens for this frame.
Token *next;

// This function does nothing and should be optimized out; it's needed
// so we can share the regular LatticeFasterDecoderTpl code and the code
// for LatticeFasterOnlineDecoder that supports fast traceback.
inline void SetBackpointer (Token *backpointer) { }

// This constructor just ignores the 'backpointer' argument. That argument is
// needed so that we can use the same decoder code for LatticeFasterDecoderTpl
// and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
// fast way to obtain the best path).
inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
Token *next, Token *backpointer):
tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { }
};

struct BackpointerToken {
using ForwardLinkT = ForwardLink<BackpointerToken>;
using Token = BackpointerToken;

// BackpointerToken is like Token but also
// Standard token type for LatticeFasterDecoder. Each active HCLG
// (decoding-graph) state on each frame has one token.

// tot_cost is the total (LM + acoustic) cost from the beginning of the
// utterance up to this point. (but see cost_offset_, which is subtracted
// to keep it in a good numerical range).
BaseFloat tot_cost;

// exta_cost is >= 0. After calling PruneForwardLinks, this equals
// the minimum difference between the cost of the best path, and the cost of
// this is on, and the cost of the absolute best path, under the assumption
// that any of the currently active states at the decoding front may
// eventually succeed (e.g. if you were to take the currently active states
// one by one and compute this difference, and then take the minimum).
BaseFloat extra_cost;

// 'links' is the head of singly-linked list of ForwardLinks, which is what we
// use for lattice generation.
ForwardLinkT *links;

//'next' is the next in the singly-linked list of tokens for this frame.
BackpointerToken *next;

// Best preceding BackpointerToken (could be a on this frame, connected to
// this via an epsilon transition, or on a previous frame). This is only
// required for an efficient GetBestPath function in
// LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
// (the "links" list is what stores the forward links, for that).
Token *backpointer;

inline void SetBackpointer (Token *backpointer) {
this->backpointer = backpointer;
}

inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
Token *next, Token *backpointer):
tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
backpointer(backpointer) { }
};

} // namespace decoder


/** This is the "normal" lattice-generating decoder.
See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
Expand All @@ -106,14 +224,14 @@ struct LatticeFasterDecoderConfig {
without having to know at compile time.
*/

template <typename FST = fst::StdFst>
template <typename FST, typename Token = decoder::StdToken>
class LatticeFasterDecoderTpl {
public:
using Arc = typename FST::Arc;
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using ForwardLinkT = decoder::ForwardLink<Token>;

// Instantiate this class once for each thing you have to decode.
// This version of the constructor does not take ownership of
Expand Down Expand Up @@ -164,8 +282,13 @@ class LatticeFasterDecoderTpl {
/// of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one.
/// The raw lattice will be topologically sorted.
bool GetRawLattice(Lattice *ofst,
bool use_final_probs = true) const;
///
/// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
/// which also supports a pruning beam, in case for some reason
/// you want it pruned tighter than the regular lattice beam.
/// We could put that here in future if needed.
bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;



/// [Deprecated, users should now use GetRawLattice and determinize it
Expand Down Expand Up @@ -220,53 +343,13 @@ class LatticeFasterDecoderTpl {
// whenever we call ProcessEmitting().
inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }

private:
// ForwardLinks are the links from a token to a token on the next frame.
// or sometimes on the current frame (for input-epsilon links).
struct Token;
struct ForwardLink {
Token *next_tok; // the next token [or NULL if represents final-state]
Label ilabel; // ilabel on link.
Label olabel; // olabel on link.
BaseFloat graph_cost; // graph cost of traversing link (contains LM, etc.)
BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing link
ForwardLink *next; // next in singly-linked list of forward links from a
// token.
inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
BaseFloat graph_cost, BaseFloat acoustic_cost,
ForwardLink *next):
next_tok(next_tok), ilabel(ilabel), olabel(olabel),
graph_cost(graph_cost), acoustic_cost(acoustic_cost),
next(next) { }
};
protected:
// we make things protected instead of private, as code in
// LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
// internals.

// Token is what's resident in a particular state at a particular time.
// In this decoder a Token actually contains *forward* links.
// When first created, a Token just has the (total) cost. We add forward
// links from it when we process the next frame.
struct Token {
BaseFloat tot_cost; // would equal weight.Value()... cost up to this point.
BaseFloat extra_cost; // >= 0. This is used in pruning away tokens.
// there is a comment in lattice-faster-decoder.cc explaining this;
// search for "a note on the definition of extra_cost".

ForwardLink *links; // Head of singly linked list of ForwardLinks

Token *next; // Next in list of tokens for this frame.

inline Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links,
Token *next):
tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next) { }
inline void DeleteForwardLinks() {
ForwardLink *l = links, *m;
while (l != NULL) {
m = l->next;
delete l;
l = m;
}
links = NULL;
}
};
// Deletes the elements of the singly linked list tok->links.
inline static void DeleteForwardLinks(Token *tok);

// head of per-frame list of Tokens (list is in topological order),
// and something saying whether we ever pruned it using PruneForwardLinks.
Expand Down Expand Up @@ -296,8 +379,11 @@ class LatticeFasterDecoderTpl {
// index plus one, which is used to index into the active_toks_ array.
// Returns the Token pointer. Sets "changed" (if non-NULL) to true if the
// token was newly created or the cost changed.
// The 'backpointer' argument has no purpose (and will hopefully be optimized
// out) if Token == StdToken.
inline Token *FindOrAddToken(StateId state, int32 frame_plus_one,
BaseFloat tot_cost, bool *changed);
BaseFloat tot_cost, Token *backpointer,
bool *changed);

// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
Expand Down Expand Up @@ -438,7 +524,7 @@ class LatticeFasterDecoderTpl {
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl);
};

typedef LatticeFasterDecoderTpl<fst::StdFst> LatticeFasterDecoder;
typedef LatticeFasterDecoderTpl<fst::StdFst, decoder::StdToken> LatticeFasterDecoder;



Expand Down
Loading

0 comments on commit d0c68a6

Please sign in to comment.