Avoid catch_log_exception overhead in MQTT for simple callbacks (#118036)

This commit is contained in:
J. Nick Koston 2024-05-24 14:32:32 -10:00 committed by GitHub
parent 65a702761b
commit c7a1c59215
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,7 +27,15 @@ from homeassistant.const import (
CONF_USERNAME,
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import CALLBACK_TYPE, Event, HassJob, HomeAssistant, callback
from homeassistant.core import (
CALLBACK_TYPE,
Event,
HassJob,
HassJobType,
HomeAssistant,
callback,
get_hassjob_callable_job_type,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.start import async_at_started
@ -35,7 +43,7 @@ from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import create_eager_task
from homeassistant.util.collection import chunked_or_all
from homeassistant.util.logging import catch_log_exception
from homeassistant.util.logging import catch_log_exception, log_exception
from .const import (
CONF_BIRTH_MESSAGE,
@ -202,13 +210,7 @@ async def async_subscribe(
) from exc
return await mqtt_data.client.async_subscribe(
topic,
catch_log_exception(
msg_callback,
lambda msg: (
f"Exception in {msg_callback.__name__} when handling msg on "
f"'{msg.topic}': '{msg.payload}'"
),
),
msg_callback,
qos,
encoding,
)
@ -828,6 +830,17 @@ class MQTT:
return
self._subscribe_debouncer.async_schedule()
def _exception_message(
self,
msg_callback: AsyncMessageCallbackType | MessageCallbackType,
msg: ReceiveMessage,
) -> str:
"""Return a string with the exception message."""
return (
f"Exception in {msg_callback.__name__} when handling msg on "
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
)
async def async_subscribe(
self,
topic: str,
@ -842,12 +855,21 @@ class MQTT:
if not isinstance(topic, str):
raise HomeAssistantError("Topic needs to be a string!")
job_type = get_hassjob_callable_job_type(msg_callback)
if job_type is not HassJobType.Callback:
# Only wrap the callback with catch_log_exception
# if it is not a simple callback since we catch
# exceptions for simple callbacks inline for
# performance reasons.
msg_callback = catch_log_exception(
msg_callback, partial(self._exception_message, msg_callback)
)
job = HassJob(msg_callback, job_type=job_type)
is_simple_match = not ("+" in topic or "#" in topic)
matcher = None if is_simple_match else _matcher_for_topic(topic)
subscription = Subscription(
topic, is_simple_match, matcher, HassJob(msg_callback), qos, encoding
)
subscription = Subscription(topic, is_simple_match, matcher, job, qos, encoding)
self._async_track_subscription(subscription)
self._matching_subscriptions.cache_clear()
@ -1126,7 +1148,18 @@ class MQTT:
msg_cache_by_subscription_topic[subscription_topic] = receive_msg
else:
receive_msg = msg_cache_by_subscription_topic[subscription_topic]
self.hass.async_run_hass_job(subscription.job, receive_msg)
job = subscription.job
if job.job_type is HassJobType.Callback:
# We do not wrap Callback jobs in catch_log_exception since
# its expensive and we have to do it 2x for every entity
try:
job.target(receive_msg)
except Exception: # noqa: BLE001
log_exception(
partial(self._exception_message, job.target, receive_msg)
)
else:
self.hass.async_run_hass_job(job, receive_msg)
self._mqtt_data.state_write_requests.process_write_state_requests(msg)
@callback