Use language util in stt/tts (#91521)

* Use language util in stt/tts

* Test language util in stt/tts

* Fix common in TTS

* Update snapshot
This commit is contained in:
Michael Hansen 2023-04-17 22:23:43 -05:00 committed by GitHub
parent 95d16c9829
commit f96515b90a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 89 additions and 44 deletions

View File

@ -193,9 +193,18 @@ class SpeechToTextEntity(RestoreEntity):
@callback
def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider."""
if metadata.language not in self.supported_languages:
language_matches = language_util.matches(
metadata.language,
self.supported_languages,
)
if language_matches:
metadata.language = language_matches[0]
else:
return False
if (
metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates

View File

@ -11,6 +11,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util import language as language_util
from .const import (
DATA_PROVIDERS,
@ -158,9 +159,18 @@ class Provider(ABC):
@callback
def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider."""
if metadata.language not in self.supported_languages:
language_matches = language_util.matches(
metadata.language,
self.supported_languages,
)
if language_matches:
metadata.language = language_matches[0]
else:
return False
if (
metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates

View File

@ -276,13 +276,20 @@ class SpeechManager:
# Languages
language = language or provider.default_language
if (
language is None
or provider.supported_languages is None
or language not in provider.supported_languages
):
if language is None or provider.supported_languages is None:
raise HomeAssistantError(f"Not supported language {language}")
if language not in provider.supported_languages:
language_matches = language_util.matches(
language, provider.supported_languages
)
if language_matches:
# Choose best match
language = language_matches[0]
else:
raise HomeAssistantError(f"Not supported language {language}")
# Options
if (default_options := provider.default_options) and options:
merged_options = dict(default_options)

View File

@ -52,7 +52,7 @@ class BaseProvider:
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return ["en"]
return ["de-DE", "en-US"]
@property
def supported_formats(self) -> list[AudioFormats]:
@ -213,7 +213,7 @@ async def test_get_provider_info(
response = await client.get(f"/api/stt/{TEST_DOMAIN}")
assert response.status == HTTPStatus.OK
assert await response.json() == {
"languages": ["en"],
"languages": ["de-DE", "en-US"],
"formats": ["wav", "ogg"],
"codecs": ["pcm", "opus"],
"sample_rates": [16000],
@ -236,6 +236,7 @@ async def test_non_existing_provider(
response = await client.get("/api/stt/not_exist")
assert response.status == HTTPStatus.NOT_FOUND
# Language en is matched with en-US
response = await client.post(
"/api/stt/not_exist",
headers={
@ -258,6 +259,8 @@ async def test_stream_audio(
) -> None:
"""Test streaming audio and getting response."""
client = await hass_client()
# Language en is matched with en-US
response = await client.post(
f"/api/stt/{TEST_DOMAIN}",
headers={

View File

@ -16,9 +16,10 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from tests.common import MockPlatform
SUPPORT_LANGUAGES = ["de", "en", "en_US"]
SUPPORT_LANGUAGES = ["de_DE", "en_GB", "en_US"]
TEST_LANGUAGES = ["de", "en"]
DEFAULT_LANG = "en"
DEFAULT_LANG = "en_US"
class MockProvider(Provider):
@ -55,7 +56,11 @@ class MockTTS(MockPlatform):
"""A mock TTS platform."""
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
{
vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(
SUPPORT_LANGUAGES + TEST_LANGUAGES
)
}
)
def __init__(

View File

@ -49,7 +49,7 @@ async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> st
@pytest.fixture
def mock_provider() -> MockProvider:
"""Test TTS provider."""
return MockProvider("en")
return MockProvider("en_US")
@pytest.fixture
@ -104,11 +104,11 @@ async def test_setup_component_and_test_service(
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
== "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
)
await hass.async_block_till_done()
assert (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
).is_file()
@ -118,6 +118,7 @@ async def test_setup_component_and_test_service_with_config_language(
"""Set up a TTS platform and call service."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
# Language de is matched with de_DE
config = {tts.DOMAIN: {"platform": "test", "language": "de"}}
with assert_setup_component(1, tts.DOMAIN):
@ -136,11 +137,11 @@ async def test_setup_component_and_test_service_with_config_language(
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3"
== "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3"
)
await hass.async_block_till_done()
assert (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3"
).is_file()
@ -197,6 +198,7 @@ async def test_setup_component_and_test_service_with_service_language(
with assert_setup_component(1, tts.DOMAIN):
assert await async_setup_component(hass, tts.DOMAIN, config)
# Language de is matched to de_DE
await hass.services.async_call(
tts.DOMAIN,
"test_say",
@ -211,11 +213,11 @@ async def test_setup_component_and_test_service_with_service_language(
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3"
== "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3"
)
await hass.async_block_till_done()
assert (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3"
).is_file()
@ -258,6 +260,7 @@ async def test_setup_component_and_test_service_with_service_options(
with assert_setup_component(1, tts.DOMAIN):
assert await async_setup_component(hass, tts.DOMAIN, config)
# Language de is matched with de_DE
await hass.services.async_call(
tts.DOMAIN,
"test_say",
@ -275,12 +278,12 @@ async def test_setup_component_and_test_service_with_service_options(
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3"
== f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3"
)
await hass.async_block_till_done()
assert (
empty_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3"
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3"
).is_file()
@ -305,6 +308,7 @@ async def test_setup_component_and_test_with_service_options_def(
with assert_setup_component(1, tts.DOMAIN):
assert await async_setup_component(hass, tts.DOMAIN, config)
# Language de is matched with de_DE
await hass.services.async_call(
tts.DOMAIN,
"test_say",
@ -321,12 +325,12 @@ async def test_setup_component_and_test_with_service_options_def(
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3"
== f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3"
)
await hass.async_block_till_done()
assert (
empty_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3"
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3"
).is_file()
@ -354,6 +358,7 @@ async def test_setup_component_and_test_with_service_options_def_2(
with assert_setup_component(1, tts.DOMAIN):
assert await async_setup_component(hass, tts.DOMAIN, config)
# Language de is matched with de_DE
await hass.services.async_call(
tts.DOMAIN,
"test_say",
@ -371,12 +376,12 @@ async def test_setup_component_and_test_with_service_options_def_2(
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3"
== f"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3"
)
await hass.async_block_till_done()
assert (
empty_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3"
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3"
).is_file()
@ -391,6 +396,7 @@ async def test_setup_component_and_test_service_with_service_options_wrong(
with assert_setup_component(1, tts.DOMAIN):
assert await async_setup_component(hass, tts.DOMAIN, config)
# Language de is matched with de_DE
with pytest.raises(HomeAssistantError):
await hass.services.async_call(
tts.DOMAIN,
@ -409,7 +415,7 @@ async def test_setup_component_and_test_service_with_service_options_wrong(
await hass.async_block_till_done()
assert not (
empty_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_{opt_hash}_test.mp3"
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_{opt_hash}_test.mp3"
).is_file()
@ -439,7 +445,7 @@ async def test_setup_component_and_test_service_with_base_url_set(
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== "http://fnord"
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
"_en_-_test.mp3"
"_en-us_-_test.mp3"
)
@ -468,7 +474,7 @@ async def test_setup_component_and_test_service_clear_cache(
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
await hass.async_block_till_done()
assert (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
).is_file()
await hass.services.async_call(
@ -476,7 +482,7 @@ async def test_setup_component_and_test_service_clear_cache(
)
assert not (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
).is_file()
@ -510,10 +516,11 @@ async def test_setup_component_and_test_service_with_receive_voice(
url = await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
client = await hass_client()
req = await client.get(url)
# Language en is matched with en_US
_, tts_data = mock_provider.get_tts_audio("bla", "en")
assert tts_data is not None
tts_data = tts.SpeechManager.write_tags(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3",
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3",
tts_data,
mock_provider,
message,
@ -539,6 +546,7 @@ async def test_setup_component_and_test_service_with_receive_voice_german(
"""Set up a TTS platform and call service and receive voice."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
# Language de is matched with de_DE
config = {tts.DOMAIN: {"platform": "test", "language": "de"}}
with assert_setup_component(1, tts.DOMAIN):
@ -560,7 +568,7 @@ async def test_setup_component_and_test_service_with_receive_voice_german(
_, tts_data = mock_provider.get_tts_audio("bla", "de")
assert tts_data is not None
tts_data = tts.SpeechManager.write_tags(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de_-_test.mp3",
"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_test.mp3",
tts_data,
mock_provider,
"There is someone at the door.",
@ -582,7 +590,7 @@ async def test_setup_component_and_web_view_wrong_file(
client = await hass_client()
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
req = await client.get(url)
assert req.status == HTTPStatus.NOT_FOUND
@ -599,7 +607,7 @@ async def test_setup_component_and_web_view_wrong_filename(
client = await hass_client()
url = "/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd_en_-_test.mp3"
url = "/api/tts_proxy/265944dsk32c1b2a621be5930510bb2cd_en-us_-_test.mp3"
req = await client.get(url)
assert req.status == HTTPStatus.NOT_FOUND
@ -628,7 +636,7 @@ async def test_setup_component_test_without_cache(
assert len(calls) == 1
await hass.async_block_till_done()
assert not (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
).is_file()
@ -656,7 +664,7 @@ async def test_setup_component_test_with_cache_call_service_without_cache(
assert len(calls) == 1
await hass.async_block_till_done()
assert not (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
).is_file()
@ -666,10 +674,11 @@ async def test_setup_component_test_with_cache_dir(
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
# Language en is matched with en_US
_, tts_data = mock_provider.get_tts_audio("bla", "en")
assert tts_data is not None
cache_file = (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
)
with open(cache_file, "wb") as voice_file:
@ -705,7 +714,7 @@ async def test_setup_component_test_with_cache_dir(
assert len(calls) == 1
assert (
await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID])
== "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
== "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
)
await hass.async_block_till_done()
@ -753,10 +762,11 @@ async def test_setup_component_load_cache_retrieve_without_mem_cache(
mock_tts,
) -> None:
"""Set up component and load cache and get without mem cache."""
# Language en is matched with en_US
_, tts_data = mock_provider.get_tts_audio("bla", "en")
assert tts_data is not None
cache_file = (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
)
with open(cache_file, "wb") as voice_file:
@ -769,7 +779,7 @@ async def test_setup_component_load_cache_retrieve_without_mem_cache(
client = await hass_client()
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
req = await client.get(url)
assert req.status == HTTPStatus.OK
@ -793,8 +803,8 @@ async def test_setup_component_and_web_get_url(
assert req.status == HTTPStatus.OK
response = await req.json()
assert response == {
"url": "http://example.local:8123/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3",
"path": "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3",
"url": "http://example.local:8123/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3",
"path": "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3",
}
@ -825,7 +835,7 @@ async def test_tags_with_wave(hass: HomeAssistant, mock_provider: MockProvider)
)
tagged_data = ORIG_WRITE_TAGS(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav",
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.wav",
tts_data,
mock_provider,
"AI person is in front of your door.",
@ -937,6 +947,7 @@ def test_resolve_engine(hass: HomeAssistant, setup_tts) -> None:
async def test_support_options(hass: HomeAssistant, setup_tts) -> None:
"""Test supporting options."""
# Language en is matched with en_US
assert await tts.async_support_options(hass, "test", "en") is True
assert await tts.async_support_options(hass, "test", "nl") is False
assert (