diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index 7771922697..291a03523e 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -1,4 +1,20 @@ #!/usr/bin/env python3 +from __future__ import annotations + +from abc import ABC, abstractmethod +import os +from pathlib import Path +import re +from subprocess import call +import sys +from textwrap import dedent +from typing import Any + +# Generate with +# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto +import aioesphomeapi.api_options_pb2 as pb +import google.protobuf.descriptor_pb2 as descriptor + """Python 3 script to automatically generate C++ classes for ESPHome's native API. It's pretty crappy spaghetti code, but it works. @@ -17,25 +33,14 @@ then run this script with python3 and the files will be generated, they still need to be formatted """ -from abc import ABC, abstractmethod -import os -from pathlib import Path -import re -from subprocess import call -import sys -from textwrap import dedent - -# Generate with -# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto -import aioesphomeapi.api_options_pb2 as pb -import google.protobuf.descriptor_pb2 as descriptor FILE_HEADER = """// This file was automatically generated with a tool. // See scripts/api_protobuf/api_protobuf.py """ -def indent_list(text, padding=" "): +def indent_list(text: str, padding: str = " ") -> list[str]: + """Indent each line of the given text with the specified padding.""" lines = [] for line in text.splitlines(): if line == "": @@ -48,54 +53,62 @@ def indent_list(text, padding=" "): return lines -def indent(text, padding=" "): +def indent(text: str, padding: str = " ") -> str: return "\n".join(indent_list(text, padding)) -def camel_to_snake(name): +def camel_to_snake(name: str) -> str: # https://stackoverflow.com/a/1176023 s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() class TypeInfo(ABC): - def __init__(self, field): + """Base class for all type information.""" + + def __init__(self, field: descriptor.FieldDescriptorProto) -> None: self._field = field @property - def default_value(self): + def default_value(self) -> str: + """Get the default value.""" return "" @property - def name(self): + def name(self) -> str: + """Get the name of the field.""" return self._field.name @property - def arg_name(self): + def arg_name(self) -> str: + """Get the argument name.""" return self.name @property - def field_name(self): + def field_name(self) -> str: + """Get the field name.""" return self.name @property - def number(self): + def number(self) -> int: + """Get the field number.""" return self._field.number @property - def repeated(self): + def repeated(self) -> bool: + """Check if the field is repeated.""" return self._field.label == 3 @property - def cpp_type(self): + def cpp_type(self) -> str: raise NotImplementedError @property - def reference_type(self): + def reference_type(self) -> str: return f"{self.cpp_type} " @property - def const_reference_type(self): + def const_reference_type(self) -> str: return f"{self.cpp_type} " @property @@ -171,28 +184,31 @@ class TypeInfo(ABC): decode_64bit = None @property - def encode_content(self): + def encode_content(self) -> str: return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});" encode_func = None @property - def dump_content(self): + def dump_content(self) -> str: o = f'out.append(" {self.name}: ");\n' o += self.dump(f"this->{self.field_name}") + "\n" o += 'out.append("\\n");\n' return o @abstractmethod - def dump(self, name: str): - pass + def dump(self, name: str) -> str: + """Dump the value to the output.""" -TYPE_INFO = {} +TYPE_INFO: dict[int, TypeInfo] = {} -def register_type(name): - def func(value): +def register_type(name: int): + """Decorator to register a type with a name and number.""" + + def func(value: TypeInfo) -> TypeInfo: + """Register the type with the given name and number.""" TYPE_INFO[name] = value return value @@ -206,7 +222,7 @@ class DoubleType(TypeInfo): decode_64bit = "value.as_double()" encode_func = "encode_double" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%g", {name});\n' o += "out.append(buffer);" return o @@ -219,7 +235,7 @@ class FloatType(TypeInfo): decode_32bit = "value.as_float()" encode_func = "encode_float" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%g", {name});\n' o += "out.append(buffer);" return o @@ -232,7 +248,7 @@ class Int64Type(TypeInfo): decode_varint = "value.as_int64()" encode_func = "encode_int64" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' o += "out.append(buffer);" return o @@ -245,7 +261,7 @@ class UInt64Type(TypeInfo): decode_varint = "value.as_uint64()" encode_func = "encode_uint64" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%llu", {name});\n' o += "out.append(buffer);" return o @@ -258,7 +274,7 @@ class Int32Type(TypeInfo): decode_varint = "value.as_int32()" encode_func = "encode_int32" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' o += "out.append(buffer);" return o @@ -271,7 +287,7 @@ class Fixed64Type(TypeInfo): decode_64bit = "value.as_fixed64()" encode_func = "encode_fixed64" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%llu", {name});\n' o += "out.append(buffer);" return o @@ -284,7 +300,7 @@ class Fixed32Type(TypeInfo): decode_32bit = "value.as_fixed32()" encode_func = "encode_fixed32" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRIu32, {name});\n' o += "out.append(buffer);" return o @@ -297,7 +313,7 @@ class BoolType(TypeInfo): decode_varint = "value.as_bool()" encode_func = "encode_bool" - def dump(self, name): + def dump(self, name: str) -> str: o = f"out.append(YESNO({name}));" return o @@ -319,28 +335,28 @@ class StringType(TypeInfo): @register_type(11) class MessageType(TypeInfo): @property - def cpp_type(self): + def cpp_type(self) -> str: return self._field.type_name[1:] default_value = "" @property - def reference_type(self): + def reference_type(self) -> str: return f"{self.cpp_type} &" @property - def const_reference_type(self): + def const_reference_type(self) -> str: return f"const {self.cpp_type} &" @property - def encode_func(self): + def encode_func(self) -> str: return f"encode_message<{self.cpp_type}>" @property - def decode_length(self): + def decode_length(self) -> str: return f"value.as_message<{self.cpp_type}>()" - def dump(self, name): + def dump(self, name: str) -> str: o = f"{name}.dump_to(out);" return o @@ -354,7 +370,7 @@ class BytesType(TypeInfo): decode_length = "value.as_string()" encode_func = "encode_string" - def dump(self, name): + def dump(self, name: str) -> str: o = f'out.append("\'").append({name}).append("\'");' return o @@ -366,7 +382,7 @@ class UInt32Type(TypeInfo): decode_varint = "value.as_uint32()" encode_func = "encode_uint32" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRIu32, {name});\n' o += "out.append(buffer);" return o @@ -375,20 +391,20 @@ class UInt32Type(TypeInfo): @register_type(14) class EnumType(TypeInfo): @property - def cpp_type(self): + def cpp_type(self) -> str: return f"enums::{self._field.type_name[1:]}" @property - def decode_varint(self): + def decode_varint(self) -> str: return f"value.as_enum<{self.cpp_type}>()" default_value = "" @property - def encode_func(self): + def encode_func(self) -> str: return f"encode_enum<{self.cpp_type}>" - def dump(self, name): + def dump(self, name: str) -> str: o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));" return o @@ -400,7 +416,7 @@ class SFixed32Type(TypeInfo): decode_32bit = "value.as_sfixed32()" encode_func = "encode_sfixed32" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' o += "out.append(buffer);" return o @@ -413,7 +429,7 @@ class SFixed64Type(TypeInfo): decode_64bit = "value.as_sfixed64()" encode_func = "encode_sfixed64" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' o += "out.append(buffer);" return o @@ -426,7 +442,7 @@ class SInt32Type(TypeInfo): decode_varint = "value.as_sint32()" encode_func = "encode_sint32" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' o += "out.append(buffer);" return o @@ -439,27 +455,27 @@ class SInt64Type(TypeInfo): decode_varint = "value.as_sint64()" encode_func = "encode_sint64" - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' o += "out.append(buffer);" return o class RepeatedTypeInfo(TypeInfo): - def __init__(self, field): + def __init__(self, field: descriptor.FieldDescriptorProto) -> None: super().__init__(field) - self._ti = TYPE_INFO[field.type](field) + self._ti: TypeInfo = TYPE_INFO[field.type](field) @property - def cpp_type(self): + def cpp_type(self) -> str: return f"std::vector<{self._ti.cpp_type}>" @property - def reference_type(self): + def reference_type(self) -> str: return f"{self.cpp_type} &" @property - def const_reference_type(self): + def const_reference_type(self) -> str: return f"const {self.cpp_type} &" @property @@ -515,19 +531,19 @@ class RepeatedTypeInfo(TypeInfo): ) @property - def _ti_is_bool(self): + def _ti_is_bool(self) -> bool: # std::vector is specialized for bool, reference does not work return isinstance(self._ti, BoolType) @property - def encode_content(self): + def encode_content(self) -> str: o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n" o += "}" return o @property - def dump_content(self): + def dump_content(self) -> str: o = f"for (const auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o += f' out.append(" {self.name}: ");\n' o += indent(self._ti.dump("it")) + "\n" @@ -539,7 +555,8 @@ class RepeatedTypeInfo(TypeInfo): pass -def build_enum_type(desc): +def build_enum_type(desc) -> tuple[str, str]: + """Builds the enum type.""" name = desc.name out = f"enum {name} : uint32_t {{\n" for v in desc.value: @@ -561,15 +578,15 @@ def build_enum_type(desc): return out, cpp -def build_message_type(desc): - public_content = [] - protected_content = [] - decode_varint = [] - decode_length = [] - decode_32bit = [] - decode_64bit = [] - encode = [] - dump = [] +def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: + public_content: list[str] = [] + protected_content: list[str] = [] + decode_varint: list[str] = [] + decode_length: list[str] = [] + decode_32bit: list[str] = [] + decode_64bit: list[str] = [] + encode: list[str] = [] + dump: list[str] = [] for field in desc.field: if field.label == 3: @@ -687,27 +704,35 @@ SOURCE_BOTH = 0 SOURCE_SERVER = 1 SOURCE_CLIENT = 2 -RECEIVE_CASES = {} +RECEIVE_CASES: dict[int, str] = {} -ifdefs = {} +ifdefs: dict[str, str] = {} -def get_opt(desc, opt, default=None): +def get_opt( + desc: descriptor.DescriptorProto, + opt: descriptor.MessageOptions, + default: Any = None, +) -> Any: + """Get the option from the descriptor.""" if not desc.options.HasExtension(opt): return default return desc.options.Extensions[opt] -def build_service_message_type(mt): +def build_service_message_type( + mt: descriptor.DescriptorProto, +) -> tuple[str, str] | None: + """Builds the service message type.""" snake = camel_to_snake(mt.name) - id_ = get_opt(mt, pb.id) + id_: int | None = get_opt(mt, pb.id) if id_ is None: return None - source = get_opt(mt, pb.source, 0) + source: int = get_opt(mt, pb.source, 0) - ifdef = get_opt(mt, pb.ifdef) - log = get_opt(mt, pb.log, True) + ifdef: str | None = get_opt(mt, pb.ifdef) + log: bool = get_opt(mt, pb.log, True) hout = "" cout = "" @@ -754,7 +779,8 @@ def build_service_message_type(mt): return hout, cout -def main(): +def main() -> None: + """Main function to generate the C++ classes.""" cwd = Path(__file__).resolve().parent root = cwd.parent.parent / "esphome" / "components" / "api" prot_file = root / "api.protoc" @@ -959,7 +985,7 @@ def main(): try: import clang_format - def exec_clang_format(path): + def exec_clang_format(path: Path) -> None: clang_format_path = os.path.join( os.path.dirname(clang_format.__file__), "data", "bin", "clang-format" )