mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Add endpoint validation for AWS S3 (#144334)
This commit is contained in:
parent
bdf6f7f590
commit
0ec7dc5654
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from aiobotocore.session import AioSession
|
from aiobotocore.session import AioSession
|
||||||
from botocore.exceptions import ClientError, ConnectionError, ParamValidationError
|
from botocore.exceptions import ClientError, ConnectionError, ParamValidationError
|
||||||
@ -17,6 +18,7 @@ from homeassistant.helpers.selector import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
|
AWS_DOMAIN,
|
||||||
CONF_ACCESS_KEY_ID,
|
CONF_ACCESS_KEY_ID,
|
||||||
CONF_BUCKET,
|
CONF_BUCKET,
|
||||||
CONF_ENDPOINT_URL,
|
CONF_ENDPOINT_URL,
|
||||||
@ -57,28 +59,34 @@ class S3ConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
CONF_ENDPOINT_URL: user_input[CONF_ENDPOINT_URL],
|
CONF_ENDPOINT_URL: user_input[CONF_ENDPOINT_URL],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
session = AioSession()
|
if not urlparse(user_input[CONF_ENDPOINT_URL]).hostname.endswith(
|
||||||
async with session.create_client(
|
AWS_DOMAIN
|
||||||
"s3",
|
):
|
||||||
endpoint_url=user_input.get(CONF_ENDPOINT_URL),
|
|
||||||
aws_secret_access_key=user_input[CONF_SECRET_ACCESS_KEY],
|
|
||||||
aws_access_key_id=user_input[CONF_ACCESS_KEY_ID],
|
|
||||||
) as client:
|
|
||||||
await client.head_bucket(Bucket=user_input[CONF_BUCKET])
|
|
||||||
except ClientError:
|
|
||||||
errors["base"] = "invalid_credentials"
|
|
||||||
except ParamValidationError as err:
|
|
||||||
if "Invalid bucket name" in str(err):
|
|
||||||
errors[CONF_BUCKET] = "invalid_bucket_name"
|
|
||||||
except ValueError:
|
|
||||||
errors[CONF_ENDPOINT_URL] = "invalid_endpoint_url"
|
errors[CONF_ENDPOINT_URL] = "invalid_endpoint_url"
|
||||||
except ConnectionError:
|
|
||||||
errors[CONF_ENDPOINT_URL] = "cannot_connect"
|
|
||||||
else:
|
else:
|
||||||
return self.async_create_entry(
|
try:
|
||||||
title=user_input[CONF_BUCKET], data=user_input
|
session = AioSession()
|
||||||
)
|
async with session.create_client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=user_input.get(CONF_ENDPOINT_URL),
|
||||||
|
aws_secret_access_key=user_input[CONF_SECRET_ACCESS_KEY],
|
||||||
|
aws_access_key_id=user_input[CONF_ACCESS_KEY_ID],
|
||||||
|
) as client:
|
||||||
|
await client.head_bucket(Bucket=user_input[CONF_BUCKET])
|
||||||
|
except ClientError:
|
||||||
|
errors["base"] = "invalid_credentials"
|
||||||
|
except ParamValidationError as err:
|
||||||
|
if "Invalid bucket name" in str(err):
|
||||||
|
errors[CONF_BUCKET] = "invalid_bucket_name"
|
||||||
|
except ValueError:
|
||||||
|
errors[CONF_ENDPOINT_URL] = "invalid_endpoint_url"
|
||||||
|
except ConnectionError:
|
||||||
|
errors[CONF_ENDPOINT_URL] = "cannot_connect"
|
||||||
|
else:
|
||||||
|
return self.async_create_entry(
|
||||||
|
title=user_input[CONF_BUCKET], data=user_input
|
||||||
|
)
|
||||||
|
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="user",
|
step_id="user",
|
||||||
|
@ -12,7 +12,8 @@ CONF_SECRET_ACCESS_KEY = "secret_access_key"
|
|||||||
CONF_ENDPOINT_URL = "endpoint_url"
|
CONF_ENDPOINT_URL = "endpoint_url"
|
||||||
CONF_BUCKET = "bucket"
|
CONF_BUCKET = "bucket"
|
||||||
|
|
||||||
DEFAULT_ENDPOINT_URL = "https://s3.eu-central-1.amazonaws.com/"
|
AWS_DOMAIN = "amazonaws.com"
|
||||||
|
DEFAULT_ENDPOINT_URL = f"https://s3.eu-central-1.{AWS_DOMAIN}/"
|
||||||
|
|
||||||
DATA_BACKUP_AGENT_LISTENERS: HassKey[list[Callable[[], None]]] = HassKey(
|
DATA_BACKUP_AGENT_LISTENERS: HassKey[list[Callable[[], None]]] = HassKey(
|
||||||
f"{DOMAIN}.backup_agent_listeners"
|
f"{DOMAIN}.backup_agent_listeners"
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
"cannot_connect": "[%key:component::aws_s3::exceptions::cannot_connect::message%]",
|
"cannot_connect": "[%key:component::aws_s3::exceptions::cannot_connect::message%]",
|
||||||
"invalid_bucket_name": "[%key:component::aws_s3::exceptions::invalid_bucket_name::message%]",
|
"invalid_bucket_name": "[%key:component::aws_s3::exceptions::invalid_bucket_name::message%]",
|
||||||
"invalid_credentials": "[%key:component::aws_s3::exceptions::invalid_credentials::message%]",
|
"invalid_credentials": "[%key:component::aws_s3::exceptions::invalid_credentials::message%]",
|
||||||
"invalid_endpoint_url": "Invalid endpoint URL"
|
"invalid_endpoint_url": "Invalid endpoint URL. Please make sure it's a valid AWS S3 endpoint URL."
|
||||||
},
|
},
|
||||||
"abort": {
|
"abort": {
|
||||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
|
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
|
||||||
|
@ -10,6 +10,6 @@ from homeassistant.components.aws_s3.const import (
|
|||||||
USER_INPUT = {
|
USER_INPUT = {
|
||||||
CONF_ACCESS_KEY_ID: "TestTestTestTestTest",
|
CONF_ACCESS_KEY_ID: "TestTestTestTestTest",
|
||||||
CONF_SECRET_ACCESS_KEY: "TestTestTestTestTestTestTestTestTestTest",
|
CONF_SECRET_ACCESS_KEY: "TestTestTestTestTestTestTestTestTestTest",
|
||||||
CONF_ENDPOINT_URL: "http://127.0.0.1:9000",
|
CONF_ENDPOINT_URL: "https://s3.eu-south-1.amazonaws.com",
|
||||||
CONF_BUCKET: "test",
|
CONF_BUCKET: "test",
|
||||||
}
|
}
|
||||||
|
@ -21,8 +21,12 @@ from tests.common import MockConfigEntry
|
|||||||
|
|
||||||
async def _async_start_flow(
|
async def _async_start_flow(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
user_input: dict[str, str] | None = None,
|
||||||
) -> FlowResultType:
|
) -> FlowResultType:
|
||||||
"""Initialize the config flow."""
|
"""Initialize the config flow."""
|
||||||
|
if user_input is None:
|
||||||
|
user_input = USER_INPUT
|
||||||
|
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
)
|
)
|
||||||
@ -30,7 +34,7 @@ async def _async_start_flow(
|
|||||||
|
|
||||||
return await hass.config_entries.flow.async_configure(
|
return await hass.config_entries.flow.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
USER_INPUT,
|
user_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -116,3 +120,24 @@ async def test_abort_if_already_configured(
|
|||||||
result = await _async_start_flow(hass)
|
result = await _async_start_flow(hass)
|
||||||
assert result["type"] is FlowResultType.ABORT
|
assert result["type"] is FlowResultType.ABORT
|
||||||
assert result["reason"] == "already_configured"
|
assert result["reason"] == "already_configured"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flow_create_not_aws_endpoint(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> None:
|
||||||
|
"""Test config flow with a not aws endpoint should raise an error."""
|
||||||
|
result = await _async_start_flow(
|
||||||
|
hass, USER_INPUT | {CONF_ENDPOINT_URL: "http://example.com"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert result["errors"] == {CONF_ENDPOINT_URL: "invalid_endpoint_url"}
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
USER_INPUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result["title"] == "test"
|
||||||
|
assert result["data"] == USER_INPUT
|
||||||
|
Loading…
x
Reference in New Issue
Block a user