From f96515b90a5afe74b5a848f97715d06568909ea4 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 17 Apr 2023 22:23:43 -0500 Subject: [PATCH] Use language util in stt/tts (#91521) * Use language util in stt/tts * Test language util in stt/tts * Fix common in TTS * Update snapshot --- homeassistant/components/stt/__init__.py | 13 ++++- homeassistant/components/stt/legacy.py | 14 ++++- homeassistant/components/tts/__init__.py | 17 ++++-- tests/components/stt/test_init.py | 7 ++- tests/components/tts/common.py | 11 +++- tests/components/tts/test_init.py | 71 ++++++++++++++---------- 6 files changed, 89 insertions(+), 44 deletions(-) diff --git a/homeassistant/components/stt/__init__.py b/homeassistant/components/stt/__init__.py index e0ab446153c..4873403d59d 100644 --- a/homeassistant/components/stt/__init__.py +++ b/homeassistant/components/stt/__init__.py @@ -193,9 +193,18 @@ class SpeechToTextEntity(RestoreEntity): @callback def check_metadata(self, metadata: SpeechMetadata) -> bool: """Check if given metadata supported by this provider.""" + if metadata.language not in self.supported_languages: + language_matches = language_util.matches( + metadata.language, + self.supported_languages, + ) + if language_matches: + metadata.language = language_matches[0] + else: + return False + if ( - metadata.language not in self.supported_languages - or metadata.format not in self.supported_formats + metadata.format not in self.supported_formats or metadata.codec not in self.supported_codecs or metadata.bit_rate not in self.supported_bit_rates or metadata.sample_rate not in self.supported_sample_rates diff --git a/homeassistant/components/stt/legacy.py b/homeassistant/components/stt/legacy.py index ffa21a257f1..be8429b9a4f 100644 --- a/homeassistant/components/stt/legacy.py +++ b/homeassistant/components/stt/legacy.py @@ -11,6 +11,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_prepare_setup_platform +from homeassistant.util import language as language_util from .const import ( DATA_PROVIDERS, @@ -158,9 +159,18 @@ class Provider(ABC): @callback def check_metadata(self, metadata: SpeechMetadata) -> bool: """Check if given metadata supported by this provider.""" + if metadata.language not in self.supported_languages: + language_matches = language_util.matches( + metadata.language, + self.supported_languages, + ) + if language_matches: + metadata.language = language_matches[0] + else: + return False + if ( - metadata.language not in self.supported_languages - or metadata.format not in self.supported_formats + metadata.format not in self.supported_formats or metadata.codec not in self.supported_codecs or metadata.bit_rate not in self.supported_bit_rates or metadata.sample_rate not in self.supported_sample_rates diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index c3349a7bf3e..27b9a754000 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -276,13 +276,20 @@ class SpeechManager: # Languages language = language or provider.default_language - if ( - language is None - or provider.supported_languages is None - or language not in provider.supported_languages - ): + + if language is None or provider.supported_languages is None: raise HomeAssistantError(f"Not supported language {language}") + if language not in provider.supported_languages: + language_matches = language_util.matches( + language, provider.supported_languages + ) + if language_matches: + # Choose best match + language = language_matches[0] + else: + raise HomeAssistantError(f"Not supported language {language}") + # Options if (default_options := provider.default_options) and options: merged_options = dict(default_options) diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index fd3dbba0e1f..5f3f491904a 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -52,7 +52,7 @@ class BaseProvider: @property def supported_languages(self) -> list[str]: """Return a list of supported languages.""" - return ["en"] + return ["de-DE", "en-US"] @property def supported_formats(self) -> list[AudioFormats]: @@ -213,7 +213,7 @@ async def test_get_provider_info( response = await client.get(f"/api/stt/{TEST_DOMAIN}") assert response.status == HTTPStatus.OK assert await response.json() == { - "languages": ["en"], + "languages": ["de-DE", "en-US"], "formats": ["wav", "ogg"], "codecs": ["pcm", "opus"], "sample_rates": [16000], @@ -236,6 +236,7 @@ async def test_non_existing_provider( response = await client.get("/api/stt/not_exist") assert response.status == HTTPStatus.NOT_FOUND + # Language en is matched with en-US response = await client.post( "/api/stt/not_exist", headers={ @@ -258,6 +259,8 @@ async def test_stream_audio( ) -> None: """Test streaming audio and getting response.""" client = await hass_client() + + # Language en is matched with en-US response = await client.post( f"/api/stt/{TEST_DOMAIN}", headers={ diff --git a/tests/components/tts/common.py b/tests/components/tts/common.py index fbdfad54e18..9e175089fc8 100644 --- a/tests/components/tts/common.py +++ b/tests/components/tts/common.py @@ -16,9 +16,10 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from tests.common import MockPlatform -SUPPORT_LANGUAGES = ["de", "en", "en_US"] +SUPPORT_LANGUAGES = ["de_DE", "en_GB", "en_US"] +TEST_LANGUAGES = ["de", "en"] -DEFAULT_LANG = "en" +DEFAULT_LANG = "en_US" class MockProvider(Provider): @@ -55,7 +56,11 @@ class MockTTS(MockPlatform): """A mock TTS platform.""" PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( - {vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)} + { + vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In( + SUPPORT_LANGUAGES + TEST_LANGUAGES + ) + } ) def __init__( diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 6e0a012bc37..bc176f08fda 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -49,7 +49,7 @@ async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> st @pytest.fixture def mock_provider() -> MockProvider: """Test TTS provider.""" - return MockProvider("en") + return MockProvider("en_US") @pytest.fixture @@ -104,11 +104,11 @@ async def test_setup_component_and_test_service( assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert ( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) - == "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + == "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ) await hass.async_block_till_done() assert ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ).is_file() @@ -118,6 +118,7 @@ async def test_setup_component_and_test_service_with_config_language( """Set up a TTS platform and call service.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + # Language de is matched with de_DE config = {tts.DOMAIN: {"platform": "test", "language": "de"}} with assert_setup_component(1, tts.DOMAIN): @@ -136,11 +137,11 @@ async def test_setup_component_and_test_service_with_config_language( assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert ( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) - == "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3" + == "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3" ) await hass.async_block_till_done() assert ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3" ).is_file() @@ -197,6 +198,7 @@ async def test_setup_component_and_test_service_with_service_language( with assert_setup_component(1, tts.DOMAIN): assert await async_setup_component(hass, tts.DOMAIN, config) + # Language de is matched to de_DE await hass.services.async_call( tts.DOMAIN, "test_say", @@ -211,11 +213,11 @@ async def test_setup_component_and_test_service_with_service_language( assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert ( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) - == "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3" + == "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3" ) await hass.async_block_till_done() assert ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3" ).is_file() @@ -258,6 +260,7 @@ async def test_setup_component_and_test_service_with_service_options( with assert_setup_component(1, tts.DOMAIN): assert await async_setup_component(hass, tts.DOMAIN, config) + # Language de is matched with de_DE await hass.services.async_call( tts.DOMAIN, "test_say", @@ -275,12 +278,12 @@ async def test_setup_component_and_test_service_with_service_options( assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert ( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) - == f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3" + == f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3" ) await hass.async_block_till_done() assert ( empty_cache_dir - / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3" + / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3" ).is_file() @@ -305,6 +308,7 @@ async def test_setup_component_and_test_with_service_options_def( with assert_setup_component(1, tts.DOMAIN): assert await async_setup_component(hass, tts.DOMAIN, config) + # Language de is matched with de_DE await hass.services.async_call( tts.DOMAIN, "test_say", @@ -321,12 +325,12 @@ async def test_setup_component_and_test_with_service_options_def( assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert ( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) - == f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3" + == f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3" ) await hass.async_block_till_done() assert ( empty_cache_dir - / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3" + / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3" ).is_file() @@ -354,6 +358,7 @@ async def test_setup_component_and_test_with_service_options_def_2( with assert_setup_component(1, tts.DOMAIN): assert await async_setup_component(hass, tts.DOMAIN, config) + # Language de is matched with de_DE await hass.services.async_call( tts.DOMAIN, "test_say", @@ -371,12 +376,12 @@ async def test_setup_component_and_test_with_service_options_def_2( assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert ( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) - == f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3" + == f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3" ) await hass.async_block_till_done() assert ( empty_cache_dir - / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3" + / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3" ).is_file() @@ -391,6 +396,7 @@ async def test_setup_component_and_test_service_with_service_options_wrong( with assert_setup_component(1, tts.DOMAIN): assert await async_setup_component(hass, tts.DOMAIN, config) + # Language de is matched with de_DE with pytest.raises(HomeAssistantError): await hass.services.async_call( tts.DOMAIN, @@ -409,7 +415,7 @@ async def test_setup_component_and_test_service_with_service_options_wrong( await hass.async_block_till_done() assert not ( empty_cache_dir - / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3" + / f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3" ).is_file() @@ -439,7 +445,7 @@ async def test_setup_component_and_test_service_with_base_url_set( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == "http://fnord" "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" - "_en_-_test.mp3" + "_en-us_-_test.mp3" ) @@ -468,7 +474,7 @@ async def test_setup_component_and_test_service_clear_cache( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) await hass.async_block_till_done() assert ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ).is_file() await hass.services.async_call( @@ -476,7 +482,7 @@ async def test_setup_component_and_test_service_clear_cache( ) assert not ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ).is_file() @@ -510,10 +516,11 @@ async def test_setup_component_and_test_service_with_receive_voice( url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) client = await hass_client() req = await client.get(url) + # Language en is matched with en_US _, tts_data = mock_provider.get_tts_audio("bla", "en") assert tts_data is not None tts_data = tts.SpeechManager.write_tags( - "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3", + "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3", tts_data, mock_provider, message, @@ -539,6 +546,7 @@ async def test_setup_component_and_test_service_with_receive_voice_german( """Set up a TTS platform and call service and receive voice.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + # Language de is matched with de_DE config = {tts.DOMAIN: {"platform": "test", "language": "de"}} with assert_setup_component(1, tts.DOMAIN): @@ -560,7 +568,7 @@ async def test_setup_component_and_test_service_with_receive_voice_german( _, tts_data = mock_provider.get_tts_audio("bla", "de") assert tts_data is not None tts_data = tts.SpeechManager.write_tags( - "42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3", + "42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3", tts_data, mock_provider, "There is someone at the door.", @@ -582,7 +590,7 @@ async def test_setup_component_and_web_view_wrong_file( client = await hass_client() - url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" req = await client.get(url) assert req.status == HTTPStatus.NOT_FOUND @@ -599,7 +607,7 @@ async def test_setup_component_and_web_view_wrong_filename( client = await hass_client() - url = "/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd_en_-_test.mp3" + url = "/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd_en-us_-_test.mp3" req = await client.get(url) assert req.status == HTTPStatus.NOT_FOUND @@ -628,7 +636,7 @@ async def test_setup_component_test_without_cache( assert len(calls) == 1 await hass.async_block_till_done() assert not ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ).is_file() @@ -656,7 +664,7 @@ async def test_setup_component_test_with_cache_call_service_without_cache( assert len(calls) == 1 await hass.async_block_till_done() assert not ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ).is_file() @@ -666,10 +674,11 @@ async def test_setup_component_test_with_cache_dir( """Set up a TTS platform with cache and call service without cache.""" calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA) + # Language en is matched with en_US _, tts_data = mock_provider.get_tts_audio("bla", "en") assert tts_data is not None cache_file = ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ) with open(cache_file, "wb") as voice_file: @@ -705,7 +714,7 @@ async def test_setup_component_test_with_cache_dir( assert len(calls) == 1 assert ( await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) - == "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + == "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ) await hass.async_block_till_done() @@ -753,10 +762,11 @@ async def test_setup_component_load_cache_retrieve_without_mem_cache( mock_tts, ) -> None: """Set up component and load cache and get without mem cache.""" + # Language en is matched with en_US _, tts_data = mock_provider.get_tts_audio("bla", "en") assert tts_data is not None cache_file = ( - empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" ) with open(cache_file, "wb") as voice_file: @@ -769,7 +779,7 @@ async def test_setup_component_load_cache_retrieve_without_mem_cache( client = await hass_client() - url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" + url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" req = await client.get(url) assert req.status == HTTPStatus.OK @@ -793,8 +803,8 @@ async def test_setup_component_and_web_get_url( assert req.status == HTTPStatus.OK response = await req.json() assert response == { - "url": "http://example.local:8123/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3", - "path": "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3", + "url": "http://example.local:8123/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3", + "path": "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3", } @@ -825,7 +835,7 @@ async def test_tags_with_wave(hass: HomeAssistant, mock_provider: MockProvider) ) tagged_data = ORIG_WRITE_TAGS( - "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav", + "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.wav", tts_data, mock_provider, "AI person is in front of your door.", @@ -937,6 +947,7 @@ def test_resolve_engine(hass: HomeAssistant, setup_tts) -> None: async def test_support_options(hass: HomeAssistant, setup_tts) -> None: """Test supporting options.""" + # Language en is matched with en_US assert await tts.async_support_options(hass, "test", "en") is True assert await tts.async_support_options(hass, "test", "nl") is False assert (