Update device_traker for async platforms (#5102)

Async DeviceScanner object, migrate to async, cleanups
This commit is contained in:
Pascal Vizeli 2017-01-02 20:50:42 +01:00 committed by Johann Kellerman
parent 9c6a985c56
commit b2371c6614
26 changed files with 124 additions and 72 deletions

View File

@ -8,7 +8,7 @@ import asyncio
from datetime import timedelta
import logging
import os
from typing import Any, Sequence, Callable
from typing import Any, List, Sequence, Callable
import aiohttp
import async_timeout
@ -142,23 +142,34 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
if platform is None:
return
_LOGGER.info("Setting up %s.%s", DOMAIN, p_type)
try:
if hasattr(platform, 'get_scanner'):
scanner = None
setup = None
if hasattr(platform, 'async_get_scanner'):
scanner = yield from platform.async_get_scanner(
hass, {DOMAIN: p_config})
elif hasattr(platform, 'get_scanner'):
scanner = yield from hass.loop.run_in_executor(
None, platform.get_scanner, hass, {DOMAIN: p_config})
elif hasattr(platform, 'async_setup_scanner'):
setup = yield from platform.setup_scanner(
hass, p_config, tracker.see)
elif hasattr(platform, 'setup_scanner'):
setup = yield from hass.loop.run_in_executor(
None, platform.setup_scanner, hass, p_config, tracker.see)
else:
raise HomeAssistantError("Invalid device_tracker platform.")
if scanner is None:
_LOGGER.error('Error setting up platform %s', p_type)
return
if scanner:
yield from async_setup_scanner_platform(
hass, p_config, scanner, tracker.async_see)
return
ret = yield from hass.loop.run_in_executor(
None, platform.setup_scanner, hass, p_config, tracker.see)
if not ret:
if not setup:
_LOGGER.error('Error setting up platform %s', p_type)
return
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error setting up platform %s', p_type)
@ -526,6 +537,34 @@ class Device(Entity):
yield from resp.release()
class DeviceScanner(object):
"""Device scanner object."""
hass = None # type: HomeAssistantType
def scan_devices(self) -> List[str]:
"""Scan for devices."""
raise NotImplementedError()
def async_scan_devices(self) -> Any:
"""Scan for devices.
This method must be run in the event loop and returns a coroutine.
"""
return self.hass.loop.run_in_executor(None, self.scan_devices)
def get_device_name(self, mac: str) -> str:
"""Get device name from mac."""
raise NotImplementedError()
def async_get_device_name(self, mac: str) -> Any:
"""Get device name from mac.
This method must be run in the event loop and returns a coroutine.
"""
return self.hass.loop.run_in_executor(None, self.get_device_name, mac)
def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
"""Load devices from YAML configuration file."""
return run_coroutine_threadsafe(
@ -582,26 +621,28 @@ def async_setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
This method is a coroutine.
"""
interval = config.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
scanner.hass = hass
# Initial scan of each mac we also tell about host name for config
seen = set() # type: Any
def device_tracker_scan(now: dt_util.dt.datetime):
@asyncio.coroutine
def async_device_tracker_scan(now: dt_util.dt.datetime):
"""Called when interval matches."""
found_devices = scanner.scan_devices()
found_devices = yield from scanner.async_scan_devices()
for mac in found_devices:
if mac in seen:
host_name = None
else:
host_name = scanner.get_device_name(mac)
host_name = yield from scanner.async_get_device_name(mac)
seen.add(mac)
hass.add_job(async_see_device(mac=mac, host_name=host_name))
hass.async_add_job(async_see_device(mac=mac, host_name=host_name))
async_track_utc_time_change(
hass, device_tracker_scan, second=range(0, 60, interval))
hass, async_device_tracker_scan, second=range(0, 60, interval))
hass.async_add_job(device_tracker_scan, None)
hass.async_add_job(async_device_tracker_scan, None)
def update_config(path: str, dev_id: str, device: Device):

View File

@ -14,7 +14,8 @@ import voluptuous as vol
import homeassistant.helpers.config_validation as cv
import homeassistant.util.dt as dt_util
from homeassistant.components.device_tracker import (DOMAIN, PLATFORM_SCHEMA)
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -46,7 +47,7 @@ def get_scanner(hass, config):
Device = namedtuple("Device", ["mac", "ip", "last_update"])
class ActiontecDeviceScanner(object):
class ActiontecDeviceScanner(DeviceScanner):
"""This class queries a an actiontec router for connected devices."""
def __init__(self, config):

View File

@ -12,7 +12,8 @@ from datetime import timedelta
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -42,7 +43,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class ArubaDeviceScanner(object):
class ArubaDeviceScanner(DeviceScanner):
"""This class queries a Aruba Access Point for connected devices."""
def __init__(self, config):

View File

@ -14,7 +14,8 @@ from datetime import timedelta
import voluptuous as vol
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
import homeassistant.helpers.config_validation as cv
@ -97,7 +98,7 @@ def get_scanner(hass, config):
AsusWrtResult = namedtuple('AsusWrtResult', 'neighbors leases arp nvram')
class AsusWrtDeviceScanner(object):
class AsusWrtDeviceScanner(DeviceScanner):
"""This class queries a router running ASUSWRT firmware."""
# Eighth attribute needed for mode (AP mode vs router mode)

View File

@ -11,8 +11,8 @@ import requests
import voluptuous as vol
from homeassistant.components.device_tracker import (PLATFORM_SCHEMA,
ATTR_ATTRIBUTES)
from homeassistant.components.device_tracker import (
PLATFORM_SCHEMA, ATTR_ATTRIBUTES)
from homeassistant.const import CONF_USERNAME, CONF_PASSWORD
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import track_utc_time_change

View File

@ -9,7 +9,7 @@ import logging
from datetime import timedelta
import homeassistant.util.dt as dt_util
from homeassistant.components.device_tracker import DOMAIN
from homeassistant.components.device_tracker import DOMAIN, DeviceScanner
from homeassistant.util import Throttle
REQUIREMENTS = ['pybbox==0.0.5-alpha']
@ -29,7 +29,7 @@ def get_scanner(hass, config):
Device = namedtuple('Device', ['mac', 'name', 'ip', 'last_update'])
class BboxDeviceScanner(object):
class BboxDeviceScanner(DeviceScanner):
"""This class scans for devices connected to the bbox."""
def __init__(self, config):

View File

@ -5,13 +5,8 @@ from datetime import timedelta
import voluptuous as vol
from homeassistant.helpers.event import track_point_in_utc_time
from homeassistant.components.device_tracker import (
YAML_DEVICES,
CONF_TRACK_NEW,
CONF_SCAN_INTERVAL,
DEFAULT_SCAN_INTERVAL,
PLATFORM_SCHEMA,
load_config,
DEFAULT_TRACK_NEW
YAML_DEVICES, CONF_TRACK_NEW, CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL,
PLATFORM_SCHEMA, load_config, DEFAULT_TRACK_NEW
)
import homeassistant.util as util
import homeassistant.util.dt as dt_util

View File

@ -16,7 +16,8 @@ import requests
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST
from homeassistant.util import Throttle
@ -40,7 +41,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class BTHomeHub5DeviceScanner(object):
class BTHomeHub5DeviceScanner(DeviceScanner):
"""This class queries a BT Home Hub 5."""
def __init__(self, config):

View File

@ -10,7 +10,8 @@ from datetime import timedelta
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, \
CONF_PORT
from homeassistant.util import Throttle
@ -39,7 +40,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class CiscoDeviceScanner(object):
class CiscoDeviceScanner(DeviceScanner):
"""This class queries a wireless router running Cisco IOS firmware."""
def __init__(self, config):

View File

@ -13,7 +13,8 @@ import requests
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -41,7 +42,7 @@ def get_scanner(hass, config):
return None
class DdWrtDeviceScanner(object):
class DdWrtDeviceScanner(DeviceScanner):
"""This class queries a wireless router running DD-WRT firmware."""
def __init__(self, config):

View File

@ -10,7 +10,8 @@ from datetime import timedelta
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -38,7 +39,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class FritzBoxScanner(object):
class FritzBoxScanner(DeviceScanner):
"""This class queries a FRITZ!Box router."""
def __init__(self, config):

View File

@ -12,7 +12,7 @@ import voluptuous as vol
from homeassistant.const import CONF_USERNAME, CONF_PASSWORD
from homeassistant.components.device_tracker import (
PLATFORM_SCHEMA, DOMAIN, ATTR_ATTRIBUTES, ENTITY_ID_FORMAT)
PLATFORM_SCHEMA, DOMAIN, ATTR_ATTRIBUTES, ENTITY_ID_FORMAT, DeviceScanner)
from homeassistant.components.zone import active_zone
from homeassistant.helpers.event import track_utc_time_change
import homeassistant.helpers.config_validation as cv
@ -131,7 +131,7 @@ def setup_scanner(hass, config: dict, see):
return True
class Icloud(object):
class Icloud(DeviceScanner):
"""Represent an icloud account in Home Assistant."""
def __init__(self, hass, username, password, name, see):

View File

@ -8,9 +8,8 @@ import asyncio
from functools import partial
import logging
from homeassistant.const import (ATTR_LATITUDE, ATTR_LONGITUDE,
STATE_NOT_HOME,
HTTP_UNPROCESSABLE_ENTITY)
from homeassistant.const import (
ATTR_LATITUDE, ATTR_LONGITUDE, STATE_NOT_HOME, HTTP_UNPROCESSABLE_ENTITY)
from homeassistant.components.http import HomeAssistantView
# pylint: disable=unused-import
from homeassistant.components.device_tracker import ( # NOQA

View File

@ -14,7 +14,8 @@ import requests
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -37,7 +38,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class LuciDeviceScanner(object):
class LuciDeviceScanner(DeviceScanner):
"""This class queries a wireless router running OpenWrt firmware.
Adapted from Tomato scanner.

View File

@ -11,7 +11,8 @@ from datetime import timedelta
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import (
CONF_HOST, CONF_PASSWORD, CONF_USERNAME, CONF_PORT)
from homeassistant.util import Throttle
@ -47,7 +48,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class NetgearDeviceScanner(object):
class NetgearDeviceScanner(DeviceScanner):
"""Queries a Netgear wireless router using the SOAP-API."""
def __init__(self, host, username, password, port):

View File

@ -14,7 +14,8 @@ import voluptuous as vol
import homeassistant.helpers.config_validation as cv
import homeassistant.util.dt as dt_util
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOSTS
from homeassistant.util import Throttle
@ -63,7 +64,7 @@ def _arp(ip_address):
return None
class NmapDeviceScanner(object):
class NmapDeviceScanner(DeviceScanner):
"""This class scans for devices using nmap."""
exclude = []

View File

@ -12,7 +12,8 @@ from datetime import timedelta
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST
from homeassistant.util import Throttle
@ -46,7 +47,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class SnmpScanner(object):
class SnmpScanner(DeviceScanner):
"""Queries any SNMP capable Access Point for connected devices."""
def __init__(self, config):

View File

@ -12,7 +12,8 @@ import requests
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST
from homeassistant.util import Throttle
@ -35,7 +36,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class SwisscomDeviceScanner(object):
class SwisscomDeviceScanner(DeviceScanner):
"""This class queries a router running Swisscom Internet-Box firmware."""
def __init__(self, config):

View File

@ -13,7 +13,8 @@ from datetime import timedelta
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -46,7 +47,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class ThomsonDeviceScanner(object):
class ThomsonDeviceScanner(DeviceScanner):
"""This class queries a router running THOMSON firmware."""
def __init__(self, config):

View File

@ -14,7 +14,8 @@ import requests
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -38,7 +39,7 @@ def get_scanner(hass, config):
return TomatoDeviceScanner(config[DOMAIN])
class TomatoDeviceScanner(object):
class TomatoDeviceScanner(DeviceScanner):
"""This class queries a wireless router running Tomato firmware."""
def __init__(self, config):

View File

@ -15,7 +15,8 @@ import requests
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -42,7 +43,7 @@ def get_scanner(hass, config):
return None
class TplinkDeviceScanner(object):
class TplinkDeviceScanner(DeviceScanner):
"""This class queries a wireless router running TP-Link firmware."""
def __init__(self, config):

View File

@ -14,7 +14,8 @@ import requests
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle
@ -37,7 +38,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None
class UbusDeviceScanner(object):
class UbusDeviceScanner(DeviceScanner):
"""
This class queries a wireless router running OpenWrt firmware.

View File

@ -10,7 +10,8 @@ import voluptuous as vol
import homeassistant.helpers.config_validation as cv
import homeassistant.loader as loader
from homeassistant.components.device_tracker import DOMAIN, PLATFORM_SCHEMA
from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner)
from homeassistant.const import CONF_HOST, CONF_USERNAME, CONF_PASSWORD
# Unifi package doesn't list urllib3 as a requirement
@ -59,7 +60,7 @@ def get_scanner(hass, config):
return UnifiScanner(ctrl)
class UnifiScanner(object):
class UnifiScanner(DeviceScanner):
"""Provide device_tracker support from Unifi WAP client data."""
def __init__(self, controller):

View File

@ -14,12 +14,9 @@ from homeassistant.helpers.event import track_point_in_utc_time
from homeassistant.util.dt import utcnow
from homeassistant.util import slugify
from homeassistant.const import (
CONF_PASSWORD,
CONF_SCAN_INTERVAL,
CONF_USERNAME)
CONF_PASSWORD, CONF_SCAN_INTERVAL, CONF_USERNAME)
from homeassistant.components.device_tracker import (
DEFAULT_SCAN_INTERVAL,
PLATFORM_SCHEMA)
DEFAULT_SCAN_INTERVAL, PLATFORM_SCHEMA)
MIN_TIME_BETWEEN_SCANS = timedelta(minutes=1)

View File

@ -315,7 +315,8 @@ class TestComponentsDeviceTracker(unittest.TestCase):
scanner = get_component('device_tracker.test').SCANNER
with patch.dict(device_tracker.DISCOVERY_PLATFORMS, {'test': 'test'}):
with patch.object(scanner, 'scan_devices') as mock_scan:
with patch.object(scanner, 'scan_devices',
autospec=True) as mock_scan:
with assert_setup_component(1, device_tracker.DOMAIN):
assert setup_component(
self.hass, device_tracker.DOMAIN, TEST_PLATFORM)

View File

@ -1,12 +1,14 @@
"""Provide a mock device scanner."""
from homeassistant.components.device_tracker import DeviceScanner
def get_scanner(hass, config):
"""Return a mock scanner."""
return SCANNER
class MockScanner(object):
class MockScanner(DeviceScanner):
"""Mock device scanner."""
def __init__(self):