[micro_wake_word] add new VPE features (#8655)

This commit is contained in:
Kevin Ahrendt 2025-04-30 19:22:48 -05:00 committed by GitHub
parent 6de6a0c82c
commit cdc77506de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 788 additions and 503 deletions

View File

@ -12,6 +12,7 @@ import esphome.config_validation as cv
from esphome.const import (
CONF_FILE,
CONF_ID,
CONF_INTERNAL,
CONF_MICROPHONE,
CONF_MODEL,
CONF_PASSWORD,
@ -40,6 +41,7 @@ CONF_ON_WAKE_WORD_DETECTED = "on_wake_word_detected"
CONF_PROBABILITY_CUTOFF = "probability_cutoff"
CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size"
CONF_SLIDING_WINDOW_SIZE = "sliding_window_size"
CONF_STOP_AFTER_DETECTION = "stop_after_detection"
CONF_TENSOR_ARENA_SIZE = "tensor_arena_size"
CONF_VAD = "vad"
@ -49,13 +51,20 @@ micro_wake_word_ns = cg.esphome_ns.namespace("micro_wake_word")
MicroWakeWord = micro_wake_word_ns.class_("MicroWakeWord", cg.Component)
DisableModelAction = micro_wake_word_ns.class_("DisableModelAction", automation.Action)
EnableModelAction = micro_wake_word_ns.class_("EnableModelAction", automation.Action)
StartAction = micro_wake_word_ns.class_("StartAction", automation.Action)
StopAction = micro_wake_word_ns.class_("StopAction", automation.Action)
ModelIsEnabledCondition = micro_wake_word_ns.class_(
"ModelIsEnabledCondition", automation.Condition
)
IsRunningCondition = micro_wake_word_ns.class_(
"IsRunningCondition", automation.Condition
)
WakeWordModel = micro_wake_word_ns.class_("WakeWordModel")
def _validate_json_filename(value):
value = cv.string(value)
@ -169,9 +178,10 @@ def _convert_manifest_v1_to_v2(v1_manifest):
# Original Inception-based V1 manifest models require a minimum of 45672 bytes
v2_manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE] = 45672
# Original Inception-based V1 manifest models use a 20 ms feature step size
v2_manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE] = 20
# Original Inception-based V1 manifest models were trained only on TTS English samples
v2_manifest[KEY_TRAINED_LANGUAGES] = ["en"]
return v2_manifest
@ -296,14 +306,16 @@ MODEL_SOURCE_SCHEMA = cv.Any(
MODEL_SCHEMA = cv.Schema(
{
cv.GenerateID(CONF_ID): cv.declare_id(WakeWordModel),
cv.Optional(CONF_MODEL): MODEL_SOURCE_SCHEMA,
cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage,
cv.Optional(CONF_SLIDING_WINDOW_SIZE): cv.positive_int,
cv.Optional(CONF_INTERNAL, default=False): cv.boolean,
cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8),
}
)
# Provide a default VAD model that could be overridden
# Provides a default VAD model that could be overridden
VAD_MODEL_SCHEMA = MODEL_SCHEMA.extend(
cv.Schema(
{
@ -343,6 +355,7 @@ CONFIG_SCHEMA = cv.All(
single=True
),
cv.Optional(CONF_VAD): _maybe_empty_vad_schema,
cv.Optional(CONF_STOP_AFTER_DETECTION, default=True): cv.boolean,
cv.Optional(CONF_MODEL): cv.invalid(
f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter."
),
@ -433,29 +446,20 @@ async def to_code(config):
mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE])
cg.add(var.set_microphone_source(mic_source))
cg.add_define("USE_MICRO_WAKE_WORD")
cg.add_define("USE_OTA_STATE_CALLBACK")
esp32.add_idf_component(
name="esp-tflite-micro",
repo="https://github.com/espressif/esp-tflite-micro",
ref="v1.3.1",
)
# add esp-nn dependency for tflite-micro to work around https://github.com/espressif/esp-nn/issues/17
# ...remove after switching to IDF 5.1.4+
esp32.add_idf_component(
name="esp-nn",
repo="https://github.com/espressif/esp-nn",
ref="v1.1.0",
ref="v1.3.3.1",
)
cg.add_build_flag("-DTF_LITE_STATIC_MEMORY")
cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON")
cg.add_build_flag("-DESP_NN")
if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED):
await automation.build_automation(
var.get_wake_word_detected_trigger(),
[(cg.std_string, "wake_word")],
on_wake_word_detection_config,
)
cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0")
if vad_model := config.get(CONF_VAD):
cg.add_define("USE_MICRO_WAKE_WORD_VAD")
@ -463,7 +467,7 @@ async def to_code(config):
# Use the general model loading code for the VAD codegen
config[CONF_MODELS].append(vad_model)
for model_parameters in config[CONF_MODELS]:
for i, model_parameters in enumerate(config[CONF_MODELS]):
model_config = model_parameters.get(CONF_MODEL)
data = []
manifest, data = _model_config_to_manifest_data(model_config)
@ -474,6 +478,8 @@ async def to_code(config):
probability_cutoff = model_parameters.get(
CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF]
)
quantized_probability_cutoff = int(probability_cutoff * 255)
sliding_window_size = model_parameters.get(
CONF_SLIDING_WINDOW_SIZE,
manifest[KEY_MICRO][CONF_SLIDING_WINDOW_SIZE],
@ -483,24 +489,40 @@ async def to_code(config):
cg.add(
var.add_vad_model(
prog_arr,
probability_cutoff,
quantized_probability_cutoff,
sliding_window_size,
manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE],
)
)
else:
cg.add(
var.add_wake_word_model(
prog_arr,
probability_cutoff,
sliding_window_size,
manifest[KEY_WAKE_WORD],
manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE],
)
# Only enable the first wake word by default. After first boot, the enable state is saved/loaded to the flash
default_enabled = i == 0
wake_word_model = cg.new_Pvariable(
model_parameters[CONF_ID],
str(model_parameters[CONF_ID]),
prog_arr,
quantized_probability_cutoff,
sliding_window_size,
manifest[KEY_WAKE_WORD],
manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE],
default_enabled,
model_parameters[CONF_INTERNAL],
)
for lang in manifest[KEY_TRAINED_LANGUAGES]:
cg.add(wake_word_model.add_trained_language(lang))
cg.add(var.add_wake_word_model(wake_word_model))
cg.add(var.set_features_step_size(manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE]))
cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0")
cg.add(var.set_stop_after_detection(config[CONF_STOP_AFTER_DETECTION]))
if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED):
await automation.build_automation(
var.get_wake_word_detected_trigger(),
[(cg.std_string, "wake_word")],
on_wake_word_detection_config,
)
MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)})
@ -515,3 +537,30 @@ async def micro_wake_word_action_to_code(config, action_id, template_arg, args):
var = cg.new_Pvariable(action_id, template_arg)
await cg.register_parented(var, config[CONF_ID])
return var
MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA = automation.maybe_simple_id(
{
cv.Required(CONF_ID): cv.use_id(WakeWordModel),
}
)
@register_action(
"micro_wake_word.enable_model",
EnableModelAction,
MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA,
)
@register_action(
"micro_wake_word.disable_model",
DisableModelAction,
MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA,
)
@register_condition(
"micro_wake_word.model_is_enabled",
ModelIsEnabledCondition,
MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA,
)
async def model_action(config, action_id, template_arg, args):
    """Codegen for the per-model enable/disable actions and enabled condition.

    Looks up the WakeWordModel variable referenced by CONF_ID and constructs
    the action/condition object with it as the constructor argument.
    """
    model = await cg.get_variable(config[CONF_ID])
    return cg.new_Pvariable(action_id, template_arg, model)

