From e5817e9445b0847257bb997b140518fa7a480c0e Mon Sep 17 00:00:00 2001 From: Stefan Agner Date: Wed, 20 Oct 2021 14:40:28 +0200 Subject: [PATCH] D-Bus signal handling improvements (#3248) * Allow to update connection information * Introduce Signal wrapper class using async context manager This allows to start monitoring signals and execute code before processing signals. With that it is possible to check for state changes in a race free manor. * Fix unit tests --- supervisor/dbus/network/connection.py | 7 ++ supervisor/host/network.py | 33 +++--- supervisor/utils/dbus.py | 144 ++++++++++++++------------ tests/conftest.py | 22 +++- 4 files changed, 120 insertions(+), 86 deletions(-) diff --git a/supervisor/dbus/network/connection.py b/supervisor/dbus/network/connection.py index a9fce0584..636945b08 100644 --- a/supervisor/dbus/network/connection.py +++ b/supervisor/dbus/network/connection.py @@ -2,6 +2,8 @@ from ipaddress import ip_address, ip_interface from typing import Optional +from supervisor.dbus.utils import dbus_connected + from ...const import ATTR_ADDRESS, ATTR_PREFIX from ...utils.dbus import DBus from ..const import ( @@ -79,6 +81,11 @@ class NetworkConnection(DBusInterfaceProxy): async def connect(self) -> None: """Get connection information.""" self.dbus = await DBus.connect(DBUS_NAME_NM, self.object_path) + await self.update() + + @dbus_connected + async def update(self): + """Update connection information.""" self.properties = await self.dbus.get_properties(DBUS_IFACE_CONNECTION_ACTIVE) # IPv4 diff --git a/supervisor/host/network.py b/supervisor/host/network.py index 89075d10f..2c9ad8e81 100644 --- a/supervisor/host/network.py +++ b/supervisor/host/network.py @@ -191,25 +191,24 @@ class NetworkManager(CoreSysAttributes): ) if con: - # Only consider activated or deactivated signals, continue waiting on others - def message_filter(msg_body): - state: ConnectionStateType = msg_body[0] - if state == ConnectionStateType.DEACTIVATED: - return True - elif state == ConnectionStateType.ACTIVATED: - return True - return False + async with con.dbus.signal( + DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED + ) as signal: + # From this point we monitor signals. However, it might be that + # the state change before this point. Get the state currently to + # avoid any race condition. + await con.update() + state: ConnectionStateType = con.state - result = await con.dbus.wait_signal( - DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED, message_filter - ) + while state != ConnectionStateType.ACTIVATED: + if state == ConnectionStateType.DEACTIVATED: + raise HostNetworkError( + "Activating connection failed, check connection settings." + ) - _LOGGER.debug("StateChanged signal received, result: %s", str(result)) - state: ConnectionStateType = result[0] - if state != ConnectionStateType.ACTIVATED: - raise HostNetworkError( - "Activating connection failed, check connection settings." - ) + msg = await signal.wait_for_signal() + state = msg[0] + _LOGGER.debug("Active connection state changed to %s", state) await self.update() diff --git a/supervisor/utils/dbus.py b/supervisor/utils/dbus.py index fb38196d4..02597503c 100644 --- a/supervisor/utils/dbus.py +++ b/supervisor/utils/dbus.py @@ -193,72 +193,14 @@ class DBus: _LOGGER.error("No Set attribute %s for %s", name, interface) raise DBusFatalError() from err - async def wait_signal(self, signal_member, message_filter=None) -> Any: + def signal(self, signal_member) -> DBusSignalWrapper: + """Get signal context manager for this object.""" + return DBusSignalWrapper(self, signal_member) + + async def wait_signal(self, signal_member) -> Any: """Wait for signal on this object.""" - signal_parts = signal_member.split(".") - interface = ".".join(signal_parts[:-1]) - member = signal_parts[-1] - match = f"type='signal',interface={interface},member={member},path={self.object_path}" - - _LOGGER.debug("Install match for signal %s", signal_member) - await self._bus.call( - Message( - destination="org.freedesktop.DBus", - interface="org.freedesktop.DBus", - path="/org/freedesktop/DBus", - member="AddMatch", - signature="s", - body=[match], - ) - ) - - loop = asyncio.get_event_loop() - future = loop.create_future() - - def message_handler(msg: Message): - if msg.message_type != MessageType.SIGNAL: - return - - _LOGGER.debug( - "Signal message received %s, %s.%s object %s", - msg.body, - msg.interface, - msg.member, - msg.path, - ) - if ( - msg.interface != interface - or msg.member != member - or msg.path != self.object_path - ): - return - - # Avoid race condition: We already received signal but handler not yet removed. - if future.done(): - return - - msg_body = _remove_dbus_signature(msg.body) - if message_filter and not message_filter(msg_body): - return - - future.set_result(msg_body) - - self._bus.add_message_handler(message_handler) - result = await future - self._bus.remove_message_handler(message_handler) - - await self._bus.call( - Message( - destination="org.freedesktop.DBus", - interface="org.freedesktop.DBus", - path="/org/freedesktop/DBus", - member="RemoveMatch", - signature="s", - body=[match], - ) - ) - - return result + async with self.signal(signal_member) as signal: + return await signal.wait_for_signal() def __getattr__(self, name: str) -> DBusCallWrapper: """Map to dbus method.""" @@ -293,3 +235,75 @@ class DBusCallWrapper: return self.dbus.call_dbus(interface, *args) return _method_wrapper + + +class DBusSignalWrapper: + """Wrapper for D-Bus Signal.""" + + def __init__(self, dbus: DBus, signal_member: str) -> None: + """Initialize wrapper.""" + self._dbus: DBus = dbus + signal_parts = signal_member.split(".") + self._interface = ".".join(signal_parts[:-1]) + self._member = signal_parts[-1] + self._match: str = f"type='signal',interface={self._interface},member={self._member},path={self._dbus.object_path}" + self._messages: asyncio.Queue[Message] = asyncio.Queue() + + def _message_handler(self, msg: Message): + if msg.message_type != MessageType.SIGNAL: + return + + _LOGGER.debug( + "Signal message received %s, %s.%s object %s", + msg.body, + msg.interface, + msg.member, + msg.path, + ) + if ( + msg.interface != self._interface + or msg.member != self._member + or msg.path != self._dbus.object_path + ): + return + + self._messages.put_nowait(msg) + + async def __aenter__(self): + """Install match for signals and start collecting signal messages.""" + + _LOGGER.debug("Install match for signal %s.%s", self._interface, self._member) + await self._dbus._bus.call( + Message( + destination="org.freedesktop.DBus", + interface="org.freedesktop.DBus", + path="/org/freedesktop/DBus", + member="AddMatch", + signature="s", + body=[self._match], + ) + ) + + self._dbus._bus.add_message_handler(self._message_handler) + return self + + async def wait_for_signal(self) -> Message: + """Wait for signal and returns signal payload.""" + msg = await self._messages.get() + return msg.body + + async def __aexit__(self, exc_t, exc_v, exc_tb): + """Stop collecting signal messages and remove match for signals.""" + + self._dbus._bus.remove_message_handler(self._message_handler) + + await self._dbus._bus.call( + Message( + destination="org.freedesktop.DBus", + interface="org.freedesktop.DBus", + path="/org/freedesktop/DBus", + member="RemoveMatch", + signature="s", + body=[self._match], + ) + ) diff --git a/tests/conftest.py b/tests/conftest.py index b6894e9a2..751ba43ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -80,10 +80,19 @@ def dbus() -> DBus: return load_json_fixture(f"{fixture}.json") - async def mock_wait_signal(_, signal_method, ___): - if signal_method == DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED: + async def mock_wait_for_signal(self): + if ( + self._interface + "." + self._method + == DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED + ): return [2, 0] + async def mock_signal___aenter__(self): + return self + + async def mock_signal___aexit__(self, exc_t, exc_v, exc_tb): + pass + async def mock_init_proxy(self): filetype = "xml" @@ -108,14 +117,19 @@ def dbus() -> DBus: return load_json_fixture(f"{fixture}.json") with patch("supervisor.utils.dbus.DBus.call_dbus", new=mock_call_dbus), patch( - "supervisor.utils.dbus.DBus.wait_signal", new=mock_wait_signal - ), patch( "supervisor.dbus.interface.DBusInterface.is_connected", return_value=True, ), patch( "supervisor.utils.dbus.DBus.get_properties", new=mock_get_properties ), patch( "supervisor.utils.dbus.DBus._init_proxy", new=mock_init_proxy + ), patch( + "supervisor.utils.dbus.DBusSignalWrapper.__aenter__", new=mock_signal___aenter__ + ), patch( + "supervisor.utils.dbus.DBusSignalWrapper.__aexit__", new=mock_signal___aexit__ + ), patch( + "supervisor.utils.dbus.DBusSignalWrapper.wait_for_signal", + new=mock_wait_for_signal, ): yield dbus_commands