Avoid bytes to string to bytes conversion in websocket api (#108139)

This commit is contained in:
J. Nick Koston 2024-01-16 10:37:34 -10:00 committed by GitHub
parent ad35113e86
commit 60ab360fe7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 137 additions and 93 deletions

View File

@ -214,7 +214,7 @@ class APIStatesView(HomeAssistantView):
if entity_perm(state.entity_id, "read") if entity_perm(state.entity_id, "read")
) )
response = web.Response( response = web.Response(
body=f'[{",".join(states)}]', body=b"[" + b",".join(states) + b"]",
content_type=CONTENT_TYPE_JSON, content_type=CONTENT_TYPE_JSON,
zlib_executor_size=32768, zlib_executor_size=32768,
) )

View File

@ -45,16 +45,16 @@ def websocket_list_devices(
msg_json_prefix = ( msg_json_prefix = (
f'{{"id":{msg["id"]},"type": "{websocket_api.const.TYPE_RESULT}",' f'{{"id":{msg["id"]},"type": "{websocket_api.const.TYPE_RESULT}",'
f'"success":true,"result": [' f'"success":true,"result": ['
) ).encode()
# Concatenate cached entity registry item JSON serializations # Concatenate cached entity registry item JSON serializations
msg_json = ( msg_json = (
msg_json_prefix msg_json_prefix
+ ",".join( + b",".join(
entry.json_repr entry.json_repr
for entry in registry.devices.values() for entry in registry.devices.values()
if entry.json_repr is not None if entry.json_repr is not None
) )
+ "]}" + b"]}"
) )
connection.send_message(msg_json) connection.send_message(msg_json)

View File

@ -43,20 +43,23 @@ def websocket_list_entities(
msg_json_prefix = ( msg_json_prefix = (
f'{{"id":{msg["id"]},"type": "{websocket_api.const.TYPE_RESULT}",' f'{{"id":{msg["id"]},"type": "{websocket_api.const.TYPE_RESULT}",'
'"success":true,"result": [' '"success":true,"result": ['
) ).encode()
# Concatenate cached entity registry item JSON serializations # Concatenate cached entity registry item JSON serializations
msg_json = ( msg_json = (
msg_json_prefix msg_json_prefix
+ ",".join( + b",".join(
entry.partial_json_repr entry.partial_json_repr
for entry in registry.entities.values() for entry in registry.entities.values()
if entry.partial_json_repr is not None if entry.partial_json_repr is not None
) )
+ "]}" + b"]}"
) )
connection.send_message(msg_json) connection.send_message(msg_json)
_ENTITY_CATEGORIES_JSON = json_dumps(er.ENTITY_CATEGORY_INDEX_TO_VALUE)
@websocket_api.websocket_command( @websocket_api.websocket_command(
{vol.Required("type"): "config/entity_registry/list_for_display"} {vol.Required("type"): "config/entity_registry/list_for_display"}
) )
@ -69,20 +72,19 @@ def websocket_list_entities_for_display(
"""Handle list registry entries command.""" """Handle list registry entries command."""
registry = er.async_get(hass) registry = er.async_get(hass)
# Build start of response message # Build start of response message
entity_categories = json_dumps(er.ENTITY_CATEGORY_INDEX_TO_VALUE)
msg_json_prefix = ( msg_json_prefix = (
f'{{"id":{msg["id"]},"type":"{websocket_api.const.TYPE_RESULT}","success":true,' f'{{"id":{msg["id"]},"type":"{websocket_api.const.TYPE_RESULT}","success":true,'
f'"result":{{"entity_categories":{entity_categories},"entities":[' f'"result":{{"entity_categories":{_ENTITY_CATEGORIES_JSON},"entities":['
) ).encode()
# Concatenate cached entity registry item JSON serializations # Concatenate cached entity registry item JSON serializations
msg_json = ( msg_json = (
msg_json_prefix msg_json_prefix
+ ",".join( + b",".join(
entry.display_json_repr entry.display_json_repr
for entry in registry.entities.values() for entry in registry.entities.values()
if entry.disabled_by is None and entry.display_json_repr is not None if entry.disabled_by is None and entry.display_json_repr is not None
) )
+ "]}}" + b"]}}"
) )
connection.send_message(msg_json) connection.send_message(msg_json)

View File

@ -34,7 +34,7 @@ from homeassistant.helpers.event import (
async_track_point_in_utc_time, async_track_point_in_utc_time,
async_track_state_change_event, async_track_state_change_event,
) )
from homeassistant.helpers.json import JSON_DUMP from homeassistant.helpers.json import json_bytes
from homeassistant.helpers.typing import EventType from homeassistant.helpers.typing import EventType
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -72,9 +72,9 @@ def _ws_get_significant_states(
significant_changes_only: bool, significant_changes_only: bool,
minimal_response: bool, minimal_response: bool,
no_attributes: bool, no_attributes: bool,
) -> str: ) -> bytes:
"""Fetch history significant_states and convert them to json in the executor.""" """Fetch history significant_states and convert them to json in the executor."""
return JSON_DUMP( return json_bytes(
messages.result_message( messages.result_message(
msg_id, msg_id,
history.get_significant_states( history.get_significant_states(
@ -201,9 +201,9 @@ def _generate_websocket_response(
start_time: dt, start_time: dt,
end_time: dt, end_time: dt,
states: MutableMapping[str, list[dict[str, Any]]], states: MutableMapping[str, list[dict[str, Any]]],
) -> str: ) -> bytes:
"""Generate a websocket response.""" """Generate a websocket response."""
return JSON_DUMP( return json_bytes(
messages.event_message( messages.event_message(
msg_id, _generate_stream_message(states, start_time, end_time) msg_id, _generate_stream_message(states, start_time, end_time)
) )
@ -221,7 +221,7 @@ def _generate_historical_response(
minimal_response: bool, minimal_response: bool,
no_attributes: bool, no_attributes: bool,
send_empty: bool, send_empty: bool,
) -> tuple[float, dt | None, str | None]: ) -> tuple[float, dt | None, bytes | None]:
"""Generate a historical response.""" """Generate a historical response."""
states = cast( states = cast(
MutableMapping[str, list[dict[str, Any]]], MutableMapping[str, list[dict[str, Any]]],
@ -346,7 +346,7 @@ async def _async_events_consumer(
if history_states := _events_to_compressed_states(events, no_attributes): if history_states := _events_to_compressed_states(events, no_attributes):
connection.send_message( connection.send_message(
JSON_DUMP( json_bytes(
messages.event_message( messages.event_message(
msg_id, msg_id,
{"states": history_states}, {"states": history_states},

View File

@ -16,7 +16,7 @@ from homeassistant.components.websocket_api import messages
from homeassistant.components.websocket_api.connection import ActiveConnection from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.helpers.json import JSON_DUMP from homeassistant.helpers.json import json_bytes
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import DOMAIN from .const import DOMAIN
@ -70,7 +70,7 @@ def _async_send_empty_response(
stream_end_time = end_time or dt_util.utcnow() stream_end_time = end_time or dt_util.utcnow()
empty_stream_message = _generate_stream_message([], start_time, stream_end_time) empty_stream_message = _generate_stream_message([], start_time, stream_end_time)
empty_response = messages.event_message(msg_id, empty_stream_message) empty_response = messages.event_message(msg_id, empty_stream_message)
connection.send_message(JSON_DUMP(empty_response)) connection.send_message(json_bytes(empty_response))
async def _async_send_historical_events( async def _async_send_historical_events(
@ -165,7 +165,7 @@ async def _async_get_ws_stream_events(
formatter: Callable[[int, Any], dict[str, Any]], formatter: Callable[[int, Any], dict[str, Any]],
event_processor: EventProcessor, event_processor: EventProcessor,
partial: bool, partial: bool,
) -> tuple[str, dt | None]: ) -> tuple[bytes, dt | None]:
"""Async wrapper around _ws_formatted_get_events.""" """Async wrapper around _ws_formatted_get_events."""
return await get_instance(hass).async_add_executor_job( return await get_instance(hass).async_add_executor_job(
_ws_stream_get_events, _ws_stream_get_events,
@ -196,7 +196,7 @@ def _ws_stream_get_events(
formatter: Callable[[int, Any], dict[str, Any]], formatter: Callable[[int, Any], dict[str, Any]],
event_processor: EventProcessor, event_processor: EventProcessor,
partial: bool, partial: bool,
) -> tuple[str, dt | None]: ) -> tuple[bytes, dt | None]:
"""Fetch events and convert them to json in the executor.""" """Fetch events and convert them to json in the executor."""
events = event_processor.get_events(start_day, end_day) events = event_processor.get_events(start_day, end_day)
last_time = None last_time = None
@ -209,7 +209,7 @@ def _ws_stream_get_events(
# data in case the UI needs to show that historical # data in case the UI needs to show that historical
# data is still loading in the future # data is still loading in the future
message["partial"] = True message["partial"] = True
return JSON_DUMP(formatter(msg_id, message)), last_time return json_bytes(formatter(msg_id, message)), last_time
async def _async_events_consumer( async def _async_events_consumer(
@ -238,7 +238,7 @@ async def _async_events_consumer(
async_event_to_row(e) for e in events async_event_to_row(e) for e in events
): ):
connection.send_message( connection.send_message(
JSON_DUMP( json_bytes(
messages.event_message( messages.event_message(
msg_id, msg_id,
{"events": logbook_events}, {"events": logbook_events},
@ -435,9 +435,9 @@ def _ws_formatted_get_events(
start_time: dt, start_time: dt,
end_time: dt, end_time: dt,
event_processor: EventProcessor, event_processor: EventProcessor,
) -> str: ) -> bytes:
"""Fetch events and convert them to json in the executor.""" """Fetch events and convert them to json in the executor."""
return JSON_DUMP( return json_bytes(
messages.result_message( messages.result_message(
msg_id, event_processor.get_events(start_time, end_time) msg_id, event_processor.get_events(start_time, end_time)
) )

View File

@ -12,7 +12,7 @@ from homeassistant.components.websocket_api import messages
from homeassistant.core import HomeAssistant, callback, valid_entity_id from homeassistant.core import HomeAssistant, callback, valid_entity_id
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.json import JSON_DUMP from homeassistant.helpers.json import json_bytes
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.unit_conversion import ( from homeassistant.util.unit_conversion import (
DataRateConverter, DataRateConverter,
@ -97,9 +97,9 @@ def _ws_get_statistic_during_period(
statistic_id: str, statistic_id: str,
types: set[Literal["max", "mean", "min", "change"]] | None, types: set[Literal["max", "mean", "min", "change"]] | None,
units: dict[str, str], units: dict[str, str],
) -> str: ) -> bytes:
"""Fetch statistics and convert them to json in the executor.""" """Fetch statistics and convert them to json in the executor."""
return JSON_DUMP( return json_bytes(
messages.result_message( messages.result_message(
msg_id, msg_id,
statistic_during_period( statistic_during_period(
@ -155,7 +155,7 @@ def _ws_get_statistics_during_period(
period: Literal["5minute", "day", "hour", "week", "month"], period: Literal["5minute", "day", "hour", "week", "month"],
units: dict[str, str], units: dict[str, str],
types: set[Literal["change", "last_reset", "max", "mean", "min", "state", "sum"]], types: set[Literal["change", "last_reset", "max", "mean", "min", "state", "sum"]],
) -> str: ) -> bytes:
"""Fetch statistics and convert them to json in the executor.""" """Fetch statistics and convert them to json in the executor."""
result = statistics_during_period( result = statistics_during_period(
hass, hass,
@ -174,7 +174,7 @@ def _ws_get_statistics_during_period(
item["end"] = int(end * 1000) item["end"] = int(end * 1000)
if (last_reset := item.get("last_reset")) is not None: if (last_reset := item.get("last_reset")) is not None:
item["last_reset"] = int(last_reset * 1000) item["last_reset"] = int(last_reset * 1000)
return JSON_DUMP(messages.result_message(msg_id, result)) return json_bytes(messages.result_message(msg_id, result))
async def ws_handle_get_statistics_during_period( async def ws_handle_get_statistics_during_period(
@ -242,12 +242,12 @@ def _ws_get_list_statistic_ids(
hass: HomeAssistant, hass: HomeAssistant,
msg_id: int, msg_id: int,
statistic_type: Literal["mean"] | Literal["sum"] | None = None, statistic_type: Literal["mean"] | Literal["sum"] | None = None,
) -> str: ) -> bytes:
"""Fetch a list of available statistic_id and convert them to JSON. """Fetch a list of available statistic_id and convert them to JSON.
Runs in the executor. Runs in the executor.
""" """
return JSON_DUMP( return json_bytes(
messages.result_message(msg_id, list_statistic_ids(hass, None, statistic_type)) messages.result_message(msg_id, list_statistic_ids(hass, None, statistic_type))
) )

View File

@ -57,7 +57,7 @@ class AuthPhase:
self, self,
logger: WebSocketAdapter, logger: WebSocketAdapter,
hass: HomeAssistant, hass: HomeAssistant,
send_message: Callable[[str | dict[str, Any]], None], send_message: Callable[[bytes | str | dict[str, Any]], None],
cancel_ws: CALLBACK_TYPE, cancel_ws: CALLBACK_TYPE,
request: Request, request: Request,
) -> None: ) -> None:

View File

@ -104,7 +104,7 @@ def pong_message(iden: int) -> dict[str, Any]:
@callback @callback
def _forward_events_check_permissions( def _forward_events_check_permissions(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], send_message: Callable[[bytes | str | dict[str, Any] | Callable[[], str]], None],
user: User, user: User,
msg_id: int, msg_id: int,
event: Event, event: Event,
@ -124,7 +124,7 @@ def _forward_events_check_permissions(
@callback @callback
def _forward_events_unconditional( def _forward_events_unconditional(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], send_message: Callable[[bytes | str | dict[str, Any] | Callable[[], str]], None],
msg_id: int, msg_id: int,
event: Event, event: Event,
) -> None: ) -> None:
@ -352,17 +352,17 @@ def handle_get_states(
def _send_handle_get_states_response( def _send_handle_get_states_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str] connection: ActiveConnection, msg_id: int, serialized_states: list[bytes]
) -> None: ) -> None:
"""Send handle get states response.""" """Send handle get states response."""
connection.send_message( connection.send_message(
construct_result_message(msg_id, f'[{",".join(serialized_states)}]') construct_result_message(msg_id, b"[" + b",".join(serialized_states) + b"]")
) )
@callback @callback
def _forward_entity_changes( def _forward_entity_changes(
send_message: Callable[[str | dict[str, Any] | Callable[[], str]], None], send_message: Callable[[str | bytes | dict[str, Any] | Callable[[], str]], None],
entity_ids: set[str], entity_ids: set[str],
user: User, user: User,
msg_id: int, msg_id: int,
@ -444,11 +444,19 @@ def handle_subscribe_entities(
def _send_handle_entities_init_response( def _send_handle_entities_init_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str] connection: ActiveConnection, msg_id: int, serialized_states: list[bytes]
) -> None: ) -> None:
"""Send handle entities init response.""" """Send handle entities init response."""
connection.send_message( connection.send_message(
f'{{"id":{msg_id},"type":"event","event":{{"a":{{{",".join(serialized_states)}}}}}}}' b"".join(
(
b'{"id":',
str(msg_id).encode(),
b',"type":"event","event":{"a":{',
b",".join(serialized_states),
b"}}}",
)
)
) )
@ -474,7 +482,7 @@ async def handle_get_services(
) -> None: ) -> None:
"""Handle get services command.""" """Handle get services command."""
payload = await _async_get_all_descriptions_json(hass) payload = await _async_get_all_descriptions_json(hass)
connection.send_message(construct_result_message(msg["id"], payload)) connection.send_message(construct_result_message(msg["id"], payload.encode()))
@callback @callback

View File

@ -51,7 +51,7 @@ class ActiveConnection:
self, self,
logger: WebSocketAdapter, logger: WebSocketAdapter,
hass: HomeAssistant, hass: HomeAssistant,
send_message: Callable[[str | dict[str, Any]], None], send_message: Callable[[bytes | str | dict[str, Any]], None],
user: User, user: User,
refresh_token: RefreshToken, refresh_token: RefreshToken,
) -> None: ) -> None:
@ -244,7 +244,7 @@ class ActiveConnection:
@callback @callback
def _connect_closed_error( def _connect_closed_error(
self, msg: str | dict[str, Any] | Callable[[], str] self, msg: bytes | str | dict[str, Any] | Callable[[], str]
) -> None: ) -> None:
"""Send a message when the connection is closed.""" """Send a message when the connection is closed."""
self.logger.debug("Tried to send message %s on closed connection", msg) self.logger.debug("Tried to send message %s on closed connection", msg)

View File

@ -5,6 +5,7 @@ import asyncio
from collections import deque from collections import deque
from collections.abc import Callable from collections.abc import Callable
import datetime as dt import datetime as dt
from functools import partial
import logging import logging
from typing import TYPE_CHECKING, Any, Final from typing import TYPE_CHECKING, Any, Final
@ -28,7 +29,7 @@ from .const import (
URL, URL,
) )
from .error import Disconnect from .error import Disconnect
from .messages import message_to_json from .messages import message_to_json_bytes
from .util import describe_request from .util import describe_request
if TYPE_CHECKING: if TYPE_CHECKING:
@ -94,7 +95,7 @@ class WebSocketHandler:
# to where messages are queued. This allows the implementation # to where messages are queued. This allows the implementation
# to use a deque and an asyncio.Future to avoid the overhead of # to use a deque and an asyncio.Future to avoid the overhead of
# an asyncio.Queue. # an asyncio.Queue.
self._message_queue: deque[str | None] = deque() self._message_queue: deque[bytes | None] = deque()
self._ready_future: asyncio.Future[None] | None = None self._ready_future: asyncio.Future[None] | None = None
def __repr__(self) -> str: def __repr__(self) -> str:
@ -121,7 +122,10 @@ class WebSocketHandler:
message_queue = self._message_queue message_queue = self._message_queue
logger = self._logger logger = self._logger
wsock = self._wsock wsock = self._wsock
send_str = wsock.send_str writer = wsock._writer # pylint: disable=protected-access
if TYPE_CHECKING:
assert writer is not None
send_str = partial(writer.send, binary=False)
loop = self._hass.loop loop = self._hass.loop
debug = logger.debug debug = logger.debug
is_enabled_for = logger.isEnabledFor is_enabled_for = logger.isEnabledFor
@ -151,7 +155,7 @@ class WebSocketHandler:
await send_str(message) await send_str(message)
continue continue
messages: list[str] = [message] messages: list[bytes] = [message]
while messages_remaining: while messages_remaining:
# A None message is used to signal the end of the connection # A None message is used to signal the end of the connection
if (message := message_queue.popleft()) is None: if (message := message_queue.popleft()) is None:
@ -159,7 +163,7 @@ class WebSocketHandler:
messages.append(message) messages.append(message)
messages_remaining -= 1 messages_remaining -= 1
coalesced_messages = f'[{",".join(messages)}]' coalesced_messages = b"".join((b"[", b",".join(messages), b"]"))
if debug_enabled: if debug_enabled:
debug("%s: Sending %s", self.description, coalesced_messages) debug("%s: Sending %s", self.description, coalesced_messages)
await send_str(coalesced_messages) await send_str(coalesced_messages)
@ -181,7 +185,7 @@ class WebSocketHandler:
self._peak_checker_unsub = None self._peak_checker_unsub = None
@callback @callback
def _send_message(self, message: str | dict[str, Any]) -> None: def _send_message(self, message: str | bytes | dict[str, Any]) -> None:
"""Send a message to the client. """Send a message to the client.
Closes connection if the client is not reading the messages. Closes connection if the client is not reading the messages.
@ -194,7 +198,9 @@ class WebSocketHandler:
return return
if isinstance(message, dict): if isinstance(message, dict):
message = message_to_json(message) message = message_to_json_bytes(message)
elif isinstance(message, str):
message = message.encode("utf-8")
message_queue = self._message_queue message_queue = self._message_queue
queue_size_before_add = len(message_queue) queue_size_before_add = len(message_queue)

View File

@ -16,7 +16,11 @@ from homeassistant.const import (
) )
from homeassistant.core import Event, State from homeassistant.core import Event, State
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.json import JSON_DUMP, find_paths_unserializable_data from homeassistant.helpers.json import (
JSON_DUMP,
find_paths_unserializable_data,
json_bytes,
)
from homeassistant.util.json import format_unserializable_data from homeassistant.util.json import format_unserializable_data
from . import const from . import const
@ -44,7 +48,7 @@ BASE_ERROR_MESSAGE = {
"success": False, "success": False,
} }
INVALID_JSON_PARTIAL_MESSAGE = JSON_DUMP( INVALID_JSON_PARTIAL_MESSAGE = json_bytes(
{ {
**BASE_ERROR_MESSAGE, **BASE_ERROR_MESSAGE,
"error": { "error": {
@ -60,9 +64,17 @@ def result_message(iden: int, result: Any = None) -> dict[str, Any]:
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result} return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}
def construct_result_message(iden: int, payload: str) -> str: def construct_result_message(iden: int, payload: bytes) -> bytes:
"""Construct a success result message JSON.""" """Construct a success result message JSON."""
return f'{{"id":{iden},"type":"result","success":true,"result":{payload}}}' return b"".join(
(
b'{"id":',
str(iden).encode(),
b',"type":"result","success":true,"result":',
payload,
b"}",
)
)
def error_message( def error_message(
@ -96,7 +108,7 @@ def event_message(iden: int, event: Any) -> dict[str, Any]:
return {"id": iden, "type": "event", "event": event} return {"id": iden, "type": "event", "event": event}
def cached_event_message(iden: int, event: Event) -> str: def cached_event_message(iden: int, event: Event) -> bytes:
"""Return an event message. """Return an event message.
Serialize to json once per message. Serialize to json once per message.
@ -105,23 +117,30 @@ def cached_event_message(iden: int, event: Event) -> str:
all getting many of the same events (mostly state changed) all getting many of the same events (mostly state changed)
we can avoid serializing the same data for each connection. we can avoid serializing the same data for each connection.
""" """
return f'{_partial_cached_event_message(event)[:-1]},"id":{iden}}}' return b"".join(
(
_partial_cached_event_message(event)[:-1],
b',"id":',
str(iden).encode(),
b"}",
)
)
@lru_cache(maxsize=128) @lru_cache(maxsize=128)
def _partial_cached_event_message(event: Event) -> str: def _partial_cached_event_message(event: Event) -> bytes:
"""Cache and serialize the event to json. """Cache and serialize the event to json.
The message is constructed without the id which appended The message is constructed without the id which appended
in cached_event_message. in cached_event_message.
""" """
return ( return (
_message_to_json_or_none({"type": "event", "event": event.json_fragment}) _message_to_json_bytes_or_none({"type": "event", "event": event.json_fragment})
or INVALID_JSON_PARTIAL_MESSAGE or INVALID_JSON_PARTIAL_MESSAGE
) )
def cached_state_diff_message(iden: int, event: Event) -> str: def cached_state_diff_message(iden: int, event: Event) -> bytes:
"""Return an event message. """Return an event message.
Serialize to json once per message. Serialize to json once per message.
@ -130,18 +149,27 @@ def cached_state_diff_message(iden: int, event: Event) -> str:
all getting many of the same events (mostly state changed) all getting many of the same events (mostly state changed)
we can avoid serializing the same data for each connection. we can avoid serializing the same data for each connection.
""" """
return f'{_partial_cached_state_diff_message(event)[:-1]},"id":{iden}}}' return b"".join(
(
_partial_cached_state_diff_message(event)[:-1],
b',"id":',
str(iden).encode(),
b"}",
)
)
@lru_cache(maxsize=128) @lru_cache(maxsize=128)
def _partial_cached_state_diff_message(event: Event) -> str: def _partial_cached_state_diff_message(event: Event) -> bytes:
"""Cache and serialize the event to json. """Cache and serialize the event to json.
The message is constructed without the id which The message is constructed without the id which
will be appended in cached_state_diff_message will be appended in cached_state_diff_message
""" """
return ( return (
_message_to_json_or_none({"type": "event", "event": _state_diff_event(event)}) _message_to_json_bytes_or_none(
{"type": "event", "event": _state_diff_event(event)}
)
or INVALID_JSON_PARTIAL_MESSAGE or INVALID_JSON_PARTIAL_MESSAGE
) )
@ -212,10 +240,10 @@ def _state_diff(
return {ENTITY_EVENT_CHANGE: {new_state.entity_id: diff}} return {ENTITY_EVENT_CHANGE: {new_state.entity_id: diff}}
def _message_to_json_or_none(message: dict[str, Any]) -> str | None: def _message_to_json_bytes_or_none(message: dict[str, Any]) -> bytes | None:
"""Serialize a websocket message to json or return None.""" """Serialize a websocket message to json or return None."""
try: try:
return JSON_DUMP(message) return json_bytes(message)
except (ValueError, TypeError): except (ValueError, TypeError):
_LOGGER.error( _LOGGER.error(
"Unable to serialize to JSON. Bad data found at %s", "Unable to serialize to JSON. Bad data found at %s",
@ -226,9 +254,9 @@ def _message_to_json_or_none(message: dict[str, Any]) -> str | None:
return None return None
def message_to_json(message: dict[str, Any]) -> str: def message_to_json_bytes(message: dict[str, Any]) -> bytes:
"""Serialize a websocket message to json or return an error.""" """Serialize a websocket message to json or return an error."""
return _message_to_json_or_none(message) or JSON_DUMP( return _message_to_json_bytes_or_none(message) or json_bytes(
error_message( error_message(
message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response" message["id"], const.ERR_UNKNOWN_ERROR, "Invalid JSON in response"
) )

View File

@ -86,7 +86,7 @@ from .helpers.deprecation import (
check_if_deprecated_constant, check_if_deprecated_constant,
dir_with_deprecated_constants, dir_with_deprecated_constants,
) )
from .helpers.json import json_dumps, json_fragment from .helpers.json import json_bytes, json_fragment
from .util import dt as dt_util, location from .util import dt as dt_util, location
from .util.async_ import ( from .util.async_ import (
cancelling, cancelling,
@ -1039,7 +1039,7 @@ class Context:
@cached_property @cached_property
def json_fragment(self) -> json_fragment: def json_fragment(self) -> json_fragment:
"""Return a JSON fragment of the context.""" """Return a JSON fragment of the context."""
return json_fragment(json_dumps(self._as_dict)) return json_fragment(json_bytes(self._as_dict))
class EventOrigin(enum.Enum): class EventOrigin(enum.Enum):
@ -1126,7 +1126,7 @@ class Event:
@cached_property @cached_property
def json_fragment(self) -> json_fragment: def json_fragment(self) -> json_fragment:
"""Return an event as a JSON fragment.""" """Return an event as a JSON fragment."""
return json_fragment(json_dumps(self._as_dict)) return json_fragment(json_bytes(self._as_dict))
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return the representation.""" """Return the representation."""
@ -1512,9 +1512,9 @@ class State:
return ReadOnlyDict(as_dict) return ReadOnlyDict(as_dict)
@cached_property @cached_property
def as_dict_json(self) -> str: def as_dict_json(self) -> bytes:
"""Return a JSON string of the State.""" """Return a JSON string of the State."""
return json_dumps(self._as_dict) return json_bytes(self._as_dict)
@cached_property @cached_property
def json_fragment(self) -> json_fragment: def json_fragment(self) -> json_fragment:
@ -1550,14 +1550,14 @@ class State:
return compressed_state return compressed_state
@cached_property @cached_property
def as_compressed_state_json(self) -> str: def as_compressed_state_json(self) -> bytes:
"""Build a compressed JSON key value pair of a state for adds. """Build a compressed JSON key value pair of a state for adds.
The JSON string is a key value pair of the entity_id and the compressed state. The JSON string is a key value pair of the entity_id and the compressed state.
It is used for sending multiple states in a single message. It is used for sending multiple states in a single message.
""" """
return json_dumps({self.entity_id: self.as_compressed_state})[1:-1] return json_bytes({self.entity_id: self.as_compressed_state})[1:-1]
@classmethod @classmethod
def from_dict(cls, json_dict: dict[str, Any]) -> Self | None: def from_dict(cls, json_dict: dict[str, Any]) -> Self | None:

View File

@ -29,7 +29,7 @@ from .deprecation import (
dir_with_deprecated_constants, dir_with_deprecated_constants,
) )
from .frame import report from .frame import report
from .json import JSON_DUMP, find_paths_unserializable_data from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -277,11 +277,11 @@ class DeviceEntry:
} }
@cached_property @cached_property
def json_repr(self) -> str | None: def json_repr(self) -> bytes | None:
"""Return a cached JSON representation of the entry.""" """Return a cached JSON representation of the entry."""
try: try:
dict_repr = self.dict_repr dict_repr = self.dict_repr
return JSON_DUMP(dict_repr) return json_bytes(dict_repr)
except (ValueError, TypeError): except (ValueError, TypeError):
_LOGGER.error( _LOGGER.error(
"Unable to serialize entry %s to JSON. Bad data found at %s", "Unable to serialize entry %s to JSON. Bad data found at %s",

View File

@ -51,7 +51,7 @@ from homeassistant.util.read_only_dict import ReadOnlyDict
from . import device_registry as dr, storage from . import device_registry as dr, storage
from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED from .device_registry import EVENT_DEVICE_REGISTRY_UPDATED
from .json import JSON_DUMP, find_paths_unserializable_data from .json import JSON_DUMP, find_paths_unserializable_data, json_bytes
from .typing import UNDEFINED, UndefinedType from .typing import UNDEFINED, UndefinedType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -227,14 +227,14 @@ class RegistryEntry:
return display_dict return display_dict
@cached_property @cached_property
def display_json_repr(self) -> str | None: def display_json_repr(self) -> bytes | None:
"""Return a cached partial JSON representation of the entry. """Return a cached partial JSON representation of the entry.
This version only includes what's needed for display. This version only includes what's needed for display.
""" """
try: try:
dict_repr = self._as_display_dict dict_repr = self._as_display_dict
json_repr: str | None = JSON_DUMP(dict_repr) if dict_repr else None json_repr: bytes | None = json_bytes(dict_repr) if dict_repr else None
return json_repr return json_repr
except (ValueError, TypeError): except (ValueError, TypeError):
_LOGGER.error( _LOGGER.error(
@ -282,11 +282,11 @@ class RegistryEntry:
} }
@cached_property @cached_property
def partial_json_repr(self) -> str | None: def partial_json_repr(self) -> bytes | None:
"""Return a cached partial JSON representation of the entry.""" """Return a cached partial JSON representation of the entry."""
try: try:
dict_repr = self.as_partial_dict dict_repr = self.as_partial_dict
return JSON_DUMP(dict_repr) return json_bytes(dict_repr)
except (ValueError, TypeError): except (ValueError, TypeError):
_LOGGER.error( _LOGGER.error(
"Unable to serialize entry %s to JSON. Bad data found at %s", "Unable to serialize entry %s to JSON. Bad data found at %s",

View File

@ -5,7 +5,7 @@ from homeassistant.components.websocket_api.messages import (
_partial_cached_event_message as lru_event_cache, _partial_cached_event_message as lru_event_cache,
_state_diff_event, _state_diff_event,
cached_event_message, cached_event_message,
message_to_json, message_to_json_bytes,
) )
from homeassistant.const import EVENT_STATE_CHANGED from homeassistant.const import EVENT_STATE_CHANGED
from homeassistant.core import Context, Event, HomeAssistant, State, callback from homeassistant.core import Context, Event, HomeAssistant, State, callback
@ -282,18 +282,18 @@ async def test_state_diff_event(hass: HomeAssistant) -> None:
} }
async def test_message_to_json(caplog: pytest.LogCaptureFixture) -> None: async def test_message_to_json_bytes(caplog: pytest.LogCaptureFixture) -> None:
"""Test we can serialize websocket messages.""" """Test we can serialize websocket messages."""
json_str = message_to_json({"id": 1, "message": "xyz"}) json_str = message_to_json_bytes({"id": 1, "message": "xyz"})
assert json_str == '{"id":1,"message":"xyz"}' assert json_str == b'{"id":1,"message":"xyz"}'
json_str2 = message_to_json({"id": 1, "message": _Unserializeable()}) json_str2 = message_to_json_bytes({"id": 1, "message": _Unserializeable()})
assert ( assert (
json_str2 json_str2
== '{"id":1,"type":"result","success":false,"error":{"code":"unknown_error","message":"Invalid JSON in response"}}' == b'{"id":1,"type":"result","success":false,"error":{"code":"unknown_error","message":"Invalid JSON in response"}}'
) )
assert "Unable to serialize to JSON" in caplog.text assert "Unable to serialize to JSON" in caplog.text

View File

@ -742,9 +742,9 @@ def test_state_as_dict_json() -> None:
context=ha.Context(id="01H0D6K3RFJAYAV2093ZW30PCW"), context=ha.Context(id="01H0D6K3RFJAYAV2093ZW30PCW"),
) )
expected = ( expected = (
'{"entity_id":"happy.happy","state":"on","attributes":{"pig":"dog"},' b'{"entity_id":"happy.happy","state":"on","attributes":{"pig":"dog"},'
'"last_changed":"1984-12-08T12:00:00","last_updated":"1984-12-08T12:00:00",' b'"last_changed":"1984-12-08T12:00:00","last_updated":"1984-12-08T12:00:00",'
'"context":{"id":"01H0D6K3RFJAYAV2093ZW30PCW","parent_id":null,"user_id":null}}' b'"context":{"id":"01H0D6K3RFJAYAV2093ZW30PCW","parent_id":null,"user_id":null}}'
) )
as_dict_json_1 = state.as_dict_json as_dict_json_1 = state.as_dict_json
assert as_dict_json_1 == expected assert as_dict_json_1 == expected
@ -852,7 +852,7 @@ def test_state_as_compressed_state_json() -> None:
last_changed=last_time, last_changed=last_time,
context=ha.Context(id="01H0D6H5K3SZJ3XGDHED1TJ79N"), context=ha.Context(id="01H0D6H5K3SZJ3XGDHED1TJ79N"),
) )
expected = '"happy.happy":{"s":"on","a":{"pig":"dog"},"c":"01H0D6H5K3SZJ3XGDHED1TJ79N","lc":471355200.0}' expected = b'"happy.happy":{"s":"on","a":{"pig":"dog"},"c":"01H0D6H5K3SZJ3XGDHED1TJ79N","lc":471355200.0}'
as_compressed_state = state.as_compressed_state_json as_compressed_state = state.as_compressed_state_json
# We are not too concerned about these being ReadOnlyDict # We are not too concerned about these being ReadOnlyDict
# since we don't expect them to be called by external callers # since we don't expect them to be called by external callers