View File

@ -0,0 +1,54 @@
#pragma once
#include "micro_wake_word.h"
#include "streaming_model.h"
#ifdef USE_ESP_IDF
namespace esphome {
namespace micro_wake_word {
/// Automation action that starts wake word detection on the parented MicroWakeWord component.
template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> {
 public:
  void play(Ts... x) override {
    // Template arguments are unused; simply delegate to the component.
    this->parent_->start();
  }
};
/// Automation action that stops wake word detection on the parented MicroWakeWord component.
template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<MicroWakeWord> {
 public:
  void play(Ts... x) override {
    // Template arguments are unused; simply delegate to the component.
    this->parent_->stop();
  }
};
/// Automation condition that reports whether wake word detection is currently running.
template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<MicroWakeWord> {
 public:
  bool check(Ts... x) override {
    // Forwards directly to the component's running state.
    return this->parent_->is_running();
  }
};
/// Automation action that enables one specific wake word model.
template<typename... Ts> class EnableModelAction : public Action<Ts...> {
 public:
  /// @param model Non-owning pointer to the model to enable; must outlive this action.
  explicit EnableModelAction(WakeWordModel *model) : wake_word_model_(model) {}
  void play(Ts... x) override {
    this->wake_word_model_->enable();
  }

 protected:
  WakeWordModel *wake_word_model_;  // Not owned
};
/// Automation action that disables one specific wake word model.
template<typename... Ts> class DisableModelAction : public Action<Ts...> {
 public:
  /// @param model Non-owning pointer to the model to disable; must outlive this action.
  explicit DisableModelAction(WakeWordModel *model) : wake_word_model_(model) {}
  void play(Ts... x) override {
    this->wake_word_model_->disable();
  }

 protected:
  WakeWordModel *wake_word_model_;  // Not owned
};
/// Automation condition that reports whether one specific wake word model is enabled.
template<typename... Ts> class ModelIsEnabledCondition : public Condition<Ts...> {
 public:
  /// @param model Non-owning pointer to the model to query; must outlive this condition.
  explicit ModelIsEnabledCondition(WakeWordModel *model) : wake_word_model_(model) {}
  bool check(Ts... x) override {
    return this->wake_word_model_->is_enabled();
  }

 protected:
  WakeWordModel *wake_word_model_;  // Not owned
};
} // namespace micro_wake_word
} // namespace esphome
#endif

View File

