Add types package for paho-mqtt (#83599)

This commit is contained in:
Marc Mueller 2022-12-09 15:27:46 +01:00 committed by GitHub
parent a77d9af989
commit 9a97784168
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 16 deletions

View File

@ -317,8 +317,8 @@ class MqttClientSetup:
client_cert = get_file_path(CONF_CLIENT_CERT, config.get(CONF_CLIENT_CERT)) client_cert = get_file_path(CONF_CLIENT_CERT, config.get(CONF_CLIENT_CERT))
tls_insecure = config.get(CONF_TLS_INSECURE) tls_insecure = config.get(CONF_TLS_INSECURE)
if transport == TRANSPORT_WEBSOCKETS: if transport == TRANSPORT_WEBSOCKETS:
ws_path = config.get(CONF_WS_PATH) ws_path: str = config[CONF_WS_PATH]
ws_headers = config.get(CONF_WS_HEADERS) ws_headers: dict[str, str] = config[CONF_WS_HEADERS]
self._client.ws_set_options(ws_path, ws_headers) self._client.ws_set_options(ws_path, ws_headers)
if certificate is not None: if certificate is not None:
self._client.tls_set( self._client.tls_set(
@ -340,6 +340,8 @@ class MqttClientSetup:
class MQTT: class MQTT:
"""Home Assistant MQTT client.""" """Home Assistant MQTT client."""
_mqttc: mqtt.Client
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
@ -347,10 +349,6 @@ class MQTT:
conf: ConfigType, conf: ConfigType,
) -> None: ) -> None:
"""Initialize Home Assistant MQTT client.""" """Initialize Home Assistant MQTT client."""
# We don't import on the top because some integrations
# should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel
self._mqtt_data = get_mqtt_data(hass) self._mqtt_data = get_mqtt_data(hass)
self.hass = hass self.hass = hass
@ -360,7 +358,6 @@ class MQTT:
self.connected = False self.connected = False
self._ha_started = asyncio.Event() self._ha_started = asyncio.Event()
self._last_subscribe = time.time() self._last_subscribe = time.time()
self._mqttc: mqtt.Client = None
self._cleanup_on_unload: list[Callable[[], None]] = [] self._cleanup_on_unload: list[Callable[[], None]] = []
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
@ -526,12 +523,9 @@ class MQTT:
""" """
def _client_unsubscribe(topic: str) -> int: def _client_unsubscribe(topic: str) -> int:
result: int | None = None
mid: int | None = None
result, mid = self._mqttc.unsubscribe(topic) result, mid = self._mqttc.unsubscribe(topic)
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
_raise_on_error(result) _raise_on_error(result)
assert mid
return mid return mid
if any(other.topic == topic for other in self.subscriptions): if any(other.topic == topic for other in self.subscriptions):
@ -563,8 +557,8 @@ class MQTT:
_process_client_subscriptions _process_client_subscriptions
) )
tasks = [] tasks: list[Coroutine[Any, Any, None]] = []
errors = [] errors: list[int] = []
for result, mid in results: for result, mid in results:
if result == 0: if result == 0:
tasks.append(self._wait_for_mid(mid)) tasks.append(self._wait_for_mid(mid))
@ -777,7 +771,7 @@ class MQTT:
) )
def _raise_on_errors(result_codes: Iterable[int | None]) -> None: def _raise_on_errors(result_codes: Iterable[int]) -> None:
"""Raise error if error result.""" """Raise error if error result."""
# pylint: disable-next=import-outside-toplevel # pylint: disable-next=import-outside-toplevel
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
@ -790,7 +784,7 @@ def _raise_on_errors(result_codes: Iterable[int | None]) -> None:
raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}") raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}")
def _raise_on_error(result_code: int | None) -> None: def _raise_on_error(result_code: int) -> None:
"""Raise error if error result.""" """Raise error if error result."""
_raise_on_errors((result_code,)) _raise_on_errors((result_code,))

View File

@ -177,7 +177,7 @@ async def async_create_certificate_temp_files(
await hass.async_add_executor_job(_create_temp_dir_and_files) await hass.async_add_executor_job(_create_temp_dir_and_files)
def get_file_path(option: str, default: str | None = None) -> Path | str | None: def get_file_path(option: str, default: str | None = None) -> str | None:
"""Get file path of a certificate file.""" """Get file path of a certificate file."""
temp_dir = Path(tempfile.gettempdir()) / TEMP_DIR_NAME temp_dir = Path(tempfile.gettempdir()) / TEMP_DIR_NAME
if not temp_dir.exists(): if not temp_dir.exists():
@ -187,7 +187,7 @@ def get_file_path(option: str, default: str | None = None) -> Path | str | None:
if not file_path.exists(): if not file_path.exists():
return default return default
return temp_dir / option return str(temp_dir / option)
def migrate_certificate_file_to_content(file_name_or_auto: str) -> str | None: def migrate_certificate_file_to_content(file_name_or_auto: str) -> str | None:

View File

@ -39,6 +39,7 @@ types-chardet==0.1.5
types-decorator==0.1.7 types-decorator==0.1.7
types-enum34==0.1.8 types-enum34==0.1.8
types-ipaddress==0.1.5 types-ipaddress==0.1.5
types-paho-mqtt==1.6.0.1
types-pkg-resources==0.1.3 types-pkg-resources==0.1.3
types-python-dateutil==2.8.19.2 types-python-dateutil==2.8.19.2
types-python-slugify==0.1.2 types-python-slugify==0.1.2