diff --git a/homeassistant/components/device_tracker/owntracks.py b/homeassistant/components/device_tracker/owntracks.py index 77241e1a8ab..0c869dd4b57 100644 --- a/homeassistant/components/device_tracker/owntracks.py +++ b/homeassistant/components/device_tracker/owntracks.py @@ -367,6 +367,29 @@ def async_handle_transition_message(hass, context, message): message['event']) +@asyncio.coroutine +def async_handle_waypoint(hass, name_base, waypoint): + """Handle a waypoint.""" + name = waypoint['desc'] + pretty_name = '{} - {}'.format(name_base, name) + lat = waypoint['lat'] + lon = waypoint['lon'] + rad = waypoint['rad'] + + # check zone exists + entity_id = zone_comp.ENTITY_ID_FORMAT.format(slugify(pretty_name)) + + # Check if state already exists + if hass.states.get(entity_id) is not None: + return + + zone = zone_comp.Zone(hass, pretty_name, lat, lon, rad, + zone_comp.ICON_IMPORT, False) + zone.entity_id = entity_id + yield from zone.async_update_ha_state() + + +@HANDLERS.register('waypoint') @HANDLERS.register('waypoints') @asyncio.coroutine def async_handle_waypoints_message(hass, context, message): @@ -380,30 +403,17 @@ def async_handle_waypoints_message(hass, context, message): if user not in context.waypoint_whitelist: return - wayps = message['waypoints'] + if 'waypoints' in message: + wayps = message['waypoints'] + else: + wayps = [message] _LOGGER.info("Got %d waypoints from %s", len(wayps), message['topic']) name_base = ' '.join(_parse_topic(message['topic'])) for wayp in wayps: - name = wayp['desc'] - pretty_name = '{} - {}'.format(name_base, name) - lat = wayp['lat'] - lon = wayp['lon'] - rad = wayp['rad'] - - # check zone exists - entity_id = zone_comp.ENTITY_ID_FORMAT.format(slugify(pretty_name)) - - # Check if state already exists - if hass.states.get(entity_id) is not None: - continue - - zone = zone_comp.Zone(hass, pretty_name, lat, lon, rad, - zone_comp.ICON_IMPORT, False) - zone.entity_id = entity_id - yield from zone.async_update_ha_state() + yield from async_handle_waypoint(hass, name_base, wayp) @HANDLERS.register('encrypted') @@ -423,10 +433,22 @@ def async_handle_encrypted_message(hass, context, message): @HANDLERS.register('lwt') +@HANDLERS.register('configuration') +@HANDLERS.register('beacon') +@HANDLERS.register('cmd') +@HANDLERS.register('steps') +@HANDLERS.register('card') @asyncio.coroutine -def async_handle_lwt_message(hass, context, message): - """Handle an lwt message.""" - _LOGGER.debug('Not handling lwt message: %s', message) +def async_handle_not_impl_msg(hass, context, message): + """Handle valid but not implemented message types.""" + _LOGGER.debug('Not handling %s message: %s', message.get("_type"), message) + + +@asyncio.coroutine +def async_handle_unsupported_msg(hass, context, message): + """Handle an unsupported or invalid message type.""" + _LOGGER.warning('Received unsupported message type: %s.', + message.get('_type')) @asyncio.coroutine @@ -434,11 +456,6 @@ def async_handle_message(hass, context, message): """Handle an OwnTracks message.""" msgtype = message.get('_type') - handler = HANDLERS.get(msgtype) - - if handler is None: - _LOGGER.warning( - 'Received unsupported message type: %s.', msgtype) - return + handler = HANDLERS.get(msgtype, async_handle_unsupported_msg) yield from handler(hass, context, message) diff --git a/tests/components/device_tracker/test_owntracks.py b/tests/components/device_tracker/test_owntracks.py index a06adcb286a..4f5efb9d09d 100644 --- a/tests/components/device_tracker/test_owntracks.py +++ b/tests/components/device_tracker/test_owntracks.py @@ -18,10 +18,13 @@ DEVICE = 'phone' LOCATION_TOPIC = 'owntracks/{}/{}'.format(USER, DEVICE) EVENT_TOPIC = 'owntracks/{}/{}/event'.format(USER, DEVICE) -WAYPOINT_TOPIC = 'owntracks/{}/{}/waypoints'.format(USER, DEVICE) +WAYPOINTS_TOPIC = 'owntracks/{}/{}/waypoints'.format(USER, DEVICE) +WAYPOINT_TOPIC = 'owntracks/{}/{}/waypoint'.format(USER, DEVICE) USER_BLACKLIST = 'ram' -WAYPOINT_TOPIC_BLOCKED = 'owntracks/{}/{}/waypoints'.format( +WAYPOINTS_TOPIC_BLOCKED = 'owntracks/{}/{}/waypoints'.format( USER_BLACKLIST, DEVICE) +LWT_TOPIC = 'owntracks/{}/{}/lwt'.format(USER, DEVICE) +BAD_TOPIC = 'owntracks/{}/{}/unsupported'.format(USER, DEVICE) DEVICE_TRACKER_STATE = 'device_tracker.{}_{}'.format(USER, DEVICE) @@ -232,6 +235,15 @@ WAYPOINTS_UPDATED_MESSAGE = { ] } +WAYPOINT_MESSAGE = { + "_type": "waypoint", + "tst": 4, + "lat": 9, + "lon": 47, + "rad": 50, + "desc": "exp_wayp1" +} + WAYPOINT_ENTITY_NAMES = [ 'zone.greg_phone__exp_wayp1', 'zone.greg_phone__exp_wayp2', @@ -239,10 +251,26 @@ WAYPOINT_ENTITY_NAMES = [ 'zone.ram_phone__exp_wayp2', ] +LWT_MESSAGE = { + "_type": "lwt", + "tst": 1 +} + +BAD_MESSAGE = { + "_type": "unsupported", + "tst": 1 +} + BAD_JSON_PREFIX = '--$this is bad json#--' BAD_JSON_SUFFIX = '** and it ends here ^^' +# def raise_on_not_implemented(hass, context, message): +def raise_on_not_implemented(): + """Throw NotImplemented.""" + raise NotImplementedError("oopsie") + + class BaseMQTT(unittest.TestCase): """Base MQTT assert functions.""" @@ -1056,7 +1084,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): def test_waypoint_import_simple(self): """Test a simple import of list of waypoints.""" waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() - self.send_message(WAYPOINT_TOPIC, waypoints_message) + self.send_message(WAYPOINTS_TOPIC, waypoints_message) # Check if it made it into states wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0]) self.assertTrue(wayp is not None) @@ -1066,7 +1094,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): def test_waypoint_import_blacklist(self): """Test import of list of waypoints for blacklisted user.""" waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() - self.send_message(WAYPOINT_TOPIC_BLOCKED, waypoints_message) + self.send_message(WAYPOINTS_TOPIC_BLOCKED, waypoints_message) # Check if it made it into states wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2]) self.assertTrue(wayp is None) @@ -1088,7 +1116,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): run_coroutine_threadsafe(owntracks.async_setup_scanner( self.hass, test_config, mock_see), self.hass.loop).result() waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() - self.send_message(WAYPOINT_TOPIC_BLOCKED, waypoints_message) + self.send_message(WAYPOINTS_TOPIC_BLOCKED, waypoints_message) # Check if it made it into states wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2]) self.assertTrue(wayp is not None) @@ -1098,7 +1126,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): def test_waypoint_import_bad_json(self): """Test importing a bad JSON payload.""" waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() - self.send_message(WAYPOINT_TOPIC, waypoints_message, True) + self.send_message(WAYPOINTS_TOPIC, waypoints_message, True) # Check if it made it into states wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2]) self.assertTrue(wayp is None) @@ -1108,15 +1136,40 @@ class TestDeviceTrackerOwnTracks(BaseMQTT): def test_waypoint_import_existing(self): """Test importing a zone that exists.""" waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() - self.send_message(WAYPOINT_TOPIC, waypoints_message) + self.send_message(WAYPOINTS_TOPIC, waypoints_message) # Get the first waypoint exported wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0]) # Send an update waypoints_message = WAYPOINTS_UPDATED_MESSAGE.copy() - self.send_message(WAYPOINT_TOPIC, waypoints_message) + self.send_message(WAYPOINTS_TOPIC, waypoints_message) new_wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0]) self.assertTrue(wayp == new_wayp) + def test_single_waypoint_import(self): + """Test single waypoint message.""" + waypoint_message = WAYPOINT_MESSAGE.copy() + self.send_message(WAYPOINT_TOPIC, waypoint_message) + wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0]) + self.assertTrue(wayp is not None) + + def test_not_implemented_message(self): + """Handle not implemented message type.""" + patch_handler = patch('homeassistant.components.device_tracker.' + 'owntracks.async_handle_not_impl_msg', + return_value=mock_coro(False)) + patch_handler.start() + self.assertFalse(self.send_message(LWT_TOPIC, LWT_MESSAGE)) + patch_handler.stop() + + def test_unsupported_message(self): + """Handle not implemented message type.""" + patch_handler = patch('homeassistant.components.device_tracker.' + 'owntracks.async_handle_unsupported_msg', + return_value=mock_coro(False)) + patch_handler.start() + self.assertFalse(self.send_message(BAD_TOPIC, BAD_MESSAGE)) + patch_handler.stop() + def generate_ciphers(secret): """Generate test ciphers for the DEFAULT_LOCATION_MESSAGE.""" @@ -1143,7 +1196,7 @@ def generate_ciphers(secret): json.dumps(DEFAULT_LOCATION_MESSAGE).encode("utf-8")) ) ).decode("utf-8") - return (ctxt, mctxt) + return ctxt, mctxt TEST_SECRET_KEY = 's3cretkey' @@ -1172,7 +1225,7 @@ def mock_cipher(): if key != mkey: raise ValueError() return plaintext - return (len(TEST_SECRET_KEY), mock_decrypt) + return len(TEST_SECRET_KEY), mock_decrypt class TestDeviceTrackerOwnTrackConfigs(BaseMQTT):