Use first media player announcement format for TTS (#125237)

* Use ANNOUNCEMENT format from first media player for TTS

* Fix formatting

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Michael Hansen 2024-09-06 10:57:09 -05:00 committed by GitHub
parent 20639b0f02
commit ee59303d3c
6 changed files with 123 additions and 4 deletions
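
Taken together, the change lets the ESPHome assist satellite derive its preferred text-to-speech output from the first ANNOUNCEMENT-capable format advertised by the device's media players instead of always requesting WAV. Below is a minimal, self-contained sketch of that selection logic; the Purpose and SupportedFormat stand-ins and the plain-string option keys are illustrative, while the real code uses aioesphomeapi's MediaPlayerSupportedFormat and the tts.ATTR_PREFERRED_* constants shown in the diffs that follow.

from dataclasses import dataclass
from enum import Enum
from itertools import chain
from typing import Any


class Purpose(Enum):
    """Stand-in for aioesphomeapi's MediaPlayerFormatPurpose."""

    DEFAULT = 0
    ANNOUNCEMENT = 1


@dataclass
class SupportedFormat:
    """Stand-in for aioesphomeapi's MediaPlayerSupportedFormat."""

    format: str
    sample_rate: int
    num_channels: int
    purpose: Purpose


def pick_tts_options(
    formats_by_player: dict[str, list[SupportedFormat]],
) -> dict[str, Any] | None:
    """Return preferred TTS options from the first ANNOUNCEMENT format, if any."""
    for fmt in chain(*formats_by_player.values()):
        if fmt.purpose == Purpose.ANNOUNCEMENT:
            return {
                "preferred_format": fmt.format,
                "preferred_sample_rate": fmt.sample_rate,
                "preferred_sample_channels": fmt.num_channels,
                "preferred_sample_bytes": 2,
            }
    return None  # No announcement format: leave the TTS options unset.


formats = {
    "media_player.kitchen": [
        SupportedFormat("flac", 48000, 2, Purpose.DEFAULT),
        SupportedFormat("mp3", 22050, 1, Purpose.ANNOUNCEMENT),
    ]
}
assert pick_tts_options(formats) == {
    "preferred_format": "mp3",
    "preferred_sample_rate": 22050,
    "preferred_sample_channels": 1,
    "preferred_sample_bytes": 2,
}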


@@ -72,6 +72,7 @@ class AssistSatelliteEntity(entity.Entity):
_run_has_tts: bool = False
_is_announcing = False
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
_attr_tts_options: dict[str, Any] | None = None
__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
@@ -91,6 +92,11 @@ class AssistSatelliteEntity(entity.Entity):
"""Entity ID of the VAD sensitivity to use for the next conversation."""
return self._attr_vad_sensitivity_entity_id
@property
def tts_options(self) -> dict[str, Any] | None:
"""Options passed for text-to-speech."""
return self._attr_tts_options
async def async_intercept_wake_word(self) -> str | None:
"""Intercept the next wake word from the satellite.
@@ -137,6 +143,9 @@ class AssistSatelliteEntity(entity.Entity):
if pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
if self.tts_options is not None:
tts_options.update(self.tts_options)
media_id = tts_generate_media_source_id(
self.hass,
message,
@@ -253,7 +262,7 @@ class AssistSatelliteEntity(entity.Entity):
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
-tts_audio_output="wav",
+tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
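
With the entity changes above, an integration opts in by setting _attr_tts_options; the base entity merges those options into the pipeline-derived TTS options before generating the media-source ID and forwards them unchanged as tts_audio_output. A hedged sketch of how a satellite implementation might use the hook (the KitchenSatellite class and the chosen values are assumptions; only the _attr_tts_options attribute itself comes from this diff):

from homeassistant.components import tts
from homeassistant.components.assist_satellite.entity import AssistSatelliteEntity


class KitchenSatellite(AssistSatelliteEntity):
    """Hypothetical satellite that asks TTS engines for 16 kHz, 16-bit mono WAV."""

    _attr_tts_options = {
        tts.ATTR_PREFERRED_FORMAT: "wav",
        tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
        tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
    }

These are the same tts.ATTR_PREFERRED_* options the ESPHome satellite fills in below.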


@@ -6,12 +6,14 @@ import asyncio
from collections.abc import AsyncIterable
from functools import partial
import io
from itertools import chain
import logging
import socket
from typing import Any, cast
import wave
from aioesphomeapi import (
MediaPlayerFormatPurpose,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
@@ -288,6 +290,18 @@ class EsphomeAssistSatellite(
end_stage = PipelineStage.TTS
if feature_flags & VoiceAssistantFeature.SPEAKER:
# Stream WAV audio
self._attr_tts_options = {
tts.ATTR_PREFERRED_FORMAT: "wav",
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
else:
# ANNOUNCEMENT format from media player
self._update_tts_format()
# Run the pipeline
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
self.entry_data.async_set_assist_pipeline_state(True)
@@ -340,6 +354,19 @@ class EsphomeAssistSatellite(
timer_info.is_active,
)
def _update_tts_format(self) -> None:
"""Update the TTS format from the first media player."""
for supported_format in chain(*self.entry_data.media_player_formats.values()):
# Find first announcement format
if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
self._attr_tts_options = {
tts.ATTR_PREFERRED_FORMAT: supported_format.format,
tts.ATTR_PREFERRED_SAMPLE_RATE: supported_format.sample_rate,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: supported_format.num_channels,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
break
async def _stream_tts_audio(
self,
media_id: str,
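
In _update_tts_format above, chain(*self.entry_data.media_player_formats.values()) flattens the per-entity format lists into one iterable, so the first ANNOUNCEMENT format found across any of the device's media players wins. A small standalone illustration of that flattening (entity IDs and format strings are made up):

from itertools import chain

# Per-entity format lists, keyed by entity_id as in RuntimeEntryData.media_player_formats.
media_player_formats = {
    "media_player.front": ["flac (default)", "mp3 (announcement)"],
    "media_player.rear": ["wav (announcement)"],
}

# chain(*values) walks the lists in dict insertion order, so a format from the
# first media player is considered before any format from later ones.
flattened = list(chain(*media_player_formats.values()))
assert flattened == [
    "flac (default)",
    "mp3 (announcement)",
    "wav (announcement)",
]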


@@ -31,6 +31,7 @@ from aioesphomeapi import (
LightInfo,
LockInfo,
MediaPlayerInfo,
MediaPlayerSupportedFormat,
NumberInfo,
SelectInfo,
SensorInfo,
@@ -148,6 +149,9 @@ class RuntimeEntryData:
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
] = field(default_factory=dict)
original_options: dict[str, Any] = field(default_factory=dict)
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
default_factory=lambda: defaultdict(list)
)
@property
def name(self) -> str:
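
media_player_formats uses a default_factory lambda so every RuntimeEntryData instance gets its own defaultdict(list), and entities that never registered formats simply read back as empty lists. A minimal illustration with plain strings standing in for MediaPlayerSupportedFormat:

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class EntryData:
    """Cut-down stand-in for RuntimeEntryData, keeping only the new field."""

    media_player_formats: dict[str, list[str]] = field(
        default_factory=lambda: defaultdict(list)
    )


a = EntryData()
b = EntryData()
a.media_player_formats["media_player.kitchen"] = ["mp3", "flac"]

# The default_factory gives every instance its own mapping...
assert b.media_player_formats == {}
# ...and defaultdict(list) makes unknown entities read back as empty lists.
assert a.media_player_formats["media_player.office"] == []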


@@ -3,7 +3,7 @@
from __future__ import annotations
from functools import partial
-from typing import Any
+from typing import Any, cast
from aioesphomeapi import (
EntityInfo,
@@ -66,6 +66,9 @@ class EsphomeMediaPlayer(
if self._static_info.supports_pause:
flags |= MediaPlayerEntityFeature.PAUSE | MediaPlayerEntityFeature.PLAY
self._attr_supported_features = flags
self._entry_data.media_player_formats[self.entity_id] = cast(
MediaPlayerInfo, static_info
).supported_formats
@property
@esphome_state_property
@@ -103,6 +106,11 @@ class EsphomeMediaPlayer(
self._key, media_url=media_id, announcement=announcement
)
async def async_will_remove_from_hass(self) -> None:
"""Handle entity being removed."""
await super().async_will_remove_from_hass()
self._entry_data.media_player_formats.pop(self.entity_id, None)
async def async_browse_media(
self,
media_content_type: MediaType | str | None = None,
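
The media player stores its supported formats when its static info is processed and drops the entry again in async_will_remove_from_hass, so the satellite never picks a format from an entity that has been removed. A minimal sketch of that paired register/cleanup pattern (the FormatRegistry class is purely illustrative; the real code keeps the dict on RuntimeEntryData):

class FormatRegistry:
    def __init__(self) -> None:
        self.media_player_formats: dict[str, list[str]] = {}

    def register(self, entity_id: str, formats: list[str]) -> None:
        # Called once the entity's static info (and its supported formats) is known.
        self.media_player_formats[entity_id] = formats

    def unregister(self, entity_id: str) -> None:
        # pop(..., None) keeps removal idempotent if the entity never registered.
        self.media_player_formats.pop(entity_id, None)


registry = FormatRegistry()
registry.register("media_player.kitchen", ["mp3", "flac"])
registry.unregister("media_player.kitchen")
registry.unregister("media_player.kitchen")  # Second removal is a no-op, not a KeyError.
assert registry.media_player_formats == {}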


@@ -61,7 +61,7 @@ async def test_entity_state(
assert kwargs["stt_stream"] is audio_stream
assert kwargs["pipeline_id"] is None
assert kwargs["device_id"] is None
-assert kwargs["tts_audio_output"] == "wav"
+assert kwargs["tts_audio_output"] is None
assert kwargs["wake_word_phrase"] is None
assert kwargs["audio_settings"] == AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
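
The expected default changes from "wav" to None because the base entity no longer hard-codes a format: tts_audio_output now forwards tts_options, and _attr_tts_options stays None unless an integration sets it.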


@@ -11,6 +11,9 @@ from aioesphomeapi import (
APIClient,
EntityInfo,
EntityState,
MediaPlayerFormatPurpose,
MediaPlayerInfo,
MediaPlayerSupportedFormat,
UserService,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
@@ -20,7 +23,7 @@ from aioesphomeapi import (
)
import pytest
-from homeassistant.components import assist_satellite
+from homeassistant.components import assist_satellite, tts
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite.entity import (
AssistSatelliteEntity,
@@ -820,3 +823,71 @@ async def test_streaming_tts_errors(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
{},
)
async def test_tts_format_from_media_player(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test that the text-to-speech format is pulled from the first media player."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[
MediaPlayerInfo(
object_id="mymedia_player",
key=1,
name="my media_player",
unique_id="my_media_player",
supports_pause=True,
supported_formats=[
MediaPlayerSupportedFormat(
format="flac",
sample_rate=48000,
num_channels=2,
purpose=MediaPlayerFormatPurpose.DEFAULT,
),
# This is the format that should be used for tts
MediaPlayerSupportedFormat(
format="mp3",
sample_rate=22050,
num_channels=1,
purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
),
],
)
],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_pipeline_from_audio_stream:
await satellite.handle_pipeline_start(
conversation_id="",
flags=0,
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase=None,
)
mock_pipeline_from_audio_stream.assert_called_once()
kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs
# Should be ANNOUNCEMENT format from media player
assert kwargs.get("tts_audio_output") == {
tts.ATTR_PREFERRED_FORMAT: "mp3",
tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}