diff --git a/supervisor/dbus/network/connection.py b/supervisor/dbus/network/connection.py index a9fce0584..636945b08 100644 --- a/supervisor/dbus/network/connection.py +++ b/supervisor/dbus/network/connection.py @@ -2,6 +2,8 @@ from ipaddress import ip_address, ip_interface from typing import Optional +from supervisor.dbus.utils import dbus_connected + from ...const import ATTR_ADDRESS, ATTR_PREFIX from ...utils.dbus import DBus from ..const import ( @@ -79,6 +81,11 @@ class NetworkConnection(DBusInterfaceProxy): async def connect(self) -> None: """Get connection information.""" self.dbus = await DBus.connect(DBUS_NAME_NM, self.object_path) + await self.update() + + @dbus_connected + async def update(self): + """Update connection information.""" self.properties = await self.dbus.get_properties(DBUS_IFACE_CONNECTION_ACTIVE) # IPv4 diff --git a/supervisor/host/network.py b/supervisor/host/network.py index 89075d10f..2c9ad8e81 100644 --- a/supervisor/host/network.py +++ b/supervisor/host/network.py @@ -191,25 +191,24 @@ class NetworkManager(CoreSysAttributes): ) if con: - # Only consider activated or deactivated signals, continue waiting on others - def message_filter(msg_body): - state: ConnectionStateType = msg_body[0] - if state == ConnectionStateType.DEACTIVATED: - return True - elif state == ConnectionStateType.ACTIVATED: - return True - return False + async with con.dbus.signal( + DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED + ) as signal: + # From this point we monitor signals. However, it might be that + # the state change before this point. Get the state currently to + # avoid any race condition. + await con.update() + state: ConnectionStateType = con.state - result = await con.dbus.wait_signal( - DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED, message_filter - ) + while state != ConnectionStateType.ACTIVATED: + if state == ConnectionStateType.DEACTIVATED: + raise HostNetworkError( + "Activating connection failed, check connection settings." + ) - _LOGGER.debug("StateChanged signal received, result: %s", str(result)) - state: ConnectionStateType = result[0] - if state != ConnectionStateType.ACTIVATED: - raise HostNetworkError( - "Activating connection failed, check connection settings." - ) + msg = await signal.wait_for_signal() + state = msg[0] + _LOGGER.debug("Active connection state changed to %s", state) await self.update() diff --git a/supervisor/utils/dbus.py b/supervisor/utils/dbus.py index fb38196d4..02597503c 100644 --- a/supervisor/utils/dbus.py +++ b/supervisor/utils/dbus.py @@ -193,72 +193,14 @@ class DBus: _LOGGER.error("No Set attribute %s for %s", name, interface) raise DBusFatalError() from err - async def wait_signal(self, signal_member, message_filter=None) -> Any: + def signal(self, signal_member) -> DBusSignalWrapper: + """Get signal context manager for this object.""" + return DBusSignalWrapper(self, signal_member) + + async def wait_signal(self, signal_member) -> Any: """Wait for signal on this object.""" - signal_parts = signal_member.split(".") - interface = ".".join(signal_parts[:-1]) - member = signal_parts[-1] - match = f"type='signal',interface={interface},member={member},path={self.object_path}" - - _LOGGER.debug("Install match for signal %s", signal_member) - await self._bus.call( - Message( - destination="org.freedesktop.DBus", - interface="org.freedesktop.DBus", - path="/org/freedesktop/DBus", - member="AddMatch", - signature="s", - body=[match], - ) - ) - - loop = asyncio.get_event_loop() - future = loop.create_future() - - def message_handler(msg: Message): - if msg.message_type != MessageType.SIGNAL: - return - - _LOGGER.debug( - "Signal message received %s, %s.%s object %s", - msg.body, - msg.interface, - msg.member, - msg.path, - ) - if ( - msg.interface != interface - or msg.member != member - or msg.path != self.object_path - ): - return - - # Avoid race condition: We already received signal but handler not yet removed. - if future.done(): - return - - msg_body = _remove_dbus_signature(msg.body) - if message_filter and not message_filter(msg_body): - return - - future.set_result(msg_body) - - self._bus.add_message_handler(message_handler) - result = await future - self._bus.remove_message_handler(message_handler) - - await self._bus.call( - Message( - destination="org.freedesktop.DBus", - interface="org.freedesktop.DBus", - path="/org/freedesktop/DBus", - member="RemoveMatch", - signature="s", - body=[match], - ) - ) - - return result + async with self.signal(signal_member) as signal: + return await signal.wait_for_signal() def __getattr__(self, name: str) -> DBusCallWrapper: """Map to dbus method.""" @@ -293,3 +235,75 @@ class DBusCallWrapper: return self.dbus.call_dbus(interface, *args) return _method_wrapper + + +class DBusSignalWrapper: + """Wrapper for D-Bus Signal.""" + + def __init__(self, dbus: DBus, signal_member: str) -> None: + """Initialize wrapper.""" + self._dbus: DBus = dbus + signal_parts = signal_member.split(".") + self._interface = ".".join(signal_parts[:-1]) + self._member = signal_parts[-1] + self._match: str = f"type='signal',interface={self._interface},member={self._member},path={self._dbus.object_path}" + self._messages: asyncio.Queue[Message] = asyncio.Queue() + + def _message_handler(self, msg: Message): + if msg.message_type != MessageType.SIGNAL: + return + + _LOGGER.debug( + "Signal message received %s, %s.%s object %s", + msg.body, + msg.interface, + msg.member, + msg.path, + ) + if ( + msg.interface != self._interface + or msg.member != self._member + or msg.path != self._dbus.object_path + ): + return + + self._messages.put_nowait(msg) + + async def __aenter__(self): + """Install match for signals and start collecting signal messages.""" + + _LOGGER.debug("Install match for signal %s.%s", self._interface, self._member) + await self._dbus._bus.call( + Message( + destination="org.freedesktop.DBus", + interface="org.freedesktop.DBus", + path="/org/freedesktop/DBus", + member="AddMatch", + signature="s", + body=[self._match], + ) + ) + + self._dbus._bus.add_message_handler(self._message_handler) + return self + + async def wait_for_signal(self) -> Message: + """Wait for signal and returns signal payload.""" + msg = await self._messages.get() + return msg.body + + async def __aexit__(self, exc_t, exc_v, exc_tb): + """Stop collecting signal messages and remove match for signals.""" + + self._dbus._bus.remove_message_handler(self._message_handler) + + await self._dbus._bus.call( + Message( + destination="org.freedesktop.DBus", + interface="org.freedesktop.DBus", + path="/org/freedesktop/DBus", + member="RemoveMatch", + signature="s", + body=[self._match], + ) + ) diff --git a/tests/conftest.py b/tests/conftest.py index b6894e9a2..751ba43ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -80,10 +80,19 @@ def dbus() -> DBus: return load_json_fixture(f"{fixture}.json") - async def mock_wait_signal(_, signal_method, ___): - if signal_method == DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED: + async def mock_wait_for_signal(self): + if ( + self._interface + "." + self._method + == DBUS_SIGNAL_NM_CONNECTION_ACTIVE_CHANGED + ): return [2, 0] + async def mock_signal___aenter__(self): + return self + + async def mock_signal___aexit__(self, exc_t, exc_v, exc_tb): + pass + async def mock_init_proxy(self): filetype = "xml" @@ -108,14 +117,19 @@ def dbus() -> DBus: return load_json_fixture(f"{fixture}.json") with patch("supervisor.utils.dbus.DBus.call_dbus", new=mock_call_dbus), patch( - "supervisor.utils.dbus.DBus.wait_signal", new=mock_wait_signal - ), patch( "supervisor.dbus.interface.DBusInterface.is_connected", return_value=True, ), patch( "supervisor.utils.dbus.DBus.get_properties", new=mock_get_properties ), patch( "supervisor.utils.dbus.DBus._init_proxy", new=mock_init_proxy + ), patch( + "supervisor.utils.dbus.DBusSignalWrapper.__aenter__", new=mock_signal___aenter__ + ), patch( + "supervisor.utils.dbus.DBusSignalWrapper.__aexit__", new=mock_signal___aexit__ + ), patch( + "supervisor.utils.dbus.DBusSignalWrapper.wait_for_signal", + new=mock_wait_for_signal, ): yield dbus_commands