@ -1,5 +1,4 @@
#include "micro_wake_word.h"
#include "streaming_model.h"
#ifdef USE_ESP_IDF
@ -7,41 +6,57 @@
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
#include <frontend.h>
#include <frontend_util.h>
#include "esphome/components/audio/audio_transfer_buffer.h"
#include <tensorflow/lite/core/c/common.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
#include <cmath>
#ifdef USE_OTA
#include "esphome/components/ota/ota_backend.h"
#endif
namespace esphome {
namespace micro_wake_word {
static const char *const TAG = "micro_wake_word";
static const size_t SAMPLE_RATE_HZ = 16000; // 16 kHz
static const size_t BUFFER_LENGTH = 64; // 0.064 seconds
static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH;
static const size_t INPUT_BUFFER_SIZE = 16 * SAMPLE_RATE_HZ / 1000; // 16ms * 16kHz / 1000ms
static const ssize_t DETECTION_QUEUE_LENGTH = 5;
static const size_t DATA_TIMEOUT_MS = 50;
static const uint32_t RING_BUFFER_DURATION_MS = 120;
static const uint32_t RING_BUFFER_SAMPLES = RING_BUFFER_DURATION_MS * (AUDIO_SAMPLE_FREQUENCY / 1000);
static const size_t RING_BUFFER_SIZE = RING_BUFFER_SAMPLES * sizeof(int16_t);
static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072;
static const UBaseType_t INFERENCE_TASK_PRIORITY = 3;
enum EventGroupBits : uint32_t {
COMMAND_STOP = (1 << 0), // Signals the inference task should stop
TASK_STARTING = (1 << 3),
TASK_RUNNING = (1 << 4),
TASK_STOPPING = (1 << 5),
TASK_STOPPED = (1 << 6),
ERROR_MEMORY = (1 << 9),
ERROR_INFERENCE = (1 << 10),
WARNING_FULL_RING_BUFFER = (1 << 13),
ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE,
ALL_BITS = 0xfffff, // 24 total bits available in an event group
};
float MicroWakeWord::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; }
static const LogString *micro_wake_word_state_to_string(State state) {
switch (state) {
case State::IDLE:
return LOG_STR("IDLE");
case State::START_MICROPHONE:
return LOG_STR("START_MICROPHONE");
case State::STARTING_MICROPHONE:
return LOG_STR("STARTING_MICROPHONE");
case State::STARTING:
return LOG_STR("STARTING");
case State::DETECTING_WAKE_WORD:
return LOG_STR("DETECTING_WAKE_WORD");
case State::STOP_MICROPHONE:
return LOG_STR("STOP_MICROPHONE");
case State::STOPPING_MICROPHONE:
return LOG_STR("STOPPING_MICROPHONE");
case State::STOPPING:
return LOG_STR("STOPPING");
case State::STOPPED:
return LOG_STR("STOPPED");
default:
return LOG_STR("UNKNOWN");
}
@ -51,7 +66,7 @@ void MicroWakeWord::dump_config() {
ESP_LOGCONFIG(TAG, "microWakeWord:");
ESP_LOGCONFIG(TAG, " models:");
for (auto &model : this->wake_word_models_) {
model.log_model_config();
model->log_model_config();
}
#ifdef USE_MICRO_WAKE_WORD_VAD
this->vad_model_->log_model_config();
@ -61,108 +76,266 @@ void MicroWakeWord::dump_config() {
void MicroWakeWord::setup() {
ESP_LOGCONFIG(TAG, "Setting up microWakeWord...");
this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
this->frontend_config_.window.step_size_ms = this->features_step_size_;
this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT;
this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT;
this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS;
this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING;
this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING;
this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING;
this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN;
this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH;
this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET;
this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS;
this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG;
this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT;
this->event_group_ = xEventGroupCreate();
if (this->event_group_ == nullptr) {
ESP_LOGE(TAG, "Failed to create event group");
this->mark_failed();
return;
}
this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent));
if (this->detection_queue_ == nullptr) {
ESP_LOGE(TAG, "Failed to create detection event queue");
this->mark_failed();
return;
}
this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) {
if (this->state_ != State::DETECTING_WAKE_WORD) {
if (this->state_ == State::STOPPED) {
return;
}
std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_;
if (this->ring_buffer_.use_count() == 2) {
// mWW still owns the ring buffer and temp_ring_buffer does as well, proceed to copy audio into ring buffer
std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_.lock();
if (this->ring_buffer_.use_count() > 1) {
size_t bytes_free = temp_ring_buffer->free();
if (bytes_free < data.size()) {
ESP_LOGW(
TAG,
"Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). "
"Resetting the ring buffer. Wake word detection accuracy will be reduced.",
bytes_free, data.size());
xEventGroupSetBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
temp_ring_buffer->reset();
}
temp_ring_buffer->write((void *) data.data(), data.size());
}
});
if (!this->register_streaming_ops_(this->streaming_op_resolver_)) {
this->mark_failed();
return;
#ifdef USE_OTA
ota::get_global_ota_callback()->add_on_state_callback(
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
if (state == ota::OTA_STARTED) {
this->suspend_task_();
} else if (state == ota::OTA_ERROR) {
this->resume_task_();
}
});
#endif
ESP_LOGCONFIG(TAG, "Micro Wake Word initialized");
}
void MicroWakeWord::inference_task(void *params) {
MicroWakeWord *this_mww = (MicroWakeWord *) params;
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING);
{ // Ensures any C++ objects fall out of scope to deallocate before deleting the task
const size_t new_samples_to_read = this_mww->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000);
std::unique_ptr<audio::AudioSourceTransferBuffer> audio_buffer;
int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE];
if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
// Allocate audio transfer buffer
audio_buffer = audio::AudioSourceTransferBuffer::create(new_samples_to_read * sizeof(int16_t));
if (audio_buffer == nullptr) {
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
}
}
if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
// Allocate ring buffer
std::shared_ptr<RingBuffer> temp_ring_buffer = RingBuffer::create(RING_BUFFER_SIZE);
if (temp_ring_buffer.use_count() == 0) {
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
}
audio_buffer->set_source(temp_ring_buffer);
this_mww->ring_buffer_ = temp_ring_buffer;
}
if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
this_mww->microphone_source_->start();
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING);
while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) {
audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS));
if (audio_buffer->available() < new_samples_to_read * sizeof(int16_t)) {
// Insufficient data to generate new spectrogram features, read more next iteration
continue;
}
// Generate new spectrogram features
size_t processed_samples = this_mww->generate_features_(
(int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer);
audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t));
// Run inference using the new spectrogram features
if (!this_mww->update_model_probabilities_(features_buffer)) {
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE);
break;
}
// Process each model's probabilities and possibly send a Detection Event to the queue
this_mww->process_probabilities_();
}
}
}
ESP_LOGCONFIG(TAG, "Micro Wake Word initialized");
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING);
this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
this->frontend_config_.window.step_size_ms = this->features_step_size_;
this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
this->frontend_config_.filterbank.lower_band_limit = 125.0;
this->frontend_config_.filterbank.upper_band_limit = 7500.0;
this->frontend_config_.noise_reduction.smoothing_bits = 10;
this->frontend_config_.noise_reduction.even_smoothing = 0.025;
this->frontend_config_.noise_reduction.odd_smoothing = 0.06;
this->frontend_config_.noise_reduction.min_signal_remaining = 0.05;
this->frontend_config_.pcan_gain_control.enable_pcan = 1;
this->frontend_config_.pcan_gain_control.strength = 0.95;
this->frontend_config_.pcan_gain_control.offset = 80.0;
this->frontend_config_.pcan_gain_control.gain_bits = 21;
this->frontend_config_.log_scale.enable_log = 1;
this->frontend_config_.log_scale.scale_shift = 6;
this_mww->unload_models_();
this_mww->microphone_source_->stop();
FrontendFreeStateContents(&this_mww->frontend_state_);
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED);
while (true) {
// Continuously delay until the main loop deletes the task
delay(10);
}
}
void MicroWakeWord::add_wake_word_model(const uint8_t *model_start, float probability_cutoff,
size_t sliding_window_average_size, const std::string &wake_word,
size_t tensor_arena_size) {
this->wake_word_models_.emplace_back(model_start, probability_cutoff, sliding_window_average_size, wake_word,
tensor_arena_size);
std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() {
std::vector<WakeWordModel *> external_wake_word_models;
for (auto *model : this->wake_word_models_) {
if (!model->get_internal_only()) {
external_wake_word_models.push_back(model);
}
}
return external_wake_word_models;
}
void MicroWakeWord::add_wake_word_model(WakeWordModel *model) { this->wake_word_models_.push_back(model); }
#ifdef USE_MICRO_WAKE_WORD_VAD
void MicroWakeWord::add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size,
void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
size_t tensor_arena_size) {
this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size);
}
#endif
// Suspends the inference FreeRTOS task if it exists; no-op when the task isn't running.
void MicroWakeWord::suspend_task_() {
  if (this->inference_task_handle_ == nullptr)
    return;  // Nothing to suspend
  vTaskSuspend(this->inference_task_handle_);
}
// Resumes a previously suspended inference FreeRTOS task; no-op when the task doesn't exist.
void MicroWakeWord::resume_task_() {
  if (this->inference_task_handle_ == nullptr)
    return;  // Nothing to resume
  vTaskResume(this->inference_task_handle_);
}
void MicroWakeWord::loop() {
uint32_t event_group_bits = xEventGroupGetBits(this->event_group_);
if (event_group_bits & EventGroupBits::ERROR_MEMORY) {
xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY);
ESP_LOGE(TAG, "Encountered an error allocating buffers");
}
if (event_group_bits & EventGroupBits::ERROR_INFERENCE) {
xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE);
ESP_LOGE(TAG, "Encountered an error while performing an inference");
}
if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) {
xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake "
"word detection accuracy will temporarily be reduced.");
}
if (event_group_bits & EventGroupBits::TASK_STARTING) {
ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers");
xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING);
}
if (event_group_bits & EventGroupBits::TASK_RUNNING) {
ESP_LOGD(TAG, "Inference task is running");
xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING);
this->set_state_(State::DETECTING_WAKE_WORD);
}
if (event_group_bits & EventGroupBits::TASK_STOPPING) {
ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers");
xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING);
}
if ((event_group_bits & EventGroupBits::TASK_STOPPED)) {
ESP_LOGD(TAG, "Inference task is finished, freeing task resources");
vTaskDelete(this->inference_task_handle_);
this->inference_task_handle_ = nullptr;
xEventGroupClearBits(this->event_group_, ALL_BITS);
xQueueReset(this->detection_queue_);
this->set_state_(State::STOPPED);
}
if ((this->pending_start_) && (this->state_ == State::STOPPED)) {
this->set_state_(State::STARTING);
this->pending_start_ = false;
}
if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) {
this->set_state_(State::STOPPING);
this->pending_stop_ = false;
}
switch (this->state_) {
case State::IDLE:
break;
case State::START_MICROPHONE:
ESP_LOGD(TAG, "Starting Microphone");
this->microphone_source_->start();
this->set_state_(State::STARTING_MICROPHONE);
break;
case State::STARTING_MICROPHONE:
if (this->microphone_source_->is_running()) {
this->set_state_(State::DETECTING_WAKE_WORD);
}
break;
case State::DETECTING_WAKE_WORD:
while (this->has_enough_samples_()) {
this->update_model_probabilities_();
if (this->detect_wake_words_()) {
ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str());
this->detected_ = true;
this->set_state_(State::STOP_MICROPHONE);
case State::STARTING:
if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) {
// Setup preprocessor feature generator. If done in the task, it would lock the task to its initial core, as it
// uses floating point operations.
if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) {
this->status_momentary_error(
"Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000);
return;
}
xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this,
INFERENCE_TASK_PRIORITY, &this->inference_task_handle_);
if (this->inference_task_handle_ == nullptr) {
FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state
this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000);
}
}
break;
case State::STOP_MICROPHONE:
ESP_LOGD(TAG, "Stopping Microphone");
this->microphone_source_->stop();
this->set_state_(State::STOPPING_MICROPHONE);
this->unload_models_();
this->deallocate_buffers_();
break;
case State::STOPPING_MICROPHONE:
if (this->microphone_source_->is_stopped()) {
this->set_state_(State::IDLE);
if (this->detected_) {
this->wake_word_detected_trigger_->trigger(this->detected_wake_word_);
this->detected_ = false;
this->detected_wake_word_ = "";
case State::DETECTING_WAKE_WORD: {
DetectionEvent detection_event;
while (xQueueReceive(this->detection_queue_, &detection_event, 0)) {
if (detection_event.blocked_by_vad) {
ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str());
} else {
constexpr float uint8_to_float_divisor =
255.0f; // Converting a quantized uint8 probability to floating point
ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f",
detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor),
(detection_event.max_probability / uint8_to_float_divisor));
this->wake_word_detected_trigger_->trigger(*detection_event.wake_word);
if (this->stop_after_detection_) {
this->stop();
}
}
}
break;
}
case State::STOPPING:
xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP);
break;
case State::STOPPED:
break;
}
}
@ -177,199 +350,40 @@ void MicroWakeWord::start() {
return;
}
if (this->state_ != State::IDLE) {
ESP_LOGW(TAG, "Wake word is already running");
if (this->is_running()) {
ESP_LOGW(TAG, "Wake word detection is already running");
return;
}
if (!this->load_models_() || !this->allocate_buffers_()) {
ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers");
this->status_set_error();
} else {
this->status_clear_error();
}
ESP_LOGD(TAG, "Starting wake word detection");
if (this->status_has_error()) {
ESP_LOGW(TAG, "Wake word component has an error. Please check logs");
return;
}
this->reset_states_();
this->set_state_(State::START_MICROPHONE);
this->pending_start_ = true;
this->pending_stop_ = false;
}
void MicroWakeWord::stop() {
if (this->state_ == State::IDLE) {
ESP_LOGW(TAG, "Wake word is already stopped");
if (this->state_ == STOPPED)
return;
}
if (this->state_ == State::STOPPING_MICROPHONE) {
ESP_LOGW(TAG, "Wake word is already stopping");
return;
}
this->set_state_(State::STOP_MICROPHONE);
ESP_LOGD(TAG, "Stopping wake word detection");
this->pending_start_ = false;
this->pending_stop_ = true;
}
void MicroWakeWord::set_state_(State state) {
ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
LOG_STR_ARG(micro_wake_word_state_to_string(state)));
this->state_ = state;
if (this->state_ != state) {
ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
LOG_STR_ARG(micro_wake_word_state_to_string(state)));
this->state_ = state;
}
}
// Allocates (idempotently) the three audio buffers used for detection, preferring
// external PSRAM but allowing fallback. Returns false on any allocation failure;
// already-allocated buffers are left untouched so this is safe to call repeatedly.
bool MicroWakeWord::allocate_buffers_() {
ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE);
if (this->input_buffer_ == nullptr) {
// NOTE(review): the element count is multiplied by sizeof(int16_t) even though the
// allocator is already typed for int16_t — this appears to over-allocate 2x; confirm intent.
this->input_buffer_ = audio_samples_allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t));
if (this->input_buffer_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate input buffer");
return false;
}
}
if (this->preprocessor_audio_buffer_ == nullptr) {
// Holds exactly one feature-step worth of samples for the spectrogram frontend.
this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(this->new_samples_to_get_());
if (this->preprocessor_audio_buffer_ == nullptr) {
ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer.");
return false;
}
}
if (this->ring_buffer_.use_count() == 0) {
// use_count() == 0 means the shared_ptr is empty, i.e. no ring buffer exists yet.
this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t));
if (this->ring_buffer_.use_count() == 0) {
ESP_LOGE(TAG, "Could not allocate ring buffer");
return false;
}
}
return true;
}
// Releases the buffers created by allocate_buffers_(). Null checks make this safe to
// call even when allocation previously failed partway through.
void MicroWakeWord::deallocate_buffers_() {
ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE);
if (this->input_buffer_ != nullptr) {
// Size argument mirrors the one used at allocation time.
audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t));
this->input_buffer_ = nullptr;
}
if (this->preprocessor_audio_buffer_ != nullptr) {
audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_());
this->preprocessor_audio_buffer_ = nullptr;
}
// Drop this component's reference; the RingBuffer is freed once all owners release it.
this->ring_buffer_.reset();
}
// Initializes the spectrogram frontend state and loads every streaming model
// (wake word models plus, when compiled in, the VAD model).
// Returns false if any step fails; frontend state is freed on its own failure path,
// but models loaded before a later model failure are left loaded.
bool MicroWakeWord::load_models_() {
// Setup preprocessor feature generator
if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) {
ESP_LOGD(TAG, "Failed to populate frontend state");
FrontendFreeStateContents(&this->frontend_state_);
return false;
}
// Setup streaming models
for (auto &model : this->wake_word_models_) {
if (!model.load_model(this->streaming_op_resolver_)) {
ESP_LOGE(TAG, "Failed to initialize a wake word model.");
return false;
}
}
#ifdef USE_MICRO_WAKE_WORD_VAD
if (!this->vad_model_->load_model(this->streaming_op_resolver_)) {
ESP_LOGE(TAG, "Failed to initialize VAD model.");
return false;
}
#endif
return true;
}
// Frees the spectrogram frontend state and unloads all streaming models,
// releasing their tensor arenas. Counterpart to load_models_().
void MicroWakeWord::unload_models_() {
FrontendFreeStateContents(&this->frontend_state_);
for (auto &model : this->wake_word_models_) {
model.unload_model();
}
#ifdef USE_MICRO_WAKE_WORD_VAD
this->vad_model_->unload_model();
#endif
}
// Generates one window of spectrogram features from buffered audio and runs a
// streaming inference on every model with it. Returns early (no-op) when there
// isn't enough new audio to produce a feature window.
void MicroWakeWord::update_model_probabilities_() {
int8_t audio_features[PREPROCESSOR_FEATURE_SIZE];
if (!this->generate_features_for_window_(audio_features)) {
return;
}
// Increase the counter since the last positive detection
// (negative values suppress detections until enough windows have been processed).
this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
for (auto &model : this->wake_word_models_) {
// Perform inference
model.perform_streaming_inference(audio_features);
}
#ifdef USE_MICRO_WAKE_WORD_VAD
this->vad_model_->perform_streaming_inference(audio_features);
#endif
}
// Checks every wake word model for a positive detection. On the first model that
// triggers (and, when VAD is compiled in, only while the VAD model also reports
// voice activity) stores its wake word in detected_wake_word_ and returns true.
bool MicroWakeWord::detect_wake_words_() {
// Verify we have processed samples since the last positive detection
if (this->ignore_windows_ < 0) {
return false;
}
#ifdef USE_MICRO_WAKE_WORD_VAD
bool vad_state = this->vad_model_->determine_detected();
#endif
for (auto &model : this->wake_word_models_) {
if (model.determine_detected()) {
#ifdef USE_MICRO_WAKE_WORD_VAD
// VAD gate: only accept the detection if voice activity was also detected.
if (vad_state) {
#endif
this->detected_wake_word_ = model.get_wake_word();
return true;
#ifdef USE_MICRO_WAKE_WORD_VAD
} else {
ESP_LOGD(TAG, "Wake word model predicts %s, but VAD model doesn't.", model.get_wake_word().c_str());
}
#endif
}
}
return false;
}
// True when the ring buffer holds at least one feature-step worth of 16-bit samples,
// i.e. enough new audio to generate the next spectrogram window.
bool MicroWakeWord::has_enough_samples_() {
  const size_t samples_per_step = this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000);
  return this->ring_buffer_->available() >= samples_per_step * sizeof(int16_t);
}
bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
// Ensure we have enough new audio samples in the ring buffer for a full window
if (!this->has_enough_samples_()) {
return false;
}
size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_),
this->new_samples_to_get_() * sizeof(int16_t), pdMS_TO_TICKS(200));
if (bytes_read == 0) {
ESP_LOGE(TAG, "Could not read data from Ring Buffer");
} else if (bytes_read < this->new_samples_to_get_() * sizeof(int16_t)) {
ESP_LOGD(TAG, "Partial Read of Data by Model");
ESP_LOGD(TAG, "Could only read %d bytes when required %d bytes ", bytes_read,
(int) (this->new_samples_to_get_() * sizeof(int16_t)));
return false;
}
size_t num_samples_read;
struct FrontendOutput frontend_output = FrontendProcessSamples(
&this->frontend_state_, this->preprocessor_audio_buffer_, this->new_samples_to_get_(), &num_samples_read);
size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available,
int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) {
size_t processed_samples = 0;
struct FrontendOutput frontend_output =
FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples);
for (size_t i = 0; i < frontend_output.size; ++i) {
// These scaling values are set to match the TFLite audio frontend int8 output.
@ -379,8 +393,8 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F
// for historical reasons, to match up with the output of other feature
// generators.
// The process is then further complicated when we quantize the model. This
// means we have to scale the 0.0 to 26.0 real values to the -128 to 127
// signed integer numbers.
// means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN)
// to 127 (INT8_MAX) signed integer numbers.
// All this means that to get matching values from our integer feature
// output into the tensor input, we have to perform:
// input = (((feature / 25.6) / 26.0) * 256) - 128
@ -389,74 +403,63 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F
constexpr int32_t value_scale = 256;
constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding
int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div;
value -= 128;
if (value < -128) {
value = -128;
}
if (value > 127) {
value = 127;
}
features[i] = value;
value -= INT8_MIN;
features_buffer[i] = clamp<int8_t>(value, INT8_MIN, INT8_MAX);
}
return true;
return processed_samples;
}
void MicroWakeWord::reset_states_() {
ESP_LOGD(TAG, "Resetting buffers and probabilities");
this->ring_buffer_->reset();
this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
void MicroWakeWord::process_probabilities_() {
#ifdef USE_MICRO_WAKE_WORD_VAD
DetectionEvent vad_state = this->vad_model_->determine_detected();
this->vad_state_ = vad_state.detected; // atomic write, so thread safe
#endif
for (auto &model : this->wake_word_models_) {
model.reset_probabilities();
if (model->get_unprocessed_probability_status()) {
// Only detect wake words if there is a new probability since the last check
DetectionEvent wake_word_state = model->determine_detected();
if (wake_word_state.detected) {
#ifdef USE_MICRO_WAKE_WORD_VAD
if (vad_state.detected) {
#endif
xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
model->reset_probabilities();
#ifdef USE_MICRO_WAKE_WORD_VAD
} else {
wake_word_state.blocked_by_vad = true;
xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
}
#endif
}
}
}
}
void MicroWakeWord::unload_models_() {
  // Release each wake word model's TFLite interpreter and arena memory.
  for (auto *wake_word_model : this->wake_word_models_) {
    wake_word_model->unload_model();
  }
#ifdef USE_MICRO_WAKE_WORD_VAD
  // Clear the VAD model's probability history before freeing its resources.
  this->vad_model_->reset_probabilities();
  this->vad_model_->unload_model();
#endif
}
bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) {
if (op_resolver.AddCallOnce() != kTfLiteOk)
return false;
if (op_resolver.AddVarHandle() != kTfLiteOk)
return false;
if (op_resolver.AddReshape() != kTfLiteOk)
return false;
if (op_resolver.AddReadVariable() != kTfLiteOk)
return false;
if (op_resolver.AddStridedSlice() != kTfLiteOk)
return false;
if (op_resolver.AddConcatenation() != kTfLiteOk)
return false;
if (op_resolver.AddAssignVariable() != kTfLiteOk)
return false;
if (op_resolver.AddConv2D() != kTfLiteOk)
return false;
if (op_resolver.AddMul() != kTfLiteOk)
return false;
if (op_resolver.AddAdd() != kTfLiteOk)
return false;
if (op_resolver.AddMean() != kTfLiteOk)
return false;
if (op_resolver.AddFullyConnected() != kTfLiteOk)
return false;
if (op_resolver.AddLogistic() != kTfLiteOk)
return false;
if (op_resolver.AddQuantize() != kTfLiteOk)
return false;
if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
return false;
if (op_resolver.AddAveragePool2D() != kTfLiteOk)
return false;
if (op_resolver.AddMaxPool2D() != kTfLiteOk)
return false;
if (op_resolver.AddPad() != kTfLiteOk)
return false;
if (op_resolver.AddPack() != kTfLiteOk)
return false;
if (op_resolver.AddSplitV() != kTfLiteOk)
return false;
bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) {
bool success = true;
return true;
for (auto &model : this->wake_word_models_) {
// Perform inference
success = success & model->perform_streaming_inference(audio_features);
}
#ifdef USE_MICRO_WAKE_WORD_VAD
success = success & this->vad_model_->perform_streaming_inference(audio_features);
#endif
return success;
}
} // namespace micro_wake_word

