diff --git a/esphome/components/captive_portal/captive_portal.cpp b/esphome/components/captive_portal/captive_portal.cpp index 51e5cfc8ff..ba392bb0f2 100644 --- a/esphome/components/captive_portal/captive_portal.cpp +++ b/esphome/components/captive_portal/captive_portal.cpp @@ -47,7 +47,9 @@ void CaptivePortal::start() { this->base_->init(); if (!this->initialized_) { this->base_->add_handler(this); +#ifdef USE_WEBSERVER_OTA this->base_->add_ota_handler(); +#endif } #ifdef USE_ARDUINO diff --git a/esphome/components/web_server/__init__.py b/esphome/components/web_server/__init__.py index f2c1824028..ca145c732b 100644 --- a/esphome/components/web_server/__init__.py +++ b/esphome/components/web_server/__init__.py @@ -40,6 +40,7 @@ CONF_SORTING_GROUP_ID = "sorting_group_id" CONF_SORTING_GROUPS = "sorting_groups" CONF_SORTING_WEIGHT = "sorting_weight" + web_server_ns = cg.esphome_ns.namespace("web_server") WebServer = web_server_ns.class_("WebServer", cg.Component, cg.Controller) @@ -72,12 +73,6 @@ def validate_local(config): return config -def validate_ota(config): - if CORE.using_esp_idf and config[CONF_OTA]: - raise cv.Invalid("Enabling 'ota' is not supported for IDF framework yet") - return config - - def validate_sorting_groups(config): if CONF_SORTING_GROUPS in config and config[CONF_VERSION] != 3: raise cv.Invalid( @@ -175,15 +170,7 @@ CONFIG_SCHEMA = cv.All( web_server_base.WebServerBase ), cv.Optional(CONF_INCLUDE_INTERNAL, default=False): cv.boolean, - cv.SplitDefault( - CONF_OTA, - esp8266=True, - esp32_arduino=True, - esp32_idf=False, - bk72xx=True, - ln882x=True, - rtl87xx=True, - ): cv.boolean, + cv.Optional(CONF_OTA, default=True): cv.boolean, cv.Optional(CONF_LOG, default=True): cv.boolean, cv.Optional(CONF_LOCAL): cv.boolean, cv.Optional(CONF_SORTING_GROUPS): cv.ensure_list(sorting_group), @@ -200,7 +187,6 @@ CONFIG_SCHEMA = cv.All( ), default_url, validate_local, - validate_ota, validate_sorting_groups, ) @@ -286,6 +272,10 @@ async def to_code(config): cg.add(var.set_css_url(config[CONF_CSS_URL])) cg.add(var.set_js_url(config[CONF_JS_URL])) cg.add(var.set_allow_ota(config[CONF_OTA])) + if config[CONF_OTA]: + # Define USE_WEBSERVER_OTA based only on web_server OTA config + # This allows web server OTA to work without loading the OTA component + cg.add_define("USE_WEBSERVER_OTA") cg.add(var.set_expose_log(config[CONF_LOG])) if config[CONF_ENABLE_PRIVATE_NETWORK_ACCESS]: cg.add_define("USE_WEBSERVER_PRIVATE_NETWORK_ACCESS") diff --git a/esphome/components/web_server/web_server.cpp b/esphome/components/web_server/web_server.cpp index 669bfbf279..e0027d0b27 100644 --- a/esphome/components/web_server/web_server.cpp +++ b/esphome/components/web_server/web_server.cpp @@ -299,8 +299,10 @@ void WebServer::setup() { #endif this->base_->add_handler(this); +#ifdef USE_WEBSERVER_OTA if (this->allow_ota_) this->base_->add_ota_handler(); +#endif // doesn't need defer functionality - if the queue is full, the client JS knows it's alive because it's clearly // getting a lot of events @@ -2030,6 +2032,10 @@ void WebServer::handleRequest(AsyncWebServerRequest *request) { return; } #endif + + // No matching handler found - send 404 + ESP_LOGV(TAG, "Request for unknown URL: %s", request->url().c_str()); + request->send(404, "text/plain", "Not Found"); } bool WebServer::isRequestHandlerTrivial() const { return false; } diff --git a/esphome/components/web_server_base/web_server_base.cpp b/esphome/components/web_server_base/web_server_base.cpp index 2835585387..9ad88e09f4 100644 --- a/esphome/components/web_server_base/web_server_base.cpp +++ b/esphome/components/web_server_base/web_server_base.cpp @@ -14,11 +14,114 @@ #endif #endif +#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +#include +#include +#endif + namespace esphome { namespace web_server_base { static const char *const TAG = "web_server_base"; +#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +// Minimal OTA backend implementation for web server +// This allows OTA updates via web server without requiring the OTA component +// TODO: In the future, this should be refactored into a common ota_base component +// that both web_server and ota components can depend on, avoiding code duplication +// while keeping the components independent. This would allow both ESP-IDF and Arduino +// implementations to share the base OTA functionality without requiring the full OTA component. +// The IDFWebServerOTABackend class is intentionally designed with the same interface +// as OTABackend to make it easy to swap to using OTABackend when the ota component +// is split into ota and ota_base in the future. +class IDFWebServerOTABackend { + public: + bool begin() { + this->partition_ = esp_ota_get_next_update_partition(nullptr); + if (this->partition_ == nullptr) { + ESP_LOGE(TAG, "No OTA partition available"); + return false; + } + +#if CONFIG_ESP_TASK_WDT_TIMEOUT_S < 15 + // The following function takes longer than the default timeout of WDT due to flash erase +#if ESP_IDF_VERSION_MAJOR >= 5 + esp_task_wdt_config_t wdtc; + wdtc.idle_core_mask = 0; +#if CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU0 + wdtc.idle_core_mask |= (1 << 0); +#endif +#if CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1 + wdtc.idle_core_mask |= (1 << 1); +#endif + wdtc.timeout_ms = 15000; + wdtc.trigger_panic = false; + esp_task_wdt_reconfigure(&wdtc); +#else + esp_task_wdt_init(15, false); +#endif +#endif + + esp_err_t err = esp_ota_begin(this->partition_, 0, &this->update_handle_); + +#if CONFIG_ESP_TASK_WDT_TIMEOUT_S < 15 + // Set the WDT back to the configured timeout +#if ESP_IDF_VERSION_MAJOR >= 5 + wdtc.timeout_ms = CONFIG_ESP_TASK_WDT_TIMEOUT_S * 1000; + esp_task_wdt_reconfigure(&wdtc); +#else + esp_task_wdt_init(CONFIG_ESP_TASK_WDT_TIMEOUT_S, false); +#endif +#endif + + if (err != ESP_OK) { + esp_ota_abort(this->update_handle_); + this->update_handle_ = 0; + ESP_LOGE(TAG, "esp_ota_begin failed: %s", esp_err_to_name(err)); + return false; + } + return true; + } + + bool write(uint8_t *data, size_t len) { + esp_err_t err = esp_ota_write(this->update_handle_, data, len); + if (err != ESP_OK) { + ESP_LOGE(TAG, "esp_ota_write failed: %s", esp_err_to_name(err)); + return false; + } + return true; + } + + bool end() { + esp_err_t err = esp_ota_end(this->update_handle_); + this->update_handle_ = 0; + if (err != ESP_OK) { + ESP_LOGE(TAG, "esp_ota_end failed: %s", esp_err_to_name(err)); + return false; + } + + err = esp_ota_set_boot_partition(this->partition_); + if (err != ESP_OK) { + ESP_LOGE(TAG, "esp_ota_set_boot_partition failed: %s", esp_err_to_name(err)); + return false; + } + + return true; + } + + void abort() { + if (this->update_handle_ != 0) { + esp_ota_abort(this->update_handle_); + this->update_handle_ = 0; + } + } + + private: + esp_ota_handle_t update_handle_{0}; + const esp_partition_t *partition_{nullptr}; +}; +#endif + void WebServerBase::add_handler(AsyncWebHandler *handler) { // remove all handlers @@ -31,6 +134,33 @@ void WebServerBase::add_handler(AsyncWebHandler *handler) { } } +#ifdef USE_WEBSERVER_OTA +void OTARequestHandler::report_ota_progress_(AsyncWebServerRequest *request) { + const uint32_t now = millis(); + if (now - this->last_ota_progress_ > 1000) { + if (request->contentLength() != 0) { + float percentage = (this->ota_read_length_ * 100.0f) / request->contentLength(); + ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage); + } else { + ESP_LOGD(TAG, "OTA in progress: %u bytes read", this->ota_read_length_); + } + this->last_ota_progress_ = now; + } +} + +void OTARequestHandler::schedule_ota_reboot_() { + ESP_LOGI(TAG, "OTA update successful!"); + this->parent_->set_timeout(100, []() { + ESP_LOGI(TAG, "Performing OTA reboot now"); + App.safe_reboot(); + }); +} + +void OTARequestHandler::ota_init_(const char *filename) { + ESP_LOGI(TAG, "OTA Update Start: %s", filename); + this->ota_read_length_ = 0; +} + void report_ota_error() { #ifdef USE_ARDUINO StreamString ss; @@ -44,8 +174,7 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin #ifdef USE_ARDUINO bool success; if (index == 0) { - ESP_LOGI(TAG, "OTA Update Start: %s", filename.c_str()); - this->ota_read_length_ = 0; + this->ota_init_(filename.c_str()); #ifdef USE_ESP8266 Update.runAsync(true); // NOLINTNEXTLINE(readability-static-accessed-through-instance) @@ -72,31 +201,68 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin return; } this->ota_read_length_ += len; - - const uint32_t now = millis(); - if (now - this->last_ota_progress_ > 1000) { - if (request->contentLength() != 0) { - float percentage = (this->ota_read_length_ * 100.0f) / request->contentLength(); - ESP_LOGD(TAG, "OTA in progress: %0.1f%%", percentage); - } else { - ESP_LOGD(TAG, "OTA in progress: %u bytes read", this->ota_read_length_); - } - this->last_ota_progress_ = now; - } + this->report_ota_progress_(request); if (final) { if (Update.end(true)) { - ESP_LOGI(TAG, "OTA update successful!"); - this->parent_->set_timeout(100, []() { App.safe_reboot(); }); + this->schedule_ota_reboot_(); } else { report_ota_error(); } } -#endif +#endif // USE_ARDUINO + +#ifdef USE_ESP_IDF + // ESP-IDF implementation + if (index == 0 && !this->ota_backend_) { + // Initialize OTA on first call + this->ota_init_(filename.c_str()); + this->ota_success_ = false; + + auto *backend = new IDFWebServerOTABackend(); + if (!backend->begin()) { + ESP_LOGE(TAG, "OTA begin failed"); + delete backend; + return; + } + this->ota_backend_ = backend; + } + + auto *backend = static_cast(this->ota_backend_); + if (!backend) { + return; + } + + // Process data + if (len > 0) { + if (!backend->write(data, len)) { + ESP_LOGE(TAG, "OTA write failed"); + backend->abort(); + delete backend; + this->ota_backend_ = nullptr; + return; + } + this->ota_read_length_ += len; + this->report_ota_progress_(request); + } + + // Finalize + if (final) { + this->ota_success_ = backend->end(); + if (this->ota_success_) { + this->schedule_ota_reboot_(); + } else { + ESP_LOGE(TAG, "OTA end failed"); + } + delete backend; + this->ota_backend_ = nullptr; + } +#endif // USE_ESP_IDF } + void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) { -#ifdef USE_ARDUINO AsyncWebServerResponse *response; +#ifdef USE_ARDUINO if (!Update.hasError()) { response = request->beginResponse(200, "text/plain", "Update Successful!"); } else { @@ -105,16 +271,20 @@ void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) { Update.printError(ss); response = request->beginResponse(200, "text/plain", ss); } +#endif // USE_ARDUINO +#ifdef USE_ESP_IDF + // Send response based on the OTA result + response = request->beginResponse(200, "text/plain", this->ota_success_ ? "Update Successful!" : "Update Failed!"); +#endif // USE_ESP_IDF response->addHeader("Connection", "close"); request->send(response); -#endif } void WebServerBase::add_ota_handler() { -#ifdef USE_ARDUINO this->add_handler(new OTARequestHandler(this)); // NOLINT -#endif } +#endif + float WebServerBase::get_setup_priority() const { // Before WiFi (captive portal) return setup_priority::WIFI + 2.0f; diff --git a/esphome/components/web_server_base/web_server_base.h b/esphome/components/web_server_base/web_server_base.h index 641006cb99..09a41956c9 100644 --- a/esphome/components/web_server_base/web_server_base.h +++ b/esphome/components/web_server_base/web_server_base.h @@ -110,13 +110,17 @@ class WebServerBase : public Component { void add_handler(AsyncWebHandler *handler); +#ifdef USE_WEBSERVER_OTA void add_ota_handler(); +#endif void set_port(uint16_t port) { port_ = port; } uint16_t get_port() const { return port_; } protected: +#ifdef USE_WEBSERVER_OTA friend class OTARequestHandler; +#endif int initialized_{0}; uint16_t port_{80}; @@ -125,6 +129,7 @@ class WebServerBase : public Component { internal::Credentials credentials_; }; +#ifdef USE_WEBSERVER_OTA class OTARequestHandler : public AsyncWebHandler { public: OTARequestHandler(WebServerBase *parent) : parent_(parent) {} @@ -139,10 +144,21 @@ class OTARequestHandler : public AsyncWebHandler { bool isRequestHandlerTrivial() const override { return false; } protected: + void report_ota_progress_(AsyncWebServerRequest *request); + void schedule_ota_reboot_(); + void ota_init_(const char *filename); + uint32_t last_ota_progress_{0}; uint32_t ota_read_length_{0}; WebServerBase *parent_; + + private: +#ifdef USE_ESP_IDF + void *ota_backend_{nullptr}; + bool ota_success_{false}; +#endif }; +#endif // USE_WEBSERVER_OTA } // namespace web_server_base } // namespace esphome diff --git a/esphome/components/web_server_idf/__init__.py b/esphome/components/web_server_idf/__init__.py index 506e1c5c13..fe1c6f2640 100644 --- a/esphome/components/web_server_idf/__init__.py +++ b/esphome/components/web_server_idf/__init__.py @@ -1,5 +1,7 @@ -from esphome.components.esp32 import add_idf_sdkconfig_option +from esphome.components.esp32 import add_idf_component, add_idf_sdkconfig_option import esphome.config_validation as cv +from esphome.const import CONF_OTA, CONF_WEB_SERVER +from esphome.core import CORE CODEOWNERS = ["@dentra"] @@ -12,3 +14,7 @@ CONFIG_SCHEMA = cv.All( async def to_code(config): # Increase the maximum supported size of headers section in HTTP request packet to be processed by the server add_idf_sdkconfig_option("CONFIG_HTTPD_MAX_REQ_HDR_LEN", 1024) + # Check if web_server component has OTA enabled + if CORE.config.get(CONF_WEB_SERVER, {}).get(CONF_OTA, True): + # Add multipart parser component for ESP-IDF OTA support + add_idf_component(name="zorxx/multipart-parser", ref="1.0.1") diff --git a/esphome/components/web_server_idf/multipart.cpp b/esphome/components/web_server_idf/multipart.cpp new file mode 100644 index 0000000000..8655226ab9 --- /dev/null +++ b/esphome/components/web_server_idf/multipart.cpp @@ -0,0 +1,254 @@ +#include "esphome/core/defines.h" +#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) +#include "multipart.h" +#include "utils.h" +#include "esphome/core/log.h" +#include +#include "multipart_parser.h" + +namespace esphome { +namespace web_server_idf { + +static const char *const TAG = "multipart"; + +// ========== MultipartReader Implementation ========== + +MultipartReader::MultipartReader(const std::string &boundary) { + // Initialize settings with callbacks + memset(&settings_, 0, sizeof(settings_)); + settings_.on_header_field = on_header_field; + settings_.on_header_value = on_header_value; + settings_.on_part_data = on_part_data; + settings_.on_part_data_end = on_part_data_end; + + ESP_LOGV(TAG, "Initializing multipart parser with boundary: '%s' (len: %zu)", boundary.c_str(), boundary.length()); + + // Create parser with boundary + parser_ = multipart_parser_init(boundary.c_str(), &settings_); + if (parser_) { + multipart_parser_set_data(parser_, this); + } else { + ESP_LOGE(TAG, "Failed to initialize multipart parser"); + } +} + +MultipartReader::~MultipartReader() { + if (parser_) { + multipart_parser_free(parser_); + } +} + +size_t MultipartReader::parse(const char *data, size_t len) { + if (!parser_) { + ESP_LOGE(TAG, "Parser not initialized"); + return 0; + } + + size_t parsed = multipart_parser_execute(parser_, data, len); + + if (parsed != len) { + ESP_LOGW(TAG, "Parser consumed %zu of %zu bytes - possible error", parsed, len); + } + + return parsed; +} + +void MultipartReader::process_header_(const char *value, size_t length) { + // Process the completed header (field + value pair) + std::string value_str(value, length); + + if (str_startswith_case_insensitive(current_header_field_, "content-disposition")) { + // Parse name and filename from Content-Disposition + current_part_.name = extract_header_param(value_str, "name"); + current_part_.filename = extract_header_param(value_str, "filename"); + } else if (str_startswith_case_insensitive(current_header_field_, "content-type")) { + current_part_.content_type = str_trim(value_str); + } + + // Clear field for next header + current_header_field_.clear(); +} + +int MultipartReader::on_header_field(multipart_parser *parser, const char *at, size_t length) { + MultipartReader *reader = static_cast(multipart_parser_get_data(parser)); + reader->current_header_field_.assign(at, length); + return 0; +} + +int MultipartReader::on_header_value(multipart_parser *parser, const char *at, size_t length) { + MultipartReader *reader = static_cast(multipart_parser_get_data(parser)); + reader->process_header_(at, length); + return 0; +} + +int MultipartReader::on_part_data(multipart_parser *parser, const char *at, size_t length) { + MultipartReader *reader = static_cast(multipart_parser_get_data(parser)); + // Only process file uploads + if (reader->has_file() && reader->data_callback_) { + // IMPORTANT: The 'at' pointer points to data within the parser's input buffer. + // This data is only valid during this callback. The callback handler MUST + // process or copy the data immediately - it cannot store the pointer for + // later use as the buffer will be overwritten. + reader->data_callback_(reinterpret_cast(at), length); + } + return 0; +} + +int MultipartReader::on_part_data_end(multipart_parser *parser) { + MultipartReader *reader = static_cast(multipart_parser_get_data(parser)); + ESP_LOGV(TAG, "Part data end"); + if (reader->part_complete_callback_) { + reader->part_complete_callback_(); + } + // Clear part info for next part + reader->current_part_ = Part{}; + return 0; +} + +// ========== Utility Functions ========== + +// Case-insensitive string prefix check +bool str_startswith_case_insensitive(const std::string &str, const std::string &prefix) { + if (str.length() < prefix.length()) { + return false; + } + return str_ncmp_ci(str.c_str(), prefix.c_str(), prefix.length()); +} + +// Extract a parameter value from a header line +// Handles both quoted and unquoted values +std::string extract_header_param(const std::string &header, const std::string ¶m) { + size_t search_pos = 0; + + while (search_pos < header.length()) { + // Look for param name + const char *found = stristr(header.c_str() + search_pos, param.c_str()); + if (!found) { + return ""; + } + size_t pos = found - header.c_str(); + + // Check if this is a word boundary (not part of another parameter) + if (pos > 0 && header[pos - 1] != ' ' && header[pos - 1] != ';' && header[pos - 1] != '\t') { + search_pos = pos + 1; + continue; + } + + // Move past param name + pos += param.length(); + + // Skip whitespace and find '=' + while (pos < header.length() && (header[pos] == ' ' || header[pos] == '\t')) { + pos++; + } + + if (pos >= header.length() || header[pos] != '=') { + search_pos = pos; + continue; + } + + pos++; // Skip '=' + + // Skip whitespace after '=' + while (pos < header.length() && (header[pos] == ' ' || header[pos] == '\t')) { + pos++; + } + + if (pos >= header.length()) { + return ""; + } + + // Check if value is quoted + if (header[pos] == '"') { + pos++; + size_t end = header.find('"', pos); + if (end != std::string::npos) { + return header.substr(pos, end - pos); + } + // Malformed - no closing quote + return ""; + } + + // Unquoted value - find the end (semicolon, comma, or end of string) + size_t end = pos; + while (end < header.length() && header[end] != ';' && header[end] != ',' && header[end] != ' ' && + header[end] != '\t') { + end++; + } + + return header.substr(pos, end - pos); + } + + return ""; +} + +// Parse boundary from Content-Type header +// Returns true if boundary found, false otherwise +// boundary_start and boundary_len will point to the boundary value +bool parse_multipart_boundary(const char *content_type, const char **boundary_start, size_t *boundary_len) { + if (!content_type) { + return false; + } + + // Check for multipart/form-data (case-insensitive) + if (!stristr(content_type, "multipart/form-data")) { + return false; + } + + // Look for boundary parameter + const char *b = stristr(content_type, "boundary="); + if (!b) { + return false; + } + + const char *start = b + 9; // Skip "boundary=" + + // Skip whitespace + while (*start == ' ' || *start == '\t') { + start++; + } + + if (!*start) { + return false; + } + + // Find end of boundary + const char *end = start; + if (*end == '"') { + // Quoted boundary + start++; + end++; + while (*end && *end != '"') { + end++; + } + *boundary_len = end - start; + } else { + // Unquoted boundary + while (*end && *end != ' ' && *end != ';' && *end != '\r' && *end != '\n' && *end != '\t') { + end++; + } + *boundary_len = end - start; + } + + if (*boundary_len == 0) { + return false; + } + + *boundary_start = start; + + return true; +} + +// Trim whitespace from both ends of a string +std::string str_trim(const std::string &str) { + size_t start = str.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) { + return ""; + } + size_t end = str.find_last_not_of(" \t\r\n"); + return str.substr(start, end - start + 1); +} + +} // namespace web_server_idf +} // namespace esphome +#endif // defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) diff --git a/esphome/components/web_server_idf/multipart.h b/esphome/components/web_server_idf/multipart.h new file mode 100644 index 0000000000..967c72ffa5 --- /dev/null +++ b/esphome/components/web_server_idf/multipart.h @@ -0,0 +1,86 @@ +#pragma once +#include "esphome/core/defines.h" +#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) + +#include +#include +#include +#include +#include +#include +#include + +namespace esphome { +namespace web_server_idf { + +// Wrapper around zorxx/multipart-parser for ESP-IDF OTA uploads +class MultipartReader { + public: + struct Part { + std::string name; + std::string filename; + std::string content_type; + }; + + // IMPORTANT: The data pointer in DataCallback is only valid during the callback! + // The multipart parser passes pointers to its internal buffer which will be + // overwritten after the callback returns. Callbacks MUST process or copy the + // data immediately - storing the pointer for deferred processing will result + // in use-after-free bugs. + using DataCallback = std::function; + using PartCompleteCallback = std::function; + + explicit MultipartReader(const std::string &boundary); + ~MultipartReader(); + + // Set callbacks for handling data + void set_data_callback(DataCallback callback) { data_callback_ = std::move(callback); } + void set_part_complete_callback(PartCompleteCallback callback) { part_complete_callback_ = std::move(callback); } + + // Parse incoming data + size_t parse(const char *data, size_t len); + + // Get current part info + const Part &get_current_part() const { return current_part_; } + + // Check if we found a file upload + bool has_file() const { return !current_part_.filename.empty(); } + + private: + static int on_header_field(multipart_parser *parser, const char *at, size_t length); + static int on_header_value(multipart_parser *parser, const char *at, size_t length); + static int on_part_data(multipart_parser *parser, const char *at, size_t length); + static int on_part_data_end(multipart_parser *parser); + + multipart_parser *parser_{nullptr}; + multipart_parser_settings settings_{}; + + Part current_part_; + std::string current_header_field_; + + DataCallback data_callback_; + PartCompleteCallback part_complete_callback_; + + void process_header_(const char *value, size_t length); +}; + +// ========== Utility Functions ========== + +// Case-insensitive string prefix check +bool str_startswith_case_insensitive(const std::string &str, const std::string &prefix); + +// Extract a parameter value from a header line +// Handles both quoted and unquoted values +std::string extract_header_param(const std::string &header, const std::string ¶m); + +// Parse boundary from Content-Type header +// Returns true if boundary found, false otherwise +// boundary_start and boundary_len will point to the boundary value +bool parse_multipart_boundary(const char *content_type, const char **boundary_start, size_t *boundary_len); + +// Trim whitespace from both ends of a string +std::string str_trim(const std::string &str); + +} // namespace web_server_idf +} // namespace esphome +#endif // defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA) diff --git a/esphome/components/web_server_idf/utils.cpp b/esphome/components/web_server_idf/utils.cpp index 349acce50d..ac5df90bb8 100644 --- a/esphome/components/web_server_idf/utils.cpp +++ b/esphome/components/web_server_idf/utils.cpp @@ -1,5 +1,7 @@ #ifdef USE_ESP_IDF #include +#include +#include #include "esphome/core/helpers.h" #include "esphome/core/log.h" #include "http_parser.h" @@ -88,6 +90,36 @@ optional query_key_value(const std::string &query_url, const std::s return {val.get()}; } +// Helper function for case-insensitive string region comparison +bool str_ncmp_ci(const char *s1, const char *s2, size_t n) { + for (size_t i = 0; i < n; i++) { + if (!char_equals_ci(s1[i], s2[i])) { + return false; + } + } + return true; +} + +// Case-insensitive string search (like strstr but case-insensitive) +const char *stristr(const char *haystack, const char *needle) { + if (!haystack) { + return nullptr; + } + + size_t needle_len = strlen(needle); + if (needle_len == 0) { + return haystack; + } + + for (const char *p = haystack; *p; p++) { + if (str_ncmp_ci(p, needle, needle_len)) { + return p; + } + } + + return nullptr; +} + } // namespace web_server_idf } // namespace esphome #endif // USE_ESP_IDF diff --git a/esphome/components/web_server_idf/utils.h b/esphome/components/web_server_idf/utils.h index 9ed17c1d50..988b962d72 100644 --- a/esphome/components/web_server_idf/utils.h +++ b/esphome/components/web_server_idf/utils.h @@ -2,6 +2,7 @@ #ifdef USE_ESP_IDF #include +#include #include "esphome/core/helpers.h" namespace esphome { @@ -12,6 +13,15 @@ optional request_get_header(httpd_req_t *req, const char *name); optional request_get_url_query(httpd_req_t *req); optional query_key_value(const std::string &query_url, const std::string &key); +// Helper function for case-insensitive character comparison +inline bool char_equals_ci(char a, char b) { return ::tolower(a) == ::tolower(b); } + +// Helper function for case-insensitive string region comparison +bool str_ncmp_ci(const char *s1, const char *s2, size_t n); + +// Case-insensitive string search (like strstr but case-insensitive) +const char *stristr(const char *haystack, const char *needle); + } // namespace web_server_idf } // namespace esphome #endif // USE_ESP_IDF diff --git a/esphome/components/web_server_idf/web_server_idf.cpp b/esphome/components/web_server_idf/web_server_idf.cpp index 409230806c..9478e4748c 100644 --- a/esphome/components/web_server_idf/web_server_idf.cpp +++ b/esphome/components/web_server_idf/web_server_idf.cpp @@ -1,16 +1,25 @@ #ifdef USE_ESP_IDF #include +#include +#include +#include #include "esphome/core/helpers.h" #include "esphome/core/log.h" #include "esp_tls_crypto.h" +#include +#include #include "utils.h" - #include "web_server_idf.h" +#ifdef USE_WEBSERVER_OTA +#include +#include "multipart.h" // For parse_multipart_boundary and other utils +#endif + #ifdef USE_WEBSERVER #include "esphome/components/web_server/web_server.h" #include "esphome/components/web_server/list_entities.h" @@ -72,18 +81,32 @@ void AsyncWebServer::begin() { esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) { ESP_LOGVV(TAG, "Enter AsyncWebServer::request_post_handler. uri=%s", r->uri); auto content_type = request_get_header(r, "Content-Type"); - if (content_type.has_value() && *content_type != "application/x-www-form-urlencoded") { - ESP_LOGW(TAG, "Only application/x-www-form-urlencoded supported for POST request"); - // fallback to get handler to support backward compatibility - return AsyncWebServer::request_handler(r); - } if (!request_has_header(r, "Content-Length")) { - ESP_LOGW(TAG, "Content length is requred for post: %s", r->uri); + ESP_LOGW(TAG, "Content length is required for post: %s", r->uri); httpd_resp_send_err(r, HTTPD_411_LENGTH_REQUIRED, nullptr); return ESP_OK; } + if (content_type.has_value()) { + const char *content_type_char = content_type.value().c_str(); + + // Check most common case first + if (stristr(content_type_char, "application/x-www-form-urlencoded") != nullptr) { + // Normal form data - proceed with regular handling +#ifdef USE_WEBSERVER_OTA + } else if (stristr(content_type_char, "multipart/form-data") != nullptr) { + auto *server = static_cast(r->user_ctx); + return server->handle_multipart_upload_(r, content_type_char); +#endif + } else { + ESP_LOGW(TAG, "Unsupported content type for POST: %s", content_type_char); + // fallback to get handler to support backward compatibility + return AsyncWebServer::request_handler(r); + } + } + + // Handle regular form data if (r->content_len > HTTPD_MAX_REQ_HDR_LEN) { ESP_LOGW(TAG, "Request size is to big: %zu", r->content_len); httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr); @@ -539,6 +562,97 @@ void AsyncEventSourceResponse::deferrable_send_state(void *source, const char *e } #endif +#ifdef USE_WEBSERVER_OTA +esp_err_t AsyncWebServer::handle_multipart_upload_(httpd_req_t *r, const char *content_type) { + static constexpr size_t MULTIPART_CHUNK_SIZE = 1460; // Match Arduino AsyncWebServer buffer size + static constexpr size_t YIELD_INTERVAL_BYTES = 16 * 1024; // Yield every 16KB to prevent watchdog + + // Parse boundary and create reader + const char *boundary_start; + size_t boundary_len; + if (!parse_multipart_boundary(content_type, &boundary_start, &boundary_len)) { + ESP_LOGE(TAG, "Failed to parse multipart boundary"); + httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr); + return ESP_FAIL; + } + + AsyncWebServerRequest req(r); + AsyncWebHandler *handler = nullptr; + for (auto *h : this->handlers_) { + if (h->canHandle(&req)) { + handler = h; + break; + } + } + + if (!handler) { + ESP_LOGW(TAG, "No handler found for OTA request"); + httpd_resp_send_err(r, HTTPD_404_NOT_FOUND, nullptr); + return ESP_OK; + } + + // Upload state + std::string filename; + size_t index = 0; + // Create reader on heap to reduce stack usage + auto reader = std::make_unique("--" + std::string(boundary_start, boundary_len)); + + // Configure callbacks + reader->set_data_callback([&](const uint8_t *data, size_t len) { + if (!reader->has_file() || !len) + return; + + if (filename.empty()) { + filename = reader->get_current_part().filename; + ESP_LOGV(TAG, "Processing file: '%s'", filename.c_str()); + handler->handleUpload(&req, filename, 0, nullptr, 0, false); // Start + } + + handler->handleUpload(&req, filename, index, const_cast(data), len, false); + index += len; + }); + + reader->set_part_complete_callback([&]() { + if (index > 0) { + handler->handleUpload(&req, filename, index, nullptr, 0, true); // End + filename.clear(); + index = 0; + } + }); + + // Process data + std::unique_ptr buffer(new char[MULTIPART_CHUNK_SIZE]); + size_t bytes_since_yield = 0; + + for (size_t remaining = r->content_len; remaining > 0;) { + int recv_len = httpd_req_recv(r, buffer.get(), std::min(remaining, MULTIPART_CHUNK_SIZE)); + + if (recv_len <= 0) { + httpd_resp_send_err(r, recv_len == HTTPD_SOCK_ERR_TIMEOUT ? HTTPD_408_REQ_TIMEOUT : HTTPD_400_BAD_REQUEST, + nullptr); + return recv_len == HTTPD_SOCK_ERR_TIMEOUT ? ESP_ERR_TIMEOUT : ESP_FAIL; + } + + if (reader->parse(buffer.get(), recv_len) != static_cast(recv_len)) { + ESP_LOGW(TAG, "Multipart parser error"); + httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr); + return ESP_FAIL; + } + + remaining -= recv_len; + bytes_since_yield += recv_len; + + if (bytes_since_yield > YIELD_INTERVAL_BYTES) { + vTaskDelay(1); + bytes_since_yield = 0; + } + } + + handler->handleRequest(&req); + return ESP_OK; +} +#endif // USE_WEBSERVER_OTA + } // namespace web_server_idf } // namespace esphome diff --git a/esphome/components/web_server_idf/web_server_idf.h b/esphome/components/web_server_idf/web_server_idf.h index 7547117224..8de25c8e96 100644 --- a/esphome/components/web_server_idf/web_server_idf.h +++ b/esphome/components/web_server_idf/web_server_idf.h @@ -204,6 +204,9 @@ class AsyncWebServer { static esp_err_t request_handler(httpd_req_t *r); static esp_err_t request_post_handler(httpd_req_t *r); esp_err_t request_handler_(AsyncWebServerRequest *request) const; +#ifdef USE_WEBSERVER_OTA + esp_err_t handle_multipart_upload_(httpd_req_t *r, const char *content_type); +#endif std::vector handlers_; std::function on_not_found_{}; }; diff --git a/esphome/core/defines.h b/esphome/core/defines.h index ea3c8bdc17..cfaed6fdb7 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -153,6 +153,7 @@ #define USE_SPI #define USE_VOICE_ASSISTANT #define USE_WEBSERVER +#define USE_WEBSERVER_OTA #define USE_WEBSERVER_PORT 80 // NOLINT #define USE_WEBSERVER_SORTING #define USE_WIFI_11KV_SUPPORT diff --git a/esphome/idf_component.yml b/esphome/idf_component.yml index 6299909033..c43b622684 100644 --- a/esphome/idf_component.yml +++ b/esphome/idf_component.yml @@ -17,3 +17,5 @@ dependencies: version: 2.0.11 rules: - if: "target in [esp32h2, esp32p4]" + zorxx/multipart-parser: + version: 1.0.1 diff --git a/tests/components/web_server/test_no_ota.esp32-idf.yaml b/tests/components/web_server/test_no_ota.esp32-idf.yaml new file mode 100644 index 0000000000..1f677fb948 --- /dev/null +++ b/tests/components/web_server/test_no_ota.esp32-idf.yaml @@ -0,0 +1,9 @@ +packages: + device_base: !include common.yaml + +# No OTA component defined for this test + +web_server: + port: 8080 + version: 2 + ota: false diff --git a/tests/components/web_server/test_ota.esp32-idf.yaml b/tests/components/web_server/test_ota.esp32-idf.yaml new file mode 100644 index 0000000000..294e7f862e --- /dev/null +++ b/tests/components/web_server/test_ota.esp32-idf.yaml @@ -0,0 +1,32 @@ +# Test configuration for ESP-IDF web server with OTA enabled +esphome: + name: test-web-server-ota-idf + +# Force ESP-IDF framework +esp32: + board: esp32dev + framework: + type: esp-idf + +packages: + device_base: !include common.yaml + +# Enable OTA for multipart upload testing +ota: + - platform: esphome + password: "test_ota_password" + +# Web server with OTA enabled +web_server: + port: 8080 + version: 2 + ota: true + include_internal: true + +# Enable debug logging for OTA +logger: + level: DEBUG + logs: + web_server: VERBOSE + web_server_idf: VERBOSE + diff --git a/tests/components/web_server/test_ota_disabled.esp32-idf.yaml b/tests/components/web_server/test_ota_disabled.esp32-idf.yaml new file mode 100644 index 0000000000..c7c7574e3b --- /dev/null +++ b/tests/components/web_server/test_ota_disabled.esp32-idf.yaml @@ -0,0 +1,11 @@ +packages: + device_base: !include common.yaml + +# OTA is configured but web_server OTA is disabled +ota: + - platform: esphome + +web_server: + port: 8080 + version: 2 + ota: false