Improve API protobuf decode method readability and reduce code size (#9455)

This commit is contained in:
J. Nick Koston 2025-07-15 15:15:11 -10:00 committed by GitHub
parent 5c2dea79ef
commit b5be45273f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 606 additions and 795 deletions

File diff suppressed because it is too large Load Diff

View File

@ -8,7 +8,6 @@ from pathlib import Path
import re import re
from subprocess import call from subprocess import call
import sys import sys
from textwrap import dedent
from typing import Any from typing import Any
import aioesphomeapi.api_options_pb2 as pb import aioesphomeapi.api_options_pb2 as pb
@ -157,13 +156,7 @@ class TypeInfo(ABC):
content = self.decode_varint content = self.decode_varint
if content is None: if content is None:
return None return None
return dedent( return f"case {self.number}: this->{self.field_name} = {content}; break;"
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
decode_varint = None decode_varint = None
@ -172,13 +165,7 @@ class TypeInfo(ABC):
content = self.decode_length content = self.decode_length
if content is None: if content is None:
return None return None
return dedent( return f"case {self.number}: this->{self.field_name} = {content}; break;"
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
decode_length = None decode_length = None
@ -187,13 +174,7 @@ class TypeInfo(ABC):
content = self.decode_32bit content = self.decode_32bit
if content is None: if content is None:
return None return None
return dedent( return f"case {self.number}: this->{self.field_name} = {content}; break;"
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
decode_32bit = None decode_32bit = None
@ -202,13 +183,7 @@ class TypeInfo(ABC):
content = self.decode_64bit content = self.decode_64bit
if content is None: if content is None:
return None return None
return dedent( return f"case {self.number}: this->{self.field_name} = {content}; break;"
f"""\
case {self.number}: {{
this->{self.field_name} = {content};
return true;
}}"""
)
decode_64bit = None decode_64bit = None
@ -580,13 +555,7 @@ class MessageType(TypeInfo):
@property @property
def decode_length_content(self) -> str: def decode_length_content(self) -> str:
# Custom decode that doesn't use templates # Custom decode that doesn't use templates
return dedent( return f"case {self.number}: value.decode_to_message(this->{self.field_name}); break;"
f"""\
case {self.number}: {{
value.decode_to_message(this->{self.field_name});
return true;
}}"""
)
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f"{name}.dump_to(out);" o = f"{name}.dump_to(out);"
@ -797,12 +766,8 @@ class RepeatedTypeInfo(TypeInfo):
content = self._ti.decode_varint content = self._ti.decode_varint
if content is None: if content is None:
return None return None
return dedent( return (
f"""\ f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
) )
@property @property
@ -810,22 +775,11 @@ class RepeatedTypeInfo(TypeInfo):
content = self._ti.decode_length content = self._ti.decode_length
if content is None and isinstance(self._ti, MessageType): if content is None and isinstance(self._ti, MessageType):
# Special handling for non-template message decoding # Special handling for non-template message decoding
return dedent( return f"case {self.number}: this->{self.field_name}.emplace_back(); value.decode_to_message(this->{self.field_name}.back()); break;"
f"""\
case {self.number}: {{
this->{self.field_name}.emplace_back();
value.decode_to_message(this->{self.field_name}.back());
return true;
}}"""
)
if content is None: if content is None:
return None return None
return dedent( return (
f"""\ f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
) )
@property @property
@ -833,12 +787,8 @@ class RepeatedTypeInfo(TypeInfo):
content = self._ti.decode_32bit content = self._ti.decode_32bit
if content is None: if content is None:
return None return None
return dedent( return (
f"""\ f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
) )
@property @property
@ -846,12 +796,8 @@ class RepeatedTypeInfo(TypeInfo):
content = self._ti.decode_64bit content = self._ti.decode_64bit
if content is None: if content is None:
return None return None
return dedent( return (
f"""\ f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
case {self.number}: {{
this->{self.field_name}.push_back({content});
return true;
}}"""
) )
@property @property
@ -1155,41 +1101,45 @@ def build_message_type(
cpp = "" cpp = ""
if decode_varint: if decode_varint:
decode_varint.append("default:\n return false;")
o = f"bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n" o = f"bool {desc.name}::decode_varint(uint32_t field_id, ProtoVarInt value) {{\n"
o += " switch (field_id) {\n" o += " switch (field_id) {\n"
o += indent("\n".join(decode_varint), " ") + "\n" o += indent("\n".join(decode_varint), " ") + "\n"
o += " default: return false;\n"
o += " }\n" o += " }\n"
o += " return true;\n"
o += "}\n" o += "}\n"
cpp += o cpp += o
prot = "bool decode_varint(uint32_t field_id, ProtoVarInt value) override;" prot = "bool decode_varint(uint32_t field_id, ProtoVarInt value) override;"
protected_content.insert(0, prot) protected_content.insert(0, prot)
if decode_length: if decode_length:
decode_length.append("default:\n return false;")
o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n" o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n"
o += " switch (field_id) {\n" o += " switch (field_id) {\n"
o += indent("\n".join(decode_length), " ") + "\n" o += indent("\n".join(decode_length), " ") + "\n"
o += " default: return false;\n"
o += " }\n" o += " }\n"
o += " return true;\n"
o += "}\n" o += "}\n"
cpp += o cpp += o
prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;" prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;"
protected_content.insert(0, prot) protected_content.insert(0, prot)
if decode_32bit: if decode_32bit:
decode_32bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n" o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n"
o += " switch (field_id) {\n" o += " switch (field_id) {\n"
o += indent("\n".join(decode_32bit), " ") + "\n" o += indent("\n".join(decode_32bit), " ") + "\n"
o += " default: return false;\n"
o += " }\n" o += " }\n"
o += " return true;\n"
o += "}\n" o += "}\n"
cpp += o cpp += o
prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;" prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;"
protected_content.insert(0, prot) protected_content.insert(0, prot)
if decode_64bit: if decode_64bit:
decode_64bit.append("default:\n return false;")
o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n" o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n"
o += " switch (field_id) {\n" o += " switch (field_id) {\n"
o += indent("\n".join(decode_64bit), " ") + "\n" o += indent("\n".join(decode_64bit), " ") + "\n"
o += " default: return false;\n"
o += " }\n" o += " }\n"
o += " return true;\n"
o += "}\n" o += "}\n"
cpp += o cpp += o
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;" prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"