Async migration device_tracker (#4406)

* Async migration device_tracker

* change location stuff to async

* address paulus comments

* fix lint & add async discovery listener

* address paulus comments v2

* fix tests

* fix test_mqtt

* fix test_init

* fix gps_acc

* fix lint

* change async_update_stale to callback
This commit is contained in:
Pascal Vizeli 2016-11-18 23:35:08 +01:00 committed by GitHub
parent 265232af98
commit c56f99baaf
7 changed files with 233 additions and 120 deletions

View File

@ -8,13 +8,13 @@ import asyncio
from datetime import timedelta
import logging
import os
import threading
from typing import Any, Sequence, Callable
import voluptuous as vol
from homeassistant.bootstrap import (
prepare_setup_platform, log_exception)
async_prepare_setup_platform, async_log_exception)
from homeassistant.core import callback
from homeassistant.components import group, zone
from homeassistant.components.discovery import SERVICE_NETGEAR
from homeassistant.config import load_yaml_config_file
@ -28,7 +28,7 @@ from homeassistant.util.async import run_coroutine_threadsafe
import homeassistant.util.dt as dt_util
from homeassistant.util.yaml import dump
from homeassistant.helpers.event import track_utc_time_change
from homeassistant.helpers.event import async_track_utc_time_change
from homeassistant.const import (
ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE,
DEVICE_DEFAULT_NAME, STATE_HOME, STATE_NOT_HOME, ATTR_ENTITY_ID)
@ -106,14 +106,15 @@ def see(hass: HomeAssistantType, mac: str=None, dev_id: str=None,
hass.services.call(DOMAIN, SERVICE_SEE, data)
def setup(hass: HomeAssistantType, config: ConfigType):
@asyncio.coroutine
def async_setup(hass: HomeAssistantType, config: ConfigType):
"""Setup device tracker."""
yaml_path = hass.config.path(YAML_DEVICES)
try:
conf = config.get(DOMAIN, [])
except vol.Invalid as ex:
log_exception(ex, DOMAIN, config, hass)
async_log_exception(ex, DOMAIN, config, hass)
return False
else:
conf = conf[0] if len(conf) > 0 else {}
@ -121,60 +122,77 @@ def setup(hass: HomeAssistantType, config: ConfigType):
timedelta(seconds=DEFAULT_CONSIDER_HOME))
track_new = conf.get(CONF_TRACK_NEW, DEFAULT_TRACK_NEW)
devices = load_config(yaml_path, hass, consider_home)
devices = yield from async_load_config(yaml_path, hass, consider_home)
tracker = DeviceTracker(hass, consider_home, track_new, devices)
def setup_platform(p_type, p_config, disc_info=None):
# update tracked devices
update_tasks = [device.async_update_ha_state() for device in devices
if device.track]
if update_tasks:
yield from asyncio.wait(update_tasks, loop=hass.loop)
@asyncio.coroutine
def async_setup_platform(p_type, p_config, disc_info=None):
"""Setup a device tracker platform."""
platform = prepare_setup_platform(hass, config, DOMAIN, p_type)
platform = yield from async_prepare_setup_platform(
hass, config, DOMAIN, p_type)
if platform is None:
return
try:
if hasattr(platform, 'get_scanner'):
scanner = platform.get_scanner(hass, {DOMAIN: p_config})
scanner = yield from hass.loop.run_in_executor(
None, platform.get_scanner, hass, {DOMAIN: p_config})
if scanner is None:
_LOGGER.error('Error setting up platform %s', p_type)
return
setup_scanner_platform(hass, p_config, scanner, tracker.see)
yield from async_setup_scanner_platform(
hass, p_config, scanner, tracker.async_see)
return
if not platform.setup_scanner(hass, p_config, tracker.see):
ret = yield from hass.loop.run_in_executor(
None, platform.setup_scanner, hass, p_config, tracker.see)
if not ret:
_LOGGER.error('Error setting up platform %s', p_type)
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error setting up platform %s', p_type)
for p_type, p_config in config_per_platform(config, DOMAIN):
setup_platform(p_type, p_config)
setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
in config_per_platform(config, DOMAIN)]
if setup_tasks:
yield from asyncio.wait(setup_tasks, loop=hass.loop)
def device_tracker_discovered(service, info):
yield from tracker.async_setup_group()
@callback
def async_device_tracker_discovered(service, info):
"""Called when a device tracker platform is discovered."""
setup_platform(DISCOVERY_PLATFORMS[service], {}, info)
hass.async_add_job(
async_setup_platform(DISCOVERY_PLATFORMS[service], {}, info))
discovery.listen(hass, DISCOVERY_PLATFORMS.keys(),
device_tracker_discovered)
discovery.async_listen(
hass, DISCOVERY_PLATFORMS.keys(), async_device_tracker_discovered)
def update_stale(now):
"""Clean up stale devices."""
tracker.update_stale(now)
track_utc_time_change(hass, update_stale, second=range(0, 60, 5))
# Clean up stale devices
async_track_utc_time_change(
hass, tracker.async_update_stale, second=range(0, 60, 5))
tracker.setup_group()
def see_service(call):
@asyncio.coroutine
def async_see_service(call):
"""Service to see a device."""
args = {key: value for key, value in call.data.items() if key in
(ATTR_MAC, ATTR_DEV_ID, ATTR_HOST_NAME, ATTR_LOCATION_NAME,
ATTR_GPS, ATTR_GPS_ACCURACY, ATTR_BATTERY, ATTR_ATTRIBUTES)}
tracker.see(**args)
yield from tracker.async_see(**args)
descriptions = load_yaml_config_file(
os.path.join(os.path.dirname(__file__), 'services.yaml'))
hass.services.register(DOMAIN, SERVICE_SEE, see_service,
descriptions.get(SERVICE_SEE))
descriptions = yield from hass.loop.run_in_executor(
None, load_yaml_config_file,
os.path.join(os.path.dirname(__file__), 'services.yaml')
)
hass.services.async_register(
DOMAIN, SERVICE_SEE, async_see_service, descriptions.get(SERVICE_SEE))
return True
@ -188,94 +206,116 @@ class DeviceTracker(object):
self.hass = hass
self.devices = {dev.dev_id: dev for dev in devices}
self.mac_to_dev = {dev.mac: dev for dev in devices if dev.mac}
self.consider_home = consider_home
self.track_new = track_new
self.group = None # type: group.Group
self._is_updating = asyncio.Lock(loop=hass.loop)
for dev in devices:
if self.devices[dev.dev_id] is not dev:
_LOGGER.warning('Duplicate device IDs detected %s', dev.dev_id)
if dev.mac and self.mac_to_dev[dev.mac] is not dev:
_LOGGER.warning('Duplicate device MAC addresses detected %s',
dev.mac)
self.consider_home = consider_home
self.track_new = track_new
self.lock = threading.Lock()
for device in devices:
if device.track:
device.update_ha_state()
self.group = None # type: group.Group
def see(self, mac: str=None, dev_id: str=None, host_name: str=None,
location_name: str=None, gps: GPSType=None, gps_accuracy=None,
battery: str=None, attributes: dict=None):
"""Notify the device tracker that you see a device."""
with self.lock:
if mac is None and dev_id is None:
raise HomeAssistantError('Neither mac or device id passed in')
elif mac is not None:
mac = str(mac).upper()
device = self.mac_to_dev.get(mac)
if not device:
dev_id = util.slugify(host_name or '') or util.slugify(mac)
else:
dev_id = cv.slug(str(dev_id).lower())
device = self.devices.get(dev_id)
self.hass.add_job(
self.async_see(mac, dev_id, host_name, location_name, gps,
gps_accuracy, battery, attributes)
)
if device:
device.seen(host_name, location_name, gps, gps_accuracy,
battery, attributes)
if device.track:
device.update_ha_state()
return
@asyncio.coroutine
def async_see(self, mac: str=None, dev_id: str=None, host_name: str=None,
location_name: str=None, gps: GPSType=None,
gps_accuracy=None, battery: str=None, attributes: dict=None):
"""Notify the device tracker that you see a device.
# If no device can be found, create it
dev_id = util.ensure_unique_string(dev_id, self.devices.keys())
device = Device(
self.hass, self.consider_home, self.track_new,
dev_id, mac, (host_name or dev_id).replace('_', ' '))
self.devices[dev_id] = device
if mac is not None:
self.mac_to_dev[mac] = device
device.seen(host_name, location_name, gps, gps_accuracy, battery,
attributes)
This method is a coroutine.
"""
if mac is None and dev_id is None:
raise HomeAssistantError('Neither mac or device id passed in')
elif mac is not None:
mac = str(mac).upper()
device = self.mac_to_dev.get(mac)
if not device:
dev_id = util.slugify(host_name or '') or util.slugify(mac)
else:
dev_id = cv.slug(str(dev_id).lower())
device = self.devices.get(dev_id)
if device:
yield from device.async_seen(host_name, location_name, gps,
gps_accuracy, battery, attributes)
if device.track:
device.update_ha_state()
yield from device.async_update_ha_state()
return
self.hass.bus.fire(EVENT_NEW_DEVICE, {
ATTR_ENTITY_ID: device.entity_id,
ATTR_HOST_NAME: device.host_name,
})
# If no device can be found, create it
dev_id = util.ensure_unique_string(dev_id, self.devices.keys())
device = Device(
self.hass, self.consider_home, self.track_new,
dev_id, mac, (host_name or dev_id).replace('_', ' '))
self.devices[dev_id] = device
if mac is not None:
self.mac_to_dev[mac] = device
# During init, we ignore the group
if self.group is not None:
self.group.update_tracked_entity_ids(
list(self.group.tracking) + [device.entity_id])
update_config(self.hass.config.path(YAML_DEVICES), dev_id, device)
yield from device.async_seen(host_name, location_name, gps,
gps_accuracy, battery, attributes)
def setup_group(self):
"""Initialize group for all tracked devices."""
run_coroutine_threadsafe(
self.async_setup_group(), self.hass.loop).result()
if device.track:
yield from device.async_update_ha_state()
self.hass.bus.async_fire(EVENT_NEW_DEVICE, {
ATTR_ENTITY_ID: device.entity_id,
ATTR_HOST_NAME: device.host_name,
})
# During init, we ignore the group
if self.group is not None:
yield from self.group.async_update_tracked_entity_ids(
list(self.group.tracking) + [device.entity_id])
# update known_devices.yaml
self.hass.async_add_job(
self.async_update_config(self.hass.config.path(YAML_DEVICES),
dev_id, device)
)
@asyncio.coroutine
def async_update_config(self, path, dev_id, device):
"""Add device to YAML configuration file.
This method is a coroutine.
"""
with (yield from self._is_updating):
self.hass.loop.run_in_executor(
None, update_config, self.hass.config.path(YAML_DEVICES),
dev_id, device)
@asyncio.coroutine
def async_setup_group(self):
"""Initialize group for all tracked devices.
This method must be run in the event loop.
This method is a coroutine.
"""
entity_ids = (dev.entity_id for dev in self.devices.values()
if dev.track)
self.group = yield from group.Group.async_create_group(
self.hass, GROUP_NAME_ALL_DEVICES, entity_ids, False)
def update_stale(self, now: dt_util.dt.datetime):
"""Update stale devices."""
with self.lock:
for device in self.devices.values():
if (device.track and device.last_update_home and
device.stale(now)):
device.update_ha_state(True)
@callback
def async_update_stale(self, now: dt_util.dt.datetime):
"""Update stale devices.
This method must be run in the event loop.
"""
for device in self.devices.values():
if (device.track and device.last_update_home) and \
device.stale(now):
self.hass.async_add_job(device.async_update_ha_state(True))
class Device(Entity):
@ -362,9 +402,10 @@ class Device(Entity):
"""If device should be hidden."""
return self.away_hide and self.state != STATE_HOME
def seen(self, host_name: str=None, location_name: str=None,
gps: GPSType=None, gps_accuracy=0, battery: str=None,
attributes: dict=None):
@asyncio.coroutine
def async_seen(self, host_name: str=None, location_name: str=None,
gps: GPSType=None, gps_accuracy=0, battery: str=None,
attributes: dict=None):
"""Mark the device as seen."""
self.last_seen = dt_util.utcnow()
self.host_name = host_name
@ -373,28 +414,38 @@ class Device(Entity):
self.battery = battery
self.attributes = attributes
self.gps = None
if gps is not None:
try:
self.gps = float(gps[0]), float(gps[1])
except (ValueError, TypeError, IndexError):
_LOGGER.warning('Could not parse gps value for %s: %s',
self.dev_id, gps)
self.update()
# pylint: disable=not-an-iterable
yield from self.async_update()
def stale(self, now: dt_util.dt.datetime=None):
"""Return if device state is stale."""
"""Return if device state is stale.
Async friendly.
"""
return self.last_seen and \
(now or dt_util.utcnow()) - self.last_seen > self.consider_home
def update(self):
"""Update state of entity."""
@asyncio.coroutine
def async_update(self):
"""Update state of entity.
This method is a coroutine.
"""
if not self.last_seen:
return
elif self.location_name:
self._state = self.location_name
elif self.gps is not None:
zone_state = zone.active_zone(self.hass, self.gps[0], self.gps[1],
self.gps_accuracy)
zone_state = zone.async_active_zone(
self.hass, self.gps[0], self.gps[1], self.gps_accuracy)
if zone_state is None:
self._state = STATE_NOT_HOME
elif zone_state.entity_id == zone.ENTITY_ID_HOME:
@ -412,6 +463,17 @@ class Device(Entity):
def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
"""Load devices from YAML configuration file."""
return run_coroutine_threadsafe(
async_load_config(path, hass, consider_home), hass.loop).result()
@asyncio.coroutine
def async_load_config(path: str, hass: HomeAssistantType,
consider_home: timedelta):
"""Load devices from YAML configuration file.
This method is a coroutine.
"""
dev_schema = vol.Schema({
vol.Required('name'): cv.string,
vol.Optional('track', default=False): cv.boolean,
@ -426,7 +488,8 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
try:
result = []
try:
devices = load_yaml_config_file(path)
devices = yield from hass.loop.run_in_executor(
None, load_yaml_config_file, path)
except HomeAssistantError as err:
_LOGGER.error('Unable to load %s: %s', path, str(err))
return []
@ -436,7 +499,7 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
device = dev_schema(device)
device['dev_id'] = cv.slugify(dev_id)
except vol.Invalid as exp:
log_exception(exp, dev_id, devices, hass)
async_log_exception(exp, dev_id, devices, hass)
else:
result.append(Device(hass, **device))
return result
@ -445,9 +508,13 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
return []
def setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
scanner: Any, see_device: Callable):
"""Helper method to connect scanner-based platform to device tracker."""
@asyncio.coroutine
def async_setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
scanner: Any, async_see_device: Callable):
"""Helper method to connect scanner-based platform to device tracker.
This method is a coroutine.
"""
interval = config.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
# Initial scan of each mac we also tell about host name for config
@ -455,18 +522,20 @@ def setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
def device_tracker_scan(now: dt_util.dt.datetime):
"""Called when interval matches."""
for mac in scanner.scan_devices():
found_devices = scanner.scan_devices()
for mac in found_devices:
if mac in seen:
host_name = None
else:
host_name = scanner.get_device_name(mac)
seen.add(mac)
see_device(mac=mac, host_name=host_name)
hass.async_add_job(async_see_device(mac=mac, host_name=host_name))
track_utc_time_change(hass, device_tracker_scan, second=range(0, 60,
interval))
async_track_utc_time_change(
hass, device_tracker_scan, second=range(0, 60, interval))
device_tracker_scan(None)
hass.async_add_job(device_tracker_scan, None)
def update_config(path: str, dev_id: str, device: Device):
@ -484,7 +553,10 @@ def update_config(path: str, dev_id: str, device: Device):
def get_gravatar_for_email(email: str):
"""Return an 80px Gravatar for the given email address."""
"""Return an 80px Gravatar for the given email address.
Async friendly.
"""
import hashlib
url = 'https://www.gravatar.com/avatar/{}.jpg?s=80&d=wavatar'
return url.format(hashlib.md5(email.encode('utf-8').lower()).hexdigest())

