Switch dhcp to use async sniff for faster shutdown (#45339)

* Switch dhcp to use async sniff for faster shutdown

* Do not actually start the thread since we do not know when it will finish starting
This commit is contained in:
J. Nick Koston 2021-01-19 13:49:49 -06:00 committed by GitHub
parent a9a0f8938f
commit 3ae527c158
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 45 deletions

View File

@ -7,10 +7,12 @@ import logging
import os
import threading
from scapy.config import conf
from scapy.data import ETH_P_ALL
from scapy.error import Scapy_Exception
from scapy.layers.dhcp import DHCP
from scapy.layers.l2 import Ether
from scapy.sendrecv import sniff
from scapy.sendrecv import AsyncSniffer
from homeassistant.components.device_tracker.const import (
ATTR_HOST_NAME,
@ -54,15 +56,12 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool:
for cls in (DHCPWatcher, DeviceTrackerWatcher):
watcher = cls(hass, address_data, integration_matchers)
watcher.async_start()
await watcher.async_start()
watchers.append(watcher)
async def _async_stop(*_):
for watcher in watchers:
if hasattr(watcher, "async_stop"):
watcher.async_stop()
else:
await hass.async_add_executor_job(watcher.stop)
await watcher.async_stop()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop)
@ -144,15 +143,13 @@ class DeviceTrackerWatcher(WatcherBase):
super().__init__(hass, address_data, integration_matchers)
self._unsub = None
@callback
def async_stop(self):
async def async_stop(self):
"""Stop watching for new device trackers."""
if self._unsub:
self._unsub()
self._unsub = None
@callback
def async_start(self):
async def async_start(self):
"""Stop watching for new device trackers."""
self._unsub = async_track_state_added_domain(
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
@ -190,33 +187,35 @@ class DeviceTrackerWatcher(WatcherBase):
self.hass.async_create_task(task)
class DHCPWatcher(WatcherBase, threading.Thread):
class DHCPWatcher(WatcherBase):
"""Class to watch dhcp requests."""
def __init__(self, hass, address_data, integration_matchers):
"""Initialize class."""
super().__init__(hass, address_data, integration_matchers)
self.name = "dhcp-discovery"
self._stop_event = threading.Event()
self._sniffer = None
self._started = threading.Event()
def stop(self):
async def async_stop(self):
"""Stop watching for new device trackers."""
await self.hass.async_add_executor_job(self._stop)
def _stop(self):
"""Stop the thread."""
self._stop_event.set()
self.join()
if self._started.is_set():
self._sniffer.stop()
@callback
def async_start(self):
"""Start the thread."""
self.start()
def run(self):
async def async_start(self):
"""Start watching for dhcp packets."""
try:
sniff(
sniff_socket = conf.L2socket(type=ETH_P_ALL)
self._sniffer = AsyncSniffer(
filter=FILTER,
opened_socket=[sniff_socket],
started_callback=self._started.set,
prn=self.handle_dhcp_packet,
stop_filter=lambda _: self._stop_event.is_set(),
)
self._sniffer.start()
except (Scapy_Exception, OSError) as ex:
if os.geteuid() == 0:
_LOGGER.error("Cannot watch for dhcp packets: %s", ex)

View File

@ -280,18 +280,14 @@ async def test_setup_and_stop(hass):
)
await hass.async_block_till_done()
wait_event = threading.Event()
def _sniff_wait():
wait_event.wait()
with patch("homeassistant.components.dhcp.sniff", _sniff_wait):
with patch("homeassistant.components.dhcp.AsyncSniffer.start") as start_call:
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
wait_event.set()
start_call.assert_called_once()
async def test_setup_fails_as_root(hass, caplog):
@ -307,7 +303,7 @@ async def test_setup_fails_as_root(hass, caplog):
wait_event = threading.Event()
with patch("os.geteuid", return_value=0), patch(
"homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception
"homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception
):
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
@ -331,7 +327,7 @@ async def test_setup_fails_non_root(hass, caplog):
wait_event = threading.Event()
with patch("os.geteuid", return_value=10), patch(
"homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception
"homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception
):
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
@ -363,9 +359,9 @@ async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass):
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await device_tracker_watcher.async_start()
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 1
@ -389,7 +385,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass):
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await device_tracker_watcher.async_start()
await hass.async_block_till_done()
hass.states.async_set(
"device_tracker.august_connect",
@ -402,7 +398,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass):
},
)
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 1
@ -426,7 +422,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass)
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await device_tracker_watcher.async_start()
await hass.async_block_till_done()
hass.states.async_set(
"device_tracker.august_connect",
@ -439,7 +435,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass)
},
)
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0
@ -456,7 +452,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await device_tracker_watcher.async_start()
await hass.async_block_till_done()
hass.states.async_set(
"device_tracker.august_connect",
@ -469,7 +465,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has
},
)
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0
@ -488,7 +484,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await device_tracker_watcher.async_start()
await hass.async_block_till_done()
hass.states.async_set(
"device_tracker.august_connect",
@ -500,7 +496,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi
},
)
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0
@ -527,9 +523,9 @@ async def test_device_tracker_ignore_self_assigned_ips_before_start(hass):
{},
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
)
device_tracker_watcher.async_start()
await device_tracker_watcher.async_start()
await hass.async_block_till_done()
device_tracker_watcher.async_stop()
await device_tracker_watcher.async_stop()
await hass.async_block_till_done()
assert len(mock_init.mock_calls) == 0