-
Notifications
You must be signed in to change notification settings - Fork 20
/
torchmoji.h
32 lines (23 loc) · 964 Bytes
/
torchmoji.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#ifndef TORCHMOJI_H
#define TORCHMOJI_H
#include "VoxCommon.hpp"
// TorchMoji: emotion contextualizer model.
// Following Cookie's design, the last layer is skipped and the hidden states
// are used to feed the TTS model, which allows emotion to be manipulated at
// inference time.
class TorchMoji
{
public:
    // Default constructor (defined outside this header).
    // NOTE(review): presumably leaves the object unloaded until Initialize() is called - confirm in the .cpp.
    TorchMoji();

    // Constructs from two paths.
    // -> InitPath: first path argument (NOTE(review): presumably the model file, mirroring Initialize()'s Path - confirm)
    // -> DPath: second path argument (NOTE(review): presumably the dictionary file, mirroring Initialize()'s DictPath - confirm)
    TorchMoji(const std::string &InitPath, const std::string &DPath);

    // Initializes this instance from the given paths.
    // -> Path: NOTE(review): presumably the TorchScript model to load into Model - confirm
    // -> DictPath: path of the dictionary (NOTE(review): presumably handed to LoadDict - confirm)
    void Initialize(const std::string &Path, const std::string &DictPath);

    // Return hidden states of emotion state.
    // -> Seq: vector of words
    // <- Returns a float vector of size VoxCommon::TorchMojiEmbSize containing
    //    the hidden states, ready to feed into the TTS model.
    std::vector<float> Infer(const std::vector<std::string> &Seq);

private:
    // Vocabulary: key is the word, value is its ID.
    std::map<std::string, int32_t> Dictionary;

    // TorchScript module handle for the TorchMoji network.
    torch::jit::script::Module Model;

    // Dictionary loader taking the dictionary file path.
    // NOTE(review): presumably fills Dictionary - confirm in the .cpp.
    void LoadDict(const std::string &Path);

    // Converts a sequence of words into a sequence of IDs.
    // NOTE(review): handling of words missing from Dictionary is not visible here - check the .cpp.
    std::vector<int32_t> WordsToIDs(const std::vector<std::string> &Words);
};
#endif // TORCHMOJI_H