Add strict typing for auth (#75586)

This commit is contained in:
Marc Mueller 2022-08-16 16:10:37 +02:00 committed by GitHub
parent 735dec8dde
commit 563ec67d39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 174 additions and 85 deletions

View File

@ -59,6 +59,7 @@ homeassistant.components.ampio.*
homeassistant.components.anthemav.* homeassistant.components.anthemav.*
homeassistant.components.aseko_pool_live.* homeassistant.components.aseko_pool_live.*
homeassistant.components.asuswrt.* homeassistant.components.asuswrt.*
homeassistant.components.auth.*
homeassistant.components.automation.* homeassistant.components.automation.*
homeassistant.components.backup.* homeassistant.components.backup.*
homeassistant.components.baf.* homeassistant.components.baf.*

View File

@ -124,15 +124,22 @@ as part of a config flow.
""" """
from __future__ import annotations from __future__ import annotations
from datetime import timedelta from collections.abc import Callable
from datetime import datetime, timedelta
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Optional, cast
import uuid import uuid
from aiohttp import web from aiohttp import web
from multidict import MultiDictProxy
import voluptuous as vol import voluptuous as vol
from homeassistant.auth import InvalidAuthError from homeassistant.auth import InvalidAuthError
from homeassistant.auth.models import TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials from homeassistant.auth.models import (
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
Credentials,
User,
)
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.http.auth import ( from homeassistant.components.http.auth import (
async_sign_path, async_sign_path,
@ -151,11 +158,16 @@ from . import indieauth, login_flow, mfa_setup_flow
DOMAIN = "auth" DOMAIN = "auth"
StoreResultType = Callable[[str, Credentials], str]
RetrieveResultType = Callable[[str, str], Optional[Credentials]]
@bind_hass @bind_hass
def create_auth_code(hass, client_id: str, credential: Credentials) -> str: def create_auth_code(
hass: HomeAssistant, client_id: str, credential: Credentials
) -> str:
"""Create an authorization code to fetch tokens.""" """Create an authorization code to fetch tokens."""
return hass.data[DOMAIN](client_id, credential) return cast(StoreResultType, hass.data[DOMAIN])(client_id, credential)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -188,15 +200,15 @@ class TokenView(HomeAssistantView):
requires_auth = False requires_auth = False
cors_allowed = True cors_allowed = True
def __init__(self, retrieve_auth): def __init__(self, retrieve_auth: RetrieveResultType) -> None:
"""Initialize the token view.""" """Initialize the token view."""
self._retrieve_auth = retrieve_auth self._retrieve_auth = retrieve_auth
@log_invalid_auth @log_invalid_auth
async def post(self, request): async def post(self, request: web.Request) -> web.Response:
"""Grant a token.""" """Grant a token."""
hass = request.app["hass"] hass: HomeAssistant = request.app["hass"]
data = await request.post() data = cast(MultiDictProxy[str], await request.post())
grant_type = data.get("grant_type") grant_type = data.get("grant_type")
@ -217,7 +229,11 @@ class TokenView(HomeAssistantView):
{"error": "unsupported_grant_type"}, status_code=HTTPStatus.BAD_REQUEST {"error": "unsupported_grant_type"}, status_code=HTTPStatus.BAD_REQUEST
) )
async def _async_handle_revoke_token(self, hass, data): async def _async_handle_revoke_token(
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
) -> web.Response:
"""Handle revoke token request.""" """Handle revoke token request."""
# OAuth 2.0 Token Revocation [RFC7009] # OAuth 2.0 Token Revocation [RFC7009]
@ -235,7 +251,12 @@ class TokenView(HomeAssistantView):
await hass.auth.async_remove_refresh_token(refresh_token) await hass.auth.async_remove_refresh_token(refresh_token)
return web.Response(status=HTTPStatus.OK) return web.Response(status=HTTPStatus.OK)
async def _async_handle_auth_code(self, hass, data, remote_addr): async def _async_handle_auth_code(
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
remote_addr: str | None,
) -> web.Response:
"""Handle authorization code request.""" """Handle authorization code request."""
client_id = data.get("client_id") client_id = data.get("client_id")
if client_id is None or not indieauth.verify_client_id(client_id): if client_id is None or not indieauth.verify_client_id(client_id):
@ -298,7 +319,12 @@ class TokenView(HomeAssistantView):
}, },
) )
async def _async_handle_refresh_token(self, hass, data, remote_addr): async def _async_handle_refresh_token(
self,
hass: HomeAssistant,
data: MultiDictProxy[str],
remote_addr: str | None,
) -> web.Response:
"""Handle authorization code request.""" """Handle authorization code request."""
client_id = data.get("client_id") client_id = data.get("client_id")
if client_id is not None and not indieauth.verify_client_id(client_id): if client_id is not None and not indieauth.verify_client_id(client_id):
@ -366,15 +392,15 @@ class LinkUserView(HomeAssistantView):
url = "/auth/link_user" url = "/auth/link_user"
name = "api:auth:link_user" name = "api:auth:link_user"
def __init__(self, retrieve_credentials): def __init__(self, retrieve_credentials: RetrieveResultType) -> None:
"""Initialize the link user view.""" """Initialize the link user view."""
self._retrieve_credentials = retrieve_credentials self._retrieve_credentials = retrieve_credentials
@RequestDataValidator(vol.Schema({"code": str, "client_id": str})) @RequestDataValidator(vol.Schema({"code": str, "client_id": str}))
async def post(self, request, data): async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Link a user.""" """Link a user."""
hass = request.app["hass"] hass: HomeAssistant = request.app["hass"]
user = request["hass_user"] user: User = request["hass_user"]
credentials = self._retrieve_credentials(data["client_id"], data["code"]) credentials = self._retrieve_credentials(data["client_id"], data["code"])
@ -394,12 +420,12 @@ class LinkUserView(HomeAssistantView):
@callback @callback
def _create_auth_code_store(): def _create_auth_code_store() -> tuple[StoreResultType, RetrieveResultType]:
"""Create an in memory store.""" """Create an in memory store."""
temp_results = {} temp_results: dict[tuple[str, str], tuple[datetime, Credentials]] = {}
@callback @callback
def store_result(client_id, result): def store_result(client_id: str, result: Credentials) -> str:
"""Store flow result and return a code to retrieve it.""" """Store flow result and return a code to retrieve it."""
if not isinstance(result, Credentials): if not isinstance(result, Credentials):
raise ValueError("result has to be a Credentials instance") raise ValueError("result has to be a Credentials instance")
@ -412,7 +438,7 @@ def _create_auth_code_store():
return code return code
@callback @callback
def retrieve_result(client_id, code): def retrieve_result(client_id: str, code: str) -> Credentials | None:
"""Retrieve flow result.""" """Retrieve flow result."""
key = (client_id, code) key = (client_id, code)
@ -437,8 +463,8 @@ def _create_auth_code_store():
@websocket_api.ws_require_user() @websocket_api.ws_require_user()
@websocket_api.async_response @websocket_api.async_response
async def websocket_current_user( async def websocket_current_user(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
): ) -> None:
"""Return the current user.""" """Return the current user."""
user = connection.user user = connection.user
enabled_modules = await hass.auth.async_get_enabled_mfa(user) enabled_modules = await hass.auth.async_get_enabled_mfa(user)
@ -482,8 +508,8 @@ async def websocket_current_user(
@websocket_api.ws_require_user() @websocket_api.ws_require_user()
@websocket_api.async_response @websocket_api.async_response
async def websocket_create_long_lived_access_token( async def websocket_create_long_lived_access_token(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
): ) -> None:
"""Create or a long-lived access token.""" """Create or a long-lived access token."""
refresh_token = await hass.auth.async_create_refresh_token( refresh_token = await hass.auth.async_create_refresh_token(
connection.user, connection.user,
@ -506,12 +532,12 @@ async def websocket_create_long_lived_access_token(
@websocket_api.ws_require_user() @websocket_api.ws_require_user()
@callback @callback
def websocket_refresh_tokens( def websocket_refresh_tokens(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
): ) -> None:
"""Return metadata of users refresh tokens.""" """Return metadata of users refresh tokens."""
current_id = connection.refresh_token_id current_id = connection.refresh_token_id
tokens = [] tokens: list[dict[str, Any]] = []
for refresh in connection.user.refresh_tokens.values(): for refresh in connection.user.refresh_tokens.values():
if refresh.credential: if refresh.credential:
auth_provider_type = refresh.credential.auth_provider_type auth_provider_type = refresh.credential.auth_provider_type
@ -545,8 +571,8 @@ def websocket_refresh_tokens(
@websocket_api.ws_require_user() @websocket_api.ws_require_user()
@websocket_api.async_response @websocket_api.async_response
async def websocket_delete_refresh_token( async def websocket_delete_refresh_token(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
): ) -> None:
"""Handle a delete refresh token request.""" """Handle a delete refresh token request."""
refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"]) refresh_token = connection.user.refresh_tokens.get(msg["refresh_token_id"])
@ -569,8 +595,8 @@ async def websocket_delete_refresh_token(
@websocket_api.ws_require_user() @websocket_api.ws_require_user()
@callback @callback
def websocket_sign_path( def websocket_sign_path(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
): ) -> None:
"""Handle a sign path request.""" """Handle a sign path request."""
connection.send_message( connection.send_message(
websocket_api.result_message( websocket_api.result_message(

View File

@ -1,18 +1,24 @@
"""Helpers to resolve client ID/secret.""" """Helpers to resolve client ID/secret."""
from __future__ import annotations
import asyncio import asyncio
from html.parser import HTMLParser from html.parser import HTMLParser
from ipaddress import ip_address from ipaddress import ip_address
import logging import logging
from urllib.parse import urljoin, urlparse from urllib.parse import ParseResult, urljoin, urlparse
import aiohttp import aiohttp
import aiohttp.client_exceptions
from homeassistant.core import HomeAssistant
from homeassistant.util.network import is_local from homeassistant.util.network import is_local
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def verify_redirect_uri(hass, client_id, redirect_uri): async def verify_redirect_uri(
hass: HomeAssistant, client_id: str, redirect_uri: str
) -> bool:
"""Verify that the client and redirect uri match.""" """Verify that the client and redirect uri match."""
try: try:
client_id_parts = _parse_client_id(client_id) client_id_parts = _parse_client_id(client_id)
@ -47,24 +53,24 @@ async def verify_redirect_uri(hass, client_id, redirect_uri):
class LinkTagParser(HTMLParser): class LinkTagParser(HTMLParser):
"""Parser to find link tags.""" """Parser to find link tags."""
def __init__(self, rel): def __init__(self, rel: str) -> None:
"""Initialize a link tag parser.""" """Initialize a link tag parser."""
super().__init__() super().__init__()
self.rel = rel self.rel = rel
self.found = [] self.found: list[str | None] = []
def handle_starttag(self, tag, attrs): def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
"""Handle finding a start tag.""" """Handle finding a start tag."""
if tag != "link": if tag != "link":
return return
attrs = dict(attrs) attributes: dict[str, str | None] = dict(attrs)
if attrs.get("rel") == self.rel: if attributes.get("rel") == self.rel:
self.found.append(attrs.get("href")) self.found.append(attributes.get("href"))
async def fetch_redirect_uris(hass, url): async def fetch_redirect_uris(hass: HomeAssistant, url: str) -> list[str]:
"""Find link tag with redirect_uri values. """Find link tag with redirect_uri values.
IndieAuth 4.2.2 IndieAuth 4.2.2
@ -108,7 +114,7 @@ async def fetch_redirect_uris(hass, url):
return [urljoin(url, found) for found in parser.found] return [urljoin(url, found) for found in parser.found]
def verify_client_id(client_id): def verify_client_id(client_id: str) -> bool:
"""Verify that the client id is valid.""" """Verify that the client id is valid."""
try: try:
_parse_client_id(client_id) _parse_client_id(client_id)
@ -117,7 +123,7 @@ def verify_client_id(client_id):
return False return False
def _parse_url(url): def _parse_url(url: str) -> ParseResult:
"""Parse a url in parts and canonicalize according to IndieAuth.""" """Parse a url in parts and canonicalize according to IndieAuth."""
parts = urlparse(url) parts = urlparse(url)
@ -134,7 +140,7 @@ def _parse_url(url):
return parts return parts
def _parse_client_id(client_id): def _parse_client_id(client_id: str) -> ParseResult:
"""Test if client id is a valid URL according to IndieAuth section 3.2. """Test if client id is a valid URL according to IndieAuth section 3.2.
https://indieauth.spec.indieweb.org/#client-identifier https://indieauth.spec.indieweb.org/#client-identifier

View File

@ -66,14 +66,19 @@ associate with an credential if "type" set to "link_user" in
"version": 1 "version": 1
} }
""" """
from __future__ import annotations
from collections.abc import Callable
from http import HTTPStatus from http import HTTPStatus
from ipaddress import ip_address from ipaddress import ip_address
from typing import TYPE_CHECKING, Any
from aiohttp import web from aiohttp import web
import voluptuous as vol import voluptuous as vol
import voluptuous_serialize import voluptuous_serialize
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.auth import AuthManagerFlowManager
from homeassistant.auth.models import Credentials from homeassistant.auth.models import Credentials
from homeassistant.components import onboarding from homeassistant.components import onboarding
from homeassistant.components.http.auth import async_user_not_allowed_do_auth from homeassistant.components.http.auth import async_user_not_allowed_do_auth
@ -88,8 +93,13 @@ from homeassistant.core import HomeAssistant
from . import indieauth from . import indieauth
if TYPE_CHECKING:
from . import StoreResultType
async def async_setup(hass, store_result):
async def async_setup(
hass: HomeAssistant, store_result: Callable[[str, Credentials], str]
) -> None:
"""Component to allow users to login.""" """Component to allow users to login."""
hass.http.register_view(AuthProvidersView) hass.http.register_view(AuthProvidersView)
hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow, store_result)) hass.http.register_view(LoginFlowIndexView(hass.auth.login_flow, store_result))
@ -103,9 +113,9 @@ class AuthProvidersView(HomeAssistantView):
name = "api:auth:providers" name = "api:auth:providers"
requires_auth = False requires_auth = False
async def get(self, request): async def get(self, request: web.Request) -> web.Response:
"""Get available auth providers.""" """Get available auth providers."""
hass = request.app["hass"] hass: HomeAssistant = request.app["hass"]
if not onboarding.async_is_user_onboarded(hass): if not onboarding.async_is_user_onboarded(hass):
return self.json_message( return self.json_message(
message="Onboarding not finished", message="Onboarding not finished",
@ -121,7 +131,9 @@ class AuthProvidersView(HomeAssistantView):
) )
def _prepare_result_json(result): def _prepare_result_json(
result: data_entry_flow.FlowResult,
) -> data_entry_flow.FlowResult:
"""Convert result to JSON.""" """Convert result to JSON."""
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY: if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
data = result.copy() data = result.copy()
@ -147,12 +159,21 @@ class LoginFlowBaseView(HomeAssistantView):
requires_auth = False requires_auth = False
def __init__(self, flow_mgr, store_result): def __init__(
self,
flow_mgr: AuthManagerFlowManager,
store_result: StoreResultType,
) -> None:
"""Initialize the flow manager index view.""" """Initialize the flow manager index view."""
self._flow_mgr = flow_mgr self._flow_mgr = flow_mgr
self._store_result = store_result self._store_result = store_result
async def _async_flow_result_to_response(self, request, client_id, result): async def _async_flow_result_to_response(
self,
request: web.Request,
client_id: str,
result: data_entry_flow.FlowResult,
) -> web.Response:
"""Convert the flow result to a response.""" """Convert the flow result to a response."""
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY: if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
# @log_invalid_auth does not work here since it returns HTTP 200. # @log_invalid_auth does not work here since it returns HTTP 200.
@ -196,7 +217,7 @@ class LoginFlowIndexView(LoginFlowBaseView):
url = "/auth/login_flow" url = "/auth/login_flow"
name = "api:auth:login_flow" name = "api:auth:login_flow"
async def get(self, request): async def get(self, request: web.Request) -> web.Response:
"""Do not allow index of flows in progress.""" """Do not allow index of flows in progress."""
return web.Response(status=HTTPStatus.METHOD_NOT_ALLOWED) return web.Response(status=HTTPStatus.METHOD_NOT_ALLOWED)
@ -211,15 +232,18 @@ class LoginFlowIndexView(LoginFlowBaseView):
) )
) )
@log_invalid_auth @log_invalid_auth
async def post(self, request, data): async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Create a new login flow.""" """Create a new login flow."""
if not await indieauth.verify_redirect_uri( hass: HomeAssistant = request.app["hass"]
request.app["hass"], data["client_id"], data["redirect_uri"] client_id: str = data["client_id"]
): redirect_uri: str = data["redirect_uri"]
if not await indieauth.verify_redirect_uri(hass, client_id, redirect_uri):
return self.json_message( return self.json_message(
"invalid client id or redirect uri", HTTPStatus.BAD_REQUEST "invalid client id or redirect uri", HTTPStatus.BAD_REQUEST
) )
handler: tuple[str, ...] | str
if isinstance(data["handler"], list): if isinstance(data["handler"], list):
handler = tuple(data["handler"]) handler = tuple(data["handler"])
else: else:
@ -227,9 +251,9 @@ class LoginFlowIndexView(LoginFlowBaseView):
try: try:
result = await self._flow_mgr.async_init( result = await self._flow_mgr.async_init(
handler, handler, # type: ignore[arg-type]
context={ context={
"ip_address": ip_address(request.remote), "ip_address": ip_address(request.remote), # type: ignore[arg-type]
"credential_only": data.get("type") == "link_user", "credential_only": data.get("type") == "link_user",
}, },
) )
@ -240,9 +264,7 @@ class LoginFlowIndexView(LoginFlowBaseView):
"Handler does not support init", HTTPStatus.BAD_REQUEST "Handler does not support init", HTTPStatus.BAD_REQUEST
) )
return await self._async_flow_result_to_response( return await self._async_flow_result_to_response(request, client_id, result)
request, data["client_id"], result
)
class LoginFlowResourceView(LoginFlowBaseView): class LoginFlowResourceView(LoginFlowBaseView):
@ -251,13 +273,15 @@ class LoginFlowResourceView(LoginFlowBaseView):
url = "/auth/login_flow/{flow_id}" url = "/auth/login_flow/{flow_id}"
name = "api:auth:login_flow:resource" name = "api:auth:login_flow:resource"
async def get(self, request): async def get(self, request: web.Request) -> web.Response:
"""Do not allow getting status of a flow in progress.""" """Do not allow getting status of a flow in progress."""
return self.json_message("Invalid flow specified", HTTPStatus.NOT_FOUND) return self.json_message("Invalid flow specified", HTTPStatus.NOT_FOUND)
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA)) @RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
@log_invalid_auth @log_invalid_auth
async def post(self, request, data, flow_id): async def post(
self, request: web.Request, data: dict[str, Any], flow_id: str
) -> web.Response:
"""Handle progressing a login flow request.""" """Handle progressing a login flow request."""
client_id = data.pop("client_id") client_id = data.pop("client_id")
@ -267,7 +291,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
try: try:
# do not allow change ip during login flow # do not allow change ip during login flow
flow = self._flow_mgr.async_get(flow_id) flow = self._flow_mgr.async_get(flow_id)
if flow["context"]["ip_address"] != ip_address(request.remote): if flow["context"]["ip_address"] != ip_address(request.remote): # type: ignore[arg-type]
return self.json_message("IP address changed", HTTPStatus.BAD_REQUEST) return self.json_message("IP address changed", HTTPStatus.BAD_REQUEST)
result = await self._flow_mgr.async_configure(flow_id, data) result = await self._flow_mgr.async_configure(flow_id, data)
except data_entry_flow.UnknownFlow: except data_entry_flow.UnknownFlow:
@ -277,7 +301,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
return await self._async_flow_result_to_response(request, client_id, result) return await self._async_flow_result_to_response(request, client_id, result)
async def delete(self, request, flow_id): async def delete(self, request: web.Request, flow_id: str) -> web.Response:
"""Cancel a flow in progress.""" """Cancel a flow in progress."""
try: try:
self._flow_mgr.async_abort(flow_id) self._flow_mgr.async_abort(flow_id)

