Send language to Wyoming STT (#97344)

This commit is contained in:
Michael Hansen 2023-08-01 03:05:01 -05:00 committed by GitHub
parent 5aa3e36754
commit 8ad37d7640
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 7 deletions

View File

@ -2,7 +2,7 @@
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
import logging import logging
from wyoming.asr import Transcript from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
@ -89,6 +89,10 @@ class WyomingSttProvider(stt.SpeechToTextEntity):
"""Process an audio stream to STT service.""" """Process an audio stream to STT service."""
try: try:
async with AsyncTcpClient(self.service.host, self.service.port) as client: async with AsyncTcpClient(self.service.host, self.service.port) as client:
# Set transcription language
await client.write_event(Transcribe(language=metadata.language).event())
# Begin audio stream
await client.write_event( await client.write_event(
AudioStart( AudioStart(
rate=SAMPLE_RATE, rate=SAMPLE_RATE,
@ -106,6 +110,7 @@ class WyomingSttProvider(stt.SpeechToTextEntity):
) )
await client.write_event(chunk.event()) await client.write_event(chunk.event())
# End audio stream
await client.write_event(AudioStop().event()) await client.write_event(AudioStop().event())
while True: while True:

View File

@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -69,3 +70,16 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
return_value=TTS_INFO, return_value=TTS_INFO,
): ):
await hass.config_entries.async_setup(tts_config_entry.entry_id) await hass.config_entries.async_setup(tts_config_entry.entry_id)
@pytest.fixture
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
"""Get default STT metadata."""
return stt.SpeechMetadata(
language=hass.config.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)

View File

@ -1,6 +1,13 @@
# serializer version: 1 # serializer version: 1
# name: test_streaming_audio # name: test_streaming_audio
list([ list([
dict({
'data': dict({
'language': 'en',
}),
'payload': None,
'type': 'transcibe',
}),
dict({ dict({
'data': dict({ 'data': dict({
'channels': 1, 'channels': 1,

View File

@ -27,7 +27,9 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
assert entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO] assert entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO]
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None: async def test_streaming_audio(
hass: HomeAssistant, init_wyoming_stt, metadata, snapshot
) -> None:
"""Test streaming audio.""" """Test streaming audio."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None assert entity is not None
@ -40,7 +42,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
"homeassistant.components.wyoming.stt.AsyncTcpClient", "homeassistant.components.wyoming.stt.AsyncTcpClient",
MockAsyncTcpClient([Transcript(text="Hello world").event()]), MockAsyncTcpClient([Transcript(text="Hello world").event()]),
) as mock_client: ) as mock_client:
result = await entity.async_process_audio_stream(None, audio_stream()) result = await entity.async_process_audio_stream(metadata, audio_stream())
assert result.result == stt.SpeechResultState.SUCCESS assert result.result == stt.SpeechResultState.SUCCESS
assert result.text == "Hello world" assert result.text == "Hello world"
@ -48,7 +50,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
async def test_streaming_audio_connection_lost( async def test_streaming_audio_connection_lost(
hass: HomeAssistant, init_wyoming_stt hass: HomeAssistant, init_wyoming_stt, metadata
) -> None: ) -> None:
"""Test streaming audio and losing connection.""" """Test streaming audio and losing connection."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
@ -61,13 +63,15 @@ async def test_streaming_audio_connection_lost(
"homeassistant.components.wyoming.stt.AsyncTcpClient", "homeassistant.components.wyoming.stt.AsyncTcpClient",
MockAsyncTcpClient([None]), MockAsyncTcpClient([None]),
): ):
result = await entity.async_process_audio_stream(None, audio_stream()) result = await entity.async_process_audio_stream(metadata, audio_stream())
assert result.result == stt.SpeechResultState.ERROR assert result.result == stt.SpeechResultState.ERROR
assert result.text is None assert result.text is None
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None: async def test_streaming_audio_oserror(
hass: HomeAssistant, init_wyoming_stt, metadata
) -> None:
"""Test streaming audio and error raising.""" """Test streaming audio and error raising."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None assert entity is not None
@ -81,7 +85,7 @@ async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) ->
"homeassistant.components.wyoming.stt.AsyncTcpClient", "homeassistant.components.wyoming.stt.AsyncTcpClient",
mock_client, mock_client,
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): ), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
result = await entity.async_process_audio_stream(None, audio_stream()) result = await entity.async_process_audio_stream(metadata, audio_stream())
assert result.result == stt.SpeechResultState.ERROR assert result.result == stt.SpeechResultState.ERROR
assert result.text is None assert result.text is None