Enable strict typing for aruba (#106839)

This commit is contained in:
Marc Mueller 2024-01-01 20:33:15 +01:00 committed by GitHub
parent 800351287b
commit f67bae2cde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 18 deletions

View File

@ -76,6 +76,7 @@ homeassistant.components.aprs.*
homeassistant.components.aqualogic.* homeassistant.components.aqualogic.*
homeassistant.components.aquostv.* homeassistant.components.aquostv.*
homeassistant.components.aranet.* homeassistant.components.aranet.*
homeassistant.components.aruba.*
homeassistant.components.aseko_pool_live.* homeassistant.components.aseko_pool_live.*
homeassistant.components.assist_pipeline.* homeassistant.components.assist_pipeline.*
homeassistant.components.asuswrt.* homeassistant.components.asuswrt.*

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging import logging
import re import re
from typing import Any
import pexpect import pexpect
import voluptuous as vol import voluptuous as vol
@ -44,33 +45,33 @@ def get_scanner(hass: HomeAssistant, config: ConfigType) -> ArubaDeviceScanner |
class ArubaDeviceScanner(DeviceScanner): class ArubaDeviceScanner(DeviceScanner):
"""Class which queries a Aruba Access Point for connected devices.""" """Class which queries a Aruba Access Point for connected devices."""
def __init__(self, config): def __init__(self, config: dict[str, Any]) -> None:
"""Initialize the scanner.""" """Initialize the scanner."""
self.host = config[CONF_HOST] self.host: str = config[CONF_HOST]
self.username = config[CONF_USERNAME] self.username: str = config[CONF_USERNAME]
self.password = config[CONF_PASSWORD] self.password: str = config[CONF_PASSWORD]
self.last_results = {} self.last_results: dict[str, dict[str, str]] = {}
# Test the router is accessible. # Test the router is accessible.
data = self.get_aruba_data() data = self.get_aruba_data()
self.success_init = data is not None self.success_init = data is not None
def scan_devices(self): def scan_devices(self) -> list[str]:
"""Scan for new devices and return a list with found device IDs.""" """Scan for new devices and return a list with found device IDs."""
self._update_info() self._update_info()
return [client["mac"] for client in self.last_results] return [client["mac"] for client in self.last_results.values()]
def get_device_name(self, device): def get_device_name(self, device: str) -> str | None:
"""Return the name of the given device or None if we don't know.""" """Return the name of the given device or None if we don't know."""
if not self.last_results: if not self.last_results:
return None return None
for client in self.last_results: for client in self.last_results.values():
if client["mac"] == device: if client["mac"] == device:
return client["name"] return client["name"]
return None return None
def _update_info(self): def _update_info(self) -> bool:
"""Ensure the information from the Aruba Access Point is up to date. """Ensure the information from the Aruba Access Point is up to date.
Return boolean if scanning successful. Return boolean if scanning successful.
@ -81,10 +82,10 @@ class ArubaDeviceScanner(DeviceScanner):
if not (data := self.get_aruba_data()): if not (data := self.get_aruba_data()):
return False return False
self.last_results = data.values() self.last_results = data
return True return True
def get_aruba_data(self): def get_aruba_data(self) -> dict[str, dict[str, str]] | None:
"""Retrieve data from Aruba Access Point and return parsed result.""" """Retrieve data from Aruba Access Point and return parsed result."""
connect = f"ssh {self.username}@{self.host} -o HostKeyAlgorithms=ssh-rsa" connect = f"ssh {self.username}@{self.host} -o HostKeyAlgorithms=ssh-rsa"
@ -103,22 +104,22 @@ class ArubaDeviceScanner(DeviceScanner):
) )
if query == 1: if query == 1:
_LOGGER.error("Timeout") _LOGGER.error("Timeout")
return return None
if query == 2: if query == 2:
_LOGGER.error("Unexpected response from router") _LOGGER.error("Unexpected response from router")
return return None
if query == 3: if query == 3:
ssh.sendline("yes") ssh.sendline("yes")
ssh.expect("password:") ssh.expect("password:")
elif query == 4: elif query == 4:
_LOGGER.error("Host key changed") _LOGGER.error("Host key changed")
return return None
elif query == 5: elif query == 5:
_LOGGER.error("Connection refused by server") _LOGGER.error("Connection refused by server")
return return None
elif query == 6: elif query == 6:
_LOGGER.error("Connection timed out") _LOGGER.error("Connection timed out")
return return None
ssh.sendline(self.password) ssh.sendline(self.password)
ssh.expect("#") ssh.expect("#")
ssh.sendline("show clients") ssh.sendline("show clients")
@ -126,7 +127,7 @@ class ArubaDeviceScanner(DeviceScanner):
devices_result = ssh.before.split(b"\r\n") devices_result = ssh.before.split(b"\r\n")
ssh.sendline("exit") ssh.sendline("exit")
devices = {} devices: dict[str, dict[str, str]] = {}
for device in devices_result: for device in devices_result:
if match := _DEVICES_REGEX.search(device.decode("utf-8")): if match := _DEVICES_REGEX.search(device.decode("utf-8")):
devices[match.group("ip")] = { devices[match.group("ip")] = {

View File

@ -520,6 +520,16 @@ disallow_untyped_defs = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.aruba.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.aseko_pool_live.*] [mypy-homeassistant.components.aseko_pool_live.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true