Commit 050188c: SpeechRecognizer dictation
crc-32 committed Oct 29, 2024 (1 parent: e493116)
Showing 8 changed files with 279 additions and 23 deletions.
5 changes: 5 additions & 0 deletions android/shared/src/androidMain/AndroidManifest.xml
@@ -1,5 +1,10 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android">

    <uses-permission android:name="android.permission.RUN_USER_INITIATED_JOBS" />
    <queries>
        <intent>
            <action android:name="android.speech.RecognitionService" />
        </intent>
    </queries>

</manifest>
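
The <queries> declaration above gives the app package visibility of speech recognition services, which Android 11+ requires before SpeechRecognizer can bind to one. A minimal sketch of how client code could confirm a recognizer is actually resolvable at runtime (illustrative only, not part of this commit):

import android.content.Context
import android.speech.SpeechRecognizer

// With the <queries> entry in the manifest, this returns true on devices that ship a
// recognition service; without it the lookup can fail even when a recognizer is installed.
fun isRecognizerAvailable(context: Context): Boolean =
    SpeechRecognizer.isRecognitionAvailable(context)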
@@ -1,5 +1,7 @@
package io.rebble.cobble.shared.di

import android.os.Build
import android.os.Build.VERSION_CODES
import android.service.notification.StatusBarNotification
import com.benasher44.uuid.Uuid
import io.rebble.cobble.shared.AndroidPlatformContext
@@ -13,15 +15,20 @@ import io.rebble.cobble.shared.domain.notifications.AndroidNotificationActionExe
import io.rebble.cobble.shared.domain.notifications.CallNotificationProcessor
import io.rebble.cobble.shared.domain.notifications.NotificationProcessor
import io.rebble.cobble.shared.domain.notifications.PlatformNotificationActionExecutor
import io.rebble.cobble.shared.domain.voice.DictationService
import io.rebble.cobble.shared.domain.voice.NullDictationService
import io.rebble.cobble.shared.domain.voice.SpeechRecognizerDictationService
import io.rebble.cobble.shared.handlers.*
import io.rebble.cobble.shared.handlers.music.MusicHandler
import io.rebble.cobble.shared.jobs.AndroidJobScheduler
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import org.koin.android.ext.koin.androidContext
import org.koin.core.module.dsl.factoryOf
import org.koin.core.module.dsl.singleOf
import org.koin.core.qualifier.named
import org.koin.dsl.bind
import org.koin.dsl.binds
import org.koin.dsl.module

val androidModule = module {
@@ -67,4 +74,9 @@ val androidModule = module {
    singleOf(::NotificationProcessor)
    singleOf(::CallNotificationProcessor)
    singleOf(::AndroidPlatformAppMessageIPC) bind PlatformAppMessageIPC::class
    if (Build.VERSION.SDK_INT >= VERSION_CODES.TIRAMISU) {
        factoryOf(::SpeechRecognizerDictationService) bind DictationService::class
    } else {
        factoryOf(::NullDictationService) bind DictationService::class
    }
}
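
The module now binds DictationService to the new SpeechRecognizerDictationService on Android 13+ and falls back to NullDictationService on older releases. The fallback's implementation is not shown in this diff; a plausible minimal sketch (an assumption, not the committed code) would simply report dictation as unavailable:

import io.rebble.cobble.shared.domain.voice.*
import io.rebble.libpebblecommon.packets.Result
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

// Hypothetical no-op fallback: immediately tells the caller that dictation is unavailable.
class NullDictationService : DictationService {
    override fun handleSpeechStream(
        speexEncoderInfo: SpeexEncoderInfo,
        audioStreamFrames: Flow<AudioStreamFrame>
    ): Flow<DictationServiceResponse> = flow {
        emit(DictationServiceResponse.Error(Result.FailServiceUnavailable))
    }
}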
@@ -14,7 +14,6 @@ fun initKoin(context: Context) {
            dataStoreModule,
            androidModule,
            libpebbleModule,
-           voiceModule,
            dependenciesModule
        )
    }
@@ -0,0 +1,230 @@
package io.rebble.cobble.shared.domain.voice

import android.content.Context
import android.content.Intent
import android.media.AudioFormat
import android.os.Build
import android.os.Build.VERSION_CODES
import android.os.Bundle
import android.os.ParcelFileDescriptor
import android.speech.*
import androidx.annotation.RequiresApi
import androidx.compose.ui.text.intl.Locale
import com.example.speex_codec.SpeexCodec
import com.example.speex_codec.SpeexDecodeResult
import io.rebble.cobble.shared.Logging
import io.rebble.libpebblecommon.packets.Result
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.*
import org.koin.core.component.KoinComponent
import org.koin.core.component.inject
import java.io.InputStream
import java.io.PipedInputStream
import java.io.PipedOutputStream
import java.nio.ByteBuffer
import kotlin.math.roundToInt

@RequiresApi(VERSION_CODES.TIRAMISU)
class SpeechRecognizerDictationService: DictationService, KoinComponent {
    private val context: Context by inject()
    private val scope = CoroutineScope(Dispatchers.IO)

    sealed class SpeechRecognizerStatus {
        object Ready: SpeechRecognizerStatus()
        class Error(val error: Int): SpeechRecognizerStatus()
        class Results(val results: List<Pair<Float, String>>): SpeechRecognizerStatus()
    }

    private fun beginSpeechRecognition(speechRecognizer: SpeechRecognizer, intent: Intent) = callbackFlow<SpeechRecognizerStatus> {
        speechRecognizer.setRecognitionListener(object : RecognitionListener {
            private var lastPartials = emptyList<Pair<Float, String>>()
            override fun onReadyForSpeech(params: Bundle?) {
                trySend(SpeechRecognizerStatus.Ready)
            }

            override fun onBeginningOfSpeech() {
                Logging.i("Speech start detected")
            }

            override fun onRmsChanged(rmsdB: Float) {}

            override fun onBufferReceived(buffer: ByteArray?) {}

            override fun onEndOfSpeech() {
                Logging.i("Speech end detected")
            }

            override fun onError(error: Int) {
                trySend(SpeechRecognizerStatus.Error(error))
            }

            override fun onResults(results: Bundle?) {
                //XXX: it seems that with offline recognition on a Pixel we only get partial results, with confidence scores once they're final
                trySend(SpeechRecognizerStatus.Results(lastPartials))
            }

            override fun onPartialResults(results: Bundle?) {
                val result = results?.getStringArrayList(SpeechRecognizer.RESULTS_RECOGNITION)?.toList()
                val confidence = results?.getFloatArray(SpeechRecognizer.CONFIDENCE_SCORES)?.toList()
                if (confidence != null && result != null) {
                    lastPartials = confidence.zip(result)
                }
            }

            override fun onEvent(eventType: Int, params: Bundle?) {}
        })
        speechRecognizer.startListening(intent)
        awaitClose {
            Logging.d("Closing speech recognition listener")
            speechRecognizer.cancel()
        }
    }.flowOn(Dispatchers.Main)

    companion object {
        fun buildRecognizerIntent(audioSource: ParcelFileDescriptor? = null, encoding: Int = AudioFormat.ENCODING_PCM_16BIT, sampleRate: Int = 16000) = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply {
            putExtra(RecognizerIntent.EXTRA_LANGUAGE_MODEL, RecognizerIntent.LANGUAGE_MODEL_FREE_FORM)
            audioSource?.let {
                putExtra(RecognizerIntent.EXTRA_AUDIO_SOURCE, audioSource)
                putExtra(RecognizerIntent.EXTRA_AUDIO_SOURCE_ENCODING, encoding)
                putExtra(RecognizerIntent.EXTRA_AUDIO_SOURCE_CHANNEL_COUNT, 1)
                putExtra(RecognizerIntent.EXTRA_AUDIO_SOURCE_SAMPLING_RATE, sampleRate)
            }
            putExtra(RecognizerIntent.EXTRA_LANGUAGE, Locale.current.toLanguageTag())
        }
    }

    override fun handleSpeechStream(speexEncoderInfo: SpeexEncoderInfo, audioStreamFrames: Flow<AudioStreamFrame>) = flow {
        if (!SpeechRecognizer.isRecognitionAvailable(context)) {
            emit(DictationServiceResponse.Error(Result.FailServiceUnavailable))
            return@flow
        }
        val decoder = SpeexCodec(speexEncoderInfo.sampleRate, speexEncoderInfo.bitRate)
        val decodedBuf = ByteArray(speexEncoderInfo.frameSize * Short.SIZE_BYTES)
        val recognizerPipes = ParcelFileDescriptor.createSocketPair()
        val recognizerReadPipe = recognizerPipes[0]
        val recognizerWritePipe = ParcelFileDescriptor.AutoCloseOutputStream(recognizerPipes[1]).buffered(320 * Short.SIZE_BYTES)
        val recognizerIntent = buildRecognizerIntent(recognizerReadPipe, AudioFormat.ENCODING_PCM_16BIT, speexEncoderInfo.sampleRate.toInt())
        //val recognizerIntent = buildRecognizerIntent()
        val speechRecognizer = withContext(Dispatchers.Main) {
            if (Build.VERSION.SDK_INT > VERSION_CODES.R && SpeechRecognizer.isOnDeviceRecognitionAvailable(context)) {
                SpeechRecognizer.createOnDeviceSpeechRecognizer(context)
            } else {
                SpeechRecognizer.createSpeechRecognizer(context)
            }
        }
        val supported = withContext(Dispatchers.Main) {
            speechRecognizer.checkRecognitionSupport(recognizerIntent)
        }

        //TODO: handle downloads, etc
        Logging.d("Recognition support: $supported")
        if (supported == RecognitionSupportResult.Unsupported) {
            Logging.e("Speech recognition language/type not supported")
            emit(DictationServiceResponse.Error(Result.FailServiceUnavailable))
            return@flow
        }
        val audioJob = scope.launch {
            audioStreamFrames
                .onEach { frame ->
                    if (frame is AudioStreamFrame.Stop) {
                        //Logging.v("Stop")
                        withContext(Dispatchers.IO) {
                            recognizerWritePipe.flush()
                        }
                        withContext(Dispatchers.Main) {
                            //XXX: Shouldn't use main here for I/O call but recognizer has weird thread behaviour
                            recognizerWritePipe.close()
                            recognizerReadPipe.close()
                            speechRecognizer.stopListening()
                        }
                    } else if (frame is AudioStreamFrame.AudioData) {
                        val result = decoder.decodeFrame(frame.data, decodedBuf, hasHeaderByte = true)
                        if (result != SpeexDecodeResult.Success) {
                            Logging.e("Speex decode error: ${result.name}")
                        }
                        withContext(Dispatchers.IO) {
                            recognizerWritePipe.write(decodedBuf)
                        }
                    }
                }
                .catch {
                    Logging.e("Error in audio stream: $it")
                }
                .collect()
        }
        try {
            beginSpeechRecognition(speechRecognizer, recognizerIntent).collect { status ->
                when (status) {
                    is SpeechRecognizerStatus.Ready -> emit(DictationServiceResponse.Ready)
                    is SpeechRecognizerStatus.Error -> {
                        Logging.e("Speech recognition error: ${status.error}")
                        when (status.error) {
                            SpeechRecognizer.ERROR_NETWORK -> emit(DictationServiceResponse.Error(Result.FailServiceUnavailable))
                            SpeechRecognizer.ERROR_SPEECH_TIMEOUT -> emit(DictationServiceResponse.Error(Result.FailTimeout))
                            SpeechRecognizer.ERROR_NO_MATCH -> emit(DictationServiceResponse.Error(Result.FailRecognizerError))
                            else -> emit(DictationServiceResponse.Error(Result.FailServiceUnavailable))
                        }
                        emit(DictationServiceResponse.Complete)
                    }
                    is SpeechRecognizerStatus.Results -> {
                        Logging.d("Speech recognition results: ${status.results}")
                        emit(DictationServiceResponse.Transcription(
                            listOf(
                                buildList {
                                    status.results.firstOrNull()?.second?.split(" ")?.forEach {
                                        add(Word(it, 100u))
                                    }
                                }
                            )
                        ))
                        emit(DictationServiceResponse.Complete)
                    }
                }
            }
        } finally {
            audioJob.cancel()
            speechRecognizer.destroy()
        }

    }
}

enum class RecognitionSupportResult {
    SupportedOnDevice,
    SupportedOnline,
    NeedsDownload,
    Unsupported
}

@RequiresApi(VERSION_CODES.TIRAMISU)
suspend fun SpeechRecognizer.checkRecognitionSupport(intent: Intent): RecognitionSupportResult {
    val result = CompletableDeferred<RecognitionSupport>()
    val language = Locale.current.toLanguageTag()
    val executor = Dispatchers.IO.asExecutor()
    checkRecognitionSupport(intent, executor, object : RecognitionSupportCallback {
        override fun onSupportResult(recognitionSupport: RecognitionSupport) {
            //TODO: override locale depending on user choice
            result.complete(recognitionSupport)
        }

        override fun onError(error: Int) {
            result.completeExceptionally(Exception("Error checking recognition support: $error"))
        }
    })
    val support = result.await()
    return when {
        support.supportedOnDeviceLanguages.contains(language) -> RecognitionSupportResult.SupportedOnDevice
        support.installedOnDeviceLanguages.contains(language) -> RecognitionSupportResult.SupportedOnDevice
        support.onlineLanguages.contains(language) -> RecognitionSupportResult.SupportedOnline
        support.pendingOnDeviceLanguages.contains(language) -> RecognitionSupportResult.NeedsDownload
        else -> RecognitionSupportResult.Unsupported
    }
}
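
Taken together, handleSpeechStream turns the watch's Speex audio stream into DictationServiceResponse values: Ready once the recognizer is listening, then either an Error or a Transcription, followed by Complete. A hedged sketch of how a caller might drive it (names other than the response types are assumptions, not code from this commit):

import io.rebble.cobble.shared.domain.voice.*
import kotlinx.coroutines.flow.Flow
import org.koin.core.component.KoinComponent
import org.koin.core.component.inject

// Illustrative consumer: resolve the platform DictationService from Koin, feed it the
// Speex frames received from the watch, and react to each response it emits.
class DictationSessionRunner : KoinComponent {
    private val dictationService: DictationService by inject()

    suspend fun run(encoderInfo: SpeexEncoderInfo, frames: Flow<AudioStreamFrame>) {
        dictationService.handleSpeechStream(encoderInfo, frames).collect { response ->
            when (response) {
                DictationServiceResponse.Ready -> println("Recognizer ready, acknowledge session setup")
                is DictationServiceResponse.Transcription -> println("Sentences: ${response.sentences}")
                is DictationServiceResponse.Error -> println("Dictation failed")
                DictationServiceResponse.Complete -> println("Session finished")
                else -> { /* any other states are ignored in this sketch */ }
            }
        }
    }
}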

This file was deleted.

@@ -81,10 +81,14 @@ class VoiceSessionHandler(
                when (it) {
                    is DictationServiceResponse.Ready -> {
                        pebbleDevice.activeVoiceSession.value = voiceSession
-                       pebbleDevice.voiceService.send(SessionSetupResult(
+                       val result = SessionSetupResult(
                            sessionType = SessionType.Dictation,
                            result = Result.Success
-                       ))
+                       )
+                       if (appInitiated) {
+                           result.flags.set(1u)
+                       }
+                       pebbleDevice.voiceService.send(result)
                        sentReady = true
                    }
                    is DictationServiceResponse.Error -> {
@@ -102,14 +106,22 @@
                        }
                    }
                    is DictationServiceResponse.Transcription -> {
-                       val a = DictationResult(
+                       val resp = DictationResult(
                            voiceSession.sessionId.toUShort(),
                            Result.Success,
-                           listOf(
-                               makeTranscription(it.sentences)
-                           )
+                           buildList {
+                               add(makeTranscription(it.sentences))
+                               if (appInitiated && voiceSession.appUuid != null) {
+                                   add(VoiceAttribute.AppUuid().apply {
+                                       uuid.set(voiceSession.appUuid)
+                                   })
+                               }
+                           }
                        )
-                       pebbleDevice.voiceService.send(a)
+                       if (appInitiated) {
+                           resp.flags.set(1u)
+                       }
+                       pebbleDevice.voiceService.send(resp)
                    }
                }
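
Both the session-setup acknowledgement and the dictation result mark app-initiated sessions by setting bit 0 of their flags field, and the transcription additionally carries the watchapp's UUID attribute. If that flag handling grows, it could be factored into small helpers; a hedged sketch (not in the commit; the import paths and the assumption that both packet classes expose the same writable flags field are mine):

import io.rebble.libpebblecommon.packets.DictationResult
import io.rebble.libpebblecommon.packets.SessionSetupResult

// Hypothetical helpers mirroring the two call sites above that do `flags.set(1u)`.
fun SessionSetupResult.markAppInitiated() = flags.set(1u)
fun DictationResult.markAppInitiated() = flags.set(1u)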
11 changes: 10 additions & 1 deletion android/speex_codec/src/main/cpp/speex_codec.cpp
@@ -5,9 +5,14 @@
extern "C"
JNIEXPORT jint JNICALL
Java_com_example_speex_1codec_SpeexCodec_decode(JNIEnv *env, jobject thiz,
-                                                jbyteArray encoded_frame, jbyteArray out_frame) {
+                                                jbyteArray encoded_frame, jbyteArray out_frame, jboolean has_header_byte) {
    jbyte *encoded_frame_data = env->GetByteArrayElements(encoded_frame, nullptr);
    jsize encoded_frame_length = env->GetArrayLength(encoded_frame);
+   if (has_header_byte) {
+       // Skip the leading header byte so libspeex only sees the frame payload
+       encoded_frame_data++;
+       encoded_frame_length--;
+   }
    auto *bits = reinterpret_cast<SpeexBits *>(env->GetLongField(thiz, env->GetFieldID(env->GetObjectClass(thiz), "speexDecBits", "J")));
    auto *dec_state = reinterpret_cast<void *>(env->GetLongField(thiz, env->GetFieldID(env->GetObjectClass(thiz), "speexDecState", "J")));
    jshort pcm_frame[320];
@@ -16,6 +21,10 @@
    if (result == 0) {
        env->SetByteArrayRegion(out_frame, 0, sizeof(pcm_frame), reinterpret_cast<jbyte *>(pcm_frame));
    }
+   if (has_header_byte) {
+       // Move the pointer back to the original start address before releasing the array
+       encoded_frame_data--;
+   }
    env->ReleaseByteArrayElements(encoded_frame, encoded_frame_data, 0);
    return result;
}
@@ -12,16 +12,16 @@ class SpeexCodec(private val sampleRate: Long, private val bitRate: Int): AutoCl
     * @param decodedFrame The buffer to store the decoded frame in.
     *
     */
-   fun decodeFrame(encodedFrame: ByteArray, decodedFrame: ByteArray): SpeexDecodeResult {
-       return SpeexDecodeResult.fromInt(decode(encodedFrame, decodedFrame))
+   fun decodeFrame(encodedFrame: ByteArray, decodedFrame: ByteArray, hasHeaderByte: Boolean = true): SpeexDecodeResult {
+       return SpeexDecodeResult.fromInt(decode(encodedFrame, decodedFrame, hasHeaderByte))
    }

    override fun close() {
        destroySpeexBits(speexDecBits)
        destroyDecState(speexDecState)
    }

-   private external fun decode(encodedFrame: ByteArray, decodedFrame: ByteArray): Int
+   private external fun decode(encodedFrame: ByteArray, decodedFrame: ByteArray, hasHeaderByte: Boolean): Int
    private external fun initSpeexBits(): Long
    private external fun initDecState(sampleRate: Long, bitRate: Int): Long
    private external fun destroySpeexBits(speexBits: Long)
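
Callers now pass hasHeaderByte so the JNI decoder can skip the per-frame header byte that precedes the Speex payload in frames coming from the watch, as SpeechRecognizerDictationService does above. A hedged usage sketch, with the helper name and buffer sizing being illustrative assumptions rather than committed code:

import com.example.speex_codec.SpeexCodec
import com.example.speex_codec.SpeexDecodeResult

// Decode one watch frame to 16-bit PCM, skipping the leading header byte.
fun decodePebbleFrame(codec: SpeexCodec, frame: ByteArray, frameSizeSamples: Int): ByteArray {
    val pcm = ByteArray(frameSizeSamples * Short.SIZE_BYTES)
    val result = codec.decodeFrame(frame, pcm, hasHeaderByte = true)
    check(result == SpeexDecodeResult.Success) { "Speex decode failed: ${result.name}" }
    return pcm
}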
