Improve auth generic typing (#133061)

This commit is contained in:
Marc Mueller 2024-12-12 20:14:56 +01:00 committed by GitHub
parent ce70cb9e33
commit 32c1b519ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 50 additions and 34 deletions

View File

@ -115,7 +115,7 @@ class AuthManagerFlowManager(
*,
context: AuthFlowContext | None = None,
data: dict[str, Any] | None = None,
) -> LoginFlow:
) -> LoginFlow[Any]:
"""Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
if not auth_provider:

View File

@ -4,8 +4,9 @@ from __future__ import annotations
import logging
import types
from typing import Any
from typing import Any, Generic
from typing_extensions import TypeVar
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -34,6 +35,12 @@ DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed")
_LOGGER = logging.getLogger(__name__)
_MultiFactorAuthModuleT = TypeVar(
"_MultiFactorAuthModuleT",
bound="MultiFactorAuthModule",
default="MultiFactorAuthModule",
)
class MultiFactorAuthModule:
"""Multi-factor Auth Module of validation function."""
@ -71,7 +78,7 @@ class MultiFactorAuthModule:
"""Return a voluptuous schema to define mfa auth module's input."""
raise NotImplementedError
async def async_setup_flow(self, user_id: str) -> SetupFlow:
async def async_setup_flow(self, user_id: str) -> SetupFlow[Any]:
"""Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow
@ -95,11 +102,14 @@ class MultiFactorAuthModule:
raise NotImplementedError
class SetupFlow(data_entry_flow.FlowHandler):
class SetupFlow(data_entry_flow.FlowHandler, Generic[_MultiFactorAuthModuleT]):
"""Handler for the setup flow."""
def __init__(
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
self,
auth_module: _MultiFactorAuthModuleT,
setup_schema: vol.Schema,
user_id: str,
) -> None:
"""Initialize the setup flow."""
self._auth_module = auth_module

View File

@ -162,7 +162,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
return sorted(unordered_services)
async def async_setup_flow(self, user_id: str) -> SetupFlow:
async def async_setup_flow(self, user_id: str) -> NotifySetupFlow:
"""Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow
@ -268,7 +268,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
await self.hass.services.async_call("notify", notify_service, data)
class NotifySetupFlow(SetupFlow):
class NotifySetupFlow(SetupFlow[NotifyAuthModule]):
"""Handler for the setup flow."""
def __init__(
@ -280,8 +280,6 @@ class NotifySetupFlow(SetupFlow):
) -> None:
"""Initialize the setup flow."""
super().__init__(auth_module, setup_schema, user_id)
# to fix typing complaint
self._auth_module: NotifyAuthModule = auth_module
self._available_notify_services = available_notify_services
self._secret: str | None = None
self._count: int | None = None

View File

@ -114,7 +114,7 @@ class TotpAuthModule(MultiFactorAuthModule):
self._users[user_id] = ota_secret # type: ignore[index]
return ota_secret
async def async_setup_flow(self, user_id: str) -> SetupFlow:
async def async_setup_flow(self, user_id: str) -> TotpSetupFlow:
"""Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow
@ -174,10 +174,9 @@ class TotpAuthModule(MultiFactorAuthModule):
return bool(pyotp.TOTP(ota_secret).verify(code, valid_window=1))
class TotpSetupFlow(SetupFlow):
class TotpSetupFlow(SetupFlow[TotpAuthModule]):
"""Handler for the setup flow."""
_auth_module: TotpAuthModule
_ota_secret: str
_url: str
_image: str

View File

@ -5,8 +5,9 @@ from __future__ import annotations
from collections.abc import Mapping
import logging
import types
from typing import Any
from typing import Any, Generic
from typing_extensions import TypeVar
import voluptuous as vol
from voluptuous.humanize import humanize_error
@ -46,6 +47,8 @@ AUTH_PROVIDER_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA,
)
_AuthProviderT = TypeVar("_AuthProviderT", bound="AuthProvider", default="AuthProvider")
class AuthProvider:
"""Provider of user authentication."""
@ -105,7 +108,7 @@ class AuthProvider:
# Implement by extending class
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow[Any]:
"""Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance.
@ -192,12 +195,15 @@ async def load_auth_provider_module(
return module
class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]):
class LoginFlow(
FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
Generic[_AuthProviderT],
):
"""Handler for the login flow."""
_flow_result = AuthFlowResult
def __init__(self, auth_provider: AuthProvider) -> None:
def __init__(self, auth_provider: _AuthProviderT) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider
self._auth_module_id: str | None = None

