mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 02:37:08 +00:00
Assist satellite to use TTS tokens for announcements (#140336)
* Migrate Assist Satellite to use token * Fix tests * Fix tests
This commit is contained in:
parent
e710d3699c
commit
f32bb1a318
@ -23,9 +23,6 @@ from homeassistant.components.assist_pipeline import (
|
|||||||
vad,
|
vad,
|
||||||
)
|
)
|
||||||
from homeassistant.components.media_player import async_process_play_media_url
|
from homeassistant.components.media_player import async_process_play_media_url
|
||||||
from homeassistant.components.tts import (
|
|
||||||
generate_media_source_id as tts_generate_media_source_id,
|
|
||||||
)
|
|
||||||
from homeassistant.core import Context, callback
|
from homeassistant.core import Context, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import chat_session, entity
|
from homeassistant.helpers import chat_session, entity
|
||||||
@ -98,6 +95,9 @@ class AssistSatelliteAnnouncement:
|
|||||||
original_media_id: str
|
original_media_id: str
|
||||||
"""The raw media ID before processing."""
|
"""The raw media ID before processing."""
|
||||||
|
|
||||||
|
tts_token: str | None
|
||||||
|
"""The TTS token of the media."""
|
||||||
|
|
||||||
media_id_source: Literal["url", "media_id", "tts"]
|
media_id_source: Literal["url", "media_id", "tts"]
|
||||||
"""Source of the media ID."""
|
"""Source of the media ID."""
|
||||||
|
|
||||||
@ -474,6 +474,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
) -> AssistSatelliteAnnouncement:
|
) -> AssistSatelliteAnnouncement:
|
||||||
"""Resolve the media ID."""
|
"""Resolve the media ID."""
|
||||||
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
||||||
|
tts_token: str | None = None
|
||||||
|
|
||||||
if media_id:
|
if media_id:
|
||||||
original_media_id = media_id
|
original_media_id = media_id
|
||||||
@ -484,6 +485,10 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
pipeline_id = self._resolve_pipeline()
|
pipeline_id = self._resolve_pipeline()
|
||||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||||
|
|
||||||
|
engine = tts.async_resolve_engine(self.hass, pipeline.tts_engine)
|
||||||
|
if engine is None:
|
||||||
|
raise HomeAssistantError(f"TTS engine {pipeline.tts_engine} not found")
|
||||||
|
|
||||||
tts_options: dict[str, Any] = {}
|
tts_options: dict[str, Any] = {}
|
||||||
if pipeline.tts_voice is not None:
|
if pipeline.tts_voice is not None:
|
||||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||||
@ -491,14 +496,23 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
if self.tts_options is not None:
|
if self.tts_options is not None:
|
||||||
tts_options.update(self.tts_options)
|
tts_options.update(self.tts_options)
|
||||||
|
|
||||||
media_id = tts_generate_media_source_id(
|
stream = tts.async_create_stream(
|
||||||
self.hass,
|
self.hass,
|
||||||
message,
|
engine=engine,
|
||||||
engine=pipeline.tts_engine,
|
language=pipeline.tts_language,
|
||||||
|
options=tts_options,
|
||||||
|
)
|
||||||
|
stream.async_set_message(message)
|
||||||
|
|
||||||
|
tts_token = stream.token
|
||||||
|
media_id = stream.url
|
||||||
|
original_media_id = tts.generate_media_source_id(
|
||||||
|
self.hass,
|
||||||
|
message,
|
||||||
|
engine=engine,
|
||||||
language=pipeline.tts_language,
|
language=pipeline.tts_language,
|
||||||
options=tts_options,
|
options=tts_options,
|
||||||
)
|
)
|
||||||
original_media_id = media_id
|
|
||||||
|
|
||||||
if media_source.is_media_source_id(media_id):
|
if media_source.is_media_source_id(media_id):
|
||||||
if not media_id_source:
|
if not media_id_source:
|
||||||
@ -517,5 +531,9 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
media_id = async_process_play_media_url(self.hass, media_id)
|
media_id = async_process_play_media_url(self.hass, media_id)
|
||||||
|
|
||||||
return AssistSatelliteAnnouncement(
|
return AssistSatelliteAnnouncement(
|
||||||
message, media_id, original_media_id, media_id_source
|
message=message,
|
||||||
|
media_id=media_id,
|
||||||
|
original_media_id=original_media_id,
|
||||||
|
tts_token=tts_token,
|
||||||
|
media_id_source=media_id_source,
|
||||||
)
|
)
|
||||||
|
@ -31,6 +31,8 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
from . import ENTITY_ID
|
from . import ENTITY_ID
|
||||||
from .conftest import MockAssistSatellite
|
from .conftest import MockAssistSatellite
|
||||||
|
|
||||||
|
from tests.components.tts.common import MockResultStream
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_chat_session_conversation_id() -> Generator[Mock]:
|
def mock_chat_session_conversation_id() -> Generator[Mock]:
|
||||||
@ -186,8 +188,9 @@ async def test_new_pipeline_cancels_pipeline(
|
|||||||
{"message": "Hello"},
|
{"message": "Hello"},
|
||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="Hello",
|
message="Hello",
|
||||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
|
||||||
original_media_id="media-source://bla",
|
original_media_id="media-source://bla",
|
||||||
|
tts_token="test-token",
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -200,6 +203,7 @@ async def test_new_pipeline_cancels_pipeline(
|
|||||||
message="Hello",
|
message="Hello",
|
||||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||||
original_media_id="media-source://given",
|
original_media_id="media-source://given",
|
||||||
|
tts_token=None,
|
||||||
media_id_source="media_id",
|
media_id_source="media_id",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -209,6 +213,7 @@ async def test_new_pipeline_cancels_pipeline(
|
|||||||
message="",
|
message="",
|
||||||
media_id="http://example.com/bla.mp3",
|
media_id="http://example.com/bla.mp3",
|
||||||
original_media_id="http://example.com/bla.mp3",
|
original_media_id="http://example.com/bla.mp3",
|
||||||
|
tts_token=None,
|
||||||
media_id_source="url",
|
media_id_source="url",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -243,9 +248,17 @@ async def test_announce(
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
"homeassistant.components.tts.generate_media_source_id",
|
||||||
new=tts_generate_media_source_id,
|
new=tts_generate_media_source_id,
|
||||||
),
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
|
return_value="tts.cloud",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_create_stream",
|
||||||
|
return_value=MockResultStream(hass, "wav", b""),
|
||||||
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.media_source.async_resolve_media",
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
return_value=PlayMedia(
|
return_value=PlayMedia(
|
||||||
@ -500,7 +513,8 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
"Better system prompt",
|
"Better system prompt",
|
||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="Hello",
|
message="Hello",
|
||||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
|
||||||
|
tts_token="test-token",
|
||||||
original_media_id="media-source://generated",
|
original_media_id="media-source://generated",
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
),
|
),
|
||||||
@ -517,6 +531,7 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="Hello",
|
message="Hello",
|
||||||
media_id="https://www.home-assistant.io/resolved.mp3",
|
media_id="https://www.home-assistant.io/resolved.mp3",
|
||||||
|
tts_token=None,
|
||||||
original_media_id="media-source://given",
|
original_media_id="media-source://given",
|
||||||
media_id_source="media_id",
|
media_id_source="media_id",
|
||||||
),
|
),
|
||||||
@ -530,6 +545,7 @@ async def test_vad_sensitivity_entity_not_found(
|
|||||||
AssistSatelliteAnnouncement(
|
AssistSatelliteAnnouncement(
|
||||||
message="",
|
message="",
|
||||||
media_id="http://example.com/given.mp3",
|
media_id="http://example.com/given.mp3",
|
||||||
|
tts_token=None,
|
||||||
original_media_id="http://example.com/given.mp3",
|
original_media_id="http://example.com/given.mp3",
|
||||||
media_id_source="url",
|
media_id_source="url",
|
||||||
),
|
),
|
||||||
@ -554,9 +570,17 @@ async def test_start_conversation(
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
"homeassistant.components.tts.generate_media_source_id",
|
||||||
return_value="media-source://generated",
|
return_value="media-source://generated",
|
||||||
),
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
|
return_value="tts.cloud",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_create_stream",
|
||||||
|
return_value=MockResultStream(hass, "wav", b""),
|
||||||
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.media_source.async_resolve_media",
|
"homeassistant.components.media_source.async_resolve_media",
|
||||||
return_value=PlayMedia(
|
return_value=PlayMedia(
|
||||||
|
@ -4,28 +4,28 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.media_source import PlayMedia
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from .conftest import TEST_DOMAIN, MockAssistSatellite
|
from .conftest import TEST_DOMAIN, MockAssistSatellite
|
||||||
|
|
||||||
|
from tests.components.tts.common import MockResultStream
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_tts():
|
async def mock_tts(hass: HomeAssistant):
|
||||||
"""Mock TTS service."""
|
"""Mock TTS service."""
|
||||||
|
assert await async_setup_component(hass, "tts", {})
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
"homeassistant.components.tts.generate_media_source_id",
|
||||||
return_value="media-source://bla",
|
return_value="media-source://bla",
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.media_source.async_resolve_media",
|
"homeassistant.components.tts.async_create_stream",
|
||||||
return_value=PlayMedia(
|
return_value=MockResultStream(hass, "wav", b""),
|
||||||
url="https://www.home-assistant.io/resolved.mp3",
|
|
||||||
mime_type="audio/mp3",
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
@ -41,9 +41,13 @@ async def test_broadcast_intent(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test we can invoke a broadcast intent."""
|
"""Test we can invoke a broadcast intent."""
|
||||||
|
|
||||||
result = await intent.async_handle(
|
with patch(
|
||||||
hass, "test", intent.INTENT_BROADCAST, {"message": {"value": "Hello"}}
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
)
|
return_value="tts.cloud",
|
||||||
|
):
|
||||||
|
result = await intent.async_handle(
|
||||||
|
hass, "test", intent.INTENT_BROADCAST, {"message": {"value": "Hello"}}
|
||||||
|
)
|
||||||
|
|
||||||
assert result.as_dict() == {
|
assert result.as_dict() == {
|
||||||
"card": {},
|
"card": {},
|
||||||
@ -71,13 +75,17 @@ async def test_broadcast_intent(
|
|||||||
assert len(entity2.announcements) == 1
|
assert len(entity2.announcements) == 1
|
||||||
assert len(entity_no_features.announcements) == 0
|
assert len(entity_no_features.announcements) == 0
|
||||||
|
|
||||||
result = await intent.async_handle(
|
with patch(
|
||||||
hass,
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
"test",
|
return_value="tts.cloud",
|
||||||
intent.INTENT_BROADCAST,
|
):
|
||||||
{"message": {"value": "Hello"}},
|
result = await intent.async_handle(
|
||||||
device_id=entity.device_entry.id,
|
hass,
|
||||||
)
|
"test",
|
||||||
|
intent.INTENT_BROADCAST,
|
||||||
|
{"message": {"value": "Hello"}},
|
||||||
|
device_id=entity.device_entry.id,
|
||||||
|
)
|
||||||
# Broadcast doesn't targets device that triggered it.
|
# Broadcast doesn't targets device that triggered it.
|
||||||
assert result.as_dict() == {
|
assert result.as_dict() == {
|
||||||
"card": {},
|
"card": {},
|
||||||
|
@ -41,7 +41,6 @@ from homeassistant.components.esphome.assist_satellite import (
|
|||||||
EsphomeAssistSatellite,
|
EsphomeAssistSatellite,
|
||||||
VoiceAssistantUDPServer,
|
VoiceAssistantUDPServer,
|
||||||
)
|
)
|
||||||
from homeassistant.components.media_source import PlayMedia
|
|
||||||
from homeassistant.components.select import (
|
from homeassistant.components.select import (
|
||||||
DOMAIN as SELECT_DOMAIN,
|
DOMAIN as SELECT_DOMAIN,
|
||||||
SERVICE_SELECT_OPTION,
|
SERVICE_SELECT_OPTION,
|
||||||
@ -57,6 +56,8 @@ from homeassistant.helpers.entity_component import EntityComponent
|
|||||||
|
|
||||||
from .conftest import MockESPHomeDevice
|
from .conftest import MockESPHomeDevice
|
||||||
|
|
||||||
|
from tests.components.tts.common import MockResultStream
|
||||||
|
|
||||||
|
|
||||||
def get_satellite_entity(
|
def get_satellite_entity(
|
||||||
hass: HomeAssistant, mac_address: str
|
hass: HomeAssistant, mac_address: str
|
||||||
@ -1209,22 +1210,23 @@ async def test_announce_message(
|
|||||||
media_id: str, timeout: float, text: str
|
media_id: str, timeout: float, text: str
|
||||||
):
|
):
|
||||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||||
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
|
||||||
assert text == "test-text"
|
assert text == "test-text"
|
||||||
|
|
||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
"homeassistant.components.tts.generate_media_source_id",
|
||||||
return_value="media-source://bla",
|
return_value="media-source://bla",
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.media_source.async_resolve_media",
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
return_value=PlayMedia(
|
return_value="tts.cloud_tts",
|
||||||
url="https://www.home-assistant.io/resolved.mp3",
|
),
|
||||||
mime_type="audio/mp3",
|
patch(
|
||||||
),
|
"homeassistant.components.tts.async_create_stream",
|
||||||
|
return_value=MockResultStream(hass, "wav", b""),
|
||||||
),
|
),
|
||||||
patch.object(
|
patch.object(
|
||||||
mock_client,
|
mock_client,
|
||||||
|
@ -270,6 +270,8 @@ async def mock_config_entry_setup(
|
|||||||
class MockResultStream(ResultStream):
|
class MockResultStream(ResultStream):
|
||||||
"""Mock result stream."""
|
"""Mock result stream."""
|
||||||
|
|
||||||
|
test_set_message: str | None = None
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant, extension: str, data: bytes) -> None:
|
def __init__(self, hass: HomeAssistant, extension: str, data: bytes) -> None:
|
||||||
"""Initialize the result stream."""
|
"""Initialize the result stream."""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -285,6 +287,11 @@ class MockResultStream(ResultStream):
|
|||||||
hass.data[DATA_TTS_MANAGER].token_to_stream[self.token] = self
|
hass.data[DATA_TTS_MANAGER].token_to_stream[self.token] = self
|
||||||
self._mock_data = data
|
self._mock_data = data
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_set_message(self, message: str) -> None:
|
||||||
|
"""Set message to be generated."""
|
||||||
|
self.test_set_message = message
|
||||||
|
|
||||||
async def async_stream_result(self):
|
async def async_stream_result(self):
|
||||||
"""Stream the result."""
|
"""Stream the result."""
|
||||||
yield self._mock_data
|
yield self._mock_data
|
||||||
|
@ -27,6 +27,8 @@ from homeassistant.helpers import entity_registry as er
|
|||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.components.tts.common import MockResultStream
|
||||||
|
|
||||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||||
_MEDIA_ID = "12345"
|
_MEDIA_ID = "12345"
|
||||||
|
|
||||||
@ -879,6 +881,7 @@ async def test_announce(
|
|||||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
|
tts_token="test-token",
|
||||||
original_media_id=_MEDIA_ID,
|
original_media_id=_MEDIA_ID,
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
@ -926,6 +929,7 @@ async def test_voip_id_is_ip_address(
|
|||||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
|
tts_token="test-token",
|
||||||
original_media_id=_MEDIA_ID,
|
original_media_id=_MEDIA_ID,
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
@ -978,6 +982,7 @@ async def test_announce_timeout(
|
|||||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
|
tts_token="test-token",
|
||||||
original_media_id=_MEDIA_ID,
|
original_media_id=_MEDIA_ID,
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
@ -1018,6 +1023,7 @@ async def test_start_conversation(
|
|||||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
|
tts_token="test-token",
|
||||||
original_media_id=_MEDIA_ID,
|
original_media_id=_MEDIA_ID,
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
@ -1162,8 +1168,16 @@ async def test_start_conversation_user_doesnt_pick_up(
|
|||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
"homeassistant.components.tts.generate_media_source_id",
|
||||||
return_value="test media id",
|
return_value="media-source://bla",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_resolve_engine",
|
||||||
|
return_value="test tts",
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_create_stream",
|
||||||
|
return_value=MockResultStream(hass, "wav", b""),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
satellite.transport = Mock()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user