Source code for infer

import sounddevice as sd

import torch

from models import EncDecBaseModel
from transforms import MfccTransform, Scattering

def record(seconds=1, sample_rate=16000):
    """Record the user with the machine's default microphone.

    Args:
        seconds (int, optional): Length of each recording in seconds. Defaults to 1.
        sample_rate (int, optional): Sampling rate in Hz. Defaults to 16000.

    Yields:
        torch.Tensor: A tensor of shape (samples, 1) holding the recorded audio clip.
    """
    while True:
        if input("Press Enter to start recording or q to exit: ") == 'q':
            break
        print(f"\033[91m Recording started for {seconds} seconds. \033[00m")
        # Record the audio clip from the default input device
        recording = sd.rec(int(seconds * sample_rate), samplerate=sample_rate, channels=1)
        sd.wait()
        print("\033[91m Recording ended. \033[00m")
        yield torch.from_numpy(recording)
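The generator above yields one clip per key press. If you want to keep the clips for later inspection, a minimal sketch of writing them to disk, assuming the third-party soundfile package and illustrative file names (neither appears in this project):

import soundfile as sf

# record() is the generator defined above; each clip has shape (samples, 1)
for i, clip in enumerate(record(seconds=1, sample_rate=16000)):
    # soundfile accepts the underlying NumPy array directly
    sf.write(f"clip_{i}.wav", clip.numpy(), samplerate=16000)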
def predict(model, audio, device):
    """Return the model's prediction, i.e. the argmax of the softmax of the logits.

    Args:
        model (nn.Module): PyTorch neural network.
        audio (torch.Tensor): Audio tensor.
        device (torch.device): Device to run inference on, preferably CPU.

    Returns:
        str: Predicted label.
    """
    audio = audio.squeeze().to(device)
    # pre_process = MfccTransform(sample_rate=16000)
    prediction = torch.nn.functional.softmax(model(audio.unsqueeze(0)), dim=-1).squeeze().argmax(dim=-1)
    labels_names = ["backward", "bed", "bird", "cat", "dog", "down", "eight", "five",
                    "follow", "forward", "four", "go", "happy", "house", "learn", "left",
                    "marvin", "nine", "no", "off", "on", "one", "right", "seven", "sheila",
                    "six", "stop", "three", "tree", "two", "up", "visual", "wow", "yes", "zero"]
    return labels_names[prediction]
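predict is not tied to the microphone; a minimal sketch of offline inference on a pre-recorded clip, assuming the torchaudio package and a hypothetical one-second, 16 kHz mono file example.wav (neither appears in the original script):

import torch
import torchaudio

from models import EncDecBaseModel
from transforms import Scattering

device = torch.device("cpu")
# Build and load the model the same way the __main__ block below does
model = torch.nn.Sequential(
    Scattering(),
    EncDecBaseModel(num_mels=125, num_classes=35, final_filter=128, input_length=1600))
model.load_state_dict(torch.load('../models/model.pt'))
model.eval()

waveform, sr = torchaudio.load("example.wav")  # hypothetical one-second clip
if sr != 16000:
    # Resample to the 16 kHz rate the model expects
    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)

# predict() is the function defined above
print(predict(model, waveform, device))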
if __name__ == '__main__':
    # This script enables live testing of the trained models: choose your model,
    # run the script and follow the instructions to record yourself.

    # Define device
    device = torch.device("cpu")

    # Load the PyTorch model
    PATH = '../models/model.pt'
    model = torch.nn.Sequential(
        Scattering(),
        EncDecBaseModel(num_mels=125, num_classes=35, final_filter=128, input_length=1600))
    # Alternative: EncDecBaseModel(num_mels=64, num_classes=35, final_filter=128, input_length=1601)
    model.load_state_dict(torch.load(PATH))
    model.to(device)
    model.eval()

    for audio in record():
        print(f"\033[92m Predicted: {predict(model, audio, device)}. \033[00m \n")