diff --git a/homeassistant/components/dhcp/__init__.py b/homeassistant/components/dhcp/__init__.py index d926e6f7925..1933aeafbad 100644 --- a/homeassistant/components/dhcp/__init__.py +++ b/homeassistant/components/dhcp/__init__.py @@ -7,10 +7,12 @@ import logging import os import threading +from scapy.config import conf +from scapy.data import ETH_P_ALL from scapy.error import Scapy_Exception from scapy.layers.dhcp import DHCP from scapy.layers.l2 import Ether -from scapy.sendrecv import sniff +from scapy.sendrecv import AsyncSniffer from homeassistant.components.device_tracker.const import ( ATTR_HOST_NAME, @@ -54,15 +56,12 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool: for cls in (DHCPWatcher, DeviceTrackerWatcher): watcher = cls(hass, address_data, integration_matchers) - watcher.async_start() + await watcher.async_start() watchers.append(watcher) async def _async_stop(*_): for watcher in watchers: - if hasattr(watcher, "async_stop"): - watcher.async_stop() - else: - await hass.async_add_executor_job(watcher.stop) + await watcher.async_stop() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop) @@ -144,15 +143,13 @@ class DeviceTrackerWatcher(WatcherBase): super().__init__(hass, address_data, integration_matchers) self._unsub = None - @callback - def async_stop(self): + async def async_stop(self): """Stop watching for new device trackers.""" if self._unsub: self._unsub() self._unsub = None - @callback - def async_start(self): + async def async_start(self): """Stop watching for new device trackers.""" self._unsub = async_track_state_added_domain( self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event @@ -190,33 +187,35 @@ class DeviceTrackerWatcher(WatcherBase): self.hass.async_create_task(task) -class DHCPWatcher(WatcherBase, threading.Thread): +class DHCPWatcher(WatcherBase): """Class to watch dhcp requests.""" def __init__(self, hass, address_data, integration_matchers): """Initialize class.""" super().__init__(hass, address_data, integration_matchers) - self.name = "dhcp-discovery" - self._stop_event = threading.Event() + self._sniffer = None + self._started = threading.Event() - def stop(self): + async def async_stop(self): + """Stop watching for new device trackers.""" + await self.hass.async_add_executor_job(self._stop) + + def _stop(self): """Stop the thread.""" - self._stop_event.set() - self.join() + if self._started.is_set(): + self._sniffer.stop() - @callback - def async_start(self): - """Start the thread.""" - self.start() - - def run(self): + async def async_start(self): """Start watching for dhcp packets.""" try: - sniff( + sniff_socket = conf.L2socket(type=ETH_P_ALL) + self._sniffer = AsyncSniffer( filter=FILTER, + opened_socket=[sniff_socket], + started_callback=self._started.set, prn=self.handle_dhcp_packet, - stop_filter=lambda _: self._stop_event.is_set(), ) + self._sniffer.start() except (Scapy_Exception, OSError) as ex: if os.geteuid() == 0: _LOGGER.error("Cannot watch for dhcp packets: %s", ex) diff --git a/tests/components/dhcp/test_init.py b/tests/components/dhcp/test_init.py index e1c2d988096..eda54ea08d4 100644 --- a/tests/components/dhcp/test_init.py +++ b/tests/components/dhcp/test_init.py @@ -280,18 +280,14 @@ async def test_setup_and_stop(hass): ) await hass.async_block_till_done() - wait_event = threading.Event() - - def _sniff_wait(): - wait_event.wait() - - with patch("homeassistant.components.dhcp.sniff", _sniff_wait): + with patch("homeassistant.components.dhcp.AsyncSniffer.start") as start_call: hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) await hass.async_block_till_done() hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) await hass.async_block_till_done() - wait_event.set() + + start_call.assert_called_once() async def test_setup_fails_as_root(hass, caplog): @@ -307,7 +303,7 @@ async def test_setup_fails_as_root(hass, caplog): wait_event = threading.Event() with patch("os.geteuid", return_value=0), patch( - "homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception + "homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception ): hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) await hass.async_block_till_done() @@ -331,7 +327,7 @@ async def test_setup_fails_non_root(hass, caplog): wait_event = threading.Event() with patch("os.geteuid", return_value=10), patch( - "homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception + "homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception ): hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) await hass.async_block_till_done() @@ -363,9 +359,9 @@ async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass): {}, [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], ) - device_tracker_watcher.async_start() + await device_tracker_watcher.async_start() await hass.async_block_till_done() - device_tracker_watcher.async_stop() + await device_tracker_watcher.async_stop() await hass.async_block_till_done() assert len(mock_init.mock_calls) == 1 @@ -389,7 +385,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass): {}, [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], ) - device_tracker_watcher.async_start() + await device_tracker_watcher.async_start() await hass.async_block_till_done() hass.states.async_set( "device_tracker.august_connect", @@ -402,7 +398,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass): }, ) await hass.async_block_till_done() - device_tracker_watcher.async_stop() + await device_tracker_watcher.async_stop() await hass.async_block_till_done() assert len(mock_init.mock_calls) == 1 @@ -426,7 +422,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass) {}, [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], ) - device_tracker_watcher.async_start() + await device_tracker_watcher.async_start() await hass.async_block_till_done() hass.states.async_set( "device_tracker.august_connect", @@ -439,7 +435,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass) }, ) await hass.async_block_till_done() - device_tracker_watcher.async_stop() + await device_tracker_watcher.async_stop() await hass.async_block_till_done() assert len(mock_init.mock_calls) == 0 @@ -456,7 +452,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has {}, [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], ) - device_tracker_watcher.async_start() + await device_tracker_watcher.async_start() await hass.async_block_till_done() hass.states.async_set( "device_tracker.august_connect", @@ -469,7 +465,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has }, ) await hass.async_block_till_done() - device_tracker_watcher.async_stop() + await device_tracker_watcher.async_stop() await hass.async_block_till_done() assert len(mock_init.mock_calls) == 0 @@ -488,7 +484,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi {}, [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], ) - device_tracker_watcher.async_start() + await device_tracker_watcher.async_start() await hass.async_block_till_done() hass.states.async_set( "device_tracker.august_connect", @@ -500,7 +496,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi }, ) await hass.async_block_till_done() - device_tracker_watcher.async_stop() + await device_tracker_watcher.async_stop() await hass.async_block_till_done() assert len(mock_init.mock_calls) == 0 @@ -527,9 +523,9 @@ async def test_device_tracker_ignore_self_assigned_ips_before_start(hass): {}, [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], ) - device_tracker_watcher.async_start() + await device_tracker_watcher.async_start() await hass.async_block_till_done() - device_tracker_watcher.async_stop() + await device_tracker_watcher.async_stop() await hass.async_block_till_done() assert len(mock_init.mock_calls) == 0