View File

@ -5,33 +5,27 @@
#include "preprocessor_settings.h"
#include "streaming_model.h"
#include "esphome/components/microphone/microphone_source.h"
#include "esphome/core/automation.h"
#include "esphome/core/component.h"
#include "esphome/core/ring_buffer.h"
#include "esphome/components/microphone/microphone_source.h"
#include <freertos/event_groups.h>
#include <frontend.h>
#include <frontend_util.h>
#include <tensorflow/lite/core/c/common.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
namespace esphome {
namespace micro_wake_word {
enum State {
IDLE,
START_MICROPHONE,
STARTING_MICROPHONE,
STARTING,
DETECTING_WAKE_WORD,
STOP_MICROPHONE,
STOPPING_MICROPHONE,
STOPPING,
STOPPED,
};
// The number of audio slices to process before accepting a positive detection
static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74;
class MicroWakeWord : public Component {
public:
void setup() override;
@ -42,7 +36,7 @@ class MicroWakeWord : public Component {
void start();
void stop();
bool is_running() const { return this->state_ != State::IDLE; }
bool is_running() const { return this->state_ != State::STOPPED; }
void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; }
@ -50,118 +44,87 @@ class MicroWakeWord : public Component {
this->microphone_source_ = microphone_source;
}
void set_stop_after_detection(bool stop_after_detection) { this->stop_after_detection_ = stop_after_detection; }
Trigger<std::string> *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; }
void add_wake_word_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
const std::string &wake_word, size_t tensor_arena_size);
void add_wake_word_model(WakeWordModel *model);
#ifdef USE_MICRO_WAKE_WORD_VAD
void add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size,
void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
size_t tensor_arena_size);
// Intended for the voice assistant component to fetch VAD status
bool get_vad_state() { return this->vad_state_; }
#endif
// Intended for the voice assistant component to access which wake words are available
// Since these are pointers to the WakeWordModel objects, the voice assistant component can enable or disable them
std::vector<WakeWordModel *> get_wake_words();
protected:
microphone::MicrophoneSource *microphone_source_{nullptr};
Trigger<std::string> *wake_word_detected_trigger_ = new Trigger<std::string>();
State state_{State::IDLE};
State state_{State::STOPPED};
std::shared_ptr<RingBuffer> ring_buffer_;
std::vector<WakeWordModel> wake_word_models_;
std::weak_ptr<RingBuffer> ring_buffer_;
std::vector<WakeWordModel *> wake_word_models_;
#ifdef USE_MICRO_WAKE_WORD_VAD
std::unique_ptr<VADModel> vad_model_;
bool vad_state_{false};
#endif
tflite::MicroMutableOpResolver<20> streaming_op_resolver_;
bool pending_start_{false};
bool pending_stop_{false};
bool stop_after_detection_;
uint8_t features_step_size_;
// Audio frontend handles generating spectrogram features
struct FrontendConfig frontend_config_;
struct FrontendState frontend_state_;
// When the wake word detection first starts, we ignore this many audio
// feature slices before accepting a positive detection
int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION};
// Handles managing the stop/state of the inference task
EventGroupHandle_t event_group_;
uint8_t features_step_size_;
// Used to send messages about the models' states to the main loop
QueueHandle_t detection_queue_;
// Stores audio read from the microphone before being added to the ring buffer.
int16_t *input_buffer_{nullptr};
// Stores audio to be fed into the audio frontend for generating features.
int16_t *preprocessor_audio_buffer_{nullptr};
static void inference_task(void *params);
TaskHandle_t inference_task_handle_{nullptr};
bool detected_{false};
std::string detected_wake_word_{""};
/// @brief Suspends the inference task
void suspend_task_();
/// @brief Resumes the inference task
void resume_task_();
void set_state_(State state);
/// @brief Tests if there are enough samples in the ring buffer to generate new features.
/// @return True if enough samples, false otherwise.
bool has_enough_samples_();
/// @brief Generates spectrogram features from an input buffer of audio samples
/// @param audio_buffer (int16_t *) Buffer containing input audio samples
/// @param samples_available (size_t) Number of samples available in the input buffer
/// @param features_buffer (int8_t *) Buffer to store generated features
/// @return (size_t) Number of samples processed from the input buffer
size_t generate_features_(int16_t *audio_buffer, size_t samples_available,
int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]);
/// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_
/// @return True if successful, false otherwise
bool allocate_buffers_();
/// @brief Processes any new probabilities for each model. If any wake word is detected, it will send a DetectionEvent
/// to the detection_queue_.
void process_probabilities_();
/// @brief Frees memory allocated for input_buffer_ and preprocessor_audio_buffer_
void deallocate_buffers_();
/// @brief Loads streaming models and prepares the feature generation frontend
/// @return True if successful, false otherwise
bool load_models_();
/// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. Frees memory used by the feature
/// generation frontend.
/// @brief Deletes each model's TFLite interpreters and frees tensor arena memory.
void unload_models_();
/** Performs inference with each configured model
*
* If enough audio samples are available, it will generate one slice of new features.
* It then loops through and performs inference with each of the loaded models.
*/
void update_model_probabilities_();
/** Checks every model's recent probabilities to determine if the wake word has been predicted
*
* Verifies the models have processed enough new samples for accurate predictions.
* Sets detected_wake_word_ to the wake word, if one is detected.
* @return True if a wake word is predicted, false otherwise
*/
bool detect_wake_words_();
/** Generates features for a window of audio samples
*
* Reads samples from the ring buffer and feeds them into the preprocessor frontend.
* Adapted from TFLite microspeech frontend.
* @param features int8_t array to store the audio features
* @return True if successful, false otherwise.
*/
bool generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]);
/// @brief Resets the ring buffer, ignore_windows_, and sliding window probabilities
void reset_states_();
/// @brief Returns true if successfully registered the streaming model's TensorFlow operations
bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver);
/// @brief Runs an inference with each model using the new spectrogram features
/// @param audio_features (int8_t *) Buffer containing new spectrogram features
/// @return True if successful, false if any errors were encountered
bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]);
inline uint16_t new_samples_to_get_() { return (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)); }
};
// Automation action that starts wake word detection on the parented
// MicroWakeWord component when played.
template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> {
 public:
  void play(Ts... x) override { this->parent_->start(); }
};
// Automation action that stops wake word detection on the parented
// MicroWakeWord component when played.
template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<MicroWakeWord> {
 public:
  void play(Ts... x) override { this->parent_->stop(); }
};
// Automation condition that reports whether the parented MicroWakeWord
// component is currently running (see MicroWakeWord::is_running()).
template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<MicroWakeWord> {
 public:
  bool check(Ts... x) override { return this->parent_->is_running(); }
};
} // namespace micro_wake_word
} // namespace esphome

