From 9647eeb2e08822da190e47208e57b4133facd5cf Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 16 Mar 2021 14:21:05 +0100 Subject: [PATCH] Add custom JSONEncoder for automation traces (#47942) * Add custom JSONEncoder for automation traces * Add tests * Update default case to include type * Tweak * Refactor * Tweak * Lint * Update websocket_api.py --- homeassistant/components/automation/trace.py | 18 ++++++++ .../components/automation/websocket_api.py | 12 +++++- tests/components/automation/test_trace.py | 42 +++++++++++++++++++ 3 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 tests/components/automation/test_trace.py diff --git a/homeassistant/components/automation/trace.py b/homeassistant/components/automation/trace.py index 351ca1ed979..4aac3d327b8 100644 --- a/homeassistant/components/automation/trace.py +++ b/homeassistant/components/automation/trace.py @@ -2,11 +2,13 @@ from collections import OrderedDict from contextlib import contextmanager import datetime as dt +from datetime import timedelta from itertools import count import logging from typing import Any, Awaitable, Callable, Deque, Dict, Optional from homeassistant.core import Context, HomeAssistant, callback +from homeassistant.helpers.json import JSONEncoder as HAJSONEncoder from homeassistant.helpers.trace import TraceElement, trace_id_set from homeassistant.helpers.typing import TemplateVarsType from homeassistant.util import dt as dt_util @@ -203,3 +205,19 @@ def get_debug_traces(hass, summary=False): traces.extend(get_debug_traces_for_automation(hass, automation_id, summary)) return traces + + +class TraceJSONEncoder(HAJSONEncoder): + """JSONEncoder that supports Home Assistant objects and falls back to repr(o).""" + + def default(self, o: Any) -> Any: + """Convert certain objects. + + Fall back to repr(o). + """ + if isinstance(o, timedelta): + return {"__type": str(type(o)), "total_seconds": o.total_seconds()} + try: + return super().default(o) + except TypeError: + return {"__type": str(type(o)), "repr": repr(o)} diff --git a/homeassistant/components/automation/websocket_api.py b/homeassistant/components/automation/websocket_api.py index bb47dd58ff9..eba56f94e7d 100644 --- a/homeassistant/components/automation/websocket_api.py +++ b/homeassistant/components/automation/websocket_api.py @@ -1,4 +1,6 @@ """Websocket API for automation.""" +import json + import voluptuous as vol from homeassistant.components import websocket_api @@ -21,7 +23,12 @@ from homeassistant.helpers.script import ( debug_stop, ) -from .trace import get_debug_trace, get_debug_traces, get_debug_traces_for_automation +from .trace import ( + TraceJSONEncoder, + get_debug_trace, + get_debug_traces, + get_debug_traces_for_automation, +) # mypy: allow-untyped-calls, allow-untyped-defs @@ -55,8 +62,9 @@ def websocket_automation_trace_get(hass, connection, msg): run_id = msg["run_id"] trace = get_debug_trace(hass, automation_id, run_id) + message = websocket_api.messages.result_message(msg["id"], trace) - connection.send_result(msg["id"], trace) + connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False)) @callback diff --git a/tests/components/automation/test_trace.py b/tests/components/automation/test_trace.py new file mode 100644 index 00000000000..818f1ee1768 --- /dev/null +++ b/tests/components/automation/test_trace.py @@ -0,0 +1,42 @@ +"""Test Automation trace helpers.""" +from datetime import timedelta + +from homeassistant import core +from homeassistant.components import automation +from homeassistant.util import dt as dt_util + + +def test_json_encoder(hass): + """Test the Trace JSON Encoder.""" + ha_json_enc = automation.trace.TraceJSONEncoder() + state = core.State("test.test", "hello") + + # Test serializing a datetime + now = dt_util.utcnow() + assert ha_json_enc.default(now) == now.isoformat() + + # Test serializing a timedelta + data = timedelta( + days=50, + seconds=27, + microseconds=10, + milliseconds=29000, + minutes=5, + hours=8, + weeks=2, + ) + assert ha_json_enc.default(data) == { + "__type": str(type(data)), + "total_seconds": data.total_seconds(), + } + + # Test serializing a set() + data = {"milk", "beer"} + assert sorted(ha_json_enc.default(data)) == sorted(list(data)) + + # Test serializong object which implements as_dict + assert ha_json_enc.default(state) == state.as_dict() + + # Default method falls back to repr(o) + o = object() + assert ha_json_enc.default(o) == {"__type": str(type(o)), "repr": repr(o)}