Add TYPE_CHECKING condition on type assertions for mqtt (#100107)

Add TYPE_CHECKING condition on type assertions
This commit is contained in:
Jan Bouwhuis 2023-09-11 10:58:33 +02:00 committed by GitHub
parent eb0099dee8
commit 20d0ebe3fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 39 additions and 26 deletions

View File

@ -5,7 +5,7 @@ import asyncio
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime from datetime import datetime
import logging import logging
from typing import Any, TypeVar, cast from typing import TYPE_CHECKING, Any, TypeVar, cast
import jinja2 import jinja2
import voluptuous as vol import voluptuous as vol
@ -313,6 +313,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
return return
if TYPE_CHECKING:
assert msg_topic is not None assert msg_topic is not None
await mqtt_data.client.async_publish(msg_topic, payload, qos, retain) await mqtt_data.client.async_publish(msg_topic, payload, qos, retain)

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from base64 import b64decode from base64 import b64decode
import functools import functools
import logging import logging
from typing import TYPE_CHECKING
import voluptuous as vol import voluptuous as vol
@ -112,6 +113,7 @@ class MqttCamera(MqttEntity, Camera):
if CONF_IMAGE_ENCODING in self._config: if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload) self._last_image = b64decode(msg.payload)
else: else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes) assert isinstance(msg.payload, bytes)
self._last_image = msg.payload self._last_image = msg.payload

View File

@ -6,7 +6,7 @@ from collections.abc import Callable
import queue import queue
from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError from ssl import PROTOCOL_TLS_CLIENT, SSLContext, SSLError
from types import MappingProxyType from types import MappingProxyType
from typing import Any from typing import TYPE_CHECKING, Any
from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509 import load_pem_x509_certificate from cryptography.x509 import load_pem_x509_certificate
@ -224,6 +224,7 @@ class FlowHandler(ConfigFlow, domain=DOMAIN):
) -> FlowResult: ) -> FlowResult:
"""Confirm a Hass.io discovery.""" """Confirm a Hass.io discovery."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
if TYPE_CHECKING:
assert self._hassio_discovery assert self._hassio_discovery
if user_input is not None: if user_input is not None:
@ -312,6 +313,7 @@ class MQTTOptionsFlowHandler(OptionsFlow):
def _birth_will(birt_or_will: str) -> dict[str, Any]: def _birth_will(birt_or_will: str) -> dict[str, Any]:
"""Return the user input for birth or will.""" """Return the user input for birth or will."""
if TYPE_CHECKING:
assert user_input assert user_input
return { return {
ATTR_TOPIC: user_input[f"{birt_or_will}_topic"], ATTR_TOPIC: user_input[f"{birt_or_will}_topic"],

View File

@ -5,7 +5,7 @@ from collections import deque
from collections.abc import Callable from collections.abc import Callable
import datetime as dt import datetime as dt
from functools import wraps from functools import wraps
from typing import Any from typing import TYPE_CHECKING, Any
import attr import attr
@ -128,11 +128,11 @@ def update_entity_discovery_data(
hass: HomeAssistant, discovery_payload: DiscoveryInfoType, entity_id: str hass: HomeAssistant, discovery_payload: DiscoveryInfoType, entity_id: str
) -> None: ) -> None:
"""Update discovery data.""" """Update discovery data."""
assert ( discovery_data = get_mqtt_data(hass).debug_info_entities[entity_id][
discovery_data := get_mqtt_data(hass).debug_info_entities[entity_id][
"discovery_data" "discovery_data"
] ]
) is not None if TYPE_CHECKING:
assert discovery_data is not None
discovery_data[ATTR_DISCOVERY_PAYLOAD] = discovery_payload discovery_data[ATTR_DISCOVERY_PAYLOAD] = discovery_payload

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
import attr import attr
import voluptuous as vol import voluptuous as vol
@ -269,6 +269,7 @@ async def async_setup_trigger(
config = TRIGGER_DISCOVERY_SCHEMA(config) config = TRIGGER_DISCOVERY_SCHEMA(config)
device_id = update_device(hass, config_entry, config) device_id = update_device(hass, config_entry, config)
if TYPE_CHECKING:
assert isinstance(device_id, str) assert isinstance(device_id, str)
mqtt_device_trigger = MqttDeviceTrigger( mqtt_device_trigger = MqttDeviceTrigger(
hass, config, device_id, discovery_data, config_entry hass, config, device_id, discovery_data, config_entry
@ -286,6 +287,7 @@ async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None
if device_trigger: if device_trigger:
device_trigger.detach_trigger() device_trigger.detach_trigger()
discovery_data = device_trigger.discovery_data discovery_data = device_trigger.discovery_data
if TYPE_CHECKING:
assert discovery_data is not None assert discovery_data is not None
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash = discovery_data[ATTR_DISCOVERY_HASH]
debug_info.remove_trigger_discovery_data(hass, discovery_hash) debug_info.remove_trigger_discovery_data(hass, discovery_hash)

View File

@ -1,7 +1,7 @@
"""Diagnostics support for MQTT.""" """Diagnostics support for MQTT."""
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import TYPE_CHECKING, Any
from homeassistant.components import device_tracker from homeassistant.components import device_tracker
from homeassistant.components.diagnostics import async_redact_data from homeassistant.components.diagnostics import async_redact_data
@ -45,6 +45,7 @@ def _async_get_diagnostics(
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return diagnostics for a config entry.""" """Return diagnostics for a config entry."""
mqtt_instance = get_mqtt_data(hass).client mqtt_instance = get_mqtt_data(hass).client
if TYPE_CHECKING:
assert mqtt_instance is not None assert mqtt_instance is not None
redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG) redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG)

