Improve MQTT topic validation (#14099)

* Improve MQTT topic validation

* Fix test

* Improve length check
This commit is contained in:
Otto Winter 2018-04-27 13:15:45 +02:00 committed by Pascal Vizeli
parent 4b06392442
commit 9d1f9fe204
2 changed files with 93 additions and 15 deletions

View File

@ -90,22 +90,52 @@ ATTR_RETAIN = CONF_RETAIN
MAX_RECONNECT_WAIT = 300 # seconds
def valid_subscribe_topic(value: Any, invalid_chars='\0') -> str:
"""Validate that we can subscribe using this MQTT topic."""
def valid_topic(value: Any) -> str:
"""Validate that this is a valid topic name/filter."""
value = cv.string(value)
if all(c not in value for c in invalid_chars):
return vol.Length(min=1, max=65535)(value)
raise vol.Invalid('Invalid MQTT topic name')
try:
raw_value = value.encode('utf-8')
except UnicodeError:
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
if not raw_value:
raise vol.Invalid("MQTT topic name/filter must not be empty.")
if len(raw_value) > 65535:
raise vol.Invalid("MQTT topic name/filter must not be longer than "
"65535 encoded bytes.")
if '\0' in value:
raise vol.Invalid("MQTT topic name/filter must not contain null "
"character.")
return value
def valid_subscribe_topic(value: Any) -> str:
"""Validate that we can subscribe using this MQTT topic."""
value = valid_topic(value)
for i in (i for i, c in enumerate(value) if c == '+'):
if (i > 0 and value[i - 1] != '/') or \
(i < len(value) - 1 and value[i + 1] != '/'):
raise vol.Invalid("Single-level wildcard must occupy an entire "
"level of the filter")
index = value.find('#')
if index != -1:
if index != len(value) - 1:
# If there are multiple wildcards, this will also trigger
raise vol.Invalid("Multi-level wildcard must be the last "
"character in the topic filter.")
if len(value) > 1 and value[index - 1] != '/':
raise vol.Invalid("Multi-level wildcard must be after a topic "
"level separator.")
return value
def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
return valid_subscribe_topic(value, invalid_chars='#+\0')
def valid_discovery_topic(value: Any) -> str:
"""Validate a discovery topic."""
return valid_subscribe_topic(value, invalid_chars='#+\0/')
value = valid_topic(value)
if '+' in value or '#' in value:
raise vol.Invalid("Wildcards can not be used in topic names")
return value
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
@ -143,8 +173,10 @@ CONFIG_SCHEMA = vol.Schema({
vol.Optional(CONF_WILL_MESSAGE): MQTT_WILL_BIRTH_SCHEMA,
vol.Optional(CONF_BIRTH_MESSAGE): MQTT_WILL_BIRTH_SCHEMA,
vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean,
# discovery_prefix must be a valid publish topic because if no
# state topic is specified, it will be created with the given prefix.
vol.Optional(CONF_DISCOVERY_PREFIX,
default=DEFAULT_DISCOVERY_PREFIX): valid_discovery_topic,
default=DEFAULT_DISCOVERY_PREFIX): valid_publish_topic,
}),
}, extra=vol.ALLOW_EXTRA)

View File

@ -131,10 +131,56 @@ class TestMQTTComponent(unittest.TestCase):
self.hass.data['mqtt'].async_publish.call_args[0][2], 2)
self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3])
def test_invalid_mqtt_topics(self):
"""Test invalid topics."""
def test_validate_topic(self):
"""Test topic name/filter validation."""
# Invalid UTF-8, must not contain U+D800 to U+DFFF.
self.assertRaises(vol.Invalid, mqtt.valid_topic, '\ud800')
self.assertRaises(vol.Invalid, mqtt.valid_topic, '\udfff')
# Topic MUST NOT be empty
self.assertRaises(vol.Invalid, mqtt.valid_topic, '')
# Topic MUST NOT be longer than 65535 encoded bytes.
self.assertRaises(vol.Invalid, mqtt.valid_topic, 'ü' * 32768)
# UTF-8 MUST NOT include null character
self.assertRaises(vol.Invalid, mqtt.valid_topic, 'bad\0one')
# Topics "SHOULD NOT" include these special characters
# (not MUST NOT, RFC2119). The receiver MAY close the connection.
mqtt.valid_topic('\u0001')
mqtt.valid_topic('\u001F')
mqtt.valid_topic('\u009F')
mqtt.valid_topic('\u009F')
mqtt.valid_topic('\uffff')
def test_validate_subscribe_topic(self):
"""Test invalid subscribe topics."""
mqtt.valid_subscribe_topic('#')
mqtt.valid_subscribe_topic('sport/#')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/#/')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'foo/bar#')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'foo/#/bar')
mqtt.valid_subscribe_topic('+')
mqtt.valid_subscribe_topic('+/tennis/#')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport+')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport+/')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/+1')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/+#')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad+topic')
mqtt.valid_subscribe_topic('sport/+/player1')
mqtt.valid_subscribe_topic('/finance')
mqtt.valid_subscribe_topic('+/+')
mqtt.valid_subscribe_topic('$SYS/#')
def test_validate_publish_topic(self):
"""Test invalid publish topics."""
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'pub+')
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'pub/+')
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, '1#')
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'bad+topic')
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad\0one')
mqtt.valid_publish_topic('//')
# Topic names beginning with $ SHOULD NOT be used, but can
mqtt.valid_publish_topic('$SYS/')
# pylint: disable=invalid-name