Merge branch 'bk7200_tagged_pointer_fix' into integration
commit d463dd0f57
@@ -1687,7 +1687,9 @@ void APIConnection::DeferredBatch::add_item(EntityBase *entity, MessageCreator c
   // O(n) but optimized for RAM and not performance.
   for (auto &item : items) {
     if (item.entity == entity && item.message_type == message_type) {
-      // Update the existing item with the new creator
+      // Clean up old creator before replacing
+      item.creator.cleanup(message_type);
+      // Move assign the new creator
       item.creator = std::move(creator);
       return;
     }

@@ -1730,11 +1732,11 @@ void APIConnection::process_batch_() {
     return;
   }

-  size_t num_items = this->deferred_batch_.items.size();
+  size_t num_items = this->deferred_batch_.size();

   // Fast path for single message - allocate exact size needed
   if (num_items == 1) {
-    const auto &item = this->deferred_batch_.items[0];
+    const auto &item = this->deferred_batch_[0];

     // Let the creator calculate size and encode if it fits
     uint16_t payload_size =

@@ -1764,7 +1766,8 @@ void APIConnection::process_batch_() {

   // Pre-calculate exact buffer size needed based on message types
   uint32_t total_estimated_size = 0;
-  for (const auto &item : this->deferred_batch_.items) {
+  for (size_t i = 0; i < this->deferred_batch_.size(); i++) {
+    const auto &item = this->deferred_batch_[i];
     total_estimated_size += get_estimated_message_size(item.message_type);
   }

@@ -1785,7 +1788,8 @@ void APIConnection::process_batch_() {
   uint32_t current_offset = 0;

   // Process items and encode directly to buffer
-  for (const auto &item : this->deferred_batch_.items) {
+  for (size_t i = 0; i < this->deferred_batch_.size(); i++) {
+    const auto &item = this->deferred_batch_[i];
     // Try to encode message
     // The creator will calculate overhead to determine if the message fits
     uint16_t payload_size = item.creator(item.entity, this, remaining_size, false, item.message_type);

@@ -1840,17 +1844,15 @@ void APIConnection::process_batch_() {
   // Log messages after send attempt for VV debugging
   // It's safe to use the buffer for logging at this point regardless of send result
   for (size_t i = 0; i < items_processed; i++) {
-    const auto &item = this->deferred_batch_.items[i];
+    const auto &item = this->deferred_batch_[i];
     this->log_batch_item_(item);
   }
 #endif

   // Handle remaining items more efficiently
-  if (items_processed < this->deferred_batch_.items.size()) {
-    // Remove processed items from the beginning
-    this->deferred_batch_.items.erase(this->deferred_batch_.items.begin(),
-                                      this->deferred_batch_.items.begin() + items_processed);
+  if (items_processed < this->deferred_batch_.size()) {
+    // Remove processed items from the beginning with proper cleanup
+    this->deferred_batch_.remove_front(items_processed);

     // Reschedule for remaining items
     this->schedule_batch_();
   } else {

@@ -1861,23 +1863,16 @@ void APIConnection::process_batch_() {

 uint16_t APIConnection::MessageCreator::operator()(EntityBase *entity, APIConnection *conn, uint32_t remaining_size,
                                                    bool is_single, uint16_t message_type) const {
-  if (has_tagged_string_ptr_()) {
-    // Handle string-based messages
-    switch (message_type) {
 #ifdef USE_EVENT
-      case EventResponse::MESSAGE_TYPE: {
-        auto *e = static_cast<event::Event *>(entity);
-        return APIConnection::try_send_event_response(e, *get_string_ptr_(), conn, remaining_size, is_single);
-      }
+  // Special case: EventResponse uses string pointer
+  if (message_type == EventResponse::MESSAGE_TYPE) {
+    auto *e = static_cast<event::Event *>(entity);
+    return APIConnection::try_send_event_response(e, *data_.string_ptr, conn, remaining_size, is_single);
+  }
 #endif
-      default:
-        // Should not happen, return 0 to indicate no message
-        return 0;
-    }
-  } else {
-    // Function pointer case
-    return data_.ptr(entity, conn, remaining_size, is_single);
-  }
+  // All other message types use function pointers
+  return data_.function_ptr(entity, conn, remaining_size, is_single);
 }

 uint16_t APIConnection::try_send_list_info_done(EntityBase *entity, APIConnection *conn, uint32_t remaining_size,
@@ -451,96 +451,53 @@ class APIConnection : public APIServerConnection {
   // Function pointer type for message encoding
   using MessageCreatorPtr = uint16_t (*)(EntityBase *, APIConnection *, uint32_t remaining_size, bool is_single);

-  // Optimized MessageCreator class using tagged pointer
   class MessageCreator {
-    // Ensure pointer alignment allows LSB tagging
-    static_assert(alignof(std::string *) > 1, "String pointer alignment must be > 1 for LSB tagging");
-
    public:
     // Constructor for function pointer
-    MessageCreator(MessageCreatorPtr ptr) {
-      // Function pointers are always aligned, so LSB is 0
-      data_.ptr = ptr;
-    }
+    MessageCreator(MessageCreatorPtr ptr) { data_.function_ptr = ptr; }

     // Constructor for string state capture
-    explicit MessageCreator(const std::string &str_value) {
-      // Allocate string and tag the pointer
-      auto *str = new std::string(str_value);
-      // Set LSB to 1 to indicate string pointer
-      data_.tagged = reinterpret_cast<uintptr_t>(str) | 1;
-    }
+    explicit MessageCreator(const std::string &str_value) { data_.string_ptr = new std::string(str_value); }

-    // Destructor
-    ~MessageCreator() {
-      if (has_tagged_string_ptr_()) {
-        delete get_string_ptr_();
-      }
-    }
+    // No destructor - cleanup must be called explicitly with message_type

-    // Copy constructor
-    MessageCreator(const MessageCreator &other) {
-      if (other.has_tagged_string_ptr_()) {
-        auto *str = new std::string(*other.get_string_ptr_());
-        data_.tagged = reinterpret_cast<uintptr_t>(str) | 1;
-      } else {
-        data_ = other.data_;
-      }
-    }
+    // Delete copy operations - MessageCreator should only be moved
+    MessageCreator(const MessageCreator &other) = delete;
+    MessageCreator &operator=(const MessageCreator &other) = delete;

     // Move constructor
-    MessageCreator(MessageCreator &&other) noexcept : data_(other.data_) { other.data_.ptr = nullptr; }
+    MessageCreator(MessageCreator &&other) noexcept : data_(other.data_) { other.data_.function_ptr = nullptr; }

-    // Assignment operators (needed for batch deduplication)
-    MessageCreator &operator=(const MessageCreator &other) {
-      if (this != &other) {
-        // Clean up current string data if needed
-        if (has_tagged_string_ptr_()) {
-          delete get_string_ptr_();
-        }
-        // Copy new data
-        if (other.has_tagged_string_ptr_()) {
-          auto *str = new std::string(*other.get_string_ptr_());
-          data_.tagged = reinterpret_cast<uintptr_t>(str) | 1;
-        } else {
-          data_ = other.data_;
-        }
-      }
-      return *this;
-    }
-
+    // Move assignment
     MessageCreator &operator=(MessageCreator &&other) noexcept {
       if (this != &other) {
-        // Clean up current string data if needed
-        if (has_tagged_string_ptr_()) {
-          delete get_string_ptr_();
-        }
-        // Move data
+        // IMPORTANT: Caller must ensure cleanup() was called if this contains a string!
+        // In our usage, this happens in add_item() deduplication and vector::erase()
         data_ = other.data_;
-        // Reset other to safe state
-        other.data_.ptr = nullptr;
+        other.data_.function_ptr = nullptr;
       }
       return *this;
     }

-    // Call operator - now accepts message_type as parameter
+    // Call operator - uses message_type to determine union type
     uint16_t operator()(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, bool is_single,
                         uint16_t message_type) const;

-   private:
-    // Check if this contains a string pointer
-    bool has_tagged_string_ptr_() const { return (data_.tagged & 1) != 0; }
-
-    // Get the actual string pointer (clears the tag bit)
-    std::string *get_string_ptr_() const {
-      // NOLINTNEXTLINE(performance-no-int-to-ptr)
-      return reinterpret_cast<std::string *>(data_.tagged & ~uintptr_t(1));
-    }
-
-    union {
-      MessageCreatorPtr ptr;
-      uintptr_t tagged;
-    } data_;  // 4 bytes on 32-bit
+    // Manual cleanup method - must be called before destruction for string types
+    void cleanup(uint16_t message_type) {
+#ifdef USE_EVENT
+      if (message_type == EventResponse::MESSAGE_TYPE && data_.string_ptr != nullptr) {
+        delete data_.string_ptr;
+        data_.string_ptr = nullptr;
+      }
+#endif
+    }
+
+   private:
+    union Data {
+      MessageCreatorPtr function_ptr;
+      std::string *string_ptr;
+    } data_;  // 4 bytes on 32-bit, 8 bytes on 64-bit - same as before
   };

   // Generic batching mechanism for both state updates and entity info

@@ -558,20 +515,46 @@ class APIConnection : public APIServerConnection {
     std::vector<BatchItem> items;
     uint32_t batch_start_time{0};

+   private:
+    // Helper to cleanup items from the beginning
+    void cleanup_items_(size_t count) {
+      for (size_t i = 0; i < count; i++) {
+        items[i].creator.cleanup(items[i].message_type);
+      }
+    }
+
+   public:
     DeferredBatch() {
       // Pre-allocate capacity for typical batch sizes to avoid reallocation
       items.reserve(8);
     }

+    ~DeferredBatch() {
+      // Ensure cleanup of any remaining items
+      clear();
+    }
+
     // Add item to the batch
     void add_item(EntityBase *entity, MessageCreator creator, uint16_t message_type);
     // Add item to the front of the batch (for high priority messages like ping)
     void add_item_front(EntityBase *entity, MessageCreator creator, uint16_t message_type);

+    // Clear all items with proper cleanup
     void clear() {
+      cleanup_items_(items.size());
       items.clear();
       batch_start_time = 0;
     }

+    // Remove processed items from the front with proper cleanup
+    void remove_front(size_t count) {
+      cleanup_items_(count);
+      items.erase(items.begin(), items.begin() + count);
+    }
+
     bool empty() const { return items.empty(); }
+    size_t size() const { return items.size(); }
+    const BatchItem &operator[](size_t index) const { return items[index]; }
   };

   // DeferredBatch here (16 bytes, 4-byte aligned)
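The header above replaces the tagged-pointer MessageCreator with a plain one-pointer union whose active member is implied by the message type, so nothing is freed automatically any more. As a reading aid, here is a minimal standalone sketch of that pattern under the same rules: cleanup() must run before an element is overwritten or erased. All names below (Creator, STRING_MSG, Item, remove_front) are illustrative, not ESPHome identifiers.

#include <cstdint>
#include <string>
#include <vector>

using CreatorFn = uint16_t (*)();
constexpr uint16_t STRING_MSG = 42;  // illustrative "message type" whose payload owns a string

class Creator {
 public:
  Creator(CreatorFn fn) { data_.fn = fn; }
  explicit Creator(const std::string &s) { data_.str = new std::string(s); }

  // Move-only, like MessageCreator above.
  Creator(const Creator &) = delete;
  Creator &operator=(const Creator &) = delete;
  Creator(Creator &&other) noexcept : data_(other.data_) { other.data_.fn = nullptr; }
  Creator &operator=(Creator &&other) noexcept {
    if (this != &other) {
      // precondition: cleanup() was already called if *this owned a string
      data_ = other.data_;
      other.data_.fn = nullptr;
    }
    return *this;
  }

  // The message type is the external discriminant telling us which member is active.
  void cleanup(uint16_t msg_type) {
    if (msg_type == STRING_MSG && data_.str != nullptr) {
      delete data_.str;
      data_.str = nullptr;
    }
  }

 private:
  union Data {
    CreatorFn fn;
    std::string *str;
  } data_;  // one pointer wide, same footprint as the old tagged pointer
};

struct Item {
  uint16_t msg_type;
  Creator creator;
};

// Mirrors DeferredBatch::remove_front(): cleanup() must precede erase(), because
// vector::erase() move-assigns the surviving elements over the erased slots and
// never knows how to free an owned string by itself.
void remove_front(std::vector<Item> &items, size_t count) {
  for (size_t i = 0; i < count; i++)
    items[i].creator.cleanup(items[i].msg_type);
  items.erase(items.begin(), items.begin() + count);
}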
@@ -44,3 +44,4 @@ async def to_code(config):
     cg.add_build_flag("-std=gnu++20")
     cg.add_define("ESPHOME_BOARD", "host")
     cg.add_platformio_option("platform", "platformio/native")
+    cg.add_platformio_option("lib_ldf_mode", "off")
@@ -90,16 +90,25 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) {

   } else {
     // data starts at 2 and length is 4 for read registers commands
-    if (this->role == ModbusRole::SERVER && (function_code == 0x1 || function_code == 0x3 || function_code == 0x4)) {
-      data_offset = 2;
-      data_len = 4;
-    }
-    // the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands
-    if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) {
-      data_offset = 2;
-      data_len = 4;
+    if (this->role == ModbusRole::SERVER) {
+      if (function_code == 0x1 || function_code == 0x3 || function_code == 0x4 || function_code == 0x6) {
+        data_offset = 2;
+        data_len = 4;
+      } else if (function_code == 0x10) {
+        if (at < 6) {
+          return true;
+        }
+        data_offset = 2;
+        // starting address (2 bytes) + quantity of registers (2 bytes) + byte count itself (1 byte) + actual byte count
+        data_len = 2 + 2 + 1 + raw[6];
+      }
+    } else {
+      // the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands
+      if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) {
+        data_offset = 2;
+        data_len = 4;
+      }
     }

     // Error ( msb indicates error )
     // response format: Byte[0] = device address, Byte[1] function code | 0x80 , Byte[2] exception code, Byte[3-4] crc

@@ -132,6 +141,7 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) {
   bool found = false;
   for (auto *device : this->devices_) {
     if (device->address_ == address) {
+      found = true;
       // Is it an error response?
       if ((function_code & 0x80) == 0x80) {
         ESP_LOGD(TAG, "Modbus error function code: 0x%X exception: %d", function_code, raw[2]);

@@ -141,13 +151,21 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) {
           // Ignore modbus exception not related to a pending command
           ESP_LOGD(TAG, "Ignoring Modbus error - not expecting a response");
         }
-      } else if (this->role == ModbusRole::SERVER && (function_code == 0x3 || function_code == 0x4)) {
+        continue;
+      }
+      if (this->role == ModbusRole::SERVER) {
+        if (function_code == 0x3 || function_code == 0x4) {
           device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8),
                                            uint16_t(data[3]) | (uint16_t(data[2]) << 8));
-      } else {
-        device->on_modbus_data(data);
+          continue;
+        }
+        if (function_code == 0x6 || function_code == 0x10) {
+          device->on_modbus_write_registers(function_code, data);
+          continue;
+        }
       }
-      found = true;
+      // fallthrough for other function codes
+      device->on_modbus_data(data);
     }
   }
   waiting_for_response = 0;
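For context on the new data_len computation above: the Modbus function 0x10 (Write Multiple Registers) request frame has the following standard layout, shown here with the same raw[] indexing used by parse_modbus_byte_() (a reference sketch, not part of the diff):

// raw[0]     device address
// raw[1]     function code (0x10)
// raw[2..3]  starting register address (hi, lo)
// raw[4..5]  quantity of registers (hi, lo)
// raw[6]     byte count N (normally 2 * quantity of registers)
// raw[7..]   N data bytes, followed by the 2-byte CRC
//
// Hence the parser keeps reading (returns true) while at < 6, i.e. until the byte-count
// field has arrived, and then sizes the remaining frame as data_len = 2 + 2 + 1 + raw[6].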
@@ -59,6 +59,7 @@ class ModbusDevice {
   virtual void on_modbus_data(const std::vector<uint8_t> &data) = 0;
   virtual void on_modbus_error(uint8_t function_code, uint8_t exception_code) {}
   virtual void on_modbus_read_registers(uint8_t function_code, uint16_t start_address, uint16_t number_of_registers){};
+  virtual void on_modbus_write_registers(uint8_t function_code, const std::vector<uint8_t> &data){};
   void send(uint8_t function, uint16_t start_address, uint16_t number_of_entities, uint8_t payload_len = 0,
             const uint8_t *payload = nullptr) {
     this->parent_->send(this->address_, function, start_address, number_of_entities, payload_len, payload);
@@ -39,6 +39,7 @@ CODEOWNERS = ["@martgras"]
 AUTO_LOAD = ["modbus"]

 CONF_READ_LAMBDA = "read_lambda"
+CONF_WRITE_LAMBDA = "write_lambda"
 CONF_SERVER_REGISTERS = "server_registers"
 MULTI_CONF = True

@@ -148,6 +149,7 @@ ModbusServerRegisterSchema = cv.Schema(
         cv.Required(CONF_ADDRESS): cv.positive_int,
         cv.Optional(CONF_VALUE_TYPE, default="U_WORD"): cv.enum(SENSOR_VALUE_TYPE),
         cv.Required(CONF_READ_LAMBDA): cv.returning_lambda,
+        cv.Optional(CONF_WRITE_LAMBDA): cv.returning_lambda,
     }
 )

@@ -318,6 +320,17 @@ async def to_code(config):
                 ),
             )
         )
+        if CONF_WRITE_LAMBDA in server_register:
+            cg.add(
+                server_register_var.set_write_lambda(
+                    cg.TemplateArguments(cpp_type),
+                    await cg.process_lambda(
+                        server_register[CONF_WRITE_LAMBDA],
+                        parameters=[(cg.uint16, "address"), (cpp_type, "x")],
+                        return_type=cg.bool_,
+                    ),
+                )
+            )
         cg.add(var.add_server_register(server_register_var))
     await register_modbus_device(var, config)
     for conf in config.get(CONF_ON_COMMAND_SENT, []):
@@ -152,6 +152,86 @@ void ModbusController::on_modbus_read_registers(uint8_t function_code, uint16_t
   this->send(function_code, start_address, number_of_registers, response.size(), response.data());
 }

+void ModbusController::on_modbus_write_registers(uint8_t function_code, const std::vector<uint8_t> &data) {
+  uint16_t number_of_registers;
+  uint16_t payload_offset;
+
+  if (function_code == 0x10) {
+    number_of_registers = uint16_t(data[3]) | (uint16_t(data[2]) << 8);
+    if (number_of_registers == 0 || number_of_registers > 0x7B) {
+      ESP_LOGW(TAG, "Invalid number of registers %d. Sending exception response.", number_of_registers);
+      send_error(function_code, 3);
+      return;
+    }
+    uint16_t payload_size = data[4];
+    if (payload_size != number_of_registers * 2) {
+      ESP_LOGW(TAG, "Payload size of %d bytes is not 2 times the number of registers (%d). Sending exception response.",
+               payload_size, number_of_registers);
+      send_error(function_code, 3);
+      return;
+    }
+    payload_offset = 5;
+  } else if (function_code == 0x06) {
+    number_of_registers = 1;
+    payload_offset = 2;
+  } else {
+    ESP_LOGW(TAG, "Invalid function code 0x%X. Sending exception response.", function_code);
+    send_error(function_code, 1);
+    return;
+  }
+
+  uint16_t start_address = uint16_t(data[1]) | (uint16_t(data[0]) << 8);
+  ESP_LOGD(TAG,
+           "Received write holding registers for device 0x%X. FC: 0x%X. Start address: 0x%X. Number of registers: "
+           "0x%X.",
+           this->address_, function_code, start_address, number_of_registers);
+
+  auto for_each_register = [this, start_address, number_of_registers, payload_offset](
+                               const std::function<bool(ServerRegister *, uint16_t offset)> &callback) -> bool {
+    uint16_t offset = payload_offset;
+    for (uint16_t current_address = start_address; current_address < start_address + number_of_registers;) {
+      bool ok = false;
+      for (auto *server_register : this->server_registers_) {
+        if (server_register->address == current_address) {
+          ok = callback(server_register, offset);
+          current_address += server_register->register_count;
+          offset += server_register->register_count * sizeof(uint16_t);
+          break;
+        }
+      }
+
+      if (!ok) {
+        return false;
+      }
+    }
+    return true;
+  };
+
+  // check all registers are writable before writing to any of them:
+  if (!for_each_register([](ServerRegister *server_register, uint16_t offset) -> bool {
+        return server_register->write_lambda != nullptr;
+      })) {
+    send_error(function_code, 1);
+    return;
+  }
+
+  // Actually write to the registers:
+  if (!for_each_register([&data](ServerRegister *server_register, uint16_t offset) {
+        int64_t number = payload_to_number(data, server_register->value_type, offset, 0xFFFFFFFF);
+        return server_register->write_lambda(number);
+      })) {
+    send_error(function_code, 4);
+    return;
+  }
+
+  std::vector<uint8_t> response;
+  response.reserve(6);
+  response.push_back(this->address_);
+  response.push_back(function_code);
+  response.insert(response.end(), data.begin(), data.begin() + 4);
+  this->send_raw(response);
+}
+
 SensorSet ModbusController::find_sensors_(ModbusRegisterType register_type, uint16_t start_address) const {
   auto reg_it = std::find_if(
       std::begin(this->register_ranges_), std::end(this->register_ranges_),
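The send_error() calls in the new handler use the standard Modbus exception codes. For reference, a sketch of the values involved (the enum name is illustrative; the ESPHome code passes the raw numbers):

#include <cstdint>

enum ModbusExceptionCode : uint8_t {
  ILLEGAL_FUNCTION = 0x01,       // unsupported function code, or a register that has no write_lambda
  ILLEGAL_DATA_ADDRESS = 0x02,
  ILLEGAL_DATA_VALUE = 0x03,     // bad register count, or byte count != 2 * number of registers
  SERVER_DEVICE_FAILURE = 0x04,  // a write_lambda returned false
};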
@@ -258,6 +258,7 @@ class SensorItem {

 class ServerRegister {
   using ReadLambda = std::function<int64_t()>;
+  using WriteLambda = std::function<bool(int64_t value)>;

  public:
  ServerRegister(uint16_t address, SensorValueType value_type, uint8_t register_count) {

@@ -277,6 +278,17 @@ class ServerRegister {
     };
   }

+  template<typename T>
+  void set_write_lambda(const std::function<bool(uint16_t address, const T v)> &&user_write_lambda) {
+    this->write_lambda = [this, user_write_lambda](int64_t number) {
+      if constexpr (std::is_same_v<T, float>) {
+        float float_value = bit_cast<float>(static_cast<uint32_t>(number));
+        return user_write_lambda(this->address, float_value);
+      }
+      return user_write_lambda(this->address, static_cast<T>(number));
+    };
+  }
+
   // Formats a raw value into a string representation based on the value type for debugging
   std::string format_value(int64_t value) const {
     switch (this->value_type) {

@@ -304,6 +316,7 @@ class ServerRegister {
   SensorValueType value_type{SensorValueType::RAW};
   uint8_t register_count{0};
   ReadLambda read_lambda;
+  WriteLambda write_lambda;
 };

 // ModbusController::create_register_ranges_ tries to optimize register range

@@ -485,6 +498,8 @@ class ModbusController : public PollingComponent, public modbus::ModbusDevice {
   void on_modbus_error(uint8_t function_code, uint8_t exception_code) override;
   /// called when a modbus request (function code 0x03 or 0x04) was parsed without errors
   void on_modbus_read_registers(uint8_t function_code, uint16_t start_address, uint16_t number_of_registers) final;
+  /// called when a modbus request (function code 0x06 or 0x10) was parsed without errors
+  void on_modbus_write_registers(uint8_t function_code, const std::vector<uint8_t> &data) final;
   /// default delegate called by process_modbus_data when a response has retrieved from the incoming queue
   void on_register_data(ModbusRegisterType register_type, uint16_t start_address, const std::vector<uint8_t> &data);
   /// default delegate called by process_modbus_data when a response for a write response has retrieved from the
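To connect set_write_lambda() with the Python codegen earlier in this diff: for a server register that declares a write_lambda, the generated call is roughly shaped like the sketch below. The variable name and the int32_t instantiation are illustrative assumptions, not literal generated output.

// Hypothetical generated code for a server register with a write_lambda:
server_register_var->set_write_lambda<int32_t>(
    [=](uint16_t address, const int32_t x) -> bool {
      // body of the YAML write_lambda; returning false makes the controller
      // answer with exception code 4 (see on_modbus_write_registers above)
      return true;
    });

For float registers, the int64_t payload is first reinterpreted via bit_cast in the if constexpr branch above before the user lambda sees it.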
@@ -74,7 +74,7 @@ BASE_SCHEMA = cv.All(
         {
             cv.Required(CONF_PATH): validate_yaml_filename,
             cv.Optional(CONF_VARS, default={}): cv.Schema(
-                {cv.string: cv.string}
+                {cv.string: object}
             ),
         }
     ),

@@ -148,7 +148,6 @@ def _process_base_package(config: dict) -> dict:
                 raise cv.Invalid(
                     f"Current ESPHome Version is too old to use this package: {ESPHOME_VERSION} < {min_version}"
                 )
-            vars = {k: str(v) for k, v in vars.items()}
             new_yaml = yaml_util.substitute_vars(new_yaml, vars)
             packages[f"{filename}{idx}"] = new_yaml
         except EsphomeError as e:
@@ -5,6 +5,13 @@ from esphome.config_helpers import Extend, Remove, merge_config
 import esphome.config_validation as cv
 from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS
 from esphome.yaml_util import ESPHomeDataBase, make_data_base
+from .jinja import (
+    Jinja,
+    JinjaStr,
+    has_jinja,
+    TemplateError,
+    TemplateRuntimeError,
+)

 CODEOWNERS = ["@esphome/core"]
 _LOGGER = logging.getLogger(__name__)

@@ -28,7 +35,7 @@ def validate_substitution_key(value):

 CONFIG_SCHEMA = cv.Schema(
     {
-        validate_substitution_key: cv.string_strict,
+        validate_substitution_key: object,
     }
 )

@@ -37,7 +44,42 @@ async def to_code(config):
     pass


-def _expand_substitutions(substitutions, value, path, ignore_missing):
+def _expand_jinja(value, orig_value, path, jinja, ignore_missing):
+    if has_jinja(value):
+        # If the original value passed in to this function is a JinjaStr, it means it contains an unresolved
+        # Jinja expression from a previous pass.
+        if isinstance(orig_value, JinjaStr):
+            # Rebuild the JinjaStr in case it was lost while replacing substitutions.
+            value = JinjaStr(value, orig_value.upvalues)
+        try:
+            # Invoke the jinja engine to evaluate the expression.
+            value, err = jinja.expand(value)
+            if err is not None:
+                if not ignore_missing and "password" not in path:
+                    _LOGGER.warning(
+                        "Found '%s' (see %s) which looks like an expression,"
+                        " but could not resolve all the variables: %s",
+                        value,
+                        "->".join(str(x) for x in path),
+                        err.message,
+                    )
+        except (
+            TemplateError,
+            TemplateRuntimeError,
+            RuntimeError,
+            ArithmeticError,
+            AttributeError,
+            TypeError,
+        ) as err:
+            raise cv.Invalid(
+                f"{type(err).__name__} Error evaluating jinja expression '{value}': {str(err)}."
+                f" See {'->'.join(str(x) for x in path)}",
+                path,
+            )
+    return value
+
+
+def _expand_substitutions(substitutions, value, path, jinja, ignore_missing):
     if "$" not in value:
         return value

@@ -47,7 +89,8 @@ def _expand_substitutions(substitutions, value, path, ignore_missing):
     while True:
         m = cv.VARIABLE_PROG.search(value, i)
         if not m:
-            # Nothing more to match. Done
+            # No more variable substitutions found. See if the remainder looks like a jinja template
+            value = _expand_jinja(value, orig_value, path, jinja, ignore_missing)
             break

         i, j = m.span(0)

@@ -67,8 +110,15 @@ def _expand_substitutions(substitutions, value, path, ignore_missing):
             continue

         sub = substitutions[name]
+
+        if i == 0 and j == len(value):
+            # The variable spans the whole expression, e.g., "${varName}". Return its resolved value directly
+            # to conserve its type.
+            value = sub
+            break
+
         tail = value[j:]
-        value = value[:i] + sub
+        value = value[:i] + str(sub)
         i = len(value)
         value += tail

@@ -77,36 +127,40 @@ def _expand_substitutions(substitutions, value, path, ignore_missing):
     if isinstance(orig_value, ESPHomeDataBase):
         # even though string can get larger or smaller, the range should point
         # to original document marks
-        return make_data_base(value, orig_value)
+        value = make_data_base(value, orig_value)
+
     return value


-def _substitute_item(substitutions, item, path, ignore_missing):
+def _substitute_item(substitutions, item, path, jinja, ignore_missing):
     if isinstance(item, list):
         for i, it in enumerate(item):
-            sub = _substitute_item(substitutions, it, path + [i], ignore_missing)
+            sub = _substitute_item(substitutions, it, path + [i], jinja, ignore_missing)
             if sub is not None:
                 item[i] = sub
     elif isinstance(item, dict):
         replace_keys = []
         for k, v in item.items():
             if path or k != CONF_SUBSTITUTIONS:
-                sub = _substitute_item(substitutions, k, path + [k], ignore_missing)
+                sub = _substitute_item(
+                    substitutions, k, path + [k], jinja, ignore_missing
+                )
                 if sub is not None:
                     replace_keys.append((k, sub))
-            sub = _substitute_item(substitutions, v, path + [k], ignore_missing)
+            sub = _substitute_item(substitutions, v, path + [k], jinja, ignore_missing)
             if sub is not None:
                 item[k] = sub
         for old, new in replace_keys:
             item[new] = merge_config(item.get(old), item.get(new))
             del item[old]
     elif isinstance(item, str):
-        sub = _expand_substitutions(substitutions, item, path, ignore_missing)
-        if sub != item:
+        sub = _expand_substitutions(substitutions, item, path, jinja, ignore_missing)
+        if isinstance(sub, JinjaStr) or sub != item:
             return sub
     elif isinstance(item, (core.Lambda, Extend, Remove)):
-        sub = _expand_substitutions(substitutions, item.value, path, ignore_missing)
+        sub = _expand_substitutions(
+            substitutions, item.value, path, jinja, ignore_missing
+        )
         if sub != item:
             item.value = sub
     return None

@@ -116,11 +170,11 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals
     if CONF_SUBSTITUTIONS not in config and not command_line_substitutions:
         return

-    substitutions = config.get(CONF_SUBSTITUTIONS)
-    if substitutions is None:
-        substitutions = command_line_substitutions
-    elif command_line_substitutions:
-        substitutions = {**substitutions, **command_line_substitutions}
+    # Merge substitutions in config, overriding with substitutions coming from command line:
+    substitutions = {
+        **config.get(CONF_SUBSTITUTIONS, {}),
+        **(command_line_substitutions or {}),
+    }
     with cv.prepend_path("substitutions"):
         if not isinstance(substitutions, dict):
             raise cv.Invalid(

@@ -133,7 +187,7 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals
             sub = validate_substitution_key(key)
             if sub != key:
                 replace_keys.append((key, sub))
-            substitutions[key] = cv.string_strict(value)
+            substitutions[key] = value
         for old, new in replace_keys:
             substitutions[new] = substitutions[old]
             del substitutions[old]

@@ -141,4 +195,7 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals
     config[CONF_SUBSTITUTIONS] = substitutions
     # Move substitutions to the first place to replace substitutions in them correctly
     config.move_to_end(CONF_SUBSTITUTIONS, False)
-    _substitute_item(substitutions, config, [], ignore_missing)
+
+    # Create a Jinja environment that will consider substitutions in scope:
+    jinja = Jinja(substitutions)
+    _substitute_item(substitutions, config, [], jinja, ignore_missing)
new file: esphome/components/substitutions/jinja.py (99 lines)
@@ -0,0 +1,99 @@
import logging
import math
import re

import jinja2 as jinja
from jinja2.nativetypes import NativeEnvironment

TemplateError = jinja.TemplateError
TemplateSyntaxError = jinja.TemplateSyntaxError
TemplateRuntimeError = jinja.TemplateRuntimeError
UndefinedError = jinja.UndefinedError
Undefined = jinja.Undefined

_LOGGER = logging.getLogger(__name__)

DETECT_JINJA = r"(\$\{)"
detect_jinja_re = re.compile(
    r"<%.+?%>"  # Block form expression: <% ... %>
    r"|\$\{[^}]+\}",  # Braced form expression: ${ ... }
    flags=re.MULTILINE,
)


def has_jinja(st):
    return detect_jinja_re.search(st) is not None


class JinjaStr(str):
    """
    Wraps a string containing an unresolved Jinja expression,
    storing the variables visible to it when it failed to resolve.
    For example, an expression inside a package, `${ A * B }` may fail
    to resolve at package parsing time if `A` is a local package var
    but `B` is a substitution defined in the root yaml.
    Therefore, we store the value of `A` as an upvalue bound
    to the original string so we may be able to resolve `${ A * B }`
    later in the main substitutions pass.
    """

    def __new__(cls, value: str, upvalues=None):
        obj = super().__new__(cls, value)
        obj.upvalues = upvalues or {}
        return obj

    def __init__(self, value: str, upvalues=None):
        self.upvalues = upvalues or {}


class Jinja:
    """
    Wraps a Jinja environment
    """

    def __init__(self, context_vars):
        self.env = NativeEnvironment(
            trim_blocks=True,
            lstrip_blocks=True,
            block_start_string="<%",
            block_end_string="%>",
            line_statement_prefix="#",
            line_comment_prefix="##",
            variable_start_string="${",
            variable_end_string="}",
            undefined=jinja.StrictUndefined,
        )
        self.env.add_extension("jinja2.ext.do")
        self.env.globals["math"] = math  # Inject entire math module
        self.context_vars = {**context_vars}
        self.env.globals = {**self.env.globals, **self.context_vars}

    def expand(self, content_str):
        """
        Renders a string that may contain Jinja expressions or statements
        Returns the resulting processed string if all values could be resolved.
        Otherwise, it returns a tagged (JinjaStr) string that captures variables
        in scope (upvalues), like a closure for later evaluation.
        """
        result = None
        override_vars = {}
        if isinstance(content_str, JinjaStr):
            # If `value` is already a JinjaStr, it means we are trying to evaluate it again
            # in a parent pass.
            # Hopefully, all required variables are visible now.
            override_vars = content_str.upvalues
        try:
            template = self.env.from_string(content_str)
            result = template.render(override_vars)
            if isinstance(result, Undefined):
                # This happens when the expression is simply an undefined variable. Jinja does not
                # raise an exception, instead we get "Undefined".
                # Trigger an UndefinedError exception so we skip to below.
                print("" + result)
        except (TemplateSyntaxError, UndefinedError) as err:
            # `content_str` contains a Jinja expression that refers to a variable that is undefined
            # in this scope. Perhaps it refers to a root substitution that is not visible yet.
            # Therefore, return the original `content_str` as a JinjaStr, which contains the variables
            # that are actually visible to it at this point to postpone evaluation.
            return JinjaStr(content_str, {**self.context_vars, **override_vars}), err

        return result, None
@@ -789,7 +789,6 @@ def validate_config(
     result.add_output_path([CONF_SUBSTITUTIONS], CONF_SUBSTITUTIONS)
     try:
         substitutions.do_substitution_pass(config, command_line_substitutions)
-        substitutions.do_substitution_pass(config, command_line_substitutions)
     except vol.Invalid as err:
         result.add_error(err)
         return result
@@ -292,8 +292,6 @@ class ESPHomeLoaderMixin:
             if file is None:
                 raise yaml.MarkedYAMLError("Must include 'file'", node.start_mark)
             vars = fields.get(CONF_VARS)
-            if vars:
-                vars = {k: str(v) for k, v in vars.items()}
             return file, vars

         if isinstance(node, yaml.nodes.MappingNode):
@@ -21,6 +21,7 @@ esphome-glyphsets==0.2.0
 pillow==10.4.0
 cairosvg==2.8.2
 freetype-py==2.5.1
+jinja2==3.1.6

 # esp-idf requires this, but doesn't bundle it by default
 # https://github.com/espressif/esp-idf/blob/220590d599e134d7a5e7f1e683cc4550349ffbf8/requirements.txt#L24
@@ -33,7 +33,18 @@ modbus_controller:
         read_lambda: |-
           return 42.3;
     max_cmd_retries: 0
+  - id: modbus_controller3
+    address: 0x3
+    modbus_id: mod_bus2
+    server_registers:
+      - address: 0x0009
+        value_type: S_DWORD
+        read_lambda: |-
+          return 31;
+        write_lambda: |-
+          printf("address=%d, value=%d", x);
+          return true;
+    max_cmd_retries: 0

 binary_sensor:
   - platform: modbus_controller
     modbus_controller_id: modbus_controller1
new file: tests/unit_tests/fixtures/substitutions/.gitignore (1 line)
@@ -0,0 +1 @@
*.received.yaml
@@ -0,0 +1,19 @@
substitutions:
  var1: '1'
  var2: '2'
  var21: '79'
esphome:
  name: test
test_list:
- '1'
- '1'
- '1'
- '1'
- 'Values: 1 2'
- 'Value: 79'
- 1 + 2
- 1 * 2
- 'Undefined var: ${undefined_var}'
- ${undefined_var}
- $undefined_var
- ${ undefined_var }

@@ -0,0 +1,21 @@
esphome:
  name: test

substitutions:
  var1: "1"
  var2: "2"
  var21: "79"

test_list:
  - "$var1"
  - "${var1}"
  - $var1
  - ${var1}
  - "Values: $var1 ${var2}"
  - "Value: ${var2${var1}}"
  - "$var1 + $var2"
  - "${ var1 } * ${ var2 }"
  - "Undefined var: ${undefined_var}"
  - ${undefined_var}
  - $undefined_var
  - ${ undefined_var }

@@ -0,0 +1,15 @@
substitutions:
  var1: '1'
  var2: '2'
  a: alpha
test_list:
- values:
  - var1: '1'
  - a: A
  - b: B-default
  - c: The value of C is C
- values:
  - var1: '1'
  - a: alpha
  - b: beta
  - c: The value of C is $c

@@ -0,0 +1,15 @@
substitutions:
  var1: "1"
  var2: "2"
  a: "alpha"

test_list:
  - !include
    file: inc1.yaml
    vars:
      a: "A"
      c: "C"
  - !include
    file: inc1.yaml
    vars:
      b: "beta"

@@ -0,0 +1,24 @@
substitutions:
  width: 7
  height: 8
  enabled: true
  pin: &id001
    number: 18
    inverted: true
  area: 25
  numberOne: 1
  var1: 79
test_list:
- The area is 56
- 56
- 56 + 1
- ENABLED
- list:
  - 7
  - 8
- width: 7
  height: 8
- *id001
- The pin number is 18
- The square root is: 5.0
- The number is 80

@@ -0,0 +1,22 @@
substitutions:
  width: 7
  height: 8
  enabled: true
  pin:
    number: 18
    inverted: true
  area: 25
  numberOne: 1
  var1: 79

test_list:
  - "The area is ${width * height}"
  - ${width * height}
  - ${width * height} + 1
  - ${enabled and "ENABLED" or "DISABLED"}
  - list: ${ [width, height] }
  - "${ {'width': width, 'height': height} }"
  - ${pin}
  - The pin number is ${pin.number}
  - The square root is: ${math.sqrt(area)}
  - The number is ${var${numberOne} + 1}

@@ -0,0 +1,17 @@
substitutions:
  B: 5
  var7: 79
package_result:
- The value of A*B is 35, where A is a package var and B is a substitution in the
  root file
- Double substitution also works; the value of var7 is 79, where A is a package
  var
local_results:
- The value of B is 5
- 'You will see, however, that

  ${A} is not substituted here, since

  it is out of scope.

  '

@@ -0,0 +1,16 @@
substitutions:
  B: 5
  var7: 79

packages:
  closures_package: !include
    file: closures_package.yaml
    vars:
      A: 7

local_results:
  - The value of B is ${B}
  - |
    You will see, however, that
    ${A} is not substituted here, since
    it is out of scope.

@@ -0,0 +1,5 @@
display:
- platform: ili9xxx
  dimensions:
    width: 960
    height: 544

@@ -0,0 +1,7 @@
# main.yaml
packages:
  my_display: !include
    file: display.yaml
    vars:
      high_dpi: true
      native_height: 272

@@ -0,0 +1,3 @@
package_result:
  - The value of A*B is ${A * B}, where A is a package var and B is a substitution in the root file
  - Double substitution also works; the value of var7 is ${var$A}, where A is a package var
new file: tests/unit_tests/fixtures/substitutions/display.yaml (11 lines)
@@ -0,0 +1,11 @@
# display.yaml

defaults:
  native_width: 480
  native_height: 480

display:
  - platform: ili9xxx
    dimensions:
      width: ${high_dpi and native_width * 2 or native_width}
      height: ${high_dpi and native_height * 2 or native_height}
new file: tests/unit_tests/fixtures/substitutions/inc1.yaml (8 lines)
@@ -0,0 +1,8 @@
defaults:
  b: "B-default"

values:
  - var1: $var1
  - a: $a
  - b: ${b}
  - c: The value of C is $c
new file: tests/unit_tests/test_substitutions.py (125 lines)
@@ -0,0 +1,125 @@
import glob
import logging
import os

from esphome import yaml_util
from esphome.components import substitutions
from esphome.const import CONF_PACKAGES

_LOGGER = logging.getLogger(__name__)

# Set to True for dev mode behavior
# This will generate the expected version of the test files.

DEV_MODE = False


def sort_dicts(obj):
    """Recursively sort dictionaries for order-insensitive comparison."""
    if isinstance(obj, dict):
        return {k: sort_dicts(obj[k]) for k in sorted(obj)}
    elif isinstance(obj, list):
        # Lists are not sorted; we preserve order
        return [sort_dicts(i) for i in obj]
    else:
        return obj


def dict_diff(a, b, path=""):
    """Recursively find differences between two dict/list structures."""
    diffs = []
    if isinstance(a, dict) and isinstance(b, dict):
        a_keys = set(a)
        b_keys = set(b)
        for key in a_keys - b_keys:
            diffs.append(f"{path}/{key} only in actual")
        for key in b_keys - a_keys:
            diffs.append(f"{path}/{key} only in expected")
        for key in a_keys & b_keys:
            diffs.extend(dict_diff(a[key], b[key], f"{path}/{key}"))
    elif isinstance(a, list) and isinstance(b, list):
        min_len = min(len(a), len(b))
        for i in range(min_len):
            diffs.extend(dict_diff(a[i], b[i], f"{path}[{i}]"))
        if len(a) > len(b):
            for i in range(min_len, len(a)):
                diffs.append(f"{path}[{i}] only in actual: {a[i]!r}")
        elif len(b) > len(a):
            for i in range(min_len, len(b)):
                diffs.append(f"{path}[{i}] only in expected: {b[i]!r}")
    else:
        if a != b:
            diffs.append(f"\t{path}: actual={a!r} expected={b!r}")
    return diffs


def write_yaml(path, data):
    with open(path, "w", encoding="utf-8") as f:
        f.write(yaml_util.dump(data))


def test_substitutions_fixtures(fixture_path):
    base_dir = fixture_path / "substitutions"
    sources = sorted(glob.glob(str(base_dir / "*.input.yaml")))
    assert sources, f"No input YAML files found in {base_dir}"

    failures = []
    for source_path in sources:
        try:
            expected_path = source_path.replace(".input.yaml", ".approved.yaml")
            test_case = os.path.splitext(os.path.basename(source_path))[0].replace(
                ".input", ""
            )

            # Load using ESPHome's YAML loader
            config = yaml_util.load_yaml(source_path)

            if CONF_PACKAGES in config:
                from esphome.components.packages import do_packages_pass

                config = do_packages_pass(config)

            substitutions.do_substitution_pass(config, None)

            # Also load expected using ESPHome's loader, or use {} if missing and DEV_MODE
            if os.path.isfile(expected_path):
                expected = yaml_util.load_yaml(expected_path)
            elif DEV_MODE:
                expected = {}
            else:
                assert os.path.isfile(expected_path), (
                    f"Expected file missing: {expected_path}"
                )

            # Sort dicts only (not lists) for comparison
            got_sorted = sort_dicts(config)
            expected_sorted = sort_dicts(expected)

            if got_sorted != expected_sorted:
                diff = "\n".join(dict_diff(got_sorted, expected_sorted))
                msg = (
                    f"Substitution result mismatch for {os.path.basename(source_path)}\n"
                    f"Diff:\n{diff}\n\n"
                    f"Got: {got_sorted}\n"
                    f"Expected: {expected_sorted}"
                )
                # Write out the received file when test fails
                if DEV_MODE:
                    received_path = os.path.join(
                        os.path.dirname(source_path), f"{test_case}.received.yaml"
                    )
                    write_yaml(received_path, config)
                    print(msg)
                    failures.append(msg)
                else:
                    raise AssertionError(msg)
        except Exception as err:
            _LOGGER.error("Error in test file %s", source_path)
            raise err

    if DEV_MODE and failures:
        print(f"\n{len(failures)} substitution test case(s) failed.")

    if DEV_MODE:
        _LOGGER.error("Tests passed, but Dev mode is enabled.")
    assert not DEV_MODE  # make sure DEV_MODE is disabled after you are finished.