View File

@ -14,6 +14,7 @@ from homeassistant.const import (
CONF_LONGITUDE, CONF_ICON)
from homeassistant.helpers import config_per_platform
from homeassistant.helpers.entity import Entity, async_generate_entity_id
from homeassistant.util.async import run_callback_threadsafe
from homeassistant.util.location import distance
import homeassistant.helpers.config_validation as cv
@ -51,9 +52,19 @@ PLATFORM_SCHEMA = vol.Schema({
def active_zone(hass, latitude, longitude, radius=0):
"""Find the active zone for given latitude, longitude."""
return run_callback_threadsafe(
hass.loop, async_active_zone, hass, latitude, longitude, radius
).result()
def async_active_zone(hass, latitude, longitude, radius=0):
"""Find the active zone for given latitude, longitude.
This method must be run in the event loop.
"""
# Sort entity IDs so that we are deterministic if equal distance to 2 zones
zones = (hass.states.get(entity_id) for entity_id
in sorted(hass.states.entity_ids(DOMAIN)))
in sorted(hass.states.async_entity_ids(DOMAIN)))
min_dist = None
closest = None
@ -80,7 +91,10 @@ def active_zone(hass, latitude, longitude, radius=0):
def in_zone(zone, latitude, longitude, radius=0):
"""Test if given latitude, longitude is in given zone."""
"""Test if given latitude, longitude is in given zone.
Async friendly.
"""
zone_dist = distance(
latitude, longitude,
zone.attributes[ATTR_LATITUDE], zone.attributes[ATTR_LONGITUDE])

View File

@ -14,6 +14,16 @@ ATTR_PLATFORM = 'platform'
def listen(hass, service, callback):
"""Setup listener for discovery of specific service.
Service can be a string or a list/tuple.
"""
run_callback_threadsafe(
hass.loop, async_listen, hass, service, callback).result()
@core.callback
def async_listen(hass, service, callback):
"""Setup listener for discovery of specific service.
Service can be a string or a list/tuple.
"""
if isinstance(service, str):
@ -21,12 +31,14 @@ def listen(hass, service, callback):
else:
service = tuple(service)
@core.callback
def discovery_event_listener(event):
"""Listen for discovery events."""
if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service:
callback(event.data[ATTR_SERVICE], event.data.get(ATTR_DISCOVERED))
hass.async_add_job(callback, event.data[ATTR_SERVICE],
event.data.get(ATTR_DISCOVERED))
hass.bus.listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener)
hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener)
def discover(hass, service, discovered=None, component=None, hass_config=None):

