mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 09:47:13 +00:00
Significantly improve Tesla Fleet config flow (#146794)
* Improved config flow * Tests * Improvements * Dashboard url & tests * Apply suggestions from code review Co-authored-by: Norbert Rittel <norbert@rittel.de> * revert oauth change * fully restore oauth file * remove CONF_DOMAIN * Add pick_implementation back in * Use try else * Improve translation * use CONF_DOMAIN --------- Co-authored-by: Norbert Rittel <norbert@rittel.de>
This commit is contained in:
parent
e8667dfbe0
commit
b563f9078a
@ -4,14 +4,30 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import Any
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
import jwt
|
||||
from tesla_fleet_api import TeslaFleetApi
|
||||
from tesla_fleet_api.const import SERVERS
|
||||
from tesla_fleet_api.exceptions import (
|
||||
InvalidResponse,
|
||||
PreconditionFailed,
|
||||
TeslaFleetError,
|
||||
)
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.selector import (
|
||||
QrCodeSelector,
|
||||
QrCodeSelectorConfig,
|
||||
QrErrorCorrectionLevel,
|
||||
)
|
||||
|
||||
from .const import DOMAIN, LOGGER
|
||||
from .const import CONF_DOMAIN, DOMAIN, LOGGER
|
||||
from .oauth import TeslaUserImplementation
|
||||
|
||||
|
||||
class OAuth2FlowHandler(
|
||||
@ -21,36 +37,173 @@ class OAuth2FlowHandler(
|
||||
|
||||
DOMAIN = DOMAIN
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize config flow."""
|
||||
super().__init__()
|
||||
self.domain: str | None = None
|
||||
self.registration_status: dict[str, bool] = {}
|
||||
self.tesla_apis: dict[str, TeslaFleetApi] = {}
|
||||
self.failed_regions: list[str] = []
|
||||
self.data: dict[str, Any] = {}
|
||||
self.uid: str | None = None
|
||||
self.api: TeslaFleetApi | None = None
|
||||
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Return logger."""
|
||||
return LOGGER
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle a flow start."""
|
||||
return await super().async_step_user()
|
||||
|
||||
async def async_oauth_create_entry(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the initial step."""
|
||||
|
||||
"""Handle OAuth completion and proceed to domain registration."""
|
||||
token = jwt.decode(
|
||||
data["token"]["access_token"], options={"verify_signature": False}
|
||||
)
|
||||
uid = token["sub"]
|
||||
|
||||
await self.async_set_unique_id(uid)
|
||||
self.data = data
|
||||
self.uid = token["sub"]
|
||||
server = SERVERS[token["ou_code"].lower()]
|
||||
|
||||
await self.async_set_unique_id(self.uid)
|
||||
if self.source == SOURCE_REAUTH:
|
||||
self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch")
|
||||
return self.async_update_reload_and_abort(
|
||||
self._get_reauth_entry(), data=data
|
||||
)
|
||||
self._abort_if_unique_id_configured()
|
||||
return self.async_create_entry(title=uid, data=data)
|
||||
|
||||
# OAuth done, setup a Partner API connection
|
||||
implementation = cast(TeslaUserImplementation, self.flow_impl)
|
||||
|
||||
session = async_get_clientsession(self.hass)
|
||||
self.api = TeslaFleetApi(
|
||||
session=session,
|
||||
server=server,
|
||||
partner_scope=True,
|
||||
charging_scope=False,
|
||||
energy_scope=False,
|
||||
user_scope=False,
|
||||
vehicle_scope=False,
|
||||
)
|
||||
await self.api.get_private_key(self.hass.config.path("tesla_fleet.key"))
|
||||
await self.api.partner_login(
|
||||
implementation.client_id, implementation.client_secret
|
||||
)
|
||||
|
||||
return await self.async_step_domain_input()
|
||||
|
||||
async def async_step_domain_input(
|
||||
self,
|
||||
user_input: dict[str, Any] | None = None,
|
||||
errors: dict[str, str] | None = None,
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle domain input step."""
|
||||
|
||||
errors = errors or {}
|
||||
|
||||
if user_input is not None:
|
||||
domain = user_input[CONF_DOMAIN].strip().lower()
|
||||
|
||||
# Validate domain format
|
||||
if not self._is_valid_domain(domain):
|
||||
errors[CONF_DOMAIN] = "invalid_domain"
|
||||
else:
|
||||
self.domain = domain
|
||||
return await self.async_step_domain_registration()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="domain_input",
|
||||
description_placeholders={
|
||||
"dashboard": "https://developer.tesla.com/en_AU/dashboard/"
|
||||
},
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_DOMAIN): str,
|
||||
}
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_domain_registration(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle domain registration for both regions."""
|
||||
|
||||
assert self.api
|
||||
assert self.api.private_key
|
||||
assert self.domain
|
||||
|
||||
errors = {}
|
||||
description_placeholders = {
|
||||
"public_key_url": f"https://{self.domain}/.well-known/appspecific/com.tesla.3p.public-key.pem",
|
||||
"pem": self.api.public_pem,
|
||||
}
|
||||
|
||||
try:
|
||||
register_response = await self.api.partner.register(self.domain)
|
||||
except PreconditionFailed:
|
||||
return await self.async_step_domain_input(
|
||||
errors={CONF_DOMAIN: "precondition_failed"}
|
||||
)
|
||||
except InvalidResponse:
|
||||
errors["base"] = "invalid_response"
|
||||
except TeslaFleetError as e:
|
||||
errors["base"] = "unknown_error"
|
||||
description_placeholders["error"] = e.message
|
||||
else:
|
||||
# Get public key from response
|
||||
registered_public_key = register_response.get("response", {}).get(
|
||||
"public_key"
|
||||
)
|
||||
|
||||
if not registered_public_key:
|
||||
errors["base"] = "public_key_not_found"
|
||||
elif (
|
||||
registered_public_key.lower()
|
||||
!= self.api.public_uncompressed_point.lower()
|
||||
):
|
||||
errors["base"] = "public_key_mismatch"
|
||||
else:
|
||||
return await self.async_step_registration_complete()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="domain_registration",
|
||||
description_placeholders=description_placeholders,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_registration_complete(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Show completion and virtual key installation."""
|
||||
if user_input is not None and self.uid and self.data:
|
||||
return self.async_create_entry(title=self.uid, data=self.data)
|
||||
|
||||
if not self.domain:
|
||||
return await self.async_step_domain_input()
|
||||
|
||||
virtual_key_url = f"https://www.tesla.com/_ak/{self.domain}"
|
||||
data_schema = vol.Schema({}).extend(
|
||||
{
|
||||
vol.Optional("qr_code"): QrCodeSelector(
|
||||
config=QrCodeSelectorConfig(
|
||||
data=virtual_key_url,
|
||||
scale=6,
|
||||
error_correction_level=QrErrorCorrectionLevel.QUARTILE,
|
||||
)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="registration_complete",
|
||||
data_schema=data_schema,
|
||||
description_placeholders={
|
||||
"virtual_key_url": virtual_key_url,
|
||||
},
|
||||
)
|
||||
|
||||
async def async_step_reauth(
|
||||
self, entry_data: Mapping[str, Any]
|
||||
@ -67,4 +220,11 @@ class OAuth2FlowHandler(
|
||||
step_id="reauth_confirm",
|
||||
description_placeholders={"name": "Tesla Fleet"},
|
||||
)
|
||||
return await self.async_step_user()
|
||||
# For reauth, skip domain registration and go straight to OAuth
|
||||
return await super().async_step_user()
|
||||
|
||||
def _is_valid_domain(self, domain: str) -> bool:
|
||||
"""Validate domain format."""
|
||||
# Basic domain validation regex
|
||||
domain_pattern = re.compile(r"^(?:[a-zA-Z0-9]+\.)+[a-zA-Z0-9-]+$")
|
||||
return bool(domain_pattern.match(domain))
|
||||
|
@ -9,6 +9,7 @@ from tesla_fleet_api.const import Scope
|
||||
|
||||
DOMAIN = "tesla_fleet"
|
||||
|
||||
CONF_DOMAIN = "domain"
|
||||
CONF_REFRESH_TOKEN = "refresh_token"
|
||||
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
|
@ -4,6 +4,7 @@
|
||||
"authorize_url_timeout": "[%key:common::config_flow::abort::oauth2_authorize_url_timeout%]",
|
||||
"missing_configuration": "[%key:common::config_flow::abort::oauth2_missing_configuration%]",
|
||||
"already_configured": "Configuration updated for profile.",
|
||||
"already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]",
|
||||
"no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]",
|
||||
"oauth_error": "[%key:common::config_flow::abort::oauth2_error%]",
|
||||
"oauth_timeout": "[%key:common::config_flow::abort::oauth2_timeout%]",
|
||||
@ -13,7 +14,12 @@
|
||||
"reauth_account_mismatch": "The reauthentication account does not match the original account"
|
||||
},
|
||||
"error": {
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_account%]"
|
||||
"invalid_domain": "Invalid domain format. Please enter a valid domain name.",
|
||||
"public_key_not_found": "Public key not found.",
|
||||
"public_key_mismatch": "The public key hosted at your domain does not match the expected key. Please ensure the correct public key is hosted at the specified location.",
|
||||
"precondition_failed": "The domain does not match the application's allowed origins.",
|
||||
"invalid_response": "The registration was rejected by Tesla",
|
||||
"unknown_error": "An unknown error occurred: {error}"
|
||||
},
|
||||
"step": {
|
||||
"pick_implementation": {
|
||||
@ -25,6 +31,21 @@
|
||||
"implementation": "[%key:common::config_flow::description::implementation%]"
|
||||
}
|
||||
},
|
||||
"domain_input": {
|
||||
"title": "Tesla Fleet domain registration",
|
||||
"description": "Enter the domain that will host your public key. This is typically the domain of the origin you specified during registration at {dashboard}.",
|
||||
"data": {
|
||||
"domain": "Domain"
|
||||
}
|
||||
},
|
||||
"domain_registration": {
|
||||
"title": "Registering public key",
|
||||
"description": "You must host the public key at:\n\n{public_key_url}\n\n```\n{pem}\n```"
|
||||
},
|
||||
"registration_complete": {
|
||||
"title": "Command signing",
|
||||
"description": "To enable command signing, you must open the Tesla app, select your vehicle, and then visit the following URL to set up a virtual key. You must repeat this process for each vehicle.\n\n{virtual_key_url}"
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"title": "[%key:common::config_flow::title::reauth%]",
|
||||
"description": "The {name} integration needs to re-authenticate your account"
|
||||
|
@ -1,16 +1,23 @@
|
||||
"""Test the Tesla Fleet config flow."""
|
||||
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import pytest
|
||||
from tesla_fleet_api.exceptions import (
|
||||
InvalidResponse,
|
||||
PreconditionFailed,
|
||||
TeslaFleetError,
|
||||
)
|
||||
|
||||
from homeassistant.components.application_credentials import (
|
||||
ClientCredential,
|
||||
async_import_client_credential,
|
||||
)
|
||||
from homeassistant.components.tesla_fleet.config_flow import OAuth2FlowHandler
|
||||
from homeassistant.components.tesla_fleet.const import (
|
||||
AUTHORIZE_URL,
|
||||
CONF_DOMAIN,
|
||||
DOMAIN,
|
||||
SCOPES,
|
||||
TOKEN_URL,
|
||||
@ -64,15 +71,30 @@ async def create_credential(hass: HomeAssistant) -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_private_key():
|
||||
"""Mock private key for testing."""
|
||||
private_key = Mock()
|
||||
public_key = Mock()
|
||||
private_key.public_key.return_value = public_key
|
||||
public_key.public_bytes.side_effect = [
|
||||
b"-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA\n-----END PUBLIC KEY-----",
|
||||
bytes.fromhex(
|
||||
"0404112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff1122"
|
||||
),
|
||||
]
|
||||
return private_key
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_full_flow_user_cred(
|
||||
async def test_full_flow_with_domain_registration(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
access_token: str,
|
||||
mock_private_key,
|
||||
) -> None:
|
||||
"""Check full flow."""
|
||||
|
||||
"""Test full flow with domain registration."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
@ -95,7 +117,7 @@ async def test_full_flow_user_cred(
|
||||
assert parsed_query["redirect_uri"][0] == REDIRECT
|
||||
assert parsed_query["state"][0] == state
|
||||
assert parsed_query["scope"][0] == " ".join(SCOPES)
|
||||
assert "code_challenge" not in parsed_query # Ensure not a PKCE flow
|
||||
assert "code_challenge" not in parsed_query
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
@ -112,21 +134,416 @@ async def test_full_flow_user_cred(
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
with patch(
|
||||
"homeassistant.components.tesla_fleet.async_setup_entry", return_value=True
|
||||
) as mock_setup:
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.tesla_fleet.config_flow.TeslaFleetApi"
|
||||
) as mock_api_class,
|
||||
patch(
|
||||
"homeassistant.components.tesla_fleet.async_setup_entry", return_value=True
|
||||
),
|
||||
):
|
||||
mock_api = AsyncMock()
|
||||
mock_api.private_key = mock_private_key
|
||||
mock_api.get_private_key = AsyncMock()
|
||||
mock_api.partner_login = AsyncMock()
|
||||
mock_api.public_uncompressed_point = "0404112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff1122"
|
||||
mock_api.partner.register.return_value = {
|
||||
"response": {
|
||||
"public_key": "0404112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff1122"
|
||||
}
|
||||
}
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Complete OAuth
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "domain_input"
|
||||
|
||||
# Enter domain - this should automatically register and go to registration_complete
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_DOMAIN: "example.com"}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "registration_complete"
|
||||
|
||||
# Complete flow - provide user input to complete registration
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == UNIQUE_ID
|
||||
assert "result" in result
|
||||
assert result["result"].unique_id == UNIQUE_ID
|
||||
assert "token" in result["result"].data
|
||||
assert result["result"].data["token"]["access_token"] == access_token
|
||||
assert result["result"].data["token"]["refresh_token"] == "mock-refresh-token"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_domain_input_invalid_domain(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
access_token: str,
|
||||
mock_private_key,
|
||||
) -> None:
|
||||
"""Test domain input with invalid domain."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": REDIRECT,
|
||||
},
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": access_token,
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.tesla_fleet.config_flow.TeslaFleetApi"
|
||||
) as mock_api_class,
|
||||
):
|
||||
mock_api = AsyncMock()
|
||||
mock_api.private_key = mock_private_key
|
||||
mock_api.get_private_key = AsyncMock()
|
||||
mock_api.partner_login = AsyncMock()
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Complete OAuth
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "domain_input"
|
||||
|
||||
# Enter invalid domain
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_DOMAIN: "invalid-domain"}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "domain_input"
|
||||
assert result["errors"] == {CONF_DOMAIN: "invalid_domain"}
|
||||
|
||||
# Enter valid domain - this should automatically register and go to registration_complete
|
||||
mock_api.public_uncompressed_point = "0404112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff1122"
|
||||
mock_api.partner.register.return_value = {
|
||||
"response": {
|
||||
"public_key": "0404112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff112233445566778899aabbccddeeff1122"
|
||||
}
|
||||
}
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_DOMAIN: "example.com"}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "registration_complete"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "expected_error"),
|
||||
[
|
||||
(InvalidResponse, "invalid_response"),
|
||||
(TeslaFleetError("Custom error"), "unknown_error"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_domain_registration_errors(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
access_token: str,
|
||||
mock_private_key,
|
||||
side_effect,
|
||||
expected_error,
|
||||
) -> None:
|
||||
"""Test domain registration with errors that stay on domain_registration step."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": REDIRECT,
|
||||
},
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": access_token,
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.tesla_fleet.config_flow.TeslaFleetApi"
|
||||
) as mock_api_class,
|
||||
):
|
||||
mock_api = AsyncMock()
|
||||
mock_api.private_key = mock_private_key
|
||||
mock_api.get_private_key = AsyncMock()
|
||||
mock_api.partner_login = AsyncMock()
|
||||
mock_api.public_uncompressed_point = "test_point"
|
||||
mock_api.partner.register.side_effect = side_effect
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Complete OAuth
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
# Enter domain - this should fail and stay on domain_registration
|
||||
with patch(
|
||||
"homeassistant.helpers.translation.async_get_translations", return_value={}
|
||||
):
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_DOMAIN: "example.com"}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "domain_registration"
|
||||
assert result["errors"] == {"base": expected_error}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_domain_registration_precondition_failed(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
access_token: str,
|
||||
mock_private_key,
|
||||
) -> None:
|
||||
"""Test domain registration with PreconditionFailed redirects to domain_input."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": REDIRECT,
|
||||
},
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": access_token,
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.tesla_fleet.config_flow.TeslaFleetApi"
|
||||
) as mock_api_class,
|
||||
):
|
||||
mock_api = AsyncMock()
|
||||
mock_api.private_key = mock_private_key
|
||||
mock_api.get_private_key = AsyncMock()
|
||||
mock_api.partner_login = AsyncMock()
|
||||
mock_api.public_uncompressed_point = "test_point"
|
||||
mock_api.partner.register.side_effect = PreconditionFailed
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Complete OAuth
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
# Enter domain - this should go to domain_registration and then fail back to domain_input
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_DOMAIN: "example.com"}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "domain_input"
|
||||
assert result["errors"] == {CONF_DOMAIN: "precondition_failed"}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_domain_registration_public_key_not_found(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
access_token: str,
|
||||
mock_private_key,
|
||||
) -> None:
|
||||
"""Test domain registration with missing public key."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": REDIRECT,
|
||||
},
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": access_token,
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.tesla_fleet.config_flow.TeslaFleetApi"
|
||||
) as mock_api_class,
|
||||
):
|
||||
mock_api = AsyncMock()
|
||||
mock_api.private_key = mock_private_key
|
||||
mock_api.get_private_key = AsyncMock()
|
||||
mock_api.partner_login = AsyncMock()
|
||||
mock_api.public_uncompressed_point = "test_point"
|
||||
mock_api.partner.register.return_value = {"response": {}}
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Complete OAuth
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
# Enter domain - this should fail and stay on domain_registration
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_DOMAIN: "example.com"}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "domain_registration"
|
||||
assert result["errors"] == {"base": "public_key_not_found"}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_domain_registration_public_key_mismatch(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
access_token: str,
|
||||
mock_private_key,
|
||||
) -> None:
|
||||
"""Test domain registration with public key mismatch."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": REDIRECT,
|
||||
},
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": access_token,
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.tesla_fleet.config_flow.TeslaFleetApi"
|
||||
) as mock_api_class,
|
||||
):
|
||||
mock_api = AsyncMock()
|
||||
mock_api.private_key = mock_private_key
|
||||
mock_api.get_private_key = AsyncMock()
|
||||
mock_api.partner_login = AsyncMock()
|
||||
mock_api.public_uncompressed_point = "expected_key"
|
||||
mock_api.partner.register.return_value = {
|
||||
"response": {"public_key": "different_key"}
|
||||
}
|
||||
mock_api_class.return_value = mock_api
|
||||
|
||||
# Complete OAuth
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
# Enter domain - this should fail and stay on domain_registration
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {CONF_DOMAIN: "example.com"}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "domain_registration"
|
||||
assert result["errors"] == {"base": "public_key_mismatch"}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_registration_complete_no_domain(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test registration complete step without domain."""
|
||||
|
||||
flow_instance = OAuth2FlowHandler()
|
||||
flow_instance.hass = hass
|
||||
flow_instance.domain = None
|
||||
|
||||
result = await flow_instance.async_step_registration_complete({})
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "domain_input"
|
||||
|
||||
|
||||
async def test_registration_complete_with_domain_and_user_input(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test registration complete step with domain and user input."""
|
||||
|
||||
flow_instance = OAuth2FlowHandler()
|
||||
flow_instance.hass = hass
|
||||
flow_instance.domain = "example.com"
|
||||
flow_instance.uid = UNIQUE_ID
|
||||
flow_instance.data = {"token": {"access_token": "test"}}
|
||||
|
||||
result = await flow_instance.async_step_registration_complete({"complete": True})
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == UNIQUE_ID
|
||||
|
||||
|
||||
async def test_registration_complete_with_domain_no_user_input(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test registration complete step with domain but no user input."""
|
||||
|
||||
flow_instance = OAuth2FlowHandler()
|
||||
flow_instance.hass = hass
|
||||
flow_instance.domain = "example.com"
|
||||
|
||||
result = await flow_instance.async_step_registration_complete(None)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "registration_complete"
|
||||
assert (
|
||||
result["description_placeholders"]["virtual_key_url"]
|
||||
== "https://www.tesla.com/_ak/example.com"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
@ -225,3 +642,89 @@ async def test_reauth_account_mismatch(
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reauth_account_mismatch"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("current_request_with_host")
|
||||
async def test_duplicate_unique_id_abort(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
access_token: str,
|
||||
) -> None:
|
||||
"""Test duplicate unique ID aborts flow."""
|
||||
# Create existing entry
|
||||
existing_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
unique_id=UNIQUE_ID,
|
||||
version=1,
|
||||
data={},
|
||||
)
|
||||
existing_entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": REDIRECT,
|
||||
},
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": access_token,
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
# Complete OAuth - should abort due to duplicate unique_id
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
|
||||
async def test_reauth_confirm_form(hass: HomeAssistant) -> None:
|
||||
"""Test reauth confirm form display."""
|
||||
old_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
unique_id=UNIQUE_ID,
|
||||
version=1,
|
||||
data={},
|
||||
)
|
||||
old_entry.add_to_hass(hass)
|
||||
|
||||
result = await old_entry.start_reauth_flow(hass)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
assert result["description_placeholders"] == {"name": "Tesla Fleet"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("domain", "expected_valid"),
|
||||
[
|
||||
("example.com", True),
|
||||
("test.example.com", True),
|
||||
("sub.domain.example.org", True),
|
||||
("https://example.com", False),
|
||||
("invalid-domain", False),
|
||||
("", False),
|
||||
("example", False),
|
||||
("example.", False),
|
||||
(".example.com", False),
|
||||
("exam ple.com", False),
|
||||
],
|
||||
)
|
||||
def test_is_valid_domain(domain: str, expected_valid: bool) -> None:
|
||||
"""Test domain validation."""
|
||||
|
||||
assert OAuth2FlowHandler()._is_valid_domain(domain) == expected_valid
|
||||
|
Loading…
x
Reference in New Issue
Block a user