mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 13:47:35 +00:00
Improve auth generic typing (#133061)
This commit is contained in:
parent
ce70cb9e33
commit
32c1b519ad
@ -115,7 +115,7 @@ class AuthManagerFlowManager(
|
|||||||
*,
|
*,
|
||||||
context: AuthFlowContext | None = None,
|
context: AuthFlowContext | None = None,
|
||||||
data: dict[str, Any] | None = None,
|
data: dict[str, Any] | None = None,
|
||||||
) -> LoginFlow:
|
) -> LoginFlow[Any]:
|
||||||
"""Create a login flow."""
|
"""Create a login flow."""
|
||||||
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
|
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
|
||||||
if not auth_provider:
|
if not auth_provider:
|
||||||
|
@ -4,8 +4,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import types
|
import types
|
||||||
from typing import Any
|
from typing import Any, Generic
|
||||||
|
|
||||||
|
from typing_extensions import TypeVar
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
@ -34,6 +35,12 @@ DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed")
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MultiFactorAuthModuleT = TypeVar(
|
||||||
|
"_MultiFactorAuthModuleT",
|
||||||
|
bound="MultiFactorAuthModule",
|
||||||
|
default="MultiFactorAuthModule",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiFactorAuthModule:
|
class MultiFactorAuthModule:
|
||||||
"""Multi-factor Auth Module of validation function."""
|
"""Multi-factor Auth Module of validation function."""
|
||||||
@ -71,7 +78,7 @@ class MultiFactorAuthModule:
|
|||||||
"""Return a voluptuous schema to define mfa auth module's input."""
|
"""Return a voluptuous schema to define mfa auth module's input."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def async_setup_flow(self, user_id: str) -> SetupFlow:
|
async def async_setup_flow(self, user_id: str) -> SetupFlow[Any]:
|
||||||
"""Return a data entry flow handler for setup module.
|
"""Return a data entry flow handler for setup module.
|
||||||
|
|
||||||
Mfa module should extend SetupFlow
|
Mfa module should extend SetupFlow
|
||||||
@ -95,11 +102,14 @@ class MultiFactorAuthModule:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class SetupFlow(data_entry_flow.FlowHandler):
|
class SetupFlow(data_entry_flow.FlowHandler, Generic[_MultiFactorAuthModuleT]):
|
||||||
"""Handler for the setup flow."""
|
"""Handler for the setup flow."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
|
self,
|
||||||
|
auth_module: _MultiFactorAuthModuleT,
|
||||||
|
setup_schema: vol.Schema,
|
||||||
|
user_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the setup flow."""
|
"""Initialize the setup flow."""
|
||||||
self._auth_module = auth_module
|
self._auth_module = auth_module
|
||||||
|
@ -162,7 +162,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
|||||||
|
|
||||||
return sorted(unordered_services)
|
return sorted(unordered_services)
|
||||||
|
|
||||||
async def async_setup_flow(self, user_id: str) -> SetupFlow:
|
async def async_setup_flow(self, user_id: str) -> NotifySetupFlow:
|
||||||
"""Return a data entry flow handler for setup module.
|
"""Return a data entry flow handler for setup module.
|
||||||
|
|
||||||
Mfa module should extend SetupFlow
|
Mfa module should extend SetupFlow
|
||||||
@ -268,7 +268,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
|||||||
await self.hass.services.async_call("notify", notify_service, data)
|
await self.hass.services.async_call("notify", notify_service, data)
|
||||||
|
|
||||||
|
|
||||||
class NotifySetupFlow(SetupFlow):
|
class NotifySetupFlow(SetupFlow[NotifyAuthModule]):
|
||||||
"""Handler for the setup flow."""
|
"""Handler for the setup flow."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -280,8 +280,6 @@ class NotifySetupFlow(SetupFlow):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the setup flow."""
|
"""Initialize the setup flow."""
|
||||||
super().__init__(auth_module, setup_schema, user_id)
|
super().__init__(auth_module, setup_schema, user_id)
|
||||||
# to fix typing complaint
|
|
||||||
self._auth_module: NotifyAuthModule = auth_module
|
|
||||||
self._available_notify_services = available_notify_services
|
self._available_notify_services = available_notify_services
|
||||||
self._secret: str | None = None
|
self._secret: str | None = None
|
||||||
self._count: int | None = None
|
self._count: int | None = None
|
||||||
|
@ -114,7 +114,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
|||||||
self._users[user_id] = ota_secret # type: ignore[index]
|
self._users[user_id] = ota_secret # type: ignore[index]
|
||||||
return ota_secret
|
return ota_secret
|
||||||
|
|
||||||
async def async_setup_flow(self, user_id: str) -> SetupFlow:
|
async def async_setup_flow(self, user_id: str) -> TotpSetupFlow:
|
||||||
"""Return a data entry flow handler for setup module.
|
"""Return a data entry flow handler for setup module.
|
||||||
|
|
||||||
Mfa module should extend SetupFlow
|
Mfa module should extend SetupFlow
|
||||||
@ -174,10 +174,9 @@ class TotpAuthModule(MultiFactorAuthModule):
|
|||||||
return bool(pyotp.TOTP(ota_secret).verify(code, valid_window=1))
|
return bool(pyotp.TOTP(ota_secret).verify(code, valid_window=1))
|
||||||
|
|
||||||
|
|
||||||
class TotpSetupFlow(SetupFlow):
|
class TotpSetupFlow(SetupFlow[TotpAuthModule]):
|
||||||
"""Handler for the setup flow."""
|
"""Handler for the setup flow."""
|
||||||
|
|
||||||
_auth_module: TotpAuthModule
|
|
||||||
_ota_secret: str
|
_ota_secret: str
|
||||||
_url: str
|
_url: str
|
||||||
_image: str
|
_image: str
|
||||||
|
@ -5,8 +5,9 @@ from __future__ import annotations
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
import logging
|
import logging
|
||||||
import types
|
import types
|
||||||
from typing import Any
|
from typing import Any, Generic
|
||||||
|
|
||||||
|
from typing_extensions import TypeVar
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
from voluptuous.humanize import humanize_error
|
from voluptuous.humanize import humanize_error
|
||||||
|
|
||||||
@ -46,6 +47,8 @@ AUTH_PROVIDER_SCHEMA = vol.Schema(
|
|||||||
extra=vol.ALLOW_EXTRA,
|
extra=vol.ALLOW_EXTRA,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_AuthProviderT = TypeVar("_AuthProviderT", bound="AuthProvider", default="AuthProvider")
|
||||||
|
|
||||||
|
|
||||||
class AuthProvider:
|
class AuthProvider:
|
||||||
"""Provider of user authentication."""
|
"""Provider of user authentication."""
|
||||||
@ -105,7 +108,7 @@ class AuthProvider:
|
|||||||
|
|
||||||
# Implement by extending class
|
# Implement by extending class
|
||||||
|
|
||||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow[Any]:
|
||||||
"""Return the data flow for logging in with auth provider.
|
"""Return the data flow for logging in with auth provider.
|
||||||
|
|
||||||
Auth provider should extend LoginFlow and return an instance.
|
Auth provider should extend LoginFlow and return an instance.
|
||||||
@ -192,12 +195,15 @@ async def load_auth_provider_module(
|
|||||||
return module
|
return module
|
||||||
|
|
||||||
|
|
||||||
class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]):
|
class LoginFlow(
|
||||||
|
FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
|
||||||
|
Generic[_AuthProviderT],
|
||||||
|
):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
_flow_result = AuthFlowResult
|
_flow_result = AuthFlowResult
|
||||||
|
|
||||||
def __init__(self, auth_provider: AuthProvider) -> None:
|
def __init__(self, auth_provider: _AuthProviderT) -> None:
|
||||||
"""Initialize the login flow."""
|
"""Initialize the login flow."""
|
||||||
self._auth_provider = auth_provider
|
self._auth_provider = auth_provider
|
||||||
self._auth_module_id: str | None = None
|
self._auth_module_id: str | None = None
|
||||||
|
@ -6,7 +6,7 @@ import asyncio
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, cast
|
from typing import Any
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -59,7 +59,9 @@ class CommandLineAuthProvider(AuthProvider):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._user_meta: dict[str, dict[str, Any]] = {}
|
self._user_meta: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
async def async_login_flow(
|
||||||
|
self, context: AuthFlowContext | None
|
||||||
|
) -> CommandLineLoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return CommandLineLoginFlow(self)
|
return CommandLineLoginFlow(self)
|
||||||
|
|
||||||
@ -133,7 +135,7 @@ class CommandLineAuthProvider(AuthProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CommandLineLoginFlow(LoginFlow):
|
class CommandLineLoginFlow(LoginFlow[CommandLineAuthProvider]):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
@ -145,9 +147,9 @@ class CommandLineLoginFlow(LoginFlow):
|
|||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
user_input["username"] = user_input["username"].strip()
|
user_input["username"] = user_input["username"].strip()
|
||||||
try:
|
try:
|
||||||
await cast(
|
await self._auth_provider.async_validate_login(
|
||||||
CommandLineAuthProvider, self._auth_provider
|
user_input["username"], user_input["password"]
|
||||||
).async_validate_login(user_input["username"], user_input["password"])
|
)
|
||||||
except InvalidAuthError:
|
except InvalidAuthError:
|
||||||
errors["base"] = "invalid_auth"
|
errors["base"] = "invalid_auth"
|
||||||
|
|
||||||
|
@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider):
|
|||||||
await data.async_load()
|
await data.async_load()
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
async def async_login_flow(self, context: AuthFlowContext | None) -> HassLoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return HassLoginFlow(self)
|
return HassLoginFlow(self)
|
||||||
|
|
||||||
@ -400,7 +400,7 @@ class HassAuthProvider(AuthProvider):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HassLoginFlow(LoginFlow):
|
class HassLoginFlow(LoginFlow[HassAuthProvider]):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
@ -411,7 +411,7 @@ class HassLoginFlow(LoginFlow):
|
|||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
try:
|
try:
|
||||||
await cast(HassAuthProvider, self._auth_provider).async_validate_login(
|
await self._auth_provider.async_validate_login(
|
||||||
user_input["username"], user_input["password"]
|
user_input["username"], user_input["password"]
|
||||||
)
|
)
|
||||||
except InvalidAuth:
|
except InvalidAuth:
|
||||||
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
import hmac
|
import hmac
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -36,7 +35,9 @@ class InvalidAuthError(HomeAssistantError):
|
|||||||
class ExampleAuthProvider(AuthProvider):
|
class ExampleAuthProvider(AuthProvider):
|
||||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||||
|
|
||||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
async def async_login_flow(
|
||||||
|
self, context: AuthFlowContext | None
|
||||||
|
) -> ExampleLoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
return ExampleLoginFlow(self)
|
return ExampleLoginFlow(self)
|
||||||
|
|
||||||
@ -93,7 +94,7 @@ class ExampleAuthProvider(AuthProvider):
|
|||||||
return UserMeta(name=name, is_active=True)
|
return UserMeta(name=name, is_active=True)
|
||||||
|
|
||||||
|
|
||||||
class ExampleLoginFlow(LoginFlow):
|
class ExampleLoginFlow(LoginFlow[ExampleAuthProvider]):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
@ -104,7 +105,7 @@ class ExampleLoginFlow(LoginFlow):
|
|||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
try:
|
try:
|
||||||
cast(ExampleAuthProvider, self._auth_provider).async_validate_login(
|
self._auth_provider.async_validate_login(
|
||||||
user_input["username"], user_input["password"]
|
user_input["username"], user_input["password"]
|
||||||
)
|
)
|
||||||
except InvalidAuthError:
|
except InvalidAuthError:
|
||||||
|
@ -104,7 +104,9 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||||||
"""Trusted Networks auth provider does not support MFA."""
|
"""Trusted Networks auth provider does not support MFA."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
async def async_login_flow(
|
||||||
|
self, context: AuthFlowContext | None
|
||||||
|
) -> TrustedNetworksLoginFlow:
|
||||||
"""Return a flow to login."""
|
"""Return a flow to login."""
|
||||||
assert context is not None
|
assert context is not None
|
||||||
ip_addr = cast(IPAddress, context.get("ip_address"))
|
ip_addr = cast(IPAddress, context.get("ip_address"))
|
||||||
@ -214,7 +216,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||||||
self.async_validate_access(ip_address(remote_ip))
|
self.async_validate_access(ip_address(remote_ip))
|
||||||
|
|
||||||
|
|
||||||
class TrustedNetworksLoginFlow(LoginFlow):
|
class TrustedNetworksLoginFlow(LoginFlow[TrustedNetworksAuthProvider]):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -235,9 +237,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
|||||||
) -> AuthFlowResult:
|
) -> AuthFlowResult:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
try:
|
try:
|
||||||
cast(
|
self._auth_provider.async_validate_access(self._ip_address)
|
||||||
TrustedNetworksAuthProvider, self._auth_provider
|
|
||||||
).async_validate_access(self._ip_address)
|
|
||||||
|
|
||||||
except InvalidAuthError:
|
except InvalidAuthError:
|
||||||
return self.async_abort(reason="not_allowed")
|
return self.async_abort(reason="not_allowed")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user