Enable strict type checking on apple_tv integration (#101688)

* Enable strict type checking on apple_tv integration

* move some instance variables to class variables

* fix type of attr_value

* fix tests for description_placeholders assertion

* nits

* Apply suggestions from code review

* Update remote.py

* Apply suggestions from code review

* Update __init__.py

* Update __init__.py

* Update __init__.py

* Update config_flow.py

* Improve test coverage

* Update test_config_flow.py

* Update __init__.py

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Stackie Jia 2024-02-15 22:17:00 +08:00 committed by GitHub
parent d555f91702
commit 636c7ce350
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 253 additions and 104 deletions

View File

@ -80,6 +80,7 @@ homeassistant.components.anthemav.*
homeassistant.components.apache_kafka.* homeassistant.components.apache_kafka.*
homeassistant.components.apcupsd.* homeassistant.components.apcupsd.*
homeassistant.components.api.* homeassistant.components.api.*
homeassistant.components.apple_tv.*
homeassistant.components.apprise.* homeassistant.components.apprise.*
homeassistant.components.aprs.* homeassistant.components.aprs.*
homeassistant.components.aqualogic.* homeassistant.components.aqualogic.*

View File

@ -1,8 +1,10 @@
"""The Apple TV integration.""" """The Apple TV integration."""
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from random import randrange from random import randrange
from typing import TYPE_CHECKING, cast from typing import Any, cast
from pyatv import connect, exceptions, scan from pyatv import connect, exceptions, scan
from pyatv.conf import AppleTV from pyatv.conf import AppleTV
@ -25,7 +27,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
Platform, Platform,
) )
from homeassistant.core import HomeAssistant, callback from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
@ -89,7 +91,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
async def on_hass_stop(event): async def on_hass_stop(event: Event) -> None:
"""Stop push updates when hass stops.""" """Stop push updates when hass stops."""
await manager.disconnect() await manager.disconnect()
@ -120,33 +122,29 @@ class AppleTVEntity(Entity):
_attr_should_poll = False _attr_should_poll = False
_attr_has_entity_name = True _attr_has_entity_name = True
_attr_name = None _attr_name = None
atv: AppleTVInterface | None = None
def __init__( def __init__(self, name: str, identifier: str, manager: AppleTVManager) -> None:
self, name: str, identifier: str | None, manager: "AppleTVManager"
) -> None:
"""Initialize device.""" """Initialize device."""
self.atv: AppleTVInterface = None # type: ignore[assignment]
self.manager = manager self.manager = manager
if TYPE_CHECKING:
assert identifier is not None
self._attr_unique_id = identifier self._attr_unique_id = identifier
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, identifier)}, identifiers={(DOMAIN, identifier)},
name=name, name=name,
) )
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Handle when an entity is about to be added to Home Assistant.""" """Handle when an entity is about to be added to Home Assistant."""
@callback @callback
def _async_connected(atv): def _async_connected(atv: AppleTVInterface) -> None:
"""Handle that a connection was made to a device.""" """Handle that a connection was made to a device."""
self.atv = atv self.atv = atv
self.async_device_connected(atv) self.async_device_connected(atv)
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def _async_disconnected(): def _async_disconnected() -> None:
"""Handle that a connection to a device was lost.""" """Handle that a connection to a device was lost."""
self.async_device_disconnected() self.async_device_disconnected()
self.atv = None self.atv = None
@ -169,10 +167,10 @@ class AppleTVEntity(Entity):
) )
) )
def async_device_connected(self, atv): def async_device_connected(self, atv: AppleTVInterface) -> None:
"""Handle when connection is made to device.""" """Handle when connection is made to device."""
def async_device_disconnected(self): def async_device_disconnected(self) -> None:
"""Handle when connection was lost to device.""" """Handle when connection was lost to device."""
@ -184,22 +182,23 @@ class AppleTVManager(DeviceListener):
in case of problems. in case of problems.
""" """
atv: AppleTVInterface | None = None
_connection_attempts = 0
_connection_was_lost = False
_task: asyncio.Task[None] | None = None
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None: def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Initialize power manager.""" """Initialize power manager."""
self.config_entry = config_entry self.config_entry = config_entry
self.hass = hass self.hass = hass
self.atv: AppleTVInterface | None = None
self.is_on = not config_entry.options.get(CONF_START_OFF, False) self.is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0
self._connection_was_lost = False
self._task = None
async def init(self): async def init(self) -> None:
"""Initialize power management.""" """Initialize power management."""
if self.is_on: if self.is_on:
await self.connect() await self.connect()
def connection_lost(self, _): def connection_lost(self, exception: Exception) -> None:
"""Device was unexpectedly disconnected. """Device was unexpectedly disconnected.
This is a callback function from pyatv.interface.DeviceListener. This is a callback function from pyatv.interface.DeviceListener.
@ -210,14 +209,14 @@ class AppleTVManager(DeviceListener):
self._connection_was_lost = True self._connection_was_lost = True
self._handle_disconnect() self._handle_disconnect()
def connection_closed(self): def connection_closed(self) -> None:
"""Device connection was (intentionally) closed. """Device connection was (intentionally) closed.
This is a callback function from pyatv.interface.DeviceListener. This is a callback function from pyatv.interface.DeviceListener.
""" """
self._handle_disconnect() self._handle_disconnect()
def _handle_disconnect(self): def _handle_disconnect(self) -> None:
"""Handle that the device disconnected and restart connect loop.""" """Handle that the device disconnected and restart connect loop."""
if self.atv: if self.atv:
self.atv.close() self.atv.close()
@ -225,12 +224,12 @@ class AppleTVManager(DeviceListener):
self._dispatch_send(SIGNAL_DISCONNECTED) self._dispatch_send(SIGNAL_DISCONNECTED)
self._start_connect_loop() self._start_connect_loop()
async def connect(self): async def connect(self) -> None:
"""Connect to device.""" """Connect to device."""
self.is_on = True self.is_on = True
self._start_connect_loop() self._start_connect_loop()
async def disconnect(self): async def disconnect(self) -> None:
"""Disconnect from device.""" """Disconnect from device."""
_LOGGER.debug("Disconnecting from device") _LOGGER.debug("Disconnecting from device")
self.is_on = False self.is_on = False
@ -244,7 +243,7 @@ class AppleTVManager(DeviceListener):
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception("An error occurred while disconnecting") _LOGGER.exception("An error occurred while disconnecting")
def _start_connect_loop(self): def _start_connect_loop(self) -> None:
"""Start background connect loop to device.""" """Start background connect loop to device."""
if not self._task and self.atv is None and self.is_on: if not self._task and self.atv is None and self.is_on:
self._task = asyncio.create_task(self._connect_loop()) self._task = asyncio.create_task(self._connect_loop())
@ -258,7 +257,7 @@ class AppleTVManager(DeviceListener):
if conf := await self._scan(): if conf := await self._scan():
await self._connect(conf, raise_missing_credentials) await self._connect(conf, raise_missing_credentials)
async def async_first_connect(self): async def async_first_connect(self) -> None:
"""Connect to device for the first time.""" """Connect to device for the first time."""
connect_ok = False connect_ok = False
try: try:
@ -286,7 +285,7 @@ class AppleTVManager(DeviceListener):
_LOGGER.exception("Failed to connect") _LOGGER.exception("Failed to connect")
await self.disconnect() await self.disconnect()
async def _connect_loop(self): async def _connect_loop(self) -> None:
"""Connect loop background task function.""" """Connect loop background task function."""
_LOGGER.debug("Starting connect loop") _LOGGER.debug("Starting connect loop")
@ -295,7 +294,8 @@ class AppleTVManager(DeviceListener):
while self.is_on and self.atv is None: while self.is_on and self.atv is None:
await self.connect_once(raise_missing_credentials=False) await self.connect_once(raise_missing_credentials=False)
if self.atv is not None: if self.atv is not None:
break # Calling self.connect_once may have set self.atv
break # type: ignore[unreachable]
self._connection_attempts += 1 self._connection_attempts += 1
backoff = min( backoff = min(
max( max(
@ -392,7 +392,7 @@ class AppleTVManager(DeviceListener):
self._connection_was_lost = False self._connection_was_lost = False
@callback @callback
def _async_setup_device_registry(self): def _async_setup_device_registry(self) -> None:
attrs = { attrs = {
ATTR_IDENTIFIERS: {(DOMAIN, self.config_entry.unique_id)}, ATTR_IDENTIFIERS: {(DOMAIN, self.config_entry.unique_id)},
ATTR_MANUFACTURER: "Apple", ATTR_MANUFACTURER: "Apple",
@ -423,18 +423,18 @@ class AppleTVManager(DeviceListener):
) )
@property @property
def is_connecting(self): def is_connecting(self) -> bool:
"""Return true if connection is in progress.""" """Return true if connection is in progress."""
return self._task is not None return self._task is not None
def _address_updated(self, address): def _address_updated(self, address: str) -> None:
"""Update cached address in config entry.""" """Update cached address in config entry."""
_LOGGER.debug("Changing address to %s", address) _LOGGER.debug("Changing address to %s", address)
self.hass.config_entries.async_update_entry( self.hass.config_entries.async_update_entry(
self.config_entry, data={**self.config_entry.data, CONF_ADDRESS: address} self.config_entry, data={**self.config_entry.data, CONF_ADDRESS: address}
) )
def _dispatch_send(self, signal, *args): def _dispatch_send(self, signal: str, *args: Any) -> None:
"""Dispatch a signal to all entities managed by this manager.""" """Dispatch a signal to all entities managed by this manager."""
async_dispatcher_send( async_dispatcher_send(
self.hass, f"{signal}_{self.config_entry.unique_id}", *args self.hass, f"{signal}_{self.config_entry.unique_id}", *args

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque from collections import deque
from collections.abc import Mapping from collections.abc import Awaitable, Callable, Mapping
from ipaddress import ip_address from ipaddress import ip_address
import logging import logging
from random import randrange from random import randrange
@ -13,12 +13,13 @@ from pyatv import exceptions, pair, scan
from pyatv.const import DeviceModel, PairingRequirement, Protocol from pyatv.const import DeviceModel, PairingRequirement, Protocol
from pyatv.convert import model_str, protocol_str from pyatv.convert import model_str, protocol_str
from pyatv.helpers import get_unique_id from pyatv.helpers import get_unique_id
from pyatv.interface import BaseConfig, PairingHandler
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
from homeassistant.const import CONF_ADDRESS, CONF_NAME, CONF_PIN from homeassistant.const import CONF_ADDRESS, CONF_NAME, CONF_PIN
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import AbortFlow, FlowResult from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
@ -49,10 +50,12 @@ OPTIONS_FLOW = {
} }
async def device_scan(hass, identifier, loop): async def device_scan(
hass: HomeAssistant, identifier: str | None, loop: asyncio.AbstractEventLoop
) -> tuple[BaseConfig | None, list[str] | None]:
"""Scan for a specific device using identifier as filter.""" """Scan for a specific device using identifier as filter."""
def _filter_device(dev): def _filter_device(dev: BaseConfig) -> bool:
if identifier is None: if identifier is None:
return True return True
if identifier == str(dev.address): if identifier == str(dev.address):
@ -61,9 +64,12 @@ async def device_scan(hass, identifier, loop):
return True return True
return any(service.identifier == identifier for service in dev.services) return any(service.identifier == identifier for service in dev.services)
def _host_filter(): def _host_filter() -> list[str] | None:
if identifier is None:
return None
try: try:
return [ip_address(identifier)] ip_address(identifier)
return [identifier]
except ValueError: except ValueError:
return None return None
@ -84,6 +90,13 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
scan_filter: str | None = None
atv: BaseConfig | None = None
atv_identifiers: list[str] | None = None
protocol: Protocol | None = None
pairing: PairingHandler | None = None
protocols_to_pair: deque[Protocol] | None = None
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow( def async_get_options_flow(
@ -92,18 +105,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Get options flow for this handler.""" """Get options flow for this handler."""
return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW) return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW)
def __init__(self): def __init__(self) -> None:
"""Initialize a new AppleTVConfigFlow.""" """Initialize a new AppleTVConfigFlow."""
self.scan_filter = None self.credentials: dict[int, str | None] = {} # Protocol -> credentials
self.atv = None
self.atv_identifiers = None
self.protocol = None
self.pairing = None
self.credentials = {} # Protocol -> credentials
self.protocols_to_pair = deque()
@property @property
def device_identifier(self): def device_identifier(self) -> str | None:
"""Return a identifier for the config entry. """Return a identifier for the config entry.
A device has multiple unique identifiers, but Home Assistant only supports one A device has multiple unique identifiers, but Home Assistant only supports one
@ -118,6 +125,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
existing config entry. If that's the case, the unique_id from that entry is existing config entry. If that's the case, the unique_id from that entry is
re-used, otherwise the newly discovered identifier is used instead. re-used, otherwise the newly discovered identifier is used instead.
""" """
assert self.atv
all_identifiers = set(self.atv.all_identifiers) all_identifiers = set(self.atv.all_identifiers)
if unique_id := self._entry_unique_id_from_identifers(all_identifiers): if unique_id := self._entry_unique_id_from_identifers(all_identifiers):
return unique_id return unique_id
@ -143,7 +151,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self.context["identifier"] = self.unique_id self.context["identifier"] = self.unique_id
return await self.async_step_reconfigure() return await self.async_step_reconfigure()
async def async_step_reconfigure(self, user_input=None): async def async_step_reconfigure(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Inform user that reconfiguration is about to start.""" """Inform user that reconfiguration is about to start."""
if user_input is not None: if user_input is not None:
return await self.async_find_device_wrapper( return await self.async_find_device_wrapper(
@ -152,7 +162,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_show_form(step_id="reconfigure") return self.async_show_form(step_id="reconfigure")
async def async_step_user(self, user_input=None): async def async_step_user(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle the initial step.""" """Handle the initial step."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
@ -170,6 +182,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
await self.async_set_unique_id( await self.async_set_unique_id(
self.device_identifier, raise_on_progress=False self.device_identifier, raise_on_progress=False
) )
assert self.atv
self.context["all_identifiers"] = self.atv.all_identifiers self.context["all_identifiers"] = self.atv.all_identifiers
return await self.async_step_confirm() return await self.async_step_confirm()
@ -275,8 +288,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
context["all_identifiers"].append(unique_id) context["all_identifiers"].append(unique_id)
raise AbortFlow("already_in_progress") raise AbortFlow("already_in_progress")
async def async_found_zeroconf_device(self, user_input=None): async def async_found_zeroconf_device(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle device found after Zeroconf discovery.""" """Handle device found after Zeroconf discovery."""
assert self.atv
self.context["all_identifiers"] = self.atv.all_identifiers self.context["all_identifiers"] = self.atv.all_identifiers
# Also abort if an integration with this identifier already exists # Also abort if an integration with this identifier already exists
await self.async_set_unique_id(self.device_identifier) await self.async_set_unique_id(self.device_identifier)
@ -288,7 +304,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self.context["identifier"] = self.unique_id self.context["identifier"] = self.unique_id
return await self.async_step_confirm() return await self.async_step_confirm()
async def async_find_device_wrapper(self, next_func, allow_exist=False): async def async_find_device_wrapper(
self,
next_func: Callable[[], Awaitable[FlowResult]],
allow_exist: bool = False,
) -> FlowResult:
"""Find a specific device and call another function when done. """Find a specific device and call another function when done.
This function will do error handling and bail out when an error This function will do error handling and bail out when an error
@ -306,7 +326,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return await next_func() return await next_func()
async def async_find_device(self, allow_exist=False): async def async_find_device(self, allow_exist: bool = False) -> None:
"""Scan for the selected device to discover services.""" """Scan for the selected device to discover services."""
self.atv, self.atv_identifiers = await device_scan( self.atv, self.atv_identifiers = await device_scan(
self.hass, self.scan_filter, self.hass.loop self.hass, self.scan_filter, self.hass.loop
@ -357,8 +377,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if not allow_exist: if not allow_exist:
raise DeviceAlreadyConfigured() raise DeviceAlreadyConfigured()
async def async_step_confirm(self, user_input=None): async def async_step_confirm(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle user-confirmation of discovered node.""" """Handle user-confirmation of discovered node."""
assert self.atv
if user_input is not None: if user_input is not None:
expected_identifier_count = len(self.context["all_identifiers"]) expected_identifier_count = len(self.context["all_identifiers"])
# If number of services found during device scan mismatch number of # If number of services found during device scan mismatch number of
@ -384,7 +407,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
}, },
) )
async def async_pair_next_protocol(self): async def async_pair_next_protocol(self) -> FlowResult:
"""Start pairing process for the next available protocol.""" """Start pairing process for the next available protocol."""
await self._async_cleanup() await self._async_cleanup()
@ -393,8 +416,16 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return await self._async_get_entry() return await self._async_get_entry()
self.protocol = self.protocols_to_pair.popleft() self.protocol = self.protocols_to_pair.popleft()
assert self.atv
service = self.atv.get_service(self.protocol) service = self.atv.get_service(self.protocol)
if service is None:
_LOGGER.debug(
"%s does not support pairing (cannot find a corresponding service)",
self.protocol,
)
return await self.async_pair_next_protocol()
# Service requires a password # Service requires a password
if service.requires_password: if service.requires_password:
return await self.async_step_password() return await self.async_step_password()
@ -413,7 +444,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
_LOGGER.debug("%s requires pairing", self.protocol) _LOGGER.debug("%s requires pairing", self.protocol)
# Protocol specific arguments # Protocol specific arguments
pair_args = {} pair_args: dict[str, Any] = {}
if self.protocol in {Protocol.AirPlay, Protocol.Companion, Protocol.DMAP}: if self.protocol in {Protocol.AirPlay, Protocol.Companion, Protocol.DMAP}:
pair_args["name"] = "Home Assistant" pair_args["name"] = "Home Assistant"
if self.protocol == Protocol.DMAP: if self.protocol == Protocol.DMAP:
@ -448,8 +479,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return await self.async_step_pair_no_pin() return await self.async_step_pair_no_pin()
async def async_step_protocol_disabled(self, user_input=None): async def async_step_protocol_disabled(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Inform user that a protocol is disabled and cannot be paired.""" """Inform user that a protocol is disabled and cannot be paired."""
assert self.protocol
if user_input is not None: if user_input is not None:
return await self.async_pair_next_protocol() return await self.async_pair_next_protocol()
return self.async_show_form( return self.async_show_form(
@ -457,9 +491,13 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"protocol": protocol_str(self.protocol)}, description_placeholders={"protocol": protocol_str(self.protocol)},
) )
async def async_step_pair_with_pin(self, user_input=None): async def async_step_pair_with_pin(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle pairing step where a PIN is required from the user.""" """Handle pairing step where a PIN is required from the user."""
errors = {} errors = {}
assert self.pairing
assert self.protocol
if user_input is not None: if user_input is not None:
try: try:
self.pairing.pin(user_input[CONF_PIN]) self.pairing.pin(user_input[CONF_PIN])
@ -480,8 +518,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"protocol": protocol_str(self.protocol)}, description_placeholders={"protocol": protocol_str(self.protocol)},
) )
async def async_step_pair_no_pin(self, user_input=None): async def async_step_pair_no_pin(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle step where user has to enter a PIN on the device.""" """Handle step where user has to enter a PIN on the device."""
assert self.pairing
assert self.protocol
if user_input is not None: if user_input is not None:
await self.pairing.finish() await self.pairing.finish()
if self.pairing.has_paired: if self.pairing.has_paired:
@ -497,12 +539,15 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
step_id="pair_no_pin", step_id="pair_no_pin",
description_placeholders={ description_placeholders={
"protocol": protocol_str(self.protocol), "protocol": protocol_str(self.protocol),
"pin": pin, "pin": str(pin),
}, },
) )
async def async_step_service_problem(self, user_input=None): async def async_step_service_problem(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Inform user that a service will not be added.""" """Inform user that a service will not be added."""
assert self.protocol
if user_input is not None: if user_input is not None:
return await self.async_pair_next_protocol() return await self.async_pair_next_protocol()
@ -511,8 +556,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"protocol": protocol_str(self.protocol)}, description_placeholders={"protocol": protocol_str(self.protocol)},
) )
async def async_step_password(self, user_input=None): async def async_step_password(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Inform user that password is not supported.""" """Inform user that password is not supported."""
assert self.protocol
if user_input is not None: if user_input is not None:
return await self.async_pair_next_protocol() return await self.async_pair_next_protocol()
@ -521,18 +569,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"protocol": protocol_str(self.protocol)}, description_placeholders={"protocol": protocol_str(self.protocol)},
) )
async def _async_cleanup(self): async def _async_cleanup(self) -> None:
"""Clean up allocated resources.""" """Clean up allocated resources."""
if self.pairing is not None: if self.pairing is not None:
await self.pairing.close() await self.pairing.close()
self.pairing = None self.pairing = None
async def _async_get_entry(self): async def _async_get_entry(self) -> FlowResult:
"""Return config entry or update existing config entry.""" """Return config entry or update existing config entry."""
# Abort if no protocols were paired # Abort if no protocols were paired
if not self.credentials: if not self.credentials:
return self.async_abort(reason="setup_failed") return self.async_abort(reason="setup_failed")
assert self.atv
data = { data = {
CONF_NAME: self.atv.name, CONF_NAME: self.atv.name,
CONF_CREDENTIALS: self.credentials, CONF_CREDENTIALS: self.credentials,

View File

@ -16,7 +16,15 @@ from pyatv.const import (
ShuffleState, ShuffleState,
) )
from pyatv.helpers import is_streamable from pyatv.helpers import is_streamable
from pyatv.interface import AppleTV, Playing from pyatv.interface import (
AppleTV,
AudioListener,
OutputDevice,
Playing,
PowerListener,
PushListener,
PushUpdater,
)
from homeassistant.components import media_source from homeassistant.components import media_source
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
@ -101,7 +109,9 @@ async def async_setup_entry(
async_add_entities([AppleTvMediaPlayer(name, config_entry.unique_id, manager)]) async_add_entities([AppleTvMediaPlayer(name, config_entry.unique_id, manager)])
class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity): class AppleTvMediaPlayer(
AppleTVEntity, MediaPlayerEntity, PowerListener, AudioListener, PushListener
):
"""Representation of an Apple TV media player.""" """Representation of an Apple TV media player."""
_attr_supported_features = SUPPORT_APPLE_TV _attr_supported_features = SUPPORT_APPLE_TV
@ -116,9 +126,9 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
def async_device_connected(self, atv: AppleTV) -> None: def async_device_connected(self, atv: AppleTV) -> None:
"""Handle when connection is made to device.""" """Handle when connection is made to device."""
# NB: Do not use _is_feature_available here as it only works when playing # NB: Do not use _is_feature_available here as it only works when playing
if self.atv.features.in_state(FeatureState.Available, FeatureName.PushUpdates): if atv.features.in_state(FeatureState.Available, FeatureName.PushUpdates):
self.atv.push_updater.listener = self atv.push_updater.listener = self
self.atv.push_updater.start() atv.push_updater.start()
self._attr_supported_features = SUPPORT_BASE self._attr_supported_features = SUPPORT_BASE
@ -126,7 +136,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
# "Unsupported" are considered here as the state of such a feature can never # "Unsupported" are considered here as the state of such a feature can never
# change after a connection has been established, i.e. an unsupported feature # change after a connection has been established, i.e. an unsupported feature
# can never change to be supported. # can never change to be supported.
all_features = self.atv.features.all_features() all_features = atv.features.all_features()
for feature_name, support_flag in SUPPORT_FEATURE_MAPPING.items(): for feature_name, support_flag in SUPPORT_FEATURE_MAPPING.items():
feature_info = all_features.get(feature_name) feature_info = all_features.get(feature_name)
if feature_info and feature_info.state != FeatureState.Unsupported: if feature_info and feature_info.state != FeatureState.Unsupported:
@ -136,16 +146,18 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
# metadata update arrives (sometime very soon after this callback returns) # metadata update arrives (sometime very soon after this callback returns)
# Listen to power updates # Listen to power updates
self.atv.power.listener = self atv.power.listener = self
# Listen to volume updates # Listen to volume updates
self.atv.audio.listener = self atv.audio.listener = self
if self.atv.features.in_state(FeatureState.Available, FeatureName.AppList): if atv.features.in_state(FeatureState.Available, FeatureName.AppList):
self.hass.create_task(self._update_app_list()) self.hass.create_task(self._update_app_list())
async def _update_app_list(self) -> None: async def _update_app_list(self) -> None:
_LOGGER.debug("Updating app list") _LOGGER.debug("Updating app list")
if not self.atv:
return
try: try:
apps = await self.atv.apps.app_list() apps = await self.atv.apps.app_list()
except exceptions.NotSupportedError: except exceptions.NotSupportedError:
@ -189,33 +201,56 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
return None return None
@callback @callback
def playstatus_update(self, _, playing: Playing) -> None: def playstatus_update(self, updater: PushUpdater, playstatus: Playing) -> None:
"""Print what is currently playing when it changes.""" """Print what is currently playing when it changes.
self._playing = playing
This is a callback function from pyatv.interface.PushListener.
"""
self._playing = playstatus
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def playstatus_error(self, _, exception: Exception) -> None: def playstatus_error(self, updater: PushUpdater, exception: Exception) -> None:
"""Inform about an error and restart push updates.""" """Inform about an error and restart push updates.
This is a callback function from pyatv.interface.PushListener.
"""
_LOGGER.warning("A %s error occurred: %s", exception.__class__, exception) _LOGGER.warning("A %s error occurred: %s", exception.__class__, exception)
self._playing = None self._playing = None
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def powerstate_update(self, old_state: PowerState, new_state: PowerState) -> None: def powerstate_update(self, old_state: PowerState, new_state: PowerState) -> None:
"""Update power state when it changes.""" """Update power state when it changes.
This is a callback function from pyatv.interface.PowerListener.
"""
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def volume_update(self, old_level: float, new_level: float) -> None: def volume_update(self, old_level: float, new_level: float) -> None:
"""Update volume when it changes.""" """Update volume when it changes.
This is a callback function from pyatv.interface.AudioListener.
"""
self.async_write_ha_state() self.async_write_ha_state()
@callback
def outputdevices_update(
self, old_devices: list[OutputDevice], new_devices: list[OutputDevice]
) -> None:
"""Output devices were updated.
This is a callback function from pyatv.interface.AudioListener.
"""
@property @property
def app_id(self) -> str | None: def app_id(self) -> str | None:
"""ID of the current running app.""" """ID of the current running app."""
if self._is_feature_available(FeatureName.App) and ( if (
app := self.atv.metadata.app self.atv
and self._is_feature_available(FeatureName.App)
and (app := self.atv.metadata.app) is not None
): ):
return app.identifier return app.identifier
return None return None
@ -223,8 +258,10 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
@property @property
def app_name(self) -> str | None: def app_name(self) -> str | None:
"""Name of the current running app.""" """Name of the current running app."""
if self._is_feature_available(FeatureName.App) and ( if (
app := self.atv.metadata.app self.atv
and self._is_feature_available(FeatureName.App)
and (app := self.atv.metadata.app) is not None
): ):
return app.name return app.name
return None return None
@ -255,7 +292,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
@property @property
def volume_level(self) -> float | None: def volume_level(self) -> float | None:
"""Volume level of the media player (0..1).""" """Volume level of the media player (0..1)."""
if self._is_feature_available(FeatureName.Volume): if self.atv and self._is_feature_available(FeatureName.Volume):
return self.atv.audio.volume / 100.0 # from percent return self.atv.audio.volume / 100.0 # from percent
return None return None
@ -286,6 +323,8 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
"""Send the play_media command to the media player.""" """Send the play_media command to the media player."""
# If input (file) has a file format supported by pyatv, then stream it with # If input (file) has a file format supported by pyatv, then stream it with
# RAOP. Otherwise try to play it with regular AirPlay. # RAOP. Otherwise try to play it with regular AirPlay.
if not self.atv:
return
if media_type in {MediaType.APP, MediaType.URL}: if media_type in {MediaType.APP, MediaType.URL}:
await self.atv.apps.launch_app(media_id) await self.atv.apps.launch_app(media_id)
return return
@ -313,7 +352,8 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
"""Hash value for media image.""" """Hash value for media image."""
state = self.state state = self.state
if ( if (
self._playing self.atv
and self._playing
and self._is_feature_available(FeatureName.Artwork) and self._is_feature_available(FeatureName.Artwork)
and state not in {None, MediaPlayerState.OFF, MediaPlayerState.IDLE} and state not in {None, MediaPlayerState.OFF, MediaPlayerState.IDLE}
): ):
@ -323,7 +363,11 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
async def async_get_media_image(self) -> tuple[bytes | None, str | None]: async def async_get_media_image(self) -> tuple[bytes | None, str | None]:
"""Fetch media image of current playing image.""" """Fetch media image of current playing image."""
state = self.state state = self.state
if self._playing and state not in {MediaPlayerState.OFF, MediaPlayerState.IDLE}: if (
self.atv
and self._playing
and state not in {MediaPlayerState.OFF, MediaPlayerState.IDLE}
):
artwork = await self.atv.metadata.artwork() artwork = await self.atv.metadata.artwork()
if artwork: if artwork:
return artwork.bytes, artwork.mimetype return artwork.bytes, artwork.mimetype
@ -439,20 +483,24 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
async def async_turn_on(self) -> None: async def async_turn_on(self) -> None:
"""Turn the media player on.""" """Turn the media player on."""
if self._is_feature_available(FeatureName.TurnOn): if self.atv and self._is_feature_available(FeatureName.TurnOn):
await self.atv.power.turn_on() await self.atv.power.turn_on()
async def async_turn_off(self) -> None: async def async_turn_off(self) -> None:
"""Turn the media player off.""" """Turn the media player off."""
if (self._is_feature_available(FeatureName.TurnOff)) and ( if (
not self._is_feature_available(FeatureName.PowerState) self.atv
or self.atv.power.power_state == PowerState.On and (self._is_feature_available(FeatureName.TurnOff))
and (
not self._is_feature_available(FeatureName.PowerState)
or self.atv.power.power_state == PowerState.On
)
): ):
await self.atv.power.turn_off() await self.atv.power.turn_off()
async def async_media_play_pause(self) -> None: async def async_media_play_pause(self) -> None:
"""Pause media on media player.""" """Pause media on media player."""
if self._playing: if self.atv and self._playing:
await self.atv.remote_control.play_pause() await self.atv.remote_control.play_pause()
async def async_media_play(self) -> None: async def async_media_play(self) -> None:
@ -519,5 +567,6 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
async def async_select_source(self, source: str) -> None: async def async_select_source(self, source: str) -> None:
"""Select input source.""" """Select input source."""
if app_id := self._app_list.get(source): if self.atv:
await self.atv.apps.launch_app(app_id) if app_id := self._app_list.get(source):
await self.atv.apps.launch_app(app_id)

View File

@ -15,7 +15,7 @@ from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import AppleTVEntity from . import AppleTVEntity, AppleTVManager
from .const import DOMAIN from .const import DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -38,8 +38,10 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Load Apple TV remote based on a config entry.""" """Load Apple TV remote based on a config entry."""
name = config_entry.data[CONF_NAME] name: str = config_entry.data[CONF_NAME]
manager = hass.data[DOMAIN][config_entry.unique_id] # apple_tv config entries always have a unique id
assert config_entry.unique_id is not None
manager: AppleTVManager = hass.data[DOMAIN][config_entry.unique_id]
async_add_entities([AppleTVRemote(name, config_entry.unique_id, manager)]) async_add_entities([AppleTVRemote(name, config_entry.unique_id, manager)])
@ -47,7 +49,7 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
"""Device that sends commands to an Apple TV.""" """Device that sends commands to an Apple TV."""
@property @property
def is_on(self): def is_on(self) -> bool:
"""Return true if device is on.""" """Return true if device is on."""
return self.atv is not None return self.atv is not None
@ -64,13 +66,13 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
num_repeats = kwargs[ATTR_NUM_REPEATS] num_repeats = kwargs[ATTR_NUM_REPEATS]
delay = kwargs.get(ATTR_DELAY_SECS, DEFAULT_DELAY_SECS) delay = kwargs.get(ATTR_DELAY_SECS, DEFAULT_DELAY_SECS)
if not self.is_on: if not self.atv:
_LOGGER.error("Unable to send commands, not connected to %s", self.name) _LOGGER.error("Unable to send commands, not connected to %s", self.name)
return return
for _ in range(num_repeats): for _ in range(num_repeats):
for single_command in command: for single_command in command:
attr_value = None attr_value: Any = None
if attributes := COMMAND_TO_ATTRIBUTE.get(single_command): if attributes := COMMAND_TO_ATTRIBUTE.get(single_command):
attr_value = self.atv attr_value = self.atv
for attr_name in attributes: for attr_name in attributes:
@ -81,5 +83,5 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
raise ValueError("Command not found. Exiting sequence") raise ValueError("Command not found. Exiting sequence")
_LOGGER.info("Sending command %s", single_command) _LOGGER.info("Sending command %s", single_command)
await attr_value() # type: ignore[operator] await attr_value()
await asyncio.sleep(delay) await asyncio.sleep(delay)

View File

@ -560,6 +560,16 @@ disallow_untyped_defs = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.apple_tv.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.apprise.*] [mypy-homeassistant.components.apprise.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true

View File

@ -1,6 +1,6 @@
"""Test config flow.""" """Test config flow."""
from ipaddress import IPv4Address, ip_address from ipaddress import IPv4Address, ip_address
from unittest.mock import ANY, patch from unittest.mock import ANY, Mock, patch
from pyatv import exceptions from pyatv import exceptions
from pyatv.const import PairingRequirement, Protocol from pyatv.const import PairingRequirement, Protocol
@ -125,7 +125,7 @@ async def test_user_adds_full_device(hass: HomeAssistant, full_device, pairing)
result["flow_id"], {"pin": 1111} result["flow_id"], {"pin": 1111}
) )
assert result4["type"] == data_entry_flow.FlowResultType.FORM assert result4["type"] == data_entry_flow.FlowResultType.FORM
assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": 1111} assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": "1111"}
result5 = await hass.config_entries.flow.async_configure(result["flow_id"], {}) result5 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result5["type"] == data_entry_flow.FlowResultType.FORM assert result5["type"] == data_entry_flow.FlowResultType.FORM
@ -167,7 +167,7 @@ async def test_user_adds_dmap_device(
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {}) result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == data_entry_flow.FlowResultType.FORM assert result3["type"] == data_entry_flow.FlowResultType.FORM
assert result3["description_placeholders"] == {"pin": 1111, "protocol": "DMAP"} assert result3["description_placeholders"] == {"pin": "1111", "protocol": "DMAP"}
result6 = await hass.config_entries.flow.async_configure( result6 = await hass.config_entries.flow.async_configure(
result["flow_id"], {"pin": 1234} result["flow_id"], {"pin": 1234}
@ -646,7 +646,7 @@ async def test_zeroconf_add_dmap_device(
{}, {},
) )
assert result2["type"] == data_entry_flow.FlowResultType.FORM assert result2["type"] == data_entry_flow.FlowResultType.FORM
assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": 1111} assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": "1111"}
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {}) result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == "create_entry" assert result3["type"] == "create_entry"
@ -1130,6 +1130,43 @@ async def test_zeroconf_pair_additionally_found_protocols(
assert result5["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY assert result5["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
async def test_zeroconf_mismatch(
hass: HomeAssistant, mock_scan, pairing, mock_zeroconf: None
) -> None:
"""Test the technically possible case where a protocol has no service.
This could happen in case of mDNS issues.
"""
mock_scan.result = [
create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service())
]
mock_scan.result[0].get_service = Mock(return_value=None)
# Find device with AirPlay service and set up flow for it
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
ip_address=ip_address("127.0.0.1"),
ip_addresses=[ip_address("127.0.0.1")],
hostname="mock_hostname",
port=None,
type="_airplay._tcp.local.",
name="Kitchen",
properties={"deviceid": "airplayid"},
),
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result["type"] == data_entry_flow.FlowResultType.ABORT
assert result["reason"] == "setup_failed"
# Re-configuration # Re-configuration