-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLivePredictionLoop.py
More file actions
128 lines (105 loc) · 4.09 KB
/
LivePredictionLoop.py
File metadata and controls
128 lines (105 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Live microphone speech recognition using a trained SpeechRecognitionModel.
This script:
- Loads a trained CTC-based speech recognition model
- Records audio from the microphone in fixed-length chunks
- Runs speech recognition on each chunk
- Prints transcriptions continuously until interrupted
"""
import torch
import torchaudio
import numpy as np
import sounddevice as sd
import scipy.io.wavfile as wavfile
import tempfile
import os
from VoiceRecognition import SpeechRecognitionModel
from Predict import predict
# Audio capture settings shared by the recorder and the mel-spectrogram transform.
SAMPLE_RATE = 16_000  # Hz
DURATION = 10  # seconds of audio recorded per transcription chunk
def load_model(checkpoint_path, device, vocab_size=30):
    """
    Load a trained speech recognition model from a checkpoint.

    Args:
        checkpoint_path (str): Path to the model checkpoint (.pth).
        device (torch.device): Device to load the model onto.
        vocab_size (int): Output size passed to ``SpeechRecognitionModel``.
            Must match the value the checkpoint was trained with; a mismatch
            makes ``load_state_dict`` raise. Defaults to 30.

    Returns:
        SpeechRecognitionModel: Loaded model in evaluation mode.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Training checkpoints wrap the weights under "model_state_dict"; also
    # accept a bare state dict saved directly with torch.save(model.state_dict()).
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    model = SpeechRecognitionModel(output_size=vocab_size)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # disable dropout / batch-norm updates for inference
    return model
def record_to_wav(duration, sample_rate):
    """
    Record audio from the default microphone and save it as a temporary WAV file.

    The caller owns the returned file and is responsible for deleting it.
    If writing the file fails (or is interrupted), the file is removed
    before the exception propagates.

    Args:
        duration (int | float): Recording duration in seconds.
        sample_rate (int): Sample rate in Hz.

    Returns:
        str: Path to the temporary WAV file.
    """
    print("Recording to wav...")
    num_frames = int(duration * sample_rate)
    audio = sd.rec(num_frames, samplerate=sample_rate, channels=1, dtype='float32')
    sd.wait()
    # Flatten the single recorded channel to 1-D. reshape(-1) is used instead
    # of squeeze so that a one-frame recording does not collapse to a 0-d array.
    audio = audio.reshape(-1)

    fd, tmp_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # wavfile.write reopens the file by path
    try:
        wavfile.write(tmp_wav_path, sample_rate, audio)
    except BaseException:
        # BaseException so Ctrl+C during the write doesn't leak the temp file.
        try:
            os.remove(tmp_wav_path)
        except OSError:
            pass
        raise
    return tmp_wav_path
def live_loop(model, mel_transform, amplitude_to_db, device):
    """
    Transcribe microphone audio chunk by chunk until the user presses Ctrl+C.

    Each iteration records ``DURATION`` seconds at ``SAMPLE_RATE`` Hz into a
    temporary WAV file and passes that file to ``predict``. It prints the
    returned text with every ``<pad>`` and ``<unk>`` token removed, then
    deletes the WAV file even if prediction raised.

    Args:
        model (SpeechRecognitionModel): Trained speech recognition model.
        mel_transform (MelSpectrogram): Mel spectrogram transform.
        amplitude_to_db (AmplitudeToDB): Log-amplitude transform.
        device (torch.device): Device to run inference on.
    """
    print("Listening! press ctrl+c to stop...")
    try:
        while True:
            chunk_path = record_to_wav(DURATION, SAMPLE_RATE)
            try:
                text = predict(chunk_path, model, mel_transform, amplitude_to_db, device=device)
                # Strip special tokens, <pad> first and then <unk>.
                for special_token in ("<pad>", "<unk>"):
                    text = text.replace(special_token, "")
                print(text)
            finally:
                # The WAV file is only a hand-off to predict(); never keep it.
                os.remove(chunk_path)
    except KeyboardInterrupt:
        print("Exiting...")
def _select_checkpoint(checkpoints, choice):
    """
    Map the user's menu input to one of ``checkpoints``.

    An empty answer selects the first (default) checkpoint. Anything that
    is not a valid in-range index also falls back to the default, with a
    message so the fallback is not silent.

    Args:
        checkpoints (list[str]): Non-empty, ordered list of checkpoint files.
        choice (str): Raw text typed by the user.

    Returns:
        str: The selected checkpoint filename.
    """
    choice = choice.strip()
    if not choice:
        return checkpoints[0]
    try:
        index = int(choice)
    except ValueError:
        # str.isdigit() accepts characters such as '²' that int() rejects,
        # so parse directly instead of pre-checking.
        index = -1
    if 0 <= index < len(checkpoints):
        return checkpoints[index]
    print(f"Invalid selection {choice!r}; using default checkpoint.")
    return checkpoints[0]


def main():
    """
    Entry point for live microphone speech recognition.

    Lists the ``.pth`` files in the current directory and asks the user to
    pick one. Loads the chosen model, builds the audio transforms, and runs
    the live transcription loop.

    Raises:
        FileNotFoundError: If no ``.pth`` checkpoint exists in the current directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- CHECKPOINT SELECTION ----
    checkpoints = sorted(
        f for f in os.listdir(".") if f.endswith(".pth") and os.path.isfile(f)
    )
    if not checkpoints:
        raise FileNotFoundError("No .pth checkpoints found in the current directory.")
    print("\nAvailable checkpoints:")
    for i, ckpt in enumerate(checkpoints):
        print(f"[{i}] {ckpt}")
    print("[Enter] Use default checkpoint (first in list)")
    try:
        choice = input("Select a checkpoint to load: ")
    except EOFError:
        # Closed/piped stdin: behave as if the user pressed Enter.
        choice = ""
    checkpoint_path = _select_checkpoint(checkpoints, choice)
    print(f"\nLoading checkpoint: {checkpoint_path}")
    model = load_model(checkpoint_path, device)

    # ---- AUDIO TRANSFORMS ----
    # NOTE(review): these parameters presumably must match the ones used when
    # training the checkpoint; confirm against the training script.
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=1024,
        hop_length=512,
        n_mels=64,
    )
    amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    # ---- LIVE LOOP ----
    live_loop(model, mel_transform, amplitude_to_db, device)


if __name__ == "__main__":
    main()