From 2ed35debdceced266c94477c9edabb360e1e6bdc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 29 Sep 2021 23:50:21 -0500 Subject: [PATCH] Fix dhcp discovery matching due to deferred imports (#56814) --- homeassistant/components/dhcp/__init__.py | 50 +++++----- tests/components/dhcp/test_init.py | 114 ++++++++++++---------- 2 files changed, 83 insertions(+), 81 deletions(-) diff --git a/homeassistant/components/dhcp/__init__.py b/homeassistant/components/dhcp/__init__.py index e6debfea2eb..61208ac6423 100644 --- a/homeassistant/components/dhcp/__init__.py +++ b/homeassistant/components/dhcp/__init__.py @@ -282,6 +282,9 @@ class DHCPWatcher(WatcherBase): from scapy import ( # pylint: disable=import-outside-toplevel,unused-import # noqa: F401 arch, ) + from scapy.layers.dhcp import DHCP # pylint: disable=import-outside-toplevel + from scapy.layers.inet import IP # pylint: disable=import-outside-toplevel + from scapy.layers.l2 import Ether # pylint: disable=import-outside-toplevel # # Importing scapy.sendrecv will cause a scapy resync which will @@ -294,6 +297,24 @@ class DHCPWatcher(WatcherBase): AsyncSniffer, ) + def _handle_dhcp_packet(packet): + """Process a dhcp packet.""" + if DHCP not in packet: + return + + options = packet[DHCP].options + request_type = _decode_dhcp_option(options, MESSAGE_TYPE) + if request_type != DHCP_REQUEST: + # Not a DHCP request + return + + ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src + hostname = _decode_dhcp_option(options, HOSTNAME) or "" + mac_address = _format_mac(packet[Ether].src) + + if ip_address is not None and mac_address is not None: + self.process_client(ip_address, hostname, mac_address) + # disable scapy promiscuous mode as we do not need it conf.sniff_promisc = 0 @@ -320,7 +341,7 @@ class DHCPWatcher(WatcherBase): self._sniffer = AsyncSniffer( filter=FILTER, started_callback=self._started.set, - prn=self.handle_dhcp_packet, + prn=_handle_dhcp_packet, store=0, ) @@ -328,33 +349,6 @@ class DHCPWatcher(WatcherBase): if self._sniffer.thread: self._sniffer.thread.name = self.__class__.__name__ - def handle_dhcp_packet(self, packet): - """Process a dhcp packet.""" - # Local import because importing from scapy has side effects such as opening - # sockets - from scapy.layers.dhcp import DHCP # pylint: disable=import-outside-toplevel - from scapy.layers.inet import IP # pylint: disable=import-outside-toplevel - from scapy.layers.l2 import Ether # pylint: disable=import-outside-toplevel - - if DHCP not in packet: - return - - options = packet[DHCP].options - - request_type = _decode_dhcp_option(options, MESSAGE_TYPE) - if request_type != DHCP_REQUEST: - # DHCP request - return - - ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src - hostname = _decode_dhcp_option(options, HOSTNAME) or "" - mac_address = _format_mac(packet[Ether].src) - - if ip_address is None or mac_address is None: - return - - self.process_client(ip_address, hostname, mac_address) - def create_task(self, task): """Pass a task to hass.add_job since we are in a thread.""" return self.hass.add_job(task) diff --git a/tests/components/dhcp/test_init.py b/tests/components/dhcp/test_init.py index 90ce1ebbf20..f00a0135e8d 100644 --- a/tests/components/dhcp/test_init.py +++ b/tests/components/dhcp/test_init.py @@ -1,7 +1,7 @@ """Test the DHCP discovery integration.""" import datetime import threading -from unittest.mock import patch +from unittest.mock import MagicMock, patch from scapy.error import Scapy_Exception from scapy.layers.dhcp import DHCP @@ -123,20 +123,39 @@ RAW_DHCP_REQUEST_WITHOUT_HOSTNAME = ( ) -async def test_dhcp_match_hostname_and_macaddress(hass): - """Test matching based on hostname and macaddress.""" +async def _async_get_handle_dhcp_packet(hass, integration_matchers): dhcp_watcher = dhcp.DHCPWatcher( hass, {}, - [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], + integration_matchers, ) + handle_dhcp_packet = None + def _mock_sniffer(*args, **kwargs): + nonlocal handle_dhcp_packet + handle_dhcp_packet = kwargs["prn"] + return MagicMock() + + with patch("homeassistant.components.dhcp._verify_l2socket_setup",), patch( + "scapy.arch.common.compile_filter" + ), patch("scapy.sendrecv.AsyncSniffer", _mock_sniffer): + await dhcp_watcher.async_start() + + return handle_dhcp_packet + + +async def test_dhcp_match_hostname_and_macaddress(hass): + """Test matching based on hostname and macaddress.""" + integration_matchers = [ + {"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"} + ] packet = Ether(RAW_DHCP_REQUEST) + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) # Ensure no change is ignored - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -152,18 +171,17 @@ async def test_dhcp_match_hostname_and_macaddress(hass): async def test_dhcp_renewal_match_hostname_and_macaddress(hass): """Test renewal matching based on hostname and macaddress.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, - {}, - [{"domain": "mock-domain", "hostname": "irobot-*", "macaddress": "501479*"}], - ) + integration_matchers = [ + {"domain": "mock-domain", "hostname": "irobot-*", "macaddress": "501479*"} + ] packet = Ether(RAW_DHCP_RENEWAL) + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) # Ensure no change is ignored - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -179,14 +197,13 @@ async def test_dhcp_renewal_match_hostname_and_macaddress(hass): async def test_dhcp_match_hostname(hass): """Test matching based on hostname only.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "hostname": "connect"}] - ) + integration_matchers = [{"domain": "mock-domain", "hostname": "connect"}] packet = Ether(RAW_DHCP_REQUEST) + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -202,14 +219,13 @@ async def test_dhcp_match_hostname(hass): async def test_dhcp_match_macaddress(hass): """Test matching based on macaddress only.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "macaddress": "B8B7F1*"}] - ) + integration_matchers = [{"domain": "mock-domain", "macaddress": "B8B7F1*"}] packet = Ether(RAW_DHCP_REQUEST) + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -225,14 +241,13 @@ async def test_dhcp_match_macaddress(hass): async def test_dhcp_match_macaddress_without_hostname(hass): """Test matching based on macaddress only.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "macaddress": "606BBD*"}] - ) + integration_matchers = [{"domain": "mock-domain", "macaddress": "606BBD*"}] packet = Ether(RAW_DHCP_REQUEST_WITHOUT_HOSTNAME) + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -248,51 +263,46 @@ async def test_dhcp_match_macaddress_without_hostname(hass): async def test_dhcp_nomatch(hass): """Test not matching based on macaddress only.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "macaddress": "ABC123*"}] - ) + integration_matchers = [{"domain": "mock-domain", "macaddress": "ABC123*"}] packet = Ether(RAW_DHCP_REQUEST) + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 async def test_dhcp_nomatch_hostname(hass): """Test not matching based on hostname only.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] - ) + integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}] packet = Ether(RAW_DHCP_REQUEST) + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 async def test_dhcp_nomatch_non_dhcp_packet(hass): """Test matching does not throw on a non-dhcp packet.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] - ) + integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}] packet = Ether(b"") + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 async def test_dhcp_nomatch_non_dhcp_request_packet(hass): """Test nothing happens with the wrong message-type.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] - ) + integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}] packet = Ether(RAW_DHCP_REQUEST) @@ -305,17 +315,16 @@ async def test_dhcp_nomatch_non_dhcp_request_packet(hass): ("hostname", b"connect"), ] + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 async def test_dhcp_invalid_hostname(hass): """Test we ignore invalid hostnames.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] - ) + integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}] packet = Ether(RAW_DHCP_REQUEST) @@ -328,17 +337,16 @@ async def test_dhcp_invalid_hostname(hass): ("hostname", "connect"), ] + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 async def test_dhcp_missing_hostname(hass): """Test we ignore missing hostnames.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] - ) + integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}] packet = Ether(RAW_DHCP_REQUEST) @@ -351,17 +359,16 @@ async def test_dhcp_missing_hostname(hass): ("hostname", None), ] + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 async def test_dhcp_invalid_option(hass): """Test we ignore invalid hostname option.""" - dhcp_watcher = dhcp.DHCPWatcher( - hass, {}, [{"domain": "mock-domain", "hostname": "nomatch*"}] - ) + integration_matchers = [{"domain": "mock-domain", "hostname": "nomatch*"}] packet = Ether(RAW_DHCP_REQUEST) @@ -374,8 +381,9 @@ async def test_dhcp_invalid_option(hass): ("hostname"), ] + handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - dhcp_watcher.handle_dhcp_packet(packet) + handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0