mirror of
https://github.com/home-assistant/core.git
synced 2025-07-27 07:07:28 +00:00
Use first media player announcement format for TTS (#125237)
* Use ANNOUNCEMENT format from first media player for tts
* Fix formatting

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
20639b0f02
commit
ee59303d3c
@ -72,6 +72,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
_run_has_tts: bool = False
|
_run_has_tts: bool = False
|
||||||
_is_announcing = False
|
_is_announcing = False
|
||||||
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||||
|
_attr_tts_options: dict[str, Any] | None = None
|
||||||
|
|
||||||
__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
@ -91,6 +92,11 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
"""Entity ID of the VAD sensitivity to use for the next conversation."""
|
"""Entity ID of the VAD sensitivity to use for the next conversation."""
|
||||||
return self._attr_vad_sensitivity_entity_id
|
return self._attr_vad_sensitivity_entity_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tts_options(self) -> dict[str, Any] | None:
|
||||||
|
"""Options passed for text-to-speech."""
|
||||||
|
return self._attr_tts_options
|
||||||
|
|
||||||
async def async_intercept_wake_word(self) -> str | None:
|
async def async_intercept_wake_word(self) -> str | None:
|
||||||
"""Intercept the next wake word from the satellite.
|
"""Intercept the next wake word from the satellite.
|
||||||
|
|
||||||
@ -137,6 +143,9 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
if pipeline.tts_voice is not None:
|
if pipeline.tts_voice is not None:
|
||||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||||
|
|
||||||
|
if self.tts_options is not None:
|
||||||
|
tts_options.update(self.tts_options)
|
||||||
|
|
||||||
media_id = tts_generate_media_source_id(
|
media_id = tts_generate_media_source_id(
|
||||||
self.hass,
|
self.hass,
|
||||||
message,
|
message,
|
||||||
@ -253,7 +262,7 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
pipeline_id=self._resolve_pipeline(),
|
pipeline_id=self._resolve_pipeline(),
|
||||||
conversation_id=self._conversation_id,
|
conversation_id=self._conversation_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
tts_audio_output="wav",
|
tts_audio_output=self.tts_options,
|
||||||
wake_word_phrase=wake_word_phrase,
|
wake_word_phrase=wake_word_phrase,
|
||||||
audio_settings=AudioSettings(
|
audio_settings=AudioSettings(
|
||||||
silence_seconds=self._resolve_vad_sensitivity()
|
silence_seconds=self._resolve_vad_sensitivity()
|
||||||
|
@ -6,12 +6,14 @@ import asyncio
|
|||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import io
|
import io
|
||||||
|
from itertools import chain
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
from aioesphomeapi import (
|
from aioesphomeapi import (
|
||||||
|
MediaPlayerFormatPurpose,
|
||||||
VoiceAssistantAudioSettings,
|
VoiceAssistantAudioSettings,
|
||||||
VoiceAssistantCommandFlag,
|
VoiceAssistantCommandFlag,
|
||||||
VoiceAssistantEventType,
|
VoiceAssistantEventType,
|
||||||
@ -288,6 +290,18 @@ class EsphomeAssistSatellite(
|
|||||||
|
|
||||||
end_stage = PipelineStage.TTS
|
end_stage = PipelineStage.TTS
|
||||||
|
|
||||||
|
if feature_flags & VoiceAssistantFeature.SPEAKER:
|
||||||
|
# Stream WAV audio
|
||||||
|
self._attr_tts_options = {
|
||||||
|
tts.ATTR_PREFERRED_FORMAT: "wav",
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# ANNOUNCEMENT format from media player
|
||||||
|
self._update_tts_format()
|
||||||
|
|
||||||
# Run the pipeline
|
# Run the pipeline
|
||||||
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
||||||
self.entry_data.async_set_assist_pipeline_state(True)
|
self.entry_data.async_set_assist_pipeline_state(True)
|
||||||
@ -340,6 +354,19 @@ class EsphomeAssistSatellite(
|
|||||||
timer_info.is_active,
|
timer_info.is_active,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _update_tts_format(self) -> None:
    """Update the TTS format from the first media player.

    Walks the supported formats of every media player stored in
    ``entry_data.media_player_formats`` and uses the first one whose
    purpose is ANNOUNCEMENT to build ``_attr_tts_options``.

    If no announcement format is found, ``_attr_tts_options`` is left
    unchanged.
    """
    all_formats = chain.from_iterable(self.entry_data.media_player_formats.values())
    for supported_format in all_formats:
        # Find first announcement format
        if supported_format.purpose != MediaPlayerFormatPurpose.ANNOUNCEMENT:
            continue

        options: dict[str, Any] = {
            tts.ATTR_PREFERRED_FORMAT: supported_format.format,
        }

        # NOTE(review): a value of 0 is assumed to mean "not specified by
        # the device" -- confirm against the aioesphomeapi model. Passing 0
        # through would ask TTS for a 0 Hz / 0 channel conversion, so the
        # preference is simply omitted in that case.
        if supported_format.sample_rate > 0:
            options[tts.ATTR_PREFERRED_SAMPLE_RATE] = supported_format.sample_rate

        if supported_format.num_channels > 0:
            options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = supported_format.num_channels

        # The sample width is not read from the media player format here;
        # 2 bytes per sample is always requested.
        options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = 2

        self._attr_tts_options = options
        break
|
||||||
|
|
||||||
async def _stream_tts_audio(
|
async def _stream_tts_audio(
|
||||||
self,
|
self,
|
||||||
media_id: str,
|
media_id: str,
|
||||||
|
@ -31,6 +31,7 @@ from aioesphomeapi import (
|
|||||||
LightInfo,
|
LightInfo,
|
||||||
LockInfo,
|
LockInfo,
|
||||||
MediaPlayerInfo,
|
MediaPlayerInfo,
|
||||||
|
MediaPlayerSupportedFormat,
|
||||||
NumberInfo,
|
NumberInfo,
|
||||||
SelectInfo,
|
SelectInfo,
|
||||||
SensorInfo,
|
SensorInfo,
|
||||||
@ -148,6 +149,9 @@ class RuntimeEntryData:
|
|||||||
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
|
tuple[type[EntityInfo], int], list[Callable[[EntityInfo], None]]
|
||||||
] = field(default_factory=dict)
|
] = field(default_factory=dict)
|
||||||
original_options: dict[str, Any] = field(default_factory=dict)
|
original_options: dict[str, Any] = field(default_factory=dict)
|
||||||
|
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
|
||||||
|
default_factory=lambda: defaultdict(list)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from aioesphomeapi import (
|
from aioesphomeapi import (
|
||||||
EntityInfo,
|
EntityInfo,
|
||||||
@ -66,6 +66,9 @@ class EsphomeMediaPlayer(
|
|||||||
if self._static_info.supports_pause:
|
if self._static_info.supports_pause:
|
||||||
flags |= MediaPlayerEntityFeature.PAUSE | MediaPlayerEntityFeature.PLAY
|
flags |= MediaPlayerEntityFeature.PAUSE | MediaPlayerEntityFeature.PLAY
|
||||||
self._attr_supported_features = flags
|
self._attr_supported_features = flags
|
||||||
|
self._entry_data.media_player_formats[self.entity_id] = cast(
|
||||||
|
MediaPlayerInfo, static_info
|
||||||
|
).supported_formats
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@esphome_state_property
|
@esphome_state_property
|
||||||
@ -103,6 +106,11 @@ class EsphomeMediaPlayer(
|
|||||||
self._key, media_url=media_id, announcement=announcement
|
self._key, media_url=media_id, announcement=announcement
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_will_remove_from_hass(self) -> None:
    """Handle entity removal by dropping its formats from the entry data."""
    await super().async_will_remove_from_hass()
    # Forget this entity's supported formats; a missing entry is ignored.
    registered_formats = self._entry_data.media_player_formats
    registered_formats.pop(self.entity_id, None)
|
||||||
|
|
||||||
async def async_browse_media(
|
async def async_browse_media(
|
||||||
self,
|
self,
|
||||||
media_content_type: MediaType | str | None = None,
|
media_content_type: MediaType | str | None = None,
|
||||||
|
@ -61,7 +61,7 @@ async def test_entity_state(
|
|||||||
assert kwargs["stt_stream"] is audio_stream
|
assert kwargs["stt_stream"] is audio_stream
|
||||||
assert kwargs["pipeline_id"] is None
|
assert kwargs["pipeline_id"] is None
|
||||||
assert kwargs["device_id"] is None
|
assert kwargs["device_id"] is None
|
||||||
assert kwargs["tts_audio_output"] == "wav"
|
assert kwargs["tts_audio_output"] is None
|
||||||
assert kwargs["wake_word_phrase"] is None
|
assert kwargs["wake_word_phrase"] is None
|
||||||
assert kwargs["audio_settings"] == AudioSettings(
|
assert kwargs["audio_settings"] == AudioSettings(
|
||||||
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||||
|
@ -11,6 +11,9 @@ from aioesphomeapi import (
|
|||||||
APIClient,
|
APIClient,
|
||||||
EntityInfo,
|
EntityInfo,
|
||||||
EntityState,
|
EntityState,
|
||||||
|
MediaPlayerFormatPurpose,
|
||||||
|
MediaPlayerInfo,
|
||||||
|
MediaPlayerSupportedFormat,
|
||||||
UserService,
|
UserService,
|
||||||
VoiceAssistantAudioSettings,
|
VoiceAssistantAudioSettings,
|
||||||
VoiceAssistantCommandFlag,
|
VoiceAssistantCommandFlag,
|
||||||
@ -20,7 +23,7 @@ from aioesphomeapi import (
|
|||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import assist_satellite
|
from homeassistant.components import assist_satellite, tts
|
||||||
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||||
from homeassistant.components.assist_satellite.entity import (
|
from homeassistant.components.assist_satellite.entity import (
|
||||||
AssistSatelliteEntity,
|
AssistSatelliteEntity,
|
||||||
@ -820,3 +823,71 @@ async def test_streaming_tts_errors(
|
|||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||||
{},
|
{},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_format_from_media_player(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that the satellite takes its TTS format from the media player.

    The device exposes a media player with a DEFAULT and an ANNOUNCEMENT
    format; the pipeline must be started with the ANNOUNCEMENT one.
    """
    default_format = MediaPlayerSupportedFormat(
        format="flac",
        sample_rate=48000,
        num_channels=2,
        purpose=MediaPlayerFormatPurpose.DEFAULT,
    )
    # This is the format that should be used for tts
    announcement_format = MediaPlayerSupportedFormat(
        format="mp3",
        sample_rate=22050,
        num_channels=1,
        purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
    )
    media_player_info = MediaPlayerInfo(
        object_id="mymedia_player",
        key=1,
        name="my media_player",
        unique_id="my_media_player",
        supports_pause=True,
        supported_formats=[default_format, announcement_format],
    )

    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[media_player_info],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    pipeline_target = (
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
    )
    with patch(pipeline_target) as mock_pipeline_from_audio_stream:
        await satellite.handle_pipeline_start(
            conversation_id="",
            flags=0,
            audio_settings=VoiceAssistantAudioSettings(),
            wake_word_phrase=None,
        )

    mock_pipeline_from_audio_stream.assert_called_once()
    kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs

    # Should be ANNOUNCEMENT format from media player
    expected_tts_options = {
        tts.ATTR_PREFERRED_FORMAT: "mp3",
        tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
        tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
    }
    assert kwargs.get("tts_audio_output") == expected_tts_options
|
||||||
|
Loading…
x
Reference in New Issue
Block a user