From 8ceef728535005cb424a44b65a41b722d5b1ba8a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 28 Jan 2020 10:54:39 -0800 Subject: [PATCH] Google Assistant: Track if request is local (#31226) * Track if request is local * Cancel early if 2FA disabled * Allow disabling 2FA for ack * Do not mark devices with 2FA as reachable * Add request source to GA events --- homeassistant/components/cloud/client.py | 4 +- .../components/google_assistant/const.py | 3 + .../components/google_assistant/helpers.py | 13 ++++- .../components/google_assistant/http.py | 7 ++- .../components/google_assistant/smart_home.py | 21 +++++-- .../components/google_assistant/trait.py | 4 +- tests/components/google_assistant/__init__.py | 1 + .../google_assistant/test_smart_home.py | 57 ++++++++++++++++--- .../components/google_assistant/test_trait.py | 8 ++- 9 files changed, 99 insertions(+), 19 deletions(-) diff --git a/homeassistant/components/cloud/client.py b/homeassistant/components/cloud/client.py index 24947ed7952..ef73d4356d5 100644 --- a/homeassistant/components/cloud/client.py +++ b/homeassistant/components/cloud/client.py @@ -11,7 +11,7 @@ from homeassistant.components.alexa import ( errors as alexa_errors, smart_home as alexa_sh, ) -from homeassistant.components.google_assistant import smart_home as ga +from homeassistant.components.google_assistant import const as gc, smart_home as ga from homeassistant.core import Context, callback from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.typing import HomeAssistantType @@ -160,7 +160,7 @@ class CloudClient(Interface): gconf = await self.get_google_config() return await ga.async_handle_message( - self._hass, gconf, gconf.cloud_user, payload + self._hass, gconf, gconf.cloud_user, payload, gc.SOURCE_CLOUD ) async def async_webhook_message(self, payload: Dict[Any, Any]) -> Dict[Any, Any]: diff --git a/homeassistant/components/google_assistant/const.py b/homeassistant/components/google_assistant/const.py index dcb87d1d93d..add625d2de4 100644 --- a/homeassistant/components/google_assistant/const.py +++ b/homeassistant/components/google_assistant/const.py @@ -143,3 +143,6 @@ CHALLENGE_PIN_NEEDED = "pinNeeded" CHALLENGE_FAILED_PIN_NEEDED = "challengeFailedPinNeeded" STORE_AGENT_USER_IDS = "agent_user_ids" + +SOURCE_CLOUD = "cloud" +SOURCE_LOCAL = "local" diff --git a/homeassistant/components/google_assistant/helpers.py b/homeassistant/components/google_assistant/helpers.py index 6493d759880..8444ba11c61 100644 --- a/homeassistant/components/google_assistant/helpers.py +++ b/homeassistant/components/google_assistant/helpers.py @@ -28,6 +28,7 @@ from .const import ( DOMAIN, DOMAIN_TO_GOOGLE_TYPES, ERR_FUNCTION_NOT_SUPPORTED, + SOURCE_LOCAL, STORE_AGENT_USER_IDS, ) from .error import SmartHomeError @@ -232,7 +233,7 @@ class AbstractConfig(ABC): return json_response(smart_home.turned_off_response(payload)) result = await smart_home.async_handle_message( - self.hass, self, self.local_sdk_user_id, payload + self.hass, self, self.local_sdk_user_id, payload, SOURCE_LOCAL ) if _LOGGER.isEnabledFor(logging.DEBUG): @@ -286,15 +287,22 @@ class RequestData: self, config: AbstractConfig, user_id: str, + source: str, request_id: str, devices: Optional[List[dict]], ): """Initialize the request data.""" self.config = config + self.source = source self.request_id = request_id self.context = Context(user_id=user_id) self.devices = devices + @property + def is_local_request(self): + """Return if this is a local request.""" + return self.source == SOURCE_LOCAL + def get_google_type(domain, device_class): """Google type based on domain and device class.""" @@ -354,6 +362,9 @@ class GoogleEntity: features = state.attributes.get(ATTR_SUPPORTED_FEATURES, 0) device_class = state.attributes.get(ATTR_DEVICE_CLASS) + if not self.config.should_2fa(state): + return False + return any( trait.might_2fa(domain, features, device_class) for trait in self.traits() ) diff --git a/homeassistant/components/google_assistant/http.py b/homeassistant/components/google_assistant/http.py index f8fa51da8d7..7bd3583e5c2 100644 --- a/homeassistant/components/google_assistant/http.py +++ b/homeassistant/components/google_assistant/http.py @@ -30,6 +30,7 @@ from .const import ( HOMEGRAPH_TOKEN_URL, REPORT_STATE_BASE_URL, REQUEST_SYNC_BASE_URL, + SOURCE_CLOUD, ) from .helpers import AbstractConfig from .smart_home import async_handle_message @@ -238,6 +239,10 @@ class GoogleAssistantView(HomeAssistantView): """Handle Google Assistant requests.""" message: dict = await request.json() result = await async_handle_message( - request.app["hass"], self.config, request["hass_user"].id, message + request.app["hass"], + self.config, + request["hass_user"].id, + message, + SOURCE_CLOUD, ) return self.json(result) diff --git a/homeassistant/components/google_assistant/smart_home.py b/homeassistant/components/google_assistant/smart_home.py index 8033bcec865..bf6c32505aa 100644 --- a/homeassistant/components/google_assistant/smart_home.py +++ b/homeassistant/components/google_assistant/smart_home.py @@ -21,9 +21,11 @@ HANDLERS = Registry() _LOGGER = logging.getLogger(__name__) -async def async_handle_message(hass, config, user_id, message): +async def async_handle_message(hass, config, user_id, message, source): """Handle incoming API messages.""" - data = RequestData(config, user_id, message["requestId"], message.get("devices")) + data = RequestData( + config, user_id, source, message["requestId"], message.get("devices") + ) response = await _process(hass, data, message) @@ -75,7 +77,9 @@ async def async_devices_sync(hass, data, payload): https://developers.google.com/assistant/smarthome/develop/process-intents#SYNC """ hass.bus.async_fire( - EVENT_SYNC_RECEIVED, {"request_id": data.request_id}, context=data.context + EVENT_SYNC_RECEIVED, + {"request_id": data.request_id, "source": data.source}, + context=data.context, ) agent_user_id = data.config.get_agent_user_id(data.context) @@ -108,7 +112,11 @@ async def async_devices_query(hass, data, payload): hass.bus.async_fire( EVENT_QUERY_RECEIVED, - {"request_id": data.request_id, ATTR_ENTITY_ID: devid}, + { + "request_id": data.request_id, + ATTR_ENTITY_ID: devid, + "source": data.source, + }, context=data.context, ) @@ -142,6 +150,7 @@ async def handle_devices_execute(hass, data, payload): "request_id": data.request_id, ATTR_ENTITY_ID: entity_id, "execution": execution, + "source": data.source, }, context=data.context, ) @@ -234,7 +243,9 @@ async def async_devices_reachable(hass, data: RequestData, payload): "devices": [ entity.reachable_device_serialize() for entity in async_get_entities(hass, data.config) - if entity.entity_id in google_ids and entity.should_expose() + if entity.entity_id in google_ids + and entity.should_expose() + and not entity.might_2fa() ] } diff --git a/homeassistant/components/google_assistant/trait.py b/homeassistant/components/google_assistant/trait.py index 14839066ebe..b4585ebde03 100644 --- a/homeassistant/components/google_assistant/trait.py +++ b/homeassistant/components/google_assistant/trait.py @@ -1447,6 +1447,8 @@ def _verify_pin_challenge(data, state, challenge): def _verify_ack_challenge(data, state, challenge): - """Verify a pin challenge.""" + """Verify an ack challenge.""" + if not data.config.should_2fa(state): + return if not challenge or not challenge.get("ack"): raise ChallengeNeeded(CHALLENGE_ACK_NEEDED) diff --git a/tests/components/google_assistant/__init__.py b/tests/components/google_assistant/__init__.py index edb12f06f33..9ef0599d394 100644 --- a/tests/components/google_assistant/__init__.py +++ b/tests/components/google_assistant/__init__.py @@ -22,6 +22,7 @@ class MockConfig(helpers.AbstractConfig): *, secure_devices_pin=None, should_expose=None, + should_2fa=None, entity_config=None, hass=None, local_sdk_webhook_id=None, diff --git a/tests/components/google_assistant/test_smart_home.py b/tests/components/google_assistant/test_smart_home.py index 7ffe9cda477..b3467eae326 100644 --- a/tests/components/google_assistant/test_smart_home.py +++ b/tests/components/google_assistant/test_smart_home.py @@ -82,6 +82,7 @@ async def test_sync_message(hass): config, "test-agent", {"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]}, + const.SOURCE_CLOUD, ) assert result == { @@ -115,7 +116,7 @@ async def test_sync_message(hass): assert len(events) == 1 assert events[0].event_type == EVENT_SYNC_RECEIVED - assert events[0].data == {"request_id": REQ_ID} + assert events[0].data == {"request_id": REQ_ID, "source": "cloud"} # pylint: disable=redefined-outer-name @@ -148,6 +149,7 @@ async def test_sync_in_area(hass, registries): config, "test-agent", {"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]}, + const.SOURCE_CLOUD, ) assert result == { @@ -181,7 +183,7 @@ async def test_sync_in_area(hass, registries): assert len(events) == 1 assert events[0].event_type == EVENT_SYNC_RECEIVED - assert events[0].data == {"request_id": REQ_ID} + assert events[0].data == {"request_id": REQ_ID, "source": "cloud"} async def test_query_message(hass): @@ -220,6 +222,7 @@ async def test_query_message(hass): } ], }, + const.SOURCE_CLOUD, ) assert result == { @@ -247,11 +250,23 @@ async def test_query_message(hass): assert len(events) == 3 assert events[0].event_type == EVENT_QUERY_RECEIVED - assert events[0].data == {"request_id": REQ_ID, "entity_id": "light.demo_light"} + assert events[0].data == { + "request_id": REQ_ID, + "entity_id": "light.demo_light", + "source": "cloud", + } assert events[1].event_type == EVENT_QUERY_RECEIVED - assert events[1].data == {"request_id": REQ_ID, "entity_id": "light.another_light"} + assert events[1].data == { + "request_id": REQ_ID, + "entity_id": "light.another_light", + "source": "cloud", + } assert events[2].event_type == EVENT_QUERY_RECEIVED - assert events[2].data == {"request_id": REQ_ID, "entity_id": "light.non_existing"} + assert events[2].data == { + "request_id": REQ_ID, + "entity_id": "light.non_existing", + "source": "cloud", + } async def test_execute(hass): @@ -300,6 +315,7 @@ async def test_execute(hass): } ], }, + const.SOURCE_CLOUD, ) assert result == { @@ -341,6 +357,7 @@ async def test_execute(hass): "command": "action.devices.commands.OnOff", "params": {"on": True}, }, + "source": "cloud", } assert events[1].event_type == EVENT_COMMAND_RECEIVED assert events[1].data == { @@ -350,6 +367,7 @@ async def test_execute(hass): "command": "action.devices.commands.BrightnessAbsolute", "params": {"brightness": 20}, }, + "source": "cloud", } assert events[2].event_type == EVENT_COMMAND_RECEIVED assert events[2].data == { @@ -359,6 +377,7 @@ async def test_execute(hass): "command": "action.devices.commands.OnOff", "params": {"on": True}, }, + "source": "cloud", } assert events[3].event_type == EVENT_COMMAND_RECEIVED assert events[3].data == { @@ -368,6 +387,7 @@ async def test_execute(hass): "command": "action.devices.commands.BrightnessAbsolute", "params": {"brightness": 20}, }, + "source": "cloud", } assert len(service_events) == 2 @@ -424,6 +444,7 @@ async def test_raising_error_trait(hass): } ], }, + const.SOURCE_CLOUD, ) assert result == { @@ -448,6 +469,7 @@ async def test_raising_error_trait(hass): "command": "action.devices.commands.ThermostatTemperatureSetpoint", "params": {"thermostatTemperatureSetpoint": 10}, }, + "source": "cloud", } @@ -483,6 +505,7 @@ async def test_unavailable_state_does_sync(hass): BASIC_CONFIG, "test-agent", {"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]}, + const.SOURCE_CLOUD, ) assert result == { @@ -515,7 +538,7 @@ async def test_unavailable_state_does_sync(hass): assert len(events) == 1 assert events[0].event_type == EVENT_SYNC_RECEIVED - assert events[0].data == {"request_id": REQ_ID} + assert events[0].data == {"request_id": REQ_ID, "source": "cloud"} @pytest.mark.parametrize( @@ -545,6 +568,7 @@ async def test_device_class_switch(hass, device_class, google_type): BASIC_CONFIG, "test-agent", {"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]}, + const.SOURCE_CLOUD, ) assert result == { @@ -589,6 +613,7 @@ async def test_device_class_binary_sensor(hass, device_class, google_type): BASIC_CONFIG, "test-agent", {"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]}, + const.SOURCE_CLOUD, ) assert result == { @@ -629,6 +654,7 @@ async def test_device_class_cover(hass, device_class, google_type): BASIC_CONFIG, "test-agent", {"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]}, + const.SOURCE_CLOUD, ) assert result == { @@ -669,6 +695,7 @@ async def test_device_media_player(hass, device_class, google_type): BASIC_CONFIG, "test-agent", {"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]}, + const.SOURCE_CLOUD, ) assert result == { @@ -702,6 +729,7 @@ async def test_query_disconnect(hass): config, "test-agent", {"inputs": [{"intent": "action.devices.DISCONNECT"}], "requestId": REQ_ID}, + const.SOURCE_CLOUD, ) assert result is None assert len(mock_disconnect.mock_calls) == 1 @@ -751,6 +779,7 @@ async def test_trait_execute_adding_query_data(hass): } ], }, + const.SOURCE_CLOUD, ) assert result == { @@ -817,6 +846,7 @@ async def test_identify(hass): } ], }, + const.SOURCE_CLOUD, ) assert result == { @@ -851,8 +881,11 @@ async def test_reachable_devices(hass): # Not passed in as google_id hass.states.async_set("light.not_mentioned", "on") + # Has 2FA + hass.states.async_set("lock.has_2fa", "on") + config = MockConfig( - should_expose=lambda state: state.entity_id != "light.not_expose" + should_expose=lambda state: state.entity_id != "light.not_expose", ) user_agent_id = "mock-user-id" @@ -898,9 +931,19 @@ async def test_reachable_devices(hass): "webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219", }, }, + { + "id": "lock.has_2fa", + "customData": { + "httpPort": 8123, + "httpSSL": False, + "proxyDeviceId": proxy_device_id, + "webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219", + }, + }, {"id": proxy_device_id, "customData": {}}, ], }, + const.SOURCE_CLOUD, ) assert result == { diff --git a/tests/components/google_assistant/test_trait.py b/tests/components/google_assistant/test_trait.py index f59d4006d29..232da039ea7 100644 --- a/tests/components/google_assistant/test_trait.py +++ b/tests/components/google_assistant/test_trait.py @@ -51,11 +51,15 @@ _LOGGER = logging.getLogger(__name__) REQ_ID = "ff36a3cc-ec34-11e6-b1a0-64510650abcf" -BASIC_DATA = helpers.RequestData(BASIC_CONFIG, "test-agent", REQ_ID, None) +BASIC_DATA = helpers.RequestData( + BASIC_CONFIG, "test-agent", const.SOURCE_CLOUD, REQ_ID, None +) PIN_CONFIG = MockConfig(secure_devices_pin="1234") -PIN_DATA = helpers.RequestData(PIN_CONFIG, "test-agent", REQ_ID, None) +PIN_DATA = helpers.RequestData( + PIN_CONFIG, "test-agent", const.SOURCE_CLOUD, REQ_ID, None +) async def test_brightness_light(hass):