Remove fuzzy language matching from stt and tts (#92002)

* Remove fuzzy language matching from stt and tts

* Update tests
This commit is contained in:
Erik Montnemery 2023-04-25 17:54:42 +02:00 committed by GitHub
parent d1e6e4078c
commit 792ea92e55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 57 additions and 97 deletions

View File

@ -221,18 +221,9 @@ class SpeechToTextEntity(RestoreEntity):
@callback
def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider."""
if metadata.language not in self.supported_languages:
language_matches = language_util.matches(
metadata.language,
self.supported_languages,
)
if language_matches:
metadata.language = language_matches[0]
else:
return False
if (
metadata.format not in self.supported_formats
metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates

View File

@ -11,7 +11,6 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.util import language as language_util
from .const import (
DATA_PROVIDERS,
@ -163,18 +162,9 @@ class Provider(ABC):
@callback
def check_metadata(self, metadata: SpeechMetadata) -> bool:
"""Check if given metadata supported by this provider."""
if metadata.language not in self.supported_languages:
language_matches = language_util.matches(
metadata.language,
self.supported_languages,
)
if language_matches:
metadata.language = language_matches[0]
else:
return False
if (
metadata.format not in self.supported_formats
metadata.language not in self.supported_languages
or metadata.format not in self.supported_formats
or metadata.codec not in self.supported_codecs
or metadata.bit_rate not in self.supported_bit_rates
or metadata.sample_rate not in self.supported_sample_rates

View File

@ -483,20 +483,13 @@ class SpeechManager:
"""Validate and process options."""
# Languages
language = language or engine_instance.default_language
if language is None or engine_instance.supported_languages is None:
if (
language is None
or engine_instance.supported_languages is None
or language not in engine_instance.supported_languages
):
raise HomeAssistantError(f"Not supported language {language}")
if language not in engine_instance.supported_languages:
language_matches = language_util.matches(
language, engine_instance.supported_languages
)
if language_matches:
# Choose best match
language = language_matches[0]
else:
raise HomeAssistantError(f"Not supported language {language}")
# Options
if (default_options := engine_instance.default_options) and options:
merged_options = dict(default_options)

View File

@ -151,7 +151,7 @@
dict({
'data': dict({
'engine': 'test',
'language': 'en-UA',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'Arnold Schwarzenegger',
}),
@ -160,7 +160,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-UA&voice=Arnold+Schwarzenegger",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}),
@ -238,7 +238,7 @@
dict({
'data': dict({
'engine': 'test',
'language': 'en-AU',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': 'Arnold Schwarzenegger',
}),
@ -247,7 +247,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-AU&voice=Arnold+Schwarzenegger",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=Arnold+Schwarzenegger",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3',
}),

View File

@ -84,9 +84,9 @@ async def test_pipeline_from_audio_stream_legacy(
"language": "en",
"name": "test_name",
"stt_engine": "test",
"stt_language": "en-UK",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-AU",
"tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger",
}
)
@ -150,9 +150,9 @@ async def test_pipeline_from_audio_stream_entity(
"language": "en",
"name": "test_name",
"stt_engine": mock_stt_provider_entity.entity_id,
"stt_language": "en-UK",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-UA",
"tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger",
}
)

View File

@ -54,7 +54,7 @@ class BaseProvider:
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return ["de", "de-CH", "en-US"]
return ["de", "de-CH", "en"]
@property
def supported_formats(self) -> list[AudioFormats]:
@ -224,7 +224,7 @@ async def test_get_provider_info(
response = await client.get(f"/api/stt/{setup.url_path}")
assert response.status == HTTPStatus.OK
assert await response.json() == {
"languages": ["de", "de-CH", "en-US"],
"languages": ["de", "de-CH", "en"],
"formats": ["wav", "ogg"],
"codecs": ["pcm", "opus"],
"sample_rates": [16000],
@ -247,7 +247,6 @@ async def test_non_existing_provider(
response = await client.get("/api/stt/not_exist")
assert response.status == HTTPStatus.NOT_FOUND
# Language en is matched with en-US
response = await client.post(
"/api/stt/not_exist",
headers={
@ -270,8 +269,6 @@ async def test_stream_audio(
) -> None:
"""Test streaming audio and getting response."""
client = await hass_client()
# Language en is matched with en-US
response = await client.post(
f"/api/stt/{setup.url_path}",
headers={
@ -404,7 +401,7 @@ async def test_ws_list_engines(
assert msg["success"]
assert msg["result"] == {
"providers": [
{"engine_id": engine_id, "supported_languages": ["de", "de-CH", "en-US"]}
{"engine_id": engine_id, "supported_languages": ["de", "de-CH", "en"]}
]
}
@ -421,7 +418,7 @@ async def test_ws_list_engines(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"providers": [{"engine_id": engine_id, "supported_languages": ["en-US"]}]
"providers": [{"engine_id": engine_id, "supported_languages": ["en"]}]
}
await client.send_json_auto_id({"type": "stt/engine/list", "language": "en-UK"})
@ -429,7 +426,7 @@ async def test_ws_list_engines(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"providers": [{"engine_id": engine_id, "supported_languages": ["en-US"]}]
"providers": [{"engine_id": engine_id, "supported_languages": ["en"]}]
}
await client.send_json_auto_id({"type": "stt/engine/list", "language": "de"})

View File

@ -32,7 +32,6 @@ from tests.common import (
DEFAULT_LANG = "en_US"
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
TEST_DOMAIN = "test"
TEST_LANGUAGES = ["de", "en"]
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
@ -105,11 +104,7 @@ class MockTTS(MockPlatform):
"""A mock TTS platform."""
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{
vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(
SUPPORT_LANGUAGES + TEST_LANGUAGES
)
}
{vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
)
def __init__(self, provider: MockProvider, **kwargs: Any) -> None:

View File

@ -217,9 +217,9 @@ async def test_service(
).is_file()
# Language de is matched with de_DE
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"), [(MockProvider("de"), MockTTSEntity("de"))]
("mock_provider", "mock_tts_entity"),
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -346,7 +346,7 @@ async def test_service_default_special_language(
{
ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
},
"test",
),
@ -357,7 +357,7 @@ async def test_service_default_special_language(
ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
},
"tts.test",
),
@ -455,7 +455,7 @@ async def test_service_wrong_language(
{
ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"voice": "alex", "age": 5},
},
"test",
@ -467,7 +467,7 @@ async def test_service_wrong_language(
ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"voice": "alex", "age": 5},
},
"tts.test",
@ -541,7 +541,7 @@ class MockEntityWithDefaults(MockTTSEntity):
{
ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
},
"test",
),
@ -552,7 +552,7 @@ class MockEntityWithDefaults(MockTTSEntity):
ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
},
"tts.test",
),
@ -607,7 +607,7 @@ async def test_service_default_options(
{
ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"age": 5},
},
"test",
@ -619,7 +619,7 @@ async def test_service_default_options(
ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"age": 5},
},
"tts.test",
@ -674,7 +674,7 @@ async def test_merge_default_service_options(
{
ATTR_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"speed": 1},
},
"test",
@ -686,7 +686,7 @@ async def test_merge_default_service_options(
ATTR_ENTITY_ID: "tts.test",
tts.ATTR_MEDIA_PLAYER_ENTITY_ID: "media_player.something",
tts.ATTR_MESSAGE: "There is someone at the door.",
tts.ATTR_LANGUAGE: "de",
tts.ATTR_LANGUAGE: "de_DE",
tts.ATTR_OPTIONS: {"speed": 1},
},
"tts.test",
@ -855,7 +855,8 @@ async def test_service_receive_voice(
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"), [(MockProvider("de"), MockTTSEntity("de"))]
("mock_provider", "mock_tts_entity"),
[(MockProvider("de_DE"), MockTTSEntity("de_DE"))],
)
@pytest.mark.parametrize(
("setup", "tts_service", "service_data", "expected_url_suffix"),
@ -1047,7 +1048,6 @@ async def test_setup_legacy_cache_dir(
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
# Language en is matched with en_US
tts_data = b""
cache_file = (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
@ -1084,7 +1084,6 @@ async def test_setup_cache_dir(
"""Set up a TTS platform with cache and call service without cache."""
calls = async_mock_service(hass, DOMAIN_MP, SERVICE_PLAY_MEDIA)
# Language en is matched with en_US
tts_data = b""
cache_file = empty_cache_dir / (
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
@ -1187,10 +1186,9 @@ async def test_load_cache_legacy_retrieve_without_mem_cache(
hass_client: ClientSessionGenerator,
) -> None:
"""Set up component and load cache and get without mem cache."""
# Language en is matched with en_US
tts_data = b""
cache_file = (
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
empty_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
)
with open(cache_file, "wb") as voice_file:
@ -1200,7 +1198,7 @@ async def test_load_cache_legacy_retrieve_without_mem_cache(
client = await hass_client()
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3"
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
req = await client.get(url)
assert req.status == HTTPStatus.OK
@ -1214,7 +1212,6 @@ async def test_load_cache_retrieve_without_mem_cache(
hass_client: ClientSessionGenerator,
) -> None:
"""Set up component and load cache and get without mem cache."""
# Language en is matched with en_US
tts_data = b""
cache_file = empty_cache_dir / (
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
@ -1306,7 +1303,7 @@ async def test_tags_with_wave() -> None:
)
tagged_data = ORIG_WRITE_TAGS(
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.wav",
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.wav",
tts_data,
"Test",
"AI person is in front of your door.",
@ -1367,9 +1364,9 @@ def test_invalid_base_url(value) -> None:
("engine", "language", "options", "cache", "result_query"),
(
(None, None, None, None, ""),
(None, "de", None, None, "language=de"),
(None, "de", {"voice": "henk"}, None, "language=de&voice=henk"),
(None, "de", None, True, "cache=true&language=de"),
(None, "de_DE", None, None, "language=de_DE"),
(None, "de_DE", {"voice": "henk"}, None, "language=de_DE&voice=henk"),
(None, "de_DE", None, True, "cache=true&language=de_DE"),
),
)
async def test_generate_media_source_id(
@ -1456,11 +1453,12 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
)
async def test_support_options(hass: HomeAssistant, setup: str, engine_id: str) -> None:
"""Test supporting options."""
# Language en is matched with en_US
assert await tts.async_support_options(hass, engine_id, "en") is True
assert await tts.async_support_options(hass, engine_id, "en_US") is True
assert await tts.async_support_options(hass, engine_id, "nl") is False
assert (
await tts.async_support_options(hass, engine_id, "en", {"invalid_option": "yo"})
await tts.async_support_options(
hass, engine_id, "en_US", {"invalid_option": "yo"}
)
is False
)
@ -1496,7 +1494,7 @@ async def test_legacy_fetching_in_async(
# Test async_get_media_source_audio
media_source_id = tts.generate_media_source_id(
hass, "test message", "test", "en", None, None
hass, "test message", "test", "en_US", None, None
)
task = hass.async_create_task(
@ -1526,7 +1524,7 @@ async def test_legacy_fetching_in_async(
# Test error is not cached
media_source_id = tts.generate_media_source_id(
hass, "test message 2", "test", "en", None, None
hass, "test message 2", "test", "en_US", None, None
)
tts_audio = asyncio.Future()
tts_audio.set_exception(HomeAssistantError("test error"))
@ -1569,7 +1567,7 @@ async def test_fetching_in_async(
# Test async_get_media_source_audio
media_source_id = tts.generate_media_source_id(
hass, "test message", "tts.test", "en", None, None
hass, "test message", "tts.test", "en_US", None, None
)
task = hass.async_create_task(
@ -1599,7 +1597,7 @@ async def test_fetching_in_async(
# Test error is not cached
media_source_id = tts.generate_media_source_id(
hass, "test message 2", "tts.test", "en", None, None
hass, "test message 2", "tts.test", "en_US", None, None
)
tts_audio = asyncio.Future()
tts_audio.set_exception(HomeAssistantError("test error"))

View File

@ -109,7 +109,7 @@ async def test_legacy_resolving(hass: HomeAssistant, mock_provider: MSProvider)
mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media(
hass,
"media-source://tts/test?message=Bye%20World&language=de&voice=Paulus",
"media-source://tts/test?message=Bye%20World&language=de_DE&voice=Paulus",
None,
)
assert media.url.startswith("/api/tts_proxy/")
@ -144,7 +144,7 @@ async def test_resolving(hass: HomeAssistant, mock_tts_entity: MSEntity) -> None
mock_get_tts_audio.reset_mock()
media = await media_source.async_resolve_media(
hass,
"media-source://tts/tts.test?message=Bye%20World&language=de&voice=Paulus",
"media-source://tts/tts.test?message=Bye%20World&language=de_DE&voice=Paulus",
None,
)
assert media.url.startswith("/api/tts_proxy/")

View File

@ -104,7 +104,7 @@ async def test_setup_service(
"name": "tts_test",
"entity_id": "tts.test",
"media_player": "media_player.demo",
"language": "en",
"language": "en_US",
},
}

View File

@ -53,9 +53,7 @@ async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) ->
) as mock_client:
extension, data = await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(
hass, "Hello world", "tts.test_tts", hass.config.language
),
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
)
assert extension == "wav"
@ -89,7 +87,7 @@ async def test_get_tts_audio_raw(
hass,
"Hello world",
"tts.test_tts",
hass.config.language,
"en-US",
options={tts.ATTR_AUDIO_OUTPUT: "raw"},
),
)
@ -109,9 +107,7 @@ async def test_get_tts_audio_connection_lost(
), pytest.raises(HomeAssistantError):
await tts.async_get_media_source_audio(
hass,
tts.generate_media_source_id(
hass, "Hello world", "tts.test_tts", hass.config.language
),
tts.generate_media_source_id(hass, "Hello world", "tts.test_tts", "en-US"),
)