View File

@ -7,6 +7,10 @@
namespace esphome {
namespace micro_wake_word {
// Settings for controlling the spectrogram feature generation by the preprocessor.
// These must match the settings used when training a particular model.
// All microWakeWord models have been trained with these specific parameters.
// The number of features the audio preprocessor generates per slice
static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40;
// Duration of each slice used as input into the preprocessor
@ -14,6 +18,21 @@ static const uint8_t FEATURE_DURATION_MS = 30;
// Audio sample frequency in hertz
static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000;
// Lower and upper frequency bounds (Hz) of the audio frontend's filterbank.
static const float FILTERBANK_LOWER_BAND_LIMIT = 125.0;
static const float FILTERBANK_UPPER_BAND_LIMIT = 7500.0;

// Noise reduction smoothing parameters for spectrogram feature generation.
static const uint8_t NOISE_REDUCTION_SMOOTHING_BITS = 10;
static const float NOISE_REDUCTION_EVEN_SMOOTHING = 0.025;
static const float NOISE_REDUCTION_ODD_SMOOTHING = 0.06;
static const float NOISE_REDUCTION_MIN_SIGNAL_REMAINING = 0.05;

// PCAN automatic gain control settings (presumably per-channel amplitude
// normalization, as in the TFLite audio frontend — confirm against frontend_util.h).
static const bool PCAN_GAIN_CONTROL_ENABLE_PCAN = true;
static const float PCAN_GAIN_CONTROL_STRENGTH = 0.95;
static const float PCAN_GAIN_CONTROL_OFFSET = 80.0;
static const uint8_t PCAN_GAIN_CONTROL_GAIN_BITS = 21;

// Logarithmic scaling applied to the filterbank output.
static const bool LOG_SCALE_ENABLE_LOG = true;
static const uint8_t LOG_SCALE_SCALE_SHIFT = 6;
} // namespace micro_wake_word
} // namespace esphome

