From 32ffd006facea907b2013da19336cf81bd16e6d3 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 25 Nov 2016 13:04:06 -0800 Subject: [PATCH] Reorganize HTTP component (#4575) * Move HTTP to own folder * Break HTTP into middlewares * Lint * Split tests per middleware * Clean up HTTP tests * Make HomeAssistantViews more stateless * Lint * Make HTTP setup async --- homeassistant/components/alexa.py | 6 +- homeassistant/components/api.py | 55 +- homeassistant/components/camera/__init__.py | 11 +- .../components/device_tracker/gpslogger.py | 12 +- .../components/device_tracker/locative.py | 17 +- homeassistant/components/emulated_hue.py | 46 +- homeassistant/components/foursquare.py | 8 +- homeassistant/components/frontend/__init__.py | 37 +- homeassistant/components/history.py | 15 +- homeassistant/components/http.py | 641 ------------------ homeassistant/components/http/__init__.py | 407 +++++++++++ homeassistant/components/http/auth.py | 61 ++ homeassistant/components/http/ban.py | 132 ++++ homeassistant/components/http/const.py | 12 + homeassistant/components/http/static.py | 93 +++ homeassistant/components/http/util.py | 25 + homeassistant/components/ios.py | 12 +- homeassistant/components/logbook.py | 8 +- .../components/media_player/__init__.py | 11 +- homeassistant/components/notify/html5.py | 12 +- homeassistant/components/sensor/fitbit.py | 13 +- homeassistant/components/sensor/torque.py | 8 +- homeassistant/components/switch/netio.py | 3 +- homeassistant/const.py | 1 - homeassistant/util/logging.py | 17 + tests/common.py | 17 +- tests/components/camera/test_generic.py | 2 +- tests/components/http/__init__.py | 1 + tests/components/http/test_auth.py | 169 +++++ tests/components/http/test_ban.py | 118 ++++ tests/components/http/test_init.py | 111 +++ tests/components/notify/test_html5.py | 25 +- tests/components/test_frontend.py | 3 +- tests/components/test_http.py | 285 -------- tests/scripts/test_check_config.py | 8 + 35 files changed, 1318 insertions(+), 1084 deletions(-) delete mode 100644 homeassistant/components/http.py create mode 100644 homeassistant/components/http/__init__.py create mode 100644 homeassistant/components/http/auth.py create mode 100644 homeassistant/components/http/ban.py create mode 100644 homeassistant/components/http/const.py create mode 100644 homeassistant/components/http/static.py create mode 100644 homeassistant/components/http/util.py create mode 100644 homeassistant/util/logging.py create mode 100644 tests/components/http/__init__.py create mode 100644 tests/components/http/test_auth.py create mode 100644 tests/components/http/test_ban.py create mode 100644 tests/components/http/test_init.py delete mode 100644 tests/components/test_http.py diff --git a/homeassistant/components/alexa.py b/homeassistant/components/alexa.py index 72c0b2a8705..9bd0d783fee 100644 --- a/homeassistant/components/alexa.py +++ b/homeassistant/components/alexa.py @@ -118,7 +118,7 @@ class AlexaIntentsView(HomeAssistantView): def __init__(self, hass, intents): """Initialize Alexa view.""" - super().__init__(hass) + super().__init__() intents = copy.deepcopy(intents) template.attach(hass, intents) @@ -150,7 +150,7 @@ class AlexaIntentsView(HomeAssistantView): return None intent = req.get('intent') - response = AlexaResponse(self.hass, intent) + response = AlexaResponse(request.app['hass'], intent) if req_type == 'LaunchRequest': response.add_speech( @@ -282,7 +282,7 @@ class AlexaFlashBriefingView(HomeAssistantView): def __init__(self, hass, flash_briefings): """Initialize Alexa view.""" - super().__init__(hass) + super().__init__() self.flash_briefings = copy.deepcopy(flash_briefings) template.attach(hass, self.flash_briefings) diff --git a/homeassistant/components/api.py b/homeassistant/components/api.py index ae5e1de7c1b..da8ad9f88ba 100644 --- a/homeassistant/components/api.py +++ b/homeassistant/components/api.py @@ -77,8 +77,10 @@ class APIEventStream(HomeAssistantView): @asyncio.coroutine def get(self, request): """Provide a streaming interface for the event bus.""" + # pylint: disable=no-self-use + hass = request.app['hass'] stop_obj = object() - to_write = asyncio.Queue(loop=self.hass.loop) + to_write = asyncio.Queue(loop=hass.loop) restrict = request.GET.get('restrict') if restrict: @@ -106,7 +108,7 @@ class APIEventStream(HomeAssistantView): response.content_type = 'text/event-stream' yield from response.prepare(request) - unsub_stream = self.hass.bus.async_listen(MATCH_ALL, forward_events) + unsub_stream = hass.bus.async_listen(MATCH_ALL, forward_events) try: _LOGGER.debug('STREAM %s ATTACHED', id(stop_obj)) @@ -117,7 +119,7 @@ class APIEventStream(HomeAssistantView): while True: try: with async_timeout.timeout(STREAM_PING_INTERVAL, - loop=self.hass.loop): + loop=hass.loop): payload = yield from to_write.get() if payload is stop_obj: @@ -145,7 +147,7 @@ class APIConfigView(HomeAssistantView): @ha.callback def get(self, request): """Get current configuration.""" - return self.json(self.hass.config.as_dict()) + return self.json(request.app['hass'].config.as_dict()) class APIDiscoveryView(HomeAssistantView): @@ -158,10 +160,11 @@ class APIDiscoveryView(HomeAssistantView): @ha.callback def get(self, request): """Get discovery info.""" - needs_auth = self.hass.config.api.api_password is not None + hass = request.app['hass'] + needs_auth = hass.config.api.api_password is not None return self.json({ - 'base_url': self.hass.config.api.base_url, - 'location_name': self.hass.config.location_name, + 'base_url': hass.config.api.base_url, + 'location_name': hass.config.location_name, 'requires_api_password': needs_auth, 'version': __version__ }) @@ -176,7 +179,7 @@ class APIStatesView(HomeAssistantView): @ha.callback def get(self, request): """Get current states.""" - return self.json(self.hass.states.async_all()) + return self.json(request.app['hass'].states.async_all()) class APIEntityStateView(HomeAssistantView): @@ -188,7 +191,7 @@ class APIEntityStateView(HomeAssistantView): @ha.callback def get(self, request, entity_id): """Retrieve state of entity.""" - state = self.hass.states.get(entity_id) + state = request.app['hass'].states.get(entity_id) if state: return self.json(state) else: @@ -197,6 +200,7 @@ class APIEntityStateView(HomeAssistantView): @asyncio.coroutine def post(self, request, entity_id): """Update state of entity.""" + hass = request.app['hass'] try: data = yield from request.json() except ValueError: @@ -211,15 +215,14 @@ class APIEntityStateView(HomeAssistantView): attributes = data.get('attributes') force_update = data.get('force_update', False) - is_new_state = self.hass.states.get(entity_id) is None + is_new_state = hass.states.get(entity_id) is None # Write state - self.hass.states.async_set(entity_id, new_state, attributes, - force_update) + hass.states.async_set(entity_id, new_state, attributes, force_update) # Read the state back for our response status_code = HTTP_CREATED if is_new_state else 200 - resp = self.json(self.hass.states.get(entity_id), status_code) + resp = self.json(hass.states.get(entity_id), status_code) resp.headers.add('Location', URL_API_STATES_ENTITY.format(entity_id)) @@ -228,7 +231,7 @@ class APIEntityStateView(HomeAssistantView): @ha.callback def delete(self, request, entity_id): """Remove entity.""" - if self.hass.states.async_remove(entity_id): + if request.app['hass'].states.async_remove(entity_id): return self.json_message('Entity removed') else: return self.json_message('Entity not found', HTTP_NOT_FOUND) @@ -243,7 +246,7 @@ class APIEventListenersView(HomeAssistantView): @ha.callback def get(self, request): """Get event listeners.""" - return self.json(async_events_json(self.hass)) + return self.json(async_events_json(request.app['hass'])) class APIEventView(HomeAssistantView): @@ -271,7 +274,8 @@ class APIEventView(HomeAssistantView): if state: event_data[key] = state - self.hass.bus.async_fire(event_type, event_data, ha.EventOrigin.remote) + request.app['hass'].bus.async_fire(event_type, event_data, + ha.EventOrigin.remote) return self.json_message("Event {} fired.".format(event_type)) @@ -285,7 +289,7 @@ class APIServicesView(HomeAssistantView): @ha.callback def get(self, request): """Get registered services.""" - return self.json(async_services_json(self.hass)) + return self.json(async_services_json(request.app['hass'])) class APIDomainServicesView(HomeAssistantView): @@ -300,12 +304,12 @@ class APIDomainServicesView(HomeAssistantView): Returns a list of changed states. """ + hass = request.app['hass'] body = yield from request.text() data = json.loads(body) if body else None - with AsyncTrackStates(self.hass) as changed_states: - yield from self.hass.services.async_call(domain, service, data, - True) + with AsyncTrackStates(hass) as changed_states: + yield from hass.services.async_call(domain, service, data, True) return self.json(changed_states) @@ -320,6 +324,7 @@ class APIEventForwardingView(HomeAssistantView): @asyncio.coroutine def post(self, request): """Setup an event forwarder.""" + hass = request.app['hass'] try: data = yield from request.json() except ValueError: @@ -340,14 +345,14 @@ class APIEventForwardingView(HomeAssistantView): api = rem.API(host, api_password, port) - valid = yield from self.hass.loop.run_in_executor( + valid = yield from hass.loop.run_in_executor( None, api.validate_api) if not valid: return self.json_message("Unable to validate API.", HTTP_UNPROCESSABLE_ENTITY) if self.event_forwarder is None: - self.event_forwarder = rem.EventForwarder(self.hass) + self.event_forwarder = rem.EventForwarder(hass) self.event_forwarder.async_connect(api) @@ -389,7 +394,7 @@ class APIComponentsView(HomeAssistantView): @ha.callback def get(self, request): """Get current loaded components.""" - return self.json(self.hass.config.components) + return self.json(request.app['hass'].config.components) class APIErrorLogView(HomeAssistantView): @@ -402,7 +407,7 @@ class APIErrorLogView(HomeAssistantView): def get(self, request): """Serve error log.""" resp = yield from self.file( - request, self.hass.config.path(ERROR_LOG_FILENAME)) + request, request.app['hass'].config.path(ERROR_LOG_FILENAME)) return resp @@ -417,7 +422,7 @@ class APITemplateView(HomeAssistantView): """Render a template.""" try: data = yield from request.json() - tpl = template.Template(data['template'], self.hass) + tpl = template.Template(data['template'], request.app['hass']) return tpl.async_render(data.get('variables')) except (ValueError, TemplateError) as ex: return self.json_message('Error rendering template: {}'.format(ex), diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index 6724598419f..427d4535ef6 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -13,7 +13,7 @@ from aiohttp import web from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa -from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED DOMAIN = 'camera' DEPENDENCIES = ['http'] @@ -33,8 +33,8 @@ def async_setup(hass, config): component = EntityComponent( logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL) - hass.http.register_view(CameraImageView(hass, component.entities)) - hass.http.register_view(CameraMjpegStream(hass, component.entities)) + hass.http.register_view(CameraImageView(component.entities)) + hass.http.register_view(CameraMjpegStream(component.entities)) yield from component.async_setup(config) return True @@ -165,9 +165,8 @@ class CameraView(HomeAssistantView): requires_auth = False - def __init__(self, hass, entities): + def __init__(self, entities): """Initialize a basic camera view.""" - super().__init__(hass) self.entities = entities @asyncio.coroutine @@ -178,7 +177,7 @@ class CameraView(HomeAssistantView): if camera is None: return web.Response(status=404) - authenticated = (request.authenticated or + authenticated = (request[KEY_AUTHENTICATED] or request.GET.get('token') == camera.access_token) if not authenticated: diff --git a/homeassistant/components/device_tracker/gpslogger.py b/homeassistant/components/device_tracker/gpslogger.py index 462ae16300c..2e897ccb10c 100644 --- a/homeassistant/components/device_tracker/gpslogger.py +++ b/homeassistant/components/device_tracker/gpslogger.py @@ -21,7 +21,7 @@ DEPENDENCIES = ['http'] def setup_scanner(hass, config, see): """Setup an endpoint for the GPSLogger application.""" - hass.http.register_view(GPSLoggerView(hass, see)) + hass.http.register_view(GPSLoggerView(see)) return True @@ -32,20 +32,18 @@ class GPSLoggerView(HomeAssistantView): url = '/api/gpslogger' name = 'api:gpslogger' - def __init__(self, hass, see): + def __init__(self, see): """Initialize GPSLogger url endpoints.""" - super().__init__(hass) self.see = see @asyncio.coroutine def get(self, request): """A GPSLogger message received as GET.""" - res = yield from self._handle(request.GET) + res = yield from self._handle(request.app['hass'], request.GET) return res @asyncio.coroutine - # pylint: disable=too-many-return-statements - def _handle(self, data): + def _handle(self, hass, data): """Handle gpslogger request.""" if 'latitude' not in data or 'longitude' not in data: return ('Latitude and longitude not specified.', @@ -66,7 +64,7 @@ class GPSLoggerView(HomeAssistantView): if 'battery' in data: battery = float(data['battery']) - yield from self.hass.loop.run_in_executor( + yield from hass.loop.run_in_executor( None, partial(self.see, dev_id=device, gps=gps_location, battery=battery, gps_accuracy=accuracy)) diff --git a/homeassistant/components/device_tracker/locative.py b/homeassistant/components/device_tracker/locative.py index e6bd74e57c9..10641b3a921 100644 --- a/homeassistant/components/device_tracker/locative.py +++ b/homeassistant/components/device_tracker/locative.py @@ -23,7 +23,7 @@ DEPENDENCIES = ['http'] def setup_scanner(hass, config, see): """Setup an endpoint for the Locative application.""" - hass.http.register_view(LocativeView(hass, see)) + hass.http.register_view(LocativeView(see)) return True @@ -34,27 +34,26 @@ class LocativeView(HomeAssistantView): url = '/api/locative' name = 'api:locative' - def __init__(self, hass, see): + def __init__(self, see): """Initialize Locative url endpoints.""" - super().__init__(hass) self.see = see @asyncio.coroutine def get(self, request): """Locative message received as GET.""" - res = yield from self._handle(request.GET) + res = yield from self._handle(request.app['hass'], request.GET) return res @asyncio.coroutine def post(self, request): """Locative message received.""" data = yield from request.post() - res = yield from self._handle(data) + res = yield from self._handle(request.app['hass'], data) return res @asyncio.coroutine # pylint: disable=too-many-return-statements - def _handle(self, data): + def _handle(self, hass, data): """Handle locative request.""" if 'latitude' not in data or 'longitude' not in data: return ('Latitude and longitude not specified.', @@ -81,19 +80,19 @@ class LocativeView(HomeAssistantView): gps_location = (data[ATTR_LATITUDE], data[ATTR_LONGITUDE]) if direction == 'enter': - yield from self.hass.loop.run_in_executor( + yield from hass.loop.run_in_executor( None, partial(self.see, dev_id=device, location_name=location_name, gps=gps_location)) return 'Setting location to {}'.format(location_name) elif direction == 'exit': - current_state = self.hass.states.get( + current_state = hass.states.get( '{}.{}'.format(DOMAIN, device)) if current_state is None or current_state.state == location_name: location_name = STATE_NOT_HOME - yield from self.hass.loop.run_in_executor( + yield from hass.loop.run_in_executor( None, partial(self.see, dev_id=device, location_name=location_name, gps=gps_location)) diff --git a/homeassistant/components/emulated_hue.py b/homeassistant/components/emulated_hue.py index afb5c63918c..dcb6bcb64b2 100644 --- a/homeassistant/components/emulated_hue.py +++ b/homeassistant/components/emulated_hue.py @@ -78,14 +78,13 @@ def setup(hass, yaml_config): cors_origins=None, use_x_forwarded_for=False, trusted_networks=None, - ip_bans=None, login_threshold=0, is_ban_enabled=False ) - server.register_view(DescriptionXmlView(hass, config)) - server.register_view(HueUsernameView(hass)) - server.register_view(HueLightsView(hass, config)) + server.register_view(DescriptionXmlView(config)) + server.register_view(HueUsernameView) + server.register_view(HueLightsView(config)) upnp_listener = UPNPResponderThread( config.host_ip_addr, config.listen_port) @@ -157,9 +156,8 @@ class DescriptionXmlView(HomeAssistantView): name = 'description:xml' requires_auth = False - def __init__(self, hass, config): + def __init__(self, config): """Initialize the instance of the view.""" - super().__init__(hass) self.config = config @core.callback @@ -201,10 +199,6 @@ class HueUsernameView(HomeAssistantView): extra_urls = ['/api/'] requires_auth = False - def __init__(self, hass): - """Initialize the instance of the view.""" - super().__init__(hass) - @asyncio.coroutine def post(self, request): """Handle a POST request.""" @@ -229,30 +223,33 @@ class HueLightsView(HomeAssistantView): '/api/{username}/lights/{entity_id}/state'] requires_auth = False - def __init__(self, hass, config): + def __init__(self, config): """Initialize the instance of the view.""" - super().__init__(hass) self.config = config self.cached_states = {} @core.callback def get(self, request, username, entity_id=None): """Handle a GET request.""" + hass = request.app['hass'] + if entity_id is None: - return self.async_get_lights_list() + return self.async_get_lights_list(hass) if not request.path.endswith('state'): - return self.async_get_light_state(entity_id) + return self.async_get_light_state(hass, entity_id) return web.Response(text="Method not allowed", status=405) @asyncio.coroutine def put(self, request, username, entity_id=None): """Handle a PUT request.""" + hass = request.app['hass'] + if not request.path.endswith('state'): return web.Response(text="Method not allowed", status=405) - if entity_id and self.hass.states.get(entity_id) is None: + if entity_id and hass.states.get(entity_id) is None: return self.json_message('Entity not found', HTTP_NOT_FOUND) try: @@ -260,24 +257,25 @@ class HueLightsView(HomeAssistantView): except ValueError: return self.json_message('Invalid JSON', HTTP_BAD_REQUEST) - result = yield from self.async_put_light_state(json_data, entity_id) + result = yield from self.async_put_light_state(hass, json_data, + entity_id) return result @core.callback - def async_get_lights_list(self): + def async_get_lights_list(self, hass): """Process a request to get the list of available lights.""" json_response = {} - for entity in self.hass.states.async_all(): + for entity in hass.states.async_all(): if self.is_entity_exposed(entity): json_response[entity.entity_id] = entity_to_json(entity) return self.json(json_response) @core.callback - def async_get_light_state(self, entity_id): + def async_get_light_state(self, hass, entity_id): """Process a request to get the state of an individual light.""" - entity = self.hass.states.get(entity_id) + entity = hass.states.get(entity_id) if entity is None or not self.is_entity_exposed(entity): return web.Response(text="Entity not found", status=404) @@ -295,12 +293,12 @@ class HueLightsView(HomeAssistantView): return self.json(json_response) @asyncio.coroutine - def async_put_light_state(self, request_json, entity_id): + def async_put_light_state(self, hass, request_json, entity_id): """Process a request to set the state of an individual light.""" config = self.config # Retrieve the entity from the state machine - entity = self.hass.states.get(entity_id) + entity = hass.states.get(entity_id) if entity is None: return web.Response(text="Entity not found", status=404) @@ -345,8 +343,8 @@ class HueLightsView(HomeAssistantView): self.cached_states[entity_id] = (result, brightness) # Perform the requested action - yield from self.hass.services.async_call(core.DOMAIN, service, data, - blocking=True) + yield from hass.services.async_call(core.DOMAIN, service, data, + blocking=True) json_response = \ [create_hue_success_response(entity_id, HUE_API_STATE_ON, result)] diff --git a/homeassistant/components/foursquare.py b/homeassistant/components/foursquare.py index bb4c66ad1f9..2afa808b502 100644 --- a/homeassistant/components/foursquare.py +++ b/homeassistant/components/foursquare.py @@ -75,8 +75,7 @@ def setup(hass, config): descriptions[DOMAIN][SERVICE_CHECKIN], schema=CHECKIN_SERVICE_SCHEMA) - hass.http.register_view(FoursquarePushReceiver( - hass, config[CONF_PUSH_SECRET])) + hass.http.register_view(FoursquarePushReceiver(config[CONF_PUSH_SECRET])) return True @@ -88,9 +87,8 @@ class FoursquarePushReceiver(HomeAssistantView): url = "/api/foursquare" name = "foursquare" - def __init__(self, hass, push_secret): + def __init__(self, push_secret): """Initialize the OAuth callback view.""" - super().__init__(hass) self.push_secret = push_secret @asyncio.coroutine @@ -110,4 +108,4 @@ class FoursquarePushReceiver(HomeAssistantView): "push secret: %s", secret) return self.json_message('Incorrect secret', HTTP_BAD_REQUEST) - self.hass.bus.async_fire(EVENT_PUSH, data) + request.app['hass'].bus.async_fire(EVENT_PUSH, data) diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index 6fde1ae388a..e19e5f6edec 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -11,6 +11,8 @@ from homeassistant.core import callback from homeassistant.const import HTTP_NOT_FOUND from homeassistant.components import api, group from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http.auth import is_trusted_ip +from homeassistant.components.http.const import KEY_DEVELOPMENT from .version import FINGERPRINTS DOMAIN = 'frontend' @@ -155,7 +157,7 @@ def setup(hass, config): if os.path.isdir(local): hass.http.register_static_path("/local", local) - index_view = hass.data[DATA_INDEX_VIEW] = IndexView(hass) + index_view = hass.data[DATA_INDEX_VIEW] = IndexView() hass.http.register_view(index_view) # Components have registered panels before frontend got setup. @@ -185,12 +187,14 @@ class BootstrapView(HomeAssistantView): @callback def get(self, request): """Return all data needed to bootstrap Home Assistant.""" + hass = request.app['hass'] + return self.json({ - 'config': self.hass.config.as_dict(), - 'states': self.hass.states.async_all(), - 'events': api.async_events_json(self.hass), - 'services': api.async_services_json(self.hass), - 'panels': self.hass.data[DATA_PANELS], + 'config': hass.config.as_dict(), + 'states': hass.states.async_all(), + 'events': api.async_events_json(hass), + 'services': api.async_services_json(hass), + 'panels': hass.data[DATA_PANELS], }) @@ -202,10 +206,8 @@ class IndexView(HomeAssistantView): requires_auth = False extra_urls = ['/states', '/states/{entity_id}'] - def __init__(self, hass): + def __init__(self): """Initialize the frontend view.""" - super().__init__(hass) - from jinja2 import FileSystemLoader, Environment self.templates = Environment( @@ -217,14 +219,16 @@ class IndexView(HomeAssistantView): @asyncio.coroutine def get(self, request, entity_id=None): """Serve the index view.""" + hass = request.app['hass'] + if entity_id is not None: - state = self.hass.states.get(entity_id) + state = hass.states.get(entity_id) if (not state or state.domain != 'group' or not state.attributes.get(group.ATTR_VIEW)): return self.json_message('Entity not found', HTTP_NOT_FOUND) - if self.hass.http.development: + if request.app[KEY_DEVELOPMENT]: core_url = '/static/home-assistant-polymer/build/core.js' ui_url = '/static/home-assistant-polymer/src/home-assistant.html' else: @@ -241,19 +245,18 @@ class IndexView(HomeAssistantView): if panel == 'states': panel_url = '' else: - panel_url = self.hass.data[DATA_PANELS][panel]['url'] + panel_url = hass.data[DATA_PANELS][panel]['url'] no_auth = 'true' - if self.hass.config.api.api_password: + if hass.config.api.api_password: # require password if set no_auth = 'false' - if self.hass.http.is_trusted_ip( - self.hass.http.get_real_ip(request)): + if is_trusted_ip(request): # bypass for trusted networks no_auth = 'true' icons_url = '/static/mdi-{}.html'.format(FINGERPRINTS['mdi.html']) - template = yield from self.hass.loop.run_in_executor( + template = yield from hass.loop.run_in_executor( None, self.templates.get_template, 'index.html') # pylint is wrong @@ -262,7 +265,7 @@ class IndexView(HomeAssistantView): resp = template.render( core_url=core_url, ui_url=ui_url, no_auth=no_auth, icons_url=icons_url, icons=FINGERPRINTS['mdi.html'], - panel_url=panel_url, panels=self.hass.data[DATA_PANELS]) + panel_url=panel_url, panels=hass.data[DATA_PANELS]) return web.Response(text=resp, content_type='text/html') diff --git a/homeassistant/components/history.py b/homeassistant/components/history.py index c3dd0bd3f5a..eee0570c9bc 100644 --- a/homeassistant/components/history.py +++ b/homeassistant/components/history.py @@ -184,8 +184,8 @@ def setup(hass, config): filters.included_entities = include[CONF_ENTITIES] filters.included_domains = include[CONF_DOMAINS] - hass.http.register_view(Last5StatesView(hass)) - hass.http.register_view(HistoryPeriodView(hass, filters)) + hass.http.register_view(Last5StatesView) + hass.http.register_view(HistoryPeriodView(filters)) register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box') return True @@ -197,14 +197,10 @@ class Last5StatesView(HomeAssistantView): url = '/api/history/entity/{entity_id}/recent_states' name = 'api:history:entity-recent-states' - def __init__(self, hass): - """Initilalize the history last 5 states view.""" - super().__init__(hass) - @asyncio.coroutine def get(self, request, entity_id): """Retrieve last 5 states of entity.""" - result = yield from self.hass.loop.run_in_executor( + result = yield from request.app['hass'].loop.run_in_executor( None, last_5_states, entity_id) return self.json(result) @@ -216,9 +212,8 @@ class HistoryPeriodView(HomeAssistantView): name = 'api:history:view-period' extra_urls = ['/api/history/period/{datetime}'] - def __init__(self, hass, filters): + def __init__(self, filters): """Initilalize the history period view.""" - super().__init__(hass) self.filters = filters @asyncio.coroutine @@ -240,7 +235,7 @@ class HistoryPeriodView(HomeAssistantView): end_time = start_time + one_day entity_id = request.GET.get('filter_entity_id') - result = yield from self.hass.loop.run_in_executor( + result = yield from request.app['hass'].loop.run_in_executor( None, get_significant_states, start_time, end_time, entity_id, self.filters) diff --git a/homeassistant/components/http.py b/homeassistant/components/http.py deleted file mode 100644 index 054d8050599..00000000000 --- a/homeassistant/components/http.py +++ /dev/null @@ -1,641 +0,0 @@ -""" -This module provides WSGI application to serve the Home Assistant API. - -For more details about this component, please refer to the documentation at -https://home-assistant.io/components/http/ -""" -import asyncio -import json -import logging -import mimetypes -import ssl -from datetime import datetime -from ipaddress import ip_address, ip_network -from pathlib import Path - -import hmac -import os -import re -import voluptuous as vol -from aiohttp import web, hdrs -from aiohttp.file_sender import FileSender -from aiohttp.web_exceptions import ( - HTTPUnauthorized, HTTPMovedPermanently, HTTPNotModified, HTTPForbidden) -from aiohttp.web_urldispatcher import StaticResource - -import homeassistant.helpers.config_validation as cv -import homeassistant.remote as rem -from homeassistant import util -from homeassistant.components import persistent_notification -from homeassistant.config import load_yaml_config_file -from homeassistant.const import ( - SERVER_PORT, HTTP_HEADER_HA_AUTH, # HTTP_HEADER_CACHE_CONTROL, - CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS, EVENT_HOMEASSISTANT_STOP, - EVENT_HOMEASSISTANT_START, HTTP_HEADER_X_FORWARDED_FOR) -from homeassistant.core import is_callback -from homeassistant.exceptions import HomeAssistantError -from homeassistant.util.yaml import dump - -DOMAIN = 'http' -REQUIREMENTS = ('aiohttp_cors==0.5.0',) - -CONF_API_PASSWORD = 'api_password' -CONF_SERVER_HOST = 'server_host' -CONF_SERVER_PORT = 'server_port' -CONF_DEVELOPMENT = 'development' -CONF_SSL_CERTIFICATE = 'ssl_certificate' -CONF_SSL_KEY = 'ssl_key' -CONF_CORS_ORIGINS = 'cors_allowed_origins' -CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for' -CONF_TRUSTED_NETWORKS = 'trusted_networks' -CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold' -CONF_IP_BAN_ENABLED = 'ip_ban_enabled' - -DATA_API_PASSWORD = 'api_password' -NOTIFICATION_ID_LOGIN = 'http-login' -NOTIFICATION_ID_BAN = 'ip-ban' - -IP_BANS = 'ip_bans.yaml' -ATTR_BANNED_AT = "banned_at" - - -# TLS configuation follows the best-practice guidelines specified here: -# https://wiki.mozilla.org/Security/Server_Side_TLS -# Intermediate guidelines are followed. -SSL_VERSION = ssl.PROTOCOL_SSLv23 -SSL_OPTS = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 -if hasattr(ssl, 'OP_NO_COMPRESSION'): - SSL_OPTS |= ssl.OP_NO_COMPRESSION -CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" \ - "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" \ - "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" \ - "DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" \ - "ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" \ - "ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:" \ - "ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:" \ - "ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:" \ - "DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:" \ - "DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:" \ - "ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:" \ - "AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:" \ - "AES256-SHA:DES-CBC3-SHA:!DSS" - -_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE) - -_LOGGER = logging.getLogger(__name__) - -CONFIG_SCHEMA = vol.Schema({ - DOMAIN: vol.Schema({ - vol.Optional(CONF_API_PASSWORD): cv.string, - vol.Optional(CONF_SERVER_HOST): cv.string, - vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT): - vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)), - vol.Optional(CONF_DEVELOPMENT): cv.string, - vol.Optional(CONF_SSL_CERTIFICATE): cv.isfile, - vol.Optional(CONF_SSL_KEY): cv.isfile, - vol.Optional(CONF_CORS_ORIGINS): vol.All(cv.ensure_list, [cv.string]), - vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean, - vol.Optional(CONF_TRUSTED_NETWORKS): - vol.All(cv.ensure_list, [ip_network]), - vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD): cv.positive_int, - vol.Optional(CONF_IP_BAN_ENABLED): cv.boolean - }), -}, extra=vol.ALLOW_EXTRA) - - -# TEMP TO GET TESTS TO RUN -def request_class(): - """.""" - raise Exception('not implemented') - - -class HideSensitiveFilter(logging.Filter): - """Filter API password calls.""" - - def __init__(self, hass): - """Initialize sensitive data filter.""" - super().__init__() - self.hass = hass - - def filter(self, record): - """Hide sensitive data in messages.""" - if self.hass.http.api_password is None: - return True - - record.msg = record.msg.replace(self.hass.http.api_password, '*******') - - return True - - -def setup(hass, config): - """Set up the HTTP API and debug interface.""" - logging.getLogger('aiohttp.access').addFilter(HideSensitiveFilter(hass)) - - conf = config.get(DOMAIN, {}) - - api_password = util.convert(conf.get(CONF_API_PASSWORD), str) - server_host = conf.get(CONF_SERVER_HOST, '0.0.0.0') - server_port = conf.get(CONF_SERVER_PORT, SERVER_PORT) - development = str(conf.get(CONF_DEVELOPMENT, '')) == '1' - ssl_certificate = conf.get(CONF_SSL_CERTIFICATE) - ssl_key = conf.get(CONF_SSL_KEY) - cors_origins = conf.get(CONF_CORS_ORIGINS, []) - use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False) - trusted_networks = [ - ip_network(trusted_network) - for trusted_network in conf.get(CONF_TRUSTED_NETWORKS, [])] - is_ban_enabled = bool(conf.get(CONF_IP_BAN_ENABLED, False)) - login_threshold = int(conf.get(CONF_LOGIN_ATTEMPTS_THRESHOLD, -1)) - ip_bans = load_ip_bans_config(hass.config.path(IP_BANS)) - - server = HomeAssistantWSGI( - hass, - development=development, - server_host=server_host, - server_port=server_port, - api_password=api_password, - ssl_certificate=ssl_certificate, - ssl_key=ssl_key, - cors_origins=cors_origins, - use_x_forwarded_for=use_x_forwarded_for, - trusted_networks=trusted_networks, - ip_bans=ip_bans, - login_threshold=login_threshold, - is_ban_enabled=is_ban_enabled - ) - - @asyncio.coroutine - def stop_server(event): - """Callback to stop the server.""" - yield from server.stop() - - @asyncio.coroutine - def start_server(event): - """Callback to start the server.""" - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server) - yield from server.start() - - hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_server) - - hass.http = server - hass.config.api = rem.API(server_host if server_host != '0.0.0.0' - else util.get_local_ip(), - api_password, server_port, - ssl_certificate is not None) - - return True - - -class GzipFileSender(FileSender): - """FileSender class capable of sending gzip version if available.""" - - # pylint: disable=invalid-name - - development = False - - @asyncio.coroutine - def send(self, request, filepath): - """Send filepath to client using request.""" - gzip = False - if 'gzip' in request.headers[hdrs.ACCEPT_ENCODING]: - gzip_path = filepath.with_name(filepath.name + '.gz') - - if gzip_path.is_file(): - filepath = gzip_path - gzip = True - - st = filepath.stat() - - modsince = request.if_modified_since - if modsince is not None and st.st_mtime <= modsince.timestamp(): - raise HTTPNotModified() - - ct, encoding = mimetypes.guess_type(str(filepath)) - if not ct: - ct = 'application/octet-stream' - - resp = self._response_factory() - resp.content_type = ct - if encoding: - resp.headers[hdrs.CONTENT_ENCODING] = encoding - if gzip: - resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING - resp.last_modified = st.st_mtime - - # CACHE HACK - if not self.development: - cache_time = 31 * 86400 # = 1 month - resp.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format( - cache_time) - - file_size = st.st_size - - resp.content_length = file_size - with filepath.open('rb') as f: - yield from self._sendfile(request, resp, f, file_size) - - return resp - - -_GZIP_FILE_SENDER = GzipFileSender() - - -@asyncio.coroutine -def staticresource_enhancer(app, handler): - """Enhance StaticResourceHandler. - - Adds gzip encoding and fingerprinting matching. - """ - inst = getattr(handler, '__self__', None) - if not isinstance(inst, StaticResource): - return handler - - # pylint: disable=protected-access - inst._file_sender = _GZIP_FILE_SENDER - - @asyncio.coroutine - def middleware_handler(request): - """Strip out fingerprints from resource names.""" - fingerprinted = _FINGERPRINT.match(request.match_info['filename']) - - if fingerprinted: - request.match_info['filename'] = \ - '{}.{}'.format(*fingerprinted.groups()) - - resp = yield from handler(request) - return resp - - return middleware_handler - - -class HomeAssistantWSGI(object): - """WSGI server for Home Assistant.""" - - def __init__(self, hass, development, api_password, ssl_certificate, - ssl_key, server_host, server_port, cors_origins, - use_x_forwarded_for, trusted_networks, - ip_bans, login_threshold, is_ban_enabled): - """Initialize the WSGI Home Assistant server.""" - import aiohttp_cors - - self.app = web.Application(middlewares=[staticresource_enhancer], - loop=hass.loop) - self.hass = hass - self.development = development - self.api_password = api_password - self.ssl_certificate = ssl_certificate - self.ssl_key = ssl_key - self.server_host = server_host - self.server_port = server_port - self.use_x_forwarded_for = use_x_forwarded_for - self.trusted_networks = trusted_networks \ - if trusted_networks is not None else [] - self.event_forwarder = None - self._handler = None - self.server = None - self.login_threshold = login_threshold - self.ip_bans = ip_bans if ip_bans is not None else [] - self.failed_login_attempts = {} - self.is_ban_enabled = is_ban_enabled - - if cors_origins: - self.cors = aiohttp_cors.setup(self.app, defaults={ - host: aiohttp_cors.ResourceOptions( - allow_headers=ALLOWED_CORS_HEADERS, - allow_methods='*', - ) for host in cors_origins - }) - else: - self.cors = None - - # CACHE HACK - _GZIP_FILE_SENDER.development = development - - def register_view(self, view): - """Register a view with the WSGI server. - - The view argument must be a class that inherits from HomeAssistantView. - It is optional to instantiate it before registering; this method will - handle it either way. - """ - if isinstance(view, type): - # Instantiate the view, if needed - view = view(self.hass) - - view.register(self.app.router) - - def register_redirect(self, url, redirect_to): - """Register a redirect with the server. - - If given this must be either a string or callable. In case of a - callable it's called with the url adapter that triggered the match and - the values of the URL as keyword arguments and has to return the target - for the redirect, otherwise it has to be a string with placeholders in - rule syntax. - """ - def redirect(request): - """Redirect to location.""" - raise HTTPMovedPermanently(redirect_to) - - self.app.router.add_route('GET', url, redirect) - - def register_static_path(self, url_root, path, cache_length=31): - """Register a folder to serve as a static path. - - Specify optional cache length of asset in days. - """ - if os.path.isdir(path): - self.app.router.add_static(url_root, path) - return - - filepath = Path(path) - - @asyncio.coroutine - def serve_file(request): - """Redirect to location.""" - res = yield from _GZIP_FILE_SENDER.send(request, filepath) - return res - - # aiohttp supports regex matching for variables. Using that as temp - # to work around cache busting MD5. - # Turns something like /static/dev-panel.html into - # /static/{filename:dev-panel(-[a-z0-9]{32}|)\.html} - base, ext = url_root.rsplit('.', 1) - base, file = base.rsplit('/', 1) - regex = r"{}(-[a-z0-9]{{32}}|)\.{}".format(file, ext) - url_pattern = "{}/{{filename:{}}}".format(base, regex) - - self.app.router.add_route('GET', url_pattern, serve_file) - - @asyncio.coroutine - def start(self): - """Start the wsgi server.""" - if self.cors is not None: - for route in list(self.app.router.routes()): - self.cors.add(route) - - if self.ssl_certificate: - context = ssl.SSLContext(SSL_VERSION) - context.options |= SSL_OPTS - context.set_ciphers(CIPHERS) - context.load_cert_chain(self.ssl_certificate, self.ssl_key) - else: - context = None - - self._handler = self.app.make_handler() - self.server = yield from self.hass.loop.create_server( - self._handler, self.server_host, self.server_port, ssl=context) - - @asyncio.coroutine - def stop(self): - """Stop the wsgi server.""" - self.server.close() - yield from self.server.wait_closed() - yield from self.app.shutdown() - yield from self._handler.finish_connections(60.0) - yield from self.app.cleanup() - - def get_real_ip(self, request): - """Return the clients correct ip address, even in proxied setups.""" - if self.use_x_forwarded_for \ - and HTTP_HEADER_X_FORWARDED_FOR in request.headers: - return request.headers.get( - HTTP_HEADER_X_FORWARDED_FOR).split(',')[0] - else: - peername = request.transport.get_extra_info('peername') - return peername[0] if peername is not None else None - - def is_trusted_ip(self, remote_addr): - """Match an ip address against trusted CIDR networks.""" - return any(ip_address(remote_addr) in trusted_network - for trusted_network in self.hass.http.trusted_networks) - - def wrong_login_attempt(self, remote_addr): - """Registering wrong login attempt.""" - if not self.is_ban_enabled or self.login_threshold < 1: - return - - if remote_addr in self.failed_login_attempts: - self.failed_login_attempts[remote_addr] += 1 - else: - self.failed_login_attempts[remote_addr] = 1 - - if self.failed_login_attempts[remote_addr] > self.login_threshold: - new_ban = IpBan(remote_addr) - self.ip_bans.append(new_ban) - update_ip_bans_config(self.hass.config.path(IP_BANS), new_ban) - _LOGGER.warning('Banned IP %s for too many login attempts', - remote_addr) - persistent_notification.async_create( - self.hass, - 'Too many login attempts from {}'.format(remote_addr), - 'Banning IP address', NOTIFICATION_ID_BAN) - - def is_banned_ip(self, remote_addr): - """Check if IP address is in a ban list.""" - if not self.is_ban_enabled: - return False - - ip_address_ = ip_address(remote_addr) - for ip_ban in self.ip_bans: - if ip_ban.ip_address == ip_address_: - return True - - return False - - -class HomeAssistantView(object): - """Base view for all views.""" - - url = None - extra_urls = [] - requires_auth = True # Views inheriting from this class can override this - - def __init__(self, hass): - """Initilalize the base view.""" - if not hasattr(self, 'url'): - class_name = self.__class__.__name__ - raise AttributeError( - '{0} missing required attribute "url"'.format(class_name) - ) - - if not hasattr(self, 'name'): - class_name = self.__class__.__name__ - raise AttributeError( - '{0} missing required attribute "name"'.format(class_name) - ) - - self.hass = hass - - # pylint: disable=no-self-use - def json(self, result, status_code=200): - """Return a JSON response.""" - msg = json.dumps( - result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8') - return web.Response( - body=msg, content_type=CONTENT_TYPE_JSON, status=status_code) - - def json_message(self, error, status_code=200): - """Return a JSON message response.""" - return self.json({'message': error}, status_code) - - @asyncio.coroutine - # pylint: disable=no-self-use - def file(self, request, fil): - """Return a file.""" - assert isinstance(fil, str), 'only string paths allowed' - response = yield from _GZIP_FILE_SENDER.send(request, Path(fil)) - return response - - def register(self, router): - """Register the view with a router.""" - assert self.url is not None, 'No url set for view' - urls = [self.url] + self.extra_urls - - for method in ('get', 'post', 'delete', 'put'): - handler = getattr(self, method, None) - - if not handler: - continue - - handler = request_handler_factory(self, handler) - - for url in urls: - router.add_route(method, url, handler) - - # aiohttp_cors does not work with class based views - # self.app.router.add_route('*', self.url, self, name=self.name) - - # for url in self.extra_urls: - # self.app.router.add_route('*', url, self) - - -def request_handler_factory(view, handler): - """Factory to wrap our handler classes. - - Eventually authentication should be managed by middleware. - """ - @asyncio.coroutine - def handle(request): - """Handle incoming request.""" - if not view.hass.is_running: - return web.Response(status=503) - - remote_addr = view.hass.http.get_real_ip(request) - - if view.hass.http.is_banned_ip(remote_addr): - raise HTTPForbidden() - - # Auth code verbose on purpose - authenticated = False - - if view.hass.http.api_password is None: - authenticated = True - - elif view.hass.http.is_trusted_ip(remote_addr): - authenticated = True - - elif hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''), - view.hass.http.api_password): - # A valid auth header has been set - authenticated = True - - elif hmac.compare_digest(request.GET.get(DATA_API_PASSWORD, ''), - view.hass.http.api_password): - authenticated = True - - if view.requires_auth and not authenticated: - view.hass.http.wrong_login_attempt(remote_addr) - _LOGGER.warning('Login attempt or request with an invalid ' - 'password from %s', remote_addr) - persistent_notification.async_create( - view.hass, - 'Invalid password used from {}'.format(remote_addr), - 'Login attempt failed', NOTIFICATION_ID_LOGIN) - raise HTTPUnauthorized() - - request.authenticated = authenticated - - _LOGGER.info('Serving %s to %s (auth: %s)', - request.path, remote_addr, authenticated) - - assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \ - "Handler should be a coroutine or a callback." - - result = handler(request, **request.match_info) - - if asyncio.iscoroutine(result): - result = yield from result - - if isinstance(result, web.StreamResponse): - # The method handler returned a ready-made Response, how nice of it - return result - - status_code = 200 - - if isinstance(result, tuple): - result, status_code = result - - if isinstance(result, str): - result = result.encode('utf-8') - elif result is None: - result = b'' - elif not isinstance(result, bytes): - assert False, ('Result should be None, string, bytes or Response. ' - 'Got: {}').format(result) - - return web.Response(body=result, status=status_code) - - return handle - - -class IpBan(object): - """Represents banned IP address.""" - - def __init__(self, ip_ban: str, banned_at: datetime=None) -> None: - """Initializing Ip Ban object.""" - self.ip_address = ip_address(ip_ban) - self.banned_at = banned_at - if self.banned_at is None: - self.banned_at = datetime.utcnow() - - -def load_ip_bans_config(path: str): - """Loading list of banned IPs from config file.""" - ip_list = [] - ip_schema = vol.Schema({ - vol.Optional('banned_at'): vol.Any(None, cv.datetime) - }) - - try: - try: - list_ = load_yaml_config_file(path) - except HomeAssistantError as err: - _LOGGER.error('Unable to load %s: %s', path, str(err)) - return [] - - for ip_ban, ip_info in list_.items(): - try: - ip_info = ip_schema(ip_info) - ip_info['ip_ban'] = ip_address(ip_ban) - ip_list.append(IpBan(**ip_info)) - except vol.Invalid: - _LOGGER.exception('Failed to load IP ban') - continue - - except(HomeAssistantError, FileNotFoundError): - # No need to report error, file absence means - # that no bans were applied. - return [] - - return ip_list - - -def update_ip_bans_config(path: str, ip_ban: IpBan): - """Update config file with new banned IP address.""" - with open(path, 'a') as out: - ip_ = {str(ip_ban.ip_address): { - ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S") - }} - out.write('\n') - out.write(dump(ip_)) diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py new file mode 100644 index 00000000000..0404f4a0df6 --- /dev/null +++ b/homeassistant/components/http/__init__.py @@ -0,0 +1,407 @@ +""" +This module provides WSGI application to serve the Home Assistant API. + +For more details about this component, please refer to the documentation at +https://home-assistant.io/components/http/ +""" +import asyncio +import json +import logging +import ssl +from ipaddress import ip_network +from pathlib import Path + +import os +import voluptuous as vol +from aiohttp import web +from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently + +import homeassistant.helpers.config_validation as cv +import homeassistant.remote as rem +from homeassistant.util import get_local_ip +from homeassistant.components import persistent_notification +from homeassistant.const import ( + SERVER_PORT, CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS, + EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START) +from homeassistant.core import is_callback +from homeassistant.util.logging import HideSensitiveDataFilter + +from .auth import auth_middleware +from .ban import ban_middleware, process_wrong_login +from .const import ( + KEY_USE_X_FORWARDED_FOR, KEY_TRUSTED_NETWORKS, + KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, + KEY_DEVELOPMENT, KEY_AUTHENTICATED) +from .static import GZIP_FILE_SENDER, staticresource_middleware +from .util import get_real_ip + +DOMAIN = 'http' +REQUIREMENTS = ('aiohttp_cors==0.5.0',) + +CONF_API_PASSWORD = 'api_password' +CONF_SERVER_HOST = 'server_host' +CONF_SERVER_PORT = 'server_port' +CONF_DEVELOPMENT = 'development' +CONF_SSL_CERTIFICATE = 'ssl_certificate' +CONF_SSL_KEY = 'ssl_key' +CONF_CORS_ORIGINS = 'cors_allowed_origins' +CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for' +CONF_TRUSTED_NETWORKS = 'trusted_networks' +CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold' +CONF_IP_BAN_ENABLED = 'ip_ban_enabled' + +NOTIFICATION_ID_LOGIN = 'http-login' + +# TLS configuation follows the best-practice guidelines specified here: +# https://wiki.mozilla.org/Security/Server_Side_TLS +# Intermediate guidelines are followed. +SSL_VERSION = ssl.PROTOCOL_SSLv23 +SSL_OPTS = ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 +if hasattr(ssl, 'OP_NO_COMPRESSION'): + SSL_OPTS |= ssl.OP_NO_COMPRESSION +CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:" \ + "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:" \ + "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:" \ + "DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:" \ + "ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" \ + "ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:" \ + "ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:" \ + "ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:" \ + "DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:" \ + "DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:" \ + "ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:" \ + "AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:" \ + "AES256-SHA:DES-CBC3-SHA:!DSS" + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_SERVER_HOST = '0.0.0.0' +DEFAULT_DEVELOPMENT = '0' +DEFAULT_LOGIN_ATTEMPT_THRESHOLD = -1 + +HTTP_SCHEMA = vol.Schema({ + vol.Optional(CONF_API_PASSWORD, default=None): cv.string, + vol.Optional(CONF_SERVER_HOST, default=DEFAULT_SERVER_HOST): cv.string, + vol.Optional(CONF_SERVER_PORT, default=SERVER_PORT): + vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)), + vol.Optional(CONF_DEVELOPMENT, default=DEFAULT_DEVELOPMENT): cv.string, + vol.Optional(CONF_SSL_CERTIFICATE, default=None): cv.isfile, + vol.Optional(CONF_SSL_KEY, default=None): cv.isfile, + vol.Optional(CONF_CORS_ORIGINS, default=[]): vol.All(cv.ensure_list, + [cv.string]), + vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean, + vol.Optional(CONF_TRUSTED_NETWORKS, default=[]): + vol.All(cv.ensure_list, [ip_network]), + vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD, + default=DEFAULT_LOGIN_ATTEMPT_THRESHOLD): cv.positive_int, + vol.Optional(CONF_IP_BAN_ENABLED, default=True): cv.boolean +}) + +CONFIG_SCHEMA = vol.Schema({ + DOMAIN: HTTP_SCHEMA, +}, extra=vol.ALLOW_EXTRA) + + +@asyncio.coroutine +def async_setup(hass, config): + """Set up the HTTP API and debug interface.""" + conf = config.get(DOMAIN) + + if conf is None: + conf = HTTP_SCHEMA({}) + + api_password = conf[CONF_API_PASSWORD] + server_host = conf[CONF_SERVER_HOST] + server_port = conf[CONF_SERVER_PORT] + development = conf[CONF_DEVELOPMENT] == '1' + ssl_certificate = conf[CONF_SSL_CERTIFICATE] + ssl_key = conf[CONF_SSL_KEY] + cors_origins = conf[CONF_CORS_ORIGINS] + use_x_forwarded_for = conf[CONF_USE_X_FORWARDED_FOR] + trusted_networks = conf[CONF_TRUSTED_NETWORKS] + is_ban_enabled = conf[CONF_IP_BAN_ENABLED] + login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD] + + if api_password is not None: + logging.getLogger('aiohttp.access').addFilter( + HideSensitiveDataFilter(api_password)) + + server = HomeAssistantWSGI( + hass, + development=development, + server_host=server_host, + server_port=server_port, + api_password=api_password, + ssl_certificate=ssl_certificate, + ssl_key=ssl_key, + cors_origins=cors_origins, + use_x_forwarded_for=use_x_forwarded_for, + trusted_networks=trusted_networks, + login_threshold=login_threshold, + is_ban_enabled=is_ban_enabled + ) + + @asyncio.coroutine + def stop_server(event): + """Callback to stop the server.""" + yield from server.stop() + + @asyncio.coroutine + def start_server(event): + """Callback to start the server.""" + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server) + yield from server.start() + + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_server) + + hass.http = server + hass.config.api = rem.API(server_host if server_host != '0.0.0.0' + else get_local_ip(), + api_password, server_port, + ssl_certificate is not None) + + return True + + +class HomeAssistantWSGI(object): + """WSGI server for Home Assistant.""" + + def __init__(self, hass, development, api_password, ssl_certificate, + ssl_key, server_host, server_port, cors_origins, + use_x_forwarded_for, trusted_networks, + login_threshold, is_ban_enabled): + """Initialize the WSGI Home Assistant server.""" + import aiohttp_cors + + middlewares = [auth_middleware, staticresource_middleware] + + if is_ban_enabled: + middlewares.insert(0, ban_middleware) + + self.app = web.Application(middlewares=middlewares, loop=hass.loop) + self.app['hass'] = hass + self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for + self.app[KEY_TRUSTED_NETWORKS] = trusted_networks + self.app[KEY_BANS_ENABLED] = is_ban_enabled + self.app[KEY_LOGIN_THRESHOLD] = login_threshold + self.app[KEY_DEVELOPMENT] = development + + self.hass = hass + self.development = development + self.api_password = api_password + self.ssl_certificate = ssl_certificate + self.ssl_key = ssl_key + self.server_host = server_host + self.server_port = server_port + self._handler = None + self.server = None + + if cors_origins: + self.cors = aiohttp_cors.setup(self.app, defaults={ + host: aiohttp_cors.ResourceOptions( + allow_headers=ALLOWED_CORS_HEADERS, + allow_methods='*', + ) for host in cors_origins + }) + else: + self.cors = None + + def register_view(self, view): + """Register a view with the WSGI server. + + The view argument must be a class that inherits from HomeAssistantView. + It is optional to instantiate it before registering; this method will + handle it either way. + """ + if isinstance(view, type): + # Instantiate the view, if needed + view = view() + + if not hasattr(view, 'url'): + class_name = view.__class__.__name__ + raise AttributeError( + '{0} missing required attribute "url"'.format(class_name) + ) + + if not hasattr(view, 'name'): + class_name = view.__class__.__name__ + raise AttributeError( + '{0} missing required attribute "name"'.format(class_name) + ) + + view.register(self.app.router) + + def register_redirect(self, url, redirect_to): + """Register a redirect with the server. + + If given this must be either a string or callable. In case of a + callable it's called with the url adapter that triggered the match and + the values of the URL as keyword arguments and has to return the target + for the redirect, otherwise it has to be a string with placeholders in + rule syntax. + """ + def redirect(request): + """Redirect to location.""" + raise HTTPMovedPermanently(redirect_to) + + self.app.router.add_route('GET', url, redirect) + + def register_static_path(self, url_root, path, cache_length=31): + """Register a folder to serve as a static path. + + Specify optional cache length of asset in days. + """ + if os.path.isdir(path): + self.app.router.add_static(url_root, path) + return + + filepath = Path(path) + + @asyncio.coroutine + def serve_file(request): + """Serve file from disk.""" + res = yield from GZIP_FILE_SENDER.send(request, filepath) + return res + + # aiohttp supports regex matching for variables. Using that as temp + # to work around cache busting MD5. + # Turns something like /static/dev-panel.html into + # /static/{filename:dev-panel(-[a-z0-9]{32}|)\.html} + base, ext = url_root.rsplit('.', 1) + base, file = base.rsplit('/', 1) + regex = r"{}(-[a-z0-9]{{32}}|)\.{}".format(file, ext) + url_pattern = "{}/{{filename:{}}}".format(base, regex) + + self.app.router.add_route('GET', url_pattern, serve_file) + + @asyncio.coroutine + def start(self): + """Start the wsgi server.""" + if self.cors is not None: + for route in list(self.app.router.routes()): + self.cors.add(route) + + if self.ssl_certificate: + context = ssl.SSLContext(SSL_VERSION) + context.options |= SSL_OPTS + context.set_ciphers(CIPHERS) + context.load_cert_chain(self.ssl_certificate, self.ssl_key) + else: + context = None + + self._handler = self.app.make_handler() + self.server = yield from self.hass.loop.create_server( + self._handler, self.server_host, self.server_port, ssl=context) + + @asyncio.coroutine + def stop(self): + """Stop the wsgi server.""" + self.server.close() + yield from self.server.wait_closed() + yield from self.app.shutdown() + yield from self._handler.finish_connections(60.0) + yield from self.app.cleanup() + + +class HomeAssistantView(object): + """Base view for all views.""" + + url = None + extra_urls = [] + requires_auth = True # Views inheriting from this class can override this + + # pylint: disable=no-self-use + def json(self, result, status_code=200): + """Return a JSON response.""" + msg = json.dumps( + result, sort_keys=True, cls=rem.JSONEncoder).encode('UTF-8') + return web.Response( + body=msg, content_type=CONTENT_TYPE_JSON, status=status_code) + + def json_message(self, error, status_code=200): + """Return a JSON message response.""" + return self.json({'message': error}, status_code) + + @asyncio.coroutine + # pylint: disable=no-self-use + def file(self, request, fil): + """Return a file.""" + assert isinstance(fil, str), 'only string paths allowed' + response = yield from GZIP_FILE_SENDER.send(request, Path(fil)) + return response + + def register(self, router): + """Register the view with a router.""" + assert self.url is not None, 'No url set for view' + urls = [self.url] + self.extra_urls + + for method in ('get', 'post', 'delete', 'put'): + handler = getattr(self, method, None) + + if not handler: + continue + + handler = request_handler_factory(self, handler) + + for url in urls: + router.add_route(method, url, handler) + + # aiohttp_cors does not work with class based views + # self.app.router.add_route('*', self.url, self, name=self.name) + + # for url in self.extra_urls: + # self.app.router.add_route('*', url, self) + + +def request_handler_factory(view, handler): + """Factory to wrap our handler classes.""" + assert asyncio.iscoroutinefunction(handler) or is_callback(handler), \ + "Handler should be a coroutine or a callback." + + @asyncio.coroutine + def handle(request): + """Handle incoming request.""" + if not request.app['hass'].is_running: + return web.Response(status=503) + + remote_addr = get_real_ip(request) + authenticated = request.get(KEY_AUTHENTICATED, False) + + if view.requires_auth and not authenticated: + yield from process_wrong_login(request) + _LOGGER.warning('Login attempt or request with an invalid ' + 'password from %s', remote_addr) + persistent_notification.async_create( + request.app['hass'], + 'Invalid password used from {}'.format(remote_addr), + 'Login attempt failed', NOTIFICATION_ID_LOGIN) + raise HTTPUnauthorized() + + _LOGGER.info('Serving %s to %s (auth: %s)', + request.path, remote_addr, authenticated) + + result = handler(request, **request.match_info) + + if asyncio.iscoroutine(result): + result = yield from result + + if isinstance(result, web.StreamResponse): + # The method handler returned a ready-made Response, how nice of it + return result + + status_code = 200 + + if isinstance(result, tuple): + result, status_code = result + + if isinstance(result, str): + result = result.encode('utf-8') + elif result is None: + result = b'' + elif not isinstance(result, bytes): + assert False, ('Result should be None, string, bytes or Response. ' + 'Got: {}').format(result) + + return web.Response(body=result, status=status_code) + + return handle diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py new file mode 100644 index 00000000000..14b442e5dde --- /dev/null +++ b/homeassistant/components/http/auth.py @@ -0,0 +1,61 @@ +"""Authentication for HTTP component.""" +import asyncio +import hmac +import logging + +from homeassistant.const import HTTP_HEADER_HA_AUTH +from .util import get_real_ip +from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED + +DATA_API_PASSWORD = 'api_password' + +_LOGGER = logging.getLogger(__name__) + + +@asyncio.coroutine +def auth_middleware(app, handler): + """Authentication middleware.""" + # If no password set, just always set authenticated=True + if app['hass'].http.api_password is None: + @asyncio.coroutine + def no_auth_middleware_handler(request): + """Auth middleware to approve all requests.""" + request[KEY_AUTHENTICATED] = True + return handler(request) + + return no_auth_middleware_handler + + @asyncio.coroutine + def auth_middleware_handler(request): + """Auth middleware to check authentication.""" + hass = app['hass'] + + # Auth code verbose on purpose + authenticated = False + + if hmac.compare_digest(request.headers.get(HTTP_HEADER_HA_AUTH, ''), + hass.http.api_password): + # A valid auth header has been set + authenticated = True + + elif hmac.compare_digest(request.GET.get(DATA_API_PASSWORD, ''), + hass.http.api_password): + authenticated = True + + elif is_trusted_ip(request): + authenticated = True + + request[KEY_AUTHENTICATED] = authenticated + + return handler(request) + + return auth_middleware_handler + + +def is_trusted_ip(request): + """Test if request is from a trusted ip.""" + ip_addr = get_real_ip(request) + + return ip_addr and any( + ip_addr in trusted_network for trusted_network + in request.app[KEY_TRUSTED_NETWORKS]) diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py new file mode 100644 index 00000000000..b3f17c1dd57 --- /dev/null +++ b/homeassistant/components/http/ban.py @@ -0,0 +1,132 @@ +"""Ban logic for HTTP component.""" +import asyncio +from collections import defaultdict +from datetime import datetime +from ipaddress import ip_address +import logging + +from aiohttp.web_exceptions import HTTPForbidden +import voluptuous as vol + +from homeassistant.components import persistent_notification +from homeassistant.config import load_yaml_config_file +from homeassistant.exceptions import HomeAssistantError +import homeassistant.helpers.config_validation as cv +from homeassistant.util.yaml import dump +from .const import ( + KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD, + KEY_FAILED_LOGIN_ATTEMPTS) +from .util import get_real_ip + +NOTIFICATION_ID_BAN = 'ip-ban' + +IP_BANS_FILE = 'ip_bans.yaml' +ATTR_BANNED_AT = "banned_at" + +SCHEMA_IP_BAN_ENTRY = vol.Schema({ + vol.Optional('banned_at'): vol.Any(None, cv.datetime) +}) + +_LOGGER = logging.getLogger(__name__) + + +@asyncio.coroutine +def ban_middleware(app, handler): + """IP Ban middleware.""" + if not app[KEY_BANS_ENABLED]: + return handler + + if KEY_BANNED_IPS not in app: + hass = app['hass'] + app[KEY_BANNED_IPS] = yield from hass.loop.run_in_executor( + None, load_ip_bans_config, hass.config.path(IP_BANS_FILE)) + + @asyncio.coroutine + def ban_middleware_handler(request): + """Verify if IP is not banned.""" + ip_address_ = get_real_ip(request) + + is_banned = any(ip_ban.ip_address == ip_address_ + for ip_ban in request.app[KEY_BANNED_IPS]) + + if is_banned: + raise HTTPForbidden() + + return handler(request) + + return ban_middleware_handler + + +@asyncio.coroutine +def process_wrong_login(request): + """Process a wrong login attempt.""" + if (not request.app[KEY_BANS_ENABLED] or + request.app[KEY_LOGIN_THRESHOLD] < 1): + return + + if KEY_FAILED_LOGIN_ATTEMPTS not in request.app: + request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int) + + remote_addr = get_real_ip(request) + + request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1 + + if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] > + request.app[KEY_LOGIN_THRESHOLD]): + new_ban = IpBan(remote_addr) + request.app[KEY_BANNED_IPS].append(new_ban) + + hass = request.app['hass'] + yield from hass.loop.run_in_executor( + None, update_ip_bans_config, hass.config.path(IP_BANS_FILE), + new_ban) + + _LOGGER.warning('Banned IP %s for too many login attempts', + remote_addr) + + persistent_notification.async_create( + hass, + 'Too many login attempts from {}'.format(remote_addr), + 'Banning IP address', NOTIFICATION_ID_BAN) + + +class IpBan(object): + """Represents banned IP address.""" + + def __init__(self, ip_ban: str, banned_at: datetime=None) -> None: + """Initializing Ip Ban object.""" + self.ip_address = ip_address(ip_ban) + self.banned_at = banned_at or datetime.utcnow() + + +def load_ip_bans_config(path: str): + """Loading list of banned IPs from config file.""" + ip_list = [] + + try: + list_ = load_yaml_config_file(path) + except FileNotFoundError: + return [] + except HomeAssistantError as err: + _LOGGER.error('Unable to load %s: %s', path, str(err)) + return [] + + for ip_ban, ip_info in list_.items(): + try: + ip_info = SCHEMA_IP_BAN_ENTRY(ip_info) + ip_list.append(IpBan(ip_ban, ip_info['banned_at'])) + except vol.Invalid as err: + _LOGGER.error('Failed to load IP ban %s: %s', ip_info, err) + continue + + return ip_list + + +def update_ip_bans_config(path: str, ip_ban: IpBan): + """Update config file with new banned IP address.""" + with open(path, 'a') as out: + ip_ = {str(ip_ban.ip_address): { + ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S") + }} + out.write('\n') + out.write(dump(ip_)) diff --git a/homeassistant/components/http/const.py b/homeassistant/components/http/const.py new file mode 100644 index 00000000000..625bc24c461 --- /dev/null +++ b/homeassistant/components/http/const.py @@ -0,0 +1,12 @@ +"""HTTP specific constants.""" +KEY_AUTHENTICATED = 'ha_authenticated' +KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for' +KEY_TRUSTED_NETWORKS = 'ha_trusted_networks' +KEY_REAL_IP = 'ha_real_ip' +KEY_BANS_ENABLED = 'ha_bans_enabled' +KEY_BANNED_IPS = 'ha_banned_ips' +KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts' +KEY_LOGIN_THRESHOLD = 'ha_login_treshold' +KEY_DEVELOPMENT = 'ha_development' + +HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For' diff --git a/homeassistant/components/http/static.py b/homeassistant/components/http/static.py new file mode 100644 index 00000000000..c8c55870e0f --- /dev/null +++ b/homeassistant/components/http/static.py @@ -0,0 +1,93 @@ +"""Static file handling for HTTP component.""" +import asyncio +import mimetypes +import re + +from aiohttp import hdrs +from aiohttp.file_sender import FileSender +from aiohttp.web_urldispatcher import StaticResource +from aiohttp.web_exceptions import HTTPNotModified + +from .const import KEY_DEVELOPMENT + +_FINGERPRINT = re.compile(r'^(.+)-[a-z0-9]{32}\.(\w+)$', re.IGNORECASE) + + +class GzipFileSender(FileSender): + """FileSender class capable of sending gzip version if available.""" + + # pylint: disable=invalid-name + + @asyncio.coroutine + def send(self, request, filepath): + """Send filepath to client using request.""" + gzip = False + if 'gzip' in request.headers[hdrs.ACCEPT_ENCODING]: + gzip_path = filepath.with_name(filepath.name + '.gz') + + if gzip_path.is_file(): + filepath = gzip_path + gzip = True + + st = filepath.stat() + + modsince = request.if_modified_since + if modsince is not None and st.st_mtime <= modsince.timestamp(): + raise HTTPNotModified() + + ct, encoding = mimetypes.guess_type(str(filepath)) + if not ct: + ct = 'application/octet-stream' + + resp = self._response_factory() + resp.content_type = ct + if encoding: + resp.headers[hdrs.CONTENT_ENCODING] = encoding + if gzip: + resp.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING + resp.last_modified = st.st_mtime + + # CACHE HACK + if not request.app[KEY_DEVELOPMENT]: + cache_time = 31 * 86400 # = 1 month + resp.headers[hdrs.CACHE_CONTROL] = "public, max-age={}".format( + cache_time) + + file_size = st.st_size + + resp.content_length = file_size + with filepath.open('rb') as f: + yield from self._sendfile(request, resp, f, file_size) + + return resp + + +GZIP_FILE_SENDER = GzipFileSender() + + +@asyncio.coroutine +def staticresource_middleware(app, handler): + """Enhance StaticResourceHandler middleware. + + Adds gzip encoding and fingerprinting matching. + """ + inst = getattr(handler, '__self__', None) + if not isinstance(inst, StaticResource): + return handler + + # pylint: disable=protected-access + inst._file_sender = GZIP_FILE_SENDER + + @asyncio.coroutine + def static_middleware_handler(request): + """Strip out fingerprints from resource names.""" + fingerprinted = _FINGERPRINT.match(request.match_info['filename']) + + if fingerprinted: + request.match_info['filename'] = \ + '{}.{}'.format(*fingerprinted.groups()) + + resp = yield from handler(request) + return resp + + return static_middleware_handler diff --git a/homeassistant/components/http/util.py b/homeassistant/components/http/util.py new file mode 100644 index 00000000000..1a5a3d98a22 --- /dev/null +++ b/homeassistant/components/http/util.py @@ -0,0 +1,25 @@ +"""HTTP utilities.""" +from ipaddress import ip_address + +from .const import ( + KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR) + + +def get_real_ip(request): + """Get IP address of client.""" + if KEY_REAL_IP in request: + return request[KEY_REAL_IP] + + if (request.app[KEY_USE_X_FORWARDED_FOR] and + HTTP_HEADER_X_FORWARDED_FOR in request.headers): + request[KEY_REAL_IP] = ip_address( + request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0]) + else: + peername = request.transport.get_extra_info('peername') + + if peername: + request[KEY_REAL_IP] = ip_address(peername[0]) + else: + request[KEY_REAL_IP] = None + + return request[KEY_REAL_IP] diff --git a/homeassistant/components/ios.py b/homeassistant/components/ios.py index f9b17b552de..d83bffabc91 100644 --- a/homeassistant/components/ios.py +++ b/homeassistant/components/ios.py @@ -250,11 +250,10 @@ def setup(hass, config): discovery.load_platform(hass, "sensor", DOMAIN, {}, config) - hass.http.register_view(iOSIdentifyDeviceView(hass)) + hass.http.register_view(iOSIdentifyDeviceView) app_config = config.get(DOMAIN, {}) - hass.http.register_view(iOSPushConfigView(hass, - app_config.get(CONF_PUSH, {}))) + hass.http.register_view(iOSPushConfigView(app_config.get(CONF_PUSH, {}))) return True @@ -266,9 +265,8 @@ class iOSPushConfigView(HomeAssistantView): url = "/api/ios/push" name = "api:ios:push" - def __init__(self, hass, push_config): + def __init__(self, push_config): """Init the view.""" - super().__init__(hass) self.push_config = push_config @callback @@ -283,10 +281,6 @@ class iOSIdentifyDeviceView(HomeAssistantView): url = "/api/ios/identify" name = "api:ios:identify" - def __init__(self, hass): - """Init the view.""" - super().__init__(hass) - @asyncio.coroutine def post(self, request): """Handle the POST request for device identification.""" diff --git a/homeassistant/components/logbook.py b/homeassistant/components/logbook.py index 18e80c4c761..49ab709f8f5 100644 --- a/homeassistant/components/logbook.py +++ b/homeassistant/components/logbook.py @@ -101,7 +101,7 @@ def setup(hass, config): message = message.async_render() async_log_entry(hass, name, message, domain, entity_id) - hass.http.register_view(LogbookView(hass, config)) + hass.http.register_view(LogbookView(config)) register_built_in_panel(hass, 'logbook', 'Logbook', 'mdi:format-list-bulleted-type') @@ -118,9 +118,8 @@ class LogbookView(HomeAssistantView): name = 'api:logbook' extra_urls = ['/api/logbook/{datetime}'] - def __init__(self, hass, config): + def __init__(self, config): """Initilalize the logbook view.""" - super().__init__(hass) self.config = config @asyncio.coroutine @@ -146,7 +145,8 @@ class LogbookView(HomeAssistantView): events = recorder.execute(query) return _exclude_events(events, self.config) - events = yield from self.hass.loop.run_in_executor(None, get_results) + events = yield from request.app['hass'].loop.run_in_executor( + None, get_results) return self.json(humanify(events)) diff --git a/homeassistant/components/media_player/__init__.py b/homeassistant/components/media_player/__init__.py index c689cdbccc4..5665699d4f3 100644 --- a/homeassistant/components/media_player/__init__.py +++ b/homeassistant/components/media_player/__init__.py @@ -17,7 +17,7 @@ from homeassistant.config import load_yaml_config_file from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa -from homeassistant.components.http import HomeAssistantView +from homeassistant.components.http import HomeAssistantView, KEY_AUTHENTICATED import homeassistant.helpers.config_validation as cv from homeassistant.util.async import run_coroutine_threadsafe from homeassistant.const import ( @@ -304,7 +304,7 @@ def setup(hass, config): component = EntityComponent( logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL) - hass.http.register_view(MediaPlayerImageView(hass, component.entities)) + hass.http.register_view(MediaPlayerImageView(component.entities)) component.setup(config) @@ -736,9 +736,8 @@ class MediaPlayerImageView(HomeAssistantView): url = "/api/media_player_proxy/{entity_id}" name = "api:media_player:image" - def __init__(self, hass, entities): + def __init__(self, entities): """Initialize a media player view.""" - super().__init__(hass) self.entities = entities @asyncio.coroutine @@ -748,14 +747,14 @@ class MediaPlayerImageView(HomeAssistantView): if player is None: return web.Response(status=404) - authenticated = (request.authenticated or + authenticated = (request[KEY_AUTHENTICATED] or request.GET.get('token') == player.access_token) if not authenticated: return web.Response(status=401) data, content_type = yield from _async_fetch_image( - self.hass, player.media_image_url) + request.app['hass'], player.media_image_url) if data is None: return web.Response(status=500) diff --git a/homeassistant/components/notify/html5.py b/homeassistant/components/notify/html5.py index baf887c1e6e..6621b4be6ab 100644 --- a/homeassistant/components/notify/html5.py +++ b/homeassistant/components/notify/html5.py @@ -107,8 +107,8 @@ def get_service(hass, config): return None hass.http.register_view( - HTML5PushRegistrationView(hass, registrations, json_path)) - hass.http.register_view(HTML5PushCallbackView(hass, registrations)) + HTML5PushRegistrationView(registrations, json_path)) + hass.http.register_view(HTML5PushCallbackView(registrations)) gcm_api_key = config.get(ATTR_GCM_API_KEY) gcm_sender_id = config.get(ATTR_GCM_SENDER_ID) @@ -168,9 +168,8 @@ class HTML5PushRegistrationView(HomeAssistantView): url = '/api/notify.html5' name = 'api:notify.html5' - def __init__(self, hass, registrations, json_path): + def __init__(self, registrations, json_path): """Init HTML5PushRegistrationView.""" - super().__init__(hass) self.registrations = registrations self.json_path = json_path @@ -237,9 +236,8 @@ class HTML5PushCallbackView(HomeAssistantView): url = '/api/notify.html5/callback' name = 'api:notify.html5/callback' - def __init__(self, hass, registrations): + def __init__(self, registrations): """Init HTML5PushCallbackView.""" - super().__init__(hass) self.registrations = registrations def decode_jwt(self, token): @@ -324,7 +322,7 @@ class HTML5PushCallbackView(HomeAssistantView): event_name = '{}.{}'.format(NOTIFY_CALLBACK_EVENT, event_payload[ATTR_TYPE]) - self.hass.bus.fire(event_name, event_payload) + request.app['hass'].bus.fire(event_name, event_payload) return self.json({'status': 'ok', 'event': event_payload[ATTR_TYPE]}) diff --git a/homeassistant/components/sensor/fitbit.py b/homeassistant/components/sensor/fitbit.py index cd0346c2469..697fecca077 100644 --- a/homeassistant/components/sensor/fitbit.py +++ b/homeassistant/components/sensor/fitbit.py @@ -274,7 +274,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None): hass.http.register_redirect(FITBIT_AUTH_START, fitbit_auth_start_url) hass.http.register_view(FitbitAuthCallbackView( - hass, config, add_devices, oauth)) + config, add_devices, oauth)) request_oauth_completion(hass) @@ -286,9 +286,8 @@ class FitbitAuthCallbackView(HomeAssistantView): url = '/auth/fitbit/callback' name = 'auth:fitbit:callback' - def __init__(self, hass, config, add_devices, oauth): + def __init__(self, config, add_devices, oauth): """Initialize the OAuth callback view.""" - super().__init__(hass) self.config = config self.add_devices = add_devices self.oauth = oauth @@ -299,6 +298,7 @@ class FitbitAuthCallbackView(HomeAssistantView): from oauthlib.oauth2.rfc6749.errors import MismatchingStateError from oauthlib.oauth2.rfc6749.errors import MissingTokenError + hass = request.app['hass'] data = request.GET response_message = """Fitbit has been successfully authorized! @@ -306,7 +306,7 @@ class FitbitAuthCallbackView(HomeAssistantView): if data.get('code') is not None: redirect_uri = '{}{}'.format( - self.hass.config.api.base_url, FITBIT_AUTH_CALLBACK_PATH) + hass.config.api.base_url, FITBIT_AUTH_CALLBACK_PATH) try: self.oauth.fetch_access_token(data.get('code'), redirect_uri) @@ -336,12 +336,11 @@ class FitbitAuthCallbackView(HomeAssistantView): ATTR_CLIENT_ID: self.oauth.client_id, ATTR_CLIENT_SECRET: self.oauth.client_secret } - if not config_from_file(self.hass.config.path(FITBIT_CONFIG_FILE), + if not config_from_file(hass.config.path(FITBIT_CONFIG_FILE), config_contents): _LOGGER.error("Failed to save config file") - self.hass.async_add_job(setup_platform, self.hass, self.config, - self.add_devices) + hass.async_add_job(setup_platform, hass, self.config, self.add_devices) return html_response diff --git a/homeassistant/components/sensor/torque.py b/homeassistant/components/sensor/torque.py index 8c88a4e22d2..7f63ac5b4e6 100644 --- a/homeassistant/components/sensor/torque.py +++ b/homeassistant/components/sensor/torque.py @@ -59,7 +59,7 @@ def setup_platform(hass, config, add_devices, discovery_info=None): sensors = {} hass.http.register_view(TorqueReceiveDataView( - hass, email, vehicle, sensors, add_devices)) + email, vehicle, sensors, add_devices)) return True @@ -69,9 +69,8 @@ class TorqueReceiveDataView(HomeAssistantView): url = API_PATH name = 'api:torque' - def __init__(self, hass, email, vehicle, sensors, add_devices): + def __init__(self, email, vehicle, sensors, add_devices): """Initialize a Torque view.""" - super().__init__(hass) self.email = email self.vehicle = vehicle self.sensors = sensors @@ -80,6 +79,7 @@ class TorqueReceiveDataView(HomeAssistantView): @callback def get(self, request): """Handle Torque data request.""" + hass = request.app['hass'] data = request.GET if self.email is not None and self.email != data[SENSOR_EMAIL_FIELD]: @@ -108,7 +108,7 @@ class TorqueReceiveDataView(HomeAssistantView): self.sensors[pid] = TorqueSensor( ENTITY_NAME_FORMAT.format(self.vehicle, names[pid]), units.get(pid, None)) - self.hass.async_add_job(self.add_devices, [self.sensors[pid]]) + hass.async_add_job(self.add_devices, [self.sensors[pid]]) return None diff --git a/homeassistant/components/switch/netio.py b/homeassistant/components/switch/netio.py index 74505cdcdc2..95da15898b9 100644 --- a/homeassistant/components/switch/netio.py +++ b/homeassistant/components/switch/netio.py @@ -97,6 +97,7 @@ class NetioApiView(HomeAssistantView): @callback def get(self, request, host): """Request handler.""" + hass = request.app['hass'] data = request.GET states, consumptions, cumulated_consumptions, start_dates = \ [], [], [], [] @@ -119,7 +120,7 @@ class NetioApiView(HomeAssistantView): ndev.start_dates = start_dates for dev in DEVICES[host].entities: - self.hass.async_add_job(dev.async_update_ha_state()) + hass.async_add_job(dev.async_update_ha_state()) return self.json(True) diff --git a/homeassistant/const.py b/homeassistant/const.py index 1b0921ccc95..64a4e7e5c45 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -360,7 +360,6 @@ HTTP_HEADER_CONTENT_LENGTH = 'Content-Length' HTTP_HEADER_CACHE_CONTROL = 'Cache-Control' HTTP_HEADER_EXPIRES = 'Expires' HTTP_HEADER_ORIGIN = 'Origin' -HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For' HTTP_HEADER_X_REQUESTED_WITH = 'X-Requested-With' HTTP_HEADER_ACCEPT = 'Accept' HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin' diff --git a/homeassistant/util/logging.py b/homeassistant/util/logging.py new file mode 100644 index 00000000000..d324e7253b7 --- /dev/null +++ b/homeassistant/util/logging.py @@ -0,0 +1,17 @@ +"""Logging utilities.""" +import logging + + +class HideSensitiveDataFilter(logging.Filter): + """Filter API password calls.""" + + def __init__(self, text): + """Initialize sensitive data filter.""" + super().__init__() + self.text = text + + def filter(self, record): + """Hide sensitive data in messages.""" + record.msg = record.msg.replace(self.text, '*******') + + return True diff --git a/tests/common.py b/tests/common.py index 25a10783c28..fc779e120f8 100644 --- a/tests/common.py +++ b/tests/common.py @@ -10,6 +10,8 @@ import logging import threading from contextlib import contextmanager +from aiohttp import web + from homeassistant import core as ha, loader from homeassistant.bootstrap import ( setup_component, async_prepare_setup_component) @@ -22,6 +24,9 @@ from homeassistant.const import ( EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE, ATTR_DISCOVERED, SERVER_PORT) from homeassistant.components import sun, mqtt +from homeassistant.components.http.auth import auth_middleware +from homeassistant.components.http.const import ( + KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED) _TEST_INSTANCE_PORT = SERVER_PORT _LOGGER = logging.getLogger(__name__) @@ -210,13 +215,23 @@ def mock_http_component(hass): """Store registered view.""" if isinstance(view, type): # Instantiate the view, if needed - view = view(hass) + view = view() hass.http.views[view.name] = view hass.http.register_view = mock_register_view +def mock_http_component_app(hass): + """Create an aiohttp.web.Application instance for testing.""" + hass.http.api_password = None + app = web.Application(middlewares=[auth_middleware], loop=hass.loop) + app['hass'] = hass + app[KEY_USE_X_FORWARDED_FOR] = False + app[KEY_BANS_ENABLED] = False + return app + + def mock_mqtt_component(hass): """Mock the MQTT component.""" with mock.patch('homeassistant.components.mqtt.MQTT') as mock_mqtt: diff --git a/tests/components/camera/test_generic.py b/tests/components/camera/test_generic.py index e2ce9c15936..ac7b0063158 100644 --- a/tests/components/camera/test_generic.py +++ b/tests/components/camera/test_generic.py @@ -27,8 +27,8 @@ def test_fetching_url(aioclient_mock, hass, test_client): resp = yield from client.get('/api/camera_proxy/camera.config_test') - assert aioclient_mock.call_count == 1 assert resp.status == 200 + assert aioclient_mock.call_count == 1 body = yield from resp.text() assert body == 'hello world' diff --git a/tests/components/http/__init__.py b/tests/components/http/__init__.py new file mode 100644 index 00000000000..869e80fff75 --- /dev/null +++ b/tests/components/http/__init__.py @@ -0,0 +1 @@ +"""Tests for the HTTP component.""" diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py new file mode 100644 index 00000000000..d41a1f03d1b --- /dev/null +++ b/tests/components/http/test_auth.py @@ -0,0 +1,169 @@ +"""The tests for the Home Assistant HTTP component.""" +# pylint: disable=protected-access +import logging +from ipaddress import ip_address, ip_network +from unittest.mock import patch + +import requests + +from homeassistant import bootstrap, const +import homeassistant.components.http as http +from homeassistant.components.http.const import ( + KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR) + +from tests.common import get_test_instance_port, get_test_home_assistant + +API_PASSWORD = 'test1234' +SERVER_PORT = get_test_instance_port() +HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT) +HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE) +HA_HEADERS = { + const.HTTP_HEADER_HA_AUTH: API_PASSWORD, + const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON, +} +# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases +TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1', + 'FD01:DB8::1'] +TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1', + '2001:DB8:ABCD::1'] +UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1'] + +hass = None + + +def _url(path=''): + """Helper method to generate URLs.""" + return HTTP_BASE_URL + path + + +# pylint: disable=invalid-name +def setUpModule(): + """Initialize a Home Assistant server.""" + global hass + + hass = get_test_home_assistant() + + bootstrap.setup_component( + hass, http.DOMAIN, { + http.DOMAIN: { + http.CONF_API_PASSWORD: API_PASSWORD, + http.CONF_SERVER_PORT: SERVER_PORT, + } + } + ) + + bootstrap.setup_component(hass, 'api') + + hass.http.app[KEY_TRUSTED_NETWORKS] = [ + ip_network(trusted_network) + for trusted_network in TRUSTED_NETWORKS] + + hass.start() + + +# pylint: disable=invalid-name +def tearDownModule(): + """Stop the Home Assistant server.""" + hass.stop() + + +class TestHttp: + """Test HTTP component.""" + + def test_access_denied_without_password(self): + """Test access without password.""" + req = requests.get(_url(const.URL_API)) + + assert req.status_code == 401 + + def test_access_denied_with_wrong_password_in_header(self): + """Test access with wrong password.""" + req = requests.get( + _url(const.URL_API), + headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'}) + + assert req.status_code == 401 + + def test_access_denied_with_x_forwarded_for(self, caplog): + """Test access denied through the X-Forwarded-For http header.""" + hass.http.use_x_forwarded_for = True + for remote_addr in UNTRUSTED_ADDRESSES: + req = requests.get(_url(const.URL_API), headers={ + HTTP_HEADER_X_FORWARDED_FOR: remote_addr}) + + assert req.status_code == 401, \ + "{} shouldn't be trusted".format(remote_addr) + + def test_access_denied_with_untrusted_ip(self, caplog): + """Test access with an untrusted ip address.""" + for remote_addr in UNTRUSTED_ADDRESSES: + with patch('homeassistant.components.http.' + 'util.get_real_ip', + return_value=ip_address(remote_addr)): + req = requests.get( + _url(const.URL_API), params={'api_password': ''}) + + assert req.status_code == 401, \ + "{} shouldn't be trusted".format(remote_addr) + + def test_access_with_password_in_header(self, caplog): + """Test access with password in URL.""" + # Hide logging from requests package that we use to test logging + caplog.set_level( + logging.WARNING, logger='requests.packages.urllib3.connectionpool') + + req = requests.get( + _url(const.URL_API), + headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD}) + + assert req.status_code == 200 + + logs = caplog.text + + assert const.URL_API in logs + assert API_PASSWORD not in logs + + def test_access_denied_with_wrong_password_in_url(self): + """Test access with wrong password.""" + req = requests.get( + _url(const.URL_API), params={'api_password': 'wrongpassword'}) + + assert req.status_code == 401 + + def test_access_with_password_in_url(self, caplog): + """Test access with password in URL.""" + # Hide logging from requests package that we use to test logging + caplog.set_level( + logging.WARNING, logger='requests.packages.urllib3.connectionpool') + + req = requests.get( + _url(const.URL_API), params={'api_password': API_PASSWORD}) + + assert req.status_code == 200 + + logs = caplog.text + + assert const.URL_API in logs + assert API_PASSWORD not in logs + + def test_access_granted_with_x_forwarded_for(self, caplog): + """Test access denied through the X-Forwarded-For http header.""" + hass.http.app[KEY_USE_X_FORWARDED_FOR] = True + for remote_addr in TRUSTED_ADDRESSES: + req = requests.get(_url(const.URL_API), headers={ + HTTP_HEADER_X_FORWARDED_FOR: remote_addr}) + + assert req.status_code == 200, \ + "{} should be trusted".format(remote_addr) + + def test_access_granted_with_trusted_ip(self, caplog): + """Test access with trusted addresses.""" + for remote_addr in TRUSTED_ADDRESSES: + with patch('homeassistant.components.http.' + 'auth.get_real_ip', + return_value=ip_address(remote_addr)): + req = requests.get( + _url(const.URL_API), params={'api_password': ''}) + + assert req.status_code == 200, \ + '{} should be trusted'.format(remote_addr) diff --git a/tests/components/http/test_ban.py b/tests/components/http/test_ban.py new file mode 100644 index 00000000000..b2aeca2917f --- /dev/null +++ b/tests/components/http/test_ban.py @@ -0,0 +1,118 @@ +"""The tests for the Home Assistant HTTP component.""" +# pylint: disable=protected-access +from ipaddress import ip_address +from unittest.mock import patch, mock_open + +import requests + +from homeassistant import bootstrap, const +import homeassistant.components.http as http +from homeassistant.components.http.const import ( + KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS) +from homeassistant.components.http.ban import IpBan, IP_BANS_FILE + +from tests.common import get_test_instance_port, get_test_home_assistant + +API_PASSWORD = 'test1234' +SERVER_PORT = get_test_instance_port() +HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT) +HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE) +HA_HEADERS = { + const.HTTP_HEADER_HA_AUTH: API_PASSWORD, + const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON, +} +BANNED_IPS = ['200.201.202.203', '100.64.0.2'] + +hass = None + + +def _url(path=''): + """Helper method to generate URLs.""" + return HTTP_BASE_URL + path + + +# pylint: disable=invalid-name +def setUpModule(): + """Initialize a Home Assistant server.""" + global hass + + hass = get_test_home_assistant() + + bootstrap.setup_component( + hass, http.DOMAIN, { + http.DOMAIN: { + http.CONF_API_PASSWORD: API_PASSWORD, + http.CONF_SERVER_PORT: SERVER_PORT, + } + } + ) + + bootstrap.setup_component(hass, 'api') + + hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip + in BANNED_IPS] + hass.start() + + +# pylint: disable=invalid-name +def tearDownModule(): + """Stop the Home Assistant server.""" + hass.stop() + + +class TestHttp: + """Test HTTP component.""" + + def test_access_from_banned_ip(self): + """Test accessing to server from banned IP. Both trusted and not.""" + hass.http.app[KEY_BANS_ENABLED] = True + for remote_addr in BANNED_IPS: + with patch('homeassistant.components.http.' + 'ban.get_real_ip', + return_value=ip_address(remote_addr)): + req = requests.get( + _url(const.URL_API)) + assert req.status_code == 403 + + def test_access_from_banned_ip_when_ban_is_off(self): + """Test accessing to server from banned IP when feature is off""" + hass.http.app[KEY_BANS_ENABLED] = False + for remote_addr in BANNED_IPS: + with patch('homeassistant.components.http.' + 'ban.get_real_ip', + return_value=ip_address(remote_addr)): + req = requests.get( + _url(const.URL_API), + headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD}) + assert req.status_code == 200 + + def test_ip_bans_file_creation(self): + """Testing if banned IP file created""" + hass.http.app[KEY_BANS_ENABLED] = True + hass.http.app[KEY_LOGIN_THRESHOLD] = 1 + + m = mock_open() + + def call_server(): + with patch('homeassistant.components.http.' + 'ban.get_real_ip', + return_value=ip_address("200.201.202.204")): + print("GETTING API") + return requests.get( + _url(const.URL_API), + headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'}) + + with patch('homeassistant.components.http.ban.open', m, create=True): + req = call_server() + assert req.status_code == 401 + assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + assert m.call_count == 0 + + req = call_server() + assert req.status_code == 401 + assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 + m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a') + + req = call_server() + assert req.status_code == 403 + assert m.call_count == 1 diff --git a/tests/components/http/test_init.py b/tests/components/http/test_init.py new file mode 100644 index 00000000000..a1e0532bc14 --- /dev/null +++ b/tests/components/http/test_init.py @@ -0,0 +1,111 @@ +"""The tests for the Home Assistant HTTP component.""" +import requests + +from homeassistant import bootstrap, const +import homeassistant.components.http as http + +from tests.common import get_test_instance_port, get_test_home_assistant + +API_PASSWORD = 'test1234' +SERVER_PORT = get_test_instance_port() +HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT) +HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE) +HA_HEADERS = { + const.HTTP_HEADER_HA_AUTH: API_PASSWORD, + const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON, +} +CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE] + +hass = None + + +def _url(path=''): + """Helper method to generate URLs.""" + return HTTP_BASE_URL + path + + +# pylint: disable=invalid-name +def setUpModule(): + """Initialize a Home Assistant server.""" + global hass + + hass = get_test_home_assistant() + + bootstrap.setup_component( + hass, http.DOMAIN, { + http.DOMAIN: { + http.CONF_API_PASSWORD: API_PASSWORD, + http.CONF_SERVER_PORT: SERVER_PORT, + http.CONF_CORS_ORIGINS: CORS_ORIGINS, + } + } + ) + + bootstrap.setup_component(hass, 'api') + + hass.start() + + +# pylint: disable=invalid-name +def tearDownModule(): + """Stop the Home Assistant server.""" + hass.stop() + + +class TestHttp: + """Test HTTP component.""" + + def test_cors_allowed_with_password_in_url(self): + """Test cross origin resource sharing with password in url.""" + req = requests.get(_url(const.URL_API), + params={'api_password': API_PASSWORD}, + headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL}) + + allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN + + assert req.status_code == 200 + assert req.headers.get(allow_origin) == HTTP_BASE_URL + + def test_cors_allowed_with_password_in_header(self): + """Test cross origin resource sharing with password in header.""" + headers = { + const.HTTP_HEADER_HA_AUTH: API_PASSWORD, + const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL + } + req = requests.get(_url(const.URL_API), headers=headers) + + allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN + + assert req.status_code == 200 + assert req.headers.get(allow_origin) == HTTP_BASE_URL + + def test_cors_denied_without_origin_header(self): + """Test cross origin resource sharing with password in header.""" + headers = { + const.HTTP_HEADER_HA_AUTH: API_PASSWORD + } + req = requests.get(_url(const.URL_API), headers=headers) + + allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN + allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS + + assert req.status_code == 200 + assert allow_origin not in req.headers + assert allow_headers not in req.headers + + def test_cors_preflight_allowed(self): + """Test cross origin resource sharing preflight (OPTIONS) request.""" + headers = { + const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL, + 'Access-Control-Request-Method': 'GET', + 'Access-Control-Request-Headers': 'x-ha-access' + } + req = requests.options(_url(const.URL_API), headers=headers) + + allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN + allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS + + assert req.status_code == 200 + assert req.headers.get(allow_origin) == HTTP_BASE_URL + assert req.headers.get(allow_headers) == \ + const.HTTP_HEADER_HA_AUTH.upper() diff --git a/tests/components/notify/test_html5.py b/tests/components/notify/test_html5.py index 82e43300db7..8d27a11e094 100644 --- a/tests/components/notify/test_html5.py +++ b/tests/components/notify/test_html5.py @@ -3,10 +3,10 @@ import asyncio import json from unittest.mock import patch, MagicMock, mock_open -from aiohttp import web - from homeassistant.components.notify import html5 +from tests.common import mock_http_component_app + SUBSCRIPTION_1 = { 'browser': 'chrome', 'subscription': { @@ -121,7 +121,8 @@ class TestHtml5Notify(object): assert view.json_path == hass.config.path.return_value assert view.registrations == {} - app = web.Application(loop=loop) + hass.loop = loop + app = mock_http_component_app(hass) view.register(app.router) client = yield from test_client(app) hass.http.is_banned_ip.return_value = False @@ -153,7 +154,8 @@ class TestHtml5Notify(object): view = hass.mock_calls[1][1][0] - app = web.Application(loop=loop) + hass.loop = loop + app = mock_http_component_app(hass) view.register(app.router) client = yield from test_client(app) hass.http.is_banned_ip.return_value = False @@ -208,7 +210,8 @@ class TestHtml5Notify(object): assert view.json_path == hass.config.path.return_value assert view.registrations == config - app = web.Application(loop=loop) + hass.loop = loop + app = mock_http_component_app(hass) view.register(app.router) client = yield from test_client(app) hass.http.is_banned_ip.return_value = False @@ -253,7 +256,8 @@ class TestHtml5Notify(object): assert view.json_path == hass.config.path.return_value assert view.registrations == config - app = web.Application(loop=loop) + hass.loop = loop + app = mock_http_component_app(hass) view.register(app.router) client = yield from test_client(app) hass.http.is_banned_ip.return_value = False @@ -296,7 +300,8 @@ class TestHtml5Notify(object): assert view.json_path == hass.config.path.return_value assert view.registrations == config - app = web.Application(loop=loop) + hass.loop = loop + app = mock_http_component_app(hass) view.register(app.router) client = yield from test_client(app) hass.http.is_banned_ip.return_value = False @@ -331,7 +336,8 @@ class TestHtml5Notify(object): view = hass.mock_calls[2][1][0] - app = web.Application(loop=loop) + hass.loop = loop + app = mock_http_component_app(hass) view.register(app.router) client = yield from test_client(app) hass.http.is_banned_ip.return_value = False @@ -387,7 +393,8 @@ class TestHtml5Notify(object): bearer_token = "Bearer {}".format(push_payload['data']['jwt']) - app = web.Application(loop=loop) + hass.loop = loop + app = mock_http_component_app(hass) view.register(app.router) client = yield from test_client(app) hass.http.is_banned_ip.return_value = False diff --git a/tests/components/test_frontend.py b/tests/components/test_frontend.py index 3ff366babd9..a56fac9ed5d 100644 --- a/tests/components/test_frontend.py +++ b/tests/components/test_frontend.py @@ -6,7 +6,7 @@ import unittest import requests import homeassistant.bootstrap as bootstrap -from homeassistant.components import frontend, http +from homeassistant.components import http from homeassistant.const import HTTP_HEADER_HA_AUTH from tests.common import get_test_instance_port, get_test_home_assistant @@ -45,7 +45,6 @@ def setUpModule(): def tearDownModule(): """Stop everything that was started.""" hass.stop() - frontend.PANELS = {} class TestFrontend(unittest.TestCase): diff --git a/tests/components/test_http.py b/tests/components/test_http.py deleted file mode 100644 index 83cda160ac1..00000000000 --- a/tests/components/test_http.py +++ /dev/null @@ -1,285 +0,0 @@ -"""The tests for the Home Assistant HTTP component.""" -# pylint: disable=protected-access -import logging -from ipaddress import ip_network -from unittest.mock import patch, mock_open - -import requests - -from homeassistant import bootstrap, const -import homeassistant.components.http as http - -from tests.common import get_test_instance_port, get_test_home_assistant - -API_PASSWORD = 'test1234' -SERVER_PORT = get_test_instance_port() -HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT) -HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE) -HA_HEADERS = { - const.HTTP_HEADER_HA_AUTH: API_PASSWORD, - const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON, -} -# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases -TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1', - 'FD01:DB8::1'] -TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1', - '2001:DB8:ABCD::1'] -UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1'] -BANNED_IPS = ['200.201.202.203', '100.64.0.1'] - -CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE] - -hass = None - - -def _url(path=''): - """Helper method to generate URLs.""" - return HTTP_BASE_URL + path - - -# pylint: disable=invalid-name -def setUpModule(): - """Initialize a Home Assistant server.""" - global hass - - hass = get_test_home_assistant() - - hass.bus.listen('test_event', lambda _: _) - hass.states.set('test.test', 'a_state') - - bootstrap.setup_component( - hass, http.DOMAIN, { - http.DOMAIN: { - http.CONF_API_PASSWORD: API_PASSWORD, - http.CONF_SERVER_PORT: SERVER_PORT, - http.CONF_CORS_ORIGINS: CORS_ORIGINS, - } - } - ) - - bootstrap.setup_component(hass, 'api') - - hass.http.trusted_networks = [ - ip_network(trusted_network) - for trusted_network in TRUSTED_NETWORKS] - - hass.http.ip_bans = [http.IpBan(banned_ip) - for banned_ip in BANNED_IPS] - - hass.start() - - -# pylint: disable=invalid-name -def tearDownModule(): - """Stop the Home Assistant server.""" - hass.stop() - - -class TestHttp: - """Test HTTP component.""" - - def test_access_denied_without_password(self): - """Test access without password.""" - req = requests.get(_url(const.URL_API)) - - assert req.status_code == 401 - - def test_access_denied_with_wrong_password_in_header(self): - """Test access with wrong password.""" - req = requests.get( - _url(const.URL_API), - headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'}) - - assert req.status_code == 401 - - def test_access_denied_with_x_forwarded_for(self, caplog): - """Test access denied through the X-Forwarded-For http header.""" - hass.http.use_x_forwarded_for = True - for remote_addr in UNTRUSTED_ADDRESSES: - req = requests.get(_url(const.URL_API), headers={ - const.HTTP_HEADER_X_FORWARDED_FOR: remote_addr}) - - assert req.status_code == 401, \ - "{} shouldn't be trusted".format(remote_addr) - - def test_access_denied_with_untrusted_ip(self, caplog): - """Test access with an untrusted ip address.""" - for remote_addr in UNTRUSTED_ADDRESSES: - with patch('homeassistant.components.http.' - 'HomeAssistantWSGI.get_real_ip', - return_value=remote_addr): - req = requests.get( - _url(const.URL_API), params={'api_password': ''}) - - assert req.status_code == 401, \ - "{} shouldn't be trusted".format(remote_addr) - - def test_access_with_password_in_header(self, caplog): - """Test access with password in URL.""" - # Hide logging from requests package that we use to test logging - caplog.set_level( - logging.WARNING, logger='requests.packages.urllib3.connectionpool') - - req = requests.get( - _url(const.URL_API), - headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD}) - - assert req.status_code == 200 - - logs = caplog.text - - # assert const.URL_API in logs - assert API_PASSWORD not in logs - - def test_access_denied_with_wrong_password_in_url(self): - """Test access with wrong password.""" - req = requests.get( - _url(const.URL_API), params={'api_password': 'wrongpassword'}) - - assert req.status_code == 401 - - def test_access_with_password_in_url(self, caplog): - """Test access with password in URL.""" - # Hide logging from requests package that we use to test logging - caplog.set_level( - logging.WARNING, logger='requests.packages.urllib3.connectionpool') - - req = requests.get( - _url(const.URL_API), params={'api_password': API_PASSWORD}) - - assert req.status_code == 200 - - logs = caplog.text - - # assert const.URL_API in logs - assert API_PASSWORD not in logs - - def test_access_granted_with_x_forwarded_for(self, caplog): - """Test access denied through the X-Forwarded-For http header.""" - hass.http.use_x_forwarded_for = True - for remote_addr in TRUSTED_ADDRESSES: - req = requests.get(_url(const.URL_API), headers={ - const.HTTP_HEADER_X_FORWARDED_FOR: remote_addr}) - - assert req.status_code == 200, \ - "{} should be trusted".format(remote_addr) - - def test_access_granted_with_trusted_ip(self, caplog): - """Test access with trusted addresses.""" - for remote_addr in TRUSTED_ADDRESSES: - with patch('homeassistant.components.http.' - 'HomeAssistantWSGI.get_real_ip', - return_value=remote_addr): - req = requests.get( - _url(const.URL_API), params={'api_password': ''}) - - assert req.status_code == 200, \ - '{} should be trusted'.format(remote_addr) - - def test_cors_allowed_with_password_in_url(self): - """Test cross origin resource sharing with password in url.""" - req = requests.get(_url(const.URL_API), - params={'api_password': API_PASSWORD}, - headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL}) - - allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN - - assert req.status_code == 200 - assert req.headers.get(allow_origin) == HTTP_BASE_URL - - def test_cors_allowed_with_password_in_header(self): - """Test cross origin resource sharing with password in header.""" - headers = { - const.HTTP_HEADER_HA_AUTH: API_PASSWORD, - const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL - } - req = requests.get(_url(const.URL_API), headers=headers) - - allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN - - assert req.status_code == 200 - assert req.headers.get(allow_origin) == HTTP_BASE_URL - - def test_cors_denied_without_origin_header(self): - """Test cross origin resource sharing with password in header.""" - headers = { - const.HTTP_HEADER_HA_AUTH: API_PASSWORD - } - req = requests.get(_url(const.URL_API), headers=headers) - - allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN - allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS - - assert req.status_code == 200 - assert allow_origin not in req.headers - assert allow_headers not in req.headers - - def test_cors_preflight_allowed(self): - """Test cross origin resource sharing preflight (OPTIONS) request.""" - headers = { - const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL, - 'Access-Control-Request-Method': 'GET', - 'Access-Control-Request-Headers': 'x-ha-access' - } - req = requests.options(_url(const.URL_API), headers=headers) - - allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN - allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS - - assert req.status_code == 200 - assert req.headers.get(allow_origin) == HTTP_BASE_URL - assert req.headers.get(allow_headers) == \ - const.HTTP_HEADER_HA_AUTH.upper() - - def test_access_from_banned_ip(self): - """Test accessing to server from banned IP. Both trusted and not.""" - hass.http.is_ban_enabled = True - for remote_addr in BANNED_IPS: - with patch('homeassistant.components.http.' - 'HomeAssistantWSGI.get_real_ip', - return_value=remote_addr): - req = requests.get( - _url(const.URL_API)) - assert req.status_code == 403 - - def test_access_from_banned_ip_when_ban_is_off(self): - """Test accessing to server from banned IP when feature is off""" - hass.http.is_ban_enabled = False - for remote_addr in BANNED_IPS: - with patch('homeassistant.components.http.' - 'HomeAssistantWSGI.get_real_ip', - return_value=remote_addr): - req = requests.get( - _url(const.URL_API), - headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD}) - assert req.status_code == 200 - - def test_ip_bans_file_creation(self): - """Testing if banned IP file created""" - hass.http.is_ban_enabled = True - hass.http.login_threshold = 1 - - m = mock_open() - - def call_server(): - with patch('homeassistant.components.http.' - 'HomeAssistantWSGI.get_real_ip', - return_value="200.201.202.204"): - return requests.get( - _url(const.URL_API), - headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'}) - - with patch('homeassistant.components.http.open', m, create=True): - req = call_server() - assert req.status_code == 401 - assert len(hass.http.ip_bans) == len(BANNED_IPS) - assert m.call_count == 0 - - req = call_server() - assert req.status_code == 401 - assert len(hass.http.ip_bans) == len(BANNED_IPS) + 1 - m.assert_called_once_with(hass.config.path(http.IP_BANS), 'a') - - req = call_server() - assert req.status_code == 403 - assert m.call_count == 1 diff --git a/tests/scripts/test_check_config.py b/tests/scripts/test_check_config.py index e709d4693c7..b4994c5f136 100644 --- a/tests/scripts/test_check_config.py +++ b/tests/scripts/test_check_config.py @@ -165,7 +165,15 @@ class TestCheckConfig(unittest.TestCase): self.assertDictEqual({ 'components': {'http': {'api_password': 'abc123', + 'cors_allowed_origins': [], + 'development': '0', + 'ip_ban_enabled': True, + 'login_attempts_threshold': -1, + 'server_host': '0.0.0.0', 'server_port': 8123, + 'ssl_certificate': None, + 'ssl_key': None, + 'trusted_networks': [], 'use_x_forwarded_for': False}}, 'except': {}, 'secret_cache': {secrets_path: {'http_pw': 'abc123'}},