More data entry flow and HTTP related type hints (#34430)

This commit is contained in:
Ville Skyttä 2020-05-26 17:28:22 +03:00 committed by GitHub
parent bc1dac80b6
commit f8416484f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 79 additions and 56 deletions

View File

@ -172,7 +172,7 @@ class CalendarEventView(http.HomeAssistantView):
url = "/api/calendars/{entity_id}" url = "/api/calendars/{entity_id}"
name = "api:calendars:calendar" name = "api:calendars:calendar"
def __init__(self, component): def __init__(self, component: EntityComponent) -> None:
"""Initialize calendar view.""" """Initialize calendar view."""
self.component = component self.component = component
@ -200,11 +200,11 @@ class CalendarListView(http.HomeAssistantView):
url = "/api/calendars" url = "/api/calendars"
name = "api:calendars" name = "api:calendars"
def __init__(self, component): def __init__(self, component: EntityComponent) -> None:
"""Initialize calendar view.""" """Initialize calendar view."""
self.component = component self.component = component
async def get(self, request): async def get(self, request: web.Request) -> web.Response:
"""Retrieve calendar list.""" """Retrieve calendar list."""
hass = request.app["hass"] hass = request.app["hass"]
calendar_list = [] calendar_list = []

View File

@ -473,11 +473,11 @@ class CameraView(HomeAssistantView):
requires_auth = False requires_auth = False
def __init__(self, component): def __init__(self, component: EntityComponent) -> None:
"""Initialize a basic camera view.""" """Initialize a basic camera view."""
self.component = component self.component = component
async def get(self, request, entity_id): async def get(self, request: web.Request, entity_id: str) -> web.Response:
"""Start a GET request.""" """Start a GET request."""
camera = self.component.get_entity(entity_id) camera = self.component.get_entity(entity_id)
@ -509,7 +509,7 @@ class CameraImageView(CameraView):
url = "/api/camera_proxy/{entity_id}" url = "/api/camera_proxy/{entity_id}"
name = "api:camera:image" name = "api:camera:image"
async def handle(self, request, camera): async def handle(self, request: web.Request, camera: Camera) -> web.Response:
"""Serve camera image.""" """Serve camera image."""
with suppress(asyncio.CancelledError, asyncio.TimeoutError): with suppress(asyncio.CancelledError, asyncio.TimeoutError):
async with async_timeout.timeout(10): async with async_timeout.timeout(10):
@ -527,7 +527,7 @@ class CameraMjpegStream(CameraView):
url = "/api/camera_proxy_stream/{entity_id}" url = "/api/camera_proxy_stream/{entity_id}"
name = "api:camera:stream" name = "api:camera:stream"
async def handle(self, request, camera): async def handle(self, request: web.Request, camera: Camera) -> web.Response:
"""Serve camera stream, possibly with interval.""" """Serve camera stream, possibly with interval."""
interval = request.query.get("interval") interval = request.query.get("interval")
if interval is None: if interval is None:

View File

@ -4,7 +4,9 @@ from datetime import timedelta
from itertools import groupby from itertools import groupby
import logging import logging
import time import time
from typing import Optional, cast
from aiohttp import web
from sqlalchemy import and_, func from sqlalchemy import and_, func
import voluptuous as vol import voluptuous as vol
@ -337,20 +339,22 @@ class HistoryPeriodView(HomeAssistantView):
self.filters = filters self.filters = filters
self.use_include_order = use_include_order self.use_include_order = use_include_order
async def get(self, request, datetime=None): async def get(
self, request: web.Request, datetime: Optional[str] = None
) -> web.Response:
"""Return history over a period of time.""" """Return history over a period of time."""
if datetime: if datetime:
datetime = dt_util.parse_datetime(datetime) datetime_ = dt_util.parse_datetime(datetime)
if datetime is None: if datetime_ is None:
return self.json_message("Invalid datetime", HTTP_BAD_REQUEST) return self.json_message("Invalid datetime", HTTP_BAD_REQUEST)
now = dt_util.utcnow() now = dt_util.utcnow()
one_day = timedelta(days=1) one_day = timedelta(days=1)
if datetime: if datetime_:
start_time = dt_util.as_utc(datetime) start_time = dt_util.as_utc(datetime_)
else: else:
start_time = now - one_day start_time = now - one_day
@ -376,14 +380,17 @@ class HistoryPeriodView(HomeAssistantView):
hass = request.app["hass"] hass = request.app["hass"]
return await hass.async_add_executor_job( return cast(
self._sorted_significant_states_json, web.Response,
hass, await hass.async_add_executor_job(
start_time, self._sorted_significant_states_json,
end_time, hass,
entity_ids, start_time,
include_start_time_state, end_time,
significant_changes_only, entity_ids,
include_start_time_state,
significant_changes_only,
),
) )
def _sorted_significant_states_json( def _sorted_significant_states_json(

View File

@ -1,12 +1,14 @@
"""Decorator for view methods to help with data validation.""" """Decorator for view methods to help with data validation."""
from functools import wraps from functools import wraps
import logging import logging
from typing import Any, Awaitable, Callable
from aiohttp import web
import voluptuous as vol import voluptuous as vol
from homeassistant.const import HTTP_BAD_REQUEST from homeassistant.const import HTTP_BAD_REQUEST
# mypy: allow-untyped-defs from .view import HomeAssistantView
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -20,7 +22,7 @@ class RequestDataValidator:
Will return a 400 if no JSON provided or doesn't match schema. Will return a 400 if no JSON provided or doesn't match schema.
""" """
def __init__(self, schema, allow_empty=False): def __init__(self, schema: vol.Schema, allow_empty: bool = False) -> None:
"""Initialize the decorator.""" """Initialize the decorator."""
if isinstance(schema, dict): if isinstance(schema, dict):
schema = vol.Schema(schema) schema = vol.Schema(schema)
@ -28,11 +30,15 @@ class RequestDataValidator:
self._schema = schema self._schema = schema
self._allow_empty = allow_empty self._allow_empty = allow_empty
def __call__(self, method): def __call__(
self, method: Callable[..., Awaitable[web.StreamResponse]]
) -> Callable:
"""Decorate a function.""" """Decorate a function."""
@wraps(method) @wraps(method)
async def wrapper(view, request, *args, **kwargs): async def wrapper(
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
) -> web.StreamResponse:
"""Wrap a request handler with data validation.""" """Wrap a request handler with data validation."""
data = None data = None
try: try:

View File

@ -2,9 +2,10 @@
import asyncio import asyncio
import json import json
import logging import logging
from typing import List, Optional from typing import Any, Callable, List, Optional
from aiohttp import web from aiohttp import web
from aiohttp.typedefs import LooseHeaders
from aiohttp.web_exceptions import ( from aiohttp.web_exceptions import (
HTTPBadRequest, HTTPBadRequest,
HTTPInternalServerError, HTTPInternalServerError,
@ -22,9 +23,6 @@ from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_REAL_IP
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-defs, no-check-untyped-defs
class HomeAssistantView: class HomeAssistantView:
"""Base view for all views.""" """Base view for all views."""
@ -35,7 +33,7 @@ class HomeAssistantView:
cors_allowed = False cors_allowed = False
@staticmethod @staticmethod
def context(request): def context(request: web.Request) -> Context:
"""Generate a context from a request.""" """Generate a context from a request."""
user = request.get("hass_user") user = request.get("hass_user")
if user is None: if user is None:
@ -44,7 +42,9 @@ class HomeAssistantView:
return Context(user_id=user.id) return Context(user_id=user.id)
@staticmethod @staticmethod
def json(result, status_code=HTTP_OK, headers=None): def json(
result: Any, status_code: int = HTTP_OK, headers: Optional[LooseHeaders] = None,
) -> web.Response:
"""Return a JSON response.""" """Return a JSON response."""
try: try:
msg = json.dumps( msg = json.dumps(
@ -63,15 +63,19 @@ class HomeAssistantView:
return response return response
def json_message( def json_message(
self, message, status_code=HTTP_OK, message_code=None, headers=None self,
): message: str,
status_code: int = HTTP_OK,
message_code: Optional[str] = None,
headers: Optional[LooseHeaders] = None,
) -> web.Response:
"""Return a JSON message response.""" """Return a JSON message response."""
data = {"message": message} data = {"message": message}
if message_code is not None: if message_code is not None:
data["code"] = message_code data["code"] = message_code
return self.json(data, status_code, headers=headers) return self.json(data, status_code, headers=headers)
def register(self, app, router): def register(self, app: web.Application, router: web.UrlDispatcher) -> None:
"""Register the view with a router.""" """Register the view with a router."""
assert self.url is not None, "No url set for view" assert self.url is not None, "No url set for view"
urls = [self.url] + self.extra_urls urls = [self.url] + self.extra_urls
@ -95,13 +99,13 @@ class HomeAssistantView:
app["allow_cors"](route) app["allow_cors"](route)
def request_handler_factory(view, handler): def request_handler_factory(view: HomeAssistantView, handler: Callable) -> Callable:
"""Wrap the handler classes.""" """Wrap the handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback( assert asyncio.iscoroutinefunction(handler) or is_callback(
handler handler
), "Handler should be a coroutine or a callback." ), "Handler should be a coroutine or a callback."
async def handle(request): async def handle(request: web.Request) -> web.StreamResponse:
"""Handle incoming request.""" """Handle incoming request."""
if not request.app[KEY_HASS].is_running: if not request.app[KEY_HASS].is_running:
return web.Response(status=503) return web.Response(status=503)
@ -139,15 +143,17 @@ def request_handler_factory(view, handler):
if isinstance(result, tuple): if isinstance(result, tuple):
result, status_code = result result, status_code = result
if isinstance(result, str): if isinstance(result, bytes):
result = result.encode("utf-8") bresult = result
elif isinstance(result, str):
bresult = result.encode("utf-8")
elif result is None: elif result is None:
result = b"" bresult = b""
elif not isinstance(result, bytes): else:
assert ( assert (
False False
), f"Result should be None, string, bytes or Response. Got: {result}" ), f"Result should be None, string, bytes or Response. Got: {result}"
return web.Response(body=result, status=status_code) return web.Response(body=bresult, status=status_code)
return handle return handle

View File

@ -200,7 +200,7 @@ class MailboxPlatformsView(MailboxView):
url = "/api/mailbox/platforms" url = "/api/mailbox/platforms"
name = "api:mailbox:platforms" name = "api:mailbox:platforms"
async def get(self, request): async def get(self, request: web.Request) -> web.Response:
"""Retrieve list of platforms.""" """Retrieve list of platforms."""
platforms = [] platforms = []
for mailbox in self.mailboxes: for mailbox in self.mailboxes:

View File

@ -12,6 +12,7 @@ from urllib.parse import urlparse
from aiohttp import web from aiohttp import web
from aiohttp.hdrs import CACHE_CONTROL, CONTENT_TYPE from aiohttp.hdrs import CACHE_CONTROL, CONTENT_TYPE
from aiohttp.typedefs import LooseHeaders
import async_timeout import async_timeout
import voluptuous as vol import voluptuous as vol
@ -863,7 +864,7 @@ class MediaPlayerImageView(HomeAssistantView):
"""Initialize a media player view.""" """Initialize a media player view."""
self.component = component self.component = component
async def get(self, request, entity_id): async def get(self, request: web.Request, entity_id: str) -> web.Response:
"""Start a get request.""" """Start a get request."""
player = self.component.get_entity(entity_id) player = self.component.get_entity(entity_id)
if player is None: if player is None:
@ -883,7 +884,7 @@ class MediaPlayerImageView(HomeAssistantView):
if data is None: if data is None:
return web.Response(status=HTTP_INTERNAL_SERVER_ERROR) return web.Response(status=HTTP_INTERNAL_SERVER_ERROR)
headers = {CACHE_CONTROL: "max-age=3600"} headers: LooseHeaders = {CACHE_CONTROL: "max-age=3600"}
return web.Response(body=data, content_type=content_type, headers=headers) return web.Response(body=data, content_type=content_type, headers=headers)

View File

@ -530,7 +530,7 @@ class TextToSpeechUrlView(HomeAssistantView):
"""Initialize a tts view.""" """Initialize a tts view."""
self.tts = tts self.tts = tts
async def post(self, request): async def post(self, request: web.Request) -> web.Response:
"""Generate speech and provide url.""" """Generate speech and provide url."""
try: try:
data = await request.json() data = await request.json()
@ -570,7 +570,7 @@ class TextToSpeechView(HomeAssistantView):
"""Initialize a tts view.""" """Initialize a tts view."""
self.tts = tts self.tts = tts
async def get(self, request, filename): async def get(self, request: web.Request, filename: str) -> web.Response:
"""Start a get request.""" """Start a get request."""
try: try:
content, data = await self.tts.async_read_tts(filename) content, data = await self.tts.async_read_tts(filename)

View File

@ -42,7 +42,7 @@ class WebsocketAPIView(HomeAssistantView):
url = URL url = URL
requires_auth = False requires_auth = False
async def get(self, request): async def get(self, request: web.Request) -> web.WebSocketResponse:
"""Handle an incoming websocket connection.""" """Handle an incoming websocket connection."""
return await WebSocketHandler(request.app["hass"], request).async_handle() return await WebSocketHandler(request.app["hass"], request).async_handle()
@ -148,7 +148,7 @@ class WebSocketHandler:
self._handle_task.cancel() self._handle_task.cancel()
self._writer_task.cancel() self._writer_task.cancel()
async def async_handle(self): async def async_handle(self) -> web.WebSocketResponse:
"""Handle a websocket response.""" """Handle a websocket response."""
request = self.request request = self.request
wsock = self.wsock = web.WebSocketResponse(heartbeat=55) wsock = self.wsock = web.WebSocketResponse(heartbeat=55)

View File

@ -1,5 +1,8 @@
"""Helpers for the data entry flow.""" """Helpers for the data entry flow."""
from typing import Any, Dict
from aiohttp import web
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries, data_entry_flow from homeassistant import config_entries, data_entry_flow
@ -8,18 +11,16 @@ from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import HTTP_NOT_FOUND from homeassistant.const import HTTP_NOT_FOUND
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
# mypy: allow-untyped-calls, allow-untyped-defs
class _BaseFlowManagerView(HomeAssistantView): class _BaseFlowManagerView(HomeAssistantView):
"""Foundation for flow manager views.""" """Foundation for flow manager views."""
def __init__(self, flow_mgr): def __init__(self, flow_mgr: data_entry_flow.FlowManager) -> None:
"""Initialize the flow manager index view.""" """Initialize the flow manager index view."""
self._flow_mgr = flow_mgr self._flow_mgr = flow_mgr
# pylint: disable=no-self-use # pylint: disable=no-self-use
def _prepare_result_json(self, result): def _prepare_result_json(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""Convert result to JSON.""" """Convert result to JSON."""
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY: if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
data = result.copy() data = result.copy()
@ -57,7 +58,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
) )
async def post(self, request, data): async def post(self, request: web.Request, data: Dict[str, Any]) -> web.Response:
"""Handle a POST request.""" """Handle a POST request."""
if isinstance(data["handler"], list): if isinstance(data["handler"], list):
handler = tuple(data["handler"]) handler = tuple(data["handler"])
@ -66,7 +67,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
try: try:
result = await self._flow_mgr.async_init( result = await self._flow_mgr.async_init(
handler, handler, # type: ignore
context={ context={
"source": config_entries.SOURCE_USER, "source": config_entries.SOURCE_USER,
"show_advanced_options": data["show_advanced_options"], "show_advanced_options": data["show_advanced_options"],
@ -85,7 +86,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
class FlowManagerResourceView(_BaseFlowManagerView): class FlowManagerResourceView(_BaseFlowManagerView):
"""View to interact with the flow manager.""" """View to interact with the flow manager."""
async def get(self, request, flow_id): async def get(self, request: web.Request, flow_id: str) -> web.Response:
"""Get the current state of a data_entry_flow.""" """Get the current state of a data_entry_flow."""
try: try:
result = await self._flow_mgr.async_configure(flow_id) result = await self._flow_mgr.async_configure(flow_id)
@ -97,7 +98,9 @@ class FlowManagerResourceView(_BaseFlowManagerView):
return self.json(result) return self.json(result)
@RequestDataValidator(vol.Schema(dict), allow_empty=True) @RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(self, request, flow_id, data): async def post(
self, request: web.Request, flow_id: str, data: Dict[str, Any]
) -> web.Response:
"""Handle a POST request.""" """Handle a POST request."""
try: try:
result = await self._flow_mgr.async_configure(flow_id, data) result = await self._flow_mgr.async_configure(flow_id, data)
@ -110,7 +113,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
return self.json(result) return self.json(result)
async def delete(self, request, flow_id): async def delete(self, request: web.Request, flow_id: str) -> web.Response:
"""Cancel a flow in progress.""" """Cancel a flow in progress."""
try: try:
self._flow_mgr.async_abort(flow_id) self._flow_mgr.async_abort(flow_id)