mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Add start conversation support to ESPHome (#141387)
This commit is contained in:
parent
7319637bd5
commit
ae18fa2e30
@ -253,6 +253,11 @@ class EsphomeAssistSatellite(
|
||||
# Will use media player for TTS/announcements
|
||||
self._update_tts_format()
|
||||
|
||||
if feature_flags & VoiceAssistantFeature.START_CONVERSATION:
|
||||
self._attr_supported_features |= (
|
||||
assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||
)
|
||||
|
||||
# Update wake word select when config is updated
|
||||
self.async_on_remove(
|
||||
self.entry_data.async_register_assist_satellite_set_wake_word_callback(
|
||||
@ -342,6 +347,23 @@ class EsphomeAssistSatellite(
|
||||
|
||||
Should block until the announcement is done playing.
|
||||
"""
|
||||
await self._do_announce(announcement, run_pipeline_after=False)
|
||||
|
||||
async def async_start_conversation(
|
||||
self, start_announcement: assist_satellite.AssistSatelliteAnnouncement
|
||||
) -> None:
|
||||
"""Start a conversation from the satellite."""
|
||||
await self._do_announce(start_announcement, run_pipeline_after=True)
|
||||
|
||||
async def _do_announce(
|
||||
self,
|
||||
announcement: assist_satellite.AssistSatelliteAnnouncement,
|
||||
run_pipeline_after: bool,
|
||||
) -> None:
|
||||
"""Announce media on the satellite.
|
||||
|
||||
Optionally run a voice pipeline after the announcement has finished.
|
||||
"""
|
||||
_LOGGER.debug(
|
||||
"Waiting for announcement to finished (message=%s, media_id=%s)",
|
||||
announcement.message,
|
||||
@ -374,7 +396,10 @@ class EsphomeAssistSatellite(
|
||||
media_id = async_process_play_media_url(self.hass, proxy_url)
|
||||
|
||||
await self.cli.send_voice_assistant_announcement_await_response(
|
||||
media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
|
||||
media_id,
|
||||
_ANNOUNCEMENT_TIMEOUT_SEC,
|
||||
announcement.message,
|
||||
start_conversation=run_pipeline_after,
|
||||
)
|
||||
|
||||
async def handle_pipeline_start(
|
||||
|
@ -25,7 +25,12 @@ from aioesphomeapi import (
|
||||
)
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import assist_satellite, conversation, tts
|
||||
from homeassistant.components import (
|
||||
assist_pipeline,
|
||||
assist_satellite,
|
||||
conversation,
|
||||
tts,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteConfiguration,
|
||||
@ -1160,7 +1165,7 @@ async def test_announce_supported_features(
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that the announce supported feature is set by flags."""
|
||||
"""Test that the announce supported feature is not set by default."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
@ -1207,11 +1212,12 @@ async def test_announce_message(
|
||||
done = asyncio.Event()
|
||||
|
||||
async def send_voice_assistant_announcement_await_response(
|
||||
media_id: str, timeout: float, text: str
|
||||
media_id: str, timeout: float, text: str, start_conversation: bool
|
||||
):
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
|
||||
assert text == "test-text"
|
||||
assert not start_conversation
|
||||
|
||||
done.set()
|
||||
|
||||
@ -1296,10 +1302,11 @@ async def test_announce_media_id(
|
||||
done = asyncio.Event()
|
||||
|
||||
async def send_voice_assistant_announcement_await_response(
|
||||
media_id: str, timeout: float, text: str
|
||||
media_id: str, timeout: float, text: str, start_conversation: bool
|
||||
):
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
assert media_id == "https://www.home-assistant.io/proxied.flac"
|
||||
assert not start_conversation
|
||||
|
||||
done.set()
|
||||
|
||||
@ -1338,6 +1345,234 @@ async def test_announce_media_id(
|
||||
)
|
||||
|
||||
|
||||
async def test_start_conversation_supported_features(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that the start conversation supported feature is not set by default."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
assert not (
|
||||
satellite.supported_features & AssistSatelliteEntityFeature.START_CONVERSATION
|
||||
)
|
||||
|
||||
|
||||
async def test_start_conversation_message(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test start conversation with message."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.SPEAKER
|
||||
| VoiceAssistantFeature.API_AUDIO
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
| VoiceAssistantFeature.START_CONVERSATION
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
pipeline = assist_pipeline.Pipeline(
|
||||
conversation_engine="test engine",
|
||||
conversation_language="en",
|
||||
language="en",
|
||||
name="test pipeline",
|
||||
stt_engine="test stt",
|
||||
stt_language="en",
|
||||
tts_engine="test tts",
|
||||
tts_language="en",
|
||||
tts_voice=None,
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def send_voice_assistant_announcement_await_response(
|
||||
media_id: str, timeout: float, text: str, start_conversation: bool
|
||||
):
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
|
||||
assert text == "test-text"
|
||||
assert start_conversation
|
||||
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.tts.generate_media_source_id",
|
||||
return_value="media-source://bla",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_resolve_engine",
|
||||
return_value="tts.cloud_tts",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_create_stream",
|
||||
return_value=MockResultStream(hass, "wav", b""),
|
||||
),
|
||||
patch.object(
|
||||
mock_client,
|
||||
"send_voice_assistant_announcement_await_response",
|
||||
new=send_voice_assistant_announcement_await_response,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_get_pipeline",
|
||||
return_value=pipeline,
|
||||
),
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
"start_conversation",
|
||||
{"entity_id": satellite.entity_id, "start_message": "test-text"},
|
||||
blocking=True,
|
||||
)
|
||||
await done.wait()
|
||||
assert satellite.state == AssistSatelliteState.IDLE
|
||||
|
||||
|
||||
async def test_start_conversation_media_id(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
device_registry: dr.DeviceRegistry,
|
||||
) -> None:
|
||||
"""Test start conversation with media id."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[
|
||||
MediaPlayerInfo(
|
||||
object_id="mymedia_player",
|
||||
key=1,
|
||||
name="my media_player",
|
||||
unique_id="my_media_player",
|
||||
supports_pause=True,
|
||||
supported_formats=[
|
||||
MediaPlayerSupportedFormat(
|
||||
format="flac",
|
||||
sample_rate=48000,
|
||||
num_channels=2,
|
||||
purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
|
||||
sample_bytes=2,
|
||||
),
|
||||
],
|
||||
)
|
||||
],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.SPEAKER
|
||||
| VoiceAssistantFeature.API_AUDIO
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
| VoiceAssistantFeature.START_CONVERSATION
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
dev = device_registry.async_get_device(
|
||||
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||
)
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
pipeline = assist_pipeline.Pipeline(
|
||||
conversation_engine="test engine",
|
||||
conversation_language="en",
|
||||
language="en",
|
||||
name="test pipeline",
|
||||
stt_engine="test stt",
|
||||
stt_language="en",
|
||||
tts_engine="test tts",
|
||||
tts_language="en",
|
||||
tts_voice=None,
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def send_voice_assistant_announcement_await_response(
|
||||
media_id: str, timeout: float, text: str, start_conversation: bool
|
||||
):
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
assert media_id == "https://www.home-assistant.io/proxied.flac"
|
||||
assert start_conversation
|
||||
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
mock_client,
|
||||
"send_voice_assistant_announcement_await_response",
|
||||
new=send_voice_assistant_announcement_await_response,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.esphome.assist_satellite.async_create_proxy_url",
|
||||
return_value="https://www.home-assistant.io/proxied.flac",
|
||||
) as mock_async_create_proxy_url,
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_get_pipeline",
|
||||
return_value=pipeline,
|
||||
),
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
"start_conversation",
|
||||
{
|
||||
"entity_id": satellite.entity_id,
|
||||
"start_media_id": "https://www.home-assistant.io/resolved.mp3",
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
await done.wait()
|
||||
assert satellite.state == AssistSatelliteState.IDLE
|
||||
|
||||
mock_async_create_proxy_url.assert_called_once_with(
|
||||
hass,
|
||||
dev.id,
|
||||
"https://www.home-assistant.io/resolved.mp3",
|
||||
media_format="flac",
|
||||
rate=48000,
|
||||
channels=2,
|
||||
width=2,
|
||||
)
|
||||
|
||||
|
||||
async def test_satellite_unloaded_on_disconnect(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
|
Loading…
x
Reference in New Issue
Block a user