From 455c629f478bb966dd9a8ebae7f7562a4fc83fdf Mon Sep 17 00:00:00 2001 From: Christopher Viel Date: Fri, 5 Jan 2018 17:29:27 -0500 Subject: [PATCH] Don't duplicate html5 registrations (#11451) * Don't duplicate html5 registrations If a registration is posted and another registration with the same endpoint URL exists, update that one instead. That way, we preserve the device name that has been configured. The previous behavior used to append 'unnamed device' registrations over and over, leading to multiple copies of the same registration. The endpoint URL is unique per service worker so it is safe to update matching registrations. * Refactor html5 registration view to not write json in the event loop --- homeassistant/components/notify/html5.py | 33 +++- tests/components/notify/test_html5.py | 212 +++++++++++++++-------- 2 files changed, 172 insertions(+), 73 deletions(-) diff --git a/homeassistant/components/notify/html5.py b/homeassistant/components/notify/html5.py index 2314722a2ab..a979ab5fb2f 100644 --- a/homeassistant/components/notify/html5.py +++ b/homeassistant/components/notify/html5.py @@ -169,15 +169,35 @@ class HTML5PushRegistrationView(HomeAssistantView): return self.json_message( humanize_error(data, ex), HTTP_BAD_REQUEST) - name = ensure_unique_string('unnamed device', self.registrations) + name = self.find_registration_name(data) + previous_registration = self.registrations.get(name) self.registrations[name] = data - if not save_json(self.json_path, self.registrations): + try: + hass = request.app['hass'] + + yield from hass.async_add_job(save_json, self.json_path, + self.registrations) + return self.json_message( + 'Push notification subscriber registered.') + except HomeAssistantError: + if previous_registration is not None: + self.registrations[name] = previous_registration + else: + self.registrations.pop(name) + return self.json_message( 'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR) - return self.json_message('Push notification subscriber registered.') + def find_registration_name(self, data): + """Find a registration name matching data or generate a unique one.""" + endpoint = data.get(ATTR_SUBSCRIPTION).get(ATTR_ENDPOINT) + for key, registration in self.registrations.items(): + subscription = registration.get(ATTR_SUBSCRIPTION) + if subscription.get(ATTR_ENDPOINT) == endpoint: + return key + return ensure_unique_string('unnamed device', self.registrations) @asyncio.coroutine def delete(self, request): @@ -202,7 +222,12 @@ class HTML5PushRegistrationView(HomeAssistantView): reg = self.registrations.pop(found) - if not save_json(self.json_path, self.registrations): + try: + hass = request.app['hass'] + + yield from hass.async_add_job(save_json, self.json_path, + self.registrations) + except HomeAssistantError: self.registrations[found] = reg return self.json_message( 'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR) diff --git a/tests/components/notify/test_html5.py b/tests/components/notify/test_html5.py index c3998b6db64..6fb2e6454de 100644 --- a/tests/components/notify/test_html5.py +++ b/tests/components/notify/test_html5.py @@ -4,10 +4,14 @@ import json from unittest.mock import patch, MagicMock, mock_open from aiohttp.hdrs import AUTHORIZATION +from homeassistant.exceptions import HomeAssistantError +from homeassistant.util.json import save_json from homeassistant.components.notify import html5 from tests.common import mock_http_component_app +CONFIG_FILE = 'file.conf' + SUBSCRIPTION_1 = { 'browser': 'chrome', 'subscription': { @@ -108,36 +112,30 @@ class TestHtml5Notify(object): 'unnamed device': SUBSCRIPTION_1, } - m = mock_open() - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = 'file.conf' - service = html5.get_service(hass, {}) + hass.config.path.return_value = CONFIG_FILE + service = html5.get_service(hass, {}) - assert service is not None + assert service is not None - # assert hass.called - assert len(hass.mock_calls) == 3 + assert len(hass.mock_calls) == 3 - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == {} + view = hass.mock_calls[1][1][0] + assert view.json_path == hass.config.path.return_value + assert view.registrations == {} - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - resp = yield from client.post(REGISTER_URL, - data=json.dumps(SUBSCRIPTION_1)) + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_1)) - content = yield from resp.text() - assert resp.status == 200, content - assert view.registrations == expected - handle = m() - assert json.loads(handle.write.call_args[0][0]) == expected + content = yield from resp.text() + assert resp.status == 200, content + assert view.registrations == expected + + hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) @asyncio.coroutine def test_registering_new_device_expiration_view(self, loop, test_client): @@ -147,36 +145,114 @@ class TestHtml5Notify(object): 'unnamed device': SUBSCRIPTION_4, } - m = mock_open() - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = 'file.conf' - service = html5.get_service(hass, {}) + hass.config.path.return_value = CONFIG_FILE + service = html5.get_service(hass, {}) - assert service is not None + assert service is not None - # assert hass.called - assert len(hass.mock_calls) == 3 + # assert hass.called + assert len(hass.mock_calls) == 3 - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == {} + view = hass.mock_calls[1][1][0] + assert view.json_path == hass.config.path.return_value + assert view.registrations == {} - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - resp = yield from client.post(REGISTER_URL, - data=json.dumps(SUBSCRIPTION_4)) + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_4)) - content = yield from resp.text() - assert resp.status == 200, content - assert view.registrations == expected - handle = m() - assert json.loads(handle.write.call_args[0][0]) == expected + content = yield from resp.text() + assert resp.status == 200, content + assert view.registrations == expected + + hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) + + @asyncio.coroutine + def test_registering_new_device_fails_view(self, loop, test_client): + """Test subs. are not altered when registering a new device fails.""" + hass = MagicMock() + expected = {} + + hass.config.path.return_value = CONFIG_FILE + html5.get_service(hass, {}) + view = hass.mock_calls[1][1][0] + + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + + hass.async_add_job.side_effect = HomeAssistantError() + + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_1)) + + content = yield from resp.text() + assert resp.status == 500, content + assert view.registrations == expected + + @asyncio.coroutine + def test_registering_existing_device_view(self, loop, test_client): + """Test subscription is updated when registering existing device.""" + hass = MagicMock() + expected = { + 'unnamed device': SUBSCRIPTION_4, + } + + hass.config.path.return_value = CONFIG_FILE + html5.get_service(hass, {}) + view = hass.mock_calls[1][1][0] + + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + + yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_1)) + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_4)) + + content = yield from resp.text() + assert resp.status == 200, content + assert view.registrations == expected + + hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) + + @asyncio.coroutine + def test_registering_existing_device_fails_view(self, loop, test_client): + """Test sub. is not updated when registering existing device fails.""" + hass = MagicMock() + expected = { + 'unnamed device': SUBSCRIPTION_1, + } + + hass.config.path.return_value = CONFIG_FILE + html5.get_service(hass, {}) + view = hass.mock_calls[1][1][0] + + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + + yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_1)) + + hass.async_add_job.side_effect = HomeAssistantError() + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_4)) + + content = yield from resp.text() + assert resp.status == 500, content + assert view.registrations == expected @asyncio.coroutine def test_registering_new_device_validation(self, loop, test_client): @@ -188,7 +264,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -240,7 +316,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -266,8 +342,9 @@ class TestHtml5Notify(object): assert resp.status == 200, resp.response assert view.registrations == config - handle = m() - assert json.loads(handle.write.call_args[0][0]) == config + + hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, + config) @asyncio.coroutine def test_unregister_device_view_handle_unknown_subscription( @@ -285,7 +362,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -309,13 +386,13 @@ class TestHtml5Notify(object): assert resp.status == 200, resp.response assert view.registrations == config - handle = m() - assert handle.write.call_count == 0 + + hass.async_add_job.assert_not_called() @asyncio.coroutine - def test_unregistering_device_view_handles_json_safe_error( + def test_unregistering_device_view_handles_save_error( self, loop, test_client): - """Test that the HTML unregister view handles JSON write errors.""" + """Test that the HTML unregister view handles save errors.""" hass = MagicMock() config = { @@ -328,7 +405,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -346,16 +423,13 @@ class TestHtml5Notify(object): client = yield from test_client(app) hass.http.is_banned_ip.return_value = False - with patch('homeassistant.components.notify.html5.save_json', - return_value=False): - resp = yield from client.delete(REGISTER_URL, data=json.dumps({ - 'subscription': SUBSCRIPTION_1['subscription'], - })) + hass.async_add_job.side_effect = HomeAssistantError() + resp = yield from client.delete(REGISTER_URL, data=json.dumps({ + 'subscription': SUBSCRIPTION_1['subscription'], + })) assert resp.status == 500, resp.response assert view.registrations == config - handle = m() - assert handle.write.call_count == 0 @asyncio.coroutine def test_callback_view_no_jwt(self, loop, test_client): @@ -367,7 +441,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -404,7 +478,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {'gcm_sender_id': '100'}) assert service is not None