mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Switch dhcp to use async sniff for faster shutdown (#45339)
* Switch dhcp to use async sniff for faster shutdown * Do not actually start the thread since we do not know when it will finish starting
This commit is contained in:
parent
a9a0f8938f
commit
3ae527c158
@ -7,10 +7,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
from scapy.config import conf
|
||||||
|
from scapy.data import ETH_P_ALL
|
||||||
from scapy.error import Scapy_Exception
|
from scapy.error import Scapy_Exception
|
||||||
from scapy.layers.dhcp import DHCP
|
from scapy.layers.dhcp import DHCP
|
||||||
from scapy.layers.l2 import Ether
|
from scapy.layers.l2 import Ether
|
||||||
from scapy.sendrecv import sniff
|
from scapy.sendrecv import AsyncSniffer
|
||||||
|
|
||||||
from homeassistant.components.device_tracker.const import (
|
from homeassistant.components.device_tracker.const import (
|
||||||
ATTR_HOST_NAME,
|
ATTR_HOST_NAME,
|
||||||
@ -54,15 +56,12 @@ async def async_setup(hass: HomeAssistant, config: dict) -> bool:
|
|||||||
|
|
||||||
for cls in (DHCPWatcher, DeviceTrackerWatcher):
|
for cls in (DHCPWatcher, DeviceTrackerWatcher):
|
||||||
watcher = cls(hass, address_data, integration_matchers)
|
watcher = cls(hass, address_data, integration_matchers)
|
||||||
watcher.async_start()
|
await watcher.async_start()
|
||||||
watchers.append(watcher)
|
watchers.append(watcher)
|
||||||
|
|
||||||
async def _async_stop(*_):
|
async def _async_stop(*_):
|
||||||
for watcher in watchers:
|
for watcher in watchers:
|
||||||
if hasattr(watcher, "async_stop"):
|
await watcher.async_stop()
|
||||||
watcher.async_stop()
|
|
||||||
else:
|
|
||||||
await hass.async_add_executor_job(watcher.stop)
|
|
||||||
|
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop)
|
||||||
|
|
||||||
@ -144,15 +143,13 @@ class DeviceTrackerWatcher(WatcherBase):
|
|||||||
super().__init__(hass, address_data, integration_matchers)
|
super().__init__(hass, address_data, integration_matchers)
|
||||||
self._unsub = None
|
self._unsub = None
|
||||||
|
|
||||||
@callback
|
async def async_stop(self):
|
||||||
def async_stop(self):
|
|
||||||
"""Stop watching for new device trackers."""
|
"""Stop watching for new device trackers."""
|
||||||
if self._unsub:
|
if self._unsub:
|
||||||
self._unsub()
|
self._unsub()
|
||||||
self._unsub = None
|
self._unsub = None
|
||||||
|
|
||||||
@callback
|
async def async_start(self):
|
||||||
def async_start(self):
|
|
||||||
"""Stop watching for new device trackers."""
|
"""Stop watching for new device trackers."""
|
||||||
self._unsub = async_track_state_added_domain(
|
self._unsub = async_track_state_added_domain(
|
||||||
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
|
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
|
||||||
@ -190,33 +187,35 @@ class DeviceTrackerWatcher(WatcherBase):
|
|||||||
self.hass.async_create_task(task)
|
self.hass.async_create_task(task)
|
||||||
|
|
||||||
|
|
||||||
class DHCPWatcher(WatcherBase, threading.Thread):
|
class DHCPWatcher(WatcherBase):
|
||||||
"""Class to watch dhcp requests."""
|
"""Class to watch dhcp requests."""
|
||||||
|
|
||||||
def __init__(self, hass, address_data, integration_matchers):
|
def __init__(self, hass, address_data, integration_matchers):
|
||||||
"""Initialize class."""
|
"""Initialize class."""
|
||||||
super().__init__(hass, address_data, integration_matchers)
|
super().__init__(hass, address_data, integration_matchers)
|
||||||
self.name = "dhcp-discovery"
|
self._sniffer = None
|
||||||
self._stop_event = threading.Event()
|
self._started = threading.Event()
|
||||||
|
|
||||||
def stop(self):
|
async def async_stop(self):
|
||||||
|
"""Stop watching for new device trackers."""
|
||||||
|
await self.hass.async_add_executor_job(self._stop)
|
||||||
|
|
||||||
|
def _stop(self):
|
||||||
"""Stop the thread."""
|
"""Stop the thread."""
|
||||||
self._stop_event.set()
|
if self._started.is_set():
|
||||||
self.join()
|
self._sniffer.stop()
|
||||||
|
|
||||||
@callback
|
async def async_start(self):
|
||||||
def async_start(self):
|
|
||||||
"""Start the thread."""
|
|
||||||
self.start()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
"""Start watching for dhcp packets."""
|
"""Start watching for dhcp packets."""
|
||||||
try:
|
try:
|
||||||
sniff(
|
sniff_socket = conf.L2socket(type=ETH_P_ALL)
|
||||||
|
self._sniffer = AsyncSniffer(
|
||||||
filter=FILTER,
|
filter=FILTER,
|
||||||
|
opened_socket=[sniff_socket],
|
||||||
|
started_callback=self._started.set,
|
||||||
prn=self.handle_dhcp_packet,
|
prn=self.handle_dhcp_packet,
|
||||||
stop_filter=lambda _: self._stop_event.is_set(),
|
|
||||||
)
|
)
|
||||||
|
self._sniffer.start()
|
||||||
except (Scapy_Exception, OSError) as ex:
|
except (Scapy_Exception, OSError) as ex:
|
||||||
if os.geteuid() == 0:
|
if os.geteuid() == 0:
|
||||||
_LOGGER.error("Cannot watch for dhcp packets: %s", ex)
|
_LOGGER.error("Cannot watch for dhcp packets: %s", ex)
|
||||||
|
@ -280,18 +280,14 @@ async def test_setup_and_stop(hass):
|
|||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
wait_event = threading.Event()
|
with patch("homeassistant.components.dhcp.AsyncSniffer.start") as start_call:
|
||||||
|
|
||||||
def _sniff_wait():
|
|
||||||
wait_event.wait()
|
|
||||||
|
|
||||||
with patch("homeassistant.components.dhcp.sniff", _sniff_wait):
|
|
||||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
wait_event.set()
|
|
||||||
|
start_call.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_fails_as_root(hass, caplog):
|
async def test_setup_fails_as_root(hass, caplog):
|
||||||
@ -307,7 +303,7 @@ async def test_setup_fails_as_root(hass, caplog):
|
|||||||
wait_event = threading.Event()
|
wait_event = threading.Event()
|
||||||
|
|
||||||
with patch("os.geteuid", return_value=0), patch(
|
with patch("os.geteuid", return_value=0), patch(
|
||||||
"homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception
|
"homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception
|
||||||
):
|
):
|
||||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
@ -331,7 +327,7 @@ async def test_setup_fails_non_root(hass, caplog):
|
|||||||
wait_event = threading.Event()
|
wait_event = threading.Event()
|
||||||
|
|
||||||
with patch("os.geteuid", return_value=10), patch(
|
with patch("os.geteuid", return_value=10), patch(
|
||||||
"homeassistant.components.dhcp.sniff", side_effect=Scapy_Exception
|
"homeassistant.components.dhcp.AsyncSniffer.start", side_effect=Scapy_Exception
|
||||||
):
|
):
|
||||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
@ -363,9 +359,9 @@ async def test_device_tracker_hostname_and_macaddress_exists_before_start(hass):
|
|||||||
{},
|
{},
|
||||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||||
)
|
)
|
||||||
device_tracker_watcher.async_start()
|
await device_tracker_watcher.async_start()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
device_tracker_watcher.async_stop()
|
await device_tracker_watcher.async_stop()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(mock_init.mock_calls) == 1
|
assert len(mock_init.mock_calls) == 1
|
||||||
@ -389,7 +385,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass):
|
|||||||
{},
|
{},
|
||||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||||
)
|
)
|
||||||
device_tracker_watcher.async_start()
|
await device_tracker_watcher.async_start()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
hass.states.async_set(
|
hass.states.async_set(
|
||||||
"device_tracker.august_connect",
|
"device_tracker.august_connect",
|
||||||
@ -402,7 +398,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start(hass):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
device_tracker_watcher.async_stop()
|
await device_tracker_watcher.async_stop()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(mock_init.mock_calls) == 1
|
assert len(mock_init.mock_calls) == 1
|
||||||
@ -426,7 +422,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass)
|
|||||||
{},
|
{},
|
||||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||||
)
|
)
|
||||||
device_tracker_watcher.async_start()
|
await device_tracker_watcher.async_start()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
hass.states.async_set(
|
hass.states.async_set(
|
||||||
"device_tracker.august_connect",
|
"device_tracker.august_connect",
|
||||||
@ -439,7 +435,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_home(hass)
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
device_tracker_watcher.async_stop()
|
await device_tracker_watcher.async_stop()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(mock_init.mock_calls) == 0
|
assert len(mock_init.mock_calls) == 0
|
||||||
@ -456,7 +452,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has
|
|||||||
{},
|
{},
|
||||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||||
)
|
)
|
||||||
device_tracker_watcher.async_start()
|
await device_tracker_watcher.async_start()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
hass.states.async_set(
|
hass.states.async_set(
|
||||||
"device_tracker.august_connect",
|
"device_tracker.august_connect",
|
||||||
@ -469,7 +465,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_not_router(has
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
device_tracker_watcher.async_stop()
|
await device_tracker_watcher.async_stop()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(mock_init.mock_calls) == 0
|
assert len(mock_init.mock_calls) == 0
|
||||||
@ -488,7 +484,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi
|
|||||||
{},
|
{},
|
||||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||||
)
|
)
|
||||||
device_tracker_watcher.async_start()
|
await device_tracker_watcher.async_start()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
hass.states.async_set(
|
hass.states.async_set(
|
||||||
"device_tracker.august_connect",
|
"device_tracker.august_connect",
|
||||||
@ -500,7 +496,7 @@ async def test_device_tracker_hostname_and_macaddress_after_start_hostname_missi
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
device_tracker_watcher.async_stop()
|
await device_tracker_watcher.async_stop()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(mock_init.mock_calls) == 0
|
assert len(mock_init.mock_calls) == 0
|
||||||
@ -527,9 +523,9 @@ async def test_device_tracker_ignore_self_assigned_ips_before_start(hass):
|
|||||||
{},
|
{},
|
||||||
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
[{"domain": "mock-domain", "hostname": "connect", "macaddress": "B8B7F1*"}],
|
||||||
)
|
)
|
||||||
device_tracker_watcher.async_start()
|
await device_tracker_watcher.async_start()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
device_tracker_watcher.async_stop()
|
await device_tracker_watcher.async_stop()
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert len(mock_init.mock_calls) == 0
|
assert len(mock_init.mock_calls) == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user