From 334359bb0a118a917740682c0104a29d06187ec0 Mon Sep 17 00:00:00 2001 From: tronikos Date: Tue, 3 Sep 2024 06:23:07 -0700 Subject: [PATCH] Add Google Cloud Speech-to-Text (STT) (#120854) * Google Cloud * . * fix * mypy * add tests * Update .coveragerc * Update const.py * upload file, reconfigure and import flow * fixes * default to latest_short * mypy * update * Allow clearing options in the UI * update * update * update --- .../components/google_cloud/__init__.py | 2 +- .../components/google_cloud/config_flow.py | 20 ++- .../components/google_cloud/const.py | 164 ++++++++++++++++++ .../components/google_cloud/manifest.json | 5 +- .../components/google_cloud/strings.json | 3 +- homeassistant/components/google_cloud/stt.py | 147 ++++++++++++++++ requirements_all.txt | 3 + requirements_test_all.txt | 3 + .../google_cloud/test_config_flow.py | 2 + 9 files changed, 345 insertions(+), 4 deletions(-) create mode 100644 homeassistant/components/google_cloud/stt.py diff --git a/homeassistant/components/google_cloud/__init__.py b/homeassistant/components/google_cloud/__init__.py index 84848543790..9d1923fd87d 100644 --- a/homeassistant/components/google_cloud/__init__.py +++ b/homeassistant/components/google_cloud/__init__.py @@ -6,7 +6,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant -PLATFORMS = [Platform.TTS] +PLATFORMS = [Platform.STT, Platform.TTS] async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: diff --git a/homeassistant/components/google_cloud/config_flow.py b/homeassistant/components/google_cloud/config_flow.py index bf97de67eb1..dec849de4e6 100644 --- a/homeassistant/components/google_cloud/config_flow.py +++ b/homeassistant/components/google_cloud/config_flow.py @@ -26,7 +26,16 @@ from homeassistant.helpers.selector import ( SelectSelectorMode, ) -from .const import CONF_KEY_FILE, CONF_SERVICE_ACCOUNT_INFO, DEFAULT_LANG, DOMAIN, TITLE +from .const import ( + CONF_KEY_FILE, + CONF_SERVICE_ACCOUNT_INFO, + CONF_STT_MODEL, + DEFAULT_LANG, + DEFAULT_STT_MODEL, + DOMAIN, + SUPPORTED_STT_MODELS, + TITLE, +) from .helpers import ( async_tts_voices, tts_options_schema, @@ -162,6 +171,15 @@ class GoogleCloudOptionsFlowHandler(OptionsFlowWithConfigEntry): **tts_options_schema( self.options, voices, from_config_flow=True ).schema, + vol.Optional( + CONF_STT_MODEL, + default=DEFAULT_STT_MODEL, + ): SelectSelector( + SelectSelectorConfig( + mode=SelectSelectorMode.DROPDOWN, + options=SUPPORTED_STT_MODELS, + ) + ), } ), self.options, diff --git a/homeassistant/components/google_cloud/const.py b/homeassistant/components/google_cloud/const.py index 6a718bf35d3..f416d36483a 100644 --- a/homeassistant/components/google_cloud/const.py +++ b/homeassistant/components/google_cloud/const.py @@ -10,6 +10,7 @@ CONF_KEY_FILE = "key_file" DEFAULT_LANG = "en-US" +# TTS constants CONF_GENDER = "gender" CONF_VOICE = "voice" CONF_ENCODING = "encoding" @@ -18,3 +19,166 @@ CONF_PITCH = "pitch" CONF_GAIN = "gain" CONF_PROFILES = "profiles" CONF_TEXT_TYPE = "text_type" + +# STT constants +CONF_STT_MODEL = "stt_model" + +DEFAULT_STT_MODEL = "latest_short" + +# https://cloud.google.com/speech-to-text/docs/transcription-model +SUPPORTED_STT_MODELS = [ + "latest_long", + "latest_short", + "telephony", + "telephony_short", + "medical_dictation", + "medical_conversation", + "command_and_search", + "default", + "phone_call", + "video", +] + +# https://cloud.google.com/speech-to-text/docs/speech-to-text-supported-languages +STT_LANGUAGES = [ + "af-ZA", + "am-ET", + "ar-AE", + "ar-BH", + "ar-DZ", + "ar-EG", + "ar-IL", + "ar-IQ", + "ar-JO", + "ar-KW", + "ar-LB", + "ar-MA", + "ar-MR", + "ar-OM", + "ar-PS", + "ar-QA", + "ar-SA", + "ar-SY", + "ar-TN", + "ar-YE", + "az-AZ", + "bg-BG", + "bn-BD", + "bn-IN", + "bs-BA", + "ca-ES", + "cmn-Hans-CN", + "cmn-Hans-HK", + "cmn-Hant-TW", + "cs-CZ", + "da-DK", + "de-AT", + "de-CH", + "de-DE", + "el-GR", + "en-AU", + "en-CA", + "en-GB", + "en-GH", + "en-HK", + "en-IE", + "en-IN", + "en-KE", + "en-NG", + "en-NZ", + "en-PH", + "en-PK", + "en-SG", + "en-TZ", + "en-US", + "en-ZA", + "es-AR", + "es-BO", + "es-CL", + "es-CO", + "es-CR", + "es-DO", + "es-EC", + "es-ES", + "es-GT", + "es-HN", + "es-MX", + "es-NI", + "es-PA", + "es-PE", + "es-PR", + "es-PY", + "es-SV", + "es-US", + "es-UY", + "es-VE", + "et-EE", + "eu-ES", + "fa-IR", + "fi-FI", + "fil-PH", + "fr-BE", + "fr-CA", + "fr-CH", + "fr-FR", + "gl-ES", + "gu-IN", + "hi-IN", + "hr-HR", + "hu-HU", + "hy-AM", + "id-ID", + "is-IS", + "it-CH", + "it-IT", + "iw-IL", + "ja-JP", + "jv-ID", + "ka-GE", + "kk-KZ", + "km-KH", + "kn-IN", + "ko-KR", + "lo-LA", + "lt-LT", + "lv-LV", + "mk-MK", + "ml-IN", + "mn-MN", + "mr-IN", + "ms-MY", + "my-MM", + "ne-NP", + "nl-BE", + "nl-NL", + "no-NO", + "pa-Guru-IN", + "pl-PL", + "pt-BR", + "pt-PT", + "ro-RO", + "ru-RU", + "si-LK", + "sk-SK", + "sl-SI", + "sq-AL", + "sr-RS", + "su-ID", + "sv-SE", + "sw-KE", + "sw-TZ", + "ta-IN", + "ta-LK", + "ta-MY", + "ta-SG", + "te-IN", + "th-TH", + "tr-TR", + "uk-UA", + "ur-IN", + "ur-PK", + "uz-UZ", + "vi-VN", + "yue-Hant-HK", + "zu-ZA", +] diff --git a/homeassistant/components/google_cloud/manifest.json b/homeassistant/components/google_cloud/manifest.json index d0dda80a870..3e08b6254db 100644 --- a/homeassistant/components/google_cloud/manifest.json +++ b/homeassistant/components/google_cloud/manifest.json @@ -7,5 +7,8 @@ "documentation": "https://www.home-assistant.io/integrations/google_cloud", "integration_type": "service", "iot_class": "cloud_push", - "requirements": ["google-cloud-texttospeech==2.17.2"] + "requirements": [ + "google-cloud-texttospeech==2.17.2", + "google-cloud-speech==2.27.0" + ] } diff --git a/homeassistant/components/google_cloud/strings.json b/homeassistant/components/google_cloud/strings.json index 0a0804005de..3bf9d8c8489 100644 --- a/homeassistant/components/google_cloud/strings.json +++ b/homeassistant/components/google_cloud/strings.json @@ -24,7 +24,8 @@ "pitch": "Default pitch of the voice", "gain": "Default volume gain (in dB) of the voice", "profiles": "Default audio profiles", - "text_type": "Default text type" + "text_type": "Default text type", + "stt_model": "STT model" } } } diff --git a/homeassistant/components/google_cloud/stt.py b/homeassistant/components/google_cloud/stt.py new file mode 100644 index 00000000000..13715ae29f8 --- /dev/null +++ b/homeassistant/components/google_cloud/stt.py @@ -0,0 +1,147 @@ +"""Support for the Google Cloud STT service.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, AsyncIterable +import logging + +from google.api_core.exceptions import GoogleAPIError, Unauthenticated +from google.cloud import speech_v1 + +from homeassistant.components.stt import ( + AudioBitRates, + AudioChannels, + AudioCodecs, + AudioFormats, + AudioSampleRates, + SpeechMetadata, + SpeechResult, + SpeechResultState, + SpeechToTextEntity, +) +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import ( + CONF_SERVICE_ACCOUNT_INFO, + CONF_STT_MODEL, + DEFAULT_STT_MODEL, + DOMAIN, + STT_LANGUAGES, +) + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Google Cloud speech platform via config entry.""" + service_account_info = config_entry.data[CONF_SERVICE_ACCOUNT_INFO] + client = speech_v1.SpeechAsyncClient.from_service_account_info(service_account_info) + async_add_entities([GoogleCloudSpeechToTextEntity(config_entry, client)]) + + +class GoogleCloudSpeechToTextEntity(SpeechToTextEntity): + """Google Cloud STT entity.""" + + def __init__( + self, + entry: ConfigEntry, + client: speech_v1.SpeechAsyncClient, + ) -> None: + """Init Google Cloud STT entity.""" + self._attr_unique_id = f"{entry.entry_id}-stt" + self._attr_name = entry.title + self._attr_device_info = dr.DeviceInfo( + identifiers={(DOMAIN, entry.entry_id)}, + manufacturer="Google", + model="Cloud", + entry_type=dr.DeviceEntryType.SERVICE, + ) + self._entry = entry + self._client = client + self._model = entry.options.get(CONF_STT_MODEL, DEFAULT_STT_MODEL) + + @property + def supported_languages(self) -> list[str]: + """Return a list of supported languages.""" + return STT_LANGUAGES + + @property + def supported_formats(self) -> list[AudioFormats]: + """Return a list of supported formats.""" + return [AudioFormats.WAV, AudioFormats.OGG] + + @property + def supported_codecs(self) -> list[AudioCodecs]: + """Return a list of supported codecs.""" + return [AudioCodecs.PCM, AudioCodecs.OPUS] + + @property + def supported_bit_rates(self) -> list[AudioBitRates]: + """Return a list of supported bitrates.""" + return [AudioBitRates.BITRATE_16] + + @property + def supported_sample_rates(self) -> list[AudioSampleRates]: + """Return a list of supported samplerates.""" + return [AudioSampleRates.SAMPLERATE_16000] + + @property + def supported_channels(self) -> list[AudioChannels]: + """Return a list of supported channels.""" + return [AudioChannels.CHANNEL_MONO] + + async def async_process_audio_stream( + self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] + ) -> SpeechResult: + """Process an audio stream to STT service.""" + streaming_config = speech_v1.StreamingRecognitionConfig( + config=speech_v1.RecognitionConfig( + encoding=( + speech_v1.RecognitionConfig.AudioEncoding.OGG_OPUS + if metadata.codec == AudioCodecs.OPUS + else speech_v1.RecognitionConfig.AudioEncoding.LINEAR16 + ), + sample_rate_hertz=metadata.sample_rate, + language_code=metadata.language, + model=self._model, + ) + ) + + async def request_generator() -> ( + AsyncGenerator[speech_v1.StreamingRecognizeRequest] + ): + # The first request must only contain a streaming_config + yield speech_v1.StreamingRecognizeRequest(streaming_config=streaming_config) + # All subsequent requests must only contain audio_content + async for audio_content in stream: + yield speech_v1.StreamingRecognizeRequest(audio_content=audio_content) + + try: + responses = await self._client.streaming_recognize( + requests=request_generator(), + timeout=10, + ) + + transcript = "" + async for response in responses: + _LOGGER.debug("response: %s", response) + if not response.results: + continue + result = response.results[0] + if not result.alternatives: + continue + transcript += response.results[0].alternatives[0].transcript + except GoogleAPIError as err: + _LOGGER.error("Error occurred during Google Cloud STT call: %s", err) + if isinstance(err, Unauthenticated): + self._entry.async_start_reauth(self.hass) + return SpeechResult(None, SpeechResultState.ERROR) + + return SpeechResult(transcript, SpeechResultState.SUCCESS) diff --git a/requirements_all.txt b/requirements_all.txt index e14ef8af35a..ebc621486c9 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -985,6 +985,9 @@ google-api-python-client==2.71.0 # homeassistant.components.google_pubsub google-cloud-pubsub==2.23.0 +# homeassistant.components.google_cloud +google-cloud-speech==2.27.0 + # homeassistant.components.google_cloud google-cloud-texttospeech==2.17.2 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index b28d39da203..d151c8e6bc2 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -835,6 +835,9 @@ google-api-python-client==2.71.0 # homeassistant.components.google_pubsub google-cloud-pubsub==2.23.0 +# homeassistant.components.google_cloud +google-cloud-speech==2.27.0 + # homeassistant.components.google_cloud google-cloud-texttospeech==2.17.2 diff --git a/tests/components/google_cloud/test_config_flow.py b/tests/components/google_cloud/test_config_flow.py index a5a51052e66..e4b4631f223 100644 --- a/tests/components/google_cloud/test_config_flow.py +++ b/tests/components/google_cloud/test_config_flow.py @@ -161,6 +161,7 @@ async def test_options_flow( "gain", "profiles", "text_type", + "stt_model", } assert mock_api_tts_from_service_account_info.list_voices.call_count == 2 @@ -179,5 +180,6 @@ async def test_options_flow( "gain": 0.0, "profiles": [], "text_type": "text", + "stt_model": "latest_short", } assert mock_api_tts_from_service_account_info.list_voices.call_count == 3