Google Assistant: Track if request is local (#31226)

* Track if request is local

* Cancel early if 2FA disabled

* Allow disabling 2FA for ack

* Do not mark devices with 2FA as reachable

* Add request source to GA events
This commit is contained in:
Paulus Schoutsen 2020-01-28 10:54:39 -08:00 committed by GitHub
parent 03954be12d
commit 8ceef72853
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 99 additions and 19 deletions

View File

@ -11,7 +11,7 @@ from homeassistant.components.alexa import (
errors as alexa_errors,
smart_home as alexa_sh,
)
from homeassistant.components.google_assistant import smart_home as ga
from homeassistant.components.google_assistant import const as gc, smart_home as ga
from homeassistant.core import Context, callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import HomeAssistantType
@ -160,7 +160,7 @@ class CloudClient(Interface):
gconf = await self.get_google_config()
return await ga.async_handle_message(
self._hass, gconf, gconf.cloud_user, payload
self._hass, gconf, gconf.cloud_user, payload, gc.SOURCE_CLOUD
)
async def async_webhook_message(self, payload: Dict[Any, Any]) -> Dict[Any, Any]:

View File

@ -143,3 +143,6 @@ CHALLENGE_PIN_NEEDED = "pinNeeded"
CHALLENGE_FAILED_PIN_NEEDED = "challengeFailedPinNeeded"
STORE_AGENT_USER_IDS = "agent_user_ids"
SOURCE_CLOUD = "cloud"
SOURCE_LOCAL = "local"

View File

@ -28,6 +28,7 @@ from .const import (
DOMAIN,
DOMAIN_TO_GOOGLE_TYPES,
ERR_FUNCTION_NOT_SUPPORTED,
SOURCE_LOCAL,
STORE_AGENT_USER_IDS,
)
from .error import SmartHomeError
@ -232,7 +233,7 @@ class AbstractConfig(ABC):
return json_response(smart_home.turned_off_response(payload))
result = await smart_home.async_handle_message(
self.hass, self, self.local_sdk_user_id, payload
self.hass, self, self.local_sdk_user_id, payload, SOURCE_LOCAL
)
if _LOGGER.isEnabledFor(logging.DEBUG):
@ -286,15 +287,22 @@ class RequestData:
self,
config: AbstractConfig,
user_id: str,
source: str,
request_id: str,
devices: Optional[List[dict]],
):
"""Initialize the request data."""
self.config = config
self.source = source
self.request_id = request_id
self.context = Context(user_id=user_id)
self.devices = devices
@property
def is_local_request(self):
"""Return if this is a local request."""
return self.source == SOURCE_LOCAL
def get_google_type(domain, device_class):
"""Google type based on domain and device class."""
@ -354,6 +362,9 @@ class GoogleEntity:
features = state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
device_class = state.attributes.get(ATTR_DEVICE_CLASS)
if not self.config.should_2fa(state):
return False
return any(
trait.might_2fa(domain, features, device_class) for trait in self.traits()
)

View File

@ -30,6 +30,7 @@ from .const import (
HOMEGRAPH_TOKEN_URL,
REPORT_STATE_BASE_URL,
REQUEST_SYNC_BASE_URL,
SOURCE_CLOUD,
)
from .helpers import AbstractConfig
from .smart_home import async_handle_message
@ -238,6 +239,10 @@ class GoogleAssistantView(HomeAssistantView):
"""Handle Google Assistant requests."""
message: dict = await request.json()
result = await async_handle_message(
request.app["hass"], self.config, request["hass_user"].id, message
request.app["hass"],
self.config,
request["hass_user"].id,
message,
SOURCE_CLOUD,
)
return self.json(result)

View File