View File

@ -1,8 +1,7 @@
#ifdef USE_ESP_IDF
#include "streaming_model.h"
#include "esphome/core/hal.h"
#ifdef USE_ESP_IDF
#include "esphome/core/helpers.h"
#include "esphome/core/log.h"
@ -13,18 +12,18 @@ namespace micro_wake_word {
void WakeWordModel::log_model_config() {
ESP_LOGCONFIG(TAG, " - Wake Word: %s", this->wake_word_.c_str());
ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_);
ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f);
ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_);
}
void VADModel::log_model_config() {
ESP_LOGCONFIG(TAG, " - VAD Model");
ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_);
ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f);
ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_);
}
bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) {
ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);
bool StreamingModel::load_model_() {
RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE);
if (this->tensor_arena_ == nullptr) {
this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_);
@ -51,8 +50,9 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver)
}
if (this->interpreter_ == nullptr) {
this->interpreter_ = make_unique<tflite::MicroInterpreter>(
tflite::GetModel(this->model_start_), op_resolver, this->tensor_arena_, this->tensor_arena_size_, this->mrv_);
this->interpreter_ =
make_unique<tflite::MicroInterpreter>(tflite::GetModel(this->model_start_), this->streaming_op_resolver_,
this->tensor_arena_, this->tensor_arena_size_, this->mrv_);
if (this->interpreter_->AllocateTensors() != kTfLiteOk) {
ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
return false;
@ -84,34 +84,55 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver)
}
}
this->loaded_ = true;
this->reset_probabilities();
return true;
}
void StreamingModel::unload_model() {
this->interpreter_.reset();
ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);
RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE);
arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
this->tensor_arena_ = nullptr;
arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
this->var_arena_ = nullptr;
if (this->tensor_arena_ != nullptr) {
arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
this->tensor_arena_ = nullptr;
}
if (this->var_arena_ != nullptr) {
arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
this->var_arena_ = nullptr;
}
this->loaded_ = false;
}
bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
if (this->interpreter_ != nullptr) {
if (this->enabled_ && !this->loaded_) {
// Model is enabled but isn't loaded
if (!this->load_model_()) {
return false;
}
}
if (!this->enabled_ && this->loaded_) {
// Model is disabled but still loaded
this->unload_model();
return true;
}
if (this->loaded_) {
TfLiteTensor *input = this->interpreter_->input(0);
uint8_t stride = this->interpreter_->input(0)->dims->data[1];
this->current_stride_step_ = this->current_stride_step_ % stride;
std::memmove(
(int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_,
features, PREPROCESSOR_FEATURE_SIZE);
++this->current_stride_step_;
uint8_t stride = this->interpreter_->input(0)->dims->data[1];
if (this->current_stride_step_ >= stride) {
this->current_stride_step_ = 0;
TfLiteStatus invoke_status = this->interpreter_->Invoke();
if (invoke_status != kTfLiteOk) {
ESP_LOGW(TAG, "Streaming interpreter invoke failed");
@ -124,65 +145,159 @@ bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCES
if (this->last_n_index_ == this->sliding_window_size_)
this->last_n_index_ = 0;
this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0]; // probability;
this->unprocessed_probability_status_ = true;
}
return true;
this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
}
ESP_LOGE(TAG, "Streaming interpreter is not initialized.");
return false;
return true;
}
void StreamingModel::reset_probabilities() {
  // Zero the sliding window of recent probabilities so stale values cannot
  // trigger a detection, and restart the warm-up ignore window.
  this->recent_streaming_probabilities_.assign(this->recent_streaming_probabilities_.size(), 0);
  this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
}
WakeWordModel::WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
const std::string &wake_word, size_t tensor_arena_size) {
WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff,
size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
bool default_enabled, bool internal_only) {
this->id_ = id;
this->model_start_ = model_start;
this->probability_cutoff_ = probability_cutoff;
this->sliding_window_size_ = sliding_window_average_size;
this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0);
this->wake_word_ = wake_word;
this->tensor_arena_size_ = tensor_arena_size;
this->register_streaming_ops_(this->streaming_op_resolver_);
this->current_stride_step_ = 0;
this->internal_only_ = internal_only;
this->pref_ = global_preferences->make_preference<bool>(fnv1_hash(id));
bool enabled;
if (this->pref_.load(&enabled)) {
// Use the enabled state loaded from flash
this->enabled_ = enabled;
} else {
// If no state saved, then use the default
this->enabled_ = default_enabled;
}
};
bool WakeWordModel::determine_detected() {
void WakeWordModel::enable() {
  // Mark the model enabled; the next perform_streaming_inference call loads it.
  this->enabled_ = true;
  if (this->internal_only_)
    return;  // Internal-only models never persist their state to flash
  this->pref_.save(&this->enabled_);
}
void WakeWordModel::disable() {
  // Mark the model disabled; the next perform_streaming_inference call unloads it.
  this->enabled_ = false;
  if (this->internal_only_)
    return;  // Internal-only models never persist their state to flash
  this->pref_.save(&this->enabled_);
}
DetectionEvent WakeWordModel::determine_detected() {
DetectionEvent detection_event;
detection_event.wake_word = &this->wake_word_;
detection_event.max_probability = 0;
detection_event.average_probability = 0;
if ((this->ignore_windows_ < 0) || !this->enabled_) {
detection_event.detected = false;
return detection_event;
}
uint32_t sum = 0;
for (auto &prob : this->recent_streaming_probabilities_) {
detection_event.max_probability = std::max(detection_event.max_probability, prob);
sum += prob;
}
float sliding_window_average = static_cast<float>(sum) / static_cast<float>(255 * this->sliding_window_size_);
detection_event.average_probability = sum / this->sliding_window_size_;
detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_;
// Detect the wake word if the sliding window average is above the cutoff
if (sliding_window_average > this->probability_cutoff_) {
ESP_LOGD(TAG, "The '%s' model sliding average probability is %.3f and most recent probability is %.3f",
this->wake_word_.c_str(), sliding_window_average,
this->recent_streaming_probabilities_[this->last_n_index_] / (255.0));
return true;
}
return false;
this->unprocessed_probability_status_ = false;
return detection_event;
}
VADModel::VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size,
VADModel::VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
size_t tensor_arena_size) {
this->model_start_ = model_start;
this->probability_cutoff_ = probability_cutoff;
this->sliding_window_size_ = sliding_window_size;
this->recent_streaming_probabilities_.resize(sliding_window_size, 0);
this->tensor_arena_size_ = tensor_arena_size;
};
this->register_streaming_ops_(this->streaming_op_resolver_);
}
DetectionEvent VADModel::determine_detected() {
DetectionEvent detection_event;
detection_event.max_probability = 0;
detection_event.average_probability = 0;
if (!this->enabled_) {
// We disabled the VAD model for some reason... so we shouldn't block wake words from being detected
detection_event.detected = true;
return detection_event;
}
bool VADModel::determine_detected() {
uint32_t sum = 0;
for (auto &prob : this->recent_streaming_probabilities_) {
detection_event.max_probability = std::max(detection_event.max_probability, prob);
sum += prob;
}
float sliding_window_average = static_cast<float>(sum) / static_cast<float>(255 * this->sliding_window_size_);
detection_event.average_probability = sum / this->sliding_window_size_;
detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_);
return sliding_window_average > this->probability_cutoff_;
return detection_event;
}
bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) {
  // Registers every TensorFlow Lite Micro operation the streaming models use.
  // The && chain short-circuits on the first failed registration, matching the
  // original early-return-per-op behavior; returns true only if all 20 succeed.
  return (op_resolver.AddCallOnce() == kTfLiteOk) && (op_resolver.AddVarHandle() == kTfLiteOk) &&
         (op_resolver.AddReshape() == kTfLiteOk) && (op_resolver.AddReadVariable() == kTfLiteOk) &&
         (op_resolver.AddStridedSlice() == kTfLiteOk) && (op_resolver.AddConcatenation() == kTfLiteOk) &&
         (op_resolver.AddAssignVariable() == kTfLiteOk) && (op_resolver.AddConv2D() == kTfLiteOk) &&
         (op_resolver.AddMul() == kTfLiteOk) && (op_resolver.AddAdd() == kTfLiteOk) &&
         (op_resolver.AddMean() == kTfLiteOk) && (op_resolver.AddFullyConnected() == kTfLiteOk) &&
         (op_resolver.AddLogistic() == kTfLiteOk) && (op_resolver.AddQuantize() == kTfLiteOk) &&
         (op_resolver.AddDepthwiseConv2D() == kTfLiteOk) && (op_resolver.AddAveragePool2D() == kTfLiteOk) &&
         (op_resolver.AddMaxPool2D() == kTfLiteOk) && (op_resolver.AddPad() == kTfLiteOk) &&
         (op_resolver.AddPack() == kTfLiteOk) && (op_resolver.AddSplitV() == kTfLiteOk);
}
} // namespace micro_wake_word

