fix: remove audio primitives and fix imports (deepset-ai#22)
* remove audio primitives and fix imports

* pylint

* hooks

* pass

* wrong import

* tests

* check only filename

* path to string

* keep all other fields
ZanSara committed Mar 30, 2023
1 parent 45b0d37 commit fb35ad2
Showing 4 changed files with 46 additions and 38 deletions.
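The pattern that replaces the removed SpeechAnswer / SpeechDocument primitives: round-trip a plain Answer or Document through to_dict() / from_dict() so every existing field (offsets, score, meta, and so on) survives the copy, then overwrite only the audio-specific fields afterwards. A minimal sketch of the idea; the sample answer and output path below are illustrative, not from the commit:

    from haystack.schema import Answer

    text_answer = Answer(
        answer="answer",
        context="the context for this answer is here",
        meta={"some_meta": "some_value"},
    )

    # Dict round-trip copies every field of the original answer ("keep all other fields").
    audio_answer = Answer.from_dict(text_answer.to_dict())

    # Only the audio-specific fields get overwritten; the path is made up.
    audio_answer.answer = "generated_audio/answer.wav"
    audio_answer.meta.update({"answer_text": text_answer.answer, "audio_format": "wav"})
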
36 changes: 18 additions & 18 deletions nodes/text2speech/tests/test_nodes.py
@@ -18,14 +18,14 @@

 from transformers import WhisperProcessor, WhisperForConditionalGeneration

-from haystack.schema import Span, Answer, SpeechAnswer, Document, SpeechDocument
-from haystack.nodes.audio import AnswerToSpeech, DocumentToSpeech
-from haystack.nodes.audio._text_to_speech import TextToSpeech
+from haystack.schema import Span, Answer, Document
+from text2speech import AnswerToSpeech, DocumentToSpeech
+from text2speech.utils import TextToSpeech


 SAMPLES_PATH = Path(__file__).parent / "samples"


 class WhisperHelper:
     def __init__(self, model):
         self._processor = WhisperProcessor.from_pretrained(model)
@@ -50,6 +50,7 @@ def transcribe(self, media_file: str):
 def whisper_helper():
     return WhisperHelper("openai/whisper-medium")

+
 @pytest.mark.integration
 class TestTextToSpeech:
     def test_text_to_speech_audio_data(self, tmp_path, whisper_helper: WhisperHelper):
@@ -144,20 +145,19 @@ def test_answer_to_speech(self, tmp_path, whisper_helper: WhisperHelper):
         )
         results, _ = answer2speech.run(answers=[text_answer])

-        audio_answer: SpeechAnswer = results["answers"][0]
-        assert isinstance(audio_answer, SpeechAnswer)
-        assert audio_answer.type == "generative"
-        assert audio_answer.answer_audio.name == expected_audio_answer.name
-        assert audio_answer.context_audio.name == expected_audio_context.name
-        assert audio_answer.answer == "answer"
-        assert audio_answer.context == "the context for this answer is here"
+        audio_answer: Answer = results["answers"][0]
+        assert isinstance(audio_answer, Answer)
+        assert audio_answer.answer.split(os.path.sep)[-1] == str(expected_audio_answer).split(os.path.sep)[-1]
+        assert audio_answer.context.split(os.path.sep)[-1] == str(expected_audio_context).split(os.path.sep)[-1]
         assert audio_answer.offsets_in_document == [Span(31, 37)]
         assert audio_answer.offsets_in_context == [Span(21, 27)]
+        assert audio_answer.meta["answer_text"] == "answer"
+        assert audio_answer.meta["context_text"] == "the context for this answer is here"
         assert audio_answer.meta["some_meta"] == "some_value"
         assert audio_answer.meta["audio_format"] == "wav"

         expected_doc = whisper_helper.transcribe(str(expected_audio_answer))
-        generated_doc = whisper_helper.transcribe(str(audio_answer.answer_audio))
+        generated_doc = whisper_helper.transcribe(str(audio_answer.answer))

         assert expected_doc[0] in generated_doc[0]
@@ -178,15 +178,15 @@ def test_document_to_speech(self, tmp_path, whisper_helper: WhisperHelper):

         results, _ = doc2speech.run(documents=[text_doc])

-        audio_doc: SpeechDocument = results["documents"][0]
-        assert isinstance(audio_doc, SpeechDocument)
+        audio_doc: Document = results["documents"][0]
+        assert isinstance(audio_doc, Document)
         assert audio_doc.content_type == "audio"
-        assert audio_doc.content_audio.name == expected_audio_content.name
-        assert audio_doc.content == "this is the content of the document"
+        assert audio_doc.content.split(os.path.sep)[-1] == str(expected_audio_content).split(os.path.sep)[-1]
+        assert audio_doc.meta["content_text"] == "this is the content of the document"
         assert audio_doc.meta["name"] == "test_document.txt"
         assert audio_doc.meta["audio_format"] == "wav"

         expected_doc = whisper_helper.transcribe(str(expected_audio_content))
-        generated_doc = whisper_helper.transcribe(str(audio_doc.content_audio))
+        generated_doc = whisper_helper.transcribe(str(audio_doc.content))

         assert expected_doc[0] in generated_doc[0]
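The reworked assertions compare only the final path component, since the node writes generated files into a run-specific audio directory while the expected samples live under SAMPLES_PATH. The split(os.path.sep)[-1] idiom used above is equivalent to taking Path(...).name; a small illustration with made-up paths:

    import os
    from pathlib import Path

    generated = os.path.join("generated_audio", "run_1", "answer.wav")  # hypothetical output
    expected = Path("samples") / "answer.wav"

    # Both sides reduce to the bare filename "answer.wav".
    assert generated.split(os.path.sep)[-1] == str(expected).split(os.path.sep)[-1]
    assert Path(generated).name == expected.name  # equivalent, arguably clearer
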
19 changes: 10 additions & 9 deletions nodes/text2speech/text2speech/answer_to_speech.py
@@ -6,8 +6,8 @@
 from tqdm.auto import tqdm

 from haystack.nodes import BaseComponent
-from haystack.schema import Answer, SpeechAnswer
-from haystack.nodes.audio._text_to_speech import TextToSpeech
+from haystack.schema import Answer
+from text2speech.utils import TextToSpeech


 class AnswerToSpeech(BaseComponent):
@@ -79,16 +79,17 @@ def run(self, answers: List[Answer]) -> Tuple[Dict[str, List[Answer]], str]: #
                 text=answer.context, generated_audio_dir=self.generated_audio_dir, **self.params
             )

-            audio_answer = SpeechAnswer.from_text_answer(
-                answer_object=answer,
-                audio_answer=answer_audio,
-                audio_context=context_audio,
-                additional_meta={
+            audio_answer = Answer.from_dict(answer.to_dict())
+            audio_answer.answer = str(answer_audio)
+            audio_answer.context = str(context_audio)
+            audio_answer.meta.update(
+                {
                     "answer_text": answer.answer,
                     "context_text": answer.context,
                     "audio_format": self.params.get("audio_format", answer_audio.suffix.replace(".", "")),
                     "sample_rate": self.converter.model.fs,
-                },
+                }
             )
+            audio_answer.type = "generative"
             audio_answers.append(audio_answer)

         return {"answers": audio_answers}, "output_1"
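Downstream consumers now receive ordinary Answer objects: the audio file paths sit in answer.answer and answer.context, while the original strings move into meta. A sketch of reading the node's output, assuming answer2speech has already been constructed and text_answer is a plain text Answer as in the tests above:

    results, _ = answer2speech.run(answers=[text_answer])
    audio_answer = results["answers"][0]

    audio_path = audio_answer.answer                      # now a path string to the generated .wav
    original_text = audio_answer.meta["answer_text"]      # the pre-conversion answer text
    original_context = audio_answer.meta["context_text"]  # the pre-conversion context text
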
17 changes: 9 additions & 8 deletions nodes/text2speech/text2speech/document_to_speech.py
@@ -4,8 +4,8 @@
 from tqdm.auto import tqdm

 from haystack.nodes import BaseComponent
-from haystack.schema import Document, SpeechDocument
-from haystack.nodes.audio._text_to_speech import TextToSpeech
+from haystack.schema import Document
+from text2speech.utils import TextToSpeech


 class DocumentToSpeech(BaseComponent):
@@ -60,15 +60,16 @@ def run(self, documents: List[Document]) -> Tuple[Dict[str, List[Document]], str
             content_audio = self.converter.text_to_audio_file(
                 text=doc.content, generated_audio_dir=self.generated_audio_dir, **self.params
             )
-            audio_document = SpeechDocument.from_text_document(
-                document_object=doc,
-                audio_content=content_audio,
-                additional_meta={
+            audio_document = Document.from_dict(doc.to_dict())
+            audio_document.content = str(content_audio)
+            audio_document.content_type = "audio"
+            audio_document.meta.update(
+                {
                     "content_text": doc.content,
                     "audio_format": self.params.get("audio_format", content_audio.suffix.replace(".", "")),
                     "sample_rate": self.converter.model.fs,
-                },
+                }
             )
+            audio_document.type = "generative"
             audio_documents.append(audio_document)

         return {"documents": audio_documents}, "output_1"
Expand Down
12 changes: 9 additions & 3 deletions nodes/text2speech/text2speech/utils/text_to_speech.py
@@ -9,7 +9,7 @@
 import torch
 from pydub import AudioSegment

-from haystack.errors import AudioNodeError
+from haystack.errors import NodeError
 from haystack.modeling.utils import initialize_device_settings

@@ -27,6 +27,12 @@
 )


+class Text2SpeechError(NodeError):
+    """
+    Error class for text2speech nodes
+    """
+
+
 class TextToSpeech:
     """
     This class converts text into audio using text-to-speech models.
@@ -138,12 +144,12 @@ def text_to_audio_data(self, text: str, _models_output_key: str = "wav") -> np.a
         """
         prediction = self.model(text)
         if not prediction:
-            raise AudioNodeError(
+            raise Text2SpeechError(
                 "The model returned no predictions. Make sure you selected a valid text-to-speech model."
             )
         output = prediction.get(_models_output_key, None)
         if output is None:
-            raise AudioNodeError(
+            raise Text2SpeechError(
                 f"The model returned no output under the {_models_output_key} key. "
                 f"The available output keys are {prediction.keys()}. Make sure you selected the right key."
             )
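Because Text2SpeechError subclasses haystack's NodeError rather than using the removed AudioNodeError, callers that already catch NodeError keep working unchanged. A quick sketch; the import path is inferred from this repo's layout:

    from haystack.errors import NodeError
    from text2speech.utils.text_to_speech import Text2SpeechError

    try:
        raise Text2SpeechError("The model returned no predictions.")
    except NodeError as err:  # the subclass is still caught by the existing, broader except clause
        print(f"Handled as NodeError: {err}")
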
