Add custom JSONEncoder for automation traces (#47942)

* Add custom JSONEncoder for automation traces

* Add tests

* Update default case to include type

* Tweak

* Refactor

* Tweak

* Lint

* Update websocket_api.py
This commit is contained in:
Erik Montnemery 2021-03-16 14:21:05 +01:00 committed by GitHub
parent 5f2326fb57
commit 9647eeb2e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 2 deletions

View File

@ -2,11 +2,13 @@
from collections import OrderedDict from collections import OrderedDict
from contextlib import contextmanager from contextlib import contextmanager
import datetime as dt import datetime as dt
from datetime import timedelta
from itertools import count from itertools import count
import logging import logging
from typing import Any, Awaitable, Callable, Deque, Dict, Optional from typing import Any, Awaitable, Callable, Deque, Dict, Optional
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.json import JSONEncoder as HAJSONEncoder
from homeassistant.helpers.trace import TraceElement, trace_id_set from homeassistant.helpers.trace import TraceElement, trace_id_set
from homeassistant.helpers.typing import TemplateVarsType from homeassistant.helpers.typing import TemplateVarsType
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -203,3 +205,19 @@ def get_debug_traces(hass, summary=False):
traces.extend(get_debug_traces_for_automation(hass, automation_id, summary)) traces.extend(get_debug_traces_for_automation(hass, automation_id, summary))
return traces return traces
class TraceJSONEncoder(HAJSONEncoder):
"""JSONEncoder that supports Home Assistant objects and falls back to repr(o)."""
def default(self, o: Any) -> Any:
"""Convert certain objects.
Fall back to repr(o).
"""
if isinstance(o, timedelta):
return {"__type": str(type(o)), "total_seconds": o.total_seconds()}
try:
return super().default(o)
except TypeError:
return {"__type": str(type(o)), "repr": repr(o)}

View File

@ -1,4 +1,6 @@
"""Websocket API for automation.""" """Websocket API for automation."""
import json
import voluptuous as vol import voluptuous as vol
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
@ -21,7 +23,12 @@ from homeassistant.helpers.script import (
debug_stop, debug_stop,
) )
from .trace import get_debug_trace, get_debug_traces, get_debug_traces_for_automation from .trace import (
TraceJSONEncoder,
get_debug_trace,
get_debug_traces,
get_debug_traces_for_automation,
)
# mypy: allow-untyped-calls, allow-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs
@ -55,8 +62,9 @@ def websocket_automation_trace_get(hass, connection, msg):
run_id = msg["run_id"] run_id = msg["run_id"]
trace = get_debug_trace(hass, automation_id, run_id) trace = get_debug_trace(hass, automation_id, run_id)
message = websocket_api.messages.result_message(msg["id"], trace)
connection.send_result(msg["id"], trace) connection.send_message(json.dumps(message, cls=TraceJSONEncoder, allow_nan=False))
@callback @callback

View File

@ -0,0 +1,42 @@
"""Test Automation trace helpers."""
from datetime import timedelta
from homeassistant import core
from homeassistant.components import automation
from homeassistant.util import dt as dt_util
def test_json_encoder(hass):
"""Test the Trace JSON Encoder."""
ha_json_enc = automation.trace.TraceJSONEncoder()
state = core.State("test.test", "hello")
# Test serializing a datetime
now = dt_util.utcnow()
assert ha_json_enc.default(now) == now.isoformat()
# Test serializing a timedelta
data = timedelta(
days=50,
seconds=27,
microseconds=10,
milliseconds=29000,
minutes=5,
hours=8,
weeks=2,
)
assert ha_json_enc.default(data) == {
"__type": str(type(data)),
"total_seconds": data.total_seconds(),
}
# Test serializing a set()
data = {"milk", "beer"}
assert sorted(ha_json_enc.default(data)) == sorted(list(data))
# Test serializong object which implements as_dict
assert ha_json_enc.default(state) == state.as_dict()
# Default method falls back to repr(o)
o = object()
assert ha_json_enc.default(o) == {"__type": str(type(o)), "repr": repr(o)}