Enable strict type checking on apple_tv integration (#101688)

* Enable strict type checking on apple_tv integration

* move some instance variables to class variables

* fix type of attr_value

* fix tests for description_placeholders assertion

* nits

* Apply suggestions from code review

* Update remote.py

* Apply suggestions from code review

* Update __init__.py

* Update __init__.py

* Update __init__.py

* Update config_flow.py

* Improve test coverage

* Update test_config_flow.py

* Update __init__.py

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Stackie Jia 2024-02-15 22:17:00 +08:00 committed by GitHub
parent d555f91702
commit 636c7ce350
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 253 additions and 104 deletions

View File

@ -80,6 +80,7 @@ homeassistant.components.anthemav.*
homeassistant.components.apache_kafka.*
homeassistant.components.apcupsd.*
homeassistant.components.api.*
homeassistant.components.apple_tv.*
homeassistant.components.apprise.*
homeassistant.components.aprs.*
homeassistant.components.aqualogic.*

View File

@ -1,8 +1,10 @@
"""The Apple TV integration."""
from __future__ import annotations
import asyncio
import logging
from random import randrange
from typing import TYPE_CHECKING, cast
from typing import Any, cast
from pyatv import connect, exceptions, scan
from pyatv.conf import AppleTV
@ -25,7 +27,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
Platform,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.aiohttp_client import async_get_clientsession
@ -89,7 +91,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
async def on_hass_stop(event):
async def on_hass_stop(event: Event) -> None:
"""Stop push updates when hass stops."""
await manager.disconnect()
@ -120,33 +122,29 @@ class AppleTVEntity(Entity):
_attr_should_poll = False
_attr_has_entity_name = True
_attr_name = None
atv: AppleTVInterface | None = None
def __init__(
self, name: str, identifier: str | None, manager: "AppleTVManager"
) -> None:
def __init__(self, name: str, identifier: str, manager: AppleTVManager) -> None:
"""Initialize device."""
self.atv: AppleTVInterface = None # type: ignore[assignment]
self.manager = manager
if TYPE_CHECKING:
assert identifier is not None
self._attr_unique_id = identifier
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, identifier)},
name=name,
)
async def async_added_to_hass(self):
async def async_added_to_hass(self) -> None:
"""Handle when an entity is about to be added to Home Assistant."""
@callback
def _async_connected(atv):
def _async_connected(atv: AppleTVInterface) -> None:
"""Handle that a connection was made to a device."""
self.atv = atv
self.async_device_connected(atv)
self.async_write_ha_state()
@callback
def _async_disconnected():
def _async_disconnected() -> None:
"""Handle that a connection to a device was lost."""
self.async_device_disconnected()
self.atv = None
@ -169,10 +167,10 @@ class AppleTVEntity(Entity):
)
)
def async_device_connected(self, atv):
def async_device_connected(self, atv: AppleTVInterface) -> None:
"""Handle when connection is made to device."""
def async_device_disconnected(self):
def async_device_disconnected(self) -> None:
"""Handle when connection was lost to device."""
@ -184,22 +182,23 @@ class AppleTVManager(DeviceListener):
in case of problems.
"""
atv: AppleTVInterface | None = None
_connection_attempts = 0
_connection_was_lost = False
_task: asyncio.Task[None] | None = None
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Initialize power manager."""
self.config_entry = config_entry
self.hass = hass
self.atv: AppleTVInterface | None = None
self.is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0
self._connection_was_lost = False
self._task = None
async def init(self):
async def init(self) -> None:
"""Initialize power management."""
if self.is_on:
await self.connect()
def connection_lost(self, _):
def connection_lost(self, exception: Exception) -> None:
"""Device was unexpectedly disconnected.
This is a callback function from pyatv.interface.DeviceListener.
@ -210,14 +209,14 @@ class AppleTVManager(DeviceListener):
self._connection_was_lost = True
self._handle_disconnect()
def connection_closed(self):
def connection_closed(self) -> None:
"""Device connection was (intentionally) closed.
This is a callback function from pyatv.interface.DeviceListener.
"""
self._handle_disconnect()
def _handle_disconnect(self):
def _handle_disconnect(self) -> None:
"""Handle that the device disconnected and restart connect loop."""
if self.atv:
self.atv.close()
@ -225,12 +224,12 @@ class AppleTVManager(DeviceListener):
self._dispatch_send(SIGNAL_DISCONNECTED)
self._start_connect_loop()
async def connect(self):
async def connect(self) -> None:
"""Connect to device."""
self.is_on = True
self._start_connect_loop()
async def disconnect(self):
async def disconnect(self) -> None:
"""Disconnect from device."""
_LOGGER.debug("Disconnecting from device")
self.is_on = False
@ -244,7 +243,7 @@ class AppleTVManager(DeviceListener):
except Exception: # pylint: disable=broad-except
_LOGGER.exception("An error occurred while disconnecting")
def _start_connect_loop(self):
def _start_connect_loop(self) -> None:
"""Start background connect loop to device."""
if not self._task and self.atv is None and self.is_on:
self._task = asyncio.create_task(self._connect_loop())
@ -258,7 +257,7 @@ class AppleTVManager(DeviceListener):
if conf := await self._scan():
await self._connect(conf, raise_missing_credentials)
async def async_first_connect(self):
async def async_first_connect(self) -> None:
"""Connect to device for the first time."""
connect_ok = False
try:
@ -286,7 +285,7 @@ class AppleTVManager(DeviceListener):
_LOGGER.exception("Failed to connect")
await self.disconnect()
async def _connect_loop(self):
async def _connect_loop(self) -> None:
"""Connect loop background task function."""
_LOGGER.debug("Starting connect loop")
@ -295,7 +294,8 @@ class AppleTVManager(DeviceListener):
while self.is_on and self.atv is None:
await self.connect_once(raise_missing_credentials=False)
if self.atv is not None:
break
# Calling self.connect_once may have set self.atv
break # type: ignore[unreachable]
self._connection_attempts += 1
backoff = min(
max(
@ -392,7 +392,7 @@ class AppleTVManager(DeviceListener):
self._connection_was_lost = False
@callback
def _async_setup_device_registry(self):
def _async_setup_device_registry(self) -> None:
attrs = {
ATTR_IDENTIFIERS: {(DOMAIN, self.config_entry.unique_id)},
ATTR_MANUFACTURER: "Apple",
@ -423,18 +423,18 @@ class AppleTVManager(DeviceListener):
)
@property
def is_connecting(self):
def is_connecting(self) -> bool:
"""Return true if connection is in progress."""
return self._task is not None
def _address_updated(self, address):
def _address_updated(self, address: str) -> None:
"""Update cached address in config entry."""
_LOGGER.debug("Changing address to %s", address)
self.hass.config_entries.async_update_entry(
self.config_entry, data={**self.config_entry.data, CONF_ADDRESS: address}
)
def _dispatch_send(self, signal, *args):
def _dispatch_send(self, signal: str, *args: Any) -> None:
"""Dispatch a signal to all entities managed by this manager."""
async_dispatcher_send(
self.hass, f"{signal}_{self.config_entry.unique_id}", *args

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from ipaddress import ip_address
import logging
from random import randrange
@ -13,12 +13,13 @@ from pyatv import exceptions, pair, scan
from pyatv.const import DeviceModel, PairingRequirement, Protocol
from pyatv.convert import model_str, protocol_str
from pyatv.helpers import get_unique_id
from pyatv.interface import BaseConfig, PairingHandler
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.components import zeroconf
from homeassistant.const import CONF_ADDRESS, CONF_NAME, CONF_PIN
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
@ -49,10 +50,12 @@ OPTIONS_FLOW = {
}
async def device_scan(hass, identifier, loop):
async def device_scan(
hass: HomeAssistant, identifier: str | None, loop: asyncio.AbstractEventLoop
) -> tuple[BaseConfig | None, list[str] | None]:
"""Scan for a specific device using identifier as filter."""
def _filter_device(dev):
def _filter_device(dev: BaseConfig) -> bool:
if identifier is None:
return True
if identifier == str(dev.address):
@ -61,9 +64,12 @@ async def device_scan(hass, identifier, loop):
return True
return any(service.identifier == identifier for service in dev.services)
def _host_filter():
def _host_filter() -> list[str] | None:
if identifier is None:
return None
try:
return [ip_address(identifier)]
ip_address(identifier)
return [identifier]
except ValueError:
return None
@ -84,6 +90,13 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1
scan_filter: str | None = None
atv: BaseConfig | None = None
atv_identifiers: list[str] | None = None
protocol: Protocol | None = None
pairing: PairingHandler | None = None
protocols_to_pair: deque[Protocol] | None = None
@staticmethod
@callback
def async_get_options_flow(
@ -92,18 +105,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Get options flow for this handler."""
return SchemaOptionsFlowHandler(config_entry, OPTIONS_FLOW)
def __init__(self):
def __init__(self) -> None:
"""Initialize a new AppleTVConfigFlow."""
self.scan_filter = None
self.atv = None
self.atv_identifiers = None
self.protocol = None
self.pairing = None
self.credentials = {} # Protocol -> credentials
self.protocols_to_pair = deque()
self.credentials: dict[int, str | None] = {} # Protocol -> credentials
@property
def device_identifier(self):
def device_identifier(self) -> str | None:
"""Return a identifier for the config entry.
A device has multiple unique identifiers, but Home Assistant only supports one
@ -118,6 +125,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
existing config entry. If that's the case, the unique_id from that entry is
re-used, otherwise the newly discovered identifier is used instead.
"""
assert self.atv
all_identifiers = set(self.atv.all_identifiers)
if unique_id := self._entry_unique_id_from_identifers(all_identifiers):
return unique_id
@ -143,7 +151,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self.context["identifier"] = self.unique_id
return await self.async_step_reconfigure()
async def async_step_reconfigure(self, user_input=None):
async def async_step_reconfigure(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Inform user that reconfiguration is about to start."""
if user_input is not None:
return await self.async_find_device_wrapper(
@ -152,7 +162,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_show_form(step_id="reconfigure")
async def async_step_user(self, user_input=None):
async def async_step_user(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle the initial step."""
errors = {}
if user_input is not None:
@ -170,6 +182,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
await self.async_set_unique_id(
self.device_identifier, raise_on_progress=False
)
assert self.atv
self.context["all_identifiers"] = self.atv.all_identifiers
return await self.async_step_confirm()
@ -275,8 +288,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
context["all_identifiers"].append(unique_id)
raise AbortFlow("already_in_progress")
async def async_found_zeroconf_device(self, user_input=None):
async def async_found_zeroconf_device(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle device found after Zeroconf discovery."""
assert self.atv
self.context["all_identifiers"] = self.atv.all_identifiers
# Also abort if an integration with this identifier already exists
await self.async_set_unique_id(self.device_identifier)
@ -288,7 +304,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self.context["identifier"] = self.unique_id
return await self.async_step_confirm()
async def async_find_device_wrapper(self, next_func, allow_exist=False):
async def async_find_device_wrapper(
self,
next_func: Callable[[], Awaitable[FlowResult]],
allow_exist: bool = False,
) -> FlowResult:
"""Find a specific device and call another function when done.
This function will do error handling and bail out when an error
@ -306,7 +326,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return await next_func()
async def async_find_device(self, allow_exist=False):
async def async_find_device(self, allow_exist: bool = False) -> None:
"""Scan for the selected device to discover services."""
self.atv, self.atv_identifiers = await device_scan(
self.hass, self.scan_filter, self.hass.loop
@ -357,8 +377,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if not allow_exist:
raise DeviceAlreadyConfigured()
async def async_step_confirm(self, user_input=None):
async def async_step_confirm(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle user-confirmation of discovered node."""
assert self.atv
if user_input is not None:
expected_identifier_count = len(self.context["all_identifiers"])
# If number of services found during device scan mismatch number of
@ -384,7 +407,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
},
)
async def async_pair_next_protocol(self):
async def async_pair_next_protocol(self) -> FlowResult:
"""Start pairing process for the next available protocol."""
await self._async_cleanup()
@ -393,8 +416,16 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return await self._async_get_entry()
self.protocol = self.protocols_to_pair.popleft()
assert self.atv
service = self.atv.get_service(self.protocol)
if service is None:
_LOGGER.debug(
"%s does not support pairing (cannot find a corresponding service)",
self.protocol,
)
return await self.async_pair_next_protocol()
# Service requires a password
if service.requires_password:
return await self.async_step_password()
@ -413,7 +444,7 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
_LOGGER.debug("%s requires pairing", self.protocol)
# Protocol specific arguments
pair_args = {}
pair_args: dict[str, Any] = {}
if self.protocol in {Protocol.AirPlay, Protocol.Companion, Protocol.DMAP}:
pair_args["name"] = "Home Assistant"
if self.protocol == Protocol.DMAP:
@ -448,8 +479,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return await self.async_step_pair_no_pin()
async def async_step_protocol_disabled(self, user_input=None):
async def async_step_protocol_disabled(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Inform user that a protocol is disabled and cannot be paired."""
assert self.protocol
if user_input is not None:
return await self.async_pair_next_protocol()
return self.async_show_form(
@ -457,9 +491,13 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"protocol": protocol_str(self.protocol)},
)
async def async_step_pair_with_pin(self, user_input=None):
async def async_step_pair_with_pin(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle pairing step where a PIN is required from the user."""
errors = {}
assert self.pairing
assert self.protocol
if user_input is not None:
try:
self.pairing.pin(user_input[CONF_PIN])
@ -480,8 +518,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"protocol": protocol_str(self.protocol)},
)
async def async_step_pair_no_pin(self, user_input=None):
async def async_step_pair_no_pin(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Handle step where user has to enter a PIN on the device."""
assert self.pairing
assert self.protocol
if user_input is not None:
await self.pairing.finish()
if self.pairing.has_paired:
@ -497,12 +539,15 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
step_id="pair_no_pin",
description_placeholders={
"protocol": protocol_str(self.protocol),
"pin": pin,
"pin": str(pin),
},
)
async def async_step_service_problem(self, user_input=None):
async def async_step_service_problem(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Inform user that a service will not be added."""
assert self.protocol
if user_input is not None:
return await self.async_pair_next_protocol()
@ -511,8 +556,11 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"protocol": protocol_str(self.protocol)},
)
async def async_step_password(self, user_input=None):
async def async_step_password(
self, user_input: dict[str, str] | None = None
) -> FlowResult:
"""Inform user that password is not supported."""
assert self.protocol
if user_input is not None:
return await self.async_pair_next_protocol()
@ -521,18 +569,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"protocol": protocol_str(self.protocol)},
)
async def _async_cleanup(self):
async def _async_cleanup(self) -> None:
"""Clean up allocated resources."""
if self.pairing is not None:
await self.pairing.close()
self.pairing = None
async def _async_get_entry(self):
async def _async_get_entry(self) -> FlowResult:
"""Return config entry or update existing config entry."""
# Abort if no protocols were paired
if not self.credentials:
return self.async_abort(reason="setup_failed")
assert self.atv
data = {
CONF_NAME: self.atv.name,
CONF_CREDENTIALS: self.credentials,

View File

@ -16,7 +16,15 @@ from pyatv.const import (
ShuffleState,
)
from pyatv.helpers import is_streamable
from pyatv.interface import AppleTV, Playing
from pyatv.interface import (
AppleTV,
AudioListener,
OutputDevice,
Playing,
PowerListener,
PushListener,
PushUpdater,
)
from homeassistant.components import media_source
from homeassistant.components.media_player import (
@ -101,7 +109,9 @@ async def async_setup_entry(
async_add_entities([AppleTvMediaPlayer(name, config_entry.unique_id, manager)])
class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
class AppleTvMediaPlayer(
AppleTVEntity, MediaPlayerEntity, PowerListener, AudioListener, PushListener
):
"""Representation of an Apple TV media player."""
_attr_supported_features = SUPPORT_APPLE_TV
@ -116,9 +126,9 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
def async_device_connected(self, atv: AppleTV) -> None:
"""Handle when connection is made to device."""
# NB: Do not use _is_feature_available here as it only works when playing
if self.atv.features.in_state(FeatureState.Available, FeatureName.PushUpdates):
self.atv.push_updater.listener = self
self.atv.push_updater.start()
if atv.features.in_state(FeatureState.Available, FeatureName.PushUpdates):
atv.push_updater.listener = self
atv.push_updater.start()
self._attr_supported_features = SUPPORT_BASE
@ -126,7 +136,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
# "Unsupported" are considered here as the state of such a feature can never
# change after a connection has been established, i.e. an unsupported feature
# can never change to be supported.
all_features = self.atv.features.all_features()
all_features = atv.features.all_features()
for feature_name, support_flag in SUPPORT_FEATURE_MAPPING.items():
feature_info = all_features.get(feature_name)
if feature_info and feature_info.state != FeatureState.Unsupported:
@ -136,16 +146,18 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
# metadata update arrives (sometime very soon after this callback returns)
# Listen to power updates
self.atv.power.listener = self
atv.power.listener = self
# Listen to volume updates
self.atv.audio.listener = self
atv.audio.listener = self
if self.atv.features.in_state(FeatureState.Available, FeatureName.AppList):
if atv.features.in_state(FeatureState.Available, FeatureName.AppList):
self.hass.create_task(self._update_app_list())
async def _update_app_list(self) -> None:
_LOGGER.debug("Updating app list")
if not self.atv:
return
try:
apps = await self.atv.apps.app_list()
except exceptions.NotSupportedError:
@ -189,33 +201,56 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
return None
@callback
def playstatus_update(self, _, playing: Playing) -> None:
"""Print what is currently playing when it changes."""
self._playing = playing
def playstatus_update(self, updater: PushUpdater, playstatus: Playing) -> None:
"""Print what is currently playing when it changes.
This is a callback function from pyatv.interface.PushListener.
"""
self._playing = playstatus
self.async_write_ha_state()
@callback
def playstatus_error(self, _, exception: Exception) -> None:
"""Inform about an error and restart push updates."""
def playstatus_error(self, updater: PushUpdater, exception: Exception) -> None:
"""Inform about an error and restart push updates.
This is a callback function from pyatv.interface.PushListener.
"""
_LOGGER.warning("A %s error occurred: %s", exception.__class__, exception)
self._playing = None
self.async_write_ha_state()
@callback
def powerstate_update(self, old_state: PowerState, new_state: PowerState) -> None:
"""Update power state when it changes."""
"""Update power state when it changes.
This is a callback function from pyatv.interface.PowerListener.
"""
self.async_write_ha_state()
@callback
def volume_update(self, old_level: float, new_level: float) -> None:
"""Update volume when it changes."""
"""Update volume when it changes.
This is a callback function from pyatv.interface.AudioListener.
"""
self.async_write_ha_state()
@callback
def outputdevices_update(
self, old_devices: list[OutputDevice], new_devices: list[OutputDevice]
) -> None:
"""Output devices were updated.
This is a callback function from pyatv.interface.AudioListener.
"""
@property
def app_id(self) -> str | None:
"""ID of the current running app."""
if self._is_feature_available(FeatureName.App) and (
app := self.atv.metadata.app
if (
self.atv
and self._is_feature_available(FeatureName.App)
and (app := self.atv.metadata.app) is not None
):
return app.identifier
return None
@ -223,8 +258,10 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
@property
def app_name(self) -> str | None:
"""Name of the current running app."""
if self._is_feature_available(FeatureName.App) and (
app := self.atv.metadata.app
if (
self.atv
and self._is_feature_available(FeatureName.App)
and (app := self.atv.metadata.app) is not None
):
return app.name
return None
@ -255,7 +292,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
@property
def volume_level(self) -> float | None:
"""Volume level of the media player (0..1)."""
if self._is_feature_available(FeatureName.Volume):
if self.atv and self._is_feature_available(FeatureName.Volume):
return self.atv.audio.volume / 100.0 # from percent
return None
@ -286,6 +323,8 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
"""Send the play_media command to the media player."""
# If input (file) has a file format supported by pyatv, then stream it with
# RAOP. Otherwise try to play it with regular AirPlay.
if not self.atv:
return
if media_type in {MediaType.APP, MediaType.URL}:
await self.atv.apps.launch_app(media_id)
return
@ -313,7 +352,8 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
"""Hash value for media image."""
state = self.state
if (
self._playing
self.atv
and self._playing
and self._is_feature_available(FeatureName.Artwork)
and state not in {None, MediaPlayerState.OFF, MediaPlayerState.IDLE}
):
@ -323,7 +363,11 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
async def async_get_media_image(self) -> tuple[bytes | None, str | None]:
"""Fetch media image of current playing image."""
state = self.state
if self._playing and state not in {MediaPlayerState.OFF, MediaPlayerState.IDLE}:
if (
self.atv
and self._playing
and state not in {MediaPlayerState.OFF, MediaPlayerState.IDLE}
):
artwork = await self.atv.metadata.artwork()
if artwork:
return artwork.bytes, artwork.mimetype
@ -439,20 +483,24 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
async def async_turn_on(self) -> None:
"""Turn the media player on."""
if self._is_feature_available(FeatureName.TurnOn):
if self.atv and self._is_feature_available(FeatureName.TurnOn):
await self.atv.power.turn_on()
async def async_turn_off(self) -> None:
"""Turn the media player off."""
if (self._is_feature_available(FeatureName.TurnOff)) and (
not self._is_feature_available(FeatureName.PowerState)
or self.atv.power.power_state == PowerState.On
if (
self.atv
and (self._is_feature_available(FeatureName.TurnOff))
and (
not self._is_feature_available(FeatureName.PowerState)
or self.atv.power.power_state == PowerState.On
)
):
await self.atv.power.turn_off()
async def async_media_play_pause(self) -> None:
"""Pause media on media player."""
if self._playing:
if self.atv and self._playing:
await self.atv.remote_control.play_pause()
async def async_media_play(self) -> None:
@ -519,5 +567,6 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
async def async_select_source(self, source: str) -> None:
"""Select input source."""
if app_id := self._app_list.get(source):
await self.atv.apps.launch_app(app_id)
if self.atv:
if app_id := self._app_list.get(source):
await self.atv.apps.launch_app(app_id)

View File

@ -15,7 +15,7 @@ from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from . import AppleTVEntity
from . import AppleTVEntity, AppleTVManager
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
@ -38,8 +38,10 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Load Apple TV remote based on a config entry."""
name = config_entry.data[CONF_NAME]
manager = hass.data[DOMAIN][config_entry.unique_id]
name: str = config_entry.data[CONF_NAME]
# apple_tv config entries always have a unique id
assert config_entry.unique_id is not None
manager: AppleTVManager = hass.data[DOMAIN][config_entry.unique_id]
async_add_entities([AppleTVRemote(name, config_entry.unique_id, manager)])
@ -47,7 +49,7 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
"""Device that sends commands to an Apple TV."""
@property
def is_on(self):
def is_on(self) -> bool:
"""Return true if device is on."""
return self.atv is not None
@ -64,13 +66,13 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
num_repeats = kwargs[ATTR_NUM_REPEATS]
delay = kwargs.get(ATTR_DELAY_SECS, DEFAULT_DELAY_SECS)
if not self.is_on:
if not self.atv:
_LOGGER.error("Unable to send commands, not connected to %s", self.name)
return
for _ in range(num_repeats):
for single_command in command:
attr_value = None
attr_value: Any = None
if attributes := COMMAND_TO_ATTRIBUTE.get(single_command):
attr_value = self.atv
for attr_name in attributes:
@ -81,5 +83,5 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
raise ValueError("Command not found. Exiting sequence")
_LOGGER.info("Sending command %s", single_command)
await attr_value() # type: ignore[operator]
await attr_value()
await asyncio.sleep(delay)

View File

@ -560,6 +560,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.apple_tv.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.apprise.*]
check_untyped_defs = true
disallow_incomplete_defs = true

View File

@ -1,6 +1,6 @@
"""Test config flow."""
from ipaddress import IPv4Address, ip_address
from unittest.mock import ANY, patch
from unittest.mock import ANY, Mock, patch
from pyatv import exceptions
from pyatv.const import PairingRequirement, Protocol
@ -125,7 +125,7 @@ async def test_user_adds_full_device(hass: HomeAssistant, full_device, pairing)
result["flow_id"], {"pin": 1111}
)
assert result4["type"] == data_entry_flow.FlowResultType.FORM
assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": 1111}
assert result4["description_placeholders"] == {"protocol": "DMAP", "pin": "1111"}
result5 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result5["type"] == data_entry_flow.FlowResultType.FORM
@ -167,7 +167,7 @@ async def test_user_adds_dmap_device(
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == data_entry_flow.FlowResultType.FORM
assert result3["description_placeholders"] == {"pin": 1111, "protocol": "DMAP"}
assert result3["description_placeholders"] == {"pin": "1111", "protocol": "DMAP"}
result6 = await hass.config_entries.flow.async_configure(
result["flow_id"], {"pin": 1234}
@ -646,7 +646,7 @@ async def test_zeroconf_add_dmap_device(
{},
)
assert result2["type"] == data_entry_flow.FlowResultType.FORM
assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": 1111}
assert result2["description_placeholders"] == {"protocol": "DMAP", "pin": "1111"}
result3 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result3["type"] == "create_entry"
@ -1130,6 +1130,43 @@ async def test_zeroconf_pair_additionally_found_protocols(
assert result5["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
async def test_zeroconf_mismatch(
hass: HomeAssistant, mock_scan, pairing, mock_zeroconf: None
) -> None:
"""Test the technically possible case where a protocol has no service.
This could happen in case of mDNS issues.
"""
mock_scan.result = [
create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service())
]
mock_scan.result[0].get_service = Mock(return_value=None)
# Find device with AirPlay service and set up flow for it
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_ZEROCONF},
data=zeroconf.ZeroconfServiceInfo(
ip_address=ip_address("127.0.0.1"),
ip_addresses=[ip_address("127.0.0.1")],
hostname="mock_hostname",
port=None,
type="_airplay._tcp.local.",
name="Kitchen",
properties={"deviceid": "airplayid"},
),
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result["type"] == data_entry_flow.FlowResultType.ABORT
assert result["reason"] == "setup_failed"
# Re-configuration