Use defaultdict instead of setdefault in mqtt client (#118070)

This commit is contained in:
J. Nick Koston 2024-05-24 14:34:06 -10:00 committed by GitHub
parent 3031e4733b
commit 90d10dd773
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections import defaultdict
from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
import contextlib import contextlib
from dataclasses import dataclass from dataclasses import dataclass
@ -428,13 +429,15 @@ class MQTT:
self.config_entry = config_entry self.config_entry = config_entry
self.conf = conf self.conf = conf
self._simple_subscriptions: dict[str, list[Subscription]] = {} self._simple_subscriptions: defaultdict[str, list[Subscription]] = defaultdict(
list
)
self._wildcard_subscriptions: list[Subscription] = [] self._wildcard_subscriptions: list[Subscription] = []
# _retained_topics prevents a Subscription from receiving a # _retained_topics prevents a Subscription from receiving a
# retained message more than once per topic. This prevents flooding # retained message more than once per topic. This prevents flooding
# already active subscribers when new subscribers subscribe to a topic # already active subscribers when new subscribers subscribe to a topic
# which has subscribed messages. # which has subscribed messages.
self._retained_topics: dict[Subscription, set[str]] = {} self._retained_topics: defaultdict[Subscription, set[str]] = defaultdict(set)
self.connected = False self.connected = False
self._ha_started = asyncio.Event() self._ha_started = asyncio.Event()
self._cleanup_on_unload: list[Callable[[], None]] = [] self._cleanup_on_unload: list[Callable[[], None]] = []
@ -786,9 +789,7 @@ class MQTT:
The caller is responsible clearing the cache of _matching_subscriptions. The caller is responsible clearing the cache of _matching_subscriptions.
""" """
if subscription.is_simple_match: if subscription.is_simple_match:
self._simple_subscriptions.setdefault(subscription.topic, []).append( self._simple_subscriptions[subscription.topic].append(subscription)
subscription
)
else: else:
self._wildcard_subscriptions.append(subscription) self._wildcard_subscriptions.append(subscription)
@ -1108,7 +1109,7 @@ class MQTT:
for subscription in subscriptions: for subscription in subscriptions:
if msg.retain: if msg.retain:
retained_topics = self._retained_topics.setdefault(subscription, set()) retained_topics = self._retained_topics[subscription]
# Skip if the subscription already received a retained message # Skip if the subscription already received a retained message
if topic in retained_topics: if topic in retained_topics:
continue continue