mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 11:47:06 +00:00
Add VoIP announce (#136781)
* Implement async_announce for VoIP * Add tests * Add network to voip dependencies
This commit is contained in:
parent
7256575c09
commit
64cda8cdb8
@ -8,23 +8,29 @@ from functools import partial
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
import wave
|
||||
|
||||
from voip_utils import RtpDatagramProtocol
|
||||
from voip_utils import SIP_PORT, RtpDatagramProtocol
|
||||
from voip_utils.sip import SipEndpoint, get_sip_endpoint
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteAnnouncement,
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityDescription,
|
||||
AssistSatelliteEntityFeature,
|
||||
)
|
||||
from homeassistant.components.network import async_get_source_ip
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
from .const import CHANNELS, CONF_SIP_PORT, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
||||
from .devices import VoIPDevice
|
||||
from .entity import VoIPEntity
|
||||
|
||||
@ -34,6 +40,9 @@ if TYPE_CHECKING:
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PIPELINE_TIMEOUT_SEC: Final = 30
|
||||
_ANNOUNCEMENT_BEFORE_DELAY: Final = 0.5
|
||||
_ANNOUNCEMENT_AFTER_DELAY: Final = 1.0
|
||||
_ANNOUNCEMENT_HANGUP_SEC: Final = 0.5
|
||||
|
||||
|
||||
class Tones(IntFlag):
|
||||
@ -80,6 +89,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||
_attr_translation_key = "assist_satellite"
|
||||
_attr_name = None
|
||||
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -105,6 +115,12 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||
self._tones = tones
|
||||
self._processing_tone_done = asyncio.Event()
|
||||
|
||||
self._announcement: AssistSatelliteAnnouncement | None = None
|
||||
self._announcement_done = asyncio.Event()
|
||||
self._check_announcement_ended_task: asyncio.Task | None = None
|
||||
self._last_chunk_time: float | None = None
|
||||
self._rtp_port: int | None = None
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||
@ -149,25 +165,108 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||
"""Set the current satellite configuration."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
|
||||
"""Announce media on the satellite.
|
||||
|
||||
Plays announcement in a loop, blocking until the caller hangs up.
|
||||
"""
|
||||
self._announcement_done.clear()
|
||||
|
||||
if self._rtp_port is None:
|
||||
# Choose random port for RTP
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
sock.bind(("", 0))
|
||||
_rtp_ip, self._rtp_port = sock.getsockname()
|
||||
sock.close()
|
||||
|
||||
# HA SIP server
|
||||
source_ip = await async_get_source_ip(self.hass)
|
||||
sip_port = self.config_entry.options.get(CONF_SIP_PORT, SIP_PORT)
|
||||
source_endpoint = get_sip_endpoint(host=source_ip, port=sip_port)
|
||||
|
||||
try:
|
||||
# VoIP ID is SIP header
|
||||
destination_endpoint = SipEndpoint(self.voip_device.voip_id)
|
||||
except ValueError:
|
||||
# VoIP ID is IP address
|
||||
destination_endpoint = get_sip_endpoint(
|
||||
host=self.voip_device.voip_id, port=SIP_PORT
|
||||
)
|
||||
|
||||
self._announcement = announcement
|
||||
|
||||
# Make the call
|
||||
self.hass.data[DOMAIN].protocol.outgoing_call(
|
||||
source=source_endpoint,
|
||||
destination=destination_endpoint,
|
||||
rtp_port=self._rtp_port,
|
||||
)
|
||||
|
||||
await self._announcement_done.wait()
|
||||
|
||||
async def _check_announcement_ended(self) -> None:
|
||||
"""Continuously checks if an audio chunk was received within a time limit.
|
||||
|
||||
If not, the caller is presumed to have hung up and the announcement is ended.
|
||||
"""
|
||||
while self._announcement is not None:
|
||||
if (self._last_chunk_time is not None) and (
|
||||
(time.monotonic() - self._last_chunk_time) > _ANNOUNCEMENT_HANGUP_SEC
|
||||
):
|
||||
# Caller hung up
|
||||
self._announcement = None
|
||||
self._announcement_done.set()
|
||||
self._check_announcement_ended_task = None
|
||||
_LOGGER.debug("Announcement ended")
|
||||
break
|
||||
|
||||
await asyncio.sleep(_ANNOUNCEMENT_HANGUP_SEC / 2)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# VoIP
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||
"""Handle raw audio chunk."""
|
||||
if self._run_pipeline_task is None:
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._clear_audio_queue()
|
||||
self._tts_done.clear()
|
||||
self._last_chunk_time = time.monotonic()
|
||||
|
||||
if self._announcement is None:
|
||||
# Pipeline with STT
|
||||
if self._run_pipeline_task is None:
|
||||
# Run pipeline until voice command finishes, then start over
|
||||
self._clear_audio_queue()
|
||||
self._tts_done.clear()
|
||||
self._run_pipeline_task = (
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
)
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
elif self._run_pipeline_task is None:
|
||||
# Announcement only
|
||||
if self._check_announcement_ended_task is None:
|
||||
# Check if caller hung up
|
||||
self._check_announcement_ended_task = (
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._check_announcement_ended(),
|
||||
"voip_announcement_ended",
|
||||
)
|
||||
)
|
||||
|
||||
# Play announcement (will repeat)
|
||||
self._run_pipeline_task = self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._run_pipeline(),
|
||||
"voip_pipeline_run",
|
||||
self._play_announcement(self._announcement),
|
||||
"voip_play_announcement",
|
||||
)
|
||||
|
||||
self._audio_queue.put_nowait(audio_bytes)
|
||||
|
||||
async def _run_pipeline(self) -> None:
|
||||
"""Run a pipeline with STT input and TTS output."""
|
||||
_LOGGER.debug("Starting pipeline")
|
||||
|
||||
self.async_set_context(Context(user_id=self.config_entry.data["user"]))
|
||||
@ -209,6 +308,23 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||
self._run_pipeline_task = None
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
|
||||
async def _play_announcement(
|
||||
self, announcement: AssistSatelliteAnnouncement
|
||||
) -> None:
|
||||
"""Play an announcement once."""
|
||||
_LOGGER.debug("Playing announcement")
|
||||
|
||||
try:
|
||||
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
|
||||
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
|
||||
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected error while playing announcement")
|
||||
raise
|
||||
finally:
|
||||
self._run_pipeline_task = None
|
||||
_LOGGER.debug("Announcement finished")
|
||||
|
||||
def _clear_audio_queue(self) -> None:
|
||||
"""Ensure audio queue is empty."""
|
||||
while not self._audio_queue.empty():
|
||||
@ -239,7 +355,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||
self._pipeline_had_error = True
|
||||
_LOGGER.warning(event)
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None:
|
||||
"""Send TTS audio to caller via RTP."""
|
||||
try:
|
||||
if self.transport is None:
|
||||
@ -253,7 +369,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
||||
|
||||
if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
|
||||
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
|
||||
# Don't overlap TTS and processing beep
|
||||
_LOGGER.debug("Waiting for processing tone")
|
||||
await self._processing_tone_done.wait()
|
||||
|
@ -3,7 +3,7 @@
|
||||
"name": "Voice over IP",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["assist_pipeline", "assist_satellite"],
|
||||
"dependencies": ["assist_pipeline", "assist_satellite", "network"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/voip",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
|
@ -16,7 +16,7 @@ from homeassistant.components.assist_satellite import AssistSatelliteEntity
|
||||
|
||||
# pylint: disable-next=hass-component-root-import
|
||||
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
|
||||
from homeassistant.components.voip import HassVoipDatagramProtocol
|
||||
from homeassistant.components.voip import DOMAIN, HassVoipDatagramProtocol
|
||||
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
|
||||
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
|
||||
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
|
||||
@ -844,3 +844,100 @@ async def test_pipeline_error(
|
||||
|
||||
assert sum(played_audio_bytes) > 0
|
||||
assert played_audio_bytes == snapshot()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
async def test_announce(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test announcement."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
assert (
|
||||
satellite.supported_features
|
||||
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
)
|
||||
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
|
||||
# Protocol has already been mocked, but "outgoing_call" is not async
|
||||
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
|
||||
mock_protocol.outgoing_call = Mock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||
) as mock_send_tts,
|
||||
):
|
||||
satellite.transport = Mock()
|
||||
announce_task = hass.async_create_background_task(
|
||||
satellite.async_announce(announcement), "voip_announce"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
mock_protocol.outgoing_call.assert_called_once()
|
||||
|
||||
# Trigger announcement
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
await announce_task
|
||||
|
||||
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("socket_enabled")
|
||||
async def test_voip_id_is_ip_address(
|
||||
hass: HomeAssistant,
|
||||
voip_devices: VoIPDevices,
|
||||
voip_device: VoIPDevice,
|
||||
) -> None:
|
||||
"""Test announcement when VoIP is an IP address instead of a SIP header."""
|
||||
assert await async_setup_component(hass, "voip", {})
|
||||
|
||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||
assert isinstance(satellite, VoipAssistSatellite)
|
||||
assert (
|
||||
satellite.supported_features
|
||||
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
)
|
||||
|
||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test announcement",
|
||||
media_id=_MEDIA_ID,
|
||||
original_media_id=_MEDIA_ID,
|
||||
media_id_source="tts",
|
||||
)
|
||||
|
||||
# Protocol has already been mocked, but "outgoing_call" is not async
|
||||
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
|
||||
mock_protocol.outgoing_call = Mock()
|
||||
|
||||
with (
|
||||
patch.object(voip_device, "voip_id", "192.168.68.10"),
|
||||
patch(
|
||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||
) as mock_send_tts,
|
||||
):
|
||||
satellite.transport = Mock()
|
||||
announce_task = hass.async_create_background_task(
|
||||
satellite.async_announce(announcement), "voip_announce"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
mock_protocol.outgoing_call.assert_called_once()
|
||||
assert (
|
||||
mock_protocol.outgoing_call.call_args.kwargs["destination"].host
|
||||
== "192.168.68.10"
|
||||
)
|
||||
|
||||
# Trigger announcement
|
||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||
await announce_task
|
||||
|
||||
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user