Switch dhcp to use async sniff for faster shutdown (#45339)

* Switch dhcp to use async sniff for faster shutdown

* Do not actually start the thread since we do not know when it will finish starting
This commit is contained in:
J. Nick Koston 2021-01-19 13:49:49 -06:00 committed by GitHub
parent a9a0f8938f
commit 3ae527c158
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 45 deletions

View File

@ -7,10 +7,12 @@ import logging
import os import os
import threading import threading
from scapy.config import conf
from scapy.data import ETH_P_ALL
from scapy.error import Scapy_Exception from scapy.error import Scapy_Exception
from scapy.layers.dhcp import DHCP from scapy.layers.dhcp import DHCP
from scapy.layers.l2 import Ether from scapy.layers.l2 import Ether
from scapy.sendrecv import sniff from scapy.sendrecv import AsyncSniffer
from homeassistant.components.device_tracker.const import ( from homeassistant.components.device_tracker.const import (
ATTR_HOST_NAME, ATTR_HOST_NAME,
@ -54,15 +56,12 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool:
for cls in (DHCPWatcher, DeviceTrackerWatcher): for cls in (DHCPWatcher, DeviceTrackerWatcher):
watcher = cls(hass, address_data, integration_matchers) watcher = cls(hass, address_data, integration_matchers)
watcher.async_start() await watcher.async_start()
watchers.append(watcher) watchers.append(watcher)
async def _async_stop(*_): async def _async_stop(*_):
for watcher in watchers: for watcher in watchers:
if hasattr(watcher, "async_stop"): await watcher.async_stop()
watcher.async_stop()
else:
await hass.async_add_executor_job(watcher.stop)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop)
@ -144,15 +143,13 @@ class DeviceTrackerWatcher(WatcherBase):
super().__init__(hass, address_data, integration_matchers) super().__init__(hass, address_data, integration_matchers)
self._unsub = None self._unsub = None
@callback async def async_stop(self):
def async_stop(self):
"""Stop watching for new device trackers.""" """Stop watching for new device trackers."""
if self._unsub: if self._unsub:
self._unsub() self._unsub()
self._unsub = None self._unsub = None
@callback async def async_start(self):
def async_start(self):
"""Stop watching for new device trackers.""" """Stop watching for new device trackers."""
self._unsub = async_track_state_added_domain( self._unsub = async_track_state_added_domain(
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
@ -190,33 +187,35 @@ class DeviceTrackerWatcher(WatcherBase):
self.hass.async_create_task(task) self.hass.async_create_task(task)
class DHCPWatcher(WatcherBase, threading.Thread): class DHCPWatcher(WatcherBase):
"""Class to watch dhcp requests.""" """Class to watch dhcp requests."""
def __init__(self, hass, address_data, integration_matchers): def __init__(self, hass, address_data, integration_matchers):
"""Initialize class.""" """Initialize class."""
super().__init__(hass, address_data, integration_matchers) super().__init__(hass, address_data, integration_matchers)
self.name = "dhcp-discovery" self._sniffer = None
self._stop_event = threading.Event() self._started = threading.Event()
def stop(self): async def async_stop(self):
"""Stop watching for new device trackers."""
await self.hass.async_add_executor_job(self._stop)
def _stop(self):
"""Stop the thread.""" """Stop the thread."""
self._stop_event.set() if self._started.is_set():
self.join() self._sniffer.stop()
@callback async def async_start(self):
def async_start(self):
"""Start the thread."""
self.start()
def run(self):
"""Start watching for dhcp packets.""" """Start watching for dhcp packets."""
try: try:
sniff( sniff_socket = conf.L2socket(type=ETH_P_ALL)
self._sniffer = AsyncSniffer(
filter=FILTER, filter=FILTER,
opened_socket=[sniff_socket],
started_callback=self._started.set,
prn=self.handle_dhcp_packet, prn=self.handle_dhcp_packet,
stop_filter=lambda _: self._stop_event.is_set(),
) )
self._sniffer.start()
except (Scapy_Exception, OSError) as ex: except (Scapy_Exception, OSError) as ex:
if os.geteuid() == 0: if os.geteuid() == 0:
_LOGGER.error("Cannot watch for dhcp packets: %s", ex) _LOGGER.error("Cannot watch for dhcp packets: %s", ex)

View File

@ -280,18 +280,14 @@ async def test_setup_and_stop(hass):
) )
await hass.async_block_till_done() await hass.async_block_till_done()
wait_event = threading.Event() with patch("homeassistant.components.dhcp.AsyncSniffer.start") as start_call:
def _sniff_wait():
wait_event.wait()
with patch("homeassistant.components.dhcp.sniff", _sniff_wait):
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done() await hass.async_block_till_done()
wait_event.set()
start_call.assert_called_once()
async def test_setup_fails_as_root(hass, caplog): async def test_setup_fails_as_root(hass, caplog):
@ -307,7 +303,7 @@ async def test_setup_fails_as_root(hass, caplog):
wait_event = threading.Event() wait_event = threading.Event()
with patch("os.geteuid", return_value=0), patch( with patch("os.geteuid", return_value=0), patch(
"homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception "homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception
): ):
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -331,7 +327,7 @@ async def test_setup_fails_non_root(hass, caplog):
wait_event = threading.Event() wait_event = threading.Event()
with patch("os.geteuid", return_value=10), patch( with patch("os.geteuid", return_value=10), patch(
"homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception "homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception
): ):
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -363,9 +359,9 @@ async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass):
{}, {},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
) )
device_tracker_watcher.async_start() await device_tracker_watcher.async_start()
await hass.async_block_till_done() await hass.async_block_till_done()
device_tracker_watcher.async_stop() await device_tracker_watcher.async_stop()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 1 assert len(mock_init.mock_calls) == 1
@ -389,7 +385,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass):
{}, {},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
) )
device_tracker_watcher.async_start() await device_tracker_watcher.async_start()
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"device_tracker.august_connect", "device_tracker.august_connect",
@ -402,7 +398,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass):
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
device_tracker_watcher.async_stop() await device_tracker_watcher.async_stop()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 1 assert len(mock_init.mock_calls) == 1
@ -426,7 +422,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass)
{}, {},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
) )
device_tracker_watcher.async_start() await device_tracker_watcher.async_start()
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"device_tracker.august_connect", "device_tracker.august_connect",
@ -439,7 +435,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass)
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
device_tracker_watcher.async_stop() await device_tracker_watcher.async_stop()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
@ -456,7 +452,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has
{}, {},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
) )
device_tracker_watcher.async_start() await device_tracker_watcher.async_start()
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"device_tracker.august_connect", "device_tracker.august_connect",
@ -469,7 +465,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
device_tracker_watcher.async_stop() await device_tracker_watcher.async_stop()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
@ -488,7 +484,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi
{}, {},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
) )
device_tracker_watcher.async_start() await device_tracker_watcher.async_start()
await hass.async_block_till_done() await hass.async_block_till_done()
hass.states.async_set( hass.states.async_set(
"device_tracker.august_connect", "device_tracker.august_connect",
@ -500,7 +496,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
device_tracker_watcher.async_stop() await device_tracker_watcher.async_stop()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0
@ -527,9 +523,9 @@ async def test_device_tracker_ignore_self_assigned_ips_before_start(hass):
{}, {},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}], [{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
) )
device_tracker_watcher.async_start() await device_tracker_watcher.async_start()
await hass.async_block_till_done() await hass.async_block_till_done()
device_tracker_watcher.async_stop() await device_tracker_watcher.async_stop()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0 assert len(mock_init.mock_calls) == 0