Add more type hints to helpers (#20811)

* Add type hints to helpers.aiohttp_client

* Add type hints to helpers.area_registry
This commit is contained in:
Ville Skyttä 2019-02-07 23:34:14 +02:00 committed by Paulus Schoutsen
parent 16159cc3d0
commit d45f25ce2c
3 changed files with 45 additions and 27 deletions

View File

@ -1,6 +1,9 @@
"""Helper for aiohttp webclient stuff.""" """Helper for aiohttp webclient stuff."""
import asyncio import asyncio
import sys import sys
from ssl import SSLContext # noqa: F401
from typing import Any, Awaitable, Optional, cast
from typing import Union # noqa: F401
import aiohttp import aiohttp
from aiohttp.hdrs import USER_AGENT, CONTENT_TYPE from aiohttp.hdrs import USER_AGENT, CONTENT_TYPE
@ -8,8 +11,9 @@ from aiohttp import web
from aiohttp.web_exceptions import HTTPGatewayTimeout, HTTPBadGateway from aiohttp.web_exceptions import HTTPGatewayTimeout, HTTPBadGateway
import async_timeout import async_timeout
from homeassistant.core import callback from homeassistant.core import callback, Event
from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__ from homeassistant.const import EVENT_HOMEASSISTANT_CLOSE, __version__
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util import ssl as ssl_util from homeassistant.util import ssl as ssl_util
@ -23,7 +27,8 @@ SERVER_SOFTWARE = 'HomeAssistant/{0} aiohttp/{1} Python/{2[0]}.{2[1]}'.format(
@callback @callback
@bind_hass @bind_hass
def async_get_clientsession(hass, verify_ssl=True): def async_get_clientsession(hass: HomeAssistantType,
verify_ssl: bool = True) -> aiohttp.ClientSession:
"""Return default aiohttp ClientSession. """Return default aiohttp ClientSession.
This method must be run in the event loop. This method must be run in the event loop.
@ -36,13 +41,15 @@ def async_get_clientsession(hass, verify_ssl=True):
if key not in hass.data: if key not in hass.data:
hass.data[key] = async_create_clientsession(hass, verify_ssl) hass.data[key] = async_create_clientsession(hass, verify_ssl)
return hass.data[key] return cast(aiohttp.ClientSession, hass.data[key])
@callback @callback
@bind_hass @bind_hass
def async_create_clientsession(hass, verify_ssl=True, auto_cleanup=True, def async_create_clientsession(hass: HomeAssistantType,
**kwargs): verify_ssl: bool = True,
auto_cleanup: bool = True,
**kwargs: Any) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies. """Create a new ClientSession with kwargs, i.e. for cookies.
If auto_cleanup is False, you need to call detach() after the session If auto_cleanup is False, you need to call detach() after the session
@ -67,8 +74,10 @@ def async_create_clientsession(hass, verify_ssl=True, auto_cleanup=True,
@bind_hass @bind_hass
async def async_aiohttp_proxy_web(hass, request, web_coro, async def async_aiohttp_proxy_web(
buffer_size=102400, timeout=10): hass: HomeAssistantType, request: web.BaseRequest,
web_coro: Awaitable[aiohttp.ClientResponse], buffer_size: int = 102400,
timeout: int = 10) -> Optional[web.StreamResponse]:
"""Stream websession request to aiohttp web response.""" """Stream websession request to aiohttp web response."""
try: try:
with async_timeout.timeout(timeout, loop=hass.loop): with async_timeout.timeout(timeout, loop=hass.loop):
@ -76,7 +85,7 @@ async def async_aiohttp_proxy_web(hass, request, web_coro,
except asyncio.CancelledError: except asyncio.CancelledError:
# The user cancelled the request # The user cancelled the request
return return None
except asyncio.TimeoutError as err: except asyncio.TimeoutError as err:
# Timeout trying to start the web request # Timeout trying to start the web request
@ -98,8 +107,12 @@ async def async_aiohttp_proxy_web(hass, request, web_coro,
@bind_hass @bind_hass
async def async_aiohttp_proxy_stream(hass, request, stream, content_type, async def async_aiohttp_proxy_stream(hass: HomeAssistantType,
buffer_size=102400, timeout=10): request: web.BaseRequest,
stream: aiohttp.StreamReader,
content_type: str,
buffer_size: int = 102400,
timeout: int = 10) -> web.StreamResponse:
"""Stream a stream to aiohttp web response.""" """Stream a stream to aiohttp web response."""
response = web.StreamResponse() response = web.StreamResponse()
response.content_type = content_type response.content_type = content_type
@ -122,13 +135,14 @@ async def async_aiohttp_proxy_stream(hass, request, stream, content_type,
@callback @callback
def _async_register_clientsession_shutdown(hass, clientsession): def _async_register_clientsession_shutdown(
hass: HomeAssistantType, clientsession: aiohttp.ClientSession) -> None:
"""Register ClientSession close on Home Assistant shutdown. """Register ClientSession close on Home Assistant shutdown.
This method must be run in the event loop. This method must be run in the event loop.
""" """
@callback @callback
def _async_close_websession(event): def _async_close_websession(event: Event) -> None:
"""Close websession.""" """Close websession."""
clientsession.detach() clientsession.detach()
@ -137,7 +151,8 @@ def _async_register_clientsession_shutdown(hass, clientsession):
@callback @callback
def _async_get_connector(hass, verify_ssl=True): def _async_get_connector(hass: HomeAssistantType,
verify_ssl: bool = True) -> aiohttp.BaseConnector:
"""Return the connector pool for aiohttp. """Return the connector pool for aiohttp.
This method must be run in the event loop. This method must be run in the event loop.
@ -145,17 +160,18 @@ def _async_get_connector(hass, verify_ssl=True):
key = DATA_CONNECTOR if verify_ssl else DATA_CONNECTOR_NOTVERIFY key = DATA_CONNECTOR if verify_ssl else DATA_CONNECTOR_NOTVERIFY
if key in hass.data: if key in hass.data:
return hass.data[key] return cast(aiohttp.BaseConnector, hass.data[key])
if verify_ssl: if verify_ssl:
ssl_context = ssl_util.client_context() ssl_context = \
ssl_util.client_context() # type: Union[bool, SSLContext]
else: else:
ssl_context = False ssl_context = False
connector = aiohttp.TCPConnector(loop=hass.loop, ssl=ssl_context) connector = aiohttp.TCPConnector(loop=hass.loop, ssl=ssl_context)
hass.data[key] = connector hass.data[key] = connector
async def _async_close_connector(event): async def _async_close_connector(event: Event) -> None:
"""Close connector pool.""" """Close connector pool."""
await connector.close() await connector.close()

View File

@ -2,12 +2,14 @@
import logging import logging
import uuid import uuid
from collections import OrderedDict from collections import OrderedDict
from typing import List, Optional from typing import MutableMapping # noqa: F401
from typing import Iterable, Optional, cast
import attr import attr
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from .typing import HomeAssistantType
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -29,14 +31,14 @@ class AreaEntry:
class AreaRegistry: class AreaRegistry:
"""Class to hold a registry of areas.""" """Class to hold a registry of areas."""
def __init__(self, hass) -> None: def __init__(self, hass: HomeAssistantType) -> None:
"""Initialize the area registry.""" """Initialize the area registry."""
self.hass = hass self.hass = hass
self.areas = None self.areas = {} # type: MutableMapping[str, AreaEntry]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
@callback @callback
def async_list_areas(self) -> List[AreaEntry]: def async_list_areas(self) -> Iterable[AreaEntry]:
"""Get all areas.""" """Get all areas."""
return self.areas.values() return self.areas.values()
@ -81,18 +83,18 @@ class AreaRegistry:
return new return new
@callback @callback
def _async_is_registered(self, name) -> Optional[AreaEntry]: def _async_is_registered(self, name: str) -> Optional[AreaEntry]:
"""Check if a name is currently registered.""" """Check if a name is currently registered."""
for area in self.areas.values(): for area in self.areas.values():
if name == area.name: if name == area.name:
return area return area
return False return None
async def async_load(self) -> None: async def async_load(self) -> None:
"""Load the area registry.""" """Load the area registry."""
data = await self._store.async_load() data = await self._store.async_load()
areas = OrderedDict() areas = OrderedDict() # type: OrderedDict[str, AreaEntry]
if data is not None: if data is not None:
for area in data['areas']: for area in data['areas']:
@ -124,16 +126,16 @@ class AreaRegistry:
@bind_hass @bind_hass
async def async_get_registry(hass) -> AreaRegistry: async def async_get_registry(hass: HomeAssistantType) -> AreaRegistry:
"""Return area registry instance.""" """Return area registry instance."""
task = hass.data.get(DATA_REGISTRY) task = hass.data.get(DATA_REGISTRY)
if task is None: if task is None:
async def _load_reg(): async def _load_reg() -> AreaRegistry:
registry = AreaRegistry(hass) registry = AreaRegistry(hass)
await registry.async_load() await registry.async_load()
return registry return registry
task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg()) task = hass.data[DATA_REGISTRY] = hass.async_create_task(_load_reg())
return await task return cast(AreaRegistry, await task)

View File

@ -60,4 +60,4 @@ whitelist_externals=/bin/bash
deps = deps =
-r{toxinidir}/requirements_test.txt -r{toxinidir}/requirements_test.txt
commands = commands =
/bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,condition,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py' /bin/bash -c 'mypy homeassistant/*.py homeassistant/{auth,util}/ homeassistant/helpers/{__init__,aiohttp_client,area_registry,condition,deprecation,dispatcher,entity_values,entityfilter,icon,intent,json,location,signal,state,sun,temperature,translation,typing}.py'