Skip to content

Commit

Permalink
format,test=doc
Browse files Browse the repository at this point in the history
  • Loading branch information
zh794390558 committed Feb 28, 2022
1 parent 54341c8 commit 7509869
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 27 deletions.
2 changes: 1 addition & 1 deletion paddlespeech/s2t/io/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
Expand Down
4 changes: 1 addition & 3 deletions paddlespeech/s2t/models/u2_st/u2_st.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
from paddlespeech.s2t.modules.mask import mask_finished_preds
from paddlespeech.s2t.modules.mask import mask_finished_scores
from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
Expand Down Expand Up @@ -291,7 +289,7 @@ def translate(
device = speech.place

# Let's assume B = batch_size and N = beam_size
# 1. Encoder and init hypothesis
# 1. Encoder and init hypothesis
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/t2s/modules/transformer/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ def repeat(N, fn):
Returns:
MultiSequential: Repeated model instance.
"""
return MultiSequential(*[fn(n) for n in range(N)])
return MultiSequential(* [fn(n) for n in range(N)])
36 changes: 14 additions & 22 deletions tests/unit/asr/deepspeech2_online_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import unittest

import numpy as np
import paddle
import pickle
import os
from paddle import inference

from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline


class TestDeepSpeech2ModelOnline(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -185,15 +186,12 @@ def test_ds2_8(self):
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)




class TestDeepSpeech2StaticModelOnline(unittest.TestCase):

def setUp(self):
export_prefix = "exp/deepspeech2_online/checkpoints/test_export"
if not os.path.exists(os.path.dirname(export_prefix)):
os.makedirs(os.path.dirname(export_prefix), mode=0o755)
infer_model = DeepSpeech2InferModelOnline(
infer_model = DeepSpeech2InferModelOnline(
feat_size=161,
dict_size=4233,
num_conv_layers=2,
Expand All @@ -207,27 +205,25 @@ def setUp(self):

with open("test_data/static_ds2online_inputs.pickle", "rb") as f:
self.data_dict = pickle.load(f)

self.setup_model(export_prefix)


def setup_model(self, export_prefix):
deepspeech_config = inference.Config(
export_prefix + ".pdmodel",
export_prefix + ".pdiparams")
if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
deepspeech_config = inference.Config(export_prefix + ".pdmodel",
export_prefix + ".pdiparams")
if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and
os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
deepspeech_config.enable_use_gpu(100, 0)
deepspeech_config.enable_memory_optim()
deepspeech_predictor = inference.create_predictor(deepspeech_config)
self.predictor = deepspeech_predictor

def test_unit(self):
input_names = self.predictor.get_input_names()
audio_handle = self.predictor.get_input_handle(input_names[0])
audio_len_handle = self.predictor.get_input_handle(input_names[1])
h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3])


x_chunk = self.data_dict["audio_chunk"]
x_chunk_lens = self.data_dict["audio_chunk_lens"]
Expand All @@ -246,13 +242,9 @@ def test_unit(self):
c_box_handle.reshape(chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(chunk_state_c_box)



output_names = self.predictor.get_output_names()
output_handle = self.predictor.get_output_handle(
output_names[0])
output_lens_handle = self.predictor.get_output_handle(
output_names[1])
output_handle = self.predictor.get_output_handle(output_names[0])
output_lens_handle = self.predictor.get_output_handle(output_names[1])
output_state_h_handle = self.predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.predictor.get_output_handle(
Expand All @@ -264,7 +256,7 @@ def test_unit(self):
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu()
return True


if __name__ == '__main__':
unittest.main()

0 comments on commit 7509869

Please sign in to comment.