diff --git a/homeassistant/components/aws/__init__.py b/homeassistant/components/aws/__init__.py index bd1f6b55090..a15e56e9de8 100644 --- a/homeassistant/components/aws/__init__.py +++ b/homeassistant/components/aws/__init__.py @@ -13,14 +13,18 @@ from homeassistant.helpers import config_validation as cv, discovery from . import config_flow # noqa from .const import ( CONF_ACCESS_KEY_ID, + CONF_CONTEXT, + CONF_CREDENTIAL_NAME, + CONF_CREDENTIALS, + CONF_NOTIFY, + CONF_REGION, CONF_SECRET_ACCESS_KEY, + CONF_SERVICE, DATA_CONFIG, DATA_HASS_CONFIG, DATA_SESSIONS, DOMAIN, - CONF_NOTIFY, ) -from .notify import PLATFORM_SCHEMA as NOTIFY_PLATFORM_SCHEMA REQUIREMENTS = ["aiobotocore==0.10.2"] @@ -37,14 +41,31 @@ AWS_CREDENTIAL_SCHEMA = vol.Schema( DEFAULT_CREDENTIAL = [{CONF_NAME: "default", CONF_PROFILE_NAME: "default"}] +SUPPORTED_SERVICES = ["lambda", "sns", "sqs"] + +NOTIFY_PLATFORM_SCHEMA = vol.Schema( + { + vol.Optional(CONF_NAME): cv.string, + vol.Required(CONF_SERVICE): vol.All( + cv.string, vol.Lower, vol.In(SUPPORTED_SERVICES) + ), + vol.Required(CONF_REGION): vol.All(cv.string, vol.Lower), + vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string, + vol.Inclusive(CONF_SECRET_ACCESS_KEY, ATTR_CREDENTIALS): cv.string, + vol.Exclusive(CONF_PROFILE_NAME, ATTR_CREDENTIALS): cv.string, + vol.Exclusive(CONF_CREDENTIAL_NAME, ATTR_CREDENTIALS): cv.string, + vol.Optional(CONF_CONTEXT): vol.Coerce(dict), + } +) + CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.Schema( { vol.Optional( - ATTR_CREDENTIALS, default=DEFAULT_CREDENTIAL + CONF_CREDENTIALS, default=DEFAULT_CREDENTIAL ): vol.All(cv.ensure_list, [AWS_CREDENTIAL_SCHEMA]), - vol.Optional(CONF_NOTIFY): vol.All( + vol.Optional(CONF_NOTIFY, default=[]): vol.All( cv.ensure_list, [NOTIFY_PLATFORM_SCHEMA] ), } @@ -98,9 +119,10 @@ async def async_setup_entry(hass, entry): if conf is None: conf = CONFIG_SCHEMA({DOMAIN: entry.data})[DOMAIN] + # validate credentials and create sessions validation = True tasks = [] - for cred in conf.get(ATTR_CREDENTIALS): + for cred in conf[ATTR_CREDENTIALS]: tasks.append(_validate_aws_credentials(hass, cred)) if tasks: results = await asyncio.gather(*tasks, return_exceptions=True) @@ -109,15 +131,22 @@ async def async_setup_entry(hass, entry): if isinstance(result, Exception): _LOGGER.error( "Validating credential [%s] failed: %s", - name, result, exc_info=result + name, + result, + exc_info=result, ) validation = False else: hass.data[DATA_SESSIONS][name] = result - # No entry support for notify component yet - for notify_config in conf.get(CONF_NOTIFY, []): - discovery.load_platform(hass, "notify", DOMAIN, notify_config, config) + # set up notify platform, no entry support for notify component yet, + # have to use discovery to load platform. + for notify_config in conf[CONF_NOTIFY]: + hass.async_create_task( + discovery.async_load_platform( + hass, "notify", DOMAIN, notify_config, config + ) + ) return validation diff --git a/homeassistant/components/aws/const.py b/homeassistant/components/aws/const.py index c8b0eed8b6b..4fa88566934 100644 --- a/homeassistant/components/aws/const.py +++ b/homeassistant/components/aws/const.py @@ -1,13 +1,16 @@ """Constant for AWS component.""" DOMAIN = "aws" -DATA_KEY = DOMAIN + DATA_CONFIG = "aws_config" DATA_HASS_CONFIG = "aws_hass_config" DATA_SESSIONS = "aws_sessions" -CONF_REGION = "region_name" CONF_ACCESS_KEY_ID = "aws_access_key_id" -CONF_SECRET_ACCESS_KEY = "aws_secret_access_key" -CONF_PROFILE_NAME = "profile_name" +CONF_CONTEXT = "context" CONF_CREDENTIAL_NAME = "credential_name" +CONF_CREDENTIALS = 'credentials' CONF_NOTIFY = "notify" +CONF_PROFILE_NAME = "profile_name" +CONF_REGION = "region_name" +CONF_SECRET_ACCESS_KEY = "aws_secret_access_key" +CONF_SERVICE = "service" diff --git a/homeassistant/components/aws/notify.py b/homeassistant/components/aws/notify.py index 020d92200b9..48b80b64ce2 100644 --- a/homeassistant/components/aws/notify.py +++ b/homeassistant/components/aws/notify.py @@ -1,29 +1,23 @@ """AWS platform for notify component.""" import asyncio -import logging -import json import base64 +import json +import logging -import voluptuous as vol - -import homeassistant.helpers.config_validation as cv -from homeassistant.const import CONF_PLATFORM, CONF_NAME, ATTR_CREDENTIALS from homeassistant.components.notify import ( ATTR_TARGET, ATTR_TITLE, ATTR_TITLE_DEFAULT, BaseNotificationService, - PLATFORM_SCHEMA, ) -from homeassistant.exceptions import HomeAssistantError +from homeassistant.const import CONF_PLATFORM, CONF_NAME from homeassistant.helpers.json import JSONEncoder - from .const import ( - CONF_ACCESS_KEY_ID, + CONF_CONTEXT, CONF_CREDENTIAL_NAME, CONF_PROFILE_NAME, CONF_REGION, - CONF_SECRET_ACCESS_KEY, + CONF_SERVICE, DATA_SESSIONS, ) @@ -31,69 +25,43 @@ DEPENDENCIES = ["aws"] _LOGGER = logging.getLogger(__name__) -CONF_CONTEXT = "context" -CONF_SERVICE = "service" -SUPPORTED_SERVICES = ["lambda", "sns", "sqs"] - - -def _in_avilable_region(config): - """Check if region is available.""" +async def get_available_regions(hass, service): + """Get available regions for a service.""" import aiobotocore session = aiobotocore.get_session() - available_regions = session.get_available_regions(config[CONF_SERVICE]) - if config[CONF_REGION] not in available_regions: - raise vol.Invalid( - "Region {} is not available for {} service, mustin {}".format( - config[CONF_REGION], config[CONF_SERVICE], available_regions - ) - ) - return config - - -PLATFORM_SCHEMA = vol.Schema( - vol.All( - PLATFORM_SCHEMA.extend( - { - # override notify.PLATFORM_SCHEMA.CONF_PLATFORM to Optional - # we don't need this field when we use discovery - vol.Optional(CONF_PLATFORM): cv.string, - vol.Required(CONF_SERVICE): vol.All( - cv.string, vol.Lower, vol.In(SUPPORTED_SERVICES) - ), - vol.Required(CONF_REGION): vol.All(cv.string, vol.Lower), - vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string, - vol.Inclusive( - CONF_SECRET_ACCESS_KEY, ATTR_CREDENTIALS - ): cv.string, - vol.Exclusive(CONF_PROFILE_NAME, ATTR_CREDENTIALS): cv.string, - vol.Exclusive( - CONF_CREDENTIAL_NAME, ATTR_CREDENTIALS - ): cv.string, - vol.Optional(CONF_CONTEXT): vol.Coerce(dict), - }, - extra=vol.PREVENT_EXTRA, - ), - _in_avilable_region, + # get_available_regions is not a coroutine since it does not perform + # network I/O. But it still perform file I/O heavily, so put it into + # an executor thread to unblock event loop + return await hass.async_add_executor_job( + session.get_available_regions, service ) -) async def async_get_service(hass, config, discovery_info=None): """Get the AWS notification service.""" + if discovery_info is None: + _LOGGER.error('Please config aws notify platform in aws component') + return None + import aiobotocore session = None - if discovery_info is not None: - conf = discovery_info - else: - conf = config + conf = discovery_info service = conf[CONF_SERVICE] region_name = conf[CONF_REGION] + available_regions = await get_available_regions(hass, service) + if region_name not in available_regions: + _LOGGER.error( + "Region %s is not available for %s service, must in %s", + region_name, service, available_regions + ) + return None + aws_config = conf.copy() del aws_config[CONF_SERVICE] @@ -106,13 +74,14 @@ async def async_get_service(hass, config, discovery_info=None): del aws_config[CONF_CONTEXT] if not aws_config: - # no platform config, use aws component config instead + # no platform config, use the first aws component credential instead if hass.data[DATA_SESSIONS]: - session = list(hass.data[DATA_SESSIONS].values())[0] + session = next(iter(hass.data[DATA_SESSIONS].values())) else: - raise ValueError( - "No available aws session for {}".format(config[CONF_NAME]) + _LOGGER.error( + "Missing aws credential for %s", config[CONF_NAME] ) + return None if session is None: credential_name = aws_config.get(CONF_CREDENTIAL_NAME) @@ -148,7 +117,8 @@ async def async_get_service(hass, config, discovery_info=None): if service == "sqs": return AWSSQS(session, aws_config) - raise ValueError("Unsupported service {}".format(service)) + # should not reach here since service was checked in schema + return None class AWSNotify(BaseNotificationService): @@ -159,17 +129,6 @@ class AWSNotify(BaseNotificationService): self.session = session self.aws_config = aws_config - def send_message(self, message, **kwargs): - """Send notification.""" - raise NotImplementedError("Please call async_send_message()") - - async def async_send_message(self, message="", **kwargs): - """Send notification.""" - targets = kwargs.get(ATTR_TARGET) - - if not targets: - raise HomeAssistantError("At least one target is required") - class AWSLambda(AWSNotify): """Implement the notification service for the AWS Lambda service.""" @@ -183,9 +142,11 @@ class AWSLambda(AWSNotify): async def async_send_message(self, message="", **kwargs): """Send notification to specified LAMBDA ARN.""" - await super().async_send_message(message, **kwargs) + if not kwargs.get(ATTR_TARGET): + _LOGGER.error("At least one target is required") + return - cleaned_kwargs = dict((k, v) for k, v in kwargs.items() if v) + cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not None} payload = {"message": message} payload.update(cleaned_kwargs) json_payload = json.dumps(payload) @@ -214,12 +175,14 @@ class AWSSNS(AWSNotify): async def async_send_message(self, message="", **kwargs): """Send notification to specified SNS ARN.""" - await super().async_send_message(message, **kwargs) + if not kwargs.get(ATTR_TARGET): + _LOGGER.error("At least one target is required") + return message_attributes = { k: {"StringValue": json.dumps(v), "DataType": "String"} for k, v in kwargs.items() - if v + if v is not None } subject = kwargs.get(ATTR_TITLE, ATTR_TITLE_DEFAULT) @@ -248,9 +211,11 @@ class AWSSQS(AWSNotify): async def async_send_message(self, message="", **kwargs): """Send notification to specified SQS ARN.""" - await super().async_send_message(message, **kwargs) + if not kwargs.get(ATTR_TARGET): + _LOGGER.error("At least one target is required") + return - cleaned_kwargs = dict((k, v) for k, v in kwargs.items() if v) + cleaned_kwargs = {k: v for k, v in kwargs.items() if v is not None} message_body = {"message": message} message_body.update(cleaned_kwargs) json_body = json.dumps(message_body)