View File

@ -7,7 +7,7 @@ import functools
import logging import logging
import re import re
import time import time
from typing import Any from typing import TYPE_CHECKING, Any
import voluptuous as vol import voluptuous as vol
@ -343,6 +343,7 @@ async def async_start( # noqa: C901
integration: str, msg: ReceiveMessage integration: str, msg: ReceiveMessage
) -> None: ) -> None:
"""Process the received message.""" """Process the received message."""
if TYPE_CHECKING:
assert mqtt_data.data_config_flow_lock assert mqtt_data.data_config_flow_lock
key = f"{integration}_{msg.subscribed_topic}" key = f"{integration}_{msg.subscribed_topic}"

View File

@ -6,7 +6,7 @@ import binascii
from collections.abc import Callable from collections.abc import Callable
import functools import functools
import logging import logging
from typing import Any from typing import TYPE_CHECKING, Any
import httpx import httpx
import voluptuous as vol import voluptuous as vol
@ -172,6 +172,7 @@ class MqttImage(MqttEntity, ImageEntity):
if CONF_IMAGE_ENCODING in self._config: if CONF_IMAGE_ENCODING in self._config:
self._last_image = b64decode(msg.payload) self._last_image = b64decode(msg.payload)
else: else:
if TYPE_CHECKING:
assert isinstance(msg.payload, bytes) assert isinstance(msg.payload, bytes)
self._last_image = msg.payload self._last_image = msg.payload
except (binascii.Error, ValueError, AssertionError) as err: except (binascii.Error, ValueError, AssertionError) as err:

View File

@ -6,7 +6,7 @@ import asyncio
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from functools import partial from functools import partial
import logging import logging
from typing import Any, Protocol, cast, final from typing import TYPE_CHECKING, Any, Protocol, cast, final
import voluptuous as vol import voluptuous as vol
@ -850,6 +850,7 @@ class MqttDiscoveryUpdate(Entity):
discovery_hash, discovery_hash,
payload, payload,
) )
if TYPE_CHECKING:
assert self._discovery_data assert self._discovery_data
old_payload: DiscoveryInfoType old_payload: DiscoveryInfoType
old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD] old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
@ -877,6 +878,7 @@ class MqttDiscoveryUpdate(Entity):
send_discovery_done(self.hass, self._discovery_data) send_discovery_done(self.hass, self._discovery_data)
if discovery_hash: if discovery_hash:
if TYPE_CHECKING:
assert self._discovery_data is not None assert self._discovery_data is not None
debug_info.add_entity_discovery_data( debug_info.add_entity_discovery_data(
self.hass, self._discovery_data, self.entity_id self.hass, self._discovery_data, self.entity_id

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Coroutine from collections.abc import Callable, Coroutine
from typing import Any from typing import TYPE_CHECKING, Any
import attr import attr
@ -31,6 +31,7 @@ class EntitySubscription:
) -> None: ) -> None:
"""Re-subscribe to the new topic if necessary.""" """Re-subscribe to the new topic if necessary."""
if not self._should_resubscribe(other): if not self._should_resubscribe(other):
if TYPE_CHECKING:
assert other assert other
self.unsubscribe_callback = other.unsubscribe_callback self.unsubscribe_callback = other.unsubscribe_callback
return return