Add typing to protobuf code generator (#8541)

This commit is contained in:
J. Nick Koston 2025-04-15 10:19:22 -10:00 committed by GitHub
parent 7e133171e0
commit 3677ef71d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
from abc import ABC, abstractmethod
import os
from pathlib import Path
import re
from subprocess import call
import sys
from textwrap import dedent
from typing import Any
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
"""Python 3 script to automatically generate C++ classes for ESPHome's native API. """Python 3 script to automatically generate C++ classes for ESPHome's native API.
It's pretty crappy spaghetti code, but it works. It's pretty crappy spaghetti code, but it works.
@ -17,25 +33,14 @@ then run this script with python3 and the files
will be generated, they still need to be formatted will be generated, they still need to be formatted
""" """
from abc import ABC, abstractmethod
import os
from pathlib import Path
import re
from subprocess import call
import sys
from textwrap import dedent
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
FILE_HEADER = """// This file was automatically generated with a tool. FILE_HEADER = """// This file was automatically generated with a tool.
// See scripts/api_protobuf/api_protobuf.py // See scripts/api_protobuf/api_protobuf.py
""" """
def indent_list(text, padding=" "): def indent_list(text: str, padding: str = " ") -> list[str]:
"""Indent each line of the given text with the specified padding."""
lines = [] lines = []
for line in text.splitlines(): for line in text.splitlines():
if line == "": if line == "":
@ -48,54 +53,62 @@ def indent_list(text, padding=" "):
return lines return lines
def indent(text, padding=" "): def indent(text: str, padding: str = " ") -> str:
return "\n".join(indent_list(text, padding)) return "\n".join(indent_list(text, padding))
def camel_to_snake(name): def camel_to_snake(name: str) -> str:
# https://stackoverflow.com/a/1176023 # https://stackoverflow.com/a/1176023
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
class TypeInfo(ABC): class TypeInfo(ABC):
def __init__(self, field): """Base class for all type information."""
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
self._field = field self._field = field
@property @property
def default_value(self): def default_value(self) -> str:
"""Get the default value."""
return "" return ""
@property @property
def name(self): def name(self) -> str:
"""Get the name of the field."""
return self._field.name return self._field.name
@property @property
def arg_name(self): def arg_name(self) -> str:
"""Get the argument name."""
return self.name return self.name
@property @property
def field_name(self): def field_name(self) -> str:
"""Get the field name."""
return self.name return self.name
@property @property
def number(self): def number(self) -> int:
"""Get the field number."""
return self._field.number return self._field.number
@property @property
def repeated(self): def repeated(self) -> bool:
"""Check if the field is repeated."""
return self._field.label == 3 return self._field.label == 3
@property @property
def cpp_type(self): def cpp_type(self) -> str:
raise NotImplementedError raise NotImplementedError
@property @property
def reference_type(self): def reference_type(self) -> str:
return f"{self.cpp_type} " return f"{self.cpp_type} "
@property @property
def const_reference_type(self): def const_reference_type(self) -> str:
return f"{self.cpp_type} " return f"{self.cpp_type} "
@property @property
@ -171,28 +184,31 @@ class TypeInfo(ABC):
decode_64bit = None decode_64bit = None
@property @property
def encode_content(self): def encode_content(self) -> str:
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});" return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
encode_func = None encode_func = None
@property @property
def dump_content(self): def dump_content(self) -> str:
o = f'out.append(" {self.name}: ");\n' o = f'out.append(" {self.name}: ");\n'
o += self.dump(f"this->{self.field_name}") + "\n" o += self.dump(f"this->{self.field_name}") + "\n"
o += 'out.append("\\n");\n' o += 'out.append("\\n");\n'
return o return o
@abstractmethod @abstractmethod
def dump(self, name: str): def dump(self, name: str) -> str:
pass """Dump the value to the output."""
TYPE_INFO = {} TYPE_INFO: dict[int, TypeInfo] = {}
def register_type(name): def register_type(name: int):
def func(value): """Decorator to register a type with a name and number."""
def func(value: TypeInfo) -> TypeInfo:
"""Register the type with the given name and number."""
TYPE_INFO[name] = value TYPE_INFO[name] = value
return value return value
@ -206,7 +222,7 @@ class DoubleType(TypeInfo):
decode_64bit = "value.as_double()" decode_64bit = "value.as_double()"
encode_func = "encode_double" encode_func = "encode_double"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%g", {name});\n' o = f'sprintf(buffer, "%g", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -219,7 +235,7 @@ class FloatType(TypeInfo):
decode_32bit = "value.as_float()" decode_32bit = "value.as_float()"
encode_func = "encode_float" encode_func = "encode_float"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%g", {name});\n' o = f'sprintf(buffer, "%g", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -232,7 +248,7 @@ class Int64Type(TypeInfo):
decode_varint = "value.as_int64()" decode_varint = "value.as_int64()"
encode_func = "encode_int64" encode_func = "encode_int64"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -245,7 +261,7 @@ class UInt64Type(TypeInfo):
decode_varint = "value.as_uint64()" decode_varint = "value.as_uint64()"
encode_func = "encode_uint64" encode_func = "encode_uint64"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%llu", {name});\n' o = f'sprintf(buffer, "%llu", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -258,7 +274,7 @@ class Int32Type(TypeInfo):
decode_varint = "value.as_int32()" decode_varint = "value.as_int32()"
encode_func = "encode_int32" encode_func = "encode_int32"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -271,7 +287,7 @@ class Fixed64Type(TypeInfo):
decode_64bit = "value.as_fixed64()" decode_64bit = "value.as_fixed64()"
encode_func = "encode_fixed64" encode_func = "encode_fixed64"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%llu", {name});\n' o = f'sprintf(buffer, "%llu", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -284,7 +300,7 @@ class Fixed32Type(TypeInfo):
decode_32bit = "value.as_fixed32()" decode_32bit = "value.as_fixed32()"
encode_func = "encode_fixed32" encode_func = "encode_fixed32"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRIu32, {name});\n' o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -297,7 +313,7 @@ class BoolType(TypeInfo):
decode_varint = "value.as_bool()" decode_varint = "value.as_bool()"
encode_func = "encode_bool" encode_func = "encode_bool"
def dump(self, name): def dump(self, name: str) -> str:
o = f"out.append(YESNO({name}));" o = f"out.append(YESNO({name}));"
return o return o
@ -319,28 +335,28 @@ class StringType(TypeInfo):
@register_type(11) @register_type(11)
class MessageType(TypeInfo): class MessageType(TypeInfo):
@property @property
def cpp_type(self): def cpp_type(self) -> str:
return self._field.type_name[1:] return self._field.type_name[1:]
default_value = "" default_value = ""
@property @property
def reference_type(self): def reference_type(self) -> str:
return f"{self.cpp_type} &" return f"{self.cpp_type} &"
@property @property
def const_reference_type(self): def const_reference_type(self) -> str:
return f"const {self.cpp_type} &" return f"const {self.cpp_type} &"
@property @property
def encode_func(self): def encode_func(self) -> str:
return f"encode_message<{self.cpp_type}>" return f"encode_message<{self.cpp_type}>"
@property @property
def decode_length(self): def decode_length(self) -> str:
return f"value.as_message<{self.cpp_type}>()" return f"value.as_message<{self.cpp_type}>()"
def dump(self, name): def dump(self, name: str) -> str:
o = f"{name}.dump_to(out);" o = f"{name}.dump_to(out);"
return o return o
@ -354,7 +370,7 @@ class BytesType(TypeInfo):
decode_length = "value.as_string()" decode_length = "value.as_string()"
encode_func = "encode_string" encode_func = "encode_string"
def dump(self, name): def dump(self, name: str) -> str:
o = f'out.append("\'").append({name}).append("\'");' o = f'out.append("\'").append({name}).append("\'");'
return o return o
@ -366,7 +382,7 @@ class UInt32Type(TypeInfo):
decode_varint = "value.as_uint32()" decode_varint = "value.as_uint32()"
encode_func = "encode_uint32" encode_func = "encode_uint32"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRIu32, {name});\n' o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -375,20 +391,20 @@ class UInt32Type(TypeInfo):
@register_type(14) @register_type(14)
class EnumType(TypeInfo): class EnumType(TypeInfo):
@property @property
def cpp_type(self): def cpp_type(self) -> str:
return f"enums::{self._field.type_name[1:]}" return f"enums::{self._field.type_name[1:]}"
@property @property
def decode_varint(self): def decode_varint(self) -> str:
return f"value.as_enum<{self.cpp_type}>()" return f"value.as_enum<{self.cpp_type}>()"
default_value = "" default_value = ""
@property @property
def encode_func(self): def encode_func(self) -> str:
return f"encode_enum<{self.cpp_type}>" return f"encode_enum<{self.cpp_type}>"
def dump(self, name): def dump(self, name: str) -> str:
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));" o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
return o return o
@ -400,7 +416,7 @@ class SFixed32Type(TypeInfo):
decode_32bit = "value.as_sfixed32()" decode_32bit = "value.as_sfixed32()"
encode_func = "encode_sfixed32" encode_func = "encode_sfixed32"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -413,7 +429,7 @@ class SFixed64Type(TypeInfo):
decode_64bit = "value.as_sfixed64()" decode_64bit = "value.as_sfixed64()"
encode_func = "encode_sfixed64" encode_func = "encode_sfixed64"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -426,7 +442,7 @@ class SInt32Type(TypeInfo):
decode_varint = "value.as_sint32()" decode_varint = "value.as_sint32()"
encode_func = "encode_sint32" encode_func = "encode_sint32"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
@ -439,27 +455,27 @@ class SInt64Type(TypeInfo):
decode_varint = "value.as_sint64()" decode_varint = "value.as_sint64()"
encode_func = "encode_sint64" encode_func = "encode_sint64"
def dump(self, name): def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
class RepeatedTypeInfo(TypeInfo): class RepeatedTypeInfo(TypeInfo):
def __init__(self, field): def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
super().__init__(field) super().__init__(field)
self._ti = TYPE_INFO[field.type](field) self._ti: TypeInfo = TYPE_INFO[field.type](field)
@property @property
def cpp_type(self): def cpp_type(self) -> str:
return f"std::vector<{self._ti.cpp_type}>" return f"std::vector<{self._ti.cpp_type}>"
@property @property
def reference_type(self): def reference_type(self) -> str:
return f"{self.cpp_type} &" return f"{self.cpp_type} &"
@property @property
def const_reference_type(self): def const_reference_type(self) -> str:
return f"const {self.cpp_type} &" return f"const {self.cpp_type} &"
@property @property
@ -515,19 +531,19 @@ class RepeatedTypeInfo(TypeInfo):
) )
@property @property
def _ti_is_bool(self): def _ti_is_bool(self) -> bool:
# std::vector is specialized for bool, reference does not work # std::vector is specialized for bool, reference does not work
return isinstance(self._ti, BoolType) return isinstance(self._ti, BoolType)
@property @property
def encode_content(self): def encode_content(self) -> str:
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n" o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
o += "}" o += "}"
return o return o
@property @property
def dump_content(self): def dump_content(self) -> str:
o = f"for (const auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o = f"for (const auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
o += f' out.append(" {self.name}: ");\n' o += f' out.append(" {self.name}: ");\n'
o += indent(self._ti.dump("it")) + "\n" o += indent(self._ti.dump("it")) + "\n"
@ -539,7 +555,8 @@ class RepeatedTypeInfo(TypeInfo):
pass pass
def build_enum_type(desc): def build_enum_type(desc) -> tuple[str, str]:
"""Builds the enum type."""
name = desc.name name = desc.name
out = f"enum {name} : uint32_t {{\n" out = f"enum {name} : uint32_t {{\n"
for v in desc.value: for v in desc.value:
@ -561,15 +578,15 @@ def build_enum_type(desc):
return out, cpp return out, cpp
def build_message_type(desc): def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
public_content = [] public_content: list[str] = []
protected_content = [] protected_content: list[str] = []
decode_varint = [] decode_varint: list[str] = []
decode_length = [] decode_length: list[str] = []
decode_32bit = [] decode_32bit: list[str] = []
decode_64bit = [] decode_64bit: list[str] = []
encode = [] encode: list[str] = []
dump = [] dump: list[str] = []
for field in desc.field: for field in desc.field:
if field.label == 3: if field.label == 3:
@ -687,27 +704,35 @@ SOURCE_BOTH = 0
SOURCE_SERVER = 1 SOURCE_SERVER = 1
SOURCE_CLIENT = 2 SOURCE_CLIENT = 2
RECEIVE_CASES = {} RECEIVE_CASES: dict[int, str] = {}
ifdefs = {} ifdefs: dict[str, str] = {}
def get_opt(desc, opt, default=None): def get_opt(
desc: descriptor.DescriptorProto,
opt: descriptor.MessageOptions,
default: Any = None,
) -> Any:
"""Get the option from the descriptor."""
if not desc.options.HasExtension(opt): if not desc.options.HasExtension(opt):
return default return default
return desc.options.Extensions[opt] return desc.options.Extensions[opt]
def build_service_message_type(mt): def build_service_message_type(
mt: descriptor.DescriptorProto,
) -> tuple[str, str] | None:
"""Builds the service message type."""
snake = camel_to_snake(mt.name) snake = camel_to_snake(mt.name)
id_ = get_opt(mt, pb.id) id_: int | None = get_opt(mt, pb.id)
if id_ is None: if id_ is None:
return None return None
source = get_opt(mt, pb.source, 0) source: int = get_opt(mt, pb.source, 0)
ifdef = get_opt(mt, pb.ifdef) ifdef: str | None = get_opt(mt, pb.ifdef)
log = get_opt(mt, pb.log, True) log: bool = get_opt(mt, pb.log, True)
hout = "" hout = ""
cout = "" cout = ""
@ -754,7 +779,8 @@ def build_service_message_type(mt):
return hout, cout return hout, cout
def main(): def main() -> None:
"""Main function to generate the C++ classes."""
cwd = Path(__file__).resolve().parent cwd = Path(__file__).resolve().parent
root = cwd.parent.parent / "esphome" / "components" / "api" root = cwd.parent.parent / "esphome" / "components" / "api"
prot_file = root / "api.protoc" prot_file = root / "api.protoc"
@ -959,7 +985,7 @@ def main():
try: try:
import clang_format import clang_format
def exec_clang_format(path): def exec_clang_format(path: Path) -> None:
clang_format_path = os.path.join( clang_format_path = os.path.join(
os.path.dirname(clang_format.__file__), "data", "bin", "clang-format" os.path.dirname(clang_format.__file__), "data", "bin", "clang-format"
) )