Rewrite go2rtc binary handling to be async (#128078)

This commit is contained in:
Robert Resch 2024-10-14 15:32:00 +02:00 committed by GitHub
parent cdb1b1df15
commit f5b55d5eb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 115 additions and 81 deletions

View File

@ -50,9 +50,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up WebRTC from a config entry.""" """Set up WebRTC from a config entry."""
if binary := entry.data.get(CONF_BINARY): if binary := entry.data.get(CONF_BINARY):
# HA will manage the binary # HA will manage the binary
server = Server(binary) server = Server(hass, binary)
entry.async_on_unload(server.stop) entry.async_on_unload(server.stop)
server.start() await server.start()
client = Go2RtcClient(async_get_clientsession(hass), entry.data[CONF_HOST]) client = Go2RtcClient(async_get_clientsession(hass), entry.data[CONF_HOST])

View File

@ -1,56 +1,70 @@
"""Go2rtc server.""" """Go2rtc server."""
from __future__ import annotations import asyncio
import logging import logging
import subprocess
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from threading import Thread
from .const import DOMAIN from homeassistant.core import HomeAssistant
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_TERMINATE_TIMEOUT = 5
class Server(Thread): def _create_temp_file() -> str:
"""Server thread.""" """Create temporary config file."""
# Set delete=False to prevent the file from being deleted when the file is closed
# Linux is clearing tmp folder on reboot, so no need to delete it manually
with NamedTemporaryFile(prefix="go2rtc", suffix=".yaml", delete=False) as file:
return file.name
def __init__(self, binary: str) -> None:
async def _log_output(process: asyncio.subprocess.Process) -> None:
"""Log the output of the process."""
assert process.stdout is not None
async for line in process.stdout:
_LOGGER.debug(line[:-1].decode().strip())
class Server:
"""Go2rtc server."""
def __init__(self, hass: HomeAssistant, binary: str) -> None:
"""Initialize the server.""" """Initialize the server."""
super().__init__(name=DOMAIN, daemon=True) self._hass = hass
self._binary = binary self._binary = binary
self._stop_requested = False self._process: asyncio.subprocess.Process | None = None
def run(self) -> None: async def start(self) -> None:
"""Run the server.""" """Start the server."""
_LOGGER.debug("Starting go2rtc server") _LOGGER.debug("Starting go2rtc server")
self._stop_requested = False config_file = await self._hass.async_add_executor_job(_create_temp_file)
with (
NamedTemporaryFile(prefix="go2rtc", suffix=".yaml") as file,
subprocess.Popen(
[self._binary, "-c", "webrtc.ice_servers=[]", "-c", file.name],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as process,
):
while not self._stop_requested and process.poll() is None:
assert process.stdout
line = process.stdout.readline()
if line == b"":
break
_LOGGER.debug(line[:-1].decode())
_LOGGER.debug("Terminating go2rtc server") self._process = await asyncio.create_subprocess_exec(
self._binary,
"-c",
"webrtc.ice_servers=[]",
"-c",
config_file,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
self._hass.async_create_background_task(
_log_output(self._process), "Go2rtc log output"
)
async def stop(self) -> None:
"""Stop the server."""
if self._process:
_LOGGER.debug("Stopping go2rtc server")
process = self._process
self._process = None
process.terminate() process.terminate()
try: try:
process.wait(timeout=5) await asyncio.wait_for(process.wait(), timeout=_TERMINATE_TIMEOUT)
except subprocess.TimeoutExpired: except TimeoutError:
_LOGGER.warning("Go2rtc server didn't terminate gracefully. Killing it") _LOGGER.warning("Go2rtc server didn't terminate gracefully. Killing it")
process.kill() process.kill()
else:
_LOGGER.debug("Go2rtc server has been stopped") _LOGGER.debug("Go2rtc server has been stopped")
def stop(self) -> None:
"""Stop the server."""
self._stop_requested = True
if self.is_alive():
self.join()

View File

@ -7,6 +7,7 @@ from go2rtc_client.client import _StreamClient, _WebRTCClient
import pytest import pytest
from homeassistant.components.go2rtc.const import CONF_BINARY, DOMAIN from homeassistant.components.go2rtc.const import CONF_BINARY, DOMAIN
from homeassistant.components.go2rtc.server import Server
from homeassistant.const import CONF_HOST from homeassistant.const import CONF_HOST
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -41,9 +42,11 @@ def mock_client() -> Generator[AsyncMock]:
@pytest.fixture @pytest.fixture
def mock_server() -> Generator[Mock]: def mock_server() -> Generator[AsyncMock]:
"""Mock a go2rtc server.""" """Mock a go2rtc server."""
with patch("homeassistant.components.go2rtc.Server", autoSpec=True) as mock_server: with patch(
"homeassistant.components.go2rtc.Server", spec_set=Server
) as mock_server:
yield mock_server yield mock_server

View File

@ -184,13 +184,13 @@ async def _test_setup(
async def test_setup_go_binary( async def test_setup_go_binary(
hass: HomeAssistant, hass: HomeAssistant,
mock_client: AsyncMock, mock_client: AsyncMock,
mock_server: Mock, mock_server: AsyncMock,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
) -> None: ) -> None:
"""Test the go2rtc config entry with binary.""" """Test the go2rtc config entry with binary."""
def after_setup() -> None: def after_setup() -> None:
mock_server.assert_called_once_with("/usr/bin/go2rtc") mock_server.assert_called_once_with(hass, "/usr/bin/go2rtc")
mock_server.return_value.start.assert_called_once() mock_server.return_value.start.assert_called_once()
await _test_setup(hass, mock_client, mock_config_entry, after_setup) await _test_setup(hass, mock_client, mock_config_entry, after_setup)

View File

@ -2,20 +2,22 @@
import asyncio import asyncio
from collections.abc import Generator from collections.abc import Generator
import logging
import subprocess import subprocess
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from homeassistant.components.go2rtc.server import Server from homeassistant.components.go2rtc.server import Server
from homeassistant.core import HomeAssistant
TEST_BINARY = "/bin/go2rtc" TEST_BINARY = "/bin/go2rtc"
@pytest.fixture @pytest.fixture
def server() -> Server: def server(hass: HomeAssistant) -> Server:
"""Fixture to initialize the Server.""" """Fixture to initialize the Server."""
return Server(binary=TEST_BINARY) return Server(hass, binary=TEST_BINARY)
@pytest.fixture @pytest.fixture
@ -29,63 +31,77 @@ def mock_tempfile() -> Generator[MagicMock]:
@pytest.fixture @pytest.fixture
def mock_popen() -> Generator[MagicMock]: def mock_process() -> Generator[MagicMock]:
"""Fixture to mock subprocess.Popen.""" """Fixture to mock subprocess.Popen."""
with patch("homeassistant.components.go2rtc.server.subprocess.Popen") as mock_popen: with patch(
"homeassistant.components.go2rtc.server.asyncio.create_subprocess_exec"
) as mock_popen:
mock_popen.return_value.returncode = None
yield mock_popen yield mock_popen
@pytest.mark.usefixtures("mock_tempfile") @pytest.mark.usefixtures("mock_tempfile")
async def test_server_run_success(mock_popen: MagicMock, server: Server) -> None: async def test_server_run_success(
mock_process: MagicMock,
server: Server,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that the server runs successfully.""" """Test that the server runs successfully."""
mock_process = MagicMock()
mock_process.poll.return_value = None # Simulate process running
# Simulate process output # Simulate process output
mock_process.stdout.readline.side_effect = [ mock_process.return_value.stdout.__aiter__.return_value = iter(
[
b"log line 1\n", b"log line 1\n",
b"log line 2\n", b"log line 2\n",
b"",
] ]
mock_popen.return_value.__enter__.return_value = mock_process )
server.start() await server.start()
await asyncio.sleep(0)
# Check that Popen was called with the right arguments # Check that Popen was called with the right arguments
mock_popen.assert_called_once_with( mock_process.assert_called_once_with(
[TEST_BINARY, "-c", "webrtc.ice_servers=[]", "-c", "test.yaml"], TEST_BINARY,
"-c",
"webrtc.ice_servers=[]",
"-c",
"test.yaml",
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
) )
# Check that server read the log lines # Check that server read the log lines
assert mock_process.stdout.readline.call_count == 3 for entry in ("log line 1", "log line 2"):
assert (
"homeassistant.components.go2rtc.server",
logging.DEBUG,
entry,
) in caplog.record_tuples
server.stop() await server.stop()
mock_process.terminate.assert_called_once() mock_process.return_value.terminate.assert_called_once()
assert not server.is_alive()
@pytest.mark.usefixtures("mock_tempfile") @pytest.mark.usefixtures("mock_tempfile")
def test_server_run_process_timeout(mock_popen: MagicMock, server: Server) -> None: async def test_server_run_process_timeout(
mock_process: MagicMock, server: Server
) -> None:
"""Test server run where the process takes too long to terminate.""" """Test server run where the process takes too long to terminate."""
mock_process.return_value.stdout.__aiter__.return_value = iter(
mock_process = MagicMock() [
mock_process.poll.return_value = None # Simulate process running
# Simulate process output
mock_process.stdout.readline.side_effect = [
b"log line 1\n", b"log line 1\n",
b"",
] ]
# Simulate timeout )
mock_process.wait.side_effect = subprocess.TimeoutExpired(cmd="go2rtc", timeout=5)
mock_popen.return_value.__enter__.return_value = mock_process
async def sleep() -> None:
await asyncio.sleep(1)
# Simulate timeout
mock_process.return_value.wait.side_effect = sleep
with patch("homeassistant.components.go2rtc.server._TERMINATE_TIMEOUT", new=0.1):
# Start server thread # Start server thread
server.start() await server.start()
server.stop() await server.stop()
# Ensure terminate and kill were called due to timeout # Ensure terminate and kill were called due to timeout
mock_process.terminate.assert_called_once() mock_process.return_value.terminate.assert_called_once()
mock_process.kill.assert_called_once() mock_process.return_value.kill.assert_called_once()
assert not server.is_alive()