OwnTracks Message Handling (#10489)

* Improve handling and logging of unsupported owntracks message types

Added generic handlers for message types that are valid but not
supported by the HA component (lwt, beacon, etc.) and for
message types which are invalid. Valid but not supported
messages will now be logged as DEBUG. Invalid messages will
be logged as WARNING.

Supporting single "waypoint" messages in addition to the
roll-up "waypoints" messages.

Added tests around these features.

* Style fixes
This commit is contained in:
Eric Hagan 2017-11-10 11:29:21 -06:00 committed by Paulus Schoutsen
parent 0490ca67d1
commit 7d9d299d5a
2 changed files with 107 additions and 37 deletions

View File

@ -367,6 +367,29 @@ def async_handle_transition_message(hass, context, message):
message['event']) message['event'])
@asyncio.coroutine
def async_handle_waypoint(hass, name_base, waypoint):
"""Handle a waypoint."""
name = waypoint['desc']
pretty_name = '{} - {}'.format(name_base, name)
lat = waypoint['lat']
lon = waypoint['lon']
rad = waypoint['rad']
# check zone exists
entity_id = zone_comp.ENTITY_ID_FORMAT.format(slugify(pretty_name))
# Check if state already exists
if hass.states.get(entity_id) is not None:
return
zone = zone_comp.Zone(hass, pretty_name, lat, lon, rad,
zone_comp.ICON_IMPORT, False)
zone.entity_id = entity_id
yield from zone.async_update_ha_state()
@HANDLERS.register('waypoint')
@HANDLERS.register('waypoints') @HANDLERS.register('waypoints')
@asyncio.coroutine @asyncio.coroutine
def async_handle_waypoints_message(hass, context, message): def async_handle_waypoints_message(hass, context, message):
@ -380,30 +403,17 @@ def async_handle_waypoints_message(hass, context, message):
if user not in context.waypoint_whitelist: if user not in context.waypoint_whitelist:
return return
if 'waypoints' in message:
wayps = message['waypoints'] wayps = message['waypoints']
else:
wayps = [message]
_LOGGER.info("Got %d waypoints from %s", len(wayps), message['topic']) _LOGGER.info("Got %d waypoints from %s", len(wayps), message['topic'])
name_base = ' '.join(_parse_topic(message['topic'])) name_base = ' '.join(_parse_topic(message['topic']))
for wayp in wayps: for wayp in wayps:
name = wayp['desc'] yield from async_handle_waypoint(hass, name_base, wayp)
pretty_name = '{} - {}'.format(name_base, name)
lat = wayp['lat']
lon = wayp['lon']
rad = wayp['rad']
# check zone exists
entity_id = zone_comp.ENTITY_ID_FORMAT.format(slugify(pretty_name))
# Check if state already exists
if hass.states.get(entity_id) is not None:
continue
zone = zone_comp.Zone(hass, pretty_name, lat, lon, rad,
zone_comp.ICON_IMPORT, False)
zone.entity_id = entity_id
yield from zone.async_update_ha_state()
@HANDLERS.register('encrypted') @HANDLERS.register('encrypted')
@ -423,10 +433,22 @@ def async_handle_encrypted_message(hass, context, message):
@HANDLERS.register('lwt') @HANDLERS.register('lwt')
@HANDLERS.register('configuration')
@HANDLERS.register('beacon')
@HANDLERS.register('cmd')
@HANDLERS.register('steps')
@HANDLERS.register('card')
@asyncio.coroutine @asyncio.coroutine
def async_handle_lwt_message(hass, context, message): def async_handle_not_impl_msg(hass, context, message):
"""Handle an lwt message.""" """Handle valid but not implemented message types."""
_LOGGER.debug('Not handling lwt message: %s', message) _LOGGER.debug('Not handling %s message: %s', message.get("_type"), message)
@asyncio.coroutine
def async_handle_unsupported_msg(hass, context, message):
"""Handle an unsupported or invalid message type."""
_LOGGER.warning('Received unsupported message type: %s.',
message.get('_type'))
@asyncio.coroutine @asyncio.coroutine
@ -434,11 +456,6 @@ def async_handle_message(hass, context, message):
"""Handle an OwnTracks message.""" """Handle an OwnTracks message."""
msgtype = message.get('_type') msgtype = message.get('_type')
handler = HANDLERS.get(msgtype) handler = HANDLERS.get(msgtype, async_handle_unsupported_msg)
if handler is None:
_LOGGER.warning(
'Received unsupported message type: %s.', msgtype)
return
yield from handler(hass, context, message) yield from handler(hass, context, message)