View File

@ -8,7 +8,10 @@ from homeassistant.util import location as loc_util
def has_location(state: State) -> bool:
"""Test if state contains a valid location."""
"""Test if state contains a valid location.
Async friendly.
"""
return (isinstance(state, State) and
isinstance(state.attributes.get(ATTR_LATITUDE), float) and
isinstance(state.attributes.get(ATTR_LONGITUDE), float))
@ -16,7 +19,10 @@ def has_location(state: State) -> bool:
def closest(latitude: float, longitude: float,
states: Sequence[State]) -> State:
"""Return closest state to point."""
"""Return closest state to point.
Async friendly.
"""
with_location = [state for state in states if has_location(state)]
if not with_location:

View File

@ -51,7 +51,10 @@ def detect_location_info():
def distance(lat1, lon1, lat2, lon2):
"""Calculate the distance in meters between two points."""
"""Calculate the distance in meters between two points.
Async friendly.
"""
return vincenty((lat1, lon1), (lat2, lon2)) * 1000
@ -88,6 +91,8 @@ def vincenty(point1: Tuple[float, float], point2: Tuple[float, float],
Result in kilometers or miles between two points on the surface of a
spheroid.
Async friendly.
"""
# short-circuit coincident points
if point1[0] == point2[0] and point1[1] == point2[1]:

View File

@ -10,6 +10,7 @@ import os
from homeassistant.core import callback
from homeassistant.bootstrap import setup_component
from homeassistant.loader import get_component
from homeassistant.util.async import run_coroutine_threadsafe
import homeassistant.util.dt as dt_util
from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN,
@ -280,7 +281,7 @@ class TestComponentsDeviceTracker(unittest.TestCase):
self.assertSequenceEqual((entity_id,),
state.attributes.get(ATTR_ENTITY_ID))
@patch('homeassistant.components.device_tracker.DeviceTracker.see')
@patch('homeassistant.components.device_tracker.DeviceTracker.async_see')
def test_see_service(self, mock_see):
"""Test the see service with a unicode dev_id and NO MAC."""
self.assertTrue(setup_component(self.hass, device_tracker.DOMAIN,
@ -375,20 +376,22 @@ class TestComponentsDeviceTracker(unittest.TestCase):
# No device id or MAC(not added)
with self.assertRaises(HomeAssistantError):
tracker.see()
run_coroutine_threadsafe(
tracker.async_see(), self.hass.loop).result()
assert mock_warning.call_count == 0
# Ignore gps on invalid GPS (both added & warnings)
tracker.see(mac='mac_1_bad_gps', gps=1)
tracker.see(mac='mac_2_bad_gps', gps=[1])
tracker.see(mac='mac_3_bad_gps', gps='gps')
self.hass.block_till_done()
config = device_tracker.load_config(self.yaml_devices, self.hass,
timedelta(seconds=0))
assert mock_warning.call_count == 3
assert len(config) == 4
@patch('homeassistant.components.device_tracker.log_exception')
@patch('homeassistant.components.device_tracker.async_log_exception')
def test_config_failure(self, mock_ex):
"""Test that the device tracker see failures."""
with assert_setup_component(0, device_tracker.DOMAIN):

View File

@ -37,7 +37,8 @@ class TestComponentsDeviceTrackerMQTT(unittest.TestCase):
self.assertTrue('qos' in config)
with patch('homeassistant.components.device_tracker.mqtt.'
'setup_scanner', side_effect=mock_setup_scanner) as mock_sp:
'setup_scanner', autospec=True,
side_effect=mock_setup_scanner) as mock_sp:
dev_id = 'paulus'
topic = '/location/paulus'