mirror of
https://github.com/esphome/esphome.git
synced 2025-07-28 14:16:40 +00:00
[micro_wake_word] add new VPE features (#8655)
This commit is contained in:
parent
6de6a0c82c
commit
cdc77506de
@ -12,6 +12,7 @@ import esphome.config_validation as cv
|
||||
from esphome.const import (
|
||||
CONF_FILE,
|
||||
CONF_ID,
|
||||
CONF_INTERNAL,
|
||||
CONF_MICROPHONE,
|
||||
CONF_MODEL,
|
||||
CONF_PASSWORD,
|
||||
@ -40,6 +41,7 @@ CONF_ON_WAKE_WORD_DETECTED = "on_wake_word_detected"
|
||||
CONF_PROBABILITY_CUTOFF = "probability_cutoff"
|
||||
CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size"
|
||||
CONF_SLIDING_WINDOW_SIZE = "sliding_window_size"
|
||||
CONF_STOP_AFTER_DETECTION = "stop_after_detection"
|
||||
CONF_TENSOR_ARENA_SIZE = "tensor_arena_size"
|
||||
CONF_VAD = "vad"
|
||||
|
||||
@ -49,13 +51,20 @@ micro_wake_word_ns = cg.esphome_ns.namespace("micro_wake_word")
|
||||
|
||||
MicroWakeWord = micro_wake_word_ns.class_("MicroWakeWord", cg.Component)
|
||||
|
||||
DisableModelAction = micro_wake_word_ns.class_("DisableModelAction", automation.Action)
|
||||
EnableModelAction = micro_wake_word_ns.class_("EnableModelAction", automation.Action)
|
||||
StartAction = micro_wake_word_ns.class_("StartAction", automation.Action)
|
||||
StopAction = micro_wake_word_ns.class_("StopAction", automation.Action)
|
||||
|
||||
ModelIsEnabledCondition = micro_wake_word_ns.class_(
|
||||
"ModelIsEnabledCondition", automation.Condition
|
||||
)
|
||||
IsRunningCondition = micro_wake_word_ns.class_(
|
||||
"IsRunningCondition", automation.Condition
|
||||
)
|
||||
|
||||
WakeWordModel = micro_wake_word_ns.class_("WakeWordModel")
|
||||
|
||||
|
||||
def _validate_json_filename(value):
|
||||
value = cv.string(value)
|
||||
@ -169,9 +178,10 @@ def _convert_manifest_v1_to_v2(v1_manifest):
|
||||
|
||||
# Original Inception-based V1 manifest models require a minimum of 45672 bytes
|
||||
v2_manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE] = 45672
|
||||
|
||||
# Original Inception-based V1 manifest models use a 20 ms feature step size
|
||||
v2_manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE] = 20
|
||||
# Original Inception-based V1 manifest models were trained only on TTS English samples
|
||||
v2_manifest[KEY_TRAINED_LANGUAGES] = ["en"]
|
||||
|
||||
return v2_manifest
|
||||
|
||||
@ -296,14 +306,16 @@ MODEL_SOURCE_SCHEMA = cv.Any(
|
||||
|
||||
MODEL_SCHEMA = cv.Schema(
|
||||
{
|
||||
cv.GenerateID(CONF_ID): cv.declare_id(WakeWordModel),
|
||||
cv.Optional(CONF_MODEL): MODEL_SOURCE_SCHEMA,
|
||||
cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage,
|
||||
cv.Optional(CONF_SLIDING_WINDOW_SIZE): cv.positive_int,
|
||||
cv.Optional(CONF_INTERNAL, default=False): cv.boolean,
|
||||
cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8),
|
||||
}
|
||||
)
|
||||
|
||||
# Provide a default VAD model that could be overridden
|
||||
# Provides a default VAD model that could be overridden
|
||||
VAD_MODEL_SCHEMA = MODEL_SCHEMA.extend(
|
||||
cv.Schema(
|
||||
{
|
||||
@ -343,6 +355,7 @@ CONFIG_SCHEMA = cv.All(
|
||||
single=True
|
||||
),
|
||||
cv.Optional(CONF_VAD): _maybe_empty_vad_schema,
|
||||
cv.Optional(CONF_STOP_AFTER_DETECTION, default=True): cv.boolean,
|
||||
cv.Optional(CONF_MODEL): cv.invalid(
|
||||
f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter."
|
||||
),
|
||||
@ -433,29 +446,20 @@ async def to_code(config):
|
||||
mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE])
|
||||
cg.add(var.set_microphone_source(mic_source))
|
||||
|
||||
cg.add_define("USE_MICRO_WAKE_WORD")
|
||||
cg.add_define("USE_OTA_STATE_CALLBACK")
|
||||
|
||||
esp32.add_idf_component(
|
||||
name="esp-tflite-micro",
|
||||
repo="https://github.com/espressif/esp-tflite-micro",
|
||||
ref="v1.3.1",
|
||||
)
|
||||
# add esp-nn dependency for tflite-micro to work around https://github.com/espressif/esp-nn/issues/17
|
||||
# ...remove after switching to IDF 5.1.4+
|
||||
esp32.add_idf_component(
|
||||
name="esp-nn",
|
||||
repo="https://github.com/espressif/esp-nn",
|
||||
ref="v1.1.0",
|
||||
ref="v1.3.3.1",
|
||||
)
|
||||
|
||||
cg.add_build_flag("-DTF_LITE_STATIC_MEMORY")
|
||||
cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON")
|
||||
cg.add_build_flag("-DESP_NN")
|
||||
|
||||
if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED):
|
||||
await automation.build_automation(
|
||||
var.get_wake_word_detected_trigger(),
|
||||
[(cg.std_string, "wake_word")],
|
||||
on_wake_word_detection_config,
|
||||
)
|
||||
cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0")
|
||||
|
||||
if vad_model := config.get(CONF_VAD):
|
||||
cg.add_define("USE_MICRO_WAKE_WORD_VAD")
|
||||
@ -463,7 +467,7 @@ async def to_code(config):
|
||||
# Use the general model loading code for the VAD codegen
|
||||
config[CONF_MODELS].append(vad_model)
|
||||
|
||||
for model_parameters in config[CONF_MODELS]:
|
||||
for i, model_parameters in enumerate(config[CONF_MODELS]):
|
||||
model_config = model_parameters.get(CONF_MODEL)
|
||||
data = []
|
||||
manifest, data = _model_config_to_manifest_data(model_config)
|
||||
@ -474,6 +478,8 @@ async def to_code(config):
|
||||
probability_cutoff = model_parameters.get(
|
||||
CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF]
|
||||
)
|
||||
quantized_probability_cutoff = int(probability_cutoff * 255)
|
||||
|
||||
sliding_window_size = model_parameters.get(
|
||||
CONF_SLIDING_WINDOW_SIZE,
|
||||
manifest[KEY_MICRO][CONF_SLIDING_WINDOW_SIZE],
|
||||
@ -483,24 +489,40 @@ async def to_code(config):
|
||||
cg.add(
|
||||
var.add_vad_model(
|
||||
prog_arr,
|
||||
probability_cutoff,
|
||||
quantized_probability_cutoff,
|
||||
sliding_window_size,
|
||||
manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE],
|
||||
)
|
||||
)
|
||||
else:
|
||||
cg.add(
|
||||
var.add_wake_word_model(
|
||||
prog_arr,
|
||||
probability_cutoff,
|
||||
sliding_window_size,
|
||||
manifest[KEY_WAKE_WORD],
|
||||
manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE],
|
||||
)
|
||||
# Only enable the first wake word by default. After first boot, the enable state is saved/loaded to the flash
|
||||
default_enabled = i == 0
|
||||
wake_word_model = cg.new_Pvariable(
|
||||
model_parameters[CONF_ID],
|
||||
str(model_parameters[CONF_ID]),
|
||||
prog_arr,
|
||||
quantized_probability_cutoff,
|
||||
sliding_window_size,
|
||||
manifest[KEY_WAKE_WORD],
|
||||
manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE],
|
||||
default_enabled,
|
||||
model_parameters[CONF_INTERNAL],
|
||||
)
|
||||
|
||||
for lang in manifest[KEY_TRAINED_LANGUAGES]:
|
||||
cg.add(wake_word_model.add_trained_language(lang))
|
||||
|
||||
cg.add(var.add_wake_word_model(wake_word_model))
|
||||
|
||||
cg.add(var.set_features_step_size(manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE]))
|
||||
cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0")
|
||||
cg.add(var.set_stop_after_detection(config[CONF_STOP_AFTER_DETECTION]))
|
||||
|
||||
if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED):
|
||||
await automation.build_automation(
|
||||
var.get_wake_word_detected_trigger(),
|
||||
[(cg.std_string, "wake_word")],
|
||||
on_wake_word_detection_config,
|
||||
)
|
||||
|
||||
|
||||
MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)})
|
||||
@ -515,3 +537,30 @@ async def micro_wake_word_action_to_code(config, action_id, template_arg, args):
|
||||
var = cg.new_Pvariable(action_id, template_arg)
|
||||
await cg.register_parented(var, config[CONF_ID])
|
||||
return var
|
||||
|
||||
|
||||
MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA = automation.maybe_simple_id(
|
||||
{
|
||||
cv.Required(CONF_ID): cv.use_id(WakeWordModel),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@register_action(
|
||||
"micro_wake_word.enable_model",
|
||||
EnableModelAction,
|
||||
MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA,
|
||||
)
|
||||
@register_action(
|
||||
"micro_wake_word.disable_model",
|
||||
DisableModelAction,
|
||||
MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA,
|
||||
)
|
||||
@register_condition(
|
||||
"micro_wake_word.model_is_enabled",
|
||||
ModelIsEnabledCondition,
|
||||
MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA,
|
||||
)
|
||||
async def model_action(config, action_id, template_arg, args):
|
||||
parent = await cg.get_variable(config[CONF_ID])
|
||||
return cg.new_Pvariable(action_id, template_arg, parent)
|
||||
|
54
esphome/components/micro_wake_word/automation.h
Normal file
54
esphome/components/micro_wake_word/automation.h
Normal file
@ -0,0 +1,54 @@
|
||||
#pragma once
|
||||
|
||||
#include "micro_wake_word.h"
|
||||
#include "streaming_model.h"
|
||||
|
||||
#ifdef USE_ESP_IDF
|
||||
namespace esphome {
|
||||
namespace micro_wake_word {
|
||||
|
||||
template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> {
|
||||
public:
|
||||
void play(Ts... x) override { this->parent_->start(); }
|
||||
};
|
||||
|
||||
template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<MicroWakeWord> {
|
||||
public:
|
||||
void play(Ts... x) override { this->parent_->stop(); }
|
||||
};
|
||||
|
||||
template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<MicroWakeWord> {
|
||||
public:
|
||||
bool check(Ts... x) override { return this->parent_->is_running(); }
|
||||
};
|
||||
|
||||
template<typename... Ts> class EnableModelAction : public Action<Ts...> {
|
||||
public:
|
||||
explicit EnableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {}
|
||||
void play(Ts... x) override { this->wake_word_model_->enable(); }
|
||||
|
||||
protected:
|
||||
WakeWordModel *wake_word_model_;
|
||||
};
|
||||
|
||||
template<typename... Ts> class DisableModelAction : public Action<Ts...> {
|
||||
public:
|
||||
explicit DisableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {}
|
||||
void play(Ts... x) override { this->wake_word_model_->disable(); }
|
||||
|
||||
protected:
|
||||
WakeWordModel *wake_word_model_;
|
||||
};
|
||||
|
||||
template<typename... Ts> class ModelIsEnabledCondition : public Condition<Ts...> {
|
||||
public:
|
||||
explicit ModelIsEnabledCondition(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {}
|
||||
bool check(Ts... x) override { return this->wake_word_model_->is_enabled(); }
|
||||
|
||||
protected:
|
||||
WakeWordModel *wake_word_model_;
|
||||
};
|
||||
|
||||
} // namespace micro_wake_word
|
||||
} // namespace esphome
|
||||
#endif
|
@ -1,5 +1,4 @@
|
||||
#include "micro_wake_word.h"
|
||||
#include "streaming_model.h"
|
||||
|
||||
#ifdef USE_ESP_IDF
|
||||
|
||||
@ -7,41 +6,57 @@
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
#include <frontend.h>
|
||||
#include <frontend_util.h>
|
||||
#include "esphome/components/audio/audio_transfer_buffer.h"
|
||||
|
||||
#include <tensorflow/lite/core/c/common.h>
|
||||
#include <tensorflow/lite/micro/micro_interpreter.h>
|
||||
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
|
||||
|
||||
#include <cmath>
|
||||
#ifdef USE_OTA
|
||||
#include "esphome/components/ota/ota_backend.h"
|
||||
#endif
|
||||
|
||||
namespace esphome {
|
||||
namespace micro_wake_word {
|
||||
|
||||
static const char *const TAG = "micro_wake_word";
|
||||
|
||||
static const size_t SAMPLE_RATE_HZ = 16000; // 16 kHz
|
||||
static const size_t BUFFER_LENGTH = 64; // 0.064 seconds
|
||||
static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH;
|
||||
static const size_t INPUT_BUFFER_SIZE = 16 * SAMPLE_RATE_HZ / 1000; // 16ms * 16kHz / 1000ms
|
||||
static const ssize_t DETECTION_QUEUE_LENGTH = 5;
|
||||
|
||||
static const size_t DATA_TIMEOUT_MS = 50;
|
||||
|
||||
static const uint32_t RING_BUFFER_DURATION_MS = 120;
|
||||
static const uint32_t RING_BUFFER_SAMPLES = RING_BUFFER_DURATION_MS * (AUDIO_SAMPLE_FREQUENCY / 1000);
|
||||
static const size_t RING_BUFFER_SIZE = RING_BUFFER_SAMPLES * sizeof(int16_t);
|
||||
|
||||
static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072;
|
||||
static const UBaseType_t INFERENCE_TASK_PRIORITY = 3;
|
||||
|
||||
enum EventGroupBits : uint32_t {
|
||||
COMMAND_STOP = (1 << 0), // Signals the inference task should stop
|
||||
|
||||
TASK_STARTING = (1 << 3),
|
||||
TASK_RUNNING = (1 << 4),
|
||||
TASK_STOPPING = (1 << 5),
|
||||
TASK_STOPPED = (1 << 6),
|
||||
|
||||
ERROR_MEMORY = (1 << 9),
|
||||
ERROR_INFERENCE = (1 << 10),
|
||||
|
||||
WARNING_FULL_RING_BUFFER = (1 << 13),
|
||||
|
||||
ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE,
|
||||
ALL_BITS = 0xfffff, // 24 total bits available in an event group
|
||||
};
|
||||
|
||||
float MicroWakeWord::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; }
|
||||
|
||||
static const LogString *micro_wake_word_state_to_string(State state) {
|
||||
switch (state) {
|
||||
case State::IDLE:
|
||||
return LOG_STR("IDLE");
|
||||
case State::START_MICROPHONE:
|
||||
return LOG_STR("START_MICROPHONE");
|
||||
case State::STARTING_MICROPHONE:
|
||||
return LOG_STR("STARTING_MICROPHONE");
|
||||
case State::STARTING:
|
||||
return LOG_STR("STARTING");
|
||||
case State::DETECTING_WAKE_WORD:
|
||||
return LOG_STR("DETECTING_WAKE_WORD");
|
||||
case State::STOP_MICROPHONE:
|
||||
return LOG_STR("STOP_MICROPHONE");
|
||||
case State::STOPPING_MICROPHONE:
|
||||
return LOG_STR("STOPPING_MICROPHONE");
|
||||
case State::STOPPING:
|
||||
return LOG_STR("STOPPING");
|
||||
case State::STOPPED:
|
||||
return LOG_STR("STOPPED");
|
||||
default:
|
||||
return LOG_STR("UNKNOWN");
|
||||
}
|
||||
@ -51,7 +66,7 @@ void MicroWakeWord::dump_config() {
|
||||
ESP_LOGCONFIG(TAG, "microWakeWord:");
|
||||
ESP_LOGCONFIG(TAG, " models:");
|
||||
for (auto &model : this->wake_word_models_) {
|
||||
model.log_model_config();
|
||||
model->log_model_config();
|
||||
}
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
this->vad_model_->log_model_config();
|
||||
@ -61,108 +76,266 @@ void MicroWakeWord::dump_config() {
|
||||
void MicroWakeWord::setup() {
|
||||
ESP_LOGCONFIG(TAG, "Setting up microWakeWord...");
|
||||
|
||||
this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
|
||||
this->frontend_config_.window.step_size_ms = this->features_step_size_;
|
||||
this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
|
||||
this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT;
|
||||
this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT;
|
||||
this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS;
|
||||
this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING;
|
||||
this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING;
|
||||
this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING;
|
||||
this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN;
|
||||
this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH;
|
||||
this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET;
|
||||
this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS;
|
||||
this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG;
|
||||
this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT;
|
||||
|
||||
this->event_group_ = xEventGroupCreate();
|
||||
if (this->event_group_ == nullptr) {
|
||||
ESP_LOGE(TAG, "Failed to create event group");
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent));
|
||||
if (this->detection_queue_ == nullptr) {
|
||||
ESP_LOGE(TAG, "Failed to create detection event queue");
|
||||
this->mark_failed();
|
||||
return;
|
||||
}
|
||||
|
||||
this->microphone_source_->add_data_callback([this](const std::vector<uint8_t> &data) {
|
||||
if (this->state_ != State::DETECTING_WAKE_WORD) {
|
||||
if (this->state_ == State::STOPPED) {
|
||||
return;
|
||||
}
|
||||
std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_;
|
||||
if (this->ring_buffer_.use_count() == 2) {
|
||||
// mWW still owns the ring buffer and temp_ring_buffer does as well, proceed to copy audio into ring buffer
|
||||
|
||||
std::shared_ptr<RingBuffer> temp_ring_buffer = this->ring_buffer_.lock();
|
||||
if (this->ring_buffer_.use_count() > 1) {
|
||||
size_t bytes_free = temp_ring_buffer->free();
|
||||
|
||||
if (bytes_free < data.size()) {
|
||||
ESP_LOGW(
|
||||
TAG,
|
||||
"Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). "
|
||||
"Resetting the ring buffer. Wake word detection accuracy will be reduced.",
|
||||
bytes_free, data.size());
|
||||
|
||||
xEventGroupSetBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
|
||||
temp_ring_buffer->reset();
|
||||
}
|
||||
temp_ring_buffer->write((void *) data.data(), data.size());
|
||||
}
|
||||
});
|
||||
|
||||
if (!this->register_streaming_ops_(this->streaming_op_resolver_)) {
|
||||
this->mark_failed();
|
||||
return;
|
||||
#ifdef USE_OTA
|
||||
ota::get_global_ota_callback()->add_on_state_callback(
|
||||
[this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) {
|
||||
if (state == ota::OTA_STARTED) {
|
||||
this->suspend_task_();
|
||||
} else if (state == ota::OTA_ERROR) {
|
||||
this->resume_task_();
|
||||
}
|
||||
});
|
||||
#endif
|
||||
ESP_LOGCONFIG(TAG, "Micro Wake Word initialized");
|
||||
}
|
||||
|
||||
void MicroWakeWord::inference_task(void *params) {
|
||||
MicroWakeWord *this_mww = (MicroWakeWord *) params;
|
||||
|
||||
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING);
|
||||
|
||||
{ // Ensures any C++ objects fall out of scope to deallocate before deleting the task
|
||||
const size_t new_samples_to_read = this_mww->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000);
|
||||
std::unique_ptr<audio::AudioSourceTransferBuffer> audio_buffer;
|
||||
int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE];
|
||||
|
||||
if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
|
||||
// Allocate audio transfer buffer
|
||||
audio_buffer = audio::AudioSourceTransferBuffer::create(new_samples_to_read * sizeof(int16_t));
|
||||
|
||||
if (audio_buffer == nullptr) {
|
||||
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
|
||||
}
|
||||
}
|
||||
|
||||
if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
|
||||
// Allocate ring buffer
|
||||
std::shared_ptr<RingBuffer> temp_ring_buffer = RingBuffer::create(RING_BUFFER_SIZE);
|
||||
if (temp_ring_buffer.use_count() == 0) {
|
||||
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY);
|
||||
}
|
||||
audio_buffer->set_source(temp_ring_buffer);
|
||||
this_mww->ring_buffer_ = temp_ring_buffer;
|
||||
}
|
||||
|
||||
if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) {
|
||||
this_mww->microphone_source_->start();
|
||||
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING);
|
||||
|
||||
while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) {
|
||||
audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS));
|
||||
|
||||
if (audio_buffer->available() < new_samples_to_read * sizeof(int16_t)) {
|
||||
// Insufficient data to generate new spectrogram features, read more next iteration
|
||||
continue;
|
||||
}
|
||||
|
||||
// Generate new spectrogram features
|
||||
size_t processed_samples = this_mww->generate_features_(
|
||||
(int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer);
|
||||
audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t));
|
||||
|
||||
// Run inference using the new spectorgram features
|
||||
if (!this_mww->update_model_probabilities_(features_buffer)) {
|
||||
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE);
|
||||
break;
|
||||
}
|
||||
|
||||
// Process each model's probabilities and possibly send a Detection Event to the queue
|
||||
this_mww->process_probabilities_();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ESP_LOGCONFIG(TAG, "Micro Wake Word initialized");
|
||||
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING);
|
||||
|
||||
this->frontend_config_.window.size_ms = FEATURE_DURATION_MS;
|
||||
this->frontend_config_.window.step_size_ms = this->features_step_size_;
|
||||
this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE;
|
||||
this->frontend_config_.filterbank.lower_band_limit = 125.0;
|
||||
this->frontend_config_.filterbank.upper_band_limit = 7500.0;
|
||||
this->frontend_config_.noise_reduction.smoothing_bits = 10;
|
||||
this->frontend_config_.noise_reduction.even_smoothing = 0.025;
|
||||
this->frontend_config_.noise_reduction.odd_smoothing = 0.06;
|
||||
this->frontend_config_.noise_reduction.min_signal_remaining = 0.05;
|
||||
this->frontend_config_.pcan_gain_control.enable_pcan = 1;
|
||||
this->frontend_config_.pcan_gain_control.strength = 0.95;
|
||||
this->frontend_config_.pcan_gain_control.offset = 80.0;
|
||||
this->frontend_config_.pcan_gain_control.gain_bits = 21;
|
||||
this->frontend_config_.log_scale.enable_log = 1;
|
||||
this->frontend_config_.log_scale.scale_shift = 6;
|
||||
this_mww->unload_models_();
|
||||
this_mww->microphone_source_->stop();
|
||||
FrontendFreeStateContents(&this_mww->frontend_state_);
|
||||
|
||||
xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED);
|
||||
while (true) {
|
||||
// Continuously delay until the main loop deletes the task
|
||||
delay(10);
|
||||
}
|
||||
}
|
||||
|
||||
void MicroWakeWord::add_wake_word_model(const uint8_t *model_start, float probability_cutoff,
|
||||
size_t sliding_window_average_size, const std::string &wake_word,
|
||||
size_t tensor_arena_size) {
|
||||
this->wake_word_models_.emplace_back(model_start, probability_cutoff, sliding_window_average_size, wake_word,
|
||||
tensor_arena_size);
|
||||
std::vector<WakeWordModel *> MicroWakeWord::get_wake_words() {
|
||||
std::vector<WakeWordModel *> external_wake_word_models;
|
||||
for (auto *model : this->wake_word_models_) {
|
||||
if (!model->get_internal_only()) {
|
||||
external_wake_word_models.push_back(model);
|
||||
}
|
||||
}
|
||||
return external_wake_word_models;
|
||||
}
|
||||
|
||||
void MicroWakeWord::add_wake_word_model(WakeWordModel *model) { this->wake_word_models_.push_back(model); }
|
||||
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
void MicroWakeWord::add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size,
|
||||
void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
|
||||
size_t tensor_arena_size) {
|
||||
this->vad_model_ = make_unique<VADModel>(model_start, probability_cutoff, sliding_window_size, tensor_arena_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
void MicroWakeWord::suspend_task_() {
|
||||
if (this->inference_task_handle_ != nullptr) {
|
||||
vTaskSuspend(this->inference_task_handle_);
|
||||
}
|
||||
}
|
||||
|
||||
void MicroWakeWord::resume_task_() {
|
||||
if (this->inference_task_handle_ != nullptr) {
|
||||
vTaskResume(this->inference_task_handle_);
|
||||
}
|
||||
}
|
||||
|
||||
void MicroWakeWord::loop() {
|
||||
uint32_t event_group_bits = xEventGroupGetBits(this->event_group_);
|
||||
|
||||
if (event_group_bits & EventGroupBits::ERROR_MEMORY) {
|
||||
xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY);
|
||||
ESP_LOGE(TAG, "Encountered an error allocating buffers");
|
||||
}
|
||||
|
||||
if (event_group_bits & EventGroupBits::ERROR_INFERENCE) {
|
||||
xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE);
|
||||
ESP_LOGE(TAG, "Encountered an error while performing an inference");
|
||||
}
|
||||
|
||||
if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) {
|
||||
xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER);
|
||||
ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake "
|
||||
"word detection accuracy will temporarily be reduced.");
|
||||
}
|
||||
|
||||
if (event_group_bits & EventGroupBits::TASK_STARTING) {
|
||||
ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers");
|
||||
xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING);
|
||||
}
|
||||
|
||||
if (event_group_bits & EventGroupBits::TASK_RUNNING) {
|
||||
ESP_LOGD(TAG, "Inference task is running");
|
||||
|
||||
xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING);
|
||||
this->set_state_(State::DETECTING_WAKE_WORD);
|
||||
}
|
||||
|
||||
if (event_group_bits & EventGroupBits::TASK_STOPPING) {
|
||||
ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers");
|
||||
xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING);
|
||||
}
|
||||
|
||||
if ((event_group_bits & EventGroupBits::TASK_STOPPED)) {
|
||||
ESP_LOGD(TAG, "Inference task is finished, freeing task resources");
|
||||
vTaskDelete(this->inference_task_handle_);
|
||||
this->inference_task_handle_ = nullptr;
|
||||
xEventGroupClearBits(this->event_group_, ALL_BITS);
|
||||
xQueueReset(this->detection_queue_);
|
||||
this->set_state_(State::STOPPED);
|
||||
}
|
||||
|
||||
if ((this->pending_start_) && (this->state_ == State::STOPPED)) {
|
||||
this->set_state_(State::STARTING);
|
||||
this->pending_start_ = false;
|
||||
}
|
||||
|
||||
if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) {
|
||||
this->set_state_(State::STOPPING);
|
||||
this->pending_stop_ = false;
|
||||
}
|
||||
|
||||
switch (this->state_) {
|
||||
case State::IDLE:
|
||||
break;
|
||||
case State::START_MICROPHONE:
|
||||
ESP_LOGD(TAG, "Starting Microphone");
|
||||
this->microphone_source_->start();
|
||||
this->set_state_(State::STARTING_MICROPHONE);
|
||||
break;
|
||||
case State::STARTING_MICROPHONE:
|
||||
if (this->microphone_source_->is_running()) {
|
||||
this->set_state_(State::DETECTING_WAKE_WORD);
|
||||
}
|
||||
break;
|
||||
case State::DETECTING_WAKE_WORD:
|
||||
while (this->has_enough_samples_()) {
|
||||
this->update_model_probabilities_();
|
||||
if (this->detect_wake_words_()) {
|
||||
ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str());
|
||||
this->detected_ = true;
|
||||
this->set_state_(State::STOP_MICROPHONE);
|
||||
case State::STARTING:
|
||||
if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) {
|
||||
// Setup preprocesor feature generator. If done in the task, it would lock the task to its initial core, as it
|
||||
// uses floating point operations.
|
||||
if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) {
|
||||
this->status_momentary_error(
|
||||
"Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000);
|
||||
return;
|
||||
}
|
||||
|
||||
xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this,
|
||||
INFERENCE_TASK_PRIORITY, &this->inference_task_handle_);
|
||||
|
||||
if (this->inference_task_handle_ == nullptr) {
|
||||
FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state
|
||||
this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case State::STOP_MICROPHONE:
|
||||
ESP_LOGD(TAG, "Stopping Microphone");
|
||||
this->microphone_source_->stop();
|
||||
this->set_state_(State::STOPPING_MICROPHONE);
|
||||
this->unload_models_();
|
||||
this->deallocate_buffers_();
|
||||
break;
|
||||
case State::STOPPING_MICROPHONE:
|
||||
if (this->microphone_source_->is_stopped()) {
|
||||
this->set_state_(State::IDLE);
|
||||
if (this->detected_) {
|
||||
this->wake_word_detected_trigger_->trigger(this->detected_wake_word_);
|
||||
this->detected_ = false;
|
||||
this->detected_wake_word_ = "";
|
||||
case State::DETECTING_WAKE_WORD: {
|
||||
DetectionEvent detection_event;
|
||||
while (xQueueReceive(this->detection_queue_, &detection_event, 0)) {
|
||||
if (detection_event.blocked_by_vad) {
|
||||
ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str());
|
||||
} else {
|
||||
constexpr float uint8_to_float_divisor =
|
||||
255.0f; // Converting a quantized uint8 probability to floating point
|
||||
ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f",
|
||||
detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor),
|
||||
(detection_event.max_probability / uint8_to_float_divisor));
|
||||
this->wake_word_detected_trigger_->trigger(*detection_event.wake_word);
|
||||
if (this->stop_after_detection_) {
|
||||
this->stop();
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case State::STOPPING:
|
||||
xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP);
|
||||
break;
|
||||
case State::STOPPED:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@ -177,199 +350,40 @@ void MicroWakeWord::start() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (this->state_ != State::IDLE) {
|
||||
ESP_LOGW(TAG, "Wake word is already running");
|
||||
if (this->is_running()) {
|
||||
ESP_LOGW(TAG, "Wake word detection is already running");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!this->load_models_() || !this->allocate_buffers_()) {
|
||||
ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers");
|
||||
this->status_set_error();
|
||||
} else {
|
||||
this->status_clear_error();
|
||||
}
|
||||
ESP_LOGD(TAG, "Starting wake word detection");
|
||||
|
||||
if (this->status_has_error()) {
|
||||
ESP_LOGW(TAG, "Wake word component has an error. Please check logs");
|
||||
return;
|
||||
}
|
||||
|
||||
this->reset_states_();
|
||||
this->set_state_(State::START_MICROPHONE);
|
||||
this->pending_start_ = true;
|
||||
this->pending_stop_ = false;
|
||||
}
|
||||
|
||||
void MicroWakeWord::stop() {
|
||||
if (this->state_ == State::IDLE) {
|
||||
ESP_LOGW(TAG, "Wake word is already stopped");
|
||||
if (this->state_ == STOPPED)
|
||||
return;
|
||||
}
|
||||
if (this->state_ == State::STOPPING_MICROPHONE) {
|
||||
ESP_LOGW(TAG, "Wake word is already stopping");
|
||||
return;
|
||||
}
|
||||
this->set_state_(State::STOP_MICROPHONE);
|
||||
|
||||
ESP_LOGD(TAG, "Stopping wake word detection");
|
||||
|
||||
this->pending_start_ = false;
|
||||
this->pending_stop_ = true;
|
||||
}
|
||||
|
||||
void MicroWakeWord::set_state_(State state) {
|
||||
ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
|
||||
LOG_STR_ARG(micro_wake_word_state_to_string(state)));
|
||||
this->state_ = state;
|
||||
if (this->state_ != state) {
|
||||
ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)),
|
||||
LOG_STR_ARG(micro_wake_word_state_to_string(state)));
|
||||
this->state_ = state;
|
||||
}
|
||||
}
|
||||
|
||||
bool MicroWakeWord::allocate_buffers_() {
|
||||
ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE);
|
||||
|
||||
if (this->input_buffer_ == nullptr) {
|
||||
this->input_buffer_ = audio_samples_allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t));
|
||||
if (this->input_buffer_ == nullptr) {
|
||||
ESP_LOGE(TAG, "Could not allocate input buffer");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (this->preprocessor_audio_buffer_ == nullptr) {
|
||||
this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(this->new_samples_to_get_());
|
||||
if (this->preprocessor_audio_buffer_ == nullptr) {
|
||||
ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (this->ring_buffer_.use_count() == 0) {
|
||||
this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t));
|
||||
if (this->ring_buffer_.use_count() == 0) {
|
||||
ESP_LOGE(TAG, "Could not allocate ring buffer");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void MicroWakeWord::deallocate_buffers_() {
|
||||
ExternalRAMAllocator<int16_t> audio_samples_allocator(ExternalRAMAllocator<int16_t>::ALLOW_FAILURE);
|
||||
if (this->input_buffer_ != nullptr) {
|
||||
audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t));
|
||||
this->input_buffer_ = nullptr;
|
||||
}
|
||||
|
||||
if (this->preprocessor_audio_buffer_ != nullptr) {
|
||||
audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_());
|
||||
this->preprocessor_audio_buffer_ = nullptr;
|
||||
}
|
||||
|
||||
this->ring_buffer_.reset();
|
||||
}
|
||||
|
||||
bool MicroWakeWord::load_models_() {
|
||||
// Setup preprocesor feature generator
|
||||
if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) {
|
||||
ESP_LOGD(TAG, "Failed to populate frontend state");
|
||||
FrontendFreeStateContents(&this->frontend_state_);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Setup streaming models
|
||||
for (auto &model : this->wake_word_models_) {
|
||||
if (!model.load_model(this->streaming_op_resolver_)) {
|
||||
ESP_LOGE(TAG, "Failed to initialize a wake word model.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
if (!this->vad_model_->load_model(this->streaming_op_resolver_)) {
|
||||
ESP_LOGE(TAG, "Failed to initialize VAD model.");
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void MicroWakeWord::unload_models_() {
|
||||
FrontendFreeStateContents(&this->frontend_state_);
|
||||
|
||||
for (auto &model : this->wake_word_models_) {
|
||||
model.unload_model();
|
||||
}
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
this->vad_model_->unload_model();
|
||||
#endif
|
||||
}
|
||||
|
||||
void MicroWakeWord::update_model_probabilities_() {
|
||||
int8_t audio_features[PREPROCESSOR_FEATURE_SIZE];
|
||||
|
||||
if (!this->generate_features_for_window_(audio_features)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Increase the counter since the last positive detection
|
||||
this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
|
||||
|
||||
for (auto &model : this->wake_word_models_) {
|
||||
// Perform inference
|
||||
model.perform_streaming_inference(audio_features);
|
||||
}
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
this->vad_model_->perform_streaming_inference(audio_features);
|
||||
#endif
|
||||
}
|
||||
|
||||
bool MicroWakeWord::detect_wake_words_() {
|
||||
// Verify we have processed samples since the last positive detection
|
||||
if (this->ignore_windows_ < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
bool vad_state = this->vad_model_->determine_detected();
|
||||
#endif
|
||||
|
||||
for (auto &model : this->wake_word_models_) {
|
||||
if (model.determine_detected()) {
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
if (vad_state) {
|
||||
#endif
|
||||
this->detected_wake_word_ = model.get_wake_word();
|
||||
return true;
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
} else {
|
||||
ESP_LOGD(TAG, "Wake word model predicts %s, but VAD model doesn't.", model.get_wake_word().c_str());
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MicroWakeWord::has_enough_samples_() {
|
||||
return this->ring_buffer_->available() >=
|
||||
(this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)) * sizeof(int16_t);
|
||||
}
|
||||
|
||||
bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
|
||||
// Ensure we have enough new audio samples in the ring buffer for a full window
|
||||
if (!this->has_enough_samples_()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_),
|
||||
this->new_samples_to_get_() * sizeof(int16_t), pdMS_TO_TICKS(200));
|
||||
|
||||
if (bytes_read == 0) {
|
||||
ESP_LOGE(TAG, "Could not read data from Ring Buffer");
|
||||
} else if (bytes_read < this->new_samples_to_get_() * sizeof(int16_t)) {
|
||||
ESP_LOGD(TAG, "Partial Read of Data by Model");
|
||||
ESP_LOGD(TAG, "Could only read %d bytes when required %d bytes ", bytes_read,
|
||||
(int) (this->new_samples_to_get_() * sizeof(int16_t)));
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t num_samples_read;
|
||||
struct FrontendOutput frontend_output = FrontendProcessSamples(
|
||||
&this->frontend_state_, this->preprocessor_audio_buffer_, this->new_samples_to_get_(), &num_samples_read);
|
||||
size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available,
|
||||
int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) {
|
||||
size_t processed_samples = 0;
|
||||
struct FrontendOutput frontend_output =
|
||||
FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples);
|
||||
|
||||
for (size_t i = 0; i < frontend_output.size; ++i) {
|
||||
// These scaling values are set to match the TFLite audio frontend int8 output.
|
||||
@ -379,8 +393,8 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F
|
||||
// for historical reasons, to match up with the output of other feature
|
||||
// generators.
|
||||
// The process is then further complicated when we quantize the model. This
|
||||
// means we have to scale the 0.0 to 26.0 real values to the -128 to 127
|
||||
// signed integer numbers.
|
||||
// means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN)
|
||||
// to 127 (INT8_MAX) signed integer numbers.
|
||||
// All this means that to get matching values from our integer feature
|
||||
// output into the tensor input, we have to perform:
|
||||
// input = (((feature / 25.6) / 26.0) * 256) - 128
|
||||
@ -389,74 +403,63 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F
|
||||
constexpr int32_t value_scale = 256;
|
||||
constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding
|
||||
int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div;
|
||||
value -= 128;
|
||||
if (value < -128) {
|
||||
value = -128;
|
||||
}
|
||||
if (value > 127) {
|
||||
value = 127;
|
||||
}
|
||||
features[i] = value;
|
||||
|
||||
value -= INT8_MIN;
|
||||
features_buffer[i] = clamp<int8_t>(value, INT8_MIN, INT8_MAX);
|
||||
}
|
||||
|
||||
return true;
|
||||
return processed_samples;
|
||||
}
|
||||
|
||||
void MicroWakeWord::reset_states_() {
|
||||
ESP_LOGD(TAG, "Resetting buffers and probabilities");
|
||||
this->ring_buffer_->reset();
|
||||
this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
|
||||
void MicroWakeWord::process_probabilities_() {
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
DetectionEvent vad_state = this->vad_model_->determine_detected();
|
||||
|
||||
this->vad_state_ = vad_state.detected; // atomic write, so thread safe
|
||||
#endif
|
||||
|
||||
for (auto &model : this->wake_word_models_) {
|
||||
model.reset_probabilities();
|
||||
if (model->get_unprocessed_probability_status()) {
|
||||
// Only detect wake words if there is a new probability since the last check
|
||||
DetectionEvent wake_word_state = model->determine_detected();
|
||||
if (wake_word_state.detected) {
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
if (vad_state.detected) {
|
||||
#endif
|
||||
xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
|
||||
model->reset_probabilities();
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
} else {
|
||||
wake_word_state.blocked_by_vad = true;
|
||||
xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MicroWakeWord::unload_models_() {
|
||||
for (auto &model : this->wake_word_models_) {
|
||||
model->unload_model();
|
||||
}
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
this->vad_model_->reset_probabilities();
|
||||
this->vad_model_->unload_model();
|
||||
#endif
|
||||
}
|
||||
|
||||
bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) {
|
||||
if (op_resolver.AddCallOnce() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddVarHandle() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddReshape() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddReadVariable() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddStridedSlice() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddConcatenation() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddAssignVariable() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddConv2D() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddMul() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddAdd() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddMean() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddFullyConnected() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddLogistic() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddQuantize() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddAveragePool2D() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddMaxPool2D() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddPad() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddPack() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddSplitV() != kTfLiteOk)
|
||||
return false;
|
||||
bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) {
|
||||
bool success = true;
|
||||
|
||||
return true;
|
||||
for (auto &model : this->wake_word_models_) {
|
||||
// Perform inference
|
||||
success = success & model->perform_streaming_inference(audio_features);
|
||||
}
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
success = success & this->vad_model_->perform_streaming_inference(audio_features);
|
||||
#endif
|
||||
|
||||
return success;
|
||||
}
|
||||
|
||||
} // namespace micro_wake_word
|
||||
|
@ -5,33 +5,27 @@
|
||||
#include "preprocessor_settings.h"
|
||||
#include "streaming_model.h"
|
||||
|
||||
#include "esphome/components/microphone/microphone_source.h"
|
||||
|
||||
#include "esphome/core/automation.h"
|
||||
#include "esphome/core/component.h"
|
||||
#include "esphome/core/ring_buffer.h"
|
||||
|
||||
#include "esphome/components/microphone/microphone_source.h"
|
||||
#include <freertos/event_groups.h>
|
||||
|
||||
#include <frontend.h>
|
||||
#include <frontend_util.h>
|
||||
|
||||
#include <tensorflow/lite/core/c/common.h>
|
||||
#include <tensorflow/lite/micro/micro_interpreter.h>
|
||||
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
|
||||
|
||||
namespace esphome {
|
||||
namespace micro_wake_word {
|
||||
|
||||
enum State {
|
||||
IDLE,
|
||||
START_MICROPHONE,
|
||||
STARTING_MICROPHONE,
|
||||
STARTING,
|
||||
DETECTING_WAKE_WORD,
|
||||
STOP_MICROPHONE,
|
||||
STOPPING_MICROPHONE,
|
||||
STOPPING,
|
||||
STOPPED,
|
||||
};
|
||||
|
||||
// The number of audio slices to process before accepting a positive detection
|
||||
static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74;
|
||||
|
||||
class MicroWakeWord : public Component {
|
||||
public:
|
||||
void setup() override;
|
||||
@ -42,7 +36,7 @@ class MicroWakeWord : public Component {
|
||||
void start();
|
||||
void stop();
|
||||
|
||||
bool is_running() const { return this->state_ != State::IDLE; }
|
||||
bool is_running() const { return this->state_ != State::STOPPED; }
|
||||
|
||||
void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; }
|
||||
|
||||
@ -50,118 +44,87 @@ class MicroWakeWord : public Component {
|
||||
this->microphone_source_ = microphone_source;
|
||||
}
|
||||
|
||||
void set_stop_after_detection(bool stop_after_detection) { this->stop_after_detection_ = stop_after_detection; }
|
||||
|
||||
Trigger<std::string> *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; }
|
||||
|
||||
void add_wake_word_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
|
||||
const std::string &wake_word, size_t tensor_arena_size);
|
||||
void add_wake_word_model(WakeWordModel *model);
|
||||
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
void add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size,
|
||||
void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
|
||||
size_t tensor_arena_size);
|
||||
|
||||
// Intended for the voice assistant component to fetch VAD status
|
||||
bool get_vad_state() { return this->vad_state_; }
|
||||
#endif
|
||||
|
||||
// Intended for the voice assistant component to access which wake words are available
|
||||
// Since these are pointers to the WakeWordModel objects, the voice assistant component can enable or disable them
|
||||
std::vector<WakeWordModel *> get_wake_words();
|
||||
|
||||
protected:
|
||||
microphone::MicrophoneSource *microphone_source_{nullptr};
|
||||
Trigger<std::string> *wake_word_detected_trigger_ = new Trigger<std::string>();
|
||||
State state_{State::IDLE};
|
||||
State state_{State::STOPPED};
|
||||
|
||||
std::shared_ptr<RingBuffer> ring_buffer_;
|
||||
|
||||
std::vector<WakeWordModel> wake_word_models_;
|
||||
std::weak_ptr<RingBuffer> ring_buffer_;
|
||||
std::vector<WakeWordModel *> wake_word_models_;
|
||||
|
||||
#ifdef USE_MICRO_WAKE_WORD_VAD
|
||||
std::unique_ptr<VADModel> vad_model_;
|
||||
bool vad_state_{false};
|
||||
#endif
|
||||
|
||||
tflite::MicroMutableOpResolver<20> streaming_op_resolver_;
|
||||
bool pending_start_{false};
|
||||
bool pending_stop_{false};
|
||||
|
||||
bool stop_after_detection_;
|
||||
|
||||
uint8_t features_step_size_;
|
||||
|
||||
// Audio frontend handles generating spectrogram features
|
||||
struct FrontendConfig frontend_config_;
|
||||
struct FrontendState frontend_state_;
|
||||
|
||||
// When the wake word detection first starts, we ignore this many audio
|
||||
// feature slices before accepting a positive detection
|
||||
int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION};
|
||||
// Handles managing the stop/state of the inference task
|
||||
EventGroupHandle_t event_group_;
|
||||
|
||||
uint8_t features_step_size_;
|
||||
// Used to send messages about the models' states to the main loop
|
||||
QueueHandle_t detection_queue_;
|
||||
|
||||
// Stores audio read from the microphone before being added to the ring buffer.
|
||||
int16_t *input_buffer_{nullptr};
|
||||
// Stores audio to be fed into the audio frontend for generating features.
|
||||
int16_t *preprocessor_audio_buffer_{nullptr};
|
||||
static void inference_task(void *params);
|
||||
TaskHandle_t inference_task_handle_{nullptr};
|
||||
|
||||
bool detected_{false};
|
||||
std::string detected_wake_word_{""};
|
||||
/// @brief Suspends the inference task
|
||||
void suspend_task_();
|
||||
/// @brief Resumes the inference task
|
||||
void resume_task_();
|
||||
|
||||
void set_state_(State state);
|
||||
|
||||
/// @brief Tests if there are enough samples in the ring buffer to generate new features.
|
||||
/// @return True if enough samples, false otherwise.
|
||||
bool has_enough_samples_();
|
||||
/// @brief Generates spectrogram features from an input buffer of audio samples
|
||||
/// @param audio_buffer (int16_t *) Buffer containing input audio samples
|
||||
/// @param samples_available (size_t) Number of samples avaiable in the input buffer
|
||||
/// @param features_buffer (int8_t *) Buffer to store generated features
|
||||
/// @return (size_t) Number of samples processed from the input buffer
|
||||
size_t generate_features_(int16_t *audio_buffer, size_t samples_available,
|
||||
int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]);
|
||||
|
||||
/// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_
|
||||
/// @return True if successful, false otherwise
|
||||
bool allocate_buffers_();
|
||||
/// @brief Processes any new probabilities for each model. If any wake word is detected, it will send a DetectionEvent
|
||||
/// to the detection_queue_.
|
||||
void process_probabilities_();
|
||||
|
||||
/// @brief Frees memory allocated for input_buffer_ and preprocessor_audio_buffer_
|
||||
void deallocate_buffers_();
|
||||
|
||||
/// @brief Loads streaming models and prepares the feature generation frontend
|
||||
/// @return True if successful, false otherwise
|
||||
bool load_models_();
|
||||
|
||||
/// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. Frees memory used by the feature
|
||||
/// generation frontend.
|
||||
/// @brief Deletes each model's TFLite interpreters and frees tensor arena memory.
|
||||
void unload_models_();
|
||||
|
||||
/** Performs inference with each configured model
|
||||
*
|
||||
* If enough audio samples are available, it will generate one slice of new features.
|
||||
* It then loops through and performs inference with each of the loaded models.
|
||||
*/
|
||||
void update_model_probabilities_();
|
||||
|
||||
/** Checks every model's recent probabilities to determine if the wake word has been predicted
|
||||
*
|
||||
* Verifies the models have processed enough new samples for accurate predictions.
|
||||
* Sets detected_wake_word_ to the wake word, if one is detected.
|
||||
* @return True if a wake word is predicted, false otherwise
|
||||
*/
|
||||
bool detect_wake_words_();
|
||||
|
||||
/** Generates features for a window of audio samples
|
||||
*
|
||||
* Reads samples from the ring buffer and feeds them into the preprocessor frontend.
|
||||
* Adapted from TFLite microspeech frontend.
|
||||
* @param features int8_t array to store the audio features
|
||||
* @return True if successful, false otherwise.
|
||||
*/
|
||||
bool generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]);
|
||||
|
||||
/// @brief Resets the ring buffer, ignore_windows_, and sliding window probabilities
|
||||
void reset_states_();
|
||||
|
||||
/// @brief Returns true if successfully registered the streaming model's TensorFlow operations
|
||||
bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver);
|
||||
/// @brief Runs an inference with each model using the new spectrogram features
|
||||
/// @param audio_features (int8_t *) Buffer containing new spectrogram features
|
||||
/// @return True if successful, false if any errors were encountered
|
||||
bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]);
|
||||
|
||||
inline uint16_t new_samples_to_get_() { return (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)); }
|
||||
};
|
||||
|
||||
template<typename... Ts> class StartAction : public Action<Ts...>, public Parented<MicroWakeWord> {
|
||||
public:
|
||||
void play(Ts... x) override { this->parent_->start(); }
|
||||
};
|
||||
|
||||
template<typename... Ts> class StopAction : public Action<Ts...>, public Parented<MicroWakeWord> {
|
||||
public:
|
||||
void play(Ts... x) override { this->parent_->stop(); }
|
||||
};
|
||||
|
||||
template<typename... Ts> class IsRunningCondition : public Condition<Ts...>, public Parented<MicroWakeWord> {
|
||||
public:
|
||||
bool check(Ts... x) override { return this->parent_->is_running(); }
|
||||
};
|
||||
|
||||
} // namespace micro_wake_word
|
||||
} // namespace esphome
|
||||
|
||||
|
@ -7,6 +7,10 @@
|
||||
namespace esphome {
|
||||
namespace micro_wake_word {
|
||||
|
||||
// Settings for controlling the spectrogram feature generation by the preprocessor.
|
||||
// These must match the settings used when training a particular model.
|
||||
// All microWakeWord models have been trained with these specific paramters.
|
||||
|
||||
// The number of features the audio preprocessor generates per slice
|
||||
static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40;
|
||||
// Duration of each slice used as input into the preprocessor
|
||||
@ -14,6 +18,21 @@ static const uint8_t FEATURE_DURATION_MS = 30;
|
||||
// Audio sample frequency in hertz
|
||||
static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000;
|
||||
|
||||
static const float FILTERBANK_LOWER_BAND_LIMIT = 125.0;
|
||||
static const float FILTERBANK_UPPER_BAND_LIMIT = 7500.0;
|
||||
|
||||
static const uint8_t NOISE_REDUCTION_SMOOTHING_BITS = 10;
|
||||
static const float NOISE_REDUCTION_EVEN_SMOOTHING = 0.025;
|
||||
static const float NOISE_REDUCTION_ODD_SMOOTHING = 0.06;
|
||||
static const float NOISE_REDUCTION_MIN_SIGNAL_REMAINING = 0.05;
|
||||
|
||||
static const bool PCAN_GAIN_CONTROL_ENABLE_PCAN = true;
|
||||
static const float PCAN_GAIN_CONTROL_STRENGTH = 0.95;
|
||||
static const float PCAN_GAIN_CONTROL_OFFSET = 80.0;
|
||||
static const uint8_t PCAN_GAIN_CONTROL_GAIN_BITS = 21;
|
||||
|
||||
static const bool LOG_SCALE_ENABLE_LOG = true;
|
||||
static const uint8_t LOG_SCALE_SCALE_SHIFT = 6;
|
||||
} // namespace micro_wake_word
|
||||
} // namespace esphome
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
#ifdef USE_ESP_IDF
|
||||
|
||||
#include "streaming_model.h"
|
||||
|
||||
#include "esphome/core/hal.h"
|
||||
#ifdef USE_ESP_IDF
|
||||
|
||||
#include "esphome/core/helpers.h"
|
||||
#include "esphome/core/log.h"
|
||||
|
||||
@ -13,18 +12,18 @@ namespace micro_wake_word {
|
||||
|
||||
void WakeWordModel::log_model_config() {
|
||||
ESP_LOGCONFIG(TAG, " - Wake Word: %s", this->wake_word_.c_str());
|
||||
ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_);
|
||||
ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f);
|
||||
ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_);
|
||||
}
|
||||
|
||||
void VADModel::log_model_config() {
|
||||
ESP_LOGCONFIG(TAG, " - VAD Model");
|
||||
ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_);
|
||||
ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f);
|
||||
ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_);
|
||||
}
|
||||
|
||||
bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) {
|
||||
ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);
|
||||
bool StreamingModel::load_model_() {
|
||||
RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE);
|
||||
|
||||
if (this->tensor_arena_ == nullptr) {
|
||||
this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_);
|
||||
@ -51,8 +50,9 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver)
|
||||
}
|
||||
|
||||
if (this->interpreter_ == nullptr) {
|
||||
this->interpreter_ = make_unique<tflite::MicroInterpreter>(
|
||||
tflite::GetModel(this->model_start_), op_resolver, this->tensor_arena_, this->tensor_arena_size_, this->mrv_);
|
||||
this->interpreter_ =
|
||||
make_unique<tflite::MicroInterpreter>(tflite::GetModel(this->model_start_), this->streaming_op_resolver_,
|
||||
this->tensor_arena_, this->tensor_arena_size_, this->mrv_);
|
||||
if (this->interpreter_->AllocateTensors() != kTfLiteOk) {
|
||||
ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model");
|
||||
return false;
|
||||
@ -84,34 +84,55 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver)
|
||||
}
|
||||
}
|
||||
|
||||
this->loaded_ = true;
|
||||
this->reset_probabilities();
|
||||
return true;
|
||||
}
|
||||
|
||||
void StreamingModel::unload_model() {
|
||||
this->interpreter_.reset();
|
||||
|
||||
ExternalRAMAllocator<uint8_t> arena_allocator(ExternalRAMAllocator<uint8_t>::ALLOW_FAILURE);
|
||||
RAMAllocator<uint8_t> arena_allocator(RAMAllocator<uint8_t>::ALLOW_FAILURE);
|
||||
|
||||
arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
|
||||
this->tensor_arena_ = nullptr;
|
||||
arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
|
||||
this->var_arena_ = nullptr;
|
||||
if (this->tensor_arena_ != nullptr) {
|
||||
arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_);
|
||||
this->tensor_arena_ = nullptr;
|
||||
}
|
||||
|
||||
if (this->var_arena_ != nullptr) {
|
||||
arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE);
|
||||
this->var_arena_ = nullptr;
|
||||
}
|
||||
|
||||
this->loaded_ = false;
|
||||
}
|
||||
|
||||
bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) {
|
||||
if (this->interpreter_ != nullptr) {
|
||||
if (this->enabled_ && !this->loaded_) {
|
||||
// Model is enabled but isn't loaded
|
||||
if (!this->load_model_()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!this->enabled_ && this->loaded_) {
|
||||
// Model is disabled but still loaded
|
||||
this->unload_model();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (this->loaded_) {
|
||||
TfLiteTensor *input = this->interpreter_->input(0);
|
||||
|
||||
uint8_t stride = this->interpreter_->input(0)->dims->data[1];
|
||||
this->current_stride_step_ = this->current_stride_step_ % stride;
|
||||
|
||||
std::memmove(
|
||||
(int8_t *) (tflite::GetTensorData<int8_t>(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_,
|
||||
features, PREPROCESSOR_FEATURE_SIZE);
|
||||
++this->current_stride_step_;
|
||||
|
||||
uint8_t stride = this->interpreter_->input(0)->dims->data[1];
|
||||
|
||||
if (this->current_stride_step_ >= stride) {
|
||||
this->current_stride_step_ = 0;
|
||||
|
||||
TfLiteStatus invoke_status = this->interpreter_->Invoke();
|
||||
if (invoke_status != kTfLiteOk) {
|
||||
ESP_LOGW(TAG, "Streaming interpreter invoke failed");
|
||||
@ -124,65 +145,159 @@ bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCES
|
||||
if (this->last_n_index_ == this->sliding_window_size_)
|
||||
this->last_n_index_ = 0;
|
||||
this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0]; // probability;
|
||||
this->unprocessed_probability_status_ = true;
|
||||
}
|
||||
return true;
|
||||
this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0);
|
||||
}
|
||||
ESP_LOGE(TAG, "Streaming interpreter is not initialized.");
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void StreamingModel::reset_probabilities() {
|
||||
for (auto &prob : this->recent_streaming_probabilities_) {
|
||||
prob = 0;
|
||||
}
|
||||
this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
|
||||
}
|
||||
|
||||
WakeWordModel::WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
|
||||
const std::string &wake_word, size_t tensor_arena_size) {
|
||||
WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff,
|
||||
size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
|
||||
bool default_enabled, bool internal_only) {
|
||||
this->id_ = id;
|
||||
this->model_start_ = model_start;
|
||||
this->probability_cutoff_ = probability_cutoff;
|
||||
this->sliding_window_size_ = sliding_window_average_size;
|
||||
this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0);
|
||||
this->wake_word_ = wake_word;
|
||||
this->tensor_arena_size_ = tensor_arena_size;
|
||||
this->register_streaming_ops_(this->streaming_op_resolver_);
|
||||
this->current_stride_step_ = 0;
|
||||
this->internal_only_ = internal_only;
|
||||
|
||||
this->pref_ = global_preferences->make_preference<bool>(fnv1_hash(id));
|
||||
bool enabled;
|
||||
if (this->pref_.load(&enabled)) {
|
||||
// Use the enabled state loaded from flash
|
||||
this->enabled_ = enabled;
|
||||
} else {
|
||||
// If no state saved, then use the default
|
||||
this->enabled_ = default_enabled;
|
||||
}
|
||||
};
|
||||
|
||||
bool WakeWordModel::determine_detected() {
|
||||
void WakeWordModel::enable() {
|
||||
this->enabled_ = true;
|
||||
if (!this->internal_only_) {
|
||||
this->pref_.save(&this->enabled_);
|
||||
}
|
||||
}
|
||||
|
||||
void WakeWordModel::disable() {
|
||||
this->enabled_ = false;
|
||||
if (!this->internal_only_) {
|
||||
this->pref_.save(&this->enabled_);
|
||||
}
|
||||
}
|
||||
|
||||
DetectionEvent WakeWordModel::determine_detected() {
|
||||
DetectionEvent detection_event;
|
||||
detection_event.wake_word = &this->wake_word_;
|
||||
detection_event.max_probability = 0;
|
||||
detection_event.average_probability = 0;
|
||||
|
||||
if ((this->ignore_windows_ < 0) || !this->enabled_) {
|
||||
detection_event.detected = false;
|
||||
return detection_event;
|
||||
}
|
||||
|
||||
uint32_t sum = 0;
|
||||
for (auto &prob : this->recent_streaming_probabilities_) {
|
||||
detection_event.max_probability = std::max(detection_event.max_probability, prob);
|
||||
sum += prob;
|
||||
}
|
||||
|
||||
float sliding_window_average = static_cast<float>(sum) / static_cast<float>(255 * this->sliding_window_size_);
|
||||
detection_event.average_probability = sum / this->sliding_window_size_;
|
||||
detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_;
|
||||
|
||||
// Detect the wake word if the sliding window average is above the cutoff
|
||||
if (sliding_window_average > this->probability_cutoff_) {
|
||||
ESP_LOGD(TAG, "The '%s' model sliding average probability is %.3f and most recent probability is %.3f",
|
||||
this->wake_word_.c_str(), sliding_window_average,
|
||||
this->recent_streaming_probabilities_[this->last_n_index_] / (255.0));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
this->unprocessed_probability_status_ = false;
|
||||
return detection_event;
|
||||
}
|
||||
|
||||
VADModel::VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size,
|
||||
VADModel::VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
|
||||
size_t tensor_arena_size) {
|
||||
this->model_start_ = model_start;
|
||||
this->probability_cutoff_ = probability_cutoff;
|
||||
this->sliding_window_size_ = sliding_window_size;
|
||||
this->recent_streaming_probabilities_.resize(sliding_window_size, 0);
|
||||
this->tensor_arena_size_ = tensor_arena_size;
|
||||
};
|
||||
this->register_streaming_ops_(this->streaming_op_resolver_);
|
||||
}
|
||||
|
||||
DetectionEvent VADModel::determine_detected() {
|
||||
DetectionEvent detection_event;
|
||||
detection_event.max_probability = 0;
|
||||
detection_event.average_probability = 0;
|
||||
|
||||
if (!this->enabled_) {
|
||||
// We disabled the VAD model for some reason... so we shouldn't block wake words from being detected
|
||||
detection_event.detected = true;
|
||||
return detection_event;
|
||||
}
|
||||
|
||||
bool VADModel::determine_detected() {
|
||||
uint32_t sum = 0;
|
||||
for (auto &prob : this->recent_streaming_probabilities_) {
|
||||
detection_event.max_probability = std::max(detection_event.max_probability, prob);
|
||||
sum += prob;
|
||||
}
|
||||
|
||||
float sliding_window_average = static_cast<float>(sum) / static_cast<float>(255 * this->sliding_window_size_);
|
||||
detection_event.average_probability = sum / this->sliding_window_size_;
|
||||
detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_);
|
||||
|
||||
return sliding_window_average > this->probability_cutoff_;
|
||||
return detection_event;
|
||||
}
|
||||
|
||||
bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) {
|
||||
if (op_resolver.AddCallOnce() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddVarHandle() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddReshape() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddReadVariable() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddStridedSlice() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddConcatenation() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddAssignVariable() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddConv2D() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddMul() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddAdd() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddMean() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddFullyConnected() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddLogistic() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddQuantize() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddAveragePool2D() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddMaxPool2D() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddPad() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddPack() != kTfLiteOk)
|
||||
return false;
|
||||
if (op_resolver.AddSplitV() != kTfLiteOk)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace micro_wake_word
|
||||
|
@ -4,6 +4,8 @@
|
||||
|
||||
#include "preprocessor_settings.h"
|
||||
|
||||
#include "esphome/core/preferences.h"
|
||||
|
||||
#include <tensorflow/lite/core/c/common.h>
|
||||
#include <tensorflow/lite/micro/micro_interpreter.h>
|
||||
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
|
||||
@ -11,30 +13,63 @@
|
||||
namespace esphome {
|
||||
namespace micro_wake_word {
|
||||
|
||||
static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100;
|
||||
static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024;
|
||||
|
||||
struct DetectionEvent {
|
||||
std::string *wake_word;
|
||||
bool detected;
|
||||
bool partially_detection; // Set if the most recent probability exceed the threshold, but the sliding window average
|
||||
// hasn't yet
|
||||
uint8_t max_probability;
|
||||
uint8_t average_probability;
|
||||
bool blocked_by_vad = false;
|
||||
};
|
||||
|
||||
class StreamingModel {
|
||||
public:
|
||||
virtual void log_model_config() = 0;
|
||||
virtual bool determine_detected() = 0;
|
||||
virtual DetectionEvent determine_detected() = 0;
|
||||
|
||||
// Performs inference on the given features.
|
||||
// - If the model is enabled but not loaded, it will load it
|
||||
// - If the model is disabled but loaded, it will unload it
|
||||
// Returns true if sucessful or false if there is an error
|
||||
bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]);
|
||||
|
||||
/// @brief Sets all recent_streaming_probabilities to 0
|
||||
/// @brief Sets all recent_streaming_probabilities to 0 and resets the ignore window count
|
||||
void reset_probabilities();
|
||||
|
||||
/// @brief Allocates tensor and variable arenas and sets up the model interpreter
|
||||
/// @param op_resolver MicroMutableOpResolver object that must exist until the model is unloaded
|
||||
/// @return True if successful, false otherwise
|
||||
bool load_model(tflite::MicroMutableOpResolver<20> &op_resolver);
|
||||
|
||||
/// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory
|
||||
void unload_model();
|
||||
|
||||
protected:
|
||||
uint8_t current_stride_step_{0};
|
||||
/// @brief Enable the model. The next performing_streaming_inference call will load it.
|
||||
virtual void enable() { this->enabled_ = true; }
|
||||
|
||||
float probability_cutoff_;
|
||||
/// @brief Disable the model. The next performing_streaming_inference call will unload it.
|
||||
virtual void disable() { this->enabled_ = false; }
|
||||
|
||||
/// @brief Return true if the model is enabled.
|
||||
bool is_enabled() { return this->enabled_; }
|
||||
|
||||
bool get_unprocessed_probability_status() { return this->unprocessed_probability_status_; }
|
||||
|
||||
protected:
|
||||
/// @brief Allocates tensor and variable arenas and sets up the model interpreter
|
||||
/// @return True if successful, false otherwise
|
||||
bool load_model_();
|
||||
/// @brief Returns true if successfully registered the streaming model's TensorFlow operations
|
||||
bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver);
|
||||
|
||||
tflite::MicroMutableOpResolver<20> streaming_op_resolver_;
|
||||
|
||||
bool loaded_{false};
|
||||
bool enabled_{true};
|
||||
bool unprocessed_probability_status_{false};
|
||||
uint8_t current_stride_step_{0};
|
||||
int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION};
|
||||
|
||||
uint8_t probability_cutoff_; // Quantized probability cutoff mapping 0.0 - 1.0 to 0 - 255
|
||||
size_t sliding_window_size_;
|
||||
size_t last_n_index_{0};
|
||||
size_t tensor_arena_size_;
|
||||
@ -50,32 +85,62 @@ class StreamingModel {
|
||||
|
||||
class WakeWordModel final : public StreamingModel {
|
||||
public:
|
||||
WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size,
|
||||
const std::string &wake_word, size_t tensor_arena_size);
|
||||
/// @brief Constructs a wake word model object
|
||||
/// @param id (std::string) identifier for this model
|
||||
/// @param model_start (const uint8_t *) pointer to the start of the model's TFLite FlatBuffer
|
||||
/// @param probability_cutoff (uint8_t) probability cutoff for acceping the wake word has been said
|
||||
/// @param sliding_window_average_size (size_t) the length of the sliding window computing the mean rolling
|
||||
/// probability
|
||||
/// @param wake_word (std::string) Friendly name of the wake word
|
||||
/// @param tensor_arena_size (size_t) Size in bytes for allocating the tensor arena
|
||||
/// @param default_enabled (bool) If true, it will be enabled by default on first boot
|
||||
/// @param internal_only (bool) If true, the model will not be exposed to HomeAssistant as an available model
|
||||
WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff,
|
||||
size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
|
||||
bool default_enabled, bool internal_only);
|
||||
|
||||
void log_model_config() override;
|
||||
|
||||
/// @brief Checks for the wake word by comparing the mean probability in the sliding window with the probability
|
||||
/// cutoff
|
||||
/// @return True if wake word is detected, false otherwise
|
||||
bool determine_detected() override;
|
||||
DetectionEvent determine_detected() override;
|
||||
|
||||
const std::string &get_id() const { return this->id_; }
|
||||
const std::string &get_wake_word() const { return this->wake_word_; }
|
||||
|
||||
void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); }
|
||||
const std::vector<std::string> &get_trained_languages() const { return this->trained_languages_; }
|
||||
|
||||
/// @brief Enable the model and save to flash. The next performing_streaming_inference call will load it.
|
||||
void enable() override;
|
||||
|
||||
/// @brief Disable the model and save to flash. The next performing_streaming_inference call will unload it.
|
||||
void disable() override;
|
||||
|
||||
bool get_internal_only() { return this->internal_only_; }
|
||||
|
||||
protected:
|
||||
std::string id_;
|
||||
std::string wake_word_;
|
||||
std::vector<std::string> trained_languages_;
|
||||
|
||||
bool internal_only_;
|
||||
|
||||
ESPPreferenceObject pref_;
|
||||
};
|
||||
|
||||
class VADModel final : public StreamingModel {
|
||||
public:
|
||||
VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size);
|
||||
VADModel(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size,
|
||||
size_t tensor_arena_size);
|
||||
|
||||
void log_model_config() override;
|
||||
|
||||
/// @brief Checks for voice activity by comparing the max probability in the sliding window with the probability
|
||||
/// cutoff
|
||||
/// @return True if voice activity is detected, false otherwise
|
||||
bool determine_detected() override;
|
||||
DetectionEvent determine_detected() override;
|
||||
};
|
||||
|
||||
} // namespace micro_wake_word
|
||||
|
@ -79,6 +79,7 @@
|
||||
#define USE_LVGL_TEXTAREA
|
||||
#define USE_LVGL_TILEVIEW
|
||||
#define USE_LVGL_TOUCHSCREEN
|
||||
#define USE_MICRO_WAKE_WORD
|
||||
#define USE_MD5
|
||||
#define USE_MDNS
|
||||
#define USE_MEDIA_PLAYER
|
||||
|
@ -14,8 +14,24 @@ micro_wake_word:
|
||||
microphone: echo_microphone
|
||||
on_wake_word_detected:
|
||||
- logger.log: "Wake word detected"
|
||||
- micro_wake_word.stop:
|
||||
- if:
|
||||
condition:
|
||||
- micro_wake_word.model_is_enabled: hey_jarvis_model
|
||||
then:
|
||||
- micro_wake_word.disable_model: hey_jarvis_model
|
||||
else:
|
||||
- micro_wake_word.enable_model: hey_jarvis_model
|
||||
- if:
|
||||
condition:
|
||||
- not:
|
||||
- micro_wake_word.is_running:
|
||||
then:
|
||||
micro_wake_word.start:
|
||||
stop_after_detection: false
|
||||
models:
|
||||
- model: hey_jarvis
|
||||
probability_cutoff: 0.7
|
||||
id: hey_jarvis_model
|
||||
- model: okay_nabu
|
||||
sliding_window_size: 5
|
||||
|
Loading…
x
Reference in New Issue
Block a user