TTS to use ffmpeg in streaming fashion (#140536)

This commit is contained in:
Paulus Schoutsen 2025-04-19 06:41:52 -04:00 committed by GitHub
parent 42c4ed85a1
commit 6f99b1d69b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 73 deletions

View File

@ -14,8 +14,6 @@ import mimetypes
import os
import re
import secrets
import subprocess
import tempfile
from time import monotonic
from typing import Any, Final
@ -309,80 +307,73 @@ async def _async_convert_audio(
) -> AsyncGenerator[bytes]:
"""Convert audio to a preferred format using ffmpeg."""
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
audio_bytes = b"".join([chunk async for chunk in audio_bytes_gen])
data = await hass.async_add_executor_job(
lambda: _convert_audio(
ffmpeg_manager.binary,
from_extension,
audio_bytes,
to_extension,
to_sample_rate=to_sample_rate,
to_sample_channels=to_sample_channels,
to_sample_bytes=to_sample_bytes,
)
command = [
ffmpeg_manager.binary,
"-hide_banner",
"-loglevel",
"error",
"-f",
from_extension,
"-i",
"pipe:",
"-f",
to_extension,
]
if to_sample_rate is not None:
command.extend(["-ar", str(to_sample_rate)])
if to_sample_channels is not None:
command.extend(["-ac", str(to_sample_channels)])
if to_extension == "mp3":
# Max quality for MP3.
command.extend(["-q:a", "0"])
if to_sample_bytes == 2:
# 16-bit samples.
command.extend(["-sample_fmt", "s16"])
command.append("pipe:1") # Send output to stdout.
process = await asyncio.create_subprocess_exec(
*command,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
yield data
async def write_input() -> None:
    """Stream the source audio into ffmpeg's stdin and then close it.

    Each chunk from ``audio_bytes_gen`` is written to the subprocess and
    drained before the next one is fetched, so back-pressure from ffmpeg
    is respected. Stdin is always closed on exit so ffmpeg sees EOF.

    A broken pipe is deliberately not propagated: it means ffmpeg exited
    before consuming all input, and the caller reports that failure from
    the process return code and stderr.
    """
    assert process.stdin
    try:
        async for chunk in audio_bytes_gen:
            process.stdin.write(chunk)
            await process.stdin.drain()
    except (BrokenPipeError, ConnectionResetError):
        # ffmpeg went away mid-stream. Raising here would resurface when the
        # writer task is awaited and hide the real error; this mirrors what
        # asyncio's Process.communicate() does for the same condition.
        _LOGGER.debug("ffmpeg closed its stdin before all audio was written")
    finally:
        # Closing stdin signals end-of-input so ffmpeg can flush and exit.
        if process.stdin:
            process.stdin.close()
def _convert_audio(
ffmpeg_binary: str,
from_extension: str,
audio_bytes: bytes,
to_extension: str,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
to_sample_bytes: int | None = None,
) -> bytes:
"""Convert audio to a preferred format using ffmpeg."""
writer_task = hass.async_create_background_task(
write_input(), "tts_ffmpeg_conversion"
)
# We have to use a temporary file here because some formats like WAV store
# the length of the file in the header, and therefore cannot be written in a
# streaming fashion.
with tempfile.NamedTemporaryFile(
mode="wb+", suffix=f".{to_extension}"
) as output_file:
# input
command = [
ffmpeg_binary,
"-y", # overwrite temp file
"-f",
from_extension,
"-i",
"pipe:", # input from stdin
]
# output
command.extend(["-f", to_extension])
if to_sample_rate is not None:
command.extend(["-ar", str(to_sample_rate)])
if to_sample_channels is not None:
command.extend(["-ac", str(to_sample_channels)])
if to_extension == "mp3":
# Max quality for MP3
command.extend(["-q:a", "0"])
if to_sample_bytes == 2:
# 16-bit samples
command.extend(["-sample_fmt", "s16"])
command.append(output_file.name)
with subprocess.Popen(
command, stdin=subprocess.PIPE, stderr=subprocess.PIPE
) as proc:
_stdout, stderr = proc.communicate(input=audio_bytes)
if proc.returncode != 0:
_LOGGER.error(stderr.decode())
raise RuntimeError(
f"Unexpected error while running ffmpeg with arguments: {command}."
"See log for details."
)
output_file.seek(0)
return output_file.read()
assert process.stdout
chunk_size = 4096
try:
while True:
chunk = await process.stdout.read(chunk_size)
if not chunk:
break
yield chunk
finally:
# Ensure we wait for the input writer to complete.
await writer_task
# Wait for process termination and check for errors.
retcode = await process.wait()
if retcode != 0:
assert process.stderr
stderr_data = await process.stderr.read()
_LOGGER.error(stderr_data.decode())
raise RuntimeError(
f"Unexpected error while running ffmpeg with arguments: {command}. "
"See log for details."
)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:

View File

@ -117,7 +117,6 @@ async def test_get_tts_audio_different_formats(
assert wav_file.getframerate() == 48000
assert wav_file.getsampwidth() == 2
assert wav_file.getnchannels() == 2
assert wav_file.getnframes() == wav_file.getframerate() # one second
assert mock_client.written == snapshot