Improve auth generic typing (#133061)

This commit is contained in:
Marc Mueller 2024-12-12 20:14:56 +01:00 committed by GitHub
parent ce70cb9e33
commit 32c1b519ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 50 additions and 34 deletions

View File

@ -115,7 +115,7 @@ class AuthManagerFlowManager(
*, *,
context: AuthFlowContext | None = None, context: AuthFlowContext | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> LoginFlow: ) -> LoginFlow[Any]:
"""Create a login flow.""" """Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key) auth_provider = self.auth_manager.get_auth_provider(*handler_key)
if not auth_provider: if not auth_provider:

View File

@ -4,8 +4,9 @@ from __future__ import annotations
import logging import logging
import types import types
from typing import Any from typing import Any, Generic
from typing_extensions import TypeVar
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -34,6 +35,12 @@ DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_MultiFactorAuthModuleT = TypeVar(
"_MultiFactorAuthModuleT",
bound="MultiFactorAuthModule",
default="MultiFactorAuthModule",
)
class MultiFactorAuthModule: class MultiFactorAuthModule:
"""Multi-factor Auth Module of validation function.""" """Multi-factor Auth Module of validation function."""
@ -71,7 +78,7 @@ class MultiFactorAuthModule:
"""Return a voluptuous schema to define mfa auth module's input.""" """Return a voluptuous schema to define mfa auth module's input."""
raise NotImplementedError raise NotImplementedError
async def async_setup_flow(self, user_id: str) -> SetupFlow: async def async_setup_flow(self, user_id: str) -> SetupFlow[Any]:
"""Return a data entry flow handler for setup module. """Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow Mfa module should extend SetupFlow
@ -95,11 +102,14 @@ class MultiFactorAuthModule:
raise NotImplementedError raise NotImplementedError
class SetupFlow(data_entry_flow.FlowHandler): class SetupFlow(data_entry_flow.FlowHandler, Generic[_MultiFactorAuthModuleT]):
"""Handler for the setup flow.""" """Handler for the setup flow."""
def __init__( def __init__(
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str self,
auth_module: _MultiFactorAuthModuleT,
setup_schema: vol.Schema,
user_id: str,
) -> None: ) -> None:
"""Initialize the setup flow.""" """Initialize the setup flow."""
self._auth_module = auth_module self._auth_module = auth_module

View File

@ -162,7 +162,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
return sorted(unordered_services) return sorted(unordered_services)
async def async_setup_flow(self, user_id: str) -> SetupFlow: async def async_setup_flow(self, user_id: str) -> NotifySetupFlow:
"""Return a data entry flow handler for setup module. """Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow Mfa module should extend SetupFlow
@ -268,7 +268,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
await self.hass.services.async_call("notify", notify_service, data) await self.hass.services.async_call("notify", notify_service, data)
class NotifySetupFlow(SetupFlow): class NotifySetupFlow(SetupFlow[NotifyAuthModule]):
"""Handler for the setup flow.""" """Handler for the setup flow."""
def __init__( def __init__(
@ -280,8 +280,6 @@ class NotifySetupFlow(SetupFlow):
) -> None: ) -> None:
"""Initialize the setup flow.""" """Initialize the setup flow."""
super().__init__(auth_module, setup_schema, user_id) super().__init__(auth_module, setup_schema, user_id)
# to fix typing complaint
self._auth_module: NotifyAuthModule = auth_module
self._available_notify_services = available_notify_services self._available_notify_services = available_notify_services
self._secret: str | None = None self._secret: str | None = None
self._count: int | None = None self._count: int | None = None

View File

@ -114,7 +114,7 @@ class TotpAuthModule(MultiFactorAuthModule):
self._users[user_id] = ota_secret # type: ignore[index] self._users[user_id] = ota_secret # type: ignore[index]
return ota_secret return ota_secret
async def async_setup_flow(self, user_id: str) -> SetupFlow: async def async_setup_flow(self, user_id: str) -> TotpSetupFlow:
"""Return a data entry flow handler for setup module. """Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow Mfa module should extend SetupFlow
@ -174,10 +174,9 @@ class TotpAuthModule(MultiFactorAuthModule):
return bool(pyotp.TOTP(ota_secret).verify(code, valid_window=1)) return bool(pyotp.TOTP(ota_secret).verify(code, valid_window=1))
class TotpSetupFlow(SetupFlow): class TotpSetupFlow(SetupFlow[TotpAuthModule]):
"""Handler for the setup flow.""" """Handler for the setup flow."""
_auth_module: TotpAuthModule
_ota_secret: str _ota_secret: str
_url: str _url: str
_image: str _image: str

View File

@ -5,8 +5,9 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
import types import types
from typing import Any from typing import Any, Generic
from typing_extensions import TypeVar
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
@ -46,6 +47,8 @@ AUTH_PROVIDER_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
_AuthProviderT = TypeVar("_AuthProviderT", bound="AuthProvider", default="AuthProvider")
class AuthProvider: class AuthProvider:
"""Provider of user authentication.""" """Provider of user authentication."""
@ -105,7 +108,7 @@ class AuthProvider:
# Implement by extending class # Implement by extending class
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow[Any]:
"""Return the data flow for logging in with auth provider. """Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance. Auth provider should extend LoginFlow and return an instance.
@ -192,12 +195,15 @@ async def load_auth_provider_module(
return module return module
class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]): class LoginFlow(
FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
Generic[_AuthProviderT],
):
"""Handler for the login flow.""" """Handler for the login flow."""
_flow_result = AuthFlowResult _flow_result = AuthFlowResult
def __init__(self, auth_provider: AuthProvider) -> None: def __init__(self, auth_provider: _AuthProviderT) -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
self._auth_provider = auth_provider self._auth_provider = auth_provider
self._auth_module_id: str | None = None self._auth_module_id: str | None = None

View File

@ -6,7 +6,7 @@ import asyncio
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
import os import os
from typing import Any, cast from typing import Any
import voluptuous as vol import voluptuous as vol
@ -59,7 +59,9 @@ class CommandLineAuthProvider(AuthProvider):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._user_meta: dict[str, dict[str, Any]] = {} self._user_meta: dict[str, dict[str, Any]] = {}
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: async def async_login_flow(
self, context: AuthFlowContext | None
) -> CommandLineLoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return CommandLineLoginFlow(self) return CommandLineLoginFlow(self)
@ -133,7 +135,7 @@ class CommandLineAuthProvider(AuthProvider):
) )
class CommandLineLoginFlow(LoginFlow): class CommandLineLoginFlow(LoginFlow[CommandLineAuthProvider]):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
@ -145,9 +147,9 @@ class CommandLineLoginFlow(LoginFlow):
if user_input is not None: if user_input is not None:
user_input["username"] = user_input["username"].strip() user_input["username"] = user_input["username"].strip()
try: try:
await cast( await self._auth_provider.async_validate_login(
CommandLineAuthProvider, self._auth_provider user_input["username"], user_input["password"]
).async_validate_login(user_input["username"], user_input["password"]) )
except InvalidAuthError: except InvalidAuthError:
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"

View File

@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider):
await data.async_load() await data.async_load()
self.data = data self.data = data
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: async def async_login_flow(self, context: AuthFlowContext | None) -> HassLoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return HassLoginFlow(self) return HassLoginFlow(self)
@ -400,7 +400,7 @@ class HassAuthProvider(AuthProvider):
pass pass
class HassLoginFlow(LoginFlow): class HassLoginFlow(LoginFlow[HassAuthProvider]):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
@ -411,7 +411,7 @@ class HassLoginFlow(LoginFlow):
if user_input is not None: if user_input is not None:
try: try:
await cast(HassAuthProvider, self._auth_provider).async_validate_login( await self._auth_provider.async_validate_login(
user_input["username"], user_input["password"] user_input["username"], user_input["password"]
) )
except InvalidAuth: except InvalidAuth:

View File

@ -4,7 +4,6 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import hmac import hmac
from typing import cast
import voluptuous as vol import voluptuous as vol
@ -36,7 +35,9 @@ class InvalidAuthError(HomeAssistantError):
class ExampleAuthProvider(AuthProvider): class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords.""" """Example auth provider based on hardcoded usernames and passwords."""
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: async def async_login_flow(
self, context: AuthFlowContext | None
) -> ExampleLoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return ExampleLoginFlow(self) return ExampleLoginFlow(self)
@ -93,7 +94,7 @@ class ExampleAuthProvider(AuthProvider):
return UserMeta(name=name, is_active=True) return UserMeta(name=name, is_active=True)
class ExampleLoginFlow(LoginFlow): class ExampleLoginFlow(LoginFlow[ExampleAuthProvider]):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
@ -104,7 +105,7 @@ class ExampleLoginFlow(LoginFlow):
if user_input is not None: if user_input is not None:
try: try:
cast(ExampleAuthProvider, self._auth_provider).async_validate_login( self._auth_provider.async_validate_login(
user_input["username"], user_input["password"] user_input["username"], user_input["password"]
) )
except InvalidAuthError: except InvalidAuthError:

View File

@ -104,7 +104,9 @@ class TrustedNetworksAuthProvider(AuthProvider):
"""Trusted Networks auth provider does not support MFA.""" """Trusted Networks auth provider does not support MFA."""
return False return False
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow: async def async_login_flow(
self, context: AuthFlowContext | None
) -> TrustedNetworksLoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
assert context is not None assert context is not None
ip_addr = cast(IPAddress, context.get("ip_address")) ip_addr = cast(IPAddress, context.get("ip_address"))
@ -214,7 +216,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
self.async_validate_access(ip_address(remote_ip)) self.async_validate_access(ip_address(remote_ip))
class TrustedNetworksLoginFlow(LoginFlow): class TrustedNetworksLoginFlow(LoginFlow[TrustedNetworksAuthProvider]):
"""Handler for the login flow.""" """Handler for the login flow."""
def __init__( def __init__(
@ -235,9 +237,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
) -> AuthFlowResult: ) -> AuthFlowResult:
"""Handle the step of the form.""" """Handle the step of the form."""
try: try:
cast( self._auth_provider.async_validate_access(self._ip_address)
TrustedNetworksAuthProvider, self._auth_provider
).async_validate_access(self._ip_address)
except InvalidAuthError: except InvalidAuthError:
return self.async_abort(reason="not_allowed") return self.async_abort(reason="not_allowed")