mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Improve MQTT topic validation (#14099)
* Improve MQTT topic validation * Fix test * Improve length check
This commit is contained in:
parent
4b06392442
commit
9d1f9fe204
@ -90,22 +90,52 @@ ATTR_RETAIN = CONF_RETAIN
|
||||
MAX_RECONNECT_WAIT = 300 # seconds
|
||||
|
||||
|
||||
def valid_subscribe_topic(value: Any, invalid_chars='\0') -> str:
|
||||
"""Validate that we can subscribe using this MQTT topic."""
|
||||
def valid_topic(value: Any) -> str:
|
||||
"""Validate that this is a valid topic name/filter."""
|
||||
value = cv.string(value)
|
||||
if all(c not in value for c in invalid_chars):
|
||||
return vol.Length(min=1, max=65535)(value)
|
||||
raise vol.Invalid('Invalid MQTT topic name')
|
||||
try:
|
||||
raw_value = value.encode('utf-8')
|
||||
except UnicodeError:
|
||||
raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.")
|
||||
if not raw_value:
|
||||
raise vol.Invalid("MQTT topic name/filter must not be empty.")
|
||||
if len(raw_value) > 65535:
|
||||
raise vol.Invalid("MQTT topic name/filter must not be longer than "
|
||||
"65535 encoded bytes.")
|
||||
if '\0' in value:
|
||||
raise vol.Invalid("MQTT topic name/filter must not contain null "
|
||||
"character.")
|
||||
return value
|
||||
|
||||
|
||||
def valid_subscribe_topic(value: Any) -> str:
|
||||
"""Validate that we can subscribe using this MQTT topic."""
|
||||
value = valid_topic(value)
|
||||
for i in (i for i, c in enumerate(value) if c == '+'):
|
||||
if (i > 0 and value[i - 1] != '/') or \
|
||||
(i < len(value) - 1 and value[i + 1] != '/'):
|
||||
raise vol.Invalid("Single-level wildcard must occupy an entire "
|
||||
"level of the filter")
|
||||
|
||||
index = value.find('#')
|
||||
if index != -1:
|
||||
if index != len(value) - 1:
|
||||
# If there are multiple wildcards, this will also trigger
|
||||
raise vol.Invalid("Multi-level wildcard must be the last "
|
||||
"character in the topic filter.")
|
||||
if len(value) > 1 and value[index - 1] != '/':
|
||||
raise vol.Invalid("Multi-level wildcard must be after a topic "
|
||||
"level separator.")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def valid_publish_topic(value: Any) -> str:
|
||||
"""Validate that we can publish using this MQTT topic."""
|
||||
return valid_subscribe_topic(value, invalid_chars='#+\0')
|
||||
|
||||
|
||||
def valid_discovery_topic(value: Any) -> str:
|
||||
"""Validate a discovery topic."""
|
||||
return valid_subscribe_topic(value, invalid_chars='#+\0/')
|
||||
value = valid_topic(value)
|
||||
if '+' in value or '#' in value:
|
||||
raise vol.Invalid("Wildcards can not be used in topic names")
|
||||
return value
|
||||
|
||||
|
||||
_VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2]))
|
||||
@ -143,8 +173,10 @@ CONFIG_SCHEMA = vol.Schema({
|
||||
vol.Optional(CONF_WILL_MESSAGE): MQTT_WILL_BIRTH_SCHEMA,
|
||||
vol.Optional(CONF_BIRTH_MESSAGE): MQTT_WILL_BIRTH_SCHEMA,
|
||||
vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean,
|
||||
# discovery_prefix must be a valid publish topic because if no
|
||||
# state topic is specified, it will be created with the given prefix.
|
||||
vol.Optional(CONF_DISCOVERY_PREFIX,
|
||||
default=DEFAULT_DISCOVERY_PREFIX): valid_discovery_topic,
|
||||
default=DEFAULT_DISCOVERY_PREFIX): valid_publish_topic,
|
||||
}),
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
|
@ -131,10 +131,56 @@ class TestMQTTComponent(unittest.TestCase):
|
||||
self.hass.data['mqtt'].async_publish.call_args[0][2], 2)
|
||||
self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3])
|
||||
|
||||
def test_invalid_mqtt_topics(self):
|
||||
"""Test invalid topics."""
|
||||
def test_validate_topic(self):
|
||||
"""Test topic name/filter validation."""
|
||||
# Invalid UTF-8, must not contain U+D800 to U+DFFF.
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_topic, '\ud800')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_topic, '\udfff')
|
||||
# Topic MUST NOT be empty
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_topic, '')
|
||||
# Topic MUST NOT be longer than 65535 encoded bytes.
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_topic, 'ü' * 32768)
|
||||
# UTF-8 MUST NOT include null character
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_topic, 'bad\0one')
|
||||
|
||||
# Topics "SHOULD NOT" include these special characters
|
||||
# (not MUST NOT, RFC2119). The receiver MAY close the connection.
|
||||
mqtt.valid_topic('\u0001')
|
||||
mqtt.valid_topic('\u001F')
|
||||
mqtt.valid_topic('\u009F')
|
||||
mqtt.valid_topic('\u009F')
|
||||
mqtt.valid_topic('\uffff')
|
||||
|
||||
def test_validate_subscribe_topic(self):
|
||||
"""Test invalid subscribe topics."""
|
||||
mqtt.valid_subscribe_topic('#')
|
||||
mqtt.valid_subscribe_topic('sport/#')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/#/')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'foo/bar#')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'foo/#/bar')
|
||||
|
||||
mqtt.valid_subscribe_topic('+')
|
||||
mqtt.valid_subscribe_topic('+/tennis/#')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport+')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport+/')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/+1')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/+#')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad+topic')
|
||||
mqtt.valid_subscribe_topic('sport/+/player1')
|
||||
mqtt.valid_subscribe_topic('/finance')
|
||||
mqtt.valid_subscribe_topic('+/+')
|
||||
mqtt.valid_subscribe_topic('$SYS/#')
|
||||
|
||||
def test_validate_publish_topic(self):
|
||||
"""Test invalid publish topics."""
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'pub+')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'pub/+')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, '1#')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'bad+topic')
|
||||
self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad\0one')
|
||||
mqtt.valid_publish_topic('//')
|
||||
|
||||
# Topic names beginning with $ SHOULD NOT be used, but can
|
||||
mqtt.valid_publish_topic('$SYS/')
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
Loading…
x
Reference in New Issue
Block a user