mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Use first media player announcement format for TTS (#125237)
* Use ANNOUNCEMENT format from first media player for tts * Fix formatting --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
20639b0f02
commit
ee59303d3c
@ -72,6 +72,7 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
_run_has_tts: bool = False
|
||||
_is_announcing = False
|
||||
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||
_attr_tts_options: dict[str, Any] | None = None
|
||||
|
||||
__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
@ -91,6 +92,11 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity ID of the VAD sensitivity to use for the next conversation."""
|
||||
return self._attr_vad_sensitivity_entity_id
|
||||
|
||||
@property
|
||||
def tts_options(self) -> dict[str, Any] | None:
|
||||
"""Options passed for text-to-speech."""
|
||||
return self._attr_tts_options
|
||||
|
||||
async def async_intercept_wake_word(self) -> str | None:
|
||||
"""Intercept the next wake word from the satellite.
|
||||
|
||||
@ -137,6 +143,9 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
|
||||
if self.tts_options is not None:
|
||||
tts_options.update(self.tts_options)
|
||||
|
||||
media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
message,
|
||||
@ -253,7 +262,7 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
pipeline_id=self._resolve_pipeline(),
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output="wav",
|
||||
tts_audio_output=self.tts_options,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
audio_settings=AudioSettings(
|
||||
silence_seconds=self._resolve_vad_sensitivity()
|
||||
|
@ -6,12 +6,14 @@ import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from functools import partial
|
||||
import io
|
||||
from itertools import chain
|
||||
import logging
|
||||
import socket
|
||||
from typing import Any, cast
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import (
|
||||
MediaPlayerFormatPurpose,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
VoiceAssistantEventType,
|
||||
@ -288,6 +290,18 @@ class EsphomeAssistSatellite(
|
||||
|
||||
end_stage = PipelineStage.TTS
|
||||
|
||||
if feature_flags & VoiceAssistantFeature.SPEAKER:
|
||||
# Stream WAV audio
|
||||
self._attr_tts_options = {
|
||||
tts.ATTR_PREFERRED_FORMAT: "wav",
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||
}
|
||||
else:
|
||||
# ANNOUNCEMENT format from media player
|
||||
self._update_tts_format()
|
||||
|
||||
# Run the pipeline
|
||||
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
||||
self.entry_data.async_set_assist_pipeline_state(True)
|
||||
@ -340,6 +354,19 @@ class EsphomeAssistSatellite(
|
||||
timer_info.is_active,
|
||||
)
|
||||
|
||||
def _update_tts_format(self) -> None:
|
||||
"""Update the TTS format from the first media player."""
|
||||
for supported_format in chain(*self.entry_data.media_player_formats.values()):
|
||||
# Find first announcement format
|
||||
if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT:
|
||||
self._attr_tts_options = {
|
||||
tts.ATTR_PREFERRED_FORMAT: supported_format.format,
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: supported_format.sample_rate,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: supported_format.num_channels,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||
}
|
||||
break
|
||||
|
||||
async def _stream_tts_audio(
|
||||
self,
|
||||
media_id: str,
|
||||
|
@ -31,6 +31,7 @@ from aioesphomeapi import (
|
||||
LightInfo,
|
||||
LockInfo,
|
||||
MediaPlayerInfo,
|
||||
MediaPlayerSupportedFormat,
|
||||
NumberInfo,
|
||||
SelectInfo,
|
||||
SensorInfo,
|
||||
@ -148,6 +149,9 @@ class RuntimeEntryData:
|
||||
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
|
||||
] = field(default_factory=dict)
|
||||
original_options: dict[str, Any] = field(default_factory=dict)
|
||||
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from aioesphomeapi import (
|
||||
EntityInfo,
|
||||
@ -66,6 +66,9 @@ class EsphomeMediaPlayer(
|
||||
if self._static_info.supports_pause:
|
||||
flags |= MediaPlayerEntityFeature.PAUSE | MediaPlayerEntityFeature.PLAY
|
||||
self._attr_supported_features = flags
|
||||
self._entry_data.media_player_formats[self.entity_id] = cast(
|
||||
MediaPlayerInfo, static_info
|
||||
).supported_formats
|
||||
|
||||
@property
|
||||
@esphome_state_property
|
||||
@ -103,6 +106,11 @@ class EsphomeMediaPlayer(
|
||||
self._key, media_url=media_id, announcement=announcement
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Handle entity being removed."""
|
||||
await super().async_will_remove_from_hass()
|
||||
self._entry_data.media_player_formats.pop(self.entity_id, None)
|
||||
|
||||
async def async_browse_media(
|
||||
self,
|
||||
media_content_type: MediaType | str | None = None,
|
||||
|
@ -61,7 +61,7 @@ async def test_entity_state(
|
||||
assert kwargs["stt_stream"] is audio_stream
|
||||
assert kwargs["pipeline_id"] is None
|
||||
assert kwargs["device_id"] is None
|
||||
assert kwargs["tts_audio_output"] == "wav"
|
||||
assert kwargs["tts_audio_output"] is None
|
||||
assert kwargs["wake_word_phrase"] is None
|
||||
assert kwargs["audio_settings"] == AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||
|
@ -11,6 +11,9 @@ from aioesphomeapi import (
|
||||
APIClient,
|
||||
EntityInfo,
|
||||
EntityState,
|
||||
MediaPlayerFormatPurpose,
|
||||
MediaPlayerInfo,
|
||||
MediaPlayerSupportedFormat,
|
||||
UserService,
|
||||
VoiceAssistantAudioSettings,
|
||||
VoiceAssistantCommandFlag,
|
||||
@ -20,7 +23,7 @@ from aioesphomeapi import (
|
||||
)
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import assist_satellite
|
||||
from homeassistant.components import assist_satellite, tts
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||
from homeassistant.components.assist_satellite.entity import (
|
||||
AssistSatelliteEntity,
|
||||
@ -820,3 +823,71 @@ async def test_streaming_tts_errors(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
async def test_tts_format_from_media_player(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test that the text-to-speech format is pulled from the first media player."""
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[
|
||||
MediaPlayerInfo(
|
||||
object_id="mymedia_player",
|
||||
key=1,
|
||||
name="my media_player",
|
||||
unique_id="my_media_player",
|
||||
supports_pause=True,
|
||||
supported_formats=[
|
||||
MediaPlayerSupportedFormat(
|
||||
format="flac",
|
||||
sample_rate=48000,
|
||||
num_channels=2,
|
||||
purpose=MediaPlayerFormatPurpose.DEFAULT,
|
||||
),
|
||||
# This is the format that should be used for tts
|
||||
MediaPlayerSupportedFormat(
|
||||
format="mp3",
|
||||
sample_rate=22050,
|
||||
num_channels=1,
|
||||
purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
|
||||
),
|
||||
],
|
||||
)
|
||||
],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
) as mock_pipeline_from_audio_stream:
|
||||
await satellite.handle_pipeline_start(
|
||||
conversation_id="",
|
||||
flags=0,
|
||||
audio_settings=VoiceAssistantAudioSettings(),
|
||||
wake_word_phrase=None,
|
||||
)
|
||||
|
||||
mock_pipeline_from_audio_stream.assert_called_once()
|
||||
kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs
|
||||
|
||||
# Should be ANNOUNCEMENT format from media player
|
||||
assert kwargs.get("tts_audio_output") == {
|
||||
tts.ATTR_PREFERRED_FORMAT: "mp3",
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user