From 3b34594aa3097bfe39070151d06f622160ea9754 Mon Sep 17 00:00:00 2001 From: escoand Date: Fri, 15 Mar 2019 08:43:54 +0100 Subject: [PATCH] Add HTTP auth and SSL verification to REST notify (#22016) * add HTTP auth and SSL verification * use internal import * fix long line * avoid extra import --- homeassistant/components/notify/rest.py | 43 ++++++++++++++++++++----- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/notify/rest.py b/homeassistant/components/notify/rest.py index df25045c6ec..bc4341d7b47 100644 --- a/homeassistant/components/notify/rest.py +++ b/homeassistant/components/notify/rest.py @@ -12,8 +12,11 @@ import voluptuous as vol from homeassistant.components.notify import ( ATTR_TARGET, ATTR_TITLE, ATTR_TITLE_DEFAULT, BaseNotificationService, PLATFORM_SCHEMA) -from homeassistant.const import (CONF_RESOURCE, CONF_METHOD, CONF_NAME, - CONF_HEADERS) +from homeassistant.const import (CONF_AUTHENTICATION, CONF_HEADERS, + CONF_METHOD, CONF_NAME, CONF_PASSWORD, + CONF_RESOURCE, CONF_USERNAME, CONF_VERIFY_SSL, + HTTP_BASIC_AUTHENTICATION, + HTTP_DIGEST_AUTHENTICATION) import homeassistant.helpers.config_validation as cv CONF_DATA = 'data' @@ -23,6 +26,7 @@ CONF_TARGET_PARAMETER_NAME = 'target_param_name' CONF_TITLE_PARAMETER_NAME = 'title_param_name' DEFAULT_MESSAGE_PARAM_NAME = 'message' DEFAULT_METHOD = 'GET' +DEFAULT_VERIFY_SSL = True PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Required(CONF_RESOURCE): cv.url, @@ -35,7 +39,12 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Optional(CONF_TARGET_PARAMETER_NAME): cv.string, vol.Optional(CONF_TITLE_PARAMETER_NAME): cv.string, vol.Optional(CONF_DATA): dict, - vol.Optional(CONF_DATA_TEMPLATE): {cv.match_all: cv.template_complex} + vol.Optional(CONF_DATA_TEMPLATE): {cv.match_all: cv.template_complex}, + vol.Optional(CONF_AUTHENTICATION): + vol.In([HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION]), + vol.Optional(CONF_PASSWORD): cv.string, + vol.Optional(CONF_USERNAME): cv.string, + vol.Optional(CONF_VERIFY_SSL, default=DEFAULT_VERIFY_SSL): cv.boolean, }) _LOGGER = logging.getLogger(__name__) @@ -51,17 +60,30 @@ def get_service(hass, config, discovery_info=None): target_param_name = config.get(CONF_TARGET_PARAMETER_NAME) data = config.get(CONF_DATA) data_template = config.get(CONF_DATA_TEMPLATE) + username = config.get(CONF_USERNAME) + password = config.get(CONF_PASSWORD) + verify_ssl = config.get(CONF_VERIFY_SSL) + + if username and password: + if config.get(CONF_AUTHENTICATION) == HTTP_DIGEST_AUTHENTICATION: + auth = requests.auth.HTTPDigestAuth(username, password) + else: + auth = requests.auth.HTTPBasicAuth(username, password) + else: + auth = None return RestNotificationService( hass, resource, method, headers, message_param_name, - title_param_name, target_param_name, data, data_template) + title_param_name, target_param_name, data, data_template, auth, + verify_ssl) class RestNotificationService(BaseNotificationService): """Implementation of a notification service for REST.""" def __init__(self, hass, resource, method, headers, message_param_name, - title_param_name, target_param_name, data, data_template): + title_param_name, target_param_name, data, data_template, + auth, verify_ssl): """Initialize the service.""" self._resource = resource self._hass = hass @@ -72,6 +94,8 @@ class RestNotificationService(BaseNotificationService): self._target_param_name = target_param_name self._data = data self._data_template = data_template + self._auth = auth + self._verify_ssl = verify_ssl def send_message(self, message="", **kwargs): """Send a message to a user.""" @@ -104,13 +128,16 @@ class RestNotificationService(BaseNotificationService): if self._method == 'POST': response = requests.post(self._resource, headers=self._headers, - data=data, timeout=10) + data=data, timeout=10, auth=self._auth, + verify=self._verify_ssl) elif self._method == 'POST_JSON': response = requests.post(self._resource, headers=self._headers, - json=data, timeout=10) + json=data, timeout=10, auth=self._auth, + verify=self._verify_ssl) else: # default GET response = requests.get(self._resource, headers=self._headers, - params=data, timeout=10) + params=data, timeout=10, auth=self._auth, + verify=self._verify_ssl) success_codes = (200, 201, 202, 203, 204, 205, 206, 207, 208, 226) if response.status_code not in success_codes: