This commit is contained in:
J. Nick Koston 2025-06-29 17:22:33 -05:00
parent 2f5db85997
commit 3fca3df756
No known key found for this signature in database
12 changed files with 740 additions and 28 deletions

View File

@ -17,6 +17,10 @@ namespace ota {
std::unique_ptr<ota::OTABackend> make_ota_backend() { return make_unique<ota::IDFOTABackend>(); }
OTAResponseTypes IDFOTABackend::begin(size_t image_size) {
// Reset MD5 validation state
this->md5_set_ = false;
memset(this->expected_bin_md5_, 0, sizeof(this->expected_bin_md5_));
this->partition_ = esp_ota_get_next_update_partition(nullptr);
if (this->partition_ == nullptr) {
return OTA_RESPONSE_ERROR_NO_UPDATE_PARTITION;
@ -67,7 +71,10 @@ OTAResponseTypes IDFOTABackend::begin(size_t image_size) {
return OTA_RESPONSE_OK;
}
void IDFOTABackend::set_update_md5(const char *expected_md5) { memcpy(this->expected_bin_md5_, expected_md5, 32); }
void IDFOTABackend::set_update_md5(const char *expected_md5) {
memcpy(this->expected_bin_md5_, expected_md5, 32);
this->md5_set_ = true;
}
OTAResponseTypes IDFOTABackend::write(uint8_t *data, size_t len) {
esp_err_t err = esp_ota_write(this->update_handle_, data, len);
@ -85,10 +92,15 @@ OTAResponseTypes IDFOTABackend::write(uint8_t *data, size_t len) {
OTAResponseTypes IDFOTABackend::end() {
this->md5_.calculate();
if (!this->md5_.equals_hex(this->expected_bin_md5_)) {
this->abort();
return OTA_RESPONSE_ERROR_MD5_MISMATCH;
// Only validate MD5 if one was provided
if (this->md5_set_) {
if (!this->md5_.equals_hex(this->expected_bin_md5_)) {
this->abort();
return OTA_RESPONSE_ERROR_MD5_MISMATCH;
}
}
esp_err_t err = esp_ota_end(this->update_handle_);
this->update_handle_ = 0;
if (err == ESP_OK) {

View File

@ -6,12 +6,14 @@
#include "esphome/core/defines.h"
#include <esp_ota_ops.h>
#include <cstring>
namespace esphome {
namespace ota {
class IDFOTABackend : public OTABackend {
public:
IDFOTABackend() : md5_set_(false) { memset(expected_bin_md5_, 0, sizeof(expected_bin_md5_)); }
OTAResponseTypes begin(size_t image_size) override;
void set_update_md5(const char *md5) override;
OTAResponseTypes write(uint8_t *data, size_t len) override;
@ -24,6 +26,7 @@ class IDFOTABackend : public OTABackend {
const esp_partition_t *partition_;
md5::MD5Digest md5_{};
char expected_bin_md5_[32];
bool md5_set_;
};
} // namespace ota

View File

@ -4,6 +4,11 @@
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
#ifdef USE_ESP_IDF
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#endif
#ifdef USE_ARDUINO
#include <StreamString.h>
#if defined(USE_ESP32) || defined(USE_LIBRETINY)
@ -117,6 +122,7 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin
if (index == 0) {
this->ota_init_(filename.c_str());
this->ota_started_ = false;
this->ota_success_ = false;
// Create OTA backend
auto backend = ota::make_ota_backend();
@ -125,12 +131,14 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin
auto result = backend->begin(0);
if (result != ota::OTA_RESPONSE_OK) {
ESP_LOGE(TAG, "OTA begin failed: %d", result);
this->ota_success_ = false;
return;
}
// Store the backend pointer
this->ota_backend_ = backend.release();
this->ota_started_ = true;
this->ota_success_ = false; // Will be set to true only on successful completion
} else if (!this->ota_started_ || !this->ota_backend_) {
// Begin failed or was aborted
return;
@ -139,6 +147,29 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin
// Write data
if (len > 0) {
auto *backend = static_cast<ota::OTABackend *>(this->ota_backend_);
// Log first chunk of data received by OTA handler
if (this->ota_read_length_ == 0 && len >= 8) {
ESP_LOGD(TAG, "First data received by OTA handler: %02x %02x %02x %02x %02x %02x %02x %02x", data[0], data[1],
data[2], data[3], data[4], data[5], data[6], data[7]);
ESP_LOGD(TAG, "Data pointer in OTA handler: %p, len: %zu, index: %zu", data, len, index);
}
// Feed watchdog and yield periodically to prevent timeout during OTA
// Flash writes can be slow, especially for large chunks
static uint32_t last_ota_yield = 0;
static uint32_t ota_chunks_written = 0;
uint32_t now = millis();
ota_chunks_written++;
// Yield more frequently during OTA - every 25ms or every 2 chunks
if (now - last_ota_yield > 25 || ota_chunks_written >= 2) {
// Don't log during yield - logging itself can cause delays
vTaskDelay(2); // Let other tasks run for 2 ticks
last_ota_yield = now;
ota_chunks_written = 0;
}
auto result = backend->write(data, len);
if (result != ota::OTA_RESPONSE_OK) {
ESP_LOGE(TAG, "OTA write failed: %d", result);
@ -146,6 +177,7 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin
delete backend;
this->ota_backend_ = nullptr;
this->ota_started_ = false;
this->ota_success_ = false;
return;
}
@ -157,9 +189,11 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin
auto *backend = static_cast<ota::OTABackend *>(this->ota_backend_);
auto result = backend->end();
if (result == ota::OTA_RESPONSE_OK) {
this->ota_success_ = true;
this->schedule_ota_reboot_();
} else {
ESP_LOGE(TAG, "OTA end failed: %d", result);
this->ota_success_ = false;
}
delete backend;
this->ota_backend_ = nullptr;
@ -170,6 +204,7 @@ void OTARequestHandler::handleUpload(AsyncWebServerRequest *request, const Strin
}
void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) {
#ifdef USE_WEBSERVER_OTA
ESP_LOGD(TAG, "OTA handleRequest called");
AsyncWebServerResponse *response;
#ifdef USE_ARDUINO
if (!Update.hasError()) {
@ -182,7 +217,12 @@ void OTARequestHandler::handleRequest(AsyncWebServerRequest *request) {
}
#endif // USE_ARDUINO
#ifdef USE_ESP_IDF
response = request->beginResponse(200, "text/plain", this->ota_started_ ? "Update Successful!" : "Update Failed!");
if (this->ota_success_) {
request->send(200, "text/plain", "Update Successful!");
} else {
request->send(200, "text/plain", "Update Failed!");
}
return;
#endif // USE_ESP_IDF
response->addHeader("Connection", "close");
request->send(response);

View File

@ -127,7 +127,13 @@ class WebServerBase : public Component {
class OTARequestHandler : public AsyncWebHandler {
public:
OTARequestHandler(WebServerBase *parent) : parent_(parent) {}
OTARequestHandler(WebServerBase *parent) : parent_(parent) {
#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA)
this->ota_backend_ = nullptr;
this->ota_started_ = false;
this->ota_success_ = false;
#endif
}
void handleRequest(AsyncWebServerRequest *request) override;
void handleUpload(AsyncWebServerRequest *request, const String &filename, size_t index, uint8_t *data, size_t len,
bool final) override;
@ -153,6 +159,7 @@ class OTARequestHandler : public AsyncWebHandler {
#if defined(USE_ESP_IDF) && defined(USE_WEBSERVER_OTA)
void *ota_backend_{nullptr}; // Actually ota::OTABackend*, stored as void* to avoid incomplete type issues
bool ota_started_{false};
bool ota_success_{false};
#endif
};

View File

@ -2,6 +2,7 @@
#ifdef USE_ESP_IDF
#ifdef USE_WEBSERVER_OTA
#include "multipart_parser_utils.h"
#include "esphome/core/log.h"
namespace esphome {
namespace web_server_idf {
@ -181,6 +182,10 @@ bool parse_multipart_boundary(const char *content_type, const char **boundary_st
}
*boundary_start = start;
// Debug log the extracted boundary
ESP_LOGD("multipart_utils", "Extracted boundary: '%.*s' (len: %zu)", (int) *boundary_len, start, *boundary_len);
return true;
}

View File

@ -12,7 +12,7 @@ namespace web_server_idf {
static const char *const TAG = "multipart_reader";
MultipartReader::MultipartReader(const std::string &boundary) {
MultipartReader::MultipartReader(const std::string &boundary) : first_data_logged_(false) {
// Initialize settings with callbacks
memset(&settings_, 0, sizeof(settings_));
settings_.on_header_field = on_header_field;
@ -22,10 +22,14 @@ MultipartReader::MultipartReader(const std::string &boundary) {
settings_.on_part_data_end = on_part_data_end;
settings_.on_headers_complete = on_headers_complete;
ESP_LOGV(TAG, "Initializing multipart parser with boundary: '%s' (len: %zu)", boundary.c_str(), boundary.length());
// Create parser with boundary
parser_ = multipart_parser_init(boundary.c_str(), &settings_);
if (parser_) {
multipart_parser_set_data(parser_, this);
} else {
ESP_LOGE(TAG, "Failed to initialize multipart parser");
}
}
@ -37,9 +41,26 @@ MultipartReader::~MultipartReader() {
size_t MultipartReader::parse(const char *data, size_t len) {
if (!parser_) {
ESP_LOGE(TAG, "Parser not initialized");
return 0;
}
return multipart_parser_execute(parser_, data, len);
size_t parsed = multipart_parser_execute(parser_, data, len);
if (parsed != len) {
ESP_LOGD(TAG, "Parser consumed %zu of %zu bytes", parsed, len);
// Log the data around the error point
if (parsed < len && parsed < 32) {
ESP_LOGD(TAG, "Data at error point (offset %zu): %02x %02x %02x %02x", parsed,
parsed > 0 ? (uint8_t) data[parsed - 1] : 0, (uint8_t) data[parsed],
parsed + 1 < len ? (uint8_t) data[parsed + 1] : 0, parsed + 2 < len ? (uint8_t) data[parsed + 2] : 0);
// Log what we have vs what parser expects
ESP_LOGD(TAG, "Parser error at position %zu: got '%c' (0x%02x)", parsed, data[parsed], (uint8_t) data[parsed]);
}
}
return parsed;
}
void MultipartReader::process_header_() {
@ -95,7 +116,7 @@ int MultipartReader::on_headers_complete(multipart_parser *parser) {
int MultipartReader::on_part_data_begin(multipart_parser *parser) {
MultipartReader *reader = static_cast<MultipartReader *>(multipart_parser_get_data(parser));
ESP_LOGD(TAG, "Part data begin");
ESP_LOGV(TAG, "Part data begin");
return 0;
}
@ -104,6 +125,18 @@ int MultipartReader::on_part_data(multipart_parser *parser, const char *at, size
// Only process file uploads
if (reader->has_file() && reader->data_callback_) {
// IMPORTANT: The 'at' pointer points to data within the parser's input buffer.
// This data is only valid during this callback. The callback handler MUST
// process or copy the data immediately - it cannot store the pointer for
// later use as the buffer will be overwritten.
// Log first data bytes from multipart parser
if (!reader->first_data_logged_ && length >= 8) {
ESP_LOGD(TAG, "First part data from parser: %02x %02x %02x %02x %02x %02x %02x %02x", (uint8_t) at[0],
(uint8_t) at[1], (uint8_t) at[2], (uint8_t) at[3], (uint8_t) at[4], (uint8_t) at[5], (uint8_t) at[6],
(uint8_t) at[7]);
reader->first_data_logged_ = true;
}
reader->data_callback_(reinterpret_cast<const uint8_t *>(at), length);
}
@ -113,7 +146,7 @@ int MultipartReader::on_part_data(multipart_parser *parser, const char *at, size
int MultipartReader::on_part_data_end(multipart_parser *parser) {
MultipartReader *reader = static_cast<MultipartReader *>(multipart_parser_get_data(parser));
ESP_LOGD(TAG, "Part data end");
ESP_LOGV(TAG, "Part data end");
if (reader->part_complete_callback_) {
reader->part_complete_callback_();
@ -122,6 +155,9 @@ int MultipartReader::on_part_data_end(multipart_parser *parser) {
// Clear part info for next part
reader->current_part_ = Part{};
// Reset first_data flag for next upload
reader->first_data_logged_ = false;
return 0;
}

View File

@ -20,6 +20,11 @@ class MultipartReader {
std::string content_type;
};
// IMPORTANT: The data pointer in DataCallback is only valid during the callback!
// The multipart parser passes pointers to its internal buffer which will be
// overwritten after the callback returns. Callbacks MUST process or copy the
// data immediately - storing the pointer for deferred processing will result
// in use-after-free bugs.
using DataCallback = std::function<void(const uint8_t *data, size_t len)>;
using PartCompleteCallback = std::function<void()>;
@ -58,6 +63,7 @@ class MultipartReader {
PartCompleteCallback part_complete_callback_;
bool in_headers_{false};
bool first_data_logged_{false};
void process_header_();
};

View File

@ -7,6 +7,8 @@
#include "esphome/core/log.h"
#include "esp_tls_crypto.h"
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "utils.h"
#include "web_server_idf.h"
@ -75,7 +77,7 @@ void AsyncWebServer::begin() {
}
esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) {
ESP_LOGVV(TAG, "Enter AsyncWebServer::request_post_handler. uri=%s", r->uri);
ESP_LOGD(TAG, "Enter AsyncWebServer::request_post_handler. uri=%s", r->uri);
auto content_type = request_get_header(r, "Content-Type");
#ifdef USE_WEBSERVER_OTA
@ -91,6 +93,7 @@ esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) {
if (parse_multipart_boundary(ct.c_str(), &boundary_start, &boundary_len)) {
boundary.assign(boundary_start, boundary_len);
is_multipart = true;
ESP_LOGD(TAG, "Multipart upload detected, boundary: '%s' (len: %zu)", boundary.c_str(), boundary_len);
} else if (!is_form_urlencoded(ct.c_str())) {
ESP_LOGW(TAG, "Unsupported content type for POST: %s", ct.c_str());
// fallback to get handler to support backward compatibility
@ -123,42 +126,93 @@ esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) {
for (auto *handler : server->handlers_) {
if (handler->canHandle(&req)) {
found_handler = handler;
ESP_LOGD(TAG, "Found handler for OTA request");
break;
}
}
if (!found_handler) {
ESP_LOGW(TAG, "No handler found for OTA request");
httpd_resp_send_err(r, HTTPD_404_NOT_FOUND, nullptr);
return ESP_OK;
}
// Handle multipart upload using the multipart-parser library
MultipartReader reader(boundary);
// The multipart data starts with "--" + boundary, so we need to prepend it
std::string full_boundary = "--" + boundary;
ESP_LOGV(TAG, "Initializing multipart reader with full boundary: '%s'", full_boundary.c_str());
MultipartReader reader(full_boundary);
static constexpr size_t CHUNK_SIZE = 1024;
// IMPORTANT: chunk_buf is reused for each chunk read from the socket.
// The multipart parser will pass pointers into this buffer to callbacks.
// Those pointers are only valid during the callback execution!
std::unique_ptr<char[]> chunk_buf(new char[CHUNK_SIZE]);
size_t total_len = r->content_len;
size_t remaining = total_len;
std::string current_filename;
bool upload_started = false;
// Track if we've started the upload
bool file_started = false;
// Set up callbacks for the multipart reader
reader.set_data_callback([&](const uint8_t *data, size_t len) {
if (!current_filename.empty()) {
found_handler->handleUpload(&req, current_filename, upload_started ? 1 : 0, const_cast<uint8_t *>(data), len,
false);
upload_started = true;
// CRITICAL: The data pointer is only valid during this callback!
// The multipart parser passes pointers into the chunk_buf buffer, which will be
// overwritten when we read the next chunk. We MUST process the data immediately
// within this callback - any deferred processing will result in use-after-free bugs
// where the data pointer points to corrupted/overwritten memory.
// By the time on_part_data is called, on_headers_complete has already been called
// so we can check for filename
if (reader.has_file()) {
if (current_filename.empty()) {
// First time we see data for this file
current_filename = reader.get_current_part().filename;
ESP_LOGD(TAG, "Processing file part: '%s'", current_filename.c_str());
}
// Log first few bytes of firmware data (only once)
static bool firmware_data_logged = false;
if (!firmware_data_logged && len >= 8) {
ESP_LOGD(TAG, "First firmware bytes from callback: %02x %02x %02x %02x %02x %02x %02x %02x", data[0], data[1],
data[2], data[3], data[4], data[5], data[6], data[7]);
firmware_data_logged = true;
}
if (!file_started) {
// Initialize the upload with index=0
ESP_LOGD(TAG, "Starting upload for: '%s'", current_filename.c_str());
found_handler->handleUpload(&req, current_filename, 0, nullptr, 0, false);
file_started = true;
upload_started = true;
}
// Process the data chunk immediately - the pointer won't be valid after this callback returns!
// DO NOT store the data pointer for later use or pass it to any async/deferred operations.
if (len > 0) {
found_handler->handleUpload(&req, current_filename, 1, const_cast<uint8_t *>(data), len, false);
}
}
});
reader.set_part_complete_callback([&]() {
if (!current_filename.empty() && upload_started) {
// Signal end of this part
found_handler->handleUpload(&req, current_filename, 2, nullptr, 0, false);
ESP_LOGD(TAG, "Part complete callback called for: '%s'", current_filename.c_str());
// Signal end of this part - final=true signals completion
found_handler->handleUpload(&req, current_filename, 2, nullptr, 0, true);
current_filename.clear();
upload_started = false;
file_started = false;
}
});
// Track time to yield periodically
uint32_t last_yield = millis();
static constexpr uint32_t YIELD_INTERVAL_MS = 50; // Yield every 50ms
uint32_t chunks_processed = 0;
static constexpr uint32_t CHUNKS_PER_YIELD = 5; // Also yield every 5 chunks
while (remaining > 0) {
size_t to_read = std::min(remaining, CHUNK_SIZE);
int recv_len = httpd_req_recv(r, chunk_buf.get(), to_read);
@ -172,29 +226,69 @@ esp_err_t AsyncWebServer::request_post_handler(httpd_req_t *r) {
return ESP_FAIL;
}
// Parse multipart data
size_t parsed = reader.parse(chunk_buf.get(), recv_len);
if (parsed != recv_len) {
ESP_LOGW(TAG, "Multipart parser error at byte %zu", total_len - remaining + parsed);
httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr);
return ESP_FAIL;
// Yield periodically to prevent watchdog timeout
chunks_processed++;
uint32_t now = millis();
if (now - last_yield > YIELD_INTERVAL_MS || chunks_processed >= CHUNKS_PER_YIELD) {
// Don't log during yield - logging itself can cause delays
vTaskDelay(2); // Yield for 2 ticks to give more time to other tasks
last_yield = now;
chunks_processed = 0;
}
// Check if we found a new file part
if (reader.has_file() && current_filename.empty()) {
current_filename = reader.get_current_part().filename;
// Log received vs requested - only log every 100KB to reduce overhead
static size_t bytes_logged = 0;
bytes_logged += recv_len;
if (bytes_logged > 100000) {
ESP_LOGD(TAG, "OTA progress: %zu bytes remaining", remaining);
bytes_logged = 0;
}
// Log first few bytes for debugging
if (total_len == remaining) {
ESP_LOGD(TAG, "First chunk data (hex): %02x %02x %02x %02x %02x %02x %02x %02x", (uint8_t) chunk_buf[0],
(uint8_t) chunk_buf[1], (uint8_t) chunk_buf[2], (uint8_t) chunk_buf[3], (uint8_t) chunk_buf[4],
(uint8_t) chunk_buf[5], (uint8_t) chunk_buf[6], (uint8_t) chunk_buf[7]);
ESP_LOGD(TAG, "First chunk data (ascii): %.8s", chunk_buf.get());
ESP_LOGD(TAG, "Expected boundary start: %.8s", full_boundary.c_str());
// Log more of the first chunk to see the headers
ESP_LOGD(TAG, "First 256 bytes of upload:");
for (int i = 0; i < std::min(recv_len, 256); i += 16) {
char hex_buf[50];
char ascii_buf[17];
int n = std::min(16, recv_len - i);
for (int j = 0; j < n; j++) {
sprintf(hex_buf + j * 3, "%02x ", (uint8_t) chunk_buf[i + j]);
ascii_buf[j] = isprint(chunk_buf[i + j]) ? chunk_buf[i + j] : '.';
}
ascii_buf[n] = '\0';
ESP_LOGD(TAG, "%04x: %-48s %s", i, hex_buf, ascii_buf);
}
}
size_t parsed = reader.parse(chunk_buf.get(), recv_len);
if (parsed != recv_len) {
ESP_LOGW(TAG, "Multipart parser error at byte %zu (parsed %zu of %d bytes)", total_len - remaining + parsed,
parsed, recv_len);
httpd_resp_send_err(r, HTTPD_400_BAD_REQUEST, nullptr);
return ESP_FAIL;
}
remaining -= recv_len;
}
// Final cleanup - send final signal if upload was in progress
// This should not be needed as part_complete_callback should handle it
if (!current_filename.empty() && upload_started) {
ESP_LOGW(TAG, "Upload was not properly closed by part_complete_callback");
found_handler->handleUpload(&req, current_filename, 2, nullptr, 0, true);
file_started = false;
}
// Let handler send response
ESP_LOGD(TAG, "Calling handleRequest for OTA response");
found_handler->handleRequest(&req);
ESP_LOGD(TAG, "handleRequest completed");
return ESP_OK;
}
#endif // USE_WEBSERVER_OTA

View File

@ -0,0 +1,236 @@
import asyncio
import os
import tempfile
import aiohttp
import pytest
@pytest.fixture
async def web_server_fixture(event_loop):
"""Start the test device with web server"""
# This would be replaced with actual device setup in a real test environment
# For now, we'll assume the device is running at a specific address
base_url = "http://localhost:8080"
# Wait a bit for server to be ready
await asyncio.sleep(2)
yield base_url
async def create_test_firmware():
"""Create a dummy firmware file for testing"""
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
# Write some dummy data that looks like a firmware file
# ESP32 firmware files typically start with these magic bytes
f.write(b"\xe9\x08\x02\x20") # ESP32 magic bytes
# Add some padding to make it look like a real firmware
f.write(b"\x00" * 1024) # 1KB of zeros
f.write(b"TEST_FIRMWARE_CONTENT")
f.write(b"\x00" * 1024) # More padding
return f.name
@pytest.mark.asyncio
async def test_ota_upload_multipart(web_server_fixture):
"""Test OTA firmware upload using multipart/form-data"""
base_url = web_server_fixture
firmware_path = await create_test_firmware()
try:
# Create multipart form data
async with aiohttp.ClientSession() as session:
# First, check if OTA endpoint is available
async with session.get(f"{base_url}/") as resp:
assert resp.status == 200
content = await resp.text()
assert "ota" in content or "OTA" in content
# Prepare multipart upload
with open(firmware_path, "rb") as f:
data = aiohttp.FormData()
data.add_field(
"firmware",
f,
filename="firmware.bin",
content_type="application/octet-stream",
)
# Send OTA update request
async with session.post(f"{base_url}/ota/upload", data=data) as resp:
assert resp.status in [200, 201, 204], (
f"OTA upload failed with status {resp.status}"
)
# Check response
if resp.status == 200:
response_text = await resp.text()
# The response might be JSON or plain text depending on implementation
assert (
"success" in response_text.lower()
or "ok" in response_text.lower()
)
finally:
# Clean up
os.unlink(firmware_path)
@pytest.mark.asyncio
async def test_ota_upload_wrong_content_type(web_server_fixture):
"""Test that OTA upload fails with wrong content type"""
base_url = web_server_fixture
async with aiohttp.ClientSession() as session:
# Try to upload with wrong content type
data = b"not a firmware file"
headers = {"Content-Type": "text/plain"}
async with session.post(
f"{base_url}/ota/upload", data=data, headers=headers
) as resp:
# Should fail with bad request or similar
assert resp.status >= 400, f"Expected error status, got {resp.status}"
@pytest.mark.asyncio
async def test_ota_upload_empty_file(web_server_fixture):
"""Test that OTA upload fails with empty file"""
base_url = web_server_fixture
async with aiohttp.ClientSession() as session:
# Create empty multipart upload
data = aiohttp.FormData()
data.add_field(
"firmware",
b"",
filename="empty.bin",
content_type="application/octet-stream",
)
async with session.post(f"{base_url}/ota/upload", data=data) as resp:
# Should fail with bad request
assert resp.status >= 400, (
f"Expected error status for empty file, got {resp.status}"
)
@pytest.mark.asyncio
async def test_ota_multipart_boundary_parsing(web_server_fixture):
"""Test multipart boundary parsing edge cases"""
base_url = web_server_fixture
firmware_path = await create_test_firmware()
try:
async with aiohttp.ClientSession() as session:
# Test with custom boundary
with open(firmware_path, "rb") as f:
# Create multipart manually with specific boundary
boundary = "----WebKitFormBoundaryCustomTest123"
body = (
f"--{boundary}\r\n"
f'Content-Disposition: form-data; name="firmware"; filename="test.bin"\r\n'
f"Content-Type: application/octet-stream\r\n"
f"\r\n"
).encode()
body += f.read()
body += f"\r\n--{boundary}--\r\n".encode()
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
"Content-Length": str(len(body)),
}
async with session.post(
f"{base_url}/ota/upload", data=body, headers=headers
) as resp:
assert resp.status in [200, 201, 204], (
f"Custom boundary upload failed with status {resp.status}"
)
finally:
os.unlink(firmware_path)
@pytest.mark.asyncio
async def test_ota_concurrent_uploads(web_server_fixture):
"""Test that concurrent OTA uploads are properly handled"""
base_url = web_server_fixture
firmware_path = await create_test_firmware()
try:
async with aiohttp.ClientSession() as session:
# Create two concurrent upload tasks
async def upload_firmware():
with open(firmware_path, "rb") as f:
data = aiohttp.FormData()
data.add_field(
"firmware",
f.read(), # Read to bytes to avoid file conflicts
filename="firmware.bin",
content_type="application/octet-stream",
)
async with session.post(
f"{base_url}/ota/upload", data=data
) as resp:
return resp.status
# Start two uploads concurrently
results = await asyncio.gather(
upload_firmware(), upload_firmware(), return_exceptions=True
)
# One should succeed, the other should fail with conflict
statuses = [r for r in results if isinstance(r, int)]
assert len(statuses) == 2
assert 200 in statuses or 201 in statuses or 204 in statuses
# The other might be 409 Conflict or similar
finally:
os.unlink(firmware_path)
@pytest.mark.asyncio
async def test_ota_large_file_upload(web_server_fixture):
"""Test OTA upload with a larger file to test chunked processing"""
base_url = web_server_fixture
# Create a larger test firmware (1MB)
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
# ESP32 magic bytes
f.write(b"\xe9\x08\x02\x20")
# Write 1MB of data in chunks
chunk_size = 4096
for _ in range(256): # 256 * 4KB = 1MB
f.write(b"A" * chunk_size)
firmware_path = f.name
try:
async with aiohttp.ClientSession() as session:
with open(firmware_path, "rb") as f:
data = aiohttp.FormData()
data.add_field(
"firmware",
f,
filename="large_firmware.bin",
content_type="application/octet-stream",
)
# Use a longer timeout for large file
timeout = aiohttp.ClientTimeout(total=60)
async with session.post(
f"{base_url}/ota/upload", data=data, timeout=timeout
) as resp:
assert resp.status in [200, 201, 204], (
f"Large file OTA upload failed with status {resp.status}"
)
finally:
os.unlink(firmware_path)
if __name__ == "__main__":
# For manual testing
asyncio.run(test_ota_upload_multipart(asyncio.Event()))

View File

@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""
Test script for ESP-IDF web server multipart OTA upload functionality.
This script can be run manually to test OTA uploads to a running device.
"""
import argparse
from pathlib import Path
import sys
import time
import requests
def test_multipart_ota_upload(host, port, firmware_path):
"""Test OTA firmware upload using multipart/form-data"""
base_url = f"http://{host}:{port}"
print(f"Testing OTA upload to {base_url}")
# First check if server is reachable
try:
resp = requests.get(f"{base_url}/", timeout=5)
if resp.status_code != 200:
print(f"Error: Server returned status {resp.status_code}")
return False
print("✓ Server is reachable")
except requests.exceptions.RequestException as e:
print(f"Error: Cannot reach server - {e}")
return False
# Check if firmware file exists
if not Path(firmware_path).exists():
print(f"Error: Firmware file not found: {firmware_path}")
return False
# Prepare multipart upload
print(f"Uploading firmware: {firmware_path}")
print(f"File size: {Path(firmware_path).stat().st_size} bytes")
try:
with open(firmware_path, "rb") as f:
files = {"firmware": ("firmware.bin", f, "application/octet-stream")}
# Send OTA update request
resp = requests.post(f"{base_url}/ota/upload", files=files, timeout=60)
if resp.status_code in [200, 201, 204]:
print(f"✓ OTA upload successful (status: {resp.status_code})")
if resp.text:
print(f"Response: {resp.text}")
return True
else:
print(f"✗ OTA upload failed with status {resp.status_code}")
print(f"Response: {resp.text}")
return False
except requests.exceptions.RequestException as e:
print(f"Error during upload: {e}")
return False
def test_ota_with_wrong_content_type(host, port):
"""Test that OTA upload fails gracefully with wrong content type"""
base_url = f"http://{host}:{port}"
print("\nTesting OTA with wrong content type...")
try:
# Send plain text instead of multipart
headers = {"Content-Type": "text/plain"}
resp = requests.post(
f"{base_url}/ota/upload",
data="This is not a firmware file",
headers=headers,
timeout=10,
)
if resp.status_code >= 400:
print(
f"✓ Server correctly rejected wrong content type (status: {resp.status_code})"
)
return True
else:
print(f"✗ Server accepted wrong content type (status: {resp.status_code})")
return False
except requests.exceptions.RequestException as e:
print(f"Error: {e}")
return False
def test_ota_with_empty_file(host, port):
"""Test that OTA upload fails gracefully with empty file"""
base_url = f"http://{host}:{port}"
print("\nTesting OTA with empty file...")
try:
# Send empty file
files = {"firmware": ("empty.bin", b"", "application/octet-stream")}
resp = requests.post(f"{base_url}/ota/upload", files=files, timeout=10)
if resp.status_code >= 400:
print(
f"✓ Server correctly rejected empty file (status: {resp.status_code})"
)
return True
else:
print(f"✗ Server accepted empty file (status: {resp.status_code})")
return False
except requests.exceptions.RequestException as e:
print(f"Error: {e}")
return False
def create_test_firmware(size_kb=10):
"""Create a dummy firmware file for testing"""
import tempfile
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
# ESP32 firmware magic bytes
f.write(b"\xe9\x08\x02\x20")
# Add padding
f.write(b"\x00" * (size_kb * 1024 - 4))
return f.name
def main():
parser = argparse.ArgumentParser(
description="Test ESP-IDF web server OTA functionality"
)
parser.add_argument("--host", default="localhost", help="Device hostname or IP")
parser.add_argument("--port", type=int, default=8080, help="Web server port")
parser.add_argument(
"--firmware", help="Path to firmware file (if not specified, creates test file)"
)
parser.add_argument(
"--skip-error-tests", action="store_true", help="Skip error condition tests"
)
args = parser.parse_args()
# Create test firmware if not specified
firmware_path = args.firmware
if not firmware_path:
print("Creating test firmware file...")
firmware_path = create_test_firmware(100) # 100KB test file
print(f"Created test firmware: {firmware_path}")
all_passed = True
# Test successful OTA upload
if not test_multipart_ota_upload(args.host, args.port, firmware_path):
all_passed = False
# Test error conditions
if not args.skip_error_tests:
time.sleep(1) # Small delay between tests
if not test_ota_with_wrong_content_type(args.host, args.port):
all_passed = False
time.sleep(1)
if not test_ota_with_empty_file(args.host, args.port):
all_passed = False
# Clean up test firmware if we created it
if not args.firmware:
import os
os.unlink(firmware_path)
print("\nCleaned up test firmware")
print(f"\n{'All tests passed!' if all_passed else 'Some tests failed!'}")
return 0 if all_passed else 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,12 +1,33 @@
# Test configuration for ESP-IDF web server with OTA enabled
esphome:
name: test-web-server-ota-idf
# Force ESP-IDF framework
esp32:
board: esp32dev
framework:
type: esp-idf
packages:
device_base: !include common.yaml
# Enable OTA for this test
# Enable OTA for multipart upload testing
ota:
- platform: esphome
safe_mode: true
password: "test_ota_password"
# Web server with OTA enabled
web_server:
port: 8080
version: 2
ota: true
include_internal: true
# Enable debug logging for OTA
logger:
level: DEBUG
logs:
web_server: VERBOSE
web_server_idf: VERBOSE

View File

@ -0,0 +1,70 @@
# Testing ESP-IDF Web Server OTA Functionality
This directory contains tests for the ESP-IDF web server OTA (Over-The-Air) update functionality using multipart form uploads.
## Test Files
- `test_ota.esp32-idf.yaml` - ESPHome configuration with OTA enabled for ESP-IDF
- `test_no_ota.esp32-idf.yaml` - ESPHome configuration with OTA disabled
- `test_ota_disabled.esp32-idf.yaml` - ESPHome configuration with web_server ota: false
- `test_multipart_ota.py` - Manual test script for OTA functionality
- `test_esp_idf_ota.py` - Automated pytest for OTA functionality
## Running the Tests
### 1. Compile and Flash Test Device
```bash
# Compile the OTA-enabled configuration
esphome compile tests/components/web_server/test_ota.esp32-idf.yaml
# Flash to device
esphome upload tests/components/web_server/test_ota.esp32-idf.yaml
```
### 2. Run Manual Tests
Once the device is running, you can test OTA functionality:
```bash
# Test with default settings (creates test firmware)
python tests/components/web_server/test_multipart_ota.py --host <device-ip>
# Test with real firmware file
python tests/components/web_server/test_multipart_ota.py --host <device-ip> --firmware <path-to-firmware.bin>
# Skip error condition tests (useful for production devices)
python tests/components/web_server/test_multipart_ota.py --host <device-ip> --skip-error-tests
```
### 3. Run Automated Tests
```bash
# Run pytest suite
pytest tests/component_tests/web_server/test_esp_idf_ota.py
```
## What's Being Tested
1. **Multipart Upload**: Tests that firmware can be uploaded using multipart/form-data
2. **Error Handling**:
- Wrong content type rejection
- Empty file rejection
- Concurrent upload handling
3. **Large Files**: Tests chunked processing of larger firmware files
4. **Boundary Parsing**: Tests various multipart boundary formats
## Implementation Details
The ESP-IDF web server uses the `multipart-parser` library to handle multipart uploads. Key components:
- `MultipartReader` class for parsing multipart data
- Chunked processing to handle large files without excessive memory use
- Integration with ESPHome's OTA component for actual firmware updates
## Troubleshooting
1. **Connection Refused**: Make sure the device is on the network and the IP is correct
2. **404 Not Found**: Ensure OTA is enabled in the configuration (`ota: true` in web_server)
3. **Upload Fails**: Check device logs for detailed error messages
4. **Timeout**: Large firmware files may take time, increase timeout if needed