Improve apple_tv typing (#107694)

This commit is contained in:
J. Nick Koston 2024-01-13 22:37:04 -10:00 committed by GitHub
parent 4b8d8baa69
commit 93d363ea57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 30 deletions

View File

@ -2,10 +2,13 @@
import asyncio import asyncio
import logging import logging
from random import randrange from random import randrange
from typing import TYPE_CHECKING, cast
from pyatv import connect, exceptions, scan from pyatv import connect, exceptions, scan
from pyatv.conf import AppleTV
from pyatv.const import DeviceModel, Protocol from pyatv.const import DeviceModel, Protocol
from pyatv.convert import model_str from pyatv.convert import model_str
from pyatv.interface import AppleTV as AppleTVInterface, DeviceListener
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -92,10 +95,14 @@ class AppleTVEntity(Entity):
_attr_has_entity_name = True _attr_has_entity_name = True
_attr_name = None _attr_name = None
def __init__(self, name, identifier, manager): def __init__(
self, name: str, identifier: str | None, manager: "AppleTVManager"
) -> None:
"""Initialize device.""" """Initialize device."""
self.atv = None self.atv: AppleTVInterface = None # type: ignore[assignment]
self.manager = manager self.manager = manager
if TYPE_CHECKING:
assert identifier is not None
self._attr_unique_id = identifier self._attr_unique_id = identifier
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, identifier)}, identifiers={(DOMAIN, identifier)},
@ -143,7 +150,7 @@ class AppleTVEntity(Entity):
"""Handle when connection was lost to device.""" """Handle when connection was lost to device."""
class AppleTVManager: class AppleTVManager(DeviceListener):
"""Connection and power manager for an Apple TV. """Connection and power manager for an Apple TV.
An instance is used per device to share the same power state between An instance is used per device to share the same power state between
@ -151,11 +158,11 @@ class AppleTVManager:
in case of problems. in case of problems.
""" """
def __init__(self, hass, config_entry): def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Initialize power manager.""" """Initialize power manager."""
self.config_entry = config_entry self.config_entry = config_entry
self.hass = hass self.hass = hass
self.atv = None self.atv: AppleTVInterface | None = None
self.is_on = not config_entry.options.get(CONF_START_OFF, False) self.is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0 self._connection_attempts = 0
self._connection_was_lost = False self._connection_was_lost = False
@ -220,7 +227,7 @@ class AppleTVManager:
"Not starting connect loop (%s, %s)", self.atv is None, self.is_on "Not starting connect loop (%s, %s)", self.atv is None, self.is_on
) )
async def connect_once(self, raise_missing_credentials): async def connect_once(self, raise_missing_credentials: bool) -> None:
"""Try to connect once.""" """Try to connect once."""
try: try:
if conf := await self._scan(): if conf := await self._scan():
@ -264,49 +271,51 @@ class AppleTVManager:
_LOGGER.debug("Connect loop ended") _LOGGER.debug("Connect loop ended")
self._task = None self._task = None
async def _scan(self): async def _scan(self) -> AppleTV | None:
"""Try to find device by scanning for it.""" """Try to find device by scanning for it."""
identifiers = set( config_entry = self.config_entry
self.config_entry.data.get(CONF_IDENTIFIERS, [self.config_entry.unique_id]) identifiers: set[str] = set(
config_entry.data.get(CONF_IDENTIFIERS, [config_entry.unique_id])
) )
address = self.config_entry.data[CONF_ADDRESS] address: str = config_entry.data[CONF_ADDRESS]
hass = self.hass
# Only scan for and set up protocols that was successfully paired # Only scan for and set up protocols that was successfully paired
protocols = { protocols = {
Protocol(int(protocol)) Protocol(int(protocol)) for protocol in config_entry.data[CONF_CREDENTIALS]
for protocol in self.config_entry.data[CONF_CREDENTIALS]
} }
_LOGGER.debug("Discovering device %s", self.config_entry.title) _LOGGER.debug("Discovering device %s", config_entry.title)
aiozc = await zeroconf.async_get_async_instance(self.hass) aiozc = await zeroconf.async_get_async_instance(hass)
atvs = await scan( atvs = await scan(
self.hass.loop, hass.loop,
identifier=identifiers, identifier=identifiers,
protocol=protocols, protocol=protocols,
hosts=[address], hosts=[address],
aiozc=aiozc, aiozc=aiozc,
) )
if atvs: if atvs:
return atvs[0] return cast(AppleTV, atvs[0])
_LOGGER.debug( _LOGGER.debug(
"Failed to find device %s with address %s", "Failed to find device %s with address %s",
self.config_entry.title, config_entry.title,
address, address,
) )
# We no longer multicast scan for the device since as soon as async_step_zeroconf runs, # We no longer multicast scan for the device since as soon as async_step_zeroconf runs,
# it will update the address and reload the config entry when the device is found. # it will update the address and reload the config entry when the device is found.
return None return None
async def _connect(self, conf, raise_missing_credentials): async def _connect(self, conf: AppleTV, raise_missing_credentials: bool) -> None:
"""Connect to device.""" """Connect to device."""
credentials = self.config_entry.data[CONF_CREDENTIALS] config_entry = self.config_entry
name = self.config_entry.data[CONF_NAME] credentials: dict[int, str | None] = config_entry.data[CONF_CREDENTIALS]
name: str = config_entry.data[CONF_NAME]
missing_protocols = [] missing_protocols = []
for protocol_int, creds in credentials.items(): for protocol_int, creds in credentials.items():
protocol = Protocol(int(protocol_int)) protocol = Protocol(int(protocol_int))
if conf.get_service(protocol) is not None: if conf.get_service(protocol) is not None:
conf.set_credentials(protocol, creds) conf.set_credentials(protocol, creds) # type: ignore[arg-type]
else: else:
missing_protocols.append(protocol.name) missing_protocols.append(protocol.name)

View File

@ -154,9 +154,9 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
_LOGGER.exception("Failed to update app list") _LOGGER.exception("Failed to update app list")
else: else:
self._app_list = { self._app_list = {
app.name: app.identifier app_name: app.identifier
for app in sorted(apps, key=lambda app: app.name.lower()) for app in sorted(apps, key=lambda app: app_name.lower())
if app.name is not None if (app_name := app.name) is not None
} }
self.async_write_ha_state() self.async_write_ha_state()
@ -214,15 +214,19 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
@property @property
def app_id(self) -> str | None: def app_id(self) -> str | None:
"""ID of the current running app.""" """ID of the current running app."""
if self._is_feature_available(FeatureName.App): if self._is_feature_available(FeatureName.App) and (
return self.atv.metadata.app.identifier app := self.atv.metadata.app
):
return app.identifier
return None return None
@property @property
def app_name(self) -> str | None: def app_name(self) -> str | None:
"""Name of the current running app.""" """Name of the current running app."""
if self._is_feature_available(FeatureName.App): if self._is_feature_available(FeatureName.App) and (
return self.atv.metadata.app.name app := self.atv.metadata.app
):
return app.name
return None return None
@property @property
@ -479,7 +483,7 @@ class AppleTvMediaPlayer(AppleTVEntity, MediaPlayerEntity):
async def async_media_seek(self, position: float) -> None: async def async_media_seek(self, position: float) -> None:
"""Send seek command.""" """Send seek command."""
if self.atv: if self.atv:
await self.atv.remote_control.set_position(position) await self.atv.remote_control.set_position(round(position))
async def async_volume_up(self) -> None: async def async_volume_up(self) -> None:
"""Turn volume up for media player.""" """Turn volume up for media player."""

View File

@ -81,5 +81,5 @@ class AppleTVRemote(AppleTVEntity, RemoteEntity):
raise ValueError("Command not found. Exiting sequence") raise ValueError("Command not found. Exiting sequence")
_LOGGER.info("Sending command %s", single_command) _LOGGER.info("Sending command %s", single_command)
await attr_value() await attr_value() # type: ignore[operator]
await asyncio.sleep(delay) await asyncio.sleep(delay)