
tflite inference only predicts one label despite multiclass label training #8

Open
raziurtotha opened this issue Dec 6, 2022 · 1 comment

@raziurtotha

raziurtotha commented Dec 6, 2022

Hi Shawn,
I have trained a multiclass classifier for speech recognition with TensorFlow using your tutorials. The model runs, but it always predicts the same single class. I suspect the problem is in the inference code, because the .h5 model predicts all classes without any issue. I have been searching online for several days for some insight, but I can't quite figure it out. Here is my code. Any suggestions would be really appreciated.

import sounddevice as sd
import numpy as np
import scipy.signal
import timeit
import python_speech_features

import tflite_runtime.interpreter as tflite

import importlib

#Parameters
debug_time = 0
debug_acc = 0
word_threshold = 0.95
rec_duration = 0.5   # 0.5
sample_length = 0.5
window_stride = 0.5  # 0.5
sample_rate = 8000   # The mic requires at least 44100 Hz to work
resample_rate = 8000
num_channels = 1
num_mfcc = 16

model_path = 'model.tflite'

mfccs_old = np.zeros((32, 25))

#Load model (interpreter)
interpreter = tflite.Interpreter(model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)

#Filter and downsample
def decimate(signal, old_fs, new_fs):

    #Check to make sure we're downsampling
    if new_fs > old_fs:
        print("Error: target sample rate higher than original")
        return signal, old_fs

    #Downsampling is possible only by an integer factor
    dec_factor = old_fs / new_fs
    if not dec_factor.is_integer():
        print("Error: can only downsample by integer factor")

    #Do decimation
    resampled_signal = scipy.signal.decimate(signal, int(dec_factor))

    return resampled_signal, new_fs

#Callback that gets called every 0.5 seconds
def sd_callback(rec, frames, time, status):

    #Start timing for debug purposes
    start = timeit.default_timer()

    #Notify errors
    if status:
        print('Error:', status)

    global mfccs_old

    #Compute MFCCs
    mfccs = python_speech_features.base.mfcc(rec,
                                            samplerate=resample_rate,
                                            winlen=0.02,
                                            winstep=0.02,
                                            numcep=num_mfcc,
                                            nfilt=26,
                                            nfft=512, # 2048
                                            preemph=0.0,
                                            ceplifter=0,
                                            appendEnergy=True,
                                            winfunc=np.hanning)

    delta = python_speech_features.base.delta(mfccs, 2)

    mfccs_delta = np.append(mfccs, delta, axis=1)

    mfccs_new = mfccs_delta.transpose()
    mfccs = np.append(mfccs_old, mfccs_new, axis=1)
    #mfccs = np.insert(mfccs, [0], 0, axis=1)
    mfccs_old = mfccs_new

    #Run inference and make predictions
    in_tensor = np.float32(mfccs.reshape(1, mfccs.shape[0], mfccs.shape[1], 1))
    interpreter.set_tensor(input_details[0]['index'], in_tensor)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    val = np.amax(output_data)                      # DEFINED FOR BINARY CLASSIFICATION, CHANGE TO MULTICLASS
    ind = np.where(output_data == val)
    prediction = ind[1].astype(int)
    if val > word_threshold:
        print('index:', ind[1])
        print('accuracy:', val, '\n')
        print(int(prediction))

    if debug_acc:
        #print('accuracy:', val)
        #print('index:', ind[1])
        print('out tensor:', output_data)
    if debug_time:
        print(timeit.default_timer() - start)

#Start recording from microphone
with sd.InputStream(channels=num_channels,
        samplerate=sample_rate,
        blocksize=int(sample_rate * rec_duration),
        callback=sd_callback):
    while True:
        pass
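
For reference, a quick sanity check (a minimal sketch, using only the interpreter objects already created above) is to print the output tensor shape before starting the stream, to confirm the converted .tflite model really exposes one value per class:

#Sanity check (sketch): the last dimension should equal the number of classes
out_shape = output_details[0]['shape']
print('Output tensor shape:', out_shape)   # e.g. [1, num_classes]
if out_shape[-1] == 1:
    print('Model exposes only a single output value; re-check the conversion/training step')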

@ShawnHymel
Owner

If you print output_data, what do you see? You should see an array (list) of probabilities, one for each of your classes.
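
For example (a minimal sketch, assuming the model ends in a softmax over all classes), the callback could print the full probability vector and pick the winning class with np.argmax instead of the binary-style np.where lookup:

    #Inside sd_callback, after interpreter.invoke() (sketch; assumes a softmax output)
    output_data = interpreter.get_tensor(output_details[0]['index'])
    print('out tensor:', output_data)          # full probability vector, shape (1, num_classes)
    prediction = int(np.argmax(output_data))   # index of the most likely class
    val = float(np.max(output_data))           # its probability
    if val > word_threshold:
        print('class:', prediction, 'probability:', val)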
