mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Improve auth generic typing (#133061)
This commit is contained in:
parent
ce70cb9e33
commit
32c1b519ad
@ -115,7 +115,7 @@ class AuthManagerFlowManager(
|
||||
*,
|
||||
context: AuthFlowContext | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> LoginFlow:
|
||||
) -> LoginFlow[Any]:
|
||||
"""Create a login flow."""
|
||||
auth_provider = self.auth_manager.get_auth_provider(*handler_key)
|
||||
if not auth_provider:
|
||||
|
@ -4,8 +4,9 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import types
|
||||
from typing import Any
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
@ -34,6 +35,12 @@ DATA_REQS: HassKey[set[str]] = HassKey("mfa_auth_module_reqs_processed")
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_MultiFactorAuthModuleT = TypeVar(
|
||||
"_MultiFactorAuthModuleT",
|
||||
bound="MultiFactorAuthModule",
|
||||
default="MultiFactorAuthModule",
|
||||
)
|
||||
|
||||
|
||||
class MultiFactorAuthModule:
|
||||
"""Multi-factor Auth Module of validation function."""
|
||||
@ -71,7 +78,7 @@ class MultiFactorAuthModule:
|
||||
"""Return a voluptuous schema to define mfa auth module's input."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_setup_flow(self, user_id: str) -> SetupFlow:
|
||||
async def async_setup_flow(self, user_id: str) -> SetupFlow[Any]:
|
||||
"""Return a data entry flow handler for setup module.
|
||||
|
||||
Mfa module should extend SetupFlow
|
||||
@ -95,11 +102,14 @@ class MultiFactorAuthModule:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SetupFlow(data_entry_flow.FlowHandler):
|
||||
class SetupFlow(data_entry_flow.FlowHandler, Generic[_MultiFactorAuthModuleT]):
|
||||
"""Handler for the setup flow."""
|
||||
|
||||
def __init__(
|
||||
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
|
||||
self,
|
||||
auth_module: _MultiFactorAuthModuleT,
|
||||
setup_schema: vol.Schema,
|
||||
user_id: str,
|
||||
) -> None:
|
||||
"""Initialize the setup flow."""
|
||||
self._auth_module = auth_module
|
||||
|
@ -162,7 +162,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||
|
||||
return sorted(unordered_services)
|
||||
|
||||
async def async_setup_flow(self, user_id: str) -> SetupFlow:
|
||||
async def async_setup_flow(self, user_id: str) -> NotifySetupFlow:
|
||||
"""Return a data entry flow handler for setup module.
|
||||
|
||||
Mfa module should extend SetupFlow
|
||||
@ -268,7 +268,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
|
||||
await self.hass.services.async_call("notify", notify_service, data)
|
||||
|
||||
|
||||
class NotifySetupFlow(SetupFlow):
|
||||
class NotifySetupFlow(SetupFlow[NotifyAuthModule]):
|
||||
"""Handler for the setup flow."""
|
||||
|
||||
def __init__(
|
||||
@ -280,8 +280,6 @@ class NotifySetupFlow(SetupFlow):
|
||||
) -> None:
|
||||
"""Initialize the setup flow."""
|
||||
super().__init__(auth_module, setup_schema, user_id)
|
||||
# to fix typing complaint
|
||||
self._auth_module: NotifyAuthModule = auth_module
|
||||
self._available_notify_services = available_notify_services
|
||||
self._secret: str | None = None
|
||||
self._count: int | None = None
|
||||
|
@ -114,7 +114,7 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||
self._users[user_id] = ota_secret # type: ignore[index]
|
||||
return ota_secret
|
||||
|
||||
async def async_setup_flow(self, user_id: str) -> SetupFlow:
|
||||
async def async_setup_flow(self, user_id: str) -> TotpSetupFlow:
|
||||
"""Return a data entry flow handler for setup module.
|
||||
|
||||
Mfa module should extend SetupFlow
|
||||
@ -174,10 +174,9 @@ class TotpAuthModule(MultiFactorAuthModule):
|
||||
return bool(pyotp.TOTP(ota_secret).verify(code, valid_window=1))
|
||||
|
||||
|
||||
class TotpSetupFlow(SetupFlow):
|
||||
class TotpSetupFlow(SetupFlow[TotpAuthModule]):
|
||||
"""Handler for the setup flow."""
|
||||
|
||||
_auth_module: TotpAuthModule
|
||||
_ota_secret: str
|
||||
_url: str
|
||||
_image: str
|
||||
|
@ -5,8 +5,9 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
import types
|
||||
from typing import Any
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
@ -46,6 +47,8 @@ AUTH_PROVIDER_SCHEMA = vol.Schema(
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
_AuthProviderT = TypeVar("_AuthProviderT", bound="AuthProvider", default="AuthProvider")
|
||||
|
||||
|
||||
class AuthProvider:
|
||||
"""Provider of user authentication."""
|
||||
@ -105,7 +108,7 @@ class AuthProvider:
|
||||
|
||||
# Implement by extending class
|
||||
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow[Any]:
|
||||
"""Return the data flow for logging in with auth provider.
|
||||
|
||||
Auth provider should extend LoginFlow and return an instance.
|
||||
@ -192,12 +195,15 @@ async def load_auth_provider_module(
|
||||
return module
|
||||
|
||||
|
||||
class LoginFlow(FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]]):
|
||||
class LoginFlow(
|
||||
FlowHandler[AuthFlowContext, AuthFlowResult, tuple[str, str]],
|
||||
Generic[_AuthProviderT],
|
||||
):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
_flow_result = AuthFlowResult
|
||||
|
||||
def __init__(self, auth_provider: AuthProvider) -> None:
|
||||
def __init__(self, auth_provider: _AuthProviderT) -> None:
|
||||
"""Initialize the login flow."""
|
||||
self._auth_provider = auth_provider
|
||||
self._auth_module_id: str | None = None
|
||||
|
@ -6,7 +6,7 @@ import asyncio
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -59,7 +59,9 @@ class CommandLineAuthProvider(AuthProvider):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
async def async_login_flow(
|
||||
self, context: AuthFlowContext | None
|
||||
) -> CommandLineLoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return CommandLineLoginFlow(self)
|
||||
|
||||
@ -133,7 +135,7 @@ class CommandLineAuthProvider(AuthProvider):
|
||||
)
|
||||
|
||||
|
||||
class CommandLineLoginFlow(LoginFlow):
|
||||
class CommandLineLoginFlow(LoginFlow[CommandLineAuthProvider]):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
async def async_step_init(
|
||||
@ -145,9 +147,9 @@ class CommandLineLoginFlow(LoginFlow):
|
||||
if user_input is not None:
|
||||
user_input["username"] = user_input["username"].strip()
|
||||
try:
|
||||
await cast(
|
||||
CommandLineAuthProvider, self._auth_provider
|
||||
).async_validate_login(user_input["username"], user_input["password"])
|
||||
await self._auth_provider.async_validate_login(
|
||||
user_input["username"], user_input["password"]
|
||||
)
|
||||
except InvalidAuthError:
|
||||
errors["base"] = "invalid_auth"
|
||||
|
||||
|
@ -305,7 +305,7 @@ class HassAuthProvider(AuthProvider):
|
||||
await data.async_load()
|
||||
self.data = data
|
||||
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> HassLoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return HassLoginFlow(self)
|
||||
|
||||
@ -400,7 +400,7 @@ class HassAuthProvider(AuthProvider):
|
||||
pass
|
||||
|
||||
|
||||
class HassLoginFlow(LoginFlow):
|
||||
class HassLoginFlow(LoginFlow[HassAuthProvider]):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
async def async_step_init(
|
||||
@ -411,7 +411,7 @@ class HassLoginFlow(LoginFlow):
|
||||
|
||||
if user_input is not None:
|
||||
try:
|
||||
await cast(HassAuthProvider, self._auth_provider).async_validate_login(
|
||||
await self._auth_provider.async_validate_login(
|
||||
user_input["username"], user_input["password"]
|
||||
)
|
||||
except InvalidAuth:
|
||||
|
@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
import hmac
|
||||
from typing import cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -36,7 +35,9 @@ class InvalidAuthError(HomeAssistantError):
|
||||
class ExampleAuthProvider(AuthProvider):
|
||||
"""Example auth provider based on hardcoded usernames and passwords."""
|
||||
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
async def async_login_flow(
|
||||
self, context: AuthFlowContext | None
|
||||
) -> ExampleLoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return ExampleLoginFlow(self)
|
||||
|
||||
@ -93,7 +94,7 @@ class ExampleAuthProvider(AuthProvider):
|
||||
return UserMeta(name=name, is_active=True)
|
||||
|
||||
|
||||
class ExampleLoginFlow(LoginFlow):
|
||||
class ExampleLoginFlow(LoginFlow[ExampleAuthProvider]):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
async def async_step_init(
|
||||
@ -104,7 +105,7 @@ class ExampleLoginFlow(LoginFlow):
|
||||
|
||||
if user_input is not None:
|
||||
try:
|
||||
cast(ExampleAuthProvider, self._auth_provider).async_validate_login(
|
||||
self._auth_provider.async_validate_login(
|
||||
user_input["username"], user_input["password"]
|
||||
)
|
||||
except InvalidAuthError:
|
||||
|
@ -104,7 +104,9 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||
"""Trusted Networks auth provider does not support MFA."""
|
||||
return False
|
||||
|
||||
async def async_login_flow(self, context: AuthFlowContext | None) -> LoginFlow:
|
||||
async def async_login_flow(
|
||||
self, context: AuthFlowContext | None
|
||||
) -> TrustedNetworksLoginFlow:
|
||||
"""Return a flow to login."""
|
||||
assert context is not None
|
||||
ip_addr = cast(IPAddress, context.get("ip_address"))
|
||||
@ -214,7 +216,7 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
||||
self.async_validate_access(ip_address(remote_ip))
|
||||
|
||||
|
||||
class TrustedNetworksLoginFlow(LoginFlow):
|
||||
class TrustedNetworksLoginFlow(LoginFlow[TrustedNetworksAuthProvider]):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
def __init__(
|
||||
@ -235,9 +237,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
try:
|
||||
cast(
|
||||
TrustedNetworksAuthProvider, self._auth_provider
|
||||
).async_validate_access(self._ip_address)
|
||||
self._auth_provider.async_validate_access(self._ip_address)
|
||||
|
||||
except InvalidAuthError:
|
||||
return self.async_abort(reason="not_allowed")
|
||||
|
Loading…
x
Reference in New Issue
Block a user