View File

@ -18,10 +18,13 @@ DEVICE = 'phone'
LOCATION_TOPIC = 'owntracks/{}/{}'.format(USER, DEVICE) LOCATION_TOPIC = 'owntracks/{}/{}'.format(USER, DEVICE)
EVENT_TOPIC = 'owntracks/{}/{}/event'.format(USER, DEVICE) EVENT_TOPIC = 'owntracks/{}/{}/event'.format(USER, DEVICE)
WAYPOINT_TOPIC = 'owntracks/{}/{}/waypoints'.format(USER, DEVICE) WAYPOINTS_TOPIC = 'owntracks/{}/{}/waypoints'.format(USER, DEVICE)
WAYPOINT_TOPIC = 'owntracks/{}/{}/waypoint'.format(USER, DEVICE)
USER_BLACKLIST = 'ram' USER_BLACKLIST = 'ram'
WAYPOINT_TOPIC_BLOCKED = 'owntracks/{}/{}/waypoints'.format( WAYPOINTS_TOPIC_BLOCKED = 'owntracks/{}/{}/waypoints'.format(
USER_BLACKLIST, DEVICE) USER_BLACKLIST, DEVICE)
LWT_TOPIC = 'owntracks/{}/{}/lwt'.format(USER, DEVICE)
BAD_TOPIC = 'owntracks/{}/{}/unsupported'.format(USER, DEVICE)
DEVICE_TRACKER_STATE = 'device_tracker.{}_{}'.format(USER, DEVICE) DEVICE_TRACKER_STATE = 'device_tracker.{}_{}'.format(USER, DEVICE)
@ -232,6 +235,15 @@ WAYPOINTS_UPDATED_MESSAGE = {
] ]
} }
WAYPOINT_MESSAGE = {
"_type": "waypoint",
"tst": 4,
"lat": 9,
"lon": 47,
"rad": 50,
"desc": "exp_wayp1"
}
WAYPOINT_ENTITY_NAMES = [ WAYPOINT_ENTITY_NAMES = [
'zone.greg_phone__exp_wayp1', 'zone.greg_phone__exp_wayp1',
'zone.greg_phone__exp_wayp2', 'zone.greg_phone__exp_wayp2',
@ -239,10 +251,26 @@ WAYPOINT_ENTITY_NAMES = [
'zone.ram_phone__exp_wayp2', 'zone.ram_phone__exp_wayp2',
] ]
LWT_MESSAGE = {
"_type": "lwt",
"tst": 1
}
BAD_MESSAGE = {
"_type": "unsupported",
"tst": 1
}
BAD_JSON_PREFIX = '--$this is bad json#--' BAD_JSON_PREFIX = '--$this is bad json#--'
BAD_JSON_SUFFIX = '** and it ends here ^^' BAD_JSON_SUFFIX = '** and it ends here ^^'
# def raise_on_not_implemented(hass, context, message):
def raise_on_not_implemented():
"""Throw NotImplemented."""
raise NotImplementedError("oopsie")
class BaseMQTT(unittest.TestCase): class BaseMQTT(unittest.TestCase):
"""Base MQTT assert functions.""" """Base MQTT assert functions."""
@ -1056,7 +1084,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT):
def test_waypoint_import_simple(self): def test_waypoint_import_simple(self):
"""Test a simple import of list of waypoints.""" """Test a simple import of list of waypoints."""
waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy()
self.send_message(WAYPOINT_TOPIC, waypoints_message) self.send_message(WAYPOINTS_TOPIC, waypoints_message)
# Check if it made it into states # Check if it made it into states
wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0]) wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0])
self.assertTrue(wayp is not None) self.assertTrue(wayp is not None)
@ -1066,7 +1094,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT):
def test_waypoint_import_blacklist(self): def test_waypoint_import_blacklist(self):
"""Test import of list of waypoints for blacklisted user.""" """Test import of list of waypoints for blacklisted user."""
waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy()
self.send_message(WAYPOINT_TOPIC_BLOCKED, waypoints_message) self.send_message(WAYPOINTS_TOPIC_BLOCKED, waypoints_message)
# Check if it made it into states # Check if it made it into states
wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2]) wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2])
self.assertTrue(wayp is None) self.assertTrue(wayp is None)
@ -1088,7 +1116,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT):
run_coroutine_threadsafe(owntracks.async_setup_scanner( run_coroutine_threadsafe(owntracks.async_setup_scanner(
self.hass, test_config, mock_see), self.hass.loop).result() self.hass, test_config, mock_see), self.hass.loop).result()
waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy()
self.send_message(WAYPOINT_TOPIC_BLOCKED, waypoints_message) self.send_message(WAYPOINTS_TOPIC_BLOCKED, waypoints_message)
# Check if it made it into states # Check if it made it into states
wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2]) wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2])
self.assertTrue(wayp is not None) self.assertTrue(wayp is not None)
@ -1098,7 +1126,7 @@ class TestDeviceTrackerOwnTracks(BaseMQTT):
def test_waypoint_import_bad_json(self): def test_waypoint_import_bad_json(self):
"""Test importing a bad JSON payload.""" """Test importing a bad JSON payload."""
waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy()
self.send_message(WAYPOINT_TOPIC, waypoints_message, True) self.send_message(WAYPOINTS_TOPIC, waypoints_message, True)
# Check if it made it into states # Check if it made it into states
wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2]) wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[2])
self.assertTrue(wayp is None) self.assertTrue(wayp is None)
@ -1108,15 +1136,40 @@ class TestDeviceTrackerOwnTracks(BaseMQTT):
def test_waypoint_import_existing(self): def test_waypoint_import_existing(self):
"""Test importing a zone that exists.""" """Test importing a zone that exists."""
waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy() waypoints_message = WAYPOINTS_EXPORTED_MESSAGE.copy()
self.send_message(WAYPOINT_TOPIC, waypoints_message) self.send_message(WAYPOINTS_TOPIC, waypoints_message)
# Get the first waypoint exported # Get the first waypoint exported
wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0]) wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0])
# Send an update # Send an update
waypoints_message = WAYPOINTS_UPDATED_MESSAGE.copy() waypoints_message = WAYPOINTS_UPDATED_MESSAGE.copy()
self.send_message(WAYPOINT_TOPIC, waypoints_message) self.send_message(WAYPOINTS_TOPIC, waypoints_message)
new_wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0]) new_wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0])
self.assertTrue(wayp == new_wayp) self.assertTrue(wayp == new_wayp)
def test_single_waypoint_import(self):
"""Test single waypoint message."""
waypoint_message = WAYPOINT_MESSAGE.copy()
self.send_message(WAYPOINT_TOPIC, waypoint_message)
wayp = self.hass.states.get(WAYPOINT_ENTITY_NAMES[0])
self.assertTrue(wayp is not None)
def test_not_implemented_message(self):
"""Handle not implemented message type."""
patch_handler = patch('homeassistant.components.device_tracker.'
'owntracks.async_handle_not_impl_msg',
return_value=mock_coro(False))
patch_handler.start()
self.assertFalse(self.send_message(LWT_TOPIC, LWT_MESSAGE))
patch_handler.stop()
def test_unsupported_message(self):
"""Handle not implemented message type."""
patch_handler = patch('homeassistant.components.device_tracker.'
'owntracks.async_handle_unsupported_msg',
return_value=mock_coro(False))
patch_handler.start()
self.assertFalse(self.send_message(BAD_TOPIC, BAD_MESSAGE))
patch_handler.stop()
def generate_ciphers(secret): def generate_ciphers(secret):
"""Generate test ciphers for the DEFAULT_LOCATION_MESSAGE.""" """Generate test ciphers for the DEFAULT_LOCATION_MESSAGE."""
@ -1143,7 +1196,7 @@ def generate_ciphers(secret):
json.dumps(DEFAULT_LOCATION_MESSAGE).encode("utf-8")) json.dumps(DEFAULT_LOCATION_MESSAGE).encode("utf-8"))
) )
).decode("utf-8") ).decode("utf-8")
return (ctxt, mctxt) return ctxt, mctxt
TEST_SECRET_KEY = 's3cretkey' TEST_SECRET_KEY = 's3cretkey'
@ -1172,7 +1225,7 @@ def mock_cipher():
if key != mkey: if key != mkey:
raise ValueError() raise ValueError()
return plaintext return plaintext
return (len(TEST_SECRET_KEY), mock_decrypt) return len(TEST_SECRET_KEY), mock_decrypt
class TestDeviceTrackerOwnTrackConfigs(BaseMQTT): class TestDeviceTrackerOwnTrackConfigs(BaseMQTT):