Fix nmap_tracker typing (#54858)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
J. Nick Koston 2021-08-21 14:25:28 -05:00 committed by GitHub
parent 4916016648
commit ebb8ad308e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 46 deletions

View File

@ -7,6 +7,7 @@ from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import partial
import logging
from typing import Final
import aiohttp
from getmac import get_mac_address
@ -34,19 +35,17 @@ from .const import (
)
# Some version of nmap will fail with 'Assertion failed: htn.toclock_running == true (Target.cc: stopTimeOutClock: 503)\n'
NMAP_TRANSIENT_FAILURE = "Assertion failed: htn.toclock_running == true"
MAX_SCAN_ATTEMPTS = 16
OFFLINE_SCANS_TO_MARK_UNAVAILABLE = 3
NMAP_TRANSIENT_FAILURE: Final = "Assertion failed: htn.toclock_running == true"
MAX_SCAN_ATTEMPTS: Final = 16
OFFLINE_SCANS_TO_MARK_UNAVAILABLE: Final = 3
def short_hostname(hostname):
def short_hostname(hostname: str) -> str:
"""Return the first part of the hostname."""
if hostname is None:
return None
return hostname.split(".")[0]
def human_readable_name(hostname, vendor, mac_address):
def human_readable_name(hostname: str, vendor: str, mac_address: str) -> str:
"""Generate a human readable name."""
if hostname:
return short_hostname(hostname)
@ -65,7 +64,7 @@ class NmapDevice:
ipv4: str
manufacturer: str
reason: str
last_update: datetime.datetime
last_update: datetime
offline_scans: int
@ -74,9 +73,9 @@ class NmapTrackedDevices:
def __init__(self) -> None:
"""Initialize the data."""
self.tracked: dict = {}
self.ipv4_last_mac: dict = {}
self.config_entry_owner: dict = {}
self.tracked: dict[str, NmapDevice] = {}
self.ipv4_last_mac: dict[str, str] = {}
self.config_entry_owner: dict[str, str] = {}
_LOGGER = logging.getLogger(__name__)
@ -132,7 +131,9 @@ def signal_device_update(mac_address) -> str:
class NmapDeviceScanner:
"""This class scans for devices using nmap."""
def __init__(self, hass, entry, devices):
def __init__(
self, hass: HomeAssistant, entry: ConfigEntry, devices: NmapTrackedDevices
) -> None:
"""Initialize the scanner."""
self.devices = devices
self.home_interval = None
@ -150,9 +151,9 @@ class NmapDeviceScanner:
self._exclude = None
self._scan_interval = None
self._known_mac_addresses = {}
self._known_mac_addresses: dict[str, str] = {}
self._finished_first_scan = False
self._last_results = []
self._last_results: list[NmapDevice] = []
self._mac_vendor_lookup = None
async def async_setup(self):

View File

@ -10,6 +10,7 @@ from homeassistant import config_entries
from homeassistant.components import network
from homeassistant.components.device_tracker.const import CONF_SCAN_INTERVAL
from homeassistant.components.network.const import MDNS_TARGET_IP
from homeassistant.config_entries import ConfigEntry, OptionsFlow
from homeassistant.const import CONF_EXCLUDE, CONF_HOSTS
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
@ -40,7 +41,7 @@ async def async_get_network(hass: HomeAssistant) -> str:
return str(ip_network(f"{local_ip}/{network_prefix}", False))
def _normalize_ips_and_network(hosts_str):
def _normalize_ips_and_network(hosts_str: str) -> list[str] | None:
"""Check if a list of hosts are all ips or ip networks."""
normalized_hosts = []
@ -74,7 +75,7 @@ def _normalize_ips_and_network(hosts_str):
return normalized_hosts
def normalize_input(user_input):
def normalize_input(user_input: dict[str, Any]) -> dict[str, str]:
"""Validate hosts and exclude are valid."""
errors = {}
normalized_hosts = _normalize_ips_and_network(user_input[CONF_HOSTS])
@ -92,7 +93,9 @@ def normalize_input(user_input):
return errors
async def _async_build_schema_with_user_input(hass, user_input, include_options):
async def _async_build_schema_with_user_input(
hass: HomeAssistant, user_input: dict[str, Any], include_options: bool
) -> vol.Schema:
hosts = user_input.get(CONF_HOSTS, await async_get_network(hass))
exclude = user_input.get(
CONF_EXCLUDE, await network.async_get_source_ip(hass, MDNS_TARGET_IP)
@ -126,7 +129,9 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
"""Initialize options flow."""
self.options = dict(config_entry.options)
async def async_step_init(self, user_input=None):
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle options flow."""
errors = {}
if user_input is not None:
@ -152,9 +157,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1
def __init__(self):
def __init__(self) -> None:
"""Initialize config flow."""
self.options = {}
self.options: dict[str, Any] = {}
async def async_step_user(
self, user_input: dict[str, Any] | None = None
@ -183,14 +188,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors=errors,
)
def _async_is_unique_host_list(self, user_input):
def _async_is_unique_host_list(self, user_input: dict[str, Any]) -> bool:
hosts = _normalize_ips_and_network(user_input[CONF_HOSTS])
for entry in self._async_current_entries():
if _normalize_ips_and_network(entry.options[CONF_HOSTS]) == hosts:
return False
return True
async def async_step_import(self, user_input=None):
async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult:
"""Handle import from yaml."""
if not self._async_is_unique_host_list(user_input):
return self.async_abort(reason="already_configured")
@ -203,6 +208,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
@staticmethod
@callback
def async_get_options_flow(config_entry):
def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow:
"""Get the options flow for this handler."""
return OptionsFlowHandler(config_entry)

View File

@ -1,14 +1,15 @@
"""The Nmap Tracker integration."""
from typing import Final
DOMAIN = "nmap_tracker"
DOMAIN: Final = "nmap_tracker"
PLATFORMS = ["device_tracker"]
PLATFORMS: Final = ["device_tracker"]
NMAP_TRACKED_DEVICES = "nmap_tracked_devices"
NMAP_TRACKED_DEVICES: Final = "nmap_tracked_devices"
# Interval in minutes to exclude devices from a scan while they are home
CONF_HOME_INTERVAL = "home_interval"
CONF_OPTIONS = "scan_options"
DEFAULT_OPTIONS = "-F -T4 --min-rate 10 --host-timeout 5s"
CONF_HOME_INTERVAL: Final = "home_interval"
CONF_OPTIONS: Final = "scan_options"
DEFAULT_OPTIONS: Final = "-F -T4 --min-rate 10 --host-timeout 5s"
TRACKER_SCAN_INTERVAL = 120
TRACKER_SCAN_INTERVAL: Final = 120

View File

@ -1,13 +1,14 @@
"""Support for scanning a network with nmap."""
from __future__ import annotations
import logging
from typing import Callable
from typing import Any, Callable
import voluptuous as vol
from homeassistant.components.device_tracker import (
DOMAIN as DEVICE_TRACKER_DOMAIN,
PLATFORM_SCHEMA,
PLATFORM_SCHEMA as DEVICE_TRACKER_PLATFORM_SCHEMA,
SOURCE_TYPE_ROUTER,
)
from homeassistant.components.device_tracker.config_entry import ScannerEntity
@ -18,8 +19,10 @@ from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.typing import ConfigType
from . import NmapDeviceScanner, short_hostname, signal_device_update
from . import NmapDevice, NmapDeviceScanner, short_hostname, signal_device_update
from .const import (
CONF_HOME_INTERVAL,
CONF_OPTIONS,
@ -30,7 +33,8 @@ from .const import (
_LOGGER = logging.getLogger(__name__)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
PLATFORM_SCHEMA = DEVICE_TRACKER_PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_HOSTS): cv.ensure_list,
vol.Required(CONF_HOME_INTERVAL, default=0): cv.positive_int,
@ -40,7 +44,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
)
async def async_get_scanner(hass, config):
async def async_get_scanner(hass: HomeAssistant, config: ConfigType) -> None:
"""Validate the configuration and return a Nmap scanner."""
validated_config = config[DEVICE_TRACKER_DOMAIN]
@ -110,7 +114,7 @@ class NmapTrackerEntity(ScannerEntity):
self._active = active
@property
def _device(self) -> bool:
def _device(self) -> NmapDevice:
"""Get latest device state."""
return self._tracked[self._mac_address]
@ -140,8 +144,10 @@ class NmapTrackerEntity(ScannerEntity):
return self._mac_address
@property
def hostname(self) -> str:
def hostname(self) -> str | None:
"""Return hostname of the device."""
if not self._device.hostname:
return None
return short_hostname(self._device.hostname)
@property
@ -150,7 +156,7 @@ class NmapTrackerEntity(ScannerEntity):
return SOURCE_TYPE_ROUTER
@property
def device_info(self):
def device_info(self) -> DeviceInfo:
"""Return the device information."""
return {
"connections": {(CONNECTION_NETWORK_MAC, self._mac_address)},
@ -164,7 +170,7 @@ class NmapTrackerEntity(ScannerEntity):
return False
@property
def icon(self):
def icon(self) -> str:
"""Return device icon."""
return "mdi:lan-connect" if self._active else "mdi:lan-disconnect"
@ -174,7 +180,7 @@ class NmapTrackerEntity(ScannerEntity):
self._active = online
@property
def extra_state_attributes(self):
def extra_state_attributes(self) -> dict[str, Any]:
"""Return the attributes."""
return {
"last_time_reachable": self._device.last_update.isoformat(
@ -184,12 +190,12 @@ class NmapTrackerEntity(ScannerEntity):
}
@callback
def async_on_demand_update(self, online: bool):
def async_on_demand_update(self, online: bool) -> None:
"""Update state."""
self.async_process_update(online)
self.async_write_ha_state()
async def async_added_to_hass(self):
async def async_added_to_hass(self) -> None:
"""Register state update callback."""
self.async_on_remove(
async_dispatcher_connect(

View File

@ -1499,9 +1499,6 @@ ignore_errors = true
[mypy-homeassistant.components.nilu.*]
ignore_errors = true
[mypy-homeassistant.components.nmap_tracker.*]
ignore_errors = true
[mypy-homeassistant.components.nsw_fuel_station.*]
ignore_errors = true

View File

@ -92,7 +92,6 @@ IGNORED_MODULES: Final[list[str]] = [
"homeassistant.components.nest.legacy.*",
"homeassistant.components.nightscout.*",
"homeassistant.components.nilu.*",
"homeassistant.components.nmap_tracker.*",
"homeassistant.components.nsw_fuel_station.*",
"homeassistant.components.nuki.*",
"homeassistant.components.nws.*",