mirror of
https://github.com/esphome/esphome.git
synced 2025-04-19 13:17:19 +00:00
Add typing to protobuf code generator (#8541)
This commit is contained in:
parent
7e133171e0
commit
3677ef71d1
@ -1,4 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
from subprocess import call
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
# Generate with
|
||||
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
|
||||
import aioesphomeapi.api_options_pb2 as pb
|
||||
import google.protobuf.descriptor_pb2 as descriptor
|
||||
|
||||
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
|
||||
|
||||
It's pretty crappy spaghetti code, but it works.
|
||||
@ -17,25 +33,14 @@ then run this script with python3 and the files
|
||||
will be generated, they still need to be formatted
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
from subprocess import call
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
|
||||
# Generate with
|
||||
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
|
||||
import aioesphomeapi.api_options_pb2 as pb
|
||||
import google.protobuf.descriptor_pb2 as descriptor
|
||||
|
||||
FILE_HEADER = """// This file was automatically generated with a tool.
|
||||
// See scripts/api_protobuf/api_protobuf.py
|
||||
"""
|
||||
|
||||
|
||||
def indent_list(text, padding=" "):
|
||||
def indent_list(text: str, padding: str = " ") -> list[str]:
|
||||
"""Indent each line of the given text with the specified padding."""
|
||||
lines = []
|
||||
for line in text.splitlines():
|
||||
if line == "":
|
||||
@ -48,54 +53,62 @@ def indent_list(text, padding=" "):
|
||||
return lines
|
||||
|
||||
|
||||
def indent(text, padding=" "):
|
||||
def indent(text: str, padding: str = " ") -> str:
|
||||
return "\n".join(indent_list(text, padding))
|
||||
|
||||
|
||||
def camel_to_snake(name):
|
||||
def camel_to_snake(name: str) -> str:
|
||||
# https://stackoverflow.com/a/1176023
|
||||
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
|
||||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
||||
|
||||
|
||||
class TypeInfo(ABC):
|
||||
def __init__(self, field):
|
||||
"""Base class for all type information."""
|
||||
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
self._field = field
|
||||
|
||||
@property
|
||||
def default_value(self):
|
||||
def default_value(self) -> str:
|
||||
"""Get the default value."""
|
||||
return ""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
"""Get the name of the field."""
|
||||
return self._field.name
|
||||
|
||||
@property
|
||||
def arg_name(self):
|
||||
def arg_name(self) -> str:
|
||||
"""Get the argument name."""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def field_name(self):
|
||||
def field_name(self) -> str:
|
||||
"""Get the field name."""
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def number(self):
|
||||
def number(self) -> int:
|
||||
"""Get the field number."""
|
||||
return self._field.number
|
||||
|
||||
@property
|
||||
def repeated(self):
|
||||
def repeated(self) -> bool:
|
||||
"""Check if the field is repeated."""
|
||||
return self._field.label == 3
|
||||
|
||||
@property
|
||||
def cpp_type(self):
|
||||
def cpp_type(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def reference_type(self):
|
||||
def reference_type(self) -> str:
|
||||
return f"{self.cpp_type} "
|
||||
|
||||
@property
|
||||
def const_reference_type(self):
|
||||
def const_reference_type(self) -> str:
|
||||
return f"{self.cpp_type} "
|
||||
|
||||
@property
|
||||
@ -171,28 +184,31 @@ class TypeInfo(ABC):
|
||||
decode_64bit = None
|
||||
|
||||
@property
|
||||
def encode_content(self):
|
||||
def encode_content(self) -> str:
|
||||
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
|
||||
|
||||
encode_func = None
|
||||
|
||||
@property
|
||||
def dump_content(self):
|
||||
def dump_content(self) -> str:
|
||||
o = f'out.append(" {self.name}: ");\n'
|
||||
o += self.dump(f"this->{self.field_name}") + "\n"
|
||||
o += 'out.append("\\n");\n'
|
||||
return o
|
||||
|
||||
@abstractmethod
|
||||
def dump(self, name: str):
|
||||
pass
|
||||
def dump(self, name: str) -> str:
|
||||
"""Dump the value to the output."""
|
||||
|
||||
|
||||
TYPE_INFO = {}
|
||||
TYPE_INFO: dict[int, TypeInfo] = {}
|
||||
|
||||
|
||||
def register_type(name):
|
||||
def func(value):
|
||||
def register_type(name: int):
|
||||
"""Decorator to register a type with a name and number."""
|
||||
|
||||
def func(value: TypeInfo) -> TypeInfo:
|
||||
"""Register the type with the given name and number."""
|
||||
TYPE_INFO[name] = value
|
||||
return value
|
||||
|
||||
@ -206,7 +222,7 @@ class DoubleType(TypeInfo):
|
||||
decode_64bit = "value.as_double()"
|
||||
encode_func = "encode_double"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%g", {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -219,7 +235,7 @@ class FloatType(TypeInfo):
|
||||
decode_32bit = "value.as_float()"
|
||||
encode_func = "encode_float"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%g", {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -232,7 +248,7 @@ class Int64Type(TypeInfo):
|
||||
decode_varint = "value.as_int64()"
|
||||
encode_func = "encode_int64"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -245,7 +261,7 @@ class UInt64Type(TypeInfo):
|
||||
decode_varint = "value.as_uint64()"
|
||||
encode_func = "encode_uint64"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%llu", {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -258,7 +274,7 @@ class Int32Type(TypeInfo):
|
||||
decode_varint = "value.as_int32()"
|
||||
encode_func = "encode_int32"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -271,7 +287,7 @@ class Fixed64Type(TypeInfo):
|
||||
decode_64bit = "value.as_fixed64()"
|
||||
encode_func = "encode_fixed64"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%llu", {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -284,7 +300,7 @@ class Fixed32Type(TypeInfo):
|
||||
decode_32bit = "value.as_fixed32()"
|
||||
encode_func = "encode_fixed32"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -297,7 +313,7 @@ class BoolType(TypeInfo):
|
||||
decode_varint = "value.as_bool()"
|
||||
encode_func = "encode_bool"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f"out.append(YESNO({name}));"
|
||||
return o
|
||||
|
||||
@ -319,28 +335,28 @@ class StringType(TypeInfo):
|
||||
@register_type(11)
|
||||
class MessageType(TypeInfo):
|
||||
@property
|
||||
def cpp_type(self):
|
||||
def cpp_type(self) -> str:
|
||||
return self._field.type_name[1:]
|
||||
|
||||
default_value = ""
|
||||
|
||||
@property
|
||||
def reference_type(self):
|
||||
def reference_type(self) -> str:
|
||||
return f"{self.cpp_type} &"
|
||||
|
||||
@property
|
||||
def const_reference_type(self):
|
||||
def const_reference_type(self) -> str:
|
||||
return f"const {self.cpp_type} &"
|
||||
|
||||
@property
|
||||
def encode_func(self):
|
||||
def encode_func(self) -> str:
|
||||
return f"encode_message<{self.cpp_type}>"
|
||||
|
||||
@property
|
||||
def decode_length(self):
|
||||
def decode_length(self) -> str:
|
||||
return f"value.as_message<{self.cpp_type}>()"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f"{name}.dump_to(out);"
|
||||
return o
|
||||
|
||||
@ -354,7 +370,7 @@ class BytesType(TypeInfo):
|
||||
decode_length = "value.as_string()"
|
||||
encode_func = "encode_string"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'out.append("\'").append({name}).append("\'");'
|
||||
return o
|
||||
|
||||
@ -366,7 +382,7 @@ class UInt32Type(TypeInfo):
|
||||
decode_varint = "value.as_uint32()"
|
||||
encode_func = "encode_uint32"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -375,20 +391,20 @@ class UInt32Type(TypeInfo):
|
||||
@register_type(14)
|
||||
class EnumType(TypeInfo):
|
||||
@property
|
||||
def cpp_type(self):
|
||||
def cpp_type(self) -> str:
|
||||
return f"enums::{self._field.type_name[1:]}"
|
||||
|
||||
@property
|
||||
def decode_varint(self):
|
||||
def decode_varint(self) -> str:
|
||||
return f"value.as_enum<{self.cpp_type}>()"
|
||||
|
||||
default_value = ""
|
||||
|
||||
@property
|
||||
def encode_func(self):
|
||||
def encode_func(self) -> str:
|
||||
return f"encode_enum<{self.cpp_type}>"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
|
||||
return o
|
||||
|
||||
@ -400,7 +416,7 @@ class SFixed32Type(TypeInfo):
|
||||
decode_32bit = "value.as_sfixed32()"
|
||||
encode_func = "encode_sfixed32"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -413,7 +429,7 @@ class SFixed64Type(TypeInfo):
|
||||
decode_64bit = "value.as_sfixed64()"
|
||||
encode_func = "encode_sfixed64"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -426,7 +442,7 @@ class SInt32Type(TypeInfo):
|
||||
decode_varint = "value.as_sint32()"
|
||||
encode_func = "encode_sint32"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
@ -439,27 +455,27 @@ class SInt64Type(TypeInfo):
|
||||
decode_varint = "value.as_sint64()"
|
||||
encode_func = "encode_sint64"
|
||||
|
||||
def dump(self, name):
|
||||
def dump(self, name: str) -> str:
|
||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||
o += "out.append(buffer);"
|
||||
return o
|
||||
|
||||
|
||||
class RepeatedTypeInfo(TypeInfo):
|
||||
def __init__(self, field):
|
||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||
super().__init__(field)
|
||||
self._ti = TYPE_INFO[field.type](field)
|
||||
self._ti: TypeInfo = TYPE_INFO[field.type](field)
|
||||
|
||||
@property
|
||||
def cpp_type(self):
|
||||
def cpp_type(self) -> str:
|
||||
return f"std::vector<{self._ti.cpp_type}>"
|
||||
|
||||
@property
|
||||
def reference_type(self):
|
||||
def reference_type(self) -> str:
|
||||
return f"{self.cpp_type} &"
|
||||
|
||||
@property
|
||||
def const_reference_type(self):
|
||||
def const_reference_type(self) -> str:
|
||||
return f"const {self.cpp_type} &"
|
||||
|
||||
@property
|
||||
@ -515,19 +531,19 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
)
|
||||
|
||||
@property
|
||||
def _ti_is_bool(self):
|
||||
def _ti_is_bool(self) -> bool:
|
||||
# std::vector is specialized for bool, reference does not work
|
||||
return isinstance(self._ti, BoolType)
|
||||
|
||||
@property
|
||||
def encode_content(self):
|
||||
def encode_content(self) -> str:
|
||||
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
|
||||
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
|
||||
o += "}"
|
||||
return o
|
||||
|
||||
@property
|
||||
def dump_content(self):
|
||||
def dump_content(self) -> str:
|
||||
o = f"for (const auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
|
||||
o += f' out.append(" {self.name}: ");\n'
|
||||
o += indent(self._ti.dump("it")) + "\n"
|
||||
@ -539,7 +555,8 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
pass
|
||||
|
||||
|
||||
def build_enum_type(desc):
|
||||
def build_enum_type(desc) -> tuple[str, str]:
|
||||
"""Builds the enum type."""
|
||||
name = desc.name
|
||||
out = f"enum {name} : uint32_t {{\n"
|
||||
for v in desc.value:
|
||||
@ -561,15 +578,15 @@ def build_enum_type(desc):
|
||||
return out, cpp
|
||||
|
||||
|
||||
def build_message_type(desc):
|
||||
public_content = []
|
||||
protected_content = []
|
||||
decode_varint = []
|
||||
decode_length = []
|
||||
decode_32bit = []
|
||||
decode_64bit = []
|
||||
encode = []
|
||||
dump = []
|
||||
def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
||||
public_content: list[str] = []
|
||||
protected_content: list[str] = []
|
||||
decode_varint: list[str] = []
|
||||
decode_length: list[str] = []
|
||||
decode_32bit: list[str] = []
|
||||
decode_64bit: list[str] = []
|
||||
encode: list[str] = []
|
||||
dump: list[str] = []
|
||||
|
||||
for field in desc.field:
|
||||
if field.label == 3:
|
||||
@ -687,27 +704,35 @@ SOURCE_BOTH = 0
|
||||
SOURCE_SERVER = 1
|
||||
SOURCE_CLIENT = 2
|
||||
|
||||
RECEIVE_CASES = {}
|
||||
RECEIVE_CASES: dict[int, str] = {}
|
||||
|
||||
ifdefs = {}
|
||||
ifdefs: dict[str, str] = {}
|
||||
|
||||
|
||||
def get_opt(desc, opt, default=None):
|
||||
def get_opt(
|
||||
desc: descriptor.DescriptorProto,
|
||||
opt: descriptor.MessageOptions,
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
"""Get the option from the descriptor."""
|
||||
if not desc.options.HasExtension(opt):
|
||||
return default
|
||||
return desc.options.Extensions[opt]
|
||||
|
||||
|
||||
def build_service_message_type(mt):
|
||||
def build_service_message_type(
|
||||
mt: descriptor.DescriptorProto,
|
||||
) -> tuple[str, str] | None:
|
||||
"""Builds the service message type."""
|
||||
snake = camel_to_snake(mt.name)
|
||||
id_ = get_opt(mt, pb.id)
|
||||
id_: int | None = get_opt(mt, pb.id)
|
||||
if id_ is None:
|
||||
return None
|
||||
|
||||
source = get_opt(mt, pb.source, 0)
|
||||
source: int = get_opt(mt, pb.source, 0)
|
||||
|
||||
ifdef = get_opt(mt, pb.ifdef)
|
||||
log = get_opt(mt, pb.log, True)
|
||||
ifdef: str | None = get_opt(mt, pb.ifdef)
|
||||
log: bool = get_opt(mt, pb.log, True)
|
||||
hout = ""
|
||||
cout = ""
|
||||
|
||||
@ -754,7 +779,8 @@ def build_service_message_type(mt):
|
||||
return hout, cout
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
"""Main function to generate the C++ classes."""
|
||||
cwd = Path(__file__).resolve().parent
|
||||
root = cwd.parent.parent / "esphome" / "components" / "api"
|
||||
prot_file = root / "api.protoc"
|
||||
@ -959,7 +985,7 @@ def main():
|
||||
try:
|
||||
import clang_format
|
||||
|
||||
def exec_clang_format(path):
|
||||
def exec_clang_format(path: Path) -> None:
|
||||
clang_format_path = os.path.join(
|
||||
os.path.dirname(clang_format.__file__), "data", "bin", "clang-format"
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user