diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index b81a4fc16a7..55d99a0817e 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -90,22 +90,52 @@ ATTR_RETAIN = CONF_RETAIN MAX_RECONNECT_WAIT = 300 # seconds -def valid_subscribe_topic(value: Any, invalid_chars='\0') -> str: - """Validate that we can subscribe using this MQTT topic.""" +def valid_topic(value: Any) -> str: + """Validate that this is a valid topic name/filter.""" value = cv.string(value) - if all(c not in value for c in invalid_chars): - return vol.Length(min=1, max=65535)(value) - raise vol.Invalid('Invalid MQTT topic name') + try: + raw_value = value.encode('utf-8') + except UnicodeError: + raise vol.Invalid("MQTT topic name/filter must be valid UTF-8 string.") + if not raw_value: + raise vol.Invalid("MQTT topic name/filter must not be empty.") + if len(raw_value) > 65535: + raise vol.Invalid("MQTT topic name/filter must not be longer than " + "65535 encoded bytes.") + if '\0' in value: + raise vol.Invalid("MQTT topic name/filter must not contain null " + "character.") + return value + + +def valid_subscribe_topic(value: Any) -> str: + """Validate that we can subscribe using this MQTT topic.""" + value = valid_topic(value) + for i in (i for i, c in enumerate(value) if c == '+'): + if (i > 0 and value[i - 1] != '/') or \ + (i < len(value) - 1 and value[i + 1] != '/'): + raise vol.Invalid("Single-level wildcard must occupy an entire " + "level of the filter") + + index = value.find('#') + if index != -1: + if index != len(value) - 1: + # If there are multiple wildcards, this will also trigger + raise vol.Invalid("Multi-level wildcard must be the last " + "character in the topic filter.") + if len(value) > 1 and value[index - 1] != '/': + raise vol.Invalid("Multi-level wildcard must be after a topic " + "level separator.") + + return value def valid_publish_topic(value: Any) -> str: """Validate that we can publish using this MQTT topic.""" - return valid_subscribe_topic(value, invalid_chars='#+\0') - - -def valid_discovery_topic(value: Any) -> str: - """Validate a discovery topic.""" - return valid_subscribe_topic(value, invalid_chars='#+\0/') + value = valid_topic(value) + if '+' in value or '#' in value: + raise vol.Invalid("Wildcards can not be used in topic names") + return value _VALID_QOS_SCHEMA = vol.All(vol.Coerce(int), vol.In([0, 1, 2])) @@ -143,8 +173,10 @@ CONFIG_SCHEMA = vol.Schema({ vol.Optional(CONF_WILL_MESSAGE): MQTT_WILL_BIRTH_SCHEMA, vol.Optional(CONF_BIRTH_MESSAGE): MQTT_WILL_BIRTH_SCHEMA, vol.Optional(CONF_DISCOVERY, default=DEFAULT_DISCOVERY): cv.boolean, + # discovery_prefix must be a valid publish topic because if no + # state topic is specified, it will be created with the given prefix. vol.Optional(CONF_DISCOVERY_PREFIX, - default=DEFAULT_DISCOVERY_PREFIX): valid_discovery_topic, + default=DEFAULT_DISCOVERY_PREFIX): valid_publish_topic, }), }, extra=vol.ALLOW_EXTRA) diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index b25479bb75a..05c5de71b8c 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -131,10 +131,56 @@ class TestMQTTComponent(unittest.TestCase): self.hass.data['mqtt'].async_publish.call_args[0][2], 2) self.assertFalse(self.hass.data['mqtt'].async_publish.call_args[0][3]) - def test_invalid_mqtt_topics(self): - """Test invalid topics.""" + def test_validate_topic(self): + """Test topic name/filter validation.""" + # Invalid UTF-8, must not contain U+D800 to U+DFFF. + self.assertRaises(vol.Invalid, mqtt.valid_topic, '\ud800') + self.assertRaises(vol.Invalid, mqtt.valid_topic, '\udfff') + # Topic MUST NOT be empty + self.assertRaises(vol.Invalid, mqtt.valid_topic, '') + # Topic MUST NOT be longer than 65535 encoded bytes. + self.assertRaises(vol.Invalid, mqtt.valid_topic, 'ΓΌ' * 32768) + # UTF-8 MUST NOT include null character + self.assertRaises(vol.Invalid, mqtt.valid_topic, 'bad\0one') + + # Topics "SHOULD NOT" include these special characters + # (not MUST NOT, RFC2119). The receiver MAY close the connection. + mqtt.valid_topic('\u0001') + mqtt.valid_topic('\u001F') + mqtt.valid_topic('\u009F') + mqtt.valid_topic('\u009F') + mqtt.valid_topic('\uffff') + + def test_validate_subscribe_topic(self): + """Test invalid subscribe topics.""" + mqtt.valid_subscribe_topic('#') + mqtt.valid_subscribe_topic('sport/#') + self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/#/') + self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'foo/bar#') + self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'foo/#/bar') + + mqtt.valid_subscribe_topic('+') + mqtt.valid_subscribe_topic('+/tennis/#') + self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport+') + self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport+/') + self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/+1') + self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'sport/+#') + self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad+topic') + mqtt.valid_subscribe_topic('sport/+/player1') + mqtt.valid_subscribe_topic('/finance') + mqtt.valid_subscribe_topic('+/+') + mqtt.valid_subscribe_topic('$SYS/#') + + def test_validate_publish_topic(self): + """Test invalid publish topics.""" + self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'pub+') + self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'pub/+') + self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, '1#') self.assertRaises(vol.Invalid, mqtt.valid_publish_topic, 'bad+topic') - self.assertRaises(vol.Invalid, mqtt.valid_subscribe_topic, 'bad\0one') + mqtt.valid_publish_topic('//') + + # Topic names beginning with $ SHOULD NOT be used, but can + mqtt.valid_publish_topic('$SYS/') # pylint: disable=invalid-name