diff --git a/homeassistant/components/go2rtc/__init__.py b/homeassistant/components/go2rtc/__init__.py index eeaa35fbbb4..013c094dc23 100644 --- a/homeassistant/components/go2rtc/__init__.py +++ b/homeassistant/components/go2rtc/__init__.py @@ -38,7 +38,7 @@ from homeassistant.helpers.typing import ConfigType from homeassistant.util.hass_dict import HassKey from homeassistant.util.package import is_docker_env -from .const import CONF_DEBUG_UI, DEBUG_UI_URL_MESSAGE, DOMAIN +from .const import CONF_DEBUG_UI, DEBUG_UI_URL_MESSAGE, DEFAULT_URL, DOMAIN from .server import Server _LOGGER = logging.getLogger(__name__) @@ -121,7 +121,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, on_stop) - url = "http://localhost:1984/" + url = DEFAULT_URL hass.data[_DATA_GO2RTC] = url discovery_flow.async_create_flow( diff --git a/homeassistant/components/go2rtc/const.py b/homeassistant/components/go2rtc/const.py index b0d52e4fd39..cb03e224e52 100644 --- a/homeassistant/components/go2rtc/const.py +++ b/homeassistant/components/go2rtc/const.py @@ -4,3 +4,4 @@ DOMAIN = "go2rtc" CONF_DEBUG_UI = "debug_ui" DEBUG_UI_URL_MESSAGE = "Url and debug_ui cannot be set at the same time." +DEFAULT_URL = "http://localhost:1984/" diff --git a/homeassistant/components/go2rtc/server.py b/homeassistant/components/go2rtc/server.py index df4b5b7f13e..b2aa19d5275 100644 --- a/homeassistant/components/go2rtc/server.py +++ b/homeassistant/components/go2rtc/server.py @@ -1,17 +1,25 @@ """Go2rtc server.""" import asyncio +from contextlib import suppress import logging from tempfile import NamedTemporaryFile +from go2rtc_client import Go2RtcRestClient + from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.aiohttp_client import async_get_clientsession + +from .const import DEFAULT_URL _LOGGER = logging.getLogger(__name__) _TERMINATE_TIMEOUT = 5 _SETUP_TIMEOUT = 30 _SUCCESSFUL_BOOT_MESSAGE = "INF [api] listen addr=" _LOCALHOST_IP = "127.0.0.1" +_RESPAWN_COOLDOWN = 1 + # Default configuration for HA # - Api is listening only on localhost # - Disable rtsp listener @@ -29,6 +37,16 @@ webrtc: """ +class Go2RTCServerStartError(HomeAssistantError): + """Raised when server does not start.""" + + _message = "Go2rtc server didn't start correctly" + + +class Go2RTCWatchdogError(HomeAssistantError): + """Raised on watchdog error.""" + + def _create_temp_file(api_ip: str) -> str: """Create temporary config file.""" # Set delete=False to prevent the file from being deleted when the file is closed @@ -53,8 +71,17 @@ class Server: if enable_ui: # Listen on all interfaces for allowing access from all ips self._api_ip = "" + self._watchdog_task: asyncio.Task | None = None + self._watchdog_tasks: list[asyncio.Task] = [] async def start(self) -> None: + """Start the server.""" + await self._start() + self._watchdog_task = asyncio.create_task( + self._watchdog(), name="Go2rtc respawn" + ) + + async def _start(self) -> None: """Start the server.""" _LOGGER.debug("Starting go2rtc server") config_file = await self._hass.async_add_executor_job( @@ -82,8 +109,8 @@ class Server: except TimeoutError as err: msg = "Go2rtc server didn't start correctly" _LOGGER.exception(msg) - await self.stop() - raise HomeAssistantError("Go2rtc server didn't start correctly") from err + await self._stop() + raise Go2RTCServerStartError from err async def _log_output(self, process: asyncio.subprocess.Process) -> None: """Log the output of the process.""" @@ -95,17 +122,95 @@ class Server: if not self._startup_complete.is_set() and _SUCCESSFUL_BOOT_MESSAGE in msg: self._startup_complete.set() + async def _watchdog(self) -> None: + """Keep respawning go2rtc servers. + + A new go2rtc server is spawned if the process terminates or the API + stops responding. + """ + while True: + try: + monitor_process_task = asyncio.create_task(self._monitor_process()) + self._watchdog_tasks.append(monitor_process_task) + monitor_process_task.add_done_callback(self._watchdog_tasks.remove) + monitor_api_task = asyncio.create_task(self._monitor_api()) + self._watchdog_tasks.append(monitor_api_task) + monitor_api_task.add_done_callback(self._watchdog_tasks.remove) + try: + await asyncio.gather(monitor_process_task, monitor_api_task) + except Go2RTCWatchdogError: + _LOGGER.debug("Caught Go2RTCWatchdogError") + for task in self._watchdog_tasks: + if task.done(): + if not task.cancelled(): + task.exception() + continue + task.cancel() + await asyncio.sleep(_RESPAWN_COOLDOWN) + try: + await self._stop() + _LOGGER.debug("Spawning new go2rtc server") + with suppress(Go2RTCServerStartError): + await self._start() + except Exception: + _LOGGER.exception( + "Unexpected error when restarting go2rtc server" + ) + except Exception: + _LOGGER.exception("Unexpected error in go2rtc server watchdog") + + async def _monitor_process(self) -> None: + """Raise if the go2rtc process terminates.""" + _LOGGER.debug("Monitoring go2rtc server process") + if self._process: + await self._process.wait() + _LOGGER.debug("go2rtc server terminated") + raise Go2RTCWatchdogError("Process ended") + + async def _monitor_api(self) -> None: + """Raise if the go2rtc process terminates.""" + client = Go2RtcRestClient(async_get_clientsession(self._hass), DEFAULT_URL) + + _LOGGER.debug("Monitoring go2rtc API") + try: + while True: + await client.streams.list() + await asyncio.sleep(10) + except Exception as err: + _LOGGER.debug("go2rtc API did not reply", exc_info=True) + raise Go2RTCWatchdogError("API error") from err + + async def _stop_watchdog(self) -> None: + """Handle watchdog stop request.""" + tasks: list[asyncio.Task] = [] + if watchdog_task := self._watchdog_task: + self._watchdog_task = None + tasks.append(watchdog_task) + watchdog_task.cancel() + for task in self._watchdog_tasks: + tasks.append(task) + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + async def stop(self) -> None: + """Stop the server and abort the watchdog task.""" + _LOGGER.debug("Server stop requested") + await self._stop_watchdog() + await self._stop() + + async def _stop(self) -> None: """Stop the server.""" if self._process: _LOGGER.debug("Stopping go2rtc server") process = self._process self._process = None - process.terminate() + with suppress(ProcessLookupError): + process.terminate() try: await asyncio.wait_for(process.wait(), timeout=_TERMINATE_TIMEOUT) except TimeoutError: _LOGGER.warning("Go2rtc server didn't terminate gracefully. Killing it") - process.kill() + with suppress(ProcessLookupError): + process.kill() else: _LOGGER.debug("Go2rtc server has been stopped") diff --git a/tests/components/go2rtc/conftest.py b/tests/components/go2rtc/conftest.py index b299c28c557..495d42114f1 100644 --- a/tests/components/go2rtc/conftest.py +++ b/tests/components/go2rtc/conftest.py @@ -18,6 +18,7 @@ def rest_client() -> Generator[AsyncMock]: patch( "homeassistant.components.go2rtc.Go2RtcRestClient", ) as mock_client, + patch("homeassistant.components.go2rtc.server.Go2RtcRestClient", mock_client), ): client = mock_client.return_value client.streams = Mock(spec_set=_StreamClient) diff --git a/tests/components/go2rtc/test_server.py b/tests/components/go2rtc/test_server.py index 42f3f5e098d..1410fbeb6c3 100644 --- a/tests/components/go2rtc/test_server.py +++ b/tests/components/go2rtc/test_server.py @@ -161,3 +161,100 @@ async def test_server_failed_to_start( stderr=subprocess.STDOUT, close_fds=False, ) + + +@patch("homeassistant.components.go2rtc.server._RESPAWN_COOLDOWN", 0) +async def test_server_restart_process_exit( + hass: HomeAssistant, + mock_create_subprocess: AsyncMock, + rest_client: AsyncMock, + server: Server, +) -> None: + """Test that the server is restarted when it exits.""" + evt = asyncio.Event() + + async def wait_event() -> None: + await evt.wait() + + mock_create_subprocess.return_value.wait.side_effect = wait_event + + await server.start() + mock_create_subprocess.assert_awaited_once() + mock_create_subprocess.reset_mock() + + await asyncio.sleep(0.1) + await hass.async_block_till_done() + mock_create_subprocess.assert_not_awaited() + + evt.set() + await asyncio.sleep(0.1) + mock_create_subprocess.assert_awaited_once() + + await server.stop() + + +@patch("homeassistant.components.go2rtc.server._RESPAWN_COOLDOWN", 0) +async def test_server_restart_process_error( + hass: HomeAssistant, + mock_create_subprocess: AsyncMock, + rest_client: AsyncMock, + server: Server, +) -> None: + """Test that the server is restarted on error.""" + mock_create_subprocess.return_value.wait.side_effect = [Exception, None, None, None] + + await server.start() + mock_create_subprocess.assert_awaited_once() + mock_create_subprocess.reset_mock() + + await asyncio.sleep(0.1) + await hass.async_block_till_done() + mock_create_subprocess.assert_awaited_once() + + await server.stop() + + +@patch("homeassistant.components.go2rtc.server._RESPAWN_COOLDOWN", 0) +async def test_server_restart_api_error( + hass: HomeAssistant, + mock_create_subprocess: AsyncMock, + rest_client: AsyncMock, + server: Server, +) -> None: + """Test that the server is restarted on error.""" + rest_client.streams.list.side_effect = Exception + + await server.start() + mock_create_subprocess.assert_awaited_once() + mock_create_subprocess.reset_mock() + + await asyncio.sleep(0.1) + await hass.async_block_till_done() + mock_create_subprocess.assert_awaited_once() + + await server.stop() + + +@patch("homeassistant.components.go2rtc.server._RESPAWN_COOLDOWN", 0) +async def test_server_restart_error( + hass: HomeAssistant, + mock_create_subprocess: AsyncMock, + rest_client: AsyncMock, + server: Server, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test error handling when exception is raised during restart.""" + rest_client.streams.list.side_effect = Exception + mock_create_subprocess.return_value.terminate.side_effect = [Exception, None] + + await server.start() + mock_create_subprocess.assert_awaited_once() + mock_create_subprocess.reset_mock() + + await asyncio.sleep(0.1) + await hass.async_block_till_done() + mock_create_subprocess.assert_awaited_once() + + assert "Unexpected error when restarting go2rtc server" in caplog.text + + await server.stop()