Fix nmap_tracker typing (#54858)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
J. Nick Koston 2021-08-21 14:25:28 -05:00 committed by GitHub
parent 4916016648
commit ebb8ad308e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 46 deletions

View File

@ -7,6 +7,7 @@ from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import logging import logging
from typing import Final
import aiohttp import aiohttp
from getmac import get_mac_address from getmac import get_mac_address
@ -34,19 +35,17 @@ from .const import (
) )
# Some version of nmap will fail with 'Assertion failed: htn.toclock_running == true (Target.cc: stopTimeOutClock: 503)\n' # Some version of nmap will fail with 'Assertion failed: htn.toclock_running == true (Target.cc: stopTimeOutClock: 503)\n'
NMAP_TRANSIENT_FAILURE = "Assertion failed: htn.toclock_running == true" NMAP_TRANSIENT_FAILURE: Final = "Assertion failed: htn.toclock_running == true"
MAX_SCAN_ATTEMPTS = 16 MAX_SCAN_ATTEMPTS: Final = 16
OFFLINE_SCANS_TO_MARK_UNAVAILABLE = 3 OFFLINE_SCANS_TO_MARK_UNAVAILABLE: Final = 3
def short_hostname(hostname): def short_hostname(hostname: str) -> str:
"""Return the first part of the hostname.""" """Return the first part of the hostname."""
if hostname is None:
return None
return hostname.split(".")[0] return hostname.split(".")[0]
def human_readable_name(hostname, vendor, mac_address): def human_readable_name(hostname: str, vendor: str, mac_address: str) -> str:
"""Generate a human readable name.""" """Generate a human readable name."""
if hostname: if hostname:
return short_hostname(hostname) return short_hostname(hostname)
@ -65,7 +64,7 @@ class NmapDevice:
ipv4: str ipv4: str
manufacturer: str manufacturer: str
reason: str reason: str
last_update: datetime.datetime last_update: datetime
offline_scans: int offline_scans: int
@ -74,9 +73,9 @@ class NmapTrackedDevices:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the data.""" """Initialize the data."""
self.tracked: dict = {} self.tracked: dict[str, NmapDevice] = {}
self.ipv4_last_mac: dict = {} self.ipv4_last_mac: dict[str, str] = {}
self.config_entry_owner: dict = {} self.config_entry_owner: dict[str, str] = {}
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -132,7 +131,9 @@ def signal_device_update(mac_address) -> str:
class NmapDeviceScanner: class NmapDeviceScanner:
"""This class scans for devices using nmap.""" """This class scans for devices using nmap."""
def __init__(self, hass, entry, devices): def __init__(
self, hass: HomeAssistant, entry: ConfigEntry, devices: NmapTrackedDevices
) -> None:
"""Initialize the scanner.""" """Initialize the scanner."""
self.devices = devices self.devices = devices
self.home_interval = None self.home_interval = None
@ -150,9 +151,9 @@ class NmapDeviceScanner:
self._exclude = None self._exclude = None
self._scan_interval = None self._scan_interval = None
self._known_mac_addresses = {} self._known_mac_addresses: dict[str, str] = {}
self._finished_first_scan = False self._finished_first_scan = False
self._last_results = [] self._last_results: list[NmapDevice] = []
self._mac_vendor_lookup = None self._mac_vendor_lookup = None
async def async_setup(self): async def async_setup(self):

View File

@ -10,6 +10,7 @@ from homeassistant import config_entries
from homeassistant.components import network from homeassistant.components import network
from homeassistant.components.device_tracker.const import CONF_SCAN_INTERVAL from homeassistant.components.device_tracker.const import CONF_SCAN_INTERVAL
from homeassistant.components.network.const import MDNS_TARGET_IP from homeassistant.components.network.const import MDNS_TARGET_IP
from homeassistant.config_entries import ConfigEntry, OptionsFlow
from homeassistant.const import CONF_EXCLUDE, CONF_HOSTS from homeassistant.const import CONF_EXCLUDE, CONF_HOSTS
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
@ -40,7 +41,7 @@ async def async_get_network(hass: HomeAssistant) -> str:
return str(ip_network(f"{local_ip}/{network_prefix}", False)) return str(ip_network(f"{local_ip}/{network_prefix}", False))
def _normalize_ips_and_network(hosts_str): def _normalize_ips_and_network(hosts_str: str) -> list[str] | None:
"""Check if a list of hosts are all ips or ip networks.""" """Check if a list of hosts are all ips or ip networks."""
normalized_hosts = [] normalized_hosts = []
@ -74,7 +75,7 @@ def _normalize_ips_and_network(hosts_str):
return normalized_hosts return normalized_hosts
def normalize_input(user_input): def normalize_input(user_input: dict[str, Any]) -> dict[str, str]:
"""Validate hosts and exclude are valid.""" """Validate hosts and exclude are valid."""
errors = {} errors = {}
normalized_hosts = _normalize_ips_and_network(user_input[CONF_HOSTS]) normalized_hosts = _normalize_ips_and_network(user_input[CONF_HOSTS])
@ -92,7 +93,9 @@ def normalize_input(user_input):
return errors return errors
async def _async_build_schema_with_user_input(hass, user_input, include_options): async def _async_build_schema_with_user_input(
hass: HomeAssistant, user_input: dict[str, Any], include_options: bool
) -> vol.Schema:
hosts = user_input.get(CONF_HOSTS, await async_get_network(hass)) hosts = user_input.get(CONF_HOSTS, await async_get_network(hass))
exclude = user_input.get( exclude = user_input.get(
CONF_EXCLUDE, await network.async_get_source_ip(hass, MDNS_TARGET_IP) CONF_EXCLUDE, await network.async_get_source_ip(hass, MDNS_TARGET_IP)
@ -126,7 +129,9 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
"""Initialize options flow.""" """Initialize options flow."""
self.options = dict(config_entry.options) self.options = dict(config_entry.options)
async def async_step_init(self, user_input=None): async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle options flow.""" """Handle options flow."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
@ -152,9 +157,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
def __init__(self): def __init__(self) -> None:
"""Initialize config flow.""" """Initialize config flow."""
self.options = {} self.options: dict[str, Any] = {}
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
@ -183,14 +188,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors=errors, errors=errors,
) )
def _async_is_unique_host_list(self, user_input): def _async_is_unique_host_list(self, user_input: dict[str, Any]) -> bool:
hosts = _normalize_ips_and_network(user_input[CONF_HOSTS]) hosts = _normalize_ips_and_network(user_input[CONF_HOSTS])
for entry in self._async_current_entries(): for entry in self._async_current_entries():
if _normalize_ips_and_network(entry.options[CONF_HOSTS]) == hosts: if _normalize_ips_and_network(entry.options[CONF_HOSTS]) == hosts:
return False return False
return True return True
async def async_step_import(self, user_input=None): async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult:
"""Handle import from yaml.""" """Handle import from yaml."""
if not self._async_is_unique_host_list(user_input): if not self._async_is_unique_host_list(user_input):
return self.async_abort(reason="already_configured") return self.async_abort(reason="already_configured")
@ -203,6 +208,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow(config_entry): def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow:
"""Get the options flow for this handler.""" """Get the options flow for this handler."""
return OptionsFlowHandler(config_entry) return OptionsFlowHandler(config_entry)

View File

@ -1,14 +1,15 @@
"""The Nmap Tracker integration.""" """The Nmap Tracker integration."""
from typing import Final
DOMAIN = "nmap_tracker" DOMAIN: Final = "nmap_tracker"
PLATFORMS = ["device_tracker"] PLATFORMS: Final = ["device_tracker"]
NMAP_TRACKED_DEVICES = "nmap_tracked_devices" NMAP_TRACKED_DEVICES: Final = "nmap_tracked_devices"
# Interval in minutes to exclude devices from a scan while they are home # Interval in minutes to exclude devices from a scan while they are home
CONF_HOME_INTERVAL = "home_interval" CONF_HOME_INTERVAL: Final = "home_interval"
CONF_OPTIONS = "scan_options" CONF_OPTIONS: Final = "scan_options"
DEFAULT_OPTIONS = "-F -T4 --min-rate 10 --host-timeout 5s" DEFAULT_OPTIONS: Final = "-F -T4 --min-rate 10 --host-timeout 5s"
TRACKER_SCAN_INTERVAL = 120 TRACKER_SCAN_INTERVAL: Final = 120

View File

@ -1,13 +1,14 @@
"""Support for scanning a network with nmap.""" """Support for scanning a network with nmap."""
from __future__ import annotations
import logging import logging
from typing import Callable from typing import Any, Callable
import voluptuous as vol import voluptuous as vol
from homeassistant.components.device_tracker import ( from homeassistant.components.device_tracker import (
DOMAIN as DEVICE_TRACKER_DOMAIN, DOMAIN as DEVICE_TRACKER_DOMAIN,
PLATFORM_SCHEMA, PLATFORM_SCHEMA as DEVICE_TRACKER_PLATFORM_SCHEMA,
SOURCE_TYPE_ROUTER, SOURCE_TYPE_ROUTER,
) )
from homeassistant.components.device_tracker.config_entry import ScannerEntity from homeassistant.components.device_tracker.config_entry import ScannerEntity
@ -18,8 +19,10 @@ from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.typing import ConfigType
from . import NmapDeviceScanner, short_hostname, signal_device_update from . import NmapDevice, NmapDeviceScanner, short_hostname, signal_device_update
from .const import ( from .const import (
CONF_HOME_INTERVAL, CONF_HOME_INTERVAL,
CONF_OPTIONS, CONF_OPTIONS,
@ -30,7 +33,8 @@ from .const import (
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
PLATFORM_SCHEMA = DEVICE_TRACKER_PLATFORM_SCHEMA.extend(
{ {
vol.Required(CONF_HOSTS): cv.ensure_list, vol.Required(CONF_HOSTS): cv.ensure_list,
vol.Required(CONF_HOME_INTERVAL, default=0): cv.positive_int, vol.Required(CONF_HOME_INTERVAL, default=0): cv.positive_int,
@ -40,7 +44,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
) )
async def async_get_scanner(hass, config): async def async_get_scanner(hass: HomeAssistant, config: ConfigType) -> None:
"""Validate the configuration and return a Nmap scanner.""" """Validate the configuration and return a Nmap scanner."""
validated_config = config[DEVICE_TRACKER_DOMAIN] validated_config = config[DEVICE_TRACKER_DOMAIN]
@ -110,7 +114,7 @@ class NmapTrackerEntity(ScannerEntity):
self._active = active self._active = active
@property @property
def _device(self) -> bool: def _device(self) -> NmapDevice:
"""Get latest device state.""" """Get latest device state."""
return self._tracked[self._mac_address] return self._tracked[self._mac_address]
@ -140,8 +144,10 @@ class NmapTrackerEntity(ScannerEntity):
return self._mac_address return self._mac_address
@property @property
def hostname(self) -> str: def hostname(self) -> str | None:
"""Return hostname of the device.""" """Return hostname of the device."""
if not self._device.hostname:
return None
return short_hostname(self._device.hostname) return short_hostname(self._device.hostname)
@property @property
@ -150,7 +156,7 @@ class NmapTrackerEntity(ScannerEntity):
return SOURCE_TYPE_ROUTER return SOURCE_TYPE_ROUTER
@property @property
def device_info(self): def device_info(self) -> DeviceInfo:
"""Return the device information.""" """Return the device information."""
return { return {
"connections": {(CONNECTION_NETWORK_MAC, self._mac_address)}, "connections": {(CONNECTION_NETWORK_MAC, self._mac_address)},
@ -164,7 +170,7 @@ class NmapTrackerEntity(ScannerEntity):
return False return False
@property @property
def icon(self): def icon(self) -> str:
"""Return device icon.""" """Return device icon."""
return "mdi:lan-connect" if self._active else "mdi:lan-disconnect" return "mdi:lan-connect" if self._active else "mdi:lan-disconnect"
@ -174,7 +180,7 @@ class NmapTrackerEntity(ScannerEntity):
self._active = online self._active = online
@property @property
def extra_state_attributes(self): def extra_state_attributes(self) -> dict[str, Any]:
"""Return the attributes.""" """Return the attributes."""
return { return {
"last_time_reachable": self._device.last_update.isoformat( "last_time_reachable": self._device.last_update.isoformat(
@ -184,12 +190,12 @@ class NmapTrackerEntity(ScannerEntity):
} }
@callback @callback
def async_on_demand_update(self, online: bool): def async_on_demand_update(self, online: bool) -> None:
"""Update state.""" """Update state."""
self.async_process_update(online) self.async_process_update(online)
self.async_write_ha_state() self.async_write_ha_state()
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Register state update callback.""" """Register state update callback."""
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect( async_dispatcher_connect(

View File

@ -1499,9 +1499,6 @@ ignore_errors = true
[mypy-homeassistant.components.nilu.*] [mypy-homeassistant.components.nilu.*]
ignore_errors = true ignore_errors = true
[mypy-homeassistant.components.nmap_tracker.*]
ignore_errors = true
[mypy-homeassistant.components.nsw_fuel_station.*] [mypy-homeassistant.components.nsw_fuel_station.*]
ignore_errors = true ignore_errors = true

View File

@ -92,7 +92,6 @@ IGNORED_MODULES: Final[list[str]] = [
"homeassistant.components.nest.legacy.*", "homeassistant.components.nest.legacy.*",
"homeassistant.components.nightscout.*", "homeassistant.components.nightscout.*",
"homeassistant.components.nilu.*", "homeassistant.components.nilu.*",
"homeassistant.components.nmap_tracker.*",
"homeassistant.components.nsw_fuel_station.*", "homeassistant.components.nsw_fuel_station.*",
"homeassistant.components.nuki.*", "homeassistant.components.nuki.*",
"homeassistant.components.nws.*", "homeassistant.components.nws.*",