View File

@ -4,6 +4,8 @@
#include "preprocessor_settings.h"
#include "esphome/core/preferences.h"
#include <tensorflow/lite/core/c/common.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
@ -11,30 +13,63 @@
namespace esphome {
namespace micro_wake_word {
static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100;
static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024;
// Describes the outcome of checking a model's sliding window of probabilities.
// Sent from the inference task to the main loop through the detection queue.
struct DetectionEvent {
  std::string *wake_word;  // Friendly name of the wake word (non-owning pointer)
  bool detected;           // True if the sliding window sum exceeded the probability cutoff
  bool partially_detection;  // Set if the most recent probability exceeds the threshold, but the sliding window
                             // average hasn't yet
  uint8_t max_probability;      // Largest quantized (0-255) probability in the sliding window
  uint8_t average_probability;  // Mean quantized (0-255) probability over the sliding window
  bool blocked_by_vad = false;  // True if the detection was suppressed because the VAD model detected no voice
};
class StreamingModel {
public:
virtual void log_model_config() = 0;
virtual bool determine_detected() = 0;
virtual DetectionEvent determine_detected() = 0;
// Performs inference on the given features.
// - If the model is enabled but not loaded, it will load it
// - If the model is disabled but loaded, it will unload it
// Returns true if successful or false if there is an error
bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]);
/// @brief Sets all recent_streaming_probabilities to 0
/// @brief Sets all recent_streaming_probabilities to 0 and resets the ignore window count
void reset_probabilities();
/// @brief Allocates tensor and variable arenas and sets up the model interpreter
/// @param op_resolver MicroMutableOpResolver object that must exist until the model is unloaded
/// @return True if successful, false otherwise
bool load_model(tflite::MicroMutableOpResolver<20> &op_resolver);
/// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory
void unload_model();
protected:
uint8_t current_stride_step_{0};
/// @brief Enable the model. The next performing_streaming_inference call will load it.
virtual void enable() { this->enabled_ = true; }
float probability_cutoff_;
/// @brief Disable the model. The next performing_streaming_inference call will unload it.
virtual void disable() { this->enabled_ = false; }
/// @brief Return true if the model is enabled.
bool is_enabled() { return this->enabled_; }
bool get_unprocessed_probability_status() { return this->unprocessed_probability_status_; }
protected:
/// @brief Allocates tensor and variable arenas and sets up the model interpreter
/// @return True if successful, false otherwise
bool load_model_();
/// @brief Returns true if successfully registered the streaming model's TensorFlow operations
bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver);
tflite::MicroMutableOpResolver<20> streaming_op_resolver_;
bool loaded_{false};
bool enabled_{true};
bool unprocessed_probability_status_{false};
uint8_t current_stride_step_{0};
int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION};
uint8_t probability_cutoff_; // Quantized probability cutoff mapping 0.0 - 1.0 to 0 - 255
size_t sliding_window_size_;
size_t last_n_index_{0};
size_t tensor_arena_size_;
@ -50,32 +85,62 @@ class StreamingModel {
class WakeWordModel final : public StreamingModel {
public:
WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
const std::string &wake_word, size_t tensor_arena_size);
/// @brief Constructs a wake word model object
/// @param id (std::string) identifier for this model
/// @param model_start (const uint8_t *) pointer to the start of the model's TFLite FlatBuffer
/// @param probability_cutoff (uint8_t) probability cutoff for accepting the wake word has been said
/// @param sliding_window_average_size (size_t) the length of the sliding window computing the mean rolling
/// probability
/// @param wake_word (std::string) Friendly name of the wake word
/// @param tensor_arena_size (size_t) Size in bytes for allocating the tensor arena
/// @param default_enabled (bool) If true, it will be enabled by default on first boot
/// @param internal_only (bool) If true, the model will not be exposed to HomeAssistant as an available model
WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff,
size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
bool default_enabled, bool internal_only);
void log_model_config() override;
/// @brief Checks for the wake word by comparing the mean probability in the sliding window with the probability
/// cutoff
/// @return True if wake word is detected, false otherwise
bool determine_detected() override;
DetectionEvent determine_detected() override;
const std::string &get_id() const { return this->id_; }
const std::string &get_wake_word() const { return this->wake_word_; }
void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); }
const std::vector<std::string> &get_trained_languages() const { return this->trained_languages_; }
/// @brief Enable the model and save to flash. The next performing_streaming_inference call will load it.
void enable() override;
/// @brief Disable the model and save to flash. The next performing_streaming_inference call will unload it.
void disable() override;
bool get_internal_only() { return this->internal_only_; }
protected:
std::string id_;
std::string wake_word_;
std::vector<std::string> trained_languages_;
bool internal_only_;
ESPPreferenceObject pref_;
};
class VADModel final : public StreamingModel {
public:
VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size);
VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
size_t tensor_arena_size);
void log_model_config() override;
/// @brief Checks for voice activity by comparing the max probability in the sliding window with the probability
/// cutoff
/// @return True if voice activity is detected, false otherwise
bool determine_detected() override;
DetectionEvent determine_detected() override;
};
} // namespace micro_wake_word

View File

@ -79,6 +79,7 @@
#define USE_LVGL_TEXTAREA
#define USE_LVGL_TILEVIEW
#define USE_LVGL_TOUCHSCREEN
#define USE_MICRO_WAKE_WORD
#define USE_MD5
#define USE_MDNS
#define USE_MEDIA_PLAYER

View File

@ -14,8 +14,24 @@ micro_wake_word:
microphone: echo_microphone
on_wake_word_detected:
- logger.log: "Wake word detected"
- micro_wake_word.stop:
- if:
condition:
- micro_wake_word.model_is_enabled: hey_jarvis_model
then:
- micro_wake_word.disable_model: hey_jarvis_model
else:
- micro_wake_word.enable_model: hey_jarvis_model
- if:
condition:
- not:
- micro_wake_word.is_running:
then:
micro_wake_word.start:
stop_after_detection: false
models:
- model: hey_jarvis
probability_cutoff: 0.7
id: hey_jarvis_model
- model: okay_nabu
sliding_window_size: 5