# real_time_local_chat_ai_v2.py
import time
import numpy as np
import pyaudio
import onnxruntime as rt
import pvporcupine
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import multiprocessing as mp
import pyttsx3
from langchain.prompts import PromptTemplate
from langchain_community.llms import LlamaCpp
from langchain.chains import LLMChain, ConversationChain
from langchain.memory import ConversationBufferWindowMemory
import wave
import os
import requests

# Base directory of the project; used below to locate the local GGUF model files.
source_dir = os.getcwd()
# VAD
class VADDetector2:
    def __init__(self, model_path, chunk_size=512, format=pyaudio.paFloat32, channels=1, rate=16000):
        self.model = rt.InferenceSession(model_path, providers=['CPUExecutionProvider'], sess_options=self._get_session_options())
        self.chunk_size = chunk_size
        self.format = format
        self.channels = channels
        self.rate = rate
        # Recurrent states of the Silero VAD model, carried across successive chunks
        self.h_state = np.zeros((2, 1, 64)).astype('float32')
        self.c_state = np.zeros((2, 1, 64)).astype('float32')

    def _get_session_options(self):
        opts = rt.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        return opts

    def process_audio_chunk(self, audio_chunk):
        ort_inputs = {
            'input': audio_chunk.reshape(1, -1),
            'h': self.h_state,
            'c': self.c_state,
            'sr': np.array(self.rate, dtype='int64')
        }
        ort_outs = self.model.run(None, ort_inputs)
        self.h_state, self.c_state = ort_outs[1], ort_outs[2]
        return ort_outs[0]
    def detect_voice_activity(self, data):
        # Returns True when the chunk is classified as silence (speech probability < 0.5).
        vad_sentence_end_bool = False
        try:
            audio_chunk = np.frombuffer(data, dtype=np.float32)
            output = self.process_audio_chunk(audio_chunk)
            vad_sentence_end_bool = output.item() < 0.5
        except KeyboardInterrupt:
            print("Exiting program...")
        return vad_sentence_end_bool
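# Minimal standalone sketch of the VAD detector (not part of the assistant loop).
# It assumes a Silero VAD ONNX file at ./silero_vad.onnx and a default input device;
# adjust paths to your setup. Left commented out so the script's behaviour is unchanged.
#
#   vad = VADDetector2("./silero_vad.onnx")
#   pa = pyaudio.PyAudio()
#   probe = pa.open(format=pyaudio.paFloat32, channels=1, rate=16000,
#                   input=True, frames_per_buffer=512)
#   for _ in range(100):  # roughly 3 seconds of audio at 512 samples per chunk
#       chunk = probe.read(512, exception_on_overflow=False)
#       print("silent" if vad.detect_voice_activity(chunk) else "speech")
#   probe.stop_stream(); probe.close(); pa.terminate()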
# WAKE WORD
def detect_wake_word(jarvis_handle):
    global wake_word_bool

    # Define a function to get the next audio frame
    def get_next_audio_frame(stream, chunk_size):
        data = stream.read(chunk_size)
        return np.frombuffer(data, dtype=np.int16)

    # Set up PyAudio stream
    p = pyaudio.PyAudio()
    chunk_size = 512
    format = pyaudio.paInt16
    channels = 1
    rate = 16000
    stream = p.open(
        format=format,
        channels=channels,
        rate=rate,
        input=True,
        frames_per_buffer=chunk_size
    )
    try:
        print("Listening for wake word...")
        while True:
            audio_frame = get_next_audio_frame(stream, chunk_size)
            keyword_index = jarvis_handle.process(audio_frame)
            if keyword_index >= 0:
                wake_word_bool = True
                break
            else:
                wake_word_bool = False
    except KeyboardInterrupt:
        print("Interrupted by user")
        exit()
    finally:
        stream.stop_stream()
        stream.close()
        p.terminate()

def run_vad(vad_detector, audio_chunk):
    return vad_detector.detect_voice_activity(audio_chunk)
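# Note: Porcupine expects single-channel, 16 kHz, 16-bit PCM frames of exactly
# porcupine.frame_length samples (512 for the bundled keyword models), which is why the
# wake-word stream above is opened with paInt16 and a 512-sample buffer.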
# STT
def run_sst_model(processor_stt, model_stt, vad_detector):
    # Function for splitting audio into chunks if longer than 30 seconds
    def chunk_array(data, chunk_size):
        chunks = []
        num_chunks = len(data) // chunk_size
        remainder = len(data) % chunk_size
        for i in range(num_chunks):
            chunks.append(data[i * chunk_size : (i + 1) * chunk_size])
        if remainder > 0:
            chunks.append(data[num_chunks * chunk_size :])
        return chunks

    # Function for processing audio data
    def process_audio_stt(audio_data_list):
        start_time = time.time()
        audio_data = np.concatenate(audio_data_list, axis=0)
        input_features = processor_stt(audio_data, sampling_rate=16000, return_tensors="pt").input_features
        # Generate token ids
        predicted_ids = model_stt.generate(input_features, max_length=448)
        # Decode token ids to text
        transcription = processor_stt.batch_decode(predicted_ids, skip_special_tokens=True)
        print(f'Speech to text translation time: {time.time() - start_time}')
        return transcription

    # Set up PyAudio stream
    p = pyaudio.PyAudio()
    chunk_size = 512
    format = pyaudio.paFloat32
    channels = 1
    rate = 16000
    stream = p.open(
        format=format,
        channels=channels,
        rate=rate,
        input=True,
        frames_per_buffer=chunk_size)
    audio_data_list = []
    # Counts the number of consecutive audio chunks that are silent (a chunk of 512 samples is roughly 0.032 seconds)
    vad_sentence_end_counter = 0
    while True:
        print('Listening...')
        data = stream.read(chunk_size, exception_on_overflow=False)
        # Append audio chunk to audio_data
        data_edited = np.frombuffer(data, dtype=np.float32)
        audio_data_list.append(data_edited)
        vad_sentence_end_bool = run_vad(vad_detector, data)
        if vad_sentence_end_bool:
            vad_sentence_end_counter += 1
        else:
            vad_sentence_end_counter = 0
        if vad_sentence_end_counter > 60:
            print('Question end detected')
            print('Translating speech to text...')
            max_chunk_size = 850
            if len(audio_data_list) < max_chunk_size:
                transcription = process_audio_stt(audio_data_list)
            else:
                print('Chunking audio data...')
                transcription_list = []
                chunked_data = chunk_array(audio_data_list, max_chunk_size)
                for chunk in chunked_data:
                    transcription_part = process_audio_stt(chunk)
                    transcription_list.append(transcription_part[0])
                transcription = " ".join(transcription_list)
                transcription = [transcription]
            # Release the capture stream before handing the audio off to the caller
            stream.stop_stream()
            stream.close()
            p.terminate()
            return transcription, vad_sentence_end_counter
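# Timing notes (approximate): at 16 kHz a 512-sample chunk lasts 512 / 16000 = 0.032 s,
# so 60 consecutive silent chunks correspond to roughly 1.9 s of silence before the
# question is treated as finished. The 850-chunk limit is about 850 * 512 / 16000 = 27 s,
# which keeps each Whisper pass under the model's 30-second input window.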
# Run llm function
def run_llm(conversation, text_translation):
    print(f'Question: {text_translation}')
    start_time = time.time()
    llm_out = conversation.predict(input=text_translation)
    print(f'Time elapsed: {time.time() - start_time}')
    print(f'LLM Answer: {llm_out}')
    speak(engine, llm_out)

# Speak back to you
def speak(engine, text):
    start_time = time.time()
    engine.say(text)
    engine.runAndWait()
    print(f'Time elapsed: {time.time() - start_time}')
# Load bell audio files
def load_audio(file_path):
    """Load audio data from a WAV file."""
    with wave.open(file_path, 'rb') as wf:
        audio_data = wf.readframes(wf.getnframes())
        sample_width = wf.getsampwidth()
        channels = wf.getnchannels()
        sample_rate = wf.getframerate()
    return audio_data, sample_width, channels, sample_rate

# Play bell audio files
def play_audio(audio_data, sample_width, channels, sample_rate):
    """Play audio data using PyAudio."""
    p = pyaudio.PyAudio()
    stream = p.open(format=p.get_format_from_width(sample_width),
                    channels=channels,
                    rate=sample_rate,
                    output=True)
    stream.write(audio_data)
    stream.stop_stream()
    stream.close()
    p.terminate()
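# Quick check of the bell helpers (hedged example; assumes a short WAV such as
# ./ding.wav exists on disk). Uncomment to test playback in isolation:
#
#   data, width, ch, sr = load_audio("./ding.wav")
#   play_audio(data, width, ch, sr)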
if __name__ == "__main__":
    #
    # Settings
    #
    which_llm = "mistral_3gb"  # "stablelm", "quiklang", "mistral_3gb"
    acc_device = "cuda"  # "cuda" or "cpu"
    wake_word = "jarvis"  # "jarvis", 'hey siri', 'hey google', 'terminator', 'alexa', 'ok google', 'computer'
    wake_word_sensitivity = 0.9  # 0.8
    voice_gender = 0  # 0 = male, 1 = female
    #
    # Load voice activity detection
    #
    vad_file_path = os.path.join("/app/vad/silero_vad.onnx")
    print("Loading voice activity detection...")
    vad_detector = VADDetector2(vad_file_path)
    jarvis_handle = pvporcupine.create(keywords=[wake_word], sensitivities=[wake_word_sensitivity])
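    # Note: recent pvporcupine releases also require a Picovoice AccessKey, e.g.
    # pvporcupine.create(access_key="...", keywords=[wake_word], ...); the call above
    # assumes an older SDK version that accepts keywords alone.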
    #
    # Load speech to text
    #
    print("Loading voice understanding...")
    stt_file_path = '/app/sst/'
    processor_stt = WhisperProcessor.from_pretrained(stt_file_path, local_files_only=True)
    model_stt = WhisperForConditionalGeneration.from_pretrained(stt_file_path, local_files_only=True)
    #
    # Load bell audio files
    #
    audio_file_path1 = "app/audio/bell_short_wav2.wav"
    audio_file_path2 = "app/audio/bell2_short_wav2.wav"
    audio_data, sample_width, channels, sample_rate = load_audio(audio_file_path1)
    audio_data2, sample_width2, channels2, sample_rate2 = load_audio(audio_file_path2)
    #
    # Load language model prompt and langchain memory
    #
    print("Loading language module...")
    template = """
# Conversation history:
# {history} - Conversation history end.
# You are an AI named Jarvis who is fun, and very intelligent.
# Your goal is to help the human, answer questions, and bring humour.
# Only speak in the first person as Jarvis, do not speak in prose, do not summarise the conversation, do not use smiley faces or emojis. Respond concisely.
<|user|>
{input}<|endoftext|>
<|Jarvis|>
"""
    # Initialise prompt
    PROMPT = PromptTemplate(input_variables=["history", "input"], template=template)
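    # Note (assumption): the <|user|> / <|endoftext|> markers follow the Zephyr-style chat
    # format used by the StableLM/quiklang GGUF models listed below; Mistral-Instruct models
    # normally expect [INST] ... [/INST] formatting, so the template may need adjusting per model.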
    # Store all conversation
    # memory = ConversationBufferMemory(ai_prefix="Jarvis")
    # Store last K conversations
    memory = ConversationBufferWindowMemory(ai_prefix="Jarvis", k=10)
    llm_mapping = {
        "mistral_3gb": "mistral-7b-instruct-v0.1.Q2_K.gguf",
        "stablelm": "stablelm-zephyr-3b.Q4_K_M.gguf",
        "quiklang": "zephyr-quiklang-3b-4k.Q4_K_M.gguf"
    }
    source_llm = os.path.join(source_dir, 'llm_model', 'llm_model', llm_mapping.get(which_llm, ""))
    # LLM settings for creativity
    # Test: temperature = 0.7, top_k = 20, top_p = 0.4
    # Precise: temperature = 0.7, top_k = 40, top_p = 0.1
    # Creative: temperature = 0.72, top_k = 0, top_p = 0.73
    # Sphinx: temperature = 1.99, top_k = 30, top_p = 0.18
    #
    # Load LLM
    #
    llm = LlamaCpp(
        model_path=source_llm,
        n_gpu_layers=-1,
        n_batch=124,
        n_ctx=2048,
        max_tokens=1024,
        verbose=True,  # Verbose is required to pass to the callback manager
        repeat_penalty=1.6,  # LlamaCpp names its repetition penalty parameter repeat_penalty
        temperature=0.75,
        top_k=20,
        top_p=0.4
    )
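    # Note: n_gpu_layers=-1 offloads all layers to the GPU, which only takes effect if
    # llama-cpp-python was installed with GPU (e.g. CUDA) support; on a CPU-only build
    # the model simply runs on the CPU.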
    conversation = ConversationChain(
        prompt=PROMPT,
        llm=llm,
        verbose=True,
        memory=memory,
    )
    #
    # Load voice
    #
    print("Loading voice...")
    engine = pyttsx3.init('sapi5')
    voices = engine.getProperty('voices')
    engine.setProperty('voice', voices[voice_gender].id)
    print("Models loaded.")
    speak(engine, "Jarvis loaded, I am ready to assist you.")
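    # Note: 'sapi5' is the Windows text-to-speech driver; on Linux pyttsx3 typically uses
    # 'espeak' and on macOS 'nsss', so the driver name may need changing outside Windows,
    # e.g. engine = pyttsx3.init() to let pyttsx3 pick the platform default.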
    while True:
        try:
            # Detect wake word
            detect_wake_word(jarvis_handle)
            if wake_word_bool:
                print("Wake word detected!")
                # Sound 1
                play_audio(audio_data, sample_width, channels, sample_rate)
                # Run STT
                text_translation, vad_sentence_end_counter = run_sst_model(processor_stt, model_stt, vad_detector)
                # Sound 2
                play_audio(audio_data2, sample_width2, channels2, sample_rate2)
                # Run LLM
                if len(text_translation[0]) < 8:
                    print('Error: question too short')
                    continue
                else:
                    run_llm(conversation, text_translation)
        except KeyboardInterrupt:
            print("Ctrl+C detected. Exiting loop.")
            break