@ -21,9 +21,11 @@ HANDLERS = Registry()
_LOGGER = logging.getLogger(__name__)
async def async_handle_message(hass, config, user_id, message):
async def async_handle_message(hass, config, user_id, message, source):
"""Handle incoming API messages."""
data = RequestData(config, user_id, message["requestId"], message.get("devices"))
data = RequestData(
config, user_id, source, message["requestId"], message.get("devices")
)
response = await _process(hass, data, message)
@ -75,7 +77,9 @@ async def async_devices_sync(hass, data, payload):
https://developers.google.com/assistant/smarthome/develop/process-intents#SYNC
"""
hass.bus.async_fire(
EVENT_SYNC_RECEIVED, {"request_id": data.request_id}, context=data.context
EVENT_SYNC_RECEIVED,
{"request_id": data.request_id, "source": data.source},
context=data.context,
)
agent_user_id = data.config.get_agent_user_id(data.context)
@ -108,7 +112,11 @@ async def async_devices_query(hass, data, payload):
hass.bus.async_fire(
EVENT_QUERY_RECEIVED,
{"request_id": data.request_id, ATTR_ENTITY_ID: devid},
{
"request_id": data.request_id,
ATTR_ENTITY_ID: devid,
"source": data.source,
},
context=data.context,
)
@ -142,6 +150,7 @@ async def handle_devices_execute(hass, data, payload):
"request_id": data.request_id,
ATTR_ENTITY_ID: entity_id,
"execution": execution,
"source": data.source,
},
context=data.context,
)
@ -234,7 +243,9 @@ async def async_devices_reachable(hass, data: RequestData, payload):
"devices": [
entity.reachable_device_serialize()
for entity in async_get_entities(hass, data.config)
if entity.entity_id in google_ids and entity.should_expose()
if entity.entity_id in google_ids
and entity.should_expose()
and not entity.might_2fa()
]
}

View File

@ -1447,6 +1447,8 @@ def _verify_pin_challenge(data, state, challenge):
def _verify_ack_challenge(data, state, challenge):
"""Verify a pin challenge."""
"""Verify an ack challenge."""
if not data.config.should_2fa(state):
return
if not challenge or not challenge.get("ack"):
raise ChallengeNeeded(CHALLENGE_ACK_NEEDED)

View File

@ -22,6 +22,7 @@ class MockConfig(helpers.AbstractConfig):
*,
secure_devices_pin=None,
should_expose=None,
should_2fa=None,
entity_config=None,
hass=None,
local_sdk_webhook_id=None,

View File

@ -82,6 +82,7 @@ async def test_sync_message(hass):
config,
"test-agent",
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
const.SOURCE_CLOUD,
)
assert result == {
@ -115,7 +116,7 @@ async def test_sync_message(hass):
assert len(events) == 1
assert events[0].event_type == EVENT_SYNC_RECEIVED
assert events[0].data == {"request_id": REQ_ID}
assert events[0].data == {"request_id": REQ_ID, "source": "cloud"}
# pylint: disable=redefined-outer-name
@ -148,6 +149,7 @@ async def test_sync_in_area(hass, registries):
config,
"test-agent",
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
const.SOURCE_CLOUD,
)
assert result == {
@ -181,7 +183,7 @@ async def test_sync_in_area(hass, registries):
assert len(events) == 1
assert events[0].event_type == EVENT_SYNC_RECEIVED
assert events[0].data == {"request_id": REQ_ID}
assert events[0].data == {"request_id": REQ_ID, "source": "cloud"}
async def test_query_message(hass):
@ -220,6 +222,7 @@ async def test_query_message(hass):
}
],
},
const.SOURCE_CLOUD,
)
assert result == {
@ -247,11 +250,23 @@ async def test_query_message(hass):
assert len(events) == 3
assert events[0].event_type == EVENT_QUERY_RECEIVED
assert events[0].data == {"request_id": REQ_ID, "entity_id": "light.demo_light"}
assert events[0].data == {
"request_id": REQ_ID,
"entity_id": "light.demo_light",
"source": "cloud",
}
assert events[1].event_type == EVENT_QUERY_RECEIVED
assert events[1].data == {"request_id": REQ_ID, "entity_id": "light.another_light"}
assert events[1].data == {
"request_id": REQ_ID,
"entity_id": "light.another_light",
"source": "cloud",
}
assert events[2].event_type == EVENT_QUERY_RECEIVED
assert events[2].data == {"request_id": REQ_ID, "entity_id": "light.non_existing"}
assert events[2].data == {
"request_id": REQ_ID,
"entity_id": "light.non_existing",
"source": "cloud",
}
async def test_execute(hass):
@ -300,6 +315,7 @@ async def test_execute(hass):
}
],
},
const.SOURCE_CLOUD,
)
assert result == {
@ -341,6 +357,7 @@ async def test_execute(hass):
"command": "action.devices.commands.OnOff",
"params": {"on": True},
},
"source": "cloud",
}
assert events[1].event_type == EVENT_COMMAND_RECEIVED
assert events[1].data == {
@ -350,6 +367,7 @@ async def test_execute(hass):
"command": "action.devices.commands.BrightnessAbsolute",
"params": {"brightness": 20},
},
"source": "cloud",
}
assert events[2].event_type == EVENT_COMMAND_RECEIVED
assert events[2].data == {
@ -359,6 +377,7 @@ async def test_execute(hass):
"command": "action.devices.commands.OnOff",
"params": {"on": True},
},
"source": "cloud",
}
assert events[3].event_type == EVENT_COMMAND_RECEIVED
assert events[3].data == {
@ -368,6 +387,7 @@ async def test_execute(hass):
"command": "action.devices.commands.BrightnessAbsolute",
"params": {"brightness": 20},
},
"source": "cloud",
}
assert len(service_events) == 2
@ -424,6 +444,7 @@ async def test_raising_error_trait(hass):
}
],
},
const.SOURCE_CLOUD,
)
assert result == {
@ -448,6 +469,7 @@ async def test_raising_error_trait(hass):
"command": "action.devices.commands.ThermostatTemperatureSetpoint",
"params": {"thermostatTemperatureSetpoint": 10},
},
"source": "cloud",
}
@ -483,6 +505,7 @@ async def test_unavailable_state_does_sync(hass):
BASIC_CONFIG,
"test-agent",
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
const.SOURCE_CLOUD,
)
assert result == {
@ -515,7 +538,7 @@ async def test_unavailable_state_does_sync(hass):
assert len(events) == 1
assert events[0].event_type == EVENT_SYNC_RECEIVED
assert events[0].data == {"request_id": REQ_ID}
assert events[0].data == {"request_id": REQ_ID, "source": "cloud"}
@pytest.mark.parametrize(
@ -545,6 +568,7 @@ async def test_device_class_switch(hass, device_class, google_type):
BASIC_CONFIG,
"test-agent",
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
const.SOURCE_CLOUD,
)
assert result == {
@ -589,6 +613,7 @@ async def test_device_class_binary_sensor(hass, device_class, google_type):
BASIC_CONFIG,
"test-agent",
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
const.SOURCE_CLOUD,
)
assert result == {
@ -629,6 +654,7 @@ async def test_device_class_cover(hass, device_class, google_type):
BASIC_CONFIG,
"test-agent",
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
const.SOURCE_CLOUD,
)
assert result == {
@ -669,6 +695,7 @@ async def test_device_media_player(hass, device_class, google_type):
BASIC_CONFIG,
"test-agent",
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
const.SOURCE_CLOUD,
)
assert result == {
@ -702,6 +729,7 @@ async def test_query_disconnect(hass):
config,
"test-agent",
{"inputs": [{"intent": "action.devices.DISCONNECT"}], "requestId": REQ_ID},
const.SOURCE_CLOUD,
)
assert result is None
assert len(mock_disconnect.mock_calls) == 1
@ -751,6 +779,7 @@ async def test_trait_execute_adding_query_data(hass):
}
],
},
const.SOURCE_CLOUD,
)
assert result == {
@ -817,6 +846,7 @@ async def test_identify(hass):
}
],
},
const.SOURCE_CLOUD,
)
assert result == {
@ -851,8 +881,11 @@ async def test_reachable_devices(hass):
# Not passed in as google_id
hass.states.async_set("light.not_mentioned", "on")
# Has 2FA
hass.states.async_set("lock.has_2fa", "on")
config = MockConfig(
should_expose=lambda state: state.entity_id != "light.not_expose"
should_expose=lambda state: state.entity_id != "light.not_expose",
)
user_agent_id = "mock-user-id"
@ -898,9 +931,19 @@ async def test_reachable_devices(hass):
"webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219",
},
},
{
"id": "lock.has_2fa",
"customData": {
"httpPort": 8123,
"httpSSL": False,
"proxyDeviceId": proxy_device_id,
"webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219",
},
},
{"id": proxy_device_id, "customData": {}},
],
},
const.SOURCE_CLOUD,
)
assert result == {

View File

@ -51,11 +51,15 @@ _LOGGER = logging.getLogger(__name__)
REQ_ID = "ff36a3cc-ec34-11e6-b1a0-64510650abcf"
BASIC_DATA = helpers.RequestData(BASIC_CONFIG, "test-agent", REQ_ID, None)
BASIC_DATA = helpers.RequestData(
BASIC_CONFIG, "test-agent", const.SOURCE_CLOUD, REQ_ID, None
)
PIN_CONFIG = MockConfig(secure_devices_pin="1234")
PIN_DATA = helpers.RequestData(PIN_CONFIG, "test-agent", REQ_ID, None)
PIN_DATA = helpers.RequestData(
PIN_CONFIG, "test-agent", const.SOURCE_CLOUD, REQ_ID, None
)
async def test_brightness_light(hass):