mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Async migration device_tracker (#4406)
* Async migration device_tracker * change location stuff to async * address paulus comments * fix lint & add async discovery listener * address paulus comments v2 * fix tests * fix test_mqtt * fix test_init * fix gps_acc * fix lint * change async_update_stale to callback
This commit is contained in:
parent
265232af98
commit
c56f99baaf
@ -8,13 +8,13 @@ import asyncio
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Sequence, Callable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.bootstrap import (
|
||||
prepare_setup_platform, log_exception)
|
||||
async_prepare_setup_platform, async_log_exception)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import group, zone
|
||||
from homeassistant.components.discovery import SERVICE_NETGEAR
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
@ -28,7 +28,7 @@ from homeassistant.util.async import run_coroutine_threadsafe
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.yaml import dump
|
||||
|
||||
from homeassistant.helpers.event import track_utc_time_change
|
||||
from homeassistant.helpers.event import async_track_utc_time_change
|
||||
from homeassistant.const import (
|
||||
ATTR_GPS_ACCURACY, ATTR_LATITUDE, ATTR_LONGITUDE,
|
||||
DEVICE_DEFAULT_NAME, STATE_HOME, STATE_NOT_HOME, ATTR_ENTITY_ID)
|
||||
@ -106,14 +106,15 @@ def see(hass: HomeAssistantType, mac: str=None, dev_id: str=None,
|
||||
hass.services.call(DOMAIN, SERVICE_SEE, data)
|
||||
|
||||
|
||||
def setup(hass: HomeAssistantType, config: ConfigType):
|
||||
@asyncio.coroutine
|
||||
def async_setup(hass: HomeAssistantType, config: ConfigType):
|
||||
"""Setup device tracker."""
|
||||
yaml_path = hass.config.path(YAML_DEVICES)
|
||||
|
||||
try:
|
||||
conf = config.get(DOMAIN, [])
|
||||
except vol.Invalid as ex:
|
||||
log_exception(ex, DOMAIN, config, hass)
|
||||
async_log_exception(ex, DOMAIN, config, hass)
|
||||
return False
|
||||
else:
|
||||
conf = conf[0] if len(conf) > 0 else {}
|
||||
@ -121,60 +122,77 @@ def setup(hass: HomeAssistantType, config: ConfigType):
|
||||
timedelta(seconds=DEFAULT_CONSIDER_HOME))
|
||||
track_new = conf.get(CONF_TRACK_NEW, DEFAULT_TRACK_NEW)
|
||||
|
||||
devices = load_config(yaml_path, hass, consider_home)
|
||||
|
||||
devices = yield from async_load_config(yaml_path, hass, consider_home)
|
||||
tracker = DeviceTracker(hass, consider_home, track_new, devices)
|
||||
|
||||
def setup_platform(p_type, p_config, disc_info=None):
|
||||
# update tracked devices
|
||||
update_tasks = [device.async_update_ha_state() for device in devices
|
||||
if device.track]
|
||||
if update_tasks:
|
||||
yield from asyncio.wait(update_tasks, loop=hass.loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup_platform(p_type, p_config, disc_info=None):
|
||||
"""Setup a device tracker platform."""
|
||||
platform = prepare_setup_platform(hass, config, DOMAIN, p_type)
|
||||
platform = yield from async_prepare_setup_platform(
|
||||
hass, config, DOMAIN, p_type)
|
||||
if platform is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if hasattr(platform, 'get_scanner'):
|
||||
scanner = platform.get_scanner(hass, {DOMAIN: p_config})
|
||||
scanner = yield from hass.loop.run_in_executor(
|
||||
None, platform.get_scanner, hass, {DOMAIN: p_config})
|
||||
|
||||
if scanner is None:
|
||||
_LOGGER.error('Error setting up platform %s', p_type)
|
||||
return
|
||||
|
||||
setup_scanner_platform(hass, p_config, scanner, tracker.see)
|
||||
yield from async_setup_scanner_platform(
|
||||
hass, p_config, scanner, tracker.async_see)
|
||||
return
|
||||
|
||||
if not platform.setup_scanner(hass, p_config, tracker.see):
|
||||
ret = yield from hass.loop.run_in_executor(
|
||||
None, platform.setup_scanner, hass, p_config, tracker.see)
|
||||
if not ret:
|
||||
_LOGGER.error('Error setting up platform %s', p_type)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error setting up platform %s', p_type)
|
||||
|
||||
for p_type, p_config in config_per_platform(config, DOMAIN):
|
||||
setup_platform(p_type, p_config)
|
||||
setup_tasks = [async_setup_platform(p_type, p_config) for p_type, p_config
|
||||
in config_per_platform(config, DOMAIN)]
|
||||
if setup_tasks:
|
||||
yield from asyncio.wait(setup_tasks, loop=hass.loop)
|
||||
|
||||
def device_tracker_discovered(service, info):
|
||||
yield from tracker.async_setup_group()
|
||||
|
||||
@callback
|
||||
def async_device_tracker_discovered(service, info):
|
||||
"""Called when a device tracker platform is discovered."""
|
||||
setup_platform(DISCOVERY_PLATFORMS[service], {}, info)
|
||||
hass.async_add_job(
|
||||
async_setup_platform(DISCOVERY_PLATFORMS[service], {}, info))
|
||||
|
||||
discovery.listen(hass, DISCOVERY_PLATFORMS.keys(),
|
||||
device_tracker_discovered)
|
||||
discovery.async_listen(
|
||||
hass, DISCOVERY_PLATFORMS.keys(), async_device_tracker_discovered)
|
||||
|
||||
def update_stale(now):
|
||||
"""Clean up stale devices."""
|
||||
tracker.update_stale(now)
|
||||
track_utc_time_change(hass, update_stale, second=range(0, 60, 5))
|
||||
# Clean up stale devices
|
||||
async_track_utc_time_change(
|
||||
hass, tracker.async_update_stale, second=range(0, 60, 5))
|
||||
|
||||
tracker.setup_group()
|
||||
|
||||
def see_service(call):
|
||||
@asyncio.coroutine
|
||||
def async_see_service(call):
|
||||
"""Service to see a device."""
|
||||
args = {key: value for key, value in call.data.items() if key in
|
||||
(ATTR_MAC, ATTR_DEV_ID, ATTR_HOST_NAME, ATTR_LOCATION_NAME,
|
||||
ATTR_GPS, ATTR_GPS_ACCURACY, ATTR_BATTERY, ATTR_ATTRIBUTES)}
|
||||
tracker.see(**args)
|
||||
yield from tracker.async_see(**args)
|
||||
|
||||
descriptions = load_yaml_config_file(
|
||||
os.path.join(os.path.dirname(__file__), 'services.yaml'))
|
||||
hass.services.register(DOMAIN, SERVICE_SEE, see_service,
|
||||
descriptions.get(SERVICE_SEE))
|
||||
descriptions = yield from hass.loop.run_in_executor(
|
||||
None, load_yaml_config_file,
|
||||
os.path.join(os.path.dirname(__file__), 'services.yaml')
|
||||
)
|
||||
hass.services.async_register(
|
||||
DOMAIN, SERVICE_SEE, async_see_service, descriptions.get(SERVICE_SEE))
|
||||
|
||||
return True
|
||||
|
||||
@ -188,94 +206,116 @@ class DeviceTracker(object):
|
||||
self.hass = hass
|
||||
self.devices = {dev.dev_id: dev for dev in devices}
|
||||
self.mac_to_dev = {dev.mac: dev for dev in devices if dev.mac}
|
||||
self.consider_home = consider_home
|
||||
self.track_new = track_new
|
||||
self.group = None # type: group.Group
|
||||
self._is_updating = asyncio.Lock(loop=hass.loop)
|
||||
|
||||
for dev in devices:
|
||||
if self.devices[dev.dev_id] is not dev:
|
||||
_LOGGER.warning('Duplicate device IDs detected %s', dev.dev_id)
|
||||
if dev.mac and self.mac_to_dev[dev.mac] is not dev:
|
||||
_LOGGER.warning('Duplicate device MAC addresses detected %s',
|
||||
dev.mac)
|
||||
self.consider_home = consider_home
|
||||
self.track_new = track_new
|
||||
self.lock = threading.Lock()
|
||||
|
||||
for device in devices:
|
||||
if device.track:
|
||||
device.update_ha_state()
|
||||
|
||||
self.group = None # type: group.Group
|
||||
|
||||
def see(self, mac: str=None, dev_id: str=None, host_name: str=None,
|
||||
location_name: str=None, gps: GPSType=None, gps_accuracy=None,
|
||||
battery: str=None, attributes: dict=None):
|
||||
"""Notify the device tracker that you see a device."""
|
||||
with self.lock:
|
||||
if mac is None and dev_id is None:
|
||||
raise HomeAssistantError('Neither mac or device id passed in')
|
||||
elif mac is not None:
|
||||
mac = str(mac).upper()
|
||||
device = self.mac_to_dev.get(mac)
|
||||
if not device:
|
||||
dev_id = util.slugify(host_name or '') or util.slugify(mac)
|
||||
else:
|
||||
dev_id = cv.slug(str(dev_id).lower())
|
||||
device = self.devices.get(dev_id)
|
||||
self.hass.add_job(
|
||||
self.async_see(mac, dev_id, host_name, location_name, gps,
|
||||
gps_accuracy, battery, attributes)
|
||||
)
|
||||
|
||||
if device:
|
||||
device.seen(host_name, location_name, gps, gps_accuracy,
|
||||
battery, attributes)
|
||||
if device.track:
|
||||
device.update_ha_state()
|
||||
return
|
||||
@asyncio.coroutine
|
||||
def async_see(self, mac: str=None, dev_id: str=None, host_name: str=None,
|
||||
location_name: str=None, gps: GPSType=None,
|
||||
gps_accuracy=None, battery: str=None, attributes: dict=None):
|
||||
"""Notify the device tracker that you see a device.
|
||||
|
||||
# If no device can be found, create it
|
||||
dev_id = util.ensure_unique_string(dev_id, self.devices.keys())
|
||||
device = Device(
|
||||
self.hass, self.consider_home, self.track_new,
|
||||
dev_id, mac, (host_name or dev_id).replace('_', ' '))
|
||||
self.devices[dev_id] = device
|
||||
if mac is not None:
|
||||
self.mac_to_dev[mac] = device
|
||||
|
||||
device.seen(host_name, location_name, gps, gps_accuracy, battery,
|
||||
attributes)
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if mac is None and dev_id is None:
|
||||
raise HomeAssistantError('Neither mac or device id passed in')
|
||||
elif mac is not None:
|
||||
mac = str(mac).upper()
|
||||
device = self.mac_to_dev.get(mac)
|
||||
if not device:
|
||||
dev_id = util.slugify(host_name or '') or util.slugify(mac)
|
||||
else:
|
||||
dev_id = cv.slug(str(dev_id).lower())
|
||||
device = self.devices.get(dev_id)
|
||||
|
||||
if device:
|
||||
yield from device.async_seen(host_name, location_name, gps,
|
||||
gps_accuracy, battery, attributes)
|
||||
if device.track:
|
||||
device.update_ha_state()
|
||||
yield from device.async_update_ha_state()
|
||||
return
|
||||
|
||||
self.hass.bus.fire(EVENT_NEW_DEVICE, {
|
||||
ATTR_ENTITY_ID: device.entity_id,
|
||||
ATTR_HOST_NAME: device.host_name,
|
||||
})
|
||||
# If no device can be found, create it
|
||||
dev_id = util.ensure_unique_string(dev_id, self.devices.keys())
|
||||
device = Device(
|
||||
self.hass, self.consider_home, self.track_new,
|
||||
dev_id, mac, (host_name or dev_id).replace('_', ' '))
|
||||
self.devices[dev_id] = device
|
||||
if mac is not None:
|
||||
self.mac_to_dev[mac] = device
|
||||
|
||||
# During init, we ignore the group
|
||||
if self.group is not None:
|
||||
self.group.update_tracked_entity_ids(
|
||||
list(self.group.tracking) + [device.entity_id])
|
||||
update_config(self.hass.config.path(YAML_DEVICES), dev_id, device)
|
||||
yield from device.async_seen(host_name, location_name, gps,
|
||||
gps_accuracy, battery, attributes)
|
||||
|
||||
def setup_group(self):
|
||||
"""Initialize group for all tracked devices."""
|
||||
run_coroutine_threadsafe(
|
||||
self.async_setup_group(), self.hass.loop).result()
|
||||
if device.track:
|
||||
yield from device.async_update_ha_state()
|
||||
|
||||
self.hass.bus.async_fire(EVENT_NEW_DEVICE, {
|
||||
ATTR_ENTITY_ID: device.entity_id,
|
||||
ATTR_HOST_NAME: device.host_name,
|
||||
})
|
||||
|
||||
# During init, we ignore the group
|
||||
if self.group is not None:
|
||||
yield from self.group.async_update_tracked_entity_ids(
|
||||
list(self.group.tracking) + [device.entity_id])
|
||||
|
||||
# update known_devices.yaml
|
||||
self.hass.async_add_job(
|
||||
self.async_update_config(self.hass.config.path(YAML_DEVICES),
|
||||
dev_id, device)
|
||||
)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_update_config(self, path, dev_id, device):
|
||||
"""Add device to YAML configuration file.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
with (yield from self._is_updating):
|
||||
self.hass.loop.run_in_executor(
|
||||
None, update_config, self.hass.config.path(YAML_DEVICES),
|
||||
dev_id, device)
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_setup_group(self):
|
||||
"""Initialize group for all tracked devices.
|
||||
|
||||
This method must be run in the event loop.
|
||||
This method is a coroutine.
|
||||
"""
|
||||
entity_ids = (dev.entity_id for dev in self.devices.values()
|
||||
if dev.track)
|
||||
self.group = yield from group.Group.async_create_group(
|
||||
self.hass, GROUP_NAME_ALL_DEVICES, entity_ids, False)
|
||||
|
||||
def update_stale(self, now: dt_util.dt.datetime):
|
||||
"""Update stale devices."""
|
||||
with self.lock:
|
||||
for device in self.devices.values():
|
||||
if (device.track and device.last_update_home and
|
||||
device.stale(now)):
|
||||
device.update_ha_state(True)
|
||||
@callback
|
||||
def async_update_stale(self, now: dt_util.dt.datetime):
|
||||
"""Update stale devices.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
for device in self.devices.values():
|
||||
if (device.track and device.last_update_home) and \
|
||||
device.stale(now):
|
||||
self.hass.async_add_job(device.async_update_ha_state(True))
|
||||
|
||||
|
||||
class Device(Entity):
|
||||
@ -362,9 +402,10 @@ class Device(Entity):
|
||||
"""If device should be hidden."""
|
||||
return self.away_hide and self.state != STATE_HOME
|
||||
|
||||
def seen(self, host_name: str=None, location_name: str=None,
|
||||
gps: GPSType=None, gps_accuracy=0, battery: str=None,
|
||||
attributes: dict=None):
|
||||
@asyncio.coroutine
|
||||
def async_seen(self, host_name: str=None, location_name: str=None,
|
||||
gps: GPSType=None, gps_accuracy=0, battery: str=None,
|
||||
attributes: dict=None):
|
||||
"""Mark the device as seen."""
|
||||
self.last_seen = dt_util.utcnow()
|
||||
self.host_name = host_name
|
||||
@ -373,28 +414,38 @@ class Device(Entity):
|
||||
self.battery = battery
|
||||
self.attributes = attributes
|
||||
self.gps = None
|
||||
|
||||
if gps is not None:
|
||||
try:
|
||||
self.gps = float(gps[0]), float(gps[1])
|
||||
except (ValueError, TypeError, IndexError):
|
||||
_LOGGER.warning('Could not parse gps value for %s: %s',
|
||||
self.dev_id, gps)
|
||||
self.update()
|
||||
|
||||
# pylint: disable=not-an-iterable
|
||||
yield from self.async_update()
|
||||
|
||||
def stale(self, now: dt_util.dt.datetime=None):
|
||||
"""Return if device state is stale."""
|
||||
"""Return if device state is stale.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
return self.last_seen and \
|
||||
(now or dt_util.utcnow()) - self.last_seen > self.consider_home
|
||||
|
||||
def update(self):
|
||||
"""Update state of entity."""
|
||||
@asyncio.coroutine
|
||||
def async_update(self):
|
||||
"""Update state of entity.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
if not self.last_seen:
|
||||
return
|
||||
elif self.location_name:
|
||||
self._state = self.location_name
|
||||
elif self.gps is not None:
|
||||
zone_state = zone.active_zone(self.hass, self.gps[0], self.gps[1],
|
||||
self.gps_accuracy)
|
||||
zone_state = zone.async_active_zone(
|
||||
self.hass, self.gps[0], self.gps[1], self.gps_accuracy)
|
||||
if zone_state is None:
|
||||
self._state = STATE_NOT_HOME
|
||||
elif zone_state.entity_id == zone.ENTITY_ID_HOME:
|
||||
@ -412,6 +463,17 @@ class Device(Entity):
|
||||
|
||||
def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
|
||||
"""Load devices from YAML configuration file."""
|
||||
return run_coroutine_threadsafe(
|
||||
async_load_config(path, hass, consider_home), hass.loop).result()
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_load_config(path: str, hass: HomeAssistantType,
|
||||
consider_home: timedelta):
|
||||
"""Load devices from YAML configuration file.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
dev_schema = vol.Schema({
|
||||
vol.Required('name'): cv.string,
|
||||
vol.Optional('track', default=False): cv.boolean,
|
||||
@ -426,7 +488,8 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
|
||||
try:
|
||||
result = []
|
||||
try:
|
||||
devices = load_yaml_config_file(path)
|
||||
devices = yield from hass.loop.run_in_executor(
|
||||
None, load_yaml_config_file, path)
|
||||
except HomeAssistantError as err:
|
||||
_LOGGER.error('Unable to load %s: %s', path, str(err))
|
||||
return []
|
||||
@ -436,7 +499,7 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
|
||||
device = dev_schema(device)
|
||||
device['dev_id'] = cv.slugify(dev_id)
|
||||
except vol.Invalid as exp:
|
||||
log_exception(exp, dev_id, devices, hass)
|
||||
async_log_exception(exp, dev_id, devices, hass)
|
||||
else:
|
||||
result.append(Device(hass, **device))
|
||||
return result
|
||||
@ -445,9 +508,13 @@ def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
|
||||
return []
|
||||
|
||||
|
||||
def setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
|
||||
scanner: Any, see_device: Callable):
|
||||
"""Helper method to connect scanner-based platform to device tracker."""
|
||||
@asyncio.coroutine
|
||||
def async_setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
|
||||
scanner: Any, async_see_device: Callable):
|
||||
"""Helper method to connect scanner-based platform to device tracker.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
interval = config.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
|
||||
|
||||
# Initial scan of each mac we also tell about host name for config
|
||||
@ -455,18 +522,20 @@ def setup_scanner_platform(hass: HomeAssistantType, config: ConfigType,
|
||||
|
||||
def device_tracker_scan(now: dt_util.dt.datetime):
|
||||
"""Called when interval matches."""
|
||||
for mac in scanner.scan_devices():
|
||||
found_devices = scanner.scan_devices()
|
||||
|
||||
for mac in found_devices:
|
||||
if mac in seen:
|
||||
host_name = None
|
||||
else:
|
||||
host_name = scanner.get_device_name(mac)
|
||||
seen.add(mac)
|
||||
see_device(mac=mac, host_name=host_name)
|
||||
hass.async_add_job(async_see_device(mac=mac, host_name=host_name))
|
||||
|
||||
track_utc_time_change(hass, device_tracker_scan, second=range(0, 60,
|
||||
interval))
|
||||
async_track_utc_time_change(
|
||||
hass, device_tracker_scan, second=range(0, 60, interval))
|
||||
|
||||
device_tracker_scan(None)
|
||||
hass.async_add_job(device_tracker_scan, None)
|
||||
|
||||
|
||||
def update_config(path: str, dev_id: str, device: Device):
|
||||
@ -484,7 +553,10 @@ def update_config(path: str, dev_id: str, device: Device):
|
||||
|
||||
|
||||
def get_gravatar_for_email(email: str):
|
||||
"""Return an 80px Gravatar for the given email address."""
|
||||
"""Return an 80px Gravatar for the given email address.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
import hashlib
|
||||
url = 'https://www.gravatar.com/avatar/{}.jpg?s=80&d=wavatar'
|
||||
return url.format(hashlib.md5(email.encode('utf-8').lower()).hexdigest())
|
||||
|
@ -14,6 +14,7 @@ from homeassistant.const import (
|
||||
CONF_LONGITUDE, CONF_ICON)
|
||||
from homeassistant.helpers import config_per_platform
|
||||
from homeassistant.helpers.entity import Entity, async_generate_entity_id
|
||||
from homeassistant.util.async import run_callback_threadsafe
|
||||
from homeassistant.util.location import distance
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
@ -51,9 +52,19 @@ PLATFORM_SCHEMA = vol.Schema({
|
||||
|
||||
def active_zone(hass, latitude, longitude, radius=0):
|
||||
"""Find the active zone for given latitude, longitude."""
|
||||
return run_callback_threadsafe(
|
||||
hass.loop, async_active_zone, hass, latitude, longitude, radius
|
||||
).result()
|
||||
|
||||
|
||||
def async_active_zone(hass, latitude, longitude, radius=0):
|
||||
"""Find the active zone for given latitude, longitude.
|
||||
|
||||
This method must be run in the event loop.
|
||||
"""
|
||||
# Sort entity IDs so that we are deterministic if equal distance to 2 zones
|
||||
zones = (hass.states.get(entity_id) for entity_id
|
||||
in sorted(hass.states.entity_ids(DOMAIN)))
|
||||
in sorted(hass.states.async_entity_ids(DOMAIN)))
|
||||
|
||||
min_dist = None
|
||||
closest = None
|
||||
@ -80,7 +91,10 @@ def active_zone(hass, latitude, longitude, radius=0):
|
||||
|
||||
|
||||
def in_zone(zone, latitude, longitude, radius=0):
|
||||
"""Test if given latitude, longitude is in given zone."""
|
||||
"""Test if given latitude, longitude is in given zone.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
zone_dist = distance(
|
||||
latitude, longitude,
|
||||
zone.attributes[ATTR_LATITUDE], zone.attributes[ATTR_LONGITUDE])
|
||||
|
@ -14,6 +14,16 @@ ATTR_PLATFORM = 'platform'
|
||||
def listen(hass, service, callback):
|
||||
"""Setup listener for discovery of specific service.
|
||||
|
||||
Service can be a string or a list/tuple.
|
||||
"""
|
||||
run_callback_threadsafe(
|
||||
hass.loop, async_listen, hass, service, callback).result()
|
||||
|
||||
|
||||
@core.callback
|
||||
def async_listen(hass, service, callback):
|
||||
"""Setup listener for discovery of specific service.
|
||||
|
||||
Service can be a string or a list/tuple.
|
||||
"""
|
||||
if isinstance(service, str):
|
||||
@ -21,12 +31,14 @@ def listen(hass, service, callback):
|
||||
else:
|
||||
service = tuple(service)
|
||||
|
||||
@core.callback
|
||||
def discovery_event_listener(event):
|
||||
"""Listen for discovery events."""
|
||||
if ATTR_SERVICE in event.data and event.data[ATTR_SERVICE] in service:
|
||||
callback(event.data[ATTR_SERVICE], event.data.get(ATTR_DISCOVERED))
|
||||
hass.async_add_job(callback, event.data[ATTR_SERVICE],
|
||||
event.data.get(ATTR_DISCOVERED))
|
||||
|
||||
hass.bus.listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener)
|
||||
hass.bus.async_listen(EVENT_PLATFORM_DISCOVERED, discovery_event_listener)
|
||||
|
||||
|
||||
def discover(hass, service, discovered=None, component=None, hass_config=None):
|
||||
|
@ -8,7 +8,10 @@ from homeassistant.util import location as loc_util
|
||||
|
||||
|
||||
def has_location(state: State) -> bool:
|
||||
"""Test if state contains a valid location."""
|
||||
"""Test if state contains a valid location.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
return (isinstance(state, State) and
|
||||
isinstance(state.attributes.get(ATTR_LATITUDE), float) and
|
||||
isinstance(state.attributes.get(ATTR_LONGITUDE), float))
|
||||
@ -16,7 +19,10 @@ def has_location(state: State) -> bool:
|
||||
|
||||
def closest(latitude: float, longitude: float,
|
||||
states: Sequence[State]) -> State:
|
||||
"""Return closest state to point."""
|
||||
"""Return closest state to point.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
with_location = [state for state in states if has_location(state)]
|
||||
|
||||
if not with_location:
|
||||
|
@ -51,7 +51,10 @@ def detect_location_info():
|
||||
|
||||
|
||||
def distance(lat1, lon1, lat2, lon2):
|
||||
"""Calculate the distance in meters between two points."""
|
||||
"""Calculate the distance in meters between two points.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
return vincenty((lat1, lon1), (lat2, lon2)) * 1000
|
||||
|
||||
|
||||
@ -88,6 +91,8 @@ def vincenty(point1: Tuple[float, float], point2: Tuple[float, float],
|
||||
|
||||
Result in kilometers or miles between two points on the surface of a
|
||||
spheroid.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
# short-circuit coincident points
|
||||
if point1[0] == point2[0] and point1[1] == point2[1]:
|
||||
|
@ -10,6 +10,7 @@ import os
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.bootstrap import setup_component
|
||||
from homeassistant.loader import get_component
|
||||
from homeassistant.util.async import run_coroutine_threadsafe
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID, ATTR_ENTITY_PICTURE, ATTR_FRIENDLY_NAME, ATTR_HIDDEN,
|
||||
@ -280,7 +281,7 @@ class TestComponentsDeviceTracker(unittest.TestCase):
|
||||
self.assertSequenceEqual((entity_id,),
|
||||
state.attributes.get(ATTR_ENTITY_ID))
|
||||
|
||||
@patch('homeassistant.components.device_tracker.DeviceTracker.see')
|
||||
@patch('homeassistant.components.device_tracker.DeviceTracker.async_see')
|
||||
def test_see_service(self, mock_see):
|
||||
"""Test the see service with a unicode dev_id and NO MAC."""
|
||||
self.assertTrue(setup_component(self.hass, device_tracker.DOMAIN,
|
||||
@ -375,20 +376,22 @@ class TestComponentsDeviceTracker(unittest.TestCase):
|
||||
|
||||
# No device id or MAC(not added)
|
||||
with self.assertRaises(HomeAssistantError):
|
||||
tracker.see()
|
||||
run_coroutine_threadsafe(
|
||||
tracker.async_see(), self.hass.loop).result()
|
||||
assert mock_warning.call_count == 0
|
||||
|
||||
# Ignore gps on invalid GPS (both added & warnings)
|
||||
tracker.see(mac='mac_1_bad_gps', gps=1)
|
||||
tracker.see(mac='mac_2_bad_gps', gps=[1])
|
||||
tracker.see(mac='mac_3_bad_gps', gps='gps')
|
||||
self.hass.block_till_done()
|
||||
config = device_tracker.load_config(self.yaml_devices, self.hass,
|
||||
timedelta(seconds=0))
|
||||
assert mock_warning.call_count == 3
|
||||
|
||||
assert len(config) == 4
|
||||
|
||||
@patch('homeassistant.components.device_tracker.log_exception')
|
||||
@patch('homeassistant.components.device_tracker.async_log_exception')
|
||||
def test_config_failure(self, mock_ex):
|
||||
"""Test that the device tracker see failures."""
|
||||
with assert_setup_component(0, device_tracker.DOMAIN):
|
||||
|
@ -37,7 +37,8 @@ class TestComponentsDeviceTrackerMQTT(unittest.TestCase):
|
||||
self.assertTrue('qos' in config)
|
||||
|
||||
with patch('homeassistant.components.device_tracker.mqtt.'
|
||||
'setup_scanner', side_effect=mock_setup_scanner) as mock_sp:
|
||||
'setup_scanner', autospec=True,
|
||||
side_effect=mock_setup_scanner) as mock_sp:
|
||||
|
||||
dev_id = 'paulus'
|
||||
topic = '/location/paulus'
|
||||
|
Loading…
x
Reference in New Issue
Block a user