From 9a0ba1657e0140b90a03aa9b1dcc1af4a4c61ff2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 30 Jun 2025 21:38:19 -0500 Subject: [PATCH 1/8] Fix entity hash collisions by enforcing unique names across devices per platform (#9276) --- esphome/core/entity_helpers.py | 22 +-- ...ies_not_allowed_on_different_devices.yaml} | 80 ++++++---- tests/integration/test_duplicate_entities.py | 137 +++++++++++------- tests/unit_tests/core/test_entity_helpers.py | 40 +++-- 4 files changed, 175 insertions(+), 104 deletions(-) rename tests/integration/fixtures/{duplicate_entities_on_different_devices.yaml => duplicate_entities_not_allowed_on_different_devices.yaml} (61%) diff --git a/esphome/core/entity_helpers.py b/esphome/core/entity_helpers.py index c95acebbf9..2442fbca4b 100644 --- a/esphome/core/entity_helpers.py +++ b/esphome/core/entity_helpers.py @@ -184,25 +184,27 @@ def entity_duplicate_validator(platform: str) -> Callable[[ConfigType], ConfigTy # No name to validate return config - # Get the entity name and device info + # Get the entity name entity_name = config[CONF_NAME] - device_id = "" # Empty string for main device + # Get device name if entity is on a sub-device + device_name = None if CONF_DEVICE_ID in config: device_id_obj = config[CONF_DEVICE_ID] - # Use the device ID string directly for uniqueness - device_id = device_id_obj.id + device_name = device_id_obj.id - # For duplicate detection, just use the sanitized name - name_key = sanitize(snake_case(entity_name)) + # Calculate what object_id will actually be used + # This handles empty names correctly by using device/friendly names + name_key = get_base_entity_object_id( + entity_name, CORE.friendly_name, device_name + ) # Check for duplicates - unique_key = (device_id, platform, name_key) + unique_key = (platform, name_key) if unique_key in CORE.unique_ids: - device_prefix = f" on device '{device_id}'" if device_id else "" raise cv.Invalid( - f"Duplicate {platform} entity with name '{entity_name}' found{device_prefix}. " - f"Each entity on a device must have a unique name within its platform." + f"Duplicate {platform} entity with name '{entity_name}' found. " + f"Each entity must have a unique name within its platform across all devices." ) # Add to tracking set diff --git a/tests/integration/fixtures/duplicate_entities_on_different_devices.yaml b/tests/integration/fixtures/duplicate_entities_not_allowed_on_different_devices.yaml similarity index 61% rename from tests/integration/fixtures/duplicate_entities_on_different_devices.yaml rename to tests/integration/fixtures/duplicate_entities_not_allowed_on_different_devices.yaml index ecc502ad28..f7d017a0ae 100644 --- a/tests/integration/fixtures/duplicate_entities_on_different_devices.yaml +++ b/tests/integration/fixtures/duplicate_entities_not_allowed_on_different_devices.yaml @@ -1,6 +1,6 @@ esphome: name: duplicate-entities-test - # Define devices to test multi-device duplicate handling + # Define devices to test multi-device unique name validation devices: - id: controller_1 name: Controller 1 @@ -13,31 +13,31 @@ host: api: # Port will be automatically injected logger: -# Test that duplicate entity names are allowed on different devices +# Test that duplicate entity names are NOT allowed on different devices -# Scenario 1: Same sensor name on different devices (allowed) +# Scenario 1: Different sensor names on different devices (allowed) sensor: - platform: template - name: Temperature + name: Temperature Controller 1 device_id: controller_1 lambda: return 21.0; update_interval: 0.1s - platform: template - name: Temperature + name: Temperature Controller 2 device_id: controller_2 lambda: return 22.0; update_interval: 0.1s - platform: template - name: Temperature + name: Temperature Controller 3 device_id: controller_3 lambda: return 23.0; update_interval: 0.1s # Main device sensor (no device_id) - platform: template - name: Temperature + name: Temperature Main lambda: return 20.0; update_interval: 0.1s @@ -47,20 +47,20 @@ sensor: lambda: return 60.0; update_interval: 0.1s -# Scenario 2: Same binary sensor name on different devices (allowed) +# Scenario 2: Different binary sensor names on different devices binary_sensor: - platform: template - name: Status + name: Status Controller 1 device_id: controller_1 lambda: return true; - platform: template - name: Status + name: Status Controller 2 device_id: controller_2 lambda: return false; - platform: template - name: Status + name: Status Main lambda: return true; # Main device # Different platform can have same name as sensor @@ -68,43 +68,43 @@ binary_sensor: name: Temperature lambda: return true; -# Scenario 3: Same text sensor name on different devices +# Scenario 3: Different text sensor names on different devices text_sensor: - platform: template - name: Device Info + name: Device Info Controller 1 device_id: controller_1 lambda: return {"Controller 1 Active"}; update_interval: 0.1s - platform: template - name: Device Info + name: Device Info Controller 2 device_id: controller_2 lambda: return {"Controller 2 Active"}; update_interval: 0.1s - platform: template - name: Device Info + name: Device Info Main lambda: return {"Main Device Active"}; update_interval: 0.1s -# Scenario 4: Same switch name on different devices +# Scenario 4: Different switch names on different devices switch: - platform: template - name: Power + name: Power Controller 1 device_id: controller_1 lambda: return false; turn_on_action: [] turn_off_action: [] - platform: template - name: Power + name: Power Controller 2 device_id: controller_2 lambda: return true; turn_on_action: [] turn_off_action: [] - platform: template - name: Power + name: Power Controller 3 device_id: controller_3 lambda: return false; turn_on_action: [] @@ -117,26 +117,54 @@ switch: turn_on_action: [] turn_off_action: [] -# Scenario 5: Empty names on different devices (should use device name) +# Scenario 5: Buttons with unique names button: - platform: template - name: "" + name: "Reset Controller 1" device_id: controller_1 on_press: [] - platform: template - name: "" + name: "Reset Controller 2" device_id: controller_2 on_press: [] - platform: template - name: "" + name: "Reset Main" on_press: [] # Main device -# Scenario 6: Special characters in names +# Scenario 6: Empty names (should use device names) +select: + - platform: template + name: "" + device_id: controller_1 + options: + - "Option 1" + - "Option 2" + lambda: return {"Option 1"}; + set_action: [] + + - platform: template + name: "" + device_id: controller_2 + options: + - "Option 1" + - "Option 2" + lambda: return {"Option 1"}; + set_action: [] + + - platform: template + name: "" # Main device + options: + - "Option 1" + - "Option 2" + lambda: return {"Option 1"}; + set_action: [] + +# Scenario 7: Special characters in names - now with unique names number: - platform: template - name: "Temperature Setpoint!" + name: "Temperature Setpoint! Controller 1" device_id: controller_1 min_value: 10.0 max_value: 30.0 @@ -145,7 +173,7 @@ number: set_action: [] - platform: template - name: "Temperature Setpoint!" + name: "Temperature Setpoint! Controller 2" device_id: controller_2 min_value: 10.0 max_value: 30.0 diff --git a/tests/integration/test_duplicate_entities.py b/tests/integration/test_duplicate_entities.py index 99968204d4..b7ee8dd478 100644 --- a/tests/integration/test_duplicate_entities.py +++ b/tests/integration/test_duplicate_entities.py @@ -11,12 +11,12 @@ from .types import APIClientConnectedFactory, RunCompiledFunction @pytest.mark.asyncio -async def test_duplicate_entities_on_different_devices( +async def test_duplicate_entities_not_allowed_on_different_devices( yaml_config: str, run_compiled: RunCompiledFunction, api_client_connected: APIClientConnectedFactory, ) -> None: - """Test that duplicate entity names are allowed on different devices.""" + """Test that duplicate entity names are NOT allowed on different devices.""" async with run_compiled(yaml_config), api_client_connected() as client: # Get device info device_info = await client.device_info() @@ -52,42 +52,46 @@ async def test_duplicate_entities_on_different_devices( switches = [e for e in all_entities if e.__class__.__name__ == "SwitchInfo"] buttons = [e for e in all_entities if e.__class__.__name__ == "ButtonInfo"] numbers = [e for e in all_entities if e.__class__.__name__ == "NumberInfo"] + selects = [e for e in all_entities if e.__class__.__name__ == "SelectInfo"] - # Scenario 1: Check sensors with same "Temperature" name on different devices - temp_sensors = [s for s in sensors if s.name == "Temperature"] + # Scenario 1: Check that temperature sensors have unique names per device + temp_sensors = [s for s in sensors if "Temperature" in s.name] assert len(temp_sensors) == 4, ( f"Expected exactly 4 temperature sensors, got {len(temp_sensors)}" ) - # Verify each sensor is on a different device - temp_device_ids = set() + # Verify each sensor has a unique name + temp_names = set() temp_object_ids = set() for sensor in temp_sensors: - temp_device_ids.add(sensor.device_id) + temp_names.add(sensor.name) temp_object_ids.add(sensor.object_id) - # All should have object_id "temperature" (no suffix) - assert sensor.object_id == "temperature", ( - f"Expected object_id 'temperature', got '{sensor.object_id}'" - ) - - # Should have 4 different device IDs (including None for main device) - assert len(temp_device_ids) == 4, ( - f"Temperature sensors should be on different devices, got {temp_device_ids}" + # Should have 4 unique names + assert len(temp_names) == 4, ( + f"Temperature sensors should have unique names, got {temp_names}" ) - # Scenario 2: Check binary sensors "Status" on different devices - status_binary = [b for b in binary_sensors if b.name == "Status"] + # Object IDs should also be unique + assert len(temp_object_ids) == 4, ( + f"Temperature sensors should have unique object_ids, got {temp_object_ids}" + ) + + # Scenario 2: Check binary sensors have unique names + status_binary = [b for b in binary_sensors if "Status" in b.name] assert len(status_binary) == 3, ( f"Expected exactly 3 status binary sensors, got {len(status_binary)}" ) - # All should have object_id "status" + # All should have unique object_ids + status_names = set() for binary in status_binary: - assert binary.object_id == "status", ( - f"Expected object_id 'status', got '{binary.object_id}'" - ) + status_names.add(binary.name) + + assert len(status_names) == 3, ( + f"Status binary sensors should have unique names, got {status_names}" + ) # Scenario 3: Check that sensor and binary_sensor can have same name temp_binary = [b for b in binary_sensors if b.name == "Temperature"] @@ -96,62 +100,86 @@ async def test_duplicate_entities_on_different_devices( ) assert temp_binary[0].object_id == "temperature" - # Scenario 4: Check text sensors "Device Info" on different devices - info_text = [t for t in text_sensors if t.name == "Device Info"] + # Scenario 4: Check text sensors have unique names + info_text = [t for t in text_sensors if "Device Info" in t.name] assert len(info_text) == 3, ( f"Expected exactly 3 device info text sensors, got {len(info_text)}" ) - # All should have object_id "device_info" + # All should have unique names and object_ids + info_names = set() for text in info_text: - assert text.object_id == "device_info", ( - f"Expected object_id 'device_info', got '{text.object_id}'" - ) + info_names.add(text.name) - # Scenario 5: Check switches "Power" on different devices - power_switches = [s for s in switches if s.name == "Power"] - assert len(power_switches) == 3, ( - f"Expected exactly 3 power switches, got {len(power_switches)}" + assert len(info_names) == 3, ( + f"Device info text sensors should have unique names, got {info_names}" ) - # All should have object_id "power" - for switch in power_switches: - assert switch.object_id == "power", ( - f"Expected object_id 'power', got '{switch.object_id}'" - ) + # Scenario 5: Check switches have unique names + power_switches = [s for s in switches if "Power" in s.name] + assert len(power_switches) == 4, ( + f"Expected exactly 4 power switches, got {len(power_switches)}" + ) - # Scenario 6: Check empty name buttons (should use device name) - empty_buttons = [b for b in buttons if b.name == ""] - assert len(empty_buttons) == 3, ( - f"Expected exactly 3 empty name buttons, got {len(empty_buttons)}" + # All should have unique names + power_names = set() + for switch in power_switches: + power_names.add(switch.name) + + assert len(power_names) == 4, ( + f"Power switches should have unique names, got {power_names}" + ) + + # Scenario 6: Check reset buttons have unique names + reset_buttons = [b for b in buttons if "Reset" in b.name] + assert len(reset_buttons) == 3, ( + f"Expected exactly 3 reset buttons, got {len(reset_buttons)}" + ) + + # All should have unique names + reset_names = set() + for button in reset_buttons: + reset_names.add(button.name) + + assert len(reset_names) == 3, ( + f"Reset buttons should have unique names, got {reset_names}" + ) + + # Scenario 7: Check empty name selects (should use device names) + empty_selects = [s for s in selects if s.name == ""] + assert len(empty_selects) == 3, ( + f"Expected exactly 3 empty name selects, got {len(empty_selects)}" ) # Group by device - c1_buttons = [b for b in empty_buttons if b.device_id == controller_1.device_id] - c2_buttons = [b for b in empty_buttons if b.device_id == controller_2.device_id] + c1_selects = [s for s in empty_selects if s.device_id == controller_1.device_id] + c2_selects = [s for s in empty_selects if s.device_id == controller_2.device_id] # For main device, device_id is 0 - main_buttons = [b for b in empty_buttons if b.device_id == 0] + main_selects = [s for s in empty_selects if s.device_id == 0] - # Check object IDs for empty name entities - assert len(c1_buttons) == 1 and c1_buttons[0].object_id == "controller_1" - assert len(c2_buttons) == 1 and c2_buttons[0].object_id == "controller_2" + # Check object IDs for empty name entities - they should use device names + assert len(c1_selects) == 1 and c1_selects[0].object_id == "controller_1" + assert len(c2_selects) == 1 and c2_selects[0].object_id == "controller_2" assert ( - len(main_buttons) == 1 - and main_buttons[0].object_id == "duplicate-entities-test" + len(main_selects) == 1 + and main_selects[0].object_id == "duplicate-entities-test" ) - # Scenario 7: Check special characters in number names - temp_numbers = [n for n in numbers if n.name == "Temperature Setpoint!"] + # Scenario 8: Check special characters in number names - now unique + temp_numbers = [n for n in numbers if "Temperature Setpoint!" in n.name] assert len(temp_numbers) == 2, ( f"Expected exactly 2 temperature setpoint numbers, got {len(temp_numbers)}" ) - # Special characters should be sanitized to _ in object_id + # Should have unique names + setpoint_names = set() for number in temp_numbers: - assert number.object_id == "temperature_setpoint_", ( - f"Expected object_id 'temperature_setpoint_', got '{number.object_id}'" - ) + setpoint_names.add(number.name) + + assert len(setpoint_names) == 2, ( + f"Temperature setpoint numbers should have unique names, got {setpoint_names}" + ) # Verify we can get states for all entities (ensures they're functional) loop = asyncio.get_running_loop() @@ -164,6 +192,7 @@ async def test_duplicate_entities_on_different_devices( + len(switches) + len(buttons) + len(numbers) + + len(selects) ) def on_state(state) -> None: diff --git a/tests/unit_tests/core/test_entity_helpers.py b/tests/unit_tests/core/test_entity_helpers.py index e166eeedee..0dcdd84507 100644 --- a/tests/unit_tests/core/test_entity_helpers.py +++ b/tests/unit_tests/core/test_entity_helpers.py @@ -505,13 +505,13 @@ def test_entity_duplicate_validator() -> None: config1 = {CONF_NAME: "Temperature"} validated1 = validator(config1) assert validated1 == config1 - assert ("", "sensor", "temperature") in CORE.unique_ids + assert ("sensor", "temperature") in CORE.unique_ids # Second entity with different name should pass config2 = {CONF_NAME: "Humidity"} validated2 = validator(config2) assert validated2 == config2 - assert ("", "sensor", "humidity") in CORE.unique_ids + assert ("sensor", "humidity") in CORE.unique_ids # Duplicate entity should fail config3 = {CONF_NAME: "Temperature"} @@ -535,24 +535,36 @@ def test_entity_duplicate_validator_with_devices() -> None: device1 = ID("device1", type="Device") device2 = ID("device2", type="Device") - # Same name on different devices should pass + # First entity on device1 should pass config1 = {CONF_NAME: "Temperature", CONF_DEVICE_ID: device1} validated1 = validator(config1) assert validated1 == config1 - assert ("device1", "sensor", "temperature") in CORE.unique_ids + assert ("sensor", "temperature") in CORE.unique_ids + # Same name on different device should now fail config2 = {CONF_NAME: "Temperature", CONF_DEVICE_ID: device2} - validated2 = validator(config2) - assert validated2 == config2 - assert ("device2", "sensor", "temperature") in CORE.unique_ids - - # Duplicate on same device should fail - config3 = {CONF_NAME: "Temperature", CONF_DEVICE_ID: device1} with pytest.raises( Invalid, - match=r"Duplicate sensor entity with name 'Temperature' found on device 'device1'", + match=r"Duplicate sensor entity with name 'Temperature' found. Each entity must have a unique name within its platform across all devices.", ): - validator(config3) + validator(config2) + + # Different name on device2 should pass + config3 = {CONF_NAME: "Humidity", CONF_DEVICE_ID: device2} + validated3 = validator(config3) + assert validated3 == config3 + assert ("sensor", "humidity") in CORE.unique_ids + + # Empty names should use device names and be allowed + config4 = {CONF_NAME: "", CONF_DEVICE_ID: device1} + validated4 = validator(config4) + assert validated4 == config4 + assert ("sensor", "device1") in CORE.unique_ids + + config5 = {CONF_NAME: "", CONF_DEVICE_ID: device2} + validated5 = validator(config5) + assert validated5 == config5 + assert ("sensor", "device2") in CORE.unique_ids def test_duplicate_entity_yaml_validation( @@ -576,10 +588,10 @@ def test_duplicate_entity_with_devices_yaml_validation( ) assert result is None - # Check for the duplicate entity error message with device + # Check for the duplicate entity error message captured = capsys.readouterr() assert ( - "Duplicate sensor entity with name 'Temperature' found on device 'device1'" + "Duplicate sensor entity with name 'Temperature' found. Each entity must have a unique name within its platform across all devices." in captured.out ) From 27c745d5a10d7b7768619ceb9143bf61ac06ae5f Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Tue, 1 Jul 2025 14:38:39 +1200 Subject: [PATCH 2/8] [host] Disable platformio ldf (#9277) --- esphome/components/host/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/esphome/components/host/__init__.py b/esphome/components/host/__init__.py index da75873eaf..d3dbcba6ed 100644 --- a/esphome/components/host/__init__.py +++ b/esphome/components/host/__init__.py @@ -44,3 +44,4 @@ async def to_code(config): cg.add_build_flag("-std=gnu++20") cg.add_define("ESPHOME_BOARD", "host") cg.add_platformio_option("platform", "platformio/native") + cg.add_platformio_option("lib_ldf_mode", "off") From 8c34b72b62f69d850a54b0ec511f8b7638352776 Mon Sep 17 00:00:00 2001 From: Javier Peletier Date: Tue, 1 Jul 2025 04:57:00 +0200 Subject: [PATCH 3/8] Jinja expressions in configs (Take #3) (#8955) --- esphome/components/packages/__init__.py | 3 +- esphome/components/substitutions/__init__.py | 95 ++++++++++--- esphome/components/substitutions/jinja.py | 99 ++++++++++++++ esphome/config.py | 1 - esphome/yaml_util.py | 2 - requirements.txt | 1 + .../fixtures/substitutions/.gitignore | 1 + .../substitutions/00-simple_var.approved.yaml | 19 +++ .../substitutions/00-simple_var.input.yaml | 21 +++ .../substitutions/01-include.approved.yaml | 15 +++ .../substitutions/01-include.input.yaml | 15 +++ .../02-expressions.approved.yaml | 24 ++++ .../substitutions/02-expressions.input.yaml | 22 +++ .../substitutions/03-closures.approved.yaml | 17 +++ .../substitutions/03-closures.input.yaml | 16 +++ .../04-display_example.approved.yaml | 5 + .../04-display_example.input.yaml | 7 + .../substitutions/closures_package.yaml | 3 + .../fixtures/substitutions/display.yaml | 11 ++ .../fixtures/substitutions/inc1.yaml | 8 ++ tests/unit_tests/test_substitutions.py | 125 ++++++++++++++++++ 21 files changed, 486 insertions(+), 24 deletions(-) create mode 100644 esphome/components/substitutions/jinja.py create mode 100644 tests/unit_tests/fixtures/substitutions/.gitignore create mode 100644 tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/01-include.approved.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/01-include.input.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/03-closures.approved.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/03-closures.input.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/04-display_example.approved.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/04-display_example.input.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/closures_package.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/display.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/inc1.yaml create mode 100644 tests/unit_tests/test_substitutions.py diff --git a/esphome/components/packages/__init__.py b/esphome/components/packages/__init__.py index 08ae798282..6eb746ec63 100644 --- a/esphome/components/packages/__init__.py +++ b/esphome/components/packages/__init__.py @@ -74,7 +74,7 @@ BASE_SCHEMA = cv.All( { cv.Required(CONF_PATH): validate_yaml_filename, cv.Optional(CONF_VARS, default={}): cv.Schema( - {cv.string: cv.string} + {cv.string: object} ), } ), @@ -148,7 +148,6 @@ def _process_base_package(config: dict) -> dict: raise cv.Invalid( f"Current ESPHome Version is too old to use this package: {ESPHOME_VERSION} < {min_version}" ) - vars = {k: str(v) for k, v in vars.items()} new_yaml = yaml_util.substitute_vars(new_yaml, vars) packages[f"{filename}{idx}"] = new_yaml except EsphomeError as e: diff --git a/esphome/components/substitutions/__init__.py b/esphome/components/substitutions/__init__.py index 41e49f70db..5878af43b2 100644 --- a/esphome/components/substitutions/__init__.py +++ b/esphome/components/substitutions/__init__.py @@ -5,6 +5,13 @@ from esphome.config_helpers import Extend, Remove, merge_config import esphome.config_validation as cv from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS from esphome.yaml_util import ESPHomeDataBase, make_data_base +from .jinja import ( + Jinja, + JinjaStr, + has_jinja, + TemplateError, + TemplateRuntimeError, +) CODEOWNERS = ["@esphome/core"] _LOGGER = logging.getLogger(__name__) @@ -28,7 +35,7 @@ def validate_substitution_key(value): CONFIG_SCHEMA = cv.Schema( { - validate_substitution_key: cv.string_strict, + validate_substitution_key: object, } ) @@ -37,7 +44,42 @@ async def to_code(config): pass -def _expand_substitutions(substitutions, value, path, ignore_missing): +def _expand_jinja(value, orig_value, path, jinja, ignore_missing): + if has_jinja(value): + # If the original value passed in to this function is a JinjaStr, it means it contains an unresolved + # Jinja expression from a previous pass. + if isinstance(orig_value, JinjaStr): + # Rebuild the JinjaStr in case it was lost while replacing substitutions. + value = JinjaStr(value, orig_value.upvalues) + try: + # Invoke the jinja engine to evaluate the expression. + value, err = jinja.expand(value) + if err is not None: + if not ignore_missing and "password" not in path: + _LOGGER.warning( + "Found '%s' (see %s) which looks like an expression," + " but could not resolve all the variables: %s", + value, + "->".join(str(x) for x in path), + err.message, + ) + except ( + TemplateError, + TemplateRuntimeError, + RuntimeError, + ArithmeticError, + AttributeError, + TypeError, + ) as err: + raise cv.Invalid( + f"{type(err).__name__} Error evaluating jinja expression '{value}': {str(err)}." + f" See {'->'.join(str(x) for x in path)}", + path, + ) + return value + + +def _expand_substitutions(substitutions, value, path, jinja, ignore_missing): if "$" not in value: return value @@ -47,7 +89,8 @@ def _expand_substitutions(substitutions, value, path, ignore_missing): while True: m = cv.VARIABLE_PROG.search(value, i) if not m: - # Nothing more to match. Done + # No more variable substitutions found. See if the remainder looks like a jinja template + value = _expand_jinja(value, orig_value, path, jinja, ignore_missing) break i, j = m.span(0) @@ -67,8 +110,15 @@ def _expand_substitutions(substitutions, value, path, ignore_missing): continue sub = substitutions[name] + + if i == 0 and j == len(value): + # The variable spans the whole expression, e.g., "${varName}". Return its resolved value directly + # to conserve its type. + value = sub + break + tail = value[j:] - value = value[:i] + sub + value = value[:i] + str(sub) i = len(value) value += tail @@ -77,36 +127,40 @@ def _expand_substitutions(substitutions, value, path, ignore_missing): if isinstance(orig_value, ESPHomeDataBase): # even though string can get larger or smaller, the range should point # to original document marks - return make_data_base(value, orig_value) + value = make_data_base(value, orig_value) return value -def _substitute_item(substitutions, item, path, ignore_missing): +def _substitute_item(substitutions, item, path, jinja, ignore_missing): if isinstance(item, list): for i, it in enumerate(item): - sub = _substitute_item(substitutions, it, path + [i], ignore_missing) + sub = _substitute_item(substitutions, it, path + [i], jinja, ignore_missing) if sub is not None: item[i] = sub elif isinstance(item, dict): replace_keys = [] for k, v in item.items(): if path or k != CONF_SUBSTITUTIONS: - sub = _substitute_item(substitutions, k, path + [k], ignore_missing) + sub = _substitute_item( + substitutions, k, path + [k], jinja, ignore_missing + ) if sub is not None: replace_keys.append((k, sub)) - sub = _substitute_item(substitutions, v, path + [k], ignore_missing) + sub = _substitute_item(substitutions, v, path + [k], jinja, ignore_missing) if sub is not None: item[k] = sub for old, new in replace_keys: item[new] = merge_config(item.get(old), item.get(new)) del item[old] elif isinstance(item, str): - sub = _expand_substitutions(substitutions, item, path, ignore_missing) - if sub != item: + sub = _expand_substitutions(substitutions, item, path, jinja, ignore_missing) + if isinstance(sub, JinjaStr) or sub != item: return sub elif isinstance(item, (core.Lambda, Extend, Remove)): - sub = _expand_substitutions(substitutions, item.value, path, ignore_missing) + sub = _expand_substitutions( + substitutions, item.value, path, jinja, ignore_missing + ) if sub != item: item.value = sub return None @@ -116,11 +170,11 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals if CONF_SUBSTITUTIONS not in config and not command_line_substitutions: return - substitutions = config.get(CONF_SUBSTITUTIONS) - if substitutions is None: - substitutions = command_line_substitutions - elif command_line_substitutions: - substitutions = {**substitutions, **command_line_substitutions} + # Merge substitutions in config, overriding with substitutions coming from command line: + substitutions = { + **config.get(CONF_SUBSTITUTIONS, {}), + **(command_line_substitutions or {}), + } with cv.prepend_path("substitutions"): if not isinstance(substitutions, dict): raise cv.Invalid( @@ -133,7 +187,7 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals sub = validate_substitution_key(key) if sub != key: replace_keys.append((key, sub)) - substitutions[key] = cv.string_strict(value) + substitutions[key] = value for old, new in replace_keys: substitutions[new] = substitutions[old] del substitutions[old] @@ -141,4 +195,7 @@ def do_substitution_pass(config, command_line_substitutions, ignore_missing=Fals config[CONF_SUBSTITUTIONS] = substitutions # Move substitutions to the first place to replace substitutions in them correctly config.move_to_end(CONF_SUBSTITUTIONS, False) - _substitute_item(substitutions, config, [], ignore_missing) + + # Create a Jinja environment that will consider substitutions in scope: + jinja = Jinja(substitutions) + _substitute_item(substitutions, config, [], jinja, ignore_missing) diff --git a/esphome/components/substitutions/jinja.py b/esphome/components/substitutions/jinja.py new file mode 100644 index 0000000000..9ecdbab844 --- /dev/null +++ b/esphome/components/substitutions/jinja.py @@ -0,0 +1,99 @@ +import logging +import math +import re +import jinja2 as jinja +from jinja2.nativetypes import NativeEnvironment + +TemplateError = jinja.TemplateError +TemplateSyntaxError = jinja.TemplateSyntaxError +TemplateRuntimeError = jinja.TemplateRuntimeError +UndefinedError = jinja.UndefinedError +Undefined = jinja.Undefined + +_LOGGER = logging.getLogger(__name__) + +DETECT_JINJA = r"(\$\{)" +detect_jinja_re = re.compile( + r"<%.+?%>" # Block form expression: <% ... %> + r"|\$\{[^}]+\}", # Braced form expression: ${ ... } + flags=re.MULTILINE, +) + + +def has_jinja(st): + return detect_jinja_re.search(st) is not None + + +class JinjaStr(str): + """ + Wraps a string containing an unresolved Jinja expression, + storing the variables visible to it when it failed to resolve. + For example, an expression inside a package, `${ A * B }` may fail + to resolve at package parsing time if `A` is a local package var + but `B` is a substitution defined in the root yaml. + Therefore, we store the value of `A` as an upvalue bound + to the original string so we may be able to resolve `${ A * B }` + later in the main substitutions pass. + """ + + def __new__(cls, value: str, upvalues=None): + obj = super().__new__(cls, value) + obj.upvalues = upvalues or {} + return obj + + def __init__(self, value: str, upvalues=None): + self.upvalues = upvalues or {} + + +class Jinja: + """ + Wraps a Jinja environment + """ + + def __init__(self, context_vars): + self.env = NativeEnvironment( + trim_blocks=True, + lstrip_blocks=True, + block_start_string="<%", + block_end_string="%>", + line_statement_prefix="#", + line_comment_prefix="##", + variable_start_string="${", + variable_end_string="}", + undefined=jinja.StrictUndefined, + ) + self.env.add_extension("jinja2.ext.do") + self.env.globals["math"] = math # Inject entire math module + self.context_vars = {**context_vars} + self.env.globals = {**self.env.globals, **self.context_vars} + + def expand(self, content_str): + """ + Renders a string that may contain Jinja expressions or statements + Returns the resulting processed string if all values could be resolved. + Otherwise, it returns a tagged (JinjaStr) string that captures variables + in scope (upvalues), like a closure for later evaluation. + """ + result = None + override_vars = {} + if isinstance(content_str, JinjaStr): + # If `value` is already a JinjaStr, it means we are trying to evaluate it again + # in a parent pass. + # Hopefully, all required variables are visible now. + override_vars = content_str.upvalues + try: + template = self.env.from_string(content_str) + result = template.render(override_vars) + if isinstance(result, Undefined): + # This happens when the expression is simply an undefined variable. Jinja does not + # raise an exception, instead we get "Undefined". + # Trigger an UndefinedError exception so we skip to below. + print("" + result) + except (TemplateSyntaxError, UndefinedError) as err: + # `content_str` contains a Jinja expression that refers to a variable that is undefined + # in this scope. Perhaps it refers to a root substitution that is not visible yet. + # Therefore, return the original `content_str` as a JinjaStr, which contains the variables + # that are actually visible to it at this point to postpone evaluation. + return JinjaStr(content_str, {**self.context_vars, **override_vars}), err + + return result, None diff --git a/esphome/config.py b/esphome/config.py index ca3686a0e6..73cc7657cc 100644 --- a/esphome/config.py +++ b/esphome/config.py @@ -789,7 +789,6 @@ def validate_config( result.add_output_path([CONF_SUBSTITUTIONS], CONF_SUBSTITUTIONS) try: substitutions.do_substitution_pass(config, command_line_substitutions) - substitutions.do_substitution_pass(config, command_line_substitutions) except vol.Invalid as err: result.add_error(err) return result diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index bd1806affc..e52fc9e788 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -292,8 +292,6 @@ class ESPHomeLoaderMixin: if file is None: raise yaml.MarkedYAMLError("Must include 'file'", node.start_mark) vars = fields.get(CONF_VARS) - if vars: - vars = {k: str(v) for k, v in vars.items()} return file, vars if isinstance(node, yaml.nodes.MappingNode): diff --git a/requirements.txt b/requirements.txt index 12f3b84359..1010a311d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,7 @@ esphome-glyphsets==0.2.0 pillow==10.4.0 cairosvg==2.8.2 freetype-py==2.5.1 +jinja2==3.1.6 # esp-idf requires this, but doesn't bundle it by default # https://github.com/espressif/esp-idf/blob/220590d599e134d7a5e7f1e683cc4550349ffbf8/requirements.txt#L24 diff --git a/tests/unit_tests/fixtures/substitutions/.gitignore b/tests/unit_tests/fixtures/substitutions/.gitignore new file mode 100644 index 0000000000..0b15cdb2b7 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/.gitignore @@ -0,0 +1 @@ +*.received.yaml \ No newline at end of file diff --git a/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml b/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml new file mode 100644 index 0000000000..c031399c37 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml @@ -0,0 +1,19 @@ +substitutions: + var1: '1' + var2: '2' + var21: '79' +esphome: + name: test +test_list: + - '1' + - '1' + - '1' + - '1' + - 'Values: 1 2' + - 'Value: 79' + - 1 + 2 + - 1 * 2 + - 'Undefined var: ${undefined_var}' + - ${undefined_var} + - $undefined_var + - ${ undefined_var } diff --git a/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml b/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml new file mode 100644 index 0000000000..88a4ffb991 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml @@ -0,0 +1,21 @@ +esphome: + name: test + +substitutions: + var1: "1" + var2: "2" + var21: "79" + +test_list: + - "$var1" + - "${var1}" + - $var1 + - ${var1} + - "Values: $var1 ${var2}" + - "Value: ${var2${var1}}" + - "$var1 + $var2" + - "${ var1 } * ${ var2 }" + - "Undefined var: ${undefined_var}" + - ${undefined_var} + - $undefined_var + - ${ undefined_var } diff --git a/tests/unit_tests/fixtures/substitutions/01-include.approved.yaml b/tests/unit_tests/fixtures/substitutions/01-include.approved.yaml new file mode 100644 index 0000000000..a812fedcfd --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/01-include.approved.yaml @@ -0,0 +1,15 @@ +substitutions: + var1: '1' + var2: '2' + a: alpha +test_list: + - values: + - var1: '1' + - a: A + - b: B-default + - c: The value of C is C + - values: + - var1: '1' + - a: alpha + - b: beta + - c: The value of C is $c diff --git a/tests/unit_tests/fixtures/substitutions/01-include.input.yaml b/tests/unit_tests/fixtures/substitutions/01-include.input.yaml new file mode 100644 index 0000000000..d3daa681a4 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/01-include.input.yaml @@ -0,0 +1,15 @@ +substitutions: + var1: "1" + var2: "2" + a: "alpha" + +test_list: + - !include + file: inc1.yaml + vars: + a: "A" + c: "C" + - !include + file: inc1.yaml + vars: + b: "beta" diff --git a/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml b/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml new file mode 100644 index 0000000000..9e401ec5d6 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml @@ -0,0 +1,24 @@ +substitutions: + width: 7 + height: 8 + enabled: true + pin: &id001 + number: 18 + inverted: true + area: 25 + numberOne: 1 + var1: 79 +test_list: + - The area is 56 + - 56 + - 56 + 1 + - ENABLED + - list: + - 7 + - 8 + - width: 7 + height: 8 + - *id001 + - The pin number is 18 + - The square root is: 5.0 + - The number is 80 diff --git a/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml b/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml new file mode 100644 index 0000000000..1777b46f67 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml @@ -0,0 +1,22 @@ +substitutions: + width: 7 + height: 8 + enabled: true + pin: + number: 18 + inverted: true + area: 25 + numberOne: 1 + var1: 79 + +test_list: + - "The area is ${width * height}" + - ${width * height} + - ${width * height} + 1 + - ${enabled and "ENABLED" or "DISABLED"} + - list: ${ [width, height] } + - "${ {'width': width, 'height': height} }" + - ${pin} + - The pin number is ${pin.number} + - The square root is: ${math.sqrt(area)} + - The number is ${var${numberOne} + 1} diff --git a/tests/unit_tests/fixtures/substitutions/03-closures.approved.yaml b/tests/unit_tests/fixtures/substitutions/03-closures.approved.yaml new file mode 100644 index 0000000000..c8f7d9976c --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/03-closures.approved.yaml @@ -0,0 +1,17 @@ +substitutions: + B: 5 + var7: 79 +package_result: + - The value of A*B is 35, where A is a package var and B is a substitution in the + root file + - Double substitution also works; the value of var7 is 79, where A is a package + var +local_results: + - The value of B is 5 + - 'You will see, however, that + + ${A} is not substituted here, since + + it is out of scope. + + ' diff --git a/tests/unit_tests/fixtures/substitutions/03-closures.input.yaml b/tests/unit_tests/fixtures/substitutions/03-closures.input.yaml new file mode 100644 index 0000000000..e0b2c39e52 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/03-closures.input.yaml @@ -0,0 +1,16 @@ +substitutions: + B: 5 + var7: 79 + +packages: + closures_package: !include + file: closures_package.yaml + vars: + A: 7 + +local_results: + - The value of B is ${B} + - | + You will see, however, that + ${A} is not substituted here, since + it is out of scope. diff --git a/tests/unit_tests/fixtures/substitutions/04-display_example.approved.yaml b/tests/unit_tests/fixtures/substitutions/04-display_example.approved.yaml new file mode 100644 index 0000000000..f559181b45 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/04-display_example.approved.yaml @@ -0,0 +1,5 @@ +display: + - platform: ili9xxx + dimensions: + width: 960 + height: 544 diff --git a/tests/unit_tests/fixtures/substitutions/04-display_example.input.yaml b/tests/unit_tests/fixtures/substitutions/04-display_example.input.yaml new file mode 100644 index 0000000000..9d8f64a253 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/04-display_example.input.yaml @@ -0,0 +1,7 @@ +# main.yaml +packages: + my_display: !include + file: display.yaml + vars: + high_dpi: true + native_height: 272 diff --git a/tests/unit_tests/fixtures/substitutions/closures_package.yaml b/tests/unit_tests/fixtures/substitutions/closures_package.yaml new file mode 100644 index 0000000000..e87908814d --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/closures_package.yaml @@ -0,0 +1,3 @@ +package_result: + - The value of A*B is ${A * B}, where A is a package var and B is a substitution in the root file + - Double substitution also works; the value of var7 is ${var$A}, where A is a package var diff --git a/tests/unit_tests/fixtures/substitutions/display.yaml b/tests/unit_tests/fixtures/substitutions/display.yaml new file mode 100644 index 0000000000..1e2249dddb --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/display.yaml @@ -0,0 +1,11 @@ +# display.yaml + +defaults: + native_width: 480 + native_height: 480 + +display: + - platform: ili9xxx + dimensions: + width: ${high_dpi and native_width * 2 or native_width} + height: ${high_dpi and native_height * 2 or native_height} diff --git a/tests/unit_tests/fixtures/substitutions/inc1.yaml b/tests/unit_tests/fixtures/substitutions/inc1.yaml new file mode 100644 index 0000000000..65b91a5e16 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/inc1.yaml @@ -0,0 +1,8 @@ +defaults: + b: "B-default" + +values: + - var1: $var1 + - a: $a + - b: ${b} + - c: The value of C is $c diff --git a/tests/unit_tests/test_substitutions.py b/tests/unit_tests/test_substitutions.py new file mode 100644 index 0000000000..b377499d29 --- /dev/null +++ b/tests/unit_tests/test_substitutions.py @@ -0,0 +1,125 @@ +import glob +import logging +import os + +from esphome import yaml_util +from esphome.components import substitutions +from esphome.const import CONF_PACKAGES + +_LOGGER = logging.getLogger(__name__) + +# Set to True for dev mode behavior +# This will generate the expected version of the test files. + +DEV_MODE = False + + +def sort_dicts(obj): + """Recursively sort dictionaries for order-insensitive comparison.""" + if isinstance(obj, dict): + return {k: sort_dicts(obj[k]) for k in sorted(obj)} + elif isinstance(obj, list): + # Lists are not sorted; we preserve order + return [sort_dicts(i) for i in obj] + else: + return obj + + +def dict_diff(a, b, path=""): + """Recursively find differences between two dict/list structures.""" + diffs = [] + if isinstance(a, dict) and isinstance(b, dict): + a_keys = set(a) + b_keys = set(b) + for key in a_keys - b_keys: + diffs.append(f"{path}/{key} only in actual") + for key in b_keys - a_keys: + diffs.append(f"{path}/{key} only in expected") + for key in a_keys & b_keys: + diffs.extend(dict_diff(a[key], b[key], f"{path}/{key}")) + elif isinstance(a, list) and isinstance(b, list): + min_len = min(len(a), len(b)) + for i in range(min_len): + diffs.extend(dict_diff(a[i], b[i], f"{path}[{i}]")) + if len(a) > len(b): + for i in range(min_len, len(a)): + diffs.append(f"{path}[{i}] only in actual: {a[i]!r}") + elif len(b) > len(a): + for i in range(min_len, len(b)): + diffs.append(f"{path}[{i}] only in expected: {b[i]!r}") + else: + if a != b: + diffs.append(f"\t{path}: actual={a!r} expected={b!r}") + return diffs + + +def write_yaml(path, data): + with open(path, "w", encoding="utf-8") as f: + f.write(yaml_util.dump(data)) + + +def test_substitutions_fixtures(fixture_path): + base_dir = fixture_path / "substitutions" + sources = sorted(glob.glob(str(base_dir / "*.input.yaml"))) + assert sources, f"No input YAML files found in {base_dir}" + + failures = [] + for source_path in sources: + try: + expected_path = source_path.replace(".input.yaml", ".approved.yaml") + test_case = os.path.splitext(os.path.basename(source_path))[0].replace( + ".input", "" + ) + + # Load using ESPHome's YAML loader + config = yaml_util.load_yaml(source_path) + + if CONF_PACKAGES in config: + from esphome.components.packages import do_packages_pass + + config = do_packages_pass(config) + + substitutions.do_substitution_pass(config, None) + + # Also load expected using ESPHome's loader, or use {} if missing and DEV_MODE + if os.path.isfile(expected_path): + expected = yaml_util.load_yaml(expected_path) + elif DEV_MODE: + expected = {} + else: + assert os.path.isfile(expected_path), ( + f"Expected file missing: {expected_path}" + ) + + # Sort dicts only (not lists) for comparison + got_sorted = sort_dicts(config) + expected_sorted = sort_dicts(expected) + + if got_sorted != expected_sorted: + diff = "\n".join(dict_diff(got_sorted, expected_sorted)) + msg = ( + f"Substitution result mismatch for {os.path.basename(source_path)}\n" + f"Diff:\n{diff}\n\n" + f"Got: {got_sorted}\n" + f"Expected: {expected_sorted}" + ) + # Write out the received file when test fails + if DEV_MODE: + received_path = os.path.join( + os.path.dirname(source_path), f"{test_case}.received.yaml" + ) + write_yaml(received_path, config) + print(msg) + failures.append(msg) + else: + raise AssertionError(msg) + except Exception as err: + _LOGGER.error("Error in test file %s", source_path) + raise err + + if DEV_MODE and failures: + print(f"\n{len(failures)} substitution test case(s) failed.") + + if DEV_MODE: + _LOGGER.error("Tests passed, but Dev mode is enabled.") + assert not DEV_MODE # make sure DEV_MODE is disabled after you are finished. From e3ccb9b46c9b0c5535d9d3c5bbdefe03b0c1d7b8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 30 Jun 2025 22:04:50 -0500 Subject: [PATCH 4/8] Use interrupt based approach for esp32_touch (#9059) Co-authored-by: Keith Burzinski --- .../components/esp32_touch/esp32_touch.cpp | 355 ---------------- esphome/components/esp32_touch/esp32_touch.h | 209 +++++++-- .../esp32_touch/esp32_touch_common.cpp | 159 +++++++ .../components/esp32_touch/esp32_touch_v1.cpp | 240 +++++++++++ .../components/esp32_touch/esp32_touch_v2.cpp | 398 ++++++++++++++++++ 5 files changed, 975 insertions(+), 386 deletions(-) delete mode 100644 esphome/components/esp32_touch/esp32_touch.cpp create mode 100644 esphome/components/esp32_touch/esp32_touch_common.cpp create mode 100644 esphome/components/esp32_touch/esp32_touch_v1.cpp create mode 100644 esphome/components/esp32_touch/esp32_touch_v2.cpp diff --git a/esphome/components/esp32_touch/esp32_touch.cpp b/esphome/components/esp32_touch/esp32_touch.cpp deleted file mode 100644 index 366aa10697..0000000000 --- a/esphome/components/esp32_touch/esp32_touch.cpp +++ /dev/null @@ -1,355 +0,0 @@ -#ifdef USE_ESP32 - -#include "esp32_touch.h" -#include "esphome/core/application.h" -#include "esphome/core/log.h" -#include "esphome/core/hal.h" - -#include - -namespace esphome { -namespace esp32_touch { - -static const char *const TAG = "esp32_touch"; - -void ESP32TouchComponent::setup() { - ESP_LOGCONFIG(TAG, "Running setup"); - touch_pad_init(); -// set up and enable/start filtering based on ESP32 variant -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) - if (this->filter_configured_()) { - touch_filter_config_t filter_info = { - .mode = this->filter_mode_, - .debounce_cnt = this->debounce_count_, - .noise_thr = this->noise_threshold_, - .jitter_step = this->jitter_step_, - .smh_lvl = this->smooth_level_, - }; - touch_pad_filter_set_config(&filter_info); - touch_pad_filter_enable(); - } - - if (this->denoise_configured_()) { - touch_pad_denoise_t denoise = { - .grade = this->grade_, - .cap_level = this->cap_level_, - }; - touch_pad_denoise_set_config(&denoise); - touch_pad_denoise_enable(); - } - - if (this->waterproof_configured_()) { - touch_pad_waterproof_t waterproof = { - .guard_ring_pad = this->waterproof_guard_ring_pad_, - .shield_driver = this->waterproof_shield_driver_, - }; - touch_pad_waterproof_set_config(&waterproof); - touch_pad_waterproof_enable(); - } -#else - if (this->iir_filter_enabled_()) { - touch_pad_filter_start(this->iir_filter_); - } -#endif - -#if ESP_IDF_VERSION_MAJOR >= 5 && defined(USE_ESP32_VARIANT_ESP32) - touch_pad_set_measurement_clock_cycles(this->meas_cycle_); - touch_pad_set_measurement_interval(this->sleep_cycle_); -#else - touch_pad_set_meas_time(this->sleep_cycle_, this->meas_cycle_); -#endif - touch_pad_set_voltage(this->high_voltage_reference_, this->low_voltage_reference_, this->voltage_attenuation_); - - for (auto *child : this->children_) { -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) - touch_pad_config(child->get_touch_pad()); -#else - // Disable interrupt threshold - touch_pad_config(child->get_touch_pad(), 0); -#endif - } -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) - touch_pad_set_fsm_mode(TOUCH_FSM_MODE_TIMER); - touch_pad_fsm_start(); -#endif -} - -void ESP32TouchComponent::dump_config() { - ESP_LOGCONFIG(TAG, - "Config for ESP32 Touch Hub:\n" - " Meas cycle: %.2fms\n" - " Sleep cycle: %.2fms", - this->meas_cycle_ / (8000000.0f / 1000.0f), this->sleep_cycle_ / (150000.0f / 1000.0f)); - - const char *lv_s; - switch (this->low_voltage_reference_) { - case TOUCH_LVOLT_0V5: - lv_s = "0.5V"; - break; - case TOUCH_LVOLT_0V6: - lv_s = "0.6V"; - break; - case TOUCH_LVOLT_0V7: - lv_s = "0.7V"; - break; - case TOUCH_LVOLT_0V8: - lv_s = "0.8V"; - break; - default: - lv_s = "UNKNOWN"; - break; - } - ESP_LOGCONFIG(TAG, " Low Voltage Reference: %s", lv_s); - - const char *hv_s; - switch (this->high_voltage_reference_) { - case TOUCH_HVOLT_2V4: - hv_s = "2.4V"; - break; - case TOUCH_HVOLT_2V5: - hv_s = "2.5V"; - break; - case TOUCH_HVOLT_2V6: - hv_s = "2.6V"; - break; - case TOUCH_HVOLT_2V7: - hv_s = "2.7V"; - break; - default: - hv_s = "UNKNOWN"; - break; - } - ESP_LOGCONFIG(TAG, " High Voltage Reference: %s", hv_s); - - const char *atten_s; - switch (this->voltage_attenuation_) { - case TOUCH_HVOLT_ATTEN_1V5: - atten_s = "1.5V"; - break; - case TOUCH_HVOLT_ATTEN_1V: - atten_s = "1V"; - break; - case TOUCH_HVOLT_ATTEN_0V5: - atten_s = "0.5V"; - break; - case TOUCH_HVOLT_ATTEN_0V: - atten_s = "0V"; - break; - default: - atten_s = "UNKNOWN"; - break; - } - ESP_LOGCONFIG(TAG, " Voltage Attenuation: %s", atten_s); - -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) - if (this->filter_configured_()) { - const char *filter_mode_s; - switch (this->filter_mode_) { - case TOUCH_PAD_FILTER_IIR_4: - filter_mode_s = "IIR_4"; - break; - case TOUCH_PAD_FILTER_IIR_8: - filter_mode_s = "IIR_8"; - break; - case TOUCH_PAD_FILTER_IIR_16: - filter_mode_s = "IIR_16"; - break; - case TOUCH_PAD_FILTER_IIR_32: - filter_mode_s = "IIR_32"; - break; - case TOUCH_PAD_FILTER_IIR_64: - filter_mode_s = "IIR_64"; - break; - case TOUCH_PAD_FILTER_IIR_128: - filter_mode_s = "IIR_128"; - break; - case TOUCH_PAD_FILTER_IIR_256: - filter_mode_s = "IIR_256"; - break; - case TOUCH_PAD_FILTER_JITTER: - filter_mode_s = "JITTER"; - break; - default: - filter_mode_s = "UNKNOWN"; - break; - } - ESP_LOGCONFIG(TAG, - " Filter mode: %s\n" - " Debounce count: %" PRIu32 "\n" - " Noise threshold coefficient: %" PRIu32 "\n" - " Jitter filter step size: %" PRIu32, - filter_mode_s, this->debounce_count_, this->noise_threshold_, this->jitter_step_); - const char *smooth_level_s; - switch (this->smooth_level_) { - case TOUCH_PAD_SMOOTH_OFF: - smooth_level_s = "OFF"; - break; - case TOUCH_PAD_SMOOTH_IIR_2: - smooth_level_s = "IIR_2"; - break; - case TOUCH_PAD_SMOOTH_IIR_4: - smooth_level_s = "IIR_4"; - break; - case TOUCH_PAD_SMOOTH_IIR_8: - smooth_level_s = "IIR_8"; - break; - default: - smooth_level_s = "UNKNOWN"; - break; - } - ESP_LOGCONFIG(TAG, " Smooth level: %s", smooth_level_s); - } - - if (this->denoise_configured_()) { - const char *grade_s; - switch (this->grade_) { - case TOUCH_PAD_DENOISE_BIT12: - grade_s = "BIT12"; - break; - case TOUCH_PAD_DENOISE_BIT10: - grade_s = "BIT10"; - break; - case TOUCH_PAD_DENOISE_BIT8: - grade_s = "BIT8"; - break; - case TOUCH_PAD_DENOISE_BIT4: - grade_s = "BIT4"; - break; - default: - grade_s = "UNKNOWN"; - break; - } - ESP_LOGCONFIG(TAG, " Denoise grade: %s", grade_s); - - const char *cap_level_s; - switch (this->cap_level_) { - case TOUCH_PAD_DENOISE_CAP_L0: - cap_level_s = "L0"; - break; - case TOUCH_PAD_DENOISE_CAP_L1: - cap_level_s = "L1"; - break; - case TOUCH_PAD_DENOISE_CAP_L2: - cap_level_s = "L2"; - break; - case TOUCH_PAD_DENOISE_CAP_L3: - cap_level_s = "L3"; - break; - case TOUCH_PAD_DENOISE_CAP_L4: - cap_level_s = "L4"; - break; - case TOUCH_PAD_DENOISE_CAP_L5: - cap_level_s = "L5"; - break; - case TOUCH_PAD_DENOISE_CAP_L6: - cap_level_s = "L6"; - break; - case TOUCH_PAD_DENOISE_CAP_L7: - cap_level_s = "L7"; - break; - default: - cap_level_s = "UNKNOWN"; - break; - } - ESP_LOGCONFIG(TAG, " Denoise capacitance level: %s", cap_level_s); - } -#else - if (this->iir_filter_enabled_()) { - ESP_LOGCONFIG(TAG, " IIR Filter: %" PRIu32 "ms", this->iir_filter_); - } else { - ESP_LOGCONFIG(TAG, " IIR Filter DISABLED"); - } -#endif - - if (this->setup_mode_) { - ESP_LOGCONFIG(TAG, " Setup Mode ENABLED"); - } - - for (auto *child : this->children_) { - LOG_BINARY_SENSOR(" ", "Touch Pad", child); - ESP_LOGCONFIG(TAG, " Pad: T%" PRIu32, (uint32_t) child->get_touch_pad()); - ESP_LOGCONFIG(TAG, " Threshold: %" PRIu32, child->get_threshold()); - } -} - -uint32_t ESP32TouchComponent::component_touch_pad_read(touch_pad_t tp) { -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) - uint32_t value = 0; - if (this->filter_configured_()) { - touch_pad_filter_read_smooth(tp, &value); - } else { - touch_pad_read_raw_data(tp, &value); - } -#else - uint16_t value = 0; - if (this->iir_filter_enabled_()) { - touch_pad_read_filtered(tp, &value); - } else { - touch_pad_read(tp, &value); - } -#endif - return value; -} - -void ESP32TouchComponent::loop() { - const uint32_t now = App.get_loop_component_start_time(); - bool should_print = this->setup_mode_ && now - this->setup_mode_last_log_print_ > 250; - for (auto *child : this->children_) { - child->value_ = this->component_touch_pad_read(child->get_touch_pad()); -#if !(defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3)) - child->publish_state(child->value_ < child->get_threshold()); -#else - child->publish_state(child->value_ > child->get_threshold()); -#endif - - if (should_print) { - ESP_LOGD(TAG, "Touch Pad '%s' (T%" PRIu32 "): %" PRIu32, child->get_name().c_str(), - (uint32_t) child->get_touch_pad(), child->value_); - } - - App.feed_wdt(); - } - - if (should_print) { - // Avoid spamming logs - this->setup_mode_last_log_print_ = now; - } -} - -void ESP32TouchComponent::on_shutdown() { - bool is_wakeup_source = false; - -#if !(defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3)) - if (this->iir_filter_enabled_()) { - touch_pad_filter_stop(); - touch_pad_filter_delete(); - } -#endif - - for (auto *child : this->children_) { - if (child->get_wakeup_threshold() != 0) { - if (!is_wakeup_source) { - is_wakeup_source = true; - // Touch sensor FSM mode must be 'TOUCH_FSM_MODE_TIMER' to use it to wake-up. - touch_pad_set_fsm_mode(TOUCH_FSM_MODE_TIMER); - } - -#if !(defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3)) - // No filter available when using as wake-up source. - touch_pad_config(child->get_touch_pad(), child->get_wakeup_threshold()); -#endif - } - } - - if (!is_wakeup_source) { - touch_pad_deinit(); - } -} - -ESP32TouchBinarySensor::ESP32TouchBinarySensor(touch_pad_t touch_pad, uint32_t threshold, uint32_t wakeup_threshold) - : touch_pad_(touch_pad), threshold_(threshold), wakeup_threshold_(wakeup_threshold) {} - -} // namespace esp32_touch -} // namespace esphome - -#endif diff --git a/esphome/components/esp32_touch/esp32_touch.h b/esphome/components/esp32_touch/esp32_touch.h index 3fce8a7e18..576c1a5649 100644 --- a/esphome/components/esp32_touch/esp32_touch.h +++ b/esphome/components/esp32_touch/esp32_touch.h @@ -9,10 +9,26 @@ #include #include +#include +#include namespace esphome { namespace esp32_touch { +// IMPORTANT: Touch detection logic differs between ESP32 variants: +// - ESP32 v1 (original): Touch detected when value < threshold (capacitance increase causes value decrease) +// - ESP32-S2/S3 v2: Touch detected when value > threshold (capacitance increase causes value increase) +// This inversion is due to different hardware implementations between chip generations. +// +// INTERRUPT BEHAVIOR: +// - ESP32 v1: Interrupts fire when ANY pad is touched and continue while touched. +// Releases are detected by timeout since hardware doesn't generate release interrupts. +// - ESP32-S2/S3 v2: Hardware supports both touch and release interrupts, but release +// interrupts are unreliable and sometimes don't fire. We now only use touch interrupts +// and detect releases via timeout, similar to v1. + +static const uint32_t SETUP_MODE_LOG_INTERVAL_MS = 250; + class ESP32TouchBinarySensor; class ESP32TouchComponent : public Component { @@ -31,6 +47,14 @@ class ESP32TouchComponent : public Component { void set_voltage_attenuation(touch_volt_atten_t voltage_attenuation) { this->voltage_attenuation_ = voltage_attenuation; } + + void setup() override; + void dump_config() override; + void loop() override; + float get_setup_priority() const override { return setup_priority::DATA; } + + void on_shutdown() override; + #if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) void set_filter_mode(touch_filter_mode_t filter_mode) { this->filter_mode_ = filter_mode; } void set_debounce_count(uint32_t debounce_count) { this->debounce_count_ = debounce_count; } @@ -47,16 +71,101 @@ class ESP32TouchComponent : public Component { void set_iir_filter(uint32_t iir_filter) { this->iir_filter_ = iir_filter; } #endif - uint32_t component_touch_pad_read(touch_pad_t tp); + protected: + // Common helper methods + void dump_config_base_(); + void dump_config_sensors_(); + bool create_touch_queue_(); + void cleanup_touch_queue_(); + void configure_wakeup_pads_(); - void setup() override; - void dump_config() override; - void loop() override; + // Helper methods for loop() logic + void process_setup_mode_logging_(uint32_t now); + bool should_check_for_releases_(uint32_t now); + void publish_initial_state_if_needed_(ESP32TouchBinarySensor *child, uint32_t now); + void check_and_disable_loop_if_all_released_(size_t pads_off); + void calculate_release_timeout_(); - void on_shutdown() override; + // Common members + std::vector children_; + bool setup_mode_{false}; + uint32_t setup_mode_last_log_print_{0}; + uint32_t last_release_check_{0}; + uint32_t release_timeout_ms_{1500}; + uint32_t release_check_interval_ms_{50}; + bool initial_state_published_[TOUCH_PAD_MAX] = {false}; + + // Common configuration parameters + uint16_t sleep_cycle_{4095}; + uint16_t meas_cycle_{65535}; + touch_low_volt_t low_voltage_reference_{TOUCH_LVOLT_0V5}; + touch_high_volt_t high_voltage_reference_{TOUCH_HVOLT_2V7}; + touch_volt_atten_t voltage_attenuation_{TOUCH_HVOLT_ATTEN_0V}; + + // Common constants + static constexpr uint32_t MINIMUM_RELEASE_TIME_MS = 100; + + // ==================== PLATFORM SPECIFIC ==================== + +#ifdef USE_ESP32_VARIANT_ESP32 + // ESP32 v1 specific + + static void touch_isr_handler(void *arg); + QueueHandle_t touch_queue_{nullptr}; + + private: + // Touch event structure for ESP32 v1 + // Contains touch pad info, value, and touch state for queue communication + struct TouchPadEventV1 { + touch_pad_t pad; + uint32_t value; + bool is_touched; + }; protected: -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) + // Design note: last_touch_time_ does not require synchronization primitives because: + // 1. ESP32 guarantees atomic 32-bit aligned reads/writes + // 2. ISR only writes timestamps, main loop only reads + // 3. Timing tolerance allows for occasional stale reads (50ms check interval) + // 4. Queue operations provide implicit memory barriers + // Using atomic/critical sections would add overhead without meaningful benefit + uint32_t last_touch_time_[TOUCH_PAD_MAX] = {0}; + uint32_t iir_filter_{0}; + + bool iir_filter_enabled_() const { return this->iir_filter_ > 0; } + +#elif defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) + // ESP32-S2/S3 v2 specific + static void touch_isr_handler(void *arg); + QueueHandle_t touch_queue_{nullptr}; + + private: + // Touch event structure for ESP32 v2 (S2/S3) + // Contains touch pad and interrupt mask for queue communication + struct TouchPadEventV2 { + touch_pad_t pad; + uint32_t intr_mask; + }; + + // Track last touch time for timeout-based release detection + uint32_t last_touch_time_[TOUCH_PAD_MAX] = {0}; + + protected: + // Filter configuration + touch_filter_mode_t filter_mode_{TOUCH_PAD_FILTER_MAX}; + uint32_t debounce_count_{0}; + uint32_t noise_threshold_{0}; + uint32_t jitter_step_{0}; + touch_smooth_mode_t smooth_level_{TOUCH_PAD_SMOOTH_MAX}; + + // Denoise configuration + touch_pad_denoise_grade_t grade_{TOUCH_PAD_DENOISE_MAX}; + touch_pad_denoise_cap_t cap_level_{TOUCH_PAD_DENOISE_CAP_MAX}; + + // Waterproof configuration + touch_pad_t waterproof_guard_ring_pad_{TOUCH_PAD_MAX}; + touch_pad_shield_driver_t waterproof_shield_driver_{TOUCH_PAD_SHIELD_DRV_MAX}; + bool filter_configured_() const { return (this->filter_mode_ != TOUCH_PAD_FILTER_MAX) && (this->smooth_level_ != TOUCH_PAD_SMOOTH_MAX); } @@ -67,43 +176,78 @@ class ESP32TouchComponent : public Component { return (this->waterproof_guard_ring_pad_ != TOUCH_PAD_MAX) && (this->waterproof_shield_driver_ != TOUCH_PAD_SHIELD_DRV_MAX); } -#else - bool iir_filter_enabled_() const { return this->iir_filter_ > 0; } + + // Helper method to read touch values - non-blocking operation + // Returns the current touch pad value using either filtered or raw reading + // based on the filter configuration + uint32_t read_touch_value(touch_pad_t pad) const; + + // Helper to update touch state with a known state + void update_touch_state_(ESP32TouchBinarySensor *child, bool is_touched); + + // Helper to read touch value and update state for a given child + bool check_and_update_touch_state_(ESP32TouchBinarySensor *child); #endif - std::vector children_; - bool setup_mode_{false}; - uint32_t setup_mode_last_log_print_{0}; - // common parameters - uint16_t sleep_cycle_{4095}; - uint16_t meas_cycle_{65535}; - touch_low_volt_t low_voltage_reference_{TOUCH_LVOLT_0V5}; - touch_high_volt_t high_voltage_reference_{TOUCH_HVOLT_2V7}; - touch_volt_atten_t voltage_attenuation_{TOUCH_HVOLT_ATTEN_0V}; -#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) - touch_filter_mode_t filter_mode_{TOUCH_PAD_FILTER_MAX}; - uint32_t debounce_count_{0}; - uint32_t noise_threshold_{0}; - uint32_t jitter_step_{0}; - touch_smooth_mode_t smooth_level_{TOUCH_PAD_SMOOTH_MAX}; - touch_pad_denoise_grade_t grade_{TOUCH_PAD_DENOISE_MAX}; - touch_pad_denoise_cap_t cap_level_{TOUCH_PAD_DENOISE_CAP_MAX}; - touch_pad_t waterproof_guard_ring_pad_{TOUCH_PAD_MAX}; - touch_pad_shield_driver_t waterproof_shield_driver_{TOUCH_PAD_SHIELD_DRV_MAX}; -#else - uint32_t iir_filter_{0}; -#endif + // Helper functions for dump_config - common to both implementations + static const char *get_low_voltage_reference_str(touch_low_volt_t ref) { + switch (ref) { + case TOUCH_LVOLT_0V5: + return "0.5V"; + case TOUCH_LVOLT_0V6: + return "0.6V"; + case TOUCH_LVOLT_0V7: + return "0.7V"; + case TOUCH_LVOLT_0V8: + return "0.8V"; + default: + return "UNKNOWN"; + } + } + + static const char *get_high_voltage_reference_str(touch_high_volt_t ref) { + switch (ref) { + case TOUCH_HVOLT_2V4: + return "2.4V"; + case TOUCH_HVOLT_2V5: + return "2.5V"; + case TOUCH_HVOLT_2V6: + return "2.6V"; + case TOUCH_HVOLT_2V7: + return "2.7V"; + default: + return "UNKNOWN"; + } + } + + static const char *get_voltage_attenuation_str(touch_volt_atten_t atten) { + switch (atten) { + case TOUCH_HVOLT_ATTEN_1V5: + return "1.5V"; + case TOUCH_HVOLT_ATTEN_1V: + return "1V"; + case TOUCH_HVOLT_ATTEN_0V5: + return "0.5V"; + case TOUCH_HVOLT_ATTEN_0V: + return "0V"; + default: + return "UNKNOWN"; + } + } }; /// Simple helper class to expose a touch pad value as a binary sensor. class ESP32TouchBinarySensor : public binary_sensor::BinarySensor { public: - ESP32TouchBinarySensor(touch_pad_t touch_pad, uint32_t threshold, uint32_t wakeup_threshold); + ESP32TouchBinarySensor(touch_pad_t touch_pad, uint32_t threshold, uint32_t wakeup_threshold) + : touch_pad_(touch_pad), threshold_(threshold), wakeup_threshold_(wakeup_threshold) {} touch_pad_t get_touch_pad() const { return this->touch_pad_; } uint32_t get_threshold() const { return this->threshold_; } void set_threshold(uint32_t threshold) { this->threshold_ = threshold; } +#ifdef USE_ESP32_VARIANT_ESP32 uint32_t get_value() const { return this->value_; } +#endif uint32_t get_wakeup_threshold() const { return this->wakeup_threshold_; } protected: @@ -111,7 +255,10 @@ class ESP32TouchBinarySensor : public binary_sensor::BinarySensor { touch_pad_t touch_pad_{TOUCH_PAD_MAX}; uint32_t threshold_{0}; +#ifdef USE_ESP32_VARIANT_ESP32 uint32_t value_{0}; +#endif + bool last_state_{false}; const uint32_t wakeup_threshold_{0}; }; diff --git a/esphome/components/esp32_touch/esp32_touch_common.cpp b/esphome/components/esp32_touch/esp32_touch_common.cpp new file mode 100644 index 0000000000..fd2cdfcbad --- /dev/null +++ b/esphome/components/esp32_touch/esp32_touch_common.cpp @@ -0,0 +1,159 @@ +#ifdef USE_ESP32 + +#include "esp32_touch.h" +#include "esphome/core/log.h" +#include + +#include "soc/rtc.h" + +namespace esphome { +namespace esp32_touch { + +static const char *const TAG = "esp32_touch"; + +void ESP32TouchComponent::dump_config_base_() { + const char *lv_s = get_low_voltage_reference_str(this->low_voltage_reference_); + const char *hv_s = get_high_voltage_reference_str(this->high_voltage_reference_); + const char *atten_s = get_voltage_attenuation_str(this->voltage_attenuation_); + + ESP_LOGCONFIG(TAG, + "Config for ESP32 Touch Hub:\n" + " Meas cycle: %.2fms\n" + " Sleep cycle: %.2fms\n" + " Low Voltage Reference: %s\n" + " High Voltage Reference: %s\n" + " Voltage Attenuation: %s", + this->meas_cycle_ / (8000000.0f / 1000.0f), this->sleep_cycle_ / (150000.0f / 1000.0f), lv_s, hv_s, + atten_s); +} + +void ESP32TouchComponent::dump_config_sensors_() { + for (auto *child : this->children_) { + LOG_BINARY_SENSOR(" ", "Touch Pad", child); + ESP_LOGCONFIG(TAG, " Pad: T%" PRIu32, (uint32_t) child->get_touch_pad()); + ESP_LOGCONFIG(TAG, " Threshold: %" PRIu32, child->get_threshold()); + } +} + +bool ESP32TouchComponent::create_touch_queue_() { + // Queue size calculation: children * 4 allows for burst scenarios where ISR + // fires multiple times before main loop processes. + size_t queue_size = this->children_.size() * 4; + if (queue_size < 8) + queue_size = 8; + +#ifdef USE_ESP32_VARIANT_ESP32 + this->touch_queue_ = xQueueCreate(queue_size, sizeof(TouchPadEventV1)); +#else + this->touch_queue_ = xQueueCreate(queue_size, sizeof(TouchPadEventV2)); +#endif + + if (this->touch_queue_ == nullptr) { + ESP_LOGE(TAG, "Failed to create touch event queue of size %" PRIu32, (uint32_t) queue_size); + this->mark_failed(); + return false; + } + return true; +} + +void ESP32TouchComponent::cleanup_touch_queue_() { + if (this->touch_queue_) { + vQueueDelete(this->touch_queue_); + this->touch_queue_ = nullptr; + } +} + +void ESP32TouchComponent::configure_wakeup_pads_() { + bool is_wakeup_source = false; + + // Check if any pad is configured for wakeup + for (auto *child : this->children_) { + if (child->get_wakeup_threshold() != 0) { + is_wakeup_source = true; + +#ifdef USE_ESP32_VARIANT_ESP32 + // ESP32 v1: No filter available when using as wake-up source. + touch_pad_config(child->get_touch_pad(), child->get_wakeup_threshold()); +#else + // ESP32-S2/S3 v2: Set threshold for wakeup + touch_pad_set_thresh(child->get_touch_pad(), child->get_wakeup_threshold()); +#endif + } + } + + if (!is_wakeup_source) { + // If no pad is configured for wakeup, deinitialize touch pad + touch_pad_deinit(); + } +} + +void ESP32TouchComponent::process_setup_mode_logging_(uint32_t now) { + if (this->setup_mode_ && now - this->setup_mode_last_log_print_ > SETUP_MODE_LOG_INTERVAL_MS) { + for (auto *child : this->children_) { +#ifdef USE_ESP32_VARIANT_ESP32 + ESP_LOGD(TAG, "Touch Pad '%s' (T%" PRIu32 "): %" PRIu32, child->get_name().c_str(), + (uint32_t) child->get_touch_pad(), child->value_); +#else + // Read the value being used for touch detection + uint32_t value = this->read_touch_value(child->get_touch_pad()); + ESP_LOGD(TAG, "Touch Pad '%s' (T%d): %d", child->get_name().c_str(), child->get_touch_pad(), value); +#endif + } + this->setup_mode_last_log_print_ = now; + } +} + +bool ESP32TouchComponent::should_check_for_releases_(uint32_t now) { + if (now - this->last_release_check_ < this->release_check_interval_ms_) { + return false; + } + this->last_release_check_ = now; + return true; +} + +void ESP32TouchComponent::publish_initial_state_if_needed_(ESP32TouchBinarySensor *child, uint32_t now) { + touch_pad_t pad = child->get_touch_pad(); + if (!this->initial_state_published_[pad]) { + // Check if enough time has passed since startup + if (now > this->release_timeout_ms_) { + child->publish_initial_state(false); + this->initial_state_published_[pad] = true; + ESP_LOGV(TAG, "Touch Pad '%s' state: OFF (initial)", child->get_name().c_str()); + } + } +} + +void ESP32TouchComponent::check_and_disable_loop_if_all_released_(size_t pads_off) { + // Disable the loop to save CPU cycles when all pads are off and not in setup mode. + if (pads_off == this->children_.size() && !this->setup_mode_) { + this->disable_loop(); + } +} + +void ESP32TouchComponent::calculate_release_timeout_() { + // Calculate release timeout based on sleep cycle + // Design note: Hardware limitation - interrupts only fire reliably on touch (not release) + // We must use timeout-based detection for release events + // Formula: 3 sleep cycles converted to ms, with MINIMUM_RELEASE_TIME_MS minimum + // Per ESP-IDF docs: t_sleep = sleep_cycle / SOC_CLK_RC_SLOW_FREQ_APPROX + + uint32_t rtc_freq = rtc_clk_slow_freq_get_hz(); + + // Calculate timeout as 3 sleep cycles + this->release_timeout_ms_ = (this->sleep_cycle_ * 1000 * 3) / rtc_freq; + + if (this->release_timeout_ms_ < MINIMUM_RELEASE_TIME_MS) { + this->release_timeout_ms_ = MINIMUM_RELEASE_TIME_MS; + } + + // Check for releases at 1/4 the timeout interval + // Since hardware doesn't generate reliable release interrupts, we must poll + // for releases in the main loop. Checking at 1/4 the timeout interval provides + // a good balance between responsiveness and efficiency. + this->release_check_interval_ms_ = this->release_timeout_ms_ / 4; +} + +} // namespace esp32_touch +} // namespace esphome + +#endif // USE_ESP32 diff --git a/esphome/components/esp32_touch/esp32_touch_v1.cpp b/esphome/components/esp32_touch/esp32_touch_v1.cpp new file mode 100644 index 0000000000..a6d499e9fa --- /dev/null +++ b/esphome/components/esp32_touch/esp32_touch_v1.cpp @@ -0,0 +1,240 @@ +#ifdef USE_ESP32_VARIANT_ESP32 + +#include "esp32_touch.h" +#include "esphome/core/application.h" +#include "esphome/core/log.h" +#include "esphome/core/hal.h" + +#include +#include + +// Include HAL for ISR-safe touch reading +#include "hal/touch_sensor_ll.h" + +namespace esphome { +namespace esp32_touch { + +static const char *const TAG = "esp32_touch"; + +void ESP32TouchComponent::setup() { + // Create queue for touch events + // Queue size calculation: children * 4 allows for burst scenarios where ISR + // fires multiple times before main loop processes. This is important because + // ESP32 v1 scans all pads on each interrupt, potentially sending multiple events. + if (!this->create_touch_queue_()) { + return; + } + + touch_pad_init(); + touch_pad_set_fsm_mode(TOUCH_FSM_MODE_TIMER); + + // Set up IIR filter if enabled + if (this->iir_filter_enabled_()) { + touch_pad_filter_start(this->iir_filter_); + } + + // Configure measurement parameters +#if ESP_IDF_VERSION_MAJOR >= 5 + touch_pad_set_measurement_clock_cycles(this->meas_cycle_); + touch_pad_set_measurement_interval(this->sleep_cycle_); +#else + touch_pad_set_meas_time(this->sleep_cycle_, this->meas_cycle_); +#endif + touch_pad_set_voltage(this->high_voltage_reference_, this->low_voltage_reference_, this->voltage_attenuation_); + + // Configure each touch pad + for (auto *child : this->children_) { + touch_pad_config(child->get_touch_pad(), child->get_threshold()); + } + + // Register ISR handler + esp_err_t err = touch_pad_isr_register(touch_isr_handler, this); + if (err != ESP_OK) { + ESP_LOGE(TAG, "Failed to register touch ISR: %s", esp_err_to_name(err)); + this->cleanup_touch_queue_(); + this->mark_failed(); + return; + } + + // Calculate release timeout based on sleep cycle + this->calculate_release_timeout_(); + + // Enable touch pad interrupt + touch_pad_intr_enable(); +} + +void ESP32TouchComponent::dump_config() { + this->dump_config_base_(); + + if (this->iir_filter_enabled_()) { + ESP_LOGCONFIG(TAG, " IIR Filter: %" PRIu32 "ms", this->iir_filter_); + } else { + ESP_LOGCONFIG(TAG, " IIR Filter DISABLED"); + } + + if (this->setup_mode_) { + ESP_LOGCONFIG(TAG, " Setup Mode ENABLED"); + } + + this->dump_config_sensors_(); +} + +void ESP32TouchComponent::loop() { + const uint32_t now = App.get_loop_component_start_time(); + + // Print debug info for all pads in setup mode + this->process_setup_mode_logging_(now); + + // Process any queued touch events from interrupts + // Note: Events are only sent by ISR for pads that were measured in that cycle (value != 0) + // This is more efficient than sending all pad states every interrupt + TouchPadEventV1 event; + while (xQueueReceive(this->touch_queue_, &event, 0) == pdTRUE) { + // Find the corresponding sensor - O(n) search is acceptable since events are infrequent + for (auto *child : this->children_) { + if (child->get_touch_pad() != event.pad) { + continue; + } + + // Found matching pad - process it + child->value_ = event.value; + + // The interrupt gives us the touch state directly + bool new_state = event.is_touched; + + // Track when we last saw this pad as touched + if (new_state) { + this->last_touch_time_[event.pad] = now; + } + + // Only publish if state changed - this filters out repeated events + if (new_state != child->last_state_) { + child->last_state_ = new_state; + child->publish_state(new_state); + // Original ESP32: ISR only fires when touched, release is detected by timeout + // Note: ESP32 v1 uses inverted logic - touched when value < threshold + ESP_LOGV(TAG, "Touch Pad '%s' state: ON (value: %" PRIu32 " < threshold: %" PRIu32 ")", + child->get_name().c_str(), event.value, child->get_threshold()); + } + break; // Exit inner loop after processing matching pad + } + } + + // Check for released pads periodically + if (!this->should_check_for_releases_(now)) { + return; + } + + size_t pads_off = 0; + for (auto *child : this->children_) { + touch_pad_t pad = child->get_touch_pad(); + + // Handle initial state publication after startup + this->publish_initial_state_if_needed_(child, now); + + if (child->last_state_) { + // Pad is currently in touched state - check for release timeout + // Using subtraction handles 32-bit rollover correctly + uint32_t time_diff = now - this->last_touch_time_[pad]; + + // Check if we haven't seen this pad recently + if (time_diff > this->release_timeout_ms_) { + // Haven't seen this pad recently, assume it's released + child->last_state_ = false; + child->publish_state(false); + ESP_LOGV(TAG, "Touch Pad '%s' state: OFF (timeout)", child->get_name().c_str()); + pads_off++; + } + } else { + // Pad is already off + pads_off++; + } + } + + // Disable the loop to save CPU cycles when all pads are off and not in setup mode. + // The loop will be re-enabled by the ISR when any touch pad is touched. + // v1 hardware limitations require us to check all pads are off because: + // - v1 only generates interrupts on touch events (not releases) + // - We must poll for release timeouts in the main loop + // - We can only safely disable when no pads need timeout monitoring + this->check_and_disable_loop_if_all_released_(pads_off); +} + +void ESP32TouchComponent::on_shutdown() { + touch_pad_intr_disable(); + touch_pad_isr_deregister(touch_isr_handler, this); + this->cleanup_touch_queue_(); + + if (this->iir_filter_enabled_()) { + touch_pad_filter_stop(); + touch_pad_filter_delete(); + } + + // Configure wakeup pads if any are set + this->configure_wakeup_pads_(); +} + +void IRAM_ATTR ESP32TouchComponent::touch_isr_handler(void *arg) { + ESP32TouchComponent *component = static_cast(arg); + + touch_pad_clear_status(); + + // INTERRUPT BEHAVIOR: On ESP32 v1 hardware, the interrupt fires when ANY configured + // touch pad detects a touch (value goes below threshold). The hardware does NOT + // generate interrupts on release - only on touch events. + // The interrupt will continue to fire periodically (based on sleep_cycle) as long + // as any pad remains touched. This allows us to detect both new touches and + // continued touches, but releases must be detected by timeout in the main loop. + + // Process all configured pads to check their current state + // Note: ESP32 v1 doesn't tell us which specific pad triggered the interrupt, + // so we must scan all configured pads to find which ones were touched + for (auto *child : component->children_) { + touch_pad_t pad = child->get_touch_pad(); + + // Read current value using ISR-safe API + uint32_t value; + if (component->iir_filter_enabled_()) { + uint16_t temp_value = 0; + touch_pad_read_filtered(pad, &temp_value); + value = temp_value; + } else { + // Use low-level HAL function when filter is not enabled + value = touch_ll_read_raw_data(pad); + } + + // Skip pads with 0 value - they haven't been measured in this cycle + // This is important: not all pads are measured every interrupt cycle, + // only those that the hardware has updated + if (value == 0) { + continue; + } + + // IMPORTANT: ESP32 v1 touch detection logic - INVERTED compared to v2! + // ESP32 v1: Touch is detected when capacitance INCREASES, causing the measured value to DECREASE + // Therefore: touched = (value < threshold) + // This is opposite to ESP32-S2/S3 v2 where touched = (value > threshold) + bool is_touched = value < child->get_threshold(); + + // Always send the current state - the main loop will filter for changes + // We send both touched and untouched states because the ISR doesn't + // track previous state (to keep ISR fast and simple) + TouchPadEventV1 event; + event.pad = pad; + event.value = value; + event.is_touched = is_touched; + + // Send to queue from ISR - non-blocking, drops if queue full + BaseType_t x_higher_priority_task_woken = pdFALSE; + xQueueSendFromISR(component->touch_queue_, &event, &x_higher_priority_task_woken); + component->enable_loop_soon_any_context(); + if (x_higher_priority_task_woken) { + portYIELD_FROM_ISR(); + } + } +} + +} // namespace esp32_touch +} // namespace esphome + +#endif // USE_ESP32_VARIANT_ESP32 diff --git a/esphome/components/esp32_touch/esp32_touch_v2.cpp b/esphome/components/esp32_touch/esp32_touch_v2.cpp new file mode 100644 index 0000000000..ad77881724 --- /dev/null +++ b/esphome/components/esp32_touch/esp32_touch_v2.cpp @@ -0,0 +1,398 @@ +#if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) + +#include "esp32_touch.h" +#include "esphome/core/application.h" +#include "esphome/core/log.h" +#include "esphome/core/hal.h" + +namespace esphome { +namespace esp32_touch { + +static const char *const TAG = "esp32_touch"; + +// Helper to update touch state with a known state +void ESP32TouchComponent::update_touch_state_(ESP32TouchBinarySensor *child, bool is_touched) { + // Always update timer when touched + if (is_touched) { + this->last_touch_time_[child->get_touch_pad()] = App.get_loop_component_start_time(); + } + + if (child->last_state_ != is_touched) { + // Read value for logging + uint32_t value = this->read_touch_value(child->get_touch_pad()); + + child->last_state_ = is_touched; + child->publish_state(is_touched); + if (is_touched) { + // ESP32-S2/S3 v2: touched when value > threshold + ESP_LOGV(TAG, "Touch Pad '%s' state: ON (value: %" PRIu32 " > threshold: %" PRIu32 ")", child->get_name().c_str(), + value, child->get_threshold()); + } else { + ESP_LOGV(TAG, "Touch Pad '%s' state: OFF", child->get_name().c_str()); + } + } +} + +// Helper to read touch value and update state for a given child (used for timeout events) +bool ESP32TouchComponent::check_and_update_touch_state_(ESP32TouchBinarySensor *child) { + // Read current touch value + uint32_t value = this->read_touch_value(child->get_touch_pad()); + + // ESP32-S2/S3 v2: Touch is detected when value > threshold + bool is_touched = value > child->get_threshold(); + + this->update_touch_state_(child, is_touched); + return is_touched; +} + +void ESP32TouchComponent::setup() { + // Create queue for touch events first + if (!this->create_touch_queue_()) { + return; + } + + // Initialize touch pad peripheral + esp_err_t init_err = touch_pad_init(); + if (init_err != ESP_OK) { + ESP_LOGE(TAG, "Failed to initialize touch pad: %s", esp_err_to_name(init_err)); + this->mark_failed(); + return; + } + + // Configure each touch pad first + for (auto *child : this->children_) { + esp_err_t config_err = touch_pad_config(child->get_touch_pad()); + if (config_err != ESP_OK) { + ESP_LOGE(TAG, "Failed to configure touch pad %d: %s", child->get_touch_pad(), esp_err_to_name(config_err)); + } + } + + // Set up filtering if configured + if (this->filter_configured_()) { + touch_filter_config_t filter_info = { + .mode = this->filter_mode_, + .debounce_cnt = this->debounce_count_, + .noise_thr = this->noise_threshold_, + .jitter_step = this->jitter_step_, + .smh_lvl = this->smooth_level_, + }; + touch_pad_filter_set_config(&filter_info); + touch_pad_filter_enable(); + } + + if (this->denoise_configured_()) { + touch_pad_denoise_t denoise = { + .grade = this->grade_, + .cap_level = this->cap_level_, + }; + touch_pad_denoise_set_config(&denoise); + touch_pad_denoise_enable(); + } + + if (this->waterproof_configured_()) { + touch_pad_waterproof_t waterproof = { + .guard_ring_pad = this->waterproof_guard_ring_pad_, + .shield_driver = this->waterproof_shield_driver_, + }; + touch_pad_waterproof_set_config(&waterproof); + touch_pad_waterproof_enable(); + } + + // Configure measurement parameters + touch_pad_set_voltage(this->high_voltage_reference_, this->low_voltage_reference_, this->voltage_attenuation_); + // ESP32-S2/S3 always use the older API + touch_pad_set_meas_time(this->sleep_cycle_, this->meas_cycle_); + + // Configure timeout if needed + touch_pad_timeout_set(true, TOUCH_PAD_THRESHOLD_MAX); + + // Register ISR handler with interrupt mask + esp_err_t err = + touch_pad_isr_register(touch_isr_handler, this, static_cast(TOUCH_PAD_INTR_MASK_ALL)); + if (err != ESP_OK) { + ESP_LOGE(TAG, "Failed to register touch ISR: %s", esp_err_to_name(err)); + this->cleanup_touch_queue_(); + this->mark_failed(); + return; + } + + // Set thresholds for each pad BEFORE starting FSM + for (auto *child : this->children_) { + if (child->get_threshold() != 0) { + touch_pad_set_thresh(child->get_touch_pad(), child->get_threshold()); + } + } + + // Enable interrupts - only ACTIVE and TIMEOUT + // NOTE: We intentionally don't enable INACTIVE interrupts because they are unreliable + // on ESP32-S2/S3 hardware and sometimes don't fire. Instead, we use timeout-based + // release detection with the ability to verify the actual state. + touch_pad_intr_enable(static_cast(TOUCH_PAD_INTR_MASK_ACTIVE | TOUCH_PAD_INTR_MASK_TIMEOUT)); + + // Set FSM mode before starting + touch_pad_set_fsm_mode(TOUCH_FSM_MODE_TIMER); + + // Start FSM + touch_pad_fsm_start(); + + // Calculate release timeout based on sleep cycle + this->calculate_release_timeout_(); +} + +void ESP32TouchComponent::dump_config() { + this->dump_config_base_(); + + if (this->filter_configured_()) { + const char *filter_mode_s; + switch (this->filter_mode_) { + case TOUCH_PAD_FILTER_IIR_4: + filter_mode_s = "IIR_4"; + break; + case TOUCH_PAD_FILTER_IIR_8: + filter_mode_s = "IIR_8"; + break; + case TOUCH_PAD_FILTER_IIR_16: + filter_mode_s = "IIR_16"; + break; + case TOUCH_PAD_FILTER_IIR_32: + filter_mode_s = "IIR_32"; + break; + case TOUCH_PAD_FILTER_IIR_64: + filter_mode_s = "IIR_64"; + break; + case TOUCH_PAD_FILTER_IIR_128: + filter_mode_s = "IIR_128"; + break; + case TOUCH_PAD_FILTER_IIR_256: + filter_mode_s = "IIR_256"; + break; + case TOUCH_PAD_FILTER_JITTER: + filter_mode_s = "JITTER"; + break; + default: + filter_mode_s = "UNKNOWN"; + break; + } + ESP_LOGCONFIG(TAG, + " Filter mode: %s\n" + " Debounce count: %" PRIu32 "\n" + " Noise threshold coefficient: %" PRIu32 "\n" + " Jitter filter step size: %" PRIu32, + filter_mode_s, this->debounce_count_, this->noise_threshold_, this->jitter_step_); + const char *smooth_level_s; + switch (this->smooth_level_) { + case TOUCH_PAD_SMOOTH_OFF: + smooth_level_s = "OFF"; + break; + case TOUCH_PAD_SMOOTH_IIR_2: + smooth_level_s = "IIR_2"; + break; + case TOUCH_PAD_SMOOTH_IIR_4: + smooth_level_s = "IIR_4"; + break; + case TOUCH_PAD_SMOOTH_IIR_8: + smooth_level_s = "IIR_8"; + break; + default: + smooth_level_s = "UNKNOWN"; + break; + } + ESP_LOGCONFIG(TAG, " Smooth level: %s", smooth_level_s); + } + + if (this->denoise_configured_()) { + const char *grade_s; + switch (this->grade_) { + case TOUCH_PAD_DENOISE_BIT12: + grade_s = "BIT12"; + break; + case TOUCH_PAD_DENOISE_BIT10: + grade_s = "BIT10"; + break; + case TOUCH_PAD_DENOISE_BIT8: + grade_s = "BIT8"; + break; + case TOUCH_PAD_DENOISE_BIT4: + grade_s = "BIT4"; + break; + default: + grade_s = "UNKNOWN"; + break; + } + ESP_LOGCONFIG(TAG, " Denoise grade: %s", grade_s); + + const char *cap_level_s; + switch (this->cap_level_) { + case TOUCH_PAD_DENOISE_CAP_L0: + cap_level_s = "L0"; + break; + case TOUCH_PAD_DENOISE_CAP_L1: + cap_level_s = "L1"; + break; + case TOUCH_PAD_DENOISE_CAP_L2: + cap_level_s = "L2"; + break; + case TOUCH_PAD_DENOISE_CAP_L3: + cap_level_s = "L3"; + break; + case TOUCH_PAD_DENOISE_CAP_L4: + cap_level_s = "L4"; + break; + case TOUCH_PAD_DENOISE_CAP_L5: + cap_level_s = "L5"; + break; + case TOUCH_PAD_DENOISE_CAP_L6: + cap_level_s = "L6"; + break; + case TOUCH_PAD_DENOISE_CAP_L7: + cap_level_s = "L7"; + break; + default: + cap_level_s = "UNKNOWN"; + break; + } + ESP_LOGCONFIG(TAG, " Denoise capacitance level: %s", cap_level_s); + } + + if (this->setup_mode_) { + ESP_LOGCONFIG(TAG, " Setup Mode ENABLED"); + } + + this->dump_config_sensors_(); +} + +void ESP32TouchComponent::loop() { + const uint32_t now = App.get_loop_component_start_time(); + + // V2 TOUCH HANDLING: + // Due to unreliable INACTIVE interrupts on ESP32-S2/S3, we use a hybrid approach: + // 1. Process ACTIVE interrupts when pads are touched + // 2. Use timeout-based release detection (like v1) + // 3. But smarter than v1: verify actual state before releasing on timeout + // This prevents false releases if we missed interrupts + + // In setup mode, periodically log all pad values + this->process_setup_mode_logging_(now); + + // Process any queued touch events from interrupts + TouchPadEventV2 event; + while (xQueueReceive(this->touch_queue_, &event, 0) == pdTRUE) { + // Handle timeout events + if (event.intr_mask & TOUCH_PAD_INTR_MASK_TIMEOUT) { + // Resume measurement after timeout + touch_pad_timeout_resume(); + // For timeout events, always check the current state + } else if (!(event.intr_mask & TOUCH_PAD_INTR_MASK_ACTIVE)) { + // Skip if not an active/timeout event + continue; + } + + // Find the child for the pad that triggered the interrupt + for (auto *child : this->children_) { + if (child->get_touch_pad() != event.pad) { + continue; + } + + if (event.intr_mask & TOUCH_PAD_INTR_MASK_TIMEOUT) { + // For timeout events, we need to read the value to determine state + this->check_and_update_touch_state_(child); + } else if (event.intr_mask & TOUCH_PAD_INTR_MASK_ACTIVE) { + // We only get ACTIVE interrupts now, releases are detected by timeout + this->update_touch_state_(child, true); // Always touched for ACTIVE interrupts + } + break; + } + } + + // Check for released pads periodically (like v1) + if (!this->should_check_for_releases_(now)) { + return; + } + + size_t pads_off = 0; + for (auto *child : this->children_) { + touch_pad_t pad = child->get_touch_pad(); + + // Handle initial state publication after startup + this->publish_initial_state_if_needed_(child, now); + + if (child->last_state_) { + // Pad is currently in touched state - check for release timeout + // Using subtraction handles 32-bit rollover correctly + uint32_t time_diff = now - this->last_touch_time_[pad]; + + // Check if we haven't seen this pad recently + if (time_diff > this->release_timeout_ms_) { + // Haven't seen this pad recently - verify actual state + // Unlike v1, v2 hardware allows us to read the current state anytime + // This makes v2 smarter: we can verify if it's actually released before + // declaring a timeout, preventing false releases if interrupts were missed + bool still_touched = this->check_and_update_touch_state_(child); + + if (still_touched) { + // Still touched! Timer was reset in update_touch_state_ + ESP_LOGVV(TAG, "Touch Pad '%s' still touched after %" PRIu32 "ms timeout, resetting timer", + child->get_name().c_str(), this->release_timeout_ms_); + } else { + // Actually released - already handled by check_and_update_touch_state_ + pads_off++; + } + } + } else { + // Pad is already off + pads_off++; + } + } + + // Disable the loop when all pads are off and not in setup mode (like v1) + // We need to keep checking for timeouts, so only disable when all pads are confirmed off + this->check_and_disable_loop_if_all_released_(pads_off); +} + +void ESP32TouchComponent::on_shutdown() { + // Disable interrupts + touch_pad_intr_disable(static_cast(TOUCH_PAD_INTR_MASK_ACTIVE | TOUCH_PAD_INTR_MASK_TIMEOUT)); + touch_pad_isr_deregister(touch_isr_handler, this); + this->cleanup_touch_queue_(); + + // Configure wakeup pads if any are set + this->configure_wakeup_pads_(); +} + +void IRAM_ATTR ESP32TouchComponent::touch_isr_handler(void *arg) { + ESP32TouchComponent *component = static_cast(arg); + BaseType_t x_higher_priority_task_woken = pdFALSE; + + // Read interrupt status + TouchPadEventV2 event; + event.intr_mask = touch_pad_read_intr_status_mask(); + event.pad = touch_pad_get_current_meas_channel(); + + // Send event to queue for processing in main loop + xQueueSendFromISR(component->touch_queue_, &event, &x_higher_priority_task_woken); + component->enable_loop_soon_any_context(); + + if (x_higher_priority_task_woken) { + portYIELD_FROM_ISR(); + } +} + +uint32_t ESP32TouchComponent::read_touch_value(touch_pad_t pad) const { + // Unlike ESP32 v1, touch reads on ESP32-S2/S3 v2 are non-blocking operations. + // The hardware continuously samples in the background and we can read the + // latest value at any time without waiting. + uint32_t value = 0; + if (this->filter_configured_()) { + // Read filtered/smoothed value when filter is enabled + touch_pad_filter_read_smooth(pad, &value); + } else { + // Read raw value when filter is not configured + touch_pad_read_raw_data(pad, &value); + } + return value; +} + +} // namespace esp32_touch +} // namespace esphome + +#endif // USE_ESP32_VARIANT_ESP32S2 || USE_ESP32_VARIANT_ESP32S3 From 16ef5a93774476bcbbb53bc3bcb7fc2ffef61454 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 30 Jun 2025 22:21:11 -0500 Subject: [PATCH 5/8] Add OTA support to ESP-IDF webserver (#9264) --- .../captive_portal/captive_portal.cpp | 2 + esphome/components/web_server/__init__.py | 22 +- esphome/components/web_server/web_server.cpp | 6 + .../web_server_base/web_server_base.cpp | 210 +++++++++++++-- .../web_server_base/web_server_base.h | 16 ++ esphome/components/web_server_idf/__init__.py | 8 +- .../components/web_server_idf/multipart.cpp | 254 ++++++++++++++++++ esphome/components/web_server_idf/multipart.h | 86 ++++++ esphome/components/web_server_idf/utils.cpp | 32 +++ esphome/components/web_server_idf/utils.h | 10 + .../web_server_idf/web_server_idf.cpp | 128 ++++++++- .../web_server_idf/web_server_idf.h | 3 + esphome/core/defines.h | 1 + esphome/idf_component.yml | 2 + .../web_server/test_no_ota.esp32-idf.yaml | 9 + .../web_server/test_ota.esp32-idf.yaml | 32 +++ .../test_ota_disabled.esp32-idf.yaml | 11 + 17 files changed, 788 insertions(+), 44 deletions(-) create mode 100644 esphome/components/web_server_idf/multipart.cpp create mode 100644 esphome/components/web_server_idf/multipart.h create mode 100644 tests/components/web_server/test_no_ota.esp32-idf.yaml create mode 100644 tests/components/web_server/test_ota.esp32-idf.yaml create mode 100644 tests/components/web_server/test_ota_disabled.esp32-idf.yaml diff --git a/esphome/components/captive_portal/captive_portal.cpp b/esphome/components/captive_portal/captive_portal.cpp index 51e5cfc8ff..ba392bb0f2 100644 --- a/esphome/components/captive_portal/captive_portal.cpp +++ b/esphome/components/captive_portal/captive_portal.cpp @@ -47,7 +47,9 @@ void CaptivePortal::start() { this->base_->init(); if (!this->initialized_) { this->base_->add_handler(this); +#ifdef USE_WEBSERVER_OTA this->base_->add_ota_handler(); +#endif } #ifdef USE_ARDUINO diff --git a/esphome/components/web_server/__init__.py b/esphome/components/web_server/__init__.py index f2c1824028..ca145c732b 100644 --- a/esphome/components/web_server/__init__.py +++ b/esphome/components/web_server/__init__.py @@ -40,6 +40,7 @@ CONF_SORTING_GROUP_ID = "sorting_group_id" CONF_SORTING_GROUPS = "sorting_groups" CONF_SORTING_WEIGHT = "sorting_weight" + web_server_ns = cg.esphome_ns.namespace("web_server") WebServer = web_server_ns.class_("WebServer", cg.Component, cg.Controller) @@ -72,12 +73,6 @@ def validate_local(config): return config -def validate_ota(config): - if CORE.using_esp_idf and config[CONF_OTA]: - raise cv.Invalid("Enabling 'ota' is not supported for IDF framework yet") - return config - - def validate_sorting_groups(config): if CONF_SORTING_GROUPS in config and config[CONF_VERSION] != 3: raise cv.Invalid( @@ -175,15 +170,7 @@ CONFIG_SCHEMA = cv.All( web_server_base.WebServerBase ), cv.Optional(CONF_INCLUDE_INTERNAL, default=False): cv.boolean, - cv.SplitDefault( - CONF_OTA, - esp8266=True, - esp32_arduino=True, - esp32_idf=False, - bk72xx=True, - ln882x=True, - rtl87xx=True, - ): cv.boolean, + cv.Optional(CONF_OTA, default=True): cv.boolean, cv.Optional(CONF_LOG, default=True): cv.boolean, cv.Optional(CONF_LOCAL): cv.boolean, cv.Optional(CONF_SORTING_GROUPS): cv.ensure_list(sorting_group), @@ -200,7 +187,6 @@ CONFIG_SCHEMA = cv.All( ), default_url, validate_local, - validate_ota, validate_sorting_groups, ) @@ -286,6 +272,10 @@ async def to_code(config): cg.add(var.set_css_url(config[CONF_CSS_URL])) cg.add(var.set_js_url(config[CONF_JS_URL])) cg.add(var.set_allow_ota(config[CONF_OTA])) + if config[CONF_OTA]: + # Define USE_WEBSERVER_OTA based only on web_server OTA config + # This allows web server OTA to work without loading the OTA component + cg.add_define("USE_WEBSERVER_OTA") cg.add(var.set_expose_log(config[CONF_LOG])) if config[CONF_ENABLE_PRIVATE_NETWORK_ACCESS]: cg.add_define("USE_WEBSERVER_PRIVATE_NETWORK_ACCESS") diff --git a/esphome/components/web_server/web_server.cpp b/esphome/components/web_server/web_server.cpp index 669bfbf279..e0027d0b27 100644 --- a/esphome/components/web_server/web_server.cpp +++ b/esphome/components/web_server/web_server.cpp @@ -299,8 +299,10 @@ void WebServer::setup() { #endif this->base_->add_handler(this); +#ifdef USE_WEBSERVER_OTA if (this->allow_ota_) this->base_->add_ota_handler(); +#endif // doesn't need defer functionality - if the queue is full, the client JS knows it's alive because it's clearly // getting a lot of events @@ -2030,6 +2032,10 @@ void WebServer::handleRequest(AsyncWebServerRequest *request) { return; } #endif + + // No matching handler found - send 404 + ESP_LOGV(TAG, "Request for unknown URL: %s", request->url().c_str()); + request->send(404, "text/plain", "Not Found"); } bool WebServer::isRequestHandlerTrivial() const { return false; } diff --git a/esphome/components/web_server_base/web_server_base.cpp b/esphome/components/web_server_base/web_server_base.cpp index 2835585387..9ad88e09f4 100644 --- a/esphome/components/web_server_base/web_server_base.cpp +++ b/esphome/components/web_server_base/web_server_base.cpp @@ -14,11 +14,114 @@ #endif #endif +#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +#include +#include +#endif + namespace esphome { namespace web_server_base { static const char *const TAG = "web_server_base"; +#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +// Minimal OTA backend implementation for web server +// This allows OTA updates via web server without requiring the OTA component +// TODO: In the future, this should be refactored into a common ota_base component +// that both web_server and ota components can depend on, avoiding code duplication +// while keeping the components independent. This would allow both ESP-IDF and Arduino +// implementations to share the base OTA functionality without requiring the full OTA component. +// The IDFWebServerOTABackend class is intentionally designed with the same interface +// as OTABackend to make it easy to swap to using OTABackend when the ota component +// is split into ota and ota_base in the future. +class IDFWebServerOTABackend { + public: + bool begin() { + this->partition_ = esp_ota_get_next_update_partition(nullptr); + if (this->partition_ == nullptr) { + ESP_LOGE(TAG, "No OTA partition available"); + return false; + } + +#if CONFIG_ESP_TASK_WDT_TIMEOUT_S < 15 + // The following function takes longer than the default timeout of WDT due to flash erase +#if ESP_IDF_VERSION_MAJOR >= 5 + esp_task_wdt_config_t wdtc; + wdtc.idle_core_mask = 0; +#if CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU0 + wdtc.idle_core_mask |= (1 << 0); +#endif +#if CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1 + wdtc.idle_core_mask |= (1 << 1); +#endif + wdtc.timeout_ms = 15000; + wdtc.trigger_panic = false; + esp_task_wdt_reconfigure(&wdtc); +#else + esp_task_wdt_init(15, false); +#endif +#endif + + esp_err_t err = esp_ota_begin(this->partition_, 0, &this->update_handle_); + +#if CONFIG_ESP_TASK_WDT_TIMEOUT_S < 15 + // Set the WDT back to the configured timeout +#if ESP_IDF_VERSION_MAJOR >= 5 + wdtc.timeout_ms = CONFIG_ESP_TASK_WDT_TIMEOUT_S * 1000; + esp_task_wdt_reconfigure(&wdtc); +#else + esp_task_wdt_init(CONFIG_ESP_TASK_WDT_TIMEOUT_S, false); +#endif +#endif + + if (err != ESP_OK) { + esp_ota_abort(this->update_handle_); + this->update_handle_ = 0; + ESP_LOGE(TAG, "esp_ota_begin failed: %s", esp_err_to_name(err)); + return false; + } + return true; + } + + bool write(uint8_t *data, size_t len) { + esp_err_t err = esp_ota_write(this->update_handle_, data, len); + if (err != ESP_OK) { + ESP_LOGE(TAG, "esp_ota_write failed: %s", esp_err_to_name(err)); + return false; + } + return true; + } + + bool end() { + esp_err_t err = esp_ota_end(this->update_handle_); + this->update_handle_ = 0; + if (err != ESP_OK) { + ESP_LOGE(TAG, "esp_ota_end failed: %s", esp_err_to_name(err)); + return false; + } + + err = esp_ota_set_boot_partition(this->partition_); + if (err != ESP_OK) { + ESP_LOGE(TAG, "esp_ota_set_boot_partition failed: %s", esp_err_to_name(err)); + return false; + } + + return true; + } + + void abort() { + if (this->update_handle_ != 0) { + esp_ota_abort(this->update_handle_); + this->update_handle_ = 0; + } + } + + private: + esp_ota_handle_t update_handle_{0}; + const esp_partition_t *partition_{nullptr}; +}; +#endif + void WebServerBase::add_handler(AsyncWebHandler *handler) { // remove all handlers @@ -31,6 +134,33 @@ void WebServerBase::add_handler(AsyncWebHandler *handler) { } } +#ifdef USE_WEBSERVER_OTA +void OTARequestHandler::report_ota_progress_(AsyncWebServerRequest *request) { + const uint32_t now = millis(); + if (now - this->last_ota_progress_ > 1000) { + if (request->contentLength() != 0) { + float percentage = (this->ota_read_length_ * 100.0f) / request->contentLength(); + ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage); + } else { + ESP_LOGD(TAG, "OTA in progress: %u bytes read", this->ota_read_length_); + } + this->last_ota_progress_ = now; + } +} + +void OTARequestHandler::schedule_ota_reboot_() { + ESP_LOGI(TAG, "OTA update successful!"); + this->parent_->set_timeout(100, []() { + ESP_LOGI(TAG, "Performing OTA reboot now"); + App.safe_reboot(); + }); +} + +void OTARequestHandler::ota_init_(const char *filename) { + ESP_LOGI(TAG, "OTA Update Start: %s", filename); + this->ota_read_length_ = 0; +} + void report_ota_error() { #ifdef USE_ARDUINO StreamString ss; @@ -44,8 +174,7 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin #ifdef USE_ARDUINO bool success; if (index == 0) { - ESP_LOGI(TAG, "OTA Update Start: %s", filename.c_str()); - this->ota_read_length_ = 0; + this->ota_init_(filename.c_str()); #ifdef USE_ESP8266 Update.runAsync(true); // NOLINTNEXTLINE(readability-static-accessed-through-instance) @@ -72,31 +201,68 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin return; } this->ota_read_length_ += len; - - const uint32_t now = millis(); - if (now - this->last_ota_progress_ > 1000) { - if (request->contentLength() != 0) { - float percentage = (this->ota_read_length_ * 100.0f) / request->contentLength(); - ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage); - } else { - ESP_LOGD(TAG, "OTA in progress: %u bytes read", this->ota_read_length_); - } - this->last_ota_progress_ = now; - } + this->report_ota_progress_(request); if (final) { if (Update.end(true)) { - ESP_LOGI(TAG, "OTA update successful!"); - this->parent_->set_timeout(100, []() { App.safe_reboot(); }); + this->schedule_ota_reboot_(); } else { report_ota_error(); } } -#endif +#endif // USE_ARDUINO + +#ifdef USE_ESP_IDF + // ESP-IDF implementation + if (index == 0 && !this->ota_backend_) { + // Initialize OTA on first call + this->ota_init_(filename.c_str()); + this->ota_success_ = false; + + auto *backend = new IDFWebServerOTABackend(); + if (!backend->begin()) { + ESP_LOGE(TAG, "OTA begin failed"); + delete backend; + return; + } + this->ota_backend_ = backend; + } + + auto *backend = static_cast(this->ota_backend_); + if (!backend) { + return; + } + + // Process data + if (len > 0) { + if (!backend->write(data, len)) { + ESP_LOGE(TAG, "OTA write failed"); + backend->abort(); + delete backend; + this->ota_backend_ = nullptr; + return; + } + this->ota_read_length_ += len; + this->report_ota_progress_(request); + } + + // Finalize + if (final) { + this->ota_success_ = backend->end(); + if (this->ota_success_) { + this->schedule_ota_reboot_(); + } else { + ESP_LOGE(TAG, "OTA end failed"); + } + delete backend; + this->ota_backend_ = nullptr; + } +#endif // USE_ESP_IDF } + void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) { -#ifdef USE_ARDUINO AsyncWebServerResponse *response; +#ifdef USE_ARDUINO if (!Update.hasError()) { response = request->beginResponse(200, "text/plain", "Update Successful!"); } else { @@ -105,16 +271,20 @@ void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) { Update.printError(ss); response = request->beginResponse(200, "text/plain", ss); } +#endif // USE_ARDUINO +#ifdef USE_ESP_IDF + // Send response based on the OTA result + response = request->beginResponse(200, "text/plain", this->ota_success_ ? "Update Successful!" : "Update Failed!"); +#endif // USE_ESP_IDF response->addHeader("Connection", "close"); request->send(response); -#endif } void WebServerBase::add_ota_handler() { -#ifdef USE_ARDUINO this->add_handler(new OTARequestHandler(this)); // NOLINT -#endif } +#endif + float WebServerBase::get_setup_priority() const { // Before WiFi (captive portal) return setup_priority::WIFI + 2.0f; diff --git a/esphome/components/web_server_base/web_server_base.h b/esphome/components/web_server_base/web_server_base.h index 641006cb99..09a41956c9 100644 --- a/esphome/components/web_server_base/web_server_base.h +++ b/esphome/components/web_server_base/web_server_base.h @@ -110,13 +110,17 @@ class WebServerBase : public Component { void add_handler(AsyncWebHandler *handler); +#ifdef USE_WEBSERVER_OTA void add_ota_handler(); +#endif void set_port(uint16_t port) { port_ = port; } uint16_t get_port() const { return port_; } protected: +#ifdef USE_WEBSERVER_OTA friend class OTARequestHandler; +#endif int initialized_{0}; uint16_t port_{80}; @@ -125,6 +129,7 @@ class WebServerBase : public Component { internal::Credentials credentials_; }; +#ifdef USE_WEBSERVER_OTA class OTARequestHandler : public AsyncWebHandler { public: OTARequestHandler(WebServerBase *parent) : parent_(parent) {} @@ -139,10 +144,21 @@ class OTARequestHandler : public AsyncWebHandler { bool isRequestHandlerTrivial() const override { return false; } protected: + void report_ota_progress_(AsyncWebServerRequest *request); + void schedule_ota_reboot_(); + void ota_init_(const char *filename); + uint32_t last_ota_progress_{0}; uint32_t ota_read_length_{0}; WebServerBase *parent_; + + private: +#ifdef USE_ESP_IDF + void *ota_backend_{nullptr}; + bool ota_success_{false}; +#endif }; +#endif // USE_WEBSERVER_OTA } // namespace web_server_base } // namespace esphome diff --git a/esphome/components/web_server_idf/__init__.py b/esphome/components/web_server_idf/__init__.py index 506e1c5c13..fe1c6f2640 100644 --- a/esphome/components/web_server_idf/__init__.py +++ b/esphome/components/web_server_idf/__init__.py @@ -1,5 +1,7 @@ -from esphome.components.esp32 import add_idf_sdkconfig_option +from esphome.components.esp32 import add_idf_component, add_idf_sdkconfig_option import esphome.config_validation as cv +from esphome.const import CONF_OTA, CONF_WEB_SERVER +from esphome.core import CORE CODEOWNERS = ["@dentra"] @@ -12,3 +14,7 @@ CONFIG_SCHEMA = cv.All( async def to_code(config): # Increase the maximum supported size of headers section in HTTP request packet to be processed by the server add_idf_sdkconfig_option("CONFIG_HTTPD_MAX_REQ_HDR_LEN", 1024) + # Check if web_server component has OTA enabled + if CORE.config.get(CONF_WEB_SERVER, {}).get(CONF_OTA, True): + # Add multipart parser component for ESP-IDF OTA support + add_idf_component(name="zorxx/multipart-parser", ref="1.0.1") diff --git a/esphome/components/web_server_idf/multipart.cpp b/esphome/components/web_server_idf/multipart.cpp new file mode 100644 index 0000000000..8655226ab9 --- /dev/null +++ b/esphome/components/web_server_idf/multipart.cpp @@ -0,0 +1,254 @@ +#include "esphome/core/defines.h" +#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +#include "multipart.h" +#include "utils.h" +#include "esphome/core/log.h" +#include +#include "multipart_parser.h" + +namespace esphome { +namespace web_server_idf { + +static const char *const TAG = "multipart"; + +// ========== MultipartReader Implementation ========== + +MultipartReader::MultipartReader(const std::string &boundary) { + // Initialize settings with callbacks + memset(&settings_, 0, sizeof(settings_)); + settings_.on_header_field = on_header_field; + settings_.on_header_value = on_header_value; + settings_.on_part_data = on_part_data; + settings_.on_part_data_end = on_part_data_end; + + ESP_LOGV(TAG, "Initializing multipart parser with boundary: '%s' (len: %zu)", boundary.c_str(), boundary.length()); + + // Create parser with boundary + parser_ = multipart_parser_init(boundary.c_str(), &settings_); + if (parser_) { + multipart_parser_set_data(parser_, this); + } else { + ESP_LOGE(TAG, "Failed to initialize multipart parser"); + } +} + +MultipartReader::~MultipartReader() { + if (parser_) { + multipart_parser_free(parser_); + } +} + +size_t MultipartReader::parse(const char *data, size_t len) { + if (!parser_) { + ESP_LOGE(TAG, "Parser not initialized"); + return 0; + } + + size_t parsed = multipart_parser_execute(parser_, data, len); + + if (parsed != len) { + ESP_LOGW(TAG, "Parser consumed %zu of %zu bytes - possible error", parsed, len); + } + + return parsed; +} + +void MultipartReader::process_header_(const char *value, size_t length) { + // Process the completed header (field + value pair) + std::string value_str(value, length); + + if (str_startswith_case_insensitive(current_header_field_, "content-disposition")) { + // Parse name and filename from Content-Disposition + current_part_.name = extract_header_param(value_str, "name"); + current_part_.filename = extract_header_param(value_str, "filename"); + } else if (str_startswith_case_insensitive(current_header_field_, "content-type")) { + current_part_.content_type = str_trim(value_str); + } + + // Clear field for next header + current_header_field_.clear(); +} + +int MultipartReader::on_header_field(multipart_parser *parser, const char *at, size_t length) { + MultipartReader *reader = static_cast(multipart_parser_get_data(parser)); + reader->current_header_field_.assign(at, length); + return 0; +} + +int MultipartReader::on_header_value(multipart_parser *parser, const char *at, size_t length) { + MultipartReader *reader = static_cast(multipart_parser_get_data(parser)); + reader->process_header_(at, length); + return 0; +} + +int MultipartReader::on_part_data(multipart_parser *parser, const char *at, size_t length) { + MultipartReader *reader = static_cast(multipart_parser_get_data(parser)); + // Only process file uploads + if (reader->has_file() && reader->data_callback_) { + // IMPORTANT: The 'at' pointer points to data within the parser's input buffer. + // This data is only valid during this callback. The callback handler MUST + // process or copy the data immediately - it cannot store the pointer for + // later use as the buffer will be overwritten. + reader->data_callback_(reinterpret_cast(at), length); + } + return 0; +} + +int MultipartReader::on_part_data_end(multipart_parser *parser) { + MultipartReader *reader = static_cast(multipart_parser_get_data(parser)); + ESP_LOGV(TAG, "Part data end"); + if (reader->part_complete_callback_) { + reader->part_complete_callback_(); + } + // Clear part info for next part + reader->current_part_ = Part{}; + return 0; +} + +// ========== Utility Functions ========== + +// Case-insensitive string prefix check +bool str_startswith_case_insensitive(const std::string &str, const std::string &prefix) { + if (str.length() < prefix.length()) { + return false; + } + return str_ncmp_ci(str.c_str(), prefix.c_str(), prefix.length()); +} + +// Extract a parameter value from a header line +// Handles both quoted and unquoted values +std::string extract_header_param(const std::string &header, const std::string ¶m) { + size_t search_pos = 0; + + while (search_pos < header.length()) { + // Look for param name + const char *found = stristr(header.c_str() + search_pos, param.c_str()); + if (!found) { + return ""; + } + size_t pos = found - header.c_str(); + + // Check if this is a word boundary (not part of another parameter) + if (pos > 0 && header[pos - 1] != ' ' && header[pos - 1] != ';' && header[pos - 1] != '\t') { + search_pos = pos + 1; + continue; + } + + // Move past param name + pos += param.length(); + + // Skip whitespace and find '=' + while (pos < header.length() && (header[pos] == ' ' || header[pos] == '\t')) { + pos++; + } + + if (pos >= header.length() || header[pos] != '=') { + search_pos = pos; + continue; + } + + pos++; // Skip '=' + + // Skip whitespace after '=' + while (pos < header.length() && (header[pos] == ' ' || header[pos] == '\t')) { + pos++; + } + + if (pos >= header.length()) { + return ""; + } + + // Check if value is quoted + if (header[pos] == '"') { + pos++; + size_t end = header.find('"', pos); + if (end != std::string::npos) { + return header.substr(pos, end - pos); + } + // Malformed - no closing quote + return ""; + } + + // Unquoted value - find the end (semicolon, comma, or end of string) + size_t end = pos; + while (end < header.length() && header[end] != ';' && header[end] != ',' && header[end] != ' ' && + header[end] != '\t') { + end++; + } + + return header.substr(pos, end - pos); + } + + return ""; +} + +// Parse boundary from Content-Type header +// Returns true if boundary found, false otherwise +// boundary_start and boundary_len will point to the boundary value +bool parse_multipart_boundary(const char *content_type, const char **boundary_start, size_t *boundary_len) { + if (!content_type) { + return false; + } + + // Check for multipart/form-data (case-insensitive) + if (!stristr(content_type, "multipart/form-data")) { + return false; + } + + // Look for boundary parameter + const char *b = stristr(content_type, "boundary="); + if (!b) { + return false; + } + + const char *start = b + 9; // Skip "boundary=" + + // Skip whitespace + while (*start == ' ' || *start == '\t') { + start++; + } + + if (!*start) { + return false; + } + + // Find end of boundary + const char *end = start; + if (*end == '"') { + // Quoted boundary + start++; + end++; + while (*end && *end != '"') { + end++; + } + *boundary_len = end - start; + } else { + // Unquoted boundary + while (*end && *end != ' ' && *end != ';' && *end != '\r' && *end != '\n' && *end != '\t') { + end++; + } + *boundary_len = end - start; + } + + if (*boundary_len == 0) { + return false; + } + + *boundary_start = start; + + return true; +} + +// Trim whitespace from both ends of a string +std::string str_trim(const std::string &str) { + size_t start = str.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) { + return ""; + } + size_t end = str.find_last_not_of(" \t\r\n"); + return str.substr(start, end - start + 1); +} + +} // namespace web_server_idf +} // namespace esphome +#endif // defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) diff --git a/esphome/components/web_server_idf/multipart.h b/esphome/components/web_server_idf/multipart.h new file mode 100644 index 0000000000..967c72ffa5 --- /dev/null +++ b/esphome/components/web_server_idf/multipart.h @@ -0,0 +1,86 @@ +#pragma once +#include "esphome/core/defines.h" +#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) + +#include +#include +#include +#include +#include +#include +#include + +namespace esphome { +namespace web_server_idf { + +// Wrapper around zorxx/multipart-parser for ESP-IDF OTA uploads +class MultipartReader { + public: + struct Part { + std::string name; + std::string filename; + std::string content_type; + }; + + // IMPORTANT: The data pointer in DataCallback is only valid during the callback! + // The multipart parser passes pointers to its internal buffer which will be + // overwritten after the callback returns. Callbacks MUST process or copy the + // data immediately - storing the pointer for deferred processing will result + // in use-after-free bugs. + using DataCallback = std::function; + using PartCompleteCallback = std::function; + + explicit MultipartReader(const std::string &boundary); + ~MultipartReader(); + + // Set callbacks for handling data + void set_data_callback(DataCallback callback) { data_callback_ = std::move(callback); } + void set_part_complete_callback(PartCompleteCallback callback) { part_complete_callback_ = std::move(callback); } + + // Parse incoming data + size_t parse(const char *data, size_t len); + + // Get current part info + const Part &get_current_part() const { return current_part_; } + + // Check if we found a file upload + bool has_file() const { return !current_part_.filename.empty(); } + + private: + static int on_header_field(multipart_parser *parser, const char *at, size_t length); + static int on_header_value(multipart_parser *parser, const char *at, size_t length); + static int on_part_data(multipart_parser *parser, const char *at, size_t length); + static int on_part_data_end(multipart_parser *parser); + + multipart_parser *parser_{nullptr}; + multipart_parser_settings settings_{}; + + Part current_part_; + std::string current_header_field_; + + DataCallback data_callback_; + PartCompleteCallback part_complete_callback_; + + void process_header_(const char *value, size_t length); +}; + +// ========== Utility Functions ========== + +// Case-insensitive string prefix check +bool str_startswith_case_insensitive(const std::string &str, const std::string &prefix); + +// Extract a parameter value from a header line +// Handles both quoted and unquoted values +std::string extract_header_param(const std::string &header, const std::string ¶m); + +// Parse boundary from Content-Type header +// Returns true if boundary found, false otherwise +// boundary_start and boundary_len will point to the boundary value +bool parse_multipart_boundary(const char *content_type, const char **boundary_start, size_t *boundary_len); + +// Trim whitespace from both ends of a string +std::string str_trim(const std::string &str); + +} // namespace web_server_idf +} // namespace esphome +#endif // defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) diff --git a/esphome/components/web_server_idf/utils.cpp b/esphome/components/web_server_idf/utils.cpp index 349acce50d..ac5df90bb8 100644 --- a/esphome/components/web_server_idf/utils.cpp +++ b/esphome/components/web_server_idf/utils.cpp @@ -1,5 +1,7 @@ #ifdef USE_ESP_IDF #include +#include +#include #include "esphome/core/helpers.h" #include "esphome/core/log.h" #include "http_parser.h" @@ -88,6 +90,36 @@ optional query_key_value(const std::string &query_url, const std::s return {val.get()}; } +// Helper function for case-insensitive string region comparison +bool str_ncmp_ci(const char *s1, const char *s2, size_t n) { + for (size_t i = 0; i < n; i++) { + if (!char_equals_ci(s1[i], s2[i])) { + return false; + } + } + return true; +} + +// Case-insensitive string search (like strstr but case-insensitive) +const char *stristr(const char *haystack, const char *needle) { + if (!haystack) { + return nullptr; + } + + size_t needle_len = strlen(needle); + if (needle_len == 0) { + return haystack; + } + + for (const char *p = haystack; *p; p++) { + if (str_ncmp_ci(p, needle, needle_len)) { + return p; + } + } + + return nullptr; +} + } // namespace web_server_idf } // namespace esphome #endif // USE_ESP_IDF diff --git a/esphome/components/web_server_idf/utils.h b/esphome/components/web_server_idf/utils.h index 9ed17c1d50..988b962d72 100644 --- a/esphome/components/web_server_idf/utils.h +++ b/esphome/components/web_server_idf/utils.h @@ -2,6 +2,7 @@ #ifdef USE_ESP_IDF #include +#include #include "esphome/core/helpers.h" namespace esphome { @@ -12,6 +13,15 @@ optional request_get_header(httpd_req_t *req, const char *name); optional request_get_url_query(httpd_req_t *req); optional query_key_value(const std::string &query_url, const std::string &key); +// Helper function for case-insensitive character comparison +inline bool char_equals_ci(char a, char b) { return ::tolower(a) == ::tolower(b); } + +// Helper function for case-insensitive string region comparison +bool str_ncmp_ci(const char *s1, const char *s2, size_t n); + +// Case-insensitive string search (like strstr but case-insensitive) +const char *stristr(const char *haystack, const char *needle); + } // namespace web_server_idf } // namespace esphome #endif // USE_ESP_IDF diff --git a/esphome/components/web_server_idf/web_server_idf.cpp b/esphome/components/web_server_idf/web_server_idf.cpp index 409230806c..9478e4748c 100644 --- a/esphome/components/web_server_idf/web_server_idf.cpp +++ b/esphome/components/web_server_idf/web_server_idf.cpp @@ -1,16 +1,25 @@ #ifdef USE_ESP_IDF #include +#include +#include +#include #include "esphome/core/helpers.h" #include "esphome/core/log.h" #include "esp_tls_crypto.h" +#include +#include #include "utils.h" - #include "web_server_idf.h" +#ifdef USE_WEBSERVER_OTA +#include +#include "multipart.h" // For parse_multipart_boundary and other utils +#endif + #ifdef USE_WEBSERVER #include "esphome/components/web_server/web_server.h" #include "esphome/components/web_server/list_entities.h" @@ -72,18 +81,32 @@ void AsyncWebServer::begin() { esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) { ESP_LOGVV(TAG, "Enter AsyncWebServer::request_post_handler. uri=%s", r->uri); auto content_type = request_get_header(r, "Content-Type"); - if (content_type.has_value() && *content_type != "application/x-www-form-urlencoded") { - ESP_LOGW(TAG, "Only application/x-www-form-urlencoded supported for POST request"); - // fallback to get handler to support backward compatibility - return AsyncWebServer::request_handler(r); - } if (!request_has_header(r, "Content-Length")) { - ESP_LOGW(TAG, "Content length is requred for post: %s", r->uri); + ESP_LOGW(TAG, "Content length is required for post: %s", r->uri); httpd_resp_send_err(r, HTTPD_411_LENGTH_REQUIRED, nullptr); return ESP_OK; } + if (content_type.has_value()) { + const char *content_type_char = content_type.value().c_str(); + + // Check most common case first + if (stristr(content_type_char, "application/x-www-form-urlencoded") != nullptr) { + // Normal form data - proceed with regular handling +#ifdef USE_WEBSERVER_OTA + } else if (stristr(content_type_char, "multipart/form-data") != nullptr) { + auto *server = static_cast(r->user_ctx); + return server->handle_multipart_upload_(r, content_type_char); +#endif + } else { + ESP_LOGW(TAG, "Unsupported content type for POST: %s", content_type_char); + // fallback to get handler to support backward compatibility + return AsyncWebServer::request_handler(r); + } + } + + // Handle regular form data if (r->content_len > HTTPD_MAX_REQ_HDR_LEN) { ESP_LOGW(TAG, "Request size is to big: %zu", r->content_len); httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr); @@ -539,6 +562,97 @@ void AsyncEventSourceResponse::deferrable_send_state(void *source, const char *e } #endif +#ifdef USE_WEBSERVER_OTA +esp_err_t AsyncWebServer::handle_multipart_upload_(httpd_req_t *r, const char *content_type) { + static constexpr size_t MULTIPART_CHUNK_SIZE = 1460; // Match Arduino AsyncWebServer buffer size + static constexpr size_t YIELD_INTERVAL_BYTES = 16 * 1024; // Yield every 16KB to prevent watchdog + + // Parse boundary and create reader + const char *boundary_start; + size_t boundary_len; + if (!parse_multipart_boundary(content_type, &boundary_start, &boundary_len)) { + ESP_LOGE(TAG, "Failed to parse multipart boundary"); + httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr); + return ESP_FAIL; + } + + AsyncWebServerRequest req(r); + AsyncWebHandler *handler = nullptr; + for (auto *h : this->handlers_) { + if (h->canHandle(&req)) { + handler = h; + break; + } + } + + if (!handler) { + ESP_LOGW(TAG, "No handler found for OTA request"); + httpd_resp_send_err(r, HTTPD_404_NOT_FOUND, nullptr); + return ESP_OK; + } + + // Upload state + std::string filename; + size_t index = 0; + // Create reader on heap to reduce stack usage + auto reader = std::make_unique("--" + std::string(boundary_start, boundary_len)); + + // Configure callbacks + reader->set_data_callback([&](const uint8_t *data, size_t len) { + if (!reader->has_file() || !len) + return; + + if (filename.empty()) { + filename = reader->get_current_part().filename; + ESP_LOGV(TAG, "Processing file: '%s'", filename.c_str()); + handler->handleUpload(&req, filename, 0, nullptr, 0, false); // Start + } + + handler->handleUpload(&req, filename, index, const_cast(data), len, false); + index += len; + }); + + reader->set_part_complete_callback([&]() { + if (index > 0) { + handler->handleUpload(&req, filename, index, nullptr, 0, true); // End + filename.clear(); + index = 0; + } + }); + + // Process data + std::unique_ptr buffer(new char[MULTIPART_CHUNK_SIZE]); + size_t bytes_since_yield = 0; + + for (size_t remaining = r->content_len; remaining > 0;) { + int recv_len = httpd_req_recv(r, buffer.get(), std::min(remaining, MULTIPART_CHUNK_SIZE)); + + if (recv_len <= 0) { + httpd_resp_send_err(r, recv_len == HTTPD_SOCK_ERR_TIMEOUT ? HTTPD_408_REQ_TIMEOUT : HTTPD_400_BAD_REQUEST, + nullptr); + return recv_len == HTTPD_SOCK_ERR_TIMEOUT ? ESP_ERR_TIMEOUT : ESP_FAIL; + } + + if (reader->parse(buffer.get(), recv_len) != static_cast(recv_len)) { + ESP_LOGW(TAG, "Multipart parser error"); + httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr); + return ESP_FAIL; + } + + remaining -= recv_len; + bytes_since_yield += recv_len; + + if (bytes_since_yield > YIELD_INTERVAL_BYTES) { + vTaskDelay(1); + bytes_since_yield = 0; + } + } + + handler->handleRequest(&req); + return ESP_OK; +} +#endif // USE_WEBSERVER_OTA + } // namespace web_server_idf } // namespace esphome diff --git a/esphome/components/web_server_idf/web_server_idf.h b/esphome/components/web_server_idf/web_server_idf.h index 7547117224..8de25c8e96 100644 --- a/esphome/components/web_server_idf/web_server_idf.h +++ b/esphome/components/web_server_idf/web_server_idf.h @@ -204,6 +204,9 @@ class AsyncWebServer { static esp_err_t request_handler(httpd_req_t *r); static esp_err_t request_post_handler(httpd_req_t *r); esp_err_t request_handler_(AsyncWebServerRequest *request) const; +#ifdef USE_WEBSERVER_OTA + esp_err_t handle_multipart_upload_(httpd_req_t *r, const char *content_type); +#endif std::vector handlers_; std::function on_not_found_{}; }; diff --git a/esphome/core/defines.h b/esphome/core/defines.h index ea3c8bdc17..cfaed6fdb7 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -153,6 +153,7 @@ #define USE_SPI #define USE_VOICE_ASSISTANT #define USE_WEBSERVER +#define USE_WEBSERVER_OTA #define USE_WEBSERVER_PORT 80 // NOLINT #define USE_WEBSERVER_SORTING #define USE_WIFI_11KV_SUPPORT diff --git a/esphome/idf_component.yml b/esphome/idf_component.yml index 6299909033..c43b622684 100644 --- a/esphome/idf_component.yml +++ b/esphome/idf_component.yml @@ -17,3 +17,5 @@ dependencies: version: 2.0.11 rules: - if: "target in [esp32h2, esp32p4]" + zorxx/multipart-parser: + version: 1.0.1 diff --git a/tests/components/web_server/test_no_ota.esp32-idf.yaml b/tests/components/web_server/test_no_ota.esp32-idf.yaml new file mode 100644 index 0000000000..1f677fb948 --- /dev/null +++ b/tests/components/web_server/test_no_ota.esp32-idf.yaml @@ -0,0 +1,9 @@ +packages: + device_base: !include common.yaml + +# No OTA component defined for this test + +web_server: + port: 8080 + version: 2 + ota: false diff --git a/tests/components/web_server/test_ota.esp32-idf.yaml b/tests/components/web_server/test_ota.esp32-idf.yaml new file mode 100644 index 0000000000..294e7f862e --- /dev/null +++ b/tests/components/web_server/test_ota.esp32-idf.yaml @@ -0,0 +1,32 @@ +# Test configuration for ESP-IDF web server with OTA enabled +esphome: + name: test-web-server-ota-idf + +# Force ESP-IDF framework +esp32: + board: esp32dev + framework: + type: esp-idf + +packages: + device_base: !include common.yaml + +# Enable OTA for multipart upload testing +ota: + - platform: esphome + password: "test_ota_password" + +# Web server with OTA enabled +web_server: + port: 8080 + version: 2 + ota: true + include_internal: true + +# Enable debug logging for OTA +logger: + level: DEBUG + logs: + web_server: VERBOSE + web_server_idf: VERBOSE + diff --git a/tests/components/web_server/test_ota_disabled.esp32-idf.yaml b/tests/components/web_server/test_ota_disabled.esp32-idf.yaml new file mode 100644 index 0000000000..c7c7574e3b --- /dev/null +++ b/tests/components/web_server/test_ota_disabled.esp32-idf.yaml @@ -0,0 +1,11 @@ +packages: + device_base: !include common.yaml + +# OTA is configured but web_server OTA is disabled +ota: + - platform: esphome + +web_server: + port: 8080 + version: 2 + ota: false From 35de36d690188623cbd11062aeb59fdc35a4ee1e Mon Sep 17 00:00:00 2001 From: Javier Peletier Date: Tue, 1 Jul 2025 05:39:06 +0200 Subject: [PATCH 6/8] [modbus] Modbus server role: write holding registers (#9156) --- esphome/components/modbus/modbus.cpp | 48 +++++++---- esphome/components/modbus/modbus.h | 1 + .../components/modbus_controller/__init__.py | 13 +++ .../modbus_controller/modbus_controller.cpp | 80 +++++++++++++++++++ .../modbus_controller/modbus_controller.h | 15 ++++ .../components/modbus_controller/common.yaml | 13 ++- 6 files changed, 154 insertions(+), 16 deletions(-) diff --git a/esphome/components/modbus/modbus.cpp b/esphome/components/modbus/modbus.cpp index c2efa93fae..6350f43ef6 100644 --- a/esphome/components/modbus/modbus.cpp +++ b/esphome/components/modbus/modbus.cpp @@ -90,15 +90,24 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { } else { // data starts at 2 and length is 4 for read registers commands - if (this->role == ModbusRole::SERVER && (function_code == 0x1 || function_code == 0x3 || function_code == 0x4)) { - data_offset = 2; - data_len = 4; - } - - // the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands - if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) { - data_offset = 2; - data_len = 4; + if (this->role == ModbusRole::SERVER) { + if (function_code == 0x1 || function_code == 0x3 || function_code == 0x4 || function_code == 0x6) { + data_offset = 2; + data_len = 4; + } else if (function_code == 0x10) { + if (at < 6) { + return true; + } + data_offset = 2; + // starting address (2 bytes) + quantity of registers (2 bytes) + byte count itself (1 byte) + actual byte count + data_len = 2 + 2 + 1 + raw[6]; + } + } else { + // the response for write command mirrors the requests and data starts at offset 2 instead of 3 for read commands + if (function_code == 0x5 || function_code == 0x06 || function_code == 0xF || function_code == 0x10) { + data_offset = 2; + data_len = 4; + } } // Error ( msb indicates error ) @@ -132,6 +141,7 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { bool found = false; for (auto *device : this->devices_) { if (device->address_ == address) { + found = true; // Is it an error response? if ((function_code & 0x80) == 0x80) { ESP_LOGD(TAG, "Modbus error function code: 0x%X exception: %d", function_code, raw[2]); @@ -141,13 +151,21 @@ bool Modbus::parse_modbus_byte_(uint8_t byte) { // Ignore modbus exception not related to a pending command ESP_LOGD(TAG, "Ignoring Modbus error - not expecting a response"); } - } else if (this->role == ModbusRole::SERVER && (function_code == 0x3 || function_code == 0x4)) { - device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8), - uint16_t(data[3]) | (uint16_t(data[2]) << 8)); - } else { - device->on_modbus_data(data); + continue; } - found = true; + if (this->role == ModbusRole::SERVER) { + if (function_code == 0x3 || function_code == 0x4) { + device->on_modbus_read_registers(function_code, uint16_t(data[1]) | (uint16_t(data[0]) << 8), + uint16_t(data[3]) | (uint16_t(data[2]) << 8)); + continue; + } + if (function_code == 0x6 || function_code == 0x10) { + device->on_modbus_write_registers(function_code, data); + continue; + } + } + // fallthrough for other function codes + device->on_modbus_data(data); } } waiting_for_response = 0; diff --git a/esphome/components/modbus/modbus.h b/esphome/components/modbus/modbus.h index aebdbccc78..ec35612690 100644 --- a/esphome/components/modbus/modbus.h +++ b/esphome/components/modbus/modbus.h @@ -59,6 +59,7 @@ class ModbusDevice { virtual void on_modbus_data(const std::vector &data) = 0; virtual void on_modbus_error(uint8_t function_code, uint8_t exception_code) {} virtual void on_modbus_read_registers(uint8_t function_code, uint16_t start_address, uint16_t number_of_registers){}; + virtual void on_modbus_write_registers(uint8_t function_code, const std::vector &data){}; void send(uint8_t function, uint16_t start_address, uint16_t number_of_entities, uint8_t payload_len = 0, const uint8_t *payload = nullptr) { this->parent_->send(this->address_, function, start_address, number_of_entities, payload_len, payload); diff --git a/esphome/components/modbus_controller/__init__.py b/esphome/components/modbus_controller/__init__.py index 8079b824b0..5ab82f5e17 100644 --- a/esphome/components/modbus_controller/__init__.py +++ b/esphome/components/modbus_controller/__init__.py @@ -39,6 +39,7 @@ CODEOWNERS = ["@martgras"] AUTO_LOAD = ["modbus"] CONF_READ_LAMBDA = "read_lambda" +CONF_WRITE_LAMBDA = "write_lambda" CONF_SERVER_REGISTERS = "server_registers" MULTI_CONF = True @@ -148,6 +149,7 @@ ModbusServerRegisterSchema = cv.Schema( cv.Required(CONF_ADDRESS): cv.positive_int, cv.Optional(CONF_VALUE_TYPE, default="U_WORD"): cv.enum(SENSOR_VALUE_TYPE), cv.Required(CONF_READ_LAMBDA): cv.returning_lambda, + cv.Optional(CONF_WRITE_LAMBDA): cv.returning_lambda, } ) @@ -318,6 +320,17 @@ async def to_code(config): ), ) ) + if CONF_WRITE_LAMBDA in server_register: + cg.add( + server_register_var.set_write_lambda( + cg.TemplateArguments(cpp_type), + await cg.process_lambda( + server_register[CONF_WRITE_LAMBDA], + parameters=[(cg.uint16, "address"), (cpp_type, "x")], + return_type=cg.bool_, + ), + ) + ) cg.add(var.add_server_register(server_register_var)) await register_modbus_device(var, config) for conf in config.get(CONF_ON_COMMAND_SENT, []): diff --git a/esphome/components/modbus_controller/modbus_controller.cpp b/esphome/components/modbus_controller/modbus_controller.cpp index 81e9ccf0a6..0f3ddf920d 100644 --- a/esphome/components/modbus_controller/modbus_controller.cpp +++ b/esphome/components/modbus_controller/modbus_controller.cpp @@ -152,6 +152,86 @@ void ModbusController::on_modbus_read_registers(uint8_t function_code, uint16_t this->send(function_code, start_address, number_of_registers, response.size(), response.data()); } +void ModbusController::on_modbus_write_registers(uint8_t function_code, const std::vector &data) { + uint16_t number_of_registers; + uint16_t payload_offset; + + if (function_code == 0x10) { + number_of_registers = uint16_t(data[3]) | (uint16_t(data[2]) << 8); + if (number_of_registers == 0 || number_of_registers > 0x7B) { + ESP_LOGW(TAG, "Invalid number of registers %d. Sending exception response.", number_of_registers); + send_error(function_code, 3); + return; + } + uint16_t payload_size = data[4]; + if (payload_size != number_of_registers * 2) { + ESP_LOGW(TAG, "Payload size of %d bytes is not 2 times the number of registers (%d). Sending exception response.", + payload_size, number_of_registers); + send_error(function_code, 3); + return; + } + payload_offset = 5; + } else if (function_code == 0x06) { + number_of_registers = 1; + payload_offset = 2; + } else { + ESP_LOGW(TAG, "Invalid function code 0x%X. Sending exception response.", function_code); + send_error(function_code, 1); + return; + } + + uint16_t start_address = uint16_t(data[1]) | (uint16_t(data[0]) << 8); + ESP_LOGD(TAG, + "Received write holding registers for device 0x%X. FC: 0x%X. Start address: 0x%X. Number of registers: " + "0x%X.", + this->address_, function_code, start_address, number_of_registers); + + auto for_each_register = [this, start_address, number_of_registers, payload_offset]( + const std::function &callback) -> bool { + uint16_t offset = payload_offset; + for (uint16_t current_address = start_address; current_address < start_address + number_of_registers;) { + bool ok = false; + for (auto *server_register : this->server_registers_) { + if (server_register->address == current_address) { + ok = callback(server_register, offset); + current_address += server_register->register_count; + offset += server_register->register_count * sizeof(uint16_t); + break; + } + } + + if (!ok) { + return false; + } + } + return true; + }; + + // check all registers are writable before writing to any of them: + if (!for_each_register([](ServerRegister *server_register, uint16_t offset) -> bool { + return server_register->write_lambda != nullptr; + })) { + send_error(function_code, 1); + return; + } + + // Actually write to the registers: + if (!for_each_register([&data](ServerRegister *server_register, uint16_t offset) { + int64_t number = payload_to_number(data, server_register->value_type, offset, 0xFFFFFFFF); + return server_register->write_lambda(number); + })) { + send_error(function_code, 4); + return; + } + + std::vector response; + response.reserve(6); + response.push_back(this->address_); + response.push_back(function_code); + response.insert(response.end(), data.begin(), data.begin() + 4); + this->send_raw(response); +} + SensorSet ModbusController::find_sensors_(ModbusRegisterType register_type, uint16_t start_address) const { auto reg_it = std::find_if( std::begin(this->register_ranges_), std::end(this->register_ranges_), diff --git a/esphome/components/modbus_controller/modbus_controller.h b/esphome/components/modbus_controller/modbus_controller.h index 11d27c4025..a86ad1ccb5 100644 --- a/esphome/components/modbus_controller/modbus_controller.h +++ b/esphome/components/modbus_controller/modbus_controller.h @@ -258,6 +258,7 @@ class SensorItem { class ServerRegister { using ReadLambda = std::function; + using WriteLambda = std::function; public: ServerRegister(uint16_t address, SensorValueType value_type, uint8_t register_count) { @@ -277,6 +278,17 @@ class ServerRegister { }; } + template + void set_write_lambda(const std::function &&user_write_lambda) { + this->write_lambda = [this, user_write_lambda](int64_t number) { + if constexpr (std::is_same_v) { + float float_value = bit_cast(static_cast(number)); + return user_write_lambda(this->address, float_value); + } + return user_write_lambda(this->address, static_cast(number)); + }; + } + // Formats a raw value into a string representation based on the value type for debugging std::string format_value(int64_t value) const { switch (this->value_type) { @@ -304,6 +316,7 @@ class ServerRegister { SensorValueType value_type{SensorValueType::RAW}; uint8_t register_count{0}; ReadLambda read_lambda; + WriteLambda write_lambda; }; // ModbusController::create_register_ranges_ tries to optimize register range @@ -485,6 +498,8 @@ class ModbusController : public PollingComponent, public modbus::ModbusDevice { void on_modbus_error(uint8_t function_code, uint8_t exception_code) override; /// called when a modbus request (function code 0x03 or 0x04) was parsed without errors void on_modbus_read_registers(uint8_t function_code, uint16_t start_address, uint16_t number_of_registers) final; + /// called when a modbus request (function code 0x06 or 0x10) was parsed without errors + void on_modbus_write_registers(uint8_t function_code, const std::vector &data) final; /// default delegate called by process_modbus_data when a response has retrieved from the incoming queue void on_register_data(ModbusRegisterType register_type, uint16_t start_address, const std::vector &data); /// default delegate called by process_modbus_data when a response for a write response has retrieved from the diff --git a/tests/components/modbus_controller/common.yaml b/tests/components/modbus_controller/common.yaml index 7fa9f8dae3..7d342ee353 100644 --- a/tests/components/modbus_controller/common.yaml +++ b/tests/components/modbus_controller/common.yaml @@ -33,7 +33,18 @@ modbus_controller: read_lambda: |- return 42.3; max_cmd_retries: 0 - + - id: modbus_controller3 + address: 0x3 + modbus_id: mod_bus2 + server_registers: + - address: 0x0009 + value_type: S_DWORD + read_lambda: |- + return 31; + write_lambda: |- + printf("address=%d, value=%d", x); + return true; + max_cmd_retries: 0 binary_sensor: - platform: modbus_controller modbus_controller_id: modbus_controller1 From b000b1b70cc0c2981a8c64020720857f0df6efaa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 1 Jul 2025 09:43:50 -0500 Subject: [PATCH 7/8] Fix regression: BK7231N devices not returning entities via API --- esphome/components/api/api_connection.cpp | 47 ++++----- esphome/components/api/api_connection.h | 119 ++++++++++------------ 2 files changed, 72 insertions(+), 94 deletions(-) diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index b7624221c9..e83d508c50 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -1687,7 +1687,9 @@ void APIConnection::DeferredBatch::add_item(EntityBase *entity, MessageCreator c // O(n) but optimized for RAM and not performance. for (auto &item : items) { if (item.entity == entity && item.message_type == message_type) { - // Update the existing item with the new creator + // Clean up old creator before replacing + item.creator.cleanup(message_type); + // Move assign the new creator item.creator = std::move(creator); return; } @@ -1730,11 +1732,11 @@ void APIConnection::process_batch_() { return; } - size_t num_items = this->deferred_batch_.items.size(); + size_t num_items = this->deferred_batch_.size(); // Fast path for single message - allocate exact size needed if (num_items == 1) { - const auto &item = this->deferred_batch_.items[0]; + const auto &item = this->deferred_batch_[0]; // Let the creator calculate size and encode if it fits uint16_t payload_size = @@ -1764,7 +1766,8 @@ void APIConnection::process_batch_() { // Pre-calculate exact buffer size needed based on message types uint32_t total_estimated_size = 0; - for (const auto &item : this->deferred_batch_.items) { + for (size_t i = 0; i < this->deferred_batch_.size(); i++) { + const auto &item = this->deferred_batch_[i]; total_estimated_size += get_estimated_message_size(item.message_type); } @@ -1785,7 +1788,8 @@ void APIConnection::process_batch_() { uint32_t current_offset = 0; // Process items and encode directly to buffer - for (const auto &item : this->deferred_batch_.items) { + for (size_t i = 0; i < this->deferred_batch_.size(); i++) { + const auto &item = this->deferred_batch_[i]; // Try to encode message // The creator will calculate overhead to determine if the message fits uint16_t payload_size = item.creator(item.entity, this, remaining_size, false, item.message_type); @@ -1840,17 +1844,15 @@ void APIConnection::process_batch_() { // Log messages after send attempt for VV debugging // It's safe to use the buffer for logging at this point regardless of send result for (size_t i = 0; i < items_processed; i++) { - const auto &item = this->deferred_batch_.items[i]; + const auto &item = this->deferred_batch_[i]; this->log_batch_item_(item); } #endif // Handle remaining items more efficiently - if (items_processed < this->deferred_batch_.items.size()) { - // Remove processed items from the beginning - this->deferred_batch_.items.erase(this->deferred_batch_.items.begin(), - this->deferred_batch_.items.begin() + items_processed); - + if (items_processed < this->deferred_batch_.size()) { + // Remove processed items from the beginning with proper cleanup + this->deferred_batch_.remove_front(items_processed); // Reschedule for remaining items this->schedule_batch_(); } else { @@ -1861,23 +1863,16 @@ void APIConnection::process_batch_() { uint16_t APIConnection::MessageCreator::operator()(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, bool is_single, uint16_t message_type) const { - if (has_tagged_string_ptr_()) { - // Handle string-based messages - switch (message_type) { #ifdef USE_EVENT - case EventResponse::MESSAGE_TYPE: { - auto *e = static_cast(entity); - return APIConnection::try_send_event_response(e, *get_string_ptr_(), conn, remaining_size, is_single); - } -#endif - default: - // Should not happen, return 0 to indicate no message - return 0; - } - } else { - // Function pointer case - return data_.ptr(entity, conn, remaining_size, is_single); + // Special case: EventResponse uses string pointer + if (message_type == EventResponse::MESSAGE_TYPE) { + auto *e = static_cast(entity); + return APIConnection::try_send_event_response(e, *data_.string_ptr, conn, remaining_size, is_single); } +#endif + + // All other message types use function pointers + return data_.function_ptr(entity, conn, remaining_size, is_single); } uint16_t APIConnection::try_send_list_info_done(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 410a9ad3a5..642c11bc9f 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -451,96 +451,53 @@ class APIConnection : public APIServerConnection { // Function pointer type for message encoding using MessageCreatorPtr = uint16_t (*)(EntityBase *, APIConnection *, uint32_t remaining_size, bool is_single); - // Optimized MessageCreator class using tagged pointer class MessageCreator { - // Ensure pointer alignment allows LSB tagging - static_assert(alignof(std::string *) > 1, "String pointer alignment must be > 1 for LSB tagging"); - public: // Constructor for function pointer - MessageCreator(MessageCreatorPtr ptr) { - // Function pointers are always aligned, so LSB is 0 - data_.ptr = ptr; - } + MessageCreator(MessageCreatorPtr ptr) { data_.function_ptr = ptr; } // Constructor for string state capture - explicit MessageCreator(const std::string &str_value) { - // Allocate string and tag the pointer - auto *str = new std::string(str_value); - // Set LSB to 1 to indicate string pointer - data_.tagged = reinterpret_cast(str) | 1; - } + explicit MessageCreator(const std::string &str_value) { data_.string_ptr = new std::string(str_value); } - // Destructor - ~MessageCreator() { - if (has_tagged_string_ptr_()) { - delete get_string_ptr_(); - } - } + // No destructor - cleanup must be called explicitly with message_type - // Copy constructor - MessageCreator(const MessageCreator &other) { - if (other.has_tagged_string_ptr_()) { - auto *str = new std::string(*other.get_string_ptr_()); - data_.tagged = reinterpret_cast(str) | 1; - } else { - data_ = other.data_; - } - } + // Delete copy operations - MessageCreator should only be moved + MessageCreator(const MessageCreator &other) = delete; + MessageCreator &operator=(const MessageCreator &other) = delete; // Move constructor - MessageCreator(MessageCreator &&other) noexcept : data_(other.data_) { other.data_.ptr = nullptr; } - - // Assignment operators (needed for batch deduplication) - MessageCreator &operator=(const MessageCreator &other) { - if (this != &other) { - // Clean up current string data if needed - if (has_tagged_string_ptr_()) { - delete get_string_ptr_(); - } - // Copy new data - if (other.has_tagged_string_ptr_()) { - auto *str = new std::string(*other.get_string_ptr_()); - data_.tagged = reinterpret_cast(str) | 1; - } else { - data_ = other.data_; - } - } - return *this; - } + MessageCreator(MessageCreator &&other) noexcept : data_(other.data_) { other.data_.function_ptr = nullptr; } + // Move assignment MessageCreator &operator=(MessageCreator &&other) noexcept { if (this != &other) { - // Clean up current string data if needed - if (has_tagged_string_ptr_()) { - delete get_string_ptr_(); - } - // Move data + // IMPORTANT: Caller must ensure cleanup() was called if this contains a string! + // In our usage, this happens in add_item() deduplication and vector::erase() data_ = other.data_; - // Reset other to safe state - other.data_.ptr = nullptr; + other.data_.function_ptr = nullptr; } return *this; } - // Call operator - now accepts message_type as parameter + // Call operator - uses message_type to determine union type uint16_t operator()(EntityBase *entity, APIConnection *conn, uint32_t remaining_size, bool is_single, uint16_t message_type) const; - private: - // Check if this contains a string pointer - bool has_tagged_string_ptr_() const { return (data_.tagged & 1) != 0; } - - // Get the actual string pointer (clears the tag bit) - std::string *get_string_ptr_() const { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - return reinterpret_cast(data_.tagged & ~uintptr_t(1)); + // Manual cleanup method - must be called before destruction for string types + void cleanup(uint16_t message_type) { +#ifdef USE_EVENT + if (message_type == EventResponse::MESSAGE_TYPE && data_.string_ptr != nullptr) { + delete data_.string_ptr; + data_.string_ptr = nullptr; + } +#endif } - union { - MessageCreatorPtr ptr; - uintptr_t tagged; - } data_; // 4 bytes on 32-bit + private: + union Data { + MessageCreatorPtr function_ptr; + std::string *string_ptr; + } data_; // 4 bytes on 32-bit, 8 bytes on 64-bit - same as before }; // Generic batching mechanism for both state updates and entity info @@ -558,20 +515,46 @@ class APIConnection : public APIServerConnection { std::vector items; uint32_t batch_start_time{0}; + private: + // Helper to cleanup items from the beginning + void cleanup_items(size_t count) { + for (size_t i = 0; i < count; i++) { + items[i].creator.cleanup(items[i].message_type); + } + } + + public: DeferredBatch() { // Pre-allocate capacity for typical batch sizes to avoid reallocation items.reserve(8); } + ~DeferredBatch() { + // Ensure cleanup of any remaining items + clear(); + } + // Add item to the batch void add_item(EntityBase *entity, MessageCreator creator, uint16_t message_type); // Add item to the front of the batch (for high priority messages like ping) void add_item_front(EntityBase *entity, MessageCreator creator, uint16_t message_type); + + // Clear all items with proper cleanup void clear() { + cleanup_items(items.size()); items.clear(); batch_start_time = 0; } + + // Remove processed items from the front with proper cleanup + void remove_front(size_t count) { + cleanup_items(count); + items.erase(items.begin(), items.begin() + count); + } + bool empty() const { return items.empty(); } + size_t size() const { return items.size(); } + const BatchItem &operator[](size_t index) const { return items[index]; } }; // DeferredBatch here (16 bytes, 4-byte aligned) From c33c14a46f62294a5f5306662b9a09967d00efcd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 1 Jul 2025 11:57:02 -0500 Subject: [PATCH 8/8] tidy --- esphome/components/api/api_connection.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index 642c11bc9f..151369aa70 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -517,7 +517,7 @@ class APIConnection : public APIServerConnection { private: // Helper to cleanup items from the beginning - void cleanup_items(size_t count) { + void cleanup_items_(size_t count) { for (size_t i = 0; i < count; i++) { items[i].creator.cleanup(items[i].message_type); } @@ -541,14 +541,14 @@ class APIConnection : public APIServerConnection { // Clear all items with proper cleanup void clear() { - cleanup_items(items.size()); + cleanup_items_(items.size()); items.clear(); batch_start_time = 0; } // Remove processed items from the front with proper cleanup void remove_front(size_t count) { - cleanup_items(count); + cleanup_items_(count); items.erase(items.begin(), items.begin() + count); }