mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 17:27:52 +00:00
Assist satellite to use TTS tokens for announcements (#140336)
* Migrate Assist Satellite to use token * Fix tests * Fix tests
This commit is contained in:
parent
e710d3699c
commit
f32bb1a318
@ -23,9 +23,6 @@ from homeassistant.components.assist_pipeline import (
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.components.tts import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session, entity
|
||||
@ -98,6 +95,9 @@ class AssistSatelliteAnnouncement:
|
||||
original_media_id: str
|
||||
"""The raw media ID before processing."""
|
||||
|
||||
tts_token: str | None
|
||||
"""The TTS token of the media."""
|
||||
|
||||
media_id_source: Literal["url", "media_id", "tts"]
|
||||
"""Source of the media ID."""
|
||||
|
||||
@ -474,6 +474,7 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
) -> AssistSatelliteAnnouncement:
|
||||
"""Resolve the media ID."""
|
||||
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
||||
tts_token: str | None = None
|
||||
|
||||
if media_id:
|
||||
original_media_id = media_id
|
||||
@ -484,6 +485,10 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
pipeline_id = self._resolve_pipeline()
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
|
||||
engine = tts.async_resolve_engine(self.hass, pipeline.tts_engine)
|
||||
if engine is None:
|
||||
raise HomeAssistantError(f"TTS engine {pipeline.tts_engine} not found")
|
||||
|
||||
tts_options: dict[str, Any] = {}
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
@ -491,14 +496,23 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
if self.tts_options is not None:
|
||||
tts_options.update(self.tts_options)
|
||||
|
||||
media_id = tts_generate_media_source_id(
|
||||
stream = tts.async_create_stream(
|
||||
self.hass,
|
||||
message,
|
||||
engine=pipeline.tts_engine,
|
||||
engine=engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
stream.async_set_message(message)
|
||||
|
||||
tts_token = stream.token
|
||||
media_id = stream.url
|
||||
original_media_id = tts.generate_media_source_id(
|
||||
self.hass,
|
||||
message,
|
||||
engine=engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
original_media_id = media_id
|
||||
|
||||
if media_source.is_media_source_id(media_id):
|
||||
if not media_id_source:
|
||||
@ -517,5 +531,9 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
media_id = async_process_play_media_url(self.hass, media_id)
|
||||
|
||||
return AssistSatelliteAnnouncement(
|
||||
message, media_id, original_media_id, media_id_source
|
||||
message=message,
|
||||
media_id=media_id,
|
||||
original_media_id=original_media_id,
|
||||
tts_token=tts_token,
|
||||
media_id_source=media_id_source,
|
||||
)
|
||||
|
@ -31,6 +31,8 @@ from homeassistant.exceptions import HomeAssistantError
|
||||
from . import ENTITY_ID
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
from tests.components.tts.common import MockResultStream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chat_session_conversation_id() -> Generator[Mock]:
|
||||
@ -186,8 +188,9 @@ async def test_new_pipeline_cancels_pipeline(
|
||||
{"message": "Hello"},
|
||||
AssistSatelliteAnnouncement(
|
||||
message="Hello",
|
||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
|
||||
original_media_id="media-source://bla",
|
||||
tts_token="test-token",
|
||||
media_id_source="tts",
|
||||
),
|
||||
),
|
||||
@ -200,6 +203,7 @@ async def test_new_pipeline_cancels_pipeline(
|
||||
message="Hello",
|
||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||
original_media_id="media-source://given",
|
||||
tts_token=None,
|
||||
media_id_source="media_id",
|
||||
),
|
||||
),
|
||||
@ -209,6 +213,7 @@ async def test_new_pipeline_cancels_pipeline(
|
||||
message="",
|
||||
media_id="http://example.com/bla.mp3",
|
||||
original_media_id="http://example.com/bla.mp3",
|
||||
tts_token=None,
|
||||
media_id_source="url",
|
||||
),
|
||||
),
|
||||
@ -243,9 +248,17 @@ async def test_announce(
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
"homeassistant.components.tts.generate_media_source_id",
|
||||
new=tts_generate_media_source_id,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_resolve_engine",
|
||||
return_value="tts.cloud",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_create_stream",
|
||||
return_value=MockResultStream(hass, "wav", b""),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
@ -500,7 +513,8 @@ async def test_vad_sensitivity_entity_not_found(
|
||||
"Better system prompt",
|
||||
AssistSatelliteAnnouncement(
|
||||
message="Hello",
|
||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
|
||||
tts_token="test-token",
|
||||
original_media_id="media-source://generated",
|
||||
media_id_source="tts",
|
||||
),
|
||||
@ -517,6 +531,7 @@ async def test_vad_sensitivity_entity_not_found(
|
||||
AssistSatelliteAnnouncement(
|
||||
message="Hello",
|
||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||
tts_token=None,
|
||||
original_media_id="media-source://given",
|
||||
media_id_source="media_id",
|
||||
),
|
||||
@ -530,6 +545,7 @@ async def test_vad_sensitivity_entity_not_found(
|
||||
AssistSatelliteAnnouncement(
|
||||
message="",
|
||||
media_id="http://example.com/given.mp3",
|
||||
tts_token=None,
|
||||
original_media_id="http://example.com/given.mp3",
|
||||
media_id_source="url",
|
||||
),
|
||||
@ -554,9 +570,17 @@ async def test_start_conversation(
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
"homeassistant.components.tts.generate_media_source_id",
|
||||
return_value="media-source://generated",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_resolve_engine",
|
||||
return_value="tts.cloud",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_create_stream",
|
||||
return_value=MockResultStream(hass, "wav", b""),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
|
@ -4,28 +4,28 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from .conftest import TEST_DOMAIN, MockAssistSatellite
|
||||
|
||||
from tests.components.tts.common import MockResultStream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts():
|
||||
async def mock_tts(hass: HomeAssistant):
|
||||
"""Mock TTS service."""
|
||||
assert await async_setup_component(hass, "tts", {})
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
"homeassistant.components.tts.generate_media_source_id",
|
||||
return_value="media-source://bla",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
url="https://www.home-assistant.io/resolved.mp3",
|
||||
mime_type="audio/mp3",
|
||||
),
|
||||
"homeassistant.components.tts.async_create_stream",
|
||||
return_value=MockResultStream(hass, "wav", b""),
|
||||
),
|
||||
):
|
||||
yield
|
||||
@ -41,9 +41,13 @@ async def test_broadcast_intent(
|
||||
) -> None:
|
||||
"""Test we can invoke a broadcast intent."""
|
||||
|
||||
result = await intent.async_handle(
|
||||
hass, "test", intent.INTENT_BROADCAST, {"message": {"value": "Hello"}}
|
||||
)
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_resolve_engine",
|
||||
return_value="tts.cloud",
|
||||
):
|
||||
result = await intent.async_handle(
|
||||
hass, "test", intent.INTENT_BROADCAST, {"message": {"value": "Hello"}}
|
||||
)
|
||||
|
||||
assert result.as_dict() == {
|
||||
"card": {},
|
||||
@ -71,13 +75,17 @@ async def test_broadcast_intent(
|
||||
assert len(entity2.announcements) == 1
|
||||
assert len(entity_no_features.announcements) == 0
|
||||
|
||||
result = await intent.async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent.INTENT_BROADCAST,
|
||||
{"message": {"value": "Hello"}},
|
||||
device_id=entity.device_entry.id,
|
||||
)
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_resolve_engine",
|
||||
return_value="tts.cloud",
|
||||
):
|
||||
result = await intent.async_handle(
|
||||
hass,
|
||||
"test",
|
||||
intent.INTENT_BROADCAST,
|
||||
{"message": {"value": "Hello"}},
|
||||
device_id=entity.device_entry.id,
|
||||
)
|
||||
# Broadcast doesn't targets device that triggered it.
|
||||
assert result.as_dict() == {
|
||||
"card": {},
|
||||
|
@ -41,7 +41,6 @@ from homeassistant.components.esphome.assist_satellite import (
|
||||
EsphomeAssistSatellite,
|
||||
VoiceAssistantUDPServer,
|
||||
)
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.components.select import (
|
||||
DOMAIN as SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
@ -57,6 +56,8 @@ from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .conftest import MockESPHomeDevice
|
||||
|
||||
from tests.components.tts.common import MockResultStream
|
||||
|
||||
|
||||
def get_satellite_entity(
|
||||
hass: HomeAssistant, mac_address: str
|
||||
@ -1209,22 +1210,23 @@ async def test_announce_message(
|
||||
media_id: str, timeout: float, text: str
|
||||
):
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
||||
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
|
||||
assert text == "test-text"
|
||||
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
"homeassistant.components.tts.generate_media_source_id",
|
||||
return_value="media-source://bla",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
url="https://www.home-assistant.io/resolved.mp3",
|
||||
mime_type="audio/mp3",
|
||||
),
|
||||
"homeassistant.components.tts.async_resolve_engine",
|
||||
return_value="tts.cloud_tts",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_create_stream",
|
||||
return_value=MockResultStream(hass, "wav", b""),
|
||||
),
|
||||
patch.object(
|
||||
mock_client,
|
||||
|
@ -270,6 +270,8 @@ async def mock_config_entry_setup(
|
||||
class MockResultStream(ResultStream):
|
||||
"""Mock result stream."""
|
||||
|
||||
test_set_message: str | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant, extension: str, data: bytes) -> None:
|
||||
"""Initialize the result stream."""
|
||||
super().__init__(
|
||||
@ -285,6 +287,11 @@ class MockResultStream(ResultStream):
|
||||
hass.data[DATA_TTS_MANAGER].token_to_stream[self.token] = self
|
||||
self._mock_data = data
|
||||
|
||||
@callback
|
||||
def async_set_message(self, message: str) -> None:
|
||||
"""Set message to be generated."""
|
||||
self.test_set_message = message
|
||||
|
||||
async def async_stream_result(self):
|
||||
"""Stream the result."""
|
||||
yield self._mock_data
|
||||
|
@ -27,6 +27,8 @@ from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.components.tts.common import MockResultStream
|
||||
|
||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||
_MEDIA_ID = "12345"
|
||||
|
||||
@ -879,6 +881,7 @@ async def test_announce(
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
tts_token="test-token",
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
@ -926,6 +929,7 @@ async def test_voip_id_is_ip_address(
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
tts_token="test-token",
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
@ -978,6 +982,7 @@ async def test_announce_timeout(
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
tts_token="test-token",
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
@ -1018,6 +1023,7 @@ async def test_start_conversation(
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
tts_token="test-token",
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
@ -1162,8 +1168,16 @@ async def test_start_conversation_user_doesnt_pick_up(
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
return_value="test media id",
|
||||
"homeassistant.components.tts.generate_media_source_id",
|
||||
return_value="media-source://bla",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_resolve_engine",
|
||||
return_value="test tts",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_create_stream",
|
||||
return_value=MockResultStream(hass, "wav", b""),
|
||||
),
|
||||
):
|
||||
satellite.transport = Mock()
|
||||
|
Loading…
x
Reference in New Issue
Block a user