Merge remote-tracking branch 'origin/enum_uint32' into integration

This commit is contained in:
J. Nick Koston 2025-07-11 16:34:05 -10:00
commit a035db1d11
No known key found for this signature in database
3 changed files with 251 additions and 198 deletions

File diff suppressed because it is too large Load Diff

View File

@ -59,7 +59,6 @@ class ProtoVarInt {
uint32_t as_uint32() const { return this->value_; } uint32_t as_uint32() const { return this->value_; }
uint64_t as_uint64() const { return this->value_; } uint64_t as_uint64() const { return this->value_; }
bool as_bool() const { return this->value_; } bool as_bool() const { return this->value_; }
template<typename T> T as_enum() const { return static_cast<T>(this->as_uint32()); }
int32_t as_int32() const { int32_t as_int32() const {
// Not ZigZag encoded // Not ZigZag encoded
return static_cast<int32_t>(this->as_int64()); return static_cast<int32_t>(this->as_int64());
@ -133,15 +132,16 @@ class ProtoVarInt {
uint64_t value_; uint64_t value_;
}; };
// Forward declaration for decode_to_message
class ProtoMessage;
class ProtoLengthDelimited { class ProtoLengthDelimited {
public: public:
explicit ProtoLengthDelimited(const uint8_t *value, size_t length) : value_(value), length_(length) {} explicit ProtoLengthDelimited(const uint8_t *value, size_t length) : value_(value), length_(length) {}
std::string as_string() const { return std::string(reinterpret_cast<const char *>(this->value_), this->length_); } std::string as_string() const { return std::string(reinterpret_cast<const char *>(this->value_), this->length_); }
template<class C> C as_message() const {
auto msg = C(); // Non-template method to decode into an existing message instance
msg.decode(this->value_, this->length_); void decode_to_message(ProtoMessage &msg) const;
return msg;
}
protected: protected:
const uint8_t *const value_; const uint8_t *const value_;
@ -184,6 +184,9 @@ class Proto64Bit {
const uint64_t value_; const uint64_t value_;
}; };
// Forward declaration needed for method declaration
class ProtoMessage;
class ProtoWriteBuffer { class ProtoWriteBuffer {
public: public:
ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer) {} ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer) {}
@ -263,9 +266,6 @@ class ProtoWriteBuffer {
this->write((value >> 48) & 0xFF); this->write((value >> 48) & 0xFF);
this->write((value >> 56) & 0xFF); this->write((value >> 56) & 0xFF);
} }
template<typename T> void encode_enum(uint32_t field_id, T value, bool force = false) {
this->encode_uint32(field_id, static_cast<uint32_t>(value), force);
}
void encode_float(uint32_t field_id, float value, bool force = false) { void encode_float(uint32_t field_id, float value, bool force = false) {
if (value == 0.0f && !force) if (value == 0.0f && !force)
return; return;
@ -306,18 +306,7 @@ class ProtoWriteBuffer {
} }
this->encode_uint64(field_id, uvalue, force); this->encode_uint64(field_id, uvalue, force);
} }
template<class C> void encode_message(uint32_t field_id, const C &value, bool force = false) { void encode_message(uint32_t field_id, const ProtoMessage &value, bool force = false);
this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
size_t begin = this->buffer_->size();
value.encode(*this);
const uint32_t nested_length = this->buffer_->size() - begin;
// add size varint
std::vector<uint8_t> var;
ProtoVarInt(nested_length).encode(var);
this->buffer_->insert(this->buffer_->begin() + begin, var.begin(), var.end());
}
std::vector<uint8_t> *get_buffer() const { return buffer_; } std::vector<uint8_t> *get_buffer() const { return buffer_; }
protected: protected:
@ -345,6 +334,25 @@ class ProtoMessage {
virtual bool decode_64bit(uint32_t field_id, Proto64Bit value) { return false; } virtual bool decode_64bit(uint32_t field_id, Proto64Bit value) { return false; }
}; };
// Implementation of encode_message - must be after ProtoMessage is defined
inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessage &value, bool force) {
this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
size_t begin = this->buffer_->size();
value.encode(*this);
const uint32_t nested_length = this->buffer_->size() - begin;
// add size varint
std::vector<uint8_t> var;
ProtoVarInt(nested_length).encode(var);
this->buffer_->insert(this->buffer_->begin() + begin, var.begin(), var.end());
}
// Implementation of decode_to_message - must be after ProtoMessage is defined
inline void ProtoLengthDelimited::decode_to_message(ProtoMessage &msg) const {
msg.decode(this->value_, this->length_);
}
template<typename T> const char *proto_enum_to_string(T value); template<typename T> const char *proto_enum_to_string(T value);
class ProtoService { class ProtoService {

View File

@ -536,11 +536,23 @@ class MessageType(TypeInfo):
@property @property
def encode_func(self) -> str: def encode_func(self) -> str:
return f"encode_message<{self.cpp_type}>" return "encode_message"
@property @property
def decode_length(self) -> str: def decode_length(self) -> str:
return f"value.as_message<{self.cpp_type}>()" # For non-template decoding, we need to handle this differently
return None
@property
def decode_length_content(self) -> str:
# Custom decode that doesn't use templates
return dedent(
f"""\
case {self.number}: {{
value.decode_to_message(this->{self.field_name});
return true;
}}"""
)
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f"{name}.dump_to(out);" o = f"{name}.dump_to(out);"
@ -608,14 +620,18 @@ class EnumType(TypeInfo):
@property @property
def decode_varint(self) -> str: def decode_varint(self) -> str:
return f"value.as_enum<{self.cpp_type}>()" return f"static_cast<{self.cpp_type}>(value.as_uint32())"
default_value = "" default_value = ""
wire_type = WireType.VARINT # Uses wire type 0 wire_type = WireType.VARINT # Uses wire type 0
@property @property
def encode_func(self) -> str: def encode_func(self) -> str:
return f"encode_enum<{self.cpp_type}>" return "encode_uint32"
@property
def encode_content(self) -> str:
return f"buffer.{self.encode_func}({self.number}, static_cast<uint32_t>(this->{self.field_name}));"
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));" o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
@ -757,6 +773,16 @@ class RepeatedTypeInfo(TypeInfo):
@property @property
def decode_length_content(self) -> str: def decode_length_content(self) -> str:
content = self._ti.decode_length content = self._ti.decode_length
if content is None and isinstance(self._ti, MessageType):
# Special handling for non-template message decoding
return dedent(
f"""\
case {self.number}: {{
this->{self.field_name}.emplace_back();
value.decode_to_message(this->{self.field_name}.back());
return true;
}}"""
)
if content is None: if content is None:
return None return None
return dedent( return dedent(
@ -801,7 +827,10 @@ class RepeatedTypeInfo(TypeInfo):
@property @property
def encode_content(self) -> str: def encode_content(self) -> str:
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n" if isinstance(self._ti, EnumType):
o += f" buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>(it), true);\n"
else:
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
o += "}" o += "}"
return o return o