Use first media player announcement format for TTS (#125237)

* Use the ANNOUNCEMENT format from the first media player for TTS (a short sketch of the selection logic follows below)

* Fix formatting
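
In short, the change works like the sketch below (simplified and illustrative, not the exact integration code; the standalone helper name pick_tts_options is made up for this example): scan the formats advertised by the device's media players, take the first one whose purpose is ANNOUNCEMENT, and turn it into the TTS preferred-format options that are now passed to the pipeline as tts_audio_output. Satellites that stream audio through the API speaker instead keep a fixed 16 kHz mono 16-bit WAV format.

from itertools import chain
from typing import Any

from aioesphomeapi import MediaPlayerFormatPurpose, MediaPlayerSupportedFormat

from homeassistant.components import tts


def pick_tts_options(
    media_player_formats: dict[str, list[MediaPlayerSupportedFormat]],
) -> dict[str, Any] | None:
    """Illustrative helper: TTS options from the first ANNOUNCEMENT format, if any."""
    for supported_format in chain(*media_player_formats.values()):
        if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
            return {
                tts.ATTR_PREFERRED_FORMAT: supported_format.format,
                tts.ATTR_PREFERRED_SAMPLE_RATE: supported_format.sample_rate,
                tts.ATTR_PREFERRED_SAMPLE_CHANNELS: supported_format.num_channels,
                # 2 bytes per sample (16-bit), matching the value hard-coded in this commit
                tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
            }
    # No announcement format advertised: tts_audio_output stays None
    return None

If no media player advertises an announcement format, tts_audio_output remains None and the TTS engine's default output format applies, which is what the updated assist_satellite test asserts.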

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Authored by Michael Hansen on 2024-09-06 10:57:09 -05:00; committed by GitHub
parent 20639b0f02
commit ee59303d3c
6 changed files with 123 additions and 4 deletions


@@ -72,6 +72,7 @@ class AssistSatelliteEntity(entity.Entity):
     _run_has_tts: bool = False
     _is_announcing = False
     _wake_word_intercept_future: asyncio.Future[str | None] | None = None
+    _attr_tts_options: dict[str, Any] | None = None
 
     __assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
@@ -91,6 +92,11 @@ class AssistSatelliteEntity(entity.Entity):
         """Entity ID of the VAD sensitivity to use for the next conversation."""
         return self._attr_vad_sensitivity_entity_id
 
+    @property
+    def tts_options(self) -> dict[str, Any] | None:
+        """Options passed for text-to-speech."""
+        return self._attr_tts_options
+
     async def async_intercept_wake_word(self) -> str | None:
         """Intercept the next wake word from the satellite.
@@ -137,6 +143,9 @@ class AssistSatelliteEntity(entity.Entity):
         if pipeline.tts_voice is not None:
             tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
 
+        if self.tts_options is not None:
+            tts_options.update(self.tts_options)
+
         media_id = tts_generate_media_source_id(
             self.hass,
             message,
@@ -253,7 +262,7 @@ class AssistSatelliteEntity(entity.Entity):
             pipeline_id=self._resolve_pipeline(),
             conversation_id=self._conversation_id,
             device_id=device_id,
-            tts_audio_output="wav",
+            tts_audio_output=self.tts_options,
             wake_word_phrase=wake_word_phrase,
             audio_settings=AudioSettings(
                 silence_seconds=self._resolve_vad_sensitivity()


@@ -6,12 +6,14 @@ import asyncio
 from collections.abc import AsyncIterable
 from functools import partial
 import io
+from itertools import chain
 import logging
 import socket
 from typing import Any, cast
 import wave
 
 from aioesphomeapi import (
+    MediaPlayerFormatPurpose,
     VoiceAssistantAudioSettings,
     VoiceAssistantCommandFlag,
     VoiceAssistantEventType,
@@ -288,6 +290,18 @@ class EsphomeAssistSatellite(
             end_stage = PipelineStage.TTS
 
+        if feature_flags & VoiceAssistantFeature.SPEAKER:
+            # Stream WAV audio
+            self._attr_tts_options = {
+                tts.ATTR_PREFERRED_FORMAT: "wav",
+                tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
+                tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
+                tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
+            }
+        else:
+            # ANNOUNCEMENT format from media player
+            self._update_tts_format()
+
         # Run the pipeline
         _LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
         self.entry_data.async_set_assist_pipeline_state(True)
@@ -340,6 +354,19 @@ class EsphomeAssistSatellite(
             timer_info.is_active,
         )
 
+    def _update_tts_format(self) -> None:
+        """Update the TTS format from the first media player."""
+        for supported_format in chain(*self.entry_data.media_player_formats.values()):
+            # Find first announcement format
+            if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
+                self._attr_tts_options = {
+                    tts.ATTR_PREFERRED_FORMAT: supported_format.format,
+                    tts.ATTR_PREFERRED_SAMPLE_RATE: supported_format.sample_rate,
+                    tts.ATTR_PREFERRED_SAMPLE_CHANNELS: supported_format.num_channels,
+                    tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
+                }
+                break
+
     async def _stream_tts_audio(
         self,
         media_id: str,


@@ -31,6 +31,7 @@ from aioesphomeapi import (
     LightInfo,
     LockInfo,
     MediaPlayerInfo,
+    MediaPlayerSupportedFormat,
     NumberInfo,
     SelectInfo,
     SensorInfo,
@@ -148,6 +149,9 @@ class RuntimeEntryData:
         tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
     ] = field(default_factory=dict)
     original_options: dict[str, Any] = field(default_factory=dict)
+    media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
+        default_factory=lambda: defaultdict(list)
+    )
 
     @property
     def name(self) -> str:


@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Any
+from typing import Any, cast
 
 from aioesphomeapi import (
     EntityInfo,
@@ -66,6 +66,9 @@ class EsphomeMediaPlayer(
         if self._static_info.supports_pause:
             flags |= MediaPlayerEntityFeature.PAUSE | MediaPlayerEntityFeature.PLAY
         self._attr_supported_features = flags
+        self._entry_data.media_player_formats[self.entity_id] = cast(
+            MediaPlayerInfo, static_info
+        ).supported_formats
 
     @property
     @esphome_state_property
@@ -103,6 +106,11 @@ class EsphomeMediaPlayer(
             self._key, media_url=media_id, announcement=announcement
         )
 
+    async def async_will_remove_from_hass(self) -> None:
+        """Handle entity being removed."""
+        await super().async_will_remove_from_hass()
+        self._entry_data.media_player_formats.pop(self.entity_id, None)
+
     async def async_browse_media(
         self,
         media_content_type: MediaType | str | None = None,


@@ -61,7 +61,7 @@ async def test_entity_state(
     assert kwargs["stt_stream"] is audio_stream
     assert kwargs["pipeline_id"] is None
     assert kwargs["device_id"] is None
-    assert kwargs["tts_audio_output"] == "wav"
+    assert kwargs["tts_audio_output"] is None
     assert kwargs["wake_word_phrase"] is None
     assert kwargs["audio_settings"] == AudioSettings(
         silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)


@@ -11,6 +11,9 @@ from aioesphomeapi import (
     APIClient,
     EntityInfo,
     EntityState,
+    MediaPlayerFormatPurpose,
+    MediaPlayerInfo,
+    MediaPlayerSupportedFormat,
     UserService,
     VoiceAssistantAudioSettings,
     VoiceAssistantCommandFlag,
@@ -20,7 +23,7 @@ from aioesphomeapi import (
 )
 import pytest
 
-from homeassistant.components import assist_satellite
+from homeassistant.components import assist_satellite, tts
 from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
 from homeassistant.components.assist_satellite.entity import (
     AssistSatelliteEntity,
@@ -820,3 +823,71 @@ async def test_streaming_tts_errors(
         VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
         {},
     )
+
+
+async def test_tts_format_from_media_player(
+    hass: HomeAssistant,
+    mock_client: APIClient,
+    mock_esphome_device: Callable[
+        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
+        Awaitable[MockESPHomeDevice],
+    ],
+) -> None:
+    """Test that the text-to-speech format is pulled from the first media player."""
+    mock_device: MockESPHomeDevice = await mock_esphome_device(
+        mock_client=mock_client,
+        entity_info=[
+            MediaPlayerInfo(
+                object_id="mymedia_player",
+                key=1,
+                name="my media_player",
+                unique_id="my_media_player",
+                supports_pause=True,
+                supported_formats=[
+                    MediaPlayerSupportedFormat(
+                        format="flac",
+                        sample_rate=48000,
+                        num_channels=2,
+                        purpose=MediaPlayerFormatPurpose.DEFAULT,
+                    ),
+                    # This is the format that should be used for tts
+                    MediaPlayerSupportedFormat(
+                        format="mp3",
+                        sample_rate=22050,
+                        num_channels=1,
+                        purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
+                    ),
+                ],
+            )
+        ],
+        user_service=[],
+        states=[],
+        device_info={
+            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
+        },
+    )
+    await hass.async_block_till_done()
+
+    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
+    assert satellite is not None
+
+    with patch(
+        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
+    ) as mock_pipeline_from_audio_stream:
+        await satellite.handle_pipeline_start(
+            conversation_id="",
+            flags=0,
+            audio_settings=VoiceAssistantAudioSettings(),
+            wake_word_phrase=None,
+        )
+
+        mock_pipeline_from_audio_stream.assert_called_once()
+        kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs
+
+        # Should be ANNOUNCEMENT format from media player
+        assert kwargs.get("tts_audio_output") == {
+            tts.ATTR_PREFERRED_FORMAT: "mp3",
+            tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
+            tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
+            tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
+        }