From f42b98336c0878cf62f72e352020641f96f19cd2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 21 May 2024 17:11:05 -1000 Subject: [PATCH] Reduce overhead to validate mqtt topics (#117891) * Reduce overhead to validate mqtt topics valid_topic would iterate all the chars 4x, refactor to only do it 1x valid_subscribe_topic would enumerate all the chars when there was no + in the string * check if adding a cache helps * tweak lrus based on testing stats * note to future maintainers * note to future maintainers * keep standard lru_cache size as increasing makes no material difference --- homeassistant/components/mqtt/util.py | 48 +++++++++++++++++---------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/homeassistant/components/mqtt/util.py b/homeassistant/components/mqtt/util.py index 6f9fb8316bb..07275f8d215 100644 --- a/homeassistant/components/mqtt/util.py +++ b/homeassistant/components/mqtt/util.py @@ -123,7 +123,16 @@ async def async_wait_for_mqtt_client(hass: HomeAssistant) -> bool: def valid_topic(topic: Any) -> str: - """Validate that this is a valid topic name/filter.""" + """Validate that this is a valid topic name/filter. + + This function is not cached and is not expected to be called + directly outside of this module. It is not marked as protected + only because it's tested directly in test_util.py. + + If it gets used outside of valid_subscribe_topic and + valid_publish_topic, it may need an lru_cache decorator or + an lru_cache decorator on the function where it's used. + """ validated_topic = cv.string(topic) try: raw_validated_topic = validated_topic.encode("utf-8") @@ -135,30 +144,32 @@ def valid_topic(topic: Any) -> str: raise vol.Invalid( "MQTT topic name/filter must not be longer than 65535 encoded bytes." 
) - if "\0" in validated_topic: - raise vol.Invalid("MQTT topic name/filter must not contain null character.") - if any(char <= "\u001f" for char in validated_topic): - raise vol.Invalid("MQTT topic name/filter must not contain control characters.") - if any("\u007f" <= char <= "\u009f" for char in validated_topic): - raise vol.Invalid("MQTT topic name/filter must not contain control characters.") - if any("\ufdd0" <= char <= "\ufdef" for char in validated_topic): - raise vol.Invalid("MQTT topic name/filter must not contain non-characters.") - if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in validated_topic): - raise vol.Invalid("MQTT topic name/filter must not contain noncharacters.") + + for char in validated_topic: + if char == "\0": + raise vol.Invalid("MQTT topic name/filter must not contain null character.") + if char <= "\u001f" or "\u007f" <= char <= "\u009f": + raise vol.Invalid( + "MQTT topic name/filter must not contain control characters." + ) + if "\ufdd0" <= char <= "\ufdef" or (ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF): + raise vol.Invalid("MQTT topic name/filter must not contain non-characters.") return validated_topic +@lru_cache def valid_subscribe_topic(topic: Any) -> str: """Validate that we can subscribe using this MQTT topic.""" validated_topic = valid_topic(topic) - for i in (i for i, c in enumerate(validated_topic) if c == "+"): - if (i > 0 and validated_topic[i - 1] != "/") or ( - i < len(validated_topic) - 1 and validated_topic[i + 1] != "/" - ): - raise vol.Invalid( - "Single-level wildcard must occupy an entire level of the filter" - ) + if "+" in validated_topic: + for i in (i for i, c in enumerate(validated_topic) if c == "+"): + if (i > 0 and validated_topic[i - 1] != "/") or ( + i < len(validated_topic) - 1 and validated_topic[i + 1] != "/" + ): + raise vol.Invalid( + "Single-level wildcard must occupy an entire level of the filter" + ) index = validated_topic.find("#") if index != -1: @@ -185,6 +196,7 @@ def 
valid_subscribe_topic_template(value: Any) -> template.Template: return tpl +@lru_cache def valid_publish_topic(topic: Any) -> str: """Validate that we can publish using this MQTT topic.""" validated_topic = valid_topic(topic)