Fix audio format for VoIP (#125785)

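Fix audio format: have the VoIP satellite request WAV text-to-speech output (16000 Hz sample rate, 1 channel, 2 bytes per sample), since VoIP can only stream WAV.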
This commit is contained in:
Michael Hansen 2024-09-11 19:57:47 -05:00 committed by GitHub
parent 2475e8c0c4
commit 9651072103
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 3 deletions

View File

@ -8,7 +8,7 @@ from functools import partial
import io import io
import logging import logging
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Final from typing import TYPE_CHECKING, Any, Final
import wave import wave
from voip_utils import RtpDatagramProtocol from voip_utils import RtpDatagramProtocol
@ -120,6 +120,16 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
"""Return the entity ID of the VAD sensitivity to use for the next conversation.""" """Return the entity ID of the VAD sensitivity to use for the next conversation."""
return self.voip_device.get_vad_sensitivity_entity_id(self.hass) return self.voip_device.get_vad_sensitivity_entity_id(self.hass)
@property
def tts_options(self) -> dict[str, Any] | None:
    """Options passed for text-to-speech.

    The returned mapping asks for "wav" output with a preferred sample
    rate of 16000, 1 sample channel and 2 sample bytes.
    """
    preferred_audio: dict[str, Any] = {}
    # Container format requested from the TTS engine.
    preferred_audio[tts.ATTR_PREFERRED_FORMAT] = "wav"
    # Preferred sample rate, channel count and bytes per sample.
    preferred_audio[tts.ATTR_PREFERRED_SAMPLE_RATE] = 16000
    preferred_audio[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = 1
    preferred_audio[tts.ATTR_PREFERRED_SAMPLE_BYTES] = 2
    return preferred_audio
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass.""" """Run when entity about to be added to hass."""
await super().async_added_to_hass() await super().async_added_to_hass()

View File

@ -3,6 +3,7 @@
import asyncio import asyncio
import io import io
from pathlib import Path from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import wave import wave
@ -10,7 +11,7 @@ import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from voip_utils import CallInfo from voip_utils import CallInfo
from homeassistant.components import assist_pipeline, assist_satellite, voip from homeassistant.components import assist_pipeline, assist_satellite, tts, voip
from homeassistant.components.assist_satellite.entity import ( from homeassistant.components.assist_satellite.entity import (
AssistSatelliteEntity, AssistSatelliteEntity,
AssistSatelliteState, AssistSatelliteState,
@ -205,11 +206,24 @@ async def test_pipeline(
bad_chunk = bytes([1, 2, 3, 4]) bad_chunk = bytes([1, 2, 3, 4])
async def async_pipeline_from_audio_stream( async def async_pipeline_from_audio_stream(
hass: HomeAssistant, context: Context, *args, device_id: str | None, **kwargs hass: HomeAssistant,
context: Context,
*args,
device_id: str | None,
tts_audio_output: str | dict[str, Any] | None,
**kwargs,
): ):
assert context.user_id == voip_user_id assert context.user_id == voip_user_id
assert device_id == voip_device.device_id assert device_id == voip_device.device_id
# voip can only stream WAV
assert tts_audio_output == {
tts.ATTR_PREFERRED_FORMAT: "wav",
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
stt_stream = kwargs["stt_stream"] stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"] event_callback = kwargs["event_callback"]
in_command = False in_command = False