mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Switch dhcp to use async sniff for faster shutdown (#45339)
* Switch dhcp to use async sniff for faster shutdown * Do not actually start the thread since we do not know when it will finish starting
This commit is contained in:
parent
a9a0f8938f
commit
3ae527c158
@ -7,10 +7,12 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from scapy.config import conf
|
||||
from scapy.data import ETH_P_ALL
|
||||
from scapy.error import Scapy_Exception
|
||||
from scapy.layers.dhcp import DHCP
|
||||
from scapy.layers.l2 import Ether
|
||||
from scapy.sendrecv import sniff
|
||||
from scapy.sendrecv import AsyncSniffer
|
||||
|
||||
from homeassistant.components.device_tracker.const import (
|
||||
ATTR_HOST_NAME,
|
||||
@ -54,15 +56,12 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool:
|
||||
|
||||
for cls in (DHCPWatcher, DeviceTrackerWatcher):
|
||||
watcher = cls(hass, address_data, integration_matchers)
|
||||
watcher.async_start()
|
||||
await watcher.async_start()
|
||||
watchers.append(watcher)
|
||||
|
||||
async def _async_stop(*_):
|
||||
for watcher in watchers:
|
||||
if hasattr(watcher, "async_stop"):
|
||||
watcher.async_stop()
|
||||
else:
|
||||
await hass.async_add_executor_job(watcher.stop)
|
||||
await watcher.async_stop()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop)
|
||||
|
||||
@ -144,15 +143,13 @@ class DeviceTrackerWatcher(WatcherBase):
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
self._unsub = None
|
||||
|
||||
@callback
|
||||
def async_stop(self):
|
||||
async def async_stop(self):
|
||||
"""Stop watching for new device trackers."""
|
||||
if self._unsub:
|
||||
self._unsub()
|
||||
self._unsub = None
|
||||
|
||||
@callback
|
||||
def async_start(self):
|
||||
async def async_start(self):
|
||||
"""Stop watching for new device trackers."""
|
||||
self._unsub = async_track_state_added_domain(
|
||||
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
|
||||
@ -190,33 +187,35 @@ class DeviceTrackerWatcher(WatcherBase):
|
||||
self.hass.async_create_task(task)
|
||||
|
||||
|
||||
class DHCPWatcher(WatcherBase, threading.Thread):
|
||||
class DHCPWatcher(WatcherBase):
|
||||
"""Class to watch dhcp requests."""
|
||||
|
||||
def __init__(self, hass, address_data, integration_matchers):
|
||||
"""Initialize class."""
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
self.name = "dhcp-discovery"
|
||||
self._stop_event = threading.Event()
|
||||
self._sniffer = None
|
||||
self._started = threading.Event()
|
||||
|
||||
def stop(self):
|
||||
async def async_stop(self):
|
||||
"""Stop watching for new device trackers."""
|
||||
await self.hass.async_add_executor_job(self._stop)
|
||||
|
||||
def _stop(self):
|
||||
"""Stop the thread."""
|
||||
self._stop_event.set()
|
||||
self.join()
|
||||
if self._started.is_set():
|
||||
self._sniffer.stop()
|
||||
|
||||
@callback
|
||||
def async_start(self):
|
||||
"""Start the thread."""
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
async def async_start(self):
|
||||
"""Start watching for dhcp packets."""
|
||||
try:
|
||||
sniff(
|
||||
sniff_socket = conf.L2socket(type=ETH_P_ALL)
|
||||
self._sniffer = AsyncSniffer(
|
||||
filter=FILTER,
|
||||
opened_socket=[sniff_socket],
|
||||
started_callback=self._started.set,
|
||||
prn=self.handle_dhcp_packet,
|
||||
stop_filter=lambda _: self._stop_event.is_set(),
|
||||
)
|
||||
self._sniffer.start()
|
||||
except (Scapy_Exception, OSError) as ex:
|
||||
if os.geteuid() == 0:
|
||||
_LOGGER.error("Cannot watch for dhcp packets: %s", ex)
|
||||
|
@ -280,18 +280,14 @@ async def test_setup_and_stop(hass):
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
wait_event = threading.Event()
|
||||
|
||||
def _sniff_wait():
|
||||
wait_event.wait()
|
||||
|
||||
with patch("homeassistant.components.dhcp.sniff", _sniff_wait):
|
||||
with patch("homeassistant.components.dhcp.AsyncSniffer.start") as start_call:
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
||||
await hass.async_block_till_done()
|
||||
wait_event.set()
|
||||
|
||||
start_call.assert_called_once()
|
||||
|
||||
|
||||
async def test_setup_fails_as_root(hass, caplog):
|
||||
@ -307,7 +303,7 @@ async def test_setup_fails_as_root(hass, caplog):
|
||||
wait_event = threading.Event()
|
||||
|
||||
with patch("os.geteuid", return_value=0), patch(
|
||||
"homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception
|
||||
"homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception
|
||||
):
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
await hass.async_block_till_done()
|
||||
@ -331,7 +327,7 @@ async def test_setup_fails_non_root(hass, caplog):
|
||||
wait_event = threading.Event()
|
||||
|
||||
with patch("os.geteuid", return_value=10), patch(
|
||||
"homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception
|
||||
"homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception
|
||||
):
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
await hass.async_block_till_done()
|
||||
@ -363,9 +359,9 @@ async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass):
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||
)
|
||||
device_tracker_watcher.async_start()
|
||||
await device_tracker_watcher.async_start()
|
||||
await hass.async_block_till_done()
|
||||
device_tracker_watcher.async_stop()
|
||||
await device_tracker_watcher.async_stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_init.mock_calls) == 1
|
||||
@ -389,7 +385,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass):
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||
)
|
||||
device_tracker_watcher.async_start()
|
||||
await device_tracker_watcher.async_start()
|
||||
await hass.async_block_till_done()
|
||||
hass.states.async_set(
|
||||
"device_tracker.august_connect",
|
||||
@ -402,7 +398,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass):
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
device_tracker_watcher.async_stop()
|
||||
await device_tracker_watcher.async_stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_init.mock_calls) == 1
|
||||
@ -426,7 +422,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass)
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||
)
|
||||
device_tracker_watcher.async_start()
|
||||
await device_tracker_watcher.async_start()
|
||||
await hass.async_block_till_done()
|
||||
hass.states.async_set(
|
||||
"device_tracker.august_connect",
|
||||
@ -439,7 +435,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass)
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
device_tracker_watcher.async_stop()
|
||||
await device_tracker_watcher.async_stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
@ -456,7 +452,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||
)
|
||||
device_tracker_watcher.async_start()
|
||||
await device_tracker_watcher.async_start()
|
||||
await hass.async_block_till_done()
|
||||
hass.states.async_set(
|
||||
"device_tracker.august_connect",
|
||||
@ -469,7 +465,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
device_tracker_watcher.async_stop()
|
||||
await device_tracker_watcher.async_stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
@ -488,7 +484,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||
)
|
||||
device_tracker_watcher.async_start()
|
||||
await device_tracker_watcher.async_start()
|
||||
await hass.async_block_till_done()
|
||||
hass.states.async_set(
|
||||
"device_tracker.august_connect",
|
||||
@ -500,7 +496,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
device_tracker_watcher.async_stop()
|
||||
await device_tracker_watcher.async_stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
@ -527,9 +523,9 @@ async def test_device_tracker_ignore_self_assigned_ips_before_start(hass):
|
||||
{},
|
||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||
)
|
||||
device_tracker_watcher.async_start()
|
||||
await device_tracker_watcher.async_start()
|
||||
await hass.async_block_till_done()
|
||||
device_tracker_watcher.async_stop()
|
||||
await device_tracker_watcher.async_stop()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_init.mock_calls) == 0
|
||||
|
Loading…
x
Reference in New Issue
Block a user