Fix aws.notify platform schema (#22374)

* Fix aws component notify platform schema

* Address code review comment

* Do not allow load aws.notify from notify component

* Revert unrelated translation update

* Review comment
This commit is contained in:
Jason Hu 2019-03-27 14:53:06 -07:00 committed by GitHub
parent 24c7c2aa6e
commit f795d03503
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 92 deletions

View File

@ -13,14 +13,18 @@ from homeassistant.helpers import config_validation as cv, discovery
from . import config_flow # noqa from . import config_flow # noqa
from .const import ( from .const import (
CONF_ACCESS_KEY_ID, CONF_ACCESS_KEY_ID,
CONF_CONTEXT,
CONF_CREDENTIAL_NAME,
CONF_CREDENTIALS,
CONF_NOTIFY,
CONF_REGION,
CONF_SECRET_ACCESS_KEY, CONF_SECRET_ACCESS_KEY,
CONF_SERVICE,
DATA_CONFIG, DATA_CONFIG,
DATA_HASS_CONFIG, DATA_HASS_CONFIG,
DATA_SESSIONS, DATA_SESSIONS,
DOMAIN, DOMAIN,
CONF_NOTIFY,
) )
from .notify import PLATFORM_SCHEMA as NOTIFY_PLATFORM_SCHEMA
REQUIREMENTS = ["aiobotocore==0.10.2"] REQUIREMENTS = ["aiobotocore==0.10.2"]
@ -37,14 +41,31 @@ AWS_CREDENTIAL_SCHEMA = vol.Schema(
DEFAULT_CREDENTIAL = [{CONF_NAME: "default", CONF_PROFILE_NAME: "default"}] DEFAULT_CREDENTIAL = [{CONF_NAME: "default", CONF_PROFILE_NAME: "default"}]
SUPPORTED_SERVICES = ["lambda", "sns", "sqs"]
NOTIFY_PLATFORM_SCHEMA = vol.Schema(
{
vol.Optional(CONF_NAME): cv.string,
vol.Required(CONF_SERVICE): vol.All(
cv.string, vol.Lower, vol.In(SUPPORTED_SERVICES)
),
vol.Required(CONF_REGION): vol.All(cv.string, vol.Lower),
vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string,
vol.Inclusive(CONF_SECRET_ACCESS_KEY, ATTR_CREDENTIALS): cv.string,
vol.Exclusive(CONF_PROFILE_NAME, ATTR_CREDENTIALS): cv.string,
vol.Exclusive(CONF_CREDENTIAL_NAME, ATTR_CREDENTIALS): cv.string,
vol.Optional(CONF_CONTEXT): vol.Coerce(dict),
}
)
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
DOMAIN: vol.Schema( DOMAIN: vol.Schema(
{ {
vol.Optional( vol.Optional(
ATTR_CREDENTIALS, default=DEFAULT_CREDENTIAL CONF_CREDENTIALS, default=DEFAULT_CREDENTIAL
): vol.All(cv.ensure_list, [AWS_CREDENTIAL_SCHEMA]), ): vol.All(cv.ensure_list, [AWS_CREDENTIAL_SCHEMA]),
vol.Optional(CONF_NOTIFY): vol.All( vol.Optional(CONF_NOTIFY, default=[]): vol.All(
cv.ensure_list, [NOTIFY_PLATFORM_SCHEMA] cv.ensure_list, [NOTIFY_PLATFORM_SCHEMA]
), ),
} }
@ -98,9 +119,10 @@ async def async_setup_entry(hass, entry):
if conf is None: if conf is None:
conf = CONFIG_SCHEMA({DOMAIN: entry.data})[DOMAIN] conf = CONFIG_SCHEMA({DOMAIN: entry.data})[DOMAIN]
# validate credentials and create sessions
validation = True validation = True
tasks = [] tasks = []
for cred in conf.get(ATTR_CREDENTIALS): for cred in conf[ATTR_CREDENTIALS]:
tasks.append(_validate_aws_credentials(hass, cred)) tasks.append(_validate_aws_credentials(hass, cred))
if tasks: if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
@ -109,15 +131,22 @@ async def async_setup_entry(hass, entry):
if isinstance(result, Exception): if isinstance(result, Exception):
_LOGGER.error( _LOGGER.error(
"Validating credential [%s] failed: %s", "Validating credential [%s] failed: %s",
name, result, exc_info=result name,
result,
exc_info=result,
) )
validation = False validation = False
else: else:
hass.data[DATA_SESSIONS][name] = result hass.data[DATA_SESSIONS][name] = result
# No entry support for notify component yet # set up notify platform, no entry support for notify component yet,
for notify_config in conf.get(CONF_NOTIFY, []): # have to use discovery to load platform.
discovery.load_platform(hass, "notify", DOMAIN, notify_config, config) for notify_config in conf[CONF_NOTIFY]:
hass.async_create_task(
discovery.async_load_platform(
hass, "notify", DOMAIN, notify_config, config
)
)
return validation return validation

View File

@ -1,13 +1,16 @@
"""Constant for AWS component.""" """Constant for AWS component."""
DOMAIN = "aws" DOMAIN = "aws"
DATA_KEY = DOMAIN
DATA_CONFIG = "aws_config" DATA_CONFIG = "aws_config"
DATA_HASS_CONFIG = "aws_hass_config" DATA_HASS_CONFIG = "aws_hass_config"
DATA_SESSIONS = "aws_sessions" DATA_SESSIONS = "aws_sessions"
CONF_REGION = "region_name"
CONF_ACCESS_KEY_ID = "aws_access_key_id" CONF_ACCESS_KEY_ID = "aws_access_key_id"
CONF_SECRET_ACCESS_KEY = "aws_secret_access_key" CONF_CONTEXT = "context"
CONF_PROFILE_NAME = "profile_name"
CONF_CREDENTIAL_NAME = "credential_name" CONF_CREDENTIAL_NAME = "credential_name"
CONF_CREDENTIALS = 'credentials'
CONF_NOTIFY = "notify" CONF_NOTIFY = "notify"
CONF_PROFILE_NAME = "profile_name"
CONF_REGION = "region_name"
CONF_SECRET_ACCESS_KEY = "aws_secret_access_key"
CONF_SERVICE = "service"

View File

@ -1,29 +1,23 @@
"""AWS platform for notify component.""" """AWS platform for notify component."""
import asyncio import asyncio
import logging
import json
import base64 import base64
import json
import logging
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
from homeassistant.const import CONF_PLATFORM, CONF_NAME, ATTR_CREDENTIALS
from homeassistant.components.notify import ( from homeassistant.components.notify import (
ATTR_TARGET, ATTR_TARGET,
ATTR_TITLE, ATTR_TITLE,
ATTR_TITLE_DEFAULT, ATTR_TITLE_DEFAULT,
BaseNotificationService, BaseNotificationService,
PLATFORM_SCHEMA,
) )
from homeassistant.exceptions import HomeAssistantError from homeassistant.const import CONF_PLATFORM, CONF_NAME
from homeassistant.helpers.json import JSONEncoder from homeassistant.helpers.json import JSONEncoder
from .const import ( from .const import (
CONF_ACCESS_KEY_ID, CONF_CONTEXT,
CONF_CREDENTIAL_NAME, CONF_CREDENTIAL_NAME,
CONF_PROFILE_NAME, CONF_PROFILE_NAME,
CONF_REGION, CONF_REGION,
CONF_SECRET_ACCESS_KEY, CONF_SERVICE,
DATA_SESSIONS, DATA_SESSIONS,
) )
@ -31,69 +25,43 @@ DEPENDENCIES = ["aws"]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_CONTEXT = "context"
CONF_SERVICE = "service"
SUPPORTED_SERVICES = ["lambda", "sns", "sqs"] async def get_available_regions(hass, service):
"""Get available regions for a service."""
def _in_avilable_region(config):
"""Check if region is available."""
import aiobotocore import aiobotocore
session = aiobotocore.get_session() session = aiobotocore.get_session()
available_regions = session.get_available_regions(config[CONF_SERVICE]) # get_available_regions is not a coroutine since it does not perform
if config[CONF_REGION] not in available_regions: # network I/O. But it still perform file I/O heavily, so put it into
raise vol.Invalid( # an executor thread to unblock event loop
"Region {} is not available for {} service, mustin {}".format( return await hass.async_add_executor_job(
config[CONF_REGION], config[CONF_SERVICE], available_regions session.get_available_regions, service
) )
)
return config
PLATFORM_SCHEMA = vol.Schema(
vol.All(
PLATFORM_SCHEMA.extend(
{
# override notify.PLATFORM_SCHEMA.CONF_PLATFORM to Optional
# we don't need this field when we use discovery
vol.Optional(CONF_PLATFORM): cv.string,
vol.Required(CONF_SERVICE): vol.All(
cv.string, vol.Lower, vol.In(SUPPORTED_SERVICES)
),
vol.Required(CONF_REGION): vol.All(cv.string, vol.Lower),
vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string,
vol.Inclusive(
CONF_SECRET_ACCESS_KEY, ATTR_CREDENTIALS
): cv.string,
vol.Exclusive(CONF_PROFILE_NAME, ATTR_CREDENTIALS): cv.string,
vol.Exclusive(
CONF_CREDENTIAL_NAME, ATTR_CREDENTIALS
): cv.string,
vol.Optional(CONF_CONTEXT): vol.Coerce(dict),
},
extra=vol.PREVENT_EXTRA,
),
_in_avilable_region,
)
)
async def async_get_service(hass, config, discovery_info=None): async def async_get_service(hass, config, discovery_info=None):
"""Get the AWS notification service.""" """Get the AWS notification service."""
if discovery_info is None:
_LOGGER.error('Please config aws notify platform in aws component')
return None
import aiobotocore import aiobotocore
session = None session = None
if discovery_info is not None:
conf = discovery_info conf = discovery_info
else:
conf = config
service = conf[CONF_SERVICE] service = conf[CONF_SERVICE]
region_name = conf[CONF_REGION] region_name = conf[CONF_REGION]
available_regions = await get_available_regions(hass, service)
if region_name not in available_regions:
_LOGGER.error(
"Region %s is not available for %s service, must in %s",
region_name, service, available_regions
)
return None
aws_config = conf.copy() aws_config = conf.copy()
del aws_config[CONF_SERVICE] del aws_config[CONF_SERVICE]
@ -106,13 +74,14 @@ async def async_get_service(hass, config, discovery_info=None):
del aws_config[CONF_CONTEXT] del aws_config[CONF_CONTEXT]
if not aws_config: if not aws_config:
# no platform config, use aws component config instead # no platform config, use the first aws component credential instead
if hass.data[DATA_SESSIONS]: if hass.data[DATA_SESSIONS]:
session = list(hass.data[DATA_SESSIONS].values())[0] session = next(iter(hass.data[DATA_SESSIONS].values()))
else: else:
raise ValueError( _LOGGER.error(
"No available aws session for {}".format(config[CONF_NAME]) "Missing aws credential for %s", config[CONF_NAME]
) )
return None
if session is None: if session is None:
credential_name = aws_config.get(CONF_CREDENTIAL_NAME) credential_name = aws_config.get(CONF_CREDENTIAL_NAME)
@ -148,7 +117,8 @@ async def async_get_service(hass, config, discovery_info=None):
if service == "sqs": if service == "sqs":
return AWSSQS(session, aws_config) return AWSSQS(session, aws_config)
raise ValueError("Unsupported service {}".format(service)) # should not reach here since service was checked in schema
return None
class AWSNotify(BaseNotificationService): class AWSNotify(BaseNotificationService):
@ -159,17 +129,6 @@ class AWSNotify(BaseNotificationService):
self.session = session self.session = session
self.aws_config = aws_config self.aws_config = aws_config
def send_message(self, message, **kwargs):
"""Send notification."""
raise NotImplementedError("Please call async_send_message()")
async def async_send_message(self, message="", **kwargs):
"""Send notification."""
targets = kwargs.get(ATTR_TARGET)
if not targets:
raise HomeAssistantError("At least one target is required")
class AWSLambda(AWSNotify): class AWSLambda(AWSNotify):
"""Implement the notification service for the AWS Lambda service.""" """Implement the notification service for the AWS Lambda service."""
@ -183,9 +142,11 @@ class AWSLambda(AWSNotify):
async def async_send_message(self, message="", **kwargs): async def async_send_message(self, message="", **kwargs):
"""Send notification to specified LAMBDA ARN.""" """Send notification to specified LAMBDA ARN."""
await super().async_send_message(message, **kwargs) if not kwargs.get(ATTR_TARGET):
_LOGGER.error("At least one target is required")
return
cleaned_kwargs = dict((k, v) for k, v in kwargs.items() if v) cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not None}
payload = {"message": message} payload = {"message": message}
payload.update(cleaned_kwargs) payload.update(cleaned_kwargs)
json_payload = json.dumps(payload) json_payload = json.dumps(payload)
@ -214,12 +175,14 @@ class AWSSNS(AWSNotify):
async def async_send_message(self, message="", **kwargs): async def async_send_message(self, message="", **kwargs):
"""Send notification to specified SNS ARN.""" """Send notification to specified SNS ARN."""
await super().async_send_message(message, **kwargs) if not kwargs.get(ATTR_TARGET):
_LOGGER.error("At least one target is required")
return
message_attributes = { message_attributes = {
k: {"StringValue": json.dumps(v), "DataType": "String"} k: {"StringValue": json.dumps(v), "DataType": "String"}
for k, v in kwargs.items() for k, v in kwargs.items()
if v if v is not None
} }
subject = kwargs.get(ATTR_TITLE, ATTR_TITLE_DEFAULT) subject = kwargs.get(ATTR_TITLE, ATTR_TITLE_DEFAULT)
@ -248,9 +211,11 @@ class AWSSQS(AWSNotify):
async def async_send_message(self, message="", **kwargs): async def async_send_message(self, message="", **kwargs):
"""Send notification to specified SQS ARN.""" """Send notification to specified SQS ARN."""
await super().async_send_message(message, **kwargs) if not kwargs.get(ATTR_TARGET):
_LOGGER.error("At least one target is required")
return
cleaned_kwargs = dict((k, v) for k, v in kwargs.items() if v) cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not None}
message_body = {"message": message} message_body = {"message": message}
message_body.update(cleaned_kwargs) message_body.update(cleaned_kwargs)
json_body = json.dumps(message_body) json_body = json.dumps(message_body)