Add typing to protobuf code generator (#8541)

This commit is contained in:
J. Nick Koston 2025-04-15 10:19:22 -10:00 committed by GitHub
parent 7e133171e0
commit 3677ef71d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,20 @@
#!/usr/bin/env python3
from __future__ import annotations
from abc import ABC, abstractmethod
import os
from pathlib import Path
import re
from subprocess import call
import sys
from textwrap import dedent
from typing import Any
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
It's pretty crappy spaghetti code, but it works.
@ -17,25 +33,14 @@ then run this script with python3 and the files
will be generated, they still need to be formatted
"""
from abc import ABC, abstractmethod
import os
from pathlib import Path
import re
from subprocess import call
import sys
from textwrap import dedent
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
FILE_HEADER = """// This file was automatically generated with a tool.
// See scripts/api_protobuf/api_protobuf.py
"""
def indent_list(text, padding=" "):
def indent_list(text: str, padding: str = " ") -> list[str]:
"""Indent each line of the given text with the specified padding."""
lines = []
for line in text.splitlines():
if line == "":
@ -48,54 +53,62 @@ def indent_list(text, padding=" "):
return lines
def indent(text, padding=" "):
def indent(text: str, padding: str = " ") -> str:
return "\n".join(indent_list(text, padding))
def camel_to_snake(name):
def camel_to_snake(name: str) -> str:
# https://stackoverflow.com/a/1176023
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
class TypeInfo(ABC):
def __init__(self, field):
"""Base class for all type information."""
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
self._field = field
@property
def default_value(self):
def default_value(self) -> str:
"""Get the default value."""
return ""
@property
def name(self):
def name(self) -> str:
"""Get the name of the field."""
return self._field.name
@property
def arg_name(self):
def arg_name(self) -> str:
"""Get the argument name."""
return self.name
@property
def field_name(self):
def field_name(self) -> str:
"""Get the field name."""
return self.name
@property
def number(self):
def number(self) -> int:
"""Get the field number."""
return self._field.number
@property
def repeated(self):
def repeated(self) -> bool:
"""Check if the field is repeated."""
return self._field.label == 3
@property
def cpp_type(self):
def cpp_type(self) -> str:
raise NotImplementedError
@property
def reference_type(self):
def reference_type(self) -> str:
return f"{self.cpp_type} "
@property
def const_reference_type(self):
def const_reference_type(self) -> str:
return f"{self.cpp_type} "
@property
@ -171,28 +184,31 @@ class TypeInfo(ABC):
decode_64bit = None
@property
def encode_content(self):
def encode_content(self) -> str:
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
encode_func = None
@property
def dump_content(self):
def dump_content(self) -> str:
o = f'out.append(" {self.name}: ");\n'
o += self.dump(f"this->{self.field_name}") + "\n"
o += 'out.append("\\n");\n'
return o
@abstractmethod
def dump(self, name: str):
pass
def dump(self, name: str) -> str:
"""Dump the value to the output."""
TYPE_INFO = {}
TYPE_INFO: dict[int, TypeInfo] = {}
def register_type(name):
def func(value):
def register_type(name: int):
"""Decorator to register a type with a name and number."""
def func(value: TypeInfo) -> TypeInfo:
"""Register the type with the given name and number."""
TYPE_INFO[name] = value
return value
@ -206,7 +222,7 @@ class DoubleType(TypeInfo):
decode_64bit = "value.as_double()"
encode_func = "encode_double"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%g", {name});\n'
o += "out.append(buffer);"
return o
@ -219,7 +235,7 @@ class FloatType(TypeInfo):
decode_32bit = "value.as_float()"
encode_func = "encode_float"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%g", {name});\n'
o += "out.append(buffer);"
return o
@ -232,7 +248,7 @@ class Int64Type(TypeInfo):
decode_varint = "value.as_int64()"
encode_func = "encode_int64"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);"
return o
@ -245,7 +261,7 @@ class UInt64Type(TypeInfo):
decode_varint = "value.as_uint64()"
encode_func = "encode_uint64"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%llu", {name});\n'
o += "out.append(buffer);"
return o
@ -258,7 +274,7 @@ class Int32Type(TypeInfo):
decode_varint = "value.as_int32()"
encode_func = "encode_int32"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);"
return o
@ -271,7 +287,7 @@ class Fixed64Type(TypeInfo):
decode_64bit = "value.as_fixed64()"
encode_func = "encode_fixed64"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%llu", {name});\n'
o += "out.append(buffer);"
return o
@ -284,7 +300,7 @@ class Fixed32Type(TypeInfo):
decode_32bit = "value.as_fixed32()"
encode_func = "encode_fixed32"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += "out.append(buffer);"
return o
@ -297,7 +313,7 @@ class BoolType(TypeInfo):
decode_varint = "value.as_bool()"
encode_func = "encode_bool"
def dump(self, name):
def dump(self, name: str) -> str:
o = f"out.append(YESNO({name}));"
return o
@ -319,28 +335,28 @@ class StringType(TypeInfo):
@register_type(11)
class MessageType(TypeInfo):
@property
def cpp_type(self):
def cpp_type(self) -> str:
return self._field.type_name[1:]
default_value = ""
@property
def reference_type(self):
def reference_type(self) -> str:
return f"{self.cpp_type} &"
@property
def const_reference_type(self):
def const_reference_type(self) -> str:
return f"const {self.cpp_type} &"
@property
def encode_func(self):
def encode_func(self) -> str:
return f"encode_message<{self.cpp_type}>"
@property
def decode_length(self):
def decode_length(self) -> str:
return f"value.as_message<{self.cpp_type}>()"
def dump(self, name):
def dump(self, name: str) -> str:
o = f"{name}.dump_to(out);"
return o
@ -354,7 +370,7 @@ class BytesType(TypeInfo):
decode_length = "value.as_string()"
encode_func = "encode_string"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'out.append("\'").append({name}).append("\'");'
return o
@ -366,7 +382,7 @@ class UInt32Type(TypeInfo):
decode_varint = "value.as_uint32()"
encode_func = "encode_uint32"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += "out.append(buffer);"
return o
@ -375,20 +391,20 @@ class UInt32Type(TypeInfo):
@register_type(14)
class EnumType(TypeInfo):
@property
def cpp_type(self):
def cpp_type(self) -> str:
return f"enums::{self._field.type_name[1:]}"
@property
def decode_varint(self):
def decode_varint(self) -> str:
return f"value.as_enum<{self.cpp_type}>()"
default_value = ""
@property
def encode_func(self):
def encode_func(self) -> str:
return f"encode_enum<{self.cpp_type}>"
def dump(self, name):
def dump(self, name: str) -> str:
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
return o
@ -400,7 +416,7 @@ class SFixed32Type(TypeInfo):
decode_32bit = "value.as_sfixed32()"
encode_func = "encode_sfixed32"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);"
return o
@ -413,7 +429,7 @@ class SFixed64Type(TypeInfo):
decode_64bit = "value.as_sfixed64()"
encode_func = "encode_sfixed64"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);"
return o
@ -426,7 +442,7 @@ class SInt32Type(TypeInfo):
decode_varint = "value.as_sint32()"
encode_func = "encode_sint32"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);"
return o
@ -439,27 +455,27 @@ class SInt64Type(TypeInfo):
decode_varint = "value.as_sint64()"
encode_func = "encode_sint64"
def dump(self, name):
def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);"
return o
class RepeatedTypeInfo(TypeInfo):
def __init__(self, field):
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
super().__init__(field)
self._ti = TYPE_INFO[field.type](field)
self._ti: TypeInfo = TYPE_INFO[field.type](field)
@property
def cpp_type(self):
def cpp_type(self) -> str:
return f"std::vector<{self._ti.cpp_type}>"
@property
def reference_type(self):
def reference_type(self) -> str:
return f"{self.cpp_type} &"
@property
def const_reference_type(self):
def const_reference_type(self) -> str:
return f"const {self.cpp_type} &"
@property
@ -515,19 +531,19 @@ class RepeatedTypeInfo(TypeInfo):
)
@property
def _ti_is_bool(self):
def _ti_is_bool(self) -> bool:
# std::vector is specialized for bool, reference does not work
return isinstance(self._ti, BoolType)
@property
def encode_content(self):
def encode_content(self) -> str:
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
o += "}"
return o
@property
def dump_content(self):
def dump_content(self) -> str:
o = f"for (const auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
o += f' out.append(" {self.name}: ");\n'
o += indent(self._ti.dump("it")) + "\n"
@ -539,7 +555,8 @@ class RepeatedTypeInfo(TypeInfo):
pass
def build_enum_type(desc):
def build_enum_type(desc) -> tuple[str, str]:
"""Builds the enum type."""
name = desc.name
out = f"enum {name} : uint32_t {{\n"
for v in desc.value:
@ -561,15 +578,15 @@ def build_enum_type(desc):
return out, cpp
def build_message_type(desc):
public_content = []
protected_content = []
decode_varint = []
decode_length = []
decode_32bit = []
decode_64bit = []
encode = []
dump = []
def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
public_content: list[str] = []
protected_content: list[str] = []
decode_varint: list[str] = []
decode_length: list[str] = []
decode_32bit: list[str] = []
decode_64bit: list[str] = []
encode: list[str] = []
dump: list[str] = []
for field in desc.field:
if field.label == 3:
@ -687,27 +704,35 @@ SOURCE_BOTH = 0
SOURCE_SERVER = 1
SOURCE_CLIENT = 2
RECEIVE_CASES = {}
RECEIVE_CASES: dict[int, str] = {}
ifdefs = {}
ifdefs: dict[str, str] = {}
def get_opt(desc, opt, default=None):
def get_opt(
desc: descriptor.DescriptorProto,
opt: descriptor.MessageOptions,
default: Any = None,
) -> Any:
"""Get the option from the descriptor."""
if not desc.options.HasExtension(opt):
return default
return desc.options.Extensions[opt]
def build_service_message_type(mt):
def build_service_message_type(
mt: descriptor.DescriptorProto,
) -> tuple[str, str] | None:
"""Builds the service message type."""
snake = camel_to_snake(mt.name)
id_ = get_opt(mt, pb.id)
id_: int | None = get_opt(mt, pb.id)
if id_ is None:
return None
source = get_opt(mt, pb.source, 0)
source: int = get_opt(mt, pb.source, 0)
ifdef = get_opt(mt, pb.ifdef)
log = get_opt(mt, pb.log, True)
ifdef: str | None = get_opt(mt, pb.ifdef)
log: bool = get_opt(mt, pb.log, True)
hout = ""
cout = ""
@ -754,7 +779,8 @@ def build_service_message_type(mt):
return hout, cout
def main():
def main() -> None:
"""Main function to generate the C++ classes."""
cwd = Path(__file__).resolve().parent
root = cwd.parent.parent / "esphome" / "components" / "api"
prot_file = root / "api.protoc"
@ -959,7 +985,7 @@ def main():
try:
import clang_format
def exec_clang_format(path):
def exec_clang_format(path: Path) -> None:
clang_format_path = os.path.join(
os.path.dirname(clang_format.__file__), "data", "bin", "clang-format"
)