View File

@ -1,5 +1,8 @@
"""Helpers to setup multi-factor auth module.""" """Helpers to setup multi-factor auth module."""
from __future__ import annotations
import logging import logging
from typing import Any
import voluptuous as vol import voluptuous as vol
import voluptuous_serialize import voluptuous_serialize
@ -7,15 +10,19 @@ import voluptuous_serialize
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv
WS_TYPE_SETUP_MFA = "auth/setup_mfa" WS_TYPE_SETUP_MFA = "auth/setup_mfa"
SCHEMA_WS_SETUP_MFA = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( SCHEMA_WS_SETUP_MFA = vol.All(
{ websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend(
vol.Required("type"): WS_TYPE_SETUP_MFA, {
vol.Exclusive("mfa_module_id", "module_or_flow_id"): str, vol.Required("type"): WS_TYPE_SETUP_MFA,
vol.Exclusive("flow_id", "module_or_flow_id"): str, vol.Exclusive("mfa_module_id", "module_or_flow_id"): str,
vol.Optional("user_input"): object, vol.Exclusive("flow_id", "module_or_flow_id"): str,
} vol.Optional("user_input"): object,
}
),
cv.has_at_least_one_key("mfa_module_id", "flow_id"),
) )
WS_TYPE_DEPOSE_MFA = "auth/depose_mfa" WS_TYPE_DEPOSE_MFA = "auth/depose_mfa"
@ -31,7 +38,13 @@ _LOGGER = logging.getLogger(__name__)
class MfaFlowManager(data_entry_flow.FlowManager): class MfaFlowManager(data_entry_flow.FlowManager):
"""Manage multi factor authentication flows.""" """Manage multi factor authentication flows."""
async def async_create_flow(self, handler_key, *, context, data): async def async_create_flow( # type: ignore[override]
self,
handler_key: Any,
*,
context: dict[str, Any],
data: dict[str, Any],
) -> data_entry_flow.FlowHandler:
"""Create a setup flow. handler is a mfa module.""" """Create a setup flow. handler is a mfa module."""
mfa_module = self.hass.auth.get_auth_mfa_module(handler_key) mfa_module = self.hass.auth.get_auth_mfa_module(handler_key)
if mfa_module is None: if mfa_module is None:
@ -40,13 +53,15 @@ class MfaFlowManager(data_entry_flow.FlowManager):
user_id = data.pop("user_id") user_id = data.pop("user_id")
return await mfa_module.async_setup_flow(user_id) return await mfa_module.async_setup_flow(user_id)
async def async_finish_flow(self, flow, result): async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
"""Complete an mfs setup flow.""" """Complete an mfs setup flow."""
_LOGGER.debug("flow_result: %s", result) _LOGGER.debug("flow_result: %s", result)
return result return result
async def async_setup(hass): async def async_setup(hass: HomeAssistant) -> None:
"""Init mfa setup flow manager.""" """Init mfa setup flow manager."""
hass.data[DATA_SETUP_FLOW_MGR] = MfaFlowManager(hass) hass.data[DATA_SETUP_FLOW_MGR] = MfaFlowManager(hass)
@ -62,13 +77,13 @@ async def async_setup(hass):
@callback @callback
@websocket_api.ws_require_user(allow_system_user=False) @websocket_api.ws_require_user(allow_system_user=False)
def websocket_setup_mfa( def websocket_setup_mfa(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
): ) -> None:
"""Return a setup flow for mfa auth module.""" """Return a setup flow for mfa auth module."""
async def async_setup_flow(msg): async def async_setup_flow(msg: dict[str, Any]) -> None:
"""Return a setup flow for mfa auth module.""" """Return a setup flow for mfa auth module."""
flow_manager = hass.data[DATA_SETUP_FLOW_MGR] flow_manager: MfaFlowManager = hass.data[DATA_SETUP_FLOW_MGR]
if (flow_id := msg.get("flow_id")) is not None: if (flow_id := msg.get("flow_id")) is not None:
result = await flow_manager.async_configure(flow_id, msg.get("user_input")) result = await flow_manager.async_configure(flow_id, msg.get("user_input"))
@ -77,9 +92,8 @@ def websocket_setup_mfa(
) )
return return
mfa_module_id = msg.get("mfa_module_id") mfa_module_id = msg["mfa_module_id"]
mfa_module = hass.auth.get_auth_mfa_module(mfa_module_id) if hass.auth.get_auth_mfa_module(mfa_module_id) is None:
if mfa_module is None:
connection.send_message( connection.send_message(
websocket_api.error_message( websocket_api.error_message(
msg["id"], "no_module", f"MFA module {mfa_module_id} is not found" msg["id"], "no_module", f"MFA module {mfa_module_id} is not found"
@ -101,11 +115,11 @@ def websocket_setup_mfa(
@callback @callback
@websocket_api.ws_require_user(allow_system_user=False) @websocket_api.ws_require_user(allow_system_user=False)
def websocket_depose_mfa( def websocket_depose_mfa(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
): ) -> None:
"""Remove user from mfa module.""" """Remove user from mfa module."""
async def async_depose(msg): async def async_depose(msg: dict[str, Any]) -> None:
"""Remove user from mfa auth module.""" """Remove user from mfa auth module."""
mfa_module_id = msg["mfa_module_id"] mfa_module_id = msg["mfa_module_id"]
try: try:
@ -127,7 +141,9 @@ def websocket_depose_mfa(
hass.async_create_task(async_depose(msg)) hass.async_create_task(async_depose(msg))
def _prepare_result_json(result): def _prepare_result_json(
result: data_entry_flow.FlowResult,
) -> data_entry_flow.FlowResult:
"""Convert result to JSON.""" """Convert result to JSON."""
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY: if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
data = result.copy() data = result.copy()

View File

@ -175,7 +175,7 @@ class FlowManager(abc.ABC):
) )
@callback @callback
def async_get(self, flow_id: str) -> FlowResult | None: def async_get(self, flow_id: str) -> FlowResult:
"""Return a flow in progress as a partial FlowResult.""" """Return a flow in progress as a partial FlowResult."""
if (flow := self._progress.get(flow_id)) is None: if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow raise UnknownFlow

View File

@ -349,6 +349,16 @@ disallow_untyped_defs = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.auth.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.automation.*] [mypy-homeassistant.components.automation.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true

View File

@ -44,7 +44,13 @@ async def test_ws_setup_depose_mfa(hass, hass_ws_client):
client = await hass_ws_client(hass, access_token) client = await hass_ws_client(hass, access_token)
await client.send_json({"id": 10, "type": mfa_setup_flow.WS_TYPE_SETUP_MFA}) await client.send_json(
{
"id": 10,
"type": mfa_setup_flow.WS_TYPE_SETUP_MFA,
"mfa_module_id": "invalid_module",
}
)
result = await client.receive_json() result = await client.receive_json()
assert result["id"] == 10 assert result["id"] == 10