Reduce overhead to validate mqtt topics (#117891)

* Reduce overhead to validate mqtt topics

valid_topic would iterate all the chars 4x, refactor to only
do it 1x

valid_subscribe_topic would enumerate all the chars when there was
no + in the string

* check if adding a cache helps

* tweak lrus based on testing stats

* note to future maintainers

* note to future maintainers

* keep standard lru_cache size as increasing makes no material difference
This commit is contained in:
J. Nick Koston 2024-05-21 17:11:05 -10:00 committed by GitHub
parent 2f0215b034
commit f42b98336c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -123,7 +123,16 @@ async def async_wait_for_mqtt_client(hass: HomeAssistant) -> bool:
def valid_topic(topic: Any) -> str:
"""Validate that this is a valid topic name/filter."""
"""Validate that this is a valid topic name/filter.
This function is not cached and is not expected to be called
directly outside of this module. It is not marked as protected
only because its tested directly in test_util.py.
If it gets used outside of valid_subscribe_topic and
valid_publish_topic, it may need an lru_cache decorator or
an lru_cache decorator on the function where its used.
"""
validated_topic = cv.string(topic)
try:
raw_validated_topic = validated_topic.encode("utf-8")
@ -135,23 +144,25 @@ def valid_topic(topic: Any) -> str:
raise vol.Invalid(
"MQTT topic name/filter must not be longer than 65535 encoded bytes."
)
if "\0" in validated_topic:
for char in validated_topic:
if char == "\0":
raise vol.Invalid("MQTT topic name/filter must not contain null character.")
if any(char <= "\u001f" for char in validated_topic):
raise vol.Invalid("MQTT topic name/filter must not contain control characters.")
if any("\u007f" <= char <= "\u009f" for char in validated_topic):
raise vol.Invalid("MQTT topic name/filter must not contain control characters.")
if any("\ufdd0" <= char <= "\ufdef" for char in validated_topic):
if char <= "\u001f" or "\u007f" <= char <= "\u009f":
raise vol.Invalid(
"MQTT topic name/filter must not contain control characters."
)
if "\ufdd0" <= char <= "\ufdef" or (ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF):
raise vol.Invalid("MQTT topic name/filter must not contain non-characters.")
if any((ord(char) & 0xFFFF) in (0xFFFE, 0xFFFF) for char in validated_topic):
raise vol.Invalid("MQTT topic name/filter must not contain noncharacters.")
return validated_topic
@lru_cache
def valid_subscribe_topic(topic: Any) -> str:
"""Validate that we can subscribe using this MQTT topic."""
validated_topic = valid_topic(topic)
if "+" in validated_topic:
for i in (i for i, c in enumerate(validated_topic) if c == "+"):
if (i > 0 and validated_topic[i - 1] != "/") or (
i < len(validated_topic) - 1 and validated_topic[i + 1] != "/"
@ -185,6 +196,7 @@ def valid_subscribe_topic_template(value: Any) -> template.Template:
return tpl
@lru_cache
def valid_publish_topic(topic: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
validated_topic = valid_topic(topic)