View File

@ -6,7 +6,7 @@ import asyncio
from collections.abc import Mapping
import logging
import os
from typing import Any, cast
from typing import Any
import voluptuous as vol
@ -59,7 +59,9 @@ class CommandLineAuthProvider(AuthProvider):
super().__init__(*args, **kwargs)
self._user_meta: dict[str, dict[str, Any]] = {}
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(
self, context: AuthFlowContext | None
) -> CommandLineLoginFlow:
"""Return a flow to login."""
return CommandLineLoginFlow(self)
@ -133,7 +135,7 @@ class CommandLineAuthProvider(AuthProvider):
)
class CommandLineLoginFlow(LoginFlow):
class CommandLineLoginFlow(LoginFlow[CommandLineAuthProvider]):
"""Handler for the login flow."""
async def async_step_init(
@ -145,9 +147,9 @@ class CommandLineLoginFlow(LoginFlow):
if user_input is not None:
user_input["username"] = user_input["username"].strip()
try:
await cast(
CommandLineAuthProvider, self._auth_provider
).async_validate_login(user_input["username"], user_input["password"])
await self._auth_provider.async_validate_login(
user_input["username"], user_input["password"]
)
except InvalidAuthError:
errors["base"] = "invalid_auth"

View File

@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider):
await data.async_load()
self.data = data
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(self, context: AuthFlowContext | None) -> HassLoginFlow:
"""Return a flow to login."""
return HassLoginFlow(self)
@ -400,7 +400,7 @@ class HassAuthProvider(AuthProvider):
pass
class HassLoginFlow(LoginFlow):
class HassLoginFlow(LoginFlow[HassAuthProvider]):
"""Handler for the login flow."""
async def async_step_init(
@ -411,7 +411,7 @@ class HassLoginFlow(LoginFlow):
if user_input is not None:
try:
await cast(HassAuthProvider, self._auth_provider).async_validate_login(
await self._auth_provider.async_validate_login(
user_input["username"], user_input["password"]
)
except InvalidAuth:

View File

@ -4,7 +4,6 @@ from __future__ import annotations
from collections.abc import Mapping
import hmac
from typing import cast
import voluptuous as vol
@ -36,7 +35,9 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords."""
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(
self, context: AuthFlowContext | None
) -> ExampleLoginFlow:
"""Return a flow to login."""
return ExampleLoginFlow(self)
@ -93,7 +94,7 @@ class ExampleAuthProvider(AuthProvider):
return UserMeta(name=name, is_active=True)
class ExampleLoginFlow(LoginFlow):
class ExampleLoginFlow(LoginFlow[ExampleAuthProvider]):
"""Handler for the login flow."""
async def async_step_init(
@ -104,7 +105,7 @@ class ExampleLoginFlow(LoginFlow):
if user_input is not None:
try:
cast(ExampleAuthProvider, self._auth_provider).async_validate_login(
self._auth_provider.async_validate_login(
user_input["username"], user_input["password"]
)
except InvalidAuthError:

View File

@ -104,7 +104,9 @@ class TrustedNetworksAuthProvider(AuthProvider):
"""Trusted Networks auth provider does not support MFA."""
return False
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
async def async_login_flow(
self, context: AuthFlowContext | None
) -> TrustedNetworksLoginFlow:
"""Return a flow to login."""
assert context is not None
ip_addr = cast(IPAddress, context.get("ip_address"))
@ -214,7 +216,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
self.async_validate_access(ip_address(remote_ip))
class TrustedNetworksLoginFlow(LoginFlow):
class TrustedNetworksLoginFlow(LoginFlow[TrustedNetworksAuthProvider]):
"""Handler for the login flow."""
def __init__(
@ -235,9 +237,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
) -> AuthFlowResult:
"""Handle the step of the form."""
try:
cast(
TrustedNetworksAuthProvider, self._auth_provider
).async_validate_access(self._ip_address)
self._auth_provider.async_validate_access(self._ip_address)
except InvalidAuthError:
return self.async_abort(reason="not_allowed")