diff --git a/homeassistant/components/notify/html5.py b/homeassistant/components/notify/html5.py index 2314722a2ab..a979ab5fb2f 100644 --- a/homeassistant/components/notify/html5.py +++ b/homeassistant/components/notify/html5.py @@ -169,15 +169,35 @@ class HTML5PushRegistrationView(HomeAssistantView): return self.json_message( humanize_error(data, ex), HTTP_BAD_REQUEST) - name = ensure_unique_string('unnamed device', self.registrations) + name = self.find_registration_name(data) + previous_registration = self.registrations.get(name) self.registrations[name] = data - if not save_json(self.json_path, self.registrations): + try: + hass = request.app['hass'] + + yield from hass.async_add_job(save_json, self.json_path, + self.registrations) + return self.json_message( + 'Push notification subscriber registered.') + except HomeAssistantError: + if previous_registration is not None: + self.registrations[name] = previous_registration + else: + self.registrations.pop(name) + return self.json_message( 'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR) - return self.json_message('Push notification subscriber registered.') + def find_registration_name(self, data): + """Find a registration name matching data or generate a unique one.""" + endpoint = data.get(ATTR_SUBSCRIPTION).get(ATTR_ENDPOINT) + for key, registration in self.registrations.items(): + subscription = registration.get(ATTR_SUBSCRIPTION) + if subscription.get(ATTR_ENDPOINT) == endpoint: + return key + return ensure_unique_string('unnamed device', self.registrations) @asyncio.coroutine def delete(self, request): @@ -202,7 +222,12 @@ class HTML5PushRegistrationView(HomeAssistantView): reg = self.registrations.pop(found) - if not save_json(self.json_path, self.registrations): + try: + hass = request.app['hass'] + + yield from hass.async_add_job(save_json, self.json_path, + self.registrations) + except HomeAssistantError: self.registrations[found] = reg return self.json_message( 'Error saving registration.', HTTP_INTERNAL_SERVER_ERROR) diff --git a/tests/components/notify/test_html5.py b/tests/components/notify/test_html5.py index c3998b6db64..6fb2e6454de 100644 --- a/tests/components/notify/test_html5.py +++ b/tests/components/notify/test_html5.py @@ -4,10 +4,14 @@ import json from unittest.mock import patch, MagicMock, mock_open from aiohttp.hdrs import AUTHORIZATION +from homeassistant.exceptions import HomeAssistantError +from homeassistant.util.json import save_json from homeassistant.components.notify import html5 from tests.common import mock_http_component_app +CONFIG_FILE = 'file.conf' + SUBSCRIPTION_1 = { 'browser': 'chrome', 'subscription': { @@ -108,36 +112,30 @@ class TestHtml5Notify(object): 'unnamed device': SUBSCRIPTION_1, } - m = mock_open() - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = 'file.conf' - service = html5.get_service(hass, {}) + hass.config.path.return_value = CONFIG_FILE + service = html5.get_service(hass, {}) - assert service is not None + assert service is not None - # assert hass.called - assert len(hass.mock_calls) == 3 + assert len(hass.mock_calls) == 3 - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == {} + view = hass.mock_calls[1][1][0] + assert view.json_path == hass.config.path.return_value + assert view.registrations == {} - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - resp = yield from client.post(REGISTER_URL, - data=json.dumps(SUBSCRIPTION_1)) + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_1)) - content = yield from resp.text() - assert resp.status == 200, content - assert view.registrations == expected - handle = m() - assert json.loads(handle.write.call_args[0][0]) == expected + content = yield from resp.text() + assert resp.status == 200, content + assert view.registrations == expected + + hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) @asyncio.coroutine def test_registering_new_device_expiration_view(self, loop, test_client): @@ -147,36 +145,114 @@ class TestHtml5Notify(object): 'unnamed device': SUBSCRIPTION_4, } - m = mock_open() - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = 'file.conf' - service = html5.get_service(hass, {}) + hass.config.path.return_value = CONFIG_FILE + service = html5.get_service(hass, {}) - assert service is not None + assert service is not None - # assert hass.called - assert len(hass.mock_calls) == 3 + # assert hass.called + assert len(hass.mock_calls) == 3 - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == {} + view = hass.mock_calls[1][1][0] + assert view.json_path == hass.config.path.return_value + assert view.registrations == {} - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - resp = yield from client.post(REGISTER_URL, - data=json.dumps(SUBSCRIPTION_4)) + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_4)) - content = yield from resp.text() - assert resp.status == 200, content - assert view.registrations == expected - handle = m() - assert json.loads(handle.write.call_args[0][0]) == expected + content = yield from resp.text() + assert resp.status == 200, content + assert view.registrations == expected + + hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) + + @asyncio.coroutine + def test_registering_new_device_fails_view(self, loop, test_client): + """Test subs. are not altered when registering a new device fails.""" + hass = MagicMock() + expected = {} + + hass.config.path.return_value = CONFIG_FILE + html5.get_service(hass, {}) + view = hass.mock_calls[1][1][0] + + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + + hass.async_add_job.side_effect = HomeAssistantError() + + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_1)) + + content = yield from resp.text() + assert resp.status == 500, content + assert view.registrations == expected + + @asyncio.coroutine + def test_registering_existing_device_view(self, loop, test_client): + """Test subscription is updated when registering existing device.""" + hass = MagicMock() + expected = { + 'unnamed device': SUBSCRIPTION_4, + } + + hass.config.path.return_value = CONFIG_FILE + html5.get_service(hass, {}) + view = hass.mock_calls[1][1][0] + + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + + yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_1)) + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_4)) + + content = yield from resp.text() + assert resp.status == 200, content + assert view.registrations == expected + + hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) + + @asyncio.coroutine + def test_registering_existing_device_fails_view(self, loop, test_client): + """Test sub. is not updated when registering existing device fails.""" + hass = MagicMock() + expected = { + 'unnamed device': SUBSCRIPTION_1, + } + + hass.config.path.return_value = CONFIG_FILE + html5.get_service(hass, {}) + view = hass.mock_calls[1][1][0] + + hass.loop = loop + app = mock_http_component_app(hass) + view.register(app.router) + client = yield from test_client(app) + hass.http.is_banned_ip.return_value = False + + yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_1)) + + hass.async_add_job.side_effect = HomeAssistantError() + resp = yield from client.post(REGISTER_URL, + data=json.dumps(SUBSCRIPTION_4)) + + content = yield from resp.text() + assert resp.status == 500, content + assert view.registrations == expected @asyncio.coroutine def test_registering_new_device_validation(self, loop, test_client): @@ -188,7 +264,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -240,7 +316,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -266,8 +342,9 @@ class TestHtml5Notify(object): assert resp.status == 200, resp.response assert view.registrations == config - handle = m() - assert json.loads(handle.write.call_args[0][0]) == config + + hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, + config) @asyncio.coroutine def test_unregister_device_view_handle_unknown_subscription( @@ -285,7 +362,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -309,13 +386,13 @@ class TestHtml5Notify(object): assert resp.status == 200, resp.response assert view.registrations == config - handle = m() - assert handle.write.call_count == 0 + + hass.async_add_job.assert_not_called() @asyncio.coroutine - def test_unregistering_device_view_handles_json_safe_error( + def test_unregistering_device_view_handles_save_error( self, loop, test_client): - """Test that the HTML unregister view handles JSON write errors.""" + """Test that the HTML unregister view handles save errors.""" hass = MagicMock() config = { @@ -328,7 +405,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -346,16 +423,13 @@ class TestHtml5Notify(object): client = yield from test_client(app) hass.http.is_banned_ip.return_value = False - with patch('homeassistant.components.notify.html5.save_json', - return_value=False): - resp = yield from client.delete(REGISTER_URL, data=json.dumps({ - 'subscription': SUBSCRIPTION_1['subscription'], - })) + hass.async_add_job.side_effect = HomeAssistantError() + resp = yield from client.delete(REGISTER_URL, data=json.dumps({ + 'subscription': SUBSCRIPTION_1['subscription'], + })) assert resp.status == 500, resp.response assert view.registrations == config - handle = m() - assert handle.write.call_count == 0 @asyncio.coroutine def test_callback_view_no_jwt(self, loop, test_client): @@ -367,7 +441,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {}) assert service is not None @@ -404,7 +478,7 @@ class TestHtml5Notify(object): 'homeassistant.util.json.open', m, create=True ): - hass.config.path.return_value = 'file.conf' + hass.config.path.return_value = CONFIG_FILE service = html5.get_service(hass, {'gcm_sender_id': '100'}) assert service is not None