diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index b7624221c9..e83d508c50 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -1687,7 +1687,9 @@ void APIConnection::DeferredBatch::add_item(EntityBase *entity, MessageCreator c // O(n) but optimized for RAM and not performance. for (auto &item : items) { if (item.entity == entity && item.message_type == message_type) { - // Update the existing item with the new creator + // Clean up old creator before replacing + item.creator.cleanup(message_type); + // Move assign the new creator item.creator = std::move(creator); return; } @@ -1730,11 +1732,11 @@ void APIConnection::process_batch_() { return; } - size_t num_items = this->deferred_batch_.items.size(); + size_t num_items = this->deferred_batch_.size(); // Fast path for single message - allocate exact size needed if (num_items == 1) { - const auto &item = this->deferred_batch_.items[0]; + const auto &item = this->deferred_batch_[0]; // Let the creator calculate size and encode if it fits uint16_t payload_size = @@ -1764,7 +1766,8 @@ void APIConnection::process_batch_() { // Pre-calculate exact buffer size needed based on message types uint32_t total_estimated_size = 0; - for (const auto &item : this->deferred_batch_.items) { + for (size_t i = 0; i < this->deferred_batch_.size(); i++) { + const auto &item = this->deferred_batch_[i]; total_estimated_size += get_estimated_message_size(item.message_type); } @@ -1785,7 +1788,8 @@ void APIConnection::process_batch_() { uint32_t current_offset = 0; // Process items and encode directly to buffer - for (const auto &item : this->deferred_batch_.items) { + for (size_t i = 0; i < this->deferred_batch_.size(); i++) { + const auto &item = this->deferred_batch_[i]; // Try to encode message // The creator will calculate overhead to determine if the message fits uint16_t payload_size = item.creator(item.entity, this, remaining_size, false, item.message_type); @@ -1840,17 +1844,15 @@ void APIConnection::process_batch_() { // Log messages after send attempt for VV debugging // It's safe to use the buffer for logging at this point regardless of send result for (size_t i = 0; i < items_processed; i++) { - const auto &item = this->deferred_batch_.items[i]; + const auto &item = this->deferred_batch_[i]; this->log_batch_item_(item); } #endif // Handle remaining items more efficiently - if (items_processed < this->deferred_batch_.items.size()) { - // Remove processed items from the beginning - this->deferred_batch_.items.erase(this->deferred_batch_.items.begin(), - this->deferred_batch_.items.begin() + items_processed); - + if (items_processed < this->deferred_batch_.size()) { + // Remove processed items from the beginning with proper cleanup + this->deferred_batch_.remove_front(items_processed); // Reschedule for remaining items this->schedule_batch_(); } else { @@ -1861,23 +1863,16 @@ void APIConnection::process_batch_() { uint16_t APIConnection::MessageCreator::operator()(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, bool is_single, uint16_t message_type) const { - if (has_tagged_string_ptr_()) { - // Handle string-based messages - switch (message_type) { #ifdef USE_EVENT - case EventResponse::MESSAGE_TYPE: { - auto *e = static_cast(entity); - return APIConnection::try_send_event_response(e, *get_string_ptr_(), conn, remaining_size, is_single); - } -#endif - default: - // Should not happen, return 0 to indicate no message - return 0; - } - } else { - // Function pointer case - return data_.ptr(entity, conn, remaining_size, is_single); + // Special case: EventResponse uses string pointer + if (message_type == EventResponse::MESSAGE_TYPE) { + auto *e = static_cast(entity); + return APIConnection::try_send_event_response(e, *data_.string_ptr, conn, remaining_size, is_single); } +#endif + + // All other message types use function pointers + return data_.function_ptr(entity, conn, remaining_size, is_single); } uint16_t APIConnection::try_send_list_info_done(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 410a9ad3a5..151369aa70 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -451,96 +451,53 @@ class APIConnection : public APIServerConnection { // Function pointer type for message encoding using MessageCreatorPtr = uint16_t (*)(EntityBase *, APIConnection *, uint32_t remaining_size, bool is_single); - // Optimized MessageCreator class using tagged pointer class MessageCreator { - // Ensure pointer alignment allows LSB tagging - static_assert(alignof(std::string *) > 1, "String pointer alignment must be > 1 for LSB tagging"); - public: // Constructor for function pointer - MessageCreator(MessageCreatorPtr ptr) { - // Function pointers are always aligned, so LSB is 0 - data_.ptr = ptr; - } + MessageCreator(MessageCreatorPtr ptr) { data_.function_ptr = ptr; } // Constructor for string state capture - explicit MessageCreator(const std::string &str_value) { - // Allocate string and tag the pointer - auto *str = new std::string(str_value); - // Set LSB to 1 to indicate string pointer - data_.tagged = reinterpret_cast(str) | 1; - } + explicit MessageCreator(const std::string &str_value) { data_.string_ptr = new std::string(str_value); } - // Destructor - ~MessageCreator() { - if (has_tagged_string_ptr_()) { - delete get_string_ptr_(); - } - } + // No destructor - cleanup must be called explicitly with message_type - // Copy constructor - MessageCreator(const MessageCreator &other) { - if (other.has_tagged_string_ptr_()) { - auto *str = new std::string(*other.get_string_ptr_()); - data_.tagged = reinterpret_cast(str) | 1; - } else { - data_ = other.data_; - } - } + // Delete copy operations - MessageCreator should only be moved + MessageCreator(const MessageCreator &other) = delete; + MessageCreator &operator=(const MessageCreator &other) = delete; // Move constructor - MessageCreator(MessageCreator &&other) noexcept : data_(other.data_) { other.data_.ptr = nullptr; } - - // Assignment operators (needed for batch deduplication) - MessageCreator &operator=(const MessageCreator &other) { - if (this != &other) { - // Clean up current string data if needed - if (has_tagged_string_ptr_()) { - delete get_string_ptr_(); - } - // Copy new data - if (other.has_tagged_string_ptr_()) { - auto *str = new std::string(*other.get_string_ptr_()); - data_.tagged = reinterpret_cast(str) | 1; - } else { - data_ = other.data_; - } - } - return *this; - } + MessageCreator(MessageCreator &&other) noexcept : data_(other.data_) { other.data_.function_ptr = nullptr; } + // Move assignment MessageCreator &operator=(MessageCreator &&other) noexcept { if (this != &other) { - // Clean up current string data if needed - if (has_tagged_string_ptr_()) { - delete get_string_ptr_(); - } - // Move data + // IMPORTANT: Caller must ensure cleanup() was called if this contains a string! + // In our usage, this happens in add_item() deduplication and vector::erase() data_ = other.data_; - // Reset other to safe state - other.data_.ptr = nullptr; + other.data_.function_ptr = nullptr; } return *this; } - // Call operator - now accepts message_type as parameter + // Call operator - uses message_type to determine union type uint16_t operator()(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, bool is_single, uint16_t message_type) const; - private: - // Check if this contains a string pointer - bool has_tagged_string_ptr_() const { return (data_.tagged & 1) != 0; } - - // Get the actual string pointer (clears the tag bit) - std::string *get_string_ptr_() const { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - return reinterpret_cast(data_.tagged & ~uintptr_t(1)); + // Manual cleanup method - must be called before destruction for string types + void cleanup(uint16_t message_type) { +#ifdef USE_EVENT + if (message_type == EventResponse::MESSAGE_TYPE && data_.string_ptr != nullptr) { + delete data_.string_ptr; + data_.string_ptr = nullptr; + } +#endif } - union { - MessageCreatorPtr ptr; - uintptr_t tagged; - } data_; // 4 bytes on 32-bit + private: + union Data { + MessageCreatorPtr function_ptr; + std::string *string_ptr; + } data_; // 4 bytes on 32-bit, 8 bytes on 64-bit - same as before }; // Generic batching mechanism for both state updates and entity info @@ -558,20 +515,46 @@ class APIConnection : public APIServerConnection { std::vector items; uint32_t batch_start_time{0}; + private: + // Helper to cleanup items from the beginning + void cleanup_items_(size_t count) { + for (size_t i = 0; i < count; i++) { + items[i].creator.cleanup(items[i].message_type); + } + } + + public: DeferredBatch() { // Pre-allocate capacity for typical batch sizes to avoid reallocation items.reserve(8); } + ~DeferredBatch() { + // Ensure cleanup of any remaining items + clear(); + } + // Add item to the batch void add_item(EntityBase *entity, MessageCreator creator, uint16_t message_type); // Add item to the front of the batch (for high priority messages like ping) void add_item_front(EntityBase *entity, MessageCreator creator, uint16_t message_type); + + // Clear all items with proper cleanup void clear() { + cleanup_items_(items.size()); items.clear(); batch_start_time = 0; } + + // Remove processed items from the front with proper cleanup + void remove_front(size_t count) { + cleanup_items_(count); + items.erase(items.begin(), items.begin() + count); + } + bool empty() const { return items.empty(); } + size_t size() const { return items.size(); } + const BatchItem &operator[](size_t index) const { return items[index]; } }; // DeferredBatch here (16 bytes, 4-byte aligned) diff --git a/esphome/components/host/__init__.py b/esphome/components/host/__init__.py index da75873eaf..d3dbcba6ed 100644 --- a/esphome/components/host/__init__.py +++ b/esphome/components/host/__init__.py @@ -44,3 +44,4 @@ async def to_code(config): cg.add_build_flag("-std=gnu++20") cg.add_define("ESPHOME_BOARD", "host") cg.add_platformio_option("platform", "platformio/native") + cg.add_platformio_option("lib_ldf_mode", "off") diff --git a/esphome/components/modbus/modbus.cpp b/esphome/components/modbus/modbus.cpp index c2efa93fae..6350f43ef6 100644 --- a/esphome/components/modbus/modbus.cpp +++ b/esphome/components/modbus/modbus.cpp @@ -90,15 +90,24 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { } else { // data starts at 2 and length is 4 for read registers commands - if (this->role == ModbusRole::SERVER && (function_code == 0x1 || function_code == 0x3 || function_code == 0x4)) { - data_offset = 2; - data_len = 4; - } - - // the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands - if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) { - data_offset = 2; - data_len = 4; + if (this->role == ModbusRole::SERVER) { + if (function_code == 0x1 || function_code == 0x3 || function_code == 0x4 || function_code == 0x6) { + data_offset = 2; + data_len = 4; + } else if (function_code == 0x10) { + if (at < 6) { + return true; + } + data_offset = 2; + // starting address (2 bytes) + quantity of registers (2 bytes) + byte count itself (1 byte) + actual byte count + data_len = 2 + 2 + 1 + raw[6]; + } + } else { + // the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands + if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) { + data_offset = 2; + data_len = 4; + } } // Error ( msb indicates error ) @@ -132,6 +141,7 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { bool found = false; for (auto *device : this->devices_) { if (device->address_ == address) { + found = true; // Is it an error response? if ((function_code & 0x80) == 0x80) { ESP_LOGD(TAG, "Modbus error function code: 0x%X exception: %d", function_code, raw[2]); @@ -141,13 +151,21 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { // Ignore modbus exception not related to a pending command ESP_LOGD(TAG, "Ignoring Modbus error - not expecting a response"); } - } else if (this->role == ModbusRole::SERVER && (function_code == 0x3 || function_code == 0x4)) { - device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8), - uint16_t(data[3]) | (uint16_t(data[2]) << 8)); - } else { - device->on_modbus_data(data); + continue; } - found = true; + if (this->role == ModbusRole::SERVER) { + if (function_code == 0x3 || function_code == 0x4) { + device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8), + uint16_t(data[3]) | (uint16_t(data[2]) << 8)); + continue; + } + if (function_code == 0x6 || function_code == 0x10) { + device->on_modbus_write_registers(function_code, data); + continue; + } + } + // fallthrough for other function codes + device->on_modbus_data(data); } } waiting_for_response = 0; diff --git a/esphome/components/modbus/modbus.h b/esphome/components/modbus/modbus.h index aebdbccc78..ec35612690 100644 --- a/esphome/components/modbus/modbus.h +++ b/esphome/components/modbus/modbus.h @@ -59,6 +59,7 @@ class ModbusDevice { virtual void on_modbus_data(const std::vector &data) = 0; virtual void on_modbus_error(uint8_t function_code, uint8_t exception_code) {} virtual void on_modbus_read_registers(uint8_t function_code, uint16_t start_address, uint16_t number_of_registers){}; + virtual void on_modbus_write_registers(uint8_t function_code, const std::vector &data){}; void send(uint8_t function, uint16_t start_address, uint16_t number_of_entities, uint8_t payload_len = 0, const uint8_t *payload = nullptr) { this->parent_->send(this->address_, function, start_address, number_of_entities, payload_len, payload); diff --git a/esphome/components/modbus_controller/__init__.py b/esphome/components/modbus_controller/__init__.py index 8079b824b0..5ab82f5e17 100644 --- a/esphome/components/modbus_controller/__init__.py +++ b/esphome/components/modbus_controller/__init__.py @@ -39,6 +39,7 @@ CODEOWNERS = ["@martgras"] AUTO_LOAD = ["modbus"] CONF_READ_LAMBDA = "read_lambda" +CONF_WRITE_LAMBDA = "write_lambda" CONF_SERVER_REGISTERS = "server_registers" MULTI_CONF = True @@ -148,6 +149,7 @@ ModbusServerRegisterSchema = cv.Schema( cv.Required(CONF_ADDRESS): cv.positive_int, cv.Optional(CONF_VALUE_TYPE, default="U_WORD"): cv.enum(SENSOR_VALUE_TYPE), cv.Required(CONF_READ_LAMBDA): cv.returning_lambda, + cv.Optional(CONF_WRITE_LAMBDA): cv.returning_lambda, } ) @@ -318,6 +320,17 @@ async def to_code(config): ), ) ) + if CONF_WRITE_LAMBDA in server_register: + cg.add( + server_register_var.set_write_lambda( + cg.TemplateArguments(cpp_type), + await cg.process_lambda( + server_register[CONF_WRITE_LAMBDA], + parameters=[(cg.uint16, "address"), (cpp_type, "x")], + return_type=cg.bool_, + ), + ) + ) cg.add(var.add_server_register(server_register_var)) await register_modbus_device(var, config) for conf in config.get(CONF_ON_COMMAND_SENT, []): diff --git a/esphome/components/modbus_controller/modbus_controller.cpp b/esphome/components/modbus_controller/modbus_controller.cpp index 81e9ccf0a6..0f3ddf920d 100644 --- a/esphome/components/modbus_controller/modbus_controller.cpp +++ b/esphome/components/modbus_controller/modbus_controller.cpp @@ -152,6 +152,86 @@ void ModbusController::on_modbus_read_registers(uint8_t function_code, uint16_t this->send(function_code, start_address, number_of_registers, response.size(), response.data()); } +void ModbusController::on_modbus_write_registers(uint8_t function_code, const std::vector &data) { + uint16_t number_of_registers; + uint16_t payload_offset; + + if (function_code == 0x10) { + number_of_registers = uint16_t(data[3]) | (uint16_t(data[2]) << 8); + if (number_of_registers == 0 || number_of_registers > 0x7B) { + ESP_LOGW(TAG, "Invalid number of registers %d. Sending exception response.", number_of_registers); + send_error(function_code, 3); + return; + } + uint16_t payload_size = data[4]; + if (payload_size != number_of_registers * 2) { + ESP_LOGW(TAG, "Payload size of %d bytes is not 2 times the number of registers (%d). Sending exception response.", + payload_size, number_of_registers); + send_error(function_code, 3); + return; + } + payload_offset = 5; + } else if (function_code == 0x06) { + number_of_registers = 1; + payload_offset = 2; + } else { + ESP_LOGW(TAG, "Invalid function code 0x%X. Sending exception response.", function_code); + send_error(function_code, 1); + return; + } + + uint16_t start_address = uint16_t(data[1]) | (uint16_t(data[0]) << 8); + ESP_LOGD(TAG, + "Received write holding registers for device 0x%X. FC: 0x%X. Start address: 0x%X. Number of registers: " + "0x%X.", + this->address_, function_code, start_address, number_of_registers); + + auto for_each_register = [this, start_address, number_of_registers, payload_offset]( + const std::function &callback) -> bool { + uint16_t offset = payload_offset; + for (uint16_t current_address = start_address; current_address < start_address + number_of_registers;) { + bool ok = false; + for (auto *server_register : this->server_registers_) { + if (server_register->address == current_address) { + ok = callback(server_register, offset); + current_address += server_register->register_count; + offset += server_register->register_count * sizeof(uint16_t); + break; + } + } + + if (!ok) { + return false; + } + } + return true; + }; + + // check all registers are writable before writing to any of them: + if (!for_each_register([](ServerRegister *server_register, uint16_t offset) -> bool { + return server_register->write_lambda != nullptr; + })) { + send_error(function_code, 1); + return; + } + + // Actually write to the registers: + if (!for_each_register([&data](ServerRegister *server_register, uint16_t offset) { + int64_t number = payload_to_number(data, server_register->value_type, offset, 0xFFFFFFFF); + return server_register->write_lambda(number); + })) { + send_error(function_code, 4); + return; + } + + std::vector response; + response.reserve(6); + response.push_back(this->address_); + response.push_back(function_code); + response.insert(response.end(), data.begin(), data.begin() + 4); + this->send_raw(response); +} + SensorSet ModbusController::find_sensors_(ModbusRegisterType register_type, uint16_t start_address) const { auto reg_it = std::find_if( std::begin(this->register_ranges_), std::end(this->register_ranges_), diff --git a/esphome/components/modbus_controller/modbus_controller.h b/esphome/components/modbus_controller/modbus_controller.h index 11d27c4025..a86ad1ccb5 100644 --- a/esphome/components/modbus_controller/modbus_controller.h +++ b/esphome/components/modbus_controller/modbus_controller.h @@ -258,6 +258,7 @@ class SensorItem { class ServerRegister { using ReadLambda = std::function; + using WriteLambda = std::function; public: ServerRegister(uint16_t address, SensorValueType value_type, uint8_t register_count) { @@ -277,6 +278,17 @@ class ServerRegister { }; } + template + void set_write_lambda(const std::function &&user_write_lambda) { + this->write_lambda = [this, user_write_lambda](int64_t number) { + if constexpr (std::is_same_v) { + float float_value = bit_cast(static_cast(number)); + return user_write_lambda(this->address, float_value); + } + return user_write_lambda(this->address, static_cast(number)); + }; + } + // Formats a raw value into a string representation based on the value type for debugging std::string format_value(int64_t value) const { switch (this->value_type) { @@ -304,6 +316,7 @@ class ServerRegister { SensorValueType value_type{SensorValueType::RAW}; uint8_t register_count{0}; ReadLambda read_lambda; + WriteLambda write_lambda; }; // ModbusController::create_register_ranges_ tries to optimize register range @@ -485,6 +498,8 @@ class ModbusController : public PollingComponent, public modbus::ModbusDevice { void on_modbus_error(uint8_t function_code, uint8_t exception_code) override; /// called when a modbus request (function code 0x03 or 0x04) was parsed without errors void on_modbus_read_registers(uint8_t function_code, uint16_t start_address, uint16_t number_of_registers) final; + /// called when a modbus request (function code 0x06 or 0x10) was parsed without errors + void on_modbus_write_registers(uint8_t function_code, const std::vector &data) final; /// default delegate called by process_modbus_data when a response has retrieved from the incoming queue void on_register_data(ModbusRegisterType register_type, uint16_t start_address, const std::vector &data); /// default delegate called by process_modbus_data when a response for a write response has retrieved from the diff --git a/esphome/components/packages/__init__.py b/esphome/components/packages/__init__.py index 08ae798282..6eb746ec63 100644 --- a/esphome/components/packages/__init__.py +++ b/esphome/components/packages/__init__.py @@ -74,7 +74,7 @@ BASE_SCHEMA = cv.All( { cv.Required(CONF_PATH): validate_yaml_filename, cv.Optional(CONF_VARS, default={}): cv.Schema( - {cv.string: cv.string} + {cv.string: object} ), } ), @@ -148,7 +148,6 @@ def _process_base_package(config: dict) -> dict: raise cv.Invalid( f"Current ESPHome Version is too old to use this package: {ESPHOME_VERSION} < {min_version}" ) - vars = {k: str(v) for k, v in vars.items()} new_yaml = yaml_util.substitute_vars(new_yaml, vars) packages[f"{filename}{idx}"] = new_yaml except EsphomeError as e: diff --git a/esphome/components/substitutions/__init__.py b/esphome/components/substitutions/__init__.py index 41e49f70db..5878af43b2 100644 --- a/esphome/components/substitutions/__init__.py +++ b/esphome/components/substitutions/__init__.py @@ -5,6 +5,13 @@ from esphome.config_helpers import Extend, Remove, merge_config import esphome.config_validation as cv from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS from esphome.yaml_util import ESPHomeDataBase, make_data_base +from .jinja import ( + Jinja, + JinjaStr, + has_jinja, + TemplateError, + TemplateRuntimeError, +) CODEOWNERS = ["@esphome/core"] _LOGGER = logging.getLogger(__name__) @@ -28,7 +35,7 @@ def validate_substitution_key(value): CONFIG_SCHEMA = cv.Schema( { - validate_substitution_key: cv.string_strict, + validate_substitution_key: object, } ) @@ -37,7 +44,42 @@ async def to_code(config): pass -def _expand_substitutions(substitutions, value, path, ignore_missing): +def _expand_jinja(value, orig_value, path, jinja, ignore_missing): + if has_jinja(value): + # If the original value passed in to this function is a JinjaStr, it means it contains an unresolved + # Jinja expression from a previous pass. + if isinstance(orig_value, JinjaStr): + # Rebuild the JinjaStr in case it was lost while replacing substitutions. + value = JinjaStr(value, orig_value.upvalues) + try: + # Invoke the jinja engine to evaluate the expression. + value, err = jinja.expand(value) + if err is not None: + if not ignore_missing and "password" not in path: + _LOGGER.warning( + "Found '%s' (see %s) which looks like an expression," + " but could not resolve all the variables: %s", + value, + "->".join(str(x) for x in path), + err.message, + ) + except ( + TemplateError, + TemplateRuntimeError, + RuntimeError, + ArithmeticError, + AttributeError, + TypeError, + ) as err: + raise cv.Invalid( + f"{type(err).__name__} Error evaluating jinja expression '{value}': {str(err)}." + f" See {'->'.join(str(x) for x in path)}", + path, + ) + return value + + +def _expand_substitutions(substitutions, value, path, jinja, ignore_missing): if "$" not in value: return value @@ -47,7 +89,8 @@ def _expand_substitutions(substitutions, value, path, ignore_missing): while True: m = cv.VARIABLE_PROG.search(value, i) if not m: - # Nothing more to match. Done + # No more variable substitutions found. See if the remainder looks like a jinja template + value = _expand_jinja(value, orig_value, path, jinja, ignore_missing) break i, j = m.span(0) @@ -67,8 +110,15 @@ def _expand_substitutions(substitutions, value, path, ignore_missing): continue sub = substitutions[name] + + if i == 0 and j == len(value): + # The variable spans the whole expression, e.g., "${varName}". Return its resolved value directly + # to conserve its type. + value = sub + break + tail = value[j:] - value = value[:i] + sub + value = value[:i] + str(sub) i = len(value) value += tail @@ -77,36 +127,40 @@ def _expand_substitutions(substitutions, value, path, ignore_missing): if isinstance(orig_value, ESPHomeDataBase): # even though string can get larger or smaller, the range should point # to original document marks - return make_data_base(value, orig_value) + value = make_data_base(value, orig_value) return value -def _substitute_item(substitutions, item, path, ignore_missing): +def _substitute_item(substitutions, item, path, jinja, ignore_missing): if isinstance(item, list): for i, it in enumerate(item): - sub = _substitute_item(substitutions, it, path + [i], ignore_missing) + sub = _substitute_item(substitutions, it, path + [i], jinja, ignore_missing) if sub is not None: item[i] = sub elif isinstance(item, dict): replace_keys = [] for k, v in item.items(): if path or k != CONF_SUBSTITUTIONS: - sub = _substitute_item(substitutions, k, path + [k], ignore_missing) + sub = _substitute_item( + substitutions, k, path + [k], jinja, ignore_missing + ) if sub is not None: replace_keys.append((k, sub)) - sub = _substitute_item(substitutions, v, path + [k], ignore_missing) + sub = _substitute_item(substitutions, v, path + [k], jinja, ignore_missing) if sub is not None: item[k] = sub for old, new in replace_keys: item[new] = merge_config(item.get(old), item.get(new)) del item[old] elif isinstance(item, str): - sub = _expand_substitutions(substitutions, item, path, ignore_missing) - if sub != item: + sub = _expand_substitutions(substitutions, item, path, jinja, ignore_missing) + if isinstance(sub, JinjaStr) or sub != item: return sub elif isinstance(item, (core.Lambda, Extend, Remove)): - sub = _expand_substitutions(substitutions, item.value, path, ignore_missing) + sub = _expand_substitutions( + substitutions, item.value, path, jinja, ignore_missing + ) if sub != item: item.value = sub return None @@ -116,11 +170,11 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals if CONF_SUBSTITUTIONS not in config and not command_line_substitutions: return - substitutions = config.get(CONF_SUBSTITUTIONS) - if substitutions is None: - substitutions = command_line_substitutions - elif command_line_substitutions: - substitutions = {**substitutions, **command_line_substitutions} + # Merge substitutions in config, overriding with substitutions coming from command line: + substitutions = { + **config.get(CONF_SUBSTITUTIONS, {}), + **(command_line_substitutions or {}), + } with cv.prepend_path("substitutions"): if not isinstance(substitutions, dict): raise cv.Invalid( @@ -133,7 +187,7 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals sub = validate_substitution_key(key) if sub != key: replace_keys.append((key, sub)) - substitutions[key] = cv.string_strict(value) + substitutions[key] = value for old, new in replace_keys: substitutions[new] = substitutions[old] del substitutions[old] @@ -141,4 +195,7 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals config[CONF_SUBSTITUTIONS] = substitutions # Move substitutions to the first place to replace substitutions in them correctly config.move_to_end(CONF_SUBSTITUTIONS, False) - _substitute_item(substitutions, config, [], ignore_missing) + + # Create a Jinja environment that will consider substitutions in scope: + jinja = Jinja(substitutions) + _substitute_item(substitutions, config, [], jinja, ignore_missing) diff --git a/esphome/components/substitutions/jinja.py b/esphome/components/substitutions/jinja.py new file mode 100644 index 0000000000..9ecdbab844 --- /dev/null +++ b/esphome/components/substitutions/jinja.py @@ -0,0 +1,99 @@ +import logging +import math +import re +import jinja2 as jinja +from jinja2.nativetypes import NativeEnvironment + +TemplateError = jinja.TemplateError +TemplateSyntaxError = jinja.TemplateSyntaxError +TemplateRuntimeError = jinja.TemplateRuntimeError +UndefinedError = jinja.UndefinedError +Undefined = jinja.Undefined + +_LOGGER = logging.getLogger(__name__) + +DETECT_JINJA = r"(\$\{)" +detect_jinja_re = re.compile( + r"<%.+?%>" # Block form expression: <% ... %> + r"|\$\{[^}]+\}", # Braced form expression: ${ ... } + flags=re.MULTILINE, +) + + +def has_jinja(st): + return detect_jinja_re.search(st) is not None + + +class JinjaStr(str): + """ + Wraps a string containing an unresolved Jinja expression, + storing the variables visible to it when it failed to resolve. + For example, an expression inside a package, `${ A * B }` may fail + to resolve at package parsing time if `A` is a local package var + but `B` is a substitution defined in the root yaml. + Therefore, we store the value of `A` as an upvalue bound + to the original string so we may be able to resolve `${ A * B }` + later in the main substitutions pass. + """ + + def __new__(cls, value: str, upvalues=None): + obj = super().__new__(cls, value) + obj.upvalues = upvalues or {} + return obj + + def __init__(self, value: str, upvalues=None): + self.upvalues = upvalues or {} + + +class Jinja: + """ + Wraps a Jinja environment + """ + + def __init__(self, context_vars): + self.env = NativeEnvironment( + trim_blocks=True, + lstrip_blocks=True, + block_start_string="<%", + block_end_string="%>", + line_statement_prefix="#", + line_comment_prefix="##", + variable_start_string="${", + variable_end_string="}", + undefined=jinja.StrictUndefined, + ) + self.env.add_extension("jinja2.ext.do") + self.env.globals["math"] = math # Inject entire math module + self.context_vars = {**context_vars} + self.env.globals = {**self.env.globals, **self.context_vars} + + def expand(self, content_str): + """ + Renders a string that may contain Jinja expressions or statements + Returns the resulting processed string if all values could be resolved. + Otherwise, it returns a tagged (JinjaStr) string that captures variables + in scope (upvalues), like a closure for later evaluation. + """ + result = None + override_vars = {} + if isinstance(content_str, JinjaStr): + # If `value` is already a JinjaStr, it means we are trying to evaluate it again + # in a parent pass. + # Hopefully, all required variables are visible now. + override_vars = content_str.upvalues + try: + template = self.env.from_string(content_str) + result = template.render(override_vars) + if isinstance(result, Undefined): + # This happens when the expression is simply an undefined variable. Jinja does not + # raise an exception, instead we get "Undefined". + # Trigger an UndefinedError exception so we skip to below. + print("" + result) + except (TemplateSyntaxError, UndefinedError) as err: + # `content_str` contains a Jinja expression that refers to a variable that is undefined + # in this scope. Perhaps it refers to a root substitution that is not visible yet. + # Therefore, return the original `content_str` as a JinjaStr, which contains the variables + # that are actually visible to it at this point to postpone evaluation. + return JinjaStr(content_str, {**self.context_vars, **override_vars}), err + + return result, None diff --git a/esphome/config.py b/esphome/config.py index ca3686a0e6..73cc7657cc 100644 --- a/esphome/config.py +++ b/esphome/config.py @@ -789,7 +789,6 @@ def validate_config( result.add_output_path([CONF_SUBSTITUTIONS], CONF_SUBSTITUTIONS) try: substitutions.do_substitution_pass(config, command_line_substitutions) - substitutions.do_substitution_pass(config, command_line_substitutions) except vol.Invalid as err: result.add_error(err) return result diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index bd1806affc..e52fc9e788 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -292,8 +292,6 @@ class ESPHomeLoaderMixin: if file is None: raise yaml.MarkedYAMLError("Must include 'file'", node.start_mark) vars = fields.get(CONF_VARS) - if vars: - vars = {k: str(v) for k, v in vars.items()} return file, vars if isinstance(node, yaml.nodes.MappingNode): diff --git a/requirements.txt b/requirements.txt index 12f3b84359..1010a311d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,7 @@ esphome-glyphsets==0.2.0 pillow==10.4.0 cairosvg==2.8.2 freetype-py==2.5.1 +jinja2==3.1.6 # esp-idf requires this, but doesn't bundle it by default # https://github.com/espressif/esp-idf/blob/220590d599e134d7a5e7f1e683cc4550349ffbf8/requirements.txt#L24 diff --git a/tests/components/modbus_controller/common.yaml b/tests/components/modbus_controller/common.yaml index 7fa9f8dae3..7d342ee353 100644 --- a/tests/components/modbus_controller/common.yaml +++ b/tests/components/modbus_controller/common.yaml @@ -33,7 +33,18 @@ modbus_controller: read_lambda: |- return 42.3; max_cmd_retries: 0 - + - id: modbus_controller3 + address: 0x3 + modbus_id: mod_bus2 + server_registers: + - address: 0x0009 + value_type: S_DWORD + read_lambda: |- + return 31; + write_lambda: |- + printf("address=%d, value=%d", x); + return true; + max_cmd_retries: 0 binary_sensor: - platform: modbus_controller modbus_controller_id: modbus_controller1 diff --git a/tests/unit_tests/fixtures/substitutions/.gitignore b/tests/unit_tests/fixtures/substitutions/.gitignore new file mode 100644 index 0000000000..0b15cdb2b7 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/.gitignore @@ -0,0 +1 @@ +*.received.yaml \ No newline at end of file diff --git a/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml b/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml new file mode 100644 index 0000000000..c031399c37 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml @@ -0,0 +1,19 @@ +substitutions: + var1: '1' + var2: '2' + var21: '79' +esphome: + name: test +test_list: + - '1' + - '1' + - '1' + - '1' + - 'Values: 1 2' + - 'Value: 79' + - 1 + 2 + - 1 * 2 + - 'Undefined var: ${undefined_var}' + - ${undefined_var} + - $undefined_var + - ${ undefined_var } diff --git a/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml b/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml new file mode 100644 index 0000000000..88a4ffb991 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml @@ -0,0 +1,21 @@ +esphome: + name: test + +substitutions: + var1: "1" + var2: "2" + var21: "79" + +test_list: + - "$var1" + - "${var1}" + - $var1 + - ${var1} + - "Values: $var1 ${var2}" + - "Value: ${var2${var1}}" + - "$var1 + $var2" + - "${ var1 } * ${ var2 }" + - "Undefined var: ${undefined_var}" + - ${undefined_var} + - $undefined_var + - ${ undefined_var } diff --git a/tests/unit_tests/fixtures/substitutions/01-include.approved.yaml b/tests/unit_tests/fixtures/substitutions/01-include.approved.yaml new file mode 100644 index 0000000000..a812fedcfd --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/01-include.approved.yaml @@ -0,0 +1,15 @@ +substitutions: + var1: '1' + var2: '2' + a: alpha +test_list: + - values: + - var1: '1' + - a: A + - b: B-default + - c: The value of C is C + - values: + - var1: '1' + - a: alpha + - b: beta + - c: The value of C is $c diff --git a/tests/unit_tests/fixtures/substitutions/01-include.input.yaml b/tests/unit_tests/fixtures/substitutions/01-include.input.yaml new file mode 100644 index 0000000000..d3daa681a4 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/01-include.input.yaml @@ -0,0 +1,15 @@ +substitutions: + var1: "1" + var2: "2" + a: "alpha" + +test_list: + - !include + file: inc1.yaml + vars: + a: "A" + c: "C" + - !include + file: inc1.yaml + vars: + b: "beta" diff --git a/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml b/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml new file mode 100644 index 0000000000..9e401ec5d6 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml @@ -0,0 +1,24 @@ +substitutions: + width: 7 + height: 8 + enabled: true + pin: &id001 + number: 18 + inverted: true + area: 25 + numberOne: 1 + var1: 79 +test_list: + - The area is 56 + - 56 + - 56 + 1 + - ENABLED + - list: + - 7 + - 8 + - width: 7 + height: 8 + - *id001 + - The pin number is 18 + - The square root is: 5.0 + - The number is 80 diff --git a/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml b/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml new file mode 100644 index 0000000000..1777b46f67 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml @@ -0,0 +1,22 @@ +substitutions: + width: 7 + height: 8 + enabled: true + pin: + number: 18 + inverted: true + area: 25 + numberOne: 1 + var1: 79 + +test_list: + - "The area is ${width * height}" + - ${width * height} + - ${width * height} + 1 + - ${enabled and "ENABLED" or "DISABLED"} + - list: ${ [width, height] } + - "${ {'width': width, 'height': height} }" + - ${pin} + - The pin number is ${pin.number} + - The square root is: ${math.sqrt(area)} + - The number is ${var${numberOne} + 1} diff --git a/tests/unit_tests/fixtures/substitutions/03-closures.approved.yaml b/tests/unit_tests/fixtures/substitutions/03-closures.approved.yaml new file mode 100644 index 0000000000..c8f7d9976c --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/03-closures.approved.yaml @@ -0,0 +1,17 @@ +substitutions: + B: 5 + var7: 79 +package_result: + - The value of A*B is 35, where A is a package var and B is a substitution in the + root file + - Double substitution also works; the value of var7 is 79, where A is a package + var +local_results: + - The value of B is 5 + - 'You will see, however, that + + ${A} is not substituted here, since + + it is out of scope. + + ' diff --git a/tests/unit_tests/fixtures/substitutions/03-closures.input.yaml b/tests/unit_tests/fixtures/substitutions/03-closures.input.yaml new file mode 100644 index 0000000000..e0b2c39e52 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/03-closures.input.yaml @@ -0,0 +1,16 @@ +substitutions: + B: 5 + var7: 79 + +packages: + closures_package: !include + file: closures_package.yaml + vars: + A: 7 + +local_results: + - The value of B is ${B} + - | + You will see, however, that + ${A} is not substituted here, since + it is out of scope. diff --git a/tests/unit_tests/fixtures/substitutions/04-display_example.approved.yaml b/tests/unit_tests/fixtures/substitutions/04-display_example.approved.yaml new file mode 100644 index 0000000000..f559181b45 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/04-display_example.approved.yaml @@ -0,0 +1,5 @@ +display: + - platform: ili9xxx + dimensions: + width: 960 + height: 544 diff --git a/tests/unit_tests/fixtures/substitutions/04-display_example.input.yaml b/tests/unit_tests/fixtures/substitutions/04-display_example.input.yaml new file mode 100644 index 0000000000..9d8f64a253 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/04-display_example.input.yaml @@ -0,0 +1,7 @@ +# main.yaml +packages: + my_display: !include + file: display.yaml + vars: + high_dpi: true + native_height: 272 diff --git a/tests/unit_tests/fixtures/substitutions/closures_package.yaml b/tests/unit_tests/fixtures/substitutions/closures_package.yaml new file mode 100644 index 0000000000..e87908814d --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/closures_package.yaml @@ -0,0 +1,3 @@ +package_result: + - The value of A*B is ${A * B}, where A is a package var and B is a substitution in the root file + - Double substitution also works; the value of var7 is ${var$A}, where A is a package var diff --git a/tests/unit_tests/fixtures/substitutions/display.yaml b/tests/unit_tests/fixtures/substitutions/display.yaml new file mode 100644 index 0000000000..1e2249dddb --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/display.yaml @@ -0,0 +1,11 @@ +# display.yaml + +defaults: + native_width: 480 + native_height: 480 + +display: + - platform: ili9xxx + dimensions: + width: ${high_dpi and native_width * 2 or native_width} + height: ${high_dpi and native_height * 2 or native_height} diff --git a/tests/unit_tests/fixtures/substitutions/inc1.yaml b/tests/unit_tests/fixtures/substitutions/inc1.yaml new file mode 100644 index 0000000000..65b91a5e16 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/inc1.yaml @@ -0,0 +1,8 @@ +defaults: + b: "B-default" + +values: + - var1: $var1 + - a: $a + - b: ${b} + - c: The value of C is $c diff --git a/tests/unit_tests/test_substitutions.py b/tests/unit_tests/test_substitutions.py new file mode 100644 index 0000000000..b377499d29 --- /dev/null +++ b/tests/unit_tests/test_substitutions.py @@ -0,0 +1,125 @@ +import glob +import logging +import os + +from esphome import yaml_util +from esphome.components import substitutions +from esphome.const import CONF_PACKAGES + +_LOGGER = logging.getLogger(__name__) + +# Set to True for dev mode behavior +# This will generate the expected version of the test files. + +DEV_MODE = False + + +def sort_dicts(obj): + """Recursively sort dictionaries for order-insensitive comparison.""" + if isinstance(obj, dict): + return {k: sort_dicts(obj[k]) for k in sorted(obj)} + elif isinstance(obj, list): + # Lists are not sorted; we preserve order + return [sort_dicts(i) for i in obj] + else: + return obj + + +def dict_diff(a, b, path=""): + """Recursively find differences between two dict/list structures.""" + diffs = [] + if isinstance(a, dict) and isinstance(b, dict): + a_keys = set(a) + b_keys = set(b) + for key in a_keys - b_keys: + diffs.append(f"{path}/{key} only in actual") + for key in b_keys - a_keys: + diffs.append(f"{path}/{key} only in expected") + for key in a_keys & b_keys: + diffs.extend(dict_diff(a[key], b[key], f"{path}/{key}")) + elif isinstance(a, list) and isinstance(b, list): + min_len = min(len(a), len(b)) + for i in range(min_len): + diffs.extend(dict_diff(a[i], b[i], f"{path}[{i}]")) + if len(a) > len(b): + for i in range(min_len, len(a)): + diffs.append(f"{path}[{i}] only in actual: {a[i]!r}") + elif len(b) > len(a): + for i in range(min_len, len(b)): + diffs.append(f"{path}[{i}] only in expected: {b[i]!r}") + else: + if a != b: + diffs.append(f"\t{path}: actual={a!r} expected={b!r}") + return diffs + + +def write_yaml(path, data): + with open(path, "w", encoding="utf-8") as f: + f.write(yaml_util.dump(data)) + + +def test_substitutions_fixtures(fixture_path): + base_dir = fixture_path / "substitutions" + sources = sorted(glob.glob(str(base_dir / "*.input.yaml"))) + assert sources, f"No input YAML files found in {base_dir}" + + failures = [] + for source_path in sources: + try: + expected_path = source_path.replace(".input.yaml", ".approved.yaml") + test_case = os.path.splitext(os.path.basename(source_path))[0].replace( + ".input", "" + ) + + # Load using ESPHome's YAML loader + config = yaml_util.load_yaml(source_path) + + if CONF_PACKAGES in config: + from esphome.components.packages import do_packages_pass + + config = do_packages_pass(config) + + substitutions.do_substitution_pass(config, None) + + # Also load expected using ESPHome's loader, or use {} if missing and DEV_MODE + if os.path.isfile(expected_path): + expected = yaml_util.load_yaml(expected_path) + elif DEV_MODE: + expected = {} + else: + assert os.path.isfile(expected_path), ( + f"Expected file missing: {expected_path}" + ) + + # Sort dicts only (not lists) for comparison + got_sorted = sort_dicts(config) + expected_sorted = sort_dicts(expected) + + if got_sorted != expected_sorted: + diff = "\n".join(dict_diff(got_sorted, expected_sorted)) + msg = ( + f"Substitution result mismatch for {os.path.basename(source_path)}\n" + f"Diff:\n{diff}\n\n" + f"Got: {got_sorted}\n" + f"Expected: {expected_sorted}" + ) + # Write out the received file when test fails + if DEV_MODE: + received_path = os.path.join( + os.path.dirname(source_path), f"{test_case}.received.yaml" + ) + write_yaml(received_path, config) + print(msg) + failures.append(msg) + else: + raise AssertionError(msg) + except Exception as err: + _LOGGER.error("Error in test file %s", source_path) + raise err + + if DEV_MODE and failures: + print(f"\n{len(failures)} substitution test case(s) failed.") + + if DEV_MODE: + _LOGGER.error("Tests passed, but Dev mode is enabled.") + assert not DEV_MODE # make sure DEV_MODE is disabled after you are finished.