From 1baf0da62780f96aa02d5168a984048c3215276e Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Mon, 25 Sep 2017 09:05:09 -0700 Subject: [PATCH] Clean up OwnTracks (#9569) * Clean up OwnTracks * Address comments --- .../components/device_tracker/owntracks.py | 572 +++++++++--------- homeassistant/util/decorator.py | 14 + .../device_tracker/test_owntracks.py | 72 ++- .../device_tracker/test_upc_connect.py | 20 +- 4 files changed, 372 insertions(+), 306 deletions(-) create mode 100644 homeassistant/util/decorator.py diff --git a/homeassistant/components/device_tracker/owntracks.py b/homeassistant/components/device_tracker/owntracks.py index 5c5c3c7c92e..1c773f97692 100644 --- a/homeassistant/components/device_tracker/owntracks.py +++ b/homeassistant/components/device_tracker/owntracks.py @@ -16,7 +16,7 @@ from homeassistant.core import callback import homeassistant.helpers.config_validation as cv import homeassistant.components.mqtt as mqtt from homeassistant.const import STATE_HOME -from homeassistant.util import convert, slugify +from homeassistant.util import slugify, decorator from homeassistant.components import zone as zone_comp from homeassistant.components.device_tracker import PLATFORM_SCHEMA @@ -25,6 +25,8 @@ REQUIREMENTS = ['libnacl==1.5.2'] _LOGGER = logging.getLogger(__name__) +HANDLERS = decorator.Registry() + BEACON_DEV_ID = 'beacon' CONF_MAX_GPS_ACCURACY = 'max_gps_accuracy' @@ -32,17 +34,7 @@ CONF_SECRET = 'secret' CONF_WAYPOINT_IMPORT = 'waypoints' CONF_WAYPOINT_WHITELIST = 'waypoint_whitelist' -EVENT_TOPIC = 'owntracks/+/+/event' - -LOCATION_TOPIC = 'owntracks/+/+' - -VALIDATE_LOCATION = 'location' -VALIDATE_TRANSITION = 'transition' -VALIDATE_WAYPOINTS = 'waypoints' - -WAYPOINT_LAT_KEY = 'lat' -WAYPOINT_LON_KEY = 'lon' -WAYPOINT_TOPIC = 'owntracks/{}/{}/waypoints' +OWNTRACKS_TOPIC = 'owntracks/#' PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Optional(CONF_MAX_GPS_ACCURACY): vol.Coerce(float), @@ -77,295 +69,61 @@ def async_setup_scanner(hass, config, async_see, discovery_info=None): waypoint_whitelist = config.get(CONF_WAYPOINT_WHITELIST) secret = config.get(CONF_SECRET) - mobile_beacons_active = defaultdict(list) - regions_entered = defaultdict(list) + context = OwnTracksContext(async_see, secret, max_gps_accuracy, + waypoint_import, waypoint_whitelist) - def decrypt_payload(topic, ciphertext): - """Decrypt encrypted payload.""" + @asyncio.coroutine + def async_handle_mqtt_message(topic, payload, qos): + """Handle incoming OwnTracks message.""" try: - keylen, decrypt = get_cipher() - except OSError: - _LOGGER.warning( - "Ignoring encrypted payload because libsodium not installed") - return None - - if isinstance(secret, dict): - key = secret.get(topic) - else: - key = secret - - if key is None: - _LOGGER.warning( - "Ignoring encrypted payload because no decryption key known " - "for topic %s", topic) - return None - - key = key.encode("utf-8") - key = key[:keylen] - key = key.ljust(keylen, b'\0') - - try: - ciphertext = base64.b64decode(ciphertext) - message = decrypt(ciphertext, key) - message = message.decode("utf-8") - _LOGGER.debug("Decrypted payload: %s", message) - return message - except ValueError: - _LOGGER.warning( - "Ignoring encrypted payload because unable to decrypt using " - "key for topic %s", topic) - return None - - def validate_payload(topic, payload, data_type): - """Validate the OwnTracks payload.""" - try: - data = json.loads(payload) + message = json.loads(payload) except ValueError: # If invalid JSON _LOGGER.error("Unable to parse payload as JSON: %s", payload) - return None - if isinstance(data, dict) and \ - data.get('_type') == 'encrypted' and \ - 'data' in data: - plaintext_payload = decrypt_payload(topic, data['data']) - if plaintext_payload is None: - return None - return validate_payload(topic, plaintext_payload, data_type) + message['topic'] = topic - if not isinstance(data, dict) or data.get('_type') != data_type: - _LOGGER.debug("Skipping %s update for following data " - "because of missing or malformatted data: %s", - data_type, data) - return None - if data_type == VALIDATE_TRANSITION or data_type == VALIDATE_WAYPOINTS: - return data - if max_gps_accuracy is not None and \ - convert(data.get('acc'), float, 0.0) > max_gps_accuracy: - _LOGGER.info("Ignoring %s update because expected GPS " - "accuracy %s is not met: %s", - data_type, max_gps_accuracy, payload) - return None - if convert(data.get('acc'), float, 1.0) == 0.0: - _LOGGER.warning( - "Ignoring %s update because GPS accuracy is zero: %s", - data_type, payload) - return None - - return data - - @callback - def async_owntracks_location_update(topic, payload, qos): - """MQTT message received.""" - # Docs on available data: - # http://owntracks.org/booklet/tech/json/#_typelocation - data = validate_payload(topic, payload, VALIDATE_LOCATION) - if not data: - return - - dev_id, kwargs = _parse_see_args(topic, data) - - if regions_entered[dev_id]: - _LOGGER.debug( - "Location update ignored, inside region %s", - regions_entered[-1]) - return - - hass.async_add_job(async_see(**kwargs)) - async_see_beacons(dev_id, kwargs) - - @callback - def async_owntracks_event_update(topic, payload, qos): - """Handle MQTT event (geofences).""" - # Docs on available data: - # http://owntracks.org/booklet/tech/json/#_typetransition - data = validate_payload(topic, payload, VALIDATE_TRANSITION) - if not data: - return - - if data.get('desc') is None: - _LOGGER.error( - "Location missing from `Entering/Leaving` message - " - "please turn `Share` on in OwnTracks app") - return - # OwnTracks uses - at the start of a beacon zone - # to switch on 'hold mode' - ignore this - location = data['desc'].lstrip("-") - if location.lower() == 'home': - location = STATE_HOME - - dev_id, kwargs = _parse_see_args(topic, data) - - def enter_event(): - """Execute enter event.""" - zone = hass.states.get("zone.{}".format(slugify(location))) - if zone is None and data.get('t') == 'b': - # Not a HA zone, and a beacon so assume mobile - beacons = mobile_beacons_active[dev_id] - if location not in beacons: - beacons.append(location) - _LOGGER.info("Added beacon %s", location) - else: - # Normal region - regions = regions_entered[dev_id] - if location not in regions: - regions.append(location) - _LOGGER.info("Enter region %s", location) - _set_gps_from_zone(kwargs, location, zone) - - hass.async_add_job(async_see(**kwargs)) - async_see_beacons(dev_id, kwargs) - - def leave_event(): - """Execute leave event.""" - regions = regions_entered[dev_id] - if location in regions: - regions.remove(location) - new_region = regions[-1] if regions else None - - if new_region: - # Exit to previous region - zone = hass.states.get( - "zone.{}".format(slugify(new_region))) - _set_gps_from_zone(kwargs, new_region, zone) - _LOGGER.info("Exit to %s", new_region) - hass.async_add_job(async_see(**kwargs)) - async_see_beacons(dev_id, kwargs) - - else: - _LOGGER.info("Exit to GPS") - # Check for GPS accuracy - valid_gps = True - if 'acc' in data: - if data['acc'] == 0.0: - valid_gps = False - _LOGGER.warning( - "Ignoring GPS in region exit because accuracy" - "is zero: %s", payload) - if (max_gps_accuracy is not None and - data['acc'] > max_gps_accuracy): - valid_gps = False - _LOGGER.info( - "Ignoring GPS in region exit because expected " - "GPS accuracy %s is not met: %s", - max_gps_accuracy, payload) - if valid_gps: - hass.async_add_job(async_see(**kwargs)) - async_see_beacons(dev_id, kwargs) - - beacons = mobile_beacons_active[dev_id] - if location in beacons: - beacons.remove(location) - _LOGGER.info("Remove beacon %s", location) - - if data['event'] == 'enter': - enter_event() - elif data['event'] == 'leave': - leave_event() - else: - _LOGGER.error( - "Misformatted mqtt msgs, _type=transition, event=%s", - data['event']) - return - - @callback - def async_owntracks_waypoint_update(topic, payload, qos): - """List of waypoints published by a user.""" - # Docs on available data: - # http://owntracks.org/booklet/tech/json/#_typewaypoints - data = validate_payload(topic, payload, VALIDATE_WAYPOINTS) - if not data: - return - - wayps = data['waypoints'] - _LOGGER.info("Got %d waypoints from %s", len(wayps), topic) - for wayp in wayps: - name = wayp['desc'] - pretty_name = parse_topic(topic, True)[1] + ' - ' + name - lat = wayp[WAYPOINT_LAT_KEY] - lon = wayp[WAYPOINT_LON_KEY] - rad = wayp['rad'] - - # check zone exists - entity_id = zone_comp.ENTITY_ID_FORMAT.format(slugify(pretty_name)) - - # Check if state already exists - if hass.states.get(entity_id) is not None: - continue - - zone = zone_comp.Zone(hass, pretty_name, lat, lon, rad, - zone_comp.ICON_IMPORT, False) - zone.entity_id = entity_id - hass.async_add_job(zone.async_update_ha_state()) - - @callback - def async_see_beacons(dev_id, kwargs_param): - """Set active beacons to the current location.""" - kwargs = kwargs_param.copy() - # the battery state applies to the tracking device, not the beacon - kwargs.pop('battery', None) - for beacon in mobile_beacons_active[dev_id]: - kwargs['dev_id'] = "{}_{}".format(BEACON_DEV_ID, beacon) - kwargs['host_name'] = beacon - hass.async_add_job(async_see(**kwargs)) + yield from async_handle_message(hass, context, message) yield from mqtt.async_subscribe( - hass, LOCATION_TOPIC, async_owntracks_location_update, 1) - yield from mqtt.async_subscribe( - hass, EVENT_TOPIC, async_owntracks_event_update, 1) - - if waypoint_import: - if waypoint_whitelist is None: - yield from mqtt.async_subscribe( - hass, WAYPOINT_TOPIC.format('+', '+'), - async_owntracks_waypoint_update, 1) - else: - for whitelist_user in waypoint_whitelist: - yield from mqtt.async_subscribe( - hass, WAYPOINT_TOPIC.format(whitelist_user, '+'), - async_owntracks_waypoint_update, 1) + hass, OWNTRACKS_TOPIC, async_handle_mqtt_message, 1) return True -def parse_topic(topic, pretty=False): +def _parse_topic(topic): """Parse an MQTT topic owntracks/user/dev, return (user, dev) tuple. Async friendly. """ - parts = topic.split('/') - dev_id_format = '' - if pretty: - dev_id_format = '{} {}' - else: - dev_id_format = '{}_{}' - dev_id = slugify(dev_id_format.format(parts[1], parts[2])) - host_name = parts[1] - return (host_name, dev_id) + _, user, device, *_ = topic.split('/', 3) + + return user, device -def _parse_see_args(topic, data): +def _parse_see_args(message): """Parse the OwnTracks location parameters, into the format see expects. Async friendly. """ - (host_name, dev_id) = parse_topic(topic, False) + user, device = _parse_topic(message['topic']) + dev_id = slugify('{}_{}'.format(user, device)) kwargs = { 'dev_id': dev_id, - 'host_name': host_name, - 'gps': (data[WAYPOINT_LAT_KEY], data[WAYPOINT_LON_KEY]), + 'host_name': user, + 'gps': (message['lat'], message['lon']), 'attributes': {} } - if 'acc' in data: - kwargs['gps_accuracy'] = data['acc'] - if 'batt' in data: - kwargs['battery'] = data['batt'] - if 'vel' in data: - kwargs['attributes']['velocity'] = data['vel'] - if 'tid' in data: - kwargs['attributes']['tid'] = data['tid'] - if 'addr' in data: - kwargs['attributes']['address'] = data['addr'] + if 'acc' in message: + kwargs['gps_accuracy'] = message['acc'] + if 'batt' in message: + kwargs['battery'] = message['batt'] + if 'vel' in message: + kwargs['attributes']['velocity'] = message['vel'] + if 'tid' in message: + kwargs['attributes']['tid'] = message['tid'] + if 'addr' in message: + kwargs['attributes']['address'] = message['addr'] return dev_id, kwargs @@ -382,3 +140,269 @@ def _set_gps_from_zone(kwargs, location, zone): kwargs['gps_accuracy'] = zone.attributes['radius'] kwargs['location_name'] = location return kwargs + + +def _decrypt_payload(secret, topic, ciphertext): + """Decrypt encrypted payload.""" + try: + keylen, decrypt = get_cipher() + except OSError: + _LOGGER.warning( + "Ignoring encrypted payload because libsodium not installed") + return None + + if isinstance(secret, dict): + key = secret.get(topic) + else: + key = secret + + if key is None: + _LOGGER.warning( + "Ignoring encrypted payload because no decryption key known " + "for topic %s", topic) + return None + + key = key.encode("utf-8") + key = key[:keylen] + key = key.ljust(keylen, b'\0') + + try: + ciphertext = base64.b64decode(ciphertext) + message = decrypt(ciphertext, key) + message = message.decode("utf-8") + _LOGGER.debug("Decrypted payload: %s", message) + return message + except ValueError: + _LOGGER.warning( + "Ignoring encrypted payload because unable to decrypt using " + "key for topic %s", topic) + return None + + +class OwnTracksContext: + """Hold the current OwnTracks context.""" + + def __init__(self, async_see, secret, max_gps_accuracy, import_waypoints, + waypoint_whitelist): + """Initialize an OwnTracks context.""" + self.async_see = async_see + self.secret = secret + self.max_gps_accuracy = max_gps_accuracy + self.mobile_beacons_active = defaultdict(list) + self.regions_entered = defaultdict(list) + self.import_waypoints = import_waypoints + self.waypoint_whitelist = waypoint_whitelist + + @callback + def async_valid_accuracy(self, message): + """Check if we should ignore this message.""" + acc = message.get('acc') + + if acc is None: + return False + + try: + acc = float(acc) + except ValueError: + return False + + if acc == 0: + _LOGGER.warning( + "Ignoring %s update because GPS accuracy is zero: %s", + message['_type'], message) + return False + + if self.max_gps_accuracy is not None and \ + acc > self.max_gps_accuracy: + _LOGGER.info("Ignoring %s update because expected GPS " + "accuracy %s is not met: %s", + message['_type'], self.max_gps_accuracy, + message) + return False + + return True + + @asyncio.coroutine + def async_see_beacons(self, dev_id, kwargs_param): + """Set active beacons to the current location.""" + kwargs = kwargs_param.copy() + # the battery state applies to the tracking device, not the beacon + kwargs.pop('battery', None) + for beacon in self.mobile_beacons_active[dev_id]: + kwargs['dev_id'] = "{}_{}".format(BEACON_DEV_ID, beacon) + kwargs['host_name'] = beacon + yield from self.async_see(**kwargs) + + +@HANDLERS.register('location') +@asyncio.coroutine +def async_handle_location_message(hass, context, message): + """Handle a location message.""" + if not context.async_valid_accuracy(message): + return + + dev_id, kwargs = _parse_see_args(message) + + if context.regions_entered[dev_id]: + _LOGGER.debug( + "Location update ignored, inside region %s", + context.regions_entered[-1]) + return + + yield from context.async_see(**kwargs) + yield from context.async_see_beacons(dev_id, kwargs) + + +@asyncio.coroutine +def _async_transition_message_enter(hass, context, message, location): + """Execute enter event.""" + zone = hass.states.get("zone.{}".format(slugify(location))) + dev_id, kwargs = _parse_see_args(message) + + if zone is None and message.get('t') == 'b': + # Not a HA zone, and a beacon so assume mobile + beacons = context.mobile_beacons_active[dev_id] + if location not in beacons: + beacons.append(location) + _LOGGER.info("Added beacon %s", location) + else: + # Normal region + regions = context.regions_entered[dev_id] + if location not in regions: + regions.append(location) + _LOGGER.info("Enter region %s", location) + _set_gps_from_zone(kwargs, location, zone) + + yield from context.async_see(**kwargs) + yield from context.async_see_beacons(dev_id, kwargs) + + +@asyncio.coroutine +def _async_transition_message_leave(hass, context, message, location): + """Execute leave event.""" + dev_id, kwargs = _parse_see_args(message) + regions = context.regions_entered[dev_id] + + if location in regions: + regions.remove(location) + + new_region = regions[-1] if regions else None + + if new_region: + # Exit to previous region + zone = hass.states.get( + "zone.{}".format(slugify(new_region))) + _set_gps_from_zone(kwargs, new_region, zone) + _LOGGER.info("Exit to %s", new_region) + yield from context.async_see(**kwargs) + yield from context.async_see_beacons(dev_id, kwargs) + return + + else: + _LOGGER.info("Exit to GPS") + + # Check for GPS accuracy + if context.async_valid_accuracy(message): + yield from context.async_see(**kwargs) + yield from context.async_see_beacons(dev_id, kwargs) + + beacons = context.mobile_beacons_active[dev_id] + if location in beacons: + beacons.remove(location) + _LOGGER.info("Remove beacon %s", location) + + +@HANDLERS.register('transition') +@asyncio.coroutine +def async_handle_transition_message(hass, context, message): + """Handle a transition message.""" + if message.get('desc') is None: + _LOGGER.error( + "Location missing from `Entering/Leaving` message - " + "please turn `Share` on in OwnTracks app") + return + # OwnTracks uses - at the start of a beacon zone + # to switch on 'hold mode' - ignore this + location = message['desc'].lstrip("-") + if location.lower() == 'home': + location = STATE_HOME + + if message['event'] == 'enter': + yield from _async_transition_message_enter( + hass, context, message, location) + elif message['event'] == 'leave': + yield from _async_transition_message_leave( + hass, context, message, location) + else: + _LOGGER.error( + "Misformatted mqtt msgs, _type=transition, event=%s", + message['event']) + + +@HANDLERS.register('waypoints') +@asyncio.coroutine +def async_handle_waypoints_message(hass, context, message): + """Handle a waypoints message.""" + if not context.import_waypoints: + return + + if context.waypoint_whitelist is not None: + user = _parse_topic(message['topic'])[0] + + if user not in context.waypoint_whitelist: + return + + wayps = message['waypoints'] + + _LOGGER.info("Got %d waypoints from %s", len(wayps), message['topic']) + + name_base = ' '.join(_parse_topic(message['topic'])) + + for wayp in wayps: + name = wayp['desc'] + pretty_name = '{} - {}'.format(name_base, name) + lat = wayp['lat'] + lon = wayp['lon'] + rad = wayp['rad'] + + # check zone exists + entity_id = zone_comp.ENTITY_ID_FORMAT.format(slugify(pretty_name)) + + # Check if state already exists + if hass.states.get(entity_id) is not None: + continue + + zone = zone_comp.Zone(hass, pretty_name, lat, lon, rad, + zone_comp.ICON_IMPORT, False) + zone.entity_id = entity_id + yield from zone.async_update_ha_state() + + +@HANDLERS.register('encrypted') +@asyncio.coroutine +def async_handle_encrypted_message(hass, context, message): + """Handle an encrypted message.""" + plaintext_payload = _decrypt_payload(context.secret, message['topic'], + message['data']) + + if plaintext_payload is None: + return + + decrypted = json.loads(plaintext_payload) + decrypted['topic'] = message['topic'] + + yield from async_handle_message(hass, context, decrypted) + + +@asyncio.coroutine +def async_handle_message(hass, context, message): + """Handle an OwnTracks message.""" + msgtype = message.get('_type') + + handler = HANDLERS.get(msgtype) + + if handler is None: + error = 'Received unsupported message type: {}.'.format(msgtype) + _LOGGER.warning(error) + + yield from handler(hass, context, message) diff --git a/homeassistant/util/decorator.py b/homeassistant/util/decorator.py new file mode 100644 index 00000000000..c26606d52cf --- /dev/null +++ b/homeassistant/util/decorator.py @@ -0,0 +1,14 @@ +"""Decorator utility functions.""" + + +class Registry(dict): + """Registry of items.""" + + def register(self, name): + """Return decorator to register item with a specific name.""" + def decorator(func): + """Register decorated function.""" + self[name] = func + return func + + return decorator diff --git a/tests/components/device_tracker/test_owntracks.py b/tests/components/device_tracker/test_owntracks.py index e4944035261..3a23fe61d41 100644 --- a/tests/components/device_tracker/test_owntracks.py +++ b/tests/components/device_tracker/test_owntracks.py @@ -1,13 +1,12 @@ """The tests for the Owntracks device tracker.""" import asyncio import json -import os -from collections import defaultdict import unittest from unittest.mock import patch -from tests.common import (assert_setup_component, fire_mqtt_message, - get_test_home_assistant, mock_mqtt_component) +from tests.common import (assert_setup_component, fire_mqtt_message, mock_coro, + get_test_home_assistant, mock_mqtt_component, + mock_component) import homeassistant.components.device_tracker.owntracks as owntracks from homeassistant.setup import setup_component @@ -20,9 +19,9 @@ DEVICE = 'phone' LOCATION_TOPIC = 'owntracks/{}/{}'.format(USER, DEVICE) EVENT_TOPIC = 'owntracks/{}/{}/event'.format(USER, DEVICE) -WAYPOINT_TOPIC = owntracks.WAYPOINT_TOPIC.format(USER, DEVICE) +WAYPOINT_TOPIC = 'owntracks/{}/{}/waypoints'.format(USER, DEVICE) USER_BLACKLIST = 'ram' -WAYPOINT_TOPIC_BLOCKED = owntracks.WAYPOINT_TOPIC.format( +WAYPOINT_TOPIC_BLOCKED = 'owntracks/{}/{}/waypoints'.format( USER_BLACKLIST, DEVICE) DEVICE_TRACKER_STATE = 'device_tracker.{}_{}'.format(USER, DEVICE) @@ -252,7 +251,26 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() mock_mqtt_component(self.hass) - with assert_setup_component(1, device_tracker.DOMAIN): + mock_component(self.hass, 'group') + mock_component(self.hass, 'zone') + + patcher = patch('homeassistant.components.device_tracker.' + 'DeviceTracker.async_update_config') + patcher.start() + self.addCleanup(patcher.stop) + + orig_context = owntracks.OwnTracksContext + + def store_context(*args): + self.context = orig_context(*args) + return self.context + + with patch('homeassistant.components.device_tracker.async_load_config', + return_value=mock_coro([])), \ + patch('homeassistant.components.device_tracker.' + 'load_yaml_config_file', return_value=mock_coro({})), \ + patch.object(owntracks, 'OwnTracksContext', store_context), \ + assert_setup_component(1, device_tracker.DOMAIN): assert setup_component(self.hass, device_tracker.DOMAIN, { device_tracker.DOMAIN: { CONF_PLATFORM: 'owntracks', @@ -290,18 +308,11 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): # Clear state between teste self.hass.states.set(DEVICE_TRACKER_STATE, None) - owntracks.REGIONS_ENTERED = defaultdict(list) - owntracks.MOBILE_BEACONS_ACTIVE = defaultdict(list) def teardown_method(self, _): """Stop everything that was started.""" self.hass.stop() - try: - os.remove(self.hass.config.path(device_tracker.YAML_DEVICES)) - except FileNotFoundError: - pass - def assert_tracker_state(self, location): """Test the assertion of a tracker state.""" state = self.hass.states.get(REGION_TRACKER_STATE) @@ -372,7 +383,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): self.assert_location_state('outer') # Left clean zone state - self.assertFalse(owntracks.REGIONS_ENTERED[USER]) + self.assertFalse(self.context.regions_entered[USER]) def test_event_with_spaces(self): """Test the entry event.""" @@ -386,7 +397,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): self.send_message(EVENT_TOPIC, message) # Left clean zone state - self.assertFalse(owntracks.REGIONS_ENTERED[USER]) + self.assertFalse(self.context.regions_entered[USER]) def test_event_entry_exit_inaccurate(self): """Test the event for inaccurate exit.""" @@ -405,7 +416,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): self.assert_location_state('inner') # But does exit region correctly - self.assertFalse(owntracks.REGIONS_ENTERED[USER]) + self.assertFalse(self.context.regions_entered[USER]) def test_event_entry_exit_zero_accuracy(self): """Test entry/exit events with accuracy zero.""" @@ -424,7 +435,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): self.assert_location_state('inner') # But does exit region correctly - self.assertFalse(owntracks.REGIONS_ENTERED[USER]) + self.assertFalse(self.context.regions_entered[USER]) def test_event_exit_outside_zone_sets_away(self): """Test the event for exit zone.""" @@ -604,7 +615,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): self.hass.block_till_done() self.send_message(EVENT_TOPIC, exit_message) - self.assertEqual(owntracks.MOBILE_BEACONS_ACTIVE['greg_phone'], []) + self.assertEqual(self.context.mobile_beacons_active['greg_phone'], []) def test_mobile_multiple_enter_exit(self): """Test the multiple entering.""" @@ -618,7 +629,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): self.send_message(EVENT_TOPIC, enter_message) self.send_message(EVENT_TOPIC, exit_message) - self.assertEqual(owntracks.MOBILE_BEACONS_ACTIVE['greg_phone'], []) + self.assertEqual(self.context.mobile_beacons_active['greg_phone'], []) def test_waypoint_import_simple(self): """Test a simple import of list of waypoints.""" @@ -706,6 +717,19 @@ class TestDeviceTrackerOwnTrackConfigs(BaseMQTT): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() mock_mqtt_component(self.hass) + mock_component(self.hass, 'group') + mock_component(self.hass, 'zone') + + patch_load = patch( + 'homeassistant.components.device_tracker.async_load_config', + return_value=mock_coro([])) + patch_load.start() + self.addCleanup(patch_load.stop) + + patch_save = patch('homeassistant.components.device_tracker.' + 'DeviceTracker.async_update_config') + patch_save.start() + self.addCleanup(patch_save.stop) def teardown_method(self, method): """Tear down resources.""" @@ -749,7 +773,7 @@ class TestDeviceTrackerOwnTrackConfigs(BaseMQTT): # key missing }}) self.send_message(LOCATION_TOPIC, MOCK_ENCRYPTED_LOCATION_MESSAGE) - self.assert_location_latitude(None) + assert self.hass.states.get(DEVICE_TRACKER_STATE) is None @patch('homeassistant.components.device_tracker.owntracks.get_cipher', mock_cipher) @@ -762,7 +786,7 @@ class TestDeviceTrackerOwnTrackConfigs(BaseMQTT): CONF_SECRET: 'wrong key', }}) self.send_message(LOCATION_TOPIC, MOCK_ENCRYPTED_LOCATION_MESSAGE) - self.assert_location_latitude(None) + assert self.hass.states.get(DEVICE_TRACKER_STATE) is None @patch('homeassistant.components.device_tracker.owntracks.get_cipher', mock_cipher) @@ -776,7 +800,7 @@ class TestDeviceTrackerOwnTrackConfigs(BaseMQTT): LOCATION_TOPIC: 'wrong key' }}}) self.send_message(LOCATION_TOPIC, MOCK_ENCRYPTED_LOCATION_MESSAGE) - self.assert_location_latitude(None) + assert self.hass.states.get(DEVICE_TRACKER_STATE) is None @patch('homeassistant.components.device_tracker.owntracks.get_cipher', mock_cipher) @@ -790,7 +814,7 @@ class TestDeviceTrackerOwnTrackConfigs(BaseMQTT): 'owntracks/{}/{}'.format(USER, 'otherdevice'): 'foobar' }}}) self.send_message(LOCATION_TOPIC, MOCK_ENCRYPTED_LOCATION_MESSAGE) - self.assert_location_latitude(None) + assert self.hass.states.get(DEVICE_TRACKER_STATE) is None try: import libnacl diff --git a/tests/components/device_tracker/test_upc_connect.py b/tests/components/device_tracker/test_upc_connect.py index 1ef3aefa6a4..396d2b88b19 100644 --- a/tests/components/device_tracker/test_upc_connect.py +++ b/tests/components/device_tracker/test_upc_connect.py @@ -1,11 +1,11 @@ """The tests for the UPC ConnextBox device tracker platform.""" import asyncio -import os from unittest.mock import patch import logging +import pytest + from homeassistant.setup import setup_component -from homeassistant.components import device_tracker from homeassistant.const import ( CONF_PLATFORM, CONF_HOST) from homeassistant.components.device_tracker import DOMAIN @@ -14,7 +14,7 @@ from homeassistant.util.async import run_coroutine_threadsafe from tests.common import ( get_test_home_assistant, assert_setup_component, load_fixture, - mock_component) + mock_component, mock_coro) _LOGGER = logging.getLogger(__name__) @@ -25,6 +25,14 @@ def async_scan_devices_mock(scanner): return [] +@pytest.fixture(autouse=True) +def mock_load_config(): + """Mock device tracker loading config.""" + with patch('homeassistant.components.device_tracker.async_load_config', + return_value=mock_coro([])): + yield + + class TestUPCConnect(object): """Tests for the Ddwrt device tracker platform.""" @@ -32,16 +40,12 @@ class TestUPCConnect(object): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() mock_component(self.hass, 'zone') + mock_component(self.hass, 'group') self.host = "127.0.0.1" def teardown_method(self): """Stop everything that was started.""" - try: - os.remove(self.hass.config.path(device_tracker.YAML_DEVICES)) - except FileNotFoundError: - pass - self.hass.stop() @patch('homeassistant.components.device_tracker.upc_connect.'