mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Google Assistant: Track if request is local (#31226)
* Track if request is local * Cancel early if 2FA disabled * Allow disabling 2FA for ack * Do not mark devices with 2FA as reachable * Add request source to GA events
This commit is contained in:
parent
03954be12d
commit
8ceef72853
@ -11,7 +11,7 @@ from homeassistant.components.alexa import (
|
||||
errors as alexa_errors,
|
||||
smart_home as alexa_sh,
|
||||
)
|
||||
from homeassistant.components.google_assistant import smart_home as ga
|
||||
from homeassistant.components.google_assistant import const as gc, smart_home as ga
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
@ -160,7 +160,7 @@ class CloudClient(Interface):
|
||||
gconf = await self.get_google_config()
|
||||
|
||||
return await ga.async_handle_message(
|
||||
self._hass, gconf, gconf.cloud_user, payload
|
||||
self._hass, gconf, gconf.cloud_user, payload, gc.SOURCE_CLOUD
|
||||
)
|
||||
|
||||
async def async_webhook_message(self, payload: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
|
@ -143,3 +143,6 @@ CHALLENGE_PIN_NEEDED = "pinNeeded"
|
||||
CHALLENGE_FAILED_PIN_NEEDED = "challengeFailedPinNeeded"
|
||||
|
||||
STORE_AGENT_USER_IDS = "agent_user_ids"
|
||||
|
||||
SOURCE_CLOUD = "cloud"
|
||||
SOURCE_LOCAL = "local"
|
||||
|
@ -28,6 +28,7 @@ from .const import (
|
||||
DOMAIN,
|
||||
DOMAIN_TO_GOOGLE_TYPES,
|
||||
ERR_FUNCTION_NOT_SUPPORTED,
|
||||
SOURCE_LOCAL,
|
||||
STORE_AGENT_USER_IDS,
|
||||
)
|
||||
from .error import SmartHomeError
|
||||
@ -232,7 +233,7 @@ class AbstractConfig(ABC):
|
||||
return json_response(smart_home.turned_off_response(payload))
|
||||
|
||||
result = await smart_home.async_handle_message(
|
||||
self.hass, self, self.local_sdk_user_id, payload
|
||||
self.hass, self, self.local_sdk_user_id, payload, SOURCE_LOCAL
|
||||
)
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
@ -286,15 +287,22 @@ class RequestData:
|
||||
self,
|
||||
config: AbstractConfig,
|
||||
user_id: str,
|
||||
source: str,
|
||||
request_id: str,
|
||||
devices: Optional[List[dict]],
|
||||
):
|
||||
"""Initialize the request data."""
|
||||
self.config = config
|
||||
self.source = source
|
||||
self.request_id = request_id
|
||||
self.context = Context(user_id=user_id)
|
||||
self.devices = devices
|
||||
|
||||
@property
|
||||
def is_local_request(self):
|
||||
"""Return if this is a local request."""
|
||||
return self.source == SOURCE_LOCAL
|
||||
|
||||
|
||||
def get_google_type(domain, device_class):
|
||||
"""Google type based on domain and device class."""
|
||||
@ -354,6 +362,9 @@ class GoogleEntity:
|
||||
features = state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
|
||||
device_class = state.attributes.get(ATTR_DEVICE_CLASS)
|
||||
|
||||
if not self.config.should_2fa(state):
|
||||
return False
|
||||
|
||||
return any(
|
||||
trait.might_2fa(domain, features, device_class) for trait in self.traits()
|
||||
)
|
||||
|
@ -30,6 +30,7 @@ from .const import (
|
||||
HOMEGRAPH_TOKEN_URL,
|
||||
REPORT_STATE_BASE_URL,
|
||||
REQUEST_SYNC_BASE_URL,
|
||||
SOURCE_CLOUD,
|
||||
)
|
||||
from .helpers import AbstractConfig
|
||||
from .smart_home import async_handle_message
|
||||
@ -238,6 +239,10 @@ class GoogleAssistantView(HomeAssistantView):
|
||||
"""Handle Google Assistant requests."""
|
||||
message: dict = await request.json()
|
||||
result = await async_handle_message(
|
||||
request.app["hass"], self.config, request["hass_user"].id, message
|
||||
request.app["hass"],
|
||||
self.config,
|
||||
request["hass_user"].id,
|
||||
message,
|
||||
SOURCE_CLOUD,
|
||||
)
|
||||
return self.json(result)
|
||||
|
@ -21,9 +21,11 @@ HANDLERS = Registry()
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_handle_message(hass, config, user_id, message):
|
||||
async def async_handle_message(hass, config, user_id, message, source):
|
||||
"""Handle incoming API messages."""
|
||||
data = RequestData(config, user_id, message["requestId"], message.get("devices"))
|
||||
data = RequestData(
|
||||
config, user_id, source, message["requestId"], message.get("devices")
|
||||
)
|
||||
|
||||
response = await _process(hass, data, message)
|
||||
|
||||
@ -75,7 +77,9 @@ async def async_devices_sync(hass, data, payload):
|
||||
https://developers.google.com/assistant/smarthome/develop/process-intents#SYNC
|
||||
"""
|
||||
hass.bus.async_fire(
|
||||
EVENT_SYNC_RECEIVED, {"request_id": data.request_id}, context=data.context
|
||||
EVENT_SYNC_RECEIVED,
|
||||
{"request_id": data.request_id, "source": data.source},
|
||||
context=data.context,
|
||||
)
|
||||
|
||||
agent_user_id = data.config.get_agent_user_id(data.context)
|
||||
@ -108,7 +112,11 @@ async def async_devices_query(hass, data, payload):
|
||||
|
||||
hass.bus.async_fire(
|
||||
EVENT_QUERY_RECEIVED,
|
||||
{"request_id": data.request_id, ATTR_ENTITY_ID: devid},
|
||||
{
|
||||
"request_id": data.request_id,
|
||||
ATTR_ENTITY_ID: devid,
|
||||
"source": data.source,
|
||||
},
|
||||
context=data.context,
|
||||
)
|
||||
|
||||
@ -142,6 +150,7 @@ async def handle_devices_execute(hass, data, payload):
|
||||
"request_id": data.request_id,
|
||||
ATTR_ENTITY_ID: entity_id,
|
||||
"execution": execution,
|
||||
"source": data.source,
|
||||
},
|
||||
context=data.context,
|
||||
)
|
||||
@ -234,7 +243,9 @@ async def async_devices_reachable(hass, data: RequestData, payload):
|
||||
"devices": [
|
||||
entity.reachable_device_serialize()
|
||||
for entity in async_get_entities(hass, data.config)
|
||||
if entity.entity_id in google_ids and entity.should_expose()
|
||||
if entity.entity_id in google_ids
|
||||
and entity.should_expose()
|
||||
and not entity.might_2fa()
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -1447,6 +1447,8 @@ def _verify_pin_challenge(data, state, challenge):
|
||||
|
||||
|
||||
def _verify_ack_challenge(data, state, challenge):
|
||||
"""Verify a pin challenge."""
|
||||
"""Verify an ack challenge."""
|
||||
if not data.config.should_2fa(state):
|
||||
return
|
||||
if not challenge or not challenge.get("ack"):
|
||||
raise ChallengeNeeded(CHALLENGE_ACK_NEEDED)
|
||||
|
@ -22,6 +22,7 @@ class MockConfig(helpers.AbstractConfig):
|
||||
*,
|
||||
secure_devices_pin=None,
|
||||
should_expose=None,
|
||||
should_2fa=None,
|
||||
entity_config=None,
|
||||
hass=None,
|
||||
local_sdk_webhook_id=None,
|
||||
|
@ -82,6 +82,7 @@ async def test_sync_message(hass):
|
||||
config,
|
||||
"test-agent",
|
||||
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -115,7 +116,7 @@ async def test_sync_message(hass):
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].event_type == EVENT_SYNC_RECEIVED
|
||||
assert events[0].data == {"request_id": REQ_ID}
|
||||
assert events[0].data == {"request_id": REQ_ID, "source": "cloud"}
|
||||
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
@ -148,6 +149,7 @@ async def test_sync_in_area(hass, registries):
|
||||
config,
|
||||
"test-agent",
|
||||
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -181,7 +183,7 @@ async def test_sync_in_area(hass, registries):
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].event_type == EVENT_SYNC_RECEIVED
|
||||
assert events[0].data == {"request_id": REQ_ID}
|
||||
assert events[0].data == {"request_id": REQ_ID, "source": "cloud"}
|
||||
|
||||
|
||||
async def test_query_message(hass):
|
||||
@ -220,6 +222,7 @@ async def test_query_message(hass):
|
||||
}
|
||||
],
|
||||
},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -247,11 +250,23 @@ async def test_query_message(hass):
|
||||
|
||||
assert len(events) == 3
|
||||
assert events[0].event_type == EVENT_QUERY_RECEIVED
|
||||
assert events[0].data == {"request_id": REQ_ID, "entity_id": "light.demo_light"}
|
||||
assert events[0].data == {
|
||||
"request_id": REQ_ID,
|
||||
"entity_id": "light.demo_light",
|
||||
"source": "cloud",
|
||||
}
|
||||
assert events[1].event_type == EVENT_QUERY_RECEIVED
|
||||
assert events[1].data == {"request_id": REQ_ID, "entity_id": "light.another_light"}
|
||||
assert events[1].data == {
|
||||
"request_id": REQ_ID,
|
||||
"entity_id": "light.another_light",
|
||||
"source": "cloud",
|
||||
}
|
||||
assert events[2].event_type == EVENT_QUERY_RECEIVED
|
||||
assert events[2].data == {"request_id": REQ_ID, "entity_id": "light.non_existing"}
|
||||
assert events[2].data == {
|
||||
"request_id": REQ_ID,
|
||||
"entity_id": "light.non_existing",
|
||||
"source": "cloud",
|
||||
}
|
||||
|
||||
|
||||
async def test_execute(hass):
|
||||
@ -300,6 +315,7 @@ async def test_execute(hass):
|
||||
}
|
||||
],
|
||||
},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -341,6 +357,7 @@ async def test_execute(hass):
|
||||
"command": "action.devices.commands.OnOff",
|
||||
"params": {"on": True},
|
||||
},
|
||||
"source": "cloud",
|
||||
}
|
||||
assert events[1].event_type == EVENT_COMMAND_RECEIVED
|
||||
assert events[1].data == {
|
||||
@ -350,6 +367,7 @@ async def test_execute(hass):
|
||||
"command": "action.devices.commands.BrightnessAbsolute",
|
||||
"params": {"brightness": 20},
|
||||
},
|
||||
"source": "cloud",
|
||||
}
|
||||
assert events[2].event_type == EVENT_COMMAND_RECEIVED
|
||||
assert events[2].data == {
|
||||
@ -359,6 +377,7 @@ async def test_execute(hass):
|
||||
"command": "action.devices.commands.OnOff",
|
||||
"params": {"on": True},
|
||||
},
|
||||
"source": "cloud",
|
||||
}
|
||||
assert events[3].event_type == EVENT_COMMAND_RECEIVED
|
||||
assert events[3].data == {
|
||||
@ -368,6 +387,7 @@ async def test_execute(hass):
|
||||
"command": "action.devices.commands.BrightnessAbsolute",
|
||||
"params": {"brightness": 20},
|
||||
},
|
||||
"source": "cloud",
|
||||
}
|
||||
|
||||
assert len(service_events) == 2
|
||||
@ -424,6 +444,7 @@ async def test_raising_error_trait(hass):
|
||||
}
|
||||
],
|
||||
},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -448,6 +469,7 @@ async def test_raising_error_trait(hass):
|
||||
"command": "action.devices.commands.ThermostatTemperatureSetpoint",
|
||||
"params": {"thermostatTemperatureSetpoint": 10},
|
||||
},
|
||||
"source": "cloud",
|
||||
}
|
||||
|
||||
|
||||
@ -483,6 +505,7 @@ async def test_unavailable_state_does_sync(hass):
|
||||
BASIC_CONFIG,
|
||||
"test-agent",
|
||||
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -515,7 +538,7 @@ async def test_unavailable_state_does_sync(hass):
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0].event_type == EVENT_SYNC_RECEIVED
|
||||
assert events[0].data == {"request_id": REQ_ID}
|
||||
assert events[0].data == {"request_id": REQ_ID, "source": "cloud"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -545,6 +568,7 @@ async def test_device_class_switch(hass, device_class, google_type):
|
||||
BASIC_CONFIG,
|
||||
"test-agent",
|
||||
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -589,6 +613,7 @@ async def test_device_class_binary_sensor(hass, device_class, google_type):
|
||||
BASIC_CONFIG,
|
||||
"test-agent",
|
||||
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -629,6 +654,7 @@ async def test_device_class_cover(hass, device_class, google_type):
|
||||
BASIC_CONFIG,
|
||||
"test-agent",
|
||||
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -669,6 +695,7 @@ async def test_device_media_player(hass, device_class, google_type):
|
||||
BASIC_CONFIG,
|
||||
"test-agent",
|
||||
{"requestId": REQ_ID, "inputs": [{"intent": "action.devices.SYNC"}]},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -702,6 +729,7 @@ async def test_query_disconnect(hass):
|
||||
config,
|
||||
"test-agent",
|
||||
{"inputs": [{"intent": "action.devices.DISCONNECT"}], "requestId": REQ_ID},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
assert result is None
|
||||
assert len(mock_disconnect.mock_calls) == 1
|
||||
@ -751,6 +779,7 @@ async def test_trait_execute_adding_query_data(hass):
|
||||
}
|
||||
],
|
||||
},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -817,6 +846,7 @@ async def test_identify(hass):
|
||||
}
|
||||
],
|
||||
},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
@ -851,8 +881,11 @@ async def test_reachable_devices(hass):
|
||||
# Not passed in as google_id
|
||||
hass.states.async_set("light.not_mentioned", "on")
|
||||
|
||||
# Has 2FA
|
||||
hass.states.async_set("lock.has_2fa", "on")
|
||||
|
||||
config = MockConfig(
|
||||
should_expose=lambda state: state.entity_id != "light.not_expose"
|
||||
should_expose=lambda state: state.entity_id != "light.not_expose",
|
||||
)
|
||||
|
||||
user_agent_id = "mock-user-id"
|
||||
@ -898,9 +931,19 @@ async def test_reachable_devices(hass):
|
||||
"webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "lock.has_2fa",
|
||||
"customData": {
|
||||
"httpPort": 8123,
|
||||
"httpSSL": False,
|
||||
"proxyDeviceId": proxy_device_id,
|
||||
"webhookId": "dde3b9800a905e886cc4d38e226a6e7e3f2a6993d2b9b9f63d13e42ee7de3219",
|
||||
},
|
||||
},
|
||||
{"id": proxy_device_id, "customData": {}},
|
||||
],
|
||||
},
|
||||
const.SOURCE_CLOUD,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
|
@ -51,11 +51,15 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
REQ_ID = "ff36a3cc-ec34-11e6-b1a0-64510650abcf"
|
||||
|
||||
BASIC_DATA = helpers.RequestData(BASIC_CONFIG, "test-agent", REQ_ID, None)
|
||||
BASIC_DATA = helpers.RequestData(
|
||||
BASIC_CONFIG, "test-agent", const.SOURCE_CLOUD, REQ_ID, None
|
||||
)
|
||||
|
||||
PIN_CONFIG = MockConfig(secure_devices_pin="1234")
|
||||
|
||||
PIN_DATA = helpers.RequestData(PIN_CONFIG, "test-agent", REQ_ID, None)
|
||||
PIN_DATA = helpers.RequestData(
|
||||
PIN_CONFIG, "test-agent", const.SOURCE_CLOUD, REQ_ID, None
|
||||
)
|
||||
|
||||
|
||||
async def test_brightness_light(hass):
|
||||
|
Loading…
x
Reference in New Issue
Block a user