diff --git a/.dockerignore b/.dockerignore index 7998ff877f..ccd466d8cb 100644 --- a/.dockerignore +++ b/.dockerignore @@ -114,4 +114,5 @@ config/ examples/ Dockerfile .git/ -tests/build/ +tests/ +.* diff --git a/.github/actions/build-image/action.yaml b/.github/actions/build-image/action.yaml index 86a6b3f4d2..3d6de54f42 100644 --- a/.github/actions/build-image/action.yaml +++ b/.github/actions/build-image/action.yaml @@ -1,15 +1,11 @@ name: Build Image inputs: - platform: - description: "Platform to build for" - required: true - example: "linux/amd64" target: description: "Target to build" required: true example: "docker" - baseimg: - description: "Base image type" + build_type: + description: "Build type" required: true example: "docker" suffix: @@ -19,6 +15,11 @@ inputs: description: "Version to build" required: true example: "2023.12.0" + base_os: + description: "Base OS to use" + required: false + default: "debian" + example: "debian" runs: using: "composite" steps: @@ -46,52 +47,52 @@ runs: - name: Build and push to ghcr by digest id: build-ghcr - uses: docker/build-push-action@v6.15.0 + uses: docker/build-push-action@v6.16.0 env: DOCKER_BUILD_SUMMARY: false DOCKER_BUILD_RECORD_UPLOAD: false with: context: . file: ./docker/Dockerfile - platforms: ${{ inputs.platform }} target: ${{ inputs.target }} cache-from: type=gha cache-to: ${{ steps.cache-to.outputs.value }} build-args: | - BASEIMGTYPE=${{ inputs.baseimg }} + BUILD_TYPE=${{ inputs.build_type }} BUILD_VERSION=${{ inputs.version }} + BUILD_OS=${{ inputs.base_os }} outputs: | type=image,name=ghcr.io/${{ steps.tags.outputs.image_name }},push-by-digest=true,name-canonical=true,push=true - name: Export ghcr digests shell: bash run: | - mkdir -p /tmp/digests/${{ inputs.target }}/ghcr + mkdir -p /tmp/digests/${{ inputs.build_type }}/ghcr digest="${{ steps.build-ghcr.outputs.digest }}" - touch "/tmp/digests/${{ inputs.target }}/ghcr/${digest#sha256:}" + touch "/tmp/digests/${{ inputs.build_type }}/ghcr/${digest#sha256:}" - name: Build and push to dockerhub by digest id: build-dockerhub - uses: docker/build-push-action@v6.15.0 + uses: docker/build-push-action@v6.16.0 env: DOCKER_BUILD_SUMMARY: false DOCKER_BUILD_RECORD_UPLOAD: false with: context: . file: ./docker/Dockerfile - platforms: ${{ inputs.platform }} target: ${{ inputs.target }} cache-from: type=gha cache-to: ${{ steps.cache-to.outputs.value }} build-args: | - BASEIMGTYPE=${{ inputs.baseimg }} + BUILD_TYPE=${{ inputs.build_type }} BUILD_VERSION=${{ inputs.version }} + BUILD_OS=${{ inputs.base_os }} outputs: | type=image,name=docker.io/${{ steps.tags.outputs.image_name }},push-by-digest=true,name-canonical=true,push=true - name: Export dockerhub digests shell: bash run: | - mkdir -p /tmp/digests/${{ inputs.target }}/dockerhub + mkdir -p /tmp/digests/${{ inputs.build_type }}/dockerhub digest="${{ steps.build-dockerhub.outputs.digest }}" - touch "/tmp/digests/${{ inputs.target }}/dockerhub/${digest#sha256:}" + touch "/tmp/digests/${{ inputs.build_type }}/dockerhub/${digest#sha256:}" diff --git a/.github/actions/restore-python/action.yml b/.github/actions/restore-python/action.yml index 3ac91f8ea2..082539adaa 100644 --- a/.github/actions/restore-python/action.yml +++ b/.github/actions/restore-python/action.yml @@ -17,7 +17,7 @@ runs: steps: - name: Set up Python ${{ inputs.python-version }} id: python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: ${{ inputs.python-version }} - name: Restore Python virtual environment @@ -34,7 +34,7 @@ runs: python -m venv venv source venv/bin/activate python --version - pip install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt + pip install -r requirements.txt -r requirements_test.txt pip install -e . - name: Create Python virtual environment if: steps.cache-venv.outputs.cache-hit != 'true' && runner.os == 'Windows' @@ -43,5 +43,5 @@ runs: python -m venv venv ./venv/Scripts/activate python --version - pip install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt + pip install -r requirements.txt -r requirements_test.txt pip install -e . diff --git a/.github/dependabot.yml b/.github/dependabot.yml index bb35f16048..cf507bbaa6 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -17,7 +17,6 @@ updates: docker-actions: applies-to: version-updates patterns: - - "docker/setup-qemu-action" - "docker/login-action" - "docker/setup-buildx-action" - package-ecosystem: github-actions diff --git a/.github/workflows/ci-api-proto.yml b/.github/workflows/ci-api-proto.yml index 233fb64693..92d209cc34 100644 --- a/.github/workflows/ci-api-proto.yml +++ b/.github/workflows/ci-api-proto.yml @@ -23,7 +23,7 @@ jobs: - name: Checkout uses: actions/checkout@v4.1.7 - name: Set up Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: "3.11" @@ -57,6 +57,17 @@ jobs: event: 'REQUEST_CHANGES', body: 'You have altered the generated proto files but they do not match what is expected.\nPlease run "script/api_protobuf/api_protobuf.py" and commit the changes.' }) + - if: failure() + name: Show changes + run: git diff + - if: failure() + name: Archive artifacts + uses: actions/upload-artifact@v4.6.2 + with: + name: generated-proto-files + path: | + esphome/components/api/api_pb2.* + esphome/components/api/api_pb2_service.* - if: success() name: Dismiss review uses: actions/github-script@v7.0.1 diff --git a/.github/workflows/ci-docker.yml b/.github/workflows/ci-docker.yml index 0a08e6ffad..511ec55f3e 100644 --- a/.github/workflows/ci-docker.yml +++ b/.github/workflows/ci-docker.yml @@ -37,12 +37,15 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-latest", "ubuntu-24.04-arm"] - build_type: ["ha-addon", "docker", "lint"] + os: ["ubuntu-24.04", "ubuntu-24.04-arm"] + build_type: + - "ha-addon" + - "docker" + # - "lint" steps: - uses: actions/checkout@v4.1.7 - name: Set up Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: "3.9" - name: Set up Docker Buildx diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 91c40d37c4..8d2ec68010 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,10 +39,10 @@ jobs: uses: actions/checkout@v4.1.7 - name: Generate cache-key id: cache-key - run: echo key="${{ hashFiles('requirements.txt', 'requirements_optional.txt', 'requirements_test.txt') }}" >> $GITHUB_OUTPUT + run: echo key="${{ hashFiles('requirements.txt', 'requirements_test.txt') }}" >> $GITHUB_OUTPUT - name: Set up Python ${{ env.DEFAULT_PYTHON }} id: python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: ${{ env.DEFAULT_PYTHON }} - name: Restore Python virtual environment @@ -58,7 +58,7 @@ jobs: python -m venv venv . venv/bin/activate python --version - pip install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt + pip install -r requirements.txt -r requirements_test.txt pip install -e . ruff: @@ -165,6 +165,7 @@ jobs: . venv/bin/activate script/ci-custom.py script/build_codeowners.py --check + script/build_language_schema.py --check pytest: name: Run pytest @@ -220,7 +221,7 @@ jobs: . venv/bin/activate pytest -vv --cov-report=xml --tb=native tests - name: Upload coverage to Codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v5.4.2 with: token: ${{ secrets.CODECOV_TOKEN }} @@ -291,6 +292,11 @@ jobs: name: Run script/clang-tidy for ESP32 IDF options: --environment esp32-idf-tidy --grep USE_ESP_IDF pio_cache_key: tidyesp32-idf + - id: clang-tidy + name: Run script/clang-tidy for ZEPHYR + options: --environment nrf52-tidy --grep USE_ZEPHYR + pio_cache_key: tidy-zephyr + ignore_errors: true steps: - name: Check out code from GitHub @@ -330,13 +336,13 @@ jobs: - name: Run clang-tidy run: | . venv/bin/activate - script/clang-tidy --all-headers --fix ${{ matrix.options }} + script/clang-tidy --all-headers --fix ${{ matrix.options }} ${{ matrix.ignore_errors && '|| true' || '' }} env: # Also cache libdeps, store them in a ~/.platformio subfolder PLATFORMIO_LIBDEPS_DIR: ~/.platformio/libdeps - name: Suggested changes - run: script/ci-suggest-changes + run: script/ci-suggest-changes ${{ matrix.ignore_errors && '|| true' || '' }} # yamllint disable-line rule:line-length if: always() diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7793c574fe..88704953ce 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: steps: - uses: actions/checkout@v4.1.7 - name: Set up Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: "3.x" - name: Set up python environment @@ -68,31 +68,31 @@ jobs: uses: pypa/gh-action-pypi-publish@v1.12.4 deploy-docker: - name: Build ESPHome ${{ matrix.platform }} + name: Build ESPHome ${{ matrix.platform.arch }} if: github.repository == 'esphome/esphome' permissions: contents: read packages: write - runs-on: ubuntu-latest + runs-on: ${{ matrix.platform.os }} needs: [init] strategy: fail-fast: false matrix: platform: - - linux/amd64 - - linux/arm64 + - arch: amd64 + os: "ubuntu-24.04" + - arch: arm64 + os: "ubuntu-24.04-arm" + steps: - uses: actions/checkout@v4.1.7 - name: Set up Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: "3.9" - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3.10.0 - - name: Set up QEMU - if: matrix.platform != 'linux/amd64' - uses: docker/setup-qemu-action@v3.6.0 - name: Log in to docker hub uses: docker/login-action@v3.4.0 @@ -109,45 +109,36 @@ jobs: - name: Build docker uses: ./.github/actions/build-image with: - platform: ${{ matrix.platform }} - target: docker - baseimg: docker + target: final + build_type: docker suffix: "" version: ${{ needs.init.outputs.tag }} - name: Build ha-addon uses: ./.github/actions/build-image with: - platform: ${{ matrix.platform }} - target: hassio - baseimg: hassio + target: final + build_type: ha-addon suffix: "hassio" version: ${{ needs.init.outputs.tag }} - - name: Build lint - uses: ./.github/actions/build-image - with: - platform: ${{ matrix.platform }} - target: lint - baseimg: docker - suffix: lint - version: ${{ needs.init.outputs.tag }} - - - name: Sanitize platform name - id: sanitize - run: | - echo "${{ matrix.platform }}" | sed 's|/|-|g' > /tmp/platform - echo name=$(cat /tmp/platform) >> $GITHUB_OUTPUT + # - name: Build lint + # uses: ./.github/actions/build-image + # with: + # target: lint + # build_type: lint + # suffix: lint + # version: ${{ needs.init.outputs.tag }} - name: Upload digests uses: actions/upload-artifact@v4.6.2 with: - name: digests-${{ steps.sanitize.outputs.name }} + name: digests-${{ matrix.platform.arch }} path: /tmp/digests retention-days: 1 deploy-manifest: - name: Publish ESPHome ${{ matrix.image.title }} to ${{ matrix.registry }} + name: Publish ESPHome ${{ matrix.image.build_type }} to ${{ matrix.registry }} runs-on: ubuntu-latest needs: - init @@ -160,15 +151,12 @@ jobs: fail-fast: false matrix: image: - - title: "ha-addon" - target: "hassio" - suffix: "hassio" - - title: "docker" - target: "docker" + - build_type: "docker" suffix: "" - - title: "lint" - target: "lint" - suffix: "lint" + - build_type: "ha-addon" + suffix: "hassio" + # - build_type: "lint" + # suffix: "lint" registry: - ghcr - dockerhub @@ -176,7 +164,7 @@ jobs: - uses: actions/checkout@v4.1.7 - name: Download digests - uses: actions/download-artifact@v4.2.1 + uses: actions/download-artifact@v4.3.0 with: pattern: digests-* path: /tmp/digests @@ -212,7 +200,7 @@ jobs: done - name: Create manifest list and push - working-directory: /tmp/digests/${{ matrix.image.target }}/${{ matrix.registry }} + working-directory: /tmp/digests/${{ matrix.image.build_type }}/${{ matrix.registry }} run: | docker buildx imagetools create $(jq -Rcnr 'inputs | . / "," | map("-t " + .) | join(" ")' <<< "${{ steps.tags.outputs.tags}}") \ $(printf '${{ steps.tags.outputs.image }}@sha256:%s ' *) @@ -243,3 +231,25 @@ jobs: content: description } }) + + deploy-esphome-schema: + if: github.repository == 'esphome/esphome' && needs.init.outputs.branch_build == 'false' + runs-on: ubuntu-latest + needs: + - init + - deploy-manifest + steps: + - name: Trigger Workflow + uses: actions/github-script@v7.0.1 + with: + github-token: ${{ secrets.DEPLOY_ESPHOME_SCHEMA_REPO_TOKEN }} + script: | + github.rest.actions.createWorkflowDispatch({ + owner: "esphome", + repo: "esphome-schema", + workflow_id: "generate-schemas.yml", + ref: "main", + inputs: { + version: "${{ needs.init.outputs.tag }}", + } + }) diff --git a/.github/workflows/sync-device-classes.yml b/.github/workflows/sync-device-classes.yml index 0a0c834a71..b262a9f9c1 100644 --- a/.github/workflows/sync-device-classes.yml +++ b/.github/workflows/sync-device-classes.yml @@ -22,7 +22,7 @@ jobs: path: lib/home-assistant - name: Setup Python - uses: actions/setup-python@v5.5.0 + uses: actions/setup-python@v5.6.0 with: python-version: 3.12 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d11aa067bf..c3d5b9c783 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.11.0 + rev: v0.11.9 hooks: # Run the linter. - id: ruff @@ -33,7 +33,7 @@ repos: - id: pyupgrade args: [--py39-plus] - repo: https://github.com/adrienverge/yamllint.git - rev: v1.35.1 + rev: v1.37.1 hooks: - id: yamllint - repo: https://github.com/pre-commit/mirrors-clang-format diff --git a/CODEOWNERS b/CODEOWNERS index f6f7ac6f9c..ddd0494a3c 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -98,6 +98,7 @@ esphome/components/climate/* @esphome/core esphome/components/climate_ir/* @glmnet esphome/components/color_temperature/* @jesserockz esphome/components/combination/* @Cat-Ion @kahrendt +esphome/components/const/* @esphome/core esphome/components/coolix/* @glmnet esphome/components/copy/* @OttoWinter esphome/components/cover/* @esphome/core @@ -250,6 +251,7 @@ esphome/components/ltr501/* @latonita esphome/components/ltr_als_ps/* @latonita esphome/components/lvgl/* @clydebarrow esphome/components/m5stack_8angle/* @rnauber +esphome/components/mapping/* @clydebarrow esphome/components/matrix_keypad/* @ssieb esphome/components/max17043/* @blacknell esphome/components/max31865/* @DAVe3283 @@ -276,10 +278,11 @@ esphome/components/mdns/* @esphome/core esphome/components/media_player/* @jesserockz esphome/components/micro_wake_word/* @jesserockz @kahrendt esphome/components/micronova/* @jorre05 -esphome/components/microphone/* @jesserockz +esphome/components/microphone/* @jesserockz @kahrendt esphome/components/mics_4514/* @jesserockz esphome/components/midea/* @dudanov esphome/components/midea_ir/* @dudanov +esphome/components/mipi_spi/* @clydebarrow esphome/components/mitsubishi/* @RubyBailey esphome/components/mixer/speaker/* @kahrendt esphome/components/mlx90393/* @functionpointer @@ -317,6 +320,7 @@ esphome/components/online_image/* @clydebarrow @guillempages esphome/components/opentherm/* @olegtarasov esphome/components/ota/* @esphome/core esphome/components/output/* @esphome/core +esphome/components/packet_transport/* @clydebarrow esphome/components/pca6416a/* @Mat931 esphome/components/pca9554/* @clydebarrow @hwstar esphome/components/pcf85063/* @brogon @@ -324,7 +328,9 @@ esphome/components/pcf8563/* @KoenBreeman esphome/components/pid/* @OttoWinter esphome/components/pipsolar/* @andreashergert1984 esphome/components/pm1006/* @habbie +esphome/components/pm2005/* @andrewjswan esphome/components/pmsa003i/* @sjtrny +esphome/components/pmsx003/* @ximex esphome/components/pmwcs3/* @SeByDocKy esphome/components/pn532/* @OttoWinter @jesserockz esphome/components/pn532_i2c/* @OttoWinter @jesserockz @@ -393,6 +399,7 @@ esphome/components/smt100/* @piechade esphome/components/sn74hc165/* @jesserockz esphome/components/socket/* @esphome/core esphome/components/sonoff_d1/* @anatoly-savchenkov +esphome/components/sound_level/* @kahrendt esphome/components/speaker/* @jesserockz @kahrendt esphome/components/speaker/media_player/* @kahrendt @synesthesiam esphome/components/spi/* @clydebarrow @esphome/core @@ -424,6 +431,7 @@ esphome/components/sun/* @OttoWinter esphome/components/sun_gtil2/* @Mat931 esphome/components/switch/* @esphome/core esphome/components/switch/binary_sensor/* @ssieb +esphome/components/syslog/* @clydebarrow esphome/components/t6615/* @tylermenezes esphome/components/tc74/* @sethgirvan esphome/components/tca9548a/* @andreashergert1984 @@ -463,6 +471,7 @@ esphome/components/tuya/switch/* @jesserockz esphome/components/tuya/text_sensor/* @dentra esphome/components/uart/* @esphome/core esphome/components/uart/button/* @ssieb +esphome/components/uart/packet_transport/* @clydebarrow esphome/components/udp/* @clydebarrow esphome/components/ufire_ec/* @pvizeli esphome/components/ufire_ise/* @pvizeli diff --git a/docker/Dockerfile b/docker/Dockerfile index 117ec17ae4..39dc1c7f28 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,131 +1,54 @@ -# Build these with the build.py script -# Example: -# python3 docker/build.py --tag dev --arch amd64 --build-type docker build +ARG BUILD_VERSION=dev +ARG BUILD_OS=alpine +ARG BUILD_BASE_VERSION=2025.04.0 +ARG BUILD_TYPE=docker -# One of "docker", "hassio" -ARG BASEIMGTYPE=docker +FROM ghcr.io/esphome/docker-base:${BUILD_OS}-${BUILD_BASE_VERSION} AS base-source-docker +FROM ghcr.io/esphome/docker-base:${BUILD_OS}-ha-addon-${BUILD_BASE_VERSION} AS base-source-ha-addon +ARG BUILD_TYPE +FROM base-source-${BUILD_TYPE} AS base -# https://github.com/hassio-addons/addon-debian-base/releases -FROM ghcr.io/hassio-addons/debian-base:7.2.0 AS base-hassio -# https://hub.docker.com/_/debian?tab=tags&page=1&name=bookworm -FROM debian:12.2-slim AS base-docker +RUN git config --system --add safe.directory "*" -FROM base-${BASEIMGTYPE} AS base +RUN pip install uv==0.6.14 - -ARG TARGETARCH -ARG TARGETVARIANT - - -# Note that --break-system-packages is used below because -# https://peps.python.org/pep-0668/ added a safety check that prevents -# installing packages with the same name as a system package. This is -# not a problem for us because we are not concerned about overwriting -# system packages because we are running in an isolated container. +COPY requirements.txt / RUN \ - apt-get update \ - # Use pinned versions so that we get updates with build caching - && apt-get install -y --no-install-recommends \ - python3-pip=23.0.1+dfsg-1 \ - python3-setuptools=66.1.1-1+deb12u1 \ - python3-venv=3.11.2-1+b1 \ - python3-wheel=0.38.4-2 \ - iputils-ping=3:20221126-1+deb12u1 \ - git=1:2.39.5-0+deb12u2 \ - curl=7.88.1-10+deb12u12 \ - openssh-client=1:9.2p1-2+deb12u5 \ - python3-cffi=1.15.1-5 \ - libcairo2=1.16.0-7 \ - libmagic1=1:5.44-3 \ - patch=2.7.6-7 \ - && rm -rf \ - /tmp/* \ - /var/{cache,log}/* \ - /var/lib/apt/lists/* - -ENV \ - # Fix click python3 lang warning https://click.palletsprojects.com/en/7.x/python3/ - LANG=C.UTF-8 LC_ALL=C.UTF-8 \ - # Store globally installed pio libs in /piolibs - PLATFORMIO_GLOBALLIB_DIR=/piolibs + uv pip install --no-cache-dir \ + -r /requirements.txt RUN \ - pip3 install \ - --break-system-packages --no-cache-dir \ - # Keep platformio version in sync with requirements.txt - platformio==6.1.18 \ - # Change some platformio settings - && platformio settings set enable_telemetry No \ + platformio settings set enable_telemetry No \ && platformio settings set check_platformio_interval 1000000 \ && mkdir -p /piolibs - -# First install requirements to leverage caching when requirements don't change -# tmpfs is for https://github.com/rust-lang/cargo/issues/8719 - -COPY requirements.txt requirements_optional.txt / -RUN --mount=type=tmpfs,target=/root/.cargo < /etc/apt/sources.list.d/llvm.sources.list \ - && apt-get update \ - # Use pinned versions so that we get updates with build caching - && apt-get install -y --no-install-recommends \ - clang-format-13=1:13.0.1-11+b2 \ - patch=2.7.6-7 \ - software-properties-common=0.99.30-4.1~deb12u1 \ - nano=7.2-1+deb12u1 \ - build-essential=12.9 \ - python3-dev=3.11.2-1+b1 \ - clang-tidy-18=1:18.1.8~++20240731024826+3b5b5c1ec4a3-1~exp1~20240731144843.145 \ - && rm -rf \ - /tmp/* \ - /var/{cache,log}/* \ - /var/lib/apt/lists/* - -COPY requirements_test.txt / -RUN pip3 install --break-system-packages --no-cache-dir -r /requirements_test.txt - -VOLUME ["/esphome"] -WORKDIR /esphome +# Copy esphome and install +COPY . /esphome +RUN uv pip install --no-cache-dir -e /esphome diff --git a/docker/build.py b/docker/build.py index cdc25df340..921adac7ab 100755 --- a/docker/build.py +++ b/docker/build.py @@ -54,7 +54,7 @@ manifest_parser = subparsers.add_parser( class DockerParams: build_to: str manifest_to: str - baseimgtype: str + build_type: str platform: str target: str @@ -66,24 +66,19 @@ class DockerParams: TYPE_LINT: "esphome/esphome-lint", }[build_type] build_to = f"{prefix}-{arch}" - baseimgtype = { - TYPE_DOCKER: "docker", - TYPE_HA_ADDON: "hassio", - TYPE_LINT: "docker", - }[build_type] platform = { ARCH_AMD64: "linux/amd64", ARCH_AARCH64: "linux/arm64", }[arch] target = { - TYPE_DOCKER: "docker", - TYPE_HA_ADDON: "hassio", + TYPE_DOCKER: "final", + TYPE_HA_ADDON: "final", TYPE_LINT: "lint", }[build_type] return cls( build_to=build_to, manifest_to=prefix, - baseimgtype=baseimgtype, + build_type=build_type, platform=platform, target=target, ) @@ -145,7 +140,7 @@ def main(): "buildx", "build", "--build-arg", - f"BASEIMGTYPE={params.baseimgtype}", + f"BUILD_TYPE={params.build_type}", "--build-arg", f"BUILD_VERSION={args.tag}", "--cache-from", diff --git a/esphome/components/ac_dimmer/ac_dimmer.cpp b/esphome/components/ac_dimmer/ac_dimmer.cpp index 16101a1c2c..4901719b32 100644 --- a/esphome/components/ac_dimmer/ac_dimmer.cpp +++ b/esphome/components/ac_dimmer/ac_dimmer.cpp @@ -114,13 +114,14 @@ void IRAM_ATTR HOT AcDimmerDataStore::gpio_intr() { // fully off, disable output immediately this->gate_pin.digital_write(false); } else { + auto min_us = this->cycle_time_us * this->min_power / 1000; if (this->method == DIM_METHOD_TRAILING) { this->enable_time_us = 1; // cannot be 0 - this->disable_time_us = std::max((uint32_t) 10, this->value * this->cycle_time_us / 65535); + // calculate time until disable in µs with integer arithmetic and take into account min_power + this->disable_time_us = std::max((uint32_t) 10, this->value * (this->cycle_time_us - min_us) / 65535 + min_us); } else { // calculate time until enable in µs: (1.0-value)*cycle_time, but with integer arithmetic // also take into account min_power - auto min_us = this->cycle_time_us * this->min_power / 1000; this->enable_time_us = std::max((uint32_t) 1, ((65535 - this->value) * (this->cycle_time_us - min_us)) / 65535); if (this->method == DIM_METHOD_LEADING_PULSE) { diff --git a/esphome/components/adc/__init__.py b/esphome/components/adc/__init__.py index be420475fb..5f94c61a08 100644 --- a/esphome/components/adc/__init__.py +++ b/esphome/components/adc/__init__.py @@ -47,9 +47,10 @@ SAMPLING_MODES = { adc1_channel_t = cg.global_ns.enum("adc1_channel_t") adc2_channel_t = cg.global_ns.enum("adc2_channel_t") -# From https://github.com/espressif/esp-idf/blob/master/components/driver/include/driver/adc_common.h # pin to adc1 channel mapping +# https://github.com/espressif/esp-idf/blob/v4.4.8/components/driver/include/driver/adc.h ESP32_VARIANT_ADC1_PIN_TO_CHANNEL = { + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32/include/soc/adc_channel.h VARIANT_ESP32: { 36: adc1_channel_t.ADC1_CHANNEL_0, 37: adc1_channel_t.ADC1_CHANNEL_1, @@ -60,6 +61,41 @@ ESP32_VARIANT_ADC1_PIN_TO_CHANNEL = { 34: adc1_channel_t.ADC1_CHANNEL_6, 35: adc1_channel_t.ADC1_CHANNEL_7, }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c2/include/soc/adc_channel.h + VARIANT_ESP32C2: { + 0: adc1_channel_t.ADC1_CHANNEL_0, + 1: adc1_channel_t.ADC1_CHANNEL_1, + 2: adc1_channel_t.ADC1_CHANNEL_2, + 3: adc1_channel_t.ADC1_CHANNEL_3, + 4: adc1_channel_t.ADC1_CHANNEL_4, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c3/include/soc/adc_channel.h + VARIANT_ESP32C3: { + 0: adc1_channel_t.ADC1_CHANNEL_0, + 1: adc1_channel_t.ADC1_CHANNEL_1, + 2: adc1_channel_t.ADC1_CHANNEL_2, + 3: adc1_channel_t.ADC1_CHANNEL_3, + 4: adc1_channel_t.ADC1_CHANNEL_4, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c6/include/soc/adc_channel.h + VARIANT_ESP32C6: { + 0: adc1_channel_t.ADC1_CHANNEL_0, + 1: adc1_channel_t.ADC1_CHANNEL_1, + 2: adc1_channel_t.ADC1_CHANNEL_2, + 3: adc1_channel_t.ADC1_CHANNEL_3, + 4: adc1_channel_t.ADC1_CHANNEL_4, + 5: adc1_channel_t.ADC1_CHANNEL_5, + 6: adc1_channel_t.ADC1_CHANNEL_6, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32h2/include/soc/adc_channel.h + VARIANT_ESP32H2: { + 1: adc1_channel_t.ADC1_CHANNEL_0, + 2: adc1_channel_t.ADC1_CHANNEL_1, + 3: adc1_channel_t.ADC1_CHANNEL_2, + 4: adc1_channel_t.ADC1_CHANNEL_3, + 5: adc1_channel_t.ADC1_CHANNEL_4, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32s2/include/soc/adc_channel.h VARIANT_ESP32S2: { 1: adc1_channel_t.ADC1_CHANNEL_0, 2: adc1_channel_t.ADC1_CHANNEL_1, @@ -72,6 +108,7 @@ ESP32_VARIANT_ADC1_PIN_TO_CHANNEL = { 9: adc1_channel_t.ADC1_CHANNEL_8, 10: adc1_channel_t.ADC1_CHANNEL_9, }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32s3/include/soc/adc_channel.h VARIANT_ESP32S3: { 1: adc1_channel_t.ADC1_CHANNEL_0, 2: adc1_channel_t.ADC1_CHANNEL_1, @@ -84,40 +121,12 @@ ESP32_VARIANT_ADC1_PIN_TO_CHANNEL = { 9: adc1_channel_t.ADC1_CHANNEL_8, 10: adc1_channel_t.ADC1_CHANNEL_9, }, - VARIANT_ESP32C3: { - 0: adc1_channel_t.ADC1_CHANNEL_0, - 1: adc1_channel_t.ADC1_CHANNEL_1, - 2: adc1_channel_t.ADC1_CHANNEL_2, - 3: adc1_channel_t.ADC1_CHANNEL_3, - 4: adc1_channel_t.ADC1_CHANNEL_4, - }, - VARIANT_ESP32C2: { - 0: adc1_channel_t.ADC1_CHANNEL_0, - 1: adc1_channel_t.ADC1_CHANNEL_1, - 2: adc1_channel_t.ADC1_CHANNEL_2, - 3: adc1_channel_t.ADC1_CHANNEL_3, - 4: adc1_channel_t.ADC1_CHANNEL_4, - }, - VARIANT_ESP32C6: { - 0: adc1_channel_t.ADC1_CHANNEL_0, - 1: adc1_channel_t.ADC1_CHANNEL_1, - 2: adc1_channel_t.ADC1_CHANNEL_2, - 3: adc1_channel_t.ADC1_CHANNEL_3, - 4: adc1_channel_t.ADC1_CHANNEL_4, - 5: adc1_channel_t.ADC1_CHANNEL_5, - 6: adc1_channel_t.ADC1_CHANNEL_6, - }, - VARIANT_ESP32H2: { - 1: adc1_channel_t.ADC1_CHANNEL_0, - 2: adc1_channel_t.ADC1_CHANNEL_1, - 3: adc1_channel_t.ADC1_CHANNEL_2, - 4: adc1_channel_t.ADC1_CHANNEL_3, - 5: adc1_channel_t.ADC1_CHANNEL_4, - }, } +# pin to adc2 channel mapping +# https://github.com/espressif/esp-idf/blob/v4.4.8/components/driver/include/driver/adc.h ESP32_VARIANT_ADC2_PIN_TO_CHANNEL = { - # TODO: add other variants + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32/include/soc/adc_channel.h VARIANT_ESP32: { 4: adc2_channel_t.ADC2_CHANNEL_0, 0: adc2_channel_t.ADC2_CHANNEL_1, @@ -130,6 +139,19 @@ ESP32_VARIANT_ADC2_PIN_TO_CHANNEL = { 25: adc2_channel_t.ADC2_CHANNEL_8, 26: adc2_channel_t.ADC2_CHANNEL_9, }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c2/include/soc/adc_channel.h + VARIANT_ESP32C2: { + 5: adc2_channel_t.ADC2_CHANNEL_0, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c3/include/soc/adc_channel.h + VARIANT_ESP32C3: { + 5: adc2_channel_t.ADC2_CHANNEL_0, + }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32c6/include/soc/adc_channel.h + VARIANT_ESP32C6: {}, # no ADC2 + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32h2/include/soc/adc_channel.h + VARIANT_ESP32H2: {}, # no ADC2 + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32s2/include/soc/adc_channel.h VARIANT_ESP32S2: { 11: adc2_channel_t.ADC2_CHANNEL_0, 12: adc2_channel_t.ADC2_CHANNEL_1, @@ -142,6 +164,7 @@ ESP32_VARIANT_ADC2_PIN_TO_CHANNEL = { 19: adc2_channel_t.ADC2_CHANNEL_8, 20: adc2_channel_t.ADC2_CHANNEL_9, }, + # https://github.com/espressif/esp-idf/blob/master/components/soc/esp32s3/include/soc/adc_channel.h VARIANT_ESP32S3: { 11: adc2_channel_t.ADC2_CHANNEL_0, 12: adc2_channel_t.ADC2_CHANNEL_1, @@ -154,12 +177,6 @@ ESP32_VARIANT_ADC2_PIN_TO_CHANNEL = { 19: adc2_channel_t.ADC2_CHANNEL_8, 20: adc2_channel_t.ADC2_CHANNEL_9, }, - VARIANT_ESP32C3: { - 5: adc2_channel_t.ADC2_CHANNEL_0, - }, - VARIANT_ESP32C2: {}, - VARIANT_ESP32C6: {}, - VARIANT_ESP32H2: {}, } diff --git a/esphome/components/airthings_wave_base/__init__.py b/esphome/components/airthings_wave_base/__init__.py index 6a29683ced..c3f3b8f199 100644 --- a/esphome/components/airthings_wave_base/__init__.py +++ b/esphome/components/airthings_wave_base/__init__.py @@ -34,7 +34,7 @@ AirthingsWaveBase = airthings_wave_base_ns.class_( BASE_SCHEMA = ( - sensor.SENSOR_SCHEMA.extend( + cv.Schema( { cv.Optional(CONF_HUMIDITY): sensor.sensor_schema( unit_of_measurement=UNIT_PERCENT, diff --git a/esphome/components/alarm_control_panel/__init__.py b/esphome/components/alarm_control_panel/__init__.py index 379fbf32f9..1bcb83bce7 100644 --- a/esphome/components/alarm_control_panel/__init__.py +++ b/esphome/components/alarm_control_panel/__init__.py @@ -5,6 +5,8 @@ from esphome.components import mqtt, web_server import esphome.config_validation as cv from esphome.const import ( CONF_CODE, + CONF_ENTITY_CATEGORY, + CONF_ICON, CONF_ID, CONF_MQTT_ID, CONF_ON_STATE, @@ -12,6 +14,7 @@ from esphome.const import ( CONF_WEB_SERVER, ) from esphome.core import CORE, coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity CODEOWNERS = ["@grahambrown11", "@hwstar"] @@ -78,12 +81,11 @@ AlarmControlPanelCondition = alarm_control_panel_ns.class_( "AlarmControlPanelCondition", automation.Condition ) -ALARM_CONTROL_PANEL_SCHEMA = ( +_ALARM_CONTROL_PANEL_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( { - cv.GenerateID(): cv.declare_id(AlarmControlPanel), cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id( mqtt.MQTTAlarmControlPanelComponent ), @@ -146,6 +148,33 @@ ALARM_CONTROL_PANEL_SCHEMA = ( ) ) + +def alarm_control_panel_schema( + class_: MockObjClass, + *, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, +) -> cv.Schema: + schema = { + cv.GenerateID(): cv.declare_id(class_), + } + + for key, default, validator in [ + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_ICON, icon, cv.icon), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _ALARM_CONTROL_PANEL_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +ALARM_CONTROL_PANEL_SCHEMA = alarm_control_panel_schema(AlarmControlPanel) +ALARM_CONTROL_PANEL_SCHEMA.add_extra( + cv.deprecated_schema_constant("alarm_control_panel") +) + ALARM_CONTROL_PANEL_ACTION_SCHEMA = maybe_simple_id( { cv.GenerateID(): cv.use_id(AlarmControlPanel), @@ -209,6 +238,12 @@ async def register_alarm_control_panel(var, config): await setup_alarm_control_panel_core_(var, config) +async def new_alarm_control_panel(config, *args): + var = cg.new_Pvariable(config[CONF_ID], *args) + await register_alarm_control_panel(var, config) + return var + + @automation.register_action( "alarm_control_panel.arm_away", ArmAwayAction, ALARM_CONTROL_PANEL_ACTION_SCHEMA ) diff --git a/esphome/components/am43/cover/__init__.py b/esphome/components/am43/cover/__init__.py index d60f9cd4e7..e4ecf1444f 100644 --- a/esphome/components/am43/cover/__init__.py +++ b/esphome/components/am43/cover/__init__.py @@ -1,7 +1,7 @@ import esphome.codegen as cg from esphome.components import ble_client, cover import esphome.config_validation as cv -from esphome.const import CONF_ID, CONF_PIN +from esphome.const import CONF_PIN CODEOWNERS = ["@buxtronix"] DEPENDENCIES = ["ble_client"] @@ -15,9 +15,9 @@ Am43Component = am43_ns.class_( ) CONFIG_SCHEMA = ( - cover.COVER_SCHEMA.extend( + cover.cover_schema(Am43Component) + .extend( { - cv.GenerateID(): cv.declare_id(Am43Component), cv.Optional(CONF_PIN, default=8888): cv.int_range(min=0, max=0xFFFF), cv.Optional(CONF_INVERT_POSITION, default=False): cv.boolean, } @@ -28,9 +28,8 @@ CONFIG_SCHEMA = ( async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await cover.new_cover(config) cg.add(var.set_pin(config[CONF_PIN])) cg.add(var.set_invert_position(config[CONF_INVERT_POSITION])) await cg.register_component(var, config) - await cover.register_cover(var, config) await ble_client.register_ble_node(var, config) diff --git a/esphome/components/analog_threshold/analog_threshold_binary_sensor.cpp b/esphome/components/analog_threshold/analog_threshold_binary_sensor.cpp index f679b9994f..8dcbb2ac4b 100644 --- a/esphome/components/analog_threshold/analog_threshold_binary_sensor.cpp +++ b/esphome/components/analog_threshold/analog_threshold_binary_sensor.cpp @@ -14,7 +14,8 @@ void AnalogThresholdBinarySensor::setup() { if (std::isnan(sensor_value)) { this->publish_initial_state(false); } else { - this->publish_initial_state(sensor_value >= (this->lower_threshold_ + this->upper_threshold_) / 2.0f); + this->publish_initial_state(sensor_value >= + (this->lower_threshold_.value() + this->upper_threshold_.value()) / 2.0f); } } @@ -24,7 +25,8 @@ void AnalogThresholdBinarySensor::set_sensor(sensor::Sensor *analog_sensor) { this->sensor_->add_on_state_callback([this](float sensor_value) { // if there is an invalid sensor reading, ignore the change and keep the current state if (!std::isnan(sensor_value)) { - this->publish_state(sensor_value >= (this->state ? this->lower_threshold_ : this->upper_threshold_)); + this->publish_state(sensor_value >= + (this->state ? this->lower_threshold_.value() : this->upper_threshold_.value())); } }); } @@ -32,8 +34,8 @@ void AnalogThresholdBinarySensor::set_sensor(sensor::Sensor *analog_sensor) { void AnalogThresholdBinarySensor::dump_config() { LOG_BINARY_SENSOR("", "Analog Threshold Binary Sensor", this); LOG_SENSOR(" ", "Sensor", this->sensor_); - ESP_LOGCONFIG(TAG, " Upper threshold: %.11f", this->upper_threshold_); - ESP_LOGCONFIG(TAG, " Lower threshold: %.11f", this->lower_threshold_); + ESP_LOGCONFIG(TAG, " Upper threshold: %.11f", this->upper_threshold_.value()); + ESP_LOGCONFIG(TAG, " Lower threshold: %.11f", this->lower_threshold_.value()); } } // namespace analog_threshold diff --git a/esphome/components/analog_threshold/analog_threshold_binary_sensor.h b/esphome/components/analog_threshold/analog_threshold_binary_sensor.h index 619aef1075..efb8e3c90c 100644 --- a/esphome/components/analog_threshold/analog_threshold_binary_sensor.h +++ b/esphome/components/analog_threshold/analog_threshold_binary_sensor.h @@ -15,14 +15,13 @@ class AnalogThresholdBinarySensor : public Component, public binary_sensor::Bina float get_setup_priority() const override { return setup_priority::DATA; } void set_sensor(sensor::Sensor *analog_sensor); - void set_upper_threshold(float threshold) { this->upper_threshold_ = threshold; } - void set_lower_threshold(float threshold) { this->lower_threshold_ = threshold; } + template void set_upper_threshold(T upper_threshold) { this->upper_threshold_ = upper_threshold; } + template void set_lower_threshold(T lower_threshold) { this->lower_threshold_ = lower_threshold; } protected: sensor::Sensor *sensor_{nullptr}; - - float upper_threshold_; - float lower_threshold_; + TemplatableValue upper_threshold_{}; + TemplatableValue lower_threshold_{}; }; } // namespace analog_threshold diff --git a/esphome/components/analog_threshold/binary_sensor.py b/esphome/components/analog_threshold/binary_sensor.py index 775b3e6bbf..b5f87b9b5c 100644 --- a/esphome/components/analog_threshold/binary_sensor.py +++ b/esphome/components/analog_threshold/binary_sensor.py @@ -18,11 +18,11 @@ CONFIG_SCHEMA = ( { cv.Required(CONF_SENSOR_ID): cv.use_id(sensor.Sensor), cv.Required(CONF_THRESHOLD): cv.Any( - cv.float_, + cv.templatable(cv.float_), cv.Schema( { - cv.Required(CONF_UPPER): cv.float_, - cv.Required(CONF_LOWER): cv.float_, + cv.Required(CONF_UPPER): cv.templatable(cv.float_), + cv.Required(CONF_LOWER): cv.templatable(cv.float_), } ), ), @@ -39,9 +39,11 @@ async def to_code(config): sens = await cg.get_variable(config[CONF_SENSOR_ID]) cg.add(var.set_sensor(sens)) - if isinstance(config[CONF_THRESHOLD], float): - cg.add(var.set_upper_threshold(config[CONF_THRESHOLD])) - cg.add(var.set_lower_threshold(config[CONF_THRESHOLD])) + if isinstance(config[CONF_THRESHOLD], dict): + lower = await cg.templatable(config[CONF_THRESHOLD][CONF_LOWER], [], float) + upper = await cg.templatable(config[CONF_THRESHOLD][CONF_UPPER], [], float) else: - cg.add(var.set_upper_threshold(config[CONF_THRESHOLD][CONF_UPPER])) - cg.add(var.set_lower_threshold(config[CONF_THRESHOLD][CONF_LOWER])) + lower = await cg.templatable(config[CONF_THRESHOLD], [], float) + upper = lower + cg.add(var.set_upper_threshold(upper)) + cg.add(var.set_lower_threshold(lower)) diff --git a/esphome/components/api/__init__.py b/esphome/components/api/__init__.py index 27de5c873b..4b63c76fba 100644 --- a/esphome/components/api/__init__.py +++ b/esphome/components/api/__init__.py @@ -82,6 +82,19 @@ ACTIONS_SCHEMA = automation.validate_automation( ), ) +ENCRYPTION_SCHEMA = cv.Schema( + { + cv.Optional(CONF_KEY): validate_encryption_key, + } +) + + +def _encryption_schema(config): + if config is None: + config = {} + return ENCRYPTION_SCHEMA(config) + + CONFIG_SCHEMA = cv.All( cv.Schema( { @@ -95,11 +108,7 @@ CONFIG_SCHEMA = cv.All( CONF_SERVICES, group_of_exclusion=CONF_ACTIONS ): ACTIONS_SCHEMA, cv.Exclusive(CONF_ACTIONS, group_of_exclusion=CONF_ACTIONS): ACTIONS_SCHEMA, - cv.Optional(CONF_ENCRYPTION): cv.Schema( - { - cv.Required(CONF_KEY): validate_encryption_key, - } - ), + cv.Optional(CONF_ENCRYPTION): _encryption_schema, cv.Optional(CONF_ON_CLIENT_CONNECTED): automation.validate_automation( single=True ), @@ -151,9 +160,17 @@ async def to_code(config): config[CONF_ON_CLIENT_DISCONNECTED], ) - if encryption_config := config.get(CONF_ENCRYPTION): - decoded = base64.b64decode(encryption_config[CONF_KEY]) - cg.add(var.set_noise_psk(list(decoded))) + if (encryption_config := config.get(CONF_ENCRYPTION, None)) is not None: + if key := encryption_config.get(CONF_KEY): + decoded = base64.b64decode(key) + cg.add(var.set_noise_psk(list(decoded))) + else: + # No key provided, but encryption desired + # This will allow a plaintext client to provide a noise key, + # send it to the device, and then switch to noise. + # The key will be saved in flash and used for future connections + # and plaintext disabled. Only a factory reset can remove it. + cg.add_define("USE_API_PLAINTEXT") cg.add_define("USE_API_NOISE") cg.add_library("esphome/noise-c", "0.1.6") else: diff --git a/esphome/components/api/api.proto b/esphome/components/api/api.proto index d59b5e0d3e..1fdf4e1339 100644 --- a/esphome/components/api/api.proto +++ b/esphome/components/api/api.proto @@ -31,24 +31,26 @@ service APIConnection { option (needs_authentication) = false; } rpc execute_service (ExecuteServiceRequest) returns (void) {} + rpc noise_encryption_set_key (NoiseEncryptionSetKeyRequest) returns (NoiseEncryptionSetKeyResponse) {} - rpc cover_command (CoverCommandRequest) returns (void) {} - rpc fan_command (FanCommandRequest) returns (void) {} - rpc light_command (LightCommandRequest) returns (void) {} - rpc switch_command (SwitchCommandRequest) returns (void) {} + rpc button_command (ButtonCommandRequest) returns (void) {} rpc camera_image (CameraImageRequest) returns (void) {} rpc climate_command (ClimateCommandRequest) returns (void) {} - rpc number_command (NumberCommandRequest) returns (void) {} - rpc text_command (TextCommandRequest) returns (void) {} - rpc select_command (SelectCommandRequest) returns (void) {} - rpc button_command (ButtonCommandRequest) returns (void) {} - rpc lock_command (LockCommandRequest) returns (void) {} - rpc valve_command (ValveCommandRequest) returns (void) {} - rpc media_player_command (MediaPlayerCommandRequest) returns (void) {} + rpc cover_command (CoverCommandRequest) returns (void) {} rpc date_command (DateCommandRequest) returns (void) {} - rpc time_command (TimeCommandRequest) returns (void) {} rpc datetime_command (DateTimeCommandRequest) returns (void) {} + rpc fan_command (FanCommandRequest) returns (void) {} + rpc light_command (LightCommandRequest) returns (void) {} + rpc lock_command (LockCommandRequest) returns (void) {} + rpc media_player_command (MediaPlayerCommandRequest) returns (void) {} + rpc number_command (NumberCommandRequest) returns (void) {} + rpc select_command (SelectCommandRequest) returns (void) {} + rpc siren_command (SirenCommandRequest) returns (void) {} + rpc switch_command (SwitchCommandRequest) returns (void) {} + rpc text_command (TextCommandRequest) returns (void) {} + rpc time_command (TimeCommandRequest) returns (void) {} rpc update_command (UpdateCommandRequest) returns (void) {} + rpc valve_command (ValveCommandRequest) returns (void) {} rpc subscribe_bluetooth_le_advertisements(SubscribeBluetoothLEAdvertisementsRequest) returns (void) {} rpc bluetooth_device_request(BluetoothDeviceRequest) returns (void) {} @@ -60,6 +62,7 @@ service APIConnection { rpc bluetooth_gatt_notify(BluetoothGATTNotifyRequest) returns (void) {} rpc subscribe_bluetooth_connections_free(SubscribeBluetoothConnectionsFreeRequest) returns (BluetoothConnectionsFreeResponse) {} rpc unsubscribe_bluetooth_le_advertisements(UnsubscribeBluetoothLEAdvertisementsRequest) returns (void) {} + rpc bluetooth_scanner_set_mode(BluetoothScannerSetModeRequest) returns (void) {} rpc subscribe_voice_assistant(SubscribeVoiceAssistantRequest) returns (void) {} rpc voice_assistant_get_configuration(VoiceAssistantConfigurationRequest) returns (VoiceAssistantConfigurationResponse) {} @@ -230,6 +233,9 @@ message DeviceInfoResponse { // The Bluetooth mac address of the device. For example "AC:BC:32:89:0E:AA" string bluetooth_mac_address = 18; + + // Supports receiving and saving api encryption key + bool api_encryption_supported = 19; } message ListEntitiesRequest { @@ -650,10 +656,27 @@ message SubscribeLogsResponse { option (no_delay) = false; LogLevel level = 1; - string message = 3; + bytes message = 3; bool send_failed = 4; } +// ==================== NOISE ENCRYPTION ==================== +message NoiseEncryptionSetKeyRequest { + option (id) = 124; + option (source) = SOURCE_CLIENT; + option (ifdef) = "USE_API_NOISE"; + + bytes key = 1; +} + +message NoiseEncryptionSetKeyResponse { + option (id) = 125; + option (source) = SOURCE_SERVER; + option (ifdef) = "USE_API_NOISE"; + + bool success = 1; +} + // ==================== HOMEASSISTANT.SERVICE ==================== message SubscribeHomeassistantServicesRequest { option (id) = 34; @@ -889,6 +912,7 @@ message ClimateStateResponse { float target_temperature = 4; float target_temperature_low = 5; float target_temperature_high = 6; + // For older peers, equal to preset == CLIMATE_PRESET_AWAY bool unused_legacy_away = 7; ClimateAction action = 8; ClimateFanMode fan_mode = 9; @@ -914,6 +938,7 @@ message ClimateCommandRequest { float target_temperature_low = 7; bool has_target_temperature_high = 8; float target_temperature_high = 9; + // legacy, for older peers, newer ones should use CLIMATE_PRESET_AWAY in preset bool unused_has_legacy_away = 10; bool unused_legacy_away = 11; bool has_fan_mode = 12; @@ -1016,6 +1041,49 @@ message SelectCommandRequest { string state = 2; } +// ==================== SIREN ==================== +message ListEntitiesSirenResponse { + option (id) = 55; + option (source) = SOURCE_SERVER; + option (ifdef) = "USE_SIREN"; + + string object_id = 1; + fixed32 key = 2; + string name = 3; + string unique_id = 4; + + string icon = 5; + bool disabled_by_default = 6; + repeated string tones = 7; + bool supports_duration = 8; + bool supports_volume = 9; + EntityCategory entity_category = 10; +} +message SirenStateResponse { + option (id) = 56; + option (source) = SOURCE_SERVER; + option (ifdef) = "USE_SIREN"; + option (no_delay) = true; + + fixed32 key = 1; + bool state = 2; +} +message SirenCommandRequest { + option (id) = 57; + option (source) = SOURCE_CLIENT; + option (ifdef) = "USE_SIREN"; + option (no_delay) = true; + + fixed32 key = 1; + bool has_state = 2; + bool state = 3; + bool has_tone = 4; + string tone = 5; + bool has_duration = 6; + uint32 duration = 7; + bool has_volume = 8; + float volume = 9; +} // ==================== LOCK ==================== enum LockState { @@ -1185,8 +1253,8 @@ message SubscribeBluetoothLEAdvertisementsRequest { message BluetoothServiceData { string uuid = 1; - repeated uint32 legacy_data = 2 [deprecated = true]; - bytes data = 3; // Changed in proto version 1.7 + repeated uint32 legacy_data = 2 [deprecated = true]; // Removed in api version 1.7 + bytes data = 3; // Added in api version 1.7 } message BluetoothLEAdvertisementResponse { option (id) = 67; @@ -1195,7 +1263,7 @@ message BluetoothLEAdvertisementResponse { option (no_delay) = true; uint64 address = 1; - string name = 2; + bytes name = 2; sint32 rssi = 3; repeated string service_uuids = 4; @@ -1451,7 +1519,38 @@ message BluetoothDeviceClearCacheResponse { int32 error = 3; } -// ==================== PUSH TO TALK ==================== +enum BluetoothScannerState { + BLUETOOTH_SCANNER_STATE_IDLE = 0; + BLUETOOTH_SCANNER_STATE_STARTING = 1; + BLUETOOTH_SCANNER_STATE_RUNNING = 2; + BLUETOOTH_SCANNER_STATE_FAILED = 3; + BLUETOOTH_SCANNER_STATE_STOPPING = 4; + BLUETOOTH_SCANNER_STATE_STOPPED = 5; +} + +enum BluetoothScannerMode { + BLUETOOTH_SCANNER_MODE_PASSIVE = 0; + BLUETOOTH_SCANNER_MODE_ACTIVE = 1; +} + +message BluetoothScannerStateResponse { + option(id) = 126; + option(source) = SOURCE_SERVER; + option(ifdef) = "USE_BLUETOOTH_PROXY"; + + BluetoothScannerState state = 1; + BluetoothScannerMode mode = 2; +} + +message BluetoothScannerSetModeRequest { + option(id) = 127; + option(source) = SOURCE_CLIENT; + option(ifdef) = "USE_BLUETOOTH_PROXY"; + + BluetoothScannerMode mode = 1; +} + +// ==================== VOICE ASSISTANT ==================== enum VoiceAssistantSubscribeFlag { VOICE_ASSISTANT_SUBSCRIBE_NONE = 0; VOICE_ASSISTANT_SUBSCRIBE_API_AUDIO = 1; diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index 9d7b8c1780..ee0451f499 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -62,7 +62,14 @@ APIConnection::APIConnection(std::unique_ptr sock, APIServer *pa : parent_(parent), deferred_message_queue_(this), initial_state_iterator_(this), list_entities_iterator_(this) { this->proto_write_buffer_.reserve(64); -#if defined(USE_API_PLAINTEXT) +#if defined(USE_API_PLAINTEXT) && defined(USE_API_NOISE) + auto noise_ctx = parent->get_noise_ctx(); + if (noise_ctx->has_psk()) { + this->helper_ = std::unique_ptr{new APINoiseFrameHelper(std::move(sock), noise_ctx)}; + } else { + this->helper_ = std::unique_ptr{new APIPlaintextFrameHelper(std::move(sock))}; + } +#elif defined(USE_API_PLAINTEXT) this->helper_ = std::unique_ptr{new APIPlaintextFrameHelper(std::move(sock))}; #elif defined(USE_API_NOISE) this->helper_ = std::unique_ptr{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())}; @@ -185,15 +192,34 @@ void APIConnection::loop() { #ifdef USE_ESP32_CAMERA if (this->image_reader_.available() && this->helper_->can_write_without_blocking()) { - uint32_t to_send = std::min((size_t) 1024, this->image_reader_.available()); - auto buffer = this->create_buffer(); + // Message will use 8 more bytes than the minimum size, and typical + // MTU is 1500. Sometimes users will see as low as 1460 MTU. + // If its IPv6 the header is 40 bytes, and if its IPv4 + // the header is 20 bytes. So we have 1460 - 40 = 1420 bytes + // available for the payload. But we also need to add the size of + // the protobuf overhead, which is 8 bytes. + // + // To be safe we pick 1390 bytes as the maximum size + // to send in one go. This is the maximum size of a single packet + // that can be sent over the network. + // This is to avoid fragmentation of the packet. + uint32_t to_send = std::min((size_t) 1390, this->image_reader_.available()); + bool done = this->image_reader_.available() == to_send; + uint32_t msg_size = 0; + ProtoSize::add_fixed_field<4>(msg_size, 1, true); + // partial message size calculated manually since its a special case + // 1 for the data field, varint for the data size, and the data itself + msg_size += 1 + ProtoSize::varint(to_send) + to_send; + ProtoSize::add_bool_field(msg_size, 1, done); + + auto buffer = this->create_buffer(msg_size); // fixed32 key = 1; buffer.encode_fixed32(1, esp32_camera::global_esp32_camera->get_object_id_hash()); // bytes data = 2; buffer.encode_bytes(2, this->image_reader_.peek_data_buffer(), to_send); // bool done = 3; - bool done = this->image_reader_.available() == to_send; buffer.encode_bool(3, done); + bool success = this->send_buffer(buffer, 44); if (success) { @@ -1468,6 +1494,11 @@ BluetoothConnectionsFreeResponse APIConnection::subscribe_bluetooth_connections_ resp.limit = bluetooth_proxy::global_bluetooth_proxy->get_bluetooth_connections_limit(); return resp; } + +void APIConnection::bluetooth_scanner_set_mode(const BluetoothScannerSetModeRequest &msg) { + bluetooth_proxy::global_bluetooth_proxy->bluetooth_scanner_set_mode( + msg.mode == enums::BluetoothScannerMode::BLUETOOTH_SCANNER_MODE_ACTIVE); +} #endif #ifdef USE_VOICE_ASSISTANT @@ -1762,12 +1793,25 @@ bool APIConnection::try_send_log_message(int level, const char *tag, const char if (this->log_subscription_ < level) return false; - // Send raw so that we don't copy too much - auto buffer = this->create_buffer(); - // LogLevel level = 1; - buffer.encode_uint32(1, static_cast(level)); - // string message = 3; - buffer.encode_string(3, line, strlen(line)); + // Pre-calculate message size to avoid reallocations + const size_t line_length = strlen(line); + uint32_t msg_size = 0; + + // Add size for level field (field ID 1, varint type) + // 1 byte for field tag + size of the level varint + msg_size += 1 + api::ProtoSize::varint(static_cast(level)); + + // Add size for string field (field ID 3, string type) + // 1 byte for field tag + size of length varint + string length + msg_size += 1 + api::ProtoSize::varint(static_cast(line_length)) + line_length; + + // Create a pre-sized buffer + auto buffer = this->create_buffer(msg_size); + + // Encode the message (SubscribeLogsResponse) + buffer.encode_uint32(1, static_cast(level)); // LogLevel level = 1 + buffer.encode_string(3, line, line_length); // string message = 3 + // SubscribeLogsResponse - 29 return this->send_buffer(buffer, 29); } @@ -1848,6 +1892,9 @@ DeviceInfoResponse APIConnection::device_info(const DeviceInfoRequest &msg) { #ifdef USE_VOICE_ASSISTANT resp.legacy_voice_assistant_version = voice_assistant::global_voice_assistant->get_legacy_version(); resp.voice_assistant_feature_flags = voice_assistant::global_voice_assistant->get_feature_flags(); +#endif +#ifdef USE_API_NOISE + resp.api_encryption_supported = true; #endif return resp; } @@ -1869,6 +1916,26 @@ void APIConnection::execute_service(const ExecuteServiceRequest &msg) { ESP_LOGV(TAG, "Could not find matching service!"); } } +#ifdef USE_API_NOISE +NoiseEncryptionSetKeyResponse APIConnection::noise_encryption_set_key(const NoiseEncryptionSetKeyRequest &msg) { + psk_t psk{}; + NoiseEncryptionSetKeyResponse resp; + if (base64_decode(msg.key, psk.data(), msg.key.size()) != psk.size()) { + ESP_LOGW(TAG, "Invalid encryption key length"); + resp.success = false; + return resp; + } + + if (!this->parent_->save_noise_psk(psk, true)) { + ESP_LOGW(TAG, "Failed to save encryption key"); + resp.success = false; + return resp; + } + + resp.success = true; + return resp; +} +#endif void APIConnection::subscribe_home_assistant_states(const SubscribeHomeAssistantStatesRequest &msg) { state_subs_at_ = 0; } diff --git a/esphome/components/api/api_connection.h b/esphome/components/api/api_connection.h index f17080a6c8..1e47418d90 100644 --- a/esphome/components/api/api_connection.h +++ b/esphome/components/api/api_connection.h @@ -221,6 +221,7 @@ class APIConnection : public APIServerConnection { void bluetooth_gatt_notify(const BluetoothGATTNotifyRequest &msg) override; BluetoothConnectionsFreeResponse subscribe_bluetooth_connections_free( const SubscribeBluetoothConnectionsFreeRequest &msg) override; + void bluetooth_scanner_set_mode(const BluetoothScannerSetModeRequest &msg) override; #endif #ifdef USE_HOMEASSISTANT_TIME @@ -300,6 +301,9 @@ class APIConnection : public APIServerConnection { return {}; } void execute_service(const ExecuteServiceRequest &msg) override; +#ifdef USE_API_NOISE + NoiseEncryptionSetKeyResponse noise_encryption_set_key(const NoiseEncryptionSetKeyRequest &msg) override; +#endif bool is_authenticated() override { return this->connection_state_ == ConnectionState::AUTHENTICATED; } bool is_connection_setup() override { @@ -308,9 +312,10 @@ class APIConnection : public APIServerConnection { void on_fatal_error() override; void on_unauthenticated_access() override; void on_no_setup_connection() override; - ProtoWriteBuffer create_buffer() override { + ProtoWriteBuffer create_buffer(uint32_t reserve_size) override { // FIXME: ensure no recursive writes can happen this->proto_write_buffer_.clear(); + this->proto_write_buffer_.reserve(reserve_size); return {&this->proto_write_buffer_}; } bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) override; diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index 3d6bc95163..31b0732275 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -5,6 +5,7 @@ #include "esphome/core/helpers.h" #include "esphome/core/application.h" #include "proto.h" +#include "api_pb2_size.h" #include namespace esphome { @@ -72,6 +73,91 @@ const char *api_error_to_str(APIError err) { return "UNKNOWN"; } +// Common implementation for writing raw data to socket +template +APIError APIFrameHelper::write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, + std::vector &tx_buf, const std::string &info, StateEnum &state, + StateEnum failed_state) { + // This method writes data to socket or buffers it + // Returns APIError::OK if successful (or would block, but data has been buffered) + // Returns APIError::SOCKET_WRITE_FAILED if socket write failed, and sets state to failed_state + + if (iovcnt == 0) + return APIError::OK; // Nothing to do, success + + size_t total_write_len = 0; + for (int i = 0; i < iovcnt; i++) { +#ifdef HELPER_LOG_PACKETS + ESP_LOGVV(TAG, "Sending raw: %s", + format_hex_pretty(reinterpret_cast(iov[i].iov_base), iov[i].iov_len).c_str()); +#endif + total_write_len += iov[i].iov_len; + } + + if (!tx_buf.empty()) { + // try to empty tx_buf first + while (!tx_buf.empty()) { + ssize_t sent = socket->write(tx_buf.data(), tx_buf.size()); + if (is_would_block(sent)) { + break; + } else if (sent == -1) { + ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); + state = failed_state; + return APIError::SOCKET_WRITE_FAILED; // Socket write failed + } + // TODO: inefficient if multiple packets in txbuf + // replace with deque of buffers + tx_buf.erase(tx_buf.begin(), tx_buf.begin() + sent); + } + } + + if (!tx_buf.empty()) { + // tx buf not empty, can't write now because then stream would be inconsistent + // Reserve space upfront to avoid multiple reallocations + tx_buf.reserve(tx_buf.size() + total_write_len); + for (int i = 0; i < iovcnt; i++) { + tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base), + reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); + } + return APIError::OK; // Success, data buffered + } + + ssize_t sent = socket->writev(iov, iovcnt); + if (is_would_block(sent)) { + // operation would block, add buffer to tx_buf + // Reserve space upfront to avoid multiple reallocations + tx_buf.reserve(tx_buf.size() + total_write_len); + for (int i = 0; i < iovcnt; i++) { + tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base), + reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); + } + return APIError::OK; // Success, data buffered + } else if (sent == -1) { + // an error occurred + ESP_LOGVV(TAG, "%s: Socket write failed with errno %d", info.c_str(), errno); + state = failed_state; + return APIError::SOCKET_WRITE_FAILED; // Socket write failed + } else if ((size_t) sent != total_write_len) { + // partially sent, add end to tx_buf + size_t remaining = total_write_len - sent; + // Reserve space upfront to avoid multiple reallocations + tx_buf.reserve(tx_buf.size() + remaining); + + size_t to_consume = sent; + for (int i = 0; i < iovcnt; i++) { + if (to_consume >= iov[i].iov_len) { + to_consume -= iov[i].iov_len; + } else { + tx_buf.insert(tx_buf.end(), reinterpret_cast(iov[i].iov_base) + to_consume, + reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); + to_consume = 0; + } + } + return APIError::OK; // Success, data buffered + } + return APIError::OK; // Success, all data sent +} + #define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__) // uncomment to log raw packets //#define HELPER_LOG_PACKETS @@ -546,71 +632,6 @@ APIError APINoiseFrameHelper::try_send_tx_buf_() { return APIError::OK; } -/** Write the data to the socket, or buffer it a write would block - * - * @param data The data to write - * @param len The length of data - */ -APIError APINoiseFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) { - if (iovcnt == 0) - return APIError::OK; - APIError aerr; - - size_t total_write_len = 0; - for (int i = 0; i < iovcnt; i++) { -#ifdef HELPER_LOG_PACKETS - ESP_LOGVV(TAG, "Sending raw: %s", - format_hex_pretty(reinterpret_cast(iov[i].iov_base), iov[i].iov_len).c_str()); -#endif - total_write_len += iov[i].iov_len; - } - - if (!tx_buf_.empty()) { - // try to empty tx_buf_ first - aerr = try_send_tx_buf_(); - if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) - return aerr; - } - - if (!tx_buf_.empty()) { - // tx buf not empty, can't write now because then stream would be inconsistent - for (int i = 0; i < iovcnt; i++) { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; - } - - ssize_t sent = socket_->writev(iov, iovcnt); - if (is_would_block(sent)) { - // operation would block, add buffer to tx_buf - for (int i = 0; i < iovcnt; i++) { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; - } else if (sent == -1) { - // an error occurred - state_ = State::FAILED; - HELPER_LOG("Socket write failed with errno %d", errno); - return APIError::SOCKET_WRITE_FAILED; - } else if ((size_t) sent != total_write_len) { - // partially sent, add end to tx_buf - size_t to_consume = sent; - for (int i = 0; i < iovcnt; i++) { - if (to_consume >= iov[i].iov_len) { - to_consume -= iov[i].iov_len; - } else { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base) + to_consume, - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - to_consume = 0; - } - } - return APIError::OK; - } - // fully sent - return APIError::OK; -} APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { uint8_t header[3]; header[0] = 0x01; // indicator @@ -744,6 +765,11 @@ void noise_rand_bytes(void *output, size_t len) { } } } + +// Explicit template instantiation for Noise +template APIError APIFrameHelper::write_raw_( + const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf_, const std::string &info, + APINoiseFrameHelper::State &state, APINoiseFrameHelper::State failed_state); #endif // USE_API_NOISE #ifdef USE_API_PLAINTEXT @@ -933,6 +959,8 @@ APIError APIPlaintextFrameHelper::write_packet(uint16_t type, const uint8_t *pay } std::vector header; + header.reserve(1 + api::ProtoSize::varint(static_cast(payload_len)) + + api::ProtoSize::varint(static_cast(type))); header.push_back(0x00); ProtoVarInt(payload_len).encode(header); ProtoVarInt(type).encode(header); @@ -966,71 +994,6 @@ APIError APIPlaintextFrameHelper::try_send_tx_buf_() { return APIError::OK; } -/** Write the data to the socket, or buffer it a write would block - * - * @param data The data to write - * @param len The length of data - */ -APIError APIPlaintextFrameHelper::write_raw_(const struct iovec *iov, int iovcnt) { - if (iovcnt == 0) - return APIError::OK; - APIError aerr; - - size_t total_write_len = 0; - for (int i = 0; i < iovcnt; i++) { -#ifdef HELPER_LOG_PACKETS - ESP_LOGVV(TAG, "Sending raw: %s", - format_hex_pretty(reinterpret_cast(iov[i].iov_base), iov[i].iov_len).c_str()); -#endif - total_write_len += iov[i].iov_len; - } - - if (!tx_buf_.empty()) { - // try to empty tx_buf_ first - aerr = try_send_tx_buf_(); - if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) - return aerr; - } - - if (!tx_buf_.empty()) { - // tx buf not empty, can't write now because then stream would be inconsistent - for (int i = 0; i < iovcnt; i++) { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; - } - - ssize_t sent = socket_->writev(iov, iovcnt); - if (is_would_block(sent)) { - // operation would block, add buffer to tx_buf - for (int i = 0; i < iovcnt; i++) { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base), - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - } - return APIError::OK; - } else if (sent == -1) { - // an error occurred - state_ = State::FAILED; - HELPER_LOG("Socket write failed with errno %d", errno); - return APIError::SOCKET_WRITE_FAILED; - } else if ((size_t) sent != total_write_len) { - // partially sent, add end to tx_buf - size_t to_consume = sent; - for (int i = 0; i < iovcnt; i++) { - if (to_consume >= iov[i].iov_len) { - to_consume -= iov[i].iov_len; - } else { - tx_buf_.insert(tx_buf_.end(), reinterpret_cast(iov[i].iov_base) + to_consume, - reinterpret_cast(iov[i].iov_base) + iov[i].iov_len); - to_consume = 0; - } - } - return APIError::OK; - } - // fully sent - return APIError::OK; -} APIError APIPlaintextFrameHelper::close() { state_ = State::CLOSED; @@ -1048,6 +1011,11 @@ APIError APIPlaintextFrameHelper::shutdown(int how) { } return APIError::OK; } + +// Explicit template instantiation for Plaintext +template APIError APIFrameHelper::write_raw_( + const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf_, const std::string &info, + APIPlaintextFrameHelper::State &state, APIPlaintextFrameHelper::State failed_state); #endif // USE_API_PLAINTEXT } // namespace api diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index 56d8bf1973..59f3cf7471 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -72,6 +72,12 @@ class APIFrameHelper { virtual APIError shutdown(int how) = 0; // Give this helper a name for logging virtual void set_log_info(std::string info) = 0; + + protected: + // Common implementation for writing raw data to socket + template + APIError write_raw_(const struct iovec *iov, int iovcnt, socket::Socket *socket, std::vector &tx_buf, + const std::string &info, StateEnum &state, StateEnum failed_state); }; #ifdef USE_API_NOISE @@ -103,7 +109,9 @@ class APINoiseFrameHelper : public APIFrameHelper { APIError try_read_frame_(ParsedFrame *frame); APIError try_send_tx_buf_(); APIError write_frame_(const uint8_t *data, size_t len); - APIError write_raw_(const struct iovec *iov, int iovcnt); + inline APIError write_raw_(const struct iovec *iov, int iovcnt) { + return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); + } APIError init_handshake_(); APIError check_handshake_finished_(); void send_explicit_handshake_reject_(const std::string &reason); @@ -164,7 +172,9 @@ class APIPlaintextFrameHelper : public APIFrameHelper { APIError try_read_frame_(ParsedFrame *frame); APIError try_send_tx_buf_(); - APIError write_raw_(const struct iovec *iov, int iovcnt); + inline APIError write_raw_(const struct iovec *iov, int iovcnt) { + return APIFrameHelper::write_raw_(iov, iovcnt, socket_.get(), tx_buf_, info_, state_, State::FAILED); + } std::unique_ptr socket_; diff --git a/esphome/components/api/api_noise_context.h b/esphome/components/api/api_noise_context.h index 324e69d945..fa4435e570 100644 --- a/esphome/components/api/api_noise_context.h +++ b/esphome/components/api/api_noise_context.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include #include "esphome/core/defines.h" namespace esphome { @@ -11,11 +11,20 @@ using psk_t = std::array; class APINoiseContext { public: - void set_psk(psk_t psk) { psk_ = psk; } - const psk_t &get_psk() const { return psk_; } + void set_psk(psk_t psk) { + this->psk_ = psk; + bool has_psk = false; + for (auto i : psk) { + has_psk |= i; + } + this->has_psk_ = has_psk; + } + const psk_t &get_psk() const { return this->psk_; } + bool has_psk() const { return this->has_psk_; } protected: - psk_t psk_; + psk_t psk_{}; + bool has_psk_{false}; }; #endif // USE_API_NOISE diff --git a/esphome/components/api/api_pb2.cpp b/esphome/components/api/api_pb2.cpp index 8001a74b6d..e3181b6166 100644 --- a/esphome/components/api/api_pb2.cpp +++ b/esphome/components/api/api_pb2.cpp @@ -1,6 +1,7 @@ // This file was automatically generated with a tool. -// See scripts/api_protobuf/api_protobuf.py +// See script/api_protobuf/api_protobuf.py #include "api_pb2.h" +#include "api_pb2_size.h" #include "esphome/core/log.h" #include @@ -422,6 +423,38 @@ const char *proto_enum_to_string(enums::Bluet } #endif #ifdef HAS_PROTO_MESSAGE_DUMP +template<> const char *proto_enum_to_string(enums::BluetoothScannerState value) { + switch (value) { + case enums::BLUETOOTH_SCANNER_STATE_IDLE: + return "BLUETOOTH_SCANNER_STATE_IDLE"; + case enums::BLUETOOTH_SCANNER_STATE_STARTING: + return "BLUETOOTH_SCANNER_STATE_STARTING"; + case enums::BLUETOOTH_SCANNER_STATE_RUNNING: + return "BLUETOOTH_SCANNER_STATE_RUNNING"; + case enums::BLUETOOTH_SCANNER_STATE_FAILED: + return "BLUETOOTH_SCANNER_STATE_FAILED"; + case enums::BLUETOOTH_SCANNER_STATE_STOPPING: + return "BLUETOOTH_SCANNER_STATE_STOPPING"; + case enums::BLUETOOTH_SCANNER_STATE_STOPPED: + return "BLUETOOTH_SCANNER_STATE_STOPPED"; + default: + return "UNKNOWN"; + } +} +#endif +#ifdef HAS_PROTO_MESSAGE_DUMP +template<> const char *proto_enum_to_string(enums::BluetoothScannerMode value) { + switch (value) { + case enums::BLUETOOTH_SCANNER_MODE_PASSIVE: + return "BLUETOOTH_SCANNER_MODE_PASSIVE"; + case enums::BLUETOOTH_SCANNER_MODE_ACTIVE: + return "BLUETOOTH_SCANNER_MODE_ACTIVE"; + default: + return "UNKNOWN"; + } +} +#endif +#ifdef HAS_PROTO_MESSAGE_DUMP template<> const char *proto_enum_to_string(enums::VoiceAssistantSubscribeFlag value) { switch (value) { @@ -622,6 +655,11 @@ void HelloRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(2, this->api_version_major); buffer.encode_uint32(3, this->api_version_minor); } +void HelloRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->client_info, false); + ProtoSize::add_uint32_field(total_size, 1, this->api_version_major, false); + ProtoSize::add_uint32_field(total_size, 1, this->api_version_minor, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void HelloRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -676,6 +714,12 @@ void HelloResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(3, this->server_info); buffer.encode_string(4, this->name); } +void HelloResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint32_field(total_size, 1, this->api_version_major, false); + ProtoSize::add_uint32_field(total_size, 1, this->api_version_minor, false); + ProtoSize::add_string_field(total_size, 1, this->server_info, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void HelloResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -711,6 +755,9 @@ bool ConnectRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value } } void ConnectRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, this->password); } +void ConnectRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->password, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ConnectRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -732,6 +779,9 @@ bool ConnectResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { } } void ConnectResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->invalid_password); } +void ConnectResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_bool_field(total_size, 1, this->invalid_password, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ConnectResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -743,22 +793,27 @@ void ConnectResponse::dump_to(std::string &out) const { } #endif void DisconnectRequest::encode(ProtoWriteBuffer buffer) const {} +void DisconnectRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void DisconnectRequest::dump_to(std::string &out) const { out.append("DisconnectRequest {}"); } #endif void DisconnectResponse::encode(ProtoWriteBuffer buffer) const {} +void DisconnectResponse::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void DisconnectResponse::dump_to(std::string &out) const { out.append("DisconnectResponse {}"); } #endif void PingRequest::encode(ProtoWriteBuffer buffer) const {} +void PingRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void PingRequest::dump_to(std::string &out) const { out.append("PingRequest {}"); } #endif void PingResponse::encode(ProtoWriteBuffer buffer) const {} +void PingResponse::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void PingResponse::dump_to(std::string &out) const { out.append("PingResponse {}"); } #endif void DeviceInfoRequest::encode(ProtoWriteBuffer buffer) const {} +void DeviceInfoRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void DeviceInfoRequest::dump_to(std::string &out) const { out.append("DeviceInfoRequest {}"); } #endif @@ -792,6 +847,10 @@ bool DeviceInfoResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { this->voice_assistant_feature_flags = value.as_uint32(); return true; } + case 19: { + this->api_encryption_supported = value.as_bool(); + return true; + } default: return false; } @@ -865,6 +924,28 @@ void DeviceInfoResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(17, this->voice_assistant_feature_flags); buffer.encode_string(16, this->suggested_area); buffer.encode_string(18, this->bluetooth_mac_address); + buffer.encode_bool(19, this->api_encryption_supported); +} +void DeviceInfoResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_bool_field(total_size, 1, this->uses_password, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->mac_address, false); + ProtoSize::add_string_field(total_size, 1, this->esphome_version, false); + ProtoSize::add_string_field(total_size, 1, this->compilation_time, false); + ProtoSize::add_string_field(total_size, 1, this->model, false); + ProtoSize::add_bool_field(total_size, 1, this->has_deep_sleep, false); + ProtoSize::add_string_field(total_size, 1, this->project_name, false); + ProtoSize::add_string_field(total_size, 1, this->project_version, false); + ProtoSize::add_uint32_field(total_size, 1, this->webserver_port, false); + ProtoSize::add_uint32_field(total_size, 1, this->legacy_bluetooth_proxy_version, false); + ProtoSize::add_uint32_field(total_size, 1, this->bluetooth_proxy_feature_flags, false); + ProtoSize::add_string_field(total_size, 1, this->manufacturer, false); + ProtoSize::add_string_field(total_size, 1, this->friendly_name, false); + ProtoSize::add_uint32_field(total_size, 1, this->legacy_voice_assistant_version, false); + ProtoSize::add_uint32_field(total_size, 2, this->voice_assistant_feature_flags, false); + ProtoSize::add_string_field(total_size, 2, this->suggested_area, false); + ProtoSize::add_string_field(total_size, 2, this->bluetooth_mac_address, false); + ProtoSize::add_bool_field(total_size, 2, this->api_encryption_supported, false); } #ifdef HAS_PROTO_MESSAGE_DUMP void DeviceInfoResponse::dump_to(std::string &out) const { @@ -946,18 +1027,25 @@ void DeviceInfoResponse::dump_to(std::string &out) const { out.append(" bluetooth_mac_address: "); out.append("'").append(this->bluetooth_mac_address).append("'"); out.append("\n"); + + out.append(" api_encryption_supported: "); + out.append(YESNO(this->api_encryption_supported)); + out.append("\n"); out.append("}"); } #endif void ListEntitiesRequest::encode(ProtoWriteBuffer buffer) const {} +void ListEntitiesRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesRequest::dump_to(std::string &out) const { out.append("ListEntitiesRequest {}"); } #endif void ListEntitiesDoneResponse::encode(ProtoWriteBuffer buffer) const {} +void ListEntitiesDoneResponse::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesDoneResponse::dump_to(std::string &out) const { out.append("ListEntitiesDoneResponse {}"); } #endif void SubscribeStatesRequest::encode(ProtoWriteBuffer buffer) const {} +void SubscribeStatesRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeStatesRequest::dump_to(std::string &out) const { out.append("SubscribeStatesRequest {}"); } #endif @@ -1026,6 +1114,17 @@ void ListEntitiesBinarySensorResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(8, this->icon); buffer.encode_enum(9, this->entity_category); } +void ListEntitiesBinarySensorResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); + ProtoSize::add_bool_field(total_size, 1, this->is_status_binary_sensor, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesBinarySensorResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -1098,6 +1197,11 @@ void BinarySensorStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(2, this->state); buffer.encode_bool(3, this->missing_state); } +void BinarySensorStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BinarySensorStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -1197,6 +1301,20 @@ void ListEntitiesCoverResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(11, this->entity_category); buffer.encode_bool(12, this->supports_stop); } +void ListEntitiesCoverResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_bool_field(total_size, 1, this->assumed_state, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_position, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_tilt, false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_bool_field(total_size, 1, this->supports_stop, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesCoverResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -1291,6 +1409,13 @@ void CoverStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(4, this->tilt); buffer.encode_enum(5, this->current_operation); } +void CoverStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->legacy_state), false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->position != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->tilt != 0.0f, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->current_operation), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void CoverStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -1374,6 +1499,16 @@ void CoverCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(7, this->tilt); buffer.encode_bool(8, this->stop); } +void CoverCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->has_legacy_command, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->legacy_command), false); + ProtoSize::add_bool_field(total_size, 1, this->has_position, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->position != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->has_tilt, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->tilt != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->stop, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void CoverCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -1497,6 +1632,24 @@ void ListEntitiesFanResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(12, it, true); } } +void ListEntitiesFanResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_oscillation, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_speed, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_direction, false); + ProtoSize::add_int32_field(total_size, 1, this->supported_speed_count, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + if (!this->supported_preset_modes.empty()) { + for (const auto &it : this->supported_preset_modes) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesFanResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -1610,6 +1763,15 @@ void FanStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_int32(6, this->speed_level); buffer.encode_string(7, this->preset_mode); } +void FanStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); + ProtoSize::add_bool_field(total_size, 1, this->oscillating, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->speed), false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->direction), false); + ProtoSize::add_int32_field(total_size, 1, this->speed_level, false); + ProtoSize::add_string_field(total_size, 1, this->preset_mode, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void FanStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -1731,6 +1893,21 @@ void FanCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(12, this->has_preset_mode); buffer.encode_string(13, this->preset_mode); } +void FanCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->has_state, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); + ProtoSize::add_bool_field(total_size, 1, this->has_speed, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->speed), false); + ProtoSize::add_bool_field(total_size, 1, this->has_oscillating, false); + ProtoSize::add_bool_field(total_size, 1, this->oscillating, false); + ProtoSize::add_bool_field(total_size, 1, this->has_direction, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->direction), false); + ProtoSize::add_bool_field(total_size, 1, this->has_speed_level, false); + ProtoSize::add_int32_field(total_size, 1, this->speed_level, false); + ProtoSize::add_bool_field(total_size, 1, this->has_preset_mode, false); + ProtoSize::add_string_field(total_size, 1, this->preset_mode, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void FanCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -1890,6 +2067,31 @@ void ListEntitiesLightResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(14, this->icon); buffer.encode_enum(15, this->entity_category); } +void ListEntitiesLightResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + if (!this->supported_color_modes.empty()) { + for (const auto &it : this->supported_color_modes) { + ProtoSize::add_enum_field(total_size, 1, static_cast(it), true); + } + } + ProtoSize::add_bool_field(total_size, 1, this->legacy_supports_brightness, false); + ProtoSize::add_bool_field(total_size, 1, this->legacy_supports_rgb, false); + ProtoSize::add_bool_field(total_size, 1, this->legacy_supports_white_value, false); + ProtoSize::add_bool_field(total_size, 1, this->legacy_supports_color_temperature, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->min_mireds != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->max_mireds != 0.0f, false); + if (!this->effects.empty()) { + for (const auto &it : this->effects) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesLightResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2048,6 +2250,21 @@ void LightStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(13, this->warm_white); buffer.encode_string(9, this->effect); } +void LightStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->brightness != 0.0f, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->color_mode), false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->color_brightness != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->red != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->green != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->blue != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->white != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->color_temperature != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->cold_white != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->warm_white != 0.0f, false); + ProtoSize::add_string_field(total_size, 1, this->effect, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void LightStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2271,6 +2488,35 @@ void LightCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(18, this->has_effect); buffer.encode_string(19, this->effect); } +void LightCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->has_state, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); + ProtoSize::add_bool_field(total_size, 1, this->has_brightness, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->brightness != 0.0f, false); + ProtoSize::add_bool_field(total_size, 2, this->has_color_mode, false); + ProtoSize::add_enum_field(total_size, 2, static_cast(this->color_mode), false); + ProtoSize::add_bool_field(total_size, 2, this->has_color_brightness, false); + ProtoSize::add_fixed_field<4>(total_size, 2, this->color_brightness != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->has_rgb, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->red != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->green != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->blue != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->has_white, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->white != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->has_color_temperature, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->color_temperature != 0.0f, false); + ProtoSize::add_bool_field(total_size, 2, this->has_cold_white, false); + ProtoSize::add_fixed_field<4>(total_size, 2, this->cold_white != 0.0f, false); + ProtoSize::add_bool_field(total_size, 2, this->has_warm_white, false); + ProtoSize::add_fixed_field<4>(total_size, 2, this->warm_white != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->has_transition_length, false); + ProtoSize::add_uint32_field(total_size, 1, this->transition_length, false); + ProtoSize::add_bool_field(total_size, 2, this->has_flash_length, false); + ProtoSize::add_uint32_field(total_size, 2, this->flash_length, false); + ProtoSize::add_bool_field(total_size, 2, this->has_effect, false); + ProtoSize::add_string_field(total_size, 2, this->effect, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void LightCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2482,6 +2728,21 @@ void ListEntitiesSensorResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(12, this->disabled_by_default); buffer.encode_enum(13, this->entity_category); } +void ListEntitiesSensorResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_string_field(total_size, 1, this->unit_of_measurement, false); + ProtoSize::add_int32_field(total_size, 1, this->accuracy_decimals, false); + ProtoSize::add_bool_field(total_size, 1, this->force_update, false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->state_class), false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->legacy_last_reset_type), false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesSensorResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2571,6 +2832,11 @@ void SensorStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(2, this->state); buffer.encode_bool(3, this->missing_state); } +void SensorStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->state != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SensorStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2656,6 +2922,17 @@ void ListEntitiesSwitchResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(8, this->entity_category); buffer.encode_string(9, this->device_class); } +void ListEntitiesSwitchResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->assumed_state, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesSwitchResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2723,6 +3000,10 @@ void SwitchStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_bool(2, this->state); } +void SwitchStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SwitchStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2762,6 +3043,10 @@ void SwitchCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_bool(2, this->state); } +void SwitchCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SwitchCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2837,6 +3122,16 @@ void ListEntitiesTextSensorResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(7, this->entity_category); buffer.encode_string(8, this->device_class); } +void ListEntitiesTextSensorResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesTextSensorResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2911,6 +3206,11 @@ void TextSensorStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(2, this->state); buffer.encode_bool(3, this->missing_state); } +void TextSensorStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->state, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void TextSensorStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2948,6 +3248,10 @@ void SubscribeLogsRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(1, this->level); buffer.encode_bool(2, this->dump_config); } +void SubscribeLogsRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_enum_field(total_size, 1, static_cast(this->level), false); + ProtoSize::add_bool_field(total_size, 1, this->dump_config, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeLogsRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -2991,6 +3295,11 @@ void SubscribeLogsResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(3, this->message); buffer.encode_bool(4, this->send_failed); } +void SubscribeLogsResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_enum_field(total_size, 1, static_cast(this->level), false); + ProtoSize::add_string_field(total_size, 1, this->message, false); + ProtoSize::add_bool_field(total_size, 1, this->send_failed, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeLogsResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3009,7 +3318,56 @@ void SubscribeLogsResponse::dump_to(std::string &out) const { out.append("}"); } #endif +bool NoiseEncryptionSetKeyRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { + switch (field_id) { + case 1: { + this->key = value.as_string(); + return true; + } + default: + return false; + } +} +void NoiseEncryptionSetKeyRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, this->key); } +void NoiseEncryptionSetKeyRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->key, false); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void NoiseEncryptionSetKeyRequest::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("NoiseEncryptionSetKeyRequest {\n"); + out.append(" key: "); + out.append("'").append(this->key).append("'"); + out.append("\n"); + out.append("}"); +} +#endif +bool NoiseEncryptionSetKeyResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 1: { + this->success = value.as_bool(); + return true; + } + default: + return false; + } +} +void NoiseEncryptionSetKeyResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->success); } +void NoiseEncryptionSetKeyResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_bool_field(total_size, 1, this->success, false); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void NoiseEncryptionSetKeyResponse::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("NoiseEncryptionSetKeyResponse {\n"); + out.append(" success: "); + out.append(YESNO(this->success)); + out.append("\n"); + out.append("}"); +} +#endif void SubscribeHomeassistantServicesRequest::encode(ProtoWriteBuffer buffer) const {} +void SubscribeHomeassistantServicesRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeHomeassistantServicesRequest::dump_to(std::string &out) const { out.append("SubscribeHomeassistantServicesRequest {}"); @@ -3033,6 +3391,10 @@ void HomeassistantServiceMap::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, this->key); buffer.encode_string(2, this->value); } +void HomeassistantServiceMap::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->key, false); + ProtoSize::add_string_field(total_size, 1, this->value, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void HomeassistantServiceMap::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3092,6 +3454,13 @@ void HomeassistantServiceResponse::encode(ProtoWriteBuffer buffer) const { } buffer.encode_bool(5, this->is_event); } +void HomeassistantServiceResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->service, false); + ProtoSize::add_repeated_message(total_size, 1, this->data); + ProtoSize::add_repeated_message(total_size, 1, this->data_template); + ProtoSize::add_repeated_message(total_size, 1, this->variables); + ProtoSize::add_bool_field(total_size, 1, this->is_event, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void HomeassistantServiceResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3125,6 +3494,7 @@ void HomeassistantServiceResponse::dump_to(std::string &out) const { } #endif void SubscribeHomeAssistantStatesRequest::encode(ProtoWriteBuffer buffer) const {} +void SubscribeHomeAssistantStatesRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeHomeAssistantStatesRequest::dump_to(std::string &out) const { out.append("SubscribeHomeAssistantStatesRequest {}"); @@ -3159,6 +3529,11 @@ void SubscribeHomeAssistantStateResponse::encode(ProtoWriteBuffer buffer) const buffer.encode_string(2, this->attribute); buffer.encode_bool(3, this->once); } +void SubscribeHomeAssistantStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->entity_id, false); + ProtoSize::add_string_field(total_size, 1, this->attribute, false); + ProtoSize::add_bool_field(total_size, 1, this->once, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeHomeAssistantStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3200,6 +3575,11 @@ void HomeAssistantStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(2, this->state); buffer.encode_string(3, this->attribute); } +void HomeAssistantStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->entity_id, false); + ProtoSize::add_string_field(total_size, 1, this->state, false); + ProtoSize::add_string_field(total_size, 1, this->attribute, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void HomeAssistantStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3219,6 +3599,7 @@ void HomeAssistantStateResponse::dump_to(std::string &out) const { } #endif void GetTimeRequest::encode(ProtoWriteBuffer buffer) const {} +void GetTimeRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void GetTimeRequest::dump_to(std::string &out) const { out.append("GetTimeRequest {}"); } #endif @@ -3233,6 +3614,9 @@ bool GetTimeResponse::decode_32bit(uint32_t field_id, Proto32Bit value) { } } void GetTimeResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->epoch_seconds); } +void GetTimeResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->epoch_seconds != 0, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void GetTimeResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3268,6 +3652,10 @@ void ListEntitiesServicesArgument::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, this->name); buffer.encode_enum(2, this->type); } +void ListEntitiesServicesArgument::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->type), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesServicesArgument::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3313,6 +3701,11 @@ void ListEntitiesServicesResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(3, it, true); } } +void ListEntitiesServicesResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_repeated_message(total_size, 1, this->args); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesServicesResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3407,6 +3800,33 @@ void ExecuteServiceArgument::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(9, it, true); } } +void ExecuteServiceArgument::calculate_size(uint32_t &total_size) const { + ProtoSize::add_bool_field(total_size, 1, this->bool_, false); + ProtoSize::add_int32_field(total_size, 1, this->legacy_int, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->float_ != 0.0f, false); + ProtoSize::add_string_field(total_size, 1, this->string_, false); + ProtoSize::add_sint32_field(total_size, 1, this->int_, false); + if (!this->bool_array.empty()) { + for (const auto it : this->bool_array) { + ProtoSize::add_bool_field(total_size, 1, it, true); + } + } + if (!this->int_array.empty()) { + for (const auto &it : this->int_array) { + ProtoSize::add_sint32_field(total_size, 1, it, true); + } + } + if (!this->float_array.empty()) { + for (const auto &it : this->float_array) { + ProtoSize::add_fixed_field<4>(total_size, 1, it != 0.0f, true); + } + } + if (!this->string_array.empty()) { + for (const auto &it : this->string_array) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } +} #ifdef HAS_PROTO_MESSAGE_DUMP void ExecuteServiceArgument::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3488,6 +3908,10 @@ void ExecuteServiceRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(2, it, true); } } +void ExecuteServiceRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_repeated_message(total_size, 1, this->args); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ExecuteServiceRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3560,6 +3984,15 @@ void ListEntitiesCameraResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(6, this->icon); buffer.encode_enum(7, this->entity_category); } +void ListEntitiesCameraResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesCameraResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3630,6 +4063,11 @@ void CameraImageResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(2, this->data); buffer.encode_bool(3, this->done); } +void CameraImageResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->data, false); + ProtoSize::add_bool_field(total_size, 1, this->done, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void CameraImageResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3667,6 +4105,10 @@ void CameraImageRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->single); buffer.encode_bool(2, this->stream); } +void CameraImageRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_bool_field(total_size, 1, this->single, false); + ProtoSize::add_bool_field(total_size, 1, this->stream, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void CameraImageRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -3838,6 +4280,57 @@ void ListEntitiesClimateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(24, this->visual_min_humidity); buffer.encode_float(25, this->visual_max_humidity); } +void ListEntitiesClimateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_current_temperature, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_two_point_target_temperature, false); + if (!this->supported_modes.empty()) { + for (const auto &it : this->supported_modes) { + ProtoSize::add_enum_field(total_size, 1, static_cast(it), true); + } + } + ProtoSize::add_fixed_field<4>(total_size, 1, this->visual_min_temperature != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->visual_max_temperature != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->visual_target_temperature_step != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->legacy_supports_away, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_action, false); + if (!this->supported_fan_modes.empty()) { + for (const auto &it : this->supported_fan_modes) { + ProtoSize::add_enum_field(total_size, 1, static_cast(it), true); + } + } + if (!this->supported_swing_modes.empty()) { + for (const auto &it : this->supported_swing_modes) { + ProtoSize::add_enum_field(total_size, 1, static_cast(it), true); + } + } + if (!this->supported_custom_fan_modes.empty()) { + for (const auto &it : this->supported_custom_fan_modes) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } + if (!this->supported_presets.empty()) { + for (const auto &it : this->supported_presets) { + ProtoSize::add_enum_field(total_size, 2, static_cast(it), true); + } + } + if (!this->supported_custom_presets.empty()) { + for (const auto &it : this->supported_custom_presets) { + ProtoSize::add_string_field(total_size, 2, it, true); + } + } + ProtoSize::add_bool_field(total_size, 2, this->disabled_by_default, false); + ProtoSize::add_string_field(total_size, 2, this->icon, false); + ProtoSize::add_enum_field(total_size, 2, static_cast(this->entity_category), false); + ProtoSize::add_fixed_field<4>(total_size, 2, this->visual_current_temperature_step != 0.0f, false); + ProtoSize::add_bool_field(total_size, 2, this->supports_current_humidity, false); + ProtoSize::add_bool_field(total_size, 2, this->supports_target_humidity, false); + ProtoSize::add_fixed_field<4>(total_size, 2, this->visual_min_humidity != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 2, this->visual_max_humidity != 0.0f, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesClimateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4058,6 +4551,23 @@ void ClimateStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(14, this->current_humidity); buffer.encode_float(15, this->target_humidity); } +void ClimateStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->mode), false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->current_temperature != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->target_temperature != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->target_temperature_low != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->target_temperature_high != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->unused_legacy_away, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->action), false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->fan_mode), false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->swing_mode), false); + ProtoSize::add_string_field(total_size, 1, this->custom_fan_mode, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->preset), false); + ProtoSize::add_string_field(total_size, 1, this->custom_preset, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->current_humidity != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->target_humidity != 0.0f, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ClimateStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4266,6 +4776,31 @@ void ClimateCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(22, this->has_target_humidity); buffer.encode_float(23, this->target_humidity); } +void ClimateCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->has_mode, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->mode), false); + ProtoSize::add_bool_field(total_size, 1, this->has_target_temperature, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->target_temperature != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->has_target_temperature_low, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->target_temperature_low != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->has_target_temperature_high, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->target_temperature_high != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->unused_has_legacy_away, false); + ProtoSize::add_bool_field(total_size, 1, this->unused_legacy_away, false); + ProtoSize::add_bool_field(total_size, 1, this->has_fan_mode, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->fan_mode), false); + ProtoSize::add_bool_field(total_size, 1, this->has_swing_mode, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->swing_mode), false); + ProtoSize::add_bool_field(total_size, 2, this->has_custom_fan_mode, false); + ProtoSize::add_string_field(total_size, 2, this->custom_fan_mode, false); + ProtoSize::add_bool_field(total_size, 2, this->has_preset, false); + ProtoSize::add_enum_field(total_size, 2, static_cast(this->preset), false); + ProtoSize::add_bool_field(total_size, 2, this->has_custom_preset, false); + ProtoSize::add_string_field(total_size, 2, this->custom_preset, false); + ProtoSize::add_bool_field(total_size, 2, this->has_target_humidity, false); + ProtoSize::add_fixed_field<4>(total_size, 2, this->target_humidity != 0.0f, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ClimateCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4454,6 +4989,21 @@ void ListEntitiesNumberResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(12, this->mode); buffer.encode_string(13, this->device_class); } +void ListEntitiesNumberResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->min_value != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->max_value != 0.0f, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->step != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_string_field(total_size, 1, this->unit_of_measurement, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->mode), false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesNumberResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4545,6 +5095,11 @@ void NumberStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(2, this->state); buffer.encode_bool(3, this->missing_state); } +void NumberStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->state != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void NumberStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4583,6 +5138,10 @@ void NumberCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_float(2, this->state); } +void NumberCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->state != 0.0f, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void NumberCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4661,6 +5220,20 @@ void ListEntitiesSelectResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(7, this->disabled_by_default); buffer.encode_enum(8, this->entity_category); } +void ListEntitiesSelectResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + if (!this->options.empty()) { + for (const auto &it : this->options) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesSelectResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4737,6 +5310,11 @@ void SelectStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(2, this->state); buffer.encode_bool(3, this->missing_state); } +void SelectStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->state, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SelectStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4780,6 +5358,10 @@ void SelectCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_string(2, this->state); } +void SelectCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SelectCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4795,6 +5377,307 @@ void SelectCommandRequest::dump_to(std::string &out) const { out.append("}"); } #endif +bool ListEntitiesSirenResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 6: { + this->disabled_by_default = value.as_bool(); + return true; + } + case 8: { + this->supports_duration = value.as_bool(); + return true; + } + case 9: { + this->supports_volume = value.as_bool(); + return true; + } + case 10: { + this->entity_category = value.as_enum(); + return true; + } + default: + return false; + } +} +bool ListEntitiesSirenResponse::decode_length(uint32_t field_id, ProtoLengthDelimited value) { + switch (field_id) { + case 1: { + this->object_id = value.as_string(); + return true; + } + case 3: { + this->name = value.as_string(); + return true; + } + case 4: { + this->unique_id = value.as_string(); + return true; + } + case 5: { + this->icon = value.as_string(); + return true; + } + case 7: { + this->tones.push_back(value.as_string()); + return true; + } + default: + return false; + } +} +bool ListEntitiesSirenResponse::decode_32bit(uint32_t field_id, Proto32Bit value) { + switch (field_id) { + case 2: { + this->key = value.as_fixed32(); + return true; + } + default: + return false; + } +} +void ListEntitiesSirenResponse::encode(ProtoWriteBuffer buffer) const { + buffer.encode_string(1, this->object_id); + buffer.encode_fixed32(2, this->key); + buffer.encode_string(3, this->name); + buffer.encode_string(4, this->unique_id); + buffer.encode_string(5, this->icon); + buffer.encode_bool(6, this->disabled_by_default); + for (auto &it : this->tones) { + buffer.encode_string(7, it, true); + } + buffer.encode_bool(8, this->supports_duration); + buffer.encode_bool(9, this->supports_volume); + buffer.encode_enum(10, this->entity_category); +} +void ListEntitiesSirenResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + if (!this->tones.empty()) { + for (const auto &it : this->tones) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } + ProtoSize::add_bool_field(total_size, 1, this->supports_duration, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_volume, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void ListEntitiesSirenResponse::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("ListEntitiesSirenResponse {\n"); + out.append(" object_id: "); + out.append("'").append(this->object_id).append("'"); + out.append("\n"); + + out.append(" key: "); + sprintf(buffer, "%" PRIu32, this->key); + out.append(buffer); + out.append("\n"); + + out.append(" name: "); + out.append("'").append(this->name).append("'"); + out.append("\n"); + + out.append(" unique_id: "); + out.append("'").append(this->unique_id).append("'"); + out.append("\n"); + + out.append(" icon: "); + out.append("'").append(this->icon).append("'"); + out.append("\n"); + + out.append(" disabled_by_default: "); + out.append(YESNO(this->disabled_by_default)); + out.append("\n"); + + for (const auto &it : this->tones) { + out.append(" tones: "); + out.append("'").append(it).append("'"); + out.append("\n"); + } + + out.append(" supports_duration: "); + out.append(YESNO(this->supports_duration)); + out.append("\n"); + + out.append(" supports_volume: "); + out.append(YESNO(this->supports_volume)); + out.append("\n"); + + out.append(" entity_category: "); + out.append(proto_enum_to_string(this->entity_category)); + out.append("\n"); + out.append("}"); +} +#endif +bool SirenStateResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 2: { + this->state = value.as_bool(); + return true; + } + default: + return false; + } +} +bool SirenStateResponse::decode_32bit(uint32_t field_id, Proto32Bit value) { + switch (field_id) { + case 1: { + this->key = value.as_fixed32(); + return true; + } + default: + return false; + } +} +void SirenStateResponse::encode(ProtoWriteBuffer buffer) const { + buffer.encode_fixed32(1, this->key); + buffer.encode_bool(2, this->state); +} +void SirenStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void SirenStateResponse::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("SirenStateResponse {\n"); + out.append(" key: "); + sprintf(buffer, "%" PRIu32, this->key); + out.append(buffer); + out.append("\n"); + + out.append(" state: "); + out.append(YESNO(this->state)); + out.append("\n"); + out.append("}"); +} +#endif +bool SirenCommandRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 2: { + this->has_state = value.as_bool(); + return true; + } + case 3: { + this->state = value.as_bool(); + return true; + } + case 4: { + this->has_tone = value.as_bool(); + return true; + } + case 6: { + this->has_duration = value.as_bool(); + return true; + } + case 7: { + this->duration = value.as_uint32(); + return true; + } + case 8: { + this->has_volume = value.as_bool(); + return true; + } + default: + return false; + } +} +bool SirenCommandRequest::decode_length(uint32_t field_id, ProtoLengthDelimited value) { + switch (field_id) { + case 5: { + this->tone = value.as_string(); + return true; + } + default: + return false; + } +} +bool SirenCommandRequest::decode_32bit(uint32_t field_id, Proto32Bit value) { + switch (field_id) { + case 1: { + this->key = value.as_fixed32(); + return true; + } + case 9: { + this->volume = value.as_float(); + return true; + } + default: + return false; + } +} +void SirenCommandRequest::encode(ProtoWriteBuffer buffer) const { + buffer.encode_fixed32(1, this->key); + buffer.encode_bool(2, this->has_state); + buffer.encode_bool(3, this->state); + buffer.encode_bool(4, this->has_tone); + buffer.encode_string(5, this->tone); + buffer.encode_bool(6, this->has_duration); + buffer.encode_uint32(7, this->duration); + buffer.encode_bool(8, this->has_volume); + buffer.encode_float(9, this->volume); +} +void SirenCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->has_state, false); + ProtoSize::add_bool_field(total_size, 1, this->state, false); + ProtoSize::add_bool_field(total_size, 1, this->has_tone, false); + ProtoSize::add_string_field(total_size, 1, this->tone, false); + ProtoSize::add_bool_field(total_size, 1, this->has_duration, false); + ProtoSize::add_uint32_field(total_size, 1, this->duration, false); + ProtoSize::add_bool_field(total_size, 1, this->has_volume, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->volume != 0.0f, false); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void SirenCommandRequest::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("SirenCommandRequest {\n"); + out.append(" key: "); + sprintf(buffer, "%" PRIu32, this->key); + out.append(buffer); + out.append("\n"); + + out.append(" has_state: "); + out.append(YESNO(this->has_state)); + out.append("\n"); + + out.append(" state: "); + out.append(YESNO(this->state)); + out.append("\n"); + + out.append(" has_tone: "); + out.append(YESNO(this->has_tone)); + out.append("\n"); + + out.append(" tone: "); + out.append("'").append(this->tone).append("'"); + out.append("\n"); + + out.append(" has_duration: "); + out.append(YESNO(this->has_duration)); + out.append("\n"); + + out.append(" duration: "); + sprintf(buffer, "%" PRIu32, this->duration); + out.append(buffer); + out.append("\n"); + + out.append(" has_volume: "); + out.append(YESNO(this->has_volume)); + out.append("\n"); + + out.append(" volume: "); + sprintf(buffer, "%g", this->volume); + out.append(buffer); + out.append("\n"); + out.append("}"); +} +#endif bool ListEntitiesLockResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { switch (field_id) { case 6: { @@ -4870,6 +5753,19 @@ void ListEntitiesLockResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(10, this->requires_code); buffer.encode_string(11, this->code_format); } +void ListEntitiesLockResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_bool_field(total_size, 1, this->assumed_state, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_open, false); + ProtoSize::add_bool_field(total_size, 1, this->requires_code, false); + ProtoSize::add_string_field(total_size, 1, this->code_format, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesLockResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -4945,6 +5841,10 @@ void LockStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_enum(2, this->state); } +void LockStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->state), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void LockStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5000,6 +5900,12 @@ void LockCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(3, this->has_code); buffer.encode_string(4, this->code); } +void LockCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->command), false); + ProtoSize::add_bool_field(total_size, 1, this->has_code, false); + ProtoSize::add_string_field(total_size, 1, this->code, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void LockCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5083,6 +5989,16 @@ void ListEntitiesButtonResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(7, this->entity_category); buffer.encode_string(8, this->device_class); } +void ListEntitiesButtonResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesButtonResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5133,6 +6049,9 @@ bool ButtonCommandRequest::decode_32bit(uint32_t field_id, Proto32Bit value) { } } void ButtonCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); } +void ButtonCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ButtonCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5183,6 +6102,13 @@ void MediaPlayerSupportedFormat::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(4, this->purpose); buffer.encode_uint32(5, this->sample_bytes); } +void MediaPlayerSupportedFormat::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->format, false); + ProtoSize::add_uint32_field(total_size, 1, this->sample_rate, false); + ProtoSize::add_uint32_field(total_size, 1, this->num_channels, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->purpose), false); + ProtoSize::add_uint32_field(total_size, 1, this->sample_bytes, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void MediaPlayerSupportedFormat::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5279,6 +6205,17 @@ void ListEntitiesMediaPlayerResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(9, it, true); } } +void ListEntitiesMediaPlayerResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_bool_field(total_size, 1, this->supports_pause, false); + ProtoSize::add_repeated_message(total_size, 1, this->supported_formats); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesMediaPlayerResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5358,6 +6295,12 @@ void MediaPlayerStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(3, this->volume); buffer.encode_bool(4, this->muted); } +void MediaPlayerStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->state), false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->volume != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->muted, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void MediaPlayerStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5447,6 +6390,17 @@ void MediaPlayerCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(8, this->has_announcement); buffer.encode_bool(9, this->announcement); } +void MediaPlayerCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->has_command, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->command), false); + ProtoSize::add_bool_field(total_size, 1, this->has_volume, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->volume != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->has_media_url, false); + ProtoSize::add_string_field(total_size, 1, this->media_url, false); + ProtoSize::add_bool_field(total_size, 1, this->has_announcement, false); + ProtoSize::add_bool_field(total_size, 1, this->announcement, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void MediaPlayerCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5504,6 +6458,9 @@ bool SubscribeBluetoothLEAdvertisementsRequest::decode_varint(uint32_t field_id, void SubscribeBluetoothLEAdvertisementsRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(1, this->flags); } +void SubscribeBluetoothLEAdvertisementsRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint32_field(total_size, 1, this->flags, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeBluetoothLEAdvertisementsRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5546,6 +6503,15 @@ void BluetoothServiceData::encode(ProtoWriteBuffer buffer) const { } buffer.encode_string(3, this->data); } +void BluetoothServiceData::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->uuid, false); + if (!this->legacy_data.empty()) { + for (const auto &it : this->legacy_data) { + ProtoSize::add_uint32_field(total_size, 1, it, true); + } + } + ProtoSize::add_string_field(total_size, 1, this->data, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothServiceData::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5622,6 +6588,19 @@ void BluetoothLEAdvertisementResponse::encode(ProtoWriteBuffer buffer) const { } buffer.encode_uint32(7, this->address_type); } +void BluetoothLEAdvertisementResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_sint32_field(total_size, 1, this->rssi, false); + if (!this->service_uuids.empty()) { + for (const auto &it : this->service_uuids) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } + ProtoSize::add_repeated_message(total_size, 1, this->service_data); + ProtoSize::add_repeated_message(total_size, 1, this->manufacturer_data); + ProtoSize::add_uint32_field(total_size, 1, this->address_type, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothLEAdvertisementResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5699,6 +6678,12 @@ void BluetoothLERawAdvertisement::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(3, this->address_type); buffer.encode_string(4, this->data); } +void BluetoothLERawAdvertisement::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_sint32_field(total_size, 1, this->rssi, false); + ProtoSize::add_uint32_field(total_size, 1, this->address_type, false); + ProtoSize::add_string_field(total_size, 1, this->data, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothLERawAdvertisement::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5739,6 +6724,9 @@ void BluetoothLERawAdvertisementsResponse::encode(ProtoWriteBuffer buffer) const buffer.encode_message(1, it, true); } } +void BluetoothLERawAdvertisementsResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_repeated_message(total_size, 1, this->advertisements); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothLERawAdvertisementsResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5779,6 +6767,12 @@ void BluetoothDeviceRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(3, this->has_address_type); buffer.encode_uint32(4, this->address_type); } +void BluetoothDeviceRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->request_type), false); + ProtoSize::add_bool_field(total_size, 1, this->has_address_type, false); + ProtoSize::add_uint32_field(total_size, 1, this->address_type, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothDeviceRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5831,6 +6825,12 @@ void BluetoothDeviceConnectionResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(3, this->mtu); buffer.encode_int32(4, this->error); } +void BluetoothDeviceConnectionResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_bool_field(total_size, 1, this->connected, false); + ProtoSize::add_uint32_field(total_size, 1, this->mtu, false); + ProtoSize::add_int32_field(total_size, 1, this->error, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothDeviceConnectionResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5867,6 +6867,9 @@ bool BluetoothGATTGetServicesRequest::decode_varint(uint32_t field_id, ProtoVarI } } void BluetoothGATTGetServicesRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint64(1, this->address); } +void BluetoothGATTGetServicesRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTGetServicesRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5898,6 +6901,14 @@ void BluetoothGATTDescriptor::encode(ProtoWriteBuffer buffer) const { } buffer.encode_uint32(2, this->handle); } +void BluetoothGATTDescriptor::calculate_size(uint32_t &total_size) const { + if (!this->uuid.empty()) { + for (const auto &it : this->uuid) { + ProtoSize::add_uint64_field(total_size, 1, it, true); + } + } + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTDescriptor::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -5954,6 +6965,16 @@ void BluetoothGATTCharacteristic::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(4, it, true); } } +void BluetoothGATTCharacteristic::calculate_size(uint32_t &total_size) const { + if (!this->uuid.empty()) { + for (const auto &it : this->uuid) { + ProtoSize::add_uint64_field(total_size, 1, it, true); + } + } + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); + ProtoSize::add_uint32_field(total_size, 1, this->properties, false); + ProtoSize::add_repeated_message(total_size, 1, this->descriptors); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTCharacteristic::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6016,6 +7037,15 @@ void BluetoothGATTService::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(3, it, true); } } +void BluetoothGATTService::calculate_size(uint32_t &total_size) const { + if (!this->uuid.empty()) { + for (const auto &it : this->uuid) { + ProtoSize::add_uint64_field(total_size, 1, it, true); + } + } + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); + ProtoSize::add_repeated_message(total_size, 1, this->characteristics); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTService::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6066,6 +7096,10 @@ void BluetoothGATTGetServicesResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(2, it, true); } } +void BluetoothGATTGetServicesResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_repeated_message(total_size, 1, this->services); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTGetServicesResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6096,6 +7130,9 @@ bool BluetoothGATTGetServicesDoneResponse::decode_varint(uint32_t field_id, Prot void BluetoothGATTGetServicesDoneResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint64(1, this->address); } +void BluetoothGATTGetServicesDoneResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTGetServicesDoneResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6125,6 +7162,10 @@ void BluetoothGATTReadRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint64(1, this->address); buffer.encode_uint32(2, this->handle); } +void BluetoothGATTReadRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTReadRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6170,6 +7211,11 @@ void BluetoothGATTReadResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(2, this->handle); buffer.encode_string(3, this->data); } +void BluetoothGATTReadResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); + ProtoSize::add_string_field(total_size, 1, this->data, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTReadResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6224,6 +7270,12 @@ void BluetoothGATTWriteRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(3, this->response); buffer.encode_string(4, this->data); } +void BluetoothGATTWriteRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); + ProtoSize::add_bool_field(total_size, 1, this->response, false); + ProtoSize::add_string_field(total_size, 1, this->data, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTWriteRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6266,6 +7318,10 @@ void BluetoothGATTReadDescriptorRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint64(1, this->address); buffer.encode_uint32(2, this->handle); } +void BluetoothGATTReadDescriptorRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTReadDescriptorRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6311,6 +7367,11 @@ void BluetoothGATTWriteDescriptorRequest::encode(ProtoWriteBuffer buffer) const buffer.encode_uint32(2, this->handle); buffer.encode_string(3, this->data); } +void BluetoothGATTWriteDescriptorRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); + ProtoSize::add_string_field(total_size, 1, this->data, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTWriteDescriptorRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6354,6 +7415,11 @@ void BluetoothGATTNotifyRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(2, this->handle); buffer.encode_bool(3, this->enable); } +void BluetoothGATTNotifyRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); + ProtoSize::add_bool_field(total_size, 1, this->enable, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTNotifyRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6403,6 +7469,11 @@ void BluetoothGATTNotifyDataResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(2, this->handle); buffer.encode_string(3, this->data); } +void BluetoothGATTNotifyDataResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); + ProtoSize::add_string_field(total_size, 1, this->data, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTNotifyDataResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6424,6 +7495,7 @@ void BluetoothGATTNotifyDataResponse::dump_to(std::string &out) const { } #endif void SubscribeBluetoothConnectionsFreeRequest::encode(ProtoWriteBuffer buffer) const {} +void SubscribeBluetoothConnectionsFreeRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeBluetoothConnectionsFreeRequest::dump_to(std::string &out) const { out.append("SubscribeBluetoothConnectionsFreeRequest {}"); @@ -6454,6 +7526,15 @@ void BluetoothConnectionsFreeResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint64(3, it, true); } } +void BluetoothConnectionsFreeResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint32_field(total_size, 1, this->free, false); + ProtoSize::add_uint32_field(total_size, 1, this->limit, false); + if (!this->allocated.empty()) { + for (const auto &it : this->allocated) { + ProtoSize::add_uint64_field(total_size, 1, it, true); + } + } +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothConnectionsFreeResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6500,6 +7581,11 @@ void BluetoothGATTErrorResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(2, this->handle); buffer.encode_int32(3, this->error); } +void BluetoothGATTErrorResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); + ProtoSize::add_int32_field(total_size, 1, this->error, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTErrorResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6539,6 +7625,10 @@ void BluetoothGATTWriteResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint64(1, this->address); buffer.encode_uint32(2, this->handle); } +void BluetoothGATTWriteResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTWriteResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6573,6 +7663,10 @@ void BluetoothGATTNotifyResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint64(1, this->address); buffer.encode_uint32(2, this->handle); } +void BluetoothGATTNotifyResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_uint32_field(total_size, 1, this->handle, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothGATTNotifyResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6612,6 +7706,11 @@ void BluetoothDevicePairingResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(2, this->paired); buffer.encode_int32(3, this->error); } +void BluetoothDevicePairingResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_bool_field(total_size, 1, this->paired, false); + ProtoSize::add_int32_field(total_size, 1, this->error, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothDevicePairingResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6655,6 +7754,11 @@ void BluetoothDeviceUnpairingResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(2, this->success); buffer.encode_int32(3, this->error); } +void BluetoothDeviceUnpairingResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_bool_field(total_size, 1, this->success, false); + ProtoSize::add_int32_field(total_size, 1, this->error, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothDeviceUnpairingResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6676,6 +7780,7 @@ void BluetoothDeviceUnpairingResponse::dump_to(std::string &out) const { } #endif void UnsubscribeBluetoothLEAdvertisementsRequest::encode(ProtoWriteBuffer buffer) const {} +void UnsubscribeBluetoothLEAdvertisementsRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void UnsubscribeBluetoothLEAdvertisementsRequest::dump_to(std::string &out) const { out.append("UnsubscribeBluetoothLEAdvertisementsRequest {}"); @@ -6704,6 +7809,11 @@ void BluetoothDeviceClearCacheResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(2, this->success); buffer.encode_int32(3, this->error); } +void BluetoothDeviceClearCacheResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint64_field(total_size, 1, this->address, false); + ProtoSize::add_bool_field(total_size, 1, this->success, false); + ProtoSize::add_int32_field(total_size, 1, this->error, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void BluetoothDeviceClearCacheResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6724,6 +7834,68 @@ void BluetoothDeviceClearCacheResponse::dump_to(std::string &out) const { out.append("}"); } #endif +bool BluetoothScannerStateResponse::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 1: { + this->state = value.as_enum(); + return true; + } + case 2: { + this->mode = value.as_enum(); + return true; + } + default: + return false; + } +} +void BluetoothScannerStateResponse::encode(ProtoWriteBuffer buffer) const { + buffer.encode_enum(1, this->state); + buffer.encode_enum(2, this->mode); +} +void BluetoothScannerStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_enum_field(total_size, 1, static_cast(this->state), false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->mode), false); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void BluetoothScannerStateResponse::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("BluetoothScannerStateResponse {\n"); + out.append(" state: "); + out.append(proto_enum_to_string(this->state)); + out.append("\n"); + + out.append(" mode: "); + out.append(proto_enum_to_string(this->mode)); + out.append("\n"); + out.append("}"); +} +#endif +bool BluetoothScannerSetModeRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { + switch (field_id) { + case 1: { + this->mode = value.as_enum(); + return true; + } + default: + return false; + } +} +void BluetoothScannerSetModeRequest::encode(ProtoWriteBuffer buffer) const { + buffer.encode_enum(1, this->mode); +} +void BluetoothScannerSetModeRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_enum_field(total_size, 1, static_cast(this->mode), false); +} +#ifdef HAS_PROTO_MESSAGE_DUMP +void BluetoothScannerSetModeRequest::dump_to(std::string &out) const { + __attribute__((unused)) char buffer[64]; + out.append("BluetoothScannerSetModeRequest {\n"); + out.append(" mode: "); + out.append(proto_enum_to_string(this->mode)); + out.append("\n"); + out.append("}"); +} +#endif bool SubscribeVoiceAssistantRequest::decode_varint(uint32_t field_id, ProtoVarInt value) { switch (field_id) { case 1: { @@ -6742,6 +7914,10 @@ void SubscribeVoiceAssistantRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->subscribe); buffer.encode_uint32(2, this->flags); } +void SubscribeVoiceAssistantRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_bool_field(total_size, 1, this->subscribe, false); + ProtoSize::add_uint32_field(total_size, 1, this->flags, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void SubscribeVoiceAssistantRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6786,6 +7962,11 @@ void VoiceAssistantAudioSettings::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(2, this->auto_gain); buffer.encode_float(3, this->volume_multiplier); } +void VoiceAssistantAudioSettings::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint32_field(total_size, 1, this->noise_suppression_level, false); + ProtoSize::add_uint32_field(total_size, 1, this->auto_gain, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->volume_multiplier != 0.0f, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantAudioSettings::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6846,6 +8027,13 @@ void VoiceAssistantRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(4, this->audio_settings); buffer.encode_string(5, this->wake_word_phrase); } +void VoiceAssistantRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_bool_field(total_size, 1, this->start, false); + ProtoSize::add_string_field(total_size, 1, this->conversation_id, false); + ProtoSize::add_uint32_field(total_size, 1, this->flags, false); + ProtoSize::add_message_object(total_size, 1, this->audio_settings, false); + ProtoSize::add_string_field(total_size, 1, this->wake_word_phrase, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6891,6 +8079,10 @@ void VoiceAssistantResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(1, this->port); buffer.encode_bool(2, this->error); } +void VoiceAssistantResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_uint32_field(total_size, 1, this->port, false); + ProtoSize::add_bool_field(total_size, 1, this->error, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6924,6 +8116,10 @@ void VoiceAssistantEventData::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, this->name); buffer.encode_string(2, this->value); } +void VoiceAssistantEventData::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->value, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantEventData::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -6964,6 +8160,10 @@ void VoiceAssistantEventResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_message(2, it, true); } } +void VoiceAssistantEventResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_enum_field(total_size, 1, static_cast(this->event_type), false); + ProtoSize::add_repeated_message(total_size, 1, this->data); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantEventResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7004,6 +8204,10 @@ void VoiceAssistantAudio::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, this->data); buffer.encode_bool(2, this->end); } +void VoiceAssistantAudio::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->data, false); + ProtoSize::add_bool_field(total_size, 1, this->end, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantAudio::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7062,6 +8266,14 @@ void VoiceAssistantTimerEventResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(5, this->seconds_left); buffer.encode_bool(6, this->is_active); } +void VoiceAssistantTimerEventResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_enum_field(total_size, 1, static_cast(this->event_type), false); + ProtoSize::add_string_field(total_size, 1, this->timer_id, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_uint32_field(total_size, 1, this->total_seconds, false); + ProtoSize::add_uint32_field(total_size, 1, this->seconds_left, false); + ProtoSize::add_bool_field(total_size, 1, this->is_active, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantTimerEventResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7128,6 +8340,12 @@ void VoiceAssistantAnnounceRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(3, this->preannounce_media_id); buffer.encode_bool(4, this->start_conversation); } +void VoiceAssistantAnnounceRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->media_id, false); + ProtoSize::add_string_field(total_size, 1, this->text, false); + ProtoSize::add_string_field(total_size, 1, this->preannounce_media_id, false); + ProtoSize::add_bool_field(total_size, 1, this->start_conversation, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantAnnounceRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7161,6 +8379,9 @@ bool VoiceAssistantAnnounceFinished::decode_varint(uint32_t field_id, ProtoVarIn } } void VoiceAssistantAnnounceFinished::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(1, this->success); } +void VoiceAssistantAnnounceFinished::calculate_size(uint32_t &total_size) const { + ProtoSize::add_bool_field(total_size, 1, this->success, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantAnnounceFinished::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7196,6 +8417,15 @@ void VoiceAssistantWakeWord::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(3, it, true); } } +void VoiceAssistantWakeWord::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->id, false); + ProtoSize::add_string_field(total_size, 1, this->wake_word, false); + if (!this->trained_languages.empty()) { + for (const auto &it : this->trained_languages) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantWakeWord::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7217,6 +8447,7 @@ void VoiceAssistantWakeWord::dump_to(std::string &out) const { } #endif void VoiceAssistantConfigurationRequest::encode(ProtoWriteBuffer buffer) const {} +void VoiceAssistantConfigurationRequest::calculate_size(uint32_t &total_size) const {} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantConfigurationRequest::dump_to(std::string &out) const { out.append("VoiceAssistantConfigurationRequest {}"); @@ -7255,6 +8486,15 @@ void VoiceAssistantConfigurationResponse::encode(ProtoWriteBuffer buffer) const } buffer.encode_uint32(3, this->max_active_wake_words); } +void VoiceAssistantConfigurationResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_repeated_message(total_size, 1, this->available_wake_words); + if (!this->active_wake_words.empty()) { + for (const auto &it : this->active_wake_words) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } + ProtoSize::add_uint32_field(total_size, 1, this->max_active_wake_words, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantConfigurationResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7293,6 +8533,13 @@ void VoiceAssistantSetConfiguration::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(1, it, true); } } +void VoiceAssistantSetConfiguration::calculate_size(uint32_t &total_size) const { + if (!this->active_wake_words.empty()) { + for (const auto &it : this->active_wake_words) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } +} #ifdef HAS_PROTO_MESSAGE_DUMP void VoiceAssistantSetConfiguration::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7375,6 +8622,18 @@ void ListEntitiesAlarmControlPanelResponse::encode(ProtoWriteBuffer buffer) cons buffer.encode_bool(9, this->requires_code); buffer.encode_bool(10, this->requires_code_to_arm); } +void ListEntitiesAlarmControlPanelResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_uint32_field(total_size, 1, this->supported_features, false); + ProtoSize::add_bool_field(total_size, 1, this->requires_code, false); + ProtoSize::add_bool_field(total_size, 1, this->requires_code_to_arm, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesAlarmControlPanelResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7447,6 +8706,10 @@ void AlarmControlPanelStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_enum(2, this->state); } +void AlarmControlPanelStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->state), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void AlarmControlPanelStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7497,6 +8760,11 @@ void AlarmControlPanelCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(2, this->command); buffer.encode_string(3, this->code); } +void AlarmControlPanelCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->command), false); + ProtoSize::add_string_field(total_size, 1, this->code, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void AlarmControlPanelCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7591,6 +8859,19 @@ void ListEntitiesTextResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(10, this->pattern); buffer.encode_enum(11, this->mode); } +void ListEntitiesTextResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_uint32_field(total_size, 1, this->min_length, false); + ProtoSize::add_uint32_field(total_size, 1, this->max_length, false); + ProtoSize::add_string_field(total_size, 1, this->pattern, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->mode), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesTextResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7679,6 +8960,11 @@ void TextStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(2, this->state); buffer.encode_bool(3, this->missing_state); } +void TextStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->state, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void TextStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7722,6 +9008,10 @@ void TextCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_string(2, this->state); } +void TextCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->state, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void TextCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7792,6 +9082,15 @@ void ListEntitiesDateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(6, this->disabled_by_default); buffer.encode_enum(7, this->entity_category); } +void ListEntitiesDateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesDateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7866,6 +9165,13 @@ void DateStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(4, this->month); buffer.encode_uint32(5, this->day); } +void DateStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); + ProtoSize::add_uint32_field(total_size, 1, this->year, false); + ProtoSize::add_uint32_field(total_size, 1, this->month, false); + ProtoSize::add_uint32_field(total_size, 1, this->day, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void DateStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -7930,6 +9236,12 @@ void DateCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(3, this->month); buffer.encode_uint32(4, this->day); } +void DateCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_uint32_field(total_size, 1, this->year, false); + ProtoSize::add_uint32_field(total_size, 1, this->month, false); + ProtoSize::add_uint32_field(total_size, 1, this->day, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void DateCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8011,6 +9323,15 @@ void ListEntitiesTimeResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(6, this->disabled_by_default); buffer.encode_enum(7, this->entity_category); } +void ListEntitiesTimeResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesTimeResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8085,6 +9406,13 @@ void TimeStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(4, this->minute); buffer.encode_uint32(5, this->second); } +void TimeStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); + ProtoSize::add_uint32_field(total_size, 1, this->hour, false); + ProtoSize::add_uint32_field(total_size, 1, this->minute, false); + ProtoSize::add_uint32_field(total_size, 1, this->second, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void TimeStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8149,6 +9477,12 @@ void TimeCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_uint32(3, this->minute); buffer.encode_uint32(4, this->second); } +void TimeCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_uint32_field(total_size, 1, this->hour, false); + ProtoSize::add_uint32_field(total_size, 1, this->minute, false); + ProtoSize::add_uint32_field(total_size, 1, this->second, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void TimeCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8242,6 +9576,21 @@ void ListEntitiesEventResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(9, it, true); } } +void ListEntitiesEventResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); + if (!this->event_types.empty()) { + for (const auto &it : this->event_types) { + ProtoSize::add_string_field(total_size, 1, it, true); + } + } +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesEventResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8311,6 +9660,10 @@ void EventResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_string(2, this->event_type); } +void EventResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->event_type, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void EventResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8401,6 +9754,19 @@ void ListEntitiesValveResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(10, this->supports_position); buffer.encode_bool(11, this->supports_stop); } +void ListEntitiesValveResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); + ProtoSize::add_bool_field(total_size, 1, this->assumed_state, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_position, false); + ProtoSize::add_bool_field(total_size, 1, this->supports_stop, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesValveResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8481,6 +9847,11 @@ void ValveStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(2, this->position); buffer.encode_enum(3, this->current_operation); } +void ValveStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->position != 0.0f, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->current_operation), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ValveStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8535,6 +9906,12 @@ void ValveCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_float(3, this->position); buffer.encode_bool(4, this->stop); } +void ValveCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->has_position, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->position != 0.0f, false); + ProtoSize::add_bool_field(total_size, 1, this->stop, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ValveCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8614,6 +9991,15 @@ void ListEntitiesDateTimeResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(6, this->disabled_by_default); buffer.encode_enum(7, this->entity_category); } +void ListEntitiesDateTimeResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesDateTimeResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8678,6 +10064,11 @@ void DateTimeStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_bool(2, this->missing_state); buffer.encode_fixed32(3, this->epoch_seconds); } +void DateTimeStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->epoch_seconds != 0, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void DateTimeStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8716,6 +10107,10 @@ void DateTimeCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_fixed32(2, this->epoch_seconds); } +void DateTimeCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->epoch_seconds != 0, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void DateTimeCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8792,6 +10187,16 @@ void ListEntitiesUpdateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_enum(7, this->entity_category); buffer.encode_string(8, this->device_class); } +void ListEntitiesUpdateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_string_field(total_size, 1, this->object_id, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_string_field(total_size, 1, this->name, false); + ProtoSize::add_string_field(total_size, 1, this->unique_id, false); + ProtoSize::add_string_field(total_size, 1, this->icon, false); + ProtoSize::add_bool_field(total_size, 1, this->disabled_by_default, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->entity_category), false); + ProtoSize::add_string_field(total_size, 1, this->device_class, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void ListEntitiesUpdateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8901,6 +10306,18 @@ void UpdateStateResponse::encode(ProtoWriteBuffer buffer) const { buffer.encode_string(9, this->release_summary); buffer.encode_string(10, this->release_url); } +void UpdateStateResponse::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_bool_field(total_size, 1, this->missing_state, false); + ProtoSize::add_bool_field(total_size, 1, this->in_progress, false); + ProtoSize::add_bool_field(total_size, 1, this->has_progress, false); + ProtoSize::add_fixed_field<4>(total_size, 1, this->progress != 0.0f, false); + ProtoSize::add_string_field(total_size, 1, this->current_version, false); + ProtoSize::add_string_field(total_size, 1, this->latest_version, false); + ProtoSize::add_string_field(total_size, 1, this->title, false); + ProtoSize::add_string_field(total_size, 1, this->release_summary, false); + ProtoSize::add_string_field(total_size, 1, this->release_url, false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void UpdateStateResponse::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; @@ -8973,6 +10390,10 @@ void UpdateCommandRequest::encode(ProtoWriteBuffer buffer) const { buffer.encode_fixed32(1, this->key); buffer.encode_enum(2, this->command); } +void UpdateCommandRequest::calculate_size(uint32_t &total_size) const { + ProtoSize::add_fixed_field<4>(total_size, 1, this->key != 0, false); + ProtoSize::add_enum_field(total_size, 1, static_cast(this->command), false); +} #ifdef HAS_PROTO_MESSAGE_DUMP void UpdateCommandRequest::dump_to(std::string &out) const { __attribute__((unused)) char buffer[64]; diff --git a/esphome/components/api/api_pb2.h b/esphome/components/api/api_pb2.h index 455e3ff6cf..c0927ebdc0 100644 --- a/esphome/components/api/api_pb2.h +++ b/esphome/components/api/api_pb2.h @@ -1,8 +1,9 @@ // This file was automatically generated with a tool. -// See scripts/api_protobuf/api_protobuf.py +// See script/api_protobuf/api_protobuf.py #pragma once #include "proto.h" +#include "api_pb2_size.h" namespace esphome { namespace api { @@ -169,6 +170,18 @@ enum BluetoothDeviceRequestType : uint32_t { BLUETOOTH_DEVICE_REQUEST_TYPE_CONNECT_V3_WITHOUT_CACHE = 5, BLUETOOTH_DEVICE_REQUEST_TYPE_CLEAR_CACHE = 6, }; +enum BluetoothScannerState : uint32_t { + BLUETOOTH_SCANNER_STATE_IDLE = 0, + BLUETOOTH_SCANNER_STATE_STARTING = 1, + BLUETOOTH_SCANNER_STATE_RUNNING = 2, + BLUETOOTH_SCANNER_STATE_FAILED = 3, + BLUETOOTH_SCANNER_STATE_STOPPING = 4, + BLUETOOTH_SCANNER_STATE_STOPPED = 5, +}; +enum BluetoothScannerMode : uint32_t { + BLUETOOTH_SCANNER_MODE_PASSIVE = 0, + BLUETOOTH_SCANNER_MODE_ACTIVE = 1, +}; enum VoiceAssistantSubscribeFlag : uint32_t { VOICE_ASSISTANT_SUBSCRIBE_NONE = 0, VOICE_ASSISTANT_SUBSCRIBE_API_AUDIO = 1, @@ -245,6 +258,7 @@ class HelloRequest : public ProtoMessage { uint32_t api_version_major{0}; uint32_t api_version_minor{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -260,6 +274,7 @@ class HelloResponse : public ProtoMessage { std::string server_info{}; std::string name{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -272,6 +287,7 @@ class ConnectRequest : public ProtoMessage { public: std::string password{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -283,6 +299,7 @@ class ConnectResponse : public ProtoMessage { public: bool invalid_password{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -293,6 +310,7 @@ class ConnectResponse : public ProtoMessage { class DisconnectRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -302,6 +320,7 @@ class DisconnectRequest : public ProtoMessage { class DisconnectResponse : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -311,6 +330,7 @@ class DisconnectResponse : public ProtoMessage { class PingRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -320,6 +340,7 @@ class PingRequest : public ProtoMessage { class PingResponse : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -329,6 +350,7 @@ class PingResponse : public ProtoMessage { class DeviceInfoRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -355,7 +377,9 @@ class DeviceInfoResponse : public ProtoMessage { uint32_t voice_assistant_feature_flags{0}; std::string suggested_area{}; std::string bluetooth_mac_address{}; + bool api_encryption_supported{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -367,6 +391,7 @@ class DeviceInfoResponse : public ProtoMessage { class ListEntitiesRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -376,6 +401,7 @@ class ListEntitiesRequest : public ProtoMessage { class ListEntitiesDoneResponse : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -385,6 +411,7 @@ class ListEntitiesDoneResponse : public ProtoMessage { class SubscribeStatesRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -403,6 +430,7 @@ class ListEntitiesBinarySensorResponse : public ProtoMessage { std::string icon{}; enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -418,6 +446,7 @@ class BinarySensorStateResponse : public ProtoMessage { bool state{false}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -441,6 +470,7 @@ class ListEntitiesCoverResponse : public ProtoMessage { enums::EntityCategory entity_category{}; bool supports_stop{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -458,6 +488,7 @@ class CoverStateResponse : public ProtoMessage { float tilt{0.0f}; enums::CoverOperation current_operation{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -477,6 +508,7 @@ class CoverCommandRequest : public ProtoMessage { float tilt{0.0f}; bool stop{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -500,6 +532,7 @@ class ListEntitiesFanResponse : public ProtoMessage { enums::EntityCategory entity_category{}; std::vector supported_preset_modes{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -519,6 +552,7 @@ class FanStateResponse : public ProtoMessage { int32_t speed_level{0}; std::string preset_mode{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -544,6 +578,7 @@ class FanCommandRequest : public ProtoMessage { bool has_preset_mode{false}; std::string preset_mode{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -571,6 +606,7 @@ class ListEntitiesLightResponse : public ProtoMessage { std::string icon{}; enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -596,6 +632,7 @@ class LightStateResponse : public ProtoMessage { float warm_white{0.0f}; std::string effect{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -635,6 +672,7 @@ class LightCommandRequest : public ProtoMessage { bool has_effect{false}; std::string effect{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -660,6 +698,7 @@ class ListEntitiesSensorResponse : public ProtoMessage { bool disabled_by_default{false}; enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -675,6 +714,7 @@ class SensorStateResponse : public ProtoMessage { float state{0.0f}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -695,6 +735,7 @@ class ListEntitiesSwitchResponse : public ProtoMessage { enums::EntityCategory entity_category{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -709,6 +750,7 @@ class SwitchStateResponse : public ProtoMessage { uint32_t key{0}; bool state{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -722,6 +764,7 @@ class SwitchCommandRequest : public ProtoMessage { uint32_t key{0}; bool state{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -741,6 +784,7 @@ class ListEntitiesTextSensorResponse : public ProtoMessage { enums::EntityCategory entity_category{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -756,6 +800,7 @@ class TextSensorStateResponse : public ProtoMessage { std::string state{}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -770,6 +815,7 @@ class SubscribeLogsRequest : public ProtoMessage { enums::LogLevel level{}; bool dump_config{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -783,6 +829,7 @@ class SubscribeLogsResponse : public ProtoMessage { std::string message{}; bool send_failed{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -791,9 +838,34 @@ class SubscribeLogsResponse : public ProtoMessage { bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; bool decode_varint(uint32_t field_id, ProtoVarInt value) override; }; +class NoiseEncryptionSetKeyRequest : public ProtoMessage { + public: + std::string key{}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; +}; +class NoiseEncryptionSetKeyResponse : public ProtoMessage { + public: + bool success{false}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; class SubscribeHomeassistantServicesRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -805,6 +877,7 @@ class HomeassistantServiceMap : public ProtoMessage { std::string key{}; std::string value{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -820,6 +893,7 @@ class HomeassistantServiceResponse : public ProtoMessage { std::vector variables{}; bool is_event{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -831,6 +905,7 @@ class HomeassistantServiceResponse : public ProtoMessage { class SubscribeHomeAssistantStatesRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -843,6 +918,7 @@ class SubscribeHomeAssistantStateResponse : public ProtoMessage { std::string attribute{}; bool once{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -857,6 +933,7 @@ class HomeAssistantStateResponse : public ProtoMessage { std::string state{}; std::string attribute{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -867,6 +944,7 @@ class HomeAssistantStateResponse : public ProtoMessage { class GetTimeRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -877,6 +955,7 @@ class GetTimeResponse : public ProtoMessage { public: uint32_t epoch_seconds{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -889,6 +968,7 @@ class ListEntitiesServicesArgument : public ProtoMessage { std::string name{}; enums::ServiceArgType type{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -903,6 +983,7 @@ class ListEntitiesServicesResponse : public ProtoMessage { uint32_t key{0}; std::vector args{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -923,6 +1004,7 @@ class ExecuteServiceArgument : public ProtoMessage { std::vector float_array{}; std::vector string_array{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -937,6 +1019,7 @@ class ExecuteServiceRequest : public ProtoMessage { uint32_t key{0}; std::vector args{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -955,6 +1038,7 @@ class ListEntitiesCameraResponse : public ProtoMessage { std::string icon{}; enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -970,6 +1054,7 @@ class CameraImageResponse : public ProtoMessage { std::string data{}; bool done{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -984,6 +1069,7 @@ class CameraImageRequest : public ProtoMessage { bool single{false}; bool stream{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1019,6 +1105,7 @@ class ListEntitiesClimateResponse : public ProtoMessage { float visual_min_humidity{0.0f}; float visual_max_humidity{0.0f}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1046,6 +1133,7 @@ class ClimateStateResponse : public ProtoMessage { float current_humidity{0.0f}; float target_humidity{0.0f}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1081,6 +1169,7 @@ class ClimateCommandRequest : public ProtoMessage { bool has_target_humidity{false}; float target_humidity{0.0f}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1106,6 +1195,7 @@ class ListEntitiesNumberResponse : public ProtoMessage { enums::NumberMode mode{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1121,6 +1211,7 @@ class NumberStateResponse : public ProtoMessage { float state{0.0f}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1134,6 +1225,7 @@ class NumberCommandRequest : public ProtoMessage { uint32_t key{0}; float state{0.0f}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1152,6 +1244,7 @@ class ListEntitiesSelectResponse : public ProtoMessage { bool disabled_by_default{false}; enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1167,6 +1260,7 @@ class SelectStateResponse : public ProtoMessage { std::string state{}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1181,6 +1275,7 @@ class SelectCommandRequest : public ProtoMessage { uint32_t key{0}; std::string state{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1189,6 +1284,65 @@ class SelectCommandRequest : public ProtoMessage { bool decode_32bit(uint32_t field_id, Proto32Bit value) override; bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; }; +class ListEntitiesSirenResponse : public ProtoMessage { + public: + std::string object_id{}; + uint32_t key{0}; + std::string name{}; + std::string unique_id{}; + std::string icon{}; + bool disabled_by_default{false}; + std::vector tones{}; + bool supports_duration{false}; + bool supports_volume{false}; + enums::EntityCategory entity_category{}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_32bit(uint32_t field_id, Proto32Bit value) override; + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; +class SirenStateResponse : public ProtoMessage { + public: + uint32_t key{0}; + bool state{false}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_32bit(uint32_t field_id, Proto32Bit value) override; + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; +class SirenCommandRequest : public ProtoMessage { + public: + uint32_t key{0}; + bool has_state{false}; + bool state{false}; + bool has_tone{false}; + std::string tone{}; + bool has_duration{false}; + uint32_t duration{0}; + bool has_volume{false}; + float volume{0.0f}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_32bit(uint32_t field_id, Proto32Bit value) override; + bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override; + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; class ListEntitiesLockResponse : public ProtoMessage { public: std::string object_id{}; @@ -1203,6 +1357,7 @@ class ListEntitiesLockResponse : public ProtoMessage { bool requires_code{false}; std::string code_format{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1217,6 +1372,7 @@ class LockStateResponse : public ProtoMessage { uint32_t key{0}; enums::LockState state{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1232,6 +1388,7 @@ class LockCommandRequest : public ProtoMessage { bool has_code{false}; std::string code{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1252,6 +1409,7 @@ class ListEntitiesButtonResponse : public ProtoMessage { enums::EntityCategory entity_category{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1265,6 +1423,7 @@ class ButtonCommandRequest : public ProtoMessage { public: uint32_t key{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1280,6 +1439,7 @@ class MediaPlayerSupportedFormat : public ProtoMessage { enums::MediaPlayerFormatPurpose purpose{}; uint32_t sample_bytes{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1300,6 +1460,7 @@ class ListEntitiesMediaPlayerResponse : public ProtoMessage { bool supports_pause{false}; std::vector supported_formats{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1316,6 +1477,7 @@ class MediaPlayerStateResponse : public ProtoMessage { float volume{0.0f}; bool muted{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1336,6 +1498,7 @@ class MediaPlayerCommandRequest : public ProtoMessage { bool has_announcement{false}; bool announcement{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1349,6 +1512,7 @@ class SubscribeBluetoothLEAdvertisementsRequest : public ProtoMessage { public: uint32_t flags{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1362,6 +1526,7 @@ class BluetoothServiceData : public ProtoMessage { std::vector legacy_data{}; std::string data{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1380,6 +1545,7 @@ class BluetoothLEAdvertisementResponse : public ProtoMessage { std::vector manufacturer_data{}; uint32_t address_type{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1395,6 +1561,7 @@ class BluetoothLERawAdvertisement : public ProtoMessage { uint32_t address_type{0}; std::string data{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1407,6 +1574,7 @@ class BluetoothLERawAdvertisementsResponse : public ProtoMessage { public: std::vector advertisements{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1421,6 +1589,7 @@ class BluetoothDeviceRequest : public ProtoMessage { bool has_address_type{false}; uint32_t address_type{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1435,6 +1604,7 @@ class BluetoothDeviceConnectionResponse : public ProtoMessage { uint32_t mtu{0}; int32_t error{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1446,6 +1616,7 @@ class BluetoothGATTGetServicesRequest : public ProtoMessage { public: uint64_t address{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1458,6 +1629,7 @@ class BluetoothGATTDescriptor : public ProtoMessage { std::vector uuid{}; uint32_t handle{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1472,6 +1644,7 @@ class BluetoothGATTCharacteristic : public ProtoMessage { uint32_t properties{0}; std::vector descriptors{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1486,6 +1659,7 @@ class BluetoothGATTService : public ProtoMessage { uint32_t handle{0}; std::vector characteristics{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1499,6 +1673,7 @@ class BluetoothGATTGetServicesResponse : public ProtoMessage { uint64_t address{0}; std::vector services{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1511,6 +1686,7 @@ class BluetoothGATTGetServicesDoneResponse : public ProtoMessage { public: uint64_t address{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1523,6 +1699,7 @@ class BluetoothGATTReadRequest : public ProtoMessage { uint64_t address{0}; uint32_t handle{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1536,6 +1713,7 @@ class BluetoothGATTReadResponse : public ProtoMessage { uint32_t handle{0}; std::string data{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1551,6 +1729,7 @@ class BluetoothGATTWriteRequest : public ProtoMessage { bool response{false}; std::string data{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1564,6 +1743,7 @@ class BluetoothGATTReadDescriptorRequest : public ProtoMessage { uint64_t address{0}; uint32_t handle{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1577,6 +1757,7 @@ class BluetoothGATTWriteDescriptorRequest : public ProtoMessage { uint32_t handle{0}; std::string data{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1591,6 +1772,7 @@ class BluetoothGATTNotifyRequest : public ProtoMessage { uint32_t handle{0}; bool enable{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1604,6 +1786,7 @@ class BluetoothGATTNotifyDataResponse : public ProtoMessage { uint32_t handle{0}; std::string data{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1615,6 +1798,7 @@ class BluetoothGATTNotifyDataResponse : public ProtoMessage { class SubscribeBluetoothConnectionsFreeRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1627,6 +1811,7 @@ class BluetoothConnectionsFreeResponse : public ProtoMessage { uint32_t limit{0}; std::vector allocated{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1640,6 +1825,7 @@ class BluetoothGATTErrorResponse : public ProtoMessage { uint32_t handle{0}; int32_t error{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1652,6 +1838,7 @@ class BluetoothGATTWriteResponse : public ProtoMessage { uint64_t address{0}; uint32_t handle{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1664,6 +1851,7 @@ class BluetoothGATTNotifyResponse : public ProtoMessage { uint64_t address{0}; uint32_t handle{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1677,6 +1865,7 @@ class BluetoothDevicePairingResponse : public ProtoMessage { bool paired{false}; int32_t error{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1690,6 +1879,7 @@ class BluetoothDeviceUnpairingResponse : public ProtoMessage { bool success{false}; int32_t error{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1700,6 +1890,7 @@ class BluetoothDeviceUnpairingResponse : public ProtoMessage { class UnsubscribeBluetoothLEAdvertisementsRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1712,6 +1903,32 @@ class BluetoothDeviceClearCacheResponse : public ProtoMessage { bool success{false}; int32_t error{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; +class BluetoothScannerStateResponse : public ProtoMessage { + public: + enums::BluetoothScannerState state{}; + enums::BluetoothScannerMode mode{}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; +#ifdef HAS_PROTO_MESSAGE_DUMP + void dump_to(std::string &out) const override; +#endif + + protected: + bool decode_varint(uint32_t field_id, ProtoVarInt value) override; +}; +class BluetoothScannerSetModeRequest : public ProtoMessage { + public: + enums::BluetoothScannerMode mode{}; + void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1724,6 +1941,7 @@ class SubscribeVoiceAssistantRequest : public ProtoMessage { bool subscribe{false}; uint32_t flags{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1737,6 +1955,7 @@ class VoiceAssistantAudioSettings : public ProtoMessage { uint32_t auto_gain{0}; float volume_multiplier{0.0f}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1753,6 +1972,7 @@ class VoiceAssistantRequest : public ProtoMessage { VoiceAssistantAudioSettings audio_settings{}; std::string wake_word_phrase{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1766,6 +1986,7 @@ class VoiceAssistantResponse : public ProtoMessage { uint32_t port{0}; bool error{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1778,6 +1999,7 @@ class VoiceAssistantEventData : public ProtoMessage { std::string name{}; std::string value{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1790,6 +2012,7 @@ class VoiceAssistantEventResponse : public ProtoMessage { enums::VoiceAssistantEvent event_type{}; std::vector data{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1803,6 +2026,7 @@ class VoiceAssistantAudio : public ProtoMessage { std::string data{}; bool end{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1820,6 +2044,7 @@ class VoiceAssistantTimerEventResponse : public ProtoMessage { uint32_t seconds_left{0}; bool is_active{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1835,6 +2060,7 @@ class VoiceAssistantAnnounceRequest : public ProtoMessage { std::string preannounce_media_id{}; bool start_conversation{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1847,6 +2073,7 @@ class VoiceAssistantAnnounceFinished : public ProtoMessage { public: bool success{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1860,6 +2087,7 @@ class VoiceAssistantWakeWord : public ProtoMessage { std::string wake_word{}; std::vector trained_languages{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1870,6 +2098,7 @@ class VoiceAssistantWakeWord : public ProtoMessage { class VoiceAssistantConfigurationRequest : public ProtoMessage { public: void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1882,6 +2111,7 @@ class VoiceAssistantConfigurationResponse : public ProtoMessage { std::vector active_wake_words{}; uint32_t max_active_wake_words{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1894,6 +2124,7 @@ class VoiceAssistantSetConfiguration : public ProtoMessage { public: std::vector active_wake_words{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1914,6 +2145,7 @@ class ListEntitiesAlarmControlPanelResponse : public ProtoMessage { bool requires_code{false}; bool requires_code_to_arm{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1928,6 +2160,7 @@ class AlarmControlPanelStateResponse : public ProtoMessage { uint32_t key{0}; enums::AlarmControlPanelState state{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1942,6 +2175,7 @@ class AlarmControlPanelCommandRequest : public ProtoMessage { enums::AlarmControlPanelStateCommand command{}; std::string code{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1965,6 +2199,7 @@ class ListEntitiesTextResponse : public ProtoMessage { std::string pattern{}; enums::TextMode mode{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1980,6 +2215,7 @@ class TextStateResponse : public ProtoMessage { std::string state{}; bool missing_state{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -1994,6 +2230,7 @@ class TextCommandRequest : public ProtoMessage { uint32_t key{0}; std::string state{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2012,6 +2249,7 @@ class ListEntitiesDateResponse : public ProtoMessage { bool disabled_by_default{false}; enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2029,6 +2267,7 @@ class DateStateResponse : public ProtoMessage { uint32_t month{0}; uint32_t day{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2044,6 +2283,7 @@ class DateCommandRequest : public ProtoMessage { uint32_t month{0}; uint32_t day{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2062,6 +2302,7 @@ class ListEntitiesTimeResponse : public ProtoMessage { bool disabled_by_default{false}; enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2079,6 +2320,7 @@ class TimeStateResponse : public ProtoMessage { uint32_t minute{0}; uint32_t second{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2094,6 +2336,7 @@ class TimeCommandRequest : public ProtoMessage { uint32_t minute{0}; uint32_t second{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2114,6 +2357,7 @@ class ListEntitiesEventResponse : public ProtoMessage { std::string device_class{}; std::vector event_types{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2128,6 +2372,7 @@ class EventResponse : public ProtoMessage { uint32_t key{0}; std::string event_type{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2150,6 +2395,7 @@ class ListEntitiesValveResponse : public ProtoMessage { bool supports_position{false}; bool supports_stop{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2165,6 +2411,7 @@ class ValveStateResponse : public ProtoMessage { float position{0.0f}; enums::ValveOperation current_operation{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2180,6 +2427,7 @@ class ValveCommandRequest : public ProtoMessage { float position{0.0f}; bool stop{false}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2198,6 +2446,7 @@ class ListEntitiesDateTimeResponse : public ProtoMessage { bool disabled_by_default{false}; enums::EntityCategory entity_category{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2213,6 +2462,7 @@ class DateTimeStateResponse : public ProtoMessage { bool missing_state{false}; uint32_t epoch_seconds{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2226,6 +2476,7 @@ class DateTimeCommandRequest : public ProtoMessage { uint32_t key{0}; uint32_t epoch_seconds{0}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2244,6 +2495,7 @@ class ListEntitiesUpdateResponse : public ProtoMessage { enums::EntityCategory entity_category{}; std::string device_class{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2266,6 +2518,7 @@ class UpdateStateResponse : public ProtoMessage { std::string release_summary{}; std::string release_url{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif @@ -2280,6 +2533,7 @@ class UpdateCommandRequest : public ProtoMessage { uint32_t key{0}; enums::UpdateCommand command{}; void encode(ProtoWriteBuffer buffer) const override; + void calculate_size(uint32_t &total_size) const override; #ifdef HAS_PROTO_MESSAGE_DUMP void dump_to(std::string &out) const override; #endif diff --git a/esphome/components/api/api_pb2_service.cpp b/esphome/components/api/api_pb2_service.cpp index 6e11d7169d..5a701aeafa 100644 --- a/esphome/components/api/api_pb2_service.cpp +++ b/esphome/components/api/api_pb2_service.cpp @@ -1,5 +1,5 @@ // This file was automatically generated with a tool. -// See scripts/api_protobuf/api_protobuf.py +// See script/api_protobuf/api_protobuf.py #include "api_pb2_service.h" #include "esphome/core/log.h" @@ -179,6 +179,16 @@ bool APIServerConnectionBase::send_text_sensor_state_response(const TextSensorSt bool APIServerConnectionBase::send_subscribe_logs_response(const SubscribeLogsResponse &msg) { return this->send_message_(msg, 29); } +#ifdef USE_API_NOISE +#endif +#ifdef USE_API_NOISE +bool APIServerConnectionBase::send_noise_encryption_set_key_response(const NoiseEncryptionSetKeyResponse &msg) { +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "send_noise_encryption_set_key_response: %s", msg.dump().c_str()); +#endif + return this->send_message_(msg, 125); +} +#endif bool APIServerConnectionBase::send_homeassistant_service_response(const HomeassistantServiceResponse &msg) { #ifdef HAS_PROTO_MESSAGE_DUMP ESP_LOGVV(TAG, "send_homeassistant_service_response: %s", msg.dump().c_str()); @@ -282,6 +292,24 @@ bool APIServerConnectionBase::send_select_state_response(const SelectStateRespon #endif #ifdef USE_SELECT #endif +#ifdef USE_SIREN +bool APIServerConnectionBase::send_list_entities_siren_response(const ListEntitiesSirenResponse &msg) { +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "send_list_entities_siren_response: %s", msg.dump().c_str()); +#endif + return this->send_message_(msg, 55); +} +#endif +#ifdef USE_SIREN +bool APIServerConnectionBase::send_siren_state_response(const SirenStateResponse &msg) { +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "send_siren_state_response: %s", msg.dump().c_str()); +#endif + return this->send_message_(msg, 56); +} +#endif +#ifdef USE_SIREN +#endif #ifdef USE_LOCK bool APIServerConnectionBase::send_list_entities_lock_response(const ListEntitiesLockResponse &msg) { #ifdef HAS_PROTO_MESSAGE_DUMP @@ -462,6 +490,16 @@ bool APIServerConnectionBase::send_bluetooth_device_clear_cache_response(const B return this->send_message_(msg, 88); } #endif +#ifdef USE_BLUETOOTH_PROXY +bool APIServerConnectionBase::send_bluetooth_scanner_state_response(const BluetoothScannerStateResponse &msg) { +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "send_bluetooth_scanner_state_response: %s", msg.dump().c_str()); +#endif + return this->send_message_(msg, 126); +} +#endif +#ifdef USE_BLUETOOTH_PROXY +#endif #ifdef USE_VOICE_ASSISTANT #endif #ifdef USE_VOICE_ASSISTANT @@ -883,6 +921,17 @@ bool APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, ESP_LOGVV(TAG, "on_select_command_request: %s", msg.dump().c_str()); #endif this->on_select_command_request(msg); +#endif + break; + } + case 57: { +#ifdef USE_SIREN + SirenCommandRequest msg; + msg.decode(msg_data, msg_size); +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "on_siren_command_request: %s", msg.dump().c_str()); +#endif + this->on_siren_command_request(msg); #endif break; } @@ -1191,6 +1240,28 @@ bool APIServerConnectionBase::read_message(uint32_t msg_size, uint32_t msg_type, ESP_LOGVV(TAG, "on_voice_assistant_set_configuration: %s", msg.dump().c_str()); #endif this->on_voice_assistant_set_configuration(msg); +#endif + break; + } + case 124: { +#ifdef USE_API_NOISE + NoiseEncryptionSetKeyRequest msg; + msg.decode(msg_data, msg_size); +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "on_noise_encryption_set_key_request: %s", msg.dump().c_str()); +#endif + this->on_noise_encryption_set_key_request(msg); +#endif + break; + } + case 127: { +#ifdef USE_BLUETOOTH_PROXY + BluetoothScannerSetModeRequest msg; + msg.decode(msg_data, msg_size); +#ifdef HAS_PROTO_MESSAGE_DUMP + ESP_LOGVV(TAG, "on_bluetooth_scanner_set_mode_request: %s", msg.dump().c_str()); +#endif + this->on_bluetooth_scanner_set_mode_request(msg); #endif break; } @@ -1311,8 +1382,8 @@ void APIServerConnection::on_execute_service_request(const ExecuteServiceRequest } this->execute_service(msg); } -#ifdef USE_COVER -void APIServerConnection::on_cover_command_request(const CoverCommandRequest &msg) { +#ifdef USE_API_NOISE +void APIServerConnection::on_noise_encryption_set_key_request(const NoiseEncryptionSetKeyRequest &msg) { if (!this->is_connection_setup()) { this->on_no_setup_connection(); return; @@ -1321,11 +1392,14 @@ void APIServerConnection::on_cover_command_request(const CoverCommandRequest &ms this->on_unauthenticated_access(); return; } - this->cover_command(msg); + NoiseEncryptionSetKeyResponse ret = this->noise_encryption_set_key(msg); + if (!this->send_noise_encryption_set_key_response(ret)) { + this->on_fatal_error(); + } } #endif -#ifdef USE_FAN -void APIServerConnection::on_fan_command_request(const FanCommandRequest &msg) { +#ifdef USE_BUTTON +void APIServerConnection::on_button_command_request(const ButtonCommandRequest &msg) { if (!this->is_connection_setup()) { this->on_no_setup_connection(); return; @@ -1334,33 +1408,7 @@ void APIServerConnection::on_fan_command_request(const FanCommandRequest &msg) { this->on_unauthenticated_access(); return; } - this->fan_command(msg); -} -#endif -#ifdef USE_LIGHT -void APIServerConnection::on_light_command_request(const LightCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->light_command(msg); -} -#endif -#ifdef USE_SWITCH -void APIServerConnection::on_switch_command_request(const SwitchCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->switch_command(msg); + this->button_command(msg); } #endif #ifdef USE_ESP32_CAMERA @@ -1389,8 +1437,8 @@ void APIServerConnection::on_climate_command_request(const ClimateCommandRequest this->climate_command(msg); } #endif -#ifdef USE_NUMBER -void APIServerConnection::on_number_command_request(const NumberCommandRequest &msg) { +#ifdef USE_COVER +void APIServerConnection::on_cover_command_request(const CoverCommandRequest &msg) { if (!this->is_connection_setup()) { this->on_no_setup_connection(); return; @@ -1399,85 +1447,7 @@ void APIServerConnection::on_number_command_request(const NumberCommandRequest & this->on_unauthenticated_access(); return; } - this->number_command(msg); -} -#endif -#ifdef USE_TEXT -void APIServerConnection::on_text_command_request(const TextCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->text_command(msg); -} -#endif -#ifdef USE_SELECT -void APIServerConnection::on_select_command_request(const SelectCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->select_command(msg); -} -#endif -#ifdef USE_BUTTON -void APIServerConnection::on_button_command_request(const ButtonCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->button_command(msg); -} -#endif -#ifdef USE_LOCK -void APIServerConnection::on_lock_command_request(const LockCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->lock_command(msg); -} -#endif -#ifdef USE_VALVE -void APIServerConnection::on_valve_command_request(const ValveCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->valve_command(msg); -} -#endif -#ifdef USE_MEDIA_PLAYER -void APIServerConnection::on_media_player_command_request(const MediaPlayerCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->media_player_command(msg); + this->cover_command(msg); } #endif #ifdef USE_DATETIME_DATE @@ -1493,19 +1463,6 @@ void APIServerConnection::on_date_command_request(const DateCommandRequest &msg) this->date_command(msg); } #endif -#ifdef USE_DATETIME_TIME -void APIServerConnection::on_time_command_request(const TimeCommandRequest &msg) { - if (!this->is_connection_setup()) { - this->on_no_setup_connection(); - return; - } - if (!this->is_authenticated()) { - this->on_unauthenticated_access(); - return; - } - this->time_command(msg); -} -#endif #ifdef USE_DATETIME_DATETIME void APIServerConnection::on_date_time_command_request(const DateTimeCommandRequest &msg) { if (!this->is_connection_setup()) { @@ -1519,6 +1476,136 @@ void APIServerConnection::on_date_time_command_request(const DateTimeCommandRequ this->datetime_command(msg); } #endif +#ifdef USE_FAN +void APIServerConnection::on_fan_command_request(const FanCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->fan_command(msg); +} +#endif +#ifdef USE_LIGHT +void APIServerConnection::on_light_command_request(const LightCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->light_command(msg); +} +#endif +#ifdef USE_LOCK +void APIServerConnection::on_lock_command_request(const LockCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->lock_command(msg); +} +#endif +#ifdef USE_MEDIA_PLAYER +void APIServerConnection::on_media_player_command_request(const MediaPlayerCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->media_player_command(msg); +} +#endif +#ifdef USE_NUMBER +void APIServerConnection::on_number_command_request(const NumberCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->number_command(msg); +} +#endif +#ifdef USE_SELECT +void APIServerConnection::on_select_command_request(const SelectCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->select_command(msg); +} +#endif +#ifdef USE_SIREN +void APIServerConnection::on_siren_command_request(const SirenCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->siren_command(msg); +} +#endif +#ifdef USE_SWITCH +void APIServerConnection::on_switch_command_request(const SwitchCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->switch_command(msg); +} +#endif +#ifdef USE_TEXT +void APIServerConnection::on_text_command_request(const TextCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->text_command(msg); +} +#endif +#ifdef USE_DATETIME_TIME +void APIServerConnection::on_time_command_request(const TimeCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->time_command(msg); +} +#endif #ifdef USE_UPDATE void APIServerConnection::on_update_command_request(const UpdateCommandRequest &msg) { if (!this->is_connection_setup()) { @@ -1532,6 +1619,19 @@ void APIServerConnection::on_update_command_request(const UpdateCommandRequest & this->update_command(msg); } #endif +#ifdef USE_VALVE +void APIServerConnection::on_valve_command_request(const ValveCommandRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->valve_command(msg); +} +#endif #ifdef USE_BLUETOOTH_PROXY void APIServerConnection::on_subscribe_bluetooth_le_advertisements_request( const SubscribeBluetoothLEAdvertisementsRequest &msg) { @@ -1668,6 +1768,19 @@ void APIServerConnection::on_unsubscribe_bluetooth_le_advertisements_request( this->unsubscribe_bluetooth_le_advertisements(msg); } #endif +#ifdef USE_BLUETOOTH_PROXY +void APIServerConnection::on_bluetooth_scanner_set_mode_request(const BluetoothScannerSetModeRequest &msg) { + if (!this->is_connection_setup()) { + this->on_no_setup_connection(); + return; + } + if (!this->is_authenticated()) { + this->on_unauthenticated_access(); + return; + } + this->bluetooth_scanner_set_mode(msg); +} +#endif #ifdef USE_VOICE_ASSISTANT void APIServerConnection::on_subscribe_voice_assistant_request(const SubscribeVoiceAssistantRequest &msg) { if (!this->is_connection_setup()) { diff --git a/esphome/components/api/api_pb2_service.h b/esphome/components/api/api_pb2_service.h index 51b94bf530..8ee5c0fcf1 100644 --- a/esphome/components/api/api_pb2_service.h +++ b/esphome/components/api/api_pb2_service.h @@ -1,5 +1,5 @@ // This file was automatically generated with a tool. -// See scripts/api_protobuf/api_protobuf.py +// See script/api_protobuf/api_protobuf.py #pragma once #include "api_pb2.h" @@ -83,6 +83,12 @@ class APIServerConnectionBase : public ProtoService { #endif virtual void on_subscribe_logs_request(const SubscribeLogsRequest &value){}; bool send_subscribe_logs_response(const SubscribeLogsResponse &msg); +#ifdef USE_API_NOISE + virtual void on_noise_encryption_set_key_request(const NoiseEncryptionSetKeyRequest &value){}; +#endif +#ifdef USE_API_NOISE + bool send_noise_encryption_set_key_response(const NoiseEncryptionSetKeyResponse &msg); +#endif virtual void on_subscribe_homeassistant_services_request(const SubscribeHomeassistantServicesRequest &value){}; bool send_homeassistant_service_response(const HomeassistantServiceResponse &msg); virtual void on_subscribe_home_assistant_states_request(const SubscribeHomeAssistantStatesRequest &value){}; @@ -130,6 +136,15 @@ class APIServerConnectionBase : public ProtoService { #ifdef USE_SELECT virtual void on_select_command_request(const SelectCommandRequest &value){}; #endif +#ifdef USE_SIREN + bool send_list_entities_siren_response(const ListEntitiesSirenResponse &msg); +#endif +#ifdef USE_SIREN + bool send_siren_state_response(const SirenStateResponse &msg); +#endif +#ifdef USE_SIREN + virtual void on_siren_command_request(const SirenCommandRequest &value){}; +#endif #ifdef USE_LOCK bool send_list_entities_lock_response(const ListEntitiesLockResponse &msg); #endif @@ -228,6 +243,12 @@ class APIServerConnectionBase : public ProtoService { #ifdef USE_BLUETOOTH_PROXY bool send_bluetooth_device_clear_cache_response(const BluetoothDeviceClearCacheResponse &msg); #endif +#ifdef USE_BLUETOOTH_PROXY + bool send_bluetooth_scanner_state_response(const BluetoothScannerStateResponse &msg); +#endif +#ifdef USE_BLUETOOTH_PROXY + virtual void on_bluetooth_scanner_set_mode_request(const BluetoothScannerSetModeRequest &value){}; +#endif #ifdef USE_VOICE_ASSISTANT virtual void on_subscribe_voice_assistant_request(const SubscribeVoiceAssistantRequest &value){}; #endif @@ -349,17 +370,11 @@ class APIServerConnection : public APIServerConnectionBase { virtual void subscribe_home_assistant_states(const SubscribeHomeAssistantStatesRequest &msg) = 0; virtual GetTimeResponse get_time(const GetTimeRequest &msg) = 0; virtual void execute_service(const ExecuteServiceRequest &msg) = 0; -#ifdef USE_COVER - virtual void cover_command(const CoverCommandRequest &msg) = 0; +#ifdef USE_API_NOISE + virtual NoiseEncryptionSetKeyResponse noise_encryption_set_key(const NoiseEncryptionSetKeyRequest &msg) = 0; #endif -#ifdef USE_FAN - virtual void fan_command(const FanCommandRequest &msg) = 0; -#endif -#ifdef USE_LIGHT - virtual void light_command(const LightCommandRequest &msg) = 0; -#endif -#ifdef USE_SWITCH - virtual void switch_command(const SwitchCommandRequest &msg) = 0; +#ifdef USE_BUTTON + virtual void button_command(const ButtonCommandRequest &msg) = 0; #endif #ifdef USE_ESP32_CAMERA virtual void camera_image(const CameraImageRequest &msg) = 0; @@ -367,39 +382,51 @@ class APIServerConnection : public APIServerConnectionBase { #ifdef USE_CLIMATE virtual void climate_command(const ClimateCommandRequest &msg) = 0; #endif -#ifdef USE_NUMBER - virtual void number_command(const NumberCommandRequest &msg) = 0; -#endif -#ifdef USE_TEXT - virtual void text_command(const TextCommandRequest &msg) = 0; -#endif -#ifdef USE_SELECT - virtual void select_command(const SelectCommandRequest &msg) = 0; -#endif -#ifdef USE_BUTTON - virtual void button_command(const ButtonCommandRequest &msg) = 0; -#endif -#ifdef USE_LOCK - virtual void lock_command(const LockCommandRequest &msg) = 0; -#endif -#ifdef USE_VALVE - virtual void valve_command(const ValveCommandRequest &msg) = 0; -#endif -#ifdef USE_MEDIA_PLAYER - virtual void media_player_command(const MediaPlayerCommandRequest &msg) = 0; +#ifdef USE_COVER + virtual void cover_command(const CoverCommandRequest &msg) = 0; #endif #ifdef USE_DATETIME_DATE virtual void date_command(const DateCommandRequest &msg) = 0; #endif -#ifdef USE_DATETIME_TIME - virtual void time_command(const TimeCommandRequest &msg) = 0; -#endif #ifdef USE_DATETIME_DATETIME virtual void datetime_command(const DateTimeCommandRequest &msg) = 0; #endif +#ifdef USE_FAN + virtual void fan_command(const FanCommandRequest &msg) = 0; +#endif +#ifdef USE_LIGHT + virtual void light_command(const LightCommandRequest &msg) = 0; +#endif +#ifdef USE_LOCK + virtual void lock_command(const LockCommandRequest &msg) = 0; +#endif +#ifdef USE_MEDIA_PLAYER + virtual void media_player_command(const MediaPlayerCommandRequest &msg) = 0; +#endif +#ifdef USE_NUMBER + virtual void number_command(const NumberCommandRequest &msg) = 0; +#endif +#ifdef USE_SELECT + virtual void select_command(const SelectCommandRequest &msg) = 0; +#endif +#ifdef USE_SIREN + virtual void siren_command(const SirenCommandRequest &msg) = 0; +#endif +#ifdef USE_SWITCH + virtual void switch_command(const SwitchCommandRequest &msg) = 0; +#endif +#ifdef USE_TEXT + virtual void text_command(const TextCommandRequest &msg) = 0; +#endif +#ifdef USE_DATETIME_TIME + virtual void time_command(const TimeCommandRequest &msg) = 0; +#endif #ifdef USE_UPDATE virtual void update_command(const UpdateCommandRequest &msg) = 0; #endif +#ifdef USE_VALVE + virtual void valve_command(const ValveCommandRequest &msg) = 0; +#endif #ifdef USE_BLUETOOTH_PROXY virtual void subscribe_bluetooth_le_advertisements(const SubscribeBluetoothLEAdvertisementsRequest &msg) = 0; #endif @@ -431,6 +458,9 @@ class APIServerConnection : public APIServerConnectionBase { #ifdef USE_BLUETOOTH_PROXY virtual void unsubscribe_bluetooth_le_advertisements(const UnsubscribeBluetoothLEAdvertisementsRequest &msg) = 0; #endif +#ifdef USE_BLUETOOTH_PROXY + virtual void bluetooth_scanner_set_mode(const BluetoothScannerSetModeRequest &msg) = 0; +#endif #ifdef USE_VOICE_ASSISTANT virtual void subscribe_voice_assistant(const SubscribeVoiceAssistantRequest &msg) = 0; #endif @@ -457,17 +487,11 @@ class APIServerConnection : public APIServerConnectionBase { void on_subscribe_home_assistant_states_request(const SubscribeHomeAssistantStatesRequest &msg) override; void on_get_time_request(const GetTimeRequest &msg) override; void on_execute_service_request(const ExecuteServiceRequest &msg) override; -#ifdef USE_COVER - void on_cover_command_request(const CoverCommandRequest &msg) override; +#ifdef USE_API_NOISE + void on_noise_encryption_set_key_request(const NoiseEncryptionSetKeyRequest &msg) override; #endif -#ifdef USE_FAN - void on_fan_command_request(const FanCommandRequest &msg) override; -#endif -#ifdef USE_LIGHT - void on_light_command_request(const LightCommandRequest &msg) override; -#endif -#ifdef USE_SWITCH - void on_switch_command_request(const SwitchCommandRequest &msg) override; +#ifdef USE_BUTTON + void on_button_command_request(const ButtonCommandRequest &msg) override; #endif #ifdef USE_ESP32_CAMERA void on_camera_image_request(const CameraImageRequest &msg) override; @@ -475,39 +499,51 @@ class APIServerConnection : public APIServerConnectionBase { #ifdef USE_CLIMATE void on_climate_command_request(const ClimateCommandRequest &msg) override; #endif -#ifdef USE_NUMBER - void on_number_command_request(const NumberCommandRequest &msg) override; -#endif -#ifdef USE_TEXT - void on_text_command_request(const TextCommandRequest &msg) override; -#endif -#ifdef USE_SELECT - void on_select_command_request(const SelectCommandRequest &msg) override; -#endif -#ifdef USE_BUTTON - void on_button_command_request(const ButtonCommandRequest &msg) override; -#endif -#ifdef USE_LOCK - void on_lock_command_request(const LockCommandRequest &msg) override; -#endif -#ifdef USE_VALVE - void on_valve_command_request(const ValveCommandRequest &msg) override; -#endif -#ifdef USE_MEDIA_PLAYER - void on_media_player_command_request(const MediaPlayerCommandRequest &msg) override; +#ifdef USE_COVER + void on_cover_command_request(const CoverCommandRequest &msg) override; #endif #ifdef USE_DATETIME_DATE void on_date_command_request(const DateCommandRequest &msg) override; #endif -#ifdef USE_DATETIME_TIME - void on_time_command_request(const TimeCommandRequest &msg) override; -#endif #ifdef USE_DATETIME_DATETIME void on_date_time_command_request(const DateTimeCommandRequest &msg) override; #endif +#ifdef USE_FAN + void on_fan_command_request(const FanCommandRequest &msg) override; +#endif +#ifdef USE_LIGHT + void on_light_command_request(const LightCommandRequest &msg) override; +#endif +#ifdef USE_LOCK + void on_lock_command_request(const LockCommandRequest &msg) override; +#endif +#ifdef USE_MEDIA_PLAYER + void on_media_player_command_request(const MediaPlayerCommandRequest &msg) override; +#endif +#ifdef USE_NUMBER + void on_number_command_request(const NumberCommandRequest &msg) override; +#endif +#ifdef USE_SELECT + void on_select_command_request(const SelectCommandRequest &msg) override; +#endif +#ifdef USE_SIREN + void on_siren_command_request(const SirenCommandRequest &msg) override; +#endif +#ifdef USE_SWITCH + void on_switch_command_request(const SwitchCommandRequest &msg) override; +#endif +#ifdef USE_TEXT + void on_text_command_request(const TextCommandRequest &msg) override; +#endif +#ifdef USE_DATETIME_TIME + void on_time_command_request(const TimeCommandRequest &msg) override; +#endif #ifdef USE_UPDATE void on_update_command_request(const UpdateCommandRequest &msg) override; #endif +#ifdef USE_VALVE + void on_valve_command_request(const ValveCommandRequest &msg) override; +#endif #ifdef USE_BLUETOOTH_PROXY void on_subscribe_bluetooth_le_advertisements_request(const SubscribeBluetoothLEAdvertisementsRequest &msg) override; #endif @@ -539,6 +575,9 @@ class APIServerConnection : public APIServerConnectionBase { void on_unsubscribe_bluetooth_le_advertisements_request( const UnsubscribeBluetoothLEAdvertisementsRequest &msg) override; #endif +#ifdef USE_BLUETOOTH_PROXY + void on_bluetooth_scanner_set_mode_request(const BluetoothScannerSetModeRequest &msg) override; +#endif #ifdef USE_VOICE_ASSISTANT void on_subscribe_voice_assistant_request(const SubscribeVoiceAssistantRequest &msg) override; #endif diff --git a/esphome/components/api/api_pb2_size.h b/esphome/components/api/api_pb2_size.h new file mode 100644 index 0000000000..e591a7350f --- /dev/null +++ b/esphome/components/api/api_pb2_size.h @@ -0,0 +1,361 @@ +#pragma once + +#include "proto.h" +#include +#include + +namespace esphome { +namespace api { + +class ProtoSize { + public: + /** + * @brief ProtoSize class for Protocol Buffer serialization size calculation + * + * This class provides static methods to calculate the exact byte counts needed + * for encoding various Protocol Buffer field types. All methods are designed to be + * efficient for the common case where many fields have default values. + * + * Implements Protocol Buffer encoding size calculation according to: + * https://protobuf.dev/programming-guides/encoding/ + * + * Key features: + * - Early-return optimization for zero/default values + * - Direct total_size updates to avoid unnecessary additions + * - Specialized handling for different field types according to protobuf spec + * - Templated helpers for repeated fields and messages + */ + + /** + * @brief Calculates the size in bytes needed to encode a uint32_t value as a varint + * + * @param value The uint32_t value to calculate size for + * @return The number of bytes needed to encode the value + */ + static inline uint32_t varint(uint32_t value) { + // Optimized varint size calculation using leading zeros + // Each 7 bits requires one byte in the varint encoding + if (value < 128) + return 1; // 7 bits, common case for small values + + // For larger values, count bytes needed based on the position of the highest bit set + if (value < 16384) { + return 2; // 14 bits + } else if (value < 2097152) { + return 3; // 21 bits + } else if (value < 268435456) { + return 4; // 28 bits + } else { + return 5; // 32 bits (maximum for uint32_t) + } + } + + /** + * @brief Calculates the size in bytes needed to encode a uint64_t value as a varint + * + * @param value The uint64_t value to calculate size for + * @return The number of bytes needed to encode the value + */ + static inline uint32_t varint(uint64_t value) { + // Handle common case of values fitting in uint32_t (vast majority of use cases) + if (value <= UINT32_MAX) { + return varint(static_cast(value)); + } + + // For larger values, determine size based on highest bit position + if (value < (1ULL << 35)) { + return 5; // 35 bits + } else if (value < (1ULL << 42)) { + return 6; // 42 bits + } else if (value < (1ULL << 49)) { + return 7; // 49 bits + } else if (value < (1ULL << 56)) { + return 8; // 56 bits + } else if (value < (1ULL << 63)) { + return 9; // 63 bits + } else { + return 10; // 64 bits (maximum for uint64_t) + } + } + + /** + * @brief Calculates the size in bytes needed to encode an int32_t value as a varint + * + * Special handling is needed for negative values, which are sign-extended to 64 bits + * in Protocol Buffers, resulting in a 10-byte varint. + * + * @param value The int32_t value to calculate size for + * @return The number of bytes needed to encode the value + */ + static inline uint32_t varint(int32_t value) { + // Negative values are sign-extended to 64 bits in protocol buffers, + // which always results in a 10-byte varint for negative int32 + if (value < 0) { + return 10; // Negative int32 is always 10 bytes long + } + // For non-negative values, use the uint32_t implementation + return varint(static_cast(value)); + } + + /** + * @brief Calculates the size in bytes needed to encode an int64_t value as a varint + * + * @param value The int64_t value to calculate size for + * @return The number of bytes needed to encode the value + */ + static inline uint32_t varint(int64_t value) { + // For int64_t, we convert to uint64_t and calculate the size + // This works because the bit pattern determines the encoding size, + // and we've handled negative int32 values as a special case above + return varint(static_cast(value)); + } + + /** + * @brief Calculates the size in bytes needed to encode a field ID and wire type + * + * @param field_id The field identifier + * @param type The wire type value (from the WireType enum in the protobuf spec) + * @return The number of bytes needed to encode the field ID and wire type + */ + static inline uint32_t field(uint32_t field_id, uint32_t type) { + uint32_t tag = (field_id << 3) | (type & 0b111); + return varint(tag); + } + + /** + * @brief Common parameters for all add_*_field methods + * + * All add_*_field methods follow these common patterns: + * + * @param total_size Reference to the total message size to update + * @param field_id_size Pre-calculated size of the field ID in bytes + * @param value The value to calculate size for (type varies) + * @param force Whether to calculate size even if the value is default/zero/empty + * + * Each method follows this implementation pattern: + * 1. Skip calculation if value is default (0, false, empty) and not forced + * 2. Calculate the size based on the field's encoding rules + * 3. Add the field_id_size + calculated value size to total_size + */ + + /** + * @brief Calculates and adds the size of an int32 field to the total message size + */ + static inline void add_int32_field(uint32_t &total_size, uint32_t field_id_size, int32_t value, bool force = false) { + // Skip calculation if value is zero and not forced + if (value == 0 && !force) { + return; // No need to update total_size + } + + // Calculate and directly add to total_size + if (value < 0) { + // Negative values are encoded as 10-byte varints in protobuf + total_size += field_id_size + 10; + } else { + // For non-negative values, use the standard varint size + total_size += field_id_size + varint(static_cast(value)); + } + } + + /** + * @brief Calculates and adds the size of a uint32 field to the total message size + */ + static inline void add_uint32_field(uint32_t &total_size, uint32_t field_id_size, uint32_t value, + bool force = false) { + // Skip calculation if value is zero and not forced + if (value == 0 && !force) { + return; // No need to update total_size + } + + // Calculate and directly add to total_size + total_size += field_id_size + varint(value); + } + + /** + * @brief Calculates and adds the size of a boolean field to the total message size + */ + static inline void add_bool_field(uint32_t &total_size, uint32_t field_id_size, bool value, bool force = false) { + // Skip calculation if value is false and not forced + if (!value && !force) { + return; // No need to update total_size + } + + // Boolean fields always use 1 byte when true + total_size += field_id_size + 1; + } + + /** + * @brief Calculates and adds the size of a fixed field to the total message size + * + * Fixed fields always take exactly N bytes (4 for fixed32/float, 8 for fixed64/double). + * + * @tparam NumBytes The number of bytes for this fixed field (4 or 8) + * @param is_nonzero Whether the value is non-zero + */ + template + static inline void add_fixed_field(uint32_t &total_size, uint32_t field_id_size, bool is_nonzero, + bool force = false) { + // Skip calculation if value is zero and not forced + if (!is_nonzero && !force) { + return; // No need to update total_size + } + + // Fixed fields always take exactly NumBytes + total_size += field_id_size + NumBytes; + } + + /** + * @brief Calculates and adds the size of an enum field to the total message size + * + * Enum fields are encoded as uint32 varints. + */ + static inline void add_enum_field(uint32_t &total_size, uint32_t field_id_size, uint32_t value, bool force = false) { + // Skip calculation if value is zero and not forced + if (value == 0 && !force) { + return; // No need to update total_size + } + + // Enums are encoded as uint32 + total_size += field_id_size + varint(value); + } + + /** + * @brief Calculates and adds the size of a sint32 field to the total message size + * + * Sint32 fields use ZigZag encoding, which is more efficient for negative values. + */ + static inline void add_sint32_field(uint32_t &total_size, uint32_t field_id_size, int32_t value, bool force = false) { + // Skip calculation if value is zero and not forced + if (value == 0 && !force) { + return; // No need to update total_size + } + + // ZigZag encoding for sint32: (n << 1) ^ (n >> 31) + uint32_t zigzag = (static_cast(value) << 1) ^ (static_cast(value >> 31)); + total_size += field_id_size + varint(zigzag); + } + + /** + * @brief Calculates and adds the size of an int64 field to the total message size + */ + static inline void add_int64_field(uint32_t &total_size, uint32_t field_id_size, int64_t value, bool force = false) { + // Skip calculation if value is zero and not forced + if (value == 0 && !force) { + return; // No need to update total_size + } + + // Calculate and directly add to total_size + total_size += field_id_size + varint(value); + } + + /** + * @brief Calculates and adds the size of a uint64 field to the total message size + */ + static inline void add_uint64_field(uint32_t &total_size, uint32_t field_id_size, uint64_t value, + bool force = false) { + // Skip calculation if value is zero and not forced + if (value == 0 && !force) { + return; // No need to update total_size + } + + // Calculate and directly add to total_size + total_size += field_id_size + varint(value); + } + + /** + * @brief Calculates and adds the size of a sint64 field to the total message size + * + * Sint64 fields use ZigZag encoding, which is more efficient for negative values. + */ + static inline void add_sint64_field(uint32_t &total_size, uint32_t field_id_size, int64_t value, bool force = false) { + // Skip calculation if value is zero and not forced + if (value == 0 && !force) { + return; // No need to update total_size + } + + // ZigZag encoding for sint64: (n << 1) ^ (n >> 63) + uint64_t zigzag = (static_cast(value) << 1) ^ (static_cast(value >> 63)); + total_size += field_id_size + varint(zigzag); + } + + /** + * @brief Calculates and adds the size of a string/bytes field to the total message size + */ + static inline void add_string_field(uint32_t &total_size, uint32_t field_id_size, const std::string &str, + bool force = false) { + // Skip calculation if string is empty and not forced + if (str.empty() && !force) { + return; // No need to update total_size + } + + // Calculate and directly add to total_size + const uint32_t str_size = static_cast(str.size()); + total_size += field_id_size + varint(str_size) + str_size; + } + + /** + * @brief Calculates and adds the size of a nested message field to the total message size + * + * This helper function directly updates the total_size reference if the nested size + * is greater than zero or force is true. + * + * @param nested_size The pre-calculated size of the nested message + */ + static inline void add_message_field(uint32_t &total_size, uint32_t field_id_size, uint32_t nested_size, + bool force = false) { + // Skip calculation if nested message is empty and not forced + if (nested_size == 0 && !force) { + return; // No need to update total_size + } + + // Calculate and directly add to total_size + // Field ID + length varint + nested message content + total_size += field_id_size + varint(nested_size) + nested_size; + } + + /** + * @brief Calculates and adds the size of a nested message field to the total message size + * + * This templated version directly takes a message object, calculates its size internally, + * and updates the total_size reference. This eliminates the need for a temporary variable + * at the call site. + * + * @tparam MessageType The type of the nested message (inferred from parameter) + * @param message The nested message object + */ + template + static inline void add_message_object(uint32_t &total_size, uint32_t field_id_size, const MessageType &message, + bool force = false) { + uint32_t nested_size = 0; + message.calculate_size(nested_size); + + // Use the base implementation with the calculated nested_size + add_message_field(total_size, field_id_size, nested_size, force); + } + + /** + * @brief Calculates and adds the sizes of all messages in a repeated field to the total message size + * + * This helper processes a vector of message objects, calculating the size for each message + * and adding it to the total size. + * + * @tparam MessageType The type of the nested messages in the vector + * @param messages Vector of message objects + */ + template + static inline void add_repeated_message(uint32_t &total_size, uint32_t field_id_size, + const std::vector &messages) { + // Skip if the vector is empty + if (messages.empty()) { + return; + } + + // For repeated fields, always use force=true + for (const auto &message : messages) { + add_message_object(total_size, field_id_size, message, true); + } + } +}; + +} // namespace api +} // namespace esphome diff --git a/esphome/components/api/api_server.cpp b/esphome/components/api/api_server.cpp index 7b21a174a0..b987b44705 100644 --- a/esphome/components/api/api_server.cpp +++ b/esphome/components/api/api_server.cpp @@ -22,22 +22,40 @@ namespace api { static const char *const TAG = "api"; // APIServer +APIServer *global_api_server = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) + +APIServer::APIServer() { global_api_server = this; } + void APIServer::setup() { ESP_LOGCONFIG(TAG, "Setting up Home Assistant API server..."); this->setup_controller(); - socket_ = socket::socket_ip(SOCK_STREAM, 0); - if (socket_ == nullptr) { - ESP_LOGW(TAG, "Could not create socket."); + +#ifdef USE_API_NOISE + uint32_t hash = 88491486UL; + + this->noise_pref_ = global_preferences->make_preference(hash, true); + + SavedNoisePsk noise_pref_saved{}; + if (this->noise_pref_.load(&noise_pref_saved)) { + ESP_LOGD(TAG, "Loaded saved Noise PSK"); + + this->set_noise_psk(noise_pref_saved.psk); + } +#endif + + this->socket_ = socket::socket_ip(SOCK_STREAM, 0); + if (this->socket_ == nullptr) { + ESP_LOGW(TAG, "Could not create socket"); this->mark_failed(); return; } int enable = 1; - int err = socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); + int err = this->socket_->setsockopt(SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int)); if (err != 0) { ESP_LOGW(TAG, "Socket unable to set reuseaddr: errno %d", err); // we can still continue } - err = socket_->setblocking(false); + err = this->socket_->setblocking(false); if (err != 0) { ESP_LOGW(TAG, "Socket unable to set nonblocking mode: errno %d", err); this->mark_failed(); @@ -53,14 +71,14 @@ void APIServer::setup() { return; } - err = socket_->bind((struct sockaddr *) &server, sl); + err = this->socket_->bind((struct sockaddr *) &server, sl); if (err != 0) { ESP_LOGW(TAG, "Socket unable to bind: errno %d", errno); this->mark_failed(); return; } - err = socket_->listen(4); + err = this->socket_->listen(4); if (err != 0) { ESP_LOGW(TAG, "Socket unable to listen: errno %d", errno); this->mark_failed(); @@ -92,34 +110,45 @@ void APIServer::setup() { } #endif } + void APIServer::loop() { // Accept new clients while (true) { struct sockaddr_storage source_addr; socklen_t addr_len = sizeof(source_addr); - auto sock = socket_->accept((struct sockaddr *) &source_addr, &addr_len); + auto sock = this->socket_->accept((struct sockaddr *) &source_addr, &addr_len); if (!sock) break; ESP_LOGD(TAG, "Accepted %s", sock->getpeername().c_str()); auto *conn = new APIConnection(std::move(sock), this); - clients_.emplace_back(conn); + this->clients_.emplace_back(conn); conn->start(); } - // Partition clients into remove and active - auto new_end = std::partition(this->clients_.begin(), this->clients_.end(), - [](const std::unique_ptr &conn) { return !conn->remove_; }); - // print disconnection messages - for (auto it = new_end; it != this->clients_.end(); ++it) { - this->client_disconnected_trigger_->trigger((*it)->client_info_, (*it)->client_peername_); - ESP_LOGV(TAG, "Removing connection to %s", (*it)->client_info_.c_str()); - } - // resize vector - this->clients_.erase(new_end, this->clients_.end()); + // Process clients and remove disconnected ones in a single pass + if (!this->clients_.empty()) { + size_t client_index = 0; + while (client_index < this->clients_.size()) { + auto &client = this->clients_[client_index]; - for (auto &client : this->clients_) { - client->loop(); + if (client->remove_) { + // Handle disconnection + this->client_disconnected_trigger_->trigger(client->client_info_, client->client_peername_); + ESP_LOGV(TAG, "Removing connection to %s", client->client_info_.c_str()); + + // Swap with the last element and pop (avoids expensive vector shifts) + if (client_index < this->clients_.size() - 1) { + std::swap(this->clients_[client_index], this->clients_.back()); + } + this->clients_.pop_back(); + // Don't increment client_index since we need to process the swapped element + } else { + // Process active client + client->loop(); + client_index++; // Move to next client + } + } } if (this->reboot_timeout_ != 0) { @@ -136,16 +165,22 @@ void APIServer::loop() { } } } + void APIServer::dump_config() { ESP_LOGCONFIG(TAG, "API Server:"); ESP_LOGCONFIG(TAG, " Address: %s:%u", network::get_use_address().c_str(), this->port_); #ifdef USE_API_NOISE - ESP_LOGCONFIG(TAG, " Using noise encryption: YES"); + ESP_LOGCONFIG(TAG, " Using noise encryption: %s", YESNO(this->noise_ctx_->has_psk())); + if (!this->noise_ctx_->has_psk()) { + ESP_LOGCONFIG(TAG, " Supports noise encryption: YES"); + } #else ESP_LOGCONFIG(TAG, " Using noise encryption: NO"); #endif } + bool APIServer::uses_password() const { return !this->password_.empty(); } + bool APIServer::check_password(const std::string &password) const { // depend only on input password length const char *a = this->password_.c_str(); @@ -174,7 +209,9 @@ bool APIServer::check_password(const std::string &password) const { return result == 0; } + void APIServer::handle_disconnect(APIConnection *conn) {} + #ifdef USE_BINARY_SENSOR void APIServer::on_binary_sensor_update(binary_sensor::BinarySensor *obj, bool state) { if (obj->is_internal()) @@ -342,57 +379,6 @@ void APIServer::on_update(update::UpdateEntity *obj) { } #endif -float APIServer::get_setup_priority() const { return setup_priority::AFTER_WIFI; } -void APIServer::set_port(uint16_t port) { this->port_ = port; } -APIServer *global_api_server = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) - -void APIServer::set_password(const std::string &password) { this->password_ = password; } -void APIServer::send_homeassistant_service_call(const HomeassistantServiceResponse &call) { - for (auto &client : this->clients_) { - client->send_homeassistant_service_call(call); - } -} - -APIServer::APIServer() { global_api_server = this; } -void APIServer::subscribe_home_assistant_state(std::string entity_id, optional attribute, - std::function f) { - this->state_subs_.push_back(HomeAssistantStateSubscription{ - .entity_id = std::move(entity_id), - .attribute = std::move(attribute), - .callback = std::move(f), - .once = false, - }); -} -void APIServer::get_home_assistant_state(std::string entity_id, optional attribute, - std::function f) { - this->state_subs_.push_back(HomeAssistantStateSubscription{ - .entity_id = std::move(entity_id), - .attribute = std::move(attribute), - .callback = std::move(f), - .once = true, - }); -}; -const std::vector &APIServer::get_state_subs() const { - return this->state_subs_; -} -uint16_t APIServer::get_port() const { return this->port_; } -void APIServer::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; } -#ifdef USE_HOMEASSISTANT_TIME -void APIServer::request_time() { - for (auto &client : this->clients_) { - if (!client->remove_ && client->is_authenticated()) - client->send_time_request(); - } -} -#endif -bool APIServer::is_connected() const { return !this->clients_.empty(); } -void APIServer::on_shutdown() { - for (auto &c : this->clients_) { - c->send_disconnect_request(DisconnectRequest()); - } - delay(10); -} - #ifdef USE_ALARM_CONTROL_PANEL void APIServer::on_alarm_control_panel_update(alarm_control_panel::AlarmControlPanel *obj) { if (obj->is_internal()) @@ -402,6 +388,96 @@ void APIServer::on_alarm_control_panel_update(alarm_control_panel::AlarmControlP } #endif +float APIServer::get_setup_priority() const { return setup_priority::AFTER_WIFI; } + +void APIServer::set_port(uint16_t port) { this->port_ = port; } + +void APIServer::set_password(const std::string &password) { this->password_ = password; } + +void APIServer::send_homeassistant_service_call(const HomeassistantServiceResponse &call) { + for (auto &client : this->clients_) { + client->send_homeassistant_service_call(call); + } +} + +void APIServer::subscribe_home_assistant_state(std::string entity_id, optional attribute, + std::function f) { + this->state_subs_.push_back(HomeAssistantStateSubscription{ + .entity_id = std::move(entity_id), + .attribute = std::move(attribute), + .callback = std::move(f), + .once = false, + }); +} + +void APIServer::get_home_assistant_state(std::string entity_id, optional attribute, + std::function f) { + this->state_subs_.push_back(HomeAssistantStateSubscription{ + .entity_id = std::move(entity_id), + .attribute = std::move(attribute), + .callback = std::move(f), + .once = true, + }); +}; + +const std::vector &APIServer::get_state_subs() const { + return this->state_subs_; +} + +uint16_t APIServer::get_port() const { return this->port_; } + +void APIServer::set_reboot_timeout(uint32_t reboot_timeout) { this->reboot_timeout_ = reboot_timeout; } + +#ifdef USE_API_NOISE +bool APIServer::save_noise_psk(psk_t psk, bool make_active) { + auto &old_psk = this->noise_ctx_->get_psk(); + if (std::equal(old_psk.begin(), old_psk.end(), psk.begin())) { + ESP_LOGW(TAG, "New PSK matches old"); + return true; + } + + SavedNoisePsk new_saved_psk{psk}; + if (!this->noise_pref_.save(&new_saved_psk)) { + ESP_LOGW(TAG, "Failed to save Noise PSK"); + return false; + } + // ensure it's written immediately + if (!global_preferences->sync()) { + ESP_LOGW(TAG, "Failed to sync preferences"); + return false; + } + ESP_LOGD(TAG, "Noise PSK saved"); + if (make_active) { + this->set_timeout(100, [this, psk]() { + ESP_LOGW(TAG, "Disconnecting all clients to reset connections"); + this->set_noise_psk(psk); + for (auto &c : this->clients_) { + c->send_disconnect_request(DisconnectRequest()); + } + }); + } + return true; +} +#endif + +#ifdef USE_HOMEASSISTANT_TIME +void APIServer::request_time() { + for (auto &client : this->clients_) { + if (!client->remove_ && client->is_authenticated()) + client->send_time_request(); + } +} +#endif + +bool APIServer::is_connected() const { return !this->clients_.empty(); } + +void APIServer::on_shutdown() { + for (auto &c : this->clients_) { + c->send_disconnect_request(DisconnectRequest()); + } + delay(10); +} + } // namespace api } // namespace esphome #endif diff --git a/esphome/components/api/api_server.h b/esphome/components/api/api_server.h index 42e0b1048a..a6645b96ce 100644 --- a/esphome/components/api/api_server.h +++ b/esphome/components/api/api_server.h @@ -19,6 +19,12 @@ namespace esphome { namespace api { +#ifdef USE_API_NOISE +struct SavedNoisePsk { + psk_t psk; +} PACKED; // NOLINT +#endif + class APIServer : public Component, public Controller { public: APIServer(); @@ -35,6 +41,7 @@ class APIServer : public Component, public Controller { void set_reboot_timeout(uint32_t reboot_timeout); #ifdef USE_API_NOISE + bool save_noise_psk(psk_t psk, bool make_active = true); void set_noise_psk(psk_t psk) { noise_ctx_->set_psk(psk); } std::shared_ptr get_noise_ctx() { return noise_ctx_; } #endif // USE_API_NOISE @@ -142,6 +149,7 @@ class APIServer : public Component, public Controller { #ifdef USE_API_NOISE std::shared_ptr noise_ctx_ = std::make_shared(); + ESPPreferenceObject noise_pref_; #endif // USE_API_NOISE }; diff --git a/esphome/components/api/proto.h b/esphome/components/api/proto.h index ccc6c0d52c..b8ee6b7920 100644 --- a/esphome/components/api/proto.h +++ b/esphome/components/api/proto.h @@ -149,6 +149,18 @@ class ProtoWriteBuffer { void write(uint8_t value) { this->buffer_->push_back(value); } void encode_varint_raw(ProtoVarInt value) { value.encode(*this->buffer_); } void encode_varint_raw(uint32_t value) { this->encode_varint_raw(ProtoVarInt(value)); } + /** + * Encode a field key (tag/wire type combination). + * + * @param field_id Field number (tag) in the protobuf message + * @param type Wire type value: + * - 0: Varint (int32, int64, uint32, uint64, sint32, sint64, bool, enum) + * - 1: 64-bit (fixed64, sfixed64, double) + * - 2: Length-delimited (string, bytes, embedded messages, packed repeated fields) + * - 5: 32-bit (fixed32, sfixed32, float) + * + * Following https://protobuf.dev/programming-guides/encoding/#structure + */ void encode_field_raw(uint32_t field_id, uint32_t type) { uint32_t val = (field_id << 3) | (type & 0b111); this->encode_varint_raw(val); @@ -157,7 +169,7 @@ class ProtoWriteBuffer { if (len == 0 && !force) return; - this->encode_field_raw(field_id, 2); + this->encode_field_raw(field_id, 2); // type 2: Length-delimited string this->encode_varint_raw(len); auto *data = reinterpret_cast(string); this->buffer_->insert(this->buffer_->end(), data, data + len); @@ -171,26 +183,26 @@ class ProtoWriteBuffer { void encode_uint32(uint32_t field_id, uint32_t value, bool force = false) { if (value == 0 && !force) return; - this->encode_field_raw(field_id, 0); + this->encode_field_raw(field_id, 0); // type 0: Varint - uint32 this->encode_varint_raw(value); } void encode_uint64(uint32_t field_id, uint64_t value, bool force = false) { if (value == 0 && !force) return; - this->encode_field_raw(field_id, 0); + this->encode_field_raw(field_id, 0); // type 0: Varint - uint64 this->encode_varint_raw(ProtoVarInt(value)); } void encode_bool(uint32_t field_id, bool value, bool force = false) { if (!value && !force) return; - this->encode_field_raw(field_id, 0); + this->encode_field_raw(field_id, 0); // type 0: Varint - bool this->write(0x01); } void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) { if (value == 0 && !force) return; - this->encode_field_raw(field_id, 5); + this->encode_field_raw(field_id, 5); // type 5: 32-bit fixed32 this->write((value >> 0) & 0xFF); this->write((value >> 8) & 0xFF); this->write((value >> 16) & 0xFF); @@ -200,7 +212,7 @@ class ProtoWriteBuffer { if (value == 0 && !force) return; - this->encode_field_raw(field_id, 5); + this->encode_field_raw(field_id, 1); // type 1: 64-bit fixed64 this->write((value >> 0) & 0xFF); this->write((value >> 8) & 0xFF); this->write((value >> 16) & 0xFF); @@ -254,7 +266,7 @@ class ProtoWriteBuffer { this->encode_uint64(field_id, uvalue, force); } template void encode_message(uint32_t field_id, const C &value, bool force = false) { - this->encode_field_raw(field_id, 2); + this->encode_field_raw(field_id, 2); // type 2: Length-delimited message size_t begin = this->buffer_->size(); value.encode(*this); @@ -276,6 +288,7 @@ class ProtoMessage { virtual ~ProtoMessage() = default; virtual void encode(ProtoWriteBuffer buffer) const = 0; void decode(const uint8_t *buffer, size_t length); + virtual void calculate_size(uint32_t &total_size) const = 0; #ifdef HAS_PROTO_MESSAGE_DUMP std::string dump() const; virtual void dump_to(std::string &out) const = 0; @@ -298,13 +311,29 @@ class ProtoService { virtual void on_fatal_error() = 0; virtual void on_unauthenticated_access() = 0; virtual void on_no_setup_connection() = 0; - virtual ProtoWriteBuffer create_buffer() = 0; + /** + * Create a buffer with a reserved size. + * @param reserve_size The number of bytes to pre-allocate in the buffer. This is a hint + * to optimize memory usage and avoid reallocations during encoding. + * Implementations should aim to allocate at least this size. + * @return A ProtoWriteBuffer object with the reserved size. + */ + virtual ProtoWriteBuffer create_buffer(uint32_t reserve_size) = 0; virtual bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) = 0; virtual bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) = 0; + // Optimized method that pre-allocates buffer based on message size template bool send_message_(const C &msg, uint32_t message_type) { - auto buffer = this->create_buffer(); + uint32_t msg_size = 0; + msg.calculate_size(msg_size); + + // Create a pre-sized buffer + auto buffer = this->create_buffer(msg_size); + + // Encode message into the buffer msg.encode(buffer); + + // Send the buffer return this->send_buffer(buffer, message_type); } }; diff --git a/esphome/components/as7341/as7341.h b/esphome/components/as7341/as7341.h index e517e1d2bf..aed7996cef 100644 --- a/esphome/components/as7341/as7341.h +++ b/esphome/components/as7341/as7341.h @@ -7,7 +7,7 @@ namespace esphome { namespace as7341 { -static const uint8_t AS7341_CHIP_ID = 0X09; +static const uint8_t AS7341_CHIP_ID = 0x09; static const uint8_t AS7341_CONFIG = 0x70; static const uint8_t AS7341_LED = 0x74; diff --git a/esphome/components/atm90e32/__init__.py b/esphome/components/atm90e32/__init__.py index 8ce95be489..766807872b 100644 --- a/esphome/components/atm90e32/__init__.py +++ b/esphome/components/atm90e32/__init__.py @@ -3,5 +3,6 @@ import esphome.codegen as cg CODEOWNERS = ["@circuitsetup", "@descipher"] atm90e32_ns = cg.esphome_ns.namespace("atm90e32") +ATM90E32Component = atm90e32_ns.class_("ATM90E32Component", cg.Component) CONF_ATM90E32_ID = "atm90e32_id" diff --git a/esphome/components/atm90e32/atm90e32.cpp b/esphome/components/atm90e32/atm90e32.cpp index 43647b1855..f4f177587c 100644 --- a/esphome/components/atm90e32/atm90e32.cpp +++ b/esphome/components/atm90e32/atm90e32.cpp @@ -1,7 +1,7 @@ #include "atm90e32.h" -#include "atm90e32_reg.h" -#include "esphome/core/log.h" #include +#include +#include "esphome/core/log.h" namespace esphome { namespace atm90e32 { @@ -11,115 +11,84 @@ void ATM90E32Component::loop() { if (this->get_publish_interval_flag_()) { this->set_publish_interval_flag_(false); for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].voltage_sensor_ != nullptr) { + if (this->phase_[phase].voltage_sensor_ != nullptr) this->phase_[phase].voltage_ = this->get_phase_voltage_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].current_sensor_ != nullptr) { + + if (this->phase_[phase].current_sensor_ != nullptr) this->phase_[phase].current_ = this->get_phase_current_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].power_sensor_ != nullptr) { + + if (this->phase_[phase].power_sensor_ != nullptr) this->phase_[phase].active_power_ = this->get_phase_active_power_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].power_factor_sensor_ != nullptr) { + + if (this->phase_[phase].power_factor_sensor_ != nullptr) this->phase_[phase].power_factor_ = this->get_phase_power_factor_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].reactive_power_sensor_ != nullptr) { + + if (this->phase_[phase].reactive_power_sensor_ != nullptr) this->phase_[phase].reactive_power_ = this->get_phase_reactive_power_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].forward_active_energy_sensor_ != nullptr) { + + if (this->phase_[phase].apparent_power_sensor_ != nullptr) + this->phase_[phase].apparent_power_ = this->get_phase_apparent_power_(phase); + + if (this->phase_[phase].forward_active_energy_sensor_ != nullptr) this->phase_[phase].forward_active_energy_ = this->get_phase_forward_active_energy_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].reverse_active_energy_sensor_ != nullptr) { + + if (this->phase_[phase].reverse_active_energy_sensor_ != nullptr) this->phase_[phase].reverse_active_energy_ = this->get_phase_reverse_active_energy_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].phase_angle_sensor_ != nullptr) { + + if (this->phase_[phase].phase_angle_sensor_ != nullptr) this->phase_[phase].phase_angle_ = this->get_phase_angle_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].harmonic_active_power_sensor_ != nullptr) { + + if (this->phase_[phase].harmonic_active_power_sensor_ != nullptr) this->phase_[phase].harmonic_active_power_ = this->get_phase_harmonic_active_power_(phase); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].peak_current_sensor_ != nullptr) { + + if (this->phase_[phase].peak_current_sensor_ != nullptr) this->phase_[phase].peak_current_ = this->get_phase_peak_current_(phase); - } - } - // After the local store in collected we can publish them trusting they are withing +-1 haardware sampling - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].voltage_sensor_ != nullptr) { + + // After the local store is collected we can publish them trusting they are within +-1 hardware sampling + if (this->phase_[phase].voltage_sensor_ != nullptr) this->phase_[phase].voltage_sensor_->publish_state(this->get_local_phase_voltage_(phase)); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].current_sensor_ != nullptr) { + + if (this->phase_[phase].current_sensor_ != nullptr) this->phase_[phase].current_sensor_->publish_state(this->get_local_phase_current_(phase)); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].power_sensor_ != nullptr) { + + if (this->phase_[phase].power_sensor_ != nullptr) this->phase_[phase].power_sensor_->publish_state(this->get_local_phase_active_power_(phase)); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].power_factor_sensor_ != nullptr) { + + if (this->phase_[phase].power_factor_sensor_ != nullptr) this->phase_[phase].power_factor_sensor_->publish_state(this->get_local_phase_power_factor_(phase)); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].reactive_power_sensor_ != nullptr) { + + if (this->phase_[phase].reactive_power_sensor_ != nullptr) this->phase_[phase].reactive_power_sensor_->publish_state(this->get_local_phase_reactive_power_(phase)); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { + + if (this->phase_[phase].apparent_power_sensor_ != nullptr) + this->phase_[phase].apparent_power_sensor_->publish_state(this->get_local_phase_apparent_power_(phase)); + if (this->phase_[phase].forward_active_energy_sensor_ != nullptr) { this->phase_[phase].forward_active_energy_sensor_->publish_state( this->get_local_phase_forward_active_energy_(phase)); } - } - for (uint8_t phase = 0; phase < 3; phase++) { + if (this->phase_[phase].reverse_active_energy_sensor_ != nullptr) { this->phase_[phase].reverse_active_energy_sensor_->publish_state( this->get_local_phase_reverse_active_energy_(phase)); } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].phase_angle_sensor_ != nullptr) { + + if (this->phase_[phase].phase_angle_sensor_ != nullptr) this->phase_[phase].phase_angle_sensor_->publish_state(this->get_local_phase_angle_(phase)); - } - } - for (uint8_t phase = 0; phase < 3; phase++) { + if (this->phase_[phase].harmonic_active_power_sensor_ != nullptr) { this->phase_[phase].harmonic_active_power_sensor_->publish_state( this->get_local_phase_harmonic_active_power_(phase)); } - } - for (uint8_t phase = 0; phase < 3; phase++) { - if (this->phase_[phase].peak_current_sensor_ != nullptr) { + + if (this->phase_[phase].peak_current_sensor_ != nullptr) this->phase_[phase].peak_current_sensor_->publish_state(this->get_local_phase_peak_current_(phase)); - } } - if (this->freq_sensor_ != nullptr) { + if (this->freq_sensor_ != nullptr) this->freq_sensor_->publish_state(this->get_frequency_()); - } - if (this->chip_temperature_sensor_ != nullptr) { + + if (this->chip_temperature_sensor_ != nullptr) this->chip_temperature_sensor_->publish_state(this->get_chip_temperature_()); - } } } @@ -130,82 +99,30 @@ void ATM90E32Component::update() { } this->set_publish_interval_flag_(true); this->status_clear_warning(); -} -void ATM90E32Component::restore_calibrations_() { - if (enable_offset_calibration_) { - this->pref_.load(&this->offset_phase_); - } -}; - -void ATM90E32Component::run_offset_calibrations() { - // Run the calibrations and - // Setup voltage and current calibration offsets for PHASE A - this->offset_phase_[PHASEA].voltage_offset_ = calibrate_voltage_offset_phase(PHASEA); - this->phase_[PHASEA].voltage_offset_ = this->offset_phase_[PHASEA].voltage_offset_; - this->write16_(ATM90E32_REGISTER_UOFFSETA, this->phase_[PHASEA].voltage_offset_); // C Voltage offset - this->offset_phase_[PHASEA].current_offset_ = calibrate_current_offset_phase(PHASEA); - this->phase_[PHASEA].current_offset_ = this->offset_phase_[PHASEA].current_offset_; - this->write16_(ATM90E32_REGISTER_IOFFSETA, this->phase_[PHASEA].current_offset_); // C Current offset - // Setup voltage and current calibration offsets for PHASE B - this->offset_phase_[PHASEB].voltage_offset_ = calibrate_voltage_offset_phase(PHASEB); - this->phase_[PHASEB].voltage_offset_ = this->offset_phase_[PHASEB].voltage_offset_; - this->write16_(ATM90E32_REGISTER_UOFFSETB, this->phase_[PHASEB].voltage_offset_); // C Voltage offset - this->offset_phase_[PHASEB].current_offset_ = calibrate_current_offset_phase(PHASEB); - this->phase_[PHASEB].current_offset_ = this->offset_phase_[PHASEB].current_offset_; - this->write16_(ATM90E32_REGISTER_IOFFSETB, this->phase_[PHASEB].current_offset_); // C Current offset - // Setup voltage and current calibration offsets for PHASE C - this->offset_phase_[PHASEC].voltage_offset_ = calibrate_voltage_offset_phase(PHASEC); - this->phase_[PHASEC].voltage_offset_ = this->offset_phase_[PHASEC].voltage_offset_; - this->write16_(ATM90E32_REGISTER_UOFFSETC, this->phase_[PHASEC].voltage_offset_); // C Voltage offset - this->offset_phase_[PHASEC].current_offset_ = calibrate_current_offset_phase(PHASEC); - this->phase_[PHASEC].current_offset_ = this->offset_phase_[PHASEC].current_offset_; - this->write16_(ATM90E32_REGISTER_IOFFSETC, this->phase_[PHASEC].current_offset_); // C Current offset - this->pref_.save(&this->offset_phase_); - ESP_LOGI(TAG, "PhaseA Vo=%5d PhaseB Vo=%5d PhaseC Vo=%5d", this->offset_phase_[PHASEA].voltage_offset_, - this->offset_phase_[PHASEB].voltage_offset_, this->offset_phase_[PHASEC].voltage_offset_); - ESP_LOGI(TAG, "PhaseA Io=%5d PhaseB Io=%5d PhaseC Io=%5d", this->offset_phase_[PHASEA].current_offset_, - this->offset_phase_[PHASEB].current_offset_, this->offset_phase_[PHASEC].current_offset_); -} - -void ATM90E32Component::clear_offset_calibrations() { - // Clear the calibrations and - this->offset_phase_[PHASEA].voltage_offset_ = 0; - this->phase_[PHASEA].voltage_offset_ = this->offset_phase_[PHASEA].voltage_offset_; - this->write16_(ATM90E32_REGISTER_UOFFSETA, this->phase_[PHASEA].voltage_offset_); // C Voltage offset - this->offset_phase_[PHASEA].current_offset_ = 0; - this->phase_[PHASEA].current_offset_ = this->offset_phase_[PHASEA].current_offset_; - this->write16_(ATM90E32_REGISTER_IOFFSETA, this->phase_[PHASEA].current_offset_); // C Current offset - this->offset_phase_[PHASEB].voltage_offset_ = 0; - this->phase_[PHASEB].voltage_offset_ = this->offset_phase_[PHASEB].voltage_offset_; - this->write16_(ATM90E32_REGISTER_UOFFSETB, this->phase_[PHASEB].voltage_offset_); // C Voltage offset - this->offset_phase_[PHASEB].current_offset_ = 0; - this->phase_[PHASEB].current_offset_ = this->offset_phase_[PHASEB].current_offset_; - this->write16_(ATM90E32_REGISTER_IOFFSETB, this->phase_[PHASEB].current_offset_); // C Current offset - this->offset_phase_[PHASEC].voltage_offset_ = 0; - this->phase_[PHASEC].voltage_offset_ = this->offset_phase_[PHASEC].voltage_offset_; - this->write16_(ATM90E32_REGISTER_UOFFSETC, this->phase_[PHASEC].voltage_offset_); // C Voltage offset - this->offset_phase_[PHASEC].current_offset_ = 0; - this->phase_[PHASEC].current_offset_ = this->offset_phase_[PHASEC].current_offset_; - this->write16_(ATM90E32_REGISTER_IOFFSETC, this->phase_[PHASEC].current_offset_); // C Current offset - this->pref_.save(&this->offset_phase_); - ESP_LOGI(TAG, "PhaseA Vo=%5d PhaseB Vo=%5d PhaseC Vo=%5d", this->offset_phase_[PHASEA].voltage_offset_, - this->offset_phase_[PHASEB].voltage_offset_, this->offset_phase_[PHASEC].voltage_offset_); - ESP_LOGI(TAG, "PhaseA Io=%5d PhaseB Io=%5d PhaseC Io=%5d", this->offset_phase_[PHASEA].current_offset_, - this->offset_phase_[PHASEB].current_offset_, this->offset_phase_[PHASEC].current_offset_); +#ifdef USE_TEXT_SENSOR + this->check_phase_status(); + this->check_over_current(); + this->check_freq_status(); +#endif } void ATM90E32Component::setup() { ESP_LOGCONFIG(TAG, "Setting up ATM90E32 Component..."); this->spi_setup(); - if (this->enable_offset_calibration_) { - uint32_t hash = fnv1_hash(App.get_friendly_name()); - this->pref_ = global_preferences->make_preference(hash, true); - this->restore_calibrations_(); - } + uint16_t mmode0 = 0x87; // 3P4W 50Hz + uint16_t high_thresh = 0; + uint16_t low_thresh = 0; + if (line_freq_ == 60) { mmode0 |= 1 << 12; // sets 12th bit to 1, 60Hz + // for freq threshold registers + high_thresh = 6300; // 63.00 Hz + low_thresh = 5700; // 57.00 Hz + } else { + high_thresh = 5300; // 53.00 Hz + low_thresh = 4700; // 47.00 Hz } if (current_phases_ == 2) { @@ -216,34 +133,84 @@ void ATM90E32Component::setup() { this->write16_(ATM90E32_REGISTER_SOFTRESET, 0x789A); // Perform soft reset delay(6); // Wait for the minimum 5ms + 1ms this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x55AA); // enable register config access - if (this->read16_(ATM90E32_REGISTER_LASTSPIDATA) != 0x55AA) { + if (!this->validate_spi_read_(0x55AA, "setup()")) { ESP_LOGW(TAG, "Could not initialize ATM90E32 IC, check SPI settings"); this->mark_failed(); return; } this->write16_(ATM90E32_REGISTER_METEREN, 0x0001); // Enable Metering - this->write16_(ATM90E32_REGISTER_SAGPEAKDETCFG, 0xFF3F); // Peak Detector time ms (15:8), Sag Period ms (7:0) + this->write16_(ATM90E32_REGISTER_SAGPEAKDETCFG, 0xFF3F); // Peak Detector time (15:8) 255ms, Sag Period (7:0) 63ms this->write16_(ATM90E32_REGISTER_PLCONSTH, 0x0861); // PL Constant MSB (default) = 140625000 this->write16_(ATM90E32_REGISTER_PLCONSTL, 0xC468); // PL Constant LSB (default) - this->write16_(ATM90E32_REGISTER_ZXCONFIG, 0xD654); // ZX2, ZX1, ZX0 pin config + this->write16_(ATM90E32_REGISTER_ZXCONFIG, 0xD654); // Zero crossing (ZX2, ZX1, ZX0) pin config this->write16_(ATM90E32_REGISTER_MMODE0, mmode0); // Mode Config (frequency set in main program) this->write16_(ATM90E32_REGISTER_MMODE1, pga_gain_); // PGA Gain Configuration for Current Channels + this->write16_(ATM90E32_REGISTER_FREQHITH, high_thresh); // Frequency high threshold + this->write16_(ATM90E32_REGISTER_FREQLOTH, low_thresh); // Frequency low threshold this->write16_(ATM90E32_REGISTER_PSTARTTH, 0x1D4C); // All Active Startup Power Threshold - 0.02A/0.00032 = 7500 this->write16_(ATM90E32_REGISTER_QSTARTTH, 0x1D4C); // All Reactive Startup Power Threshold - 50% this->write16_(ATM90E32_REGISTER_SSTARTTH, 0x1D4C); // All Reactive Startup Power Threshold - 50% this->write16_(ATM90E32_REGISTER_PPHASETH, 0x02EE); // Each Phase Active Phase Threshold - 0.002A/0.00032 = 750 this->write16_(ATM90E32_REGISTER_QPHASETH, 0x02EE); // Each phase Reactive Phase Threshold - 10% - // Setup voltage and current gain for PHASE A - this->write16_(ATM90E32_REGISTER_UGAINA, this->phase_[PHASEA].voltage_gain_); // A Voltage rms gain - this->write16_(ATM90E32_REGISTER_IGAINA, this->phase_[PHASEA].ct_gain_); // A line current gain - // Setup voltage and current gain for PHASE B - this->write16_(ATM90E32_REGISTER_UGAINB, this->phase_[PHASEB].voltage_gain_); // B Voltage rms gain - this->write16_(ATM90E32_REGISTER_IGAINB, this->phase_[PHASEB].ct_gain_); // B line current gain - // Setup voltage and current gain for PHASE C - this->write16_(ATM90E32_REGISTER_UGAINC, this->phase_[PHASEC].voltage_gain_); // C Voltage rms gain - this->write16_(ATM90E32_REGISTER_IGAINC, this->phase_[PHASEC].ct_gain_); // C line current gain - this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x0000); // end configuration + + if (this->enable_offset_calibration_) { + // Initialize flash storage for offset calibrations + uint32_t o_hash = fnv1_hash(std::string("_offset_calibration_") + this->cs_->dump_summary()); + this->offset_pref_ = global_preferences->make_preference(o_hash, true); + this->restore_offset_calibrations_(); + + // Initialize flash storage for power offset calibrations + uint32_t po_hash = fnv1_hash(std::string("_power_offset_calibration_") + this->cs_->dump_summary()); + this->power_offset_pref_ = global_preferences->make_preference(po_hash, true); + this->restore_power_offset_calibrations_(); + } else { + ESP_LOGI(TAG, "[CALIBRATION] Power & Voltage/Current offset calibration is disabled. Using config file values."); + for (uint8_t phase = 0; phase < 3; ++phase) { + this->write16_(this->voltage_offset_registers[phase], + static_cast(this->offset_phase_[phase].voltage_offset_)); + this->write16_(this->current_offset_registers[phase], + static_cast(this->offset_phase_[phase].current_offset_)); + this->write16_(this->power_offset_registers[phase], + static_cast(this->power_offset_phase_[phase].active_power_offset)); + this->write16_(this->reactive_power_offset_registers[phase], + static_cast(this->power_offset_phase_[phase].reactive_power_offset)); + } + } + + if (this->enable_gain_calibration_) { + // Initialize flash storage for gain calibration + uint32_t g_hash = fnv1_hash(std::string("_gain_calibration_") + this->cs_->dump_summary()); + this->gain_calibration_pref_ = global_preferences->make_preference(g_hash, true); + this->restore_gain_calibrations_(); + + if (this->using_saved_calibrations_) { + ESP_LOGI(TAG, "[CALIBRATION] Successfully restored gain calibration from memory."); + } else { + for (uint8_t phase = 0; phase < 3; ++phase) { + this->write16_(voltage_gain_registers[phase], this->phase_[phase].voltage_gain_); + this->write16_(current_gain_registers[phase], this->phase_[phase].ct_gain_); + } + } + } else { + ESP_LOGI(TAG, "[CALIBRATION] Gain calibration is disabled. Using config file values."); + + for (uint8_t phase = 0; phase < 3; ++phase) { + this->write16_(voltage_gain_registers[phase], this->phase_[phase].voltage_gain_); + this->write16_(current_gain_registers[phase], this->phase_[phase].ct_gain_); + } + } + + // Sag threshold (78%) + uint16_t sagth = calculate_voltage_threshold(line_freq_, this->phase_[0].voltage_gain_, 0.78f); + // Overvoltage threshold (122%) + uint16_t ovth = calculate_voltage_threshold(line_freq_, this->phase_[0].voltage_gain_, 1.22f); + + // Write to registers + this->write16_(ATM90E32_REGISTER_SAGTH, sagth); + this->write16_(ATM90E32_REGISTER_OVTH, ovth); + + this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x0000); // end configuration } void ATM90E32Component::dump_config() { @@ -257,6 +224,7 @@ void ATM90E32Component::dump_config() { LOG_SENSOR(" ", "Current A", this->phase_[PHASEA].current_sensor_); LOG_SENSOR(" ", "Power A", this->phase_[PHASEA].power_sensor_); LOG_SENSOR(" ", "Reactive Power A", this->phase_[PHASEA].reactive_power_sensor_); + LOG_SENSOR(" ", "Apparent Power A", this->phase_[PHASEA].apparent_power_sensor_); LOG_SENSOR(" ", "PF A", this->phase_[PHASEA].power_factor_sensor_); LOG_SENSOR(" ", "Active Forward Energy A", this->phase_[PHASEA].forward_active_energy_sensor_); LOG_SENSOR(" ", "Active Reverse Energy A", this->phase_[PHASEA].reverse_active_energy_sensor_); @@ -267,22 +235,24 @@ void ATM90E32Component::dump_config() { LOG_SENSOR(" ", "Current B", this->phase_[PHASEB].current_sensor_); LOG_SENSOR(" ", "Power B", this->phase_[PHASEB].power_sensor_); LOG_SENSOR(" ", "Reactive Power B", this->phase_[PHASEB].reactive_power_sensor_); + LOG_SENSOR(" ", "Apparent Power B", this->phase_[PHASEB].apparent_power_sensor_); LOG_SENSOR(" ", "PF B", this->phase_[PHASEB].power_factor_sensor_); LOG_SENSOR(" ", "Active Forward Energy B", this->phase_[PHASEB].forward_active_energy_sensor_); LOG_SENSOR(" ", "Active Reverse Energy B", this->phase_[PHASEB].reverse_active_energy_sensor_); - LOG_SENSOR(" ", "Harmonic Power A", this->phase_[PHASEB].harmonic_active_power_sensor_); - LOG_SENSOR(" ", "Phase Angle A", this->phase_[PHASEB].phase_angle_sensor_); - LOG_SENSOR(" ", "Peak Current A", this->phase_[PHASEB].peak_current_sensor_); + LOG_SENSOR(" ", "Harmonic Power B", this->phase_[PHASEB].harmonic_active_power_sensor_); + LOG_SENSOR(" ", "Phase Angle B", this->phase_[PHASEB].phase_angle_sensor_); + LOG_SENSOR(" ", "Peak Current B", this->phase_[PHASEB].peak_current_sensor_); LOG_SENSOR(" ", "Voltage C", this->phase_[PHASEC].voltage_sensor_); LOG_SENSOR(" ", "Current C", this->phase_[PHASEC].current_sensor_); LOG_SENSOR(" ", "Power C", this->phase_[PHASEC].power_sensor_); LOG_SENSOR(" ", "Reactive Power C", this->phase_[PHASEC].reactive_power_sensor_); + LOG_SENSOR(" ", "Apparent Power C", this->phase_[PHASEC].apparent_power_sensor_); LOG_SENSOR(" ", "PF C", this->phase_[PHASEC].power_factor_sensor_); LOG_SENSOR(" ", "Active Forward Energy C", this->phase_[PHASEC].forward_active_energy_sensor_); LOG_SENSOR(" ", "Active Reverse Energy C", this->phase_[PHASEC].reverse_active_energy_sensor_); - LOG_SENSOR(" ", "Harmonic Power A", this->phase_[PHASEC].harmonic_active_power_sensor_); - LOG_SENSOR(" ", "Phase Angle A", this->phase_[PHASEC].phase_angle_sensor_); - LOG_SENSOR(" ", "Peak Current A", this->phase_[PHASEC].peak_current_sensor_); + LOG_SENSOR(" ", "Harmonic Power C", this->phase_[PHASEC].harmonic_active_power_sensor_); + LOG_SENSOR(" ", "Phase Angle C", this->phase_[PHASEC].phase_angle_sensor_); + LOG_SENSOR(" ", "Peak Current C", this->phase_[PHASEC].peak_current_sensor_); LOG_SENSOR(" ", "Frequency", this->freq_sensor_); LOG_SENSOR(" ", "Chip Temp", this->chip_temperature_sensor_); } @@ -298,7 +268,7 @@ uint16_t ATM90E32Component::read16_(uint16_t a_register) { uint8_t data[2]; uint16_t output; this->enable(); - delay_microseconds_safe(10); + delay_microseconds_safe(1); // min delay between CS low and first SCK is 200ns - 1ms is plenty this->write_byte(addrh); this->write_byte(addrl); this->read_array(data, 2); @@ -328,8 +298,7 @@ void ATM90E32Component::write16_(uint16_t a_register, uint16_t val) { this->write_byte16(a_register); this->write_byte16(val); this->disable(); - if (this->read16_(ATM90E32_REGISTER_LASTSPIDATA) != val) - ESP_LOGW(TAG, "SPI write error 0x%04X val 0x%04X", a_register, val); + this->validate_spi_read_(val, "write16()"); } float ATM90E32Component::get_local_phase_voltage_(uint8_t phase) { return this->phase_[phase].voltage_; } @@ -340,6 +309,8 @@ float ATM90E32Component::get_local_phase_active_power_(uint8_t phase) { return t float ATM90E32Component::get_local_phase_reactive_power_(uint8_t phase) { return this->phase_[phase].reactive_power_; } +float ATM90E32Component::get_local_phase_apparent_power_(uint8_t phase) { return this->phase_[phase].apparent_power_; } + float ATM90E32Component::get_local_phase_power_factor_(uint8_t phase) { return this->phase_[phase].power_factor_; } float ATM90E32Component::get_local_phase_forward_active_energy_(uint8_t phase) { @@ -360,8 +331,7 @@ float ATM90E32Component::get_local_phase_peak_current_(uint8_t phase) { return t float ATM90E32Component::get_phase_voltage_(uint8_t phase) { const uint16_t voltage = this->read16_(ATM90E32_REGISTER_URMS + phase); - if (this->read16_(ATM90E32_REGISTER_LASTSPIDATA) != voltage) - ESP_LOGW(TAG, "SPI URMS voltage register read error."); + this->validate_spi_read_(voltage, "get_phase_voltage()"); return (float) voltage / 100; } @@ -371,8 +341,7 @@ float ATM90E32Component::get_phase_voltage_avg_(uint8_t phase) { uint16_t voltage = 0; for (uint8_t i = 0; i < reads; i++) { voltage = this->read16_(ATM90E32_REGISTER_URMS + phase); - if (this->read16_(ATM90E32_REGISTER_LASTSPIDATA) != voltage) - ESP_LOGW(TAG, "SPI URMS voltage register read error."); + this->validate_spi_read_(voltage, "get_phase_voltage_avg_()"); accumulation += voltage; } voltage = accumulation / reads; @@ -386,8 +355,7 @@ float ATM90E32Component::get_phase_current_avg_(uint8_t phase) { uint16_t current = 0; for (uint8_t i = 0; i < reads; i++) { current = this->read16_(ATM90E32_REGISTER_IRMS + phase); - if (this->read16_(ATM90E32_REGISTER_LASTSPIDATA) != current) - ESP_LOGW(TAG, "SPI IRMS current register read error."); + this->validate_spi_read_(current, "get_phase_current_avg_()"); accumulation += current; } current = accumulation / reads; @@ -397,8 +365,7 @@ float ATM90E32Component::get_phase_current_avg_(uint8_t phase) { float ATM90E32Component::get_phase_current_(uint8_t phase) { const uint16_t current = this->read16_(ATM90E32_REGISTER_IRMS + phase); - if (this->read16_(ATM90E32_REGISTER_LASTSPIDATA) != current) - ESP_LOGW(TAG, "SPI IRMS current register read error."); + this->validate_spi_read_(current, "get_phase_current_()"); return (float) current / 1000; } @@ -412,11 +379,15 @@ float ATM90E32Component::get_phase_reactive_power_(uint8_t phase) { return val * 0.00032f; } +float ATM90E32Component::get_phase_apparent_power_(uint8_t phase) { + const int val = this->read32_(ATM90E32_REGISTER_SMEAN + phase, ATM90E32_REGISTER_SMEANLSB + phase); + return val * 0.00032f; +} + float ATM90E32Component::get_phase_power_factor_(uint8_t phase) { - const int16_t powerfactor = this->read16_(ATM90E32_REGISTER_PFMEAN + phase); - if (this->read16_(ATM90E32_REGISTER_LASTSPIDATA) != powerfactor) - ESP_LOGW(TAG, "SPI power factor read error."); - return (float) powerfactor / 1000; + uint16_t powerfactor = this->read16_(ATM90E32_REGISTER_PFMEAN + phase); // unsigned to compare to lastspidata + this->validate_spi_read_(powerfactor, "get_phase_power_factor_()"); + return (float) ((int16_t) powerfactor) / 1000; // make it signed again } float ATM90E32Component::get_phase_forward_active_energy_(uint8_t phase) { @@ -426,17 +397,19 @@ float ATM90E32Component::get_phase_forward_active_energy_(uint8_t phase) { } else { this->phase_[phase].cumulative_forward_active_energy_ = val; } - return ((float) this->phase_[phase].cumulative_forward_active_energy_ * 10 / 3200); + // 0.01CF resolution = 0.003125 Wh per count + return ((float) this->phase_[phase].cumulative_forward_active_energy_ * (10.0f / 3200.0f)); } float ATM90E32Component::get_phase_reverse_active_energy_(uint8_t phase) { - const uint16_t val = this->read16_(ATM90E32_REGISTER_ANENERGY); + const uint16_t val = this->read16_(ATM90E32_REGISTER_ANENERGY + phase); if (UINT32_MAX - this->phase_[phase].cumulative_reverse_active_energy_ > val) { this->phase_[phase].cumulative_reverse_active_energy_ += val; } else { this->phase_[phase].cumulative_reverse_active_energy_ = val; } - return ((float) this->phase_[phase].cumulative_reverse_active_energy_ * 10 / 3200); + // 0.01CF resolution = 0.003125 Wh per count + return ((float) this->phase_[phase].cumulative_reverse_active_energy_ * (10.0f / 3200.0f)); } float ATM90E32Component::get_phase_harmonic_active_power_(uint8_t phase) { @@ -446,15 +419,15 @@ float ATM90E32Component::get_phase_harmonic_active_power_(uint8_t phase) { float ATM90E32Component::get_phase_angle_(uint8_t phase) { uint16_t val = this->read16_(ATM90E32_REGISTER_PANGLE + phase) / 10.0; - return (float) (val > 180) ? val - 360.0 : val; + return (val > 180) ? (float) (val - 360.0f) : (float) val; } float ATM90E32Component::get_phase_peak_current_(uint8_t phase) { int16_t val = (float) this->read16_(ATM90E32_REGISTER_IPEAK + phase); if (!this->peak_current_signed_) - val = abs(val); + val = std::abs(val); // phase register * phase current gain value / 1000 * 2^13 - return (float) (val * this->phase_[phase].ct_gain_ / 8192000.0); + return (val * this->phase_[phase].ct_gain_ / 8192000.0); } float ATM90E32Component::get_frequency_() { @@ -467,29 +440,433 @@ float ATM90E32Component::get_chip_temperature_() { return (float) ctemp; } -uint16_t ATM90E32Component::calibrate_voltage_offset_phase(uint8_t phase) { - const uint8_t num_reads = 5; - uint64_t total_value = 0; - for (int i = 0; i < num_reads; ++i) { - const uint32_t measurement_value = read32_(ATM90E32_REGISTER_URMS + phase, ATM90E32_REGISTER_URMSLSB + phase); - total_value += measurement_value; +void ATM90E32Component::run_gain_calibrations() { + if (!this->enable_gain_calibration_) { + ESP_LOGW(TAG, "[CALIBRATION] Gain calibration is disabled! Enable it first with enable_gain_calibration: true"); + return; } - const uint32_t average_value = total_value / num_reads; - const uint32_t shifted_value = average_value >> 7; - const uint32_t voltage_offset = ~shifted_value + 1; - return voltage_offset & 0xFFFF; // Take the lower 16 bits + + float ref_voltages[3] = { + this->get_reference_voltage(0), + this->get_reference_voltage(1), + this->get_reference_voltage(2), + }; + float ref_currents[3] = {this->get_reference_current(0), this->get_reference_current(1), + this->get_reference_current(2)}; + + ESP_LOGI(TAG, "[CALIBRATION] "); + ESP_LOGI(TAG, "[CALIBRATION] ========================= Gain Calibration ========================="); + ESP_LOGI(TAG, "[CALIBRATION] ---------------------------------------------------------------------"); + ESP_LOGI(TAG, + "[CALIBRATION] | Phase | V_meas (V) | I_meas (A) | V_ref | I_ref | V_gain (old→new) | I_gain (old→new) |"); + ESP_LOGI(TAG, "[CALIBRATION] ---------------------------------------------------------------------"); + + for (uint8_t phase = 0; phase < 3; phase++) { + float measured_voltage = this->get_phase_voltage_avg_(phase); + float measured_current = this->get_phase_current_avg_(phase); + + float ref_voltage = ref_voltages[phase]; + float ref_current = ref_currents[phase]; + + uint16_t current_voltage_gain = this->read16_(voltage_gain_registers[phase]); + uint16_t current_current_gain = this->read16_(current_gain_registers[phase]); + + bool did_voltage = false; + bool did_current = false; + + // Voltage calibration + if (ref_voltage <= 0.0f) { + ESP_LOGW(TAG, "[CALIBRATION] Phase %s - Skipping voltage calibration: reference voltage is 0.", + phase_labels[phase]); + } else if (measured_voltage == 0.0f) { + ESP_LOGW(TAG, "[CALIBRATION] Phase %s - Skipping voltage calibration: measured voltage is 0.", + phase_labels[phase]); + } else { + uint32_t new_voltage_gain = static_cast((ref_voltage / measured_voltage) * current_voltage_gain); + if (new_voltage_gain == 0) { + ESP_LOGW(TAG, "[CALIBRATION] Phase %s - Voltage gain would be 0. Check reference and measured voltage.", + phase_labels[phase]); + } else { + if (new_voltage_gain >= 65535) { + ESP_LOGW( + TAG, + "[CALIBRATION] Phase %s - Voltage gain exceeds 65535. You may need a higher output voltage transformer.", + phase_labels[phase]); + new_voltage_gain = 65535; + } + this->gain_phase_[phase].voltage_gain = static_cast(new_voltage_gain); + did_voltage = true; + } + } + + // Current calibration + if (ref_current == 0.0f) { + ESP_LOGW(TAG, "[CALIBRATION] Phase %s - Skipping current calibration: reference current is 0.", + phase_labels[phase]); + } else if (measured_current == 0.0f) { + ESP_LOGW(TAG, "[CALIBRATION] Phase %s - Skipping current calibration: measured current is 0.", + phase_labels[phase]); + } else { + uint32_t new_current_gain = static_cast((ref_current / measured_current) * current_current_gain); + if (new_current_gain == 0) { + ESP_LOGW(TAG, "[CALIBRATION] Phase %s - Current gain would be 0. Check reference and measured current.", + phase_labels[phase]); + } else { + if (new_current_gain >= 65535) { + ESP_LOGW(TAG, "[CALIBRATION] Phase %s - Current gain exceeds 65535. You may need to turn up pga gain.", + phase_labels[phase]); + new_current_gain = 65535; + } + this->gain_phase_[phase].current_gain = static_cast(new_current_gain); + did_current = true; + } + } + + // Final row output + ESP_LOGI(TAG, "[CALIBRATION] | %c | %9.2f | %9.4f | %5.2f | %6.4f | %5u → %-5u | %5u → %-5u |", + 'A' + phase, measured_voltage, measured_current, ref_voltage, ref_current, current_voltage_gain, + did_voltage ? this->gain_phase_[phase].voltage_gain : current_voltage_gain, current_current_gain, + did_current ? this->gain_phase_[phase].current_gain : current_current_gain); + } + + ESP_LOGI(TAG, "[CALIBRATION] =====================================================================\n"); + + this->save_gain_calibration_to_memory_(); + this->write_gains_to_registers_(); + this->verify_gain_writes_(); } -uint16_t ATM90E32Component::calibrate_current_offset_phase(uint8_t phase) { +void ATM90E32Component::save_gain_calibration_to_memory_() { + bool success = this->gain_calibration_pref_.save(&this->gain_phase_); + if (success) { + this->using_saved_calibrations_ = true; + ESP_LOGI(TAG, "[CALIBRATION] Gain calibration saved to memory."); + } else { + this->using_saved_calibrations_ = false; + ESP_LOGE(TAG, "[CALIBRATION] Failed to save gain calibration to memory!"); + } +} + +void ATM90E32Component::run_offset_calibrations() { + if (!this->enable_offset_calibration_) { + ESP_LOGW(TAG, "[CALIBRATION] Offset calibration is disabled! Enable it first with enable_offset_calibration: true"); + return; + } + + for (uint8_t phase = 0; phase < 3; phase++) { + int16_t voltage_offset = calibrate_offset(phase, true); + int16_t current_offset = calibrate_offset(phase, false); + + this->write_offsets_to_registers_(phase, voltage_offset, current_offset); + + ESP_LOGI(TAG, "[CALIBRATION] Phase %c - offset_voltage: %d, offset_current: %d", 'A' + phase, voltage_offset, + current_offset); + } + + this->offset_pref_.save(&this->offset_phase_); // Save to flash +} + +void ATM90E32Component::run_power_offset_calibrations() { + if (!this->enable_offset_calibration_) { + ESP_LOGW( + TAG, + "[CALIBRATION] Offset power calibration is disabled! Enable it first with enable_offset_calibration: true"); + return; + } + + for (uint8_t phase = 0; phase < 3; ++phase) { + int16_t active_offset = calibrate_power_offset(phase, false); + int16_t reactive_offset = calibrate_power_offset(phase, true); + + this->write_power_offsets_to_registers_(phase, active_offset, reactive_offset); + + ESP_LOGI(TAG, "[CALIBRATION] Phase %c - offset_active_power: %d, offset_reactive_power: %d", 'A' + phase, + active_offset, reactive_offset); + } + + this->power_offset_pref_.save(&this->power_offset_phase_); // Save to flash +} + +void ATM90E32Component::write_gains_to_registers_() { + this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x55AA); + + for (int phase = 0; phase < 3; phase++) { + this->write16_(voltage_gain_registers[phase], this->gain_phase_[phase].voltage_gain); + this->write16_(current_gain_registers[phase], this->gain_phase_[phase].current_gain); + } + + this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x0000); +} + +void ATM90E32Component::write_offsets_to_registers_(uint8_t phase, int16_t voltage_offset, int16_t current_offset) { + // Save to runtime + this->offset_phase_[phase].voltage_offset_ = voltage_offset; + this->phase_[phase].voltage_offset_ = voltage_offset; + + // Save to flash-storable struct + this->offset_phase_[phase].current_offset_ = current_offset; + this->phase_[phase].current_offset_ = current_offset; + + // Write to registers + this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x55AA); + this->write16_(voltage_offset_registers[phase], static_cast(voltage_offset)); + this->write16_(current_offset_registers[phase], static_cast(current_offset)); + this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x0000); +} + +void ATM90E32Component::write_power_offsets_to_registers_(uint8_t phase, int16_t p_offset, int16_t q_offset) { + // Save to runtime + this->phase_[phase].active_power_offset_ = p_offset; + this->phase_[phase].reactive_power_offset_ = q_offset; + + // Save to flash-storable struct + this->power_offset_phase_[phase].active_power_offset = p_offset; + this->power_offset_phase_[phase].reactive_power_offset = q_offset; + + // Write to registers + this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x55AA); + this->write16_(this->power_offset_registers[phase], static_cast(p_offset)); + this->write16_(this->reactive_power_offset_registers[phase], static_cast(q_offset)); + this->write16_(ATM90E32_REGISTER_CFGREGACCEN, 0x0000); +} + +void ATM90E32Component::restore_gain_calibrations_() { + if (this->gain_calibration_pref_.load(&this->gain_phase_)) { + ESP_LOGI(TAG, "[CALIBRATION] Restoring saved gain calibrations to registers:"); + + for (uint8_t phase = 0; phase < 3; phase++) { + uint16_t v_gain = this->gain_phase_[phase].voltage_gain; + uint16_t i_gain = this->gain_phase_[phase].current_gain; + ESP_LOGI(TAG, "[CALIBRATION] Phase %c - Voltage Gain: %u, Current Gain: %u", 'A' + phase, v_gain, i_gain); + } + + this->write_gains_to_registers_(); + + if (this->verify_gain_writes_()) { + this->using_saved_calibrations_ = true; + ESP_LOGI(TAG, "[CALIBRATION] Gain calibration loaded and verified successfully."); + } else { + this->using_saved_calibrations_ = false; + ESP_LOGE(TAG, "[CALIBRATION] Gain verification failed! Calibration may not be applied correctly."); + } + } else { + this->using_saved_calibrations_ = false; + ESP_LOGW(TAG, "[CALIBRATION] No stored gain calibrations found. Using config file values."); + } +} + +void ATM90E32Component::restore_offset_calibrations_() { + if (this->offset_pref_.load(&this->offset_phase_)) { + ESP_LOGI(TAG, "[CALIBRATION] Successfully restored offset calibration from memory."); + + for (uint8_t phase = 0; phase < 3; phase++) { + auto &offset = this->offset_phase_[phase]; + write_offsets_to_registers_(phase, offset.voltage_offset_, offset.current_offset_); + ESP_LOGI(TAG, "[CALIBRATION] Phase %c - offset_voltage:: %d, offset_current: %d", 'A' + phase, + offset.voltage_offset_, offset.current_offset_); + } + } else { + ESP_LOGW(TAG, "[CALIBRATION] No stored offset calibrations found. Using default values."); + } +} + +void ATM90E32Component::restore_power_offset_calibrations_() { + if (this->power_offset_pref_.load(&this->power_offset_phase_)) { + ESP_LOGI(TAG, "[CALIBRATION] Successfully restored power offset calibration from memory."); + + for (uint8_t phase = 0; phase < 3; ++phase) { + auto &offset = this->power_offset_phase_[phase]; + write_power_offsets_to_registers_(phase, offset.active_power_offset, offset.reactive_power_offset); + ESP_LOGI(TAG, "[CALIBRATION] Phase %c - offset_active_power: %d, offset_reactive_power: %d", 'A' + phase, + offset.active_power_offset, offset.reactive_power_offset); + } + } else { + ESP_LOGW(TAG, "[CALIBRATION] No stored power offsets found. Using default values."); + } +} + +void ATM90E32Component::clear_gain_calibrations() { + ESP_LOGI(TAG, "[CALIBRATION] Clearing stored gain calibrations and restoring config-defined values..."); + + for (int phase = 0; phase < 3; phase++) { + gain_phase_[phase].voltage_gain = this->phase_[phase].voltage_gain_; + gain_phase_[phase].current_gain = this->phase_[phase].ct_gain_; + } + + bool success = this->gain_calibration_pref_.save(&this->gain_phase_); + this->using_saved_calibrations_ = false; + + if (success) { + ESP_LOGI(TAG, "[CALIBRATION] Gain calibrations cleared. Config values restored:"); + for (int phase = 0; phase < 3; phase++) { + ESP_LOGI(TAG, "[CALIBRATION] Phase %c - Voltage Gain: %u, Current Gain: %u", 'A' + phase, + gain_phase_[phase].voltage_gain, gain_phase_[phase].current_gain); + } + } else { + ESP_LOGE(TAG, "[CALIBRATION] Failed to clear gain calibrations!"); + } + + this->write_gains_to_registers_(); // Apply them to the chip immediately +} + +void ATM90E32Component::clear_offset_calibrations() { + for (uint8_t phase = 0; phase < 3; phase++) { + this->write_offsets_to_registers_(phase, 0, 0); + } + + this->offset_pref_.save(&this->offset_phase_); // Save cleared values to flash memory + + ESP_LOGI(TAG, "[CALIBRATION] Offsets cleared."); +} + +void ATM90E32Component::clear_power_offset_calibrations() { + for (uint8_t phase = 0; phase < 3; phase++) { + this->write_power_offsets_to_registers_(phase, 0, 0); + } + + this->power_offset_pref_.save(&this->power_offset_phase_); + + ESP_LOGI(TAG, "[CALIBRATION] Power offsets cleared."); +} + +int16_t ATM90E32Component::calibrate_offset(uint8_t phase, bool voltage) { const uint8_t num_reads = 5; uint64_t total_value = 0; - for (int i = 0; i < num_reads; ++i) { - const uint32_t measurement_value = read32_(ATM90E32_REGISTER_IRMS + phase, ATM90E32_REGISTER_IRMSLSB + phase); - total_value += measurement_value; + + for (uint8_t i = 0; i < num_reads; ++i) { + uint32_t reading = voltage ? this->read32_(ATM90E32_REGISTER_URMS + phase, ATM90E32_REGISTER_URMSLSB + phase) + : this->read32_(ATM90E32_REGISTER_IRMS + phase, ATM90E32_REGISTER_IRMSLSB + phase); + total_value += reading; } + const uint32_t average_value = total_value / num_reads; - const uint32_t current_offset = ~average_value + 1; - return current_offset & 0xFFFF; // Take the lower 16 bits + const uint32_t shifted = average_value >> 7; + const uint32_t offset = ~shifted + 1; + return static_cast(offset); // Takes lower 16 bits +} + +int16_t ATM90E32Component::calibrate_power_offset(uint8_t phase, bool reactive) { + const uint8_t num_reads = 5; + uint64_t total_value = 0; + + for (uint8_t i = 0; i < num_reads; ++i) { + uint32_t reading = reactive ? this->read32_(ATM90E32_REGISTER_QMEAN + phase, ATM90E32_REGISTER_QMEANLSB + phase) + : this->read32_(ATM90E32_REGISTER_PMEAN + phase, ATM90E32_REGISTER_PMEANLSB + phase); + total_value += reading; + } + + const uint32_t average_value = total_value / num_reads; + const uint32_t power_offset = ~average_value + 1; + return static_cast(power_offset); // Takes the lower 16 bits +} + +bool ATM90E32Component::verify_gain_writes_() { + bool success = true; + for (uint8_t phase = 0; phase < 3; phase++) { + uint16_t read_voltage = this->read16_(voltage_gain_registers[phase]); + uint16_t read_current = this->read16_(current_gain_registers[phase]); + + if (read_voltage != this->gain_phase_[phase].voltage_gain || + read_current != this->gain_phase_[phase].current_gain) { + ESP_LOGE(TAG, "[CALIBRATION] Mismatch detected for Phase %s!", phase_labels[phase]); + success = false; + } + } + return success; // Return true if all writes were successful, false otherwise +} + +#ifdef USE_TEXT_SENSOR +void ATM90E32Component::check_phase_status() { + uint16_t state0 = this->read16_(ATM90E32_REGISTER_EMMSTATE0); + uint16_t state1 = this->read16_(ATM90E32_REGISTER_EMMSTATE1); + + for (int phase = 0; phase < 3; phase++) { + std::string status; + + if (state0 & over_voltage_flags[phase]) + status += "Over Voltage; "; + if (state1 & voltage_sag_flags[phase]) + status += "Voltage Sag; "; + if (state1 & phase_loss_flags[phase]) + status += "Phase Loss; "; + + auto *sensor = this->phase_status_text_sensor_[phase]; + const char *phase_name = sensor ? sensor->get_name().c_str() : "Unknown Phase"; + if (!status.empty()) { + status.pop_back(); // remove space + status.pop_back(); // remove semicolon + ESP_LOGW(TAG, "%s: %s", phase_name, status.c_str()); + if (sensor != nullptr) + sensor->publish_state(status); + } else { + if (sensor != nullptr) + sensor->publish_state("Okay"); + } + } +} + +void ATM90E32Component::check_freq_status() { + uint16_t state1 = this->read16_(ATM90E32_REGISTER_EMMSTATE1); + + std::string freq_status; + + if (state1 & ATM90E32_STATUS_S1_FREQHIST) { + freq_status = "HIGH"; + } else if (state1 & ATM90E32_STATUS_S1_FREQLOST) { + freq_status = "LOW"; + } else { + freq_status = "Normal"; + } + ESP_LOGW(TAG, "Frequency status: %s", freq_status.c_str()); + + if (this->freq_status_text_sensor_ != nullptr) { + this->freq_status_text_sensor_->publish_state(freq_status); + } +} + +void ATM90E32Component::check_over_current() { + constexpr float max_current_threshold = 65.53f; + + for (uint8_t phase = 0; phase < 3; phase++) { + float current_val = + this->phase_[phase].current_sensor_ != nullptr ? this->phase_[phase].current_sensor_->state : 0.0f; + + if (current_val > max_current_threshold) { + ESP_LOGW(TAG, "Over current detected on Phase %c: %.2f A", 'A' + phase, current_val); + ESP_LOGW(TAG, "You may need to half your gain_ct: value & multiply the current and power values by 2"); + if (this->phase_status_text_sensor_[phase] != nullptr) { + this->phase_status_text_sensor_[phase]->publish_state("Over Current; "); + } + } + } +} +#endif + +uint16_t ATM90E32Component::calculate_voltage_threshold(int line_freq, uint16_t ugain, float multiplier) { + // this assumes that 60Hz electrical systems use 120V mains, + // which is usually, but not always the case + float nominal_voltage = (line_freq == 60) ? 120.0f : 220.0f; + float target_voltage = nominal_voltage * multiplier; + + float peak_01v = target_voltage * 100.0f * std::sqrt(2.0f); // convert RMS → peak, scale to 0.01V + float divider = (2.0f * ugain) / 32768.0f; + + float threshold = peak_01v / divider; + + return static_cast(threshold); +} + +bool ATM90E32Component::validate_spi_read_(uint16_t expected, const char *context) { + uint16_t last = this->read16_(ATM90E32_REGISTER_LASTSPIDATA); + if (last != expected) { + if (context != nullptr) { + ESP_LOGW(TAG, "[%s] SPI read mismatch: expected 0x%04X, got 0x%04X", context, expected, last); + } else { + ESP_LOGW(TAG, "SPI read mismatch: expected 0x%04X, got 0x%04X", expected, last); + } + return false; + } + return true; } } // namespace atm90e32 diff --git a/esphome/components/atm90e32/atm90e32.h b/esphome/components/atm90e32/atm90e32.h index 35c61d1e05..0703c40ae0 100644 --- a/esphome/components/atm90e32/atm90e32.h +++ b/esphome/components/atm90e32/atm90e32.h @@ -1,5 +1,6 @@ #pragma once +#include #include "atm90e32_reg.h" #include "esphome/components/sensor/sensor.h" #include "esphome/components/spi/spi.h" @@ -18,6 +19,26 @@ class ATM90E32Component : public PollingComponent, static const uint8_t PHASEA = 0; static const uint8_t PHASEB = 1; static const uint8_t PHASEC = 2; + const char *phase_labels[3] = {"A", "B", "C"}; + // these registers are not sucessive, so we can't just do 'base + phase' + const uint16_t voltage_gain_registers[3] = {ATM90E32_REGISTER_UGAINA, ATM90E32_REGISTER_UGAINB, + ATM90E32_REGISTER_UGAINC}; + const uint16_t current_gain_registers[3] = {ATM90E32_REGISTER_IGAINA, ATM90E32_REGISTER_IGAINB, + ATM90E32_REGISTER_IGAINC}; + const uint16_t voltage_offset_registers[3] = {ATM90E32_REGISTER_UOFFSETA, ATM90E32_REGISTER_UOFFSETB, + ATM90E32_REGISTER_UOFFSETC}; + const uint16_t current_offset_registers[3] = {ATM90E32_REGISTER_IOFFSETA, ATM90E32_REGISTER_IOFFSETB, + ATM90E32_REGISTER_IOFFSETC}; + const uint16_t power_offset_registers[3] = {ATM90E32_REGISTER_POFFSETA, ATM90E32_REGISTER_POFFSETB, + ATM90E32_REGISTER_POFFSETC}; + const uint16_t reactive_power_offset_registers[3] = {ATM90E32_REGISTER_QOFFSETA, ATM90E32_REGISTER_QOFFSETB, + ATM90E32_REGISTER_QOFFSETC}; + const uint16_t over_voltage_flags[3] = {ATM90E32_STATUS_S0_OVPHASEAST, ATM90E32_STATUS_S0_OVPHASEBST, + ATM90E32_STATUS_S0_OVPHASECST}; + const uint16_t voltage_sag_flags[3] = {ATM90E32_STATUS_S1_SAGPHASEAST, ATM90E32_STATUS_S1_SAGPHASEBST, + ATM90E32_STATUS_S1_SAGPHASECST}; + const uint16_t phase_loss_flags[3] = {ATM90E32_STATUS_S1_PHASELOSSAST, ATM90E32_STATUS_S1_PHASELOSSBST, + ATM90E32_STATUS_S1_PHASELOSSCST}; void loop() override; void setup() override; void dump_config() override; @@ -42,6 +63,14 @@ class ATM90E32Component : public PollingComponent, void set_peak_current_sensor(int phase, sensor::Sensor *obj) { this->phase_[phase].peak_current_sensor_ = obj; } void set_volt_gain(int phase, uint16_t gain) { this->phase_[phase].voltage_gain_ = gain; } void set_ct_gain(int phase, uint16_t gain) { this->phase_[phase].ct_gain_ = gain; } + void set_voltage_offset(uint8_t phase, int16_t offset) { this->offset_phase_[phase].voltage_offset_ = offset; } + void set_current_offset(uint8_t phase, int16_t offset) { this->offset_phase_[phase].current_offset_ = offset; } + void set_active_power_offset(uint8_t phase, int16_t offset) { + this->power_offset_phase_[phase].active_power_offset = offset; + } + void set_reactive_power_offset(uint8_t phase, int16_t offset) { + this->power_offset_phase_[phase].reactive_power_offset = offset; + } void set_freq_sensor(sensor::Sensor *freq_sensor) { freq_sensor_ = freq_sensor; } void set_peak_current_signed(bool flag) { peak_current_signed_ = flag; } void set_chip_temperature_sensor(sensor::Sensor *chip_temperature_sensor) { @@ -51,53 +80,104 @@ class ATM90E32Component : public PollingComponent, void set_current_phases(int phases) { current_phases_ = phases; } void set_pga_gain(uint16_t gain) { pga_gain_ = gain; } void run_offset_calibrations(); + void run_power_offset_calibrations(); void clear_offset_calibrations(); + void clear_power_offset_calibrations(); + void clear_gain_calibrations(); void set_enable_offset_calibration(bool flag) { enable_offset_calibration_ = flag; } - uint16_t calibrate_voltage_offset_phase(uint8_t /*phase*/); - uint16_t calibrate_current_offset_phase(uint8_t /*phase*/); + void set_enable_gain_calibration(bool flag) { enable_gain_calibration_ = flag; } + int16_t calibrate_offset(uint8_t phase, bool voltage); + int16_t calibrate_power_offset(uint8_t phase, bool reactive); + void run_gain_calibrations(); +#ifdef USE_NUMBER + void set_reference_voltage(uint8_t phase, number::Number *ref_voltage) { ref_voltages_[phase] = ref_voltage; } + void set_reference_current(uint8_t phase, number::Number *ref_current) { ref_currents_[phase] = ref_current; } +#endif + float get_reference_voltage(uint8_t phase) { +#ifdef USE_NUMBER + return (phase >= 0 && phase < 3 && ref_voltages_[phase]) ? ref_voltages_[phase]->state : 120.0; // Default voltage +#else + return 120.0; // Default voltage +#endif + } + float get_reference_current(uint8_t phase) { +#ifdef USE_NUMBER + return (phase >= 0 && phase < 3 && ref_currents_[phase]) ? ref_currents_[phase]->state : 5.0f; // Default current +#else + return 5.0f; // Default current +#endif + } + bool using_saved_calibrations_ = false; // Track if stored calibrations are being used +#ifdef USE_TEXT_SENSOR + void check_phase_status(); + void check_freq_status(); + void check_over_current(); + void set_phase_status_text_sensor(uint8_t phase, text_sensor::TextSensor *sensor) { + this->phase_status_text_sensor_[phase] = sensor; + } + void set_freq_status_text_sensor(text_sensor::TextSensor *sensor) { this->freq_status_text_sensor_ = sensor; } +#endif + uint16_t calculate_voltage_threshold(int line_freq, uint16_t ugain, float multiplier); int32_t last_periodic_millis = millis(); protected: +#ifdef USE_NUMBER + number::Number *ref_voltages_[3]{nullptr, nullptr, nullptr}; + number::Number *ref_currents_[3]{nullptr, nullptr, nullptr}; +#endif uint16_t read16_(uint16_t a_register); int read32_(uint16_t addr_h, uint16_t addr_l); void write16_(uint16_t a_register, uint16_t val); - float get_local_phase_voltage_(uint8_t /*phase*/); - float get_local_phase_current_(uint8_t /*phase*/); - float get_local_phase_active_power_(uint8_t /*phase*/); - float get_local_phase_reactive_power_(uint8_t /*phase*/); - float get_local_phase_power_factor_(uint8_t /*phase*/); - float get_local_phase_forward_active_energy_(uint8_t /*phase*/); - float get_local_phase_reverse_active_energy_(uint8_t /*phase*/); - float get_local_phase_angle_(uint8_t /*phase*/); - float get_local_phase_harmonic_active_power_(uint8_t /*phase*/); - float get_local_phase_peak_current_(uint8_t /*phase*/); - float get_phase_voltage_(uint8_t /*phase*/); - float get_phase_voltage_avg_(uint8_t /*phase*/); - float get_phase_current_(uint8_t /*phase*/); - float get_phase_current_avg_(uint8_t /*phase*/); - float get_phase_active_power_(uint8_t /*phase*/); - float get_phase_reactive_power_(uint8_t /*phase*/); - float get_phase_power_factor_(uint8_t /*phase*/); - float get_phase_forward_active_energy_(uint8_t /*phase*/); - float get_phase_reverse_active_energy_(uint8_t /*phase*/); - float get_phase_angle_(uint8_t /*phase*/); - float get_phase_harmonic_active_power_(uint8_t /*phase*/); - float get_phase_peak_current_(uint8_t /*phase*/); + float get_local_phase_voltage_(uint8_t phase); + float get_local_phase_current_(uint8_t phase); + float get_local_phase_active_power_(uint8_t phase); + float get_local_phase_reactive_power_(uint8_t phase); + float get_local_phase_apparent_power_(uint8_t phase); + float get_local_phase_power_factor_(uint8_t phase); + float get_local_phase_forward_active_energy_(uint8_t phase); + float get_local_phase_reverse_active_energy_(uint8_t phase); + float get_local_phase_angle_(uint8_t phase); + float get_local_phase_harmonic_active_power_(uint8_t phase); + float get_local_phase_peak_current_(uint8_t phase); + float get_phase_voltage_(uint8_t phase); + float get_phase_voltage_avg_(uint8_t phase); + float get_phase_current_(uint8_t phase); + float get_phase_current_avg_(uint8_t phase); + float get_phase_active_power_(uint8_t phase); + float get_phase_reactive_power_(uint8_t phase); + float get_phase_apparent_power_(uint8_t phase); + float get_phase_power_factor_(uint8_t phase); + float get_phase_forward_active_energy_(uint8_t phase); + float get_phase_reverse_active_energy_(uint8_t phase); + float get_phase_angle_(uint8_t phase); + float get_phase_harmonic_active_power_(uint8_t phase); + float get_phase_peak_current_(uint8_t phase); float get_frequency_(); float get_chip_temperature_(); bool get_publish_interval_flag_() { return publish_interval_flag_; }; void set_publish_interval_flag_(bool flag) { publish_interval_flag_ = flag; }; - void restore_calibrations_(); + void restore_offset_calibrations_(); + void restore_power_offset_calibrations_(); + void restore_gain_calibrations_(); + void save_gain_calibration_to_memory_(); + void write_offsets_to_registers_(uint8_t phase, int16_t voltage_offset, int16_t current_offset); + void write_power_offsets_to_registers_(uint8_t phase, int16_t p_offset, int16_t q_offset); + void write_gains_to_registers_(); + bool verify_gain_writes_(); + bool validate_spi_read_(uint16_t expected, const char *context = nullptr); struct ATM90E32Phase { uint16_t voltage_gain_{0}; uint16_t ct_gain_{0}; - uint16_t voltage_offset_{0}; - uint16_t current_offset_{0}; + int16_t voltage_offset_{0}; + int16_t current_offset_{0}; + int16_t active_power_offset_{0}; + int16_t reactive_power_offset_{0}; float voltage_{0}; float current_{0}; float active_power_{0}; float reactive_power_{0}; + float apparent_power_{0}; float power_factor_{0}; float forward_active_energy_{0}; float reverse_active_energy_{0}; @@ -119,14 +199,30 @@ class ATM90E32Component : public PollingComponent, uint32_t cumulative_reverse_active_energy_{0}; } phase_[3]; - struct Calibration { - uint16_t voltage_offset_{0}; - uint16_t current_offset_{0}; + struct OffsetCalibration { + int16_t voltage_offset_{0}; + int16_t current_offset_{0}; } offset_phase_[3]; - ESPPreferenceObject pref_; + struct PowerOffsetCalibration { + int16_t active_power_offset{0}; + int16_t reactive_power_offset{0}; + } power_offset_phase_[3]; + + struct GainCalibration { + uint16_t voltage_gain{1}; + uint16_t current_gain{1}; + } gain_phase_[3]; + + ESPPreferenceObject offset_pref_; + ESPPreferenceObject power_offset_pref_; + ESPPreferenceObject gain_calibration_pref_; sensor::Sensor *freq_sensor_{nullptr}; +#ifdef USE_TEXT_SENSOR + text_sensor::TextSensor *phase_status_text_sensor_[3]{nullptr}; + text_sensor::TextSensor *freq_status_text_sensor_{nullptr}; +#endif sensor::Sensor *chip_temperature_sensor_{nullptr}; uint16_t pga_gain_{0x15}; int line_freq_{60}; @@ -134,6 +230,7 @@ class ATM90E32Component : public PollingComponent, bool publish_interval_flag_{false}; bool peak_current_signed_{false}; bool enable_offset_calibration_{false}; + bool enable_gain_calibration_{false}; }; } // namespace atm90e32 diff --git a/esphome/components/atm90e32/atm90e32_reg.h b/esphome/components/atm90e32/atm90e32_reg.h index 954fb42e79..86c2d64569 100644 --- a/esphome/components/atm90e32/atm90e32_reg.h +++ b/esphome/components/atm90e32/atm90e32_reg.h @@ -176,16 +176,17 @@ static const uint16_t ATM90E32_REGISTER_ANENERGYCH = 0xAF; // C Reverse Harm. E /* POWER & P.F. REGISTERS */ static const uint16_t ATM90E32_REGISTER_PMEANT = 0xB0; // Total Mean Power (P) -static const uint16_t ATM90E32_REGISTER_PMEAN = 0xB1; // Mean Power Reg Base (P) +static const uint16_t ATM90E32_REGISTER_PMEAN = 0xB1; // Active Power Reg Base (P) static const uint16_t ATM90E32_REGISTER_PMEANA = 0xB1; // A Mean Power (P) static const uint16_t ATM90E32_REGISTER_PMEANB = 0xB2; // B Mean Power (P) static const uint16_t ATM90E32_REGISTER_PMEANC = 0xB3; // C Mean Power (P) static const uint16_t ATM90E32_REGISTER_QMEANT = 0xB4; // Total Mean Power (Q) -static const uint16_t ATM90E32_REGISTER_QMEAN = 0xB5; // Mean Power Reg Base (Q) +static const uint16_t ATM90E32_REGISTER_QMEAN = 0xB5; // Reactive Power Reg Base (Q) static const uint16_t ATM90E32_REGISTER_QMEANA = 0xB5; // A Mean Power (Q) static const uint16_t ATM90E32_REGISTER_QMEANB = 0xB6; // B Mean Power (Q) static const uint16_t ATM90E32_REGISTER_QMEANC = 0xB7; // C Mean Power (Q) static const uint16_t ATM90E32_REGISTER_SMEANT = 0xB8; // Total Mean Power (S) +static const uint16_t ATM90E32_REGISTER_SMEAN = 0xB9; // Apparent Mean Power Base (S) static const uint16_t ATM90E32_REGISTER_SMEANA = 0xB9; // A Mean Power (S) static const uint16_t ATM90E32_REGISTER_SMEANB = 0xBA; // B Mean Power (S) static const uint16_t ATM90E32_REGISTER_SMEANC = 0xBB; // C Mean Power (S) @@ -206,6 +207,7 @@ static const uint16_t ATM90E32_REGISTER_QMEANALSB = 0xC5; // Lower Word (A Rea static const uint16_t ATM90E32_REGISTER_QMEANBLSB = 0xC6; // Lower Word (B React. Power) static const uint16_t ATM90E32_REGISTER_QMEANCLSB = 0xC7; // Lower Word (C React. Power) static const uint16_t ATM90E32_REGISTER_SAMEANTLSB = 0xC8; // Lower Word (Tot. App. Power) +static const uint16_t ATM90E32_REGISTER_SMEANLSB = 0xC9; // Lower Word Reg Base (Apparent Power) static const uint16_t ATM90E32_REGISTER_SMEANALSB = 0xC9; // Lower Word (A App. Power) static const uint16_t ATM90E32_REGISTER_SMEANBLSB = 0xCA; // Lower Word (B App. Power) static const uint16_t ATM90E32_REGISTER_SMEANCLSB = 0xCB; // Lower Word (C App. Power) diff --git a/esphome/components/atm90e32/button/__init__.py b/esphome/components/atm90e32/button/__init__.py index 931346b386..19f62ccfbd 100644 --- a/esphome/components/atm90e32/button/__init__.py +++ b/esphome/components/atm90e32/button/__init__.py @@ -1,43 +1,95 @@ import esphome.codegen as cg from esphome.components import button import esphome.config_validation as cv -from esphome.const import CONF_ID, ENTITY_CATEGORY_CONFIG, ICON_CHIP, ICON_SCALE +from esphome.const import CONF_ID, ENTITY_CATEGORY_CONFIG, ICON_SCALE from .. import atm90e32_ns from ..sensor import ATM90E32Component +CONF_RUN_GAIN_CALIBRATION = "run_gain_calibration" +CONF_CLEAR_GAIN_CALIBRATION = "clear_gain_calibration" CONF_RUN_OFFSET_CALIBRATION = "run_offset_calibration" CONF_CLEAR_OFFSET_CALIBRATION = "clear_offset_calibration" +CONF_RUN_POWER_OFFSET_CALIBRATION = "run_power_offset_calibration" +CONF_CLEAR_POWER_OFFSET_CALIBRATION = "clear_power_offset_calibration" -ATM90E32CalibrationButton = atm90e32_ns.class_( - "ATM90E32CalibrationButton", - button.Button, +ATM90E32GainCalibrationButton = atm90e32_ns.class_( + "ATM90E32GainCalibrationButton", button.Button ) -ATM90E32ClearCalibrationButton = atm90e32_ns.class_( - "ATM90E32ClearCalibrationButton", - button.Button, +ATM90E32ClearGainCalibrationButton = atm90e32_ns.class_( + "ATM90E32ClearGainCalibrationButton", button.Button +) +ATM90E32OffsetCalibrationButton = atm90e32_ns.class_( + "ATM90E32OffsetCalibrationButton", button.Button +) +ATM90E32ClearOffsetCalibrationButton = atm90e32_ns.class_( + "ATM90E32ClearOffsetCalibrationButton", button.Button +) +ATM90E32PowerOffsetCalibrationButton = atm90e32_ns.class_( + "ATM90E32PowerOffsetCalibrationButton", button.Button +) +ATM90E32ClearPowerOffsetCalibrationButton = atm90e32_ns.class_( + "ATM90E32ClearPowerOffsetCalibrationButton", button.Button ) CONFIG_SCHEMA = { cv.GenerateID(CONF_ID): cv.use_id(ATM90E32Component), + cv.Optional(CONF_RUN_GAIN_CALIBRATION): button.button_schema( + ATM90E32GainCalibrationButton, + entity_category=ENTITY_CATEGORY_CONFIG, + icon="mdi:scale-balance", + ), + cv.Optional(CONF_CLEAR_GAIN_CALIBRATION): button.button_schema( + ATM90E32ClearGainCalibrationButton, + entity_category=ENTITY_CATEGORY_CONFIG, + icon="mdi:delete", + ), cv.Optional(CONF_RUN_OFFSET_CALIBRATION): button.button_schema( - ATM90E32CalibrationButton, + ATM90E32OffsetCalibrationButton, entity_category=ENTITY_CATEGORY_CONFIG, icon=ICON_SCALE, ), cv.Optional(CONF_CLEAR_OFFSET_CALIBRATION): button.button_schema( - ATM90E32ClearCalibrationButton, + ATM90E32ClearOffsetCalibrationButton, entity_category=ENTITY_CATEGORY_CONFIG, - icon=ICON_CHIP, + icon="mdi:delete", + ), + cv.Optional(CONF_RUN_POWER_OFFSET_CALIBRATION): button.button_schema( + ATM90E32PowerOffsetCalibrationButton, + entity_category=ENTITY_CATEGORY_CONFIG, + icon=ICON_SCALE, + ), + cv.Optional(CONF_CLEAR_POWER_OFFSET_CALIBRATION): button.button_schema( + ATM90E32ClearPowerOffsetCalibrationButton, + entity_category=ENTITY_CATEGORY_CONFIG, + icon="mdi:delete", ), } async def to_code(config): parent = await cg.get_variable(config[CONF_ID]) + + if run_gain := config.get(CONF_RUN_GAIN_CALIBRATION): + b = await button.new_button(run_gain) + await cg.register_parented(b, parent) + + if clear_gain := config.get(CONF_CLEAR_GAIN_CALIBRATION): + b = await button.new_button(clear_gain) + await cg.register_parented(b, parent) + if run_offset := config.get(CONF_RUN_OFFSET_CALIBRATION): b = await button.new_button(run_offset) await cg.register_parented(b, parent) + if clear_offset := config.get(CONF_CLEAR_OFFSET_CALIBRATION): b = await button.new_button(clear_offset) await cg.register_parented(b, parent) + + if run_power := config.get(CONF_RUN_POWER_OFFSET_CALIBRATION): + b = await button.new_button(run_power) + await cg.register_parented(b, parent) + + if clear_power := config.get(CONF_CLEAR_POWER_OFFSET_CALIBRATION): + b = await button.new_button(clear_power) + await cg.register_parented(b, parent) diff --git a/esphome/components/atm90e32/button/atm90e32_button.cpp b/esphome/components/atm90e32/button/atm90e32_button.cpp index 00715b61dd..a89f071997 100644 --- a/esphome/components/atm90e32/button/atm90e32_button.cpp +++ b/esphome/components/atm90e32/button/atm90e32_button.cpp @@ -1,4 +1,5 @@ #include "atm90e32_button.h" +#include "esphome/core/component.h" #include "esphome/core/log.h" namespace esphome { @@ -6,15 +7,73 @@ namespace atm90e32 { static const char *const TAG = "atm90e32.button"; -void ATM90E32CalibrationButton::press_action() { - ESP_LOGI(TAG, "Running offset calibrations, Note: CTs and ACVs must be 0 during this process..."); +void ATM90E32GainCalibrationButton::press_action() { + if (this->parent_ == nullptr) { + ESP_LOGW(TAG, "[CALIBRATION] No meters assigned to Gain Calibration button [%s]", this->get_name().c_str()); + return; + } + + ESP_LOGI(TAG, "%s", this->get_name().c_str()); + ESP_LOGI(TAG, + "[CALIBRATION] Use gain_ct: & gain_voltage: under each phase_x: in your config file to save these values"); + this->parent_->run_gain_calibrations(); +} + +void ATM90E32ClearGainCalibrationButton::press_action() { + if (this->parent_ == nullptr) { + ESP_LOGW(TAG, "[CALIBRATION] No meters assigned to Clear Gain button [%s]", this->get_name().c_str()); + return; + } + + ESP_LOGI(TAG, "%s", this->get_name().c_str()); + this->parent_->clear_gain_calibrations(); +} + +void ATM90E32OffsetCalibrationButton::press_action() { + if (this->parent_ == nullptr) { + ESP_LOGW(TAG, "[CALIBRATION] No meters assigned to Offset Calibration button [%s]", this->get_name().c_str()); + return; + } + + ESP_LOGI(TAG, "%s", this->get_name().c_str()); + ESP_LOGI(TAG, "[CALIBRATION] **NOTE: CTs and ACVs must be 0 during this process. USB power only**"); + ESP_LOGI(TAG, "[CALIBRATION] Use offset_voltage: & offset_current: under each phase_x: in your config file to save " + "these values"); this->parent_->run_offset_calibrations(); } -void ATM90E32ClearCalibrationButton::press_action() { - ESP_LOGI(TAG, "Offset calibrations cleared."); +void ATM90E32ClearOffsetCalibrationButton::press_action() { + if (this->parent_ == nullptr) { + ESP_LOGW(TAG, "[CALIBRATION] No meters assigned to Clear Offset button [%s]", this->get_name().c_str()); + return; + } + + ESP_LOGI(TAG, "%s", this->get_name().c_str()); this->parent_->clear_offset_calibrations(); } +void ATM90E32PowerOffsetCalibrationButton::press_action() { + if (this->parent_ == nullptr) { + ESP_LOGW(TAG, "[CALIBRATION] No meters assigned to Power Calibration button [%s]", this->get_name().c_str()); + return; + } + + ESP_LOGI(TAG, "%s", this->get_name().c_str()); + ESP_LOGI(TAG, "[CALIBRATION] **NOTE: CTs must be 0 during this process. Voltage reference should be present**"); + ESP_LOGI(TAG, "[CALIBRATION] Use offset_active_power: & offset_reactive_power: under each phase_x: in your config " + "file to save these values"); + this->parent_->run_power_offset_calibrations(); +} + +void ATM90E32ClearPowerOffsetCalibrationButton::press_action() { + if (this->parent_ == nullptr) { + ESP_LOGW(TAG, "[CALIBRATION] No meters assigned to Clear Power button [%s]", this->get_name().c_str()); + return; + } + + ESP_LOGI(TAG, "%s", this->get_name().c_str()); + this->parent_->clear_power_offset_calibrations(); +} + } // namespace atm90e32 } // namespace esphome diff --git a/esphome/components/atm90e32/button/atm90e32_button.h b/esphome/components/atm90e32/button/atm90e32_button.h index 0617099457..2449581531 100644 --- a/esphome/components/atm90e32/button/atm90e32_button.h +++ b/esphome/components/atm90e32/button/atm90e32_button.h @@ -7,17 +7,49 @@ namespace esphome { namespace atm90e32 { -class ATM90E32CalibrationButton : public button::Button, public Parented { +class ATM90E32GainCalibrationButton : public button::Button, public Parented { public: - ATM90E32CalibrationButton() = default; + ATM90E32GainCalibrationButton() = default; protected: void press_action() override; }; -class ATM90E32ClearCalibrationButton : public button::Button, public Parented { +class ATM90E32ClearGainCalibrationButton : public button::Button, public Parented { public: - ATM90E32ClearCalibrationButton() = default; + ATM90E32ClearGainCalibrationButton() = default; + + protected: + void press_action() override; +}; + +class ATM90E32OffsetCalibrationButton : public button::Button, public Parented { + public: + ATM90E32OffsetCalibrationButton() = default; + + protected: + void press_action() override; +}; + +class ATM90E32ClearOffsetCalibrationButton : public button::Button, public Parented { + public: + ATM90E32ClearOffsetCalibrationButton() = default; + + protected: + void press_action() override; +}; + +class ATM90E32PowerOffsetCalibrationButton : public button::Button, public Parented { + public: + ATM90E32PowerOffsetCalibrationButton() = default; + + protected: + void press_action() override; +}; + +class ATM90E32ClearPowerOffsetCalibrationButton : public button::Button, public Parented { + public: + ATM90E32ClearPowerOffsetCalibrationButton() = default; protected: void press_action() override; diff --git a/esphome/components/atm90e32/number/__init__.py b/esphome/components/atm90e32/number/__init__.py new file mode 100644 index 0000000000..848680b875 --- /dev/null +++ b/esphome/components/atm90e32/number/__init__.py @@ -0,0 +1,130 @@ +import esphome.codegen as cg +from esphome.components import number +import esphome.config_validation as cv +from esphome.const import ( + CONF_ID, + CONF_MAX_VALUE, + CONF_MIN_VALUE, + CONF_MODE, + CONF_PHASE_A, + CONF_PHASE_B, + CONF_PHASE_C, + CONF_REFERENCE_VOLTAGE, + CONF_STEP, + ENTITY_CATEGORY_CONFIG, + UNIT_AMPERE, + UNIT_VOLT, +) + +from .. import atm90e32_ns +from ..sensor import ATM90E32Component + +ATM90E32Number = atm90e32_ns.class_( + "ATM90E32Number", number.Number, cg.Parented.template(ATM90E32Component) +) + +CONF_REFERENCE_CURRENT = "reference_current" +PHASE_KEYS = [CONF_PHASE_A, CONF_PHASE_B, CONF_PHASE_C] + + +REFERENCE_VOLTAGE_PHASE_SCHEMA = cv.All( + cv.Schema( + { + cv.Optional(CONF_MODE, default="box"): cv.string, + cv.Optional(CONF_MIN_VALUE, default=100.0): cv.float_, + cv.Optional(CONF_MAX_VALUE, default=260.0): cv.float_, + cv.Optional(CONF_STEP, default=0.1): cv.float_, + } + ).extend( + number.number_schema( + class_=ATM90E32Number, + unit_of_measurement=UNIT_VOLT, + entity_category=ENTITY_CATEGORY_CONFIG, + icon="mdi:power-plug", + ) + ) +) + + +REFERENCE_CURRENT_PHASE_SCHEMA = cv.All( + cv.Schema( + { + cv.Optional(CONF_MODE, default="box"): cv.string, + cv.Optional(CONF_MIN_VALUE, default=1.0): cv.float_, + cv.Optional(CONF_MAX_VALUE, default=200.0): cv.float_, + cv.Optional(CONF_STEP, default=0.1): cv.float_, + } + ).extend( + number.number_schema( + class_=ATM90E32Number, + unit_of_measurement=UNIT_AMPERE, + entity_category=ENTITY_CATEGORY_CONFIG, + icon="mdi:home-lightning-bolt", + ) + ) +) + + +REFERENCE_VOLTAGE_SCHEMA = cv.Schema( + { + cv.Optional(CONF_PHASE_A): REFERENCE_VOLTAGE_PHASE_SCHEMA, + cv.Optional(CONF_PHASE_B): REFERENCE_VOLTAGE_PHASE_SCHEMA, + cv.Optional(CONF_PHASE_C): REFERENCE_VOLTAGE_PHASE_SCHEMA, + } +) + +REFERENCE_CURRENT_SCHEMA = cv.Schema( + { + cv.Optional(CONF_PHASE_A): REFERENCE_CURRENT_PHASE_SCHEMA, + cv.Optional(CONF_PHASE_B): REFERENCE_CURRENT_PHASE_SCHEMA, + cv.Optional(CONF_PHASE_C): REFERENCE_CURRENT_PHASE_SCHEMA, + } +) + +CONFIG_SCHEMA = cv.Schema( + { + cv.GenerateID(CONF_ID): cv.use_id(ATM90E32Component), + cv.Optional(CONF_REFERENCE_VOLTAGE): REFERENCE_VOLTAGE_SCHEMA, + cv.Optional(CONF_REFERENCE_CURRENT): REFERENCE_CURRENT_SCHEMA, + } +) + + +async def to_code(config): + parent = await cg.get_variable(config[CONF_ID]) + + if voltage_cfg := config.get(CONF_REFERENCE_VOLTAGE): + voltage_objs = [None, None, None] + + for i, key in enumerate(PHASE_KEYS): + if validated := voltage_cfg.get(key): + obj = await number.new_number( + validated, + min_value=validated["min_value"], + max_value=validated["max_value"], + step=validated["step"], + ) + await cg.register_parented(obj, parent) + voltage_objs[i] = obj + + # Inherit from A → B/C if only A defined + if voltage_objs[0] is not None: + for i in range(3): + if voltage_objs[i] is None: + voltage_objs[i] = voltage_objs[0] + + for i, obj in enumerate(voltage_objs): + if obj is not None: + cg.add(parent.set_reference_voltage(i, obj)) + + if current_cfg := config.get(CONF_REFERENCE_CURRENT): + for i, key in enumerate(PHASE_KEYS): + if validated := current_cfg.get(key): + obj = await number.new_number( + validated, + min_value=validated["min_value"], + max_value=validated["max_value"], + step=validated["step"], + ) + await cg.register_parented(obj, parent) + cg.add(parent.set_reference_current(i, obj)) diff --git a/esphome/components/atm90e32/number/atm90e32_number.h b/esphome/components/atm90e32/number/atm90e32_number.h new file mode 100644 index 0000000000..9b6129b26d --- /dev/null +++ b/esphome/components/atm90e32/number/atm90e32_number.h @@ -0,0 +1,16 @@ +#pragma once + +#include "esphome/core/component.h" +#include "esphome/components/atm90e32/atm90e32.h" +#include "esphome/components/number/number.h" + +namespace esphome { +namespace atm90e32 { + +class ATM90E32Number : public number::Number, public Parented { + public: + void control(float value) override { this->publish_state(value); } +}; + +} // namespace atm90e32 +} // namespace esphome diff --git a/esphome/components/atm90e32/sensor.py b/esphome/components/atm90e32/sensor.py index 0dc3bfdc4f..7cdbd69f56 100644 --- a/esphome/components/atm90e32/sensor.py +++ b/esphome/components/atm90e32/sensor.py @@ -33,6 +33,7 @@ from esphome.const import ( UNIT_DEGREES, UNIT_HERTZ, UNIT_VOLT, + UNIT_VOLT_AMPS, UNIT_VOLT_AMPS_REACTIVE, UNIT_WATT, UNIT_WATT_HOURS, @@ -45,10 +46,17 @@ CONF_GAIN_PGA = "gain_pga" CONF_CURRENT_PHASES = "current_phases" CONF_GAIN_VOLTAGE = "gain_voltage" CONF_GAIN_CT = "gain_ct" +CONF_OFFSET_VOLTAGE = "offset_voltage" +CONF_OFFSET_CURRENT = "offset_current" +CONF_OFFSET_ACTIVE_POWER = "offset_active_power" +CONF_OFFSET_REACTIVE_POWER = "offset_reactive_power" CONF_HARMONIC_POWER = "harmonic_power" CONF_PEAK_CURRENT = "peak_current" CONF_PEAK_CURRENT_SIGNED = "peak_current_signed" CONF_ENABLE_OFFSET_CALIBRATION = "enable_offset_calibration" +CONF_ENABLE_GAIN_CALIBRATION = "enable_gain_calibration" +CONF_PHASE_STATUS = "phase_status" +CONF_FREQUENCY_STATUS = "frequency_status" UNIT_DEG = "degrees" LINE_FREQS = { "50HZ": 50, @@ -92,10 +100,11 @@ ATM90E32_PHASE_SCHEMA = cv.Schema( unit_of_measurement=UNIT_VOLT_AMPS_REACTIVE, icon=ICON_LIGHTBULB, accuracy_decimals=2, + device_class=DEVICE_CLASS_POWER, state_class=STATE_CLASS_MEASUREMENT, ), cv.Optional(CONF_APPARENT_POWER): sensor.sensor_schema( - unit_of_measurement=UNIT_WATT, + unit_of_measurement=UNIT_VOLT_AMPS, accuracy_decimals=2, device_class=DEVICE_CLASS_POWER, state_class=STATE_CLASS_MEASUREMENT, @@ -137,6 +146,10 @@ ATM90E32_PHASE_SCHEMA = cv.Schema( ), cv.Optional(CONF_GAIN_VOLTAGE, default=7305): cv.uint16_t, cv.Optional(CONF_GAIN_CT, default=27961): cv.uint16_t, + cv.Optional(CONF_OFFSET_VOLTAGE, default=0): cv.int_, + cv.Optional(CONF_OFFSET_CURRENT, default=0): cv.int_, + cv.Optional(CONF_OFFSET_ACTIVE_POWER, default=0): cv.int_, + cv.Optional(CONF_OFFSET_REACTIVE_POWER, default=0): cv.int_, } ) @@ -164,9 +177,10 @@ CONFIG_SCHEMA = ( cv.Optional(CONF_CURRENT_PHASES, default="3"): cv.enum( CURRENT_PHASES, upper=True ), - cv.Optional(CONF_GAIN_PGA, default="2X"): cv.enum(PGA_GAINS, upper=True), + cv.Optional(CONF_GAIN_PGA, default="1X"): cv.enum(PGA_GAINS, upper=True), cv.Optional(CONF_PEAK_CURRENT_SIGNED, default=False): cv.boolean, cv.Optional(CONF_ENABLE_OFFSET_CALIBRATION, default=False): cv.boolean, + cv.Optional(CONF_ENABLE_GAIN_CALIBRATION, default=False): cv.boolean, } ) .extend(cv.polling_component_schema("60s")) @@ -185,6 +199,10 @@ async def to_code(config): conf = config[phase] cg.add(var.set_volt_gain(i, conf[CONF_GAIN_VOLTAGE])) cg.add(var.set_ct_gain(i, conf[CONF_GAIN_CT])) + cg.add(var.set_voltage_offset(i, conf[CONF_OFFSET_VOLTAGE])) + cg.add(var.set_current_offset(i, conf[CONF_OFFSET_CURRENT])) + cg.add(var.set_active_power_offset(i, conf[CONF_OFFSET_ACTIVE_POWER])) + cg.add(var.set_reactive_power_offset(i, conf[CONF_OFFSET_REACTIVE_POWER])) if voltage_config := conf.get(CONF_VOLTAGE): sens = await sensor.new_sensor(voltage_config) cg.add(var.set_voltage_sensor(i, sens)) @@ -218,16 +236,15 @@ async def to_code(config): if peak_current_config := conf.get(CONF_PEAK_CURRENT): sens = await sensor.new_sensor(peak_current_config) cg.add(var.set_peak_current_sensor(i, sens)) - if frequency_config := config.get(CONF_FREQUENCY): sens = await sensor.new_sensor(frequency_config) cg.add(var.set_freq_sensor(sens)) if chip_temperature_config := config.get(CONF_CHIP_TEMPERATURE): sens = await sensor.new_sensor(chip_temperature_config) cg.add(var.set_chip_temperature_sensor(sens)) - cg.add(var.set_line_freq(config[CONF_LINE_FREQUENCY])) cg.add(var.set_current_phases(config[CONF_CURRENT_PHASES])) cg.add(var.set_pga_gain(config[CONF_GAIN_PGA])) cg.add(var.set_peak_current_signed(config[CONF_PEAK_CURRENT_SIGNED])) cg.add(var.set_enable_offset_calibration(config[CONF_ENABLE_OFFSET_CALIBRATION])) + cg.add(var.set_enable_gain_calibration(config[CONF_ENABLE_GAIN_CALIBRATION])) diff --git a/esphome/components/atm90e32/text_sensor/__init__.py b/esphome/components/atm90e32/text_sensor/__init__.py new file mode 100644 index 0000000000..ab96f6c207 --- /dev/null +++ b/esphome/components/atm90e32/text_sensor/__init__.py @@ -0,0 +1,48 @@ +import esphome.codegen as cg +from esphome.components import text_sensor +import esphome.config_validation as cv +from esphome.const import CONF_ID, CONF_PHASE_A, CONF_PHASE_B, CONF_PHASE_C + +from ..sensor import ATM90E32Component + +CONF_PHASE_STATUS = "phase_status" +CONF_FREQUENCY_STATUS = "frequency_status" +PHASE_KEYS = [CONF_PHASE_A, CONF_PHASE_B, CONF_PHASE_C] + +PHASE_STATUS_SCHEMA = cv.Schema( + { + cv.Optional(CONF_PHASE_A): text_sensor.text_sensor_schema( + icon="mdi:flash-alert" + ), + cv.Optional(CONF_PHASE_B): text_sensor.text_sensor_schema( + icon="mdi:flash-alert" + ), + cv.Optional(CONF_PHASE_C): text_sensor.text_sensor_schema( + icon="mdi:flash-alert" + ), + } +) + +CONFIG_SCHEMA = cv.Schema( + { + cv.GenerateID(): cv.use_id(ATM90E32Component), + cv.Optional(CONF_PHASE_STATUS): PHASE_STATUS_SCHEMA, + cv.Optional(CONF_FREQUENCY_STATUS): text_sensor.text_sensor_schema( + icon="mdi:lightbulb-alert" + ), + } +) + + +async def to_code(config): + parent = await cg.get_variable(config[CONF_ID]) + + if phase_cfg := config.get(CONF_PHASE_STATUS): + for i, key in enumerate(PHASE_KEYS): + if sub_phase_cfg := phase_cfg.get(key): + sens = await text_sensor.new_text_sensor(sub_phase_cfg) + cg.add(parent.set_phase_status_text_sensor(i, sens)) + + if freq_status_config := config.get(CONF_FREQUENCY_STATUS): + sens = await text_sensor.new_text_sensor(freq_status_config) + cg.add(parent.set_freq_status_text_sensor(sens)) diff --git a/esphome/components/audio/__init__.py b/esphome/components/audio/__init__.py index f8ec8cbd85..f657cb5da3 100644 --- a/esphome/components/audio/__init__.py +++ b/esphome/components/audio/__init__.py @@ -37,29 +37,32 @@ AUDIO_COMPONENT_SCHEMA = cv.Schema( ) -_UNDEF = object() - - def set_stream_limits( - min_bits_per_sample: int = _UNDEF, - max_bits_per_sample: int = _UNDEF, - min_channels: int = _UNDEF, - max_channels: int = _UNDEF, - min_sample_rate: int = _UNDEF, - max_sample_rate: int = _UNDEF, + min_bits_per_sample: int = cv.UNDEFINED, + max_bits_per_sample: int = cv.UNDEFINED, + min_channels: int = cv.UNDEFINED, + max_channels: int = cv.UNDEFINED, + min_sample_rate: int = cv.UNDEFINED, + max_sample_rate: int = cv.UNDEFINED, ): + """Sets the limits for the audio stream that audio component can handle + + When the component sinks audio (e.g., a speaker), these indicate the limits to the audio it can receive. + When the component sources audio (e.g., a microphone), these indicate the limits to the audio it can send. + """ + def set_limits_in_config(config): - if min_bits_per_sample is not _UNDEF: + if min_bits_per_sample is not cv.UNDEFINED: config[CONF_MIN_BITS_PER_SAMPLE] = min_bits_per_sample - if max_bits_per_sample is not _UNDEF: + if max_bits_per_sample is not cv.UNDEFINED: config[CONF_MAX_BITS_PER_SAMPLE] = max_bits_per_sample - if min_channels is not _UNDEF: + if min_channels is not cv.UNDEFINED: config[CONF_MIN_CHANNELS] = min_channels - if max_channels is not _UNDEF: + if max_channels is not cv.UNDEFINED: config[CONF_MAX_CHANNELS] = max_channels - if min_sample_rate is not _UNDEF: + if min_sample_rate is not cv.UNDEFINED: config[CONF_MIN_SAMPLE_RATE] = min_sample_rate - if max_sample_rate is not _UNDEF: + if max_sample_rate is not cv.UNDEFINED: config[CONF_MAX_SAMPLE_RATE] = max_sample_rate return set_limits_in_config @@ -69,43 +72,87 @@ def final_validate_audio_schema( name: str, *, audio_device: str, - bits_per_sample: int, - channels: int, - sample_rate: int, + bits_per_sample: int = cv.UNDEFINED, + channels: int = cv.UNDEFINED, + sample_rate: int = cv.UNDEFINED, + enabled_channels: list[int] = cv.UNDEFINED, + audio_device_issue: bool = False, ): + """Validates audio compatibility when passed between different components. + + The component derived from ``AUDIO_COMPONENT_SCHEMA`` should call ``set_stream_limits`` in a validator to specify its compatible settings + + - If audio_device_issue is True, then the error message indicates the user should adjust the AUDIO_COMPONENT_SCHEMA derived component's configuration to match the values passed to this function + - If audio_device_issue is False, then the error message indicates the user should adjust the configuration of the component calling this function, as it falls out of the valid stream limits + + Args: + name (str): Friendly name of the component calling this function with an audio component to validate + audio_device (str): The configuration parameter name that contains the ID of an AUDIO_COMPONENT_SCHEMA derived component to validate against + bits_per_sample (int, optional): The desired bits per sample + channels (int, optional): The desired number of channels + sample_rate (int, optional): The desired sample rate + enabled_channels (list[int], optional): The desired enabled channels + audio_device_issue (bool, optional): Format the error message to indicate the problem is in the configuration for the ``audio_device`` component. Defaults to False. + """ + def validate_audio_compatiblity(audio_config): audio_schema = {} - try: - cv.int_range( - min=audio_config.get(CONF_MIN_BITS_PER_SAMPLE), - max=audio_config.get(CONF_MAX_BITS_PER_SAMPLE), - )(bits_per_sample) - except cv.Invalid as exc: - raise cv.Invalid( - f"Invalid configuration for the {name} component. The {CONF_BITS_PER_SAMPLE} {str(exc)}" - ) from exc + if bits_per_sample is not cv.UNDEFINED: + try: + cv.int_range( + min=audio_config.get(CONF_MIN_BITS_PER_SAMPLE), + max=audio_config.get(CONF_MAX_BITS_PER_SAMPLE), + )(bits_per_sample) + except cv.Invalid as exc: + if audio_device_issue: + error_string = f"Invalid configuration for the specified {audio_device}. The {name} component requires {bits_per_sample} bits per sample." + else: + error_string = f"Invalid configuration for the {name} component. The {CONF_BITS_PER_SAMPLE} {str(exc)}" + raise cv.Invalid(error_string) from exc - try: - cv.int_range( - min=audio_config.get(CONF_MIN_CHANNELS), - max=audio_config.get(CONF_MAX_CHANNELS), - )(channels) - except cv.Invalid as exc: - raise cv.Invalid( - f"Invalid configuration for the {name} component. The {CONF_NUM_CHANNELS} {str(exc)}" - ) from exc + if channels is not cv.UNDEFINED: + try: + cv.int_range( + min=audio_config.get(CONF_MIN_CHANNELS), + max=audio_config.get(CONF_MAX_CHANNELS), + )(channels) + except cv.Invalid as exc: + if audio_device_issue: + error_string = f"Invalid configuration for the specified {audio_device}. The {name} component requires {channels} channels." + else: + error_string = f"Invalid configuration for the {name} component. The {CONF_NUM_CHANNELS} {str(exc)}" + raise cv.Invalid(error_string) from exc - try: - cv.int_range( - min=audio_config.get(CONF_MIN_SAMPLE_RATE), - max=audio_config.get(CONF_MAX_SAMPLE_RATE), - )(sample_rate) - return cv.Schema(audio_schema, extra=cv.ALLOW_EXTRA)(audio_config) - except cv.Invalid as exc: - raise cv.Invalid( - f"Invalid configuration for the {name} component. The {CONF_SAMPLE_RATE} {str(exc)}" - ) from exc + if sample_rate is not cv.UNDEFINED: + try: + cv.int_range( + min=audio_config.get(CONF_MIN_SAMPLE_RATE), + max=audio_config.get(CONF_MAX_SAMPLE_RATE), + )(sample_rate) + except cv.Invalid as exc: + if audio_device_issue: + error_string = f"Invalid configuration for the specified {audio_device}. The {name} component requires a {sample_rate} sample rate." + else: + error_string = f"Invalid configuration for the {name} component. The {CONF_SAMPLE_RATE} {str(exc)}" + raise cv.Invalid(error_string) from exc + + if enabled_channels is not cv.UNDEFINED: + for channel in enabled_channels: + try: + # Channels are 0-indexed + cv.int_range( + min=0, + max=audio_config.get(CONF_MAX_CHANNELS) - 1, + )(channel) + except cv.Invalid as exc: + if audio_device_issue: + error_string = f"Invalid configuration for the specified {audio_device}. The {name} component requires channel {channel}." + else: + error_string = f"Invalid configuration for the {name} component. Enabled channel {channel} {str(exc)}" + raise cv.Invalid(error_string) from exc + + return cv.Schema(audio_schema, extra=cv.ALLOW_EXTRA)(audio_config) return cv.Schema( { @@ -118,4 +165,4 @@ def final_validate_audio_schema( async def to_code(config): - cg.add_library("esphome/esp-audio-libs", "1.1.3") + cg.add_library("esphome/esp-audio-libs", "1.1.4") diff --git a/esphome/components/audio/audio.h b/esphome/components/audio/audio.h index 6f0f1aaa46..95c31872e3 100644 --- a/esphome/components/audio/audio.h +++ b/esphome/components/audio/audio.h @@ -135,5 +135,53 @@ const char *audio_file_type_to_string(AudioFileType file_type); void scale_audio_samples(const int16_t *audio_samples, int16_t *output_buffer, int16_t scale_factor, size_t samples_to_scale); +/// @brief Unpacks a quantized audio sample into a Q31 fixed-point number. +/// @param data Pointer to uint8_t array containing the audio sample +/// @param bytes_per_sample The number of bytes per sample +/// @return Q31 sample +inline int32_t unpack_audio_sample_to_q31(const uint8_t *data, size_t bytes_per_sample) { + int32_t sample = 0; + if (bytes_per_sample == 1) { + sample |= data[0] << 24; + } else if (bytes_per_sample == 2) { + sample |= data[0] << 16; + sample |= data[1] << 24; + } else if (bytes_per_sample == 3) { + sample |= data[0] << 8; + sample |= data[1] << 16; + sample |= data[2] << 24; + } else if (bytes_per_sample == 4) { + sample |= data[0]; + sample |= data[1] << 8; + sample |= data[2] << 16; + sample |= data[3] << 24; + } + + return sample; +} + +/// @brief Packs a Q31 fixed-point number as an audio sample with the specified number of bytes per sample. +/// Packs the most significant bits - no dithering is applied. +/// @param sample Q31 fixed-point number to pack +/// @param data Pointer to data array to store +/// @param bytes_per_sample The audio data's bytes per sample +inline void pack_q31_as_audio_sample(int32_t sample, uint8_t *data, size_t bytes_per_sample) { + if (bytes_per_sample == 1) { + data[0] = static_cast(sample >> 24); + } else if (bytes_per_sample == 2) { + data[0] = static_cast(sample >> 16); + data[1] = static_cast(sample >> 24); + } else if (bytes_per_sample == 3) { + data[0] = static_cast(sample >> 8); + data[1] = static_cast(sample >> 16); + data[2] = static_cast(sample >> 24); + } else if (bytes_per_sample == 4) { + data[0] = static_cast(sample); + data[1] = static_cast(sample >> 8); + data[2] = static_cast(sample >> 16); + data[3] = static_cast(sample >> 24); + } +} + } // namespace audio } // namespace esphome diff --git a/esphome/components/audio/audio_decoder.cpp b/esphome/components/audio/audio_decoder.cpp index 60489d7d78..c74b028c4b 100644 --- a/esphome/components/audio/audio_decoder.cpp +++ b/esphome/components/audio/audio_decoder.cpp @@ -171,7 +171,7 @@ AudioDecoderState AudioDecoder::decode(bool stop_gracefully) { bytes_available_before_processing = this->input_transfer_buffer_->available(); - if ((this->potentially_failed_count_ > 10) && (bytes_read == 0)) { + if ((this->potentially_failed_count_ > 0) && (bytes_read == 0)) { // Failed to decode in last attempt and there is no new data if ((this->input_transfer_buffer_->free() == 0) && first_loop_iteration) { diff --git a/esphome/components/audio/audio_resampler.cpp b/esphome/components/audio/audio_resampler.cpp index a7621225a1..20d246f1e0 100644 --- a/esphome/components/audio/audio_resampler.cpp +++ b/esphome/components/audio/audio_resampler.cpp @@ -4,6 +4,8 @@ #include "esphome/core/hal.h" +#include + namespace esphome { namespace audio { diff --git a/esphome/components/audio/audio_resampler.h b/esphome/components/audio/audio_resampler.h index 7f4e987b4c..082ade3371 100644 --- a/esphome/components/audio/audio_resampler.h +++ b/esphome/components/audio/audio_resampler.h @@ -6,6 +6,7 @@ #include "audio_transfer_buffer.h" #include "esphome/core/defines.h" +#include "esphome/core/helpers.h" #include "esphome/core/ring_buffer.h" #ifdef USE_SPEAKER diff --git a/esphome/components/binary_sensor/__init__.py b/esphome/components/binary_sensor/__init__.py index d947c2aba6..448323da5a 100644 --- a/esphome/components/binary_sensor/__init__.py +++ b/esphome/components/binary_sensor/__init__.py @@ -386,7 +386,7 @@ def validate_click_timing(value): return value -BINARY_SENSOR_SCHEMA = ( +_BINARY_SENSOR_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMPONENT_SCHEMA) .extend( @@ -458,19 +458,17 @@ BINARY_SENSOR_SCHEMA = ( ) ) -_UNDEF = object() - def binary_sensor_schema( - class_: MockObjClass = _UNDEF, + class_: MockObjClass = cv.UNDEFINED, *, - icon: str = _UNDEF, - entity_category: str = _UNDEF, - device_class: str = _UNDEF, + icon: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + device_class: str = cv.UNDEFINED, ) -> cv.Schema: schema = {} - if class_ is not _UNDEF: + if class_ is not cv.UNDEFINED: # Not cv.optional schema[cv.GenerateID()] = cv.declare_id(class_) @@ -479,10 +477,15 @@ def binary_sensor_schema( (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), (CONF_DEVICE_CLASS, device_class, validate_device_class), ]: - if default is not _UNDEF: + if default is not cv.UNDEFINED: schema[cv.Optional(key, default=default)] = validator - return BINARY_SENSOR_SCHEMA.extend(schema) + return _BINARY_SENSOR_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +BINARY_SENSOR_SCHEMA = binary_sensor_schema() +BINARY_SENSOR_SCHEMA.add_extra(cv.deprecated_schema_constant("binary_sensor")) async def setup_binary_sensor_core_(var, config): diff --git a/esphome/components/binary_sensor/binary_sensor.cpp b/esphome/components/binary_sensor/binary_sensor.cpp index 20604a0b7e..30fbe4f0b4 100644 --- a/esphome/components/binary_sensor/binary_sensor.cpp +++ b/esphome/components/binary_sensor/binary_sensor.cpp @@ -15,21 +15,17 @@ void BinarySensor::publish_state(bool state) { if (!this->publish_dedup_.next(state)) return; if (this->filter_list_ == nullptr) { - this->send_state_internal(state, false); + this->send_state_internal(state); } else { - this->filter_list_->input(state, false); + this->filter_list_->input(state); } } void BinarySensor::publish_initial_state(bool state) { - if (!this->publish_dedup_.next(state)) - return; - if (this->filter_list_ == nullptr) { - this->send_state_internal(state, true); - } else { - this->filter_list_->input(state, true); - } + this->has_state_ = false; + this->publish_state(state); } -void BinarySensor::send_state_internal(bool state, bool is_initial) { +void BinarySensor::send_state_internal(bool state) { + bool is_initial = !this->has_state_; if (is_initial) { ESP_LOGD(TAG, "'%s': Sending initial state %s", this->get_name().c_str(), ONOFF(state)); } else { diff --git a/esphome/components/binary_sensor/binary_sensor.h b/esphome/components/binary_sensor/binary_sensor.h index 57cae9e2f5..9ba7aeeeff 100644 --- a/esphome/components/binary_sensor/binary_sensor.h +++ b/esphome/components/binary_sensor/binary_sensor.h @@ -67,7 +67,7 @@ class BinarySensor : public EntityBase, public EntityBase_DeviceClass { // ========== INTERNAL METHODS ========== // (In most use cases you won't need these) - void send_state_internal(bool state, bool is_initial); + void send_state_internal(bool state); /// Return whether this binary sensor has outputted a state. virtual bool has_state() const; diff --git a/esphome/components/binary_sensor/filter.cpp b/esphome/components/binary_sensor/filter.cpp index 8f94b108ac..fd6cc31008 100644 --- a/esphome/components/binary_sensor/filter.cpp +++ b/esphome/components/binary_sensor/filter.cpp @@ -9,37 +9,37 @@ namespace binary_sensor { static const char *const TAG = "sensor.filter"; -void Filter::output(bool value, bool is_initial) { +void Filter::output(bool value) { if (!this->dedup_.next(value)) return; if (this->next_ == nullptr) { - this->parent_->send_state_internal(value, is_initial); + this->parent_->send_state_internal(value); } else { - this->next_->input(value, is_initial); + this->next_->input(value); } } -void Filter::input(bool value, bool is_initial) { - auto b = this->new_value(value, is_initial); +void Filter::input(bool value) { + auto b = this->new_value(value); if (b.has_value()) { - this->output(*b, is_initial); + this->output(*b); } } -optional DelayedOnOffFilter::new_value(bool value, bool is_initial) { +optional DelayedOnOffFilter::new_value(bool value) { if (value) { - this->set_timeout("ON_OFF", this->on_delay_.value(), [this, is_initial]() { this->output(true, is_initial); }); + this->set_timeout("ON_OFF", this->on_delay_.value(), [this]() { this->output(true); }); } else { - this->set_timeout("ON_OFF", this->off_delay_.value(), [this, is_initial]() { this->output(false, is_initial); }); + this->set_timeout("ON_OFF", this->off_delay_.value(), [this]() { this->output(false); }); } return {}; } float DelayedOnOffFilter::get_setup_priority() const { return setup_priority::HARDWARE; } -optional DelayedOnFilter::new_value(bool value, bool is_initial) { +optional DelayedOnFilter::new_value(bool value) { if (value) { - this->set_timeout("ON", this->delay_.value(), [this, is_initial]() { this->output(true, is_initial); }); + this->set_timeout("ON", this->delay_.value(), [this]() { this->output(true); }); return {}; } else { this->cancel_timeout("ON"); @@ -49,9 +49,9 @@ optional DelayedOnFilter::new_value(bool value, bool is_initial) { float DelayedOnFilter::get_setup_priority() const { return setup_priority::HARDWARE; } -optional DelayedOffFilter::new_value(bool value, bool is_initial) { +optional DelayedOffFilter::new_value(bool value) { if (!value) { - this->set_timeout("OFF", this->delay_.value(), [this, is_initial]() { this->output(false, is_initial); }); + this->set_timeout("OFF", this->delay_.value(), [this]() { this->output(false); }); return {}; } else { this->cancel_timeout("OFF"); @@ -61,11 +61,11 @@ optional DelayedOffFilter::new_value(bool value, bool is_initial) { float DelayedOffFilter::get_setup_priority() const { return setup_priority::HARDWARE; } -optional InvertFilter::new_value(bool value, bool is_initial) { return !value; } +optional InvertFilter::new_value(bool value) { return !value; } AutorepeatFilter::AutorepeatFilter(std::vector timings) : timings_(std::move(timings)) {} -optional AutorepeatFilter::new_value(bool value, bool is_initial) { +optional AutorepeatFilter::new_value(bool value) { if (value) { // Ignore if already running if (this->active_timing_ != 0) @@ -101,7 +101,7 @@ void AutorepeatFilter::next_timing_() { void AutorepeatFilter::next_value_(bool val) { const AutorepeatFilterTiming &timing = this->timings_[this->active_timing_ - 2]; - this->output(val, false); // This is at least the second one so not initial + this->output(val); this->set_timeout("ON_OFF", val ? timing.time_on : timing.time_off, [this, val]() { this->next_value_(!val); }); } @@ -109,18 +109,18 @@ float AutorepeatFilter::get_setup_priority() const { return setup_priority::HARD LambdaFilter::LambdaFilter(std::function(bool)> f) : f_(std::move(f)) {} -optional LambdaFilter::new_value(bool value, bool is_initial) { return this->f_(value); } +optional LambdaFilter::new_value(bool value) { return this->f_(value); } -optional SettleFilter::new_value(bool value, bool is_initial) { +optional SettleFilter::new_value(bool value) { if (!this->steady_) { - this->set_timeout("SETTLE", this->delay_.value(), [this, value, is_initial]() { + this->set_timeout("SETTLE", this->delay_.value(), [this, value]() { this->steady_ = true; - this->output(value, is_initial); + this->output(value); }); return {}; } else { this->steady_ = false; - this->output(value, is_initial); + this->output(value); this->set_timeout("SETTLE", this->delay_.value(), [this]() { this->steady_ = true; }); return value; } diff --git a/esphome/components/binary_sensor/filter.h b/esphome/components/binary_sensor/filter.h index f7342db2fb..65838da49d 100644 --- a/esphome/components/binary_sensor/filter.h +++ b/esphome/components/binary_sensor/filter.h @@ -14,11 +14,11 @@ class BinarySensor; class Filter { public: - virtual optional new_value(bool value, bool is_initial) = 0; + virtual optional new_value(bool value) = 0; - void input(bool value, bool is_initial); + void input(bool value); - void output(bool value, bool is_initial); + void output(bool value); protected: friend BinarySensor; @@ -30,7 +30,7 @@ class Filter { class DelayedOnOffFilter : public Filter, public Component { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; @@ -44,7 +44,7 @@ class DelayedOnOffFilter : public Filter, public Component { class DelayedOnFilter : public Filter, public Component { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; @@ -56,7 +56,7 @@ class DelayedOnFilter : public Filter, public Component { class DelayedOffFilter : public Filter, public Component { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; @@ -68,7 +68,7 @@ class DelayedOffFilter : public Filter, public Component { class InvertFilter : public Filter { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; }; struct AutorepeatFilterTiming { @@ -86,7 +86,7 @@ class AutorepeatFilter : public Filter, public Component { public: explicit AutorepeatFilter(std::vector timings); - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; @@ -102,7 +102,7 @@ class LambdaFilter : public Filter { public: explicit LambdaFilter(std::function(bool)> f); - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; protected: std::function(bool)> f_; @@ -110,7 +110,7 @@ class LambdaFilter : public Filter { class SettleFilter : public Filter, public Component { public: - optional new_value(bool value, bool is_initial) override; + optional new_value(bool value) override; float get_setup_priority() const override; diff --git a/esphome/components/bl0906/constants.h b/esphome/components/bl0906/constants.h index 546916aa3c..a174e54bb2 100644 --- a/esphome/components/bl0906/constants.h +++ b/esphome/components/bl0906/constants.h @@ -45,7 +45,7 @@ static const uint8_t BL0906_WRITE_COMMAND = 0xCA; static const uint8_t BL0906_V_RMS = 0x16; // Total power -static const uint8_t BL0906_WATT_SUM = 0X2C; +static const uint8_t BL0906_WATT_SUM = 0x2C; // Current1~6 static const uint8_t BL0906_I_1_RMS = 0x0D; // current_1 @@ -56,29 +56,29 @@ static const uint8_t BL0906_I_5_RMS = 0x13; static const uint8_t BL0906_I_6_RMS = 0x14; // current_6 // Power1~6 -static const uint8_t BL0906_WATT_1 = 0X23; // power_1 -static const uint8_t BL0906_WATT_2 = 0X24; -static const uint8_t BL0906_WATT_3 = 0X25; -static const uint8_t BL0906_WATT_4 = 0X26; -static const uint8_t BL0906_WATT_5 = 0X29; -static const uint8_t BL0906_WATT_6 = 0X2A; // power_6 +static const uint8_t BL0906_WATT_1 = 0x23; // power_1 +static const uint8_t BL0906_WATT_2 = 0x24; +static const uint8_t BL0906_WATT_3 = 0x25; +static const uint8_t BL0906_WATT_4 = 0x26; +static const uint8_t BL0906_WATT_5 = 0x29; +static const uint8_t BL0906_WATT_6 = 0x2A; // power_6 // Active pulse count, unsigned -static const uint8_t BL0906_CF_1_CNT = 0X30; // Channel_1 -static const uint8_t BL0906_CF_2_CNT = 0X31; -static const uint8_t BL0906_CF_3_CNT = 0X32; -static const uint8_t BL0906_CF_4_CNT = 0X33; -static const uint8_t BL0906_CF_5_CNT = 0X36; -static const uint8_t BL0906_CF_6_CNT = 0X37; // Channel_6 +static const uint8_t BL0906_CF_1_CNT = 0x30; // Channel_1 +static const uint8_t BL0906_CF_2_CNT = 0x31; +static const uint8_t BL0906_CF_3_CNT = 0x32; +static const uint8_t BL0906_CF_4_CNT = 0x33; +static const uint8_t BL0906_CF_5_CNT = 0x36; +static const uint8_t BL0906_CF_6_CNT = 0x37; // Channel_6 // Total active pulse count, unsigned -static const uint8_t BL0906_CF_SUM_CNT = 0X39; +static const uint8_t BL0906_CF_SUM_CNT = 0x39; // Voltage frequency cycle -static const uint8_t BL0906_FREQUENCY = 0X4E; +static const uint8_t BL0906_FREQUENCY = 0x4E; // Internal temperature -static const uint8_t BL0906_TEMPERATURE = 0X5E; +static const uint8_t BL0906_TEMPERATURE = 0x5E; // Calibration register // RMS gain adjustment register diff --git a/esphome/components/ble_client/text_sensor/__init__.py b/esphome/components/ble_client/text_sensor/__init__.py index afa60f6c0c..a6b8956f93 100644 --- a/esphome/components/ble_client/text_sensor/__init__.py +++ b/esphome/components/ble_client/text_sensor/__init__.py @@ -4,7 +4,6 @@ from esphome.components import ble_client, esp32_ble_tracker, text_sensor import esphome.config_validation as cv from esphome.const import ( CONF_CHARACTERISTIC_UUID, - CONF_ID, CONF_NOTIFY, CONF_SERVICE_UUID, CONF_TRIGGER_ID, @@ -32,9 +31,9 @@ BLETextSensorNotifyTrigger = ble_client_ns.class_( ) CONFIG_SCHEMA = cv.All( - text_sensor.TEXT_SENSOR_SCHEMA.extend( + text_sensor.text_sensor_schema(BLETextSensor) + .extend( { - cv.GenerateID(): cv.declare_id(BLETextSensor), cv.Required(CONF_SERVICE_UUID): esp32_ble_tracker.bt_uuid, cv.Required(CONF_CHARACTERISTIC_UUID): esp32_ble_tracker.bt_uuid, cv.Optional(CONF_DESCRIPTOR_UUID): esp32_ble_tracker.bt_uuid, @@ -54,7 +53,7 @@ CONFIG_SCHEMA = cv.All( async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await text_sensor.new_text_sensor(config) if len(config[CONF_SERVICE_UUID]) == len(esp32_ble_tracker.bt_uuid16_format): cg.add( var.set_service_uuid16(esp32_ble_tracker.as_hex(config[CONF_SERVICE_UUID])) @@ -101,7 +100,6 @@ async def to_code(config): await cg.register_component(var, config) await ble_client.register_ble_node(var, config) cg.add(var.set_enable_notify(config[CONF_NOTIFY])) - await text_sensor.register_text_sensor(var, config) for conf in config.get(CONF_ON_NOTIFY, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await ble_client.register_ble_node(trigger, config) diff --git a/esphome/components/bluetooth_proxy/bluetooth_connection.cpp b/esphome/components/bluetooth_proxy/bluetooth_connection.cpp index b63f7ccde9..3c5c2bd438 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_connection.cpp +++ b/esphome/components/bluetooth_proxy/bluetooth_connection.cpp @@ -73,9 +73,8 @@ bool BluetoothConnection::gattc_event_handler(esp_gattc_cb_event_t event, esp_ga resp.address = this->address_; resp.handle = param->read.handle; resp.data.reserve(param->read.value_len); - for (uint16_t i = 0; i < param->read.value_len; i++) { - resp.data.push_back(param->read.value[i]); - } + // Use bulk insert instead of individual push_backs + resp.data.insert(resp.data.end(), param->read.value, param->read.value + param->read.value_len); this->proxy_->get_api_connection()->send_bluetooth_gatt_read_response(resp); break; } @@ -127,9 +126,8 @@ bool BluetoothConnection::gattc_event_handler(esp_gattc_cb_event_t event, esp_ga resp.address = this->address_; resp.handle = param->notify.handle; resp.data.reserve(param->notify.value_len); - for (uint16_t i = 0; i < param->notify.value_len; i++) { - resp.data.push_back(param->notify.value[i]); - } + // Use bulk insert instead of individual push_backs + resp.data.insert(resp.data.end(), param->notify.value, param->notify.value + param->notify.value_len); this->proxy_->get_api_connection()->send_bluetooth_gatt_notify_data_response(resp); break; } diff --git a/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp b/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp index 03213432cd..9c8bd4009f 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp +++ b/esphome/components/bluetooth_proxy/bluetooth_proxy.cpp @@ -25,6 +25,22 @@ std::vector get_128bit_uuid_vec(esp_bt_uuid_t uuid_source) { BluetoothProxy::BluetoothProxy() { global_bluetooth_proxy = this; } +void BluetoothProxy::setup() { + this->parent_->add_scanner_state_callback([this](esp32_ble_tracker::ScannerState state) { + if (this->api_connection_ != nullptr) { + this->send_bluetooth_scanner_state_(state); + } + }); +} + +void BluetoothProxy::send_bluetooth_scanner_state_(esp32_ble_tracker::ScannerState state) { + api::BluetoothScannerStateResponse resp; + resp.state = static_cast(state); + resp.mode = this->parent_->get_scan_active() ? api::enums::BluetoothScannerMode::BLUETOOTH_SCANNER_MODE_ACTIVE + : api::enums::BluetoothScannerMode::BLUETOOTH_SCANNER_MODE_PASSIVE; + this->api_connection_->send_bluetooth_scanner_state_response(resp); +} + bool BluetoothProxy::parse_device(const esp32_ble_tracker::ESPBTDevice &device) { if (!api::global_api_server->is_connected() || this->api_connection_ == nullptr || this->raw_advertisements_) return false; @@ -40,6 +56,9 @@ bool BluetoothProxy::parse_devices(esp_ble_gap_cb_param_t::ble_scan_result_evt_p return false; api::BluetoothLERawAdvertisementsResponse resp; + // Pre-allocate the advertisements vector to avoid reallocations + resp.advertisements.reserve(count); + for (size_t i = 0; i < count; i++) { auto &result = advertisements[i]; api::BluetoothLERawAdvertisement adv; @@ -49,9 +68,8 @@ bool BluetoothProxy::parse_devices(esp_ble_gap_cb_param_t::ble_scan_result_evt_p uint8_t length = result.adv_data_len + result.scan_rsp_len; adv.data.reserve(length); - for (uint16_t i = 0; i < length; i++) { - adv.data.push_back(result.ble_adv[i]); - } + // Use a bulk insert instead of individual push_backs + adv.data.insert(adv.data.end(), &result.ble_adv[0], &result.ble_adv[length]); resp.advertisements.push_back(std::move(adv)); @@ -69,21 +87,34 @@ void BluetoothProxy::send_api_packet_(const esp32_ble_tracker::ESPBTDevice &devi if (!device.get_name().empty()) resp.name = device.get_name(); resp.rssi = device.get_rssi(); - for (auto uuid : device.get_service_uuids()) { + + // Pre-allocate vectors based on known sizes + auto service_uuids = device.get_service_uuids(); + resp.service_uuids.reserve(service_uuids.size()); + for (auto uuid : service_uuids) { resp.service_uuids.push_back(uuid.to_string()); } - for (auto &data : device.get_service_datas()) { + + // Pre-allocate service data vector + auto service_datas = device.get_service_datas(); + resp.service_data.reserve(service_datas.size()); + for (auto &data : service_datas) { api::BluetoothServiceData service_data; service_data.uuid = data.uuid.to_string(); service_data.data.assign(data.data.begin(), data.data.end()); resp.service_data.push_back(std::move(service_data)); } - for (auto &data : device.get_manufacturer_datas()) { + + // Pre-allocate manufacturer data vector + auto manufacturer_datas = device.get_manufacturer_datas(); + resp.manufacturer_data.reserve(manufacturer_datas.size()); + for (auto &data : manufacturer_datas) { api::BluetoothServiceData manufacturer_data; manufacturer_data.uuid = data.uuid.to_string(); manufacturer_data.data.assign(data.data.begin(), data.data.end()); resp.manufacturer_data.push_back(std::move(manufacturer_data)); } + this->api_connection_->send_bluetooth_le_advertisement(resp); } @@ -145,11 +176,27 @@ void BluetoothProxy::loop() { } api::BluetoothGATTGetServicesResponse resp; resp.address = connection->get_address(); + resp.services.reserve(1); // Always one service per response in this implementation api::BluetoothGATTService service_resp; service_resp.uuid = get_128bit_uuid_vec(service_result.uuid); service_resp.handle = service_result.start_handle; uint16_t char_offset = 0; esp_gattc_char_elem_t char_result; + // Get the number of characteristics directly with one call + uint16_t total_char_count = 0; + esp_gatt_status_t char_count_status = esp_ble_gattc_get_attr_count( + connection->get_gattc_if(), connection->get_conn_id(), ESP_GATT_DB_CHARACTERISTIC, + service_result.start_handle, service_result.end_handle, 0, &total_char_count); + + if (char_count_status == ESP_GATT_OK && total_char_count > 0) { + // Only reserve if we successfully got a count + service_resp.characteristics.reserve(total_char_count); + } else if (char_count_status != ESP_GATT_OK) { + ESP_LOGW(TAG, "[%d] [%s] Error getting characteristic count, status=%d", connection->get_connection_index(), + connection->address_str().c_str(), char_count_status); + } + + // Now process characteristics while (true) { // characteristics uint16_t char_count = 1; esp_gatt_status_t char_status = esp_ble_gattc_get_all_char( @@ -171,6 +218,23 @@ void BluetoothProxy::loop() { characteristic_resp.handle = char_result.char_handle; characteristic_resp.properties = char_result.properties; char_offset++; + + // Get the number of descriptors directly with one call + uint16_t total_desc_count = 0; + esp_gatt_status_t desc_count_status = + esp_ble_gattc_get_attr_count(connection->get_gattc_if(), connection->get_conn_id(), ESP_GATT_DB_DESCRIPTOR, + char_result.char_handle, service_result.end_handle, 0, &total_desc_count); + + if (desc_count_status == ESP_GATT_OK && total_desc_count > 0) { + // Only reserve if we successfully got a count + characteristic_resp.descriptors.reserve(total_desc_count); + } else if (desc_count_status != ESP_GATT_OK) { + ESP_LOGW(TAG, "[%d] [%s] Error getting descriptor count for char handle %d, status=%d", + connection->get_connection_index(), connection->address_str().c_str(), char_result.char_handle, + desc_count_status); + } + + // Now process descriptors uint16_t desc_offset = 0; esp_gattc_descr_elem_t desc_result; while (true) { // descriptors @@ -453,6 +517,8 @@ void BluetoothProxy::subscribe_api_connection(api::APIConnection *api_connection this->api_connection_ = api_connection; this->raw_advertisements_ = flags & BluetoothProxySubscriptionFlag::SUBSCRIPTION_RAW_ADVERTISEMENTS; this->parent_->recalculate_advertisement_parser_types(); + + this->send_bluetooth_scanner_state_(this->parent_->get_scanner_state()); } void BluetoothProxy::unsubscribe_api_connection(api::APIConnection *api_connection) { @@ -525,6 +591,17 @@ void BluetoothProxy::send_device_unpairing(uint64_t address, bool success, esp_e this->api_connection_->send_bluetooth_device_unpairing_response(call); } +void BluetoothProxy::bluetooth_scanner_set_mode(bool active) { + if (this->parent_->get_scan_active() == active) { + return; + } + ESP_LOGD(TAG, "Setting scanner mode to %s", active ? "active" : "passive"); + this->parent_->set_scan_active(active); + this->parent_->stop_scan(); + this->parent_->set_scan_continuous( + true); // Set this to true to automatically start scanning again when it has cleaned up. +} + BluetoothProxy *global_bluetooth_proxy = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) } // namespace bluetooth_proxy diff --git a/esphome/components/bluetooth_proxy/bluetooth_proxy.h b/esphome/components/bluetooth_proxy/bluetooth_proxy.h index e0345ff248..de24165fe8 100644 --- a/esphome/components/bluetooth_proxy/bluetooth_proxy.h +++ b/esphome/components/bluetooth_proxy/bluetooth_proxy.h @@ -41,6 +41,7 @@ enum BluetoothProxyFeature : uint32_t { FEATURE_PAIRING = 1 << 3, FEATURE_CACHE_CLEARING = 1 << 4, FEATURE_RAW_ADVERTISEMENTS = 1 << 5, + FEATURE_STATE_AND_MODE = 1 << 6, }; enum BluetoothProxySubscriptionFlag : uint32_t { @@ -53,6 +54,7 @@ class BluetoothProxy : public esp32_ble_tracker::ESPBTDeviceListener, public Com bool parse_device(const esp32_ble_tracker::ESPBTDevice &device) override; bool parse_devices(esp_ble_gap_cb_param_t::ble_scan_result_evt_param *advertisements, size_t count) override; void dump_config() override; + void setup() override; void loop() override; esp32_ble_tracker::AdvertisementParserType get_advertisement_parser_type() override; @@ -84,6 +86,8 @@ class BluetoothProxy : public esp32_ble_tracker::ESPBTDeviceListener, public Com void send_device_unpairing(uint64_t address, bool success, esp_err_t error = ESP_OK); void send_device_clear_cache(uint64_t address, bool success, esp_err_t error = ESP_OK); + void bluetooth_scanner_set_mode(bool active); + static void uint64_to_bd_addr(uint64_t address, esp_bd_addr_t bd_addr) { bd_addr[0] = (address >> 40) & 0xff; bd_addr[1] = (address >> 32) & 0xff; @@ -107,6 +111,7 @@ class BluetoothProxy : public esp32_ble_tracker::ESPBTDeviceListener, public Com uint32_t flags = 0; flags |= BluetoothProxyFeature::FEATURE_PASSIVE_SCAN; flags |= BluetoothProxyFeature::FEATURE_RAW_ADVERTISEMENTS; + flags |= BluetoothProxyFeature::FEATURE_STATE_AND_MODE; if (this->active_) { flags |= BluetoothProxyFeature::FEATURE_ACTIVE_CONNECTIONS; flags |= BluetoothProxyFeature::FEATURE_REMOTE_CACHING; @@ -124,6 +129,7 @@ class BluetoothProxy : public esp32_ble_tracker::ESPBTDeviceListener, public Com protected: void send_api_packet_(const esp32_ble_tracker::ESPBTDevice &device); + void send_bluetooth_scanner_state_(esp32_ble_tracker::ScannerState state); BluetoothConnection *get_connection_(uint64_t address, bool reserve); diff --git a/esphome/components/button/__init__.py b/esphome/components/button/__init__.py index 366d0edf7d..b68334dd98 100644 --- a/esphome/components/button/__init__.py +++ b/esphome/components/button/__init__.py @@ -44,7 +44,7 @@ ButtonPressTrigger = button_ns.class_( validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True, space="_") -BUTTON_SCHEMA = ( +_BUTTON_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( @@ -60,15 +60,13 @@ BUTTON_SCHEMA = ( ) ) -_UNDEF = object() - def button_schema( class_: MockObjClass, *, - icon: str = _UNDEF, - entity_category: str = _UNDEF, - device_class: str = _UNDEF, + icon: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + device_class: str = cv.UNDEFINED, ) -> cv.Schema: schema = {cv.GenerateID(): cv.declare_id(class_)} @@ -77,10 +75,15 @@ def button_schema( (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), (CONF_DEVICE_CLASS, device_class, validate_device_class), ]: - if default is not _UNDEF: + if default is not cv.UNDEFINED: schema[cv.Optional(key, default=default)] = validator - return BUTTON_SCHEMA.extend(schema) + return _BUTTON_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +BUTTON_SCHEMA = button_schema(Button) +BUTTON_SCHEMA.add_extra(cv.deprecated_schema_constant("button")) async def setup_button_core_(var, config): diff --git a/esphome/components/canbus/canbus.cpp b/esphome/components/canbus/canbus.cpp index 696cfff2b7..3b86f209cd 100644 --- a/esphome/components/canbus/canbus.cpp +++ b/esphome/components/canbus/canbus.cpp @@ -86,6 +86,9 @@ void Canbus::loop() { data.push_back(can_message.data[i]); } + this->callback_manager_(can_message.can_id, can_message.use_extended_id, can_message.remote_transmission_request, + data); + // fire all triggers for (auto *trigger : this->triggers_) { if ((trigger->can_id_ == (can_message.can_id & trigger->can_id_mask_)) && diff --git a/esphome/components/canbus/canbus.h b/esphome/components/canbus/canbus.h index 1e5214fef4..7319bfb4ad 100644 --- a/esphome/components/canbus/canbus.h +++ b/esphome/components/canbus/canbus.h @@ -81,6 +81,20 @@ class Canbus : public Component { void set_bitrate(CanSpeed bit_rate) { this->bit_rate_ = bit_rate; } void add_trigger(CanbusTrigger *trigger); + /** + * Add a callback to be called when a CAN message is received. All received messages + * are passed to the callback without filtering. + * + * The callback function receives: + * - can_id of the received data + * - extended_id True if the can_id is an extended id + * - rtr If this is a remote transmission request + * - data The message data + */ + void add_callback( + std::function &data)> callback) { + this->callback_manager_.add(std::move(callback)); + } protected: template friend class CanbusSendAction; @@ -88,6 +102,8 @@ class Canbus : public Component { uint32_t can_id_; bool use_extended_id_; CanSpeed bit_rate_; + CallbackManager &data)> + callback_manager_{}; virtual bool setup_internal(); virtual Error send_message(struct CanFrame *frame); diff --git a/esphome/components/climate/__init__.py b/esphome/components/climate/__init__.py index 445507c620..7007dc13af 100644 --- a/esphome/components/climate/__init__.py +++ b/esphome/components/climate/__init__.py @@ -11,9 +11,11 @@ from esphome.const import ( CONF_CURRENT_TEMPERATURE_STATE_TOPIC, CONF_CUSTOM_FAN_MODE, CONF_CUSTOM_PRESET, + CONF_ENTITY_CATEGORY, CONF_FAN_MODE, CONF_FAN_MODE_COMMAND_TOPIC, CONF_FAN_MODE_STATE_TOPIC, + CONF_ICON, CONF_ID, CONF_MAX_TEMPERATURE, CONF_MIN_TEMPERATURE, @@ -46,6 +48,7 @@ from esphome.const import ( CONF_WEB_SERVER, ) from esphome.core import CORE, coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity IS_PLATFORM_COMPONENT = True @@ -151,12 +154,11 @@ ControlTrigger = climate_ns.class_( "ControlTrigger", automation.Trigger.template(ClimateCall.operator("ref")) ) -CLIMATE_SCHEMA = ( +_CLIMATE_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( { - cv.GenerateID(): cv.declare_id(Climate), cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id(mqtt.MQTTClimateComponent), cv.Optional(CONF_VISUAL, default={}): cv.Schema( { @@ -245,6 +247,31 @@ CLIMATE_SCHEMA = ( ) +def climate_schema( + class_: MockObjClass, + *, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, +) -> cv.Schema: + schema = { + cv.GenerateID(): cv.declare_id(class_), + } + + for key, default, validator in [ + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_ICON, icon, cv.icon), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _CLIMATE_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +CLIMATE_SCHEMA = climate_schema(Climate) +CLIMATE_SCHEMA.add_extra(cv.deprecated_schema_constant("climate")) + + async def setup_climate_core_(var, config): await setup_entity(var, config) @@ -419,6 +446,12 @@ async def register_climate(var, config): await setup_climate_core_(var, config) +async def new_climate(config, *args): + var = cg.new_Pvariable(config[CONF_ID], *args) + await register_climate(var, config) + return var + + CLIMATE_CONTROL_ACTION_SCHEMA = cv.Schema( { cv.Required(CONF_ID): cv.use_id(Climate), diff --git a/esphome/components/climate/climate_mode.h b/esphome/components/climate/climate_mode.h index c5245812c7..80efb4c048 100644 --- a/esphome/components/climate/climate_mode.h +++ b/esphome/components/climate/climate_mode.h @@ -20,7 +20,7 @@ enum ClimateMode : uint8_t { CLIMATE_MODE_FAN_ONLY = 4, /// The climate device is set to dry/humidity mode CLIMATE_MODE_DRY = 5, - /** The climate device is adjusting the temperatre dynamically. + /** The climate device is adjusting the temperature dynamically. * For example, the target temperature can be adjusted based on a schedule, or learned behavior. * The target temperature can't be adjusted when in this mode. */ diff --git a/esphome/components/climate/climate_traits.h b/esphome/components/climate/climate_traits.h index 58d7b586d7..c3a0dfca8f 100644 --- a/esphome/components/climate/climate_traits.h +++ b/esphome/components/climate/climate_traits.h @@ -40,24 +40,24 @@ namespace climate { */ class ClimateTraits { public: - bool get_supports_current_temperature() const { return supports_current_temperature_; } + bool get_supports_current_temperature() const { return this->supports_current_temperature_; } void set_supports_current_temperature(bool supports_current_temperature) { - supports_current_temperature_ = supports_current_temperature; + this->supports_current_temperature_ = supports_current_temperature; } - bool get_supports_current_humidity() const { return supports_current_humidity_; } + bool get_supports_current_humidity() const { return this->supports_current_humidity_; } void set_supports_current_humidity(bool supports_current_humidity) { - supports_current_humidity_ = supports_current_humidity; + this->supports_current_humidity_ = supports_current_humidity; } - bool get_supports_two_point_target_temperature() const { return supports_two_point_target_temperature_; } + bool get_supports_two_point_target_temperature() const { return this->supports_two_point_target_temperature_; } void set_supports_two_point_target_temperature(bool supports_two_point_target_temperature) { - supports_two_point_target_temperature_ = supports_two_point_target_temperature; + this->supports_two_point_target_temperature_ = supports_two_point_target_temperature; } - bool get_supports_target_humidity() const { return supports_target_humidity_; } + bool get_supports_target_humidity() const { return this->supports_target_humidity_; } void set_supports_target_humidity(bool supports_target_humidity) { - supports_target_humidity_ = supports_target_humidity; + this->supports_target_humidity_ = supports_target_humidity; } - void set_supported_modes(std::set modes) { supported_modes_ = std::move(modes); } - void add_supported_mode(ClimateMode mode) { supported_modes_.insert(mode); } + void set_supported_modes(std::set modes) { this->supported_modes_ = std::move(modes); } + void add_supported_mode(ClimateMode mode) { this->supported_modes_.insert(mode); } ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20") void set_supports_auto_mode(bool supports_auto_mode) { set_mode_support_(CLIMATE_MODE_AUTO, supports_auto_mode); } ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20") @@ -72,15 +72,15 @@ class ClimateTraits { } ESPDEPRECATED("This method is deprecated, use set_supported_modes() instead", "v1.20") void set_supports_dry_mode(bool supports_dry_mode) { set_mode_support_(CLIMATE_MODE_DRY, supports_dry_mode); } - bool supports_mode(ClimateMode mode) const { return supported_modes_.count(mode); } - const std::set &get_supported_modes() const { return supported_modes_; } + bool supports_mode(ClimateMode mode) const { return this->supported_modes_.count(mode); } + const std::set &get_supported_modes() const { return this->supported_modes_; } - void set_supports_action(bool supports_action) { supports_action_ = supports_action; } - bool get_supports_action() const { return supports_action_; } + void set_supports_action(bool supports_action) { this->supports_action_ = supports_action; } + bool get_supports_action() const { return this->supports_action_; } - void set_supported_fan_modes(std::set modes) { supported_fan_modes_ = std::move(modes); } - void add_supported_fan_mode(ClimateFanMode mode) { supported_fan_modes_.insert(mode); } - void add_supported_custom_fan_mode(const std::string &mode) { supported_custom_fan_modes_.insert(mode); } + void set_supported_fan_modes(std::set modes) { this->supported_fan_modes_ = std::move(modes); } + void add_supported_fan_mode(ClimateFanMode mode) { this->supported_fan_modes_.insert(mode); } + void add_supported_custom_fan_mode(const std::string &mode) { this->supported_custom_fan_modes_.insert(mode); } ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20") void set_supports_fan_mode_on(bool supported) { set_fan_mode_support_(CLIMATE_FAN_ON, supported); } ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20") @@ -99,35 +99,37 @@ class ClimateTraits { void set_supports_fan_mode_focus(bool supported) { set_fan_mode_support_(CLIMATE_FAN_FOCUS, supported); } ESPDEPRECATED("This method is deprecated, use set_supported_fan_modes() instead", "v1.20") void set_supports_fan_mode_diffuse(bool supported) { set_fan_mode_support_(CLIMATE_FAN_DIFFUSE, supported); } - bool supports_fan_mode(ClimateFanMode fan_mode) const { return supported_fan_modes_.count(fan_mode); } - bool get_supports_fan_modes() const { return !supported_fan_modes_.empty() || !supported_custom_fan_modes_.empty(); } - const std::set &get_supported_fan_modes() const { return supported_fan_modes_; } + bool supports_fan_mode(ClimateFanMode fan_mode) const { return this->supported_fan_modes_.count(fan_mode); } + bool get_supports_fan_modes() const { + return !this->supported_fan_modes_.empty() || !this->supported_custom_fan_modes_.empty(); + } + const std::set &get_supported_fan_modes() const { return this->supported_fan_modes_; } void set_supported_custom_fan_modes(std::set supported_custom_fan_modes) { - supported_custom_fan_modes_ = std::move(supported_custom_fan_modes); + this->supported_custom_fan_modes_ = std::move(supported_custom_fan_modes); } - const std::set &get_supported_custom_fan_modes() const { return supported_custom_fan_modes_; } + const std::set &get_supported_custom_fan_modes() const { return this->supported_custom_fan_modes_; } bool supports_custom_fan_mode(const std::string &custom_fan_mode) const { - return supported_custom_fan_modes_.count(custom_fan_mode); + return this->supported_custom_fan_modes_.count(custom_fan_mode); } - void set_supported_presets(std::set presets) { supported_presets_ = std::move(presets); } - void add_supported_preset(ClimatePreset preset) { supported_presets_.insert(preset); } - void add_supported_custom_preset(const std::string &preset) { supported_custom_presets_.insert(preset); } - bool supports_preset(ClimatePreset preset) const { return supported_presets_.count(preset); } - bool get_supports_presets() const { return !supported_presets_.empty(); } - const std::set &get_supported_presets() const { return supported_presets_; } + void set_supported_presets(std::set presets) { this->supported_presets_ = std::move(presets); } + void add_supported_preset(ClimatePreset preset) { this->supported_presets_.insert(preset); } + void add_supported_custom_preset(const std::string &preset) { this->supported_custom_presets_.insert(preset); } + bool supports_preset(ClimatePreset preset) const { return this->supported_presets_.count(preset); } + bool get_supports_presets() const { return !this->supported_presets_.empty(); } + const std::set &get_supported_presets() const { return this->supported_presets_; } void set_supported_custom_presets(std::set supported_custom_presets) { - supported_custom_presets_ = std::move(supported_custom_presets); + this->supported_custom_presets_ = std::move(supported_custom_presets); } - const std::set &get_supported_custom_presets() const { return supported_custom_presets_; } + const std::set &get_supported_custom_presets() const { return this->supported_custom_presets_; } bool supports_custom_preset(const std::string &custom_preset) const { - return supported_custom_presets_.count(custom_preset); + return this->supported_custom_presets_.count(custom_preset); } - void set_supported_swing_modes(std::set modes) { supported_swing_modes_ = std::move(modes); } - void add_supported_swing_mode(ClimateSwingMode mode) { supported_swing_modes_.insert(mode); } + void set_supported_swing_modes(std::set modes) { this->supported_swing_modes_ = std::move(modes); } + void add_supported_swing_mode(ClimateSwingMode mode) { this->supported_swing_modes_.insert(mode); } ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20") void set_supports_swing_mode_off(bool supported) { set_swing_mode_support_(CLIMATE_SWING_OFF, supported); } ESPDEPRECATED("This method is deprecated, use set_supported_swing_modes() instead", "v1.20") @@ -138,54 +140,58 @@ class ClimateTraits { void set_supports_swing_mode_horizontal(bool supported) { set_swing_mode_support_(CLIMATE_SWING_HORIZONTAL, supported); } - bool supports_swing_mode(ClimateSwingMode swing_mode) const { return supported_swing_modes_.count(swing_mode); } - bool get_supports_swing_modes() const { return !supported_swing_modes_.empty(); } - const std::set &get_supported_swing_modes() const { return supported_swing_modes_; } + bool supports_swing_mode(ClimateSwingMode swing_mode) const { return this->supported_swing_modes_.count(swing_mode); } + bool get_supports_swing_modes() const { return !this->supported_swing_modes_.empty(); } + const std::set &get_supported_swing_modes() const { return this->supported_swing_modes_; } - float get_visual_min_temperature() const { return visual_min_temperature_; } - void set_visual_min_temperature(float visual_min_temperature) { visual_min_temperature_ = visual_min_temperature; } - float get_visual_max_temperature() const { return visual_max_temperature_; } - void set_visual_max_temperature(float visual_max_temperature) { visual_max_temperature_ = visual_max_temperature; } - float get_visual_target_temperature_step() const { return visual_target_temperature_step_; } - float get_visual_current_temperature_step() const { return visual_current_temperature_step_; } + float get_visual_min_temperature() const { return this->visual_min_temperature_; } + void set_visual_min_temperature(float visual_min_temperature) { + this->visual_min_temperature_ = visual_min_temperature; + } + float get_visual_max_temperature() const { return this->visual_max_temperature_; } + void set_visual_max_temperature(float visual_max_temperature) { + this->visual_max_temperature_ = visual_max_temperature; + } + float get_visual_target_temperature_step() const { return this->visual_target_temperature_step_; } + float get_visual_current_temperature_step() const { return this->visual_current_temperature_step_; } void set_visual_target_temperature_step(float temperature_step) { - visual_target_temperature_step_ = temperature_step; + this->visual_target_temperature_step_ = temperature_step; } void set_visual_current_temperature_step(float temperature_step) { - visual_current_temperature_step_ = temperature_step; + this->visual_current_temperature_step_ = temperature_step; } void set_visual_temperature_step(float temperature_step) { - visual_target_temperature_step_ = temperature_step; - visual_current_temperature_step_ = temperature_step; + this->visual_target_temperature_step_ = temperature_step; + this->visual_current_temperature_step_ = temperature_step; } int8_t get_target_temperature_accuracy_decimals() const; int8_t get_current_temperature_accuracy_decimals() const; - float get_visual_min_humidity() const { return visual_min_humidity_; } - void set_visual_min_humidity(float visual_min_humidity) { visual_min_humidity_ = visual_min_humidity; } - float get_visual_max_humidity() const { return visual_max_humidity_; } - void set_visual_max_humidity(float visual_max_humidity) { visual_max_humidity_ = visual_max_humidity; } + float get_visual_min_humidity() const { return this->visual_min_humidity_; } + void set_visual_min_humidity(float visual_min_humidity) { this->visual_min_humidity_ = visual_min_humidity; } + float get_visual_max_humidity() const { return this->visual_max_humidity_; } + void set_visual_max_humidity(float visual_max_humidity) { this->visual_max_humidity_ = visual_max_humidity; } protected: void set_mode_support_(climate::ClimateMode mode, bool supported) { if (supported) { - supported_modes_.insert(mode); + this->supported_modes_.insert(mode); } else { - supported_modes_.erase(mode); + this->supported_modes_.erase(mode); } } void set_fan_mode_support_(climate::ClimateFanMode mode, bool supported) { if (supported) { - supported_fan_modes_.insert(mode); + this->supported_fan_modes_.insert(mode); } else { - supported_fan_modes_.erase(mode); + this->supported_fan_modes_.erase(mode); } } void set_swing_mode_support_(climate::ClimateSwingMode mode, bool supported) { if (supported) { - supported_swing_modes_.insert(mode); + this->supported_swing_modes_.insert(mode); } else { - supported_swing_modes_.erase(mode); + this->supported_swing_modes_.erase(mode); } } diff --git a/esphome/components/climate_ir_lg/climate_ir_lg.cpp b/esphome/components/climate_ir_lg/climate_ir_lg.cpp index c65f24ebc0..7fe0646230 100644 --- a/esphome/components/climate_ir_lg/climate_ir_lg.cpp +++ b/esphome/components/climate_ir_lg/climate_ir_lg.cpp @@ -32,7 +32,7 @@ const uint32_t FAN_MAX = 0x40; // Temperature const uint8_t TEMP_RANGE = TEMP_MAX - TEMP_MIN + 1; -const uint32_t TEMP_MASK = 0XF00; +const uint32_t TEMP_MASK = 0xF00; const uint32_t TEMP_SHIFT = 8; const uint16_t BITS = 28; @@ -43,11 +43,11 @@ void LgIrClimate::transmit_state() { // ESP_LOGD(TAG, "climate_lg_ir mode_before_ code: 0x%02X", modeBefore_); // Set command - if (send_swing_cmd_) { - send_swing_cmd_ = false; + if (this->send_swing_cmd_) { + this->send_swing_cmd_ = false; remote_state |= COMMAND_SWING; } else { - bool climate_is_off = (mode_before_ == climate::CLIMATE_MODE_OFF); + bool climate_is_off = (this->mode_before_ == climate::CLIMATE_MODE_OFF); switch (this->mode) { case climate::CLIMATE_MODE_COOL: remote_state |= climate_is_off ? COMMAND_ON_COOL : COMMAND_COOL; @@ -71,7 +71,7 @@ void LgIrClimate::transmit_state() { } } - mode_before_ = this->mode; + this->mode_before_ = this->mode; ESP_LOGD(TAG, "climate_lg_ir mode code: 0x%02X", this->mode); @@ -102,7 +102,7 @@ void LgIrClimate::transmit_state() { remote_state |= ((temp - 15) << TEMP_SHIFT); } - transmit_(remote_state); + this->transmit_(remote_state); this->publish_state(); } @@ -187,7 +187,7 @@ bool LgIrClimate::on_receive(remote_base::RemoteReceiveData data) { } void LgIrClimate::transmit_(uint32_t value) { - calc_checksum_(value); + this->calc_checksum_(value); ESP_LOGD(TAG, "Sending climate_lg_ir code: 0x%02" PRIX32, value); auto transmit = this->transmitter_->transmit(); diff --git a/esphome/components/climate_ir_lg/climate_ir_lg.h b/esphome/components/climate_ir_lg/climate_ir_lg.h index 7ee041b86f..00fc99ae73 100644 --- a/esphome/components/climate_ir_lg/climate_ir_lg.h +++ b/esphome/components/climate_ir_lg/climate_ir_lg.h @@ -21,7 +21,7 @@ class LgIrClimate : public climate_ir::ClimateIR { /// Override control to change settings of the climate device. void control(const climate::ClimateCall &call) override { - send_swing_cmd_ = call.get_swing_mode().has_value(); + this->send_swing_cmd_ = call.get_swing_mode().has_value(); // swing resets after unit powered off if (call.get_mode().has_value() && *call.get_mode() == climate::CLIMATE_MODE_OFF) this->swing_mode = climate::CLIMATE_SWING_OFF; diff --git a/esphome/components/color/__init__.py b/esphome/components/color/__init__.py index c3381cfd70..c39c5924af 100644 --- a/esphome/components/color/__init__.py +++ b/esphome/components/color/__init__.py @@ -3,6 +3,8 @@ from esphome.const import CONF_BLUE, CONF_GREEN, CONF_ID, CONF_RED, CONF_WHITE ColorStruct = cg.esphome_ns.struct("Color") +INSTANCE_TYPE = ColorStruct + MULTI_CONF = True CONF_RED_INT = "red_int" diff --git a/esphome/components/const/__init__.py b/esphome/components/const/__init__.py new file mode 100644 index 0000000000..6af357f23b --- /dev/null +++ b/esphome/components/const/__init__.py @@ -0,0 +1,5 @@ +"""Constants used by esphome components.""" + +CODEOWNERS = ["@esphome/core"] + +CONF_DRAW_ROUNDING = "draw_rounding" diff --git a/esphome/components/copy/cover/__init__.py b/esphome/components/copy/cover/__init__.py index 7db9034d02..ff5bef5668 100644 --- a/esphome/components/copy/cover/__init__.py +++ b/esphome/components/copy/cover/__init__.py @@ -5,7 +5,6 @@ from esphome.const import ( CONF_DEVICE_CLASS, CONF_ENTITY_CATEGORY, CONF_ICON, - CONF_ID, CONF_SOURCE_ID, ) from esphome.core.entity_helpers import inherit_property_from @@ -15,12 +14,15 @@ from .. import copy_ns CopyCover = copy_ns.class_("CopyCover", cover.Cover, cg.Component) -CONFIG_SCHEMA = cover.COVER_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(CopyCover), - cv.Required(CONF_SOURCE_ID): cv.use_id(cover.Cover), - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + cover.cover_schema(CopyCover) + .extend( + { + cv.Required(CONF_SOURCE_ID): cv.use_id(cover.Cover), + } + ) + .extend(cv.COMPONENT_SCHEMA) +) FINAL_VALIDATE_SCHEMA = cv.All( inherit_property_from(CONF_ICON, CONF_SOURCE_ID), @@ -30,8 +32,7 @@ FINAL_VALIDATE_SCHEMA = cv.All( async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) - await cover.register_cover(var, config) + var = await cover.new_cover(config) await cg.register_component(var, config) source = await cg.get_variable(config[CONF_SOURCE_ID]) diff --git a/esphome/components/copy/lock/__init__.py b/esphome/components/copy/lock/__init__.py index ddedea64c0..46bc08273e 100644 --- a/esphome/components/copy/lock/__init__.py +++ b/esphome/components/copy/lock/__init__.py @@ -1,7 +1,7 @@ import esphome.codegen as cg from esphome.components import lock import esphome.config_validation as cv -from esphome.const import CONF_ENTITY_CATEGORY, CONF_ICON, CONF_ID, CONF_SOURCE_ID +from esphome.const import CONF_ENTITY_CATEGORY, CONF_ICON, CONF_SOURCE_ID from esphome.core.entity_helpers import inherit_property_from from .. import copy_ns @@ -9,12 +9,15 @@ from .. import copy_ns CopyLock = copy_ns.class_("CopyLock", lock.Lock, cg.Component) -CONFIG_SCHEMA = lock.LOCK_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(CopyLock), - cv.Required(CONF_SOURCE_ID): cv.use_id(lock.Lock), - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + lock.lock_schema(CopyLock) + .extend( + { + cv.Required(CONF_SOURCE_ID): cv.use_id(lock.Lock), + } + ) + .extend(cv.COMPONENT_SCHEMA) +) FINAL_VALIDATE_SCHEMA = cv.All( inherit_property_from(CONF_ICON, CONF_SOURCE_ID), @@ -23,8 +26,7 @@ FINAL_VALIDATE_SCHEMA = cv.All( async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) - await lock.register_lock(var, config) + var = await lock.new_lock(config) await cg.register_component(var, config) source = await cg.get_variable(config[CONF_SOURCE_ID]) diff --git a/esphome/components/copy/text/__init__.py b/esphome/components/copy/text/__init__.py index aa39225bc2..f1ca404b7b 100644 --- a/esphome/components/copy/text/__init__.py +++ b/esphome/components/copy/text/__init__.py @@ -9,12 +9,15 @@ from .. import copy_ns CopyText = copy_ns.class_("CopyText", text.Text, cg.Component) -CONFIG_SCHEMA = text.TEXT_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(CopyText), - cv.Required(CONF_SOURCE_ID): cv.use_id(text.Text), - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + text.text_schema(CopyText) + .extend( + { + cv.Required(CONF_SOURCE_ID): cv.use_id(text.Text), + } + ) + .extend(cv.COMPONENT_SCHEMA) +) FINAL_VALIDATE_SCHEMA = cv.All( inherit_property_from(CONF_ICON, CONF_SOURCE_ID), diff --git a/esphome/components/cover/__init__.py b/esphome/components/cover/__init__.py index e7e3ac3bb0..13f117c3f0 100644 --- a/esphome/components/cover/__init__.py +++ b/esphome/components/cover/__init__.py @@ -5,6 +5,8 @@ from esphome.components import mqtt, web_server import esphome.config_validation as cv from esphome.const import ( CONF_DEVICE_CLASS, + CONF_ENTITY_CATEGORY, + CONF_ICON, CONF_ID, CONF_MQTT_ID, CONF_ON_OPEN, @@ -31,6 +33,7 @@ from esphome.const import ( DEVICE_CLASS_WINDOW, ) from esphome.core import CORE, coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity IS_PLATFORM_COMPONENT = True @@ -89,12 +92,11 @@ CoverClosedTrigger = cover_ns.class_( CONF_ON_CLOSED = "on_closed" -COVER_SCHEMA = ( +_COVER_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( { - cv.GenerateID(): cv.declare_id(Cover), cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id(mqtt.MQTTCoverComponent), cv.Optional(CONF_DEVICE_CLASS): cv.one_of(*DEVICE_CLASSES, lower=True), cv.Optional(CONF_POSITION_COMMAND_TOPIC): cv.All( @@ -124,6 +126,33 @@ COVER_SCHEMA = ( ) +def cover_schema( + class_: MockObjClass, + *, + device_class: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, +) -> cv.Schema: + schema = { + cv.GenerateID(): cv.declare_id(class_), + } + + for key, default, validator in [ + (CONF_DEVICE_CLASS, device_class, cv.one_of(*DEVICE_CLASSES, lower=True)), + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_ICON, icon, cv.icon), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _COVER_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +COVER_SCHEMA = cover_schema(Cover) +COVER_SCHEMA.add_extra(cv.deprecated_schema_constant("cover")) + + async def setup_cover_core_(var, config): await setup_entity(var, config) @@ -163,6 +192,12 @@ async def register_cover(var, config): await setup_cover_core_(var, config) +async def new_cover(config, *args): + var = cg.new_Pvariable(config[CONF_ID], *args) + await register_cover(var, config) + return var + + COVER_ACTION_SCHEMA = maybe_simple_id( { cv.Required(CONF_ID): cv.use_id(Cover), diff --git a/esphome/components/cst226/binary_sensor/__init__.py b/esphome/components/cst226/binary_sensor/__init__.py new file mode 100644 index 0000000000..d95f0d2b4d --- /dev/null +++ b/esphome/components/cst226/binary_sensor/__init__.py @@ -0,0 +1,28 @@ +import esphome.codegen as cg +from esphome.components import binary_sensor +import esphome.config_validation as cv + +from .. import cst226_ns +from ..touchscreen import CST226ButtonListener, CST226Touchscreen + +CONF_CST226_ID = "cst226_id" + +CST226Button = cst226_ns.class_( + "CST226Button", + binary_sensor.BinarySensor, + cg.Component, + CST226ButtonListener, + cg.Parented.template(CST226Touchscreen), +) + +CONFIG_SCHEMA = binary_sensor.binary_sensor_schema(CST226Button).extend( + { + cv.GenerateID(CONF_CST226_ID): cv.use_id(CST226Touchscreen), + } +) + + +async def to_code(config): + var = await binary_sensor.new_binary_sensor(config) + await cg.register_component(var, config) + await cg.register_parented(var, config[CONF_CST226_ID]) diff --git a/esphome/components/cst226/binary_sensor/cs226_button.h b/esphome/components/cst226/binary_sensor/cs226_button.h new file mode 100644 index 0000000000..6d409df04f --- /dev/null +++ b/esphome/components/cst226/binary_sensor/cs226_button.h @@ -0,0 +1,22 @@ +#pragma once + +#include "esphome/components/binary_sensor/binary_sensor.h" +#include "../touchscreen/cst226_touchscreen.h" +#include "esphome/core/helpers.h" + +namespace esphome { +namespace cst226 { + +class CST226Button : public binary_sensor::BinarySensor, + public Component, + public CST226ButtonListener, + public Parented { + public: + void setup() override; + void dump_config() override; + + void update_button(bool state) override; +}; + +} // namespace cst226 +} // namespace esphome diff --git a/esphome/components/cst226/binary_sensor/cstt6_button.cpp b/esphome/components/cst226/binary_sensor/cstt6_button.cpp new file mode 100644 index 0000000000..c481ce5d57 --- /dev/null +++ b/esphome/components/cst226/binary_sensor/cstt6_button.cpp @@ -0,0 +1,19 @@ +#include "cs226_button.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace cst226 { + +static const char *const TAG = "CST226.binary_sensor"; + +void CST226Button::setup() { + this->parent_->register_button_listener(this); + this->publish_initial_state(false); +} + +void CST226Button::dump_config() { LOG_BINARY_SENSOR("", "CST226 Button", this); } + +void CST226Button::update_button(bool state) { this->publish_state(state); } + +} // namespace cst226 +} // namespace esphome diff --git a/esphome/components/cst226/touchscreen/cst226_touchscreen.cpp b/esphome/components/cst226/touchscreen/cst226_touchscreen.cpp index a25859fe17..fa8cd9b057 100644 --- a/esphome/components/cst226/touchscreen/cst226_touchscreen.cpp +++ b/esphome/components/cst226/touchscreen/cst226_touchscreen.cpp @@ -3,8 +3,10 @@ namespace esphome { namespace cst226 { +static const char *const TAG = "cst226.touchscreen"; + void CST226Touchscreen::setup() { - esph_log_config(TAG, "Setting up CST226 Touchscreen..."); + ESP_LOGCONFIG(TAG, "Setting up CST226 Touchscreen..."); if (this->reset_pin_ != nullptr) { this->reset_pin_->setup(); this->reset_pin_->digital_write(true); @@ -26,6 +28,11 @@ void CST226Touchscreen::update_touches() { return; } this->status_clear_warning(); + if (data[0] == 0x83 && data[1] == 0x17 && data[5] == 0x80) { + this->update_button_state_(true); + return; + } + this->update_button_state_(false); if (data[6] != 0xAB || data[0] == 0xAB || data[5] == 0x80) { this->skip_update_ = true; return; @@ -43,13 +50,21 @@ void CST226Touchscreen::update_touches() { int16_t y = (data[index + 2] << 4) | (data[index + 3] & 0x0F); int16_t z = data[index + 4]; this->add_raw_touch_position_(id, x, y, z); - esph_log_v(TAG, "Read touch %d: %d/%d", id, x, y); + ESP_LOGV(TAG, "Read touch %d: %d/%d", id, x, y); index += 5; if (i == 0) index += 2; } } +bool CST226Touchscreen::read16_(uint16_t addr, uint8_t *data, size_t len) { + if (this->read_register16(addr, data, len) != i2c::ERROR_OK) { + ESP_LOGE(TAG, "Read data from 0x%04X failed", addr); + this->mark_failed(); + return false; + } + return true; +} void CST226Touchscreen::continue_setup_() { uint8_t buffer[8]; if (this->interrupt_pin_ != nullptr) { @@ -58,7 +73,7 @@ void CST226Touchscreen::continue_setup_() { } buffer[0] = 0xD1; if (this->write_register16(0xD1, buffer, 1) != i2c::ERROR_OK) { - esph_log_e(TAG, "Write byte to 0xD1 failed"); + ESP_LOGE(TAG, "Write byte to 0xD1 failed"); this->mark_failed(); return; } @@ -66,7 +81,7 @@ void CST226Touchscreen::continue_setup_() { if (this->read16_(0xD204, buffer, 4)) { uint16_t chip_id = buffer[2] + (buffer[3] << 8); uint16_t project_id = buffer[0] + (buffer[1] << 8); - esph_log_config(TAG, "Chip ID %X, project ID %x", chip_id, project_id); + ESP_LOGCONFIG(TAG, "Chip ID %X, project ID %x", chip_id, project_id); } if (this->x_raw_max_ == 0 || this->y_raw_max_ == 0) { if (this->read16_(0xD1F8, buffer, 4)) { @@ -80,7 +95,14 @@ void CST226Touchscreen::continue_setup_() { } } this->setup_complete_ = true; - esph_log_config(TAG, "CST226 Touchscreen setup complete"); + ESP_LOGCONFIG(TAG, "CST226 Touchscreen setup complete"); +} +void CST226Touchscreen::update_button_state_(bool state) { + if (this->button_touched_ == state) + return; + this->button_touched_ = state; + for (auto *listener : this->button_listeners_) + listener->update_button(state); } void CST226Touchscreen::dump_config() { diff --git a/esphome/components/cst226/touchscreen/cst226_touchscreen.h b/esphome/components/cst226/touchscreen/cst226_touchscreen.h index 9f518e5068..c744e51fec 100644 --- a/esphome/components/cst226/touchscreen/cst226_touchscreen.h +++ b/esphome/components/cst226/touchscreen/cst226_touchscreen.h @@ -9,10 +9,13 @@ namespace esphome { namespace cst226 { -static const char *const TAG = "cst226.touchscreen"; - static const uint8_t CST226_REG_STATUS = 0x00; +class CST226ButtonListener { + public: + virtual void update_button(bool state) = 0; +}; + class CST226Touchscreen : public touchscreen::Touchscreen, public i2c::I2CDevice { public: void setup() override; @@ -22,22 +25,19 @@ class CST226Touchscreen : public touchscreen::Touchscreen, public i2c::I2CDevice void set_interrupt_pin(InternalGPIOPin *pin) { this->interrupt_pin_ = pin; } void set_reset_pin(GPIOPin *pin) { this->reset_pin_ = pin; } bool can_proceed() override { return this->setup_complete_ || this->is_failed(); } + void register_button_listener(CST226ButtonListener *listener) { this->button_listeners_.push_back(listener); } protected: - bool read16_(uint16_t addr, uint8_t *data, size_t len) { - if (this->read_register16(addr, data, len) != i2c::ERROR_OK) { - esph_log_e(TAG, "Read data from 0x%04X failed", addr); - this->mark_failed(); - return false; - } - return true; - } + bool read16_(uint16_t addr, uint8_t *data, size_t len); void continue_setup_(); + void update_button_state_(bool state); InternalGPIOPin *interrupt_pin_{}; GPIOPin *reset_pin_{}; uint8_t chip_id_{}; bool setup_complete_{}; + std::vector button_listeners_; + bool button_touched_{}; }; } // namespace cst226 diff --git a/esphome/components/current_based/cover.py b/esphome/components/current_based/cover.py index 75f083ef14..99952adb12 100644 --- a/esphome/components/current_based/cover.py +++ b/esphome/components/current_based/cover.py @@ -5,7 +5,6 @@ import esphome.config_validation as cv from esphome.const import ( CONF_CLOSE_ACTION, CONF_CLOSE_DURATION, - CONF_ID, CONF_MAX_DURATION, CONF_OPEN_ACTION, CONF_OPEN_DURATION, @@ -30,45 +29,47 @@ CurrentBasedCover = current_based_ns.class_( "CurrentBasedCover", cover.Cover, cg.Component ) -CONFIG_SCHEMA = cover.COVER_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(CurrentBasedCover), - cv.Required(CONF_STOP_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_OPEN_SENSOR): cv.use_id(sensor.Sensor), - cv.Required(CONF_OPEN_MOVING_CURRENT_THRESHOLD): cv.float_range( - min=0, min_included=False - ), - cv.Optional(CONF_OPEN_OBSTACLE_CURRENT_THRESHOLD): cv.float_range( - min=0, min_included=False - ), - cv.Required(CONF_OPEN_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_OPEN_DURATION): cv.positive_time_period_milliseconds, - cv.Required(CONF_CLOSE_SENSOR): cv.use_id(sensor.Sensor), - cv.Required(CONF_CLOSE_MOVING_CURRENT_THRESHOLD): cv.float_range( - min=0, min_included=False - ), - cv.Optional(CONF_CLOSE_OBSTACLE_CURRENT_THRESHOLD): cv.float_range( - min=0, min_included=False - ), - cv.Required(CONF_CLOSE_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_CLOSE_DURATION): cv.positive_time_period_milliseconds, - cv.Optional(CONF_OBSTACLE_ROLLBACK, default="10%"): cv.percentage, - cv.Optional(CONF_MAX_DURATION): cv.positive_time_period_milliseconds, - cv.Optional(CONF_MALFUNCTION_DETECTION, default=True): cv.boolean, - cv.Optional(CONF_MALFUNCTION_ACTION): automation.validate_automation( - single=True - ), - cv.Optional( - CONF_START_SENSING_DELAY, default="500ms" - ): cv.positive_time_period_milliseconds, - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + cover.cover_schema(CurrentBasedCover) + .extend( + { + cv.Required(CONF_STOP_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_OPEN_SENSOR): cv.use_id(sensor.Sensor), + cv.Required(CONF_OPEN_MOVING_CURRENT_THRESHOLD): cv.float_range( + min=0, min_included=False + ), + cv.Optional(CONF_OPEN_OBSTACLE_CURRENT_THRESHOLD): cv.float_range( + min=0, min_included=False + ), + cv.Required(CONF_OPEN_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_OPEN_DURATION): cv.positive_time_period_milliseconds, + cv.Required(CONF_CLOSE_SENSOR): cv.use_id(sensor.Sensor), + cv.Required(CONF_CLOSE_MOVING_CURRENT_THRESHOLD): cv.float_range( + min=0, min_included=False + ), + cv.Optional(CONF_CLOSE_OBSTACLE_CURRENT_THRESHOLD): cv.float_range( + min=0, min_included=False + ), + cv.Required(CONF_CLOSE_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_CLOSE_DURATION): cv.positive_time_period_milliseconds, + cv.Optional(CONF_OBSTACLE_ROLLBACK, default="10%"): cv.percentage, + cv.Optional(CONF_MAX_DURATION): cv.positive_time_period_milliseconds, + cv.Optional(CONF_MALFUNCTION_DETECTION, default=True): cv.boolean, + cv.Optional(CONF_MALFUNCTION_ACTION): automation.validate_automation( + single=True + ), + cv.Optional( + CONF_START_SENSING_DELAY, default="500ms" + ): cv.positive_time_period_milliseconds, + } + ) + .extend(cv.COMPONENT_SCHEMA) +) async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await cover.new_cover(config) await cg.register_component(var, config) - await cover.register_cover(var, config) await automation.build_automation( var.get_stop_trigger(), [], config[CONF_STOP_ACTION] diff --git a/esphome/components/daikin/daikin.cpp b/esphome/components/daikin/daikin.cpp index bb8587fbeb..359c63aeca 100644 --- a/esphome/components/daikin/daikin.cpp +++ b/esphome/components/daikin/daikin.cpp @@ -65,7 +65,7 @@ void DaikinClimate::transmit_state() { transmit.perform(); } -uint8_t DaikinClimate::operation_mode_() { +uint8_t DaikinClimate::operation_mode_() const { uint8_t operating_mode = DAIKIN_MODE_ON; switch (this->mode) { case climate::CLIMATE_MODE_COOL: @@ -92,9 +92,12 @@ uint8_t DaikinClimate::operation_mode_() { return operating_mode; } -uint16_t DaikinClimate::fan_speed_() { +uint16_t DaikinClimate::fan_speed_() const { uint16_t fan_speed; switch (this->fan_mode.value()) { + case climate::CLIMATE_FAN_QUIET: + fan_speed = DAIKIN_FAN_SILENT << 8; + break; case climate::CLIMATE_FAN_LOW: fan_speed = DAIKIN_FAN_1 << 8; break; @@ -126,12 +129,11 @@ uint16_t DaikinClimate::fan_speed_() { return fan_speed; } -uint8_t DaikinClimate::temperature_() { +uint8_t DaikinClimate::temperature_() const { // Force special temperatures depending on the mode switch (this->mode) { case climate::CLIMATE_MODE_FAN_ONLY: return 0x32; - case climate::CLIMATE_MODE_HEAT_COOL: case climate::CLIMATE_MODE_DRY: return 0xc0; default: @@ -148,19 +150,25 @@ bool DaikinClimate::parse_state_frame_(const uint8_t frame[]) { if (frame[DAIKIN_STATE_FRAME_SIZE - 1] != checksum) return false; uint8_t mode = frame[5]; + // Temperature is given in degrees celcius * 2 + // only update for states that use the temperature + uint8_t temperature = frame[6]; if (mode & DAIKIN_MODE_ON) { switch (mode & 0xF0) { case DAIKIN_MODE_COOL: this->mode = climate::CLIMATE_MODE_COOL; + this->target_temperature = static_cast(temperature * 0.5f); break; case DAIKIN_MODE_DRY: this->mode = climate::CLIMATE_MODE_DRY; break; case DAIKIN_MODE_HEAT: this->mode = climate::CLIMATE_MODE_HEAT; + this->target_temperature = static_cast(temperature * 0.5f); break; case DAIKIN_MODE_AUTO: this->mode = climate::CLIMATE_MODE_HEAT_COOL; + this->target_temperature = static_cast(temperature * 0.5f); break; case DAIKIN_MODE_FAN: this->mode = climate::CLIMATE_MODE_FAN_ONLY; @@ -169,10 +177,6 @@ bool DaikinClimate::parse_state_frame_(const uint8_t frame[]) { } else { this->mode = climate::CLIMATE_MODE_OFF; } - uint8_t temperature = frame[6]; - if (!(temperature & 0xC0)) { - this->target_temperature = temperature >> 1; - } uint8_t fan_mode = frame[8]; uint8_t swing_mode = frame[9]; if (fan_mode & 0xF && swing_mode & 0xF) { @@ -187,7 +191,6 @@ bool DaikinClimate::parse_state_frame_(const uint8_t frame[]) { switch (fan_mode & 0xF0) { case DAIKIN_FAN_1: case DAIKIN_FAN_2: - case DAIKIN_FAN_SILENT: this->fan_mode = climate::CLIMATE_FAN_LOW; break; case DAIKIN_FAN_3: @@ -200,6 +203,9 @@ bool DaikinClimate::parse_state_frame_(const uint8_t frame[]) { case DAIKIN_FAN_AUTO: this->fan_mode = climate::CLIMATE_FAN_AUTO; break; + case DAIKIN_FAN_SILENT: + this->fan_mode = climate::CLIMATE_FAN_QUIET; + break; } this->publish_state(); return true; diff --git a/esphome/components/daikin/daikin.h b/esphome/components/daikin/daikin.h index b4ac309de9..159292cb55 100644 --- a/esphome/components/daikin/daikin.h +++ b/esphome/components/daikin/daikin.h @@ -44,17 +44,17 @@ class DaikinClimate : public climate_ir::ClimateIR { public: DaikinClimate() : climate_ir::ClimateIR(DAIKIN_TEMP_MIN, DAIKIN_TEMP_MAX, 1.0f, true, true, - {climate::CLIMATE_FAN_AUTO, climate::CLIMATE_FAN_LOW, climate::CLIMATE_FAN_MEDIUM, - climate::CLIMATE_FAN_HIGH}, + {climate::CLIMATE_FAN_QUIET, climate::CLIMATE_FAN_AUTO, climate::CLIMATE_FAN_LOW, + climate::CLIMATE_FAN_MEDIUM, climate::CLIMATE_FAN_HIGH}, {climate::CLIMATE_SWING_OFF, climate::CLIMATE_SWING_VERTICAL, climate::CLIMATE_SWING_HORIZONTAL, climate::CLIMATE_SWING_BOTH}) {} protected: // Transmit via IR the state of this climate controller. void transmit_state() override; - uint8_t operation_mode_(); - uint16_t fan_speed_(); - uint8_t temperature_(); + uint8_t operation_mode_() const; + uint16_t fan_speed_() const; + uint8_t temperature_() const; // Handle received IR Buffer bool on_receive(remote_base::RemoteReceiveData data) override; bool parse_state_frame_(const uint8_t frame[]); diff --git a/esphome/components/dallas_temp/dallas_temp.cpp b/esphome/components/dallas_temp/dallas_temp.cpp index ae567d6a76..46db22d97f 100644 --- a/esphome/components/dallas_temp/dallas_temp.cpp +++ b/esphome/components/dallas_temp/dallas_temp.cpp @@ -56,21 +56,13 @@ void DallasTemperatureSensor::update() { }); } -void IRAM_ATTR DallasTemperatureSensor::read_scratch_pad_int_() { - for (uint8_t &i : this->scratch_pad_) { - i = this->bus_->read8(); - } -} - bool DallasTemperatureSensor::read_scratch_pad_() { - bool success; - { - InterruptLock lock; - success = this->send_command_(DALLAS_COMMAND_READ_SCRATCH_PAD); - if (success) - this->read_scratch_pad_int_(); - } - if (!success) { + bool success = this->send_command_(DALLAS_COMMAND_READ_SCRATCH_PAD); + if (success) { + for (uint8_t &i : this->scratch_pad_) { + i = this->bus_->read8(); + } + } else { ESP_LOGW(TAG, "'%s' - reading scratch pad failed bus reset", this->get_name().c_str()); this->status_set_warning("bus reset failed"); } @@ -113,17 +105,14 @@ void DallasTemperatureSensor::setup() { return; this->scratch_pad_[4] = res; - { - InterruptLock lock; - if (this->send_command_(DALLAS_COMMAND_WRITE_SCRATCH_PAD)) { - this->bus_->write8(this->scratch_pad_[2]); // high alarm temp - this->bus_->write8(this->scratch_pad_[3]); // low alarm temp - this->bus_->write8(this->scratch_pad_[4]); // resolution - } - - // write value to EEPROM - this->send_command_(DALLAS_COMMAND_COPY_SCRATCH_PAD); + if (this->send_command_(DALLAS_COMMAND_WRITE_SCRATCH_PAD)) { + this->bus_->write8(this->scratch_pad_[2]); // high alarm temp + this->bus_->write8(this->scratch_pad_[3]); // low alarm temp + this->bus_->write8(this->scratch_pad_[4]); // resolution } + + // write value to EEPROM + this->send_command_(DALLAS_COMMAND_COPY_SCRATCH_PAD); } bool DallasTemperatureSensor::check_scratch_pad_() { @@ -138,6 +127,10 @@ bool DallasTemperatureSensor::check_scratch_pad_() { if (!chksum_validity) { ESP_LOGW(TAG, "'%s' - Scratch pad checksum invalid!", this->get_name().c_str()); this->status_set_warning("scratch pad checksum invalid"); + ESP_LOGD(TAG, "Scratch pad: %02X.%02X.%02X.%02X.%02X.%02X.%02X.%02X.%02X (%02X)", this->scratch_pad_[0], + this->scratch_pad_[1], this->scratch_pad_[2], this->scratch_pad_[3], this->scratch_pad_[4], + this->scratch_pad_[5], this->scratch_pad_[6], this->scratch_pad_[7], this->scratch_pad_[8], + crc8(this->scratch_pad_, 8)); } return chksum_validity; } diff --git a/esphome/components/dallas_temp/dallas_temp.h b/esphome/components/dallas_temp/dallas_temp.h index 604c9d0cd7..1bd2865095 100644 --- a/esphome/components/dallas_temp/dallas_temp.h +++ b/esphome/components/dallas_temp/dallas_temp.h @@ -23,7 +23,6 @@ class DallasTemperatureSensor : public PollingComponent, public sensor::Sensor, /// Get the number of milliseconds we have to wait for the conversion phase. uint16_t millis_to_wait_for_conversion_() const; bool read_scratch_pad_(); - void read_scratch_pad_int_(); bool check_scratch_pad_(); float get_temp_c_(); }; diff --git a/esphome/components/debug/debug_component.cpp b/esphome/components/debug/debug_component.cpp index 7d25bf5472..5bcc676247 100644 --- a/esphome/components/debug/debug_component.cpp +++ b/esphome/components/debug/debug_component.cpp @@ -1,6 +1,7 @@ #include "debug_component.h" #include +#include "esphome/core/application.h" #include "esphome/core/log.h" #include "esphome/core/hal.h" #include "esphome/core/helpers.h" @@ -25,6 +26,7 @@ void DebugComponent::dump_config() { #ifdef USE_SENSOR LOG_SENSOR(" ", "Free space on heap", this->free_sensor_); LOG_SENSOR(" ", "Largest free heap block", this->block_sensor_); + LOG_SENSOR(" ", "CPU frequency", this->cpu_frequency_sensor_); #if defined(USE_ESP8266) && USE_ARDUINO_VERSION_CODE >= VERSION_CODE(2, 5, 2) LOG_SENSOR(" ", "Heap fragmentation", this->fragmentation_sensor_); #endif // defined(USE_ESP8266) && USE_ARDUINO_VERSION_CODE >= VERSION_CODE(2, 5, 2) @@ -86,6 +88,9 @@ void DebugComponent::update() { this->loop_time_sensor_->publish_state(this->max_loop_time_); this->max_loop_time_ = 0; } + if (this->cpu_frequency_sensor_ != nullptr) { + this->cpu_frequency_sensor_->publish_state(arch_get_cpu_freq_hz()); + } #endif // USE_SENSOR update_platform_(); diff --git a/esphome/components/debug/debug_component.h b/esphome/components/debug/debug_component.h index 608addb4a3..a55cc7bf44 100644 --- a/esphome/components/debug/debug_component.h +++ b/esphome/components/debug/debug_component.h @@ -34,8 +34,12 @@ class DebugComponent : public PollingComponent { #endif void set_loop_time_sensor(sensor::Sensor *loop_time_sensor) { loop_time_sensor_ = loop_time_sensor; } #ifdef USE_ESP32 + void on_shutdown() override; void set_psram_sensor(sensor::Sensor *psram_sensor) { this->psram_sensor_ = psram_sensor; } #endif // USE_ESP32 + void set_cpu_frequency_sensor(sensor::Sensor *cpu_frequency_sensor) { + this->cpu_frequency_sensor_ = cpu_frequency_sensor; + } #endif // USE_SENSOR protected: uint32_t free_heap_{}; @@ -53,6 +57,7 @@ class DebugComponent : public PollingComponent { #ifdef USE_ESP32 sensor::Sensor *psram_sensor_{nullptr}; #endif // USE_ESP32 + sensor::Sensor *cpu_frequency_sensor_{nullptr}; #endif // USE_SENSOR #ifdef USE_ESP32 @@ -75,6 +80,7 @@ class DebugComponent : public PollingComponent { #endif // USE_TEXT_SENSOR std::string get_reset_reason_(); + std::string get_wakeup_cause_(); uint32_t get_free_heap_(); void get_device_info_(std::string &device_info); void update_platform_(); diff --git a/esphome/components/debug/debug_esp32.cpp b/esphome/components/debug/debug_esp32.cpp index caa9f8d743..999cb927b3 100644 --- a/esphome/components/debug/debug_esp32.cpp +++ b/esphome/components/debug/debug_esp32.cpp @@ -1,25 +1,18 @@ #include "debug_component.h" + #ifdef USE_ESP32 +#include "esphome/core/application.h" #include "esphome/core/log.h" +#include "esphome/core/hal.h" +#include #include #include #include #include -#if defined(USE_ESP32_VARIANT_ESP32) -#include -#elif defined(USE_ESP32_VARIANT_ESP32C3) -#include -#elif defined(USE_ESP32_VARIANT_ESP32C6) -#include -#elif defined(USE_ESP32_VARIANT_ESP32S2) -#include -#elif defined(USE_ESP32_VARIANT_ESP32S3) -#include -#elif defined(USE_ESP32_VARIANT_ESP32H2) -#include -#endif +#include + #ifdef USE_ARDUINO #include #endif @@ -29,6 +22,90 @@ namespace debug { static const char *const TAG = "debug"; +// index by values returned by esp_reset_reason + +static const char *const RESET_REASONS[] = { + "unknown source", + "power-on event", + "external pin", + "software via esp_restart", + "exception/panic", + "interrupt watchdog", + "task watchdog", + "other watchdogs", + "exiting deep sleep mode", + "brownout", + "SDIO", + "USB peripheral", + "JTAG", + "efuse error", + "power glitch detected", + "CPU lock up", +}; + +static const char *const REBOOT_KEY = "reboot_source"; +static const size_t REBOOT_MAX_LEN = 24; + +// on shutdown, store the source of the reboot request +void DebugComponent::on_shutdown() { + auto *component = App.get_current_component(); + char buffer[REBOOT_MAX_LEN]{}; + auto pref = global_preferences->make_preference(REBOOT_MAX_LEN, fnv1_hash(REBOOT_KEY + App.get_name())); + if (component != nullptr) { + strncpy(buffer, component->get_component_source(), REBOOT_MAX_LEN - 1); + } + ESP_LOGD(TAG, "Storing reboot source: %s", buffer); + pref.save(&buffer); + global_preferences->sync(); +} + +std::string DebugComponent::get_reset_reason_() { + std::string reset_reason; + unsigned reason = esp_reset_reason(); + if (reason < sizeof(RESET_REASONS) / sizeof(RESET_REASONS[0])) { + reset_reason = RESET_REASONS[reason]; + if (reason == ESP_RST_SW) { + auto pref = global_preferences->make_preference(REBOOT_MAX_LEN, fnv1_hash(REBOOT_KEY + App.get_name())); + char buffer[REBOOT_MAX_LEN]{}; + if (pref.load(&buffer)) { + reset_reason = "Reboot request from " + std::string(buffer); + } + } + } else { + reset_reason = "unknown source"; + } + ESP_LOGD(TAG, "Reset Reason: %s", reset_reason.c_str()); + return reset_reason; +} + +static const char *const WAKEUP_CAUSES[] = { + "undefined", + "undefined", + "external signal using RTC_IO", + "external signal using RTC_CNTL", + "timer", + "touchpad", + "ULP program", + "GPIO", + "UART", + "WIFI", + "COCPU int", + "COCPU crash", + "BT", +}; + +std::string DebugComponent::get_wakeup_cause_() { + const char *wake_reason; + unsigned reason = esp_sleep_get_wakeup_cause(); + if (reason < sizeof(WAKEUP_CAUSES) / sizeof(WAKEUP_CAUSES[0])) { + wake_reason = WAKEUP_CAUSES[reason]; + } else { + wake_reason = "unknown source"; + } + ESP_LOGD(TAG, "Wakeup Reason: %s", wake_reason); + return wake_reason; +} + void DebugComponent::log_partition_info_() { ESP_LOGCONFIG(TAG, "Partition table:"); ESP_LOGCONFIG(TAG, " %-12s %-4s %-8s %-10s %-10s", "Name", "Type", "Subtype", "Address", "Size"); @@ -42,171 +119,16 @@ void DebugComponent::log_partition_info_() { esp_partition_iterator_release(it); } -std::string DebugComponent::get_reset_reason_() { - std::string reset_reason; - switch (esp_reset_reason()) { - case ESP_RST_POWERON: - reset_reason = "Reset due to power-on event"; - break; - case ESP_RST_EXT: - reset_reason = "Reset by external pin"; - break; - case ESP_RST_SW: - reset_reason = "Software reset via esp_restart"; - break; - case ESP_RST_PANIC: - reset_reason = "Software reset due to exception/panic"; - break; - case ESP_RST_INT_WDT: - reset_reason = "Reset (software or hardware) due to interrupt watchdog"; - break; - case ESP_RST_TASK_WDT: - reset_reason = "Reset due to task watchdog"; - break; - case ESP_RST_WDT: - reset_reason = "Reset due to other watchdogs"; - break; - case ESP_RST_DEEPSLEEP: - reset_reason = "Reset after exiting deep sleep mode"; - break; - case ESP_RST_BROWNOUT: - reset_reason = "Brownout reset (software or hardware)"; - break; - case ESP_RST_SDIO: - reset_reason = "Reset over SDIO"; - break; -#ifdef USE_ESP32_VARIANT_ESP32 -#if (ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 1, 4)) - case ESP_RST_USB: - reset_reason = "Reset by USB peripheral"; - break; - case ESP_RST_JTAG: - reset_reason = "Reset by JTAG"; - break; - case ESP_RST_EFUSE: - reset_reason = "Reset due to efuse error"; - break; - case ESP_RST_PWR_GLITCH: - reset_reason = "Reset due to power glitch detected"; - break; - case ESP_RST_CPU_LOCKUP: - reset_reason = "Reset due to CPU lock up (double exception)"; - break; -#endif // ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 1, 4) -#endif // USE_ESP32_VARIANT_ESP32 - default: // Includes ESP_RST_UNKNOWN - switch (rtc_get_reset_reason(0)) { - case POWERON_RESET: - reset_reason = "Power On Reset"; - break; -#if defined(USE_ESP32_VARIANT_ESP32) - case SW_RESET: -#elif defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || \ - defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C6) - case RTC_SW_SYS_RESET: -#endif - reset_reason = "Software Reset Digital Core"; - break; -#if defined(USE_ESP32_VARIANT_ESP32) - case OWDT_RESET: - reset_reason = "Watch Dog Reset Digital Core"; - break; -#endif - case DEEPSLEEP_RESET: - reset_reason = "Deep Sleep Reset Digital Core"; - break; -#if defined(USE_ESP32_VARIANT_ESP32) - case SDIO_RESET: - reset_reason = "SLC Module Reset Digital Core"; - break; -#endif - case TG0WDT_SYS_RESET: - reset_reason = "Timer Group 0 Watch Dog Reset Digital Core"; - break; - case TG1WDT_SYS_RESET: - reset_reason = "Timer Group 1 Watch Dog Reset Digital Core"; - break; - case RTCWDT_SYS_RESET: - reset_reason = "RTC Watch Dog Reset Digital Core"; - break; -#if !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) - case INTRUSION_RESET: - reset_reason = "Intrusion Reset CPU"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32) - case TGWDT_CPU_RESET: - reset_reason = "Timer Group Reset CPU"; - break; -#elif defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || \ - defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C6) - case TG0WDT_CPU_RESET: - reset_reason = "Timer Group 0 Reset CPU"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32) - case SW_CPU_RESET: -#elif defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || \ - defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C6) - case RTC_SW_CPU_RESET: -#endif - reset_reason = "Software Reset CPU"; - break; - case RTCWDT_CPU_RESET: - reset_reason = "RTC Watch Dog Reset CPU"; - break; -#if defined(USE_ESP32_VARIANT_ESP32) - case EXT_CPU_RESET: - reset_reason = "External CPU Reset"; - break; -#endif - case RTCWDT_BROWN_OUT_RESET: - reset_reason = "Voltage Unstable Reset"; - break; - case RTCWDT_RTC_RESET: - reset_reason = "RTC Watch Dog Reset Digital Core And RTC Module"; - break; -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || \ - defined(USE_ESP32_VARIANT_ESP32C6) - case TG1WDT_CPU_RESET: - reset_reason = "Timer Group 1 Reset CPU"; - break; - case SUPER_WDT_RESET: - reset_reason = "Super Watchdog Reset Digital Core And RTC Module"; - break; - case EFUSE_RESET: - reset_reason = "eFuse Reset Digital Core"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) - case GLITCH_RTC_RESET: - reset_reason = "Glitch Reset Digital Core And RTC Module"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C6) - case USB_UART_CHIP_RESET: - reset_reason = "USB UART Reset Digital Core"; - break; - case USB_JTAG_CHIP_RESET: - reset_reason = "USB JTAG Reset Digital Core"; - break; -#endif -#if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32S3) - case POWER_GLITCH_RESET: - reset_reason = "Power Glitch Reset Digital Core And RTC Module"; - break; -#endif - default: - reset_reason = "Unknown Reset Reason"; - } - break; - } - ESP_LOGD(TAG, "Reset Reason: %s", reset_reason.c_str()); - return reset_reason; -} - uint32_t DebugComponent::get_free_heap_() { return heap_caps_get_free_size(MALLOC_CAP_INTERNAL); } +static const std::map CHIP_FEATURES = { + {CHIP_FEATURE_BLE, "BLE"}, + {CHIP_FEATURE_BT, "BT"}, + {CHIP_FEATURE_EMB_FLASH, "EMB Flash"}, + {CHIP_FEATURE_EMB_PSRAM, "EMB PSRAM"}, + {CHIP_FEATURE_WIFI_BGN, "2.4GHz WiFi"}, +}; + void DebugComponent::get_device_info_(std::string &device_info) { #if defined(USE_ARDUINO) const char *flash_mode; @@ -242,44 +164,16 @@ void DebugComponent::get_device_info_(std::string &device_info) { esp_chip_info_t info; esp_chip_info(&info); - const char *model; -#if defined(USE_ESP32_VARIANT_ESP32) - model = "ESP32"; -#elif defined(USE_ESP32_VARIANT_ESP32C3) - model = "ESP32-C3"; -#elif defined(USE_ESP32_VARIANT_ESP32C6) - model = "ESP32-C6"; -#elif defined(USE_ESP32_VARIANT_ESP32S2) - model = "ESP32-S2"; -#elif defined(USE_ESP32_VARIANT_ESP32S3) - model = "ESP32-S3"; -#elif defined(USE_ESP32_VARIANT_ESP32H2) - model = "ESP32-H2"; -#else - model = "UNKNOWN"; -#endif + const char *model = ESPHOME_VARIANT; std::string features; - if (info.features & CHIP_FEATURE_EMB_FLASH) { - features += "EMB_FLASH,"; - info.features &= ~CHIP_FEATURE_EMB_FLASH; + for (auto feature : CHIP_FEATURES) { + if (info.features & feature.first) { + features += feature.second; + features += ", "; + info.features &= ~feature.first; + } } - if (info.features & CHIP_FEATURE_WIFI_BGN) { - features += "WIFI_BGN,"; - info.features &= ~CHIP_FEATURE_WIFI_BGN; - } - if (info.features & CHIP_FEATURE_BLE) { - features += "BLE,"; - info.features &= ~CHIP_FEATURE_BLE; - } - if (info.features & CHIP_FEATURE_BT) { - features += "BT,"; - info.features &= ~CHIP_FEATURE_BT; - } - if (info.features & CHIP_FEATURE_EMB_PSRAM) { - features += "EMB_PSRAM,"; - info.features &= ~CHIP_FEATURE_EMB_PSRAM; - } - if (info.features) + if (info.features != 0) features += "Other:" + format_hex(info.features); ESP_LOGD(TAG, "Chip: Model=%s, Features=%s Cores=%u, Revision=%u", model, features.c_str(), info.cores, info.revision); @@ -289,6 +183,8 @@ void DebugComponent::get_device_info_(std::string &device_info) { device_info += features; device_info += " Cores:" + to_string(info.cores); device_info += " Revision:" + to_string(info.revision); + device_info += str_sprintf("|CPU Frequency: %" PRIu32 " MHz", arch_get_cpu_freq_hz() / 1000000); + ESP_LOGD(TAG, "CPU Frequency: %" PRIu32 " MHz", arch_get_cpu_freq_hz() / 1000000); // Framework detection device_info += "|Framework: "; @@ -315,48 +211,7 @@ void DebugComponent::get_device_info_(std::string &device_info) { device_info += "|Reset: "; device_info += get_reset_reason_(); - const char *wakeup_reason; - switch (rtc_get_wakeup_cause()) { - case NO_SLEEP: - wakeup_reason = "No Sleep"; - break; - case EXT_EVENT0_TRIG: - wakeup_reason = "External Event 0"; - break; - case EXT_EVENT1_TRIG: - wakeup_reason = "External Event 1"; - break; - case GPIO_TRIG: - wakeup_reason = "GPIO"; - break; - case TIMER_EXPIRE: - wakeup_reason = "Wakeup Timer"; - break; - case SDIO_TRIG: - wakeup_reason = "SDIO"; - break; - case MAC_TRIG: - wakeup_reason = "MAC"; - break; - case UART0_TRIG: - wakeup_reason = "UART0"; - break; - case UART1_TRIG: - wakeup_reason = "UART1"; - break; - case TOUCH_TRIG: - wakeup_reason = "Touch"; - break; - case SAR_TRIG: - wakeup_reason = "SAR"; - break; - case BT_TRIG: - wakeup_reason = "BT"; - break; - default: - wakeup_reason = "Unknown"; - } - ESP_LOGD(TAG, "Wakeup Reason: %s", wakeup_reason); + std::string wakeup_reason = this->get_wakeup_cause_(); device_info += "|Wakeup: "; device_info += wakeup_reason; } diff --git a/esphome/components/debug/sensor.py b/esphome/components/debug/sensor.py index 0a23658907..4669095d5d 100644 --- a/esphome/components/debug/sensor.py +++ b/esphome/components/debug/sensor.py @@ -1,5 +1,6 @@ import esphome.codegen as cg from esphome.components import sensor +from esphome.components.esp32 import CONF_CPU_FREQUENCY import esphome.config_validation as cv from esphome.const import ( CONF_BLOCK, @@ -10,6 +11,7 @@ from esphome.const import ( ICON_COUNTER, ICON_TIMER, UNIT_BYTES, + UNIT_HERTZ, UNIT_MILLISECOND, UNIT_PERCENT, ) @@ -60,6 +62,14 @@ CONFIG_SCHEMA = { entity_category=ENTITY_CATEGORY_DIAGNOSTIC, ), ), + cv.Optional(CONF_CPU_FREQUENCY): cv.All( + sensor.sensor_schema( + unit_of_measurement=UNIT_HERTZ, + icon="mdi:speedometer", + accuracy_decimals=0, + entity_category=ENTITY_CATEGORY_DIAGNOSTIC, + ), + ), } @@ -85,3 +95,7 @@ async def to_code(config): if psram_conf := config.get(CONF_PSRAM): sens = await sensor.new_sensor(psram_conf) cg.add(debug_component.set_psram_sensor(sens)) + + if cpu_freq_conf := config.get(CONF_CPU_FREQUENCY): + sens = await sensor.new_sensor(cpu_freq_conf) + cg.add(debug_component.set_cpu_frequency_sensor(sens)) diff --git a/esphome/components/deep_sleep/deep_sleep_esp32.cpp b/esphome/components/deep_sleep/deep_sleep_esp32.cpp index d647140865..4582d695f6 100644 --- a/esphome/components/deep_sleep/deep_sleep_esp32.cpp +++ b/esphome/components/deep_sleep/deep_sleep_esp32.cpp @@ -31,9 +31,12 @@ void DeepSleepComponent::set_wakeup_pin_mode(WakeupPinMode wakeup_pin_mode) { #if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) void DeepSleepComponent::set_ext1_wakeup(Ext1Wakeup ext1_wakeup) { this->ext1_wakeup_ = ext1_wakeup; } +#if !defined(USE_ESP32_VARIANT_ESP32H2) void DeepSleepComponent::set_touch_wakeup(bool touch_wakeup) { this->touch_wakeup_ = touch_wakeup; } #endif +#endif + void DeepSleepComponent::set_run_duration(WakeupCauseToRunDuration wakeup_cause_to_run_duration) { wakeup_cause_to_run_duration_ = wakeup_cause_to_run_duration; } @@ -65,7 +68,7 @@ bool DeepSleepComponent::prepare_to_sleep_() { } void DeepSleepComponent::deep_sleep_() { -#if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) +#if !defined(USE_ESP32_VARIANT_ESP32C3) && !defined(USE_ESP32_VARIANT_ESP32C6) && !defined(USE_ESP32_VARIANT_ESP32H2) if (this->sleep_duration_.has_value()) esp_sleep_enable_timer_wakeup(*this->sleep_duration_); if (this->wakeup_pin_ != nullptr) { @@ -84,6 +87,15 @@ void DeepSleepComponent::deep_sleep_() { esp_sleep_pd_config(ESP_PD_DOMAIN_RTC_PERIPH, ESP_PD_OPTION_ON); } #endif + +#if defined(USE_ESP32_VARIANT_ESP32H2) + if (this->sleep_duration_.has_value()) + esp_sleep_enable_timer_wakeup(*this->sleep_duration_); + if (this->ext1_wakeup_.has_value()) { + esp_sleep_enable_ext1_wakeup(this->ext1_wakeup_->mask, this->ext1_wakeup_->wakeup_mode); + } +#endif + #if defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) if (this->sleep_duration_.has_value()) esp_sleep_enable_timer_wakeup(*this->sleep_duration_); diff --git a/esphome/components/demo/__init__.py b/esphome/components/demo/__init__.py index 349bd8e4cb..96ffb58b82 100644 --- a/esphome/components/demo/__init__.py +++ b/esphome/components/demo/__init__.py @@ -17,7 +17,6 @@ from esphome.const import ( CONF_DEVICE_CLASS, CONF_FORCE_UPDATE, CONF_ICON, - CONF_ID, CONF_INVERTED, CONF_MAX_VALUE, CONF_MIN_VALUE, @@ -153,9 +152,10 @@ CONFIG_SCHEMA = cv.Schema( }, ], ): [ - climate.CLIMATE_SCHEMA.extend(cv.COMPONENT_SCHEMA).extend( + climate.climate_schema(DemoClimate) + .extend(cv.COMPONENT_SCHEMA) + .extend( { - cv.GenerateID(): cv.declare_id(DemoClimate), cv.Required(CONF_TYPE): cv.enum(CLIMATE_TYPES, int=True), } ) @@ -183,9 +183,10 @@ CONFIG_SCHEMA = cv.Schema( }, ], ): [ - cover.COVER_SCHEMA.extend(cv.COMPONENT_SCHEMA).extend( + cover.cover_schema(DemoCover) + .extend(cv.COMPONENT_SCHEMA) + .extend( { - cv.GenerateID(): cv.declare_id(DemoCover), cv.Required(CONF_TYPE): cv.enum(COVER_TYPES, int=True), } ) @@ -211,9 +212,10 @@ CONFIG_SCHEMA = cv.Schema( }, ], ): [ - fan.FAN_SCHEMA.extend(cv.COMPONENT_SCHEMA).extend( + fan.fan_schema(DemoFan) + .extend(cv.COMPONENT_SCHEMA) + .extend( { - cv.GenerateID(CONF_OUTPUT_ID): cv.declare_id(DemoFan), cv.Required(CONF_TYPE): cv.enum(FAN_TYPES, int=True), } ) @@ -251,7 +253,9 @@ CONFIG_SCHEMA = cv.Schema( }, ], ): [ - light.RGB_LIGHT_SCHEMA.extend(cv.COMPONENT_SCHEMA).extend( + light.light_schema(DemoLight, light.LightType.RGB) + .extend(cv.COMPONENT_SCHEMA) + .extend( { cv.GenerateID(CONF_OUTPUT_ID): cv.declare_id(DemoLight), cv.Required(CONF_TYPE): cv.enum(LIGHT_TYPES, int=True), @@ -377,39 +381,33 @@ async def to_code(config): await cg.register_component(var, conf) for conf in config[CONF_CLIMATES]: - var = cg.new_Pvariable(conf[CONF_ID]) + var = await climate.new_climate(conf) await cg.register_component(var, conf) - await climate.register_climate(var, conf) cg.add(var.set_type(conf[CONF_TYPE])) for conf in config[CONF_COVERS]: - var = cg.new_Pvariable(conf[CONF_ID]) + var = await cover.new_cover(conf) await cg.register_component(var, conf) - await cover.register_cover(var, conf) cg.add(var.set_type(conf[CONF_TYPE])) for conf in config[CONF_FANS]: - var = cg.new_Pvariable(conf[CONF_OUTPUT_ID]) + var = await fan.new_fan(conf) await cg.register_component(var, conf) - await fan.register_fan(var, conf) cg.add(var.set_type(conf[CONF_TYPE])) for conf in config[CONF_LIGHTS]: - var = cg.new_Pvariable(conf[CONF_OUTPUT_ID]) + var = await light.new_light(conf) await cg.register_component(var, conf) - await light.register_light(var, conf) cg.add(var.set_type(conf[CONF_TYPE])) for conf in config[CONF_NUMBERS]: - var = cg.new_Pvariable(conf[CONF_ID]) - await cg.register_component(var, conf) - await number.register_number( - var, + var = await number.new_number( conf, min_value=conf[CONF_MIN_VALUE], max_value=conf[CONF_MAX_VALUE], step=conf[CONF_STEP], ) + await cg.register_component(var, conf) cg.add(var.set_type(conf[CONF_TYPE])) for conf in config[CONF_SENSORS]: diff --git a/esphome/components/dfrobot_sen0395/switch/__init__.py b/esphome/components/dfrobot_sen0395/switch/__init__.py index f854d08398..8e492080de 100644 --- a/esphome/components/dfrobot_sen0395/switch/__init__.py +++ b/esphome/components/dfrobot_sen0395/switch/__init__.py @@ -2,6 +2,7 @@ import esphome.codegen as cg from esphome.components import switch import esphome.config_validation as cv from esphome.const import CONF_TYPE, ENTITY_CATEGORY_CONFIG +from esphome.cpp_generator import MockObjClass from .. import CONF_DFROBOT_SEN0395_ID, DfrobotSen0395Component @@ -26,32 +27,30 @@ Sen0395StartAfterBootSwitch = dfrobot_sen0395_ns.class_( "Sen0395StartAfterBootSwitch", DfrobotSen0395Switch ) -_SWITCH_SCHEMA = ( - switch.switch_schema( - entity_category=ENTITY_CATEGORY_CONFIG, + +def _switch_schema(class_: MockObjClass) -> cv.Schema: + return ( + switch.switch_schema( + class_, + entity_category=ENTITY_CATEGORY_CONFIG, + ) + .extend( + { + cv.GenerateID(CONF_DFROBOT_SEN0395_ID): cv.use_id( + DfrobotSen0395Component + ), + } + ) + .extend(cv.COMPONENT_SCHEMA) ) - .extend( - { - cv.GenerateID(CONF_DFROBOT_SEN0395_ID): cv.use_id(DfrobotSen0395Component), - } - ) - .extend(cv.COMPONENT_SCHEMA) -) + CONFIG_SCHEMA = cv.typed_schema( { - "sensor_active": _SWITCH_SCHEMA.extend( - {cv.GenerateID(): cv.declare_id(Sen0395PowerSwitch)} - ), - "turn_on_led": _SWITCH_SCHEMA.extend( - {cv.GenerateID(): cv.declare_id(Sen0395LedSwitch)} - ), - "presence_via_uart": _SWITCH_SCHEMA.extend( - {cv.GenerateID(): cv.declare_id(Sen0395UartPresenceSwitch)} - ), - "start_after_boot": _SWITCH_SCHEMA.extend( - {cv.GenerateID(): cv.declare_id(Sen0395StartAfterBootSwitch)} - ), + "sensor_active": _switch_schema(Sen0395PowerSwitch), + "turn_on_led": _switch_schema(Sen0395LedSwitch), + "presence_via_uart": _switch_schema(Sen0395UartPresenceSwitch), + "start_after_boot": _switch_schema(Sen0395StartAfterBootSwitch), } ) diff --git a/esphome/components/endstop/cover.py b/esphome/components/endstop/cover.py index 286c876ff6..c16680b6af 100644 --- a/esphome/components/endstop/cover.py +++ b/esphome/components/endstop/cover.py @@ -6,7 +6,6 @@ from esphome.const import ( CONF_CLOSE_ACTION, CONF_CLOSE_DURATION, CONF_CLOSE_ENDSTOP, - CONF_ID, CONF_MAX_DURATION, CONF_OPEN_ACTION, CONF_OPEN_DURATION, @@ -17,25 +16,27 @@ from esphome.const import ( endstop_ns = cg.esphome_ns.namespace("endstop") EndstopCover = endstop_ns.class_("EndstopCover", cover.Cover, cg.Component) -CONFIG_SCHEMA = cover.COVER_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(EndstopCover), - cv.Required(CONF_STOP_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_OPEN_ENDSTOP): cv.use_id(binary_sensor.BinarySensor), - cv.Required(CONF_OPEN_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_OPEN_DURATION): cv.positive_time_period_milliseconds, - cv.Required(CONF_CLOSE_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_CLOSE_ENDSTOP): cv.use_id(binary_sensor.BinarySensor), - cv.Required(CONF_CLOSE_DURATION): cv.positive_time_period_milliseconds, - cv.Optional(CONF_MAX_DURATION): cv.positive_time_period_milliseconds, - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + cover.cover_schema(EndstopCover) + .extend( + { + cv.Required(CONF_STOP_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_OPEN_ENDSTOP): cv.use_id(binary_sensor.BinarySensor), + cv.Required(CONF_OPEN_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_OPEN_DURATION): cv.positive_time_period_milliseconds, + cv.Required(CONF_CLOSE_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_CLOSE_ENDSTOP): cv.use_id(binary_sensor.BinarySensor), + cv.Required(CONF_CLOSE_DURATION): cv.positive_time_period_milliseconds, + cv.Optional(CONF_MAX_DURATION): cv.positive_time_period_milliseconds, + } + ) + .extend(cv.COMPONENT_SCHEMA) +) async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await cover.new_cover(config) await cg.register_component(var, config) - await cover.register_cover(var, config) await automation.build_automation( var.get_stop_trigger(), [], config[CONF_STOP_ACTION] diff --git a/esphome/components/esp32/__init__.py b/esphome/components/esp32/__init__.py index 912a8bf94b..12d0f9fcd5 100644 --- a/esphome/components/esp32/__init__.py +++ b/esphome/components/esp32/__init__.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +import itertools import logging import os from pathlib import Path @@ -37,6 +38,7 @@ from esphome.const import ( __version__, ) from esphome.core import CORE, HexInt, TimePeriod +from esphome.cpp_generator import RawExpression import esphome.final_validate as fv from esphome.helpers import copy_file_if_changed, mkdir_p, write_file_if_changed @@ -54,6 +56,12 @@ from .const import ( # noqa KEY_SUBMODULES, KEY_VARIANT, VARIANT_ESP32, + VARIANT_ESP32C2, + VARIANT_ESP32C3, + VARIANT_ESP32C6, + VARIANT_ESP32H2, + VARIANT_ESP32S2, + VARIANT_ESP32S3, VARIANT_FRIENDLY, VARIANTS, ) @@ -70,7 +78,43 @@ CONF_RELEASE = "release" CONF_ENABLE_IDF_EXPERIMENTAL_FEATURES = "enable_idf_experimental_features" +def get_cpu_frequencies(*frequencies): + return [str(x) + "MHZ" for x in frequencies] + + +CPU_FREQUENCIES = { + VARIANT_ESP32: get_cpu_frequencies(80, 160, 240), + VARIANT_ESP32S2: get_cpu_frequencies(80, 160, 240), + VARIANT_ESP32S3: get_cpu_frequencies(80, 160, 240), + VARIANT_ESP32C2: get_cpu_frequencies(80, 120), + VARIANT_ESP32C3: get_cpu_frequencies(80, 160), + VARIANT_ESP32C6: get_cpu_frequencies(80, 120, 160), + VARIANT_ESP32H2: get_cpu_frequencies(16, 32, 48, 64, 96), +} + +# Make sure not missed here if a new variant added. +assert all(v in CPU_FREQUENCIES for v in VARIANTS) + +FULL_CPU_FREQUENCIES = set(itertools.chain.from_iterable(CPU_FREQUENCIES.values())) + + def set_core_data(config): + cpu_frequency = config.get(CONF_CPU_FREQUENCY, None) + variant = config[CONF_VARIANT] + # if not specified in config, set to 160MHz if supported, the fastest otherwise + if cpu_frequency is None: + choices = CPU_FREQUENCIES[variant] + if "160MHZ" in choices: + cpu_frequency = "160MHZ" + else: + cpu_frequency = choices[-1] + config[CONF_CPU_FREQUENCY] = cpu_frequency + elif cpu_frequency not in CPU_FREQUENCIES[variant]: + raise cv.Invalid( + f"Invalid CPU frequency '{cpu_frequency}' for {config[CONF_VARIANT]}", + path=[CONF_CPU_FREQUENCY], + ) + CORE.data[KEY_ESP32] = {} CORE.data[KEY_CORE][KEY_TARGET_PLATFORM] = PLATFORM_ESP32 conf = config[CONF_FRAMEWORK] @@ -83,6 +127,7 @@ def set_core_data(config): CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] = cv.Version.parse( config[CONF_FRAMEWORK][CONF_VERSION] ) + CORE.data[KEY_ESP32][KEY_BOARD] = config[CONF_BOARD] CORE.data[KEY_ESP32][KEY_VARIANT] = config[CONF_VARIANT] CORE.data[KEY_ESP32][KEY_EXTRA_BUILD_FILES] = {} @@ -251,7 +296,7 @@ ARDUINO_PLATFORM_VERSION = cv.Version(5, 4, 0) # The default/recommended esp-idf framework version # - https://github.com/espressif/esp-idf/releases # - https://api.registry.platformio.org/v3/packages/platformio/tool/framework-espidf -RECOMMENDED_ESP_IDF_FRAMEWORK_VERSION = cv.Version(5, 1, 5) +RECOMMENDED_ESP_IDF_FRAMEWORK_VERSION = cv.Version(5, 1, 6) # The platformio/espressif32 version to use for esp-idf frameworks # - https://github.com/platformio/platform-espressif32/releases # - https://api.registry.platformio.org/v3/packages/platformio/platform/espressif32 @@ -274,12 +319,15 @@ SUPPORTED_PLATFORMIO_ESP_IDF_5X = [ # pioarduino versions that don't require a release number # List based on https://github.com/pioarduino/esp-idf/releases SUPPORTED_PIOARDUINO_ESP_IDF_5X = [ + cv.Version(5, 5, 0), cv.Version(5, 4, 1), cv.Version(5, 4, 0), + cv.Version(5, 3, 3), cv.Version(5, 3, 2), cv.Version(5, 3, 1), cv.Version(5, 3, 0), cv.Version(5, 1, 5), + cv.Version(5, 1, 6), ] @@ -321,8 +369,8 @@ def _arduino_check_versions(value): def _esp_idf_check_versions(value): value = value.copy() lookups = { - "dev": (cv.Version(5, 1, 5), "https://github.com/espressif/esp-idf.git"), - "latest": (cv.Version(5, 1, 5), None), + "dev": (cv.Version(5, 1, 6), "https://github.com/espressif/esp-idf.git"), + "latest": (cv.Version(5, 1, 6), None), "recommended": (RECOMMENDED_ESP_IDF_FRAMEWORK_VERSION, None), } @@ -550,11 +598,15 @@ FLASH_SIZES = [ ] CONF_FLASH_SIZE = "flash_size" +CONF_CPU_FREQUENCY = "cpu_frequency" CONF_PARTITIONS = "partitions" CONFIG_SCHEMA = cv.All( cv.Schema( { cv.Required(CONF_BOARD): cv.string_strict, + cv.Optional(CONF_CPU_FREQUENCY): cv.one_of( + *FULL_CPU_FREQUENCIES, upper=True + ), cv.Optional(CONF_FLASH_SIZE, default="4MB"): cv.one_of( *FLASH_SIZES, upper=True ), @@ -595,6 +647,7 @@ async def to_code(config): os.path.join(os.path.dirname(__file__), "post_build.py.script"), ) + freq = config[CONF_CPU_FREQUENCY][:-3] if conf[CONF_TYPE] == FRAMEWORK_ESP_IDF: cg.add_platformio_option("framework", "espidf") cg.add_build_flag("-DUSE_ESP_IDF") @@ -628,6 +681,9 @@ async def to_code(config): add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU0", False) add_idf_sdkconfig_option("CONFIG_ESP_TASK_WDT_CHECK_IDLE_TASK_CPU1", False) + # Set default CPU frequency + add_idf_sdkconfig_option(f"CONFIG_ESP_DEFAULT_CPU_FREQ_MHZ_{freq}", True) + cg.add_platformio_option("board_build.partitions", "partitions.csv") if CONF_PARTITIONS in config: add_extra_build_file( @@ -693,6 +749,7 @@ async def to_code(config): f"VERSION_CODE({framework_ver.major}, {framework_ver.minor}, {framework_ver.patch})" ), ) + cg.add(RawExpression(f"setCpuFrequencyMhz({freq})")) APP_PARTITION_SIZES = { diff --git a/esphome/components/esp32/core.cpp b/esphome/components/esp32/core.cpp index ff8e663ec1..c90d68d00e 100644 --- a/esphome/components/esp32/core.cpp +++ b/esphome/components/esp32/core.cpp @@ -13,11 +13,13 @@ #include #ifdef USE_ARDUINO -#include -#endif +#include +#else +#include void setup(); void loop(); +#endif namespace esphome { @@ -59,9 +61,13 @@ uint32_t arch_get_cpu_cycle_count() { return esp_cpu_get_cycle_count(); } uint32_t arch_get_cpu_cycle_count() { return cpu_hal_get_cycle_count(); } #endif uint32_t arch_get_cpu_freq_hz() { - rtc_cpu_freq_config_t config; - rtc_clk_cpu_freq_get_config(&config); - return config.freq_mhz * 1000000U; + uint32_t freq = 0; +#ifdef USE_ESP_IDF + esp_clk_tree_src_get_freq_hz(SOC_MOD_CLK_CPU, ESP_CLK_TREE_SRC_FREQ_PRECISION_CACHED, &freq); +#elif defined(USE_ARDUINO) + freq = ESP.getCpuFreqMHz() * 1000000; +#endif + return freq; } #ifdef USE_ESP_IDF diff --git a/esphome/components/esp32/gpio.cpp b/esphome/components/esp32/gpio.cpp index 7896597d3e..b554b6d09c 100644 --- a/esphome/components/esp32/gpio.cpp +++ b/esphome/components/esp32/gpio.cpp @@ -2,42 +2,66 @@ #include "gpio.h" #include "esphome/core/log.h" +#include "driver/gpio.h" +#include "driver/rtc_io.h" +#include "hal/gpio_hal.h" +#include "soc/soc_caps.h" +#include "soc/gpio_periph.h" #include +#if (SOC_RTCIO_PIN_COUNT > 0) +#include "hal/rtc_io_hal.h" +#endif + +#ifndef SOC_GPIO_SUPPORT_RTC_INDEPENDENT +#define SOC_GPIO_SUPPORT_RTC_INDEPENDENT 0 // NOLINT +#endif + namespace esphome { namespace esp32 { static const char *const TAG = "esp32"; +static const gpio_hal_context_t GPIO_HAL = {.dev = GPIO_HAL_GET_HW(GPIO_PORT_0)}; + bool ESP32InternalGPIOPin::isr_service_installed = false; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) -static gpio_mode_t IRAM_ATTR flags_to_mode(gpio::Flags flags) { +static gpio_mode_t flags_to_mode(gpio::Flags flags) { flags = (gpio::Flags)(flags & ~(gpio::FLAG_PULLUP | gpio::FLAG_PULLDOWN)); - if (flags == gpio::FLAG_INPUT) { + if (flags == gpio::FLAG_INPUT) return GPIO_MODE_INPUT; - } else if (flags == gpio::FLAG_OUTPUT) { + if (flags == gpio::FLAG_OUTPUT) return GPIO_MODE_OUTPUT; - } else if (flags == (gpio::FLAG_OUTPUT | gpio::FLAG_OPEN_DRAIN)) { + if (flags == (gpio::FLAG_OUTPUT | gpio::FLAG_OPEN_DRAIN)) return GPIO_MODE_OUTPUT_OD; - } else if (flags == (gpio::FLAG_INPUT | gpio::FLAG_OUTPUT | gpio::FLAG_OPEN_DRAIN)) { + if (flags == (gpio::FLAG_INPUT | gpio::FLAG_OUTPUT | gpio::FLAG_OPEN_DRAIN)) return GPIO_MODE_INPUT_OUTPUT_OD; - } else if (flags == (gpio::FLAG_INPUT | gpio::FLAG_OUTPUT)) { + if (flags == (gpio::FLAG_INPUT | gpio::FLAG_OUTPUT)) return GPIO_MODE_INPUT_OUTPUT; - } else { - // unsupported or gpio::FLAG_NONE - return GPIO_MODE_DISABLE; - } + // unsupported or gpio::FLAG_NONE + return GPIO_MODE_DISABLE; } struct ISRPinArg { gpio_num_t pin; + gpio::Flags flags; bool inverted; +#if defined(USE_ESP32_VARIANT_ESP32) + bool use_rtc; + int rtc_pin; +#endif }; ISRInternalGPIOPin ESP32InternalGPIOPin::to_isr() const { auto *arg = new ISRPinArg{}; // NOLINT(cppcoreguidelines-owning-memory) - arg->pin = pin_; + arg->pin = this->pin_; + arg->flags = gpio::FLAG_NONE; arg->inverted = inverted_; +#if defined(USE_ESP32_VARIANT_ESP32) + arg->use_rtc = rtc_gpio_is_valid_gpio(this->pin_); + if (arg->use_rtc) + arg->rtc_pin = rtc_io_number_get(this->pin_); +#endif return ISRInternalGPIOPin((void *) arg); } @@ -90,6 +114,7 @@ void ESP32InternalGPIOPin::setup() { if (flags_ & gpio::FLAG_OUTPUT) { gpio_set_drive_capability(pin_, drive_strength_); } + ESP_LOGD(TAG, "rtc: %d", SOC_GPIO_SUPPORT_RTC_INDEPENDENT); } void ESP32InternalGPIOPin::pin_mode(gpio::Flags flags) { @@ -115,28 +140,65 @@ void ESP32InternalGPIOPin::detach_interrupt() const { gpio_intr_disable(pin_); } using namespace esp32; bool IRAM_ATTR ISRInternalGPIOPin::digital_read() { - auto *arg = reinterpret_cast(arg_); - return bool(gpio_get_level(arg->pin)) != arg->inverted; + auto *arg = reinterpret_cast(this->arg_); + return bool(gpio_hal_get_level(&GPIO_HAL, arg->pin)) != arg->inverted; } + void IRAM_ATTR ISRInternalGPIOPin::digital_write(bool value) { - auto *arg = reinterpret_cast(arg_); - gpio_set_level(arg->pin, value != arg->inverted ? 1 : 0); + auto *arg = reinterpret_cast(this->arg_); + gpio_hal_set_level(&GPIO_HAL, arg->pin, value != arg->inverted); } + void IRAM_ATTR ISRInternalGPIOPin::clear_interrupt() { // not supported } + void IRAM_ATTR ISRInternalGPIOPin::pin_mode(gpio::Flags flags) { auto *arg = reinterpret_cast(arg_); - gpio_set_direction(arg->pin, flags_to_mode(flags)); - gpio_pull_mode_t pull_mode = GPIO_FLOATING; - if ((flags & gpio::FLAG_PULLUP) && (flags & gpio::FLAG_PULLDOWN)) { - pull_mode = GPIO_PULLUP_PULLDOWN; - } else if (flags & gpio::FLAG_PULLUP) { - pull_mode = GPIO_PULLUP_ONLY; - } else if (flags & gpio::FLAG_PULLDOWN) { - pull_mode = GPIO_PULLDOWN_ONLY; + gpio::Flags diff = (gpio::Flags)(flags ^ arg->flags); + if (diff & gpio::FLAG_OUTPUT) { + if (flags & gpio::FLAG_OUTPUT) { + gpio_hal_output_enable(&GPIO_HAL, arg->pin); + if (flags & gpio::FLAG_OPEN_DRAIN) + gpio_hal_od_enable(&GPIO_HAL, arg->pin); + } else { + gpio_hal_output_disable(&GPIO_HAL, arg->pin); + } } - gpio_set_pull_mode(arg->pin, pull_mode); + if (diff & gpio::FLAG_INPUT) { + if (flags & gpio::FLAG_INPUT) { + gpio_hal_input_enable(&GPIO_HAL, arg->pin); +#if defined(USE_ESP32_VARIANT_ESP32) + if (arg->use_rtc) { + if (flags & gpio::FLAG_PULLUP) { + rtcio_hal_pullup_enable(arg->rtc_pin); + } else { + rtcio_hal_pullup_disable(arg->rtc_pin); + } + if (flags & gpio::FLAG_PULLDOWN) { + rtcio_hal_pulldown_enable(arg->rtc_pin); + } else { + rtcio_hal_pulldown_disable(arg->rtc_pin); + } + } else +#endif + { + if (flags & gpio::FLAG_PULLUP) { + gpio_hal_pullup_en(&GPIO_HAL, arg->pin); + } else { + gpio_hal_pullup_dis(&GPIO_HAL, arg->pin); + } + if (flags & gpio::FLAG_PULLDOWN) { + gpio_hal_pulldown_en(&GPIO_HAL, arg->pin); + } else { + gpio_hal_pulldown_dis(&GPIO_HAL, arg->pin); + } + } + } else { + gpio_hal_input_disable(&GPIO_HAL, arg->pin); + } + } + arg->flags = flags; } } // namespace esphome diff --git a/esphome/components/esp32/gpio.py b/esphome/components/esp32/gpio.py index df01769a66..2bb10ce6ec 100644 --- a/esphome/components/esp32/gpio.py +++ b/esphome/components/esp32/gpio.py @@ -1,6 +1,6 @@ from dataclasses import dataclass import logging -from typing import Any +from typing import Any, Callable from esphome import pins import esphome.codegen as cg @@ -64,8 +64,7 @@ def _lookup_pin(value): def _translate_pin(value): if isinstance(value, dict) or value is None: raise cv.Invalid( - "This variable only supports pin numbers, not full pin schemas " - "(with inverted and mode)." + "This variable only supports pin numbers, not full pin schemas (with inverted and mode)." ) if isinstance(value, int) and not isinstance(value, bool): return value @@ -82,30 +81,22 @@ def _translate_pin(value): @dataclass class ESP32ValidationFunctions: - pin_validation: Any - usage_validation: Any + pin_validation: Callable[[Any], Any] + usage_validation: Callable[[Any], Any] _esp32_validations = { VARIANT_ESP32: ESP32ValidationFunctions( pin_validation=esp32_validate_gpio_pin, usage_validation=esp32_validate_supports ), - VARIANT_ESP32S2: ESP32ValidationFunctions( - pin_validation=esp32_s2_validate_gpio_pin, - usage_validation=esp32_s2_validate_supports, + VARIANT_ESP32C2: ESP32ValidationFunctions( + pin_validation=esp32_c2_validate_gpio_pin, + usage_validation=esp32_c2_validate_supports, ), VARIANT_ESP32C3: ESP32ValidationFunctions( pin_validation=esp32_c3_validate_gpio_pin, usage_validation=esp32_c3_validate_supports, ), - VARIANT_ESP32S3: ESP32ValidationFunctions( - pin_validation=esp32_s3_validate_gpio_pin, - usage_validation=esp32_s3_validate_supports, - ), - VARIANT_ESP32C2: ESP32ValidationFunctions( - pin_validation=esp32_c2_validate_gpio_pin, - usage_validation=esp32_c2_validate_supports, - ), VARIANT_ESP32C6: ESP32ValidationFunctions( pin_validation=esp32_c6_validate_gpio_pin, usage_validation=esp32_c6_validate_supports, @@ -114,6 +105,14 @@ _esp32_validations = { pin_validation=esp32_h2_validate_gpio_pin, usage_validation=esp32_h2_validate_supports, ), + VARIANT_ESP32S2: ESP32ValidationFunctions( + pin_validation=esp32_s2_validate_gpio_pin, + usage_validation=esp32_s2_validate_supports, + ), + VARIANT_ESP32S3: ESP32ValidationFunctions( + pin_validation=esp32_s3_validate_gpio_pin, + usage_validation=esp32_s3_validate_supports, + ), } diff --git a/esphome/components/esp32/gpio_esp32.py b/esphome/components/esp32/gpio_esp32.py index e4d3b6aaf3..973d2dc0ef 100644 --- a/esphome/components/esp32/gpio_esp32.py +++ b/esphome/components/esp32/gpio_esp32.py @@ -31,8 +31,7 @@ def esp32_validate_gpio_pin(value): ) if 9 <= value <= 10: _LOGGER.warning( - "Pin %s (9-10) might already be used by the " - "flash interface in QUAD IO flash mode.", + "Pin %s (9-10) might already be used by the flash interface in QUAD IO flash mode.", value, ) if value in (24, 28, 29, 30, 31): diff --git a/esphome/components/esp32/gpio_esp32_c2.py b/esphome/components/esp32/gpio_esp32_c2.py index abdcb1b655..32a24050ca 100644 --- a/esphome/components/esp32/gpio_esp32_c2.py +++ b/esphome/components/esp32/gpio_esp32_c2.py @@ -22,7 +22,7 @@ def esp32_c2_validate_supports(value): is_input = mode[CONF_INPUT] if num < 0 or num > 20: - raise cv.Invalid(f"Invalid pin number: {value} (must be 0-20)") + raise cv.Invalid(f"Invalid pin number: {num} (must be 0-20)") if is_input: # All ESP32 pins support input mode diff --git a/esphome/components/esp32/gpio_esp32_c3.py b/esphome/components/esp32/gpio_esp32_c3.py index 5b9ec0ebd9..c1427cc02a 100644 --- a/esphome/components/esp32/gpio_esp32_c3.py +++ b/esphome/components/esp32/gpio_esp32_c3.py @@ -35,7 +35,7 @@ def esp32_c3_validate_supports(value): is_input = mode[CONF_INPUT] if num < 0 or num > 21: - raise cv.Invalid(f"Invalid pin number: {value} (must be 0-21)") + raise cv.Invalid(f"Invalid pin number: {num} (must be 0-21)") if is_input: # All ESP32 pins support input mode diff --git a/esphome/components/esp32/gpio_esp32_c6.py b/esphome/components/esp32/gpio_esp32_c6.py index bc735f85c4..d466adb994 100644 --- a/esphome/components/esp32/gpio_esp32_c6.py +++ b/esphome/components/esp32/gpio_esp32_c6.py @@ -36,7 +36,7 @@ def esp32_c6_validate_supports(value): is_input = mode[CONF_INPUT] if num < 0 or num > 23: - raise cv.Invalid(f"Invalid pin number: {value} (must be 0-23)") + raise cv.Invalid(f"Invalid pin number: {num} (must be 0-23)") if is_input: # All ESP32 pins support input mode pass diff --git a/esphome/components/esp32/gpio_esp32_h2.py b/esphome/components/esp32/gpio_esp32_h2.py index 7413bf4db5..7c3a658b17 100644 --- a/esphome/components/esp32/gpio_esp32_h2.py +++ b/esphome/components/esp32/gpio_esp32_h2.py @@ -45,7 +45,7 @@ def esp32_h2_validate_supports(value): is_input = mode[CONF_INPUT] if num < 0 or num > 27: - raise cv.Invalid(f"Invalid pin number: {value} (must be 0-27)") + raise cv.Invalid(f"Invalid pin number: {num} (must be 0-27)") if is_input: # All ESP32 pins support input mode pass diff --git a/esphome/components/esp32_ble/ble.cpp b/esphome/components/esp32_ble/ble.cpp index ab2647b738..fc1303673f 100644 --- a/esphome/components/esp32_ble/ble.cpp +++ b/esphome/components/esp32_ble/ble.cpp @@ -110,6 +110,7 @@ void ESP32BLE::advertising_init_() { this->advertising_->set_scan_response(true); this->advertising_->set_min_preferred_interval(0x06); + this->advertising_->set_appearance(this->appearance_); } bool ESP32BLE::ble_setup_() { diff --git a/esphome/components/esp32_ble/ble.h b/esphome/components/esp32_ble/ble.h index ed7575f128..13ec3b6dd9 100644 --- a/esphome/components/esp32_ble/ble.h +++ b/esphome/components/esp32_ble/ble.h @@ -95,6 +95,7 @@ class ESP32BLE : public Component { void advertising_start(); void advertising_set_service_data(const std::vector &data); void advertising_set_manufacturer_data(const std::vector &data); + void advertising_set_appearance(uint16_t appearance) { this->appearance_ = appearance; } void advertising_add_service_uuid(ESPBTUUID uuid); void advertising_remove_service_uuid(ESPBTUUID uuid); void advertising_register_raw_advertisement_callback(std::function &&callback); @@ -128,11 +129,12 @@ class ESP32BLE : public Component { BLEComponentState state_{BLE_COMPONENT_STATE_OFF}; Queue ble_events_; - BLEAdvertising *advertising_; + BLEAdvertising *advertising_{}; esp_ble_io_cap_t io_cap_{ESP_IO_CAP_NONE}; - uint32_t advertising_cycle_time_; - bool enable_on_boot_; + uint32_t advertising_cycle_time_{}; + bool enable_on_boot_{}; optional name_; + uint16_t appearance_{0}; }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/esphome/components/esp32_ble/ble_advertising.h b/esphome/components/esp32_ble/ble_advertising.h index 946e414c1d..0b2142115d 100644 --- a/esphome/components/esp32_ble/ble_advertising.h +++ b/esphome/components/esp32_ble/ble_advertising.h @@ -32,6 +32,7 @@ class BLEAdvertising { void set_scan_response(bool scan_response) { this->scan_response_ = scan_response; } void set_min_preferred_interval(uint16_t interval) { this->advertising_data_.min_interval = interval; } void set_manufacturer_data(const std::vector &data); + void set_appearance(uint16_t appearance) { this->advertising_data_.appearance = appearance; } void set_service_data(const std::vector &data); void register_raw_advertisement_callback(std::function &&callback); diff --git a/esphome/components/esp32_ble_server/__init__.py b/esphome/components/esp32_ble_server/__init__.py index ab8e27ec43..0fcb5c9822 100644 --- a/esphome/components/esp32_ble_server/__init__.py +++ b/esphome/components/esp32_ble_server/__init__.py @@ -32,6 +32,7 @@ DEPENDENCIES = ["esp32"] DOMAIN = "esp32_ble_server" CONF_ADVERTISE = "advertise" +CONF_APPEARANCE = "appearance" CONF_BROADCAST = "broadcast" CONF_CHARACTERISTICS = "characteristics" CONF_DESCRIPTION = "description" @@ -421,6 +422,7 @@ CONFIG_SCHEMA = cv.Schema( cv.GenerateID(): cv.declare_id(BLEServer), cv.GenerateID(esp32_ble.CONF_BLE_ID): cv.use_id(esp32_ble.ESP32BLE), cv.Optional(CONF_MANUFACTURER): value_schema("string", templatable=False), + cv.Optional(CONF_APPEARANCE, default=0): cv.uint16_t, cv.Optional(CONF_MODEL): value_schema("string", templatable=False), cv.Optional(CONF_FIRMWARE_VERSION): value_schema("string", templatable=False), cv.Optional(CONF_MANUFACTURER_DATA): cv.Schema([cv.uint8_t]), @@ -531,6 +533,7 @@ async def to_code(config): cg.add(parent.register_gatts_event_handler(var)) cg.add(parent.register_ble_status_event_handler(var)) cg.add(var.set_parent(parent)) + cg.add(parent.advertising_set_appearance(config[CONF_APPEARANCE])) if CONF_MANUFACTURER_DATA in config: cg.add(var.set_manufacturer_data(config[CONF_MANUFACTURER_DATA])) for service_config in config[CONF_SERVICES]: diff --git a/esphome/components/esp32_ble_tracker/__init__.py b/esphome/components/esp32_ble_tracker/__init__.py index 68be2cbbe9..a4425b9680 100644 --- a/esphome/components/esp32_ble_tracker/__init__.py +++ b/esphome/components/esp32_ble_tracker/__init__.py @@ -17,6 +17,7 @@ from esphome.components.esp32_ble import ( import esphome.config_validation as cv from esphome.const import ( CONF_ACTIVE, + CONF_CONTINUOUS, CONF_DURATION, CONF_ID, CONF_INTERVAL, @@ -42,8 +43,8 @@ CONF_MAX_CONNECTIONS = "max_connections" CONF_ESP32_BLE_ID = "esp32_ble_id" CONF_SCAN_PARAMETERS = "scan_parameters" CONF_WINDOW = "window" -CONF_CONTINUOUS = "continuous" CONF_ON_SCAN_END = "on_scan_end" +CONF_SOFTWARE_COEXISTENCE = "software_coexistence" DEFAULT_MAX_CONNECTIONS = 3 IDF_MAX_CONNECTIONS = 9 @@ -203,6 +204,7 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_ON_SCAN_END): automation.validate_automation( {cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(BLEEndOfScanTrigger)} ), + cv.OnlyWith(CONF_SOFTWARE_COEXISTENCE, "wifi", default=True): bool, } ).extend(cv.COMPONENT_SCHEMA), ) @@ -310,6 +312,8 @@ async def to_code(config): if CORE.using_esp_idf: add_idf_sdkconfig_option("CONFIG_BT_ENABLED", True) + if config.get(CONF_SOFTWARE_COEXISTENCE): + add_idf_sdkconfig_option("CONFIG_SW_COEXIST_ENABLE", True) # https://github.com/espressif/esp-idf/issues/4101 # https://github.com/espressif/esp-idf/issues/2503 # Match arduino CONFIG_BTU_TASK_STACK_SIZE @@ -331,6 +335,8 @@ async def to_code(config): cg.add_define("USE_OTA_STATE_CALLBACK") # To be notified when an OTA update starts cg.add_define("USE_ESP32_BLE_CLIENT") + if config.get(CONF_SOFTWARE_COEXISTENCE): + cg.add_define("USE_ESP32_BLE_SOFTWARE_COEXISTENCE") ESP32_BLE_START_SCAN_ACTION_SCHEMA = cv.Schema( diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp index 760aac628a..be45b177ff 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.cpp @@ -21,6 +21,10 @@ #include "esphome/components/ota/ota_backend.h" #endif +#ifdef USE_ESP32_BLE_SOFTWARE_COEXISTENCE +#include +#endif + #ifdef USE_ARDUINO #include #endif @@ -57,7 +61,6 @@ void ESP32BLETracker::setup() { global_esp32_ble_tracker = this; this->scan_result_lock_ = xSemaphoreCreateMutex(); - this->scan_end_lock_ = xSemaphoreCreateMutex(); #ifdef USE_OTA ota::get_global_ota_callback()->add_on_state_callback( @@ -117,119 +120,119 @@ void ESP32BLETracker::loop() { } bool promote_to_connecting = discovered && !searching && !connecting; - if (!this->scanner_idle_) { - if (this->scan_result_index_ && // if it looks like we have a scan result we will take the lock - xSemaphoreTake(this->scan_result_lock_, 5L / portTICK_PERIOD_MS)) { - uint32_t index = this->scan_result_index_; - if (index >= ESP32BLETracker::SCAN_RESULT_BUFFER_SIZE) { - ESP_LOGW(TAG, "Too many BLE events to process. Some devices may not show up."); - } + if (this->scanner_state_ == ScannerState::RUNNING && + this->scan_result_index_ && // if it looks like we have a scan result we will take the lock + xSemaphoreTake(this->scan_result_lock_, 5L / portTICK_PERIOD_MS)) { + uint32_t index = this->scan_result_index_; + if (index >= ESP32BLETracker::SCAN_RESULT_BUFFER_SIZE) { + ESP_LOGW(TAG, "Too many BLE events to process. Some devices may not show up."); + } - if (this->raw_advertisements_) { + if (this->raw_advertisements_) { + for (auto *listener : this->listeners_) { + listener->parse_devices(this->scan_result_buffer_, this->scan_result_index_); + } + for (auto *client : this->clients_) { + client->parse_devices(this->scan_result_buffer_, this->scan_result_index_); + } + } + + if (this->parse_advertisements_) { + for (size_t i = 0; i < index; i++) { + ESPBTDevice device; + device.parse_scan_rst(this->scan_result_buffer_[i]); + + bool found = false; for (auto *listener : this->listeners_) { - listener->parse_devices(this->scan_result_buffer_, this->scan_result_index_); + if (listener->parse_device(device)) + found = true; } + for (auto *client : this->clients_) { - client->parse_devices(this->scan_result_buffer_, this->scan_result_index_); - } - } - - if (this->parse_advertisements_) { - for (size_t i = 0; i < index; i++) { - ESPBTDevice device; - device.parse_scan_rst(this->scan_result_buffer_[i]); - - bool found = false; - for (auto *listener : this->listeners_) { - if (listener->parse_device(device)) - found = true; - } - - for (auto *client : this->clients_) { - if (client->parse_device(device)) { - found = true; - if (!connecting && client->state() == ClientState::DISCOVERED) { - promote_to_connecting = true; - } + if (client->parse_device(device)) { + found = true; + if (!connecting && client->state() == ClientState::DISCOVERED) { + promote_to_connecting = true; } } + } - if (!found && !this->scan_continuous_) { - this->print_bt_device_info(device); - } + if (!found && !this->scan_continuous_) { + this->print_bt_device_info(device); } } - this->scan_result_index_ = 0; - xSemaphoreGive(this->scan_result_lock_); } - - /* - - Avoid starting the scanner if: - - we are already scanning - - we are connecting to a device - - we are disconnecting from a device - - Otherwise the scanner could fail to ever start again - and our only way to recover is to reboot. - - https://github.com/espressif/esp-idf/issues/6688 - - */ - if (!connecting && xSemaphoreTake(this->scan_end_lock_, 0L)) { - if (this->scan_continuous_) { - if (!disconnecting && !promote_to_connecting && !this->scan_start_failed_ && !this->scan_set_param_failed_) { - this->start_scan_(false); - } else { - // We didn't start the scan, so we need to release the lock - xSemaphoreGive(this->scan_end_lock_); - } - } else if (!this->scanner_idle_) { - this->end_of_scan_(); - return; - } + this->scan_result_index_ = 0; + xSemaphoreGive(this->scan_result_lock_); + } + if (this->scanner_state_ == ScannerState::STOPPED) { + this->end_of_scan_(); // Change state to IDLE + } + if (this->scanner_state_ == ScannerState::FAILED || + (this->scan_set_param_failed_ && this->scanner_state_ == ScannerState::RUNNING)) { + this->stop_scan_(); + if (this->scan_start_fail_count_ == std::numeric_limits::max()) { + ESP_LOGE(TAG, "ESP-IDF BLE scan could not restart after %d attempts, rebooting to restore BLE stack...", + std::numeric_limits::max()); + App.reboot(); } - - if (this->scan_start_failed_ || this->scan_set_param_failed_) { - if (this->scan_start_fail_count_ == std::numeric_limits::max()) { - ESP_LOGE(TAG, "ESP-IDF BLE scan could not restart after %d attempts, rebooting to restore BLE stack...", - std::numeric_limits::max()); - App.reboot(); - } - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - xSemaphoreGive(this->scan_end_lock_); - } else { - ESP_LOGD(TAG, "Stopping scan after failure..."); - this->stop_scan_(); - } - if (this->scan_start_failed_) { - ESP_LOGE(TAG, "Scan start failed: %d", this->scan_start_failed_); - this->scan_start_failed_ = ESP_BT_STATUS_SUCCESS; - } - if (this->scan_set_param_failed_) { - ESP_LOGE(TAG, "Scan set param failed: %d", this->scan_set_param_failed_); - this->scan_set_param_failed_ = ESP_BT_STATUS_SUCCESS; - } + if (this->scan_start_failed_) { + ESP_LOGE(TAG, "Scan start failed: %d", this->scan_start_failed_); + this->scan_start_failed_ = ESP_BT_STATUS_SUCCESS; + } + if (this->scan_set_param_failed_) { + ESP_LOGE(TAG, "Scan set param failed: %d", this->scan_set_param_failed_); + this->scan_set_param_failed_ = ESP_BT_STATUS_SUCCESS; } } + /* + Avoid starting the scanner if: + - we are already scanning + - we are connecting to a device + - we are disconnecting from a device + + Otherwise the scanner could fail to ever start again + and our only way to recover is to reboot. + + https://github.com/espressif/esp-idf/issues/6688 + + */ + if (this->scanner_state_ == ScannerState::IDLE && !connecting && !disconnecting && !promote_to_connecting) { +#ifdef USE_ESP32_BLE_SOFTWARE_COEXISTENCE + if (this->coex_prefer_ble_) { + this->coex_prefer_ble_ = false; + ESP_LOGD(TAG, "Setting coexistence preference to balanced."); + esp_coex_preference_set(ESP_COEX_PREFER_BALANCE); // Reset to default + } +#endif + if (this->scan_continuous_) { + this->start_scan_(false); // first = false + } + } // If there is a discovered client and no connecting // clients and no clients using the scanner to search for // devices, then stop scanning and promote the discovered // client to ready to connect. - if (promote_to_connecting) { + if (promote_to_connecting && + (this->scanner_state_ == ScannerState::RUNNING || this->scanner_state_ == ScannerState::IDLE)) { for (auto *client : this->clients_) { if (client->state() == ClientState::DISCOVERED) { - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - // Scanner is not running since we got the - // lock, so we can promote the client. - xSemaphoreGive(this->scan_end_lock_); + if (this->scanner_state_ == ScannerState::RUNNING) { + ESP_LOGD(TAG, "Stopping scan to make connection..."); + this->stop_scan_(); + } else if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGD(TAG, "Promoting client to connect..."); // We only want to promote one client at a time. // once the scanner is fully stopped. +#ifdef USE_ESP32_BLE_SOFTWARE_COEXISTENCE + ESP_LOGD(TAG, "Setting coexistence to Bluetooth to make connection."); + if (!this->coex_prefer_ble_) { + this->coex_prefer_ble_ = true; + esp_coex_preference_set(ESP_COEX_PREFER_BT); // Prioritize Bluetooth + } +#endif client->set_state(ClientState::READY_TO_CONNECT); - } else { - ESP_LOGD(TAG, "Pausing scan to make connection..."); - this->stop_scan_(); } break; } @@ -237,13 +240,7 @@ void ESP32BLETracker::loop() { } } -void ESP32BLETracker::start_scan() { - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - this->start_scan_(true); - } else { - ESP_LOGW(TAG, "Scan requested when a scan is already in progress. Ignoring."); - } -} +void ESP32BLETracker::start_scan() { this->start_scan_(true); } void ESP32BLETracker::stop_scan() { ESP_LOGD(TAG, "Stopping scan."); @@ -251,16 +248,23 @@ void ESP32BLETracker::stop_scan() { this->stop_scan_(); } -void ESP32BLETracker::ble_before_disabled_event_handler() { - this->stop_scan_(); - xSemaphoreGive(this->scan_end_lock_); -} +void ESP32BLETracker::ble_before_disabled_event_handler() { this->stop_scan_(); } void ESP32BLETracker::stop_scan_() { - this->cancel_timeout("scan"); - if (this->scanner_idle_) { + if (this->scanner_state_ != ScannerState::RUNNING && this->scanner_state_ != ScannerState::FAILED) { + if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGE(TAG, "Scan is already stopped while trying to stop."); + } else if (this->scanner_state_ == ScannerState::STARTING) { + ESP_LOGE(TAG, "Scan is starting while trying to stop."); + } else if (this->scanner_state_ == ScannerState::STOPPING) { + ESP_LOGE(TAG, "Scan is already stopping while trying to stop."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Scan is already stopped while trying to stop."); + } return; } + this->cancel_timeout("scan"); + this->set_scanner_state_(ScannerState::STOPPING); esp_err_t err = esp_ble_gap_stop_scanning(); if (err != ESP_OK) { ESP_LOGE(TAG, "esp_ble_gap_stop_scanning failed: %d", err); @@ -273,13 +277,22 @@ void ESP32BLETracker::start_scan_(bool first) { ESP_LOGW(TAG, "Cannot start scan while ESP32BLE is disabled."); return; } - // The lock must be held when calling this function. - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - ESP_LOGE(TAG, "start_scan called without holding scan_end_lock_"); + if (this->scanner_state_ != ScannerState::IDLE) { + if (this->scanner_state_ == ScannerState::STARTING) { + ESP_LOGE(TAG, "Cannot start scan while already starting."); + } else if (this->scanner_state_ == ScannerState::RUNNING) { + ESP_LOGE(TAG, "Cannot start scan while already running."); + } else if (this->scanner_state_ == ScannerState::STOPPING) { + ESP_LOGE(TAG, "Cannot start scan while already stopping."); + } else if (this->scanner_state_ == ScannerState::FAILED) { + ESP_LOGE(TAG, "Cannot start scan while already failed."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Cannot start scan while already stopped."); + } return; } - - ESP_LOGD(TAG, "Starting scan..."); + this->set_scanner_state_(ScannerState::STARTING); + ESP_LOGD(TAG, "Starting scan, set scanner state to STARTING."); if (!first) { for (auto *listener : this->listeners_) listener->on_scan_end(); @@ -307,24 +320,21 @@ void ESP32BLETracker::start_scan_(bool first) { ESP_LOGE(TAG, "esp_ble_gap_start_scanning failed: %d", err); return; } - this->scanner_idle_ = false; } void ESP32BLETracker::end_of_scan_() { // The lock must be held when calling this function. - if (xSemaphoreTake(this->scan_end_lock_, 0L)) { - ESP_LOGE(TAG, "end_of_scan_ called without holding the scan_end_lock_"); + if (this->scanner_state_ != ScannerState::STOPPED) { + ESP_LOGE(TAG, "end_of_scan_ called while scanner is not stopped."); return; } - - ESP_LOGD(TAG, "End of scan."); - this->scanner_idle_ = true; + ESP_LOGD(TAG, "End of scan, set scanner state to IDLE."); this->already_discovered_.clear(); - xSemaphoreGive(this->scan_end_lock_); this->cancel_timeout("scan"); for (auto *listener : this->listeners_) listener->on_scan_end(); + this->set_scanner_state_(ScannerState::IDLE); } void ESP32BLETracker::register_client(ESPBTClient *client) { @@ -392,19 +402,46 @@ void ESP32BLETracker::gap_scan_set_param_complete_(const esp_ble_gap_cb_param_t: void ESP32BLETracker::gap_scan_start_complete_(const esp_ble_gap_cb_param_t::ble_scan_start_cmpl_evt_param ¶m) { ESP_LOGV(TAG, "gap_scan_start_complete - status %d", param.status); this->scan_start_failed_ = param.status; + if (this->scanner_state_ != ScannerState::STARTING) { + if (this->scanner_state_ == ScannerState::RUNNING) { + ESP_LOGE(TAG, "Scan was already running when start complete."); + } else if (this->scanner_state_ == ScannerState::STOPPING) { + ESP_LOGE(TAG, "Scan was stopping when start complete."); + } else if (this->scanner_state_ == ScannerState::FAILED) { + ESP_LOGE(TAG, "Scan was in failed state when start complete."); + } else if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGE(TAG, "Scan was idle when start complete."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Scan was stopped when start complete."); + } + } if (param.status == ESP_BT_STATUS_SUCCESS) { this->scan_start_fail_count_ = 0; + this->set_scanner_state_(ScannerState::RUNNING); } else { + this->set_scanner_state_(ScannerState::FAILED); if (this->scan_start_fail_count_ != std::numeric_limits::max()) { this->scan_start_fail_count_++; } - xSemaphoreGive(this->scan_end_lock_); } } void ESP32BLETracker::gap_scan_stop_complete_(const esp_ble_gap_cb_param_t::ble_scan_stop_cmpl_evt_param ¶m) { ESP_LOGV(TAG, "gap_scan_stop_complete - status %d", param.status); - xSemaphoreGive(this->scan_end_lock_); + if (this->scanner_state_ != ScannerState::STOPPING) { + if (this->scanner_state_ == ScannerState::RUNNING) { + ESP_LOGE(TAG, "Scan was not running when stop complete."); + } else if (this->scanner_state_ == ScannerState::STARTING) { + ESP_LOGE(TAG, "Scan was not started when stop complete."); + } else if (this->scanner_state_ == ScannerState::FAILED) { + ESP_LOGE(TAG, "Scan was in failed state when stop complete."); + } else if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGE(TAG, "Scan was idle when stop complete."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Scan was stopped when stop complete."); + } + } + this->set_scanner_state_(ScannerState::STOPPED); } void ESP32BLETracker::gap_scan_result_(const esp_ble_gap_cb_param_t::ble_scan_result_evt_param ¶m) { @@ -417,7 +454,21 @@ void ESP32BLETracker::gap_scan_result_(const esp_ble_gap_cb_param_t::ble_scan_re xSemaphoreGive(this->scan_result_lock_); } } else if (param.search_evt == ESP_GAP_SEARCH_INQ_CMPL_EVT) { - xSemaphoreGive(this->scan_end_lock_); + // Scan finished on its own + if (this->scanner_state_ != ScannerState::RUNNING) { + if (this->scanner_state_ == ScannerState::STOPPING) { + ESP_LOGE(TAG, "Scan was not running when scan completed."); + } else if (this->scanner_state_ == ScannerState::STARTING) { + ESP_LOGE(TAG, "Scan was not started when scan completed."); + } else if (this->scanner_state_ == ScannerState::FAILED) { + ESP_LOGE(TAG, "Scan was in failed state when scan completed."); + } else if (this->scanner_state_ == ScannerState::IDLE) { + ESP_LOGE(TAG, "Scan was idle when scan completed."); + } else if (this->scanner_state_ == ScannerState::STOPPED) { + ESP_LOGE(TAG, "Scan was stopped when scan completed."); + } + } + this->set_scanner_state_(ScannerState::STOPPED); } } @@ -428,6 +479,11 @@ void ESP32BLETracker::gattc_event_handler(esp_gattc_cb_event_t event, esp_gatt_i } } +void ESP32BLETracker::set_scanner_state_(ScannerState state) { + this->scanner_state_ = state; + this->scanner_state_callbacks_.call(state); +} + ESPBLEiBeacon::ESPBLEiBeacon(const uint8_t *data) { memcpy(&this->beacon_data_, data, sizeof(beacon_data_)); } optional ESPBLEiBeacon::from_manufacturer_data(const ServiceData &data) { if (!data.uuid.contains(0x4C, 0x00)) @@ -680,8 +736,26 @@ void ESP32BLETracker::dump_config() { ESP_LOGCONFIG(TAG, " Scan Window: %.1f ms", this->scan_window_ * 0.625f); ESP_LOGCONFIG(TAG, " Scan Type: %s", this->scan_active_ ? "ACTIVE" : "PASSIVE"); ESP_LOGCONFIG(TAG, " Continuous Scanning: %s", YESNO(this->scan_continuous_)); - ESP_LOGCONFIG(TAG, " Scanner Idle: %s", YESNO(this->scanner_idle_)); - ESP_LOGCONFIG(TAG, " Scan End: %s", YESNO(xSemaphoreGetMutexHolder(this->scan_end_lock_) == nullptr)); + switch (this->scanner_state_) { + case ScannerState::IDLE: + ESP_LOGCONFIG(TAG, " Scanner State: IDLE"); + break; + case ScannerState::STARTING: + ESP_LOGCONFIG(TAG, " Scanner State: STARTING"); + break; + case ScannerState::RUNNING: + ESP_LOGCONFIG(TAG, " Scanner State: RUNNING"); + break; + case ScannerState::STOPPING: + ESP_LOGCONFIG(TAG, " Scanner State: STOPPING"); + break; + case ScannerState::STOPPED: + ESP_LOGCONFIG(TAG, " Scanner State: STOPPED"); + break; + case ScannerState::FAILED: + ESP_LOGCONFIG(TAG, " Scanner State: FAILED"); + break; + } ESP_LOGCONFIG(TAG, " Connecting: %d, discovered: %d, searching: %d, disconnecting: %d", connecting_, discovered_, searching_, disconnecting_); if (this->scan_start_fail_count_) { diff --git a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h index 8b712a01ea..2e45d9602c 100644 --- a/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h +++ b/esphome/components/esp32_ble_tracker/esp32_ble_tracker.h @@ -154,6 +154,21 @@ enum class ClientState { ESTABLISHED, }; +enum class ScannerState { + // Scanner is idle, init state, set from the main loop when processing STOPPED + IDLE, + // Scanner is starting, set from the main loop only + STARTING, + // Scanner is running, set from the ESP callback only + RUNNING, + // Scanner failed to start, set from the ESP callback only + FAILED, + // Scanner is stopping, set from the main loop only + STOPPING, + // Scanner is stopped, set from the ESP callback only + STOPPED, +}; + enum class ConnectionType { // The default connection type, we hold all the services in ram // for the duration of the connection. @@ -203,6 +218,7 @@ class ESP32BLETracker : public Component, void set_scan_interval(uint32_t scan_interval) { scan_interval_ = scan_interval; } void set_scan_window(uint32_t scan_window) { scan_window_ = scan_window; } void set_scan_active(bool scan_active) { scan_active_ = scan_active; } + bool get_scan_active() const { return scan_active_; } void set_scan_continuous(bool scan_continuous) { scan_continuous_ = scan_continuous; } /// Setup the FreeRTOS task and the Bluetooth stack. @@ -226,6 +242,11 @@ class ESP32BLETracker : public Component, void gap_event_handler(esp_gap_ble_cb_event_t event, esp_ble_gap_cb_param_t *param) override; void ble_before_disabled_event_handler() override; + void add_scanner_state_callback(std::function &&callback) { + this->scanner_state_callbacks_.add(std::move(callback)); + } + ScannerState get_scanner_state() const { return this->scanner_state_; } + protected: void stop_scan_(); /// Start a single scan by setting up the parameters and doing some esp-idf calls. @@ -240,6 +261,8 @@ class ESP32BLETracker : public Component, void gap_scan_start_complete_(const esp_ble_gap_cb_param_t::ble_scan_start_cmpl_evt_param ¶m); /// Called when a `ESP_GAP_BLE_SCAN_STOP_COMPLETE_EVT` event is received. void gap_scan_stop_complete_(const esp_ble_gap_cb_param_t::ble_scan_stop_cmpl_evt_param ¶m); + /// Called to set the scanner state. Will also call callbacks to let listeners know when state is changed. + void set_scanner_state_(ScannerState state); int app_id_{0}; @@ -257,12 +280,12 @@ class ESP32BLETracker : public Component, uint8_t scan_start_fail_count_{0}; bool scan_continuous_; bool scan_active_; - bool scanner_idle_{true}; + ScannerState scanner_state_{ScannerState::IDLE}; + CallbackManager scanner_state_callbacks_; bool ble_was_disabled_{true}; bool raw_advertisements_{false}; bool parse_advertisements_{false}; SemaphoreHandle_t scan_result_lock_; - SemaphoreHandle_t scan_end_lock_; size_t scan_result_index_{0}; #ifdef USE_PSRAM const static u_int8_t SCAN_RESULT_BUFFER_SIZE = 32; @@ -276,6 +299,9 @@ class ESP32BLETracker : public Component, int discovered_{0}; int searching_{0}; int disconnecting_{0}; +#ifdef USE_ESP32_BLE_SOFTWARE_COEXISTENCE + bool coex_prefer_ble_{false}; +#endif }; // NOLINTNEXTLINE diff --git a/esphome/components/esp32_can/esp32_can.cpp b/esphome/components/esp32_can/esp32_can.cpp index a40f493075..b5e72497ce 100644 --- a/esphome/components/esp32_can/esp32_can.cpp +++ b/esphome/components/esp32_can/esp32_can.cpp @@ -17,7 +17,7 @@ static const char *const TAG = "esp32_can"; static bool get_bitrate(canbus::CanSpeed bitrate, twai_timing_config_t *t_config) { switch (bitrate) { #if defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32C3) || \ - defined(USE_ESP32_VARIANT_ESP32C6) || defined(USE_ESP32_VARIANT_ESP32H6) + defined(USE_ESP32_VARIANT_ESP32C6) || defined(USE_ESP32_VARIANT_ESP32H2) case canbus::CAN_1KBPS: *t_config = (twai_timing_config_t) TWAI_TIMING_CONFIG_1KBITS(); return true; diff --git a/esphome/components/esp32_rmt_led_strip/led_strip.cpp b/esphome/components/esp32_rmt_led_strip/led_strip.cpp index 4e8c862c23..355f60ef05 100644 --- a/esphome/components/esp32_rmt_led_strip/led_strip.cpp +++ b/esphome/components/esp32_rmt_led_strip/led_strip.cpp @@ -58,7 +58,7 @@ void ESP32RMTLEDStripLightOutput::setup() { channel.flags.io_loop_back = 0; channel.flags.io_od_mode = 0; channel.flags.invert_out = 0; - channel.flags.with_dma = 0; + channel.flags.with_dma = this->use_dma_; channel.intr_priority = 0; if (rmt_new_tx_channel(&channel, &this->channel_) != ESP_OK) { ESP_LOGE(TAG, "Channel creation failed"); diff --git a/esphome/components/esp32_rmt_led_strip/led_strip.h b/esphome/components/esp32_rmt_led_strip/led_strip.h index fe49b9a2f3..f0cec9b291 100644 --- a/esphome/components/esp32_rmt_led_strip/led_strip.h +++ b/esphome/components/esp32_rmt_led_strip/led_strip.h @@ -51,6 +51,7 @@ class ESP32RMTLEDStripLightOutput : public light::AddressableLight { void set_num_leds(uint16_t num_leds) { this->num_leds_ = num_leds; } void set_is_rgbw(bool is_rgbw) { this->is_rgbw_ = is_rgbw; } void set_is_wrgb(bool is_wrgb) { this->is_wrgb_ = is_wrgb; } + void set_use_dma(bool use_dma) { this->use_dma_ = use_dma; } void set_use_psram(bool use_psram) { this->use_psram_ = use_psram; } /// Set a maximum refresh rate in µs as some lights do not like being updated too often. @@ -85,7 +86,7 @@ class ESP32RMTLEDStripLightOutput : public light::AddressableLight { rmt_encoder_handle_t encoder_{nullptr}; rmt_symbol_word_t *rmt_buf_{nullptr}; rmt_symbol_word_t bit0_, bit1_, reset_; - uint32_t rmt_symbols_; + uint32_t rmt_symbols_{48}; #else rmt_item32_t *rmt_buf_{nullptr}; rmt_item32_t bit0_, bit1_, reset_; @@ -94,11 +95,12 @@ class ESP32RMTLEDStripLightOutput : public light::AddressableLight { uint8_t pin_; uint16_t num_leds_; - bool is_rgbw_; - bool is_wrgb_; - bool use_psram_; + bool is_rgbw_{false}; + bool is_wrgb_{false}; + bool use_dma_{false}; + bool use_psram_{false}; - RGBOrder rgb_order_; + RGBOrder rgb_order_{ORDER_RGB}; uint32_t last_refresh_{0}; optional max_refresh_rate_{}; diff --git a/esphome/components/esp32_rmt_led_strip/light.py b/esphome/components/esp32_rmt_led_strip/light.py index e2c9f7e64a..ae92d99b12 100644 --- a/esphome/components/esp32_rmt_led_strip/light.py +++ b/esphome/components/esp32_rmt_led_strip/light.py @@ -3,7 +3,7 @@ import logging from esphome import pins import esphome.codegen as cg -from esphome.components import esp32_rmt, light +from esphome.components import esp32, esp32_rmt, light import esphome.config_validation as cv from esphome.const import ( CONF_CHIPSET, @@ -15,6 +15,7 @@ from esphome.const import ( CONF_RGB_ORDER, CONF_RMT_CHANNEL, CONF_RMT_SYMBOLS, + CONF_USE_DMA, ) from esphome.core import CORE @@ -138,6 +139,11 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_CHIPSET): cv.one_of(*CHIPSETS, upper=True), cv.Optional(CONF_IS_RGBW, default=False): cv.boolean, cv.Optional(CONF_IS_WRGB, default=False): cv.boolean, + cv.Optional(CONF_USE_DMA): cv.All( + esp32.only_on_variant(supported=[esp32.const.VARIANT_ESP32S3]), + cv.only_with_esp_idf, + cv.boolean, + ), cv.Optional(CONF_USE_PSRAM, default=True): cv.boolean, cv.Inclusive( CONF_BIT0_HIGH, @@ -211,6 +217,8 @@ async def to_code(config): if esp32_rmt.use_new_rmt_driver(): cg.add(var.set_rmt_symbols(config[CONF_RMT_SYMBOLS])) + if CONF_USE_DMA in config: + cg.add(var.set_use_dma(config[CONF_USE_DMA])) else: rmt_channel_t = cg.global_ns.enum("rmt_channel_t") cg.add( diff --git a/esphome/components/esp8266/gpio.cpp b/esphome/components/esp8266/gpio.cpp index a24f217756..9f23e8e67e 100644 --- a/esphome/components/esp8266/gpio.cpp +++ b/esphome/components/esp8266/gpio.cpp @@ -8,7 +8,7 @@ namespace esp8266 { static const char *const TAG = "esp8266"; -static int IRAM_ATTR flags_to_mode(gpio::Flags flags, uint8_t pin) { +static int flags_to_mode(gpio::Flags flags, uint8_t pin) { if (flags == gpio::FLAG_INPUT) { // NOLINT(bugprone-branch-clone) return INPUT; } else if (flags == gpio::FLAG_OUTPUT) { @@ -34,12 +34,36 @@ static int IRAM_ATTR flags_to_mode(gpio::Flags flags, uint8_t pin) { struct ISRPinArg { uint8_t pin; bool inverted; + volatile uint32_t *in_reg; + volatile uint32_t *out_set_reg; + volatile uint32_t *out_clr_reg; + volatile uint32_t *mode_set_reg; + volatile uint32_t *mode_clr_reg; + volatile uint32_t *func_reg; + uint32_t mask; }; ISRInternalGPIOPin ESP8266GPIOPin::to_isr() const { auto *arg = new ISRPinArg{}; // NOLINT(cppcoreguidelines-owning-memory) - arg->pin = pin_; - arg->inverted = inverted_; + arg->pin = this->pin_; + arg->inverted = this->inverted_; + if (this->pin_ < 16) { + arg->in_reg = &GPI; + arg->out_set_reg = &GPOS; + arg->out_clr_reg = &GPOC; + arg->mode_set_reg = &GPES; + arg->mode_clr_reg = &GPEC; + arg->func_reg = &GPF(this->pin_); + arg->mask = 1 << this->pin_; + } else { + arg->in_reg = &GP16I; + arg->out_set_reg = &GP16O; + arg->out_clr_reg = nullptr; + arg->mode_set_reg = &GP16E; + arg->mode_clr_reg = nullptr; + arg->func_reg = &GPF16; + arg->mask = 1; + } return ISRInternalGPIOPin((void *) arg); } @@ -88,20 +112,57 @@ void ESP8266GPIOPin::detach_interrupt() const { detachInterrupt(pin_); } using namespace esp8266; bool IRAM_ATTR ISRInternalGPIOPin::digital_read() { - auto *arg = reinterpret_cast(arg_); - return bool(digitalRead(arg->pin)) != arg->inverted; // NOLINT + auto *arg = reinterpret_cast(this->arg_); + return bool(*arg->in_reg & arg->mask) != arg->inverted; } + void IRAM_ATTR ISRInternalGPIOPin::digital_write(bool value) { auto *arg = reinterpret_cast(arg_); - digitalWrite(arg->pin, value != arg->inverted ? 1 : 0); // NOLINT + if (arg->pin < 16) { + if (value != arg->inverted) { + *arg->out_set_reg = arg->mask; + } else { + *arg->out_clr_reg = arg->mask; + } + } else { + if (value != arg->inverted) { + *arg->out_set_reg |= 1; + } else { + *arg->out_set_reg &= ~1; + } + } } + void IRAM_ATTR ISRInternalGPIOPin::clear_interrupt() { auto *arg = reinterpret_cast(arg_); GPIO_REG_WRITE(GPIO_STATUS_W1TC_ADDRESS, 1UL << arg->pin); } + void IRAM_ATTR ISRInternalGPIOPin::pin_mode(gpio::Flags flags) { - auto *arg = reinterpret_cast(arg_); - pinMode(arg->pin, flags_to_mode(flags, arg->pin)); // NOLINT + auto *arg = reinterpret_cast(this->arg_); + if (arg->pin < 16) { + if (flags & gpio::FLAG_OUTPUT) { + *arg->mode_set_reg = arg->mask; + } else { + *arg->mode_clr_reg = arg->mask; + } + if (flags & gpio::FLAG_PULLUP) { + *arg->func_reg |= 1 << GPFPU; + } else { + *arg->func_reg &= ~(1 << GPFPU); + } + } else { + if (flags & gpio::FLAG_OUTPUT) { + *arg->mode_set_reg |= 1; + } else { + *arg->mode_set_reg &= ~1; + } + if (flags & gpio::FLAG_PULLDOWN) { + *arg->func_reg |= 1 << GP16FPD; + } else { + *arg->func_reg &= ~(1 << GP16FPD); + } + } } } // namespace esphome diff --git a/esphome/components/event/__init__.py b/esphome/components/event/__init__.py index a7732dfcaf..0e5fb43690 100644 --- a/esphome/components/event/__init__.py +++ b/esphome/components/event/__init__.py @@ -41,7 +41,7 @@ EventTrigger = event_ns.class_("EventTrigger", automation.Trigger.template()) validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True, space="_") -EVENT_SCHEMA = ( +_EVENT_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMPONENT_SCHEMA) .extend( @@ -58,19 +58,17 @@ EVENT_SCHEMA = ( ) ) -_UNDEF = object() - def event_schema( - class_: MockObjClass = _UNDEF, + class_: MockObjClass = cv.UNDEFINED, *, - icon: str = _UNDEF, - entity_category: str = _UNDEF, - device_class: str = _UNDEF, + icon: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + device_class: str = cv.UNDEFINED, ) -> cv.Schema: schema = {} - if class_ is not _UNDEF: + if class_ is not cv.UNDEFINED: schema[cv.GenerateID()] = cv.declare_id(class_) for key, default, validator in [ @@ -78,10 +76,15 @@ def event_schema( (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), (CONF_DEVICE_CLASS, device_class, validate_device_class), ]: - if default is not _UNDEF: + if default is not cv.UNDEFINED: schema[cv.Optional(key, default=default)] = validator - return EVENT_SCHEMA.extend(schema) + return _EVENT_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +EVENT_SCHEMA = event_schema() +EVENT_SCHEMA.add_extra(cv.deprecated_schema_constant("event")) async def setup_event_core_(var, config, *, event_types: list[str]): diff --git a/esphome/components/factory_reset/switch/__init__.py b/esphome/components/factory_reset/switch/__init__.py index 17f4587e5d..a384a57f80 100644 --- a/esphome/components/factory_reset/switch/__init__.py +++ b/esphome/components/factory_reset/switch/__init__.py @@ -1,14 +1,7 @@ import esphome.codegen as cg from esphome.components import switch import esphome.config_validation as cv -from esphome.const import ( - CONF_ENTITY_CATEGORY, - CONF_ICON, - CONF_ID, - CONF_INVERTED, - ENTITY_CATEGORY_CONFIG, - ICON_RESTART_ALERT, -) +from esphome.const import ENTITY_CATEGORY_CONFIG, ICON_RESTART_ALERT from .. import factory_reset_ns @@ -16,21 +9,14 @@ FactoryResetSwitch = factory_reset_ns.class_( "FactoryResetSwitch", switch.Switch, cg.Component ) -CONFIG_SCHEMA = switch.SWITCH_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(FactoryResetSwitch), - cv.Optional(CONF_INVERTED): cv.invalid( - "Factory Reset switches do not support inverted mode!" - ), - cv.Optional(CONF_ICON, default=ICON_RESTART_ALERT): cv.icon, - cv.Optional( - CONF_ENTITY_CATEGORY, default=ENTITY_CATEGORY_CONFIG - ): cv.entity_category, - } +CONFIG_SCHEMA = switch.switch_schema( + FactoryResetSwitch, + block_inverted=True, + icon=ICON_RESTART_ALERT, + entity_category=ENTITY_CATEGORY_CONFIG, ).extend(cv.COMPONENT_SCHEMA) async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await switch.new_switch(config) await cg.register_component(var, config) - await switch.register_switch(var, config) diff --git a/esphome/components/fan/__init__.py b/esphome/components/fan/__init__.py index 4e0e52cd65..960809ff70 100644 --- a/esphome/components/fan/__init__.py +++ b/esphome/components/fan/__init__.py @@ -5,6 +5,10 @@ from esphome.components import mqtt, web_server import esphome.config_validation as cv from esphome.const import ( CONF_DIRECTION, + CONF_DIRECTION_COMMAND_TOPIC, + CONF_DIRECTION_STATE_TOPIC, + CONF_ENTITY_CATEGORY, + CONF_ICON, CONF_ID, CONF_MQTT_ID, CONF_OFF_SPEED_CYCLE, @@ -80,16 +84,21 @@ FanPresetSetTrigger = fan_ns.class_( FanIsOnCondition = fan_ns.class_("FanIsOnCondition", automation.Condition.template()) FanIsOffCondition = fan_ns.class_("FanIsOffCondition", automation.Condition.template()) -FAN_SCHEMA = ( +_FAN_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( { - cv.GenerateID(): cv.declare_id(Fan), cv.Optional(CONF_RESTORE_MODE, default="ALWAYS_OFF"): cv.enum( RESTORE_MODES, upper=True, space="_" ), cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id(mqtt.MQTTFanComponent), + cv.Optional(CONF_DIRECTION_STATE_TOPIC): cv.All( + cv.requires_component("mqtt"), cv.publish_topic + ), + cv.Optional(CONF_DIRECTION_COMMAND_TOPIC): cv.All( + cv.requires_component("mqtt"), cv.subscribe_topic + ), cv.Optional(CONF_OSCILLATION_STATE_TOPIC): cv.All( cv.requires_component("mqtt"), cv.publish_topic ), @@ -151,6 +160,37 @@ FAN_SCHEMA = ( ) ) + +def fan_schema( + class_: cg.Pvariable, + *, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, + default_restore_mode: str = cv.UNDEFINED, +) -> cv.Schema: + schema = { + cv.GenerateID(): cv.declare_id(class_), + } + + for key, default, validator in [ + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_ICON, icon, cv.icon), + ( + CONF_RESTORE_MODE, + default_restore_mode, + cv.enum(RESTORE_MODES, upper=True, space="_"), + ), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _FAN_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +FAN_SCHEMA = fan_schema(Fan) +FAN_SCHEMA.add_extra(cv.deprecated_schema_constant("fan")) + _PRESET_MODES_SCHEMA = cv.All( cv.ensure_list(cv.string_strict), cv.Length(min=1), @@ -193,6 +233,14 @@ async def setup_fan_core_(var, config): mqtt_ = cg.new_Pvariable(mqtt_id, var) await mqtt.register_mqtt_component(mqtt_, config) + if ( + direction_state_topic := config.get(CONF_DIRECTION_STATE_TOPIC) + ) is not None: + cg.add(mqtt_.set_custom_direction_state_topic(direction_state_topic)) + if ( + direction_command_topic := config.get(CONF_DIRECTION_COMMAND_TOPIC) + ) is not None: + cg.add(mqtt_.set_custom_direction_command_topic(direction_command_topic)) if ( oscillation_state_topic := config.get(CONF_OSCILLATION_STATE_TOPIC) ) is not None: @@ -251,10 +299,9 @@ async def register_fan(var, config): await setup_fan_core_(var, config) -async def create_fan_state(config): - var = cg.new_Pvariable(config[CONF_ID]) +async def new_fan(config, *args): + var = cg.new_Pvariable(config[CONF_ID], *args) await register_fan(var, config) - await cg.register_component(var, config) return var diff --git a/esphome/components/fastled_base/__init__.py b/esphome/components/fastled_base/__init__.py index 1e70e14f10..11e8423258 100644 --- a/esphome/components/fastled_base/__init__.py +++ b/esphome/components/fastled_base/__init__.py @@ -40,9 +40,6 @@ async def new_fastled_light(config): if CONF_MAX_REFRESH_RATE in config: cg.add(var.set_max_refresh_rate(config[CONF_MAX_REFRESH_RATE])) + cg.add_library("fastled/FastLED", "3.9.16") await light.register_light(var, config) - # https://github.com/FastLED/FastLED/blob/master/library.json - # 3.3.3 has an issue on ESP32 with RMT and fastled_clockless: - # https://github.com/esphome/issues/issues/1375 - cg.add_library("fastled/FastLED", "3.3.2") return var diff --git a/esphome/components/fastled_base/fastled_light.cpp b/esphome/components/fastled_base/fastled_light.cpp index 486364d0c0..3ecdee61b1 100644 --- a/esphome/components/fastled_base/fastled_light.cpp +++ b/esphome/components/fastled_base/fastled_light.cpp @@ -34,7 +34,7 @@ void FastLEDLightOutput::write_state(light::LightState *state) { this->mark_shown_(); ESP_LOGVV(TAG, "Writing RGB values to bus..."); - this->controller_->showLeds(); + this->controller_->showLeds(this->state_parent_->current_values.get_brightness() * 255); } } // namespace fastled_base diff --git a/esphome/components/feedback/cover.py b/esphome/components/feedback/cover.py index b90374f6e8..856818280f 100644 --- a/esphome/components/feedback/cover.py +++ b/esphome/components/feedback/cover.py @@ -7,7 +7,6 @@ from esphome.const import ( CONF_CLOSE_ACTION, CONF_CLOSE_DURATION, CONF_CLOSE_ENDSTOP, - CONF_ID, CONF_MAX_DURATION, CONF_OPEN_ACTION, CONF_OPEN_DURATION, @@ -50,36 +49,43 @@ def validate_infer_endstop(config): return config -CONFIG_FEEDBACK_COVER_BASE_SCHEMA = cover.COVER_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(FeedbackCover), - cv.Required(CONF_STOP_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_OPEN_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_OPEN_DURATION): cv.positive_time_period_milliseconds, - cv.Optional(CONF_OPEN_ENDSTOP): cv.use_id(binary_sensor.BinarySensor), - cv.Optional(CONF_OPEN_SENSOR): cv.use_id(binary_sensor.BinarySensor), - cv.Optional(CONF_OPEN_OBSTACLE_SENSOR): cv.use_id(binary_sensor.BinarySensor), - cv.Required(CONF_CLOSE_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_CLOSE_DURATION): cv.positive_time_period_milliseconds, - cv.Optional(CONF_CLOSE_ENDSTOP): cv.use_id(binary_sensor.BinarySensor), - cv.Optional(CONF_CLOSE_SENSOR): cv.use_id(binary_sensor.BinarySensor), - cv.Optional(CONF_CLOSE_OBSTACLE_SENSOR): cv.use_id(binary_sensor.BinarySensor), - cv.Optional(CONF_MAX_DURATION): cv.positive_time_period_milliseconds, - cv.Optional(CONF_HAS_BUILT_IN_ENDSTOP, default=False): cv.boolean, - cv.Optional(CONF_ASSUMED_STATE): cv.boolean, - cv.Optional( - CONF_UPDATE_INTERVAL, "1000ms" - ): cv.positive_time_period_milliseconds, - cv.Optional(CONF_INFER_ENDSTOP_FROM_MOVEMENT, False): cv.boolean, - cv.Optional( - CONF_DIRECTION_CHANGE_WAIT_TIME - ): cv.positive_time_period_milliseconds, - cv.Optional( - CONF_ACCELERATION_WAIT_TIME, "0s" - ): cv.positive_time_period_milliseconds, - cv.Optional(CONF_OBSTACLE_ROLLBACK, default="10%"): cv.percentage, - }, -).extend(cv.COMPONENT_SCHEMA) +CONFIG_FEEDBACK_COVER_BASE_SCHEMA = ( + cover.cover_schema(FeedbackCover) + .extend( + { + cv.Required(CONF_STOP_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_OPEN_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_OPEN_DURATION): cv.positive_time_period_milliseconds, + cv.Optional(CONF_OPEN_ENDSTOP): cv.use_id(binary_sensor.BinarySensor), + cv.Optional(CONF_OPEN_SENSOR): cv.use_id(binary_sensor.BinarySensor), + cv.Optional(CONF_OPEN_OBSTACLE_SENSOR): cv.use_id( + binary_sensor.BinarySensor + ), + cv.Required(CONF_CLOSE_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_CLOSE_DURATION): cv.positive_time_period_milliseconds, + cv.Optional(CONF_CLOSE_ENDSTOP): cv.use_id(binary_sensor.BinarySensor), + cv.Optional(CONF_CLOSE_SENSOR): cv.use_id(binary_sensor.BinarySensor), + cv.Optional(CONF_CLOSE_OBSTACLE_SENSOR): cv.use_id( + binary_sensor.BinarySensor + ), + cv.Optional(CONF_MAX_DURATION): cv.positive_time_period_milliseconds, + cv.Optional(CONF_HAS_BUILT_IN_ENDSTOP, default=False): cv.boolean, + cv.Optional(CONF_ASSUMED_STATE): cv.boolean, + cv.Optional( + CONF_UPDATE_INTERVAL, "1000ms" + ): cv.positive_time_period_milliseconds, + cv.Optional(CONF_INFER_ENDSTOP_FROM_MOVEMENT, False): cv.boolean, + cv.Optional( + CONF_DIRECTION_CHANGE_WAIT_TIME + ): cv.positive_time_period_milliseconds, + cv.Optional( + CONF_ACCELERATION_WAIT_TIME, "0s" + ): cv.positive_time_period_milliseconds, + cv.Optional(CONF_OBSTACLE_ROLLBACK, default="10%"): cv.percentage, + }, + ) + .extend(cv.COMPONENT_SCHEMA) +) CONFIG_SCHEMA = cv.All( @@ -90,9 +96,8 @@ CONFIG_SCHEMA = cv.All( async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await cover.new_cover(config) await cg.register_component(var, config) - await cover.register_cover(var, config) # STOP await automation.build_automation( diff --git a/esphome/components/gpio/one_wire/gpio_one_wire.cpp b/esphome/components/gpio/one_wire/gpio_one_wire.cpp index 36eaf2160a..8a56595efb 100644 --- a/esphome/components/gpio/one_wire/gpio_one_wire.cpp +++ b/esphome/components/gpio/one_wire/gpio_one_wire.cpp @@ -10,8 +10,10 @@ static const char *const TAG = "gpio.one_wire"; void GPIOOneWireBus::setup() { ESP_LOGCONFIG(TAG, "Setting up 1-wire bus..."); this->t_pin_->setup(); - // clear bus with 480µs high, otherwise initial reset in search might fail this->t_pin_->pin_mode(gpio::FLAG_INPUT | gpio::FLAG_PULLUP); + // clear bus with 480µs high, otherwise initial reset in search might fail + this->pin_.digital_write(true); + this->pin_.pin_mode(gpio::FLAG_OUTPUT); delayMicroseconds(480); this->search(); } @@ -22,40 +24,49 @@ void GPIOOneWireBus::dump_config() { this->dump_devices_(TAG); } -bool HOT IRAM_ATTR GPIOOneWireBus::reset() { +int HOT IRAM_ATTR GPIOOneWireBus::reset_int() { + InterruptLock lock; // See reset here: // https://www.maximintegrated.com/en/design/technical-documents/app-notes/1/126.html // Wait for communication to clear (delay G) - pin_.pin_mode(gpio::FLAG_INPUT | gpio::FLAG_PULLUP); + this->pin_.pin_mode(gpio::FLAG_INPUT | gpio::FLAG_PULLUP); uint8_t retries = 125; do { if (--retries == 0) - return false; + return -1; delayMicroseconds(2); - } while (!pin_.digital_read()); + } while (!this->pin_.digital_read()); - bool r; + bool r = false; // Send 480µs LOW TX reset pulse (drive bus low, delay H) - pin_.pin_mode(gpio::FLAG_OUTPUT); - pin_.digital_write(false); + this->pin_.digital_write(false); + this->pin_.pin_mode(gpio::FLAG_OUTPUT); delayMicroseconds(480); // Release the bus, delay I - pin_.pin_mode(gpio::FLAG_INPUT | gpio::FLAG_PULLUP); - delayMicroseconds(70); + this->pin_.pin_mode(gpio::FLAG_INPUT | gpio::FLAG_PULLUP); + uint32_t start = micros(); + delayMicroseconds(30); + + while (micros() - start < 300) { + // sample bus, 0=device(s) present, 1=no device present + r = !this->pin_.digital_read(); + if (r) + break; + delayMicroseconds(1); + } - // sample bus, 0=device(s) present, 1=no device present - r = !pin_.digital_read(); // delay J - delayMicroseconds(410); - return r; + delayMicroseconds(start + 480 - micros()); + this->pin_.digital_write(true); + this->pin_.pin_mode(gpio::FLAG_OUTPUT); + return r ? 1 : 0; } void HOT IRAM_ATTR GPIOOneWireBus::write_bit_(bool bit) { // drive bus low - pin_.pin_mode(gpio::FLAG_OUTPUT); - pin_.digital_write(false); + this->pin_.digital_write(false); // from datasheet: // write 0 low time: t_low0: min=60µs, max=120µs @@ -64,72 +75,62 @@ void HOT IRAM_ATTR GPIOOneWireBus::write_bit_(bool bit) { // recovery time: t_rec: min=1µs // ds18b20 appears to read the bus after roughly 14µs uint32_t delay0 = bit ? 6 : 60; - uint32_t delay1 = bit ? 59 : 5; + uint32_t delay1 = bit ? 64 : 10; // delay A/C delayMicroseconds(delay0); // release bus - pin_.digital_write(true); + this->pin_.digital_write(true); // delay B/D delayMicroseconds(delay1); } bool HOT IRAM_ATTR GPIOOneWireBus::read_bit_() { // drive bus low - pin_.pin_mode(gpio::FLAG_OUTPUT); - pin_.digital_write(false); + this->pin_.digital_write(false); - // note: for reading we'll need very accurate timing, as the - // timing for the digital_read() is tight; according to the datasheet, - // we should read at the end of 16µs starting from the bus low - // typically, the ds18b20 pulls the line high after 11µs for a logical 1 - // and 29µs for a logical 0 - - uint32_t start = micros(); - // datasheet says >1µs - delayMicroseconds(2); + // datasheet says >= 1µs + delayMicroseconds(5); // release bus, delay E - pin_.pin_mode(gpio::FLAG_INPUT | gpio::FLAG_PULLUP); - - // measure from start value directly, to get best accurate timing no matter - // how long pin_mode/delayMicroseconds took - uint32_t now = micros(); - if (now - start < 12) - delayMicroseconds(12 - (now - start)); + this->pin_.pin_mode(gpio::FLAG_INPUT | gpio::FLAG_PULLUP); + delayMicroseconds(8); // sample bus to read bit from peer - bool r = pin_.digital_read(); + bool r = this->pin_.digital_read(); - // read slot is at least 60µs; get as close to 60µs to spend less time with interrupts locked - now = micros(); - if (now - start < 60) - delayMicroseconds(60 - (now - start)); + // read slot is at least 60µs + delayMicroseconds(50); + this->pin_.digital_write(true); + this->pin_.pin_mode(gpio::FLAG_OUTPUT); return r; } void IRAM_ATTR GPIOOneWireBus::write8(uint8_t val) { + InterruptLock lock; for (uint8_t i = 0; i < 8; i++) { this->write_bit_(bool((1u << i) & val)); } } void IRAM_ATTR GPIOOneWireBus::write64(uint64_t val) { + InterruptLock lock; for (uint8_t i = 0; i < 64; i++) { this->write_bit_(bool((1ULL << i) & val)); } } uint8_t IRAM_ATTR GPIOOneWireBus::read8() { + InterruptLock lock; uint8_t ret = 0; - for (uint8_t i = 0; i < 8; i++) { + for (uint8_t i = 0; i < 8; i++) ret |= (uint8_t(this->read_bit_()) << i); - } return ret; } uint64_t IRAM_ATTR GPIOOneWireBus::read64() { + InterruptLock lock; uint64_t ret = 0; for (uint8_t i = 0; i < 8; i++) { ret |= (uint64_t(this->read_bit_()) << i); @@ -144,6 +145,7 @@ void GPIOOneWireBus::reset_search() { } uint64_t IRAM_ATTR GPIOOneWireBus::search_int() { + InterruptLock lock; if (this->last_device_flag_) return 0u; diff --git a/esphome/components/gpio/one_wire/gpio_one_wire.h b/esphome/components/gpio/one_wire/gpio_one_wire.h index fe949baec3..8874703971 100644 --- a/esphome/components/gpio/one_wire/gpio_one_wire.h +++ b/esphome/components/gpio/one_wire/gpio_one_wire.h @@ -18,7 +18,6 @@ class GPIOOneWireBus : public one_wire::OneWireBus, public Component { this->pin_ = pin->to_isr(); } - bool reset() override; void write8(uint8_t val) override; void write64(uint64_t val) override; uint8_t read8() override; @@ -31,10 +30,12 @@ class GPIOOneWireBus : public one_wire::OneWireBus, public Component { bool last_device_flag_{false}; uint64_t address_; + int reset_int() override; void reset_search() override; uint64_t search_int() override; void write_bit_(bool bit); bool read_bit_(); + bool read_bit_(uint32_t *t); }; } // namespace gpio diff --git a/esphome/components/gpio_expander/cached_gpio.h b/esphome/components/gpio_expander/cached_gpio.h index 784c5f0f4a..78c675cdb2 100644 --- a/esphome/components/gpio_expander/cached_gpio.h +++ b/esphome/components/gpio_expander/cached_gpio.h @@ -8,30 +8,45 @@ namespace esphome { namespace gpio_expander { /// @brief A class to cache the read state of a GPIO expander. +/// This class caches reads between GPIO Pins which are on the same bank. +/// This means that for reading whole Port (ex. 8 pins) component needs only one +/// I2C/SPI read per main loop call. It assumes, that one bit in byte identifies one GPIO pin +/// Template parameters: +/// T - Type which represents internal register. Could be uint8_t or uint16_t. Adjust to +/// match size of your internal GPIO bank register. +/// N - Number of pins template class CachedGpioExpander { public: bool digital_read(T pin) { - if (!this->read_cache_invalidated_[pin]) { - this->read_cache_invalidated_[pin] = true; - return this->digital_read_cache(pin); + uint8_t bank = pin / (sizeof(T) * BITS_PER_BYTE); + if (this->read_cache_invalidated_[bank]) { + this->read_cache_invalidated_[bank] = false; + if (!this->digital_read_hw(pin)) + return false; } - return this->digital_read_hw(pin); + return this->digital_read_cache(pin); } void digital_write(T pin, bool value) { this->digital_write_hw(pin, value); } protected: + /// @brief Call component low level function to read GPIO state from device virtual bool digital_read_hw(T pin) = 0; + /// @brief Call component read function from internal cache. virtual bool digital_read_cache(T pin) = 0; + /// @brief Call component low level function to write GPIO state to device virtual void digital_write_hw(T pin, bool value) = 0; + const uint8_t cache_byte_size_ = N / (sizeof(T) * BITS_PER_BYTE); + /// @brief Invalidate cache. This function should be called in component loop(). void reset_pin_cache_() { - for (T i = 0; i < N; i++) { - this->read_cache_invalidated_[i] = false; + for (T i = 0; i < this->cache_byte_size_; i++) { + this->read_cache_invalidated_[i] = true; } } - std::array read_cache_invalidated_{}; + static const uint8_t BITS_PER_BYTE = 8; + std::array read_cache_invalidated_{}; }; } // namespace gpio_expander diff --git a/esphome/components/gps/__init__.py b/esphome/components/gps/__init__.py index 51288ccc30..88e6f0fd9b 100644 --- a/esphome/components/gps/__init__.py +++ b/esphome/components/gps/__init__.py @@ -25,6 +25,7 @@ GPS = gps_ns.class_("GPS", cg.Component, uart.UARTDevice) GPSListener = gps_ns.class_("GPSListener") CONF_GPS_ID = "gps_id" +CONF_HDOP = "hdop" MULTI_CONF = True CONFIG_SCHEMA = cv.All( cv.Schema( @@ -40,7 +41,7 @@ CONFIG_SCHEMA = cv.All( ), cv.Optional(CONF_SPEED): sensor.sensor_schema( unit_of_measurement=UNIT_KILOMETER_PER_HOUR, - accuracy_decimals=6, + accuracy_decimals=3, ), cv.Optional(CONF_COURSE): sensor.sensor_schema( unit_of_measurement=UNIT_DEGREES, @@ -48,12 +49,16 @@ CONFIG_SCHEMA = cv.All( ), cv.Optional(CONF_ALTITUDE): sensor.sensor_schema( unit_of_measurement=UNIT_METER, - accuracy_decimals=1, + accuracy_decimals=2, ), cv.Optional(CONF_SATELLITES): sensor.sensor_schema( accuracy_decimals=0, state_class=STATE_CLASS_MEASUREMENT, ), + cv.Optional(CONF_HDOP): sensor.sensor_schema( + accuracy_decimals=3, + state_class=STATE_CLASS_MEASUREMENT, + ), } ) .extend(cv.polling_component_schema("20s")) @@ -92,5 +97,9 @@ async def to_code(config): sens = await sensor.new_sensor(config[CONF_SATELLITES]) cg.add(var.set_satellites_sensor(sens)) + if hdop_config := config.get(CONF_HDOP): + sens = await sensor.new_sensor(hdop_config) + cg.add(var.set_hdop_sensor(sens)) + # https://platformio.org/lib/show/1655/TinyGPSPlus cg.add_library("mikalhart/TinyGPSPlus", "1.0.2") diff --git a/esphome/components/gps/gps.cpp b/esphome/components/gps/gps.cpp index 8c924d629c..e54afdb07e 100644 --- a/esphome/components/gps/gps.cpp +++ b/esphome/components/gps/gps.cpp @@ -28,6 +28,9 @@ void GPS::update() { if (this->satellites_sensor_ != nullptr) this->satellites_sensor_->publish_state(this->satellites_); + + if (this->hdop_sensor_ != nullptr) + this->hdop_sensor_->publish_state(this->hdop_); } void GPS::loop() { @@ -44,23 +47,23 @@ void GPS::loop() { if (tiny_gps_.speed.isUpdated()) { this->speed_ = tiny_gps_.speed.kmph(); - ESP_LOGD(TAG, "Speed:"); - ESP_LOGD(TAG, " %f km/h", this->speed_); + ESP_LOGD(TAG, "Speed: %.3f km/h", this->speed_); } if (tiny_gps_.course.isUpdated()) { this->course_ = tiny_gps_.course.deg(); - ESP_LOGD(TAG, "Course:"); - ESP_LOGD(TAG, " %f °", this->course_); + ESP_LOGD(TAG, "Course: %.2f °", this->course_); } if (tiny_gps_.altitude.isUpdated()) { this->altitude_ = tiny_gps_.altitude.meters(); - ESP_LOGD(TAG, "Altitude:"); - ESP_LOGD(TAG, " %f m", this->altitude_); + ESP_LOGD(TAG, "Altitude: %.2f m", this->altitude_); } if (tiny_gps_.satellites.isUpdated()) { this->satellites_ = tiny_gps_.satellites.value(); - ESP_LOGD(TAG, "Satellites:"); - ESP_LOGD(TAG, " %d", this->satellites_); + ESP_LOGD(TAG, "Satellites: %d", this->satellites_); + } + if (tiny_gps_.hdop.isUpdated()) { + this->hdop_ = tiny_gps_.hdop.hdop(); + ESP_LOGD(TAG, "HDOP: %.3f", this->hdop_); } for (auto *listener : this->listeners_) diff --git a/esphome/components/gps/gps.h b/esphome/components/gps/gps.h index 0626fb0b0e..a400820738 100644 --- a/esphome/components/gps/gps.h +++ b/esphome/components/gps/gps.h @@ -33,6 +33,7 @@ class GPS : public PollingComponent, public uart::UARTDevice { void set_course_sensor(sensor::Sensor *course_sensor) { course_sensor_ = course_sensor; } void set_altitude_sensor(sensor::Sensor *altitude_sensor) { altitude_sensor_ = altitude_sensor; } void set_satellites_sensor(sensor::Sensor *satellites_sensor) { satellites_sensor_ = satellites_sensor; } + void set_hdop_sensor(sensor::Sensor *hdop_sensor) { hdop_sensor_ = hdop_sensor; } void register_listener(GPSListener *listener) { listener->parent_ = this; @@ -46,12 +47,13 @@ class GPS : public PollingComponent, public uart::UARTDevice { TinyGPSPlus &get_tiny_gps() { return this->tiny_gps_; } protected: - float latitude_ = -1; - float longitude_ = -1; - float speed_ = -1; - float course_ = -1; - float altitude_ = -1; - int satellites_ = -1; + float latitude_ = NAN; + float longitude_ = NAN; + float speed_ = NAN; + float course_ = NAN; + float altitude_ = NAN; + int satellites_ = 0; + double hdop_ = NAN; sensor::Sensor *latitude_sensor_{nullptr}; sensor::Sensor *longitude_sensor_{nullptr}; @@ -59,6 +61,7 @@ class GPS : public PollingComponent, public uart::UARTDevice { sensor::Sensor *course_sensor_{nullptr}; sensor::Sensor *altitude_sensor_{nullptr}; sensor::Sensor *satellites_sensor_{nullptr}; + sensor::Sensor *hdop_sensor_{nullptr}; bool has_time_{false}; TinyGPSPlus tiny_gps_; diff --git a/esphome/components/graph/__init__.py b/esphome/components/graph/__init__.py index 254294619e..6e8ba44bec 100644 --- a/esphome/components/graph/__init__.py +++ b/esphome/components/graph/__init__.py @@ -5,6 +5,7 @@ import esphome.config_validation as cv from esphome.const import ( CONF_BORDER, CONF_COLOR, + CONF_CONTINUOUS, CONF_DIRECTION, CONF_DURATION, CONF_HEIGHT, @@ -61,8 +62,6 @@ VALUE_POSITION_TYPE = { "BELOW": ValuePositionType.VALUE_POSITION_TYPE_BELOW, } -CONF_CONTINUOUS = "continuous" - GRAPH_TRACE_SCHEMA = cv.Schema( { cv.GenerateID(): cv.declare_id(GraphTrace), diff --git a/esphome/components/gree/climate.py b/esphome/components/gree/climate.py index 75436f2cf5..389c9fb3c7 100644 --- a/esphome/components/gree/climate.py +++ b/esphome/components/gree/climate.py @@ -18,6 +18,7 @@ MODELS = { "yac": Model.GREE_YAC, "yac1fb9": Model.GREE_YAC1FB9, "yx1ff": Model.GREE_YX1FF, + "yag": Model.GREE_YAG, } CONFIG_SCHEMA = climate_ir.CLIMATE_IR_WITH_RECEIVER_SCHEMA.extend( diff --git a/esphome/components/gree/gree.cpp b/esphome/components/gree/gree.cpp index 6d179a947b..e0cacb4f1e 100644 --- a/esphome/components/gree/gree.cpp +++ b/esphome/components/gree/gree.cpp @@ -22,13 +22,21 @@ void GreeClimate::transmit_state() { remote_state[0] = this->fan_speed_() | this->operation_mode_(); remote_state[1] = this->temperature_(); - if (this->model_ == GREE_YAN || this->model_ == GREE_YX1FF) { + if (this->model_ == GREE_YAN || this->model_ == GREE_YX1FF || this->model_ == GREE_YAG) { remote_state[2] = 0x60; remote_state[3] = 0x50; remote_state[4] = this->vertical_swing_(); } - if (this->model_ == GREE_YAC) { + if (this->model_ == GREE_YAG) { + remote_state[5] = 0x40; + + if (this->vertical_swing_() == GREE_VDIR_SWING || this->horizontal_swing_() == GREE_HDIR_SWING) { + remote_state[0] |= (1 << 6); + } + } + + if (this->model_ == GREE_YAC || this->model_ == GREE_YAG) { remote_state[4] |= (this->horizontal_swing_() << 4); } @@ -57,6 +65,12 @@ void GreeClimate::transmit_state() { // Calculate the checksum if (this->model_ == GREE_YAN || this->model_ == GREE_YX1FF) { remote_state[7] = ((remote_state[0] << 4) + (remote_state[1] << 4) + 0xC0); + } else if (this->model_ == GREE_YAG) { + remote_state[7] = + ((((remote_state[0] & 0x0F) + (remote_state[1] & 0x0F) + (remote_state[2] & 0x0F) + (remote_state[3] & 0x0F) + + ((remote_state[4] & 0xF0) >> 4) + ((remote_state[5] & 0xF0) >> 4) + ((remote_state[6] & 0xF0) >> 4) + 0x0A) & + 0x0F) + << 4); } else { remote_state[7] = ((((remote_state[0] & 0x0F) + (remote_state[1] & 0x0F) + (remote_state[2] & 0x0F) + (remote_state[3] & 0x0F) + diff --git a/esphome/components/gree/gree.h b/esphome/components/gree/gree.h index 6762b41eb0..f91d78cabd 100644 --- a/esphome/components/gree/gree.h +++ b/esphome/components/gree/gree.h @@ -58,7 +58,7 @@ const uint8_t GREE_VDIR_MIDDLE = 0x04; const uint8_t GREE_VDIR_MDOWN = 0x05; const uint8_t GREE_VDIR_DOWN = 0x06; -// Only available on YAC +// Only available on YAC/YAG // Horizontal air directions. Note that these cannot be set on all heat pumps const uint8_t GREE_HDIR_AUTO = 0x00; const uint8_t GREE_HDIR_MANUAL = 0x00; @@ -78,7 +78,7 @@ const uint8_t GREE_PRESET_SLEEP = 0x01; const uint8_t GREE_PRESET_SLEEP_BIT = 0x80; // Model codes -enum Model { GREE_GENERIC, GREE_YAN, GREE_YAA, GREE_YAC, GREE_YAC1FB9, GREE_YX1FF }; +enum Model { GREE_GENERIC, GREE_YAN, GREE_YAA, GREE_YAC, GREE_YAC1FB9, GREE_YX1FF, GREE_YAG }; class GreeClimate : public climate_ir::ClimateIR { public: diff --git a/esphome/components/he60r/cover.py b/esphome/components/he60r/cover.py index a483d2a571..a3a1b19f5a 100644 --- a/esphome/components/he60r/cover.py +++ b/esphome/components/he60r/cover.py @@ -1,17 +1,17 @@ import esphome.codegen as cg from esphome.components import cover, uart import esphome.config_validation as cv -from esphome.const import CONF_CLOSE_DURATION, CONF_ID, CONF_OPEN_DURATION +from esphome.const import CONF_CLOSE_DURATION, CONF_OPEN_DURATION he60r_ns = cg.esphome_ns.namespace("he60r") HE60rCover = he60r_ns.class_("HE60rCover", cover.Cover, cg.Component) CONFIG_SCHEMA = ( - cover.COVER_SCHEMA.extend(uart.UART_DEVICE_SCHEMA) + cover.cover_schema(HE60rCover) + .extend(uart.UART_DEVICE_SCHEMA) .extend(cv.COMPONENT_SCHEMA) .extend( { - cv.GenerateID(): cv.declare_id(HE60rCover), cv.Optional( CONF_OPEN_DURATION, default="15s" ): cv.positive_time_period_milliseconds, @@ -34,9 +34,8 @@ FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema( async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await cover.new_cover(config) await cg.register_component(var, config) - await cover.register_cover(var, config) await uart.register_uart_device(var, config) cg.add(var.set_close_duration(config[CONF_CLOSE_DURATION])) diff --git a/esphome/components/hm3301/hm3301.h b/esphome/components/hm3301/hm3301.h index bccdd1d35b..6779b4e195 100644 --- a/esphome/components/hm3301/hm3301.h +++ b/esphome/components/hm3301/hm3301.h @@ -8,7 +8,7 @@ namespace esphome { namespace hm3301 { -static const uint8_t SELECT_COMM_CMD = 0X88; +static const uint8_t SELECT_COMM_CMD = 0x88; class HM3301Component : public PollingComponent, public i2c::I2CDevice { public: diff --git a/esphome/components/http_request/__init__.py b/esphome/components/http_request/__init__.py index 78064fb4b4..9aa0c42fa2 100644 --- a/esphome/components/http_request/__init__.py +++ b/esphome/components/http_request/__init__.py @@ -10,9 +10,11 @@ from esphome.const import ( CONF_TIMEOUT, CONF_TRIGGER_ID, CONF_URL, + PLATFORM_HOST, __version__, ) from esphome.core import CORE, Lambda +from esphome.helpers import IS_MACOS DEPENDENCIES = ["network"] AUTO_LOAD = ["json", "watchdog"] @@ -21,6 +23,7 @@ http_request_ns = cg.esphome_ns.namespace("http_request") HttpRequestComponent = http_request_ns.class_("HttpRequestComponent", cg.Component) HttpRequestArduino = http_request_ns.class_("HttpRequestArduino", HttpRequestComponent) HttpRequestIDF = http_request_ns.class_("HttpRequestIDF", HttpRequestComponent) +HttpRequestHost = http_request_ns.class_("HttpRequestHost", HttpRequestComponent) HttpContainer = http_request_ns.class_("HttpContainer") @@ -43,10 +46,13 @@ CONF_REDIRECT_LIMIT = "redirect_limit" CONF_WATCHDOG_TIMEOUT = "watchdog_timeout" CONF_BUFFER_SIZE_RX = "buffer_size_rx" CONF_BUFFER_SIZE_TX = "buffer_size_tx" +CONF_CA_CERTIFICATE_PATH = "ca_certificate_path" CONF_MAX_RESPONSE_BUFFER_SIZE = "max_response_buffer_size" CONF_ON_RESPONSE = "on_response" CONF_HEADERS = "headers" +CONF_REQUEST_HEADERS = "request_headers" +CONF_COLLECT_HEADERS = "collect_headers" CONF_BODY = "body" CONF_JSON = "json" CONF_CAPTURE_RESPONSE = "capture_response" @@ -85,6 +91,8 @@ def validate_ssl_verification(config): def _declare_request_class(value): + if CORE.is_host: + return cv.declare_id(HttpRequestHost)(value) if CORE.using_esp_idf: return cv.declare_id(HttpRequestIDF)(value) if CORE.is_esp8266 or CORE.is_esp32 or CORE.is_rp2040: @@ -119,6 +127,10 @@ CONFIG_SCHEMA = cv.All( cv.SplitDefault(CONF_BUFFER_SIZE_TX, esp32_idf=512): cv.All( cv.uint16_t, cv.only_with_esp_idf ), + cv.Optional(CONF_CA_CERTIFICATE_PATH): cv.All( + cv.file_, + cv.only_on(PLATFORM_HOST), + ), } ).extend(cv.COMPONENT_SCHEMA), cv.require_framework_version( @@ -126,6 +138,7 @@ CONFIG_SCHEMA = cv.All( esp32_arduino=cv.Version(0, 0, 0), esp_idf=cv.Version(0, 0, 0), rp2040_arduino=cv.Version(0, 0, 0), + host=cv.Version(0, 0, 0), ), validate_ssl_verification, ) @@ -168,6 +181,21 @@ async def to_code(config): cg.add_library("ESP8266HTTPClient", None) if CORE.is_rp2040 and CORE.using_arduino: cg.add_library("HTTPClient", None) + if CORE.is_host: + if IS_MACOS: + cg.add_build_flag("-I/opt/homebrew/opt/openssl/include") + cg.add_build_flag("-L/opt/homebrew/opt/openssl/lib") + cg.add_build_flag("-lssl") + cg.add_build_flag("-lcrypto") + cg.add_build_flag("-Wl,-framework,CoreFoundation") + cg.add_build_flag("-Wl,-framework,Security") + cg.add_define("CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN") + cg.add_define("CPPHTTPLIB_OPENSSL_SUPPORT") + elif path := config.get(CONF_CA_CERTIFICATE_PATH): + cg.add_define("CPPHTTPLIB_OPENSSL_SUPPORT") + cg.add(var.set_ca_path(path)) + cg.add_build_flag("-lssl") + cg.add_build_flag("-lcrypto") await cg.register_component(var, config) @@ -176,9 +204,13 @@ HTTP_REQUEST_ACTION_SCHEMA = cv.Schema( { cv.GenerateID(): cv.use_id(HttpRequestComponent), cv.Required(CONF_URL): cv.templatable(validate_url), - cv.Optional(CONF_HEADERS): cv.All( + cv.Optional(CONF_HEADERS): cv.invalid( + "The 'headers' options has been renamed to 'request_headers'" + ), + cv.Optional(CONF_REQUEST_HEADERS): cv.All( cv.Schema({cv.string: cv.templatable(cv.string)}) ), + cv.Optional(CONF_COLLECT_HEADERS): cv.ensure_list(cv.string), cv.Optional(CONF_VERIFY_SSL): cv.invalid( f"{CONF_VERIFY_SSL} has moved to the base component configuration." ), @@ -263,11 +295,12 @@ async def http_request_action_to_code(config, action_id, template_arg, args): for key in json_: template_ = await cg.templatable(json_[key], args, cg.std_string) cg.add(var.add_json(key, template_)) - for key in config.get(CONF_HEADERS, []): - template_ = await cg.templatable( - config[CONF_HEADERS][key], args, cg.const_char_ptr - ) - cg.add(var.add_header(key, template_)) + for key, value in config.get(CONF_REQUEST_HEADERS, {}).items(): + template_ = await cg.templatable(value, args, cg.const_char_ptr) + cg.add(var.add_request_header(key, template_)) + + for value in config.get(CONF_COLLECT_HEADERS, []): + cg.add(var.add_collect_header(value)) for conf in config.get(CONF_ON_RESPONSE, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID]) diff --git a/esphome/components/http_request/http_request.cpp b/esphome/components/http_request/http_request.cpp index be8bef006e..ca9fd2c2dc 100644 --- a/esphome/components/http_request/http_request.cpp +++ b/esphome/components/http_request/http_request.cpp @@ -20,5 +20,25 @@ void HttpRequestComponent::dump_config() { } } +std::string HttpContainer::get_response_header(const std::string &header_name) { + auto response_headers = this->get_response_headers(); + auto header_name_lower_case = str_lower_case(header_name); + if (response_headers.count(header_name_lower_case) == 0) { + ESP_LOGW(TAG, "No header with name %s found", header_name_lower_case.c_str()); + return ""; + } else { + auto values = response_headers[header_name_lower_case]; + if (values.empty()) { + ESP_LOGE(TAG, "header with name %s returned an empty list, this shouldn't happen", + header_name_lower_case.c_str()); + return ""; + } else { + auto header_value = values.front(); + ESP_LOGD(TAG, "Header with name %s found with value %s", header_name_lower_case.c_str(), header_value.c_str()); + return header_value; + } + } +} + } // namespace http_request } // namespace esphome diff --git a/esphome/components/http_request/http_request.h b/esphome/components/http_request/http_request.h index e98fd1a475..a67b04eadc 100644 --- a/esphome/components/http_request/http_request.h +++ b/esphome/components/http_request/http_request.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -95,9 +96,19 @@ class HttpContainer : public Parented { size_t get_bytes_read() const { return this->bytes_read_; } + /** + * @brief Get response headers. + * + * @return The key is the lower case response header name, the value is the header value. + */ + std::map> get_response_headers() { return this->response_headers_; } + + std::string get_response_header(const std::string &header_name); + protected: size_t bytes_read_{0}; bool secure_{false}; + std::map> response_headers_{}; }; class HttpRequestResponseTrigger : public Trigger, std::string &> { @@ -119,21 +130,46 @@ class HttpRequestComponent : public Component { void set_follow_redirects(bool follow_redirects) { this->follow_redirects_ = follow_redirects; } void set_redirect_limit(uint16_t limit) { this->redirect_limit_ = limit; } - std::shared_ptr get(std::string url) { return this->start(std::move(url), "GET", "", {}); } - std::shared_ptr get(std::string url, std::list
headers) { - return this->start(std::move(url), "GET", "", std::move(headers)); + std::shared_ptr get(const std::string &url) { return this->start(url, "GET", "", {}); } + std::shared_ptr get(const std::string &url, const std::list
&request_headers) { + return this->start(url, "GET", "", request_headers); } - std::shared_ptr post(std::string url, std::string body) { - return this->start(std::move(url), "POST", std::move(body), {}); + std::shared_ptr get(const std::string &url, const std::list
&request_headers, + const std::set &collect_headers) { + return this->start(url, "GET", "", request_headers, collect_headers); } - std::shared_ptr post(std::string url, std::string body, std::list
headers) { - return this->start(std::move(url), "POST", std::move(body), std::move(headers)); + std::shared_ptr post(const std::string &url, const std::string &body) { + return this->start(url, "POST", body, {}); + } + std::shared_ptr post(const std::string &url, const std::string &body, + const std::list
&request_headers) { + return this->start(url, "POST", body, request_headers); + } + std::shared_ptr post(const std::string &url, const std::string &body, + const std::list
&request_headers, + const std::set &collect_headers) { + return this->start(url, "POST", body, request_headers, collect_headers); } - virtual std::shared_ptr start(std::string url, std::string method, std::string body, - std::list
headers) = 0; + std::shared_ptr start(const std::string &url, const std::string &method, const std::string &body, + const std::list
&request_headers) { + return this->start(url, method, body, request_headers, {}); + } + + std::shared_ptr start(const std::string &url, const std::string &method, const std::string &body, + const std::list
&request_headers, + const std::set &collect_headers) { + std::set lower_case_collect_headers; + for (const std::string &collect_header : collect_headers) { + lower_case_collect_headers.insert(str_lower_case(collect_header)); + } + return this->perform(url, method, body, request_headers, lower_case_collect_headers); + } protected: + virtual std::shared_ptr perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) = 0; const char *useragent_{nullptr}; bool follow_redirects_{}; uint16_t redirect_limit_{}; @@ -149,7 +185,11 @@ template class HttpRequestSendAction : public Action { TEMPLATABLE_VALUE(std::string, body) TEMPLATABLE_VALUE(bool, capture_response) - void add_header(const char *key, TemplatableValue value) { this->headers_.insert({key, value}); } + void add_request_header(const char *key, TemplatableValue value) { + this->request_headers_.insert({key, value}); + } + + void add_collect_header(const char *value) { this->collect_headers_.insert(value); } void add_json(const char *key, TemplatableValue value) { this->json_.insert({key, value}); } @@ -176,16 +216,17 @@ template class HttpRequestSendAction : public Action { auto f = std::bind(&HttpRequestSendAction::encode_json_func_, this, x..., std::placeholders::_1); body = json::build_json(f); } - std::list
headers; - for (const auto &item : this->headers_) { + std::list
request_headers; + for (const auto &item : this->request_headers_) { auto val = item.second; Header header; header.name = item.first; header.value = val.value(x...); - headers.push_back(header); + request_headers.push_back(header); } - auto container = this->parent_->start(this->url_.value(x...), this->method_.value(x...), body, headers); + auto container = this->parent_->start(this->url_.value(x...), this->method_.value(x...), body, request_headers, + this->collect_headers_); if (container == nullptr) { for (auto *trigger : this->error_triggers_) @@ -238,7 +279,8 @@ template class HttpRequestSendAction : public Action { } void encode_json_func_(Ts... x, JsonObject root) { this->json_func_(x..., root); } HttpRequestComponent *parent_; - std::map> headers_{}; + std::map> request_headers_{}; + std::set collect_headers_{"content-type", "content-length"}; std::map> json_{}; std::function json_func_{nullptr}; std::vector response_triggers_{}; diff --git a/esphome/components/http_request/http_request_arduino.cpp b/esphome/components/http_request/http_request_arduino.cpp index b0067e7839..b4378cdce6 100644 --- a/esphome/components/http_request/http_request_arduino.cpp +++ b/esphome/components/http_request/http_request_arduino.cpp @@ -14,8 +14,9 @@ namespace http_request { static const char *const TAG = "http_request.arduino"; -std::shared_ptr HttpRequestArduino::start(std::string url, std::string method, std::string body, - std::list
headers) { +std::shared_ptr HttpRequestArduino::perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) { if (!network::is_connected()) { this->status_momentary_error("failed", 1000); ESP_LOGW(TAG, "HTTP Request failed; Not connected to network"); @@ -95,14 +96,17 @@ std::shared_ptr HttpRequestArduino::start(std::string url, std::s if (this->useragent_ != nullptr) { container->client_.setUserAgent(this->useragent_); } - for (const auto &header : headers) { + for (const auto &header : request_headers) { container->client_.addHeader(header.name.c_str(), header.value.c_str(), false, true); } // returned needed headers must be collected before the requests - static const char *header_keys[] = {"Content-Length", "Content-Type"}; - static const size_t HEADER_COUNT = sizeof(header_keys) / sizeof(header_keys[0]); - container->client_.collectHeaders(header_keys, HEADER_COUNT); + const char *header_keys[collect_headers.size()]; + int index = 0; + for (auto const &header_name : collect_headers) { + header_keys[index++] = header_name.c_str(); + } + container->client_.collectHeaders(header_keys, index); App.feed_wdt(); container->status_code = container->client_.sendRequest(method.c_str(), body.c_str()); @@ -121,6 +125,18 @@ std::shared_ptr HttpRequestArduino::start(std::string url, std::s // Still return the container, so it can be used to get the status code and error message } + container->response_headers_ = {}; + auto header_count = container->client_.headers(); + for (int i = 0; i < header_count; i++) { + const std::string header_name = str_lower_case(container->client_.headerName(i).c_str()); + if (collect_headers.count(header_name) > 0) { + std::string header_value = container->client_.header(i).c_str(); + ESP_LOGD(TAG, "Received response header, name: %s, value: %s", header_name.c_str(), header_value.c_str()); + container->response_headers_[header_name].push_back(header_value); + break; + } + } + int content_length = container->client_.getSize(); ESP_LOGD(TAG, "Content-Length: %d", content_length); container->content_length = (size_t) content_length; diff --git a/esphome/components/http_request/http_request_arduino.h b/esphome/components/http_request/http_request_arduino.h index dfdf4a35e2..ac9ddffbb0 100644 --- a/esphome/components/http_request/http_request_arduino.h +++ b/esphome/components/http_request/http_request_arduino.h @@ -29,9 +29,10 @@ class HttpContainerArduino : public HttpContainer { }; class HttpRequestArduino : public HttpRequestComponent { - public: - std::shared_ptr start(std::string url, std::string method, std::string body, - std::list
headers) override; + protected: + std::shared_ptr perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) override; }; } // namespace http_request diff --git a/esphome/components/http_request/http_request_host.cpp b/esphome/components/http_request/http_request_host.cpp new file mode 100644 index 0000000000..192032c1ac --- /dev/null +++ b/esphome/components/http_request/http_request_host.cpp @@ -0,0 +1,141 @@ +#include "http_request_host.h" + +#ifdef USE_HOST + +#include +#include "esphome/components/network/util.h" +#include "esphome/components/watchdog/watchdog.h" + +#include "esphome/core/application.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace http_request { + +static const char *const TAG = "http_request.host"; + +std::shared_ptr HttpRequestHost::perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set response_headers) { + if (!network::is_connected()) { + this->status_momentary_error("failed", 1000); + ESP_LOGW(TAG, "HTTP Request failed; Not connected to network"); + return nullptr; + } + + std::regex url_regex(R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)", std::regex::extended); + std::smatch url_match_result; + + if (!std::regex_match(url, url_match_result, url_regex) || url_match_result.length() < 7) { + ESP_LOGE(TAG, "HTTP Request failed; Malformed URL: %s", url.c_str()); + return nullptr; + } + auto host = url_match_result[4].str(); + auto scheme_host = url_match_result[1].str() + url_match_result[3].str(); + auto path = url_match_result[5].str() + url_match_result[6].str(); + if (path.empty()) + path = "/"; + + std::shared_ptr container = std::make_shared(); + container->set_parent(this); + + const uint32_t start = millis(); + + watchdog::WatchdogManager wdm(this->get_watchdog_timeout()); + + httplib::Headers h_headers; + h_headers.emplace("Host", host.c_str()); + h_headers.emplace("User-Agent", this->useragent_); + for (const auto &[name, value] : request_headers) { + h_headers.emplace(name, value); + } + httplib::Client client(scheme_host.c_str()); + if (!client.is_valid()) { + ESP_LOGE(TAG, "HTTP Request failed; Invalid URL: %s", url.c_str()); + return nullptr; + } + client.set_follow_location(this->follow_redirects_); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (this->ca_path_ != nullptr) + client.set_ca_cert_path(this->ca_path_); +#endif + + httplib::Result result; + if (method == "GET") { + result = client.Get(path, h_headers, [&](const char *data, size_t data_length) { + ESP_LOGV(TAG, "Got data length: %zu", data_length); + container->response_body_.insert(container->response_body_.end(), (const uint8_t *) data, + (const uint8_t *) data + data_length); + return true; + }); + } else if (method == "HEAD") { + result = client.Head(path, h_headers); + } else if (method == "PUT") { + result = client.Put(path, h_headers, body, ""); + if (result) { + auto data = std::vector(result->body.begin(), result->body.end()); + container->response_body_.insert(container->response_body_.end(), data.begin(), data.end()); + } + } else if (method == "PATCH") { + result = client.Patch(path, h_headers, body, ""); + if (result) { + auto data = std::vector(result->body.begin(), result->body.end()); + container->response_body_.insert(container->response_body_.end(), data.begin(), data.end()); + } + } else if (method == "POST") { + result = client.Post(path, h_headers, body, ""); + if (result) { + auto data = std::vector(result->body.begin(), result->body.end()); + container->response_body_.insert(container->response_body_.end(), data.begin(), data.end()); + } + } else { + ESP_LOGW(TAG, "HTTP Request failed - unsupported method %s; URL: %s", method.c_str(), url.c_str()); + container->end(); + return nullptr; + } + App.feed_wdt(); + if (!result) { + ESP_LOGW(TAG, "HTTP Request failed; URL: %s, error code: %u", url.c_str(), (unsigned) result.error()); + container->end(); + this->status_momentary_error("failed", 1000); + return nullptr; + } + App.feed_wdt(); + auto response = *result; + container->status_code = response.status; + if (!is_success(response.status)) { + ESP_LOGE(TAG, "HTTP Request failed; URL: %s; Code: %d", url.c_str(), response.status); + this->status_momentary_error("failed", 1000); + // Still return the container, so it can be used to get the status code and error message + } + + container->content_length = container->response_body_.size(); + for (auto header : response.headers) { + ESP_LOGD(TAG, "Header: %s: %s", header.first.c_str(), header.second.c_str()); + auto lower_name = str_lower_case(header.first); + if (response_headers.find(lower_name) != response_headers.end()) { + container->response_headers_[lower_name].emplace_back(header.second); + } + } + container->duration_ms = millis() - start; + return container; +} + +int HttpContainerHost::read(uint8_t *buf, size_t max_len) { + auto bytes_remaining = this->response_body_.size() - this->bytes_read_; + auto read_len = std::min(max_len, bytes_remaining); + memcpy(buf, this->response_body_.data() + this->bytes_read_, read_len); + this->bytes_read_ += read_len; + return read_len; +} + +void HttpContainerHost::end() { + watchdog::WatchdogManager wdm(this->parent_->get_watchdog_timeout()); + this->response_body_ = std::vector(); + this->bytes_read_ = 0; +} + +} // namespace http_request +} // namespace esphome + +#endif // USE_HOST diff --git a/esphome/components/http_request/http_request_host.h b/esphome/components/http_request/http_request_host.h new file mode 100644 index 0000000000..49fd3b43fe --- /dev/null +++ b/esphome/components/http_request/http_request_host.h @@ -0,0 +1,37 @@ +#pragma once + +#include "http_request.h" + +#ifdef USE_HOST + +#define CPPHTTPLIB_NO_EXCEPTIONS +#include "httplib.h" +namespace esphome { +namespace http_request { + +class HttpRequestHost; +class HttpContainerHost : public HttpContainer { + public: + int read(uint8_t *buf, size_t max_len) override; + void end() override; + + protected: + friend class HttpRequestHost; + std::vector response_body_{}; +}; + +class HttpRequestHost : public HttpRequestComponent { + public: + std::shared_ptr perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set response_headers) override; + void set_ca_path(const char *ca_path) { this->ca_path_ = ca_path; } + + protected: + const char *ca_path_{}; +}; + +} // namespace http_request +} // namespace esphome + +#endif // USE_HOST diff --git a/esphome/components/http_request/http_request_idf.cpp b/esphome/components/http_request/http_request_idf.cpp index 78c37403f5..0923062822 100644 --- a/esphome/components/http_request/http_request_idf.cpp +++ b/esphome/components/http_request/http_request_idf.cpp @@ -19,14 +19,41 @@ namespace http_request { static const char *const TAG = "http_request.idf"; +struct UserData { + const std::set &collect_headers; + std::map> response_headers; +}; + void HttpRequestIDF::dump_config() { HttpRequestComponent::dump_config(); ESP_LOGCONFIG(TAG, " Buffer Size RX: %u", this->buffer_size_rx_); ESP_LOGCONFIG(TAG, " Buffer Size TX: %u", this->buffer_size_tx_); } -std::shared_ptr HttpRequestIDF::start(std::string url, std::string method, std::string body, - std::list
headers) { +esp_err_t HttpRequestIDF::http_event_handler(esp_http_client_event_t *evt) { + UserData *user_data = (UserData *) evt->user_data; + + switch (evt->event_id) { + case HTTP_EVENT_ON_HEADER: { + const std::string header_name = str_lower_case(evt->header_key); + if (user_data->collect_headers.count(header_name)) { + const std::string header_value = evt->header_value; + ESP_LOGD(TAG, "Received response header, name: %s, value: %s", header_name.c_str(), header_value.c_str()); + user_data->response_headers[header_name].push_back(header_value); + break; + } + break; + } + default: { + break; + } + } + return ESP_OK; +} + +std::shared_ptr HttpRequestIDF::perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) { if (!network::is_connected()) { this->status_momentary_error("failed", 1000); ESP_LOGE(TAG, "HTTP Request failed; Not connected to network"); @@ -76,6 +103,10 @@ std::shared_ptr HttpRequestIDF::start(std::string url, std::strin const uint32_t start = millis(); watchdog::WatchdogManager wdm(this->get_watchdog_timeout()); + config.event_handler = http_event_handler; + auto user_data = UserData{collect_headers, {}}; + config.user_data = static_cast(&user_data); + esp_http_client_handle_t client = esp_http_client_init(&config); std::shared_ptr container = std::make_shared(client); @@ -83,7 +114,7 @@ std::shared_ptr HttpRequestIDF::start(std::string url, std::strin container->set_secure(secure); - for (const auto &header : headers) { + for (const auto &header : request_headers) { esp_http_client_set_header(client, header.name.c_str(), header.value.c_str()); } @@ -124,6 +155,7 @@ std::shared_ptr HttpRequestIDF::start(std::string url, std::strin container->feed_wdt(); container->status_code = esp_http_client_get_status_code(client); container->feed_wdt(); + container->set_response_headers(user_data.response_headers); if (is_success(container->status_code)) { container->duration_ms = millis() - start; return container; diff --git a/esphome/components/http_request/http_request_idf.h b/esphome/components/http_request/http_request_idf.h index 2ed50698b9..5c5b784853 100644 --- a/esphome/components/http_request/http_request_idf.h +++ b/esphome/components/http_request/http_request_idf.h @@ -21,6 +21,10 @@ class HttpContainerIDF : public HttpContainer { /// @brief Feeds the watchdog timer if the executing task has one attached void feed_wdt(); + void set_response_headers(std::map> &response_headers) { + this->response_headers_ = std::move(response_headers); + } + protected: esp_http_client_handle_t client_; }; @@ -29,16 +33,19 @@ class HttpRequestIDF : public HttpRequestComponent { public: void dump_config() override; - std::shared_ptr start(std::string url, std::string method, std::string body, - std::list
headers) override; - void set_buffer_size_rx(uint16_t buffer_size_rx) { this->buffer_size_rx_ = buffer_size_rx; } void set_buffer_size_tx(uint16_t buffer_size_tx) { this->buffer_size_tx_ = buffer_size_tx; } protected: + std::shared_ptr perform(std::string url, std::string method, std::string body, + std::list
request_headers, + std::set collect_headers) override; // if zero ESP-IDF will use DEFAULT_HTTP_BUF_SIZE uint16_t buffer_size_rx_{}; uint16_t buffer_size_tx_{}; + + /// @brief Monitors the http client events to gather response headers + static esp_err_t http_event_handler(esp_http_client_event_t *evt); }; } // namespace http_request diff --git a/esphome/components/http_request/httplib.h b/esphome/components/http_request/httplib.h new file mode 100644 index 0000000000..a2f4436ec7 --- /dev/null +++ b/esphome/components/http_request/httplib.h @@ -0,0 +1,9691 @@ +#pragma once + +/** + * NOTE: This is a copy of httplib.h from https://github.com/yhirose/cpp-httplib + * + * It has been modified only to add ifdefs for USE_HOST. While it contains many functions unused in ESPHome, + * it was considered preferable to use it with as few changes as possible, to facilitate future updates. + */ + +#include "esphome/core/defines.h" + +// +// httplib.h +// +// Copyright (c) 2024 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifdef USE_HOST +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +#define CPPHTTPLIB_VERSION "0.18.2" + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND +#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND +#ifdef _WIN32 +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 10000 +#else +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0 +#endif +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT +#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_RANGE_MAX_COUNT +#define CPPHTTPLIB_RANGE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_TCP_NODELAY +#define CPPHTTPLIB_TCP_NODELAY false +#endif + +#ifndef CPPHTTPLIB_IPV6_V6ONLY +#define CPPHTTPLIB_IPV6_V6ONLY false +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ +#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() - 1 : 0)) +#endif + +#ifndef CPPHTTPLIB_RECV_FLAGS +#define CPPHTTPLIB_RECV_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_SEND_FLAGS +#define CPPHTTPLIB_SEND_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#if _MSC_VER < 1900 +#error Sorry, Visual Studio versions prior to 2015 are not supported +#endif + +#pragma comment(lib, "ws2_32.lib") + +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = long; +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m) &S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m) &S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include +#if !defined(_AIX) && !defined(__MVS__) +#include +#endif +#ifdef __MVS__ +#include +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif +#endif +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#ifdef CPPHTTPLIB_USE_POLL +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +using socket_t = int; +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#include +#if TARGET_OS_OSX +#include +#include +#endif // TARGET_OS_OSX +#endif // _WIN32 + +#include +#include +#include +#include + +#if defined(_WIN32) && defined(OPENSSL_USE_APPLINK) +#include +#endif + +#include +#include + +#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER) +#if OPENSSL_VERSION_NUMBER < 0x1010107f +#error Please use OpenSSL or a current version of BoringSSL +#endif +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#elif OPENSSL_VERSION_NUMBER < 0x30000000L +#error Sorry, OpenSSL versions prior to 3.0.0 are not supported +#endif + +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +#include +#include +#endif + +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template +typename std::enable_if::value, std::unique_ptr>::type make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename std::enable_if::value, std::unique_ptr>::type make_unique(std::size_t n) { + typedef typename std::remove_extent::type RT; + return std::unique_ptr(new RT[n]); +} + +namespace case_ignore { + +inline unsigned char to_lower(int c) { + const static unsigned char table[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, + 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 97, + 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, + 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, + 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, + 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 224, 225, 226, 227, 228, 229, + 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, + 252, 253, 254, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, + }; + return table[(unsigned char) (char) c]; +} + +inline bool equal(const std::string &a, const std::string &b) { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { return to_lower(ca) == to_lower(cb); }); +} + +struct equal_to { + bool operator()(const std::string &a, const std::string &b) const { return equal(a, b); } +}; + +struct hash { + size_t operator()(const std::string &key) const { return hash_core(key.data(), key.size(), 0); } + + size_t hash_core(const char *s, size_t l, size_t h) const { + return (l == 0) ? h + : hash_core(s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no + // overflow happens + (((std::numeric_limits::max)() >> 6) & h * 33) ^ + static_cast(to_lower(*s))); + } +}; + +} // namespace case_ignore + +// This is based on +// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". + +struct scope_exit { + explicit scope_exit(std::function &&f) : exit_function(std::move(f)), execute_on_destruction{true} {} + + scope_exit(scope_exit &&rhs) noexcept + : exit_function(std::move(rhs.exit_function)), execute_on_destruction{rhs.execute_on_destruction} { + rhs.release(); + } + + ~scope_exit() { + if (execute_on_destruction) { + this->exit_function(); + } + } + + void release() { this->execute_on_destruction = false; } + + private: + scope_exit(const scope_exit &) = delete; + void operator=(const scope_exit &) = delete; + scope_exit &operator=(scope_exit &&) = delete; + + std::function exit_function; + bool execute_on_destruction; +}; + +} // namespace detail + +enum StatusCode { + // Information responses + Continue_100 = 100, + SwitchingProtocol_101 = 101, + Processing_102 = 102, + EarlyHints_103 = 103, + + // Successful responses + OK_200 = 200, + Created_201 = 201, + Accepted_202 = 202, + NonAuthoritativeInformation_203 = 203, + NoContent_204 = 204, + ResetContent_205 = 205, + PartialContent_206 = 206, + MultiStatus_207 = 207, + AlreadyReported_208 = 208, + IMUsed_226 = 226, + + // Redirection messages + MultipleChoices_300 = 300, + MovedPermanently_301 = 301, + Found_302 = 302, + SeeOther_303 = 303, + NotModified_304 = 304, + UseProxy_305 = 305, + unused_306 = 306, + TemporaryRedirect_307 = 307, + PermanentRedirect_308 = 308, + + // Client error responses + BadRequest_400 = 400, + Unauthorized_401 = 401, + PaymentRequired_402 = 402, + Forbidden_403 = 403, + NotFound_404 = 404, + MethodNotAllowed_405 = 405, + NotAcceptable_406 = 406, + ProxyAuthenticationRequired_407 = 407, + RequestTimeout_408 = 408, + Conflict_409 = 409, + Gone_410 = 410, + LengthRequired_411 = 411, + PreconditionFailed_412 = 412, + PayloadTooLarge_413 = 413, + UriTooLong_414 = 414, + UnsupportedMediaType_415 = 415, + RangeNotSatisfiable_416 = 416, + ExpectationFailed_417 = 417, + ImATeapot_418 = 418, + MisdirectedRequest_421 = 421, + UnprocessableContent_422 = 422, + Locked_423 = 423, + FailedDependency_424 = 424, + TooEarly_425 = 425, + UpgradeRequired_426 = 426, + PreconditionRequired_428 = 428, + TooManyRequests_429 = 429, + RequestHeaderFieldsTooLarge_431 = 431, + UnavailableForLegalReasons_451 = 451, + + // Server error responses + InternalServerError_500 = 500, + NotImplemented_501 = 501, + BadGateway_502 = 502, + ServiceUnavailable_503 = 503, + GatewayTimeout_504 = 504, + HttpVersionNotSupported_505 = 505, + VariantAlsoNegotiates_506 = 506, + InsufficientStorage_507 = 507, + LoopDetected_508 = 508, + NotExtended_510 = 510, + NetworkAuthenticationRequired_511 = 511, +}; + +using Headers = + std::unordered_multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using Progress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { + public: + DataSink() : os(&sb_), sb_(*this) {} + + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function is_writable; + std::function done; + std::function done_with_trailer; + std::ostream os; + + private: + class data_sink_streambuf final : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + + protected: + std::streamsize xsputn(const char *s, std::streamsize n) override { + sink_.write(s, static_cast(n)); + return n; + } + + private: + DataSink &sink_; + }; + + data_sink_streambuf sb_; +}; + +using ContentProvider = std::function; + +using ContentProviderWithoutLength = std::function; + +using ContentProviderResourceReleaser = std::function; + +struct MultipartFormDataProvider { + std::string name; + ContentProviderWithoutLength provider; + std::string filename; + std::string content_type; +}; +using MultipartFormDataProviderItems = std::vector; + +using ContentReceiverWithProgress = + std::function; + +using ContentReceiver = std::function; + +using MultipartContentHeader = std::function; + +class ContentReader { + public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(std::move(reader)), multipart_reader_(std::move(multipart_reader)) {} + + bool operator()(MultipartContentHeader header, ContentReceiver receiver) const { + return multipart_reader_(std::move(header), std::move(receiver)); + } + + bool operator()(ContentReceiver receiver) const { return reader_(std::move(receiver)); } + + Reader reader_; + MultipartReader multipart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + Params params; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + std::string local_addr; + int local_port = -1; + + // for server + std::string version; + std::string target; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + std::unordered_map path_params; + + // for client + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + Progress progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl = nullptr; +#endif + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_param(const std::string &key) const; + std::string get_param_value(const std::string &key, size_t id = 0) const; + size_t get_param_value_count(const std::string &key) const; + + bool is_multipart_form_data() const; + + bool has_file(const std::string &key) const; + MultipartFormData get_file_value(const std::string &key) const; + std::vector get_file_values(const std::string &key) const; + + // private members... + size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; + size_t content_length_ = 0; + ContentProvider content_provider_; + bool is_chunked_content_provider_ = false; + size_t authorization_count_ = 0; +}; + +struct Response { + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + std::string location; // Redirect location + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, const char *def = "", size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + void set_redirect(const std::string &url, int status = StatusCode::Found_302); + void set_content(const char *s, size_t n, const std::string &content_type); + void set_content(const std::string &s, const std::string &content_type); + void set_content(std::string &&s, const std::string &content_type); + + void set_content_provider(size_t length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_content_provider(const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_chunked_content_provider(const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_file_content(const std::string &path, const std::string &content_type); + void set_file_content(const std::string &path); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(content_provider_success_); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + ContentProviderResourceReleaser content_provider_resource_releaser_; + bool is_chunked_content_provider_ = false; + bool content_provider_success_ = false; + std::string file_content_path_; + std::string file_content_content_type_; +}; + +class Stream { + public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; + virtual socket_t socket() const = 0; + + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { + public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual bool enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle() {} +}; + +class ThreadPool final : public TaskQueue { + public: + explicit ThreadPool(size_t n, size_t mqr = 0) : shutdown_(false), max_queued_requests_(mqr) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + bool enqueue(std::function fn) override { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + + private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait(lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { + break; + } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + size_t max_queued_requests_ = 0; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +using SocketOptions = std::function; + +void default_socket_options(socket_t sock); + +const char *status_message(int status); + +std::string get_bearer_token_auth(const Request &req); + +namespace detail { + +class MatcherBase { + public: + virtual ~MatcherBase() = default; + + // Match request path and populate its matches and + virtual bool match(Request &request) const = 0; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched agains the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. + * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/:1/more/fragments/:2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher final : public MatcherBase { + public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + + private: + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". + */ +class RegexMatcher final : public MatcherBase { + public: + RegexMatcher(const std::string &pattern) : regex_(pattern) {} + + bool match(Request &request) const override; + + private: + std::regex regex_; +}; + +ssize_t write_headers(Stream &strm, const Headers &headers); + +} // namespace detail + +class Server { + public: + using Handler = std::function; + + using ExceptionHandler = std::function; + + enum class HandlerResponse { + Handled, + Unhandled, + }; + using HandlerWithResponse = std::function; + + using HandlerWithContentReader = + std::function; + + using Expect100ContinueHandler = std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, HandlerWithContentReader handler); + Server &Put(const std::string &pattern, Handler handler); + Server &Put(const std::string &pattern, HandlerWithContentReader handler); + Server &Patch(const std::string &pattern, Handler handler); + Server &Patch(const std::string &pattern, HandlerWithContentReader handler); + Server &Delete(const std::string &pattern, Handler handler); + Server &Delete(const std::string &pattern, HandlerWithContentReader handler); + Server &Options(const std::string &pattern, Handler handler); + + bool set_base_dir(const std::string &dir, const std::string &mount_point = std::string()); + bool set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers = Headers()); + bool remove_mount_point(const std::string &mount_point); + Server &set_file_extension_and_mimetype_mapping(const std::string &ext, const std::string &mime); + Server &set_default_file_mimetype(const std::string &mime); + Server &set_file_request_handler(Handler handler); + + template Server &set_error_handler(ErrorHandlerFunc &&handler) { + return set_error_handler_core(std::forward(handler), + std::is_convertible{}); + } + + Server &set_exception_handler(ExceptionHandler handler); + Server &set_pre_routing_handler(HandlerWithResponse handler); + Server &set_post_routing_handler(Handler handler); + + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); + Server &set_logger(Logger logger); + + Server &set_address_family(int family); + Server &set_tcp_nodelay(bool on); + Server &set_ipv6_v6only(bool on); + Server &set_socket_options(SocketOptions socket_options); + + Server &set_default_headers(Headers headers); + Server &set_header_writer(std::function const &writer); + + Server &set_keep_alive_max_count(size_t count); + Server &set_keep_alive_timeout(time_t sec); + + Server &set_read_timeout(time_t sec, time_t usec = 0); + template Server &set_read_timeout(const std::chrono::duration &duration); + + Server &set_write_timeout(time_t sec, time_t usec = 0); + template Server &set_write_timeout(const std::chrono::duration &duration); + + Server &set_idle_interval(time_t sec, time_t usec = 0); + template Server &set_idle_interval(const std::chrono::duration &duration); + + Server &set_payload_max_length(size_t length); + + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); + int bind_to_any_port(const std::string &host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const std::string &host, int port, int socket_flags = 0); + + bool is_running() const; + void wait_until_ready() const; + void stop(); + void decommission(); + + std::function new_task_queue; + + protected: + bool process_request(Stream &strm, const std::string &remote_addr, int remote_port, const std::string &local_addr, + int local_port, bool close_connection, bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_{INVALID_SOCKET}; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + + private: + using Handlers = std::vector, Handler>>; + using HandlersForContentReader = + std::vector, HandlerWithContentReader>>; + + static std::unique_ptr make_matcher(const std::string &pattern); + + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); + Server &set_error_handler_core(Handler handler, std::false_type); + + socket_t create_server_socket(const std::string &host, int port, int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const std::string &host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(const Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, const Handlers &handlers) const; + bool dispatch_request_for_content_reader(Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const; + + bool parse_request_line(const char *s, Request &req) const; + void apply_ranges(const Request &req, Response &res, std::string &content_type, std::string &boundary) const; + bool write_response(Stream &strm, bool close_connection, Request &req, Response &res); + bool write_response_with_content(Stream &strm, bool close_connection, const Request &req, Response &res); + bool write_response_core(Stream &strm, bool close_connection, const Request &req, Response &res, + bool need_apply_ranges); + bool write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool read_content_with_content_receiver(Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, ContentReceiver multipart_receiver) const; + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_{false}; + std::atomic is_decommisioned{false}; + + struct MountPointEntry { + std::string mount_point; + std::string base_dir; + Headers headers; + }; + std::vector base_dirs_; + std::map file_extension_and_mimetype_map_; + std::string default_file_mimetype_ = "application/octet-stream"; + Handler file_request_handler_; + + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + + HandlerWithResponse error_handler_; + ExceptionHandler exception_handler_; + HandlerWithResponse pre_routing_handler_; + Handler post_routing_handler_; + Expect100ContinueHandler expect_100_continue_handler_; + + Logger logger_; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; + std::function header_writer_ = detail::write_headers; +}; + +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + +class Result { + public: + Result() = default; + Result(std::unique_ptr &&res, Error err, Headers &&request_headers = Headers{}) + : res_(std::move(res)), err_(err), request_headers_(std::move(request_headers)) {} + // Response + operator bool() const { return res_ != nullptr; } + bool operator==(std::nullptr_t) const { return res_ == nullptr; } + bool operator!=(std::nullptr_t) const { return res_ != nullptr; } + const Response &value() const { return *res_; } + Response &value() { return *res_; } + const Response &operator*() const { return *res_; } + Response &operator*() { return *res_; } + const Response *operator->() const { return res_.get(); } + Response *operator->() { return res_.get(); } + + // Error + Error error() const { return err_; } + + // Request Headers + bool has_request_header(const std::string &key) const; + std::string get_request_header_value(const std::string &key, const char *def = "", size_t id = 0) const; + uint64_t get_request_header_value_u64(const std::string &key, uint64_t def = 0, size_t id = 0) const; + size_t get_request_header_value_count(const std::string &key) const; + + private: + std::unique_ptr res_; + Error err_ = Error::Unknown; + Headers request_headers_; +}; + +class ClientImpl { + public: + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, + Progress progress); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, + Progress progress); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_ipv6_v6only(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template void set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path = std::string()); + void set_ca_cert_store(X509_STORE *ca_cert_store); + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); +#endif + + void set_logger(Logger logger); + + protected: + struct Socket { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + }; + + virtual bool create_and_connect_socket(Socket &socket, Error &error); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket) const; + void close_socket(Socket &socket); + + bool process_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); + + bool write_content_with_provider(Stream &strm, const Request &req, Error &error) const; + + void copy_settings(const ClientImpl &rhs); + + // Socket endpoint information + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + + // These are all protected under socket_mutex + size_t socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + + // Hostname-IP map + std::map addr_map_; + + // Default headers + Headers default_headers_; + + // Header writer + std::function header_writer_ = detail::write_headers; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool keep_alive_ = false; + bool follow_location_ = false; + + bool url_encode_ = true; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = nullptr; + + bool compress_ = false; + bool decompress_ = true; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_ = -1; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + + X509_STORE *ca_cert_store_ = nullptr; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::function server_certificate_verifier_; +#endif + + Logger logger_; + + private: + bool send_(Request &req, Response &res, Error &error); + Result send_(Request &&req); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_peer_could_be_closed(SSL *ssl) const; +#endif + socket_t create_client_socket(Error &error) const; + bool read_response_line(Stream &strm, const Request &req, Response &res) const; + bool write_request(Stream &strm, Request &req, bool close_connection, Error &error); + bool redirect(Request &req, Response &res, Error &error); + bool handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); + std::unique_ptr send_with_content_provider(Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error); + Result send_with_content_provider(const std::string &method, const std::string &path, const Headers &headers, + const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Progress progress); + ContentProviderWithoutLength get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const; + + std::string adjust_host_string(const std::string &host) const; + + virtual bool process_socket(const Socket &socket, std::function callback); + virtual bool is_ssl() const; +}; + +class Client { + public: + // Universal interface + explicit Client(const std::string &scheme_host_port); + + explicit Client(const std::string &scheme_host_port, const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path); + + Client(Client &&) = default; + Client &operator=(Client &&) = default; + + ~Client(); + + bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, + Progress progress); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, + Progress progress); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, + Progress progress); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type, + Progress progress); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type, Progress progress); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template void set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); +#endif + + void set_logger(Logger logger); + + // SSL +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; +#endif + + private: + std::unique_ptr cli_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_ = false; +#endif +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { + public: + SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr, const char *private_key_password = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store = nullptr); + + SSLServer(const std::function &setup_ssl_ctx_callback); + + ~SSLServer() override; + + bool is_valid() const override; + + SSL_CTX *ssl_context() const; + + void update_certs(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store = nullptr); + + private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +}; + +class SSLClient final : public ClientImpl { + public: + explicit SSLClient(const std::string &host); + + explicit SSLClient(const std::string &host, int port); + + explicit SSLClient(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path, const std::string &private_key_password = std::string()); + + explicit SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + + private: + bool create_and_connect_socket(Socket &socket, Error &error) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; + void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); + + bool process_socket(const Socket &socket, std::function callback) override; + bool is_ssl() const override; + + bool connect_with_proxy(Socket &sock, Response &res, bool &success, Error &error); + bool initialize_ssl(Socket &socket, Error &error); + + bool load_certs(); + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; + + std::vector host_components_; + + long verify_result_ = 0; + + friend class ClientImpl; +}; +#endif + +/* + * Implementation of template methods. + */ + +namespace detail { + +template inline void duration_to_sec_and_usec(const T &duration, U callback) { + auto sec = std::chrono::duration_cast(duration).count(); + auto usec = std::chrono::duration_cast(duration - std::chrono::seconds(sec)).count(); + callback(static_cast(sec), static_cast(usec)); +} + +inline uint64_t get_header_value_u64(const Headers &headers, const std::string &key, uint64_t def, size_t id) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; +} + +} // namespace detail + +inline uint64_t Request::get_header_value_u64(const std::string &key, uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline uint64_t Response::get_header_value_u64(const std::string &key, uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + +inline void default_socket_options(socket_t sock) { + int opt = 1; +#ifdef _WIN32 + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(opt)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, reinterpret_cast(&opt), sizeof(opt)); +#else +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(opt)); +#endif +#endif +} + +inline const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: + return "Continue"; + case StatusCode::SwitchingProtocol_101: + return "Switching Protocol"; + case StatusCode::Processing_102: + return "Processing"; + case StatusCode::EarlyHints_103: + return "Early Hints"; + case StatusCode::OK_200: + return "OK"; + case StatusCode::Created_201: + return "Created"; + case StatusCode::Accepted_202: + return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: + return "No Content"; + case StatusCode::ResetContent_205: + return "Reset Content"; + case StatusCode::PartialContent_206: + return "Partial Content"; + case StatusCode::MultiStatus_207: + return "Multi-Status"; + case StatusCode::AlreadyReported_208: + return "Already Reported"; + case StatusCode::IMUsed_226: + return "IM Used"; + case StatusCode::MultipleChoices_300: + return "Multiple Choices"; + case StatusCode::MovedPermanently_301: + return "Moved Permanently"; + case StatusCode::Found_302: + return "Found"; + case StatusCode::SeeOther_303: + return "See Other"; + case StatusCode::NotModified_304: + return "Not Modified"; + case StatusCode::UseProxy_305: + return "Use Proxy"; + case StatusCode::unused_306: + return "unused"; + case StatusCode::TemporaryRedirect_307: + return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: + return "Permanent Redirect"; + case StatusCode::BadRequest_400: + return "Bad Request"; + case StatusCode::Unauthorized_401: + return "Unauthorized"; + case StatusCode::PaymentRequired_402: + return "Payment Required"; + case StatusCode::Forbidden_403: + return "Forbidden"; + case StatusCode::NotFound_404: + return "Not Found"; + case StatusCode::MethodNotAllowed_405: + return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: + return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: + return "Request Timeout"; + case StatusCode::Conflict_409: + return "Conflict"; + case StatusCode::Gone_410: + return "Gone"; + case StatusCode::LengthRequired_411: + return "Length Required"; + case StatusCode::PreconditionFailed_412: + return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: + return "Payload Too Large"; + case StatusCode::UriTooLong_414: + return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: + return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: + return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: + return "Expectation Failed"; + case StatusCode::ImATeapot_418: + return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: + return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: + return "Unprocessable Content"; + case StatusCode::Locked_423: + return "Locked"; + case StatusCode::FailedDependency_424: + return "Failed Dependency"; + case StatusCode::TooEarly_425: + return "Too Early"; + case StatusCode::UpgradeRequired_426: + return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: + return "Precondition Required"; + case StatusCode::TooManyRequests_429: + return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: + return "Not Implemented"; + case StatusCode::BadGateway_502: + return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: + return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: + return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: + return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: + return "Insufficient Storage"; + case StatusCode::LoopDetected_508: + return "Loop Detected"; + case StatusCode::NotExtended_510: + return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: + return "Internal Server Error"; + } +} + +inline std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + static std::string BearerHeaderPrefix = "Bearer "; + return req.get_header_value("Authorization").substr(BearerHeaderPrefix.length()); + } + return ""; +} + +template +inline Server &Server::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); + return *this; +} + +template +inline Server &Server::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); + return *this; +} + +template +inline Server &Server::set_idle_interval(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); }); + return *this; +} + +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: + return "Success (no error)"; + case Error::Connection: + return "Could not establish connection"; + case Error::BindIPAddress: + return "Failed to bind IP address"; + case Error::Read: + return "Failed to read connection"; + case Error::Write: + return "Failed to write connection"; + case Error::ExceedRedirectCount: + return "Maximum redirect count exceeded"; + case Error::Canceled: + return "Connection handling canceled"; + case Error::SSLConnection: + return "SSL connection failed"; + case Error::SSLLoadingCerts: + return "SSL certificate loading failed"; + case Error::SSLServerVerification: + return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: + return "Compression failed"; + case Error::ConnectionTimeout: + return "Connection timed out"; + case Error::ProxyConnection: + return "Proxy connection failed"; + case Error::Unknown: + return "Unknown"; + default: + break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + +inline uint64_t Result::get_request_header_value_u64(const std::string &key, uint64_t def, size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + +template +inline void ClientImpl::set_connection_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_connection_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); +} + +template +inline void Client::set_connection_timeout(const std::chrono::duration &duration) { + cli_->set_connection_timeout(duration); +} + +template +inline void Client::set_read_timeout(const std::chrono::duration &duration) { + cli_->set_read_timeout(duration); +} + +template +inline void Client::set_write_timeout(const std::chrono::duration &duration) { + cli_->set_write_timeout(duration); +} + +/* + * Forward declarations and types that will be part of the .h file if split into + * .h + .cc. + */ + +std::string hosted_at(const std::string &hostname); + +void hosted_at(const std::string &hostname, std::vector &addrs); + +std::string append_query_params(const std::string &path, const Params ¶ms); + +std::pair make_range_header(const Ranges &ranges); + +std::pair make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false); + +namespace detail { + +#if defined(_WIN32) +inline std::wstring u8string_to_wstring(const char *s) { + std::wstring ws; + auto len = static_cast(strlen(s)); + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); + if (wlen > 0) { + ws.resize(wlen); + wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { + ws.clear(); + } + } + return ws; +} +#endif + +struct FileStat { + FileStat(const std::string &path); + bool is_file() const; + bool is_dir() const; + + private: +#if defined(_WIN32) + struct _stat st_; +#else + struct stat st_; +#endif + int ret_ = -1; +}; + +std::string encode_query_param(const std::string &value); + +std::string decode_url(const std::string &s, bool convert_plus_to_space); + +void read_file(const std::string &path, std::string &out); + +std::string trim_copy(const std::string &s); + +void divide(const char *data, std::size_t size, char d, + std::function fn); + +void divide(const std::string &str, char d, + std::function fn); + +void split(const char *b, const char *e, char d, std::function fn); + +void split(const char *b, const char *e, char d, size_t m, std::function fn); + +bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, std::function callback); + +socket_t create_client_socket(const std::string &host, const std::string &ip, int port, int address_family, + bool tcp_nodelay, bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, time_t connection_timeout_usec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, + const std::string &intf, Error &error); + +const char *get_header_value(const Headers &headers, const std::string &key, const char *def, size_t id); + +std::string params_to_query_str(const Params ¶ms); + +void parse_query_text(const char *data, std::size_t size, Params ¶ms); + +void parse_query_text(const std::string &s, Params ¶ms); + +bool parse_multipart_boundary(const std::string &content_type, std::string &boundary); + +bool parse_range_header(const std::string &s, Ranges &ranges); + +int close_socket(socket_t sock); + +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + +enum class EncodingType { None = 0, Gzip, Brotli }; + +EncodingType encoding_type(const Request &req, const Response &res); + +class BufferStream final : public Stream { + public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + const std::string &get_buffer() const; + + private: + std::string buffer; + size_t position = 0; +}; + +class compressor { + public: + virtual ~compressor() = default; + + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, Callback callback) = 0; +}; + +class decompressor { + public: + virtual ~decompressor() = default; + + virtual bool is_valid() const = 0; + + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, Callback callback) = 0; +}; + +class nocompressor final : public compressor { + public: + ~nocompressor() override = default; + + bool compress(const char *data, size_t data_length, bool /*last*/, Callback callback) override; +}; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +class gzip_compressor final : public compressor { + public: + gzip_compressor(); + ~gzip_compressor() override; + + bool compress(const char *data, size_t data_length, bool last, Callback callback) override; + + private: + bool is_valid_ = false; + z_stream strm_; +}; + +class gzip_decompressor final : public decompressor { + public: + gzip_decompressor(); + ~gzip_decompressor() override; + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, Callback callback) override; + + private: + bool is_valid_ = false; + z_stream strm_; +}; +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +class brotli_compressor final : public compressor { + public: + brotli_compressor(); + ~brotli_compressor(); + + bool compress(const char *data, size_t data_length, bool last, Callback callback) override; + + private: + BrotliEncoderState *state_ = nullptr; +}; + +class brotli_decompressor final : public decompressor { + public: + brotli_decompressor(); + ~brotli_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, Callback callback) override; + + private: + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; +}; +#endif + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { + public: + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size); + const char *ptr() const; + size_t size() const; + bool end_with_crlf() const; + bool getline(); + + private: + void append(char c); + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +class mmap { + public: + mmap(const char *path); + ~mmap(); + + bool open(const char *path); + void close(); + + bool is_open() const; + size_t size() const; + const char *data() const; + + private: +#if defined(_WIN32) + HANDLE hFile_ = NULL; + HANDLE hMapping_ = NULL; +#else + int fd_ = -1; +#endif + size_t size_ = 0; + void *addr_ = nullptr; + bool is_open_empty_file = false; +}; + +} // namespace detail + +// ---------------------------------------------------------------------------- + +/* + * Implementation that will be part of the .cc file if split into .h + .cc. + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, int &val) { + if (i >= s.size()) { + return false; + } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { + return false; + } + auto v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + static const auto charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = static_cast(code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + auto val = 0; + auto valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { + out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); + } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + if (path[i] == '\0') { + return false; + } else if (path[i] == '\\') { + return false; + } + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { + return false; + } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline FileStat::FileStat(const std::string &path) { +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path.c_str()); + ret_ = _wstat(wpath.c_str(), &st_); +#else + ret_ = stat(path.c_str(), &st_); +#endif +} +inline bool FileStat::is_file() const { return ret_ >= 0 && S_ISREG(st_.st_mode); } +inline bool FileStat::is_dir() const { return ret_ >= 0 && S_ISDIR(st_.st_mode); } + +inline std::string encode_query_param(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || c == '.' || c == '!' || c == '~' || c == '*' || + c == '\'' || c == '(' || c == ')') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string encode_url(const std::string &s) { + std::string result; + result.reserve(s.size()); + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': + result += "%20"; + break; + case '+': + result += "%2B"; + break; + case '\r': + result += "%0D"; + break; + case '\n': + result += "%0A"; + break; + case '\'': + result += "%27"; + break; + case ',': + result += "%2C"; + break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': + result += "%3B"; + break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + auto val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { + result.append(buff, len); + } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + auto val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { + return m[1].str(); + } + return std::string(); +} + +inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } + +inline std::pair trim(const char *b, const char *e, size_t left, size_t right) { + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); +} + +inline std::string trim_copy(const std::string &s) { + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); +} + +inline std::string trim_double_quotes_copy(const std::string &s) { + if (s.length() >= 2 && s.front() == '"' && s.back() == '"') { + return s.substr(1, s.size() - 2); + } + return s; +} + +inline void divide(const char *data, std::size_t size, char d, + std::function fn) { + const auto it = std::find(data, data + size, d); + const auto found = static_cast(it != data + size); + const auto lhs_data = data; + const auto lhs_size = static_cast(it - data); + const auto rhs_data = it + found; + const auto rhs_size = size - lhs_size - found; + + fn(lhs_data, lhs_size, rhs_data, rhs_size); +} + +inline void divide(const std::string &str, char d, + std::function fn) { + divide(str.data(), str.size(), d, std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, std::function fn) { + return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, size_t m, std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + } +} + +inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), fixed_buffer_size_(fixed_buffer_size) {} + +inline const char *stream_line_reader::ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } +} + +inline size_t stream_line_reader::size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } +} + +inline bool stream_line_reader::end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; +} + +inline bool stream_line_reader::getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + +#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + char prev_byte = 0; +#endif + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + if (byte == '\n') { + break; + } +#else + if (prev_byte == '\r' && byte == '\n') { + break; + } + prev_byte = byte; +#endif + } + + return true; +} + +inline void stream_line_reader::append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } +} + +inline mmap::mmap(const char *path) { open(path); } + +inline mmap::~mmap() { close(); } + +inline bool mmap::open(const char *path) { + close(); + +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path); + if (wpath.empty()) { + return false; + } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL); +#else + hFile_ = + ::CreateFileW(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); +#endif + + if (hFile_ == INVALID_HANDLE_VALUE) { + return false; + } + + LARGE_INTEGER size{}; + if (!::GetFileSizeEx(hFile_, &size)) { + return false; + } + // If the following line doesn't compile due to QuadPart, update Windows SDK. + // See: + // https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721 + if (static_cast(size.QuadPart) > (std::numeric_limits::max)()) { + // `size_t` might be 32-bits, on 32-bits Windows. + return false; + } + size_ = static_cast(size.QuadPart); + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + hMapping_ = ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); +#else + hMapping_ = ::CreateFileMappingW(hFile_, NULL, PAGE_READONLY, 0, 0, NULL); +#endif + + // Special treatment for an empty file... + if (hMapping_ == NULL && size_ == 0) { + close(); + is_open_empty_file = true; + return true; + } + + if (hMapping_ == NULL) { + close(); + return false; + } + +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); +#else + addr_ = ::MapViewOfFile(hMapping_, FILE_MAP_READ, 0, 0, 0); +#endif + + if (addr_ == nullptr) { + close(); + return false; + } +#else + fd_ = ::open(path, O_RDONLY); + if (fd_ == -1) { + return false; + } + + struct stat sb; + if (fstat(fd_, &sb) == -1) { + close(); + return false; + } + size_ = static_cast(sb.st_size); + + addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); + + // Special treatment for an empty file... + if (addr_ == MAP_FAILED && size_ == 0) { + close(); + is_open_empty_file = true; + return false; + } +#endif + + return true; +} + +inline bool mmap::is_open() const { return is_open_empty_file ? true : addr_ != nullptr; } + +inline size_t mmap::size() const { return size_; } + +inline const char *mmap::data() const { return is_open_empty_file ? "" : static_cast(addr_); } + +inline void mmap::close() { +#if defined(_WIN32) + if (addr_) { + ::UnmapViewOfFile(addr_); + addr_ = nullptr; + } + + if (hMapping_) { + ::CloseHandle(hMapping_); + hMapping_ = NULL; + } + + if (hFile_ != INVALID_HANDLE_VALUE) { + ::CloseHandle(hFile_); + hFile_ = INVALID_HANDLE_VALUE; + } + + is_open_empty_file = false; +#else + if (addr_ != nullptr) { + munmap(addr_, size_); + addr_ = nullptr; + } + + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } +#endif + size_ = 0; +} +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +template inline ssize_t handle_EINTR(T fn) { + ssize_t res = 0; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } + break; + } + return res; +} + +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return -1; + } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); }); +#endif +} + +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return -1; + } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); }); +#endif +} + +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + + if (poll_res == 0) { + return Error::ConnectionTimeout; + } + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return Error::Connection; + } +#endif + + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); }); + + if (ret == 0) { + return Error::ConnectionTimeout; + } + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + return Error::Connection; +#endif +} + +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + +class SocketStream final : public Stream { + public: + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024l * 4; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream final : public Stream { + public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; +}; +#endif + +inline bool keep_alive(const std::atomic &svr_sock, socket_t sock, time_t keep_alive_timeout_sec) { + using namespace std::chrono; + + const auto interval_usec = CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND; + + // Avoid expensive `steady_clock::now()` call for the first time + if (select_read(sock, 0, interval_usec) > 0) { + return true; + } + + const auto start = steady_clock::now() - microseconds{interval_usec}; + const auto timeout = seconds{keep_alive_timeout_sec}; + + while (true) { + if (svr_sock == INVALID_SOCKET) { + break; // Server socket is closed + } + + auto val = select_read(sock, 0, interval_usec); + if (val < 0) { + break; // Ssocket error + } else if (val == 0) { + if (steady_clock::now() - start > timeout) { + break; // Timeout + } + } else { + return true; // Ready for read + } + } + + return false; +} + +template +inline bool process_server_socket_core(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, T callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { + break; + } + count--; + } + return ret; +} + +template +inline bool process_server_socket(const std::atomic &svr_sock, socket_t sock, size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, T callback) { + return process_server_socket_core(svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, write_timeout_sec, + write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + std::function callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +inline std::string escape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '\0') { + auto ret = s; + ret[0] = '@'; + return ret; + } + return s; +} + +inline std::string unescape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '@') { + auto ret = s; + ret[0] = '\0'; + return ret; + } + return s; +} + +template +socket_t create_socket(const std::string &host, const std::string &ip, int port, int address_family, int socket_flags, + bool tcp_nodelay, bool ipv6_v6only, SocketOptions socket_options, + BindOrConnect bind_or_connect) { + // Get address info + const char *node = nullptr; + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_IP; + + if (!ip.empty()) { + node = ip.c_str(); + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + if (!host.empty()) { + node = host.c_str(); + } + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + +#ifndef _WIN32 + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) { + return INVALID_SOCKET; + } + +#ifdef SOCK_CLOEXEC + auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC, hints.ai_protocol); +#else + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); +#endif + + if (sock != INVALID_SOCKET) { + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + + auto unescaped_host = unescape_abstract_namespace_unix_domain(host); + std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast(sizeof(addr) - sizeof(addr.sun_path) + addrlen); + +#ifndef SOCK_CLOEXEC + fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif + + if (socket_options) { + socket_options(sock); + } + + bool dummy; + if (!bind_or_connect(sock, hints, dummy)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } +#endif + + auto service = std::to_string(port); + + if (getaddrinfo(node, service.c_str(), &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return INVALID_SOCKET; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, + WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + +#ifdef SOCK_CLOEXEC + auto sock = socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol); +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + +#endif + if (sock == INVALID_SOCKET) { + continue; + } + +#if !defined _WIN32 && !defined SOCK_CLOEXEC + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + close_socket(sock); + continue; + } +#endif + + if (tcp_nodelay) { + auto opt = 1; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&opt), sizeof(opt)); +#endif + } + + if (rp->ai_family == AF_INET6) { + auto opt = ipv6_v6only ? 1 : 0; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&opt), sizeof(opt)); +#endif + } + + if (socket_options) { + socket_options(sock); + } + + // bind or connect + auto quit = false; + if (bind_or_connect(sock, *rp, quit)) { + return sock; + } + + close_socket(sock); + + if (quit) { + break; + } + } + + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const std::string &host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host.c_str(), "0", &hints, &result)) { + return false; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + return ret; +} + +#if !defined _WIN32 && !defined ANDROID && !defined _AIX && !defined __MVS__ +#define USE_IF2IP +#endif + +#ifdef USE_IF2IP +inline std::string if2ip(int address_family, const std::string &ifn) { + struct ifaddrs *ifap; + getifaddrs(&ifap); + auto se = detail::scope_exit([&] { freeifaddrs(ifap); }); + + std::string addr_candidate; + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || ifa->ifa_addr->sa_family == address_family)) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + return std::string(buf, INET_ADDRSTRLEN); + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } + } + } + } + return addr_candidate; +} +#endif + +inline socket_t create_client_socket(const std::string &host, const std::string &ip, int port, int address_family, + bool tcp_nodelay, bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error) { + auto sock = create_socket( + host, ip, port, address_family, 0, tcp_nodelay, ipv6_v6only, std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai, bool &quit) -> bool { + if (!intf.empty()) { +#ifdef USE_IF2IP + auto ip_from_if = if2ip(address_family, intf); + if (ip_from_if.empty()) { + ip_from_if = intf; + } + if (!bind_ip_address(sock2, ip_from_if)) { + error = Error::BindIPAddress; + return false; + } +#endif + } + + set_nonblocking(sock2, true); + + auto ret = ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); + + if (ret < 0) { + if (is_connection_error()) { + error = Error::Connection; + return false; + } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, connection_timeout_usec); + if (error != Error::Success) { + if (error == Error::ConnectionTimeout) { + quit = true; + } + return false; + } + } + + set_nonblocking(sock2, false); + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec * 1000 + read_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec); + tv.tv_usec = static_cast(read_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec * 1000 + write_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec); + tv.tv_usec = static_cast(write_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + error = Error::Success; + return true; + }); + + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { + error = Error::Connection; + } + } + + return sock; +} + +inline bool get_ip_and_port(const struct sockaddr_storage &addr, socklen_t addr_len, std::string &ip, int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; + } + + std::array ipstr{}; + if (getnameinfo(reinterpret_cast(&addr), addr_len, ipstr.data(), + static_cast(ipstr.size()), nullptr, 0, NI_NUMERICHOST)) { + return false; + } + + ip = ipstr.data(); + return true; +} + +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), &addr_len)) { +#ifndef _WIN32 + if (addr.ss_family == AF_UNIX) { +#if defined(__linux__) + struct ucred ucred; + socklen_t len = sizeof(ucred); + if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == 0) { + port = ucred.pid; + } +#elif defined(SOL_LOCAL) && defined(SO_PEERPID) // __APPLE__ + pid_t pid; + socklen_t len = sizeof(pid); + if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &pid, &len) == 0) { + port = pid; + } +#endif + return; + } +#endif + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline constexpr unsigned int str2tag_core(const char *s, size_t l, unsigned int h) { + return (l == 0) ? h + : str2tag_core( + s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no overflow happens + (((std::numeric_limits::max)() >> 6) & h * 33) ^ static_cast(*s)); +} + +inline unsigned int str2tag(const std::string &s) { return str2tag_core(s.data(), s.size(), 0); } + +namespace udl { + +inline constexpr unsigned int operator""_t(const char *s, size_t l) { return str2tag_core(s, l, 0); } + +} // namespace udl + +inline std::string find_content_type(const std::string &path, const std::map &user_data, + const std::string &default_content_type) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { + return it->second; + } + + using udl::operator""_t; + + switch (str2tag(ext)) { + default: + return default_content_type; + + case "css"_t: + return "text/css"; + case "csv"_t: + return "text/csv"; + case "htm"_t: + case "html"_t: + return "text/html"; + case "js"_t: + case "mjs"_t: + return "text/javascript"; + case "txt"_t: + return "text/plain"; + case "vtt"_t: + return "text/vtt"; + + case "apng"_t: + return "image/apng"; + case "avif"_t: + return "image/avif"; + case "bmp"_t: + return "image/bmp"; + case "gif"_t: + return "image/gif"; + case "png"_t: + return "image/png"; + case "svg"_t: + return "image/svg+xml"; + case "webp"_t: + return "image/webp"; + case "ico"_t: + return "image/x-icon"; + case "tif"_t: + return "image/tiff"; + case "tiff"_t: + return "image/tiff"; + case "jpg"_t: + case "jpeg"_t: + return "image/jpeg"; + + case "mp4"_t: + return "video/mp4"; + case "mpeg"_t: + return "video/mpeg"; + case "webm"_t: + return "video/webm"; + + case "mp3"_t: + return "audio/mp3"; + case "mpga"_t: + return "audio/mpeg"; + case "weba"_t: + return "audio/webm"; + case "wav"_t: + return "audio/wave"; + + case "otf"_t: + return "font/otf"; + case "ttf"_t: + return "font/ttf"; + case "woff"_t: + return "font/woff"; + case "woff2"_t: + return "font/woff2"; + + case "7z"_t: + return "application/x-7z-compressed"; + case "atom"_t: + return "application/atom+xml"; + case "pdf"_t: + return "application/pdf"; + case "json"_t: + return "application/json"; + case "rss"_t: + return "application/rss+xml"; + case "tar"_t: + return "application/x-tar"; + case "xht"_t: + case "xhtml"_t: + return "application/xhtml+xml"; + case "xslt"_t: + return "application/xslt+xml"; + case "xml"_t: + return "application/xml"; + case "gz"_t: + return "application/gzip"; + case "zip"_t: + return "application/zip"; + case "wasm"_t: + return "application/wasm"; + } +} + +inline bool can_compress_content_type(const std::string &content_type) { + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: + return true; + + case "text/event-stream"_t: + return false; + + default: + return !content_type.rfind("text/", 0); + } +} + +inline EncodingType encoding_type(const Request &req, const Response &res) { + auto ret = detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { + return EncodingType::None; + } + + const auto &s = req.get_header_value("Accept-Encoding"); + (void) (s); + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { + return EncodingType::Brotli; + } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { + return EncodingType::Gzip; + } +#endif + + return EncodingType::None; +} + +inline bool nocompressor::compress(const char *data, size_t data_length, bool /*last*/, Callback callback) { + if (!data_length) { + return true; + } + return callback(data, data_length); +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline gzip_compressor::gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, Z_DEFAULT_STRATEGY) == Z_OK; +} + +inline gzip_compressor::~gzip_compressor() { deflateEnd(&strm_); } + +inline bool gzip_compressor::compress(const char *data, size_t data_length, bool last, Callback callback) { + assert(is_valid_); + + do { + constexpr size_t max_avail_in = (std::numeric_limits::max)(); + + strm_.avail_in = static_cast((std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH; + auto ret = Z_OK; + + std::array buff{}; + do { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + if (ret == Z_STREAM_ERROR) { + return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert((flush == Z_FINISH && ret == Z_STREAM_END) || (flush == Z_NO_FLUSH && ret == Z_OK)); + assert(strm_.avail_in == 0); + } while (data_length > 0); + + return true; +} + +inline gzip_decompressor::gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; +} + +inline gzip_decompressor::~gzip_decompressor() { inflateEnd(&strm_); } + +inline bool gzip_decompressor::is_valid() const { return is_valid_; } + +inline bool gzip_decompressor::decompress(const char *data, size_t data_length, Callback callback) { + assert(is_valid_); + + auto ret = Z_OK; + + do { + constexpr size_t max_avail_in = (std::numeric_limits::max)(); + + strm_.avail_in = static_cast((std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + std::array buff{}; + while (strm_.avail_in > 0 && ret == Z_OK) { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm_, Z_NO_FLUSH); + + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm_); + return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } + + if (ret != Z_OK && ret != Z_STREAM_END) { + return false; + } + + } while (data_length > 0); + + return true; +} +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +inline brotli_compressor::brotli_compressor() { state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); } + +inline brotli_compressor::~brotli_compressor() { BrotliEncoderDestroyInstance(state_); } + +inline bool brotli_compressor::compress(const char *data, size_t data_length, bool last, Callback callback) { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { + break; + } + } else { + if (!available_in) { + break; + } + } + + auto available_out = buff.size(); + auto next_out = buff.data(); + + if (!BrotliEncoderCompressStream(state_, operation, &available_in, &next_in, &available_out, &next_out, nullptr)) { + return false; + } + + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } + + return true; +} + +inline brotli_decompressor::brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT : BROTLI_DECODER_RESULT_ERROR; +} + +inline brotli_decompressor::~brotli_decompressor() { + if (decoder_s) { + BrotliDecoderDestroyInstance(decoder_s); + } +} + +inline bool brotli_decompressor::is_valid() const { return decoder_s; } + +inline bool brotli_decompressor::decompress(const char *data, size_t data_length, Callback callback) { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } + + auto next_in = reinterpret_cast(data); + size_t avail_in = data_length; + size_t total_out; + + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); + + decoder_r = BrotliDecoderDecompressStream(decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); + + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return false; + } + + if (!callback(buff.data(), buff.size() - avail_out)) { + return false; + } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; +} +#endif + +inline bool has_header(const Headers &headers, const std::string &key) { return headers.find(key) != headers.end(); } + +inline const char *get_header_value(const Headers &headers, const std::string &key, const char *def, size_t id) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second.c_str(); + } + return def; +} + +template inline bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + if (p == end) { + return false; + } + + auto key_end = p; + + if (*p++ != ':') { + return false; + } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { + return false; + } + + auto key = std::string(beg, key_end); + auto val = case_ignore::equal(key, "Location") ? std::string(p, end) : decode_url(std::string(p, end), false); + + // NOTE: From RFC 9110: + // Field values containing CR, LF, or NUL characters are + // invalid and dangerous, due to the varying ways that + // implementations might parse and interpret those + // characters; a recipient of CR, LF, or NUL within a field + // value MUST either reject the message or replace each of + // those characters with SP before further processing or + // forwarding of that message. + static const std::string CR_LF_NUL("\r\n\0", 3); + if (val.find_first_of(CR_LF_NUL) != std::string::npos) { + return false; + } + + fn(key, val); + return true; + } + + return false; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { + return false; + } + + // Check if the line ends with CRLF. + auto line_terminator_len = 2; + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { + break; + } + } else { +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + // Blank line indicates end of headers. + if (line_reader.size() == 1) { + break; + } + line_terminator_len = 1; +#else + continue; // Skip invalid line. +#endif + } + + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_reader.ptr(), end, + [&](const std::string &key, std::string &val) { headers.emplace(key, val); })) { + return false; + } + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, uint64_t len, Progress progress, ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return false; + } + + if (!out(buf, static_cast(n), r, len)) { + return false; + } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { + return false; + } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return; + } + r += static_cast(n); + } +} + +inline bool read_content_without_length(Stream &strm, ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n <= 0) { + return true; + } + + if (!out(buf, static_cast(n), r, 0)) { + return false; + } + r += static_cast(n); + } + + return true; +} + +template inline bool read_content_chunked(Stream &strm, T &x, ContentReceiverWithProgress out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { + return false; + } + + unsigned long chunk_len; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { + return false; + } + if (chunk_len == ULONG_MAX) { + return false; + } + + if (chunk_len == 0) { + break; + } + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { + return false; + } + + if (strcmp(line_reader.ptr(), "\r\n") != 0) { + return false; + } + + if (!line_reader.getline()) { + return false; + } + } + + assert(chunk_len == 0); + + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentiones "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // does't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-htpplib now allows + // chuncked transfer coding data without the final CRLF. + if (!line_reader.getline()) { + return true; + } + + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + + // Exclude line terminator + constexpr auto line_terminator_len = 2; + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, + [&](const std::string &key, const std::string &val) { x.headers.emplace(key, val); }); + + if (!line_reader.getline()) { + return false; + } + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return case_ignore::equal(get_header_value(headers, "Transfer-Encoding", "", 0), "chunked"); +} + +template +bool prepare_content_receiver(T &x, int &status, ContentReceiverWithProgress receiver, bool decompress, U callback) { + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } + + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + return decompressor->decompress(buf, n, + [&](const char *buf2, size_t n2) { return receiver(buf2, n2, off, len); }); + }; + return callback(std::move(out)); + } else { + status = StatusCode::InternalServerError_500; + return false; + } + } + } + + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + return receiver(buf, n, off, len); + }; + return callback(std::move(out)); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, Progress progress, + ContentReceiverWithProgress receiver, bool decompress) { + return prepare_content_receiver( + x, status, std::move(receiver), decompress, [&](const ContentReceiverWithProgress &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, x, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_u64(x.headers, "Content-Length", 0, 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, std::move(progress), out); + } + } + + if (!ret) { + status = exceed_payload_max_length ? StatusCode::PayloadTooLarge_413 : StatusCode::BadRequest_400; + } + return ret; + }); +} + +inline ssize_t write_request_line(Stream &strm, const std::string &method, const std::string &path) { + std::string s = method; + s += " "; + s += path; + s += " HTTP/1.1\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_response_line(Stream &strm, int status) { + std::string s = "HTTP/1.1 "; + s += std::to_string(status); + s += " "; + s += httplib::status_message(status); + s += "\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_headers(Stream &strm, const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : headers) { + std::string s; + s = x.first; + s += ": "; + s += x.second; + s += "\r\n"; + + auto len = strm.write(s.data(), s.size()); + if (len < 0) { + return len; + } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { + return len; + } + write_len += len; + return write_len; +} + +inline bool write_data(Stream &strm, const char *d, size_t l) { + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { + return false; + } + offset += static_cast(length); + } + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, size_t offset, size_t length, + T is_shutting_down, Error &error) { + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + if (strm.is_writable() && write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + while (offset < end_offset && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, end_offset - offset, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, size_t offset, size_t length, + const T &is_shutting_down) { + auto error = Error::Success; + return write_content(strm, content_provider, offset, length, is_shutting_down, error); +} + +template +inline bool write_content_without_length(Stream &strm, const ContentProvider &content_provider, + const T &is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + offset += l; + if (!strm.is_writable() || !write_data(strm, d, l)) { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + data_sink.done = [&](void) { data_available = false; }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + return false; + } else if (!content_provider(offset, 0, data_sink)) { + return false; + } else if (!ok) { + return false; + } + } + return true; +} + +template +inline bool write_content_chunked(Stream &strm, const ContentProvider &content_provider, const T &is_shutting_down, + U &compressor, Error &error) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + data_available = l > 0; + offset += l; + + std::string payload; + if (compressor.compress(d, l, false, [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + auto done_with_trailer = [&](const Headers *trailer) { + if (!ok) { + return; + } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + return; + } + } + + static const std::string done_marker("0\r\n"); + if (!write_data(strm, done_marker.data(), done_marker.size())) { + ok = false; + } + + // Trailer + if (trailer) { + for (const auto &kv : *trailer) { + std::string field_line = kv.first + ": " + kv.second + "\r\n"; + if (!write_data(strm, field_line.data(), field_line.size())) { + ok = false; + } + } + } + + static const std::string crlf("\r\n"); + if (!write_data(strm, crlf.data(), crlf.size())) { + ok = false; + } + }; + + data_sink.done = [&](void) { done_with_trailer(nullptr); }; + + data_sink.done_with_trailer = [&](const Headers &trailer) { done_with_trailer(&trailer); }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, 0, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content_chunked(Stream &strm, const ContentProvider &content_provider, const T &is_shutting_down, + U &compressor) { + auto error = Error::Success; + return write_content_chunked(strm, content_provider, is_shutting_down, compressor, error); +} + +template +inline bool redirect(T &cli, Request &req, Response &res, const std::string &path, const std::string &location, + Error &error) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count_ -= 1; + + if (res.status == StatusCode::SeeOther_303 && (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res, error); + if (ret) { + req = new_req; + res = new_res; + + if (res.location.empty()) { + res.location = location; + } + } + return ret; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { + query += "&"; + } + query += it->first; + query += "="; + query += encode_query_param(it->second); + } + return query; +} + +inline void parse_query_text(const char *data, std::size_t size, Params ¶ms) { + std::set cache; + split(data, data + size, '&', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { + return; + } + cache.insert(std::move(kv)); + + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { parse_query_text(s.data(), s.size(), params); } + +inline bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { + auto boundary_keyword = "boundary="; + auto pos = content_type.find(boundary_keyword); + if (pos == std::string::npos) { + return false; + } + auto end = content_type.find(';', pos); + auto beg = pos + strlen(boundary_keyword); + boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + return !boundary.empty(); +} + +inline void parse_disposition_params(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), ';', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { + return; + } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(trim_double_quotes_copy((key)), trim_double_quotes_copy((val))); + } + }); +} + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif + auto is_valid = [](const std::string &str) { + return std::all_of(str.cbegin(), str.cend(), [](unsigned char c) { return std::isdigit(c); }); + }; + + if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) { + const auto pos = static_cast(6); + const auto len = static_cast(s.size() - 6); + auto all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) { + return; + } + + const auto it = std::find(b, e, '-'); + if (it == e) { + all_valid_ranges = false; + return; + } + + const auto lhs = std::string(b, it); + const auto rhs = std::string(it + 1, e); + if (!is_valid(lhs) || !is_valid(rhs)) { + all_valid_ranges = false; + return; + } + + const auto first = static_cast(lhs.empty() ? -1 : std::stoll(lhs)); + const auto last = static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + if ((first == -1 && last == -1) || (first != -1 && last != -1 && first > last)) { + all_valid_ranges = false; + return; + } + + ranges.emplace_back(first, last); + }); + return all_valid_ranges && !ranges.empty(); + } + return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else +} catch (...) { + return false; +} +#endif + +class MultipartFormDataParser { + public: + MultipartFormDataParser() = default; + + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + dash_boundary_crlf_ = dash_ + boundary_ + crlf_; + crlf_dash_boundary_ = crlf_ + dash_ + boundary_; + } + + bool is_valid() const { return is_valid_; } + + bool parse(const char *buf, size_t n, const ContentReceiver &content_callback, + const MultipartContentHeader &header_callback) { + buf_append(buf, n); + + while (buf_size() > 0) { + switch (state_) { + case 0: { // Initial boundary + buf_erase(buf_find(dash_boundary_crlf_)); + if (dash_boundary_crlf_.size() > buf_size()) { + return true; + } + if (!buf_start_with(dash_boundary_crlf_)) { + return false; + } + buf_erase(dash_boundary_crlf_.size()); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + while (pos < buf_size()) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_erase(crlf_.size()); + state_ = 3; + break; + } + + const auto header = buf_head(pos); + + if (!parse_header(header.data(), header.data() + header.size(), + [&](const std::string &, const std::string &) {})) { + is_valid_ = false; + return false; + } + + static const std::string header_content_type = "Content-Type:"; + + if (start_with_case_ignore(header, header_content_type)) { + file_.content_type = trim_copy(header.substr(header_content_type.size())); + } else { + static const std::regex re_content_disposition(R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", + std::regex_constants::icase); + + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + Params params; + parse_disposition_params(m[1], params); + + auto it = params.find("name"); + if (it != params.end()) { + file_.name = it->second; + } else { + is_valid_ = false; + return false; + } + + it = params.find("filename"); + if (it != params.end()) { + file_.filename = it->second; + } + + it = params.find("filename*"); + if (it != params.end()) { + // Only allow UTF-8 enconnding... + static const std::regex re_rfc5987_encoding(R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); + + std::smatch m2; + if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { + file_.filename = decode_url(m2[1], false); // override... + } else { + is_valid_ = false; + return false; + } + } + } + } + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); + } + if (state_ != 3) { + return true; + } + break; + } + case 3: { // Body + if (crlf_dash_boundary_.size() > buf_size()) { + return true; + } + auto pos = buf_find(crlf_dash_boundary_); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { + is_valid_ = false; + return false; + } + buf_erase(pos + crlf_dash_boundary_.size()); + state_ = 4; + } else { + auto len = buf_size() - crlf_dash_boundary_.size(); + if (len > 0) { + if (!content_callback(buf_data(), len)) { + is_valid_ = false; + return false; + } + buf_erase(len); + } + return true; + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_size()) { + return true; + } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); + state_ = 1; + } else { + if (dash_.size() > buf_size()) { + return true; + } + if (buf_start_with(dash_)) { + buf_erase(dash_.size()); + is_valid_ = true; + buf_erase(buf_size()); // Remove epilogue + } else { + return true; + } + } + break; + } + } + } + + return true; + } + + private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + bool start_with_case_ignore(const std::string &a, const std::string &b) const { + if (a.size() < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + return false; + } + } + return true; + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + std::string dash_boundary_crlf_; + std::string crlf_dash_boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + MultipartFormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, const std::string &b) const { + if (epos - spos < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (a[i + spos] != b[i]) { + return false; + } + } + return true; + } + + size_t buf_size() const { return buf_epos_ - buf_spos_; } + + const char *buf_data() const { return &buf_[buf_spos_]; } + + std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); } + + bool buf_start_with(const std::string &s) const { return start_with(buf_, buf_spos_, buf_epos_, s); } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { + return buf_size(); + } + if (buf_[pos] == c) { + break; + } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { + return buf_size(); + } + + if (start_with(buf_, pos, buf_epos_, s)) { + return pos - buf_spos_; + } + + off = pos + 1; + } + + return buf_size(); + } + + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { + buf_.resize(remaining_size + n); + } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { buf_spos_ += size; } + + std::string buf_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; +}; + +inline std::string random_string(size_t length) { + static const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + static std::random_device seed_gen; + + // Request 128 bits of entropy for initialization + static std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + + static std::mt19937 engine(seed_sequence); + + std::string result; + for (size_t i = 0; i < length; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + return result; +} + +inline std::string make_multipart_data_boundary() { + return "--cpp-httplib-multipart-data-" + detail::random_string(16); +} + +inline bool is_multipart_boundary_chars_valid(const std::string &boundary) { + auto valid = true; + for (size_t i = 0; i < boundary.size(); i++) { + auto c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + valid = false; + break; + } + } + return valid; +} + +template +inline std::string serialize_multipart_formdata_item_begin(const T &item, const std::string &boundary) { + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + + return body; +} + +inline std::string serialize_multipart_formdata_item_end() { return "\r\n"; } + +inline std::string serialize_multipart_formdata_finish(const std::string &boundary) { + return "--" + boundary + "--\r\n"; +} + +inline std::string serialize_multipart_formdata_get_content_type(const std::string &boundary) { + return "multipart/form-data; boundary=" + boundary; +} + +inline std::string serialize_multipart_formdata(const MultipartFormDataItems &items, const std::string &boundary, + bool finish = true) { + std::string body; + + for (const auto &item : items) { + body += serialize_multipart_formdata_item_begin(item, boundary); + body += item.content + serialize_multipart_formdata_item_end(); + } + + if (finish) { + body += serialize_multipart_formdata_finish(boundary); + } + + return body; +} + +inline bool range_error(Request &req, Response &res) { + if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { + ssize_t contant_len = static_cast(res.content_length_ ? res.content_length_ : res.body.size()); + + ssize_t prev_first_pos = -1; + ssize_t prev_last_pos = -1; + size_t overwrapping_count = 0; + + // NOTE: The following Range check is based on '14.2. Range' in RFC 9110 + // 'HTTP Semantics' to avoid potential denial-of-service attacks. + // https://www.rfc-editor.org/rfc/rfc9110#section-14.2 + + // Too many ranges + if (req.ranges.size() > CPPHTTPLIB_RANGE_MAX_COUNT) { + return true; + } + + for (auto &r : req.ranges) { + auto &first_pos = r.first; + auto &last_pos = r.second; + + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = contant_len; + } + + if (first_pos == -1) { + first_pos = contant_len - last_pos; + last_pos = contant_len - 1; + } + + if (last_pos == -1) { + last_pos = contant_len - 1; + } + + // Range must be within content length + if (!(0 <= first_pos && first_pos <= last_pos && last_pos <= contant_len - 1)) { + return true; + } + + // Ranges must be in ascending order + if (first_pos <= prev_first_pos) { + return true; + } + + // Request must not have more than two overlapping ranges + if (first_pos <= prev_last_pos) { + overwrapping_count++; + if (overwrapping_count > 2) { + return true; + } + } + + prev_first_pos = (std::max)(prev_first_pos, first_pos); + prev_last_pos = (std::max)(prev_last_pos, last_pos); + } + } + + return false; +} + +inline std::pair get_range_offset_and_length(Range r, size_t content_length) { + assert(r.first != -1 && r.second != -1); + assert(0 <= r.first && r.first < static_cast(content_length)); + assert(r.first <= r.second && r.second < static_cast(content_length)); + (void) (content_length); + return std::make_pair(r.first, static_cast(r.second - r.first) + 1); +} + +inline std::string make_content_range_header_field(const std::pair &offset_and_length, + size_t content_length) { + auto st = offset_and_length.first; + auto ed = st + offset_and_length.second - 1; + + std::string field = "bytes "; + field += std::to_string(st); + field += "-"; + field += std::to_string(ed); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, const std::string &boundary, const std::string &content_type, + size_t content_length, SToken stoken, CToken ctoken, Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offset_and_length = get_range_offset_and_length(req.ranges[i], content_length); + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset_and_length, content_length)); + ctoken("\r\n"); + ctoken("\r\n"); + + if (!content(offset_and_length.first, offset_and_length.second)) { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--"); + + return true; +} + +inline void make_multipart_ranges_data(const Request &req, Response &res, const std::string &boundary, + const std::string &content_type, size_t content_length, std::string &data) { + process_multipart_ranges_data( + req, boundary, content_type, content_length, [&](const std::string &token) { data += token; }, + [&](const std::string &token) { data += token; }, + [&](size_t offset, size_t length) { + assert(offset + length <= content_length); + data += res.body.substr(offset, length); + return true; + }); +} + +inline size_t get_multipart_ranges_data_length(const Request &req, const std::string &boundary, + const std::string &content_type, size_t content_length) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, boundary, content_type, content_length, [&](const std::string &token) { data_length += token.size(); }, + [&](const std::string &token) { data_length += token.size(); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +template +inline bool write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, const std::string &boundary, + const std::string &content_type, size_t content_length, + const T &is_shutting_down) { + return process_multipart_ranges_data( + req, boundary, content_type, content_length, [&](const std::string &token) { strm.write(token); }, + [&](const std::string &token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, is_shutting_down); + }); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || req.method == "PRI" || + req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +inline bool has_crlf(const std::string &s) { + auto p = s.c_str(); + while (*p) { + if (*p == '\r' || *p == '\n') { + return true; + } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr(EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') << static_cast(hash[i]); + } + + return ss.str(); +} + +inline std::string MD5(const std::string &s) { return message_digest(s, EVP_md5()); } + +inline std::string SHA_256(const std::string &s) { return message_digest(s, EVP_sha256()); } + +inline std::string SHA_512(const std::string &s) { return message_digest(s, EVP_sha512()); } +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool load_system_certs_on_windows(X509_STORE *store) { + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY) NULL, L"ROOT"); + if (!hStore) { + return false; + } + + auto result = false; + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != nullptr) { + auto encoded_cert = static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return result; +} +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX +template using CFObjectPtr = std::unique_ptr::type, void (*)(CFTypeRef)>; + +inline void cf_object_ptr_deleter(CFTypeRef obj) { + if (obj) { + CFRelease(obj); + } +} + +inline bool retrieve_certs_from_keychain(CFObjectPtr &certs) { + CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; + CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, kCFBooleanTrue}; + + CFObjectPtr query( + CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, sizeof(keys) / sizeof(keys[0]), + &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks), + cf_object_ptr_deleter); + + if (!query) { + return false; + } + + CFTypeRef security_items = nullptr; + if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || + CFArrayGetTypeID() != CFGetTypeID(security_items)) { + return false; + } + + certs.reset(reinterpret_cast(security_items)); + return true; +} + +inline bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { + CFArrayRef root_security_items = nullptr; + if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { + return false; + } + + certs.reset(root_security_items); + return true; +} + +inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { + auto result = false; + for (auto i = 0; i < CFArrayGetCount(certs); ++i) { + const auto cert = reinterpret_cast(CFArrayGetValueAtIndex(certs, i)); + + if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { + continue; + } + + CFDataRef cert_data = nullptr; + if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != errSecSuccess) { + continue; + } + + CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); + + auto encoded_cert = static_cast(CFDataGetBytePtr(cert_data_ptr.get())); + + auto x509 = d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); + + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + return result; +} + +inline bool load_system_certs_on_macos(X509_STORE *store) { + auto result = false; + CFObjectPtr certs(nullptr, cf_object_ptr_deleter); + if (retrieve_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store); + } + + if (retrieve_root_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store) || result; + } + + return result; +} +#endif // TARGET_OS_OSX +#endif // _WIN32 +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef _WIN32 +class WSInit { + public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) + is_valid_ = true; + } + + ~WSInit() { + if (is_valid_) + WSACleanup(); + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, size_t cnonce_count, const std::string &cnonce, + const std::string &username, const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { + algo = auth.at("algorithm"); + } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { + A2 += ":" + H(req.body); + } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = + "Digest username=\"" + username + "\", realm=\"" + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + cnonce + "\", response=\"") + + response + "\"" + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const Response &res, std::map &auth, bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + const auto &m = *i; + auto key = s.substr(static_cast(m.position(1)), static_cast(m.length(1))); + auto val = m.length(2) > 0 ? s.substr(static_cast(m.position(2)), static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +class ContentProviderAdapter { + public: + explicit ContentProviderAdapter(ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) {} + + bool operator()(size_t offset, size_t, DataSink &sink) { return content_provider_(offset, sink); } + + private: + ContentProviderWithoutLength content_provider_; +}; + +} // namespace detail + +inline std::string hosted_at(const std::string &hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { + return std::string(); + } + return addrs[0]; +} + +inline void hosted_at(const std::string &hostname, std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(hostname.c_str(), nullptr, &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = *reinterpret_cast(rp->ai_addr); + std::string ip; + auto dummy = -1; + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, dummy)) { + addrs.push_back(ip); + } + } +} + +inline std::string append_query_params(const std::string &path, const Params ¶ms) { + std::string path_with_query = path; + const static std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + +// Header utilities +inline std::pair make_range_header(const Ranges &ranges) { + std::string field = "bytes="; + auto i = 0; + for (const auto &r : ranges) { + if (i != 0) { + field += ", "; + } + if (r.first != -1) { + field += std::to_string(r.first); + } + field += '-'; + if (r.second != -1) { + field += std::to_string(r.second); + } + i++; + } + return std::make_pair("Range", std::move(field)); +} + +inline std::pair make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +inline std::pair make_bearer_token_authentication_header(const std::string &token, + bool is_proxy = false) { + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +// Request implementation +inline bool Request::has_header(const std::string &key) const { return detail::has_header(headers, key); } + +inline std::string Request::get_header_value(const std::string &key, const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Request::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const std::string &key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline bool Request::has_param(const std::string &key) const { return params.find(key) != params.end(); } + +inline std::string Request::get_param_value(const std::string &key, size_t id) const { + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second; + } + return std::string(); +} + +inline size_t Request::get_param_value_count(const std::string &key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.rfind("multipart/form-data", 0); +} + +inline bool Request::has_file(const std::string &key) const { return files.find(key) != files.end(); } + +inline MultipartFormData Request::get_file_value(const std::string &key) const { + auto it = files.find(key); + if (it != files.end()) { + return it->second; + } + return MultipartFormData(); +} + +inline std::vector Request::get_file_values(const std::string &key) const { + std::vector values; + auto rng = files.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second); + } + return values; +} + +// Response implementation +inline bool Response::has_header(const std::string &key) const { return headers.find(key) != headers.end(); } + +inline std::string Response::get_header_value(const std::string &key, const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); +} + +inline size_t Response::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const std::string &key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline void Response::set_redirect(const std::string &url, int stat) { + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = StatusCode::Found_302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, const std::string &content_type) { + body.assign(s, n); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, const std::string &content_type) { + set_content(s.data(), s.size(), content_type); +} + +inline void Response::set_content(std::string &&s, const std::string &content_type) { + body = std::move(s); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider(size_t in_length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = in_length; + if (in_length > 0) { + content_provider_ = std::move(provider); + } + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_content_provider(const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_chunked_content_provider(const std::string &content_type, + ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = true; +} + +inline void Response::set_file_content(const std::string &path, const std::string &content_type) { + file_content_path_ = path; + file_content_content_type_ = content_type; +} + +inline void Response::set_file_content(const std::string &path) { file_content_path_ = path; } + +// Result implementation +inline bool Result::has_request_header(const std::string &key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const std::string &key, const char *def, size_t id) const { + return detail::get_header_value(request_headers_, key, def, id); +} + +inline size_t Result::get_request_header_value_count(const std::string &key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation +inline ssize_t Stream::write(const char *ptr) { return write(ptr, strlen(ptr)); } + +inline ssize_t Stream::write(const std::string &s) { return write(s.data(), s.size()); } + +namespace detail { + +// Socket stream implementation +inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec) + : sock_(sock), + read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + read_buff_(read_buff_size_, 0) {} + +inline SocketStream::~SocketStream() = default; + +inline bool SocketStream::is_readable() const { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } + +inline bool SocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN32 + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + + if (!is_readable()) { + return -1; + } + + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + } +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!is_writable()) { + return -1; + } + +#if defined(_WIN32) && !defined(_WIN64) + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SocketStream::get_local_ip_and_port(std::string &ip, int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SocketStream::socket() const { return sock_; } + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::is_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1910 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, int & /*port*/) const {} + +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, int & /*port*/) const {} + +inline socket_t BufferStream::socket() const { return 0; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + static constexpr char marker[] = "/:"; + + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only last duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find(marker, last_param_end == 0 ? last_param_end : last_param_end - 1); + if (marker_pos == std::string::npos) { + break; + } + + static_fragments_.push_back(pattern.substr(last_param_end, marker_pos - last_param_end + 1)); + + const auto param_name_start = marker_pos + 2; + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { + sep_pos = pattern.length(); + } + + auto param_name = pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = + "Encountered path parameter '" + param_name + "' multiple times in route pattern '" + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = std::smatch(); + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { + continue; + } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { + sep_pos = request.path.length(); + } + + const auto ¶m_name = param_names_[i]; + + request.path_params.emplace(param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everything up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + +} // namespace detail + +// HTTP server implementation +inline Server::Server() : new_task_queue([] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif +} + +inline Server::~Server() = default; + +inline std::unique_ptr Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + +inline Server &Server::Get(const std::string &pattern, Handler handler) { + get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, Handler handler) { + post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, HandlerWithContentReader handler) { + post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, Handler handler) { + put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, HandlerWithContentReader handler) { + put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, Handler handler) { + patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, Handler handler) { + delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Options(const std::string &pattern, Handler handler) { + options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline bool Server::set_base_dir(const std::string &dir, const std::string &mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers) { + detail::FileStat stat(dir); + if (stat.is_dir()) { + std::string mnt = !mount_point.empty() ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.push_back({mnt, dir, std::move(headers)}); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const std::string &mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->mount_point == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline Server &Server::set_file_extension_and_mimetype_mapping(const std::string &ext, const std::string &mime) { + file_extension_and_mimetype_map_[ext] = mime; + return *this; +} + +inline Server &Server::set_default_file_mimetype(const std::string &mime) { + default_file_mimetype_ = mime; + return *this; +} + +inline Server &Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(HandlerWithResponse handler, std::true_type) { + error_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler_core(Handler handler, std::false_type) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return HandlerResponse::Handled; + }; + return *this; +} + +inline Server &Server::set_exception_handler(ExceptionHandler handler) { + exception_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) { + pre_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_logger(Logger logger) { + logger_ = std::move(logger); + return *this; +} + +inline Server &Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_address_family(int family) { + address_family_ = family; + return *this; +} + +inline Server &Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; + return *this; +} + +inline Server &Server::set_ipv6_v6only(bool on) { + ipv6_v6only_ = on; + return *this; +} + +inline Server &Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); + return *this; +} + +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + +inline Server &Server::set_header_writer(std::function const &writer) { + header_writer_ = writer; + return *this; +} + +inline Server &Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + return *this; +} + +inline Server &Server::set_keep_alive_timeout(time_t sec) { + keep_alive_timeout_sec_ = sec; + return *this; +} + +inline Server &Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_idle_interval(time_t sec, time_t usec) { + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; + return *this; +} + +inline Server &Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + return *this; +} + +inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { + is_decommisioned = true; + } + return ret >= 0; +} +inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { + is_decommisioned = true; + } + return ret; +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const std::string &host, int port, int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::wait_until_ready() const { + while (!is_running_ && !is_decommisioned) { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } +} + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + is_decommisioned = false; +} + +inline void Server::decommission() { is_decommisioned = true; } + +inline bool Server::parse_request_line(const char *s, Request &req) const { + auto len = strlen(s); + if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { + return false; + } + len -= 2; + + { + size_t count = 0; + + detail::split(s, s + len, ' ', [&](const char *b, const char *e) { + switch (count) { + case 0: + req.method = std::string(b, e); + break; + case 1: + req.target = std::string(b, e); + break; + case 2: + req.version = std::string(b, e); + break; + default: + break; + } + count++; + }); + + if (count != 3) { + return false; + } + } + + static const std::set methods{"GET", "HEAD", "POST", "PUT", "DELETE", + "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; + + if (methods.find(req.method) == methods.end()) { + return false; + } + + if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { + return false; + } + + { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + + detail::divide(req.target, '?', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, std::size_t rhs_size) { + req.path = detail::decode_url(std::string(lhs_data, lhs_size), false); + detail::parse_query_text(rhs_data, rhs_size, req.params); + }); + } + + return true; +} + +inline bool Server::write_response(Stream &strm, bool close_connection, Request &req, Response &res) { + // NOTE: `req.ranges` should be empty, otherwise it will be applied + // incorrectly to the error content. + req.ranges.clear(); + return write_response_core(strm, close_connection, req, res, false); +} + +inline bool Server::write_response_with_content(Stream &strm, bool close_connection, const Request &req, + Response &res) { + return write_response_core(strm, close_connection, req, res, true); +} + +inline bool Server::write_response_core(Stream &strm, bool close_connection, const Request &req, Response &res, + bool need_apply_ranges) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_ && error_handler_(req, res) == HandlerResponse::Handled) { + need_apply_ranges = true; + } + + std::string content_type; + std::string boundary; + if (need_apply_ranges) { + apply_ranges(req, res, content_type, boundary); + } + + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); + } + + if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && !res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } + + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && !res.has_header("Content-Length")) { + res.set_header("Content-Length", "0"); + } + + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } + + if (post_routing_handler_) { + post_routing_handler_(req, res); + } + + // Response line and headers + { + detail::BufferStream bstrm; + if (!detail::write_response_line(bstrm, res.status)) { + return false; + } + if (!header_writer_(bstrm, res.headers)) { + return false; + } + + // Flush buffer + auto &data = bstrm.get_buffer(); + detail::write_data(strm, data.data(), data.size()); + } + + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; + } + } + } + + // Log + if (logger_) { + logger_(req, res); + } + + return ret; +} + +inline bool Server::write_content_with_provider(Stream &strm, const Request &req, Response &res, + const std::string &boundary, const std::string &content_type) { + auto is_shutting_down = [this]() { return this->svr_sock_ == INVALID_SOCKET; }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + return detail::write_content(strm, res.content_provider_, 0, res.content_length_, is_shutting_down); + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length(req.ranges[0], res.content_length_); + + return detail::write_content(strm, res.content_provider_, offset_and_length.first, offset_and_length.second, + is_shutting_down); + } else { + return detail::write_multipart_ranges_data(strm, req, res, boundary, content_type, res.content_length_, + is_shutting_down); + } + } else { + if (res.is_chunked_content_provider_) { + auto type = detail::encoding_type(req, res); + + std::unique_ptr compressor; + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); +#endif + } else { + compressor = detail::make_unique(); + } + assert(compressor != nullptr); + + return detail::write_content_chunked(strm, res.content_provider_, is_shutting_down, *compressor); + } else { + return detail::write_content_without_length(strm, res.content_provider_, is_shutting_down); + } + } +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + auto file_count = 0; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { + return false; + } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + if (file_count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) { + return false; + } + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { + return false; + } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { + res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? + return false; + } + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, std::move(receiver), std::move(multipart_header), + std::move(multipart_receiver)); +} + +inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) const { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiverWithProgress out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = StatusCode::BadRequest_400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = (std::min)(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, multipart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, multipart_header); + }; + } else { + out = [receiver](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { return receiver(buf, n); }; + } + + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, out, true)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = StatusCode::BadRequest_400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(const Request &req, Response &res, bool head) { + for (const auto &entry : base_dirs_) { + // Prefix match + if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { + std::string sub_path = "/" + req.path.substr(entry.mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = entry.base_dir + sub_path; + if (path.back() == '/') { + path += "index.html"; + } + + detail::FileStat stat(path); + + if (stat.is_dir()) { + res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301); + return true; + } + + if (stat.is_file()) { + for (const auto &kv : entry.headers) { + res.set_header(kv.first, kv.second); + } + + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + return false; + } + + res.set_content_provider( + mm->size(), detail::find_content_type(path, file_extension_and_mimetype_map_, default_file_mimetype_), + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + + return true; + } + } + } + } + return false; +} + +inline socket_t Server::create_server_socket(const std::string &host, int port, int socket_flags, + SocketOptions socket_options) const { + return detail::create_socket(host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, ipv6_v6only_, + std::move(socket_options), + [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { + return false; + } + return true; + }); +} + +inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { + if (is_decommisioned) { + return -1; + } + + if (!is_valid()) { + return -1; + } + + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { + return -1; + } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + if (is_decommisioned) { + return false; + } + + auto ret = true; + is_running_ = true; + auto se = detail::scope_exit([&]() { is_running_ = false; }); + + { + std::unique_ptr task_queue(new_task_queue()); + + while (svr_sock_ != INVALID_SOCKET) { +#ifndef _WIN32 + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { +#endif + auto val = detail::select_read(svr_sock_, idle_interval_sec_, idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } +#ifndef _WIN32 + } +#endif + +#if defined _WIN32 + // sockets conneced via WASAccept inherit flags NO_HANDLE_INHERIT, + // OVERLAPPED + socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); +#elif defined SOCK_CLOEXEC + socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC); +#else + socket_t sock = accept(svr_sock_, nullptr, nullptr); +#endif + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } else if (errno == EINTR || errno == EAGAIN) { + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec_ * 1000 + read_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec_); + tv.tv_usec = static_cast(read_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec_ * 1000 + write_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec_); + tv.tv_usec = static_cast(write_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + if (!task_queue->enqueue([this, sock]() { process_and_close_socket(sock); })) { + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + task_queue->shutdown(); + } + + is_decommisioned = !ret; + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && pre_routing_handler_(req, res) == HandlerResponse::Handled) { + return true; + } + + // File handler + auto is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, std::move(receiver), nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, std::move(header), std::move(receiver)); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader(req, res, std::move(reader), post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader(req, res, std::move(reader), put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader(req, res, std::move(reader), patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader(req, res, std::move(reader), delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { + return false; + } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = StatusCode::BadRequest_400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, const Handlers &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res); + return true; + } + } + return false; +} + +inline void Server::apply_ranges(const Request &req, Response &res, std::string &content_type, + std::string &boundary) const { + if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) { + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + boundary = detail::make_multipart_data_boundary(); + + res.set_header("Content-Type", "multipart/byteranges; boundary=" + boundary); + } + + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length(req.ranges[0], res.content_length_); + + length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field(offset_and_length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, boundary, content_type, res.content_length_); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider_) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } + } + } + } + } else { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { + ; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length(req.ranges[0], res.body.size()); + auto offset = offset_and_length.first; + auto length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field(offset_and_length, res.body.size()); + res.set_header("Content-Range", content_range); + + assert(offset + length <= res.body.size()); + res.body = res.body.substr(offset, length); + } else { + std::string data; + detail::make_multipart_ranges_data(req, res, boundary, content_type, res.body.size(), data); + res.body.swap(data); + } + + if (type != detail::EncodingType::None) { + std::unique_ptr compressor; + std::string content_encoding; + + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); + content_encoding = "gzip"; +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); + content_encoding = "br"; +#endif + } + + if (compressor) { + std::string compressed; + if (compressor->compress(res.body.data(), res.body.size(), true, [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + res.body.swap(compressed); + res.set_header("Content-Encoding", content_encoding); + } + } + } + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } +} + +inline bool Server::dispatch_request_for_content_reader(Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool Server::process_request(Stream &strm, const std::string &remote_addr, int remote_port, + const std::string &local_addr, int local_port, bool close_connection, + bool &connection_closed, const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { + return false; + } + + Request req; + + Response res; + res.version = "HTTP/1.1"; + res.headers = default_headers_; + +#ifdef _WIN32 + // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). +#else +#ifndef CPPHTTPLIB_USE_POLL + // Socket file descriptor exceeded FD_SETSIZE... + if (strm.socket() >= FD_SETSIZE) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::InternalServerError_500; + return write_response(strm, close_connection, req, res); + } +#endif +#endif + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::UriTooLong_414; + return write_response(strm, close_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } + + if (req.version == "HTTP/1.0" && req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } + + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + req.local_addr = local_addr; + req.local_port = local_port; + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + } + + if (setup_request) { + setup_request(req); + } + + if (req.get_header_value("Expect") == "100-continue") { + int status = StatusCode::Continue_100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case StatusCode::Continue_100: + case StatusCode::ExpectationFailed_417: + detail::write_response_line(strm, status); + strm.write("\r\n"); + break; + default: + connection_closed = true; + return write_response(strm, true, req, res); + } + } + + // Routing + auto routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else + try { + routed = routing(req, res, strm); + } catch (std::exception &e) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + std::string val; + auto s = e.what(); + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case '\r': + val += "\\r"; + break; + case '\n': + val += "\\n"; + break; + default: + val += s[i]; + break; + } + } + res.set_header("EXCEPTION_WHAT", val); + } + } catch (...) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + } +#endif + if (routed) { + if (res.status == -1) { + res.status = req.ranges.empty() ? StatusCode::OK_200 : StatusCode::PartialContent_206; + } + + if (detail::range_error(req, res)) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + + // Serve file content by using a content provider + if (!res.file_content_path_.empty()) { + const auto &path = res.file_content_path_; + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::NotFound_404; + return write_response(strm, close_connection, req, res); + } + + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type(path, file_extension_and_mimetype_map_, default_file_mimetype_); + } + + res.set_content_provider(mm->size(), content_type, [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + } + + return write_response_with_content(strm, close_connection, req, res); + } else { + if (res.status == -1) { + res.status = StatusCode::NotFound_404; + } + + return write_response(strm, close_connection, req, res); + } +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + auto ret = detail::process_server_socket( + svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, local_port, close_connection, + connection_closed, nullptr); + }); + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// HTTP client implementation +inline ClientImpl::ClientImpl(const std::string &host) : ClientImpl(host, 80, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port) + : ClientImpl(host, port, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path) + : host_(detail::escape_abstract_namespace_unix_domain(host)), + port_(port), + host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)), + client_cert_path_(client_cert_path), + client_key_path_(client_key_path) {} + +inline ClientImpl::~ClientImpl() { + std::lock_guard guard(socket_mutex_); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline bool ClientImpl::is_valid() const { return true; } + +inline void ClientImpl::copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + url_encode_ = rhs.url_encode_; + address_family_ = rhs.address_family_; + tcp_nodelay_ = rhs.tcp_nodelay_; + ipv6_v6only_ = rhs.ipv6_v6only_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + ca_cert_store_ = rhs.ca_cert_store_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; + server_certificate_verifier_ = rhs.server_certificate_verifier_; +#endif + logger_ = rhs.logger_; +} + +inline socket_t ClientImpl::create_client_socket(Error &error) const { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket(proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, + ipv6_v6only_, socket_options_, connection_timeout_sec_, + connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, interface_, error); + } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) { + ip = it->second; + } + + return detail::create_client_socket(host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_, socket_options_, + connection_timeout_sec_, connection_timeout_usec_, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, interface_, error); +} + +inline bool ClientImpl::create_and_connect_socket(Socket &socket, Error &error) { + auto sock = create_client_socket(error); + if (sock == INVALID_SOCKET) { + return false; + } + socket.sock = sock; + return true; +} + +inline void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { + // If there are any requests in flight from threads other than us, then it's + // a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) const { + if (socket.sock == INVALID_SOCKET) { + return; + } + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || socket_requests_are_from_thread_ == std::this_thread::get_id()); + + // It is also a bug if this happens while SSL is still active +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket.ssl == nullptr); +#endif + if (socket.sock == INVALID_SOCKET) { + return; + } + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; +} + +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, Response &res) const { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { + return false; + } + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#else + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#endif + + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + + // Ignore '100 Continue' + while (res.status == StatusCode::Continue_100) { + if (!line_reader.getline()) { + return false; + } // CRLF + if (!line_reader.getline()) { + return false; + } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; +} + +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { + std::lock_guard request_mutex_guard(request_mutex_); + auto ret = send_(req, res, error); + if (error == Error::SSLPeerCouldBeClosed_) { + assert(!ret); + ret = send_(req, res, error); + } + return ret; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool ClientImpl::is_ssl_peer_could_be_closed(SSL *ssl) const { + char buf[1]; + return !SSL_peek(ssl, buf, 1) && SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} +#endif + +inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { + { + std::lock_guard guard(socket_mutex_); + + // Set this to false immediately - if it ever gets set to true by the end of + // the request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { + is_alive = false; + } + } +#endif + + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down nongracefully if it seems + // like the other side has already closed the connection Also, there + // cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_, error)) { + return false; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + auto success = false; + if (!scli.connect_with_proxy(socket_, res, success, error)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_, error)) { + return false; + } + } +#endif + } + + // Mark the current socket as being in use so that it cannot be closed by + // anyone else while this request is ongoing, even though we will be + // releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); + } + + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + + auto ret = false; + auto close_connection = !keep_alive_; + + auto se = detail::scope_exit([&]() { + // Briefly lock mutex in order to mark that a request is no longer ongoing + std::lock_guard guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || close_connection || !ret) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + }); + + ret = process_socket(socket_, [&](Stream &strm) { return handle_request(strm, req, res, close_connection, error); }); + + if (!ret) { + if (error == Error::Success) { + error = Error::Unknown; + } + } + + return ret; +} + +inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { + auto res = detail::make_unique(); + auto error = Error::Success; + auto ret = send(req, *res, error); + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; +} + +inline bool ClientImpl::handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { + if (req.path.empty()) { + error = Error::Connection; + return false; + } + + auto req_save = req; + + bool ret; + + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; + } else { + ret = process_request(strm, req, res, close_connection, error); + } + + if (!ret) { + return false; + } + + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. + + // This is safe to call because handle_request is only called by send_ + // which locks the request mutex during the process. It would be a bug + // to call it from a different thread since it's a thread-safety issue + // to do these things to the socket if another thread is using the socket. + std::lock_guard guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + + if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; + ret = redirect(req, res, error); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if ((res.status == StatusCode::Unauthorized_401 || res.status == StatusCode::ProxyAuthenticationRequired_407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407; + const auto &username = is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + new_req.headers.erase(is_proxy ? "Proxy-Authorization" : "Authorization"); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res, error); + if (ret) { + res = new_res; + } + } + } + } +#endif + + return ret; +} + +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { + if (req.redirect_count_ == 0) { + error = Error::ExceedRedirectCount; + return false; + } + + auto location = res.get_header_value("location"); + if (location.empty()) { + return false; + } + + const static std::regex re( + R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { + return false; + } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + if (next_host.empty()) { + next_host = m[3].str(); + } + auto port_str = m[4].str(); + auto next_path = m[5].str(); + auto next_query = m[6].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { + next_scheme = scheme; + } + if (next_host.empty()) { + next_host = host_; + } + if (next_path.empty()) { + next_path = "/"; + } + + auto path = detail::decode_url(next_path, true) + next_query; + + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, path, location, error); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host, next_port); + cli.copy_settings(*this); + if (ca_cert_store_) { + cli.set_ca_cert_store(ca_cert_store_); + } + return detail::redirect(cli, req, res, path, location, error); +#else + return false; +#endif + } else { + ClientImpl cli(next_host, next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, path, location, error); + } + } +} + +inline bool ClientImpl::write_content_with_provider(Stream &strm, const Request &req, Error &error) const { + auto is_shutting_down = []() { return false; }; + + if (req.is_chunked_content_provider_) { + // TODO: Brotli support + std::unique_ptr compressor; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + compressor = detail::make_unique(); + } else +#endif + { + compressor = detail::make_unique(); + } + + return detail::write_content_chunked(strm, req.content_provider_, is_shutting_down, *compressor, error); + } else { + return detail::write_content(strm, req.content_provider_, 0, req.content_length_, is_shutting_down, error); + } +} + +inline bool ClientImpl::write_request(Stream &strm, Request &req, bool close_connection, Error &error) { + // Prepare additional headers + if (close_connection) { + if (!req.has_header("Connection")) { + req.set_header("Connection", "close"); + } + } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } else { + if (port_ == 80) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { + req.set_header("Accept", "*/*"); + } + + if (!req.content_receiver) { + if (!req.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { + accept_encoding += ", "; + } + accept_encoding += "gzip, deflate"; +#endif + req.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + }; + + if (req.body.empty()) { + if (req.content_provider_) { + if (!req.is_chunked_content_provider_) { + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.content_length_); + req.set_header("Content-Length", length); + } + } + } else { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { + req.set_header("Content-Length", "0"); + } + } + } else { + if (!req.has_header("Content-Type")) { + req.set_header("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + req.set_header("Content-Length", length); + } + } + + if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_basic_authentication_header(basic_auth_username_, basic_auth_password_, false)); + } + } + + if (!proxy_basic_auth_username_.empty() && !proxy_basic_auth_password_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert( + make_basic_authentication_header(proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + } + + if (!bearer_token_auth_token_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_bearer_token_authentication_header(bearer_token_auth_token_, false)); + } + } + + if (!proxy_bearer_token_auth_token_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_bearer_token_authentication_header(proxy_bearer_token_auth_token_, true)); + } + } + + // Request line and headers + { + detail::BufferStream bstrm; + + const auto &path_with_query = req.params.empty() ? req.path : append_query_params(req.path, req.params); + + const auto &path = url_encode_ ? detail::encode_url(path_with_query) : path_with_query; + + detail::write_request_line(bstrm, req.method, path); + + header_writer_(bstrm, req.headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error = Error::Write; + return false; + } + } + + // Body + if (req.body.empty()) { + return write_content_with_provider(strm, req, error); + } + + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + + return true; +} + +inline std::unique_ptr ClientImpl::send_with_content_provider( + Request &req, const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, const std::string &content_type, Error &error) { + if (!content_type.empty()) { + req.set_header("Content-Type", content_type); + } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + req.set_header("Content-Encoding", "gzip"); + } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_ && !content_provider_without_length) { + // TODO: Brotli support + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + DataSink data_sink; + + data_sink.write = [&](const char *data, size_t data_len) -> bool { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = + compressor.compress(data, data_len, last, [&](const char *compressed_data, size_t compressed_data_len) { + req.body.append(compressed_data, compressed_data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + return ok; + }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body, content_length, true, [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + error = Error::Compression; + return nullptr; + } + } + } else +#endif + { + if (content_provider) { + req.content_length_ = content_length; + req.content_provider_ = std::move(content_provider); + req.is_chunked_content_provider_ = false; + } else if (content_provider_without_length) { + req.content_length_ = 0; + req.content_provider_ = detail::ContentProviderAdapter(std::move(content_provider_without_length)); + req.is_chunked_content_provider_ = true; + req.set_header("Transfer-Encoding", "chunked"); + } else { + req.body.assign(body, content_length); + } + } + + auto res = detail::make_unique(); + return send(req, *res, error) ? std::move(res) : nullptr; +} + +inline Result ClientImpl::send_with_content_provider(const std::string &method, const std::string &path, + const Headers &headers, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Progress progress) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + req.progress = progress; + + auto error = Error::Success; + + auto res = send_with_content_provider(req, body, content_length, std::move(content_provider), + std::move(content_provider_without_length), content_type, error); + + return Result{std::move(res), error, std::move(req.headers)}; +} + +inline std::string ClientImpl::adjust_host_string(const std::string &host) const { + if (host.find(':') != std::string::npos) { + return "[" + host + "]"; + } + return host; +} + +inline bool ClientImpl::process_request(Stream &strm, Request &req, Response &res, bool close_connection, + Error &error) { + // Send request + if (!write_request(strm, req, close_connection, error)) { + return false; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + if (!is_proxy_enabled) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { + error = Error::SSLPeerCouldBeClosed_; + return false; + } + } + } +#endif + + // Receive response and headers + if (!read_response_line(strm, req, res) || !detail::read_headers(strm, res.headers)) { + error = Error::Read; + return false; + } + + // Body + if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && req.method != "CONNECT") { + auto redirect = 300 < res.status && res.status < 400 && follow_location_; + + if (req.response_handler && !redirect) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + auto out = + req.content_receiver + ? static_cast([&](const char *buf, size_t n, uint64_t off, uint64_t len) { + if (redirect) { + return true; + } + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { + error = Error::Canceled; + } + return ret; + }) + : static_cast( + [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + assert(res.body.size() + n <= res.body.max_size()); + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress || redirect) { + return true; + } + auto ret = req.progress(current, total); + if (!ret) { + error = Error::Canceled; + } + return ret; + }; + + if (res.has_header("Content-Length")) { + if (!req.content_receiver) { + auto len = res.get_header_value_u64("Content-Length"); + if (len > res.body.max_size()) { + error = Error::Read; + return false; + } + res.body.reserve(len); + } + } + + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), dummy_status, std::move(progress), + std::move(out), decompress_)) { + if (error != Error::Canceled) { + error = Error::Read; + } + return false; + } + } + + // Log + if (logger_) { + logger_(req, res); + } + + return true; +} + +inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const { + size_t cur_item = 0; + size_t cur_start = 0; + // cur_item and cur_start are copied to within the std::function and maintain + // state between successive calls + return [&, cur_item, cur_start](size_t offset, DataSink &sink) mutable -> bool { + if (!offset && !items.empty()) { + sink.os << detail::serialize_multipart_formdata(items, boundary, false); + return true; + } else if (cur_item < provider_items.size()) { + if (!cur_start) { + const auto &begin = detail::serialize_multipart_formdata_item_begin(provider_items[cur_item], boundary); + offset += begin.size(); + cur_start = offset; + sink.os << begin; + } + + DataSink cur_sink; + auto has_data = true; + cur_sink.write = sink.write; + cur_sink.done = [&]() { has_data = false; }; + + if (!provider_items[cur_item].provider(offset - cur_start, cur_sink)) { + return false; + } + + if (!has_data) { + sink.os << detail::serialize_multipart_formdata_item_end(); + cur_item++; + cur_start = 0; + } + return true; + } else { + sink.os << detail::serialize_multipart_formdata_finish(boundary); + sink.done(); + return true; + } + }; +} + +inline bool ClientImpl::process_socket(const Socket &socket, std::function callback) { + return detail::process_client_socket(socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, std::move(callback)); +} + +inline bool ClientImpl::is_ssl() const { return false; } + +inline Result ClientImpl::Get(const std::string &path) { return Get(path, Headers(), Progress()); } + +inline Result ClientImpl::Get(const std::string &path, Progress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers) { + return Get(path, headers, Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, ContentReceiver content_receiver, Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, Headers(), std::move(response_handler), std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return Get(path, Headers(), std::move(response_handler), std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = [content_receiver](const char *data, size_t data_length, uint64_t /*offset*/, + uint64_t /*total_length*/) { return content_receiver(data, data_length); }; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, const Headers &headers, + Progress progress) { + if (params.empty()) { + return Get(path, headers); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return Get(path, params, headers, nullptr, std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { + if (params.empty()) { + return Get(path, headers, std::move(response_handler), std::move(content_receiver), std::move(progress)); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(response_handler), std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Head(const std::string &path) { return Head(path, Headers()); } + +inline Result ClientImpl::Head(const std::string &path, const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline Result ClientImpl::Post(const std::string &path) { return Post(path, std::string(), std::string()); } + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers) { + return Post(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return Post(path, Headers(), body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body, content_length, nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return send_with_content_provider("POST", path, headers, body, content_length, nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, const std::string &content_type) { + return Post(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return Post(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return send_with_content_provider("POST", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { return Post(path, Headers(), params); } + +inline Result ClientImpl::Post(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return Post(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Post(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms, + Progress progress) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded", progress); +} + +inline Result ClientImpl::Post(const std::string &path, const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path) { return Put(path, std::string(), std::string()); } + +inline Result ClientImpl::Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return Put(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body, content_length, nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return send_with_content_provider("PUT", path, headers, body, content_length, nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, const std::string &content_type) { + return Put(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return Put(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return send_with_content_provider("PUT", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Put(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return Put(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Put(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, content_length, std::move(content_provider), nullptr, + content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { return Put(path, Headers(), params); } + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms, + Progress progress) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded", progress); +} + +inline Result ClientImpl::Put(const std::string &path, const MultipartFormDataItems &items) { + return Put(path, Headers(), items); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), content_type, + nullptr); +} +inline Result ClientImpl::Patch(const std::string &path) { return Patch(path, std::string(), std::string()); } + +inline Result ClientImpl::Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return Patch(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return Patch(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type) { + return Patch(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, Progress progress) { + return send_with_content_provider("PATCH", path, headers, body, content_length, nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const std::string &body, const std::string &content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_type, progress); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return Patch(path, headers, body, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return send_with_content_provider("PATCH", path, headers, body.data(), body.size(), nullptr, nullptr, content_type, + progress); +} + +inline Result ClientImpl::Patch(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return Patch(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Patch(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, content_length, std::move(content_provider), + nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path) { + return Delete(path, Headers(), std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers) { + return Delete(path, headers, std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return Delete(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return Delete(path, Headers(), body, content_length, content_type, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type) { + return Delete(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, Progress progress) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + req.progress = progress; + + if (!content_type.empty()) { + req.set_header("Content-Type", content_type); + } + req.body.assign(body, content_length); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Delete(const std::string &path, const std::string &body, const std::string &content_type) { + return Delete(path, Headers(), body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body.data(), body.size(), content_type, progress); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return Delete(path, headers, body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return Delete(path, headers, body.data(), body.size(), content_type, progress); +} + +inline Result ClientImpl::Options(const std::string &path) { return Options(path, Headers()); } + +inline Result ClientImpl::Options(const std::string &path, const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline void ClientImpl::stop() { + std::lock_guard guard(socket_mutex_); + + // If there is anything ongoing right now, the ONLY thread-safe thing we can + // do is to shutdown_socket, so that threads using this socket suddenly + // discover they can't read/write any more and error out. Everything else + // (closing the socket, shutting ssl down) is unsafe because these actions are + // not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + + // Aside from that, we set a flag for the socket to be closed when we're + // done. + socket_should_be_closed_when_request_is_done_ = true; + return; + } + + // Otherwise, still holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline std::string ClientImpl::host() const { return host_; } + +inline int ClientImpl::port() const { return port_; } + +inline size_t ClientImpl::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + +inline socket_t ClientImpl::socket() const { return socket_.sock; } + +inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +inline void ClientImpl::set_basic_auth(const std::string &username, const std::string &password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void ClientImpl::set_bearer_token_auth(const std::string &token) { bearer_token_auth_token_ = token; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_digest_auth(const std::string &username, const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } + +inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } + +inline void ClientImpl::set_url_encode(bool on) { url_encode_ = on; } + +inline void ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + +inline void ClientImpl::set_default_headers(Headers headers) { default_headers_ = std::move(headers); } + +inline void ClientImpl::set_header_writer(std::function const &writer) { + header_writer_ = writer; +} + +inline void ClientImpl::set_address_family(int family) { address_family_ = family; } + +inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +inline void ClientImpl::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + +inline void ClientImpl::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +inline void ClientImpl::set_compress(bool on) { compress_ = on; } + +inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } + +inline void ClientImpl::set_interface(const std::string &intf) { interface_ = intf; } + +inline void ClientImpl::set_proxy(const std::string &host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void ClientImpl::set_proxy_basic_auth(const std::string &username, const std::string &password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { + proxy_bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_proxy_digest_auth(const std::string &username, const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { + ca_cert_file_path_ = ca_cert_file_path; + ca_cert_dir_path_ = ca_cert_dir_path; +} + +inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store && ca_cert_store != ca_cert_store_) { + ca_cert_store_ = ca_cert_store; + } +} + +inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, std::size_t size) const { + auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + auto se = detail::scope_exit([&] { BIO_free_all(mem); }); + if (!mem) { + return nullptr; + } + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { + return nullptr; + } + + auto cts = X509_STORE_new(); + if (cts) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { + continue; + } + + if (itmp->x509) { + X509_STORE_add_cert(cts, itmp->x509); + } + if (itmp->crl) { + X509_STORE_add_crl(cts, itmp->crl); + } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return cts; +} + +inline void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline void ClientImpl::enable_server_hostname_verification(bool enabled) { server_hostname_verification_ = enabled; } + +inline void ClientImpl::set_server_certificate_verifier(std::function verifier) { + server_certificate_verifier_ = verifier; +} +#endif + +inline void ClientImpl::set_logger(Logger logger) { logger_ = std::move(logger); } + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup) { + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (ssl) { + set_nonblocking(sock, true); + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + BIO_set_nbio(bio, 1); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + set_nonblocking(sock, false); + return nullptr; + } + BIO_set_nbio(bio, 0); + set_nonblocking(sock, false); + } + + return ssl; +} + +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a + // best-efforts. + if (shutdown_gracefully) { +#ifdef _WIN32 + SSL_shutdown(ssl); +#else + timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); + + auto ret = SSL_shutdown(ssl); + while (ret == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds{100}); + ret = SSL_shutdown(ssl); + } +#endif + } + + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); +} + +template +bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, U ssl_connect_or_accept, time_t timeout_sec, + time_t timeout_usec) { + auto res = 0; + while ((res = ssl_connect_or_accept(ssl)) != 1) { + auto err = SSL_get_error(ssl, res); + switch (err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: + break; + } + return false; + } + return true; +} + +template +inline bool process_server_socket_ssl(const std::atomic &svr_sock, SSL *ssl, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core(svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +inline bool process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +class SSLInit { + public: + SSLInit() { OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); } +}; + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec) + : sock_(sock), + ssl_(ssl), + read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec) { + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); +} + +inline SSLSocketStream::~SSLSocketStream() = default; + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_READ || (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { + return ret; + } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { + auto handle_size = static_cast(std::min(size, (std::numeric_limits::max)())); + + auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_WRITE || (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (is_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret >= 0) { + return ret; + } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SSLSocketStream::socket() const { return sock_; } + +static SSLInit sslinit_; + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata(ctx_, reinterpret_cast(const_cast(private_key_password))); + } + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != 1 || + SSL_CTX_check_private_key(ctx_) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, client_ca_cert_dir_path); + + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { + SSL_CTX_free(ctx_); + } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } + +inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { + std::lock_guard guard(ctx_mutex_); + + SSL_CTX_use_certificate(ctx_, cert); + SSL_CTX_use_PrivateKey(ctx_, private_key); + + if (client_ca_cert_store != nullptr) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + } +} + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + auto ssl = detail::ssl_new( + sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + return detail::ssl_connect_or_accept_nonblocking(sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_); + }, + [](SSL * /*ssl2*/) { return true; }); + + auto ret = false; + if (ssl) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, local_port, close_connection, + connection_closed, [&](Request &req) { req.ssl = ssl; }); + }); + + // Shutdown gracefully if the result seemed successful, non-gracefully if + // the connection appeared to be closed. + const bool shutdown_gracefully = ret; + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); + } + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host) : SSLClient(host, 443, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port) : SSLClient(host, port, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path, const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { host_components_.emplace_back(b, e); }); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast(const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { host_components_.emplace_back(b, e); }); + + if (client_cert != nullptr && client_key != nullptr) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast(const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { + SSL_CTX_free(ctx_); + } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + if (ctx_) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + // Free memory allocated for old cert and use new store `ca_cert_store` + SSL_CTX_set_cert_store(ctx_, ca_cert_store); + } + } else { + X509_STORE_free(ca_cert_store); + } + } +} + +inline void SSLClient::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); +} + +inline long SSLClient::get_openssl_verify_result() const { return verify_result_; } + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + return is_valid() && ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in flight +inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket(socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + proxy_res = Response(); + if (!detail::process_client_socket(socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), proxy_digest_auth_username_, + proxy_digest_auth_password_, true)); + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +inline bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, ca_cert_dir_path_.c_str())) { + ret = false; + } + } else { + auto loaded = false; +#ifdef _WIN32 + loaded = detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX + loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); +#endif // TARGET_OS_OSX +#endif // _WIN32 + if (!loaded) { + SSL_CTX_set_default_verify_paths(ctx_); + } + } + }); + + return ret; +} + +inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); + } + + if (!detail::ssl_connect_or_accept_nonblocking(socket.sock, ssl2, SSL_connect, connection_timeout_sec_, + connection_timeout_usec_)) { + error = Error::SSLConnection; + return false; + } + + if (server_certificate_verification_) { + if (server_certificate_verifier_) { + if (!server_certificate_verifier_(ssl2)) { + error = Error::SSLServerVerification; + return false; + } + } else { + verify_result_ = SSL_get_verify_result(ssl2); + + if (verify_result_ != X509_V_OK) { + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + auto se = detail::scope_exit([&] { X509_free(server_cert); }); + + if (server_cert == nullptr) { + error = Error::SSLServerVerification; + return false; + } + + if (server_hostname_verification_) { + if (!verify_host(server_cert)) { + error = Error::SSLServerHostnameVerification; + return false; + } + } + } + } + + return true; + }, + [&](SSL *ssl2) { +#if defined(OPENSSL_IS_BORINGSSL) + SSL_set_tlsext_host_name(ssl2, host_.c_str()); +#else + // NOTE: Direct call instead of using the OpenSSL macro to suppress + // -Wold-style-cast warning + SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(host_.c_str()))); +#endif + return true; + }); + + if (ssl) { + socket.ssl = ssl; + return true; + } + + shutdown_socket(socket); + close_socket(socket); + return false; +} + +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +inline void SSLClient::shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, shutdown_gracefully); + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +inline bool SSLClient::process_socket(const Socket &socket, std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl(socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, std::move(callback)); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || verify_host_with_common_name(server_cert); +} + +inline bool SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 {}; + struct in_addr addr {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: + dsn_matched = check_host_name(name, name_len); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + } + + if (dsn_matched || ip_matched) { + ret = true; + } + } + + GENERAL_NAMES_free(const_cast(reinterpret_cast(alt_names))); + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { + return true; + } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { pattern_components.emplace_back(b, e); }); + + if (host_components_.size() != pattern_components.size()) { + return false; + } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && !p.compare(0, p.size() - 1, h)); + if (!partial_match) { + return false; + } + } + ++itr; + } + + return true; +} +#endif + +// Universal client implementation +inline Client::Client(const std::string &scheme_host_port) : Client(scheme_host_port, std::string(), std::string()) {} + +inline Client::Client(const std::string &scheme_host_port, const std::string &client_cert_path, + const std::string &client_key_path) { + const static std::regex re(R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + + std::smatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { +#else + if (!scheme.empty() && scheme != "http") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "https"; + + auto host = m[2].str(); + if (host.empty()) { + host = m[3].str(); + } + + auto port_str = m[4].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + if (is_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); + is_ssl_ = is_ssl; +#endif + } else { + cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); + } + } else { + // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress) + // if port param below changes. + cli_ = detail::make_unique(scheme_host_port, 80, client_cert_path, client_key_path); + } +} // namespace detail + +inline Client::Client(const std::string &host, int port) : cli_(detail::make_unique(host, port)) {} + +inline Client::Client(const std::string &host, int port, const std::string &client_cert_path, + const std::string &client_key_path) + : cli_(detail::make_unique(host, port, client_cert_path, client_key_path)) {} + +inline Client::~Client() = default; + +inline bool Client::is_valid() const { return cli_ != nullptr && cli_->is_valid(); } + +inline Result Client::Get(const std::string &path) { return cli_->Get(path); } +inline Result Client::Get(const std::string &path, const Headers &headers) { return cli_->Get(path, headers); } +inline Result Client::Get(const std::string &path, Progress progress) { return cli_->Get(path, std::move(progress)); } +inline Result Client::Get(const std::string &path, const Headers &headers, Progress progress) { + return cli_->Get(path, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, ContentReceiver content_receiver) { + return cli_->Get(path, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, headers, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver) { + return cli_->Get(path, std::move(response_handler), std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(response_handler), std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, std::move(response_handler), std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, headers, std::move(response_handler), std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress) { + return cli_->Get(path, params, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(response_handler), std::move(content_receiver), + std::move(progress)); +} + +inline Result Client::Head(const std::string &path) { return cli_->Head(path); } +inline Result Client::Head(const std::string &path, const Headers &headers) { return cli_->Head(path, headers); } + +inline Result Client::Post(const std::string &path) { return cli_->Post(path); } +inline Result Client::Post(const std::string &path, const Headers &headers) { return cli_->Post(path, headers); } +inline Result Client::Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Post(path, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Post(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Post(path, body, content_type); +} +inline Result Client::Post(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return cli_->Post(path, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_type, progress); +} +inline Result Client::Post(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, content_length, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return cli_->Post(path, headers, content_length, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return cli_->Post(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Params ¶ms) { return cli_->Post(path, params); } +inline Result Client::Post(const std::string &path, const Headers &headers, const Params ¶ms) { + return cli_->Post(path, headers, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress) { + return cli_->Post(path, headers, params, progress); +} +inline Result Client::Post(const std::string &path, const MultipartFormDataItems &items) { + return cli_->Post(path, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + return cli_->Post(path, headers, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Post(path, headers, items, boundary); +} +inline Result Client::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Post(path, headers, items, provider_items); +} +inline Result Client::Put(const std::string &path) { return cli_->Put(path); } +inline Result Client::Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Put(path, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Put(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Put(path, body, content_type); +} +inline Result Client::Put(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return cli_->Put(path, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_type, progress); +} +inline Result Client::Put(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, content_length, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return cli_->Put(path, headers, content_length, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return cli_->Put(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Params ¶ms) { return cli_->Put(path, params); } +inline Result Client::Put(const std::string &path, const Headers &headers, const Params ¶ms) { + return cli_->Put(path, headers, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const Params ¶ms, Progress progress) { + return cli_->Put(path, headers, params, progress); +} +inline Result Client::Put(const std::string &path, const MultipartFormDataItems &items) { + return cli_->Put(path, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + return cli_->Put(path, headers, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Put(path, headers, items, boundary); +} +inline Result Client::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Put(path, headers, items, provider_items); +} +inline Result Client::Patch(const std::string &path) { return cli_->Patch(path); } +inline Result Client::Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Patch(path, body, content_length, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Patch(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Patch(path, body, content_type); +} +inline Result Client::Patch(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Patch(path, headers, body, content_type, progress); +} +inline Result Client::Patch(const std::string &path, size_t content_length, ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, content_length, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const std::string &content_type) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, const std::string &content_type) { + return cli_->Patch(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Delete(const std::string &path) { return cli_->Delete(path); } +inline Result Client::Delete(const std::string &path, const Headers &headers) { return cli_->Delete(path, headers); } +inline Result Client::Delete(const std::string &path, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Delete(path, body, content_length, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Delete(path, headers, body, content_length, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Delete(path, body, content_type); +} +inline Result Client::Delete(const std::string &path, const std::string &body, const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_type, progress); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Delete(path, headers, body, content_type, progress); +} +inline Result Client::Options(const std::string &path) { return cli_->Options(path); } +inline Result Client::Options(const std::string &path, const Headers &headers) { return cli_->Options(path, headers); } + +inline bool Client::send(Request &req, Response &res, Error &error) { return cli_->send(req, res, error); } + +inline Result Client::send(const Request &req) { return cli_->send(req); } + +inline void Client::stop() { cli_->stop(); } + +inline std::string Client::host() const { return cli_->host(); } + +inline int Client::port() const { return cli_->port(); } + +inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } + +inline socket_t Client::socket() const { return cli_->socket(); } + +inline void Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + +inline void Client::set_default_headers(Headers headers) { cli_->set_default_headers(std::move(headers)); } + +inline void Client::set_header_writer(std::function const &writer) { + cli_->set_header_writer(writer); +} + +inline void Client::set_address_family(int family) { cli_->set_address_family(family); } + +inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } + +inline void Client::set_socket_options(SocketOptions socket_options) { + cli_->set_socket_options(std::move(socket_options)); +} + +inline void Client::set_connection_timeout(time_t sec, time_t usec) { cli_->set_connection_timeout(sec, usec); } + +inline void Client::set_read_timeout(time_t sec, time_t usec) { cli_->set_read_timeout(sec, usec); } + +inline void Client::set_write_timeout(time_t sec, time_t usec) { cli_->set_write_timeout(sec, usec); } + +inline void Client::set_basic_auth(const std::string &username, const std::string &password) { + cli_->set_basic_auth(username, password); +} +inline void Client::set_bearer_token_auth(const std::string &token) { cli_->set_bearer_token_auth(token); } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const std::string &username, const std::string &password) { + cli_->set_digest_auth(username, password); +} +#endif + +inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_follow_location(bool on) { cli_->set_follow_location(on); } + +inline void Client::set_url_encode(bool on) { cli_->set_url_encode(on); } + +inline void Client::set_compress(bool on) { cli_->set_compress(on); } + +inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } + +inline void Client::set_interface(const std::string &intf) { cli_->set_interface(intf); } + +inline void Client::set_proxy(const std::string &host, int port) { cli_->set_proxy(host, port); } +inline void Client::set_proxy_basic_auth(const std::string &username, const std::string &password) { + cli_->set_proxy_basic_auth(username, password); +} +inline void Client::set_proxy_bearer_token_auth(const std::string &token) { cli_->set_proxy_bearer_token_auth(token); } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const std::string &username, const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +inline void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +inline void Client::set_server_certificate_verifier(std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} +#endif + +inline void Client::set_logger(Logger logger) { cli_->set_logger(std::move(logger)); } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { + cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); +} + +inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } else { + cli_->set_ca_cert_store(ca_cert_store); + } +} + +inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); +} + +inline long Client::get_openssl_verify_result() const { + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} + +inline SSL_CTX *Client::ssl_context() const { + if (is_ssl_) { + return static_cast(*cli_).ssl_context(); + } + return nullptr; +} +#endif + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#if defined(_WIN32) && defined(CPPHTTPLIB_USE_POLL) +#undef poll +#endif + +#endif // CPPHTTPLIB_HTTPLIB_H + +#endif diff --git a/esphome/components/http_request/update/__init__.py b/esphome/components/http_request/update/__init__.py index f1b282d891..abb4b2a430 100644 --- a/esphome/components/http_request/update/__init__.py +++ b/esphome/components/http_request/update/__init__.py @@ -16,14 +16,17 @@ HttpRequestUpdate = http_request_ns.class_( CONF_OTA_ID = "ota_id" -CONFIG_SCHEMA = update.UPDATE_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(HttpRequestUpdate), - cv.GenerateID(CONF_OTA_ID): cv.use_id(OtaHttpRequestComponent), - cv.GenerateID(CONF_HTTP_REQUEST_ID): cv.use_id(HttpRequestComponent), - cv.Required(CONF_SOURCE): cv.url, - } -).extend(cv.polling_component_schema("6h")) +CONFIG_SCHEMA = ( + update.update_schema(HttpRequestUpdate) + .extend( + { + cv.GenerateID(CONF_OTA_ID): cv.use_id(OtaHttpRequestComponent), + cv.GenerateID(CONF_HTTP_REQUEST_ID): cv.use_id(HttpRequestComponent), + cv.Required(CONF_SOURCE): cv.url, + } + ) + .extend(cv.polling_component_schema("6h")) +) async def to_code(config): diff --git a/esphome/components/i2c/i2c.h b/esphome/components/i2c/i2c.h index 8d8e139c61..15f786245b 100644 --- a/esphome/components/i2c/i2c.h +++ b/esphome/components/i2c/i2c.h @@ -139,6 +139,10 @@ class I2CDevice { /// @param address of the device void set_i2c_address(uint8_t address) { address_ = address; } + /// @brief Returns the I2C address of the object. + /// @return the I2C address + uint8_t get_i2c_address() const { return this->address_; } + /// @brief we store the pointer to the I2CBus to use /// @param bus pointer to the I2CBus object void set_i2c_bus(I2CBus *bus) { bus_ = bus; } diff --git a/esphome/components/i2c/i2c_bus_esp_idf.cpp b/esphome/components/i2c/i2c_bus_esp_idf.cpp index c5d6dd8b2a..c14300f725 100644 --- a/esphome/components/i2c/i2c_bus_esp_idf.cpp +++ b/esphome/components/i2c/i2c_bus_esp_idf.cpp @@ -67,7 +67,7 @@ void IDFI2CBus::setup() { ESP_LOGV(TAG, "i2c_timeout set to %" PRIu32 " ticks (%" PRIu32 " us)", timeout_ * 80, timeout_); } } - err = i2c_driver_install(port_, I2C_MODE_MASTER, 0, 0, ESP_INTR_FLAG_IRAM); + err = i2c_driver_install(port_, I2C_MODE_MASTER, 0, 0, 0); if (err != ESP_OK) { ESP_LOGW(TAG, "i2c_driver_install failed: %s", esp_err_to_name(err)); this->mark_failed(); diff --git a/esphome/components/i2s_audio/__init__.py b/esphome/components/i2s_audio/__init__.py index fa515a585f..0d413adb8a 100644 --- a/esphome/components/i2s_audio/__init__.py +++ b/esphome/components/i2s_audio/__init__.py @@ -8,7 +8,15 @@ from esphome.components.esp32.const import ( VARIANT_ESP32S3, ) import esphome.config_validation as cv -from esphome.const import CONF_BITS_PER_SAMPLE, CONF_CHANNEL, CONF_ID, CONF_SAMPLE_RATE +from esphome.const import ( + CONF_BITS_PER_SAMPLE, + CONF_CHANNEL, + CONF_ID, + CONF_SAMPLE_RATE, + KEY_CORE, + KEY_FRAMEWORK_VERSION, +) +from esphome.core import CORE from esphome.cpp_generator import MockObjClass import esphome.final_validate as fv @@ -31,10 +39,14 @@ CONF_SECONDARY = "secondary" CONF_USE_APLL = "use_apll" CONF_BITS_PER_CHANNEL = "bits_per_channel" +CONF_MCLK_MULTIPLE = "mclk_multiple" CONF_MONO = "mono" CONF_LEFT = "left" CONF_RIGHT = "right" CONF_STEREO = "stereo" +CONF_BOTH = "both" + +CONF_USE_LEGACY = "use_legacy" i2s_audio_ns = cg.esphome_ns.namespace("i2s_audio") I2SAudioComponent = i2s_audio_ns.class_("I2SAudioComponent", cg.Component) @@ -50,6 +62,12 @@ I2S_MODE_OPTIONS = { CONF_SECONDARY: i2s_mode_t.I2S_MODE_SLAVE, # NOLINT } +i2s_role_t = cg.global_ns.enum("i2s_role_t") +I2S_ROLE_OPTIONS = { + CONF_PRIMARY: i2s_role_t.I2S_ROLE_MASTER, # NOLINT + CONF_SECONDARY: i2s_role_t.I2S_ROLE_SLAVE, # NOLINT +} + # https://github.com/espressif/esp-idf/blob/master/components/soc/{variant}/include/soc/soc_caps.h I2S_PORTS = { VARIANT_ESP32: 2, @@ -60,10 +78,23 @@ I2S_PORTS = { i2s_channel_fmt_t = cg.global_ns.enum("i2s_channel_fmt_t") I2S_CHANNELS = { - CONF_MONO: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ALL_LEFT, - CONF_LEFT: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ONLY_LEFT, - CONF_RIGHT: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ONLY_RIGHT, - CONF_STEREO: i2s_channel_fmt_t.I2S_CHANNEL_FMT_RIGHT_LEFT, + CONF_MONO: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ALL_LEFT, # left data to both channels + CONF_LEFT: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ONLY_LEFT, # mono data + CONF_RIGHT: i2s_channel_fmt_t.I2S_CHANNEL_FMT_ONLY_RIGHT, # mono data + CONF_STEREO: i2s_channel_fmt_t.I2S_CHANNEL_FMT_RIGHT_LEFT, # stereo data to both channels +} + +i2s_slot_mode_t = cg.global_ns.enum("i2s_slot_mode_t") +I2S_SLOT_MODE = { + CONF_MONO: i2s_slot_mode_t.I2S_SLOT_MODE_MONO, + CONF_STEREO: i2s_slot_mode_t.I2S_SLOT_MODE_STEREO, +} + +i2s_std_slot_mask_t = cg.global_ns.enum("i2s_std_slot_mask_t") +I2S_STD_SLOT_MASK = { + CONF_LEFT: i2s_std_slot_mask_t.I2S_STD_SLOT_LEFT, + CONF_RIGHT: i2s_std_slot_mask_t.I2S_STD_SLOT_RIGHT, + CONF_BOTH: i2s_std_slot_mask_t.I2S_STD_SLOT_BOTH, } i2s_bits_per_sample_t = cg.global_ns.enum("i2s_bits_per_sample_t") @@ -83,9 +114,37 @@ I2S_BITS_PER_CHANNEL = { 32: i2s_bits_per_chan_t.I2S_BITS_PER_CHAN_32BIT, } +i2s_slot_bit_width_t = cg.global_ns.enum("i2s_slot_bit_width_t") +I2S_SLOT_BIT_WIDTH = { + "default": i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_AUTO, + 8: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_8BIT, + 16: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_16BIT, + 24: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_24BIT, + 32: i2s_slot_bit_width_t.I2S_SLOT_BIT_WIDTH_32BIT, +} + +i2s_mclk_multiple_t = cg.global_ns.enum("i2s_mclk_multiple_t") +I2S_MCLK_MULTIPLE = { + 128: i2s_mclk_multiple_t.I2S_MCLK_MULTIPLE_128, + 256: i2s_mclk_multiple_t.I2S_MCLK_MULTIPLE_256, + 384: i2s_mclk_multiple_t.I2S_MCLK_MULTIPLE_384, + 512: i2s_mclk_multiple_t.I2S_MCLK_MULTIPLE_512, +} + _validate_bits = cv.float_with_unit("bits", "bit") +def validate_mclk_divisible_by_3(config): + if config[CONF_BITS_PER_SAMPLE] == 24 and config[CONF_MCLK_MULTIPLE] % 3 != 0: + raise cv.Invalid( + f"{CONF_MCLK_MULTIPLE} must be divisible by 3 when bits per sample is 24" + ) + return config + + +_use_legacy_driver = None + + def i2s_audio_component_schema( class_: MockObjClass, *, @@ -97,43 +156,83 @@ def i2s_audio_component_schema( { cv.GenerateID(): cv.declare_id(class_), cv.GenerateID(CONF_I2S_AUDIO_ID): cv.use_id(I2SAudioComponent), - cv.Optional(CONF_CHANNEL, default=default_channel): cv.enum(I2S_CHANNELS), + cv.Optional(CONF_CHANNEL, default=default_channel): cv.one_of( + *I2S_CHANNELS + ), cv.Optional(CONF_SAMPLE_RATE, default=default_sample_rate): cv.int_range( min=1 ), cv.Optional(CONF_BITS_PER_SAMPLE, default=default_bits_per_sample): cv.All( - _validate_bits, cv.enum(I2S_BITS_PER_SAMPLE) + _validate_bits, cv.one_of(*I2S_BITS_PER_SAMPLE) ), - cv.Optional(CONF_I2S_MODE, default=CONF_PRIMARY): cv.enum( - I2S_MODE_OPTIONS, lower=True + cv.Optional(CONF_I2S_MODE, default=CONF_PRIMARY): cv.one_of( + *I2S_MODE_OPTIONS, lower=True ), cv.Optional(CONF_USE_APLL, default=False): cv.boolean, cv.Optional(CONF_BITS_PER_CHANNEL, default="default"): cv.All( cv.Any(cv.float_with_unit("bits", "bit"), "default"), - cv.enum(I2S_BITS_PER_CHANNEL), + cv.one_of(*I2S_BITS_PER_CHANNEL), ), + cv.Optional(CONF_MCLK_MULTIPLE, default=256): cv.one_of(*I2S_MCLK_MULTIPLE), } ) async def register_i2s_audio_component(var, config): await cg.register_parented(var, config[CONF_I2S_AUDIO_ID]) - - cg.add(var.set_i2s_mode(config[CONF_I2S_MODE])) - cg.add(var.set_channel(config[CONF_CHANNEL])) + if use_legacy(): + cg.add(var.set_i2s_mode(I2S_MODE_OPTIONS[config[CONF_I2S_MODE]])) + cg.add(var.set_channel(I2S_CHANNELS[config[CONF_CHANNEL]])) + cg.add( + var.set_bits_per_sample(I2S_BITS_PER_SAMPLE[config[CONF_BITS_PER_SAMPLE]]) + ) + cg.add( + var.set_bits_per_channel( + I2S_BITS_PER_CHANNEL[config[CONF_BITS_PER_CHANNEL]] + ) + ) + else: + cg.add(var.set_i2s_role(I2S_ROLE_OPTIONS[config[CONF_I2S_MODE]])) + slot_mode = config[CONF_CHANNEL] + if slot_mode != CONF_STEREO: + slot_mode = CONF_MONO + slot_mask = config[CONF_CHANNEL] + if slot_mask not in [CONF_LEFT, CONF_RIGHT]: + slot_mask = CONF_BOTH + cg.add(var.set_slot_mode(I2S_SLOT_MODE[slot_mode])) + cg.add(var.set_std_slot_mask(I2S_STD_SLOT_MASK[slot_mask])) + cg.add(var.set_slot_bit_width(I2S_SLOT_BIT_WIDTH[config[CONF_BITS_PER_SAMPLE]])) cg.add(var.set_sample_rate(config[CONF_SAMPLE_RATE])) - cg.add(var.set_bits_per_sample(config[CONF_BITS_PER_SAMPLE])) - cg.add(var.set_bits_per_channel(config[CONF_BITS_PER_CHANNEL])) cg.add(var.set_use_apll(config[CONF_USE_APLL])) + cg.add(var.set_mclk_multiple(I2S_MCLK_MULTIPLE[config[CONF_MCLK_MULTIPLE]])) -CONFIG_SCHEMA = cv.Schema( - { - cv.GenerateID(): cv.declare_id(I2SAudioComponent), - cv.Required(CONF_I2S_LRCLK_PIN): pins.internal_gpio_output_pin_number, - cv.Optional(CONF_I2S_BCLK_PIN): pins.internal_gpio_output_pin_number, - cv.Optional(CONF_I2S_MCLK_PIN): pins.internal_gpio_output_pin_number, - } +def validate_use_legacy(value): + global _use_legacy_driver # noqa: PLW0603 + if CONF_USE_LEGACY in value: + if (_use_legacy_driver is not None) and ( + _use_legacy_driver != value[CONF_USE_LEGACY] + ): + raise cv.Invalid( + f"All i2s_audio components must set {CONF_USE_LEGACY} to the same value." + ) + if (not value[CONF_USE_LEGACY]) and (CORE.using_arduino): + raise cv.Invalid("Arduino supports only the legacy i2s driver.") + _use_legacy_driver = value[CONF_USE_LEGACY] + return value + + +CONFIG_SCHEMA = cv.All( + cv.Schema( + { + cv.GenerateID(): cv.declare_id(I2SAudioComponent), + cv.Required(CONF_I2S_LRCLK_PIN): pins.internal_gpio_output_pin_number, + cv.Optional(CONF_I2S_BCLK_PIN): pins.internal_gpio_output_pin_number, + cv.Optional(CONF_I2S_MCLK_PIN): pins.internal_gpio_output_pin_number, + cv.Optional(CONF_USE_LEGACY): cv.boolean, + }, + ), + validate_use_legacy, ) @@ -148,12 +247,22 @@ def _final_validate(_): ) +def use_legacy(): + framework_version = CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] + if CORE.using_esp_idf and framework_version >= cv.Version(5, 0, 0): + if not _use_legacy_driver: + return False + return True + + FINAL_VALIDATE_SCHEMA = _final_validate async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) + if use_legacy(): + cg.add_define("USE_I2S_LEGACY") cg.add(var.set_lrclk_pin(config[CONF_I2S_LRCLK_PIN])) if CONF_I2S_BCLK_PIN in config: diff --git a/esphome/components/i2s_audio/i2s_audio.h b/esphome/components/i2s_audio/i2s_audio.h index 7e2798c33d..e839bcd891 100644 --- a/esphome/components/i2s_audio/i2s_audio.h +++ b/esphome/components/i2s_audio/i2s_audio.h @@ -2,9 +2,14 @@ #ifdef USE_ESP32 -#include #include "esphome/core/component.h" #include "esphome/core/helpers.h" +#include "esphome/core/defines.h" +#ifdef USE_I2S_LEGACY +#include +#else +#include +#endif namespace esphome { namespace i2s_audio { @@ -13,20 +18,36 @@ class I2SAudioComponent; class I2SAudioBase : public Parented { public: +#ifdef USE_I2S_LEGACY void set_i2s_mode(i2s_mode_t mode) { this->i2s_mode_ = mode; } void set_channel(i2s_channel_fmt_t channel) { this->channel_ = channel; } - void set_sample_rate(uint32_t sample_rate) { this->sample_rate_ = sample_rate; } void set_bits_per_sample(i2s_bits_per_sample_t bits_per_sample) { this->bits_per_sample_ = bits_per_sample; } void set_bits_per_channel(i2s_bits_per_chan_t bits_per_channel) { this->bits_per_channel_ = bits_per_channel; } +#else + void set_i2s_role(i2s_role_t role) { this->i2s_role_ = role; } + void set_slot_mode(i2s_slot_mode_t slot_mode) { this->slot_mode_ = slot_mode; } + void set_std_slot_mask(i2s_std_slot_mask_t std_slot_mask) { this->std_slot_mask_ = std_slot_mask; } + void set_slot_bit_width(i2s_slot_bit_width_t slot_bit_width) { this->slot_bit_width_ = slot_bit_width; } +#endif + void set_sample_rate(uint32_t sample_rate) { this->sample_rate_ = sample_rate; } void set_use_apll(uint32_t use_apll) { this->use_apll_ = use_apll; } + void set_mclk_multiple(i2s_mclk_multiple_t mclk_multiple) { this->mclk_multiple_ = mclk_multiple; } protected: +#ifdef USE_I2S_LEGACY i2s_mode_t i2s_mode_{}; i2s_channel_fmt_t channel_; - uint32_t sample_rate_; i2s_bits_per_sample_t bits_per_sample_; i2s_bits_per_chan_t bits_per_channel_; +#else + i2s_role_t i2s_role_{}; + i2s_slot_mode_t slot_mode_; + i2s_std_slot_mask_t std_slot_mask_; + i2s_slot_bit_width_t slot_bit_width_; +#endif + uint32_t sample_rate_; bool use_apll_; + i2s_mclk_multiple_t mclk_multiple_; }; class I2SAudioIn : public I2SAudioBase {}; @@ -37,6 +58,7 @@ class I2SAudioComponent : public Component { public: void setup() override; +#ifdef USE_I2S_LEGACY i2s_pin_config_t get_pin_config() const { return { .mck_io_num = this->mclk_pin_, @@ -46,6 +68,20 @@ class I2SAudioComponent : public Component { .data_in_num = I2S_PIN_NO_CHANGE, }; } +#else + i2s_std_gpio_config_t get_pin_config() const { + return {.mclk = (gpio_num_t) this->mclk_pin_, + .bclk = (gpio_num_t) this->bclk_pin_, + .ws = (gpio_num_t) this->lrclk_pin_, + .dout = I2S_GPIO_UNUSED, // add local ports + .din = I2S_GPIO_UNUSED, + .invert_flags = { + .mclk_inv = false, + .bclk_inv = false, + .ws_inv = false, + }}; + } +#endif void set_mclk_pin(int pin) { this->mclk_pin_ = pin; } void set_bclk_pin(int pin) { this->bclk_pin_ = pin; } @@ -62,9 +98,13 @@ class I2SAudioComponent : public Component { I2SAudioIn *audio_in_{nullptr}; I2SAudioOut *audio_out_{nullptr}; - +#ifdef USE_I2S_LEGACY int mclk_pin_{I2S_PIN_NO_CHANGE}; int bclk_pin_{I2S_PIN_NO_CHANGE}; +#else + int mclk_pin_{I2S_GPIO_UNUSED}; + int bclk_pin_{I2S_GPIO_UNUSED}; +#endif int lrclk_pin_; i2s_port_t port_{}; }; diff --git a/esphome/components/i2s_audio/media_player/__init__.py b/esphome/components/i2s_audio/media_player/__init__.py index 2882729b1e..bed25b011f 100644 --- a/esphome/components/i2s_audio/media_player/__init__.py +++ b/esphome/components/i2s_audio/media_player/__init__.py @@ -14,6 +14,7 @@ from .. import ( I2SAudioComponent, I2SAudioOut, i2s_audio_ns, + use_legacy, ) CODEOWNERS = ["@jesserockz"] @@ -87,6 +88,14 @@ CONFIG_SCHEMA = cv.All( ) +def _final_validate(_): + if not use_legacy(): + raise cv.Invalid("I2S media player is only compatible with legacy i2s driver.") + + +FINAL_VALIDATE_SCHEMA = _final_validate + + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) diff --git a/esphome/components/i2s_audio/microphone/__init__.py b/esphome/components/i2s_audio/microphone/__init__.py index 161046e962..7bbb94f6e3 100644 --- a/esphome/components/i2s_audio/microphone/__init__.py +++ b/esphome/components/i2s_audio/microphone/__init__.py @@ -1,17 +1,28 @@ from esphome import pins import esphome.codegen as cg -from esphome.components import esp32, microphone +from esphome.components import audio, esp32, microphone from esphome.components.adc import ESP32_VARIANT_ADC1_PIN_TO_CHANNEL, validate_adc_pin import esphome.config_validation as cv -from esphome.const import CONF_ID, CONF_NUMBER +from esphome.const import ( + CONF_BITS_PER_SAMPLE, + CONF_CHANNEL, + CONF_ID, + CONF_NUM_CHANNELS, + CONF_NUMBER, + CONF_SAMPLE_RATE, +) from .. import ( CONF_I2S_DIN_PIN, + CONF_LEFT, + CONF_MONO, CONF_RIGHT, I2SAudioIn, i2s_audio_component_schema, i2s_audio_ns, register_i2s_audio_component, + use_legacy, + validate_mclk_divisible_by_3, ) CODEOWNERS = ["@jesserockz"] @@ -19,6 +30,7 @@ DEPENDENCIES = ["i2s_audio"] CONF_ADC_PIN = "adc_pin" CONF_ADC_TYPE = "adc_type" +CONF_CORRECT_DC_OFFSET = "correct_dc_offset" CONF_PDM = "pdm" I2SAudioMicrophone = i2s_audio_ns.class_( @@ -29,7 +41,7 @@ INTERNAL_ADC_VARIANTS = [esp32.const.VARIANT_ESP32] PDM_VARIANTS = [esp32.const.VARIANT_ESP32, esp32.const.VARIANT_ESP32S3] -def validate_esp32_variant(config): +def _validate_esp32_variant(config): variant = esp32.get_esp32_variant() if config[CONF_ADC_TYPE] == "external": if config[CONF_PDM]: @@ -43,16 +55,47 @@ def validate_esp32_variant(config): raise NotImplementedError +def _validate_channel(config): + if config[CONF_CHANNEL] == CONF_MONO: + raise cv.Invalid(f"I2S microphone does not support {CONF_MONO}.") + return config + + +def _set_num_channels_from_config(config): + if config[CONF_CHANNEL] in (CONF_LEFT, CONF_RIGHT): + config[CONF_NUM_CHANNELS] = 1 + else: + config[CONF_NUM_CHANNELS] = 2 + + return config + + +def _set_stream_limits(config): + audio.set_stream_limits( + min_bits_per_sample=config.get(CONF_BITS_PER_SAMPLE), + max_bits_per_sample=config.get(CONF_BITS_PER_SAMPLE), + min_channels=config.get(CONF_NUM_CHANNELS), + max_channels=config.get(CONF_NUM_CHANNELS), + min_sample_rate=config.get(CONF_SAMPLE_RATE), + max_sample_rate=config.get(CONF_SAMPLE_RATE), + )(config) + + return config + + BASE_SCHEMA = microphone.MICROPHONE_SCHEMA.extend( i2s_audio_component_schema( I2SAudioMicrophone, default_sample_rate=16000, default_channel=CONF_RIGHT, default_bits_per_sample="32bit", + ).extend( + { + cv.Optional(CONF_CORRECT_DC_OFFSET, default=False): cv.boolean, + } ) ).extend(cv.COMPONENT_SCHEMA) - CONFIG_SCHEMA = cv.All( cv.typed_schema( { @@ -70,10 +113,23 @@ CONFIG_SCHEMA = cv.All( }, key=CONF_ADC_TYPE, ), - validate_esp32_variant, + _validate_esp32_variant, + _validate_channel, + _set_num_channels_from_config, + _set_stream_limits, + validate_mclk_divisible_by_3, ) +def _final_validate(config): + if not use_legacy(): + if config[CONF_ADC_TYPE] == "internal": + raise cv.Invalid("Internal ADC is only compatible with legacy i2s driver.") + + +FINAL_VALIDATE_SCHEMA = _final_validate + + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) @@ -88,3 +144,5 @@ async def to_code(config): else: cg.add(var.set_din_pin(config[CONF_I2S_DIN_PIN])) cg.add(var.set_pdm(config[CONF_PDM])) + + cg.add(var.set_correct_dc_offset(config[CONF_CORRECT_DC_OFFSET])) diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp index 4dbc9dcdac..2ff1daa197 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.cpp @@ -2,20 +2,46 @@ #ifdef USE_ESP32 +#ifdef USE_I2S_LEGACY #include +#else +#include +#include +#endif #include "esphome/core/hal.h" #include "esphome/core/log.h" +#include "esphome/components/audio/audio.h" + namespace esphome { namespace i2s_audio { -static const size_t BUFFER_SIZE = 512; +static const UBaseType_t MAX_LISTENERS = 16; + +static const uint32_t READ_DURATION_MS = 16; + +static const size_t TASK_STACK_SIZE = 4096; +static const ssize_t TASK_PRIORITY = 23; + +// Use an exponential moving average to correct a DC offset with weight factor 1/1000 +static const int32_t DC_OFFSET_MOVING_AVERAGE_COEFFICIENT_DENOMINATOR = 1000; static const char *const TAG = "i2s_audio.microphone"; +enum MicrophoneEventGroupBits : uint32_t { + COMMAND_STOP = (1 << 0), // stops the microphone task + TASK_STARTING = (1 << 10), + TASK_RUNNING = (1 << 11), + TASK_STOPPING = (1 << 12), + TASK_STOPPED = (1 << 13), + + ALL_BITS = 0x00FFFFFF, // All valid FreeRTOS event group bits +}; + void I2SAudioMicrophone::setup() { ESP_LOGCONFIG(TAG, "Setting up I2S Audio Microphone..."); +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC if (this->adc_) { if (this->parent_->get_port() != I2S_NUM_0) { @@ -24,6 +50,7 @@ void I2SAudioMicrophone::setup() { return; } } else +#endif #endif { if (this->pdm_) { @@ -34,19 +61,75 @@ void I2SAudioMicrophone::setup() { } } } + + this->active_listeners_semaphore_ = xSemaphoreCreateCounting(MAX_LISTENERS, MAX_LISTENERS); + if (this->active_listeners_semaphore_ == nullptr) { + ESP_LOGE(TAG, "Failed to create semaphore"); + this->mark_failed(); + return; + } + + this->event_group_ = xEventGroupCreate(); + if (this->event_group_ == nullptr) { + ESP_LOGE(TAG, "Failed to create event group"); + this->mark_failed(); + return; + } + + this->configure_stream_settings_(); +} + +void I2SAudioMicrophone::configure_stream_settings_() { + uint8_t channel_count = 1; +#ifdef USE_I2S_LEGACY + uint8_t bits_per_sample = this->bits_per_sample_; + + if (this->channel_ == I2S_CHANNEL_FMT_RIGHT_LEFT) { + channel_count = 2; + } +#else + uint8_t bits_per_sample = 16; + if (this->slot_bit_width_ != I2S_SLOT_BIT_WIDTH_AUTO) { + bits_per_sample = this->slot_bit_width_; + } + + if (this->slot_mode_ == I2S_SLOT_MODE_STEREO) { + channel_count = 2; + } +#endif + +#ifdef USE_ESP32_VARIANT_ESP32 + // ESP32 reads audio aligned to a multiple of 2 bytes. For example, if configured for 24 bits per sample, then it will + // produce 32 bits per sample, where the actual data is in the most significant bits. Other ESP32 variants produce 24 + // bits per sample in this situation. + if (bits_per_sample < 16) { + bits_per_sample = 16; + } else if ((bits_per_sample > 16) && (bits_per_sample <= 32)) { + bits_per_sample = 32; + } +#endif + + if (this->pdm_) { + bits_per_sample = 16; // PDM mics are always 16 bits per sample + } + + this->audio_stream_info_ = audio::AudioStreamInfo(bits_per_sample, channel_count, this->sample_rate_); } void I2SAudioMicrophone::start() { if (this->is_failed()) return; - if (this->state_ == microphone::STATE_RUNNING) - return; // Already running - this->state_ = microphone::STATE_STARTING; + + xSemaphoreTake(this->active_listeners_semaphore_, 0); } -void I2SAudioMicrophone::start_() { + +bool I2SAudioMicrophone::start_driver_() { if (!this->parent_->try_lock()) { - return; // Waiting for another i2s to return lock + return false; // Waiting for another i2s to return lock } + esp_err_t err; + +#ifdef USE_I2S_LEGACY i2s_driver_config_t config = { .mode = (i2s_mode_t) (this->i2s_mode_ | I2S_MODE_RX), .sample_rate = this->sample_rate_, @@ -55,16 +138,14 @@ void I2SAudioMicrophone::start_() { .communication_format = I2S_COMM_FORMAT_STAND_I2S, .intr_alloc_flags = ESP_INTR_FLAG_LEVEL1, .dma_buf_count = 4, - .dma_buf_len = 256, + .dma_buf_len = 240, // Must be divisible by 3 to support 24 bits per sample on old driver and newer variants .use_apll = this->use_apll_, .tx_desc_auto_clear = false, .fixed_mclk = 0, - .mclk_multiple = I2S_MCLK_MULTIPLE_256, + .mclk_multiple = this->mclk_multiple_, .bits_per_chan = this->bits_per_channel_, }; - esp_err_t err; - #if SOC_I2S_SUPPORTS_ADC if (this->adc_) { config.mode = (i2s_mode_t) (config.mode | I2S_MODE_ADC_BUILT_IN); @@ -72,20 +153,20 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error installing I2S driver: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } err = i2s_set_adc_mode(ADC_UNIT_1, this->adc_channel_); if (err != ESP_OK) { ESP_LOGW(TAG, "Error setting ADC mode: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } err = i2s_adc_enable(this->parent_->get_port()); if (err != ESP_OK) { ESP_LOGW(TAG, "Error enabling ADC: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } } else @@ -98,7 +179,7 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error installing I2S driver: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } i2s_pin_config_t pin_config = this->parent_->get_pin_config(); @@ -108,26 +189,122 @@ void I2SAudioMicrophone::start_() { if (err != ESP_OK) { ESP_LOGW(TAG, "Error setting I2S pin: %s", esp_err_to_name(err)); this->status_set_error(); - return; + return false; } } - this->state_ = microphone::STATE_RUNNING; - this->high_freq_.start(); +#else + i2s_chan_config_t chan_cfg = { + .id = this->parent_->get_port(), + .role = this->i2s_role_, + .dma_desc_num = 4, + .dma_frame_num = 256, + .auto_clear = false, + }; + /* Allocate a new RX channel and get the handle of this channel */ + err = i2s_new_channel(&chan_cfg, NULL, &this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error creating new I2S channel: %s", esp_err_to_name(err)); + this->status_set_error(); + return false; + } + + i2s_clock_src_t clk_src = I2S_CLK_SRC_DEFAULT; +#ifdef I2S_CLK_SRC_APLL + if (this->use_apll_) { + clk_src = I2S_CLK_SRC_APLL; + } +#endif + i2s_std_gpio_config_t pin_config = this->parent_->get_pin_config(); +#if SOC_I2S_SUPPORTS_PDM_RX + if (this->pdm_) { + i2s_pdm_rx_clk_config_t clk_cfg = { + .sample_rate_hz = this->sample_rate_, + .clk_src = clk_src, + .mclk_multiple = this->mclk_multiple_, + .dn_sample_mode = I2S_PDM_DSR_8S, + }; + + i2s_pdm_rx_slot_config_t slot_cfg = I2S_PDM_RX_SLOT_DEFAULT_CONFIG(I2S_DATA_BIT_WIDTH_16BIT, this->slot_mode_); + switch (this->std_slot_mask_) { + case I2S_STD_SLOT_LEFT: + slot_cfg.slot_mask = I2S_PDM_SLOT_LEFT; + break; + case I2S_STD_SLOT_RIGHT: + slot_cfg.slot_mask = I2S_PDM_SLOT_RIGHT; + break; + case I2S_STD_SLOT_BOTH: + slot_cfg.slot_mask = I2S_PDM_SLOT_BOTH; + break; + } + + /* Init the channel into PDM RX mode */ + i2s_pdm_rx_config_t pdm_rx_cfg = { + .clk_cfg = clk_cfg, + .slot_cfg = slot_cfg, + .gpio_cfg = + { + .clk = pin_config.ws, + .din = this->din_pin_, + .invert_flags = + { + .clk_inv = pin_config.invert_flags.ws_inv, + }, + }, + }; + err = i2s_channel_init_pdm_rx_mode(this->rx_handle_, &pdm_rx_cfg); + } else +#endif + { + i2s_std_clk_config_t clk_cfg = { + .sample_rate_hz = this->sample_rate_, + .clk_src = clk_src, + .mclk_multiple = this->mclk_multiple_, + }; + i2s_std_slot_config_t std_slot_cfg = + I2S_STD_PHILIPS_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) this->slot_bit_width_, this->slot_mode_); + std_slot_cfg.slot_bit_width = this->slot_bit_width_; + std_slot_cfg.slot_mask = this->std_slot_mask_; + + pin_config.din = this->din_pin_; + + i2s_std_config_t std_cfg = { + .clk_cfg = clk_cfg, + .slot_cfg = std_slot_cfg, + .gpio_cfg = pin_config, + }; + /* Initialize the channel */ + err = i2s_channel_init_std_mode(this->rx_handle_, &std_cfg); + } + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error initializing I2S channel: %s", esp_err_to_name(err)); + this->status_set_error(); + return false; + } + + /* Before reading data, start the RX channel first */ + i2s_channel_enable(this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error enabling I2S Microphone: %s", esp_err_to_name(err)); + this->status_set_error(); + return false; + } +#endif + this->status_clear_error(); + this->configure_stream_settings_(); // redetermine the settings in case some settings were changed after compilation + return true; } void I2SAudioMicrophone::stop() { if (this->state_ == microphone::STATE_STOPPED || this->is_failed()) return; - if (this->state_ == microphone::STATE_STARTING) { - this->state_ = microphone::STATE_STOPPED; - return; - } - this->state_ = microphone::STATE_STOPPING; + + xSemaphoreGive(this->active_listeners_semaphore_); } -void I2SAudioMicrophone::stop_() { +void I2SAudioMicrophone::stop_driver_() { esp_err_t err; +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC if (this->adc_) { err = i2s_adc_disable(this->parent_->get_port()); @@ -150,68 +327,181 @@ void I2SAudioMicrophone::stop_() { this->status_set_error(); return; } +#else + /* Have to stop the channel before deleting it */ + err = i2s_channel_disable(this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error stopping I2S microphone: %s", esp_err_to_name(err)); + this->status_set_error(); + return; + } + /* If the handle is not needed any more, delete it to release the channel resources */ + err = i2s_del_channel(this->rx_handle_); + if (err != ESP_OK) { + ESP_LOGW(TAG, "Error deleting I2S channel: %s", esp_err_to_name(err)); + this->status_set_error(); + return; + } +#endif this->parent_->unlock(); - this->state_ = microphone::STATE_STOPPED; - this->high_freq_.stop(); this->status_clear_error(); } -size_t I2SAudioMicrophone::read(int16_t *buf, size_t len) { +void I2SAudioMicrophone::mic_task(void *params) { + I2SAudioMicrophone *this_microphone = (I2SAudioMicrophone *) params; + + xEventGroupSetBits(this_microphone->event_group_, MicrophoneEventGroupBits::TASK_STARTING); + + uint8_t start_counter = 0; + bool started = this_microphone->start_driver_(); + while (!started && start_counter < 10) { + // Attempt to load the driver again in 100 ms. Doesn't slow down main loop since its in a task. + vTaskDelay(pdMS_TO_TICKS(100)); + ++start_counter; + started = this_microphone->start_driver_(); + } + + if (started) { + xEventGroupSetBits(this_microphone->event_group_, MicrophoneEventGroupBits::TASK_RUNNING); + const size_t bytes_to_read = this_microphone->audio_stream_info_.ms_to_bytes(READ_DURATION_MS); + std::vector samples; + samples.reserve(bytes_to_read); + + while (!(xEventGroupGetBits(this_microphone->event_group_) & COMMAND_STOP)) { + if (this_microphone->data_callbacks_.size() > 0) { + samples.resize(bytes_to_read); + size_t bytes_read = this_microphone->read_(samples.data(), bytes_to_read, 2 * pdMS_TO_TICKS(READ_DURATION_MS)); + samples.resize(bytes_read); + if (this_microphone->correct_dc_offset_) { + this_microphone->fix_dc_offset_(samples); + } + this_microphone->data_callbacks_.call(samples); + } else { + vTaskDelay(pdMS_TO_TICKS(READ_DURATION_MS)); + } + } + } + + xEventGroupSetBits(this_microphone->event_group_, MicrophoneEventGroupBits::TASK_STOPPING); + this_microphone->stop_driver_(); + + xEventGroupSetBits(this_microphone->event_group_, MicrophoneEventGroupBits::TASK_STOPPED); + while (true) { + // Continuously delay until the loop method deletes the task + vTaskDelay(pdMS_TO_TICKS(10)); + } +} + +void I2SAudioMicrophone::fix_dc_offset_(std::vector &data) { + const size_t bytes_per_sample = this->audio_stream_info_.samples_to_bytes(1); + const uint32_t total_samples = this->audio_stream_info_.bytes_to_samples(data.size()); + + if (total_samples == 0) { + return; + } + + int64_t offset_accumulator = 0; + for (uint32_t sample_index = 0; sample_index < total_samples; ++sample_index) { + const uint32_t byte_index = sample_index * bytes_per_sample; + int32_t sample = audio::unpack_audio_sample_to_q31(&data[byte_index], bytes_per_sample); + offset_accumulator += sample; + sample -= this->dc_offset_; + audio::pack_q31_as_audio_sample(sample, &data[byte_index], bytes_per_sample); + } + + const int32_t new_offset = offset_accumulator / total_samples; + this->dc_offset_ = new_offset / DC_OFFSET_MOVING_AVERAGE_COEFFICIENT_DENOMINATOR + + (DC_OFFSET_MOVING_AVERAGE_COEFFICIENT_DENOMINATOR - 1) * this->dc_offset_ / + DC_OFFSET_MOVING_AVERAGE_COEFFICIENT_DENOMINATOR; +} + +size_t I2SAudioMicrophone::read_(uint8_t *buf, size_t len, TickType_t ticks_to_wait) { size_t bytes_read = 0; - esp_err_t err = i2s_read(this->parent_->get_port(), buf, len, &bytes_read, (100 / portTICK_PERIOD_MS)); - if (err != ESP_OK) { +#ifdef USE_I2S_LEGACY + esp_err_t err = i2s_read(this->parent_->get_port(), buf, len, &bytes_read, ticks_to_wait); +#else + // i2s_channel_read expects the timeout value in ms, not ticks + esp_err_t err = i2s_channel_read(this->rx_handle_, buf, len, &bytes_read, pdTICKS_TO_MS(ticks_to_wait)); +#endif + if ((err != ESP_OK) && ((err != ESP_ERR_TIMEOUT) || (ticks_to_wait != 0))) { + // Ignore ESP_ERR_TIMEOUT if ticks_to_wait = 0, as it will read the data on the next call ESP_LOGW(TAG, "Error reading from I2S microphone: %s", esp_err_to_name(err)); this->status_set_warning(); return 0; } - if (bytes_read == 0) { + if ((bytes_read == 0) && (ticks_to_wait > 0)) { this->status_set_warning(); return 0; } this->status_clear_warning(); - // ESP-IDF I2S implementation right-extends 8-bit data to 16 bits, - // and 24-bit data to 32 bits. - switch (this->bits_per_sample_) { - case I2S_BITS_PER_SAMPLE_8BIT: - case I2S_BITS_PER_SAMPLE_16BIT: - return bytes_read; - case I2S_BITS_PER_SAMPLE_24BIT: - case I2S_BITS_PER_SAMPLE_32BIT: { - size_t samples_read = bytes_read / sizeof(int32_t); - for (size_t i = 0; i < samples_read; i++) { - int32_t temp = reinterpret_cast(buf)[i] >> 14; - buf[i] = clamp(temp, INT16_MIN, INT16_MAX); - } - return samples_read * sizeof(int16_t); +#if defined(USE_ESP32_VARIANT_ESP32) and not defined(USE_I2S_LEGACY) + // For ESP32 8/16 bit standard mono mode samples need to be switched. + if (this->slot_mode_ == I2S_SLOT_MODE_MONO && this->slot_bit_width_ <= 16 && !this->pdm_) { + size_t samples_read = bytes_read / sizeof(int16_t); + for (int i = 0; i < samples_read; i += 2) { + int16_t tmp = buf[i]; + buf[i] = buf[i + 1]; + buf[i + 1] = tmp; } - default: - ESP_LOGE(TAG, "Unsupported bits per sample: %d", this->bits_per_sample_); - return 0; } -} - -void I2SAudioMicrophone::read_() { - std::vector samples; - samples.resize(BUFFER_SIZE); - size_t bytes_read = this->read(samples.data(), BUFFER_SIZE / sizeof(int16_t)); - samples.resize(bytes_read / sizeof(int16_t)); - this->data_callbacks_.call(samples); +#endif + return bytes_read; } void I2SAudioMicrophone::loop() { + uint32_t event_group_bits = xEventGroupGetBits(this->event_group_); + + if (event_group_bits & MicrophoneEventGroupBits::TASK_STARTING) { + ESP_LOGD(TAG, "Task has started, attempting to setup I2S audio driver"); + xEventGroupClearBits(this->event_group_, MicrophoneEventGroupBits::TASK_STARTING); + } + + if (event_group_bits & MicrophoneEventGroupBits::TASK_RUNNING) { + ESP_LOGD(TAG, "Task is running and reading data"); + + xEventGroupClearBits(this->event_group_, MicrophoneEventGroupBits::TASK_RUNNING); + this->state_ = microphone::STATE_RUNNING; + } + + if (event_group_bits & MicrophoneEventGroupBits::TASK_STOPPING) { + ESP_LOGD(TAG, "Task is stopping, attempting to unload the I2S audio driver"); + xEventGroupClearBits(this->event_group_, MicrophoneEventGroupBits::TASK_STOPPING); + } + + if ((event_group_bits & MicrophoneEventGroupBits::TASK_STOPPED)) { + ESP_LOGD(TAG, "Task is finished, freeing resources"); + vTaskDelete(this->task_handle_); + this->task_handle_ = nullptr; + xEventGroupClearBits(this->event_group_, ALL_BITS); + this->state_ = microphone::STATE_STOPPED; + } + + if ((uxSemaphoreGetCount(this->active_listeners_semaphore_) < MAX_LISTENERS) && + (this->state_ == microphone::STATE_STOPPED)) { + this->state_ = microphone::STATE_STARTING; + } + if ((uxSemaphoreGetCount(this->active_listeners_semaphore_) == MAX_LISTENERS) && + (this->state_ == microphone::STATE_RUNNING)) { + this->state_ = microphone::STATE_STOPPING; + } + switch (this->state_) { - case microphone::STATE_STOPPED: - break; case microphone::STATE_STARTING: - this->start_(); - break; - case microphone::STATE_RUNNING: - if (this->data_callbacks_.size() > 0) { - this->read_(); + if ((this->task_handle_ == nullptr) && !this->status_has_error()) { + xTaskCreate(I2SAudioMicrophone::mic_task, "mic_task", TASK_STACK_SIZE, (void *) this, TASK_PRIORITY, + &this->task_handle_); + + if (this->task_handle_ == nullptr) { + this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000); + } } break; + case microphone::STATE_RUNNING: + break; case microphone::STATE_STOPPING: - this->stop_(); + xEventGroupSetBits(this->event_group_, MicrophoneEventGroupBits::COMMAND_STOP); + break; + case microphone::STATE_STOPPED: break; } } diff --git a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h index ea3f357624..39249e879b 100644 --- a/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h +++ b/esphome/components/i2s_audio/microphone/i2s_audio_microphone.h @@ -7,6 +7,11 @@ #include "esphome/components/microphone/microphone.h" #include "esphome/core/component.h" +#include +#include +#include +#include + namespace esphome { namespace i2s_audio { @@ -18,31 +23,60 @@ class I2SAudioMicrophone : public I2SAudioIn, public microphone::Microphone, pub void loop() override; + void set_correct_dc_offset(bool correct_dc_offset) { this->correct_dc_offset_ = correct_dc_offset; } + +#ifdef USE_I2S_LEGACY void set_din_pin(int8_t pin) { this->din_pin_ = pin; } +#else + void set_din_pin(int8_t pin) { this->din_pin_ = (gpio_num_t) pin; } +#endif + void set_pdm(bool pdm) { this->pdm_ = pdm; } - size_t read(int16_t *buf, size_t len) override; - +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_ADC void set_adc_channel(adc1_channel_t channel) { this->adc_channel_ = channel; this->adc_ = true; } +#endif #endif protected: - void start_(); - void stop_(); - void read_(); + bool start_driver_(); + void stop_driver_(); + /// @brief Attempts to correct a microphone DC offset; e.g., a microphones silent level is offset from 0. Applies a + /// correction offset that is updated using an exponential moving average for all samples away from 0. + /// @param data + void fix_dc_offset_(std::vector &data); + + size_t read_(uint8_t *buf, size_t len, TickType_t ticks_to_wait); + + /// @brief Sets the Microphone ``audio_stream_info_`` member variable to the configured I2S settings. + void configure_stream_settings_(); + + static void mic_task(void *params); + + SemaphoreHandle_t active_listeners_semaphore_{nullptr}; + EventGroupHandle_t event_group_{nullptr}; + + TaskHandle_t task_handle_{nullptr}; + +#ifdef USE_I2S_LEGACY int8_t din_pin_{I2S_PIN_NO_CHANGE}; #if SOC_I2S_SUPPORTS_ADC adc1_channel_t adc_channel_{ADC1_CHANNEL_MAX}; bool adc_{false}; +#endif +#else + gpio_num_t din_pin_{I2S_GPIO_UNUSED}; + i2s_chan_handle_t rx_handle_; #endif bool pdm_{false}; - HighFrequencyLoopRequester high_freq_; + bool correct_dc_offset_; + int32_t dc_offset_{0}; }; } // namespace i2s_audio diff --git a/esphome/components/i2s_audio/speaker/__init__.py b/esphome/components/i2s_audio/speaker/__init__.py index aa3b50d336..bb9f24bf0b 100644 --- a/esphome/components/i2s_audio/speaker/__init__.py +++ b/esphome/components/i2s_audio/speaker/__init__.py @@ -26,6 +26,8 @@ from .. import ( i2s_audio_component_schema, i2s_audio_ns, register_i2s_audio_component, + use_legacy, + validate_mclk_divisible_by_3, ) AUTO_LOAD = ["audio"] @@ -60,7 +62,7 @@ I2C_COMM_FMT_OPTIONS = { "pcm_long": i2s_comm_format_t.I2S_COMM_FORMAT_PCM_LONG, } -NO_INTERNAL_DAC_VARIANTS = [esp32.const.VARIANT_ESP32S2] +INTERNAL_DAC_VARIANTS = [esp32.const.VARIANT_ESP32] def _set_num_channels_from_config(config): @@ -101,7 +103,7 @@ def _validate_esp32_variant(config): if config[CONF_DAC_TYPE] != "internal": return config variant = esp32.get_esp32_variant() - if variant in NO_INTERNAL_DAC_VARIANTS: + if variant not in INTERNAL_DAC_VARIANTS: raise cv.Invalid(f"{variant} does not have an internal DAC") return config @@ -143,8 +145,8 @@ CONFIG_SCHEMA = cv.All( cv.Required( CONF_I2S_DOUT_PIN ): pins.internal_gpio_output_pin_number, - cv.Optional(CONF_I2S_COMM_FMT, default="stand_i2s"): cv.enum( - I2C_COMM_FMT_OPTIONS, lower=True + cv.Optional(CONF_I2S_COMM_FMT, default="stand_i2s"): cv.one_of( + *I2C_COMM_FMT_OPTIONS, lower=True ), } ), @@ -154,9 +156,23 @@ CONFIG_SCHEMA = cv.All( _validate_esp32_variant, _set_num_channels_from_config, _set_stream_limits, + validate_mclk_divisible_by_3, ) +def _final_validate(config): + if not use_legacy(): + if config[CONF_DAC_TYPE] == "internal": + raise cv.Invalid("Internal DAC is only compatible with legacy i2s driver.") + if config[CONF_I2S_COMM_FMT] == "stand_max": + raise cv.Invalid( + "I2S standard max format only implemented with legacy i2s driver." + ) + + +FINAL_VALIDATE_SCHEMA = _final_validate + + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) @@ -167,7 +183,17 @@ async def to_code(config): cg.add(var.set_internal_dac_mode(config[CONF_CHANNEL])) else: cg.add(var.set_dout_pin(config[CONF_I2S_DOUT_PIN])) - cg.add(var.set_i2s_comm_fmt(config[CONF_I2S_COMM_FMT])) + if use_legacy(): + cg.add( + var.set_i2s_comm_fmt(I2C_COMM_FMT_OPTIONS[config[CONF_I2S_COMM_FMT]]) + ) + else: + fmt = "std" # equals stand_i2s, stand_pcm_long, i2s_msb, pcm_long + if config[CONF_I2S_COMM_FMT] in ["stand_msb", "i2s_lsb"]: + fmt = "msb" + elif config[CONF_I2S_COMM_FMT] in ["stand_pcm_short", "pcm_short", "pcm"]: + fmt = "pcm" + cg.add(var.set_i2s_comm_fmt(fmt)) if config[CONF_TIMEOUT] != CONF_NEVER: cg.add(var.set_timeout(config[CONF_TIMEOUT])) cg.add(var.set_buffer_duration(config[CONF_BUFFER_DURATION])) diff --git a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp index da25914c87..d85409f1a8 100644 --- a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp +++ b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.cpp @@ -2,7 +2,11 @@ #ifdef USE_ESP32 +#ifdef USE_I2S_LEGACY #include +#else +#include +#endif #include "esphome/components/audio/audio.h" @@ -10,6 +14,8 @@ #include "esphome/core/hal.h" #include "esphome/core/log.h" +#include "esp_timer.h" + namespace esphome { namespace i2s_audio { @@ -294,13 +300,21 @@ void I2SAudioSpeaker::speaker_task(void *params) { // Audio stream info changed, stop the speaker task so it will restart with the proper settings. break; } - +#ifdef USE_I2S_LEGACY i2s_event_t i2s_event; while (xQueueReceive(this_speaker->i2s_event_queue_, &i2s_event, 0)) { if (i2s_event.type == I2S_EVENT_TX_Q_OVF) { tx_dma_underflow = true; } } +#else + bool overflow; + while (xQueueReceive(this_speaker->i2s_event_queue_, &overflow, 0)) { + if (overflow) { + tx_dma_underflow = true; + } + } +#endif if (this_speaker->pause_state_) { // Pause state is accessed atomically, so thread safe @@ -319,6 +333,18 @@ void I2SAudioSpeaker::speaker_task(void *params) { bytes_read / sizeof(int16_t), this_speaker->q15_volume_factor_); } +#ifdef USE_ESP32_VARIANT_ESP32 + // For ESP32 8/16 bit mono mode samples need to be switched. + if (audio_stream_info.get_channels() == 1 && audio_stream_info.get_bits_per_sample() <= 16) { + size_t len = bytes_read / sizeof(int16_t); + int16_t *tmp_buf = (int16_t *) this_speaker->data_buffer_; + for (int i = 0; i < len; i += 2) { + int16_t tmp = tmp_buf[i]; + tmp_buf[i] = tmp_buf[i + 1]; + tmp_buf[i + 1] = tmp; + } + } +#endif // Write the audio data to a single DMA buffer at a time to reduce latency for the audio duration played // callback. const uint32_t batches = (bytes_read + single_dma_buffer_input_size - 1) / single_dma_buffer_input_size; @@ -327,6 +353,7 @@ void I2SAudioSpeaker::speaker_task(void *params) { size_t bytes_written = 0; size_t bytes_to_write = std::min(single_dma_buffer_input_size, bytes_read); +#ifdef USE_I2S_LEGACY if (audio_stream_info.get_bits_per_sample() == (uint8_t) this_speaker->bits_per_sample_) { i2s_write(this_speaker->parent_->get_port(), this_speaker->data_buffer_ + i * single_dma_buffer_input_size, bytes_to_write, &bytes_written, pdMS_TO_TICKS(DMA_BUFFER_DURATION_MS * 5)); @@ -336,26 +363,20 @@ void I2SAudioSpeaker::speaker_task(void *params) { audio_stream_info.get_bits_per_sample(), this_speaker->bits_per_sample_, &bytes_written, pdMS_TO_TICKS(DMA_BUFFER_DURATION_MS * 5)); } +#else + i2s_channel_write(this_speaker->tx_handle_, this_speaker->data_buffer_ + i * single_dma_buffer_input_size, + bytes_to_write, &bytes_written, pdMS_TO_TICKS(DMA_BUFFER_DURATION_MS * 5)); +#endif - uint32_t write_timestamp = micros(); + int64_t now = esp_timer_get_time(); if (bytes_written != bytes_to_write) { xEventGroupSetBits(this_speaker->event_group_, SpeakerEventGroupBits::ERR_ESP_INVALID_SIZE); } - bytes_read -= bytes_written; - this_speaker->accumulated_frames_written_ += audio_stream_info.bytes_to_frames(bytes_written); - const uint32_t new_playback_ms = - audio_stream_info.frames_to_milliseconds_with_remainder(&this_speaker->accumulated_frames_written_); - const uint32_t remainder_us = - audio_stream_info.frames_to_microseconds(this_speaker->accumulated_frames_written_); - - uint32_t pending_frames = - audio_stream_info.bytes_to_frames(bytes_read + this_speaker->audio_ring_buffer_->available()); - const uint32_t pending_ms = audio_stream_info.frames_to_milliseconds_with_remainder(&pending_frames); - - this_speaker->audio_output_callback_(new_playback_ms, remainder_us, pending_ms, write_timestamp); + this_speaker->audio_output_callback_(audio_stream_info.bytes_to_frames(bytes_written), + now + dma_buffers_duration_ms * 1000); tx_dma_underflow = false; last_data_received_time = millis(); @@ -369,8 +390,12 @@ void I2SAudioSpeaker::speaker_task(void *params) { } xEventGroupSetBits(this_speaker->event_group_, SpeakerEventGroupBits::STATE_STOPPING); - +#ifdef USE_I2S_LEGACY i2s_driver_uninstall(this_speaker->parent_->get_port()); +#else + i2s_channel_disable(this_speaker->tx_handle_); + i2s_del_channel(this_speaker->tx_handle_); +#endif this_speaker->parent_->unlock(); } @@ -462,12 +487,21 @@ esp_err_t I2SAudioSpeaker::allocate_buffers_(size_t data_buffer_size, size_t rin } esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_stream_info) { +#ifdef USE_I2S_LEGACY if ((this->i2s_mode_ & I2S_MODE_SLAVE) && (this->sample_rate_ != audio_stream_info.get_sample_rate())) { // NOLINT +#else + if ((this->i2s_role_ & I2S_ROLE_SLAVE) && (this->sample_rate_ != audio_stream_info.get_sample_rate())) { // NOLINT +#endif // Can't reconfigure I2S bus, so the sample rate must match the configured value return ESP_ERR_NOT_SUPPORTED; } +#ifdef USE_I2S_LEGACY if ((i2s_bits_per_sample_t) audio_stream_info.get_bits_per_sample() > this->bits_per_sample_) { +#else + if (this->slot_bit_width_ != I2S_SLOT_BIT_WIDTH_AUTO && + (i2s_slot_bit_width_t) audio_stream_info.get_bits_per_sample() > this->slot_bit_width_) { +#endif // Currently can't handle the case when the incoming audio has more bits per sample than the configured value return ESP_ERR_NOT_SUPPORTED; } @@ -476,6 +510,9 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea return ESP_ERR_INVALID_STATE; } + uint32_t dma_buffer_length = audio_stream_info.ms_to_frames(DMA_BUFFER_DURATION_MS); + +#ifdef USE_I2S_LEGACY i2s_channel_fmt_t channel = this->channel_; if (audio_stream_info.get_channels() == 1) { @@ -488,8 +525,6 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea channel = I2S_CHANNEL_FMT_RIGHT_LEFT; } - int dma_buffer_length = audio_stream_info.ms_to_frames(DMA_BUFFER_DURATION_MS); - i2s_driver_config_t config = { .mode = (i2s_mode_t) (this->i2s_mode_ | I2S_MODE_TX), .sample_rate = audio_stream_info.get_sample_rate(), @@ -498,11 +533,11 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea .communication_format = this->i2s_comm_fmt_, .intr_alloc_flags = ESP_INTR_FLAG_LEVEL1, .dma_buf_count = DMA_BUFFERS_COUNT, - .dma_buf_len = dma_buffer_length, + .dma_buf_len = (int) dma_buffer_length, .use_apll = this->use_apll_, .tx_desc_auto_clear = true, .fixed_mclk = I2S_PIN_NO_CHANGE, - .mclk_multiple = I2S_MCLK_MULTIPLE_256, + .mclk_multiple = this->mclk_multiple_, .bits_per_chan = this->bits_per_channel_, #if SOC_I2S_SUPPORTS_TDM .chan_mask = (i2s_channel_t) (I2S_TDM_ACTIVE_CH0 | I2S_TDM_ACTIVE_CH1), @@ -545,6 +580,98 @@ esp_err_t I2SAudioSpeaker::start_i2s_driver_(audio::AudioStreamInfo &audio_strea i2s_driver_uninstall(this->parent_->get_port()); this->parent_->unlock(); } +#else + i2s_chan_config_t chan_cfg = { + .id = this->parent_->get_port(), + .role = this->i2s_role_, + .dma_desc_num = DMA_BUFFERS_COUNT, + .dma_frame_num = dma_buffer_length, + .auto_clear = true, + }; + /* Allocate a new TX channel and get the handle of this channel */ + esp_err_t err = i2s_new_channel(&chan_cfg, &this->tx_handle_, NULL); + if (err != ESP_OK) { + this->parent_->unlock(); + return err; + } + + i2s_clock_src_t clk_src = I2S_CLK_SRC_DEFAULT; +#ifdef I2S_CLK_SRC_APLL + if (this->use_apll_) { + clk_src = I2S_CLK_SRC_APLL; + } +#endif + i2s_std_gpio_config_t pin_config = this->parent_->get_pin_config(); + + i2s_std_clk_config_t clk_cfg = { + .sample_rate_hz = audio_stream_info.get_sample_rate(), + .clk_src = clk_src, + .mclk_multiple = this->mclk_multiple_, + }; + + i2s_slot_mode_t slot_mode = this->slot_mode_; + i2s_std_slot_mask_t slot_mask = this->std_slot_mask_; + if (audio_stream_info.get_channels() == 1) { + slot_mode = I2S_SLOT_MODE_MONO; + } else if (audio_stream_info.get_channels() == 2) { + slot_mode = I2S_SLOT_MODE_STEREO; + slot_mask = I2S_STD_SLOT_BOTH; + } + + i2s_std_slot_config_t std_slot_cfg; + if (this->i2s_comm_fmt_ == "std") { + std_slot_cfg = + I2S_STD_PHILIPS_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) audio_stream_info.get_bits_per_sample(), slot_mode); + } else if (this->i2s_comm_fmt_ == "pcm") { + std_slot_cfg = + I2S_STD_PCM_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) audio_stream_info.get_bits_per_sample(), slot_mode); + } else { + std_slot_cfg = + I2S_STD_MSB_SLOT_DEFAULT_CONFIG((i2s_data_bit_width_t) audio_stream_info.get_bits_per_sample(), slot_mode); + } +#ifdef USE_ESP32_VARIANT_ESP32 + // There seems to be a bug on the ESP32 (non-variant) platform where setting the slot bit width higher then the bits + // per sample causes the audio to play too fast. Setting the ws_width to the configured slot bit width seems to + // make it play at the correct speed while sending more bits per slot. + if (this->slot_bit_width_ != I2S_SLOT_BIT_WIDTH_AUTO) { + std_slot_cfg.ws_width = static_cast(this->slot_bit_width_); + } +#else + std_slot_cfg.slot_bit_width = this->slot_bit_width_; +#endif + std_slot_cfg.slot_mask = slot_mask; + + pin_config.dout = this->dout_pin_; + + i2s_std_config_t std_cfg = { + .clk_cfg = clk_cfg, + .slot_cfg = std_slot_cfg, + .gpio_cfg = pin_config, + }; + /* Initialize the channel */ + err = i2s_channel_init_std_mode(this->tx_handle_, &std_cfg); + + if (err != ESP_OK) { + i2s_del_channel(this->tx_handle_); + this->parent_->unlock(); + return err; + } + if (this->i2s_event_queue_ == nullptr) { + this->i2s_event_queue_ = xQueueCreate(1, sizeof(bool)); + } + const i2s_event_callbacks_t callbacks = { + .on_send_q_ovf = i2s_overflow_cb, + }; + + i2s_channel_register_event_callback(this->tx_handle_, &callbacks, this); + + /* Before reading data, start the TX channel first */ + i2s_channel_enable(this->tx_handle_); + if (err != ESP_OK) { + i2s_del_channel(this->tx_handle_); + this->parent_->unlock(); + } +#endif return err; } @@ -564,6 +691,15 @@ void I2SAudioSpeaker::delete_task_(size_t buffer_size) { vTaskDelete(nullptr); } +#ifndef USE_I2S_LEGACY +bool IRAM_ATTR I2SAudioSpeaker::i2s_overflow_cb(i2s_chan_handle_t handle, i2s_event_data_t *event, void *user_ctx) { + I2SAudioSpeaker *this_speaker = (I2SAudioSpeaker *) user_ctx; + bool overflow = true; + xQueueOverwrite(this_speaker->i2s_event_queue_, &overflow); + return false; +} +#endif + } // namespace i2s_audio } // namespace esphome diff --git a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.h b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.h index 7b14a57aac..b5e4b94bc4 100644 --- a/esphome/components/i2s_audio/speaker/i2s_audio_speaker.h +++ b/esphome/components/i2s_audio/speaker/i2s_audio_speaker.h @@ -4,8 +4,6 @@ #include "../i2s_audio.h" -#include - #include #include #include @@ -30,11 +28,16 @@ class I2SAudioSpeaker : public I2SAudioOut, public speaker::Speaker, public Comp void set_buffer_duration(uint32_t buffer_duration_ms) { this->buffer_duration_ms_ = buffer_duration_ms; } void set_timeout(uint32_t ms) { this->timeout_ = ms; } - void set_dout_pin(uint8_t pin) { this->dout_pin_ = pin; } +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_DAC void set_internal_dac_mode(i2s_dac_mode_t mode) { this->internal_dac_mode_ = mode; } #endif + void set_dout_pin(uint8_t pin) { this->dout_pin_ = pin; } void set_i2s_comm_fmt(i2s_comm_format_t mode) { this->i2s_comm_fmt_ = mode; } +#else + void set_dout_pin(uint8_t pin) { this->dout_pin_ = (gpio_num_t) pin; } + void set_i2s_comm_fmt(std::string mode) { this->i2s_comm_fmt_ = std::move(mode); } +#endif void start() override; void stop() override; @@ -86,6 +89,10 @@ class I2SAudioSpeaker : public I2SAudioOut, public speaker::Speaker, public Comp /// @return True if an ERR_ESP bit is set and false if err == ESP_OK bool send_esp_err_to_event_group_(esp_err_t err); +#ifndef USE_I2S_LEGACY + static bool i2s_overflow_cb(i2s_chan_handle_t handle, i2s_event_data_t *event, void *user_ctx); +#endif + /// @brief Allocates the data buffer and ring buffer /// @param data_buffer_size Number of bytes to allocate for the data buffer. /// @param ring_buffer_size Number of bytes to allocate for the ring buffer. @@ -121,7 +128,6 @@ class I2SAudioSpeaker : public I2SAudioOut, public speaker::Speaker, public Comp uint32_t buffer_duration_ms_; optional timeout_; - uint8_t dout_pin_; bool task_created_{false}; bool pause_state_{false}; @@ -130,10 +136,17 @@ class I2SAudioSpeaker : public I2SAudioOut, public speaker::Speaker, public Comp size_t bytes_written_{0}; +#ifdef USE_I2S_LEGACY #if SOC_I2S_SUPPORTS_DAC i2s_dac_mode_t internal_dac_mode_{I2S_DAC_CHANNEL_DISABLE}; #endif + uint8_t dout_pin_; i2s_comm_format_t i2s_comm_fmt_; +#else + gpio_num_t dout_pin_; + std::string i2s_comm_fmt_; + i2s_chan_handle_t tx_handle_; +#endif uint32_t accumulated_frames_written_{0}; }; diff --git a/esphome/components/ili9xxx/ili9xxx_init.h b/esphome/components/ili9xxx/ili9xxx_init.h index f05b884be6..7b176ed57a 100644 --- a/esphome/components/ili9xxx/ili9xxx_init.h +++ b/esphome/components/ili9xxx/ili9xxx_init.h @@ -388,7 +388,7 @@ static const uint8_t PROGMEM INITCMD_GC9D01N[] = { 0x8D, 1, 0xFF, 0x8E, 1, 0xFF, 0x8F, 1, 0xFF, - 0X3A, 1, 0x05, // COLMOD: Pixel Format Set (3Ah) MCU interface, 16 bits / pixel + 0x3A, 1, 0x05, // COLMOD: Pixel Format Set (3Ah) MCU interface, 16 bits / pixel 0xEC, 1, 0x01, // Inversion (ECh) DINV=1+2H1V column for Dual Gate (BFh=0) // According to datasheet Inversion (ECh) value 0x01 isn't valid, but Lilygo uses it everywhere 0x74, 7, 0x02, 0x0E, 0x00, 0x00, 0x00, 0x00, 0x00, diff --git a/esphome/components/image/__init__.py b/esphome/components/image/__init__.py index 20b041a321..5d593ac3d4 100644 --- a/esphome/components/image/__init__.py +++ b/esphome/components/image/__init__.py @@ -286,11 +286,22 @@ CONF_TRANSPARENCY = "transparency" IMAGE_DOWNLOAD_TIMEOUT = 30 # seconds SOURCE_LOCAL = "local" -SOURCE_MDI = "mdi" SOURCE_WEB = "web" +SOURCE_MDI = "mdi" +SOURCE_MDIL = "mdil" +SOURCE_MEMORY = "memory" + +MDI_SOURCES = { + SOURCE_MDI: "https://raw.githubusercontent.com/Templarian/MaterialDesign/master/svg/", + SOURCE_MDIL: "https://raw.githubusercontent.com/Pictogrammers/MaterialDesignLight/refs/heads/master/svg/", + SOURCE_MEMORY: "https://raw.githubusercontent.com/Pictogrammers/Memory/refs/heads/main/src/svg/", +} + Image_ = image_ns.class_("Image") +INSTANCE_TYPE = Image_ + def compute_local_image_path(value) -> Path: url = value[CONF_URL] if isinstance(value, dict) else value @@ -311,12 +322,12 @@ def download_file(url, path): return str(path) -def download_mdi(value): +def download_gh_svg(value, source): mdi_id = value[CONF_ICON] if isinstance(value, dict) else value - base_dir = external_files.compute_local_file_dir(DOMAIN) / "mdi" + base_dir = external_files.compute_local_file_dir(DOMAIN) / source path = base_dir / f"{mdi_id}.svg" - url = f"https://raw.githubusercontent.com/Templarian/MaterialDesign/master/svg/{mdi_id}.svg" + url = MDI_SOURCES[source] + mdi_id + ".svg" return download_file(url, path) @@ -351,12 +362,12 @@ def validate_cairosvg_installed(): def validate_file_shorthand(value): value = cv.string_strict(value) - if value.startswith("mdi:"): - match = re.search(r"mdi:([a-zA-Z0-9\-]+)", value) + parts = value.strip().split(":") + if len(parts) == 2 and parts[0] in MDI_SOURCES: + match = re.match(r"[a-zA-Z0-9\-]+", parts[1]) if match is None: - raise cv.Invalid("Could not parse mdi icon name.") - icon = match.group(1) - return download_mdi(icon) + raise cv.Invalid(f"Could not parse mdi icon name from '{value}'.") + return download_gh_svg(parts[1], parts[0]) if value.startswith("http://") or value.startswith("https://"): return download_image(value) @@ -372,12 +383,20 @@ LOCAL_SCHEMA = cv.All( local_path, ) -MDI_SCHEMA = cv.All( - { - cv.Required(CONF_ICON): cv.string, - }, - download_mdi, -) + +def mdi_schema(source): + def validate_mdi(value): + return download_gh_svg(value, source) + + return cv.All( + cv.Schema( + { + cv.Required(CONF_ICON): cv.string, + } + ), + validate_mdi, + ) + WEB_SCHEMA = cv.All( { @@ -386,12 +405,13 @@ WEB_SCHEMA = cv.All( download_image, ) + TYPED_FILE_SCHEMA = cv.typed_schema( { SOURCE_LOCAL: LOCAL_SCHEMA, - SOURCE_MDI: MDI_SCHEMA, SOURCE_WEB: WEB_SCHEMA, - }, + } + | {source: mdi_schema(source) for source in MDI_SOURCES}, key=CONF_SOURCE, ) diff --git a/esphome/components/image/image.cpp b/esphome/components/image/image.cpp index f05f4af711..82e46e3460 100644 --- a/esphome/components/image/image.cpp +++ b/esphome/components/image/image.cpp @@ -6,10 +6,27 @@ namespace esphome { namespace image { void Image::draw(int x, int y, display::Display *display, Color color_on, Color color_off) { + int img_x0 = 0; + int img_y0 = 0; + int w = width_; + int h = height_; + + auto clipping = display->get_clipping(); + if (clipping.is_set()) { + if (clipping.x > x) + img_x0 += clipping.x - x; + if (clipping.y > y) + img_y0 += clipping.y - y; + if (w > clipping.x2() - x) + w = clipping.x2() - x; + if (h > clipping.y2() - y) + h = clipping.y2() - y; + } + switch (type_) { case IMAGE_TYPE_BINARY: { - for (int img_x = 0; img_x < width_; img_x++) { - for (int img_y = 0; img_y < height_; img_y++) { + for (int img_x = img_x0; img_x < w; img_x++) { + for (int img_y = img_y0; img_y < h; img_y++) { if (this->get_binary_pixel_(img_x, img_y)) { display->draw_pixel_at(x + img_x, y + img_y, color_on); } else if (!this->transparency_) { @@ -20,8 +37,8 @@ void Image::draw(int x, int y, display::Display *display, Color color_on, Color break; } case IMAGE_TYPE_GRAYSCALE: - for (int img_x = 0; img_x < width_; img_x++) { - for (int img_y = 0; img_y < height_; img_y++) { + for (int img_x = img_x0; img_x < w; img_x++) { + for (int img_y = img_y0; img_y < h; img_y++) { const uint32_t pos = (img_x + img_y * this->width_); const uint8_t gray = progmem_read_byte(this->data_start_ + pos); Color color = Color(gray, gray, gray, 0xFF); @@ -47,8 +64,8 @@ void Image::draw(int x, int y, display::Display *display, Color color_on, Color } break; case IMAGE_TYPE_RGB565: - for (int img_x = 0; img_x < width_; img_x++) { - for (int img_y = 0; img_y < height_; img_y++) { + for (int img_x = img_x0; img_x < w; img_x++) { + for (int img_y = img_y0; img_y < h; img_y++) { auto color = this->get_rgb565_pixel_(img_x, img_y); if (color.w >= 0x80) { display->draw_pixel_at(x + img_x, y + img_y, color); @@ -57,8 +74,8 @@ void Image::draw(int x, int y, display::Display *display, Color color_on, Color } break; case IMAGE_TYPE_RGB: - for (int img_x = 0; img_x < width_; img_x++) { - for (int img_y = 0; img_y < height_; img_y++) { + for (int img_x = img_x0; img_x < w; img_x++) { + for (int img_y = img_y0; img_y < h; img_y++) { auto color = this->get_rgb_pixel_(img_x, img_y); if (color.w >= 0x80) { display->draw_pixel_at(x + img_x, y + img_y, color); diff --git a/esphome/components/internal_temperature/internal_temperature.cpp b/esphome/components/internal_temperature/internal_temperature.cpp index afa5583e59..d3ff7433b6 100644 --- a/esphome/components/internal_temperature/internal_temperature.cpp +++ b/esphome/components/internal_temperature/internal_temperature.cpp @@ -9,7 +9,7 @@ uint8_t temprature_sens_read(); } #elif defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) || \ defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32H2) || \ - defined(USE_ESP32_VARIANT_ESP32C2) + defined(USE_ESP32_VARIANT_ESP32C2) || defined(USE_ESP32_VARIANT_ESP32P4) #if ESP_IDF_VERSION < ESP_IDF_VERSION_VAL(5, 0, 0) #include "driver/temp_sensor.h" #else @@ -33,7 +33,8 @@ static const char *const TAG = "internal_temperature"; #ifdef USE_ESP32 #if (ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 0, 0)) && \ (defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) || defined(USE_ESP32_VARIANT_ESP32S2) || \ - defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32H2) || defined(USE_ESP32_VARIANT_ESP32C2)) + defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32H2) || defined(USE_ESP32_VARIANT_ESP32C2) || \ + defined(USE_ESP32_VARIANT_ESP32P4)) static temperature_sensor_handle_t tsensNew = NULL; #endif // ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 0, 0) && USE_ESP32_VARIANT #endif // USE_ESP32 @@ -49,7 +50,7 @@ void InternalTemperatureSensor::update() { success = (raw != 128); #elif defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) || \ defined(USE_ESP32_VARIANT_ESP32S2) || defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32H2) || \ - defined(USE_ESP32_VARIANT_ESP32C2) + defined(USE_ESP32_VARIANT_ESP32C2) || defined(USE_ESP32_VARIANT_ESP32P4) #if ESP_IDF_VERSION < ESP_IDF_VERSION_VAL(5, 0, 0) temp_sensor_config_t tsens = TSENS_CONFIG_DEFAULT(); temp_sensor_set_config(tsens); @@ -100,7 +101,8 @@ void InternalTemperatureSensor::setup() { #ifdef USE_ESP32 #if (ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(5, 0, 0)) && \ (defined(USE_ESP32_VARIANT_ESP32C3) || defined(USE_ESP32_VARIANT_ESP32C6) || defined(USE_ESP32_VARIANT_ESP32S2) || \ - defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32H2) || defined(USE_ESP32_VARIANT_ESP32C2)) + defined(USE_ESP32_VARIANT_ESP32S3) || defined(USE_ESP32_VARIANT_ESP32H2) || defined(USE_ESP32_VARIANT_ESP32C2) || \ + defined(USE_ESP32_VARIANT_ESP32P4)) ESP_LOGCONFIG(TAG, "Setting up temperature sensor..."); temperature_sensor_config_t tsens_config = TEMPERATURE_SENSOR_CONFIG_DEFAULT(-10, 80); diff --git a/esphome/components/key_collector/__init__.py b/esphome/components/key_collector/__init__.py index 5750812f5c..17af40da1a 100644 --- a/esphome/components/key_collector/__init__.py +++ b/esphome/components/key_collector/__init__.py @@ -3,6 +3,7 @@ import esphome.codegen as cg from esphome.components import key_provider import esphome.config_validation as cv from esphome.const import ( + CONF_ENABLE_ON_BOOT, CONF_ID, CONF_MAX_LENGTH, CONF_MIN_LENGTH, @@ -28,6 +29,8 @@ CONF_ON_RESULT = "on_result" key_collector_ns = cg.esphome_ns.namespace("key_collector") KeyCollector = key_collector_ns.class_("KeyCollector", cg.Component) +EnableAction = key_collector_ns.class_("EnableAction", automation.Action) +DisableAction = key_collector_ns.class_("DisableAction", automation.Action) CONFIG_SCHEMA = cv.All( cv.COMPONENT_SCHEMA.extend( @@ -46,6 +49,7 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_ON_RESULT): automation.validate_automation(single=True), cv.Optional(CONF_ON_TIMEOUT): automation.validate_automation(single=True), cv.Optional(CONF_TIMEOUT): cv.positive_time_period_milliseconds, + cv.Optional(CONF_ENABLE_ON_BOOT, default=True): cv.boolean, } ), cv.has_at_least_one_key(CONF_END_KEYS, CONF_MAX_LENGTH), @@ -94,3 +98,34 @@ async def to_code(config): ) if CONF_TIMEOUT in config: cg.add(var.set_timeout(config[CONF_TIMEOUT])) + cg.add(var.set_enabled(config[CONF_ENABLE_ON_BOOT])) + + +@automation.register_action( + "key_collector.enable", + EnableAction, + automation.maybe_simple_id( + { + cv.GenerateID(): cv.use_id(KeyCollector), + } + ), +) +async def enable_to_code(config, action_id, template_arg, args): + var = cg.new_Pvariable(action_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) + return var + + +@automation.register_action( + "key_collector.disable", + DisableAction, + automation.maybe_simple_id( + { + cv.GenerateID(): cv.use_id(KeyCollector), + } + ), +) +async def disable_to_code(config, action_id, template_arg, args): + var = cg.new_Pvariable(action_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) + return var diff --git a/esphome/components/key_collector/key_collector.cpp b/esphome/components/key_collector/key_collector.cpp index bf2333d97d..ffb4b47fa2 100644 --- a/esphome/components/key_collector/key_collector.cpp +++ b/esphome/components/key_collector/key_collector.cpp @@ -45,6 +45,12 @@ void KeyCollector::set_provider(key_provider::KeyProvider *provider) { provider->add_on_key_callback([this](uint8_t key) { this->key_pressed_(key); }); } +void KeyCollector::set_enabled(bool enabled) { + this->enabled_ = enabled; + if (!enabled) + this->clear(false); +} + void KeyCollector::clear(bool progress_update) { this->result_.clear(); this->start_key_ = 0; @@ -55,6 +61,8 @@ void KeyCollector::clear(bool progress_update) { void KeyCollector::send_key(uint8_t key) { this->key_pressed_(key); } void KeyCollector::key_pressed_(uint8_t key) { + if (!this->enabled_) + return; this->last_key_time_ = millis(); if (!this->start_keys_.empty() && !this->start_key_) { if (this->start_keys_.find(key) != std::string::npos) { diff --git a/esphome/components/key_collector/key_collector.h b/esphome/components/key_collector/key_collector.h index 7ef53929ef..6e585ddd8e 100644 --- a/esphome/components/key_collector/key_collector.h +++ b/esphome/components/key_collector/key_collector.h @@ -25,6 +25,7 @@ class KeyCollector : public Component { Trigger *get_result_trigger() const { return this->result_trigger_; }; Trigger *get_timeout_trigger() const { return this->timeout_trigger_; }; void set_timeout(int timeout) { this->timeout_ = timeout; }; + void set_enabled(bool enabled); void clear(bool progress_update = true); void send_key(uint8_t key); @@ -47,6 +48,15 @@ class KeyCollector : public Component { Trigger *timeout_trigger_; uint32_t last_key_time_; uint32_t timeout_{0}; + bool enabled_; +}; + +template class EnableAction : public Action, public Parented { + void play(Ts... x) override { this->parent_->set_enabled(true); } +}; + +template class DisableAction : public Action, public Parented { + void play(Ts... x) override { this->parent_->set_enabled(false); } }; } // namespace key_collector diff --git a/esphome/components/ld2410/ld2410.h b/esphome/components/ld2410/ld2410.h index 8084d4c33e..1bbaa8987a 100644 --- a/esphome/components/ld2410/ld2410.h +++ b/esphome/components/ld2410/ld2410.h @@ -129,7 +129,7 @@ enum PeriodicDataStructure : uint8_t { LIGHT_SENSOR = 37, OUT_PIN_SENSOR = 38, }; -enum PeriodicDataValue : uint8_t { HEAD = 0XAA, END = 0x55, CHECK = 0x00 }; +enum PeriodicDataValue : uint8_t { HEAD = 0xAA, END = 0x55, CHECK = 0x00 }; enum AckDataStructure : uint8_t { COMMAND = 6, COMMAND_STATUS = 7 }; diff --git a/esphome/components/ld2450/ld2450.h b/esphome/components/ld2450/ld2450.h index 32e4bc02e4..e0927e5d7d 100644 --- a/esphome/components/ld2450/ld2450.h +++ b/esphome/components/ld2450/ld2450.h @@ -105,7 +105,7 @@ enum PeriodicDataStructure : uint8_t { TARGET_RESOLUTION = 10, }; -enum PeriodicDataValue : uint8_t { HEAD = 0XAA, END = 0x55, CHECK = 0x00 }; +enum PeriodicDataValue : uint8_t { HEAD = 0xAA, END = 0x55, CHECK = 0x00 }; enum AckDataStructure : uint8_t { COMMAND = 6, COMMAND_STATUS = 7 }; diff --git a/esphome/components/light/__init__.py b/esphome/components/light/__init__.py index feac385b66..237ab45f38 100644 --- a/esphome/components/light/__init__.py +++ b/esphome/components/light/__init__.py @@ -1,3 +1,5 @@ +import enum + import esphome.automation as auto import esphome.codegen as cg from esphome.components import mqtt, power_supply, web_server @@ -13,15 +15,18 @@ from esphome.const import ( CONF_COLOR_TEMPERATURE, CONF_DEFAULT_TRANSITION_LENGTH, CONF_EFFECTS, + CONF_ENTITY_CATEGORY, CONF_FLASH_TRANSITION_LENGTH, CONF_GAMMA_CORRECT, CONF_GREEN, + CONF_ICON, CONF_ID, CONF_INITIAL_STATE, CONF_MQTT_ID, CONF_ON_STATE, CONF_ON_TURN_OFF, CONF_ON_TURN_ON, + CONF_OUTPUT_ID, CONF_POWER_SUPPLY, CONF_RED, CONF_RESTORE_MODE, @@ -33,6 +38,7 @@ from esphome.const import ( CONF_WHITE, ) from esphome.core import coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity from .automation import LIGHT_STATE_SCHEMA @@ -141,6 +147,51 @@ ADDRESSABLE_LIGHT_SCHEMA = RGB_LIGHT_SCHEMA.extend( ) +class LightType(enum.IntEnum): + """Light type enum.""" + + BINARY = 0 + BRIGHTNESS_ONLY = 1 + RGB = 2 + ADDRESSABLE = 3 + + +def light_schema( + class_: MockObjClass, + type_: LightType, + *, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, + default_restore_mode: str = cv.UNDEFINED, +) -> cv.Schema: + schema = { + cv.GenerateID(CONF_OUTPUT_ID): cv.declare_id(class_), + } + + for key, default, validator in [ + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_ICON, icon, cv.icon), + ( + CONF_RESTORE_MODE, + default_restore_mode, + cv.enum(RESTORE_MODES, upper=True, space="_"), + ), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + if type_ == LightType.BINARY: + return BINARY_LIGHT_SCHEMA.extend(schema) + if type_ == LightType.BRIGHTNESS_ONLY: + return BRIGHTNESS_ONLY_LIGHT_SCHEMA.extend(schema) + if type_ == LightType.RGB: + return RGB_LIGHT_SCHEMA.extend(schema) + if type_ == LightType.ADDRESSABLE: + return ADDRESSABLE_LIGHT_SCHEMA.extend(schema) + + raise ValueError(f"Invalid light type: {type_}") + + def validate_color_temperature_channels(value): if ( CONF_COLD_WHITE_COLOR_TEMPERATURE in value @@ -223,6 +274,12 @@ async def register_light(output_var, config): await setup_light_core_(light_var, output_var, config) +async def new_light(config, *args): + output_var = cg.new_Pvariable(config[CONF_OUTPUT_ID], *args) + await register_light(output_var, config) + return output_var + + @coroutine_with_priority(100.0) async def to_code(config): cg.add_define("USE_LIGHT") diff --git a/esphome/components/lock/__init__.py b/esphome/components/lock/__init__.py index 6925861b52..a96290dca6 100644 --- a/esphome/components/lock/__init__.py +++ b/esphome/components/lock/__init__.py @@ -4,6 +4,8 @@ import esphome.codegen as cg from esphome.components import mqtt, web_server import esphome.config_validation as cv from esphome.const import ( + CONF_ENTITY_CATEGORY, + CONF_ICON, CONF_ID, CONF_MQTT_ID, CONF_ON_LOCK, @@ -12,6 +14,7 @@ from esphome.const import ( CONF_WEB_SERVER, ) from esphome.core import CORE, coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity CODEOWNERS = ["@esphome/core"] @@ -31,7 +34,19 @@ LockCondition = lock_ns.class_("LockCondition", Condition) LockLockTrigger = lock_ns.class_("LockLockTrigger", automation.Trigger.template()) LockUnlockTrigger = lock_ns.class_("LockUnlockTrigger", automation.Trigger.template()) -LOCK_SCHEMA = ( +LockState = lock_ns.enum("LockState") + +LOCK_STATES = { + "LOCKED": LockState.LOCK_STATE_LOCKED, + "UNLOCKED": LockState.LOCK_STATE_UNLOCKED, + "JAMMED": LockState.LOCK_STATE_JAMMED, + "LOCKING": LockState.LOCK_STATE_LOCKING, + "UNLOCKING": LockState.LOCK_STATE_UNLOCKING, +} + +validate_lock_state = cv.enum(LOCK_STATES, upper=True) + +_LOCK_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( @@ -52,7 +67,33 @@ LOCK_SCHEMA = ( ) -async def setup_lock_core_(var, config): +def lock_schema( + class_: MockObjClass = cv.UNDEFINED, + *, + icon: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, +) -> cv.Schema: + schema = {} + + if class_ is not cv.UNDEFINED: + schema[cv.GenerateID()] = cv.declare_id(class_) + + for key, default, validator in [ + (CONF_ICON, icon, cv.icon), + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _LOCK_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +LOCK_SCHEMA = lock_schema() +LOCK_SCHEMA.add_extra(cv.deprecated_schema_constant("lock")) + + +async def _setup_lock_core(var, config): await setup_entity(var, config) for conf in config.get(CONF_ON_LOCK, []): @@ -74,12 +115,18 @@ async def register_lock(var, config): if not CORE.has_id(config[CONF_ID]): var = cg.Pvariable(config[CONF_ID], var) cg.add(cg.App.register_lock(var)) - await setup_lock_core_(var, config) + await _setup_lock_core(var, config) + + +async def new_lock(config, *args): + var = cg.new_Pvariable(config[CONF_ID], *args) + await register_lock(var, config) + return var LOCK_ACTION_SCHEMA = maybe_simple_id( { - cv.Required(CONF_ID): cv.use_id(Lock), + cv.GenerateID(CONF_ID): cv.use_id(Lock), } ) diff --git a/esphome/components/lock/automation.h b/esphome/components/lock/automation.h index 74cfbe2ef6..8cb3b64ffe 100644 --- a/esphome/components/lock/automation.h +++ b/esphome/components/lock/automation.h @@ -1,8 +1,8 @@ #pragma once -#include "esphome/core/component.h" -#include "esphome/core/automation.h" #include "esphome/components/lock/lock.h" +#include "esphome/core/automation.h" +#include "esphome/core/component.h" namespace esphome { namespace lock { @@ -72,16 +72,5 @@ class LockUnlockTrigger : public Trigger<> { } }; -template class LockPublishAction : public Action { - public: - LockPublishAction(Lock *a_lock) : lock_(a_lock) {} - TEMPLATABLE_VALUE(LockState, state) - - void play(Ts... x) override { this->lock_->publish_state(this->state_.value(x...)); } - - protected: - Lock *lock_; -}; - } // namespace lock } // namespace esphome diff --git a/esphome/components/logger/__init__.py b/esphome/components/logger/__init__.py index 113f306327..01e75a424d 100644 --- a/esphome/components/logger/__init__.py +++ b/esphome/components/logger/__init__.py @@ -79,6 +79,7 @@ DEFAULT = "DEFAULT" CONF_INITIAL_LEVEL = "initial_level" CONF_LOGGER_ID = "logger_id" +CONF_TASK_LOG_BUFFER_SIZE = "task_log_buffer_size" UART_SELECTION_ESP32 = { VARIANT_ESP32: [UART0, UART1, UART2], @@ -180,6 +181,20 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_BAUD_RATE, default=115200): cv.positive_int, cv.Optional(CONF_TX_BUFFER_SIZE, default=512): cv.validate_bytes, cv.Optional(CONF_DEASSERT_RTS_DTR, default=False): cv.boolean, + cv.SplitDefault( + CONF_TASK_LOG_BUFFER_SIZE, + esp32=768, # Default: 768 bytes (~5-6 messages with 70-byte text plus thread names) + ): cv.All( + cv.only_on_esp32, + cv.validate_bytes, + cv.Any( + cv.int_(0), # Disabled + cv.int_range( + min=640, # Min: ~4-5 messages with 70-byte text plus thread names + max=32768, # Max: Depends on message sizes, typically ~300 messages with default size + ), + ), + ), cv.SplitDefault( CONF_HARDWARE_UART, esp8266=UART0, @@ -238,6 +253,12 @@ async def to_code(config): baud_rate, config[CONF_TX_BUFFER_SIZE], ) + if CORE.is_esp32: + task_log_buffer_size = config[CONF_TASK_LOG_BUFFER_SIZE] + if task_log_buffer_size > 0: + cg.add_define("USE_ESPHOME_TASK_LOG_BUFFER") + cg.add(log.init_log_buffer(task_log_buffer_size)) + cg.add(log.set_log_level(initial_level)) if CONF_HARDWARE_UART in config: cg.add( diff --git a/esphome/components/logger/logger.cpp b/esphome/components/logger/logger.cpp index 57f0ba9f9a..812a7cc16d 100644 --- a/esphome/components/logger/logger.cpp +++ b/esphome/components/logger/logger.cpp @@ -1,5 +1,8 @@ #include "logger.h" #include +#ifdef USE_ESPHOME_TASK_LOG_BUFFER +#include // For unique_ptr +#endif #include "esphome/core/hal.h" #include "esphome/core/log.h" @@ -10,127 +13,121 @@ namespace logger { static const char *const TAG = "logger"; -static const char *const LOG_LEVEL_COLORS[] = { - "", // NONE - ESPHOME_LOG_BOLD(ESPHOME_LOG_COLOR_RED), // ERROR - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_YELLOW), // WARNING - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_GREEN), // INFO - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_MAGENTA), // CONFIG - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_CYAN), // DEBUG - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_GRAY), // VERBOSE - ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_WHITE), // VERY_VERBOSE -}; -static const char *const LOG_LEVEL_LETTERS[] = { - "", // NONE - "E", // ERROR - "W", // WARNING - "I", // INFO - "C", // CONFIG - "D", // DEBUG - "V", // VERBOSE - "VV", // VERY_VERBOSE -}; +#ifdef USE_ESP32 +// Implementation for ESP32 (multi-core with atomic support) +// Main thread: synchronous logging with direct buffer access +// Other threads: console output with stack buffer, callbacks via async buffer +void HOT Logger::log_vprintf_(int level, const char *tag, int line, const char *format, va_list args) { // NOLINT + if (level > this->level_for(tag) || recursion_guard_.load(std::memory_order_relaxed)) + return; + recursion_guard_.store(true, std::memory_order_relaxed); -void Logger::write_header_(int level, const char *tag, int line) { - if (level < 0) - level = 0; - if (level > 7) - level = 7; - - const char *color = LOG_LEVEL_COLORS[level]; - const char *letter = LOG_LEVEL_LETTERS[level]; -#if defined(USE_ESP32) || defined(USE_LIBRETINY) TaskHandle_t current_task = xTaskGetCurrentTaskHandle(); -#else - void *current_task = nullptr; -#endif - if (current_task == main_task_) { - this->printf_to_buffer_("%s[%s][%s:%03u]: ", color, letter, tag, line); - } else { - const char *thread_name = ""; // NOLINT(clang-analyzer-deadcode.DeadStores) -#if defined(USE_ESP32) - thread_name = pcTaskGetName(current_task); -#elif defined(USE_LIBRETINY) - thread_name = pcTaskGetTaskName(current_task); -#endif - this->printf_to_buffer_("%s[%s][%s:%03u]%s[%s]%s: ", color, letter, tag, line, - ESPHOME_LOG_BOLD(ESPHOME_LOG_COLOR_RED), thread_name, color); - } -} + // For main task: call log_message_to_buffer_and_send_ which does console and callback logging + if (current_task == main_task_) { + this->log_message_to_buffer_and_send_(level, tag, line, format, args); + recursion_guard_.store(false, std::memory_order_release); + return; + } + + // For non-main tasks: use stack-allocated buffer only for console output + if (this->baud_rate_ > 0) { // If logging is enabled, write to console + // Maximum size for console log messages (includes null terminator) + static const size_t MAX_CONSOLE_LOG_MSG_SIZE = 144; + char console_buffer[MAX_CONSOLE_LOG_MSG_SIZE]; // MUST be stack allocated for thread safety + int buffer_at = 0; // Initialize buffer position + this->format_log_to_buffer_with_terminator_(level, tag, line, format, args, console_buffer, &buffer_at, + MAX_CONSOLE_LOG_MSG_SIZE); + this->write_msg_(console_buffer); + } + +#ifdef USE_ESPHOME_TASK_LOG_BUFFER + // For non-main tasks, queue the message for callbacks - but only if we have any callbacks registered + if (this->log_callback_.size() > 0) { + // This will be processed in the main loop + this->log_buffer_->send_message_thread_safe(static_cast(level), tag, static_cast(line), + current_task, format, args); + } +#endif // USE_ESPHOME_TASK_LOG_BUFFER + + recursion_guard_.store(false, std::memory_order_release); +} +#endif // USE_ESP32 + +#ifndef USE_ESP32 +// Implementation for platforms that do not support atomic operations +// or have to consider logging in other tasks void HOT Logger::log_vprintf_(int level, const char *tag, int line, const char *format, va_list args) { // NOLINT if (level > this->level_for(tag) || recursion_guard_) return; recursion_guard_ = true; - this->reset_buffer_(); - this->write_header_(level, tag, line); - this->vprintf_to_buffer_(format, args); - this->write_footer_(); - this->log_message_(level, tag); + + // Format and send to both console and callbacks + this->log_message_to_buffer_and_send_(level, tag, line, format, args); + recursion_guard_ = false; } +#endif // !USE_ESP32 + #ifdef USE_STORE_LOG_STR_IN_FLASH +// Implementation for ESP8266 with flash string support. +// Note: USE_STORE_LOG_STR_IN_FLASH is only defined for ESP8266. void Logger::log_vprintf_(int level, const char *tag, int line, const __FlashStringHelper *format, va_list args) { // NOLINT if (level > this->level_for(tag) || recursion_guard_) return; recursion_guard_ = true; - this->reset_buffer_(); - // copy format string + this->tx_buffer_at_ = 0; + + // Copy format string from progmem auto *format_pgm_p = reinterpret_cast(format); - size_t len = 0; char ch = '.'; - while (!this->is_buffer_full_() && ch != '\0') { + while (this->tx_buffer_at_ < this->tx_buffer_size_ && ch != '\0') { this->tx_buffer_[this->tx_buffer_at_++] = ch = (char) progmem_read_byte(format_pgm_p++); } - // Buffer full form copying format - if (this->is_buffer_full_()) + + // Buffer full from copying format + if (this->tx_buffer_at_ >= this->tx_buffer_size_) { + recursion_guard_ = false; // Make sure to reset the recursion guard before returning return; + } - // length of format string, includes null terminator - uint32_t offset = this->tx_buffer_at_; + // Save the offset before calling format_log_to_buffer_with_terminator_ + // since it will increment tx_buffer_at_ to the end of the formatted string + uint32_t msg_start = this->tx_buffer_at_; + this->format_log_to_buffer_with_terminator_(level, tag, line, this->tx_buffer_, args, this->tx_buffer_, + &this->tx_buffer_at_, this->tx_buffer_size_); + + // Write to console and send callback starting at the msg_start + if (this->baud_rate_ > 0) { + this->write_msg_(this->tx_buffer_ + msg_start); + } + this->call_log_callbacks_(level, tag, this->tx_buffer_ + msg_start); - // now apply vsnprintf - this->write_header_(level, tag, line); - this->vprintf_to_buffer_(this->tx_buffer_, args); - this->write_footer_(); - this->log_message_(level, tag, offset); recursion_guard_ = false; } -#endif +#endif // USE_STORE_LOG_STR_IN_FLASH -int HOT Logger::level_for(const char *tag) { - if (this->log_levels_.count(tag) != 0) - return this->log_levels_[tag]; +inline int Logger::level_for(const char *tag) { + auto it = this->log_levels_.find(tag); + if (it != this->log_levels_.end()) + return it->second; return this->current_level_; } -void HOT Logger::log_message_(int level, const char *tag, int offset) { - // remove trailing newline - if (this->tx_buffer_[this->tx_buffer_at_ - 1] == '\n') { - this->tx_buffer_at_--; - } - // make sure null terminator is present - this->set_null_terminator_(); - - const char *msg = this->tx_buffer_ + offset; - - if (this->baud_rate_ > 0) { - this->write_msg_(msg); - } - +void HOT Logger::call_log_callbacks_(int level, const char *tag, const char *msg) { #ifdef USE_ESP32 - // Suppress network-logging if memory constrained, but still log to serial - // ports. In some configurations (eg BLE enabled) there may be some transient + // Suppress network-logging if memory constrained + // In some configurations (eg BLE enabled) there may be some transient // memory exhaustion, and trying to log when OOM can lead to a crash. Skipping // here usually allows the stack to recover instead. // See issue #1234 for analysis. if (xPortGetFreeHeapSize() < 2048) return; #endif - this->log_callback_.call(level, tag, msg); } @@ -141,21 +138,50 @@ Logger::Logger(uint32_t baud_rate, size_t tx_buffer_size) : baud_rate_(baud_rate this->main_task_ = xTaskGetCurrentTaskHandle(); #endif } +#ifdef USE_ESPHOME_TASK_LOG_BUFFER +void Logger::init_log_buffer(size_t total_buffer_size) { + this->log_buffer_ = esphome::make_unique(total_buffer_size); +} +#endif -#ifdef USE_LOGGER_USB_CDC +#if defined(USE_LOGGER_USB_CDC) || defined(USE_ESP32) void Logger::loop() { -#ifdef USE_ARDUINO - if (this->uart_ != UART_SELECTION_USB_CDC) { - return; +#if defined(USE_LOGGER_USB_CDC) && defined(USE_ARDUINO) + if (this->uart_ == UART_SELECTION_USB_CDC) { + static bool opened = false; + if (opened == Serial) { + return; + } + if (false == opened) { + App.schedule_dump_config(); + } + opened = !opened; } - static bool opened = false; - if (opened == Serial) { - return; +#endif + +#ifdef USE_ESPHOME_TASK_LOG_BUFFER + // Process any buffered messages when available + if (this->log_buffer_->has_messages()) { + logger::TaskLogBuffer::LogMessage *message; + const char *text; + void *received_token; + + // Process messages from the buffer + while (this->log_buffer_->borrow_message_main_loop(&message, &text, &received_token)) { + this->tx_buffer_at_ = 0; + // Use the thread name that was stored when the message was created + // This avoids potential crashes if the task no longer exists + const char *thread_name = message->thread_name[0] != '\0' ? message->thread_name : nullptr; + this->write_header_to_buffer_(message->level, message->tag, message->line, thread_name, this->tx_buffer_, + &this->tx_buffer_at_, this->tx_buffer_size_); + this->write_body_to_buffer_(text, message->text_length, this->tx_buffer_, &this->tx_buffer_at_, + this->tx_buffer_size_); + this->write_footer_to_buffer_(this->tx_buffer_, &this->tx_buffer_at_, this->tx_buffer_size_); + this->tx_buffer_[this->tx_buffer_at_] = '\0'; + this->call_log_callbacks_(message->level, message->tag, this->tx_buffer_); + this->log_buffer_->release_message_main_loop(received_token); + } } - if (false == opened) { - App.schedule_dump_config(); - } - opened = !opened; #endif } #endif @@ -171,7 +197,7 @@ void Logger::add_on_log_callback(std::functionlog_callback_.add(std::move(callback)); } float Logger::get_setup_priority() const { return setup_priority::BUS + 500.0f; } -const char *const LOG_LEVELS[] = {"NONE", "ERROR", "WARN", "INFO", "CONFIG", "DEBUG", "VERBOSE", "VERY_VERBOSE"}; +static const char *const LOG_LEVELS[] = {"NONE", "ERROR", "WARN", "INFO", "CONFIG", "DEBUG", "VERBOSE", "VERY_VERBOSE"}; void Logger::dump_config() { ESP_LOGCONFIG(TAG, "Logger:"); @@ -181,12 +207,16 @@ void Logger::dump_config() { ESP_LOGCONFIG(TAG, " Log Baud Rate: %" PRIu32, this->baud_rate_); ESP_LOGCONFIG(TAG, " Hardware UART: %s", get_uart_selection_()); #endif +#ifdef USE_ESPHOME_TASK_LOG_BUFFER + if (this->log_buffer_) { + ESP_LOGCONFIG(TAG, " Task Log Buffer Size: %u", this->log_buffer_->size()); + } +#endif for (auto &it : this->log_levels_) { ESP_LOGCONFIG(TAG, " Level for '%s': %s", it.first.c_str(), LOG_LEVELS[it.second]); } } -void Logger::write_footer_() { this->write_to_buffer_(ESPHOME_LOG_RESET_COLOR, strlen(ESPHOME_LOG_RESET_COLOR)); } void Logger::set_log_level(int level) { if (level > ESPHOME_LOG_LEVEL) { diff --git a/esphome/components/logger/logger.h b/esphome/components/logger/logger.h index c4c873e020..8619cc0992 100644 --- a/esphome/components/logger/logger.h +++ b/esphome/components/logger/logger.h @@ -2,12 +2,19 @@ #include #include +#ifdef USE_ESP32 +#include +#endif #include "esphome/core/automation.h" #include "esphome/core/component.h" #include "esphome/core/defines.h" #include "esphome/core/helpers.h" #include "esphome/core/log.h" +#ifdef USE_ESPHOME_TASK_LOG_BUFFER +#include "task_log_buffer.h" +#endif + #ifdef USE_ARDUINO #if defined(USE_ESP8266) || defined(USE_ESP32) #include @@ -26,6 +33,29 @@ namespace esphome { namespace logger { +// Color and letter constants for log levels +static const char *const LOG_LEVEL_COLORS[] = { + "", // NONE + ESPHOME_LOG_BOLD(ESPHOME_LOG_COLOR_RED), // ERROR + ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_YELLOW), // WARNING + ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_GREEN), // INFO + ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_MAGENTA), // CONFIG + ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_CYAN), // DEBUG + ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_GRAY), // VERBOSE + ESPHOME_LOG_COLOR(ESPHOME_LOG_COLOR_WHITE), // VERY_VERBOSE +}; + +static const char *const LOG_LEVEL_LETTERS[] = { + "", // NONE + "E", // ERROR + "W", // WARNING + "I", // INFO + "C", // CONFIG + "D", // DEBUG + "V", // VERBOSE + "VV", // VERY_VERBOSE +}; + #if defined(USE_ESP32) || defined(USE_ESP8266) || defined(USE_RP2040) || defined(USE_LIBRETINY) /** Enum for logging UART selection * @@ -57,7 +87,10 @@ enum UARTSelection { class Logger : public Component { public: explicit Logger(uint32_t baud_rate, size_t tx_buffer_size); -#ifdef USE_LOGGER_USB_CDC +#ifdef USE_ESPHOME_TASK_LOG_BUFFER + void init_log_buffer(size_t total_buffer_size); +#endif +#if defined(USE_LOGGER_USB_CDC) || defined(USE_ESP32) void loop() override; #endif /// Manually set the baud rate for serial, set to 0 to disable. @@ -87,7 +120,7 @@ class Logger : public Component { void pre_setup(); void dump_config() override; - int level_for(const char *tag); + inline int level_for(const char *tag); /// Register a callback that will be called for every log message sent void add_on_log_callback(std::function &&callback); @@ -103,46 +136,66 @@ class Logger : public Component { #endif protected: - void write_header_(int level, const char *tag, int line); - void write_footer_(); - void log_message_(int level, const char *tag, int offset = 0); + void call_log_callbacks_(int level, const char *tag, const char *msg); void write_msg_(const char *msg); - inline bool is_buffer_full_() const { return this->tx_buffer_at_ >= this->tx_buffer_size_; } - inline int buffer_remaining_capacity_() const { return this->tx_buffer_size_ - this->tx_buffer_at_; } - inline void reset_buffer_() { this->tx_buffer_at_ = 0; } - inline void set_null_terminator_() { - // does not increment buffer_at - this->tx_buffer_[this->tx_buffer_at_] = '\0'; - } - inline void write_to_buffer_(char value) { - if (!this->is_buffer_full_()) - this->tx_buffer_[this->tx_buffer_at_++] = value; - } - inline void write_to_buffer_(const char *value, int length) { - for (int i = 0; i < length && !this->is_buffer_full_(); i++) { - this->tx_buffer_[this->tx_buffer_at_++] = value[i]; + // Format a log message with printf-style arguments and write it to a buffer with header, footer, and null terminator + // It's the caller's responsibility to initialize buffer_at (typically to 0) + inline void HOT format_log_to_buffer_with_terminator_(int level, const char *tag, int line, const char *format, + va_list args, char *buffer, int *buffer_at, int buffer_size) { +#if defined(USE_ESP32) || defined(USE_LIBRETINY) + this->write_header_to_buffer_(level, tag, line, this->get_thread_name_(), buffer, buffer_at, buffer_size); +#else + this->write_header_to_buffer_(level, tag, line, nullptr, buffer, buffer_at, buffer_size); +#endif + this->format_body_to_buffer_(buffer, buffer_at, buffer_size, format, args); + this->write_footer_to_buffer_(buffer, buffer_at, buffer_size); + + // Always ensure the buffer has a null terminator, even if we need to + // overwrite the last character of the actual content + if (*buffer_at >= buffer_size) { + buffer[buffer_size - 1] = '\0'; // Truncate and ensure null termination + } else { + buffer[*buffer_at] = '\0'; // Normal case, append null terminator } } - inline void vprintf_to_buffer_(const char *format, va_list args) { - if (this->is_buffer_full_()) - return; - int remaining = this->buffer_remaining_capacity_(); - int ret = vsnprintf(this->tx_buffer_ + this->tx_buffer_at_, remaining, format, args); - if (ret < 0) { - // Encoding error, do not increment buffer_at + + // Helper to format and send a log message to both console and callbacks + inline void HOT log_message_to_buffer_and_send_(int level, const char *tag, int line, const char *format, + va_list args) { + // Format to tx_buffer and prepare for output + this->tx_buffer_at_ = 0; // Initialize buffer position + this->format_log_to_buffer_with_terminator_(level, tag, line, format, args, this->tx_buffer_, &this->tx_buffer_at_, + this->tx_buffer_size_); + + if (this->baud_rate_ > 0) { + this->write_msg_(this->tx_buffer_); // If logging is enabled, write to console + } + this->call_log_callbacks_(level, tag, this->tx_buffer_); + } + + // Write the body of the log message to the buffer + inline void write_body_to_buffer_(const char *value, size_t length, char *buffer, int *buffer_at, int buffer_size) { + // Calculate available space + const int available = buffer_size - *buffer_at; + if (available <= 0) return; + + // Determine copy length (minimum of remaining capacity and string length) + const size_t copy_len = (length < static_cast(available)) ? length : available; + + // Copy the data + if (copy_len > 0) { + memcpy(buffer + *buffer_at, value, copy_len); + *buffer_at += copy_len; } - if (ret >= remaining) { - // output was too long, truncated - ret = remaining; - } - this->tx_buffer_at_ += ret; } - inline void printf_to_buffer_(const char *format, ...) { + + // Format string to explicit buffer with varargs + inline void printf_to_buffer_(const char *format, char *buffer, int *buffer_at, int buffer_size, ...) { va_list arg; - va_start(arg, format); - this->vprintf_to_buffer_(format, arg); + va_start(arg, buffer_size); + this->format_body_to_buffer_(buffer, buffer_at, buffer_size, format, arg); va_end(arg); } @@ -169,10 +222,82 @@ class Logger : public Component { std::map log_levels_{}; CallbackManager log_callback_{}; int current_level_{ESPHOME_LOG_LEVEL_VERY_VERBOSE}; - /// Prevents recursive log calls, if true a log message is already being processed. - bool recursion_guard_ = false; +#ifdef USE_ESP32 + std::atomic recursion_guard_{false}; +#ifdef USE_ESPHOME_TASK_LOG_BUFFER + std::unique_ptr log_buffer_; // Will be initialized with init_log_buffer +#endif +#else + bool recursion_guard_{false}; +#endif void *main_task_ = nullptr; CallbackManager level_callback_{}; + +#if defined(USE_ESP32) || defined(USE_LIBRETINY) + const char *HOT get_thread_name_() { + TaskHandle_t current_task = xTaskGetCurrentTaskHandle(); + if (current_task == main_task_) { + return nullptr; // Main task + } else { +#if defined(USE_ESP32) + return pcTaskGetName(current_task); +#elif defined(USE_LIBRETINY) + return pcTaskGetTaskName(current_task); +#endif + } + } +#endif + + inline void HOT write_header_to_buffer_(int level, const char *tag, int line, const char *thread_name, char *buffer, + int *buffer_at, int buffer_size) { + // Format header + if (level < 0) + level = 0; + if (level > 7) + level = 7; + + const char *color = esphome::logger::LOG_LEVEL_COLORS[level]; + const char *letter = esphome::logger::LOG_LEVEL_LETTERS[level]; + +#if defined(USE_ESP32) || defined(USE_LIBRETINY) + if (thread_name != nullptr) { + // Non-main task with thread name + this->printf_to_buffer_("%s[%s][%s:%03u]%s[%s]%s: ", buffer, buffer_at, buffer_size, color, letter, tag, line, + ESPHOME_LOG_BOLD(ESPHOME_LOG_COLOR_RED), thread_name, color); + return; + } +#endif + // Main task or non ESP32/LibreTiny platform + this->printf_to_buffer_("%s[%s][%s:%03u]: ", buffer, buffer_at, buffer_size, color, letter, tag, line); + } + + inline void HOT format_body_to_buffer_(char *buffer, int *buffer_at, int buffer_size, const char *format, + va_list args) { + // Get remaining capacity in the buffer + const int remaining = buffer_size - *buffer_at; + if (remaining <= 0) + return; + + const int ret = vsnprintf(buffer + *buffer_at, remaining, format, args); + + if (ret < 0) { + return; // Encoding error, do not increment buffer_at + } + + // Update buffer_at with the formatted length (handle truncation) + int formatted_len = (ret >= remaining) ? remaining : ret; + *buffer_at += formatted_len; + + // Remove all trailing newlines right after formatting + while (*buffer_at > 0 && buffer[*buffer_at - 1] == '\n') { + (*buffer_at)--; + } + } + + inline void HOT write_footer_to_buffer_(char *buffer, int *buffer_at, int buffer_size) { + static const int RESET_COLOR_LEN = strlen(ESPHOME_LOG_RESET_COLOR); + this->write_body_to_buffer_(ESPHOME_LOG_RESET_COLOR, RESET_COLOR_LEN, buffer, buffer_at, buffer_size); + } }; extern Logger *global_logger; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/esphome/components/logger/task_log_buffer.cpp b/esphome/components/logger/task_log_buffer.cpp new file mode 100644 index 0000000000..24d9284f1a --- /dev/null +++ b/esphome/components/logger/task_log_buffer.cpp @@ -0,0 +1,138 @@ + +#include "task_log_buffer.h" +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" + +#ifdef USE_ESPHOME_TASK_LOG_BUFFER + +namespace esphome { +namespace logger { + +TaskLogBuffer::TaskLogBuffer(size_t total_buffer_size) { + // Store the buffer size + this->size_ = total_buffer_size; + // Allocate memory for the ring buffer using ESPHome's RAM allocator + RAMAllocator allocator; + this->storage_ = allocator.allocate(this->size_); + // Create a static ring buffer with RINGBUF_TYPE_NOSPLIT for message integrity + this->ring_buffer_ = xRingbufferCreateStatic(this->size_, RINGBUF_TYPE_NOSPLIT, this->storage_, &this->structure_); +} + +TaskLogBuffer::~TaskLogBuffer() { + if (this->ring_buffer_ != nullptr) { + // Delete the ring buffer + vRingbufferDelete(this->ring_buffer_); + this->ring_buffer_ = nullptr; + + // Free the allocated memory + RAMAllocator allocator; + allocator.deallocate(this->storage_, this->size_); + this->storage_ = nullptr; + } +} + +bool TaskLogBuffer::borrow_message_main_loop(LogMessage **message, const char **text, void **received_token) { + if (message == nullptr || text == nullptr || received_token == nullptr) { + return false; + } + + size_t item_size = 0; + void *received_item = xRingbufferReceive(ring_buffer_, &item_size, 0); + if (received_item == nullptr) { + return false; + } + + LogMessage *msg = static_cast(received_item); + *message = msg; + *text = msg->text_data(); + *received_token = received_item; + + return true; +} + +void TaskLogBuffer::release_message_main_loop(void *token) { + if (token == nullptr) { + return; + } + vRingbufferReturnItem(ring_buffer_, token); + // Update counter to mark all messages as processed + last_processed_counter_ = message_counter_.load(std::memory_order_relaxed); +} + +bool TaskLogBuffer::send_message_thread_safe(uint8_t level, const char *tag, uint16_t line, TaskHandle_t task_handle, + const char *format, va_list args) { + // First, calculate the exact length needed using a null buffer (no actual writing) + va_list args_copy; + va_copy(args_copy, args); + int ret = vsnprintf(nullptr, 0, format, args_copy); + va_end(args_copy); + + if (ret <= 0) { + return false; // Formatting error or empty message + } + + // Calculate actual text length (capped to maximum size) + static constexpr size_t MAX_TEXT_SIZE = 255; + size_t text_length = (static_cast(ret) > MAX_TEXT_SIZE) ? MAX_TEXT_SIZE : ret; + + // Calculate total size needed (header + text length + null terminator) + size_t total_size = sizeof(LogMessage) + text_length + 1; + + // Acquire memory directly from the ring buffer + void *acquired_memory = nullptr; + BaseType_t result = xRingbufferSendAcquire(ring_buffer_, &acquired_memory, total_size, 0); + + if (result != pdTRUE || acquired_memory == nullptr) { + return false; // Failed to acquire memory + } + + // Set up the message header in the acquired memory + LogMessage *msg = static_cast(acquired_memory); + msg->level = level; + msg->tag = tag; + msg->line = line; + + // Store the thread name now instead of waiting until main loop processing + // This avoids crashes if the task completes or is deleted between when this message + // is enqueued and when it's processed by the main loop + const char *thread_name = pcTaskGetName(task_handle); + if (thread_name != nullptr) { + strncpy(msg->thread_name, thread_name, sizeof(msg->thread_name) - 1); + msg->thread_name[sizeof(msg->thread_name) - 1] = '\0'; // Ensure null termination + } else { + msg->thread_name[0] = '\0'; // Empty string if no thread name + } + + // Format the message text directly into the acquired memory + // We add 1 to text_length to ensure space for null terminator during formatting + char *text_area = msg->text_data(); + ret = vsnprintf(text_area, text_length + 1, format, args); + + // Handle unexpected formatting error + if (ret <= 0) { + vRingbufferReturnItem(ring_buffer_, acquired_memory); + return false; + } + + // Remove trailing newlines + while (text_length > 0 && text_area[text_length - 1] == '\n') { + text_length--; + } + + msg->text_length = text_length; + // Complete the send operation with the acquired memory + result = xRingbufferSendComplete(ring_buffer_, acquired_memory); + + if (result != pdTRUE) { + return false; // Failed to complete the message send + } + + // Message sent successfully, increment the counter + message_counter_.fetch_add(1, std::memory_order_relaxed); + return true; +} + +} // namespace logger +} // namespace esphome + +#endif // USE_ESPHOME_TASK_LOG_BUFFER diff --git a/esphome/components/logger/task_log_buffer.h b/esphome/components/logger/task_log_buffer.h new file mode 100644 index 0000000000..1618a5a121 --- /dev/null +++ b/esphome/components/logger/task_log_buffer.h @@ -0,0 +1,69 @@ +#pragma once + +#include "esphome/core/defines.h" +#include "esphome/core/helpers.h" + +#ifdef USE_ESPHOME_TASK_LOG_BUFFER +#include +#include +#include +#include +#include +#include + +namespace esphome { +namespace logger { + +class TaskLogBuffer { + public: + // Structure for a log message header (text data follows immediately after) + struct LogMessage { + const char *tag; // We store the pointer, assuming tags are static + char thread_name[16]; // Store thread name directly (only used for non-main threads) + uint16_t text_length; // Length of the message text (up to ~64KB) + uint16_t line; // Source code line number + uint8_t level; // Log level (0-7) + + // Methods for accessing message contents + inline char *text_data() { return reinterpret_cast(this) + sizeof(LogMessage); } + + inline const char *text_data() const { return reinterpret_cast(this) + sizeof(LogMessage); } + }; + + // Constructor that takes a total buffer size + explicit TaskLogBuffer(size_t total_buffer_size); + ~TaskLogBuffer(); + + // NOT thread-safe - borrow a message from the ring buffer, only call from main loop + bool borrow_message_main_loop(LogMessage **message, const char **text, void **received_token); + + // NOT thread-safe - release a message buffer and update the counter, only call from main loop + void release_message_main_loop(void *token); + + // Thread-safe - send a message to the ring buffer from any thread + bool send_message_thread_safe(uint8_t level, const char *tag, uint16_t line, TaskHandle_t task_handle, + const char *format, va_list args); + + // Check if there are messages ready to be processed using an atomic counter for performance + inline bool HOT has_messages() const { + return message_counter_.load(std::memory_order_relaxed) != last_processed_counter_; + } + + // Get the total buffer size in bytes + inline size_t size() const { return size_; } + + private: + RingbufHandle_t ring_buffer_{nullptr}; // FreeRTOS ring buffer handle + StaticRingbuffer_t structure_; // Static structure for the ring buffer + uint8_t *storage_{nullptr}; // Pointer to allocated memory + size_t size_{0}; // Size of allocated memory + + // Atomic counter for message tracking (only differences matter) + std::atomic message_counter_{0}; // Incremented when messages are committed + mutable uint16_t last_processed_counter_{0}; // Tracks last processed message +}; + +} // namespace logger +} // namespace esphome + +#endif // USE_ESPHOME_TASK_LOG_BUFFER diff --git a/esphome/components/lvgl/__init__.py b/esphome/components/lvgl/__init__.py index 30fa58c380..f60d60d9a4 100644 --- a/esphome/components/lvgl/__init__.py +++ b/esphome/components/lvgl/__init__.py @@ -2,6 +2,7 @@ import logging from esphome.automation import build_automation, register_action, validate_automation import esphome.codegen as cg +from esphome.components.const import CONF_DRAW_ROUNDING from esphome.components.display import Display import esphome.config_validation as cv from esphome.const import ( @@ -17,14 +18,14 @@ from esphome.const import ( CONF_TRIGGER_ID, CONF_TYPE, ) -from esphome.core import CORE, ID +from esphome.core import CORE, ID, Lambda from esphome.cpp_generator import MockObj from esphome.final_validate import full_config from esphome.helpers import write_file_if_changed from . import defines as df, helpers, lv_validation as lvalid -from .automation import disp_update, focused_widgets, update_to_code -from .defines import CONF_DRAW_ROUNDING, add_define +from .automation import disp_update, focused_widgets, refreshed_widgets, update_to_code +from .defines import add_define from .encoders import ( ENCODERS_CONFIG, encoders_to_code, @@ -239,6 +240,13 @@ def final_validation(configs): "A non adjustable arc may not be focused", path, ) + for w in refreshed_widgets: + path = global_config.get_path_for_id(w) + widget_conf = global_config.get_config_for_path(path[:-1]) + if not any(isinstance(v, Lambda) for v in widget_conf.values()): + raise cv.Invalid( + f"Widget '{w}' does not have any templated properties to refresh", + ) async def to_code(configs): @@ -323,7 +331,7 @@ async def to_code(configs): displays, frac, config[df.CONF_FULL_REFRESH], - config[df.CONF_DRAW_ROUNDING], + config[CONF_DRAW_ROUNDING], config[df.CONF_RESUME_ON_INPUT], ) await cg.register_component(lv_component, config) @@ -413,7 +421,7 @@ LVGL_SCHEMA = cv.All( df.CONF_DEFAULT_FONT, default="montserrat_14" ): lvalid.lv_font, cv.Optional(df.CONF_FULL_REFRESH, default=False): cv.boolean, - cv.Optional(df.CONF_DRAW_ROUNDING, default=2): cv.positive_int, + cv.Optional(CONF_DRAW_ROUNDING, default=2): cv.positive_int, cv.Optional(CONF_BUFFER_SIZE, default="100%"): cv.percentage, cv.Optional(df.CONF_LOG_LEVEL, default="WARN"): cv.one_of( *df.LV_LOG_LEVELS, upper=True diff --git a/esphome/components/lvgl/automation.py b/esphome/components/lvgl/automation.py index 4a71872022..5fea9bfdb1 100644 --- a/esphome/components/lvgl/automation.py +++ b/esphome/components/lvgl/automation.py @@ -35,7 +35,13 @@ from .lvcode import ( lv_obj, lvgl_comp, ) -from .schemas import DISP_BG_SCHEMA, LIST_ACTION_SCHEMA, LVGL_SCHEMA, base_update_schema +from .schemas import ( + ALL_STYLES, + DISP_BG_SCHEMA, + LIST_ACTION_SCHEMA, + LVGL_SCHEMA, + base_update_schema, +) from .types import ( LV_STATE, LvglAction, @@ -57,6 +63,7 @@ from .widgets import ( # Record widgets that are used in a focused action here focused_widgets = set() +refreshed_widgets = set() async def action_to_code( @@ -361,3 +368,45 @@ async def obj_update_to_code(config, action_id, template_arg, args): return await action_to_code( widgets, do_update, action_id, template_arg, args, config ) + + +def validate_refresh_config(config): + for w in config: + refreshed_widgets.add(w[CONF_ID]) + return config + + +@automation.register_action( + "lvgl.widget.refresh", + ObjUpdateAction, + cv.All( + cv.ensure_list( + cv.maybe_simple_value( + { + cv.Required(CONF_ID): cv.use_id(lv_obj_t), + }, + key=CONF_ID, + ) + ), + validate_refresh_config, + ), +) +async def obj_refresh_to_code(config, action_id, template_arg, args): + widget = await get_widgets(config) + + async def do_refresh(widget: Widget): + # only update style properties that might have changed, i.e. are templated + config = {k: v for k, v in widget.config.items() if isinstance(v, Lambda)} + await set_obj_properties(widget, config) + # must pass all widget-specific options here, even if not templated, but only do so if at least one is + # templated. First filter out common style properties. + config = {k: v for k, v in widget.config.items() if k not in ALL_STYLES} + if any(isinstance(v, Lambda) for v in config.values()): + await widget.type.to_code(widget, config) + if ( + widget.type.w_type.value_property is not None + and widget.type.w_type.value_property in config + ): + lv.event_send(widget.obj, UPDATE_EVENT, nullptr) + + return await action_to_code(widget, do_refresh, action_id, template_arg, args) diff --git a/esphome/components/lvgl/defines.py b/esphome/components/lvgl/defines.py index 7dedb55418..7783fb2321 100644 --- a/esphome/components/lvgl/defines.py +++ b/esphome/components/lvgl/defines.py @@ -424,7 +424,6 @@ CONF_DEFAULT_FONT = "default_font" CONF_DEFAULT_GROUP = "default_group" CONF_DIR = "dir" CONF_DISPLAYS = "displays" -CONF_DRAW_ROUNDING = "draw_rounding" CONF_EDITING = "editing" CONF_ENCODERS = "encoders" CONF_END_ANGLE = "end_angle" diff --git a/esphome/components/lvgl/lvgl_esphome.cpp b/esphome/components/lvgl/lvgl_esphome.cpp index 2e5ba25851..4c30d14e15 100644 --- a/esphome/components/lvgl/lvgl_esphome.cpp +++ b/esphome/components/lvgl/lvgl_esphome.cpp @@ -434,7 +434,11 @@ void LvglComponent::setup() { auto height = display->get_height(); size_t buffer_pixels = width * height / this->buffer_frac_; auto buf_bytes = buffer_pixels * LV_COLOR_DEPTH / 8; - auto *buffer = lv_custom_mem_alloc(buf_bytes); // NOLINT + void *buffer = nullptr; + if (this->buffer_frac_ >= 4) + buffer = malloc(buf_bytes); // NOLINT + if (buffer == nullptr) + buffer = lv_custom_mem_alloc(buf_bytes); // NOLINT if (buffer == nullptr) { this->mark_failed(); this->status_set_error("Memory allocation failure"); diff --git a/esphome/components/lvgl/text/__init__.py b/esphome/components/lvgl/text/__init__.py index 89db139a6a..eb56cdb7a7 100644 --- a/esphome/components/lvgl/text/__init__.py +++ b/esphome/components/lvgl/text/__init__.py @@ -19,9 +19,8 @@ from ..widgets import get_widgets, wait_for_widgets LVGLText = lvgl_ns.class_("LVGLText", text.Text) -CONFIG_SCHEMA = text.TEXT_SCHEMA.extend( +CONFIG_SCHEMA = text.text_schema(LVGLText).extend( { - cv.GenerateID(): cv.declare_id(LVGLText), cv.Required(CONF_WIDGET): cv.use_id(LvText), } ) diff --git a/esphome/components/mapping/__init__.py b/esphome/components/mapping/__init__.py new file mode 100644 index 0000000000..79657084fa --- /dev/null +++ b/esphome/components/mapping/__init__.py @@ -0,0 +1,134 @@ +import difflib + +import esphome.codegen as cg +import esphome.config_validation as cv +from esphome.const import CONF_FROM, CONF_ID, CONF_TO +from esphome.core import CORE +from esphome.cpp_generator import MockObj, VariableDeclarationExpression, add_global +from esphome.loader import get_component + +CODEOWNERS = ["@clydebarrow"] +MULTI_CONF = True + +map_ = cg.std_ns.class_("map") + +CONF_ENTRIES = "entries" +CONF_CLASS = "class" + + +class IndexType: + """ + Represents a type of index in a map. + """ + + def __init__(self, validator, data_type, conversion): + self.validator = validator + self.data_type = data_type + self.conversion = conversion + + +INDEX_TYPES = { + "int": IndexType(cv.int_, cg.int_, int), + "string": IndexType(cv.string, cg.std_string, str), +} + + +def to_schema(value): + """ + Generate a schema for the 'to' field of a map. This can be either one of the index types or a class name. + :param value: + :return: + """ + return cv.Any( + cv.one_of(*INDEX_TYPES, lower=True), + cv.one_of(*CORE.id_classes.keys()), + )(value) + + +BASE_SCHEMA = cv.Schema( + { + cv.Required(CONF_ID): cv.declare_id(map_), + cv.Required(CONF_FROM): cv.one_of(*INDEX_TYPES, lower=True), + cv.Required(CONF_TO): cv.string, + }, + extra=cv.ALLOW_EXTRA, +) + + +def get_object_type(to_): + """ + Get the object type from a string. Possible formats: + xxx The name of a component which defines INSTANCE_TYPE + esphome::xxx::yyy A C++ class name defined in a component + xxx::yyy A C++ class name defined in a component + yyy A C++ class name defined in the core + """ + + if cls := CORE.id_classes.get(to_): + return cls + if cls := CORE.id_classes.get(to_.removeprefix("esphome::")): + return cls + # get_component will throw a wobbly if we don't check this first. + if "." in to_: + return None + if component := get_component(to_): + return component.instance_type + return None + + +def map_schema(config): + config = BASE_SCHEMA(config) + if CONF_ENTRIES not in config or not isinstance(config[CONF_ENTRIES], dict): + raise cv.Invalid("an entries list is required for a map") + entries = config[CONF_ENTRIES] + if len(entries) == 0: + raise cv.Invalid("Map must have at least one entry") + to_ = config[CONF_TO] + if to_ in INDEX_TYPES: + value_type = INDEX_TYPES[to_].validator + else: + value_type = get_object_type(to_) + if value_type is None: + matches = difflib.get_close_matches(to_, CORE.id_classes) + raise cv.Invalid( + f"No known mappable class name matches '{to_}'; did you mean one of {', '.join(matches)}?" + ) + value_type = cv.use_id(value_type) + config[CONF_ENTRIES] = {k: value_type(v) for k, v in entries.items()} + return config + + +CONFIG_SCHEMA = map_schema + + +async def to_code(config): + entries = config[CONF_ENTRIES] + from_ = config[CONF_FROM] + to_ = config[CONF_TO] + index_conversion = INDEX_TYPES[from_].conversion + index_type = INDEX_TYPES[from_].data_type + if to_ in INDEX_TYPES: + value_conversion = INDEX_TYPES[to_].conversion + value_type = INDEX_TYPES[to_].data_type + entries = { + index_conversion(key): value_conversion(value) + for key, value in entries.items() + } + else: + entries = { + index_conversion(key): await cg.get_variable(value) + for key, value in entries.items() + } + value_type = get_object_type(to_) + if list(entries.values())[0].op != ".": + value_type = value_type.operator("ptr") + varid = config[CONF_ID] + varid.type = map_.template(index_type, value_type) + var = MockObj(varid, ".") + decl = VariableDeclarationExpression(varid.type, "", varid) + add_global(decl) + CORE.register_variable(varid, var) + + for key, value in entries.items(): + cg.add(var.insert((key, value))) + return var diff --git a/esphome/components/max7219digit/max7219digit.cpp b/esphome/components/max7219digit/max7219digit.cpp index ec9970d1a0..13b75ca734 100644 --- a/esphome/components/max7219digit/max7219digit.cpp +++ b/esphome/components/max7219digit/max7219digit.cpp @@ -4,6 +4,8 @@ #include "esphome/core/hal.h" #include "max7219font.h" +#include + namespace esphome { namespace max7219digit { @@ -61,45 +63,42 @@ void MAX7219Component::dump_config() { } void MAX7219Component::loop() { - uint32_t now = millis(); - + const uint32_t now = millis(); + const uint32_t millis_since_last_scroll = now - this->last_scroll_; + const size_t first_line_size = this->max_displaybuffer_[0].size(); // check if the buffer has shrunk past the current position since last update - if ((this->max_displaybuffer_[0].size() >= this->old_buffer_size_ + 3) || - (this->max_displaybuffer_[0].size() <= this->old_buffer_size_ - 3)) { + if ((first_line_size >= this->old_buffer_size_ + 3) || (first_line_size <= this->old_buffer_size_ - 3)) { + ESP_LOGV(TAG, "Buffer size changed %d to %d", this->old_buffer_size_, first_line_size); this->stepsleft_ = 0; this->display(); - this->old_buffer_size_ = this->max_displaybuffer_[0].size(); + this->old_buffer_size_ = first_line_size; } - // Reset the counter back to 0 when full string has been displayed. - if (this->stepsleft_ > this->max_displaybuffer_[0].size()) - this->stepsleft_ = 0; - - // Return if there is no need to scroll or scroll is off - if (!this->scroll_ || (this->max_displaybuffer_[0].size() <= (size_t) get_width_internal())) { + if (!this->scroll_ || (first_line_size <= (size_t) get_width_internal())) { + ESP_LOGVV(TAG, "Return if there is no need to scroll or scroll is off."); this->display(); return; } - if ((this->stepsleft_ == 0) && (now - this->last_scroll_ < this->scroll_delay_)) { + if ((this->stepsleft_ == 0) && (millis_since_last_scroll < this->scroll_delay_)) { + ESP_LOGVV(TAG, "At first step. Waiting for scroll delay"); this->display(); return; } - // Dwell time at end of string in case of stop at end if (this->scroll_mode_ == ScrollMode::STOP) { - if (this->stepsleft_ >= this->max_displaybuffer_[0].size() - (size_t) get_width_internal() + 1) { - if (now - this->last_scroll_ >= this->scroll_dwell_) { - this->stepsleft_ = 0; - this->last_scroll_ = now; - this->display(); + if (this->stepsleft_ + get_width_internal() == first_line_size + 1) { + if (millis_since_last_scroll < this->scroll_dwell_) { + ESP_LOGVV(TAG, "Dwell time at end of string in case of stop at end. Step %d, since last scroll %d, dwell %d.", + this->stepsleft_, millis_since_last_scroll, this->scroll_dwell_); + return; } - return; + ESP_LOGV(TAG, "Dwell time passed. Continue scrolling."); } } - // Actual call to scroll left action - if (now - this->last_scroll_ >= this->scroll_speed_) { + if (millis_since_last_scroll >= this->scroll_speed_) { + ESP_LOGVV(TAG, "Call to scroll left action"); this->last_scroll_ = now; this->scroll_left(); this->display(); @@ -227,19 +226,20 @@ void MAX7219Component::scroll(bool on_off) { this->set_scroll(on_off); } void MAX7219Component::scroll_left() { for (int chip_line = 0; chip_line < this->num_chip_lines_; chip_line++) { + auto scroll = [&](std::vector &line, uint16_t steps) { + std::rotate(line.begin(), std::next(line.begin(), steps), line.end()); + }; if (this->update_) { this->max_displaybuffer_[chip_line].push_back(this->bckgrnd_); - for (uint16_t i = 0; i < this->stepsleft_; i++) { - this->max_displaybuffer_[chip_line].push_back(this->max_displaybuffer_[chip_line].front()); - this->max_displaybuffer_[chip_line].erase(this->max_displaybuffer_[chip_line].begin()); - } + scroll(this->max_displaybuffer_[chip_line], + (this->stepsleft_ + 1) % (this->max_displaybuffer_[chip_line].size())); } else { - this->max_displaybuffer_[chip_line].push_back(this->max_displaybuffer_[chip_line].front()); - this->max_displaybuffer_[chip_line].erase(this->max_displaybuffer_[chip_line].begin()); + scroll(this->max_displaybuffer_[chip_line], 1); } } this->update_ = false; this->stepsleft_++; + this->stepsleft_ %= this->max_displaybuffer_[0].size(); } void MAX7219Component::send_char(uint8_t chip, uint8_t data) { diff --git a/esphome/components/mdns/__init__.py b/esphome/components/mdns/__init__.py index e8902d5222..4b5e40dfea 100644 --- a/esphome/components/mdns/__init__.py +++ b/esphome/components/mdns/__init__.py @@ -35,8 +35,8 @@ SERVICE_SCHEMA = cv.Schema( { cv.Required(CONF_SERVICE): cv.string, cv.Required(CONF_PROTOCOL): cv.string, - cv.Optional(CONF_PORT, default=0): cv.Any(0, cv.port), - cv.Optional(CONF_TXT, default={}): {cv.string: cv.string}, + cv.Optional(CONF_PORT, default=0): cv.templatable(cv.Any(0, cv.port)), + cv.Optional(CONF_TXT, default={}): {cv.string: cv.templatable(cv.string)}, } ) @@ -102,12 +102,18 @@ async def to_code(config): for service in config[CONF_SERVICES]: txt = [ - mdns_txt_record(txt_key, txt_value) + cg.StructInitializer( + MDNSTXTRecord, + ("key", txt_key), + ("value", await cg.templatable(txt_value, [], cg.std_string)), + ) for txt_key, txt_value in service[CONF_TXT].items() ] - exp = mdns_service( - service[CONF_SERVICE], service[CONF_PROTOCOL], service[CONF_PORT], txt + service[CONF_SERVICE], + service[CONF_PROTOCOL], + await cg.templatable(service[CONF_PORT], [], cg.uint16), + txt, ) cg.add(var.add_extra_service(exp)) diff --git a/esphome/components/mdns/mdns_component.cpp b/esphome/components/mdns/mdns_component.cpp index 2fc09330cd..ffc668e218 100644 --- a/esphome/components/mdns/mdns_component.cpp +++ b/esphome/components/mdns/mdns_component.cpp @@ -1,9 +1,9 @@ #include "esphome/core/defines.h" #ifdef USE_MDNS -#include "mdns_component.h" -#include "esphome/core/version.h" #include "esphome/core/application.h" #include "esphome/core/log.h" +#include "esphome/core/version.h" +#include "mdns_component.h" #ifdef USE_API #include "esphome/components/api/api_server.h" @@ -62,7 +62,11 @@ void MDNSComponent::compile_records_() { #endif #ifdef USE_API_NOISE - service.txt_records.push_back({"api_encryption", "Noise_NNpsk0_25519_ChaChaPoly_SHA256"}); + if (api::global_api_server->get_noise_ctx()->has_psk()) { + service.txt_records.push_back({"api_encryption", "Noise_NNpsk0_25519_ChaChaPoly_SHA256"}); + } else { + service.txt_records.push_back({"api_encryption_supported", "Noise_NNpsk0_25519_ChaChaPoly_SHA256"}); + } #endif #ifdef ESPHOME_PROJECT_NAME @@ -117,9 +121,11 @@ void MDNSComponent::dump_config() { ESP_LOGCONFIG(TAG, " Hostname: %s", this->hostname_.c_str()); ESP_LOGV(TAG, " Services:"); for (const auto &service : this->services_) { - ESP_LOGV(TAG, " - %s, %s, %d", service.service_type.c_str(), service.proto.c_str(), service.port); + ESP_LOGV(TAG, " - %s, %s, %d", service.service_type.c_str(), service.proto.c_str(), + const_cast &>(service.port).value()); for (const auto &record : service.txt_records) { - ESP_LOGV(TAG, " TXT: %s = %s", record.key.c_str(), record.value.c_str()); + ESP_LOGV(TAG, " TXT: %s = %s", record.key.c_str(), + const_cast &>(record.value).value().c_str()); } } } diff --git a/esphome/components/mdns/mdns_component.h b/esphome/components/mdns/mdns_component.h index dfb5b72292..9eb2ba11d0 100644 --- a/esphome/components/mdns/mdns_component.h +++ b/esphome/components/mdns/mdns_component.h @@ -3,6 +3,7 @@ #ifdef USE_MDNS #include #include +#include "esphome/core/automation.h" #include "esphome/core/component.h" namespace esphome { @@ -10,7 +11,7 @@ namespace mdns { struct MDNSTXTRecord { std::string key; - std::string value; + TemplatableValue value; }; struct MDNSService { @@ -20,7 +21,7 @@ struct MDNSService { // second label indicating protocol _including_ underscore character prefix // as defined in RFC6763 Section 7, like "_tcp" or "_udp" std::string proto; - uint16_t port; + TemplatableValue port; std::vector txt_records; }; diff --git a/esphome/components/mdns/mdns_esp32.cpp b/esphome/components/mdns/mdns_esp32.cpp index 8006eb27f1..fed18d3630 100644 --- a/esphome/components/mdns/mdns_esp32.cpp +++ b/esphome/components/mdns/mdns_esp32.cpp @@ -31,11 +31,12 @@ void MDNSComponent::setup() { mdns_txt_item_t it{}; // dup strings to ensure the pointer is valid even after the record loop it.key = strdup(record.key.c_str()); - it.value = strdup(record.value.c_str()); + it.value = strdup(const_cast &>(record.value).value().c_str()); txt_records.push_back(it); } - err = mdns_service_add(nullptr, service.service_type.c_str(), service.proto.c_str(), service.port, - txt_records.data(), txt_records.size()); + uint16_t port = const_cast &>(service.port).value(); + err = mdns_service_add(nullptr, service.service_type.c_str(), service.proto.c_str(), port, txt_records.data(), + txt_records.size()); // free records for (const auto &it : txt_records) { diff --git a/esphome/components/mdns/mdns_esp8266.cpp b/esphome/components/mdns/mdns_esp8266.cpp index 7b6e7ec448..2c90d57021 100644 --- a/esphome/components/mdns/mdns_esp8266.cpp +++ b/esphome/components/mdns/mdns_esp8266.cpp @@ -29,9 +29,11 @@ void MDNSComponent::setup() { while (*service_type == '_') { service_type++; } - MDNS.addService(service_type, proto, service.port); + uint16_t port = const_cast &>(service.port).value(); + MDNS.addService(service_type, proto, port); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), record.value.c_str()); + MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + const_cast &>(record.value).value().c_str()); } } } diff --git a/esphome/components/mdns/mdns_libretiny.cpp b/esphome/components/mdns/mdns_libretiny.cpp index c9a9a289dd..7a41ec9dce 100644 --- a/esphome/components/mdns/mdns_libretiny.cpp +++ b/esphome/components/mdns/mdns_libretiny.cpp @@ -29,9 +29,11 @@ void MDNSComponent::setup() { while (*service_type == '_') { service_type++; } - MDNS.addService(service_type, proto, service.port); + uint16_t port_ = const_cast &>(service.port).value(); + MDNS.addService(service_type, proto, port_); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), record.value.c_str()); + MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + const_cast &>(record.value).value().c_str()); } } } diff --git a/esphome/components/mdns/mdns_rp2040.cpp b/esphome/components/mdns/mdns_rp2040.cpp index 89e668ee59..95894323f4 100644 --- a/esphome/components/mdns/mdns_rp2040.cpp +++ b/esphome/components/mdns/mdns_rp2040.cpp @@ -29,9 +29,11 @@ void MDNSComponent::setup() { while (*service_type == '_') { service_type++; } - MDNS.addService(service_type, proto, service.port); + uint16_t port = const_cast &>(service.port).value(); + MDNS.addService(service_type, proto, port); for (const auto &record : service.txt_records) { - MDNS.addServiceTxt(service_type, proto, record.key.c_str(), record.value.c_str()); + MDNS.addServiceTxt(service_type, proto, record.key.c_str(), + const_cast &>(record.value).value().c_str()); } } } diff --git a/esphome/components/micro_wake_word/__init__.py b/esphome/components/micro_wake_word/__init__.py index 0862406e46..0efe2ac288 100644 --- a/esphome/components/micro_wake_word/__init__.py +++ b/esphome/components/micro_wake_word/__init__.py @@ -12,6 +12,7 @@ import esphome.config_validation as cv from esphome.const import ( CONF_FILE, CONF_ID, + CONF_INTERNAL, CONF_MICROPHONE, CONF_MODEL, CONF_PASSWORD, @@ -40,6 +41,7 @@ CONF_ON_WAKE_WORD_DETECTED = "on_wake_word_detected" CONF_PROBABILITY_CUTOFF = "probability_cutoff" CONF_SLIDING_WINDOW_AVERAGE_SIZE = "sliding_window_average_size" CONF_SLIDING_WINDOW_SIZE = "sliding_window_size" +CONF_STOP_AFTER_DETECTION = "stop_after_detection" CONF_TENSOR_ARENA_SIZE = "tensor_arena_size" CONF_VAD = "vad" @@ -49,13 +51,20 @@ micro_wake_word_ns = cg.esphome_ns.namespace("micro_wake_word") MicroWakeWord = micro_wake_word_ns.class_("MicroWakeWord", cg.Component) +DisableModelAction = micro_wake_word_ns.class_("DisableModelAction", automation.Action) +EnableModelAction = micro_wake_word_ns.class_("EnableModelAction", automation.Action) StartAction = micro_wake_word_ns.class_("StartAction", automation.Action) StopAction = micro_wake_word_ns.class_("StopAction", automation.Action) +ModelIsEnabledCondition = micro_wake_word_ns.class_( + "ModelIsEnabledCondition", automation.Condition +) IsRunningCondition = micro_wake_word_ns.class_( "IsRunningCondition", automation.Condition ) +WakeWordModel = micro_wake_word_ns.class_("WakeWordModel") + def _validate_json_filename(value): value = cv.string(value) @@ -169,9 +178,10 @@ def _convert_manifest_v1_to_v2(v1_manifest): # Original Inception-based V1 manifest models require a minimum of 45672 bytes v2_manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE] = 45672 - # Original Inception-based V1 manifest models use a 20 ms feature step size v2_manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE] = 20 + # Original Inception-based V1 manifest models were trained only on TTS English samples + v2_manifest[KEY_TRAINED_LANGUAGES] = ["en"] return v2_manifest @@ -296,14 +306,16 @@ MODEL_SOURCE_SCHEMA = cv.Any( MODEL_SCHEMA = cv.Schema( { + cv.GenerateID(CONF_ID): cv.declare_id(WakeWordModel), cv.Optional(CONF_MODEL): MODEL_SOURCE_SCHEMA, cv.Optional(CONF_PROBABILITY_CUTOFF): cv.percentage, cv.Optional(CONF_SLIDING_WINDOW_SIZE): cv.positive_int, + cv.Optional(CONF_INTERNAL, default=False): cv.boolean, cv.GenerateID(CONF_RAW_DATA_ID): cv.declare_id(cg.uint8), } ) -# Provide a default VAD model that could be overridden +# Provides a default VAD model that could be overridden VAD_MODEL_SCHEMA = MODEL_SCHEMA.extend( cv.Schema( { @@ -328,7 +340,14 @@ CONFIG_SCHEMA = cv.All( cv.Schema( { cv.GenerateID(): cv.declare_id(MicroWakeWord), - cv.GenerateID(CONF_MICROPHONE): cv.use_id(microphone.Microphone), + cv.Optional( + CONF_MICROPHONE, default={} + ): microphone.microphone_source_schema( + min_bits_per_sample=16, + max_bits_per_sample=16, + min_channels=1, + max_channels=1, + ), cv.Required(CONF_MODELS): cv.ensure_list( cv.maybe_simple_value(MODEL_SCHEMA, key=CONF_MODEL) ), @@ -336,6 +355,7 @@ CONFIG_SCHEMA = cv.All( single=True ), cv.Optional(CONF_VAD): _maybe_empty_vad_schema, + cv.Optional(CONF_STOP_AFTER_DETECTION, default=True): cv.boolean, cv.Optional(CONF_MODEL): cv.invalid( f"The {CONF_MODEL} parameter has moved to be a list element under the {CONF_MODELS} parameter." ), @@ -404,39 +424,42 @@ def _feature_step_size_validate(config): raise cv.Invalid("Cannot load models with different features step sizes.") -FINAL_VALIDATE_SCHEMA = _feature_step_size_validate +FINAL_VALIDATE_SCHEMA = cv.All( + cv.Schema( + { + cv.Required( + CONF_MICROPHONE + ): microphone.final_validate_microphone_source_schema( + "micro_wake_word", sample_rate=16000 + ), + }, + extra=cv.ALLOW_EXTRA, + ), + _feature_step_size_validate, +) async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) - mic = await cg.get_variable(config[CONF_MICROPHONE]) - cg.add(var.set_microphone(mic)) + mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) + cg.add(var.set_microphone_source(mic_source)) + + cg.add_define("USE_MICRO_WAKE_WORD") + cg.add_define("USE_OTA_STATE_CALLBACK") esp32.add_idf_component( name="esp-tflite-micro", repo="https://github.com/espressif/esp-tflite-micro", - ref="v1.3.1", - ) - # add esp-nn dependency for tflite-micro to work around https://github.com/espressif/esp-nn/issues/17 - # ...remove after switching to IDF 5.1.4+ - esp32.add_idf_component( - name="esp-nn", - repo="https://github.com/espressif/esp-nn", - ref="v1.1.0", + ref="v1.3.3.1", ) cg.add_build_flag("-DTF_LITE_STATIC_MEMORY") cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON") cg.add_build_flag("-DESP_NN") - if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): - await automation.build_automation( - var.get_wake_word_detected_trigger(), - [(cg.std_string, "wake_word")], - on_wake_word_detection_config, - ) + cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") if vad_model := config.get(CONF_VAD): cg.add_define("USE_MICRO_WAKE_WORD_VAD") @@ -444,7 +467,7 @@ async def to_code(config): # Use the general model loading code for the VAD codegen config[CONF_MODELS].append(vad_model) - for model_parameters in config[CONF_MODELS]: + for i, model_parameters in enumerate(config[CONF_MODELS]): model_config = model_parameters.get(CONF_MODEL) data = [] manifest, data = _model_config_to_manifest_data(model_config) @@ -455,6 +478,8 @@ async def to_code(config): probability_cutoff = model_parameters.get( CONF_PROBABILITY_CUTOFF, manifest[KEY_MICRO][CONF_PROBABILITY_CUTOFF] ) + quantized_probability_cutoff = int(probability_cutoff * 255) + sliding_window_size = model_parameters.get( CONF_SLIDING_WINDOW_SIZE, manifest[KEY_MICRO][CONF_SLIDING_WINDOW_SIZE], @@ -464,24 +489,40 @@ async def to_code(config): cg.add( var.add_vad_model( prog_arr, - probability_cutoff, + quantized_probability_cutoff, sliding_window_size, manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], ) ) else: - cg.add( - var.add_wake_word_model( - prog_arr, - probability_cutoff, - sliding_window_size, - manifest[KEY_WAKE_WORD], - manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], - ) + # Only enable the first wake word by default. After first boot, the enable state is saved/loaded to the flash + default_enabled = i == 0 + wake_word_model = cg.new_Pvariable( + model_parameters[CONF_ID], + str(model_parameters[CONF_ID]), + prog_arr, + quantized_probability_cutoff, + sliding_window_size, + manifest[KEY_WAKE_WORD], + manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE], + default_enabled, + model_parameters[CONF_INTERNAL], ) + for lang in manifest[KEY_TRAINED_LANGUAGES]: + cg.add(wake_word_model.add_trained_language(lang)) + + cg.add(var.add_wake_word_model(wake_word_model)) + cg.add(var.set_features_step_size(manifest[KEY_MICRO][CONF_FEATURE_STEP_SIZE])) - cg.add_library("kahrendt/ESPMicroSpeechFeatures", "1.1.0") + cg.add(var.set_stop_after_detection(config[CONF_STOP_AFTER_DETECTION])) + + if on_wake_word_detection_config := config.get(CONF_ON_WAKE_WORD_DETECTED): + await automation.build_automation( + var.get_wake_word_detected_trigger(), + [(cg.std_string, "wake_word")], + on_wake_word_detection_config, + ) MICRO_WAKE_WORD_ACTION_SCHEMA = cv.Schema({cv.GenerateID(): cv.use_id(MicroWakeWord)}) @@ -496,3 +537,30 @@ async def micro_wake_word_action_to_code(config, action_id, template_arg, args): var = cg.new_Pvariable(action_id, template_arg) await cg.register_parented(var, config[CONF_ID]) return var + + +MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA = automation.maybe_simple_id( + { + cv.Required(CONF_ID): cv.use_id(WakeWordModel), + } +) + + +@register_action( + "micro_wake_word.enable_model", + EnableModelAction, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) +@register_action( + "micro_wake_word.disable_model", + DisableModelAction, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) +@register_condition( + "micro_wake_word.model_is_enabled", + ModelIsEnabledCondition, + MICRO_WAKE_WORLD_MODEL_ACTION_SCHEMA, +) +async def model_action(config, action_id, template_arg, args): + parent = await cg.get_variable(config[CONF_ID]) + return cg.new_Pvariable(action_id, template_arg, parent) diff --git a/esphome/components/micro_wake_word/automation.h b/esphome/components/micro_wake_word/automation.h new file mode 100644 index 0000000000..f10a4ed347 --- /dev/null +++ b/esphome/components/micro_wake_word/automation.h @@ -0,0 +1,54 @@ +#pragma once + +#include "micro_wake_word.h" +#include "streaming_model.h" + +#ifdef USE_ESP_IDF +namespace esphome { +namespace micro_wake_word { + +template class StartAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->start(); } +}; + +template class StopAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->stop(); } +}; + +template class IsRunningCondition : public Condition, public Parented { + public: + bool check(Ts... x) override { return this->parent_->is_running(); } +}; + +template class EnableModelAction : public Action { + public: + explicit EnableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + void play(Ts... x) override { this->wake_word_model_->enable(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +template class DisableModelAction : public Action { + public: + explicit DisableModelAction(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + void play(Ts... x) override { this->wake_word_model_->disable(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +template class ModelIsEnabledCondition : public Condition { + public: + explicit ModelIsEnabledCondition(WakeWordModel *wake_word_model) : wake_word_model_(wake_word_model) {} + bool check(Ts... x) override { return this->wake_word_model_->is_enabled(); } + + protected: + WakeWordModel *wake_word_model_; +}; + +} // namespace micro_wake_word +} // namespace esphome +#endif diff --git a/esphome/components/micro_wake_word/micro_wake_word.cpp b/esphome/components/micro_wake_word/micro_wake_word.cpp index b58c7ec434..a44348fdc9 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.cpp +++ b/esphome/components/micro_wake_word/micro_wake_word.cpp @@ -1,5 +1,4 @@ #include "micro_wake_word.h" -#include "streaming_model.h" #ifdef USE_ESP_IDF @@ -7,41 +6,55 @@ #include "esphome/core/helpers.h" #include "esphome/core/log.h" -#include -#include +#include "esphome/components/audio/audio_transfer_buffer.h" -#include -#include -#include - -#include +#ifdef USE_OTA +#include "esphome/components/ota/ota_backend.h" +#endif namespace esphome { namespace micro_wake_word { static const char *const TAG = "micro_wake_word"; -static const size_t SAMPLE_RATE_HZ = 16000; // 16 kHz -static const size_t BUFFER_LENGTH = 64; // 0.064 seconds -static const size_t BUFFER_SIZE = SAMPLE_RATE_HZ / 1000 * BUFFER_LENGTH; -static const size_t INPUT_BUFFER_SIZE = 16 * SAMPLE_RATE_HZ / 1000; // 16ms * 16kHz / 1000ms +static const ssize_t DETECTION_QUEUE_LENGTH = 5; + +static const size_t DATA_TIMEOUT_MS = 50; + +static const uint32_t RING_BUFFER_DURATION_MS = 120; + +static const uint32_t INFERENCE_TASK_STACK_SIZE = 3072; +static const UBaseType_t INFERENCE_TASK_PRIORITY = 3; + +enum EventGroupBits : uint32_t { + COMMAND_STOP = (1 << 0), // Signals the inference task should stop + + TASK_STARTING = (1 << 3), + TASK_RUNNING = (1 << 4), + TASK_STOPPING = (1 << 5), + TASK_STOPPED = (1 << 6), + + ERROR_MEMORY = (1 << 9), + ERROR_INFERENCE = (1 << 10), + + WARNING_FULL_RING_BUFFER = (1 << 13), + + ERROR_BITS = ERROR_MEMORY | ERROR_INFERENCE, + ALL_BITS = 0xfffff, // 24 total bits available in an event group +}; float MicroWakeWord::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; } static const LogString *micro_wake_word_state_to_string(State state) { switch (state) { - case State::IDLE: - return LOG_STR("IDLE"); - case State::START_MICROPHONE: - return LOG_STR("START_MICROPHONE"); - case State::STARTING_MICROPHONE: - return LOG_STR("STARTING_MICROPHONE"); + case State::STARTING: + return LOG_STR("STARTING"); case State::DETECTING_WAKE_WORD: return LOG_STR("DETECTING_WAKE_WORD"); - case State::STOP_MICROPHONE: - return LOG_STR("STOP_MICROPHONE"); - case State::STOPPING_MICROPHONE: - return LOG_STR("STOPPING_MICROPHONE"); + case State::STOPPING: + return LOG_STR("STOPPING"); + case State::STOPPED: + return LOG_STR("STOPPED"); default: return LOG_STR("UNKNOWN"); } @@ -51,7 +64,7 @@ void MicroWakeWord::dump_config() { ESP_LOGCONFIG(TAG, "microWakeWord:"); ESP_LOGCONFIG(TAG, " models:"); for (auto &model : this->wake_word_models_) { - model.log_model_config(); + model->log_model_config(); } #ifdef USE_MICRO_WAKE_WORD_VAD this->vad_model_->log_model_config(); @@ -61,88 +74,270 @@ void MicroWakeWord::dump_config() { void MicroWakeWord::setup() { ESP_LOGCONFIG(TAG, "Setting up microWakeWord..."); - if (!this->register_streaming_ops_(this->streaming_op_resolver_)) { + this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; + this->frontend_config_.window.step_size_ms = this->features_step_size_; + this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; + this->frontend_config_.filterbank.lower_band_limit = FILTERBANK_LOWER_BAND_LIMIT; + this->frontend_config_.filterbank.upper_band_limit = FILTERBANK_UPPER_BAND_LIMIT; + this->frontend_config_.noise_reduction.smoothing_bits = NOISE_REDUCTION_SMOOTHING_BITS; + this->frontend_config_.noise_reduction.even_smoothing = NOISE_REDUCTION_EVEN_SMOOTHING; + this->frontend_config_.noise_reduction.odd_smoothing = NOISE_REDUCTION_ODD_SMOOTHING; + this->frontend_config_.noise_reduction.min_signal_remaining = NOISE_REDUCTION_MIN_SIGNAL_REMAINING; + this->frontend_config_.pcan_gain_control.enable_pcan = PCAN_GAIN_CONTROL_ENABLE_PCAN; + this->frontend_config_.pcan_gain_control.strength = PCAN_GAIN_CONTROL_STRENGTH; + this->frontend_config_.pcan_gain_control.offset = PCAN_GAIN_CONTROL_OFFSET; + this->frontend_config_.pcan_gain_control.gain_bits = PCAN_GAIN_CONTROL_GAIN_BITS; + this->frontend_config_.log_scale.enable_log = LOG_SCALE_ENABLE_LOG; + this->frontend_config_.log_scale.scale_shift = LOG_SCALE_SCALE_SHIFT; + + this->event_group_ = xEventGroupCreate(); + if (this->event_group_ == nullptr) { + ESP_LOGE(TAG, "Failed to create event group"); this->mark_failed(); return; } + this->detection_queue_ = xQueueCreate(DETECTION_QUEUE_LENGTH, sizeof(DetectionEvent)); + if (this->detection_queue_ == nullptr) { + ESP_LOGE(TAG, "Failed to create detection event queue"); + this->mark_failed(); + return; + } + + this->microphone_source_->add_data_callback([this](const std::vector &data) { + if (this->state_ == State::STOPPED) { + return; + } + std::shared_ptr temp_ring_buffer = this->ring_buffer_.lock(); + if (this->ring_buffer_.use_count() > 1) { + size_t bytes_free = temp_ring_buffer->free(); + + if (bytes_free < data.size()) { + xEventGroupSetBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); + temp_ring_buffer->reset(); + } + temp_ring_buffer->write((void *) data.data(), data.size()); + } + }); + +#ifdef USE_OTA + ota::get_global_ota_callback()->add_on_state_callback( + [this](ota::OTAState state, float progress, uint8_t error, ota::OTAComponent *comp) { + if (state == ota::OTA_STARTED) { + this->suspend_task_(); + } else if (state == ota::OTA_ERROR) { + this->resume_task_(); + } + }); +#endif ESP_LOGCONFIG(TAG, "Micro Wake Word initialized"); - - this->frontend_config_.window.size_ms = FEATURE_DURATION_MS; - this->frontend_config_.window.step_size_ms = this->features_step_size_; - this->frontend_config_.filterbank.num_channels = PREPROCESSOR_FEATURE_SIZE; - this->frontend_config_.filterbank.lower_band_limit = 125.0; - this->frontend_config_.filterbank.upper_band_limit = 7500.0; - this->frontend_config_.noise_reduction.smoothing_bits = 10; - this->frontend_config_.noise_reduction.even_smoothing = 0.025; - this->frontend_config_.noise_reduction.odd_smoothing = 0.06; - this->frontend_config_.noise_reduction.min_signal_remaining = 0.05; - this->frontend_config_.pcan_gain_control.enable_pcan = 1; - this->frontend_config_.pcan_gain_control.strength = 0.95; - this->frontend_config_.pcan_gain_control.offset = 80.0; - this->frontend_config_.pcan_gain_control.gain_bits = 21; - this->frontend_config_.log_scale.enable_log = 1; - this->frontend_config_.log_scale.scale_shift = 6; } -void MicroWakeWord::add_wake_word_model(const uint8_t *model_start, float probability_cutoff, - size_t sliding_window_average_size, const std::string &wake_word, - size_t tensor_arena_size) { - this->wake_word_models_.emplace_back(model_start, probability_cutoff, sliding_window_average_size, wake_word, - tensor_arena_size); +void MicroWakeWord::inference_task(void *params) { + MicroWakeWord *this_mww = (MicroWakeWord *) params; + + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STARTING); + + { // Ensures any C++ objects fall out of scope to deallocate before deleting the task + + const size_t new_bytes_to_process = + this_mww->microphone_source_->get_audio_stream_info().ms_to_bytes(this_mww->features_step_size_); + std::unique_ptr audio_buffer; + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]; + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + // Allocate audio transfer buffer + audio_buffer = audio::AudioSourceTransferBuffer::create(new_bytes_to_process); + + if (audio_buffer == nullptr) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); + } + } + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + // Allocate ring buffer + std::shared_ptr temp_ring_buffer = RingBuffer::create( + this_mww->microphone_source_->get_audio_stream_info().ms_to_bytes(RING_BUFFER_DURATION_MS)); + if (temp_ring_buffer.use_count() == 0) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_MEMORY); + } + audio_buffer->set_source(temp_ring_buffer); + this_mww->ring_buffer_ = temp_ring_buffer; + } + + if (!(xEventGroupGetBits(this_mww->event_group_) & ERROR_BITS)) { + this_mww->microphone_source_->start(); + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_RUNNING); + + while (!(xEventGroupGetBits(this_mww->event_group_) & COMMAND_STOP)) { + audio_buffer->transfer_data_from_source(pdMS_TO_TICKS(DATA_TIMEOUT_MS)); + + if (audio_buffer->available() < new_bytes_to_process) { + // Insufficient data to generate new spectrogram features, read more next iteration + continue; + } + + // Generate new spectrogram features + uint32_t processed_samples = this_mww->generate_features_( + (int16_t *) audio_buffer->get_buffer_start(), audio_buffer->available() / sizeof(int16_t), features_buffer); + audio_buffer->decrease_buffer_length(processed_samples * sizeof(int16_t)); + + // Run inference using the new spectorgram features + if (!this_mww->update_model_probabilities_(features_buffer)) { + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::ERROR_INFERENCE); + break; + } + + // Process each model's probabilities and possibly send a Detection Event to the queue + this_mww->process_probabilities_(); + } + } + } + + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPING); + + this_mww->unload_models_(); + this_mww->microphone_source_->stop(); + FrontendFreeStateContents(&this_mww->frontend_state_); + + xEventGroupSetBits(this_mww->event_group_, EventGroupBits::TASK_STOPPED); + while (true) { + // Continuously delay until the main loop deletes the task + delay(10); + } } +std::vector MicroWakeWord::get_wake_words() { + std::vector external_wake_word_models; + for (auto *model : this->wake_word_models_) { + if (!model->get_internal_only()) { + external_wake_word_models.push_back(model); + } + } + return external_wake_word_models; +} + +void MicroWakeWord::add_wake_word_model(WakeWordModel *model) { this->wake_word_models_.push_back(model); } + #ifdef USE_MICRO_WAKE_WORD_VAD -void MicroWakeWord::add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, +void MicroWakeWord::add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size) { this->vad_model_ = make_unique(model_start, probability_cutoff, sliding_window_size, tensor_arena_size); } #endif +void MicroWakeWord::suspend_task_() { + if (this->inference_task_handle_ != nullptr) { + vTaskSuspend(this->inference_task_handle_); + } +} + +void MicroWakeWord::resume_task_() { + if (this->inference_task_handle_ != nullptr) { + vTaskResume(this->inference_task_handle_); + } +} + void MicroWakeWord::loop() { + uint32_t event_group_bits = xEventGroupGetBits(this->event_group_); + + if (event_group_bits & EventGroupBits::ERROR_MEMORY) { + xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_MEMORY); + ESP_LOGE(TAG, "Encountered an error allocating buffers"); + } + + if (event_group_bits & EventGroupBits::ERROR_INFERENCE) { + xEventGroupClearBits(this->event_group_, EventGroupBits::ERROR_INFERENCE); + ESP_LOGE(TAG, "Encountered an error while performing an inference"); + } + + if (event_group_bits & EventGroupBits::WARNING_FULL_RING_BUFFER) { + xEventGroupClearBits(this->event_group_, EventGroupBits::WARNING_FULL_RING_BUFFER); + ESP_LOGW(TAG, "Not enough free bytes in ring buffer to store incoming audio data. Resetting the ring buffer. Wake " + "word detection accuracy will temporarily be reduced."); + } + + if (event_group_bits & EventGroupBits::TASK_STARTING) { + ESP_LOGD(TAG, "Inference task has started, attempting to allocate memory for buffers"); + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STARTING); + } + + if (event_group_bits & EventGroupBits::TASK_RUNNING) { + ESP_LOGD(TAG, "Inference task is running"); + + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_RUNNING); + this->set_state_(State::DETECTING_WAKE_WORD); + } + + if (event_group_bits & EventGroupBits::TASK_STOPPING) { + ESP_LOGD(TAG, "Inference task is stopping, deallocating buffers"); + xEventGroupClearBits(this->event_group_, EventGroupBits::TASK_STOPPING); + } + + if ((event_group_bits & EventGroupBits::TASK_STOPPED)) { + ESP_LOGD(TAG, "Inference task is finished, freeing task resources"); + vTaskDelete(this->inference_task_handle_); + this->inference_task_handle_ = nullptr; + xEventGroupClearBits(this->event_group_, ALL_BITS); + xQueueReset(this->detection_queue_); + this->set_state_(State::STOPPED); + } + + if ((this->pending_start_) && (this->state_ == State::STOPPED)) { + this->set_state_(State::STARTING); + this->pending_start_ = false; + } + + if ((this->pending_stop_) && (this->state_ == State::DETECTING_WAKE_WORD)) { + this->set_state_(State::STOPPING); + this->pending_stop_ = false; + } + switch (this->state_) { - case State::IDLE: - break; - case State::START_MICROPHONE: - ESP_LOGD(TAG, "Starting Microphone"); - this->microphone_->start(); - this->set_state_(State::STARTING_MICROPHONE); - this->high_freq_.start(); - break; - case State::STARTING_MICROPHONE: - if (this->microphone_->is_running()) { - this->set_state_(State::DETECTING_WAKE_WORD); - } - break; - case State::DETECTING_WAKE_WORD: - while (!this->has_enough_samples_()) { - this->read_microphone_(); - } - this->update_model_probabilities_(); - if (this->detect_wake_words_()) { - ESP_LOGD(TAG, "Wake Word '%s' Detected", (this->detected_wake_word_).c_str()); - this->detected_ = true; - this->set_state_(State::STOP_MICROPHONE); - } - break; - case State::STOP_MICROPHONE: - ESP_LOGD(TAG, "Stopping Microphone"); - this->microphone_->stop(); - this->set_state_(State::STOPPING_MICROPHONE); - this->high_freq_.stop(); - this->unload_models_(); - this->deallocate_buffers_(); - break; - case State::STOPPING_MICROPHONE: - if (this->microphone_->is_stopped()) { - this->set_state_(State::IDLE); - if (this->detected_) { - this->wake_word_detected_trigger_->trigger(this->detected_wake_word_); - this->detected_ = false; - this->detected_wake_word_ = ""; + case State::STARTING: + if ((this->inference_task_handle_ == nullptr) && !this->status_has_error()) { + // Setup preprocesor feature generator. If done in the task, it would lock the task to its initial core, as it + // uses floating point operations. + if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, + this->microphone_source_->get_audio_stream_info().get_sample_rate())) { + this->status_momentary_error( + "Failed to allocate buffers for spectrogram feature processor, attempting again in 1 second", 1000); + return; + } + + xTaskCreate(MicroWakeWord::inference_task, "mww", INFERENCE_TASK_STACK_SIZE, (void *) this, + INFERENCE_TASK_PRIORITY, &this->inference_task_handle_); + + if (this->inference_task_handle_ == nullptr) { + FrontendFreeStateContents(&this->frontend_state_); // Deallocate frontend state + this->status_momentary_error("Task failed to start, attempting again in 1 second", 1000); } } break; + case State::DETECTING_WAKE_WORD: { + DetectionEvent detection_event; + while (xQueueReceive(this->detection_queue_, &detection_event, 0)) { + if (detection_event.blocked_by_vad) { + ESP_LOGD(TAG, "Wake word model predicts '%s', but VAD model doesn't.", detection_event.wake_word->c_str()); + } else { + constexpr float uint8_to_float_divisor = + 255.0f; // Converting a quantized uint8 probability to floating point + ESP_LOGD(TAG, "Detected '%s' with sliding average probability is %.2f and max probability is %.2f", + detection_event.wake_word->c_str(), (detection_event.average_probability / uint8_to_float_divisor), + (detection_event.max_probability / uint8_to_float_divisor)); + this->wake_word_detected_trigger_->trigger(*detection_event.wake_word); + if (this->stop_after_detection_) { + this->stop(); + } + } + } + break; + } + case State::STOPPING: + xEventGroupSetBits(this->event_group_, EventGroupBits::COMMAND_STOP); + break; + case State::STOPPED: + break; } } @@ -157,212 +352,40 @@ void MicroWakeWord::start() { return; } - if (!this->load_models_() || !this->allocate_buffers_()) { - ESP_LOGE(TAG, "Failed to load the wake word model(s) or allocate buffers"); - this->status_set_error(); - } else { - this->status_clear_error(); - } - - if (this->status_has_error()) { - ESP_LOGW(TAG, "Wake word component has an error. Please check logs"); + if (this->is_running()) { + ESP_LOGW(TAG, "Wake word detection is already running"); return; } - if (this->state_ != State::IDLE) { - ESP_LOGW(TAG, "Wake word is already running"); - return; - } + ESP_LOGD(TAG, "Starting wake word detection"); - this->reset_states_(); - this->set_state_(State::START_MICROPHONE); + this->pending_start_ = true; + this->pending_stop_ = false; } void MicroWakeWord::stop() { - if (this->state_ == State::IDLE) { - ESP_LOGW(TAG, "Wake word is already stopped"); + if (this->state_ == STOPPED) return; - } - if (this->state_ == State::STOPPING_MICROPHONE) { - ESP_LOGW(TAG, "Wake word is already stopping"); - return; - } - this->set_state_(State::STOP_MICROPHONE); + + ESP_LOGD(TAG, "Stopping wake word detection"); + + this->pending_start_ = false; + this->pending_stop_ = true; } void MicroWakeWord::set_state_(State state) { - ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), - LOG_STR_ARG(micro_wake_word_state_to_string(state))); - this->state_ = state; + if (this->state_ != state) { + ESP_LOGD(TAG, "State changed from %s to %s", LOG_STR_ARG(micro_wake_word_state_to_string(this->state_)), + LOG_STR_ARG(micro_wake_word_state_to_string(state))); + this->state_ = state; + } } -size_t MicroWakeWord::read_microphone_() { - size_t bytes_read = this->microphone_->read(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); - if (bytes_read == 0) { - return 0; - } - - size_t bytes_free = this->ring_buffer_->free(); - - if (bytes_free < bytes_read) { - ESP_LOGW(TAG, - "Not enough free bytes in ring buffer to store incoming audio data (free bytes=%d, incoming bytes=%d). " - "Resetting the ring buffer. Wake word detection accuracy will be reduced.", - bytes_free, bytes_read); - - this->ring_buffer_->reset(); - } - - return this->ring_buffer_->write((void *) this->input_buffer_, bytes_read); -} - -bool MicroWakeWord::allocate_buffers_() { - ExternalRAMAllocator audio_samples_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - - if (this->input_buffer_ == nullptr) { - this->input_buffer_ = audio_samples_allocator.allocate(INPUT_BUFFER_SIZE * sizeof(int16_t)); - if (this->input_buffer_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate input buffer"); - return false; - } - } - - if (this->preprocessor_audio_buffer_ == nullptr) { - this->preprocessor_audio_buffer_ = audio_samples_allocator.allocate(this->new_samples_to_get_()); - if (this->preprocessor_audio_buffer_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate the audio preprocessor's buffer."); - return false; - } - } - - if (this->ring_buffer_ == nullptr) { - this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); - if (this->ring_buffer_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate ring buffer"); - return false; - } - } - - return true; -} - -void MicroWakeWord::deallocate_buffers_() { - ExternalRAMAllocator audio_samples_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - audio_samples_allocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); - this->input_buffer_ = nullptr; - audio_samples_allocator.deallocate(this->preprocessor_audio_buffer_, this->new_samples_to_get_()); - this->preprocessor_audio_buffer_ = nullptr; -} - -bool MicroWakeWord::load_models_() { - // Setup preprocesor feature generator - if (!FrontendPopulateState(&this->frontend_config_, &this->frontend_state_, AUDIO_SAMPLE_FREQUENCY)) { - ESP_LOGD(TAG, "Failed to populate frontend state"); - FrontendFreeStateContents(&this->frontend_state_); - return false; - } - - // Setup streaming models - for (auto &model : this->wake_word_models_) { - if (!model.load_model(this->streaming_op_resolver_)) { - ESP_LOGE(TAG, "Failed to initialize a wake word model."); - return false; - } - } -#ifdef USE_MICRO_WAKE_WORD_VAD - if (!this->vad_model_->load_model(this->streaming_op_resolver_)) { - ESP_LOGE(TAG, "Failed to initialize VAD model."); - return false; - } -#endif - - return true; -} - -void MicroWakeWord::unload_models_() { - FrontendFreeStateContents(&this->frontend_state_); - - for (auto &model : this->wake_word_models_) { - model.unload_model(); - } -#ifdef USE_MICRO_WAKE_WORD_VAD - this->vad_model_->unload_model(); -#endif -} - -void MicroWakeWord::update_model_probabilities_() { - int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]; - - if (!this->generate_features_for_window_(audio_features)) { - return; - } - - // Increase the counter since the last positive detection - this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); - - for (auto &model : this->wake_word_models_) { - // Perform inference - model.perform_streaming_inference(audio_features); - } -#ifdef USE_MICRO_WAKE_WORD_VAD - this->vad_model_->perform_streaming_inference(audio_features); -#endif -} - -bool MicroWakeWord::detect_wake_words_() { - // Verify we have processed samples since the last positive detection - if (this->ignore_windows_ < 0) { - return false; - } - -#ifdef USE_MICRO_WAKE_WORD_VAD - bool vad_state = this->vad_model_->determine_detected(); -#endif - - for (auto &model : this->wake_word_models_) { - if (model.determine_detected()) { -#ifdef USE_MICRO_WAKE_WORD_VAD - if (vad_state) { -#endif - this->detected_wake_word_ = model.get_wake_word(); - return true; -#ifdef USE_MICRO_WAKE_WORD_VAD - } else { - ESP_LOGD(TAG, "Wake word model predicts %s, but VAD model doesn't.", model.get_wake_word().c_str()); - } -#endif - } - } - - return false; -} - -bool MicroWakeWord::has_enough_samples_() { - return this->ring_buffer_->available() >= - (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)) * sizeof(int16_t); -} - -bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]) { - // Ensure we have enough new audio samples in the ring buffer for a full window - if (!this->has_enough_samples_()) { - return false; - } - - size_t bytes_read = this->ring_buffer_->read((void *) (this->preprocessor_audio_buffer_), - this->new_samples_to_get_() * sizeof(int16_t), pdMS_TO_TICKS(200)); - - if (bytes_read == 0) { - ESP_LOGE(TAG, "Could not read data from Ring Buffer"); - } else if (bytes_read < this->new_samples_to_get_() * sizeof(int16_t)) { - ESP_LOGD(TAG, "Partial Read of Data by Model"); - ESP_LOGD(TAG, "Could only read %d bytes when required %d bytes ", bytes_read, - (int) (this->new_samples_to_get_() * sizeof(int16_t))); - return false; - } - - size_t num_samples_read; - struct FrontendOutput frontend_output = FrontendProcessSamples( - &this->frontend_state_, this->preprocessor_audio_buffer_, this->new_samples_to_get_(), &num_samples_read); +size_t MicroWakeWord::generate_features_(int16_t *audio_buffer, size_t samples_available, + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]) { + size_t processed_samples = 0; + struct FrontendOutput frontend_output = + FrontendProcessSamples(&this->frontend_state_, audio_buffer, samples_available, &processed_samples); for (size_t i = 0; i < frontend_output.size; ++i) { // These scaling values are set to match the TFLite audio frontend int8 output. @@ -372,8 +395,8 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F // for historical reasons, to match up with the output of other feature // generators. // The process is then further complicated when we quantize the model. This - // means we have to scale the 0.0 to 26.0 real values to the -128 to 127 - // signed integer numbers. + // means we have to scale the 0.0 to 26.0 real values to the -128 (INT8_MIN) + // to 127 (INT8_MAX) signed integer numbers. // All this means that to get matching values from our integer feature // output into the tensor input, we have to perform: // input = (((feature / 25.6) / 26.0) * 256) - 128 @@ -382,74 +405,63 @@ bool MicroWakeWord::generate_features_for_window_(int8_t features[PREPROCESSOR_F constexpr int32_t value_scale = 256; constexpr int32_t value_div = 666; // 666 = 25.6 * 26.0 after rounding int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div; - value -= 128; - if (value < -128) { - value = -128; - } - if (value > 127) { - value = 127; - } - features[i] = value; + + value += INT8_MIN; // Adds a -128; i.e., subtracts 128 + features_buffer[i] = static_cast(clamp(value, INT8_MIN, INT8_MAX)); } - return true; + return processed_samples; } -void MicroWakeWord::reset_states_() { - ESP_LOGD(TAG, "Resetting buffers and probabilities"); - this->ring_buffer_->reset(); - this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; +void MicroWakeWord::process_probabilities_() { +#ifdef USE_MICRO_WAKE_WORD_VAD + DetectionEvent vad_state = this->vad_model_->determine_detected(); + + this->vad_state_ = vad_state.detected; // atomic write, so thread safe +#endif + for (auto &model : this->wake_word_models_) { - model.reset_probabilities(); + if (model->get_unprocessed_probability_status()) { + // Only detect wake words if there is a new probability since the last check + DetectionEvent wake_word_state = model->determine_detected(); + if (wake_word_state.detected) { +#ifdef USE_MICRO_WAKE_WORD_VAD + if (vad_state.detected) { +#endif + xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); + model->reset_probabilities(); +#ifdef USE_MICRO_WAKE_WORD_VAD + } else { + wake_word_state.blocked_by_vad = true; + xQueueSend(this->detection_queue_, &wake_word_state, portMAX_DELAY); + } +#endif + } + } + } +} + +void MicroWakeWord::unload_models_() { + for (auto &model : this->wake_word_models_) { + model->unload_model(); } #ifdef USE_MICRO_WAKE_WORD_VAD - this->vad_model_->reset_probabilities(); + this->vad_model_->unload_model(); #endif } -bool MicroWakeWord::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { - if (op_resolver.AddCallOnce() != kTfLiteOk) - return false; - if (op_resolver.AddVarHandle() != kTfLiteOk) - return false; - if (op_resolver.AddReshape() != kTfLiteOk) - return false; - if (op_resolver.AddReadVariable() != kTfLiteOk) - return false; - if (op_resolver.AddStridedSlice() != kTfLiteOk) - return false; - if (op_resolver.AddConcatenation() != kTfLiteOk) - return false; - if (op_resolver.AddAssignVariable() != kTfLiteOk) - return false; - if (op_resolver.AddConv2D() != kTfLiteOk) - return false; - if (op_resolver.AddMul() != kTfLiteOk) - return false; - if (op_resolver.AddAdd() != kTfLiteOk) - return false; - if (op_resolver.AddMean() != kTfLiteOk) - return false; - if (op_resolver.AddFullyConnected() != kTfLiteOk) - return false; - if (op_resolver.AddLogistic() != kTfLiteOk) - return false; - if (op_resolver.AddQuantize() != kTfLiteOk) - return false; - if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) - return false; - if (op_resolver.AddAveragePool2D() != kTfLiteOk) - return false; - if (op_resolver.AddMaxPool2D() != kTfLiteOk) - return false; - if (op_resolver.AddPad() != kTfLiteOk) - return false; - if (op_resolver.AddPack() != kTfLiteOk) - return false; - if (op_resolver.AddSplitV() != kTfLiteOk) - return false; +bool MicroWakeWord::update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]) { + bool success = true; - return true; + for (auto &model : this->wake_word_models_) { + // Perform inference + success = success & model->perform_streaming_inference(audio_features); + } +#ifdef USE_MICRO_WAKE_WORD_VAD + success = success & this->vad_model_->perform_streaming_inference(audio_features); +#endif + + return success; } } // namespace micro_wake_word diff --git a/esphome/components/micro_wake_word/micro_wake_word.h b/esphome/components/micro_wake_word/micro_wake_word.h index 0c805b75fc..d46c40e48b 100644 --- a/esphome/components/micro_wake_word/micro_wake_word.h +++ b/esphome/components/micro_wake_word/micro_wake_word.h @@ -5,33 +5,27 @@ #include "preprocessor_settings.h" #include "streaming_model.h" +#include "esphome/components/microphone/microphone_source.h" + #include "esphome/core/automation.h" #include "esphome/core/component.h" #include "esphome/core/ring_buffer.h" -#include "esphome/components/microphone/microphone.h" +#include +#include #include -#include -#include -#include - namespace esphome { namespace micro_wake_word { enum State { - IDLE, - START_MICROPHONE, - STARTING_MICROPHONE, + STARTING, DETECTING_WAKE_WORD, - STOP_MICROPHONE, - STOPPING_MICROPHONE, + STOPPING, + STOPPED, }; -// The number of audio slices to process before accepting a positive detection -static const uint8_t MIN_SLICES_BEFORE_DETECTION = 74; - class MicroWakeWord : public Component { public: void setup() override; @@ -42,132 +36,91 @@ class MicroWakeWord : public Component { void start(); void stop(); - bool is_running() const { return this->state_ != State::IDLE; } + bool is_running() const { return this->state_ != State::STOPPED; } void set_features_step_size(uint8_t step_size) { this->features_step_size_ = step_size; } - void set_microphone(microphone::Microphone *microphone) { this->microphone_ = microphone; } + void set_microphone_source(microphone::MicrophoneSource *microphone_source) { + this->microphone_source_ = microphone_source; + } + + void set_stop_after_detection(bool stop_after_detection) { this->stop_after_detection_ = stop_after_detection; } Trigger *get_wake_word_detected_trigger() const { return this->wake_word_detected_trigger_; } - void add_wake_word_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size); + void add_wake_word_model(WakeWordModel *model); #ifdef USE_MICRO_WAKE_WORD_VAD - void add_vad_model(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, + void add_vad_model(const uint8_t *model_start, uint8_t probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size); + + // Intended for the voice assistant component to fetch VAD status + bool get_vad_state() { return this->vad_state_; } #endif + // Intended for the voice assistant component to access which wake words are available + // Since these are pointers to the WakeWordModel objects, the voice assistant component can enable or disable them + std::vector get_wake_words(); + protected: - microphone::Microphone *microphone_{nullptr}; + microphone::MicrophoneSource *microphone_source_{nullptr}; Trigger *wake_word_detected_trigger_ = new Trigger(); - State state_{State::IDLE}; - HighFrequencyLoopRequester high_freq_; + State state_{State::STOPPED}; - std::unique_ptr ring_buffer_; - - std::vector wake_word_models_; + std::weak_ptr ring_buffer_; + std::vector wake_word_models_; #ifdef USE_MICRO_WAKE_WORD_VAD std::unique_ptr vad_model_; + bool vad_state_{false}; #endif - tflite::MicroMutableOpResolver<20> streaming_op_resolver_; + bool pending_start_{false}; + bool pending_stop_{false}; + + bool stop_after_detection_; + + uint8_t features_step_size_; // Audio frontend handles generating spectrogram features struct FrontendConfig frontend_config_; struct FrontendState frontend_state_; - // When the wake word detection first starts, we ignore this many audio - // feature slices before accepting a positive detection - int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; + // Handles managing the stop/state of the inference task + EventGroupHandle_t event_group_; - uint8_t features_step_size_; + // Used to send messages about the models' states to the main loop + QueueHandle_t detection_queue_; - // Stores audio read from the microphone before being added to the ring buffer. - int16_t *input_buffer_{nullptr}; - // Stores audio to be fed into the audio frontend for generating features. - int16_t *preprocessor_audio_buffer_{nullptr}; + static void inference_task(void *params); + TaskHandle_t inference_task_handle_{nullptr}; - bool detected_{false}; - std::string detected_wake_word_{""}; + /// @brief Suspends the inference task + void suspend_task_(); + /// @brief Resumes the inference task + void resume_task_(); void set_state_(State state); - /// @brief Tests if there are enough samples in the ring buffer to generate new features. - /// @return True if enough samples, false otherwise. - bool has_enough_samples_(); + /// @brief Generates spectrogram features from an input buffer of audio samples + /// @param audio_buffer (int16_t *) Buffer containing input audio samples + /// @param samples_available (size_t) Number of samples avaiable in the input buffer + /// @param features_buffer (int8_t *) Buffer to store generated features + /// @return (size_t) Number of samples processed from the input buffer + size_t generate_features_(int16_t *audio_buffer, size_t samples_available, + int8_t features_buffer[PREPROCESSOR_FEATURE_SIZE]); - /** Reads audio from microphone into the ring buffer - * - * Audio data (16000 kHz with int16 samples) is read into the input_buffer_. - * Verifies the ring buffer has enough space for all audio data. If not, it logs - * a warning and resets the ring buffer entirely. - * @return Number of bytes written to the ring buffer - */ - size_t read_microphone_(); + /// @brief Processes any new probabilities for each model. If any wake word is detected, it will send a DetectionEvent + /// to the detection_queue_. + void process_probabilities_(); - /// @brief Allocates memory for input_buffer_, preprocessor_audio_buffer_, and ring_buffer_ - /// @return True if successful, false otherwise - bool allocate_buffers_(); - - /// @brief Frees memory allocated for input_buffer_ and preprocessor_audio_buffer_ - void deallocate_buffers_(); - - /// @brief Loads streaming models and prepares the feature generation frontend - /// @return True if successful, false otherwise - bool load_models_(); - - /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. Frees memory used by the feature - /// generation frontend. + /// @brief Deletes each model's TFLite interpreters and frees tensor arena memory. void unload_models_(); - /** Performs inference with each configured model - * - * If enough audio samples are available, it will generate one slice of new features. - * It then loops through and performs inference with each of the loaded models. - */ - void update_model_probabilities_(); - - /** Checks every model's recent probabilities to determine if the wake word has been predicted - * - * Verifies the models have processed enough new samples for accurate predictions. - * Sets detected_wake_word_ to the wake word, if one is detected. - * @return True if a wake word is predicted, false otherwise - */ - bool detect_wake_words_(); - - /** Generates features for a window of audio samples - * - * Reads samples from the ring buffer and feeds them into the preprocessor frontend. - * Adapted from TFLite microspeech frontend. - * @param features int8_t array to store the audio features - * @return True if successful, false otherwise. - */ - bool generate_features_for_window_(int8_t features[PREPROCESSOR_FEATURE_SIZE]); - - /// @brief Resets the ring buffer, ignore_windows_, and sliding window probabilities - void reset_states_(); - - /// @brief Returns true if successfully registered the streaming model's TensorFlow operations - bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); - - inline uint16_t new_samples_to_get_() { return (this->features_step_size_ * (AUDIO_SAMPLE_FREQUENCY / 1000)); } -}; - -template class StartAction : public Action, public Parented { - public: - void play(Ts... x) override { this->parent_->start(); } -}; - -template class StopAction : public Action, public Parented { - public: - void play(Ts... x) override { this->parent_->stop(); } -}; - -template class IsRunningCondition : public Condition, public Parented { - public: - bool check(Ts... x) override { return this->parent_->is_running(); } + /// @brief Runs an inference with each model using the new spectrogram features + /// @param audio_features (int8_t *) Buffer containing new spectrogram features + /// @return True if successful, false if any errors were encountered + bool update_model_probabilities_(const int8_t audio_features[PREPROCESSOR_FEATURE_SIZE]); }; } // namespace micro_wake_word diff --git a/esphome/components/micro_wake_word/preprocessor_settings.h b/esphome/components/micro_wake_word/preprocessor_settings.h index 03f4fb5230..3de21de92e 100644 --- a/esphome/components/micro_wake_word/preprocessor_settings.h +++ b/esphome/components/micro_wake_word/preprocessor_settings.h @@ -7,13 +7,30 @@ namespace esphome { namespace micro_wake_word { +// Settings for controlling the spectrogram feature generation by the preprocessor. +// These must match the settings used when training a particular model. +// All microWakeWord models have been trained with these specific paramters. + // The number of features the audio preprocessor generates per slice static const uint8_t PREPROCESSOR_FEATURE_SIZE = 40; // Duration of each slice used as input into the preprocessor static const uint8_t FEATURE_DURATION_MS = 30; -// Audio sample frequency in hertz -static const uint16_t AUDIO_SAMPLE_FREQUENCY = 16000; +static const float FILTERBANK_LOWER_BAND_LIMIT = 125.0; +static const float FILTERBANK_UPPER_BAND_LIMIT = 7500.0; + +static const uint8_t NOISE_REDUCTION_SMOOTHING_BITS = 10; +static const float NOISE_REDUCTION_EVEN_SMOOTHING = 0.025; +static const float NOISE_REDUCTION_ODD_SMOOTHING = 0.06; +static const float NOISE_REDUCTION_MIN_SIGNAL_REMAINING = 0.05; + +static const bool PCAN_GAIN_CONTROL_ENABLE_PCAN = true; +static const float PCAN_GAIN_CONTROL_STRENGTH = 0.95; +static const float PCAN_GAIN_CONTROL_OFFSET = 80.0; +static const uint8_t PCAN_GAIN_CONTROL_GAIN_BITS = 21; + +static const bool LOG_SCALE_ENABLE_LOG = true; +static const uint8_t LOG_SCALE_SCALE_SHIFT = 6; } // namespace micro_wake_word } // namespace esphome diff --git a/esphome/components/micro_wake_word/streaming_model.cpp b/esphome/components/micro_wake_word/streaming_model.cpp index d0d2e2df05..ce3d8c2e4c 100644 --- a/esphome/components/micro_wake_word/streaming_model.cpp +++ b/esphome/components/micro_wake_word/streaming_model.cpp @@ -1,8 +1,7 @@ -#ifdef USE_ESP_IDF - #include "streaming_model.h" -#include "esphome/core/hal.h" +#ifdef USE_ESP_IDF + #include "esphome/core/helpers.h" #include "esphome/core/log.h" @@ -13,18 +12,18 @@ namespace micro_wake_word { void WakeWordModel::log_model_config() { ESP_LOGCONFIG(TAG, " - Wake Word: %s", this->wake_word_.c_str()); - ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_); + ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_); } void VADModel::log_model_config() { ESP_LOGCONFIG(TAG, " - VAD Model"); - ESP_LOGCONFIG(TAG, " Probability cutoff: %.3f", this->probability_cutoff_); + ESP_LOGCONFIG(TAG, " Probability cutoff: %.2f", this->probability_cutoff_ / 255.0f); ESP_LOGCONFIG(TAG, " Sliding window size: %d", this->sliding_window_size_); } -bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) { - ExternalRAMAllocator arena_allocator(ExternalRAMAllocator::ALLOW_FAILURE); +bool StreamingModel::load_model_() { + RAMAllocator arena_allocator(RAMAllocator::ALLOW_FAILURE); if (this->tensor_arena_ == nullptr) { this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_); @@ -51,8 +50,9 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) } if (this->interpreter_ == nullptr) { - this->interpreter_ = make_unique( - tflite::GetModel(this->model_start_), op_resolver, this->tensor_arena_, this->tensor_arena_size_, this->mrv_); + this->interpreter_ = + make_unique(tflite::GetModel(this->model_start_), this->streaming_op_resolver_, + this->tensor_arena_, this->tensor_arena_size_, this->mrv_); if (this->interpreter_->AllocateTensors() != kTfLiteOk) { ESP_LOGE(TAG, "Failed to allocate tensors for the streaming model"); return false; @@ -84,34 +84,55 @@ bool StreamingModel::load_model(tflite::MicroMutableOpResolver<20> &op_resolver) } } + this->loaded_ = true; + this->reset_probabilities(); return true; } void StreamingModel::unload_model() { this->interpreter_.reset(); - ExternalRAMAllocator arena_allocator(ExternalRAMAllocator::ALLOW_FAILURE); + RAMAllocator arena_allocator(RAMAllocator::ALLOW_FAILURE); - arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); - this->tensor_arena_ = nullptr; - arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); - this->var_arena_ = nullptr; + if (this->tensor_arena_ != nullptr) { + arena_allocator.deallocate(this->tensor_arena_, this->tensor_arena_size_); + this->tensor_arena_ = nullptr; + } + + if (this->var_arena_ != nullptr) { + arena_allocator.deallocate(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); + this->var_arena_ = nullptr; + } + + this->loaded_ = false; } bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]) { - if (this->interpreter_ != nullptr) { + if (this->enabled_ && !this->loaded_) { + // Model is enabled but isn't loaded + if (!this->load_model_()) { + return false; + } + } + + if (!this->enabled_ && this->loaded_) { + // Model is disabled but still loaded + this->unload_model(); + return true; + } + + if (this->loaded_) { TfLiteTensor *input = this->interpreter_->input(0); + uint8_t stride = this->interpreter_->input(0)->dims->data[1]; + this->current_stride_step_ = this->current_stride_step_ % stride; + std::memmove( (int8_t *) (tflite::GetTensorData(input)) + PREPROCESSOR_FEATURE_SIZE * this->current_stride_step_, features, PREPROCESSOR_FEATURE_SIZE); ++this->current_stride_step_; - uint8_t stride = this->interpreter_->input(0)->dims->data[1]; - if (this->current_stride_step_ >= stride) { - this->current_stride_step_ = 0; - TfLiteStatus invoke_status = this->interpreter_->Invoke(); if (invoke_status != kTfLiteOk) { ESP_LOGW(TAG, "Streaming interpreter invoke failed"); @@ -124,65 +145,161 @@ bool StreamingModel::perform_streaming_inference(const int8_t features[PREPROCES if (this->last_n_index_ == this->sliding_window_size_) this->last_n_index_ = 0; this->recent_streaming_probabilities_[this->last_n_index_] = output->data.uint8[0]; // probability; + this->unprocessed_probability_status_ = true; } - return true; + this->ignore_windows_ = std::min(this->ignore_windows_ + 1, 0); } - ESP_LOGE(TAG, "Streaming interpreter is not initialized."); - return false; + return true; } void StreamingModel::reset_probabilities() { for (auto &prob : this->recent_streaming_probabilities_) { prob = 0; } + this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION; } -WakeWordModel::WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size) { +WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff, + size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, + bool default_enabled, bool internal_only) { + this->id_ = id; this->model_start_ = model_start; - this->probability_cutoff_ = probability_cutoff; + this->default_probability_cutoff_ = default_probability_cutoff; + this->probability_cutoff_ = default_probability_cutoff; this->sliding_window_size_ = sliding_window_average_size; this->recent_streaming_probabilities_.resize(sliding_window_average_size, 0); this->wake_word_ = wake_word; this->tensor_arena_size_ = tensor_arena_size; + this->register_streaming_ops_(this->streaming_op_resolver_); + this->current_stride_step_ = 0; + this->internal_only_ = internal_only; + + this->pref_ = global_preferences->make_preference(fnv1_hash(id)); + bool enabled; + if (this->pref_.load(&enabled)) { + // Use the enabled state loaded from flash + this->enabled_ = enabled; + } else { + // If no state saved, then use the default + this->enabled_ = default_enabled; + } }; -bool WakeWordModel::determine_detected() { +void WakeWordModel::enable() { + this->enabled_ = true; + if (!this->internal_only_) { + this->pref_.save(&this->enabled_); + } +} + +void WakeWordModel::disable() { + this->enabled_ = false; + if (!this->internal_only_) { + this->pref_.save(&this->enabled_); + } +} + +DetectionEvent WakeWordModel::determine_detected() { + DetectionEvent detection_event; + detection_event.wake_word = &this->wake_word_; + detection_event.max_probability = 0; + detection_event.average_probability = 0; + + if ((this->ignore_windows_ < 0) || !this->enabled_) { + detection_event.detected = false; + return detection_event; + } + uint32_t sum = 0; for (auto &prob : this->recent_streaming_probabilities_) { + detection_event.max_probability = std::max(detection_event.max_probability, prob); sum += prob; } - float sliding_window_average = static_cast(sum) / static_cast(255 * this->sliding_window_size_); + detection_event.average_probability = sum / this->sliding_window_size_; + detection_event.detected = sum > this->probability_cutoff_ * this->sliding_window_size_; - // Detect the wake word if the sliding window average is above the cutoff - if (sliding_window_average > this->probability_cutoff_) { - ESP_LOGD(TAG, "The '%s' model sliding average probability is %.3f and most recent probability is %.3f", - this->wake_word_.c_str(), sliding_window_average, - this->recent_streaming_probabilities_[this->last_n_index_] / (255.0)); - return true; - } - return false; + this->unprocessed_probability_status_ = false; + return detection_event; } -VADModel::VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, +VADModel::VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size) { this->model_start_ = model_start; - this->probability_cutoff_ = probability_cutoff; + this->default_probability_cutoff_ = default_probability_cutoff; + this->probability_cutoff_ = default_probability_cutoff; this->sliding_window_size_ = sliding_window_size; this->recent_streaming_probabilities_.resize(sliding_window_size, 0); this->tensor_arena_size_ = tensor_arena_size; -}; + this->register_streaming_ops_(this->streaming_op_resolver_); +} + +DetectionEvent VADModel::determine_detected() { + DetectionEvent detection_event; + detection_event.max_probability = 0; + detection_event.average_probability = 0; + + if (!this->enabled_) { + // We disabled the VAD model for some reason... so we shouldn't block wake words from being detected + detection_event.detected = true; + return detection_event; + } -bool VADModel::determine_detected() { uint32_t sum = 0; for (auto &prob : this->recent_streaming_probabilities_) { + detection_event.max_probability = std::max(detection_event.max_probability, prob); sum += prob; } - float sliding_window_average = static_cast(sum) / static_cast(255 * this->sliding_window_size_); + detection_event.average_probability = sum / this->sliding_window_size_; + detection_event.detected = sum > (this->probability_cutoff_ * this->sliding_window_size_); - return sliding_window_average > this->probability_cutoff_; + return detection_event; +} + +bool StreamingModel::register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver) { + if (op_resolver.AddCallOnce() != kTfLiteOk) + return false; + if (op_resolver.AddVarHandle() != kTfLiteOk) + return false; + if (op_resolver.AddReshape() != kTfLiteOk) + return false; + if (op_resolver.AddReadVariable() != kTfLiteOk) + return false; + if (op_resolver.AddStridedSlice() != kTfLiteOk) + return false; + if (op_resolver.AddConcatenation() != kTfLiteOk) + return false; + if (op_resolver.AddAssignVariable() != kTfLiteOk) + return false; + if (op_resolver.AddConv2D() != kTfLiteOk) + return false; + if (op_resolver.AddMul() != kTfLiteOk) + return false; + if (op_resolver.AddAdd() != kTfLiteOk) + return false; + if (op_resolver.AddMean() != kTfLiteOk) + return false; + if (op_resolver.AddFullyConnected() != kTfLiteOk) + return false; + if (op_resolver.AddLogistic() != kTfLiteOk) + return false; + if (op_resolver.AddQuantize() != kTfLiteOk) + return false; + if (op_resolver.AddDepthwiseConv2D() != kTfLiteOk) + return false; + if (op_resolver.AddAveragePool2D() != kTfLiteOk) + return false; + if (op_resolver.AddMaxPool2D() != kTfLiteOk) + return false; + if (op_resolver.AddPad() != kTfLiteOk) + return false; + if (op_resolver.AddPack() != kTfLiteOk) + return false; + if (op_resolver.AddSplitV() != kTfLiteOk) + return false; + + return true; } } // namespace micro_wake_word diff --git a/esphome/components/micro_wake_word/streaming_model.h b/esphome/components/micro_wake_word/streaming_model.h index 0d85579f35..b7b22b9700 100644 --- a/esphome/components/micro_wake_word/streaming_model.h +++ b/esphome/components/micro_wake_word/streaming_model.h @@ -4,6 +4,8 @@ #include "preprocessor_settings.h" +#include "esphome/core/preferences.h" + #include #include #include @@ -11,31 +13,71 @@ namespace esphome { namespace micro_wake_word { +static const uint8_t MIN_SLICES_BEFORE_DETECTION = 100; static const uint32_t STREAMING_MODEL_VARIABLE_ARENA_SIZE = 1024; +struct DetectionEvent { + std::string *wake_word; + bool detected; + bool partially_detection; // Set if the most recent probability exceed the threshold, but the sliding window average + // hasn't yet + uint8_t max_probability; + uint8_t average_probability; + bool blocked_by_vad = false; +}; + class StreamingModel { public: virtual void log_model_config() = 0; - virtual bool determine_detected() = 0; + virtual DetectionEvent determine_detected() = 0; + // Performs inference on the given features. + // - If the model is enabled but not loaded, it will load it + // - If the model is disabled but loaded, it will unload it + // Returns true if sucessful or false if there is an error bool perform_streaming_inference(const int8_t features[PREPROCESSOR_FEATURE_SIZE]); - /// @brief Sets all recent_streaming_probabilities to 0 + /// @brief Sets all recent_streaming_probabilities to 0 and resets the ignore window count void reset_probabilities(); - /// @brief Allocates tensor and variable arenas and sets up the model interpreter - /// @param op_resolver MicroMutableOpResolver object that must exist until the model is unloaded - /// @return True if successful, false otherwise - bool load_model(tflite::MicroMutableOpResolver<20> &op_resolver); - /// @brief Destroys the TFLite interpreter and frees the tensor and variable arenas' memory void unload_model(); - protected: - uint8_t current_stride_step_{0}; + /// @brief Enable the model. The next performing_streaming_inference call will load it. + virtual void enable() { this->enabled_ = true; } - float probability_cutoff_; + /// @brief Disable the model. The next performing_streaming_inference call will unload it. + virtual void disable() { this->enabled_ = false; } + + /// @brief Return true if the model is enabled. + bool is_enabled() const { return this->enabled_; } + + bool get_unprocessed_probability_status() const { return this->unprocessed_probability_status_; } + + // Quantized probability cutoffs mapping 0.0 - 1.0 to 0 - 255 + uint8_t get_default_probability_cutoff() const { return this->default_probability_cutoff_; } + uint8_t get_probability_cutoff() const { return this->probability_cutoff_; } + void set_probability_cutoff(uint8_t probability_cutoff) { this->probability_cutoff_ = probability_cutoff; } + + protected: + /// @brief Allocates tensor and variable arenas and sets up the model interpreter + /// @return True if successful, false otherwise + bool load_model_(); + /// @brief Returns true if successfully registered the streaming model's TensorFlow operations + bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); + + tflite::MicroMutableOpResolver<20> streaming_op_resolver_; + + bool loaded_{false}; + bool enabled_{true}; + bool unprocessed_probability_status_{false}; + uint8_t current_stride_step_{0}; + int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION}; + + uint8_t default_probability_cutoff_; + uint8_t probability_cutoff_; size_t sliding_window_size_; + size_t last_n_index_{0}; size_t tensor_arena_size_; std::vector recent_streaming_probabilities_; @@ -50,32 +92,62 @@ class StreamingModel { class WakeWordModel final : public StreamingModel { public: - WakeWordModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_average_size, - const std::string &wake_word, size_t tensor_arena_size); + /// @brief Constructs a wake word model object + /// @param id (std::string) identifier for this model + /// @param model_start (const uint8_t *) pointer to the start of the model's TFLite FlatBuffer + /// @param default_probability_cutoff (uint8_t) probability cutoff for acceping the wake word has been said + /// @param sliding_window_average_size (size_t) the length of the sliding window computing the mean rolling + /// probability + /// @param wake_word (std::string) Friendly name of the wake word + /// @param tensor_arena_size (size_t) Size in bytes for allocating the tensor arena + /// @param default_enabled (bool) If true, it will be enabled by default on first boot + /// @param internal_only (bool) If true, the model will not be exposed to HomeAssistant as an available model + WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t default_probability_cutoff, + size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size, + bool default_enabled, bool internal_only); void log_model_config() override; /// @brief Checks for the wake word by comparing the mean probability in the sliding window with the probability /// cutoff /// @return True if wake word is detected, false otherwise - bool determine_detected() override; + DetectionEvent determine_detected() override; + const std::string &get_id() const { return this->id_; } const std::string &get_wake_word() const { return this->wake_word_; } + void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); } + const std::vector &get_trained_languages() const { return this->trained_languages_; } + + /// @brief Enable the model and save to flash. The next performing_streaming_inference call will load it. + void enable() override; + + /// @brief Disable the model and save to flash. The next performing_streaming_inference call will unload it. + void disable() override; + + bool get_internal_only() { return this->internal_only_; } + protected: + std::string id_; std::string wake_word_; + std::vector trained_languages_; + + bool internal_only_; + + ESPPreferenceObject pref_; }; class VADModel final : public StreamingModel { public: - VADModel(const uint8_t *model_start, float probability_cutoff, size_t sliding_window_size, size_t tensor_arena_size); + VADModel(const uint8_t *model_start, uint8_t default_probability_cutoff, size_t sliding_window_size, + size_t tensor_arena_size); void log_model_config() override; /// @brief Checks for voice activity by comparing the max probability in the sliding window with the probability /// cutoff /// @return True if voice activity is detected, false otherwise - bool determine_detected() override; + DetectionEvent determine_detected() override; }; } // namespace micro_wake_word diff --git a/esphome/components/microphone/__init__.py b/esphome/components/microphone/__init__.py index 4e5471b117..29bdcfa3f3 100644 --- a/esphome/components/microphone/__init__.py +++ b/esphome/components/microphone/__init__.py @@ -1,12 +1,21 @@ from esphome import automation from esphome.automation import maybe_simple_id import esphome.codegen as cg +from esphome.components import audio import esphome.config_validation as cv -from esphome.const import CONF_ID, CONF_TRIGGER_ID +from esphome.const import ( + CONF_BITS_PER_SAMPLE, + CONF_CHANNELS, + CONF_GAIN_FACTOR, + CONF_ID, + CONF_MICROPHONE, + CONF_TRIGGER_ID, +) from esphome.core import CORE from esphome.coroutine import coroutine_with_priority -CODEOWNERS = ["@jesserockz"] +AUTO_LOAD = ["audio"] +CODEOWNERS = ["@jesserockz", "@kahrendt"] IS_PLATFORM_COMPONENT = True @@ -15,6 +24,7 @@ CONF_ON_DATA = "on_data" microphone_ns = cg.esphome_ns.namespace("microphone") Microphone = microphone_ns.class_("Microphone") +MicrophoneSource = microphone_ns.class_("MicrophoneSource") CaptureAction = microphone_ns.class_( "CaptureAction", automation.Action, cg.Parented.template(Microphone) @@ -22,16 +32,23 @@ CaptureAction = microphone_ns.class_( StopCaptureAction = microphone_ns.class_( "StopCaptureAction", automation.Action, cg.Parented.template(Microphone) ) +MuteAction = microphone_ns.class_( + "MuteAction", automation.Action, cg.Parented.template(Microphone) +) +UnmuteAction = microphone_ns.class_( + "UnmuteAction", automation.Action, cg.Parented.template(Microphone) +) DataTrigger = microphone_ns.class_( "DataTrigger", - automation.Trigger.template(cg.std_vector.template(cg.int16).operator("ref")), + automation.Trigger.template(cg.std_vector.template(cg.uint8).operator("ref")), ) IsCapturingCondition = microphone_ns.class_( "IsCapturingCondition", automation.Condition ) +IsMutedCondition = microphone_ns.class_("IsMutedCondition", automation.Condition) async def setup_microphone_core_(var, config): @@ -39,7 +56,7 @@ async def setup_microphone_core_(var, config): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var) await automation.build_automation( trigger, - [(cg.std_vector.template(cg.int16).operator("ref").operator("const"), "x")], + [(cg.std_vector.template(cg.uint8).operator("ref").operator("const"), "x")], conf, ) @@ -50,7 +67,7 @@ async def register_microphone(var, config): await setup_microphone_core_(var, config) -MICROPHONE_SCHEMA = cv.Schema( +MICROPHONE_SCHEMA = cv.Schema.extend(audio.AUDIO_COMPONENT_SCHEMA).extend( { cv.Optional(CONF_ON_DATA): automation.validate_automation( { @@ -64,7 +81,110 @@ MICROPHONE_SCHEMA = cv.Schema( MICROPHONE_ACTION_SCHEMA = maybe_simple_id({cv.GenerateID(): cv.use_id(Microphone)}) -async def media_player_action(config, action_id, template_arg, args): +def microphone_source_schema( + min_bits_per_sample: int = 16, + max_bits_per_sample: int = 16, + min_channels: int = 1, + max_channels: int = 1, +): + """Schema for a microphone source + + Components requesting microphone data should use this schema instead of accessing a microphone directly. + + Args: + min_bits_per_sample (int, optional): Minimum number of bits per sample the requesting component supports. Defaults to 16. + max_bits_per_sample (int, optional): Maximum number of bits per sample the requesting component supports. Defaults to 16. + min_channels (int, optional): Minimum number of channels the requesting component supports. Defaults to 1. + max_channels (int, optional): Maximum number of channels the requesting component supports. Defaults to 1. + """ + + def _validate_unique_channels(config): + if len(config) != len(set(config)): + raise cv.Invalid("Channels must be unique") + return config + + return cv.All( + automation.maybe_conf( + CONF_MICROPHONE, + { + cv.GenerateID(CONF_ID): cv.declare_id(MicrophoneSource), + cv.GenerateID(CONF_MICROPHONE): cv.use_id(Microphone), + cv.Optional(CONF_BITS_PER_SAMPLE, default=16): cv.int_range( + min_bits_per_sample, max_bits_per_sample + ), + cv.Optional(CONF_CHANNELS, default="0"): cv.All( + cv.ensure_list(cv.int_range(0, 7)), + cv.Length(min=min_channels, max=max_channels), + _validate_unique_channels, + ), + cv.Optional(CONF_GAIN_FACTOR, default="1"): cv.int_range(1, 64), + }, + ), + ) + + +def final_validate_microphone_source_schema( + component_name: str, sample_rate: int = cv.UNDEFINED +): + """Validates that the microphone source can provide audio in the correct format. In particular it validates the sample rate and the enabled channels. + + Note that: + - MicrophoneSource class automatically handles converting bits per sample, so no need to validate + - microphone_source_schema already validates that channels are unique and specifies the max number of channels the component supports + + Args: + component_name (str): The name of the component requesting mic audio + sample_rate (int, optional): The sample rate the component requesting mic audio requires + """ + + def _validate_audio_compatability(config): + if sample_rate is not cv.UNDEFINED: + # Issues require changing the microphone configuration + # - Verifies sample rates match + audio.final_validate_audio_schema( + component_name, + audio_device=CONF_MICROPHONE, + sample_rate=sample_rate, + audio_device_issue=True, + )(config) + + # Issues require changing the MicrophoneSource configuration + # - Verifies that each of the enabled channels are available + audio.final_validate_audio_schema( + component_name, + audio_device=CONF_MICROPHONE, + enabled_channels=config[CONF_CHANNELS], + audio_device_issue=False, + )(config) + + return config + + return _validate_audio_compatability + + +async def microphone_source_to_code(config, passive=False): + """Creates a MicrophoneSource variable for codegen. + + Setting passive to true makes the MicrophoneSource never start/stop the microphone, but only receives audio when another component has actively started the Microphone. If false, then the microphone needs to be explicitly started/stopped. + + Args: + config (Schema): Created with `microphone_source_schema` specifying bits per sample, channels, and gain factor + passive (bool): Enable passive mode for the MicrophoneSource + """ + mic = await cg.get_variable(config[CONF_MICROPHONE]) + mic_source = cg.new_Pvariable( + config[CONF_ID], + mic, + config[CONF_BITS_PER_SAMPLE], + config[CONF_GAIN_FACTOR], + passive, + ) + for channel in config[CONF_CHANNELS]: + cg.add(mic_source.add_channel(channel)) + return mic_source + + +async def microphone_action(config, action_id, template_arg, args): var = cg.new_Pvariable(action_id, template_arg) await cg.register_parented(var, config[CONF_ID]) return var @@ -72,15 +192,25 @@ async def media_player_action(config, action_id, template_arg, args): automation.register_action( "microphone.capture", CaptureAction, MICROPHONE_ACTION_SCHEMA -)(media_player_action) +)(microphone_action) automation.register_action( "microphone.stop_capture", StopCaptureAction, MICROPHONE_ACTION_SCHEMA -)(media_player_action) +)(microphone_action) + +automation.register_action("microphone.mute", MuteAction, MICROPHONE_ACTION_SCHEMA)( + microphone_action +) +automation.register_action("microphone.unmute", UnmuteAction, MICROPHONE_ACTION_SCHEMA)( + microphone_action +) automation.register_condition( "microphone.is_capturing", IsCapturingCondition, MICROPHONE_ACTION_SCHEMA -)(media_player_action) +)(microphone_action) +automation.register_condition( + "microphone.is_muted", IsMutedCondition, MICROPHONE_ACTION_SCHEMA +)(microphone_action) @coroutine_with_priority(100.0) diff --git a/esphome/components/microphone/automation.h b/esphome/components/microphone/automation.h index 29c0ec5df2..5745909c46 100644 --- a/esphome/components/microphone/automation.h +++ b/esphome/components/microphone/automation.h @@ -16,10 +16,17 @@ template class StopCaptureAction : public Action, public void play(Ts... x) override { this->parent_->stop(); } }; -class DataTrigger : public Trigger &> { +template class MuteAction : public Action, public Parented { + void play(Ts... x) override { this->parent_->set_mute_state(true); } +}; +template class UnmuteAction : public Action, public Parented { + void play(Ts... x) override { this->parent_->set_mute_state(false); } +}; + +class DataTrigger : public Trigger &> { public: explicit DataTrigger(Microphone *mic) { - mic->add_data_callback([this](const std::vector &data) { this->trigger(data); }); + mic->add_data_callback([this](const std::vector &data) { this->trigger(data); }); } }; @@ -28,5 +35,10 @@ template class IsCapturingCondition : public Condition, p bool check(Ts... x) override { return this->parent_->is_running(); } }; +template class IsMutedCondition : public Condition, public Parented { + public: + bool check(Ts... x) override { return this->parent_->get_mute_state(); } +}; + } // namespace microphone } // namespace esphome diff --git a/esphome/components/microphone/microphone.cpp b/esphome/components/microphone/microphone.cpp new file mode 100644 index 0000000000..b1289f3791 --- /dev/null +++ b/esphome/components/microphone/microphone.cpp @@ -0,0 +1,21 @@ +#include "microphone.h" + +namespace esphome { +namespace microphone { + +void Microphone::add_data_callback(std::function &)> &&data_callback) { + std::function &)> mute_handled_callback = + [this, data_callback](const std::vector &data) { data_callback(this->silence_audio_(data)); }; + this->data_callbacks_.add(std::move(mute_handled_callback)); +} + +std::vector Microphone::silence_audio_(std::vector data) { + if (this->mute_state_) { + std::memset((void *) data.data(), 0, data.size()); + } + + return data; +} + +} // namespace microphone +} // namespace esphome diff --git a/esphome/components/microphone/microphone.h b/esphome/components/microphone/microphone.h index 914ad80bea..ea4e979e20 100644 --- a/esphome/components/microphone/microphone.h +++ b/esphome/components/microphone/microphone.h @@ -1,5 +1,7 @@ #pragma once +#include "esphome/components/audio/audio.h" + #include #include #include @@ -20,18 +22,25 @@ class Microphone { public: virtual void start() = 0; virtual void stop() = 0; - void add_data_callback(std::function &)> &&data_callback) { - this->data_callbacks_.add(std::move(data_callback)); - } - virtual size_t read(int16_t *buf, size_t len) = 0; + void add_data_callback(std::function &)> &&data_callback); bool is_running() const { return this->state_ == STATE_RUNNING; } bool is_stopped() const { return this->state_ == STATE_STOPPED; } - protected: - State state_{STATE_STOPPED}; + void set_mute_state(bool is_muted) { this->mute_state_ = is_muted; } + bool get_mute_state() { return this->mute_state_; } - CallbackManager &)> data_callbacks_{}; + audio::AudioStreamInfo get_audio_stream_info() { return this->audio_stream_info_; } + + protected: + std::vector silence_audio_(std::vector data); + + State state_{STATE_STOPPED}; + bool mute_state_{false}; + + audio::AudioStreamInfo audio_stream_info_; + + CallbackManager &)> data_callbacks_{}; }; } // namespace microphone diff --git a/esphome/components/microphone/microphone_source.cpp b/esphome/components/microphone/microphone_source.cpp new file mode 100644 index 0000000000..00efcf22a1 --- /dev/null +++ b/esphome/components/microphone/microphone_source.cpp @@ -0,0 +1,95 @@ +#include "microphone_source.h" + +namespace esphome { +namespace microphone { + +static const int32_t Q25_MAX_VALUE = (1 << 25) - 1; +static const int32_t Q25_MIN_VALUE = ~Q25_MAX_VALUE; + +void MicrophoneSource::add_data_callback(std::function &)> &&data_callback) { + std::function &)> filtered_callback = + [this, data_callback](const std::vector &data) { + if (this->enabled_ || this->passive_) { + if (this->processed_samples_.use_count() == 0) { + // Create vector if its unused + this->processed_samples_ = std::make_shared>(); + } + + // Take temporary ownership of samples vector to avoid deallaction before the callback finishes + std::shared_ptr> output_samples = this->processed_samples_; + this->process_audio_(data, *output_samples); + data_callback(*output_samples); + } + }; + this->mic_->add_data_callback(std::move(filtered_callback)); +} + +audio::AudioStreamInfo MicrophoneSource::get_audio_stream_info() { + return audio::AudioStreamInfo(this->bits_per_sample_, this->channels_.count(), + this->mic_->get_audio_stream_info().get_sample_rate()); +} + +void MicrophoneSource::start() { + if (!this->enabled_ && !this->passive_) { + this->enabled_ = true; + this->mic_->start(); + } +} + +void MicrophoneSource::stop() { + if (this->enabled_ && !this->passive_) { + this->enabled_ = false; + this->mic_->stop(); + this->processed_samples_.reset(); + } +} + +void MicrophoneSource::process_audio_(const std::vector &data, std::vector &filtered_data) { + // - Bit depth conversions are obtained by truncating bits or padding with zeros - no dithering is applied. + // - In the comments, Qxx refers to a fixed point number with xx bits of precision for representing fractional values. + // For example, audio with a bit depth of 16 can store a sample in a int16, which can be considered a Q15 number. + // - All samples are converted to Q25 before applying the gain factor - this results in a small precision loss for + // data with 32 bits per sample. Since the maximum gain factor is 64 = (1<<6), this ensures that applying the gain + // will never overflow a 32 bit signed integer. This still retains more bit depth than what is audibly noticeable. + // - Loops for reading/writing data buffers are unrolled, assuming little endian, for a small performance increase. + + const size_t source_bytes_per_sample = this->mic_->get_audio_stream_info().samples_to_bytes(1); + const uint32_t source_channels = this->mic_->get_audio_stream_info().get_channels(); + + const size_t source_bytes_per_frame = this->mic_->get_audio_stream_info().frames_to_bytes(1); + + const uint32_t total_frames = this->mic_->get_audio_stream_info().bytes_to_frames(data.size()); + const size_t target_bytes_per_sample = (this->bits_per_sample_ + 7) / 8; + const size_t target_bytes_per_frame = target_bytes_per_sample * this->channels_.count(); + + filtered_data.resize(target_bytes_per_frame * total_frames); + + uint8_t *current_data = filtered_data.data(); + + for (uint32_t frame_index = 0; frame_index < total_frames; ++frame_index) { + for (uint32_t channel_index = 0; channel_index < source_channels; ++channel_index) { + if (this->channels_.test(channel_index)) { + // Channel's current sample is included in the target mask. Convert bits per sample, if necessary. + + const uint32_t sample_index = frame_index * source_bytes_per_frame + channel_index * source_bytes_per_sample; + + int32_t sample = audio::unpack_audio_sample_to_q31(&data[sample_index], source_bytes_per_sample); // Q31 + sample >>= 6; // Q31 -> Q25 + + // Apply gain using multiplication + sample *= this->gain_factor_; // Q25 + + // Clamp ``sample`` in case gain multiplication overflows 25 bits + sample = clamp(sample, Q25_MIN_VALUE, Q25_MAX_VALUE); // Q25 + + sample *= (1 << 6); // Q25 -> Q31 + + audio::pack_q31_as_audio_sample(sample, current_data, target_bytes_per_sample); + current_data = current_data + target_bytes_per_sample; + } + } + } +} + +} // namespace microphone +} // namespace esphome diff --git a/esphome/components/microphone/microphone_source.h b/esphome/components/microphone/microphone_source.h new file mode 100644 index 0000000000..1e81a284b6 --- /dev/null +++ b/esphome/components/microphone/microphone_source.h @@ -0,0 +1,80 @@ +#pragma once + +#include "microphone.h" + +#include "esphome/components/audio/audio.h" + +#include +#include +#include +#include +#include + +namespace esphome { +namespace microphone { + +static const int32_t MAX_GAIN_FACTOR = 64; + +class MicrophoneSource { + /* + * @brief Helper class that handles converting raw microphone data to a requested format. + * Components requesting microphone audio should register a callback through this class instead of registering a + * callback directly with the microphone if a particular format is required. + * + * Raw microphone data may have a different number of bits per sample and number of channels than the requesting + * component needs. This class handles the conversion by: + * - Internally adds a callback to receive the raw microphone data + * - The ``process_audio_`` handles the raw data + * - Only the channels set in the ``channels_`` bitset are passed through + * - Passed through samples have the bits per sample converted + * - A gain factor is optionally applied to increase the volume - audio may clip! + * - The processed audio is passed to the callback of the component requesting microphone data + * - It tracks an internal enabled state, so it ignores raw microphone data when the component requesting + * microphone data is not actively requesting audio. + * + * Note that this class cannot convert sample rates! + */ + public: + MicrophoneSource(Microphone *mic, uint8_t bits_per_sample, int32_t gain_factor, bool passive) + : mic_(mic), bits_per_sample_(bits_per_sample), gain_factor_(gain_factor), passive_(passive) {} + + /// @brief Enables a channel to be processed through the callback. + /// + /// If the microphone component only has reads from one channel, it is always in channel number 0, regardless if it + /// represents left or right. If the microphone reads from both left and right, channel number 0 and 1 represent the + /// left and right channels respectively. + /// + /// @param channel 0-indexed channel number to enable + void add_channel(uint8_t channel) { this->channels_.set(channel); } + + void add_data_callback(std::function &)> &&data_callback); + + void set_gain_factor(int32_t gain_factor) { this->gain_factor_ = clamp(gain_factor, 1, MAX_GAIN_FACTOR); } + int32_t get_gain_factor() { return this->gain_factor_; } + + /// @brief Gets the AudioStreamInfo of the data after processing + /// @return audio::AudioStreamInfo with the configured bits per sample, configured channel count, and source + /// microphone's sample rate + audio::AudioStreamInfo get_audio_stream_info(); + + void start(); + void stop(); + bool is_passive() const { return this->passive_; } + bool is_running() const { return (this->mic_->is_running() && (this->enabled_ || this->passive_)); } + bool is_stopped() const { return !this->is_running(); }; + + protected: + void process_audio_(const std::vector &data, std::vector &filtered_data); + + std::shared_ptr> processed_samples_; + + Microphone *mic_; + uint8_t bits_per_sample_; + std::bitset<8> channels_; + int32_t gain_factor_; + bool enabled_{false}; + bool passive_; // Only pass audio if ``mic_`` is already running +}; + +} // namespace microphone +} // namespace esphome diff --git a/esphome/components/mics_4514/sensor.py b/esphome/components/mics_4514/sensor.py index 59ccba235a..09329ebfcf 100644 --- a/esphome/components/mics_4514/sensor.py +++ b/esphome/components/mics_4514/sensor.py @@ -9,6 +9,8 @@ from esphome.const import ( CONF_ID, CONF_METHANE, CONF_NITROGEN_DIOXIDE, + DEVICE_CLASS_CARBON_MONOXIDE, + DEVICE_CLASS_EMPTY, STATE_CLASS_MEASUREMENT, UNIT_PARTS_PER_MILLION, ) @@ -22,24 +24,33 @@ MICS4514Component = mics_4514_ns.class_( "MICS4514Component", cg.PollingComponent, i2c.I2CDevice ) -SENSORS = [ - CONF_CARBON_MONOXIDE, - CONF_METHANE, - CONF_ETHANOL, - CONF_HYDROGEN, - CONF_AMMONIA, - CONF_NITROGEN_DIOXIDE, -] +SENSORS = { + CONF_CARBON_MONOXIDE: DEVICE_CLASS_CARBON_MONOXIDE, + CONF_METHANE: DEVICE_CLASS_EMPTY, + CONF_ETHANOL: DEVICE_CLASS_EMPTY, + CONF_HYDROGEN: DEVICE_CLASS_EMPTY, + CONF_AMMONIA: DEVICE_CLASS_EMPTY, + CONF_NITROGEN_DIOXIDE: DEVICE_CLASS_EMPTY, +} + + +def common_sensor_schema(*, device_class: str) -> cv.Schema: + return sensor.sensor_schema( + accuracy_decimals=2, + device_class=device_class, + state_class=STATE_CLASS_MEASUREMENT, + unit_of_measurement=UNIT_PARTS_PER_MILLION, + ) -common_sensor_schema = sensor.sensor_schema( - unit_of_measurement=UNIT_PARTS_PER_MILLION, - state_class=STATE_CLASS_MEASUREMENT, - accuracy_decimals=2, -) CONFIG_SCHEMA = ( cv.Schema({cv.GenerateID(): cv.declare_id(MICS4514Component)}) - .extend({cv.Optional(sensor_type): common_sensor_schema for sensor_type in SENSORS}) + .extend( + { + cv.Optional(sensor_type): common_sensor_schema(device_class=device_class) + for sensor_type, device_class in SENSORS.items() + } + ) .extend(i2c.i2c_device_schema(0x75)) .extend(cv.polling_component_schema("60s")) ) diff --git a/esphome/components/mipi_spi/__init__.py b/esphome/components/mipi_spi/__init__.py new file mode 100644 index 0000000000..46b0206a1f --- /dev/null +++ b/esphome/components/mipi_spi/__init__.py @@ -0,0 +1,15 @@ +CODEOWNERS = ["@clydebarrow"] + +DOMAIN = "mipi_spi" + +CONF_DRAW_FROM_ORIGIN = "draw_from_origin" +CONF_SPI_16 = "spi_16" +CONF_PIXEL_MODE = "pixel_mode" +CONF_COLOR_DEPTH = "color_depth" +CONF_BUS_MODE = "bus_mode" +CONF_USE_AXIS_FLIPS = "use_axis_flips" +CONF_NATIVE_WIDTH = "native_width" +CONF_NATIVE_HEIGHT = "native_height" + +MODE_RGB = "RGB" +MODE_BGR = "BGR" diff --git a/esphome/components/mipi_spi/display.py b/esphome/components/mipi_spi/display.py new file mode 100644 index 0000000000..e9ed97a2a2 --- /dev/null +++ b/esphome/components/mipi_spi/display.py @@ -0,0 +1,474 @@ +import logging + +from esphome import pins +import esphome.codegen as cg +from esphome.components import display, spi +from esphome.components.spi import TYPE_OCTAL, TYPE_QUAD, TYPE_SINGLE +import esphome.config_validation as cv +from esphome.config_validation import ALLOW_EXTRA +from esphome.const import ( + CONF_BRIGHTNESS, + CONF_COLOR_ORDER, + CONF_CS_PIN, + CONF_DATA_RATE, + CONF_DC_PIN, + CONF_DIMENSIONS, + CONF_ENABLE_PIN, + CONF_HEIGHT, + CONF_ID, + CONF_INIT_SEQUENCE, + CONF_INVERT_COLORS, + CONF_LAMBDA, + CONF_MIRROR_X, + CONF_MIRROR_Y, + CONF_MODEL, + CONF_OFFSET_HEIGHT, + CONF_OFFSET_WIDTH, + CONF_RESET_PIN, + CONF_ROTATION, + CONF_SWAP_XY, + CONF_TRANSFORM, + CONF_WIDTH, +) +from esphome.core import TimePeriod + +from ..const import CONF_DRAW_ROUNDING +from ..lvgl.defines import CONF_COLOR_DEPTH +from . import ( + CONF_BUS_MODE, + CONF_DRAW_FROM_ORIGIN, + CONF_NATIVE_HEIGHT, + CONF_NATIVE_WIDTH, + CONF_PIXEL_MODE, + CONF_SPI_16, + CONF_USE_AXIS_FLIPS, + DOMAIN, + MODE_BGR, + MODE_RGB, +) +from .models import ( + DELAY_FLAG, + MADCTL_BGR, + MADCTL_MV, + MADCTL_MX, + MADCTL_MY, + MADCTL_XFLIP, + MADCTL_YFLIP, + DriverChip, + amoled, + cyd, + ili, + jc, + lanbon, + lilygo, + waveshare, +) +from .models.commands import BRIGHTNESS, DISPON, INVOFF, INVON, MADCTL, PIXFMT, SLPOUT + +DEPENDENCIES = ["spi"] + +LOGGER = logging.getLogger(DOMAIN) +mipi_spi_ns = cg.esphome_ns.namespace("mipi_spi") +MipiSpi = mipi_spi_ns.class_( + "MipiSpi", display.Display, display.DisplayBuffer, cg.Component, spi.SPIDevice +) +ColorOrder = display.display_ns.enum("ColorMode") +ColorBitness = display.display_ns.enum("ColorBitness") +Model = mipi_spi_ns.enum("Model") + +COLOR_ORDERS = { + MODE_RGB: ColorOrder.COLOR_ORDER_RGB, + MODE_BGR: ColorOrder.COLOR_ORDER_BGR, +} + +COLOR_DEPTHS = { + 8: ColorBitness.COLOR_BITNESS_332, + 16: ColorBitness.COLOR_BITNESS_565, +} +DATA_PIN_SCHEMA = pins.internal_gpio_output_pin_schema + + +DriverChip("CUSTOM", initsequence={}) + +MODELS = DriverChip.models +# These statements are noops, but serve to suppress linting of side-effect-only imports +for _ in (ili, jc, amoled, lilygo, lanbon, cyd, waveshare): + pass + +PixelMode = mipi_spi_ns.enum("PixelMode") + +PIXEL_MODE_18BIT = "18bit" +PIXEL_MODE_16BIT = "16bit" + +PIXEL_MODES = { + PIXEL_MODE_16BIT: 0x55, + PIXEL_MODE_18BIT: 0x66, +} + + +def validate_dimension(rounding): + def validator(value): + value = cv.positive_int(value) + if value % rounding != 0: + raise cv.Invalid(f"Dimensions and offsets must be divisible by {rounding}") + return value + + return validator + + +def map_sequence(value): + """ + The format is a repeated sequence of [CMD, ] where is s a sequence of bytes. The length is inferred + from the length of the sequence and should not be explicit. + A delay can be inserted by specifying "- delay N" where N is in ms + """ + if isinstance(value, str) and value.lower().startswith("delay "): + value = value.lower()[6:] + delay = cv.All( + cv.positive_time_period_milliseconds, + cv.Range(TimePeriod(milliseconds=1), TimePeriod(milliseconds=255)), + )(value) + return DELAY_FLAG, delay.total_milliseconds + if isinstance(value, int): + return (value,) + value = cv.All(cv.ensure_list(cv.int_range(0, 255)), cv.Length(1, 254))(value) + return tuple(value) + + +def power_of_two(value): + value = cv.int_range(1, 128)(value) + if value & (value - 1) != 0: + raise cv.Invalid("value must be a power of two") + return value + + +def dimension_schema(rounding): + return cv.Any( + cv.dimensions, + cv.Schema( + { + cv.Required(CONF_WIDTH): validate_dimension(rounding), + cv.Required(CONF_HEIGHT): validate_dimension(rounding), + cv.Optional(CONF_OFFSET_HEIGHT, default=0): validate_dimension( + rounding + ), + cv.Optional(CONF_OFFSET_WIDTH, default=0): validate_dimension(rounding), + } + ), + ) + + +def model_schema(bus_mode, model: DriverChip, swapsies: bool): + transform = cv.Schema( + { + cv.Required(CONF_MIRROR_X): cv.boolean, + cv.Required(CONF_MIRROR_Y): cv.boolean, + } + ) + if model.get_default(CONF_SWAP_XY, False) == cv.UNDEFINED: + transform = transform.extend( + { + cv.Optional(CONF_SWAP_XY): cv.invalid( + "Axis swapping not supported by this model" + ) + } + ) + else: + transform = transform.extend( + { + cv.Required(CONF_SWAP_XY): cv.boolean, + } + ) + # CUSTOM model will need to provide a custom init sequence + iseqconf = ( + cv.Required(CONF_INIT_SEQUENCE) + if model.initsequence is None + else cv.Optional(CONF_INIT_SEQUENCE) + ) + # Dimensions are optional if the model has a default width and the transform is not overridden + cv_dimensions = ( + cv.Optional if model.get_default(CONF_WIDTH) and not swapsies else cv.Required + ) + pixel_modes = PIXEL_MODES if bus_mode == TYPE_SINGLE else (PIXEL_MODE_16BIT,) + color_depth = ( + ("16", "8", "16bit", "8bit") if bus_mode == TYPE_SINGLE else ("16", "16bit") + ) + schema = ( + display.FULL_DISPLAY_SCHEMA.extend( + spi.spi_device_schema( + cs_pin_required=False, + default_mode="MODE3" if bus_mode == TYPE_OCTAL else "MODE0", + default_data_rate=model.get_default(CONF_DATA_RATE, 10_000_000), + mode=bus_mode, + ) + ) + .extend( + { + model.option(pin, cv.UNDEFINED): pins.gpio_output_pin_schema + for pin in (CONF_RESET_PIN, CONF_CS_PIN, CONF_DC_PIN) + } + ) + .extend( + { + cv.GenerateID(): cv.declare_id(MipiSpi), + cv_dimensions(CONF_DIMENSIONS): dimension_schema( + model.get_default(CONF_DRAW_ROUNDING, 1) + ), + model.option(CONF_ENABLE_PIN, cv.UNDEFINED): cv.ensure_list( + pins.gpio_output_pin_schema + ), + model.option(CONF_COLOR_ORDER, MODE_BGR): cv.enum( + COLOR_ORDERS, upper=True + ), + model.option(CONF_COLOR_DEPTH, 16): cv.one_of(*color_depth, lower=True), + model.option(CONF_DRAW_ROUNDING, 2): power_of_two, + model.option(CONF_PIXEL_MODE, PIXEL_MODE_16BIT): cv.Any( + cv.one_of(*pixel_modes, lower=True), + cv.int_range(0, 255, min_included=True, max_included=True), + ), + cv.Optional(CONF_TRANSFORM): transform, + cv.Optional(CONF_BUS_MODE, default=bus_mode): cv.one_of( + bus_mode, lower=True + ), + cv.Required(CONF_MODEL): cv.one_of(model.name, upper=True), + iseqconf: cv.ensure_list(map_sequence), + } + ) + .extend( + { + model.option(x): cv.boolean + for x in [ + CONF_DRAW_FROM_ORIGIN, + CONF_SPI_16, + CONF_INVERT_COLORS, + CONF_USE_AXIS_FLIPS, + ] + } + ) + ) + if brightness := model.get_default(CONF_BRIGHTNESS): + schema = schema.extend( + { + cv.Optional(CONF_BRIGHTNESS, default=brightness): cv.int_range( + 0, 0xFF, min_included=True, max_included=True + ), + } + ) + if bus_mode != TYPE_SINGLE: + return cv.All(schema, cv.only_with_esp_idf) + return schema + + +def rotation_as_transform(model, config): + """ + Check if a rotation can be implemented in hardware using the MADCTL register. + A rotation of 180 is always possible, 90 and 270 are possible if the model supports swapping X and Y. + """ + rotation = config.get(CONF_ROTATION, 0) + return rotation and ( + model.get_default(CONF_SWAP_XY) != cv.UNDEFINED or rotation == 180 + ) + + +def config_schema(config): + # First get the model and bus mode + config = cv.Schema( + { + cv.Required(CONF_MODEL): cv.one_of(*MODELS, upper=True), + }, + extra=ALLOW_EXTRA, + )(config) + model = MODELS[config[CONF_MODEL]] + bus_modes = model.modes + config = cv.Schema( + { + model.option(CONF_BUS_MODE, TYPE_SINGLE): cv.one_of(*bus_modes, lower=True), + cv.Required(CONF_MODEL): cv.one_of(*MODELS, upper=True), + }, + extra=ALLOW_EXTRA, + )(config) + bus_mode = config.get(CONF_BUS_MODE, model.modes[0]) + swapsies = config.get(CONF_TRANSFORM, {}).get(CONF_SWAP_XY) is True + config = model_schema(bus_mode, model, swapsies)(config) + # Check for invalid combinations of MADCTL config + if init_sequence := config.get(CONF_INIT_SEQUENCE): + if MADCTL in [x[0] for x in init_sequence] and CONF_TRANSFORM in config: + raise cv.Invalid( + f"transform is not supported when MADCTL ({MADCTL:#X}) is in the init sequence" + ) + + if bus_mode == TYPE_QUAD and CONF_DC_PIN in config: + raise cv.Invalid("DC pin is not supported in quad mode") + if config[CONF_PIXEL_MODE] == PIXEL_MODE_18BIT and bus_mode != TYPE_SINGLE: + raise cv.Invalid("18-bit pixel mode is not supported on a quad or octal bus") + if bus_mode != TYPE_QUAD and CONF_DC_PIN not in config: + raise cv.Invalid(f"DC pin is required in {bus_mode} mode") + return config + + +CONFIG_SCHEMA = config_schema + + +def get_transform(model, config): + can_transform = rotation_as_transform(model, config) + transform = config.get( + CONF_TRANSFORM, + { + CONF_MIRROR_X: model.get_default(CONF_MIRROR_X, False), + CONF_MIRROR_Y: model.get_default(CONF_MIRROR_Y, False), + CONF_SWAP_XY: model.get_default(CONF_SWAP_XY, False), + }, + ) + + # Can we use the MADCTL register to set the rotation? + if can_transform and CONF_TRANSFORM not in config: + rotation = config[CONF_ROTATION] + if rotation == 180: + transform[CONF_MIRROR_X] = not transform[CONF_MIRROR_X] + transform[CONF_MIRROR_Y] = not transform[CONF_MIRROR_Y] + elif rotation == 90: + transform[CONF_SWAP_XY] = not transform[CONF_SWAP_XY] + transform[CONF_MIRROR_X] = not transform[CONF_MIRROR_X] + else: + transform[CONF_SWAP_XY] = not transform[CONF_SWAP_XY] + transform[CONF_MIRROR_Y] = not transform[CONF_MIRROR_Y] + transform[CONF_TRANSFORM] = True + return transform + + +def get_sequence(model, config): + """ + Create the init sequence for the display. + Use the default sequence from the model, if any, and append any custom sequence provided in the config. + Append SLPOUT (if not already in the sequence) and DISPON to the end of the sequence + Pixel format, color order, and orientation will be set. + """ + sequence = list(model.initsequence) + custom_sequence = config.get(CONF_INIT_SEQUENCE, []) + sequence.extend(custom_sequence) + # Ensure each command is a tuple + sequence = [x if isinstance(x, tuple) else (x,) for x in sequence] + commands = [x[0] for x in sequence] + # Set pixel format if not already in the custom sequence + if PIXFMT not in commands: + pixel_mode = config[CONF_PIXEL_MODE] + if not isinstance(pixel_mode, int): + pixel_mode = PIXEL_MODES[pixel_mode] + sequence.append((PIXFMT, pixel_mode)) + # Does the chip use the flipping bits for mirroring rather than the reverse order bits? + use_flip = config[CONF_USE_AXIS_FLIPS] + if MADCTL not in commands: + madctl = 0 + transform = get_transform(model, config) + if transform.get(CONF_TRANSFORM): + LOGGER.info("Using hardware transform to implement rotation") + if transform.get(CONF_MIRROR_X): + madctl |= MADCTL_XFLIP if use_flip else MADCTL_MX + if transform.get(CONF_MIRROR_Y): + madctl |= MADCTL_YFLIP if use_flip else MADCTL_MY + if transform.get(CONF_SWAP_XY) is True: # Exclude Undefined + madctl |= MADCTL_MV + if config[CONF_COLOR_ORDER] == MODE_BGR: + madctl |= MADCTL_BGR + sequence.append((MADCTL, madctl)) + if INVON not in commands and INVOFF not in commands: + if config[CONF_INVERT_COLORS]: + sequence.append((INVON,)) + else: + sequence.append((INVOFF,)) + if BRIGHTNESS not in commands: + if brightness := config.get( + CONF_BRIGHTNESS, model.get_default(CONF_BRIGHTNESS) + ): + sequence.append((BRIGHTNESS, brightness)) + if SLPOUT not in commands: + sequence.append((SLPOUT,)) + sequence.append((DISPON,)) + + # Flatten the sequence into a list of bytes, with the length of each command + # or the delay flag inserted where needed + return sum( + tuple( + (x[1], 0xFF) if x[0] == DELAY_FLAG else (x[0], len(x) - 1) + x[1:] + for x in sequence + ), + (), + ) + + +async def to_code(config): + model = MODELS[config[CONF_MODEL]] + transform = get_transform(model, config) + if CONF_DIMENSIONS in config: + # Explicit dimensions, just use as is + dimensions = config[CONF_DIMENSIONS] + if isinstance(dimensions, dict): + width = dimensions[CONF_WIDTH] + height = dimensions[CONF_HEIGHT] + offset_width = dimensions[CONF_OFFSET_WIDTH] + offset_height = dimensions[CONF_OFFSET_HEIGHT] + else: + (width, height) = dimensions + offset_width = 0 + offset_height = 0 + else: + # Default dimensions, use model defaults and transform if needed + width = model.get_default(CONF_WIDTH) + height = model.get_default(CONF_HEIGHT) + offset_width = model.get_default(CONF_OFFSET_WIDTH, 0) + offset_height = model.get_default(CONF_OFFSET_HEIGHT, 0) + + # if mirroring axes and there are offsets, also mirror the offsets to cater for situations where + # the offset is asymmetric + if transform[CONF_MIRROR_X]: + native_width = model.get_default( + CONF_NATIVE_WIDTH, width + offset_width * 2 + ) + offset_width = native_width - width - offset_width + if transform[CONF_MIRROR_Y]: + native_height = model.get_default( + CONF_NATIVE_HEIGHT, height + offset_height * 2 + ) + offset_height = native_height - height - offset_height + # Swap default dimensions if swap_xy is set + if transform[CONF_SWAP_XY] is True: + width, height = height, width + offset_height, offset_width = offset_width, offset_height + + color_depth = config[CONF_COLOR_DEPTH] + if color_depth.endswith("bit"): + color_depth = color_depth[:-3] + color_depth = COLOR_DEPTHS[int(color_depth)] + + var = cg.new_Pvariable( + config[CONF_ID], width, height, offset_width, offset_height, color_depth + ) + cg.add(var.set_init_sequence(get_sequence(model, config))) + if rotation_as_transform(model, config): + if CONF_TRANSFORM in config: + LOGGER.warning("Use of 'transform' with 'rotation' is not recommended") + else: + config[CONF_ROTATION] = 0 + cg.add(var.set_model(config[CONF_MODEL])) + cg.add(var.set_draw_from_origin(config[CONF_DRAW_FROM_ORIGIN])) + cg.add(var.set_draw_rounding(config[CONF_DRAW_ROUNDING])) + cg.add(var.set_spi_16(config[CONF_SPI_16])) + if enable_pin := config.get(CONF_ENABLE_PIN): + enable = [await cg.gpio_pin_expression(pin) for pin in enable_pin] + cg.add(var.set_enable_pins(enable)) + + if reset_pin := config.get(CONF_RESET_PIN): + reset = await cg.gpio_pin_expression(reset_pin) + cg.add(var.set_reset_pin(reset)) + + if dc_pin := config.get(CONF_DC_PIN): + dc_pin = await cg.gpio_pin_expression(dc_pin) + cg.add(var.set_dc_pin(dc_pin)) + + if lamb := config.get(CONF_LAMBDA): + lambda_ = await cg.process_lambda( + lamb, [(display.DisplayRef, "it")], return_type=cg.void + ) + cg.add(var.set_writer(lambda_)) + await display.register_display(var, config) + await spi.register_spi_device(var, config) diff --git a/esphome/components/mipi_spi/mipi_spi.cpp b/esphome/components/mipi_spi/mipi_spi.cpp new file mode 100644 index 0000000000..2d393ac349 --- /dev/null +++ b/esphome/components/mipi_spi/mipi_spi.cpp @@ -0,0 +1,481 @@ +#include "mipi_spi.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace mipi_spi { + +void MipiSpi::setup() { + ESP_LOGCONFIG(TAG, "Setting up MIPI SPI"); + this->spi_setup(); + if (this->dc_pin_ != nullptr) { + this->dc_pin_->setup(); + this->dc_pin_->digital_write(false); + } + for (auto *pin : this->enable_pins_) { + pin->setup(); + pin->digital_write(true); + } + if (this->reset_pin_ != nullptr) { + this->reset_pin_->setup(); + this->reset_pin_->digital_write(true); + delay(5); + this->reset_pin_->digital_write(false); + delay(5); + this->reset_pin_->digital_write(true); + } + this->bus_width_ = this->parent_->get_bus_width(); + + // need to know when the display is ready for SLPOUT command - will be 120ms after reset + auto when = millis() + 120; + delay(10); + size_t index = 0; + auto &vec = this->init_sequence_; + while (index != vec.size()) { + if (vec.size() - index < 2) { + ESP_LOGE(TAG, "Malformed init sequence"); + this->mark_failed(); + return; + } + uint8_t cmd = vec[index++]; + uint8_t x = vec[index++]; + if (x == DELAY_FLAG) { + ESP_LOGD(TAG, "Delay %dms", cmd); + delay(cmd); + } else { + uint8_t num_args = x & 0x7F; + if (vec.size() - index < num_args) { + ESP_LOGE(TAG, "Malformed init sequence"); + this->mark_failed(); + return; + } + auto arg_byte = vec[index]; + switch (cmd) { + case SLEEP_OUT: { + // are we ready, boots? + int duration = when - millis(); + if (duration > 0) { + ESP_LOGD(TAG, "Sleep %dms", duration); + delay(duration); + } + } break; + + case INVERT_ON: + this->invert_colors_ = true; + break; + case MADCTL_CMD: + this->madctl_ = arg_byte; + break; + case PIXFMT: + this->pixel_mode_ = arg_byte & 0x11 ? PIXEL_MODE_16 : PIXEL_MODE_18; + break; + case BRIGHTNESS: + this->brightness_ = arg_byte; + break; + + default: + break; + } + const auto *ptr = vec.data() + index; + ESP_LOGD(TAG, "Command %02X, length %d, byte %02X", cmd, num_args, arg_byte); + this->write_command_(cmd, ptr, num_args); + index += num_args; + if (cmd == SLEEP_OUT) + delay(10); + } + } + this->setup_complete_ = true; + if (this->draw_from_origin_) + check_buffer_(); + ESP_LOGCONFIG(TAG, "MIPI SPI setup complete"); +} + +void MipiSpi::update() { + if (!this->setup_complete_ || this->is_failed()) { + return; + } + this->do_update_(); + if (this->buffer_ == nullptr || this->x_low_ > this->x_high_ || this->y_low_ > this->y_high_) + return; + ESP_LOGV(TAG, "x_low %d, y_low %d, x_high %d, y_high %d", this->x_low_, this->y_low_, this->x_high_, this->y_high_); + // Some chips require that the drawing window be aligned on certain boundaries + auto dr = this->draw_rounding_; + this->x_low_ = this->x_low_ / dr * dr; + this->y_low_ = this->y_low_ / dr * dr; + this->x_high_ = (this->x_high_ + dr) / dr * dr - 1; + this->y_high_ = (this->y_high_ + dr) / dr * dr - 1; + if (this->draw_from_origin_) { + this->x_low_ = 0; + this->y_low_ = 0; + this->x_high_ = this->width_ - 1; + } + int w = this->x_high_ - this->x_low_ + 1; + int h = this->y_high_ - this->y_low_ + 1; + this->write_to_display_(this->x_low_, this->y_low_, w, h, this->buffer_, this->x_low_, this->y_low_, + this->width_ - w - this->x_low_); + // invalidate watermarks + this->x_low_ = this->width_; + this->y_low_ = this->height_; + this->x_high_ = 0; + this->y_high_ = 0; +} + +void MipiSpi::fill(Color color) { + if (!this->check_buffer_()) + return; + this->x_low_ = 0; + this->y_low_ = 0; + this->x_high_ = this->get_width_internal() - 1; + this->y_high_ = this->get_height_internal() - 1; + switch (this->color_depth_) { + case display::COLOR_BITNESS_332: { + auto new_color = display::ColorUtil::color_to_332(color, display::ColorOrder::COLOR_ORDER_RGB); + memset(this->buffer_, (uint8_t) new_color, this->buffer_bytes_); + break; + } + default: { + auto new_color = display::ColorUtil::color_to_565(color); + if (((uint8_t) (new_color >> 8)) == ((uint8_t) new_color)) { + // Upper and lower is equal can use quicker memset operation. Takes ~20ms. + memset(this->buffer_, (uint8_t) new_color, this->buffer_bytes_); + } else { + auto *ptr_16 = reinterpret_cast(this->buffer_); + auto len = this->buffer_bytes_ / 2; + while (len--) { + *ptr_16++ = new_color; + } + } + } + } +} + +void MipiSpi::draw_absolute_pixel_internal(int x, int y, Color color) { + if (x >= this->get_width_internal() || x < 0 || y >= this->get_height_internal() || y < 0) { + return; + } + if (!this->check_buffer_()) + return; + size_t pos = (y * this->width_) + x; + switch (this->color_depth_) { + case display::COLOR_BITNESS_332: { + uint8_t new_color = display::ColorUtil::color_to_332(color); + if (this->buffer_[pos] == new_color) + return; + this->buffer_[pos] = new_color; + break; + } + + case display::COLOR_BITNESS_565: { + auto *ptr_16 = reinterpret_cast(this->buffer_); + uint8_t hi_byte = static_cast(color.r & 0xF8) | (color.g >> 5); + uint8_t lo_byte = static_cast((color.g & 0x1C) << 3) | (color.b >> 3); + uint16_t new_color = hi_byte | (lo_byte << 8); // big endian + if (ptr_16[pos] == new_color) + return; + ptr_16[pos] = new_color; + break; + } + default: + return; + } + // low and high watermark may speed up drawing from buffer + if (x < this->x_low_) + this->x_low_ = x; + if (y < this->y_low_) + this->y_low_ = y; + if (x > this->x_high_) + this->x_high_ = x; + if (y > this->y_high_) + this->y_high_ = y; +} + +void MipiSpi::reset_params_() { + if (!this->is_ready()) + return; + this->write_command_(this->invert_colors_ ? INVERT_ON : INVERT_OFF); + if (this->brightness_.has_value()) + this->write_command_(BRIGHTNESS, this->brightness_.value()); +} + +void MipiSpi::write_init_sequence_() { + size_t index = 0; + auto &vec = this->init_sequence_; + while (index != vec.size()) { + if (vec.size() - index < 2) { + ESP_LOGE(TAG, "Malformed init sequence"); + this->mark_failed(); + return; + } + uint8_t cmd = vec[index++]; + uint8_t x = vec[index++]; + if (x == DELAY_FLAG) { + ESP_LOGV(TAG, "Delay %dms", cmd); + delay(cmd); + } else { + uint8_t num_args = x & 0x7F; + if (vec.size() - index < num_args) { + ESP_LOGE(TAG, "Malformed init sequence"); + this->mark_failed(); + return; + } + const auto *ptr = vec.data() + index; + this->write_command_(cmd, ptr, num_args); + index += num_args; + } + } + this->setup_complete_ = true; + ESP_LOGCONFIG(TAG, "MIPI SPI setup complete"); +} + +void MipiSpi::set_addr_window_(uint16_t x1, uint16_t y1, uint16_t x2, uint16_t y2) { + ESP_LOGVV(TAG, "Set addr %d/%d, %d/%d", x1, y1, x2, y2); + uint8_t buf[4]; + x1 += this->offset_width_; + x2 += this->offset_width_; + y1 += this->offset_height_; + y2 += this->offset_height_; + put16_be(buf, y1); + put16_be(buf + 2, y2); + this->write_command_(RASET, buf, sizeof buf); + put16_be(buf, x1); + put16_be(buf + 2, x2); + this->write_command_(CASET, buf, sizeof buf); +} + +void MipiSpi::draw_pixels_at(int x_start, int y_start, int w, int h, const uint8_t *ptr, display::ColorOrder order, + display::ColorBitness bitness, bool big_endian, int x_offset, int y_offset, int x_pad) { + if (!this->setup_complete_ || this->is_failed()) + return; + if (w <= 0 || h <= 0) + return; + if (bitness != this->color_depth_ || big_endian != (this->bit_order_ == spi::BIT_ORDER_MSB_FIRST)) { + Display::draw_pixels_at(x_start, y_start, w, h, ptr, order, bitness, big_endian, x_offset, y_offset, x_pad); + return; + } + if (this->draw_from_origin_) { + auto stride = x_offset + w + x_pad; + for (int y = 0; y != h; y++) { + memcpy(this->buffer_ + ((y + y_start) * this->width_ + x_start) * 2, + ptr + ((y + y_offset) * stride + x_offset) * 2, w * 2); + } + ptr = this->buffer_; + w = this->width_; + h += y_start; + x_start = 0; + y_start = 0; + x_offset = 0; + y_offset = 0; + } + this->write_to_display_(x_start, y_start, w, h, ptr, x_offset, y_offset, x_pad); +} + +void MipiSpi::write_18_from_16_bit_(const uint16_t *ptr, size_t w, size_t h, size_t stride) { + stride -= w; + uint8_t transfer_buffer[6 * 256]; + size_t idx = 0; // index into transfer_buffer + while (h-- != 0) { + for (auto x = w; x-- != 0;) { + auto color_val = *ptr++; + // deal with byte swapping + transfer_buffer[idx++] = (color_val & 0xF8); // Blue + transfer_buffer[idx++] = ((color_val & 0x7) << 5) | ((color_val & 0xE000) >> 11); // Green + transfer_buffer[idx++] = (color_val >> 5) & 0xF8; // Red + if (idx == sizeof(transfer_buffer)) { + this->write_array(transfer_buffer, idx); + idx = 0; + } + } + ptr += stride; + } + if (idx != 0) + this->write_array(transfer_buffer, idx); +} + +void MipiSpi::write_18_from_8_bit_(const uint8_t *ptr, size_t w, size_t h, size_t stride) { + stride -= w; + uint8_t transfer_buffer[6 * 256]; + size_t idx = 0; // index into transfer_buffer + while (h-- != 0) { + for (auto x = w; x-- != 0;) { + auto color_val = *ptr++; + transfer_buffer[idx++] = color_val & 0xE0; // Red + transfer_buffer[idx++] = (color_val << 3) & 0xE0; // Green + transfer_buffer[idx++] = color_val << 6; // Blue + if (idx == sizeof(transfer_buffer)) { + this->write_array(transfer_buffer, idx); + idx = 0; + } + } + ptr += stride; + } + if (idx != 0) + this->write_array(transfer_buffer, idx); +} + +void MipiSpi::write_16_from_8_bit_(const uint8_t *ptr, size_t w, size_t h, size_t stride) { + stride -= w; + uint8_t transfer_buffer[6 * 256]; + size_t idx = 0; // index into transfer_buffer + while (h-- != 0) { + for (auto x = w; x-- != 0;) { + auto color_val = *ptr++; + transfer_buffer[idx++] = (color_val & 0xE0) | ((color_val & 0x1C) >> 2); + transfer_buffer[idx++] = (color_val & 0x3) << 3; + if (idx == sizeof(transfer_buffer)) { + this->write_array(transfer_buffer, idx); + idx = 0; + } + } + ptr += stride; + } + if (idx != 0) + this->write_array(transfer_buffer, idx); +} + +void MipiSpi::write_to_display_(int x_start, int y_start, int w, int h, const uint8_t *ptr, int x_offset, int y_offset, + int x_pad) { + this->set_addr_window_(x_start, y_start, x_start + w - 1, y_start + h - 1); + auto stride = x_offset + w + x_pad; + const auto *offset_ptr = ptr; + if (this->color_depth_ == display::COLOR_BITNESS_332) { + offset_ptr += y_offset * stride + x_offset; + } else { + stride *= 2; + offset_ptr += y_offset * stride + x_offset * 2; + } + + switch (this->bus_width_) { + case 4: + this->enable(); + if (x_offset == 0 && x_pad == 0 && y_offset == 0) { + // we could deal here with a non-zero y_offset, but if x_offset is zero, y_offset probably will be so don't + // bother + this->write_cmd_addr_data(8, 0x32, 24, WDATA << 8, ptr, w * h * 2, 4); + } else { + this->write_cmd_addr_data(8, 0x32, 24, WDATA << 8, nullptr, 0, 4); + for (int y = 0; y != h; y++) { + this->write_cmd_addr_data(0, 0, 0, 0, offset_ptr, w * 2, 4); + offset_ptr += stride; + } + } + break; + + case 8: + this->write_command_(WDATA); + this->enable(); + if (x_offset == 0 && x_pad == 0 && y_offset == 0) { + this->write_cmd_addr_data(0, 0, 0, 0, ptr, w * h * 2, 8); + } else { + for (int y = 0; y != h; y++) { + this->write_cmd_addr_data(0, 0, 0, 0, offset_ptr, w * 2, 8); + offset_ptr += stride; + } + } + break; + + default: + this->write_command_(WDATA); + this->enable(); + + if (this->color_depth_ == display::COLOR_BITNESS_565) { + // Source buffer is 16-bit RGB565 + if (this->pixel_mode_ == PIXEL_MODE_18) { + // Convert RGB565 to RGB666 + this->write_18_from_16_bit_(reinterpret_cast(offset_ptr), w, h, stride / 2); + } else { + // Direct RGB565 output + if (x_offset == 0 && x_pad == 0 && y_offset == 0) { + this->write_array(ptr, w * h * 2); + } else { + for (int y = 0; y != h; y++) { + this->write_array(offset_ptr, w * 2); + offset_ptr += stride; + } + } + } + } else { + // Source buffer is 8-bit RGB332 + if (this->pixel_mode_ == PIXEL_MODE_18) { + // Convert RGB332 to RGB666 + this->write_18_from_8_bit_(offset_ptr, w, h, stride); + } else { + this->write_16_from_8_bit_(offset_ptr, w, h, stride); + } + break; + } + } + this->disable(); +} + +void MipiSpi::write_command_(uint8_t cmd, const uint8_t *bytes, size_t len) { + ESP_LOGV(TAG, "Command %02X, length %d, bytes %s", cmd, len, format_hex_pretty(bytes, len).c_str()); + if (this->bus_width_ == 4) { + this->enable(); + this->write_cmd_addr_data(8, 0x02, 24, cmd << 8, bytes, len); + this->disable(); + } else if (this->bus_width_ == 8) { + this->dc_pin_->digital_write(false); + this->enable(); + this->write_cmd_addr_data(0, 0, 0, 0, &cmd, 1, 8); + this->disable(); + this->dc_pin_->digital_write(true); + if (len != 0) { + this->enable(); + this->write_cmd_addr_data(0, 0, 0, 0, bytes, len, 8); + this->disable(); + } + } else { + this->dc_pin_->digital_write(false); + this->enable(); + this->write_byte(cmd); + this->disable(); + this->dc_pin_->digital_write(true); + if (len != 0) { + if (this->spi_16_) { + for (size_t i = 0; i != len; i++) { + this->enable(); + this->write_byte(0); + this->write_byte(bytes[i]); + this->disable(); + } + } else { + this->enable(); + this->write_array(bytes, len); + this->disable(); + } + } + } +} + +void MipiSpi::dump_config() { + ESP_LOGCONFIG(TAG, "MIPI_SPI Display"); + ESP_LOGCONFIG(TAG, " Model: %s", this->model_); + ESP_LOGCONFIG(TAG, " Width: %u", this->width_); + ESP_LOGCONFIG(TAG, " Height: %u", this->height_); + if (this->offset_width_ != 0) + ESP_LOGCONFIG(TAG, " Offset width: %u", this->offset_width_); + if (this->offset_height_ != 0) + ESP_LOGCONFIG(TAG, " Offset height: %u", this->offset_height_); + ESP_LOGCONFIG(TAG, " Swap X/Y: %s", YESNO(this->madctl_ & MADCTL_MV)); + ESP_LOGCONFIG(TAG, " Mirror X: %s", YESNO(this->madctl_ & (MADCTL_MX | MADCTL_XFLIP))); + ESP_LOGCONFIG(TAG, " Mirror Y: %s", YESNO(this->madctl_ & (MADCTL_MY | MADCTL_YFLIP))); + ESP_LOGCONFIG(TAG, " Color depth: %d bits", this->color_depth_ == display::COLOR_BITNESS_565 ? 16 : 8); + ESP_LOGCONFIG(TAG, " Invert colors: %s", YESNO(this->invert_colors_)); + ESP_LOGCONFIG(TAG, " Color order: %s", this->madctl_ & MADCTL_BGR ? "BGR" : "RGB"); + ESP_LOGCONFIG(TAG, " Pixel mode: %s", this->pixel_mode_ == PIXEL_MODE_18 ? "18bit" : "16bit"); + if (this->brightness_.has_value()) + ESP_LOGCONFIG(TAG, " Brightness: %u", this->brightness_.value()); + if (this->spi_16_) + ESP_LOGCONFIG(TAG, " SPI 16bit: YES"); + ESP_LOGCONFIG(TAG, " Draw rounding: %u", this->draw_rounding_); + if (this->draw_from_origin_) + ESP_LOGCONFIG(TAG, " Draw from origin: YES"); + LOG_PIN(" CS Pin: ", this->cs_); + LOG_PIN(" Reset Pin: ", this->reset_pin_); + LOG_PIN(" DC Pin: ", this->dc_pin_); + ESP_LOGCONFIG(TAG, " SPI Mode: %d", this->mode_); + ESP_LOGCONFIG(TAG, " SPI Data rate: %dMHz", static_cast(this->data_rate_ / 1000000)); + ESP_LOGCONFIG(TAG, " SPI Bus width: %d", this->bus_width_); +} + +} // namespace mipi_spi +} // namespace esphome diff --git a/esphome/components/mipi_spi/mipi_spi.h b/esphome/components/mipi_spi/mipi_spi.h new file mode 100644 index 0000000000..052ebe3a6b --- /dev/null +++ b/esphome/components/mipi_spi/mipi_spi.h @@ -0,0 +1,171 @@ +#pragma once + +#include + +#include "esphome/components/spi/spi.h" +#include "esphome/components/display/display.h" +#include "esphome/components/display/display_buffer.h" +#include "esphome/components/display/display_color_utils.h" + +namespace esphome { +namespace mipi_spi { + +constexpr static const char *const TAG = "display.mipi_spi"; +static const uint8_t SW_RESET_CMD = 0x01; +static const uint8_t SLEEP_OUT = 0x11; +static const uint8_t NORON = 0x13; +static const uint8_t INVERT_OFF = 0x20; +static const uint8_t INVERT_ON = 0x21; +static const uint8_t ALL_ON = 0x23; +static const uint8_t WRAM = 0x24; +static const uint8_t MIPI = 0x26; +static const uint8_t DISPLAY_ON = 0x29; +static const uint8_t RASET = 0x2B; +static const uint8_t CASET = 0x2A; +static const uint8_t WDATA = 0x2C; +static const uint8_t TEON = 0x35; +static const uint8_t MADCTL_CMD = 0x36; +static const uint8_t PIXFMT = 0x3A; +static const uint8_t BRIGHTNESS = 0x51; +static const uint8_t SWIRE1 = 0x5A; +static const uint8_t SWIRE2 = 0x5B; +static const uint8_t PAGESEL = 0xFE; + +static const uint8_t MADCTL_MY = 0x80; // Bit 7 Bottom to top +static const uint8_t MADCTL_MX = 0x40; // Bit 6 Right to left +static const uint8_t MADCTL_MV = 0x20; // Bit 5 Swap axes +static const uint8_t MADCTL_RGB = 0x00; // Bit 3 Red-Green-Blue pixel order +static const uint8_t MADCTL_BGR = 0x08; // Bit 3 Blue-Green-Red pixel order +static const uint8_t MADCTL_XFLIP = 0x02; // Mirror the display horizontally +static const uint8_t MADCTL_YFLIP = 0x01; // Mirror the display vertically + +static const uint8_t DELAY_FLAG = 0xFF; +// store a 16 bit value in a buffer, big endian. +static inline void put16_be(uint8_t *buf, uint16_t value) { + buf[0] = value >> 8; + buf[1] = value; +} + +enum PixelMode { + PIXEL_MODE_16, + PIXEL_MODE_18, +}; + +class MipiSpi : public display::DisplayBuffer, + public spi::SPIDevice { + public: + MipiSpi(size_t width, size_t height, int16_t offset_width, int16_t offset_height, display::ColorBitness color_depth) + : width_(width), + height_(height), + offset_width_(offset_width), + offset_height_(offset_height), + color_depth_(color_depth) {} + void set_model(const char *model) { this->model_ = model; } + void update() override; + void setup() override; + display::ColorOrder get_color_mode() { + return this->madctl_ & MADCTL_BGR ? display::COLOR_ORDER_BGR : display::COLOR_ORDER_RGB; + } + + void set_reset_pin(GPIOPin *reset_pin) { this->reset_pin_ = reset_pin; } + void set_enable_pins(std::vector enable_pins) { this->enable_pins_ = std::move(enable_pins); } + void set_dc_pin(GPIOPin *dc_pin) { this->dc_pin_ = dc_pin; } + void set_invert_colors(bool invert_colors) { + this->invert_colors_ = invert_colors; + this->reset_params_(); + } + void set_brightness(uint8_t brightness) { + this->brightness_ = brightness; + this->reset_params_(); + } + + void set_draw_from_origin(bool draw_from_origin) { this->draw_from_origin_ = draw_from_origin; } + display::DisplayType get_display_type() override { return display::DisplayType::DISPLAY_TYPE_COLOR; } + void dump_config() override; + + int get_width_internal() override { return this->width_; } + int get_height_internal() override { return this->height_; } + bool can_proceed() override { return this->setup_complete_; } + void set_init_sequence(const std::vector &sequence) { this->init_sequence_ = sequence; } + void set_draw_rounding(unsigned rounding) { this->draw_rounding_ = rounding; } + void set_spi_16(bool spi_16) { this->spi_16_ = spi_16; } + + protected: + bool check_buffer_() { + if (this->is_failed()) + return false; + if (this->buffer_ != nullptr) + return true; + auto bytes_per_pixel = this->color_depth_ == display::COLOR_BITNESS_565 ? 2 : 1; + this->init_internal_(this->width_ * this->height_ * bytes_per_pixel); + if (this->buffer_ == nullptr) { + this->mark_failed(); + return false; + } + this->buffer_bytes_ = this->width_ * this->height_ * bytes_per_pixel; + return true; + } + void fill(Color color) override; + void draw_absolute_pixel_internal(int x, int y, Color color) override; + void draw_pixels_at(int x_start, int y_start, int w, int h, const uint8_t *ptr, display::ColorOrder order, + display::ColorBitness bitness, bool big_endian, int x_offset, int y_offset, int x_pad) override; + void write_18_from_16_bit_(const uint16_t *ptr, size_t w, size_t h, size_t stride); + void write_18_from_8_bit_(const uint8_t *ptr, size_t w, size_t h, size_t stride); + void write_16_from_8_bit_(const uint8_t *ptr, size_t w, size_t h, size_t stride); + void write_to_display_(int x_start, int y_start, int w, int h, const uint8_t *ptr, int x_offset, int y_offset, + int x_pad); + /** + * the RM67162 in quad SPI mode seems to work like this (not in the datasheet, this is deduced from the + * sample code.) + * + * Immediately after enabling /CS send 4 bytes in single-dataline SPI mode: + * 0: either 0x2 or 0x32. The first indicates that any subsequent data bytes after the initial 4 will be + * sent in 1-dataline SPI. The second indicates quad mode. + * 1: 0x00 + * 2: The command (register address) byte. + * 3: 0x00 + * + * This is followed by zero or more data bytes in either 1-wire or 4-wire mode, depending on the first byte. + * At the conclusion of the write, de-assert /CS. + * + * @param cmd + * @param bytes + * @param len + */ + void write_command_(uint8_t cmd, const uint8_t *bytes, size_t len); + + void write_command_(uint8_t cmd, uint8_t data) { this->write_command_(cmd, &data, 1); } + void write_command_(uint8_t cmd) { this->write_command_(cmd, &cmd, 0); } + void reset_params_(); + void write_init_sequence_(); + void set_addr_window_(uint16_t x1, uint16_t y1, uint16_t x2, uint16_t y2); + + GPIOPin *reset_pin_{nullptr}; + std::vector enable_pins_{}; + GPIOPin *dc_pin_{nullptr}; + uint16_t x_low_{1}; + uint16_t y_low_{1}; + uint16_t x_high_{0}; + uint16_t y_high_{0}; + bool setup_complete_{}; + + bool invert_colors_{}; + size_t width_; + size_t height_; + int16_t offset_width_; + int16_t offset_height_; + size_t buffer_bytes_{0}; + display::ColorBitness color_depth_; + PixelMode pixel_mode_{PIXEL_MODE_16}; + uint8_t bus_width_{}; + bool spi_16_{}; + uint8_t madctl_{}; + bool draw_from_origin_{false}; + unsigned draw_rounding_{2}; + optional brightness_{}; + const char *model_{"Unknown"}; + std::vector init_sequence_{}; +}; +} // namespace mipi_spi +} // namespace esphome diff --git a/esphome/components/mipi_spi/models/__init__.py b/esphome/components/mipi_spi/models/__init__.py new file mode 100644 index 0000000000..e9726032d4 --- /dev/null +++ b/esphome/components/mipi_spi/models/__init__.py @@ -0,0 +1,65 @@ +from esphome.components.spi import TYPE_OCTAL, TYPE_QUAD, TYPE_SINGLE +import esphome.config_validation as cv +from esphome.const import CONF_HEIGHT, CONF_OFFSET_HEIGHT, CONF_OFFSET_WIDTH, CONF_WIDTH + +from .. import CONF_NATIVE_HEIGHT, CONF_NATIVE_WIDTH + +MADCTL_MY = 0x80 # Bit 7 Bottom to top +MADCTL_MX = 0x40 # Bit 6 Right to left +MADCTL_MV = 0x20 # Bit 5 Reverse Mode +MADCTL_ML = 0x10 # Bit 4 LCD refresh Bottom to top +MADCTL_RGB = 0x00 # Bit 3 Red-Green-Blue pixel order +MADCTL_BGR = 0x08 # Bit 3 Blue-Green-Red pixel order +MADCTL_MH = 0x04 # Bit 2 LCD refresh right to left + +# These bits are used instead of the above bits on some chips, where using MX and MY results in incorrect +# partial updates. +MADCTL_XFLIP = 0x02 # Mirror the display horizontally +MADCTL_YFLIP = 0x01 # Mirror the display vertically + +DELAY_FLAG = 0xFFF # Special flag to indicate a delay + + +def delay(ms): + return DELAY_FLAG, ms + + +class DriverChip: + models = {} + + def __init__( + self, + name: str, + modes=(TYPE_SINGLE, TYPE_QUAD, TYPE_OCTAL), + initsequence=None, + **defaults, + ): + name = name.upper() + self.name = name + self.modes = modes + self.initsequence = initsequence + self.defaults = defaults + DriverChip.models[name] = self + + def extend(self, name, **kwargs): + defaults = self.defaults.copy() + if ( + CONF_WIDTH in defaults + and CONF_OFFSET_WIDTH in kwargs + and CONF_NATIVE_WIDTH not in defaults + ): + defaults[CONF_NATIVE_WIDTH] = defaults[CONF_WIDTH] + if ( + CONF_HEIGHT in defaults + and CONF_OFFSET_HEIGHT in kwargs + and CONF_NATIVE_HEIGHT not in defaults + ): + defaults[CONF_NATIVE_HEIGHT] = defaults[CONF_HEIGHT] + defaults.update(kwargs) + return DriverChip(name, self.modes, initsequence=self.initsequence, **defaults) + + def get_default(self, key, fallback=False): + return self.defaults.get(key, fallback) + + def option(self, name, fallback=False): + return cv.Optional(name, default=self.get_default(name, fallback)) diff --git a/esphome/components/mipi_spi/models/amoled.py b/esphome/components/mipi_spi/models/amoled.py new file mode 100644 index 0000000000..14277b243f --- /dev/null +++ b/esphome/components/mipi_spi/models/amoled.py @@ -0,0 +1,72 @@ +from esphome.components.spi import TYPE_QUAD + +from .. import MODE_RGB +from . import DriverChip, delay +from .commands import MIPI, NORON, PAGESEL, PIXFMT, SLPOUT, SWIRE1, SWIRE2, TEON, WRAM + +DriverChip( + "T-DISPLAY-S3-AMOLED", + width=240, + height=536, + cs_pin=6, + reset_pin=17, + enable_pin=38, + bus_mode=TYPE_QUAD, + brightness=0xD0, + color_order=MODE_RGB, + initsequence=(SLPOUT,), # Requires early SLPOUT +) + +DriverChip( + name="T-DISPLAY-S3-AMOLED-PLUS", + width=240, + height=536, + cs_pin=6, + reset_pin=17, + dc_pin=7, + enable_pin=38, + data_rate="40MHz", + brightness=0xD0, + color_order=MODE_RGB, + initsequence=( + (PAGESEL, 4), + (0x6A, 0x00), + (PAGESEL, 0x05), + (PAGESEL, 0x07), + (0x07, 0x4F), + (PAGESEL, 0x01), + (0x2A, 0x02), + (0x2B, 0x73), + (PAGESEL, 0x0A), + (0x29, 0x10), + (PAGESEL, 0x00), + (0x53, 0x20), + (TEON, 0x00), + (PIXFMT, 0x75), + (0xC4, 0x80), + ), +) + +RM690B0 = DriverChip( + "RM690B0", + brightness=0xD0, + color_order=MODE_RGB, + width=480, + height=600, + initsequence=( + (PAGESEL, 0x20), + (MIPI, 0x0A), + (WRAM, 0x80), + (SWIRE1, 0x51), + (SWIRE2, 0x2E), + (PAGESEL, 0x00), + (0xC2, 0x00), + delay(10), + (TEON, 0x00), + (NORON,), + ), +) + +T4_S3_AMOLED = RM690B0.extend("T4-S3", width=450, offset_width=16, bus_mode=TYPE_QUAD) + +models = {} diff --git a/esphome/components/mipi_spi/models/commands.py b/esphome/components/mipi_spi/models/commands.py new file mode 100644 index 0000000000..032a6e6b2b --- /dev/null +++ b/esphome/components/mipi_spi/models/commands.py @@ -0,0 +1,82 @@ +# MIPI DBI commands + +NOP = 0x00 +SWRESET = 0x01 +RDDID = 0x04 +RDDST = 0x09 +RDMODE = 0x0A +RDMADCTL = 0x0B +RDPIXFMT = 0x0C +RDIMGFMT = 0x0D +RDSELFDIAG = 0x0F +SLEEP_IN = 0x10 +SLPIN = 0x10 +SLEEP_OUT = 0x11 +SLPOUT = 0x11 +PTLON = 0x12 +NORON = 0x13 +INVERT_OFF = 0x20 +INVOFF = 0x20 +INVERT_ON = 0x21 +INVON = 0x21 +ALL_ON = 0x23 +WRAM = 0x24 +GAMMASET = 0x26 +MIPI = 0x26 +DISPOFF = 0x28 +DISPON = 0x29 +CASET = 0x2A +PASET = 0x2B +RASET = 0x2B +RAMWR = 0x2C +WDATA = 0x2C +RAMRD = 0x2E +PTLAR = 0x30 +VSCRDEF = 0x33 +TEON = 0x35 +MADCTL = 0x36 +MADCTL_CMD = 0x36 +VSCRSADD = 0x37 +IDMOFF = 0x38 +IDMON = 0x39 +COLMOD = 0x3A +PIXFMT = 0x3A +GETSCANLINE = 0x45 +BRIGHTNESS = 0x51 +WRDISBV = 0x51 +RDDISBV = 0x52 +WRCTRLD = 0x53 +SWIRE1 = 0x5A +SWIRE2 = 0x5B +IFMODE = 0xB0 +FRMCTR1 = 0xB1 +FRMCTR2 = 0xB2 +FRMCTR3 = 0xB3 +INVCTR = 0xB4 +DFUNCTR = 0xB6 +ETMOD = 0xB7 +PWCTR1 = 0xC0 +PWCTR2 = 0xC1 +PWCTR3 = 0xC2 +PWCTR4 = 0xC3 +PWCTR5 = 0xC4 +VMCTR1 = 0xC5 +IFCTR = 0xC6 +VMCTR2 = 0xC7 +GMCTR = 0xC8 +SETEXTC = 0xC8 +PWSET = 0xD0 +VMCTR = 0xD1 +PWSETN = 0xD2 +RDID4 = 0xD3 +RDINDEX = 0xD9 +RDID1 = 0xDA +RDID2 = 0xDB +RDID3 = 0xDC +RDIDX = 0xDD +GMCTRP1 = 0xE0 +GMCTRN1 = 0xE1 +CSCON = 0xF0 +PWCTR6 = 0xF6 +ADJCTL3 = 0xF7 +PAGESEL = 0xFE diff --git a/esphome/components/mipi_spi/models/cyd.py b/esphome/components/mipi_spi/models/cyd.py new file mode 100644 index 0000000000..a25ecf33a8 --- /dev/null +++ b/esphome/components/mipi_spi/models/cyd.py @@ -0,0 +1,10 @@ +from .ili import ILI9341 + +ILI9341.extend( + "ESP32-2432S028", + data_rate="40MHz", + cs_pin=15, + dc_pin=2, +) + +models = {} diff --git a/esphome/components/mipi_spi/models/ili.py b/esphome/components/mipi_spi/models/ili.py new file mode 100644 index 0000000000..cc12b38f5d --- /dev/null +++ b/esphome/components/mipi_spi/models/ili.py @@ -0,0 +1,749 @@ +from esphome.components.spi import TYPE_OCTAL + +from .. import MODE_RGB +from . import DriverChip, delay +from .commands import ( + ADJCTL3, + CSCON, + DFUNCTR, + ETMOD, + FRMCTR1, + FRMCTR2, + FRMCTR3, + GAMMASET, + GMCTR, + GMCTRN1, + GMCTRP1, + IDMOFF, + IFCTR, + IFMODE, + INVCTR, + NORON, + PWCTR1, + PWCTR2, + PWCTR3, + PWCTR4, + PWCTR5, + PWSET, + PWSETN, + SETEXTC, + SWRESET, + VMCTR, + VMCTR1, + VMCTR2, + VSCRSADD, +) + +DriverChip( + "M5CORE", + width=320, + height=240, + cs_pin=14, + dc_pin=27, + reset_pin=33, + initsequence=( + (SETEXTC, 0xFF, 0x93, 0x42), + (PWCTR1, 0x12, 0x12), + (PWCTR2, 0x03), + (VMCTR1, 0xF2), + (IFMODE, 0xE0), + (0xF6, 0x01, 0x00, 0x00), + ( + GMCTRP1, + 0x00, + 0x0C, + 0x11, + 0x04, + 0x11, + 0x08, + 0x37, + 0x89, + 0x4C, + 0x06, + 0x0C, + 0x0A, + 0x2E, + 0x34, + 0x0F, + ), + ( + GMCTRN1, + 0x00, + 0x0B, + 0x11, + 0x05, + 0x13, + 0x09, + 0x33, + 0x67, + 0x48, + 0x07, + 0x0E, + 0x0B, + 0x2E, + 0x33, + 0x0F, + ), + (DFUNCTR, 0x08, 0x82, 0x1D, 0x04), + (IDMOFF,), + ), +) +ILI9341 = DriverChip( + "ILI9341", + mirror_x=True, + width=240, + height=320, + initsequence=( + (0xEF, 0x03, 0x80, 0x02), + (0xCF, 0x00, 0xC1, 0x30), + (0xED, 0x64, 0x03, 0x12, 0x81), + (0xE8, 0x85, 0x00, 0x78), + (0xCB, 0x39, 0x2C, 0x00, 0x34, 0x02), + (0xF7, 0x20), + (0xEA, 0x00, 0x00), + (PWCTR1, 0x23), + (PWCTR2, 0x10), + (VMCTR1, 0x3E, 0x28), + (VMCTR2, 0x86), + (VSCRSADD, 0x00), + (FRMCTR1, 0x00, 0x18), + (DFUNCTR, 0x08, 0x82, 0x27), + (0xF2, 0x00), + (GAMMASET, 0x01), + ( + GMCTRP1, + 0x0F, + 0x31, + 0x2B, + 0x0C, + 0x0E, + 0x08, + 0x4E, + 0xF1, + 0x37, + 0x07, + 0x10, + 0x03, + 0x0E, + 0x09, + 0x00, + ), + ( + GMCTRN1, + 0x00, + 0x0E, + 0x14, + 0x03, + 0x11, + 0x07, + 0x31, + 0xC1, + 0x48, + 0x08, + 0x0F, + 0x0C, + 0x31, + 0x36, + 0x0F, + ), + ), +) +DriverChip( + "ILI9481", + mirror_x=True, + width=320, + height=480, + use_axis_flips=True, + initsequence=( + (PWSET, 0x07, 0x42, 0x18), + (VMCTR, 0x00, 0x07, 0x10), + (PWSETN, 0x01, 0x02), + (PWCTR1, 0x10, 0x3B, 0x00, 0x02, 0x11), + (VMCTR1, 0x03), + (IFCTR, 0x83), + (GMCTR, 0x32, 0x36, 0x45, 0x06, 0x16, 0x37, 0x75, 0x77, 0x54, 0x0C, 0x00), + ), +) +DriverChip( + "ILI9486", + mirror_x=True, + width=320, + height=480, + initsequence=( + (PWCTR3, 0x44), + (VMCTR1, 0x00, 0x00, 0x00, 0x00), + ( + GMCTRP1, + 0x0F, + 0x1F, + 0x1C, + 0x0C, + 0x0F, + 0x08, + 0x48, + 0x98, + 0x37, + 0x0A, + 0x13, + 0x04, + 0x11, + 0x0D, + 0x00, + ), + ( + GMCTRN1, + 0x0F, + 0x32, + 0x2E, + 0x0B, + 0x0D, + 0x05, + 0x47, + 0x75, + 0x37, + 0x06, + 0x10, + 0x03, + 0x24, + 0x20, + 0x00, + ), + ), +) +DriverChip( + "ILI9488", + width=320, + height=480, + pixel_mode="18bit", + initsequence=( + ( + GMCTRP1, + 0x0F, + 0x24, + 0x1C, + 0x0A, + 0x0F, + 0x08, + 0x43, + 0x88, + 0x32, + 0x0F, + 0x10, + 0x06, + 0x0F, + 0x07, + 0x00, + ), + ( + GMCTRN1, + 0x0F, + 0x38, + 0x30, + 0x09, + 0x0F, + 0x0F, + 0x4E, + 0x77, + 0x3C, + 0x07, + 0x10, + 0x05, + 0x23, + 0x1B, + 0x00, + ), + (PWCTR1, 0x17, 0x15), + (PWCTR2, 0x41), + (VMCTR1, 0x00, 0x12, 0x80), + (IFMODE, 0x00), + (FRMCTR1, 0xA0), + (INVCTR, 0x02), + (0xE9, 0x00), + (ADJCTL3, 0xA9, 0x51, 0x2C, 0x82), + ), +) +ILI9488_A = DriverChip( + "ILI9488_A", + width=320, + height=480, + invert_colors=False, + pixel_mode="18bit", + mirror_x=True, + initsequence=( + ( + GMCTRP1, + 0x00, + 0x03, + 0x09, + 0x08, + 0x16, + 0x0A, + 0x3F, + 0x78, + 0x4C, + 0x09, + 0x0A, + 0x08, + 0x16, + 0x1A, + 0x0F, + ), + ( + GMCTRN1, + 0x00, + 0x16, + 0x19, + 0x03, + 0x0F, + 0x05, + 0x32, + 0x45, + 0x46, + 0x04, + 0x0E, + 0x0D, + 0x35, + 0x37, + 0x0F, + ), + (PWCTR1, 0x17, 0x15), + (PWCTR2, 0x41), + (VMCTR1, 0x00, 0x12, 0x80), + (IFMODE, 0x00), + (FRMCTR1, 0xA0), + (INVCTR, 0x02), + (DFUNCTR, 0x02, 0x02), + (0xE9, 0x00), + (ADJCTL3, 0xA9, 0x51, 0x2C, 0x82), + ), +) +ST7796 = DriverChip( + "ST7796", + mirror_x=True, + width=320, + height=480, + initsequence=( + (SWRESET,), + (CSCON, 0xC3), + (CSCON, 0x96), + (VMCTR1, 0x1C), + (IFMODE, 0x80), + (INVCTR, 0x01), + (DFUNCTR, 0x80, 0x02, 0x3B), + (ETMOD, 0xC6), + (CSCON, 0x69), + (CSCON, 0x3C), + ), +) +DriverChip( + "S3BOX", + width=320, + height=240, + mirror_x=True, + mirror_y=True, + invert_colors=False, + data_rate="40MHz", + dc_pin=4, + cs_pin=5, + # reset_pin={CONF_INVERTED: True, CONF_NUMBER: 48}, + initsequence=( + (0xEF, 0x03, 0x80, 0x02), + (0xCF, 0x00, 0xC1, 0x30), + (0xED, 0x64, 0x03, 0x12, 0x81), + (0xE8, 0x85, 0x00, 0x78), + (0xCB, 0x39, 0x2C, 0x00, 0x34, 0x02), + (0xF7, 0x20), + (0xEA, 0x00, 0x00), + (PWCTR1, 0x23), + (PWCTR2, 0x10), + (VMCTR1, 0x3E, 0x28), + (VMCTR2, 0x86), + (VSCRSADD, 0x00), + (FRMCTR1, 0x00, 0x18), + (DFUNCTR, 0x08, 0x82, 0x27), + (0xF2, 0x00), + (GAMMASET, 0x01), + ( + GMCTRP1, + 0x0F, + 0x31, + 0x2B, + 0x0C, + 0x0E, + 0x08, + 0x4E, + 0xF1, + 0x37, + 0x07, + 0x10, + 0x03, + 0x0E, + 0x09, + 0x00, + ), + ( + GMCTRN1, + 0x00, + 0x0E, + 0x14, + 0x03, + 0x11, + 0x07, + 0x31, + 0xC1, + 0x48, + 0x08, + 0x0F, + 0x0C, + 0x31, + 0x36, + 0x0F, + ), + ), +) +DriverChip( + "S3BOXLITE", + mirror_x=True, + color_order=MODE_RGB, + width=320, + height=240, + cs_pin=5, + dc_pin=4, + reset_pin=48, + initsequence=( + (0xEF, 0x03, 0x80, 0x02), + (0xCF, 0x00, 0xC1, 0x30), + (0xED, 0x64, 0x03, 0x12, 0x81), + (0xE8, 0x85, 0x00, 0x78), + (0xCB, 0x39, 0x2C, 0x00, 0x34, 0x02), + (0xF7, 0x20), + (0xEA, 0x00, 0x00), + (PWCTR1, 0x23), + (PWCTR2, 0x10), + (VMCTR1, 0x3E, 0x28), + (VMCTR2, 0x86), + (VSCRSADD, 0x00), + (FRMCTR1, 0x00, 0x18), + (DFUNCTR, 0x08, 0x82, 0x27), + (0xF2, 0x00), + (GAMMASET, 0x01), + ( + GMCTRP1, + 0xF0, + 0x09, + 0x0B, + 0x06, + 0x04, + 0x15, + 0x2F, + 0x54, + 0x42, + 0x3C, + 0x17, + 0x14, + 0x18, + 0x1B, + ), + ( + GMCTRN1, + 0xE0, + 0x09, + 0x0B, + 0x06, + 0x04, + 0x03, + 0x2B, + 0x43, + 0x42, + 0x3B, + 0x16, + 0x14, + 0x17, + 0x1B, + ), + ), +) +ST7789V = DriverChip( + "ST7789V", + width=240, + height=320, + initsequence=( + (DFUNCTR, 0x0A, 0x82), + (FRMCTR2, 0x0C, 0x0C, 0x00, 0x33, 0x33), + (ETMOD, 0x35), + (0xBB, 0x28), + (PWCTR1, 0x0C), + (PWCTR3, 0x01, 0xFF), + (PWCTR4, 0x10), + (PWCTR5, 0x20), + (IFCTR, 0x0F), + (PWSET, 0xA4, 0xA1), + ( + GMCTRP1, + 0xD0, + 0x00, + 0x02, + 0x07, + 0x0A, + 0x28, + 0x32, + 0x44, + 0x42, + 0x06, + 0x0E, + 0x12, + 0x14, + 0x17, + ), + ( + GMCTRN1, + 0xD0, + 0x00, + 0x02, + 0x07, + 0x0A, + 0x28, + 0x31, + 0x54, + 0x47, + 0x0E, + 0x1C, + 0x17, + 0x1B, + 0x1E, + ), + ), +) +DriverChip( + "GC9A01A", + mirror_x=True, + width=240, + height=240, + initsequence=( + (0xEF,), + (0xEB, 0x14), + (0xFE,), + (0xEF,), + (0xEB, 0x14), + (0x84, 0x40), + (0x85, 0xFF), + (0x86, 0xFF), + (0x87, 0xFF), + (0x88, 0x0A), + (0x89, 0x21), + (0x8A, 0x00), + (0x8B, 0x80), + (0x8C, 0x01), + (0x8D, 0x01), + (0x8E, 0xFF), + (0x8F, 0xFF), + (0xB6, 0x00, 0x00), + (0x90, 0x08, 0x08, 0x08, 0x08), + (0xBD, 0x06), + (0xBC, 0x00), + (0xFF, 0x60, 0x01, 0x04), + (0xC3, 0x13), + (0xC4, 0x13), + (0xF9, 0x22), + (0xBE, 0x11), + (0xE1, 0x10, 0x0E), + (0xDF, 0x21, 0x0C, 0x02), + (0xF0, 0x45, 0x09, 0x08, 0x08, 0x26, 0x2A), + (0xF1, 0x43, 0x70, 0x72, 0x36, 0x37, 0x6F), + (0xF2, 0x45, 0x09, 0x08, 0x08, 0x26, 0x2A), + (0xF3, 0x43, 0x70, 0x72, 0x36, 0x37, 0x6F), + (0xED, 0x1B, 0x0B), + (0xAE, 0x77), + (0xCD, 0x63), + (0xE8, 0x34), + ( + 0x62, + 0x18, + 0x0D, + 0x71, + 0xED, + 0x70, + 0x70, + 0x18, + 0x0F, + 0x71, + 0xEF, + 0x70, + 0x70, + ), + ( + 0x63, + 0x18, + 0x11, + 0x71, + 0xF1, + 0x70, + 0x70, + 0x18, + 0x13, + 0x71, + 0xF3, + 0x70, + 0x70, + ), + (0x64, 0x28, 0x29, 0xF1, 0x01, 0xF1, 0x00, 0x07), + (0x66, 0x3C, 0x00, 0xCD, 0x67, 0x45, 0x45, 0x10, 0x00, 0x00, 0x00), + (0x67, 0x00, 0x3C, 0x00, 0x00, 0x00, 0x01, 0x54, 0x10, 0x32, 0x98), + (0x74, 0x10, 0x85, 0x80, 0x00, 0x00, 0x4E, 0x00), + (0x98, 0x3E, 0x07), + (0x35,), + ), +) +DriverChip( + "GC9D01N", + width=160, + height=160, + initsequence=( + (0xFE,), + (0xEF,), + (0x80, 0xFF), + (0x81, 0xFF), + (0x82, 0xFF), + (0x83, 0xFF), + (0x84, 0xFF), + (0x85, 0xFF), + (0x86, 0xFF), + (0x87, 0xFF), + (0x88, 0xFF), + (0x89, 0xFF), + (0x8A, 0xFF), + (0x8B, 0xFF), + (0x8C, 0xFF), + (0x8D, 0xFF), + (0x8E, 0xFF), + (0x8F, 0xFF), + (0x3A, 0x05), + (0xEC, 0x01), + (0x74, 0x02, 0x0E, 0x00, 0x00, 0x00, 0x00, 0x00), + (0x98, 0x3E), + (0x99, 0x3E), + (0xB5, 0x0D, 0x0D), + (0x60, 0x38, 0x0F, 0x79, 0x67), + (0x61, 0x38, 0x11, 0x79, 0x67), + (0x64, 0x38, 0x17, 0x71, 0x5F, 0x79, 0x67), + (0x65, 0x38, 0x13, 0x71, 0x5B, 0x79, 0x67), + (0x6A, 0x00, 0x00), + (0x6C, 0x22, 0x02, 0x22, 0x02, 0x22, 0x22, 0x50), + ( + 0x6E, + 0x03, + 0x03, + 0x01, + 0x01, + 0x00, + 0x00, + 0x0F, + 0x0F, + 0x0D, + 0x0D, + 0x0B, + 0x0B, + 0x09, + 0x09, + 0x00, + 0x00, + 0x00, + 0x00, + 0x0A, + 0x0A, + 0x0C, + 0x0C, + 0x0E, + 0x0E, + 0x10, + 0x10, + 0x00, + 0x00, + 0x02, + 0x02, + 0x04, + 0x04, + ), + (0xBF, 0x01), + (0xF9, 0x40), + (0x9B, 0x3B, 0x93, 0x33, 0x7F, 0x00), + (0x7E, 0x30), + (0x70, 0x0D, 0x02, 0x08, 0x0D, 0x02, 0x08), + (0x71, 0x0D, 0x02, 0x08), + (0x91, 0x0E, 0x09), + (0xC3, 0x19, 0xC4, 0x19, 0xC9, 0x3C), + (0xF0, 0x53, 0x15, 0x0A, 0x04, 0x00, 0x3E), + (0xF1, 0x56, 0xA8, 0x7F, 0x33, 0x34, 0x5F), + (0xF2, 0x53, 0x15, 0x0A, 0x04, 0x00, 0x3A), + (0xF3, 0x52, 0xA4, 0x7F, 0x33, 0x34, 0xDF), + ), +) +DriverChip( + "ST7735", + color_order=MODE_RGB, + width=128, + height=160, + initsequence=( + SWRESET, + delay(10), + (FRMCTR1, 0x01, 0x2C, 0x2D), + (FRMCTR2, 0x01, 0x2C, 0x2D), + (FRMCTR3, 0x01, 0x2C, 0x2D, 0x01, 0x2C, 0x2D), + (INVCTR, 0x07), + (PWCTR1, 0xA2, 0x02, 0x84), + (PWCTR2, 0xC5), + (PWCTR3, 0x0A, 0x00), + (PWCTR4, 0x8A, 0x2A), + (PWCTR5, 0x8A, 0xEE), + (VMCTR1, 0x0E), + ( + GMCTRP1, + 0x02, + 0x1C, + 0x07, + 0x12, + 0x37, + 0x32, + 0x29, + 0x2D, + 0x29, + 0x25, + 0x2B, + 0x39, + 0x00, + 0x01, + 0x03, + 0x10, + ), + ( + GMCTRN1, + 0x03, + 0x1D, + 0x07, + 0x06, + 0x2E, + 0x2C, + 0x29, + 0x2D, + 0x2E, + 0x2E, + 0x37, + 0x3F, + 0x00, + 0x00, + 0x02, + 0x10, + ), + NORON, + ), +) +ST7796.extend( + "WT32-SC01-PLUS", + bus_mode=TYPE_OCTAL, + mirror_x=True, + reset_pin=4, + dc_pin=0, + invert_colors=True, +) + +models = {} diff --git a/esphome/components/mipi_spi/models/jc.py b/esphome/components/mipi_spi/models/jc.py new file mode 100644 index 0000000000..449c5b87ae --- /dev/null +++ b/esphome/components/mipi_spi/models/jc.py @@ -0,0 +1,260 @@ +from esphome.components.spi import TYPE_QUAD +import esphome.config_validation as cv +from esphome.const import CONF_IGNORE_STRAPPING_WARNING, CONF_NUMBER + +from .. import MODE_RGB +from . import DriverChip + +AXS15231 = DriverChip( + "AXS15231", + draw_rounding=8, + swap_xy=cv.UNDEFINED, + color_order=MODE_RGB, + bus_mode=TYPE_QUAD, + initsequence=( + (0xBB, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5A, 0xA5), + (0xC1, 0x33), + (0xBB, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00), + ), +) + +AXS15231.extend( + "JC3248W535", + width=320, + height=480, + cs_pin={CONF_NUMBER: 45, CONF_IGNORE_STRAPPING_WARNING: True}, + data_rate="40MHz", +) + +DriverChip( + "JC3636W518", + height=360, + width=360, + offset_height=1, + draw_rounding=1, + cs_pin=10, + reset_pin=47, + invert_colors=True, + color_order=MODE_RGB, + bus_mode=TYPE_QUAD, + data_rate="40MHz", + initsequence=( + (0xF0, 0x08), + (0xF2, 0x08), + (0x9B, 0x51), + (0x86, 0x53), + (0xF2, 0x80), + (0xF0, 0x00), + (0xF0, 0x01), + (0xF1, 0x01), + (0xB0, 0x54), + (0xB1, 0x3F), + (0xB2, 0x2A), + (0xB4, 0x46), + (0xB5, 0x34), + (0xB6, 0xD5), + (0xB7, 0x30), + (0xBA, 0x00), + (0xBB, 0x08), + (0xBC, 0x08), + (0xBD, 0x00), + (0xC0, 0x80), + (0xC1, 0x10), + (0xC2, 0x37), + (0xC3, 0x80), + (0xC4, 0x10), + (0xC5, 0x37), + (0xC6, 0xA9), + (0xC7, 0x41), + (0xC8, 0x51), + (0xC9, 0xA9), + (0xCA, 0x41), + (0xCB, 0x51), + (0xD0, 0x91), + (0xD1, 0x68), + (0xD2, 0x69), + (0xF5, 0x00, 0xA5), + (0xDD, 0x3F), + (0xDE, 0x3F), + (0xF1, 0x10), + (0xF0, 0x00), + (0xF0, 0x02), + ( + 0xE0, + 0x70, + 0x09, + 0x12, + 0x0C, + 0x0B, + 0x27, + 0x38, + 0x54, + 0x4E, + 0x19, + 0x15, + 0x15, + 0x2C, + 0x2F, + ), + ( + 0xE1, + 0x70, + 0x08, + 0x11, + 0x0C, + 0x0B, + 0x27, + 0x38, + 0x43, + 0x4C, + 0x18, + 0x14, + 0x14, + 0x2B, + 0x2D, + ), + (0xF0, 0x10), + (0xF3, 0x10), + (0xE0, 0x08), + (0xE1, 0x00), + (0xE2, 0x00), + (0xE3, 0x00), + (0xE4, 0xE0), + (0xE5, 0x06), + (0xE6, 0x21), + (0xE7, 0x00), + (0xE8, 0x05), + (0xE9, 0x82), + (0xEA, 0xDF), + (0xEB, 0x89), + (0xEC, 0x20), + (0xED, 0x14), + (0xEE, 0xFF), + (0xEF, 0x00), + (0xF8, 0xFF), + (0xF9, 0x00), + (0xFA, 0x00), + (0xFB, 0x30), + (0xFC, 0x00), + (0xFD, 0x00), + (0xFE, 0x00), + (0xFF, 0x00), + (0x60, 0x42), + (0x61, 0xE0), + (0x62, 0x40), + (0x63, 0x40), + (0x64, 0x02), + (0x65, 0x00), + (0x66, 0x40), + (0x67, 0x03), + (0x68, 0x00), + (0x69, 0x00), + (0x6A, 0x00), + (0x6B, 0x00), + (0x70, 0x42), + (0x71, 0xE0), + (0x72, 0x40), + (0x73, 0x40), + (0x74, 0x02), + (0x75, 0x00), + (0x76, 0x40), + (0x77, 0x03), + (0x78, 0x00), + (0x79, 0x00), + (0x7A, 0x00), + (0x7B, 0x00), + (0x80, 0x48), + (0x81, 0x00), + (0x82, 0x05), + (0x83, 0x02), + (0x84, 0xDD), + (0x85, 0x00), + (0x86, 0x00), + (0x87, 0x00), + (0x88, 0x48), + (0x89, 0x00), + (0x8A, 0x07), + (0x8B, 0x02), + (0x8C, 0xDF), + (0x8D, 0x00), + (0x8E, 0x00), + (0x8F, 0x00), + (0x90, 0x48), + (0x91, 0x00), + (0x92, 0x09), + (0x93, 0x02), + (0x94, 0xE1), + (0x95, 0x00), + (0x96, 0x00), + (0x97, 0x00), + (0x98, 0x48), + (0x99, 0x00), + (0x9A, 0x0B), + (0x9B, 0x02), + (0x9C, 0xE3), + (0x9D, 0x00), + (0x9E, 0x00), + (0x9F, 0x00), + (0xA0, 0x48), + (0xA1, 0x00), + (0xA2, 0x04), + (0xA3, 0x02), + (0xA4, 0xDC), + (0xA5, 0x00), + (0xA6, 0x00), + (0xA7, 0x00), + (0xA8, 0x48), + (0xA9, 0x00), + (0xAA, 0x06), + (0xAB, 0x02), + (0xAC, 0xDE), + (0xAD, 0x00), + (0xAE, 0x00), + (0xAF, 0x00), + (0xB0, 0x48), + (0xB1, 0x00), + (0xB2, 0x08), + (0xB3, 0x02), + (0xB4, 0xE0), + (0xB5, 0x00), + (0xB6, 0x00), + (0xB7, 0x00), + (0xB8, 0x48), + (0xB9, 0x00), + (0xBA, 0x0A), + (0xBB, 0x02), + (0xBC, 0xE2), + (0xBD, 0x00), + (0xBE, 0x00), + (0xBF, 0x00), + (0xC0, 0x12), + (0xC1, 0xAA), + (0xC2, 0x65), + (0xC3, 0x74), + (0xC4, 0x47), + (0xC5, 0x56), + (0xC6, 0x00), + (0xC7, 0x88), + (0xC8, 0x99), + (0xC9, 0x33), + (0xD0, 0x21), + (0xD1, 0xAA), + (0xD2, 0x65), + (0xD3, 0x74), + (0xD4, 0x47), + (0xD5, 0x56), + (0xD6, 0x00), + (0xD7, 0x88), + (0xD8, 0x99), + (0xD9, 0x33), + (0xF3, 0x01), + (0xF0, 0x00), + (0xF0, 0x01), + (0xF1, 0x01), + (0xA0, 0x0B), + (0xA3, 0x2A), + (0xA5, 0xC3), + ), +) + +models = {} diff --git a/esphome/components/mipi_spi/models/lanbon.py b/esphome/components/mipi_spi/models/lanbon.py new file mode 100644 index 0000000000..6f9aa58674 --- /dev/null +++ b/esphome/components/mipi_spi/models/lanbon.py @@ -0,0 +1,15 @@ +from .ili import ST7789V + +ST7789V.extend( + "LANBON-L8", + width=240, + height=320, + mirror_x=True, + mirror_y=True, + data_rate="80MHz", + cs_pin=22, + dc_pin=21, + reset_pin=18, +) + +models = {} diff --git a/esphome/components/mipi_spi/models/lilygo.py b/esphome/components/mipi_spi/models/lilygo.py new file mode 100644 index 0000000000..dd6f9c02f7 --- /dev/null +++ b/esphome/components/mipi_spi/models/lilygo.py @@ -0,0 +1,60 @@ +from esphome.components.spi import TYPE_OCTAL + +from .. import MODE_BGR +from .ili import ST7789V, ST7796 + +ST7789V.extend( + "T-EMBED", + width=170, + height=320, + offset_width=35, + color_order=MODE_BGR, + invert_colors=True, + draw_rounding=1, + cs_pin=10, + dc_pin=13, + reset_pin=9, + data_rate="80MHz", +) + +ST7789V.extend( + "T-DISPLAY", + height=240, + width=135, + offset_width=52, + offset_height=40, + draw_rounding=1, + cs_pin=5, + dc_pin=16, + invert_colors=True, +) +ST7789V.extend( + "T-DISPLAY-S3", + height=320, + width=170, + offset_width=35, + color_order=MODE_BGR, + invert_colors=True, + draw_rounding=1, + dc_pin=7, + cs_pin=6, + reset_pin=5, + enable_pin=[9, 15], + data_rate="10MHz", + bus_mode=TYPE_OCTAL, +) + +ST7796.extend( + "T-DISPLAY-S3-PRO", + width=222, + height=480, + offset_width=49, + draw_rounding=1, + cs_pin=39, + reset_pin=47, + dc_pin=9, + backlight_pin=48, + invert_colors=True, +) + +models = {} diff --git a/esphome/components/mipi_spi/models/waveshare.py b/esphome/components/mipi_spi/models/waveshare.py new file mode 100644 index 0000000000..6d14f56fc6 --- /dev/null +++ b/esphome/components/mipi_spi/models/waveshare.py @@ -0,0 +1,139 @@ +from . import DriverChip +from .ili import ILI9488_A + +DriverChip( + "WAVESHARE-4-TFT", + width=320, + height=480, + invert_colors=True, + spi_16=True, + initsequence=( + ( + 0xF9, + 0x00, + 0x08, + ), + ( + 0xC0, + 0x19, + 0x1A, + ), + ( + 0xC1, + 0x45, + 0x00, + ), + ( + 0xC2, + 0x33, + ), + ( + 0xC5, + 0x00, + 0x28, + ), + ( + 0xB1, + 0xA0, + 0x11, + ), + ( + 0xB4, + 0x02, + ), + ( + 0xB6, + 0x00, + 0x42, + 0x3B, + ), + ( + 0xB7, + 0x07, + ), + ( + 0xE0, + 0x1F, + 0x25, + 0x22, + 0x0B, + 0x06, + 0x0A, + 0x4E, + 0xC6, + 0x39, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + ), + ( + 0xE1, + 0x1F, + 0x3F, + 0x3F, + 0x0F, + 0x1F, + 0x0F, + 0x46, + 0x49, + 0x31, + 0x05, + 0x09, + 0x03, + 0x1C, + 0x1A, + 0x00, + ), + ( + 0xF1, + 0x36, + 0x04, + 0x00, + 0x3C, + 0x0F, + 0x0F, + 0xA4, + 0x02, + ), + ( + 0xF2, + 0x18, + 0xA3, + 0x12, + 0x02, + 0x32, + 0x12, + 0xFF, + 0x32, + 0x00, + ), + ( + 0xF4, + 0x40, + 0x00, + 0x08, + 0x91, + 0x04, + ), + ( + 0xF8, + 0x21, + 0x04, + ), + ), +) + +ILI9488_A.extend( + "PICO-RESTOUCH-LCD-3.5", + spi_16=True, + pixel_mode="16bit", + mirror_x=True, + dc_pin=33, + cs_pin=34, + reset_pin=40, + data_rate="20MHz", + invert_colors=True, +) diff --git a/esphome/components/mitsubishi/mitsubishi.cpp b/esphome/components/mitsubishi/mitsubishi.cpp index 449c8fc712..3d9207dd96 100644 --- a/esphome/components/mitsubishi/mitsubishi.cpp +++ b/esphome/components/mitsubishi/mitsubishi.cpp @@ -25,8 +25,8 @@ const uint8_t MITSUBISHI_FAN_AUTO = 0x00; const uint8_t MITSUBISHI_VERTICAL_VANE_SWING = 0x38; -// const uint8_t MITSUBISHI_AUTO = 0X80; -const uint8_t MITSUBISHI_OTHERWISE = 0X40; +// const uint8_t MITSUBISHI_AUTO = 0x80; +const uint8_t MITSUBISHI_OTHERWISE = 0x40; const uint8_t MITSUBISHI_POWERFUL = 0x08; // Optional presets used to enable some model features @@ -42,13 +42,13 @@ const uint16_t MITSUBISHI_HEADER_SPACE = 1700; const uint16_t MITSUBISHI_MIN_GAP = 17500; // Marker bytes -const uint8_t MITSUBISHI_BYTE00 = 0X23; -const uint8_t MITSUBISHI_BYTE01 = 0XCB; -const uint8_t MITSUBISHI_BYTE02 = 0X26; -const uint8_t MITSUBISHI_BYTE03 = 0X01; -const uint8_t MITSUBISHI_BYTE04 = 0X00; -const uint8_t MITSUBISHI_BYTE13 = 0X00; -const uint8_t MITSUBISHI_BYTE16 = 0X00; +const uint8_t MITSUBISHI_BYTE00 = 0x23; +const uint8_t MITSUBISHI_BYTE01 = 0xCB; +const uint8_t MITSUBISHI_BYTE02 = 0x26; +const uint8_t MITSUBISHI_BYTE03 = 0x01; +const uint8_t MITSUBISHI_BYTE04 = 0x00; +const uint8_t MITSUBISHI_BYTE13 = 0x00; +const uint8_t MITSUBISHI_BYTE16 = 0x00; climate::ClimateTraits MitsubishiClimate::traits() { auto traits = climate::ClimateTraits(); diff --git a/esphome/components/mixer/speaker/mixer_speaker.cpp b/esphome/components/mixer/speaker/mixer_speaker.cpp index 121a62392c..8e480dd49b 100644 --- a/esphome/components/mixer/speaker/mixer_speaker.cpp +++ b/esphome/components/mixer/speaker/mixer_speaker.cpp @@ -53,14 +53,15 @@ void SourceSpeaker::dump_config() { } void SourceSpeaker::setup() { - this->parent_->get_output_speaker()->add_audio_output_callback( - [this](uint32_t new_playback_ms, uint32_t remainder_us, uint32_t pending_ms, uint32_t write_timestamp) { - uint32_t personal_playback_ms = std::min(new_playback_ms, this->pending_playback_ms_); - if (personal_playback_ms > 0) { - this->pending_playback_ms_ -= personal_playback_ms; - this->audio_output_callback_(personal_playback_ms, remainder_us, this->pending_playback_ms_, write_timestamp); - } - }); + this->parent_->get_output_speaker()->add_audio_output_callback([this](uint32_t new_frames, int64_t write_timestamp) { + // The SourceSpeaker may not have included any audio in the mixed output, so verify there were pending frames + uint32_t speakers_playback_frames = std::min(new_frames, this->pending_playback_frames_); + this->pending_playback_frames_ -= speakers_playback_frames; + + if (speakers_playback_frames > 0) { + this->audio_output_callback_(speakers_playback_frames, write_timestamp); + } + }); } void SourceSpeaker::loop() { @@ -153,6 +154,7 @@ esp_err_t SourceSpeaker::start_() { } } + this->pending_playback_frames_ = 0; // reset return this->parent_->start(this->audio_stream_info_); } @@ -542,11 +544,7 @@ void MixerSpeaker::audio_mixer_task(void *params) { // Update source speaker buffer length transfer_buffers_with_data[0]->decrease_buffer_length(active_stream_info.frames_to_bytes(frames_to_mix)); - speakers_with_data[0]->accumulated_frames_read_ += frames_to_mix; - - // Add new audio duration to the source speaker pending playback - speakers_with_data[0]->pending_playback_ms_ += - active_stream_info.frames_to_milliseconds_with_remainder(&speakers_with_data[0]->accumulated_frames_read_); + speakers_with_data[0]->pending_playback_frames_ += frames_to_mix; // Update output transfer buffer length output_transfer_buffer->increase_buffer_length( @@ -586,10 +584,6 @@ void MixerSpeaker::audio_mixer_task(void *params) { reinterpret_cast(output_transfer_buffer->get_buffer_end()), this_mixer->audio_stream_info_.value(), frames_to_mix); - speakers_with_data[i]->pending_playback_ms_ += - speakers_with_data[i]->get_audio_stream_info().frames_to_milliseconds_with_remainder( - &speakers_with_data[i]->accumulated_frames_read_); - if (i != transfer_buffers_with_data.size() - 1) { // Need to mix more streams together, point primary buffer and stream info to the already mixed output primary_buffer = reinterpret_cast(output_transfer_buffer->get_buffer_end()); @@ -601,11 +595,7 @@ void MixerSpeaker::audio_mixer_task(void *params) { for (int i = 0; i < transfer_buffers_with_data.size(); ++i) { transfer_buffers_with_data[i]->decrease_buffer_length( speakers_with_data[i]->get_audio_stream_info().frames_to_bytes(frames_to_mix)); - speakers_with_data[i]->accumulated_frames_read_ += frames_to_mix; - - speakers_with_data[i]->pending_playback_ms_ += - speakers_with_data[i]->get_audio_stream_info().frames_to_milliseconds_with_remainder( - &speakers_with_data[i]->accumulated_frames_read_); + speakers_with_data[i]->pending_playback_frames_ += frames_to_mix; } // Update output transfer buffer length diff --git a/esphome/components/mixer/speaker/mixer_speaker.h b/esphome/components/mixer/speaker/mixer_speaker.h index 0bd6b5f4c8..48bacc4471 100644 --- a/esphome/components/mixer/speaker/mixer_speaker.h +++ b/esphome/components/mixer/speaker/mixer_speaker.h @@ -114,9 +114,7 @@ class SourceSpeaker : public speaker::Speaker, public Component { uint32_t ducking_transition_samples_remaining_{0}; uint32_t samples_per_ducking_step_{0}; - uint32_t accumulated_frames_read_{0}; - - uint32_t pending_playback_ms_{0}; + uint32_t pending_playback_frames_{0}; }; class MixerSpeaker : public Component { diff --git a/esphome/components/mlx90393/sensor.py b/esphome/components/mlx90393/sensor.py index cb9cb84aae..372bb05bda 100644 --- a/esphome/components/mlx90393/sensor.py +++ b/esphome/components/mlx90393/sensor.py @@ -63,6 +63,11 @@ def _validate(config): raise cv.Invalid( f"{axis}: {CONF_RESOLUTION} cannot be {res} with {CONF_TEMPERATURE_COMPENSATION} enabled" ) + if config[CONF_HALLCONF] == 0xC: + if (config[CONF_OVERSAMPLING], config[CONF_FILTER]) in [(0, 0), (1, 0), (0, 1)]: + raise cv.Invalid( + f"{CONF_OVERSAMPLING}=={config[CONF_OVERSAMPLING]} and {CONF_FILTER}=={config[CONF_FILTER]} not allowed with {CONF_HALLCONF}=={config[CONF_HALLCONF]:#02x}" + ) return config diff --git a/esphome/components/mlx90393/sensor_mlx90393.cpp b/esphome/components/mlx90393/sensor_mlx90393.cpp index e86080fe9c..46fe68fab0 100644 --- a/esphome/components/mlx90393/sensor_mlx90393.cpp +++ b/esphome/components/mlx90393/sensor_mlx90393.cpp @@ -6,13 +6,41 @@ namespace mlx90393 { static const char *const TAG = "mlx90393"; +const LogString *settings_to_string(MLX90393Setting setting) { + switch (setting) { + case MLX90393_GAIN_SEL: + return LOG_STR("gain"); + case MLX90393_RESOLUTION: + return LOG_STR("resolution"); + case MLX90393_OVER_SAMPLING: + return LOG_STR("oversampling"); + case MLX90393_DIGITAL_FILTERING: + return LOG_STR("digital filtering"); + case MLX90393_TEMPERATURE_OVER_SAMPLING: + return LOG_STR("temperature oversampling"); + case MLX90393_TEMPERATURE_COMPENSATION: + return LOG_STR("temperature compensation"); + case MLX90393_HALLCONF: + return LOG_STR("hallconf"); + case MLX90393_LAST: + return LOG_STR("error"); + default: + return LOG_STR("unknown"); + } +}; + bool MLX90393Cls::transceive(const uint8_t *request, size_t request_size, uint8_t *response, size_t response_size) { i2c::ErrorCode e = this->write(request, request_size); if (e != i2c::ErrorCode::ERROR_OK) { + ESP_LOGV(TAG, "i2c failed to write %u", e); return false; } e = this->read(response, response_size); - return e == i2c::ErrorCode::ERROR_OK; + if (e != i2c::ErrorCode::ERROR_OK) { + ESP_LOGV(TAG, "i2c failed to read %u", e); + return false; + } + return true; } bool MLX90393Cls::has_drdy_pin() { return this->drdy_pin_ != nullptr; } @@ -27,6 +55,53 @@ bool MLX90393Cls::read_drdy_pin() { void MLX90393Cls::sleep_millis(uint32_t millis) { delay(millis); } void MLX90393Cls::sleep_micros(uint32_t micros) { delayMicroseconds(micros); } +uint8_t MLX90393Cls::apply_setting_(MLX90393Setting which) { + uint8_t ret = -1; + switch (which) { + case MLX90393_GAIN_SEL: + ret = this->mlx_.setGainSel(this->gain_); + break; + case MLX90393_RESOLUTION: + ret = this->mlx_.setResolution(this->resolutions_[0], this->resolutions_[1], this->resolutions_[2]); + break; + case MLX90393_OVER_SAMPLING: + ret = this->mlx_.setOverSampling(this->oversampling_); + break; + case MLX90393_DIGITAL_FILTERING: + ret = this->mlx_.setDigitalFiltering(this->filter_); + break; + case MLX90393_TEMPERATURE_OVER_SAMPLING: + ret = this->mlx_.setTemperatureOverSampling(this->temperature_oversampling_); + break; + case MLX90393_TEMPERATURE_COMPENSATION: + ret = this->mlx_.setTemperatureCompensation(this->temperature_compensation_); + break; + case MLX90393_HALLCONF: + ret = this->mlx_.setHallConf(this->hallconf_); + break; + default: + break; + } + if (ret != MLX90393::STATUS_OK) { + ESP_LOGE(TAG, "failed to apply %s", LOG_STR_ARG(settings_to_string(which))); + } + return ret; +} + +bool MLX90393Cls::apply_all_settings_() { + // perform dummy read after reset + // first one always gets NAK even tough everything is fine + uint8_t ignore = 0; + this->mlx_.getGainSel(ignore); + + uint8_t result = MLX90393::STATUS_OK; + for (int i = MLX90393_GAIN_SEL; i != MLX90393_LAST; i++) { + MLX90393Setting stage = static_cast(i); + result |= this->apply_setting_(stage); + } + return result == MLX90393::STATUS_OK; +} + void MLX90393Cls::setup() { ESP_LOGCONFIG(TAG, "Setting up MLX90393..."); // note the two arguments A0 and A1 which are used to construct an i2c address @@ -34,19 +109,12 @@ void MLX90393Cls::setup() { // see the transceive function above, which uses the address from I2CComponent this->mlx_.begin_with_hal(this, 0, 0); - this->mlx_.setGainSel(this->gain_); + if (!this->apply_all_settings_()) { + this->mark_failed(); + } - this->mlx_.setResolution(this->resolutions_[0], this->resolutions_[1], this->resolutions_[2]); - - this->mlx_.setOverSampling(this->oversampling_); - - this->mlx_.setDigitalFiltering(this->filter_); - - this->mlx_.setTemperatureOverSampling(this->temperature_oversampling_); - - this->mlx_.setTemperatureCompensation(this->temperature_compensation_); - - this->mlx_.setHallConf(this->hallconf_); + // start verify settings process + this->set_timeout("verify settings", 3000, [this]() { this->verify_settings_timeout_(MLX90393_GAIN_SEL); }); } void MLX90393Cls::dump_config() { @@ -91,5 +159,119 @@ void MLX90393Cls::update() { } } +bool MLX90393Cls::verify_setting_(MLX90393Setting which) { + uint8_t read_value = 0xFF; + uint8_t expected_value = 0xFF; + uint8_t read_status = -1; + char read_back_str[25] = {0}; + + switch (which) { + case MLX90393_GAIN_SEL: { + read_status = this->mlx_.getGainSel(read_value); + expected_value = this->gain_; + break; + } + + case MLX90393_RESOLUTION: { + uint8_t read_resolutions[3] = {0xFF}; + read_status = this->mlx_.getResolution(read_resolutions[0], read_resolutions[1], read_resolutions[2]); + snprintf(read_back_str, sizeof(read_back_str), "%u %u %u expected %u %u %u", read_resolutions[0], + read_resolutions[1], read_resolutions[2], this->resolutions_[0], this->resolutions_[1], + this->resolutions_[2]); + bool is_correct = true; + for (int i = 0; i < 3; i++) { + is_correct &= read_resolutions[i] == this->resolutions_[i]; + } + if (is_correct) { + // set read_value and expected_value to same number, so the code blow recognizes it is correct + read_value = 0; + expected_value = 0; + } else { + // set to different numbers, to show incorrect + read_value = 1; + expected_value = 0; + } + break; + } + case MLX90393_OVER_SAMPLING: { + read_status = this->mlx_.getOverSampling(read_value); + expected_value = this->oversampling_; + break; + } + case MLX90393_DIGITAL_FILTERING: { + read_status = this->mlx_.getDigitalFiltering(read_value); + expected_value = this->filter_; + break; + } + case MLX90393_TEMPERATURE_OVER_SAMPLING: { + read_status = this->mlx_.getTemperatureOverSampling(read_value); + expected_value = this->temperature_oversampling_; + break; + } + case MLX90393_TEMPERATURE_COMPENSATION: { + read_status = this->mlx_.getTemperatureCompensation(read_value); + expected_value = (bool) this->temperature_compensation_; + break; + } + case MLX90393_HALLCONF: { + read_status = this->mlx_.getHallConf(read_value); + expected_value = this->hallconf_; + break; + } + default: { + return false; + } + } + if (read_status != MLX90393::STATUS_OK) { + ESP_LOGE(TAG, "verify error: failed to read %s", LOG_STR_ARG(settings_to_string(which))); + return false; + } + if (read_back_str[0] == 0x0) { + snprintf(read_back_str, sizeof(read_back_str), "%u expected %u", read_value, expected_value); + } + bool is_correct = read_value == expected_value; + if (!is_correct) { + ESP_LOGW(TAG, "verify failed: read back wrong %s: got %s", LOG_STR_ARG(settings_to_string(which)), read_back_str); + return false; + } + ESP_LOGD(TAG, "verify succeeded for %s. got %s", LOG_STR_ARG(settings_to_string(which)), read_back_str); + return true; +} + +/** + * Regularly checks that our settings are still applied. + * Used to catch spurious chip resets. + * + * returns true if everything is fine. + * false if not + */ +void MLX90393Cls::verify_settings_timeout_(MLX90393Setting stage) { + bool is_setting_ok = this->verify_setting_(stage); + + if (!is_setting_ok) { + if (this->mlx_.checkStatus(this->mlx_.reset()) != MLX90393::STATUS_OK) { + ESP_LOGE(TAG, "failed to reset device"); + this->status_set_error(); + this->mark_failed(); + return; + } + + if (!this->apply_all_settings_()) { + ESP_LOGE(TAG, "failed to re-apply settings"); + this->status_set_error(); + this->mark_failed(); + } else { + ESP_LOGI(TAG, "reset and re-apply settings completed"); + } + } + + MLX90393Setting next_stage = static_cast(static_cast(stage) + 1); + if (next_stage == MLX90393_LAST) { + next_stage = static_cast(0); + } + + this->set_timeout("verify settings", 3000, [this, next_stage]() { this->verify_settings_timeout_(next_stage); }); +} + } // namespace mlx90393 } // namespace esphome diff --git a/esphome/components/mlx90393/sensor_mlx90393.h b/esphome/components/mlx90393/sensor_mlx90393.h index 479891a76c..8a6f3321f9 100644 --- a/esphome/components/mlx90393/sensor_mlx90393.h +++ b/esphome/components/mlx90393/sensor_mlx90393.h @@ -1,15 +1,26 @@ #pragma once -#include "esphome/core/component.h" -#include "esphome/components/sensor/sensor.h" -#include "esphome/components/i2c/i2c.h" -#include "esphome/core/hal.h" #include #include +#include "esphome/components/i2c/i2c.h" +#include "esphome/components/sensor/sensor.h" +#include "esphome/core/component.h" +#include "esphome/core/hal.h" namespace esphome { namespace mlx90393 { +enum MLX90393Setting { + MLX90393_GAIN_SEL = 0, + MLX90393_RESOLUTION, + MLX90393_OVER_SAMPLING, + MLX90393_DIGITAL_FILTERING, + MLX90393_TEMPERATURE_OVER_SAMPLING, + MLX90393_TEMPERATURE_COMPENSATION, + MLX90393_HALLCONF, + MLX90393_LAST, +}; + class MLX90393Cls : public PollingComponent, public i2c::I2CDevice, public MLX90393Hal { public: void setup() override; @@ -58,6 +69,12 @@ class MLX90393Cls : public PollingComponent, public i2c::I2CDevice, public MLX90 bool temperature_compensation_{false}; uint8_t hallconf_{0xC}; GPIOPin *drdy_pin_{nullptr}; + + bool apply_all_settings_(); + uint8_t apply_setting_(MLX90393Setting which); + + bool verify_setting_(MLX90393Setting which); + void verify_settings_timeout_(MLX90393Setting stage); }; } // namespace mlx90393 diff --git a/esphome/components/mqtt/__init__.py b/esphome/components/mqtt/__init__.py index 99f8ad76d8..63d8da5788 100644 --- a/esphome/components/mqtt/__init__.py +++ b/esphome/components/mqtt/__init__.py @@ -41,6 +41,7 @@ from esphome.const import ( CONF_REBOOT_TIMEOUT, CONF_RETAIN, CONF_SHUTDOWN_MESSAGE, + CONF_SKIP_CERT_CN_CHECK, CONF_SSL_FINGERPRINTS, CONF_STATE_TOPIC, CONF_SUBSCRIBE_QOS, @@ -67,7 +68,6 @@ def AUTO_LOAD(): CONF_DISCOVER_IP = "discover_ip" CONF_IDF_SEND_ASYNC = "idf_send_async" -CONF_SKIP_CERT_CN_CHECK = "skip_cert_cn_check" def validate_message_just_topic(value): diff --git a/esphome/components/mqtt/mqtt_client.cpp b/esphome/components/mqtt/mqtt_client.cpp index 9afa3a588d..1fcef3293c 100644 --- a/esphome/components/mqtt/mqtt_client.cpp +++ b/esphome/components/mqtt/mqtt_client.cpp @@ -138,7 +138,11 @@ void MQTTClientComponent::send_device_info_() { #endif #ifdef USE_API_NOISE - root["api_encryption"] = "Noise_NNpsk0_25519_ChaChaPoly_SHA256"; + if (api::global_api_server->get_noise_ctx()->has_psk()) { + root["api_encryption"] = "Noise_NNpsk0_25519_ChaChaPoly_SHA256"; + } else { + root["api_encryption_supported"] = "Noise_NNpsk0_25519_ChaChaPoly_SHA256"; + } #endif }, 2, this->discovery_info_.retain); diff --git a/esphome/components/mqtt/mqtt_const.h b/esphome/components/mqtt/mqtt_const.h index 445457a27f..3ddd8fc5cc 100644 --- a/esphome/components/mqtt/mqtt_const.h +++ b/esphome/components/mqtt/mqtt_const.h @@ -64,6 +64,8 @@ constexpr const char *const MQTT_DEVICE_NAME = "name"; constexpr const char *const MQTT_DEVICE_SUGGESTED_AREA = "sa"; constexpr const char *const MQTT_DEVICE_SW_VERSION = "sw"; constexpr const char *const MQTT_DEVICE_HW_VERSION = "hw"; +constexpr const char *const MQTT_DIRECTION_COMMAND_TOPIC = "dir_cmd_t"; +constexpr const char *const MQTT_DIRECTION_STATE_TOPIC = "dir_stat_t"; constexpr const char *const MQTT_DOCKED_TEMPLATE = "dock_tpl"; constexpr const char *const MQTT_DOCKED_TOPIC = "dock_t"; constexpr const char *const MQTT_EFFECT_COMMAND_TOPIC = "fx_cmd_t"; @@ -328,6 +330,8 @@ constexpr const char *const MQTT_DEVICE_NAME = "name"; constexpr const char *const MQTT_DEVICE_SUGGESTED_AREA = "suggested_area"; constexpr const char *const MQTT_DEVICE_SW_VERSION = "sw_version"; constexpr const char *const MQTT_DEVICE_HW_VERSION = "hw_version"; +constexpr const char *const MQTT_DIRECTION_COMMAND_TOPIC = "direction_command_topic"; +constexpr const char *const MQTT_DIRECTION_STATE_TOPIC = "direction_state_topic"; constexpr const char *const MQTT_DOCKED_TEMPLATE = "docked_template"; constexpr const char *const MQTT_DOCKED_TOPIC = "docked_topic"; constexpr const char *const MQTT_EFFECT_COMMAND_TOPIC = "effect_command_topic"; diff --git a/esphome/components/mqtt/mqtt_fan.cpp b/esphome/components/mqtt/mqtt_fan.cpp index 32892199fe..9e5ea54bee 100644 --- a/esphome/components/mqtt/mqtt_fan.cpp +++ b/esphome/components/mqtt/mqtt_fan.cpp @@ -43,6 +43,32 @@ void MQTTFanComponent::setup() { } }); + if (this->state_->get_traits().supports_direction()) { + this->subscribe(this->get_direction_command_topic(), [this](const std::string &topic, const std::string &payload) { + auto val = parse_on_off(payload.c_str(), "forward", "reverse"); + switch (val) { + case PARSE_ON: + ESP_LOGD(TAG, "'%s': Setting direction FORWARD", this->friendly_name().c_str()); + this->state_->make_call().set_direction(fan::FanDirection::FORWARD).perform(); + break; + case PARSE_OFF: + ESP_LOGD(TAG, "'%s': Setting direction REVERSE", this->friendly_name().c_str()); + this->state_->make_call().set_direction(fan::FanDirection::REVERSE).perform(); + break; + case PARSE_TOGGLE: + this->state_->make_call() + .set_direction(this->state_->direction == fan::FanDirection::FORWARD ? fan::FanDirection::REVERSE + : fan::FanDirection::FORWARD) + .perform(); + break; + case PARSE_NONE: + ESP_LOGW(TAG, "Unknown direction Payload %s", payload.c_str()); + this->status_momentary_warning("direction", 5000); + break; + } + }); + } + if (this->state_->get_traits().supports_oscillation()) { this->subscribe(this->get_oscillation_command_topic(), [this](const std::string &topic, const std::string &payload) { @@ -94,6 +120,10 @@ void MQTTFanComponent::setup() { void MQTTFanComponent::dump_config() { ESP_LOGCONFIG(TAG, "MQTT Fan '%s': ", this->state_->get_name().c_str()); LOG_MQTT_COMPONENT(true, true); + if (this->state_->get_traits().supports_direction()) { + ESP_LOGCONFIG(TAG, " Direction State Topic: '%s'", this->get_direction_state_topic().c_str()); + ESP_LOGCONFIG(TAG, " Direction Command Topic: '%s'", this->get_direction_command_topic().c_str()); + } if (this->state_->get_traits().supports_oscillation()) { ESP_LOGCONFIG(TAG, " Oscillation State Topic: '%s'", this->get_oscillation_state_topic().c_str()); ESP_LOGCONFIG(TAG, " Oscillation Command Topic: '%s'", this->get_oscillation_command_topic().c_str()); @@ -107,6 +137,10 @@ void MQTTFanComponent::dump_config() { bool MQTTFanComponent::send_initial_state() { return this->publish_state(); } void MQTTFanComponent::send_discovery(JsonObject root, mqtt::SendDiscoveryConfig &config) { + if (this->state_->get_traits().supports_direction()) { + root[MQTT_DIRECTION_COMMAND_TOPIC] = this->get_direction_command_topic(); + root[MQTT_DIRECTION_STATE_TOPIC] = this->get_direction_state_topic(); + } if (this->state_->get_traits().supports_oscillation()) { root[MQTT_OSCILLATION_COMMAND_TOPIC] = this->get_oscillation_command_topic(); root[MQTT_OSCILLATION_STATE_TOPIC] = this->get_oscillation_state_topic(); @@ -122,6 +156,11 @@ bool MQTTFanComponent::publish_state() { ESP_LOGD(TAG, "'%s' Sending state %s.", this->state_->get_name().c_str(), state_s); this->publish(this->get_state_topic_(), state_s); bool failed = false; + if (this->state_->get_traits().supports_direction()) { + bool success = this->publish(this->get_direction_state_topic(), + this->state_->direction == fan::FanDirection::FORWARD ? "forward" : "reverse"); + failed = failed || !success; + } if (this->state_->get_traits().supports_oscillation()) { bool success = this->publish(this->get_oscillation_state_topic(), this->state_->oscillating ? "oscillate_on" : "oscillate_off"); diff --git a/esphome/components/mqtt/mqtt_fan.h b/esphome/components/mqtt/mqtt_fan.h index 12286b9f01..fdcec0782d 100644 --- a/esphome/components/mqtt/mqtt_fan.h +++ b/esphome/components/mqtt/mqtt_fan.h @@ -15,6 +15,8 @@ class MQTTFanComponent : public mqtt::MQTTComponent { public: explicit MQTTFanComponent(fan::Fan *state); + MQTT_COMPONENT_CUSTOM_TOPIC(direction, command) + MQTT_COMPONENT_CUSTOM_TOPIC(direction, state) MQTT_COMPONENT_CUSTOM_TOPIC(oscillation, command) MQTT_COMPONENT_CUSTOM_TOPIC(oscillation, state) MQTT_COMPONENT_CUSTOM_TOPIC(speed_level, command) diff --git a/esphome/components/network/__init__.py b/esphome/components/network/__init__.py index be4e102930..129b1ced06 100644 --- a/esphome/components/network/__init__.py +++ b/esphome/components/network/__init__.py @@ -26,7 +26,7 @@ CONFIG_SCHEMA = cv.Schema( esp32_arduino=cv.Version(0, 0, 0), esp8266_arduino=cv.Version(0, 0, 0), rp2040_arduino=cv.Version(0, 0, 0), - bk72xx_libretiny=cv.Version(1, 7, 0), + bk72xx_arduino=cv.Version(1, 7, 0), ), cv.boolean_false, ), diff --git a/esphome/components/nextion/base_component.py b/esphome/components/nextion/base_component.py index 9708379861..0058d957dc 100644 --- a/esphome/components/nextion/base_component.py +++ b/esphome/components/nextion/base_component.py @@ -7,28 +7,29 @@ from esphome.const import CONF_BACKGROUND_COLOR, CONF_FOREGROUND_COLOR, CONF_VIS from . import CONF_NEXTION_ID, Nextion -CONF_VARIABLE_NAME = "variable_name" +CONF_AUTO_WAKE_ON_TOUCH = "auto_wake_on_touch" +CONF_BACKGROUND_PRESSED_COLOR = "background_pressed_color" +CONF_COMMAND_SPACING = "command_spacing" CONF_COMPONENT_NAME = "component_name" -CONF_WAVE_CHANNEL_ID = "wave_channel_id" -CONF_WAVE_MAX_VALUE = "wave_max_value" -CONF_PRECISION = "precision" -CONF_WAVEFORM_SEND_LAST_VALUE = "waveform_send_last_value" -CONF_TFT_URL = "tft_url" +CONF_EXIT_REPARSE_ON_START = "exit_reparse_on_start" +CONF_FONT_ID = "font_id" +CONF_FOREGROUND_PRESSED_COLOR = "foreground_pressed_color" +CONF_ON_BUFFER_OVERFLOW = "on_buffer_overflow" +CONF_ON_PAGE = "on_page" +CONF_ON_SETUP = "on_setup" CONF_ON_SLEEP = "on_sleep" CONF_ON_WAKE = "on_wake" -CONF_ON_SETUP = "on_setup" -CONF_ON_PAGE = "on_page" -CONF_ON_BUFFER_OVERFLOW = "on_buffer_overflow" -CONF_TOUCH_SLEEP_TIMEOUT = "touch_sleep_timeout" -CONF_WAKE_UP_PAGE = "wake_up_page" -CONF_START_UP_PAGE = "start_up_page" -CONF_AUTO_WAKE_ON_TOUCH = "auto_wake_on_touch" -CONF_WAVE_MAX_LENGTH = "wave_max_length" -CONF_BACKGROUND_PRESSED_COLOR = "background_pressed_color" -CONF_FOREGROUND_PRESSED_COLOR = "foreground_pressed_color" -CONF_FONT_ID = "font_id" -CONF_EXIT_REPARSE_ON_START = "exit_reparse_on_start" +CONF_PRECISION = "precision" CONF_SKIP_CONNECTION_HANDSHAKE = "skip_connection_handshake" +CONF_START_UP_PAGE = "start_up_page" +CONF_TFT_URL = "tft_url" +CONF_TOUCH_SLEEP_TIMEOUT = "touch_sleep_timeout" +CONF_VARIABLE_NAME = "variable_name" +CONF_WAKE_UP_PAGE = "wake_up_page" +CONF_WAVE_CHANNEL_ID = "wave_channel_id" +CONF_WAVE_MAX_LENGTH = "wave_max_length" +CONF_WAVE_MAX_VALUE = "wave_max_value" +CONF_WAVEFORM_SEND_LAST_VALUE = "waveform_send_last_value" def NextionName(value): diff --git a/esphome/components/nextion/display.py b/esphome/components/nextion/display.py index 60f26e5234..2e7c1c2825 100644 --- a/esphome/components/nextion/display.py +++ b/esphome/components/nextion/display.py @@ -9,16 +9,17 @@ from esphome.const import ( CONF_ON_TOUCH, CONF_TRIGGER_ID, ) -from esphome.core import CORE +from esphome.core import CORE, TimePeriod from . import Nextion, nextion_ns, nextion_ref from .base_component import ( CONF_AUTO_WAKE_ON_TOUCH, + CONF_COMMAND_SPACING, CONF_EXIT_REPARSE_ON_START, CONF_ON_BUFFER_OVERFLOW, - CONF_ON_PAGE, CONF_ON_SETUP, CONF_ON_SLEEP, + CONF_ON_PAGE, CONF_ON_WAKE, CONF_SKIP_CONNECTION_HANDSHAKE, CONF_START_UP_PAGE, @@ -88,6 +89,10 @@ CONFIG_SCHEMA = ( cv.Optional(CONF_AUTO_WAKE_ON_TOUCH, default=True): cv.boolean, cv.Optional(CONF_EXIT_REPARSE_ON_START, default=False): cv.boolean, cv.Optional(CONF_SKIP_CONNECTION_HANDSHAKE, default=False): cv.boolean, + cv.Optional(CONF_COMMAND_SPACING): cv.All( + cv.positive_time_period_milliseconds, + cv.Range(max=TimePeriod(milliseconds=255)), + ), } ) .extend(cv.polling_component_schema("5s")) @@ -120,6 +125,10 @@ async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await uart.register_uart_device(var, config) + if command_spacing := config.get(CONF_COMMAND_SPACING): + cg.add_define("USE_NEXTION_COMMAND_SPACING") + cg.add(var.set_command_spacing(command_spacing.total_milliseconds)) + if CONF_BRIGHTNESS in config: cg.add(var.set_brightness(config[CONF_BRIGHTNESS])) diff --git a/esphome/components/nextion/nextion.cpp b/esphome/components/nextion/nextion.cpp index 67f08f68f8..38e37300af 100644 --- a/esphome/components/nextion/nextion.cpp +++ b/esphome/components/nextion/nextion.cpp @@ -31,11 +31,22 @@ bool Nextion::send_command_(const std::string &command) { return false; } +#ifdef USE_NEXTION_COMMAND_SPACING + if (!this->ignore_is_setup_ && !this->command_pacer_.can_send()) { + return false; + } +#endif // USE_NEXTION_COMMAND_SPACING + ESP_LOGN(TAG, "send_command %s", command.c_str()); this->write_str(command.c_str()); const uint8_t to_send[3] = {0xFF, 0xFF, 0xFF}; this->write_array(to_send, sizeof(to_send)); + +#ifdef USE_NEXTION_COMMAND_SPACING + this->command_pacer_.mark_sent(); +#endif // USE_NEXTION_COMMAND_SPACING + return true; } @@ -158,6 +169,10 @@ void Nextion::dump_config() { if (this->start_up_page_ != -1) { ESP_LOGCONFIG(TAG, " Start Up Page: %" PRId16, this->start_up_page_); } + +#ifdef USE_NEXTION_COMMAND_SPACING + ESP_LOGCONFIG(TAG, " Command spacing: %" PRIu8 "ms", this->command_pacer_.get_spacing()); +#endif // USE_NEXTION_COMMAND_SPACING } float Nextion::get_setup_priority() const { return setup_priority::DATA; } @@ -312,6 +327,11 @@ bool Nextion::remove_from_q_(bool report_empty) { } NextionQueue *nb = this->nextion_queue_.front(); + if (!nb || !nb->component) { + ESP_LOGE(TAG, "Invalid queue entry!"); + this->nextion_queue_.pop_front(); + return false; + } NextionComponentBase *component = nb->component; ESP_LOGN(TAG, "Removing %s from the queue", component->get_variable_name().c_str()); @@ -341,6 +361,12 @@ void Nextion::process_nextion_commands_() { return; } +#ifdef USE_NEXTION_COMMAND_SPACING + if (!this->command_pacer_.can_send()) { + return; // Will try again in next loop iteration + } +#endif + size_t to_process_length = 0; std::string to_process; @@ -380,7 +406,9 @@ void Nextion::process_nextion_commands_() { this->setup_callback_.call(); } } - +#ifdef USE_NEXTION_COMMAND_SPACING + this->command_pacer_.mark_sent(); // Here is where we should mark the command as sent +#endif break; case 0x02: // invalid Component ID or name was used ESP_LOGW(TAG, "Nextion reported component ID or name invalid!"); @@ -524,6 +552,11 @@ void Nextion::process_nextion_commands_() { } NextionQueue *nb = this->nextion_queue_.front(); + if (!nb || !nb->component) { + ESP_LOGE(TAG, "Invalid queue entry!"); + this->nextion_queue_.pop_front(); + return; + } NextionComponentBase *component = nb->component; if (component->get_queue_type() != NextionQueueType::TEXT_SENSOR) { @@ -564,6 +597,11 @@ void Nextion::process_nextion_commands_() { } NextionQueue *nb = this->nextion_queue_.front(); + if (!nb || !nb->component) { + ESP_LOGE(TAG, "Invalid queue entry!"); + this->nextion_queue_.pop_front(); + return; + } NextionComponentBase *component = nb->component; if (component->get_queue_type() != NextionQueueType::SENSOR && diff --git a/esphome/components/nextion/nextion.h b/esphome/components/nextion/nextion.h index b2404e1f0d..4bc5305923 100644 --- a/esphome/components/nextion/nextion.h +++ b/esphome/components/nextion/nextion.h @@ -35,8 +35,54 @@ using nextion_writer_t = std::function; static const std::string COMMAND_DELIMITER{static_cast(255), static_cast(255), static_cast(255)}; +#ifdef USE_NEXTION_COMMAND_SPACING +class NextionCommandPacer { + public: + /** + * @brief Creates command pacer with initial spacing + * @param initial_spacing Initial time between commands in milliseconds + */ + explicit NextionCommandPacer(uint8_t initial_spacing = 0) : spacing_ms_(initial_spacing) {} + + /** + * @brief Set the minimum time between commands + * @param spacing_ms Spacing in milliseconds + */ + void set_spacing(uint8_t spacing_ms) { spacing_ms_ = spacing_ms; } + + /** + * @brief Get current command spacing + * @return Current spacing in milliseconds + */ + uint8_t get_spacing() const { return spacing_ms_; } + + /** + * @brief Check if enough time has passed to send next command + * @return true if enough time has passed since last command + */ + bool can_send() const { return (millis() - last_command_time_) >= spacing_ms_; } + + /** + * @brief Mark a command as sent, updating the timing + */ + void mark_sent() { last_command_time_ = millis(); } + + private: + uint8_t spacing_ms_; + uint32_t last_command_time_{0}; +}; +#endif // USE_NEXTION_COMMAND_SPACING + class Nextion : public NextionBase, public PollingComponent, public uart::UARTDevice { public: +#ifdef USE_NEXTION_COMMAND_SPACING + /** + * @brief Set the command spacing for the display + * @param spacing_ms Time in milliseconds between commands + */ + void set_command_spacing(uint32_t spacing_ms) { this->command_pacer_.set_spacing(spacing_ms); } +#endif // USE_NEXTION_COMMAND_SPACING + /** * Set the text of a component to a static string. * @param component The component name. @@ -1227,6 +1273,9 @@ class Nextion : public NextionBase, public PollingComponent, public uart::UARTDe bool is_connected() { return this->is_connected_; } protected: +#ifdef USE_NEXTION_COMMAND_SPACING + NextionCommandPacer command_pacer_{0}; +#endif // USE_NEXTION_COMMAND_SPACING std::deque nextion_queue_; std::deque waveform_queue_; uint16_t recv_ret_string_(std::string &response, uint32_t timeout, bool recv_flag); @@ -1360,5 +1409,6 @@ class Nextion : public NextionBase, public PollingComponent, public uart::UARTDe uint32_t started_ms_ = 0; bool sent_setup_commands_ = false; }; + } // namespace nextion } // namespace esphome diff --git a/esphome/components/number/__init__.py b/esphome/components/number/__init__.py index f45cfd54f2..7aa103e9d9 100644 --- a/esphome/components/number/__init__.py +++ b/esphome/components/number/__init__.py @@ -170,7 +170,7 @@ NUMBER_OPERATION_OPTIONS = { validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True, space="_") validate_unit_of_measurement = cv.string_strict -NUMBER_SCHEMA = ( +_NUMBER_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( @@ -196,16 +196,14 @@ NUMBER_SCHEMA = ( ) ) -_UNDEF = object() - def number_schema( class_: MockObjClass, *, - icon: str = _UNDEF, - entity_category: str = _UNDEF, - device_class: str = _UNDEF, - unit_of_measurement: str = _UNDEF, + icon: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + device_class: str = cv.UNDEFINED, + unit_of_measurement: str = cv.UNDEFINED, ) -> cv.Schema: schema = {cv.GenerateID(): cv.declare_id(class_)} @@ -215,10 +213,15 @@ def number_schema( (CONF_DEVICE_CLASS, device_class, validate_device_class), (CONF_UNIT_OF_MEASUREMENT, unit_of_measurement, validate_unit_of_measurement), ]: - if default is not _UNDEF: + if default is not cv.UNDEFINED: schema[cv.Optional(key, default=default)] = validator - return NUMBER_SCHEMA.extend(schema) + return _NUMBER_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +NUMBER_SCHEMA = number_schema(Number) +NUMBER_SCHEMA.add_extra(cv.deprecated_schema_constant("number")) async def setup_number_core_( diff --git a/esphome/components/one_wire/one_wire_bus.cpp b/esphome/components/one_wire/one_wire_bus.cpp index a8d29428d3..c2542177cf 100644 --- a/esphome/components/one_wire/one_wire_bus.cpp +++ b/esphome/components/one_wire/one_wire_bus.cpp @@ -17,8 +17,15 @@ const uint8_t ONE_WIRE_ROM_SEARCH = 0xF0; const std::vector &OneWireBus::get_devices() { return this->devices_; } +bool OneWireBus::reset_() { + int res = this->reset_int(); + if (res == -1) + ESP_LOGE(TAG, "1-wire bus is held low"); + return res == 1; +} + bool IRAM_ATTR OneWireBus::select(uint64_t address) { - if (!this->reset()) + if (!this->reset_()) return false; this->write8(ONE_WIRE_ROM_SELECT); this->write64(address); @@ -31,16 +38,13 @@ void OneWireBus::search() { this->reset_search(); uint64_t address; while (true) { - { - InterruptLock lock; - if (!this->reset()) { - // Reset failed or no devices present - return; - } - - this->write8(ONE_WIRE_ROM_SEARCH); - address = this->search_int(); + if (!this->reset_()) { + // Reset failed or no devices present + return; } + + this->write8(ONE_WIRE_ROM_SEARCH); + address = this->search_int(); if (address == 0) break; auto *address8 = reinterpret_cast(&address); diff --git a/esphome/components/one_wire/one_wire_bus.h b/esphome/components/one_wire/one_wire_bus.h index 6818b17499..c88532046f 100644 --- a/esphome/components/one_wire/one_wire_bus.h +++ b/esphome/components/one_wire/one_wire_bus.h @@ -9,14 +9,6 @@ namespace one_wire { class OneWireBus { public: - /** Reset the bus, should be done before all write operations. - * - * Takes approximately 1ms. - * - * @return Whether the operation was successful. - */ - virtual bool reset() = 0; - /// Write a word to the bus. LSB first. virtual void write8(uint8_t val) = 0; @@ -50,6 +42,20 @@ class OneWireBus { /// log the found devices void dump_devices_(const char *tag); + /** Reset the bus, should be done before all write operations. + * + * Takes approximately 1ms. + * + * @return Whether the operation was successful. + */ + bool reset_(); + + /** + * Bus Reset + * @return -1: signal fail, 0: no device detected, 1: device detected + */ + virtual int reset_int() = 0; + /// Reset the device search. virtual void reset_search() = 0; diff --git a/esphome/components/online_image/bmp_image.cpp b/esphome/components/online_image/bmp_image.cpp index af9019a4d2..f55c9f1813 100644 --- a/esphome/components/online_image/bmp_image.cpp +++ b/esphome/components/online_image/bmp_image.cpp @@ -62,6 +62,13 @@ int HOT BmpDecoder::decode(uint8_t *buffer, size_t size) { case 1: this->width_bytes_ = (this->width_ % 8 == 0) ? (this->width_ / 8) : (this->width_ / 8 + 1); break; + case 24: + this->width_bytes_ = this->width_ * 3; + if (this->width_bytes_ % 4 != 0) { + this->padding_bytes_ = 4 - (this->width_bytes_ % 4); + this->width_bytes_ += this->padding_bytes_; + } + break; default: ESP_LOGE(TAG, "Unsupported bits per pixel: %d", this->bits_per_pixel_); return DECODE_ERROR_UNSUPPORTED_FORMAT; @@ -78,18 +85,48 @@ int HOT BmpDecoder::decode(uint8_t *buffer, size_t size) { this->current_index_ = this->data_offset_; index = this->data_offset_; } - while (index < size) { - size_t paint_index = this->current_index_ - this->data_offset_; - - uint8_t current_byte = buffer[index]; - for (uint8_t i = 0; i < 8; i++) { - size_t x = (paint_index * 8) % this->width_ + i; - size_t y = (this->height_ - 1) - (paint_index / this->width_bytes_); - Color c = (current_byte & (1 << (7 - i))) ? display::COLOR_ON : display::COLOR_OFF; - this->draw(x, y, 1, 1, c); + switch (this->bits_per_pixel_) { + case 1: { + while (index < size) { + uint8_t current_byte = buffer[index]; + for (uint8_t i = 0; i < 8; i++) { + size_t x = (this->paint_index_ % this->width_) + i; + size_t y = (this->height_ - 1) - (this->paint_index_ / this->width_); + Color c = (current_byte & (1 << (7 - i))) ? display::COLOR_ON : display::COLOR_OFF; + this->draw(x, y, 1, 1, c); + } + this->paint_index_ += 8; + this->current_index_++; + index++; + } + break; } - this->current_index_++; - index++; + case 24: { + while (index < size) { + if (index + 2 >= size) { + this->decoded_bytes_ += index; + return index; + } + uint8_t b = buffer[index]; + uint8_t g = buffer[index + 1]; + uint8_t r = buffer[index + 2]; + size_t x = this->paint_index_ % this->width_; + size_t y = (this->height_ - 1) - (this->paint_index_ / this->width_); + Color c = Color(r, g, b); + this->draw(x, y, 1, 1, c); + this->paint_index_++; + this->current_index_ += 3; + index += 3; + if (x == this->width_ - 1 && this->padding_bytes_ > 0) { + index += this->padding_bytes_; + this->current_index_ += this->padding_bytes_; + } + } + break; + } + default: + ESP_LOGE(TAG, "Unsupported bits per pixel: %d", this->bits_per_pixel_); + return DECODE_ERROR_UNSUPPORTED_FORMAT; } this->decoded_bytes_ += size; return size; diff --git a/esphome/components/online_image/bmp_image.h b/esphome/components/online_image/bmp_image.h index 61192f6a46..916ffea1ad 100644 --- a/esphome/components/online_image/bmp_image.h +++ b/esphome/components/online_image/bmp_image.h @@ -24,6 +24,7 @@ class BmpDecoder : public ImageDecoder { protected: size_t current_index_{0}; + size_t paint_index_{0}; ssize_t width_{0}; ssize_t height_{0}; uint16_t bits_per_pixel_{0}; @@ -32,6 +33,7 @@ class BmpDecoder : public ImageDecoder { uint32_t color_table_entries_{0}; size_t width_bytes_{0}; size_t data_offset_{0}; + uint8_t padding_bytes_{0}; }; } // namespace online_image diff --git a/esphome/components/opentherm/binary_sensor/__init__.py b/esphome/components/opentherm/binary_sensor/__init__.py index d4c7861a1d..35228228ea 100644 --- a/esphome/components/opentherm/binary_sensor/__init__.py +++ b/esphome/components/opentherm/binary_sensor/__init__.py @@ -11,10 +11,8 @@ COMPONENT_TYPE = const.BINARY_SENSOR def get_entity_validation_schema(entity: schema.BinarySensorSchema) -> cv.Schema: return binary_sensor.binary_sensor_schema( - device_class=( - entity.device_class or binary_sensor._UNDEF # pylint: disable=protected-access - ), - icon=(entity.icon or binary_sensor._UNDEF), # pylint: disable=protected-access + device_class=(entity.device_class or cv.UNDEFINED), + icon=(entity.icon or cv.UNDEFINED), ) diff --git a/esphome/components/opentherm/number/__init__.py b/esphome/components/opentherm/number/__init__.py index 00aa62483c..a65864647a 100644 --- a/esphome/components/opentherm/number/__init__.py +++ b/esphome/components/opentherm/number/__init__.py @@ -3,13 +3,7 @@ from typing import Any import esphome.codegen as cg from esphome.components import number import esphome.config_validation as cv -from esphome.const import ( - CONF_ID, - CONF_INITIAL_VALUE, - CONF_RESTORE_VALUE, - CONF_STEP, - CONF_UNIT_OF_MEASUREMENT, -) +from esphome.const import CONF_INITIAL_VALUE, CONF_RESTORE_VALUE, CONF_STEP from .. import const, generate, input, schema, validate @@ -22,33 +16,30 @@ OpenthermNumber = generate.opentherm_ns.class_( async def new_openthermnumber(config: dict[str, Any]) -> cg.Pvariable: - var = cg.new_Pvariable(config[CONF_ID]) - await cg.register_component(var, config) - await number.register_number( - var, + var = await number.new_number( config, min_value=config[input.CONF_min_value], max_value=config[input.CONF_max_value], step=config[input.CONF_step], ) + await cg.register_component(var, config) input.generate_setters(var, config) - if CONF_INITIAL_VALUE in config: - cg.add(var.set_initial_value(config[CONF_INITIAL_VALUE])) - if CONF_RESTORE_VALUE in config: - cg.add(var.set_restore_value(config[CONF_RESTORE_VALUE])) + if (initial_value := config.get(CONF_INITIAL_VALUE, None)) is not None: + cg.add(var.set_initial_value(initial_value)) + if (restore_value := config.get(CONF_RESTORE_VALUE, None)) is not None: + cg.add(var.set_restore_value(restore_value)) return var def get_entity_validation_schema(entity: schema.InputSchema) -> cv.Schema: return ( - number.NUMBER_SCHEMA.extend( + number.number_schema( + OpenthermNumber, unit_of_measurement=entity.unit_of_measurement + ) + .extend( { - cv.GenerateID(): cv.declare_id(OpenthermNumber), - cv.Optional( - CONF_UNIT_OF_MEASUREMENT, entity.unit_of_measurement - ): cv.string_strict, cv.Optional(CONF_STEP, entity.step): cv.float_, cv.Optional(CONF_INITIAL_VALUE): cv.float_, cv.Optional(CONF_RESTORE_VALUE): cv.boolean, diff --git a/esphome/components/opentherm/sensor/__init__.py b/esphome/components/opentherm/sensor/__init__.py index 86c842b299..9aa33f457d 100644 --- a/esphome/components/opentherm/sensor/__init__.py +++ b/esphome/components/opentherm/sensor/__init__.py @@ -23,10 +23,10 @@ MSG_DATA_TYPES = { def get_entity_validation_schema(entity: schema.SensorSchema) -> cv.Schema: return sensor.sensor_schema( - unit_of_measurement=entity.unit_of_measurement or sensor._UNDEF, # pylint: disable=protected-access + unit_of_measurement=entity.unit_of_measurement or cv.UNDEFINED, # pylint: disable=protected-access accuracy_decimals=entity.accuracy_decimals, - device_class=entity.device_class or sensor._UNDEF, # pylint: disable=protected-access - icon=entity.icon or sensor._UNDEF, # pylint: disable=protected-access + device_class=entity.device_class or cv.UNDEFINED, # pylint: disable=protected-access + icon=entity.icon or cv.UNDEFINED, # pylint: disable=protected-access state_class=entity.state_class, ).extend( { diff --git a/esphome/components/opentherm/switch/__init__.py b/esphome/components/opentherm/switch/__init__.py index ead086d24b..f8f09b3967 100644 --- a/esphome/components/opentherm/switch/__init__.py +++ b/esphome/components/opentherm/switch/__init__.py @@ -3,7 +3,6 @@ from typing import Any import esphome.codegen as cg from esphome.components import switch import esphome.config_validation as cv -from esphome.const import CONF_ID from .. import const, generate, schema, validate @@ -16,15 +15,14 @@ OpenthermSwitch = generate.opentherm_ns.class_( async def new_openthermswitch(config: dict[str, Any]) -> cg.Pvariable: - var = cg.new_Pvariable(config[CONF_ID]) + var = await switch.new_switch(config) await cg.register_component(var, config) - await switch.register_switch(var, config) return var def get_entity_validation_schema(entity: schema.SwitchSchema) -> cv.Schema: - return switch.SWITCH_SCHEMA.extend( - {cv.GenerateID(): cv.declare_id(OpenthermSwitch)} + return switch.switch_schema( + OpenthermSwitch, default_restore_mode=entity.default_mode ).extend(cv.COMPONENT_SCHEMA) diff --git a/esphome/components/output/lock/__init__.py b/esphome/components/output/lock/__init__.py index c9bdba0f75..553114b689 100644 --- a/esphome/components/output/lock/__init__.py +++ b/esphome/components/output/lock/__init__.py @@ -1,24 +1,26 @@ import esphome.codegen as cg from esphome.components import lock, output import esphome.config_validation as cv -from esphome.const import CONF_ID, CONF_OUTPUT +from esphome.const import CONF_OUTPUT from .. import output_ns OutputLock = output_ns.class_("OutputLock", lock.Lock, cg.Component) -CONFIG_SCHEMA = lock.LOCK_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(OutputLock), - cv.Required(CONF_OUTPUT): cv.use_id(output.BinaryOutput), - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + lock.lock_schema(OutputLock) + .extend( + { + cv.Required(CONF_OUTPUT): cv.use_id(output.BinaryOutput), + } + ) + .extend(cv.COMPONENT_SCHEMA) +) async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await lock.new_lock(config) await cg.register_component(var, config) - await lock.register_lock(var, config) output_ = await cg.get_variable(config[CONF_OUTPUT]) cg.add(var.set_output(output_)) diff --git a/esphome/components/packages/__init__.py b/esphome/components/packages/__init__.py index f4d11e7bd0..08ae798282 100644 --- a/esphome/components/packages/__init__.py +++ b/esphome/components/packages/__init__.py @@ -24,22 +24,13 @@ DOMAIN = CONF_PACKAGES def validate_git_package(config: dict): + if CONF_URL not in config: + return config + config = BASE_SCHEMA(config) new_config = config - for key, conf in config.items(): - if CONF_URL in conf: - try: - conf = BASE_SCHEMA(conf) - if CONF_FILE in conf: - new_config[key][CONF_FILES] = [conf[CONF_FILE]] - del new_config[key][CONF_FILE] - except cv.MultipleInvalid as e: - with cv.prepend_path([key]): - raise e - except cv.Invalid as e: - raise cv.Invalid( - "Extra keys not allowed in git based package", - path=[key] + e.path, - ) from e + if CONF_FILE in config: + new_config[CONF_FILES] = [config[CONF_FILE]] + del new_config[CONF_FILE] return new_config @@ -74,8 +65,8 @@ BASE_SCHEMA = cv.All( cv.Required(CONF_URL): cv.url, cv.Optional(CONF_USERNAME): cv.string, cv.Optional(CONF_PASSWORD): cv.string, - cv.Exclusive(CONF_FILE, "files"): validate_yaml_filename, - cv.Exclusive(CONF_FILES, "files"): cv.All( + cv.Exclusive(CONF_FILE, CONF_FILES): validate_yaml_filename, + cv.Exclusive(CONF_FILES, CONF_FILES): cv.All( cv.ensure_list( cv.Any( validate_yaml_filename, @@ -100,14 +91,17 @@ BASE_SCHEMA = cv.All( cv.has_at_least_one_key(CONF_FILE, CONF_FILES), ) +PACKAGE_SCHEMA = cv.All( + cv.Any(validate_source_shorthand, BASE_SCHEMA, dict), validate_git_package +) -CONFIG_SCHEMA = cv.All( +CONFIG_SCHEMA = cv.Any( cv.Schema( { - str: cv.Any(validate_source_shorthand, BASE_SCHEMA, dict), + str: PACKAGE_SCHEMA, } ), - validate_git_package, + cv.ensure_list(PACKAGE_SCHEMA), ) @@ -183,25 +177,33 @@ def _process_base_package(config: dict) -> dict: return {"packages": packages} +def _process_package(package_config, config): + recursive_package = package_config + if CONF_URL in package_config: + package_config = _process_base_package(package_config) + if isinstance(package_config, dict): + recursive_package = do_packages_pass(package_config) + config = merge_config(recursive_package, config) + return config + + def do_packages_pass(config: dict): if CONF_PACKAGES not in config: return config packages = config[CONF_PACKAGES] with cv.prepend_path(CONF_PACKAGES): packages = CONFIG_SCHEMA(packages) - if not isinstance(packages, dict): + if isinstance(packages, dict): + for package_name, package_config in reversed(packages.items()): + with cv.prepend_path(package_name): + config = _process_package(package_config, config) + elif isinstance(packages, list): + for package_config in reversed(packages): + config = _process_package(package_config, config) + else: raise cv.Invalid( - f"Packages must be a key to value mapping, got {type(packages)} instead" + f"Packages must be a key to value mapping or list, got {type(packages)} instead" ) - for package_name, package_config in reversed(packages.items()): - with cv.prepend_path(package_name): - recursive_package = package_config - if CONF_URL in package_config: - package_config = _process_base_package(package_config) - if isinstance(package_config, dict): - recursive_package = do_packages_pass(package_config) - config = merge_config(recursive_package, config) - del config[CONF_PACKAGES] return config diff --git a/esphome/components/packet_transport/__init__.py b/esphome/components/packet_transport/__init__.py new file mode 100644 index 0000000000..99c1d824ca --- /dev/null +++ b/esphome/components/packet_transport/__init__.py @@ -0,0 +1,201 @@ +"""ESPHome packet transport component.""" + +import hashlib +import logging + +import esphome.codegen as cg +from esphome.components.api import CONF_ENCRYPTION +from esphome.components.binary_sensor import BinarySensor +from esphome.components.sensor import Sensor +import esphome.config_validation as cv +from esphome.const import ( + CONF_BINARY_SENSORS, + CONF_ID, + CONF_INTERNAL, + CONF_KEY, + CONF_NAME, + CONF_PLATFORM, + CONF_SENSORS, +) +from esphome.core import CORE +from esphome.cpp_generator import MockObjClass + +CODEOWNERS = ["@clydebarrow"] +AUTO_LOAD = ["xxtea"] + +packet_transport_ns = cg.esphome_ns.namespace("packet_transport") +PacketTransport = packet_transport_ns.class_("PacketTransport", cg.PollingComponent) + +IS_PLATFORM_COMPONENT = True + +DOMAIN = "packet_transport" +CONF_BROADCAST = "broadcast" +CONF_BROADCAST_ID = "broadcast_id" +CONF_PROVIDER = "provider" +CONF_PROVIDERS = "providers" +CONF_REMOTE_ID = "remote_id" +CONF_PING_PONG_ENABLE = "ping_pong_enable" +CONF_PING_PONG_RECYCLE_TIME = "ping_pong_recycle_time" +CONF_ROLLING_CODE_ENABLE = "rolling_code_enable" +CONF_TRANSPORT_ID = "transport_id" + + +_LOGGER = logging.getLogger(__name__) + + +def sensor_validation(cls: MockObjClass): + return cv.maybe_simple_value( + cv.Schema( + { + cv.Required(CONF_ID): cv.use_id(cls), + cv.Optional(CONF_BROADCAST_ID): cv.validate_id_name, + } + ), + key=CONF_ID, + ) + + +def provider_name_validate(value): + value = cv.valid_name(value) + if "_" in value: + _LOGGER.warning( + "Device names typically do not contain underscores - did you mean to use a hyphen in '%s'?", + value, + ) + return value + + +ENCRYPTION_SCHEMA = { + cv.Optional(CONF_ENCRYPTION): cv.maybe_simple_value( + cv.Schema( + { + cv.Required(CONF_KEY): cv.string, + } + ), + key=CONF_KEY, + ) +} + +PROVIDER_SCHEMA = cv.Schema( + { + cv.Required(CONF_NAME): provider_name_validate, + } +).extend(ENCRYPTION_SCHEMA) + + +def validate_(config): + if CONF_ENCRYPTION in config: + if CONF_SENSORS not in config and CONF_BINARY_SENSORS not in config: + raise cv.Invalid("No sensors or binary sensors to encrypt") + elif config[CONF_ROLLING_CODE_ENABLE]: + raise cv.Invalid("Rolling code requires an encryption key") + if config[CONF_PING_PONG_ENABLE]: + if not any(CONF_ENCRYPTION in p for p in config.get(CONF_PROVIDERS) or ()): + raise cv.Invalid("Ping-pong requires at least one encrypted provider") + return config + + +TRANSPORT_SCHEMA = ( + cv.polling_component_schema("15s") + .extend( + { + cv.Optional(CONF_ROLLING_CODE_ENABLE, default=False): cv.boolean, + cv.Optional(CONF_PING_PONG_ENABLE, default=False): cv.boolean, + cv.Optional( + CONF_PING_PONG_RECYCLE_TIME, default="600s" + ): cv.positive_time_period_seconds, + cv.Optional(CONF_SENSORS): cv.ensure_list(sensor_validation(Sensor)), + cv.Optional(CONF_BINARY_SENSORS): cv.ensure_list( + sensor_validation(BinarySensor) + ), + cv.Optional(CONF_PROVIDERS, default=[]): cv.ensure_list(PROVIDER_SCHEMA), + }, + ) + .extend(ENCRYPTION_SCHEMA) + .add_extra(validate_) +) + + +def transport_schema(cls): + return TRANSPORT_SCHEMA.extend({cv.GenerateID(): cv.declare_id(cls)}) + + +# Build a list of sensors for this platform +CORE.data[DOMAIN] = {CONF_SENSORS: []} + + +def get_sensors(transport_id): + """Return the list of sensors for this platform.""" + return ( + sensor + for sensor in CORE.data[DOMAIN][CONF_SENSORS] + if sensor[CONF_TRANSPORT_ID] == transport_id + ) + + +def validate_packet_transport_sensor(config): + if CONF_NAME in config and CONF_INTERNAL not in config: + raise cv.Invalid("Must provide internal: config when using name:") + CORE.data[DOMAIN][CONF_SENSORS].append(config) + return config + + +def packet_transport_sensor_schema(base_schema): + return cv.All( + base_schema.extend( + { + cv.GenerateID(CONF_TRANSPORT_ID): cv.use_id(PacketTransport), + cv.Optional(CONF_REMOTE_ID): cv.string_strict, + cv.Required(CONF_PROVIDER): provider_name_validate, + } + ), + cv.has_at_least_one_key(CONF_ID, CONF_REMOTE_ID), + validate_packet_transport_sensor, + ) + + +def hash_encryption_key(config: dict): + return list(hashlib.sha256(config[CONF_KEY].encode()).digest()) + + +async def register_packet_transport(var, config): + var = await cg.register_component(var, config) + cg.add(var.set_rolling_code_enable(config[CONF_ROLLING_CODE_ENABLE])) + cg.add(var.set_ping_pong_enable(config[CONF_PING_PONG_ENABLE])) + cg.add( + var.set_ping_pong_recycle_time( + config[CONF_PING_PONG_RECYCLE_TIME].total_seconds + ) + ) + # Get directly configured providers, plus those from sensors and binary sensors + providers = { + sensor[CONF_PROVIDER] for sensor in get_sensors(config[CONF_ID]) + }.union(x[CONF_NAME] for x in config[CONF_PROVIDERS]) + for provider in providers: + cg.add(var.add_provider(provider)) + for provider in config[CONF_PROVIDERS]: + name = provider[CONF_NAME] + if encryption := provider.get(CONF_ENCRYPTION): + cg.add(var.set_provider_encryption(name, hash_encryption_key(encryption))) + + for sens_conf in config.get(CONF_SENSORS, ()): + sens_id = sens_conf[CONF_ID] + sensor = await cg.get_variable(sens_id) + bcst_id = sens_conf.get(CONF_BROADCAST_ID, sens_id.id) + cg.add(var.add_sensor(bcst_id, sensor)) + for sens_conf in config.get(CONF_BINARY_SENSORS, ()): + sens_id = sens_conf[CONF_ID] + sensor = await cg.get_variable(sens_id) + bcst_id = sens_conf.get(CONF_BROADCAST_ID, sens_id.id) + cg.add(var.add_binary_sensor(bcst_id, sensor)) + + if encryption := config.get(CONF_ENCRYPTION): + cg.add(var.set_encryption_key(hash_encryption_key(encryption))) + return providers + + +async def new_packet_transport(config): + var = cg.new_Pvariable(config[CONF_ID]) + cg.add(var.set_platform_name(config[CONF_PLATFORM])) + providers = await register_packet_transport(var, config) + return var, providers diff --git a/esphome/components/packet_transport/binary_sensor.py b/esphome/components/packet_transport/binary_sensor.py new file mode 100644 index 0000000000..076e37e6bb --- /dev/null +++ b/esphome/components/packet_transport/binary_sensor.py @@ -0,0 +1,19 @@ +import esphome.codegen as cg +from esphome.components import binary_sensor +from esphome.const import CONF_ID + +from . import ( + CONF_PROVIDER, + CONF_REMOTE_ID, + CONF_TRANSPORT_ID, + packet_transport_sensor_schema, +) + +CONFIG_SCHEMA = packet_transport_sensor_schema(binary_sensor.binary_sensor_schema()) + + +async def to_code(config): + var = await binary_sensor.new_binary_sensor(config) + comp = await cg.get_variable(config[CONF_TRANSPORT_ID]) + remote_id = str(config.get(CONF_REMOTE_ID) or config.get(CONF_ID)) + cg.add(comp.add_remote_binary_sensor(config[CONF_PROVIDER], remote_id, var)) diff --git a/esphome/components/packet_transport/packet_transport.cpp b/esphome/components/packet_transport/packet_transport.cpp new file mode 100644 index 0000000000..be2f77e379 --- /dev/null +++ b/esphome/components/packet_transport/packet_transport.cpp @@ -0,0 +1,534 @@ +#include "esphome/core/log.h" +#include "esphome/core/application.h" +#include "packet_transport.h" + +#include "esphome/components/xxtea/xxtea.h" + +namespace esphome { +namespace packet_transport { +/** + * Structure of a data packet; everything is little-endian + * + * --- In clear text --- + * MAGIC_NUMBER: 16 bits + * host name length: 1 byte + * host name: (length) bytes + * padding: 0 or more null bytes to a 4 byte boundary + * + * --- Encrypted (if key set) ---- + * DATA_KEY: 1 byte: OR ROLLING_CODE_KEY: + * Rolling code (if enabled): 8 bytes + * Ping keys: if any + * repeat: + * PING_KEY: 1 byte + * ping code: 4 bytes + * Sensors: + * repeat: + * SENSOR_KEY: 1 byte + * float value: 4 bytes + * name length: 1 byte + * name + * Binary Sensors: + * repeat: + * BINARY_SENSOR_KEY: 1 byte + * bool value: 1 bytes + * name length: 1 byte + * name + * + * Padded to a 4 byte boundary with nulls + * + * Structure of a ping request packet: + * --- In clear text --- + * MAGIC_PING: 16 bits + * host name length: 1 byte + * host name: (length) bytes + * Ping key (4 bytes) + * + */ +static const char *const TAG = "packet_transport"; + +static size_t round4(size_t value) { return (value + 3) & ~3; } + +union FuData { + uint32_t u32; + float f32; +}; + +static const uint16_t MAGIC_NUMBER = 0x4553; +static const uint16_t MAGIC_PING = 0x5048; +static const uint32_t PREF_HASH = 0x45535043; +enum DataKey { + ZERO_FILL_KEY, + DATA_KEY, + SENSOR_KEY, + BINARY_SENSOR_KEY, + PING_KEY, + ROLLING_CODE_KEY, +}; + +enum DecodeResult { + DECODE_OK, + DECODE_UNMATCHED, + DECODE_ERROR, + DECODE_EMPTY, +}; + +static const size_t MAX_PING_KEYS = 4; + +static inline void add(std::vector &vec, uint32_t data) { + vec.push_back(data & 0xFF); + vec.push_back((data >> 8) & 0xFF); + vec.push_back((data >> 16) & 0xFF); + vec.push_back((data >> 24) & 0xFF); +} + +class PacketDecoder { + public: + PacketDecoder(const uint8_t *buffer, size_t len) : buffer_(buffer), len_(len) {} + + DecodeResult decode_string(char *data, size_t maxlen) { + if (this->position_ == this->len_) + return DECODE_EMPTY; + auto len = this->buffer_[this->position_]; + if (len == 0 || this->position_ + 1 + len > this->len_ || len >= maxlen) + return DECODE_ERROR; + this->position_++; + memcpy(data, this->buffer_ + this->position_, len); + data[len] = 0; + this->position_ += len; + return DECODE_OK; + } + + template DecodeResult get(T &data) { + if (this->position_ + sizeof(T) > this->len_) + return DECODE_ERROR; + T value = 0; + for (size_t i = 0; i != sizeof(T); ++i) { + value += this->buffer_[this->position_++] << (i * 8); + } + data = value; + return DECODE_OK; + } + + template DecodeResult decode(uint8_t key, T &data) { + if (this->position_ == this->len_) + return DECODE_EMPTY; + if (this->buffer_[this->position_] != key) + return DECODE_UNMATCHED; + if (this->position_ + 1 + sizeof(T) > this->len_) + return DECODE_ERROR; + this->position_++; + T value = 0; + for (size_t i = 0; i != sizeof(T); ++i) { + value += this->buffer_[this->position_++] << (i * 8); + } + data = value; + return DECODE_OK; + } + + template DecodeResult decode(uint8_t key, char *buf, size_t buflen, T &data) { + if (this->position_ == this->len_) + return DECODE_EMPTY; + if (this->buffer_[this->position_] != key) + return DECODE_UNMATCHED; + this->position_++; + T value = 0; + for (size_t i = 0; i != sizeof(T); ++i) { + value += this->buffer_[this->position_++] << (i * 8); + } + data = value; + return this->decode_string(buf, buflen); + } + + DecodeResult decode(uint8_t key) { + if (this->position_ == this->len_) + return DECODE_EMPTY; + if (this->buffer_[this->position_] != key) + return DECODE_UNMATCHED; + this->position_++; + return DECODE_OK; + } + + size_t get_remaining_size() const { return this->len_ - this->position_; } + + // align the pointer to the given byte boundary + bool bump_to(size_t boundary) { + auto newpos = this->position_; + auto offset = this->position_ % boundary; + if (offset != 0) { + newpos += boundary - offset; + } + if (newpos >= this->len_) + return false; + this->position_ = newpos; + return true; + } + + bool decrypt(const uint32_t *key) { + if (this->get_remaining_size() % 4 != 0) { + return false; + } + xxtea::decrypt((uint32_t *) (this->buffer_ + this->position_), this->get_remaining_size() / 4, key); + return true; + } + + protected: + const uint8_t *buffer_; + size_t len_; + size_t position_{}; +}; + +static inline void add(std::vector &vec, uint8_t data) { vec.push_back(data); } +static inline void add(std::vector &vec, uint16_t data) { + vec.push_back((uint8_t) data); + vec.push_back((uint8_t) (data >> 8)); +} +static inline void add(std::vector &vec, DataKey data) { vec.push_back(data); } +static void add(std::vector &vec, const char *str) { + auto len = strlen(str); + vec.push_back(len); + for (size_t i = 0; i != len; i++) { + vec.push_back(*str++); + } +} + +void PacketTransport::setup() { + this->name_ = App.get_name().c_str(); + if (strlen(this->name_) > 255) { + this->mark_failed(); + this->status_set_error("Device name exceeds 255 chars"); + return; + } + this->resend_ping_key_ = this->ping_pong_enable_; + this->pref_ = global_preferences->make_preference(PREF_HASH, true); + if (this->rolling_code_enable_) { + // restore the upper 32 bits of the rolling code, increment and save. + this->pref_.load(&this->rolling_code_[1]); + this->rolling_code_[1]++; + this->pref_.save(&this->rolling_code_[1]); + // must make sure it's saved immediately + global_preferences->sync(); + this->ping_key_ = random_uint32(); + ESP_LOGV(TAG, "Rolling code incremented, upper part now %u", (unsigned) this->rolling_code_[1]); + } +#ifdef USE_SENSOR + for (auto &sensor : this->sensors_) { + sensor.sensor->add_on_state_callback([this, &sensor](float x) { + this->updated_ = true; + sensor.updated = true; + }); + } +#endif +#ifdef USE_BINARY_SENSOR + for (auto &sensor : this->binary_sensors_) { + sensor.sensor->add_on_state_callback([this, &sensor](bool value) { + this->updated_ = true; + sensor.updated = true; + }); + } +#endif + // initialise the header. This is invariant. + add(this->header_, MAGIC_NUMBER); + add(this->header_, this->name_); + // pad to a multiple of 4 bytes + while (this->header_.size() & 0x3) + this->header_.push_back(0); +} + +void PacketTransport::init_data_() { + this->data_.clear(); + if (this->rolling_code_enable_) { + add(this->data_, ROLLING_CODE_KEY); + add(this->data_, this->rolling_code_[0]); + add(this->data_, this->rolling_code_[1]); + this->increment_code_(); + } else { + add(this->data_, DATA_KEY); + } + for (auto pkey : this->ping_keys_) { + add(this->data_, PING_KEY); + add(this->data_, pkey.second); + } +} + +void PacketTransport::flush_() { + if (!this->should_send() || this->data_.empty()) + return; + auto header_len = round4(this->header_.size()); + auto len = round4(data_.size()); + auto encode_buffer = std::vector(round4(header_len + len)); + memcpy(encode_buffer.data(), this->header_.data(), this->header_.size()); + memcpy(encode_buffer.data() + header_len, this->data_.data(), this->data_.size()); + if (this->is_encrypted_()) { + xxtea::encrypt((uint32_t *) (encode_buffer.data() + header_len), len / 4, + (uint32_t *) this->encryption_key_.data()); + } + this->send_packet(encode_buffer); +} + +void PacketTransport::add_binary_data_(uint8_t key, const char *id, bool data) { + auto len = 1 + 1 + 1 + strlen(id); + if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { + this->flush_(); + } + add(this->data_, key); + add(this->data_, (uint8_t) data); + add(this->data_, id); +} +void PacketTransport::add_data_(uint8_t key, const char *id, float data) { + FuData udata{.f32 = data}; + this->add_data_(key, id, udata.u32); +} + +void PacketTransport::add_data_(uint8_t key, const char *id, uint32_t data) { + auto len = 4 + 1 + 1 + strlen(id); + if (len + this->header_.size() + this->data_.size() > this->get_max_packet_size()) { + this->flush_(); + } + add(this->data_, key); + add(this->data_, data); + add(this->data_, id); +} +void PacketTransport::send_data_(bool all) { + if (!this->should_send()) + return; + this->init_data_(); +#ifdef USE_SENSOR + for (auto &sensor : this->sensors_) { + if (all || sensor.updated) { + sensor.updated = false; + this->add_data_(SENSOR_KEY, sensor.id, sensor.sensor->get_state()); + } + } +#endif +#ifdef USE_BINARY_SENSOR + for (auto &sensor : this->binary_sensors_) { + if (all || sensor.updated) { + sensor.updated = false; + this->add_binary_data_(BINARY_SENSOR_KEY, sensor.id, sensor.sensor->state); + } + } +#endif + this->flush_(); + this->updated_ = false; +} + +void PacketTransport::update() { + auto now = millis() / 1000; + if (this->last_key_time_ + this->ping_pong_recyle_time_ < now) { + this->resend_ping_key_ = this->ping_pong_enable_; + this->last_key_time_ = now; + } +} + +void PacketTransport::add_key_(const char *name, uint32_t key) { + if (!this->is_encrypted_()) + return; + if (this->ping_keys_.count(name) == 0 && this->ping_keys_.size() == MAX_PING_KEYS) { + ESP_LOGW(TAG, "Ping key from %s discarded", name); + return; + } + this->ping_keys_[name] = key; + this->updated_ = true; + ESP_LOGV(TAG, "Ping key from %s now %X", name, (unsigned) key); +} + +static bool process_rolling_code(Provider &provider, PacketDecoder &decoder) { + uint32_t code0, code1; + if (decoder.get(code0) != DECODE_OK || decoder.get(code1) != DECODE_OK) { + ESP_LOGW(TAG, "Rolling code requires 8 bytes"); + return false; + } + if (code1 < provider.last_code[1] || (code1 == provider.last_code[1] && code0 <= provider.last_code[0])) { + ESP_LOGW(TAG, "Rolling code for %s %08lX:%08lX is old", provider.name, (unsigned long) code1, + (unsigned long) code0); + return false; + } + provider.last_code[0] = code0; + provider.last_code[1] = code1; + ESP_LOGV(TAG, "Saw new rolling code for %s %08lX:%08lX", provider.name, (unsigned long) code1, (unsigned long) code0); + return true; +} + +/** + * Process a received packet + */ +void PacketTransport::process_(const std::vector &data) { + auto ping_key_seen = !this->ping_pong_enable_; + PacketDecoder decoder((data.data()), data.size()); + char namebuf[256]{}; + uint8_t byte; + FuData rdata{}; + uint16_t magic; + if (decoder.get(magic) != DECODE_OK) { + ESP_LOGD(TAG, "Short buffer"); + return; + } + if (magic != MAGIC_NUMBER && magic != MAGIC_PING) { + ESP_LOGV(TAG, "Bad magic %X", magic); + return; + } + + if (decoder.decode_string(namebuf, sizeof namebuf) != DECODE_OK) { + ESP_LOGV(TAG, "Bad hostname length"); + return; + } + if (strcmp(this->name_, namebuf) == 0) { + ESP_LOGVV(TAG, "Ignoring our own data"); + return; + } + if (magic == MAGIC_PING) { + uint32_t key; + if (decoder.get(key) != DECODE_OK) { + ESP_LOGW(TAG, "Bad ping request"); + return; + } + this->add_key_(namebuf, key); + ESP_LOGV(TAG, "Updated ping key for %s to %08X", namebuf, (unsigned) key); + return; + } + + if (this->providers_.count(namebuf) == 0) { + ESP_LOGVV(TAG, "Unknown hostname %s", namebuf); + return; + } + ESP_LOGV(TAG, "Found hostname %s", namebuf); + +#ifdef USE_SENSOR + auto &sensors = this->remote_sensors_[namebuf]; +#endif +#ifdef USE_BINARY_SENSOR + auto &binary_sensors = this->remote_binary_sensors_[namebuf]; +#endif + + if (!decoder.bump_to(4)) { + ESP_LOGW(TAG, "Bad packet length %zu", data.size()); + } + auto len = decoder.get_remaining_size(); + if (round4(len) != len) { + ESP_LOGW(TAG, "Bad payload length %zu", len); + return; + } + + auto &provider = this->providers_[namebuf]; + // if encryption not used with this host, ping check is pointless since it would be easily spoofed. + if (provider.encryption_key.empty()) + ping_key_seen = true; + + if (!provider.encryption_key.empty()) { + decoder.decrypt((const uint32_t *) provider.encryption_key.data()); + } + if (decoder.get(byte) != DECODE_OK) { + ESP_LOGV(TAG, "No key byte"); + return; + } + + if (byte == ROLLING_CODE_KEY) { + if (!process_rolling_code(provider, decoder)) + return; + } else if (byte != DATA_KEY) { + ESP_LOGV(TAG, "Expected rolling_key or data_key, got %X", byte); + return; + } + uint32_t key; + while (decoder.get_remaining_size() != 0) { + if (decoder.decode(ZERO_FILL_KEY) == DECODE_OK) + continue; + if (decoder.decode(PING_KEY, key) == DECODE_OK) { + if (key == this->ping_key_) { + ping_key_seen = true; + ESP_LOGV(TAG, "Found good ping key %X", (unsigned) key); + } else { + ESP_LOGV(TAG, "Unknown ping key %X", (unsigned) key); + } + continue; + } + if (!ping_key_seen) { + ESP_LOGW(TAG, "Ping key not seen"); + this->resend_ping_key_ = true; + break; + } + if (decoder.decode(BINARY_SENSOR_KEY, namebuf, sizeof(namebuf), byte) == DECODE_OK) { + ESP_LOGV(TAG, "Got binary sensor %s %d", namebuf, byte); +#ifdef USE_BINARY_SENSOR + if (binary_sensors.count(namebuf) != 0) + binary_sensors[namebuf]->publish_state(byte != 0); +#endif + continue; + } + if (decoder.decode(SENSOR_KEY, namebuf, sizeof(namebuf), rdata.u32) == DECODE_OK) { + ESP_LOGV(TAG, "Got sensor %s %f", namebuf, rdata.f32); +#ifdef USE_SENSOR + if (sensors.count(namebuf) != 0) + sensors[namebuf]->publish_state(rdata.f32); +#endif + continue; + } + if (decoder.get(byte) == DECODE_OK) { + ESP_LOGW(TAG, "Unknown key %X", byte); + ESP_LOGD(TAG, "Buffer pos: %zu contents: %s", data.size() - decoder.get_remaining_size(), + format_hex_pretty(data).c_str()); + } + break; + } +} + +void PacketTransport::dump_config() { + ESP_LOGCONFIG(TAG, "Packet Transport:"); + ESP_LOGCONFIG(TAG, " Platform: %s", this->platform_name_); + ESP_LOGCONFIG(TAG, " Encrypted: %s", YESNO(this->is_encrypted_())); + ESP_LOGCONFIG(TAG, " Ping-pong: %s", YESNO(this->ping_pong_enable_)); +#ifdef USE_SENSOR + for (auto sensor : this->sensors_) + ESP_LOGCONFIG(TAG, " Sensor: %s", sensor.id); +#endif +#ifdef USE_BINARY_SENSOR + for (auto sensor : this->binary_sensors_) + ESP_LOGCONFIG(TAG, " Binary Sensor: %s", sensor.id); +#endif + for (const auto &host : this->providers_) { + ESP_LOGCONFIG(TAG, " Remote host: %s", host.first.c_str()); + ESP_LOGCONFIG(TAG, " Encrypted: %s", YESNO(!host.second.encryption_key.empty())); +#ifdef USE_SENSOR + for (const auto &sensor : this->remote_sensors_[host.first.c_str()]) + ESP_LOGCONFIG(TAG, " Sensor: %s", sensor.first.c_str()); +#endif +#ifdef USE_BINARY_SENSOR + for (const auto &sensor : this->remote_binary_sensors_[host.first.c_str()]) + ESP_LOGCONFIG(TAG, " Binary Sensor: %s", sensor.first.c_str()); +#endif + } +} +void PacketTransport::increment_code_() { + if (this->rolling_code_enable_) { + if (++this->rolling_code_[0] == 0) { + this->rolling_code_[1]++; + this->pref_.save(&this->rolling_code_[1]); + // must make sure it's saved immediately + global_preferences->sync(); + } + } +} + +void PacketTransport::loop() { + if (this->resend_ping_key_) + this->send_ping_pong_request_(); + if (this->updated_) { + this->send_data_(this->resend_data_); + } +} + +void PacketTransport::send_ping_pong_request_() { + if (!this->ping_pong_enable_ || !this->should_send()) + return; + this->ping_key_ = random_uint32(); + this->ping_header_.clear(); + add(this->ping_header_, MAGIC_PING); + add(this->ping_header_, this->name_); + add(this->ping_header_, this->ping_key_); + this->send_packet(this->ping_header_); + this->resend_ping_key_ = false; + ESP_LOGV(TAG, "Sent new ping request %08X", (unsigned) this->ping_key_); +} +} // namespace packet_transport +} // namespace esphome diff --git a/esphome/components/packet_transport/packet_transport.h b/esphome/components/packet_transport/packet_transport.h new file mode 100644 index 0000000000..34edb82963 --- /dev/null +++ b/esphome/components/packet_transport/packet_transport.h @@ -0,0 +1,154 @@ +#pragma once + +#include "esphome/core/component.h" +#include "esphome/core/preferences.h" +#ifdef USE_SENSOR +#include "esphome/components/sensor/sensor.h" +#endif +#ifdef USE_BINARY_SENSOR +#include "esphome/components/binary_sensor/binary_sensor.h" +#endif +# +#include +#include + +/** + * Providing packet encoding functions for exchanging data with a remote host. + * + * A transport is required to send the data; this is provided by a child class. + * The child class should implement the virtual functions send_packet_ and get_max_packet_size_. + * On receipt of a data packet, it should call `this->process_()` with the data. + */ + +namespace esphome { +namespace packet_transport { + +struct Provider { + std::vector encryption_key; + const char *name; + uint32_t last_code[2]; +}; + +#ifdef USE_SENSOR +struct Sensor { + sensor::Sensor *sensor; + const char *id; + bool updated; +}; +#endif +#ifdef USE_BINARY_SENSOR +struct BinarySensor { + binary_sensor::BinarySensor *sensor; + const char *id; + bool updated; +}; +#endif + +class PacketTransport : public PollingComponent { + public: + void setup() override; + void loop() override; + void update() override; + void dump_config() override; + +#ifdef USE_SENSOR + void add_sensor(const char *id, sensor::Sensor *sensor) { + Sensor st{sensor, id, true}; + this->sensors_.push_back(st); + } + void add_remote_sensor(const char *hostname, const char *remote_id, sensor::Sensor *sensor) { + this->add_provider(hostname); + this->remote_sensors_[hostname][remote_id] = sensor; + } +#endif +#ifdef USE_BINARY_SENSOR + void add_binary_sensor(const char *id, binary_sensor::BinarySensor *sensor) { + BinarySensor st{sensor, id, true}; + this->binary_sensors_.push_back(st); + } + + void add_remote_binary_sensor(const char *hostname, const char *remote_id, binary_sensor::BinarySensor *sensor) { + this->add_provider(hostname); + this->remote_binary_sensors_[hostname][remote_id] = sensor; + } +#endif + + void add_provider(const char *hostname) { + if (this->providers_.count(hostname) == 0) { + Provider provider; + provider.encryption_key = std::vector{}; + provider.last_code[0] = 0; + provider.last_code[1] = 0; + provider.name = hostname; + this->providers_[hostname] = provider; +#ifdef USE_SENSOR + this->remote_sensors_[hostname] = std::map(); +#endif +#ifdef USE_BINARY_SENSOR + this->remote_binary_sensors_[hostname] = std::map(); +#endif + } + } + + void set_encryption_key(std::vector key) { this->encryption_key_ = std::move(key); } + void set_rolling_code_enable(bool enable) { this->rolling_code_enable_ = enable; } + void set_ping_pong_enable(bool enable) { this->ping_pong_enable_ = enable; } + void set_ping_pong_recycle_time(uint32_t recycle_time) { this->ping_pong_recyle_time_ = recycle_time; } + void set_provider_encryption(const char *name, std::vector key) { + this->providers_[name].encryption_key = std::move(key); + } + void set_platform_name(const char *name) { this->platform_name_ = name; } + + protected: + // child classes must implement this + virtual void send_packet(const std::vector &buf) const = 0; + virtual size_t get_max_packet_size() = 0; + virtual bool should_send() { return true; } + + // to be called by child classes when a data packet is received. + void process_(const std::vector &data); + void send_data_(bool all); + void flush_(); + void add_data_(uint8_t key, const char *id, float data); + void add_data_(uint8_t key, const char *id, uint32_t data); + void increment_code_(); + void add_binary_data_(uint8_t key, const char *id, bool data); + void init_data_(); + + bool updated_{}; + uint32_t ping_key_{}; + uint32_t rolling_code_[2]{}; + bool rolling_code_enable_{}; + bool ping_pong_enable_{}; + uint32_t ping_pong_recyle_time_{}; + uint32_t last_key_time_{}; + bool resend_ping_key_{}; + bool resend_data_{}; + const char *name_{}; + ESPPreferenceObject pref_{}; + + std::vector encryption_key_{}; + +#ifdef USE_SENSOR + std::vector sensors_{}; + std::map> remote_sensors_{}; +#endif +#ifdef USE_BINARY_SENSOR + std::vector binary_sensors_{}; + std::map> remote_binary_sensors_{}; +#endif + + std::map providers_{}; + std::vector ping_header_{}; + std::vector header_{}; + std::vector data_{}; + std::map ping_keys_{}; + const char *platform_name_{""}; + void add_key_(const char *name, uint32_t key); + void send_ping_pong_request_(); + + inline bool is_encrypted_() { return !this->encryption_key_.empty(); } +}; + +} // namespace packet_transport +} // namespace esphome diff --git a/esphome/components/packet_transport/sensor.py b/esphome/components/packet_transport/sensor.py new file mode 100644 index 0000000000..15c0e33b30 --- /dev/null +++ b/esphome/components/packet_transport/sensor.py @@ -0,0 +1,19 @@ +import esphome.codegen as cg +from esphome.components.sensor import new_sensor, sensor_schema +from esphome.const import CONF_ID + +from . import ( + CONF_PROVIDER, + CONF_REMOTE_ID, + CONF_TRANSPORT_ID, + packet_transport_sensor_schema, +) + +CONFIG_SCHEMA = packet_transport_sensor_schema(sensor_schema()) + + +async def to_code(config): + var = await new_sensor(config) + comp = await cg.get_variable(config[CONF_TRANSPORT_ID]) + remote_id = str(config.get(CONF_REMOTE_ID) or config.get(CONF_ID)) + cg.add(comp.add_remote_sensor(config[CONF_PROVIDER], remote_id, var)) diff --git a/esphome/components/pca9685/pca9685_output.cpp b/esphome/components/pca9685/pca9685_output.cpp index d92312355a..1998f8d12f 100644 --- a/esphome/components/pca9685/pca9685_output.cpp +++ b/esphome/components/pca9685/pca9685_output.cpp @@ -101,8 +101,9 @@ void PCA9685Output::loop() { return; const uint16_t num_channels = this->max_channel_ - this->min_channel_ + 1; + const uint16_t phase_delta_begin = 4096 / num_channels; for (uint8_t channel = this->min_channel_; channel <= this->max_channel_; channel++) { - uint16_t phase_begin = uint16_t(channel - this->min_channel_) / num_channels * 4096; + uint16_t phase_begin = (channel - this->min_channel_) * phase_delta_begin; uint16_t phase_end; uint16_t amount = this->pwm_amounts_[channel]; if (amount == 0) { diff --git a/esphome/components/pm2005/__init__.py b/esphome/components/pm2005/__init__.py new file mode 100644 index 0000000000..3716dd7b5e --- /dev/null +++ b/esphome/components/pm2005/__init__.py @@ -0,0 +1 @@ +"""PM2005/2105 component for ESPHome.""" diff --git a/esphome/components/pm2005/pm2005.cpp b/esphome/components/pm2005/pm2005.cpp new file mode 100644 index 0000000000..38847210fd --- /dev/null +++ b/esphome/components/pm2005/pm2005.cpp @@ -0,0 +1,123 @@ +#include "esphome/core/log.h" +#include "pm2005.h" + +namespace esphome { +namespace pm2005 { + +static const char *const TAG = "pm2005"; + +// Converts a sensor situation to a human readable string +static const LogString *pm2005_get_situation_string(int status) { + switch (status) { + case 1: + return LOG_STR("Close"); + case 2: + return LOG_STR("Malfunction"); + case 3: + return LOG_STR("Under detecting"); + case 0x80: + return LOG_STR("Detecting completed"); + default: + return LOG_STR("Invalid"); + } +} + +// Converts a sensor measuring mode to a human readable string +static const LogString *pm2005_get_measuring_mode_string(int status) { + switch (status) { + case 2: + return LOG_STR("Single"); + case 3: + return LOG_STR("Continuous"); + case 5: + return LOG_STR("Dynamic"); + default: + return LOG_STR("Timing"); + } +} + +static inline uint16_t get_sensor_value(const uint8_t *data, uint8_t i) { return data[i] * 0x100 + data[i + 1]; } + +void PM2005Component::setup() { + if (this->sensor_type_ == PM2005) { + ESP_LOGCONFIG(TAG, "Setting up PM2005..."); + + this->situation_value_index_ = 3; + this->pm_1_0_value_index_ = 4; + this->pm_2_5_value_index_ = 6; + this->pm_10_0_value_index_ = 8; + this->measuring_value_index_ = 10; + } else { + ESP_LOGCONFIG(TAG, "Setting up PM2105..."); + + this->situation_value_index_ = 2; + this->pm_1_0_value_index_ = 3; + this->pm_2_5_value_index_ = 5; + this->pm_10_0_value_index_ = 7; + this->measuring_value_index_ = 9; + } + + if (this->read(this->data_buffer_, 12) != i2c::ERROR_OK) { + ESP_LOGE(TAG, "Communication failed!"); + this->mark_failed(); + return; + } +} + +void PM2005Component::update() { + if (this->read(this->data_buffer_, 12) != i2c::ERROR_OK) { + ESP_LOGW(TAG, "Read result failed."); + this->status_set_warning(); + return; + } + + if (this->sensor_situation_ == this->data_buffer_[this->situation_value_index_]) { + return; + } + + this->sensor_situation_ = this->data_buffer_[this->situation_value_index_]; + ESP_LOGD(TAG, "Sensor situation: %s.", LOG_STR_ARG(pm2005_get_situation_string(this->sensor_situation_))); + if (this->sensor_situation_ == 2) { + this->status_set_warning(); + return; + } + if (this->sensor_situation_ != 0x80) { + return; + } + + uint16_t pm1 = get_sensor_value(this->data_buffer_, this->pm_1_0_value_index_); + uint16_t pm25 = get_sensor_value(this->data_buffer_, this->pm_2_5_value_index_); + uint16_t pm10 = get_sensor_value(this->data_buffer_, this->pm_10_0_value_index_); + uint16_t sensor_measuring_mode = get_sensor_value(this->data_buffer_, this->measuring_value_index_); + ESP_LOGD(TAG, "PM1.0: %d, PM2.5: %d, PM10: %d, Measuring mode: %s.", pm1, pm25, pm10, + LOG_STR_ARG(pm2005_get_measuring_mode_string(sensor_measuring_mode))); + + if (this->pm_1_0_sensor_ != nullptr) { + this->pm_1_0_sensor_->publish_state(pm1); + } + if (this->pm_2_5_sensor_ != nullptr) { + this->pm_2_5_sensor_->publish_state(pm25); + } + if (this->pm_10_0_sensor_ != nullptr) { + this->pm_10_0_sensor_->publish_state(pm10); + } + + this->status_clear_warning(); +} + +void PM2005Component::dump_config() { + ESP_LOGCONFIG(TAG, "PM2005:"); + ESP_LOGCONFIG(TAG, " Type: PM2%u05", this->sensor_type_ == PM2105); + + LOG_I2C_DEVICE(this); + if (this->is_failed()) { + ESP_LOGE(TAG, "Communication with PM2%u05 failed!", this->sensor_type_ == PM2105); + } + + LOG_SENSOR(" ", "PM1.0", this->pm_1_0_sensor_); + LOG_SENSOR(" ", "PM2.5", this->pm_2_5_sensor_); + LOG_SENSOR(" ", "PM10 ", this->pm_10_0_sensor_); +} + +} // namespace pm2005 +} // namespace esphome diff --git a/esphome/components/pm2005/pm2005.h b/esphome/components/pm2005/pm2005.h new file mode 100644 index 0000000000..219fbae5cb --- /dev/null +++ b/esphome/components/pm2005/pm2005.h @@ -0,0 +1,46 @@ +#pragma once + +#include "esphome/core/component.h" +#include "esphome/components/sensor/sensor.h" +#include "esphome/components/i2c/i2c.h" + +namespace esphome { +namespace pm2005 { + +enum SensorType { + PM2005, + PM2105, +}; + +class PM2005Component : public PollingComponent, public i2c::I2CDevice { + public: + float get_setup_priority() const override { return esphome::setup_priority::DATA; } + + void set_sensor_type(SensorType sensor_type) { this->sensor_type_ = sensor_type; } + + void set_pm_1_0_sensor(sensor::Sensor *pm_1_0_sensor) { this->pm_1_0_sensor_ = pm_1_0_sensor; } + void set_pm_2_5_sensor(sensor::Sensor *pm_2_5_sensor) { this->pm_2_5_sensor_ = pm_2_5_sensor; } + void set_pm_10_0_sensor(sensor::Sensor *pm_10_0_sensor) { this->pm_10_0_sensor_ = pm_10_0_sensor; } + + void setup() override; + void dump_config() override; + void update() override; + + protected: + uint8_t sensor_situation_{0}; + uint8_t data_buffer_[12]; + SensorType sensor_type_{PM2005}; + + sensor::Sensor *pm_1_0_sensor_{nullptr}; + sensor::Sensor *pm_2_5_sensor_{nullptr}; + sensor::Sensor *pm_10_0_sensor_{nullptr}; + + uint8_t situation_value_index_{3}; + uint8_t pm_1_0_value_index_{4}; + uint8_t pm_2_5_value_index_{6}; + uint8_t pm_10_0_value_index_{8}; + uint8_t measuring_value_index_{10}; +}; + +} // namespace pm2005 +} // namespace esphome diff --git a/esphome/components/pm2005/sensor.py b/esphome/components/pm2005/sensor.py new file mode 100644 index 0000000000..66f630f8ff --- /dev/null +++ b/esphome/components/pm2005/sensor.py @@ -0,0 +1,86 @@ +"""PM2005/2105 Sensor component for ESPHome.""" + +import esphome.codegen as cg +import esphome.config_validation as cv +from esphome.components import i2c, sensor +from esphome.const import ( + CONF_ID, + CONF_PM_1_0, + CONF_PM_2_5, + CONF_PM_10_0, + CONF_TYPE, + DEVICE_CLASS_PM1, + DEVICE_CLASS_PM10, + DEVICE_CLASS_PM25, + ICON_CHEMICAL_WEAPON, + STATE_CLASS_MEASUREMENT, + UNIT_MICROGRAMS_PER_CUBIC_METER, +) + +DEPENDENCIES = ["i2c"] +CODEOWNERS = ["@andrewjswan"] + +pm2005_ns = cg.esphome_ns.namespace("pm2005") +PM2005Component = pm2005_ns.class_( + "PM2005Component", cg.PollingComponent, i2c.I2CDevice +) + +SensorType = pm2005_ns.enum("SensorType") +SENSOR_TYPE = { + "PM2005": SensorType.PM2005, + "PM2105": SensorType.PM2105, +} + + +CONFIG_SCHEMA = cv.All( + cv.Schema( + { + cv.GenerateID(): cv.declare_id(PM2005Component), + cv.Optional(CONF_TYPE, default="PM2005"): cv.enum(SENSOR_TYPE, upper=True), + cv.Optional(CONF_PM_1_0): sensor.sensor_schema( + unit_of_measurement=UNIT_MICROGRAMS_PER_CUBIC_METER, + icon=ICON_CHEMICAL_WEAPON, + accuracy_decimals=0, + device_class=DEVICE_CLASS_PM1, + state_class=STATE_CLASS_MEASUREMENT, + ), + cv.Optional(CONF_PM_2_5): sensor.sensor_schema( + unit_of_measurement=UNIT_MICROGRAMS_PER_CUBIC_METER, + icon=ICON_CHEMICAL_WEAPON, + accuracy_decimals=0, + device_class=DEVICE_CLASS_PM25, + state_class=STATE_CLASS_MEASUREMENT, + ), + cv.Optional(CONF_PM_10_0): sensor.sensor_schema( + unit_of_measurement=UNIT_MICROGRAMS_PER_CUBIC_METER, + icon=ICON_CHEMICAL_WEAPON, + accuracy_decimals=0, + device_class=DEVICE_CLASS_PM10, + state_class=STATE_CLASS_MEASUREMENT, + ), + }, + ) + .extend(cv.polling_component_schema("60s")) + .extend(i2c.i2c_device_schema(0x28)), +) + + +async def to_code(config) -> None: + """Code generation entry point.""" + var = cg.new_Pvariable(config[CONF_ID]) + await cg.register_component(var, config) + await i2c.register_i2c_device(var, config) + + cg.add(var.set_sensor_type(config[CONF_TYPE])) + + if pm_1_0_config := config.get(CONF_PM_1_0): + sens = await sensor.new_sensor(pm_1_0_config) + cg.add(var.set_pm_1_0_sensor(sens)) + + if pm_2_5_config := config.get(CONF_PM_2_5): + sens = await sensor.new_sensor(pm_2_5_config) + cg.add(var.set_pm_2_5_sensor(sens)) + + if pm_10_0_config := config.get(CONF_PM_10_0): + sens = await sensor.new_sensor(pm_10_0_config) + cg.add(var.set_pm_10_0_sensor(sens)) diff --git a/esphome/components/pmsa003i/pmsa003i.cpp b/esphome/components/pmsa003i/pmsa003i.cpp index a9665c6a5a..36f9c9a132 100644 --- a/esphome/components/pmsa003i/pmsa003i.cpp +++ b/esphome/components/pmsa003i/pmsa003i.cpp @@ -1,5 +1,6 @@ #include "pmsa003i.h" #include "esphome/core/log.h" +#include "esphome/core/helpers.h" #include namespace esphome { @@ -7,6 +8,16 @@ namespace pmsa003i { static const char *const TAG = "pmsa003i"; +static const uint8_t COUNT_PAYLOAD_BYTES = 28; +static const uint8_t COUNT_PAYLOAD_LENGTH_BYTES = 2; +static const uint8_t COUNT_START_CHARACTER_BYTES = 2; +static const uint8_t COUNT_DATA_BYTES = COUNT_START_CHARACTER_BYTES + COUNT_PAYLOAD_LENGTH_BYTES + COUNT_PAYLOAD_BYTES; +static const uint8_t CHECKSUM_START_INDEX = COUNT_DATA_BYTES - 2; +static const uint8_t COUNT_16_BIT_VALUES = (COUNT_PAYLOAD_LENGTH_BYTES + COUNT_PAYLOAD_BYTES) / 2; +static const uint8_t START_CHARACTER_1 = 0x42; +static const uint8_t START_CHARACTER_2 = 0x4D; +static const uint8_t READ_DATA_RETRY_COUNT = 3; + void PMSA003IComponent::setup() { ESP_LOGCONFIG(TAG, "Setting up pmsa003i..."); @@ -14,7 +25,7 @@ void PMSA003IComponent::setup() { bool successful_read = this->read_data_(&data); if (!successful_read) { - for (int i = 0; i < 3; i++) { + for (uint8_t i = 0; i < READ_DATA_RETRY_COUNT; i++) { successful_read = this->read_data_(&data); if (successful_read) { break; @@ -28,7 +39,10 @@ void PMSA003IComponent::setup() { } } -void PMSA003IComponent::dump_config() { LOG_I2C_DEVICE(this); } +void PMSA003IComponent::dump_config() { + ESP_LOGCONFIG(TAG, "PMSA003I:"); + LOG_I2C_DEVICE(this); +} void PMSA003IComponent::update() { PM25AQIData data; @@ -75,35 +89,48 @@ void PMSA003IComponent::update() { } bool PMSA003IComponent::read_data_(PM25AQIData *data) { - const uint8_t num_bytes = 32; - uint8_t buffer[num_bytes]; + uint8_t buffer[COUNT_DATA_BYTES]; - this->read_bytes_raw(buffer, num_bytes); + this->read_bytes_raw(buffer, COUNT_DATA_BYTES); // https://github.com/adafruit/Adafruit_PM25AQI // Check that start byte is correct! - if (buffer[0] != 0x42) { + if (buffer[0] != START_CHARACTER_1 || buffer[1] != START_CHARACTER_2) { + ESP_LOGW(TAG, "Start character mismatch: %02X %02X != %02X %02X", buffer[0], buffer[1], START_CHARACTER_1, + START_CHARACTER_2); return false; } - // get checksum ready - int16_t sum = 0; - for (uint8_t i = 0; i < 30; i++) { - sum += buffer[i]; + const uint16_t payload_length = encode_uint16(buffer[2], buffer[3]); + if (payload_length != COUNT_PAYLOAD_BYTES) { + ESP_LOGW(TAG, "Payload length mismatch: %u != %u", payload_length, COUNT_PAYLOAD_BYTES); + return false; + } + + // Calculate checksum + uint16_t checksum = 0; + for (uint8_t i = 0; i < CHECKSUM_START_INDEX; i++) { + checksum += buffer[i]; + } + + const uint16_t check = encode_uint16(buffer[CHECKSUM_START_INDEX], buffer[CHECKSUM_START_INDEX + 1]); + if (checksum != check) { + ESP_LOGW(TAG, "Checksum mismatch: %u != %u", checksum, check); + return false; } // The data comes in endian'd, this solves it so it works on all platforms - uint16_t buffer_u16[15]; - for (uint8_t i = 0; i < 15; i++) { - buffer_u16[i] = buffer[2 + i * 2 + 1]; - buffer_u16[i] += (buffer[2 + i * 2] << 8); + uint16_t buffer_u16[COUNT_16_BIT_VALUES]; + for (uint8_t i = 0; i < COUNT_16_BIT_VALUES; i++) { + const uint8_t buffer_index = COUNT_START_CHARACTER_BYTES + i * 2; + buffer_u16[i] = encode_uint16(buffer[buffer_index], buffer[buffer_index + 1]); } // put it into a nice struct :) - memcpy((void *) data, (void *) buffer_u16, 30); + memcpy((void *) data, (void *) buffer_u16, COUNT_16_BIT_VALUES * 2); - return (sum == data->checksum); + return true; } } // namespace pmsa003i diff --git a/esphome/components/pmsa003i/pmsa003i.h b/esphome/components/pmsa003i/pmsa003i.h index 1fe4139951..59f39a7314 100644 --- a/esphome/components/pmsa003i/pmsa003i.h +++ b/esphome/components/pmsa003i/pmsa003i.h @@ -10,21 +10,21 @@ namespace pmsa003i { /**! Structure holding Plantower's standard packet **/ // From https://github.com/adafruit/Adafruit_PM25AQI struct PM25AQIData { - uint16_t framelen; ///< How long this data chunk is - uint16_t pm10_standard, ///< Standard PM1.0 - pm25_standard, ///< Standard PM2.5 - pm100_standard; ///< Standard PM10.0 - uint16_t pm10_env, ///< Environmental PM1.0 - pm25_env, ///< Environmental PM2.5 - pm100_env; ///< Environmental PM10.0 - uint16_t particles_03um, ///> 0.3um Particle Count - particles_05um, ///> 0.5um Particle Count - particles_10um, ///> 1.0um Particle Count - particles_25um, ///> 2.5um Particle Count - particles_50um, ///> 5.0um Particle Count - particles_100um; ///> 10.0um Particle Count - uint16_t unused; ///< Unused - uint16_t checksum; ///< Packet checksum + uint16_t framelen; ///< How long this data chunk is + uint16_t pm10_standard; ///< Standard PM1.0 + uint16_t pm25_standard; ///< Standard PM2.5 + uint16_t pm100_standard; ///< Standard PM10.0 + uint16_t pm10_env; ///< Environmental PM1.0 + uint16_t pm25_env; ///< Environmental PM2.5 + uint16_t pm100_env; ///< Environmental PM10.0 + uint16_t particles_03um; ///< 0.3um Particle Count + uint16_t particles_05um; ///< 0.5um Particle Count + uint16_t particles_10um; ///< 1.0um Particle Count + uint16_t particles_25um; ///< 2.5um Particle Count + uint16_t particles_50um; ///< 5.0um Particle Count + uint16_t particles_100um; ///< 10.0um Particle Count + uint16_t unused; ///< Unused + uint16_t checksum; ///< Packet checksum }; class PMSA003IComponent : public PollingComponent, public i2c::I2CDevice { @@ -34,18 +34,18 @@ class PMSA003IComponent : public PollingComponent, public i2c::I2CDevice { void update() override; float get_setup_priority() const override { return setup_priority::DATA; } - void set_standard_units(bool standard_units) { standard_units_ = standard_units; } + void set_standard_units(bool standard_units) { this->standard_units_ = standard_units; } - void set_pm_1_0_sensor(sensor::Sensor *pm_1_0) { pm_1_0_sensor_ = pm_1_0; } - void set_pm_2_5_sensor(sensor::Sensor *pm_2_5) { pm_2_5_sensor_ = pm_2_5; } - void set_pm_10_0_sensor(sensor::Sensor *pm_10_0) { pm_10_0_sensor_ = pm_10_0; } + void set_pm_1_0_sensor(sensor::Sensor *pm_1_0) { this->pm_1_0_sensor_ = pm_1_0; } + void set_pm_2_5_sensor(sensor::Sensor *pm_2_5) { this->pm_2_5_sensor_ = pm_2_5; } + void set_pm_10_0_sensor(sensor::Sensor *pm_10_0) { this->pm_10_0_sensor_ = pm_10_0; } - void set_pmc_0_3_sensor(sensor::Sensor *pmc_0_3) { pmc_0_3_sensor_ = pmc_0_3; } - void set_pmc_0_5_sensor(sensor::Sensor *pmc_0_5) { pmc_0_5_sensor_ = pmc_0_5; } - void set_pmc_1_0_sensor(sensor::Sensor *pmc_1_0) { pmc_1_0_sensor_ = pmc_1_0; } - void set_pmc_2_5_sensor(sensor::Sensor *pmc_2_5) { pmc_2_5_sensor_ = pmc_2_5; } - void set_pmc_5_0_sensor(sensor::Sensor *pmc_5_0) { pmc_5_0_sensor_ = pmc_5_0; } - void set_pmc_10_0_sensor(sensor::Sensor *pmc_10_0) { pmc_10_0_sensor_ = pmc_10_0; } + void set_pmc_0_3_sensor(sensor::Sensor *pmc_0_3) { this->pmc_0_3_sensor_ = pmc_0_3; } + void set_pmc_0_5_sensor(sensor::Sensor *pmc_0_5) { this->pmc_0_5_sensor_ = pmc_0_5; } + void set_pmc_1_0_sensor(sensor::Sensor *pmc_1_0) { this->pmc_1_0_sensor_ = pmc_1_0; } + void set_pmc_2_5_sensor(sensor::Sensor *pmc_2_5) { this->pmc_2_5_sensor_ = pmc_2_5; } + void set_pmc_5_0_sensor(sensor::Sensor *pmc_5_0) { this->pmc_5_0_sensor_ = pmc_5_0; } + void set_pmc_10_0_sensor(sensor::Sensor *pmc_10_0) { this->pmc_10_0_sensor_ = pmc_10_0; } protected: bool read_data_(PM25AQIData *data); diff --git a/esphome/components/pmsx003/pmsx003.cpp b/esphome/components/pmsx003/pmsx003.cpp index de2b23b8eb..11626768d8 100644 --- a/esphome/components/pmsx003/pmsx003.cpp +++ b/esphome/components/pmsx003/pmsx003.cpp @@ -6,45 +6,39 @@ namespace pmsx003 { static const char *const TAG = "pmsx003"; -void PMSX003Component::set_pm_1_0_std_sensor(sensor::Sensor *pm_1_0_std_sensor) { - pm_1_0_std_sensor_ = pm_1_0_std_sensor; -} -void PMSX003Component::set_pm_2_5_std_sensor(sensor::Sensor *pm_2_5_std_sensor) { - pm_2_5_std_sensor_ = pm_2_5_std_sensor; -} -void PMSX003Component::set_pm_10_0_std_sensor(sensor::Sensor *pm_10_0_std_sensor) { - pm_10_0_std_sensor_ = pm_10_0_std_sensor; -} +static const uint8_t START_CHARACTER_1 = 0x42; +static const uint8_t START_CHARACTER_2 = 0x4D; -void PMSX003Component::set_pm_1_0_sensor(sensor::Sensor *pm_1_0_sensor) { pm_1_0_sensor_ = pm_1_0_sensor; } -void PMSX003Component::set_pm_2_5_sensor(sensor::Sensor *pm_2_5_sensor) { pm_2_5_sensor_ = pm_2_5_sensor; } -void PMSX003Component::set_pm_10_0_sensor(sensor::Sensor *pm_10_0_sensor) { pm_10_0_sensor_ = pm_10_0_sensor; } +static const uint16_t PMS_STABILISING_MS = 30000; // time taken for the sensor to become stable after power on in ms -void PMSX003Component::set_pm_particles_03um_sensor(sensor::Sensor *pm_particles_03um_sensor) { - pm_particles_03um_sensor_ = pm_particles_03um_sensor; -} -void PMSX003Component::set_pm_particles_05um_sensor(sensor::Sensor *pm_particles_05um_sensor) { - pm_particles_05um_sensor_ = pm_particles_05um_sensor; -} -void PMSX003Component::set_pm_particles_10um_sensor(sensor::Sensor *pm_particles_10um_sensor) { - pm_particles_10um_sensor_ = pm_particles_10um_sensor; -} -void PMSX003Component::set_pm_particles_25um_sensor(sensor::Sensor *pm_particles_25um_sensor) { - pm_particles_25um_sensor_ = pm_particles_25um_sensor; -} -void PMSX003Component::set_pm_particles_50um_sensor(sensor::Sensor *pm_particles_50um_sensor) { - pm_particles_50um_sensor_ = pm_particles_50um_sensor; -} -void PMSX003Component::set_pm_particles_100um_sensor(sensor::Sensor *pm_particles_100um_sensor) { - pm_particles_100um_sensor_ = pm_particles_100um_sensor; -} +static const uint16_t PMS_CMD_MEASUREMENT_MODE_PASSIVE = + 0x0000; // use `PMS_CMD_MANUAL_MEASUREMENT` to trigger a measurement +static const uint16_t PMS_CMD_MEASUREMENT_MODE_ACTIVE = 0x0001; // automatically perform measurements +static const uint16_t PMS_CMD_SLEEP_MODE_SLEEP = 0x0000; // go to sleep mode +static const uint16_t PMS_CMD_SLEEP_MODE_WAKEUP = 0x0001; // wake up from sleep mode -void PMSX003Component::set_temperature_sensor(sensor::Sensor *temperature_sensor) { - temperature_sensor_ = temperature_sensor; -} -void PMSX003Component::set_humidity_sensor(sensor::Sensor *humidity_sensor) { humidity_sensor_ = humidity_sensor; } -void PMSX003Component::set_formaldehyde_sensor(sensor::Sensor *formaldehyde_sensor) { - formaldehyde_sensor_ = formaldehyde_sensor; +void PMSX003Component::dump_config() { + ESP_LOGCONFIG(TAG, "PMSX003:"); + LOG_SENSOR(" ", "PM1.0STD", this->pm_1_0_std_sensor_); + LOG_SENSOR(" ", "PM2.5STD", this->pm_2_5_std_sensor_); + LOG_SENSOR(" ", "PM10.0STD", this->pm_10_0_std_sensor_); + + LOG_SENSOR(" ", "PM1.0", this->pm_1_0_sensor_); + LOG_SENSOR(" ", "PM2.5", this->pm_2_5_sensor_); + LOG_SENSOR(" ", "PM10.0", this->pm_10_0_sensor_); + + LOG_SENSOR(" ", "PM0.3um", this->pm_particles_03um_sensor_); + LOG_SENSOR(" ", "PM0.5um", this->pm_particles_05um_sensor_); + LOG_SENSOR(" ", "PM1.0um", this->pm_particles_10um_sensor_); + LOG_SENSOR(" ", "PM2.5um", this->pm_particles_25um_sensor_); + LOG_SENSOR(" ", "PM5.0um", this->pm_particles_50um_sensor_); + LOG_SENSOR(" ", "PM10.0um", this->pm_particles_100um_sensor_); + + LOG_SENSOR(" ", "Formaldehyde", this->formaldehyde_sensor_); + + LOG_SENSOR(" ", "Temperature", this->temperature_sensor_); + LOG_SENSOR(" ", "Humidity", this->humidity_sensor_); + this->check_uart_settings(9600); } void PMSX003Component::loop() { @@ -55,8 +49,8 @@ void PMSX003Component::loop() { // need to keep track of what state we're in. if (this->update_interval_ > PMS_STABILISING_MS) { if (this->initialised_ == 0) { - this->send_command_(PMS_CMD_AUTO_MANUAL, 0); - this->send_command_(PMS_CMD_ON_STANDBY, 1); + this->send_command_(PMS_CMD_MEASUREMENT_MODE, PMS_CMD_MEASUREMENT_MODE_PASSIVE); + this->send_command_(PMS_CMD_SLEEP_MODE, PMS_CMD_SLEEP_MODE_WAKEUP); this->initialised_ = 1; } switch (this->state_) { @@ -66,7 +60,7 @@ void PMSX003Component::loop() { return; this->state_ = PMSX003_STATE_STABILISING; - this->send_command_(PMS_CMD_ON_STANDBY, 1); + this->send_command_(PMS_CMD_SLEEP_MODE, PMS_CMD_SLEEP_MODE_WAKEUP); this->fan_on_time_ = now; return; case PMSX003_STATE_STABILISING: @@ -77,7 +71,7 @@ void PMSX003Component::loop() { while (this->available()) this->read_byte(&this->data_[0]); // Trigger a new read - this->send_command_(PMS_CMD_TRIG_MANUAL, 0); + this->send_command_(PMS_CMD_MANUAL_MEASUREMENT, 0); this->state_ = PMSX003_STATE_WAITING; break; case PMSX003_STATE_WAITING: @@ -116,242 +110,212 @@ void PMSX003Component::loop() { } } } -float PMSX003Component::get_setup_priority() const { return setup_priority::DATA; } + optional PMSX003Component::check_byte_() { - uint8_t index = this->data_index_; - uint8_t byte = this->data_[index]; + const uint8_t index = this->data_index_; + const uint8_t byte = this->data_[index]; - if (index == 0) - return byte == 0x42; - - if (index == 1) - return byte == 0x4D; - - if (index == 2) - return true; - - uint16_t payload_length = this->get_16_bit_uint_(2); - if (index == 3) { - bool length_matches = false; - switch (this->type_) { - case PMSX003_TYPE_X003: - length_matches = payload_length == 28 || payload_length == 20; - break; - case PMSX003_TYPE_5003T: - case PMSX003_TYPE_5003S: - length_matches = payload_length == 28; - break; - case PMSX003_TYPE_5003ST: - length_matches = payload_length == 36; - break; + if (index == 0 || index == 1) { + const uint8_t start_char = index == 0 ? START_CHARACTER_1 : START_CHARACTER_2; + if (byte == start_char) { + return true; } - if (!length_matches) { - ESP_LOGW(TAG, "PMSX003 length %u doesn't match. Are you using the correct PMSX003 type?", payload_length); - return false; - } + ESP_LOGW(TAG, "Start character %u mismatch: 0x%02X != 0x%02X", index + 1, byte, START_CHARACTER_1); + return false; + } + + if (index == 2) { return true; } - // start (16bit) + length (16bit) + DATA (payload_length-2 bytes) + checksum (16bit) - uint8_t total_size = 4 + payload_length; + const uint16_t payload_length = this->get_16_bit_uint_(2); + if (index == 3) { + if (this->check_payload_length_(payload_length)) { + return true; + } else { + ESP_LOGW(TAG, "Payload length %u doesn't match. Are you using the correct PMSX003 type?", payload_length); + return false; + } + } - if (index < total_size - 1) + // start (16bit) + length (16bit) + DATA (payload_length - 16bit) + checksum (16bit) + const uint16_t total_size = 4 + payload_length; + + if (index < total_size - 1) { return true; + } // checksum is without checksum bytes uint16_t checksum = 0; - for (uint8_t i = 0; i < total_size - 2; i++) + for (uint16_t i = 0; i < total_size - 2; i++) { checksum += this->data_[i]; + } - uint16_t check = this->get_16_bit_uint_(total_size - 2); + const uint16_t check = this->get_16_bit_uint_(total_size - 2); if (checksum != check) { - ESP_LOGW(TAG, "PMSX003 checksum mismatch! 0x%02X!=0x%02X", checksum, check); + ESP_LOGW(TAG, "PMSX003 checksum mismatch! 0x%02X != 0x%02X", checksum, check); return false; } return {}; } -void PMSX003Component::send_command_(uint8_t cmd, uint16_t data) { - this->data_index_ = 0; - this->data_[data_index_++] = 0x42; - this->data_[data_index_++] = 0x4D; - this->data_[data_index_++] = cmd; - this->data_[data_index_++] = (data >> 8) & 0xFF; - this->data_[data_index_++] = (data >> 0) & 0xFF; - int sum = 0; - for (int i = 0; i < data_index_; i++) { - sum += this->data_[i]; +bool PMSX003Component::check_payload_length_(uint16_t payload_length) { + switch (this->type_) { + case PMSX003_TYPE_X003: + // The expected payload length is typically 28 bytes. + // However, a 20-byte payload check was already present in the code. + // No official documentation was found confirming this. + // Retaining this check to avoid breaking existing behavior. + return payload_length == 28 || payload_length == 20; // 2*13+2 + case PMSX003_TYPE_5003T: + case PMSX003_TYPE_5003S: + return payload_length == 28; // 2*13+2 (Data 13 not set/reserved) + case PMSX003_TYPE_5003ST: + return payload_length == 36; // 2*17+2 (Data 16 not set/reserved) } - this->data_[data_index_++] = (sum >> 8) & 0xFF; - this->data_[data_index_++] = (sum >> 0) & 0xFF; - for (int i = 0; i < data_index_; i++) { - this->write_byte(this->data_[i]); + return false; +} + +void PMSX003Component::send_command_(PMSX0003Command cmd, uint16_t data) { + uint8_t send_data[7] = { + START_CHARACTER_1, // Start Byte 1 + START_CHARACTER_2, // Start Byte 2 + cmd, // Command + uint8_t((data >> 8) & 0xFF), // Data 1 + uint8_t((data >> 0) & 0xFF), // Data 2 + 0, // Verify Byte 1 + 0, // Verify Byte 2 + }; + + // Calculate checksum + uint16_t checksum = 0; + for (uint8_t i = 0; i < 5; i++) { + checksum += send_data[i]; + } + send_data[5] = (checksum >> 8) & 0xFF; // Verify Byte 1 + send_data[6] = (checksum >> 0) & 0xFF; // Verify Byte 2 + + for (auto send_byte : send_data) { + this->write_byte(send_byte); } - this->data_index_ = 0; } void PMSX003Component::parse_data_() { - switch (this->type_) { - case PMSX003_TYPE_5003ST: { - float temperature = (int16_t) this->get_16_bit_uint_(30) / 10.0f; - float humidity = this->get_16_bit_uint_(32) / 10.0f; + // Particle Matter + const uint16_t pm_1_0_std_concentration = this->get_16_bit_uint_(4); + const uint16_t pm_2_5_std_concentration = this->get_16_bit_uint_(6); + const uint16_t pm_10_0_std_concentration = this->get_16_bit_uint_(8); - ESP_LOGD(TAG, "Got Temperature: %.1f°C, Humidity: %.1f%%", temperature, humidity); + const uint16_t pm_1_0_concentration = this->get_16_bit_uint_(10); + const uint16_t pm_2_5_concentration = this->get_16_bit_uint_(12); + const uint16_t pm_10_0_concentration = this->get_16_bit_uint_(14); - if (this->temperature_sensor_ != nullptr) - this->temperature_sensor_->publish_state(temperature); - if (this->humidity_sensor_ != nullptr) - this->humidity_sensor_->publish_state(humidity); - // The rest of the PMS5003ST matches the PMS5003S, continue on - } - case PMSX003_TYPE_5003S: { - uint16_t formaldehyde = this->get_16_bit_uint_(28); + const uint16_t pm_particles_03um = this->get_16_bit_uint_(16); + const uint16_t pm_particles_05um = this->get_16_bit_uint_(18); + const uint16_t pm_particles_10um = this->get_16_bit_uint_(20); + const uint16_t pm_particles_25um = this->get_16_bit_uint_(22); - ESP_LOGD(TAG, "Got Formaldehyde: %u µg/m^3", formaldehyde); + ESP_LOGD(TAG, + "Got PM1.0 Standard Concentration: %u µg/m³, PM2.5 Standard Concentration %u µg/m³, PM10.0 Standard " + "Concentration: %u µg/m³, PM1.0 Concentration: %u µg/m³, PM2.5 Concentration %u µg/m³, PM10.0 " + "Concentration: %u µg/m³", + pm_1_0_std_concentration, pm_2_5_std_concentration, pm_10_0_std_concentration, pm_1_0_concentration, + pm_2_5_concentration, pm_10_0_concentration); - if (this->formaldehyde_sensor_ != nullptr) - this->formaldehyde_sensor_->publish_state(formaldehyde); - // The rest of the PMS5003S matches the PMS5003, continue on - } - case PMSX003_TYPE_X003: { - uint16_t pm_1_0_std_concentration = this->get_16_bit_uint_(4); - uint16_t pm_2_5_std_concentration = this->get_16_bit_uint_(6); - uint16_t pm_10_0_std_concentration = this->get_16_bit_uint_(8); + if (this->pm_1_0_std_sensor_ != nullptr) + this->pm_1_0_std_sensor_->publish_state(pm_1_0_std_concentration); + if (this->pm_2_5_std_sensor_ != nullptr) + this->pm_2_5_std_sensor_->publish_state(pm_2_5_std_concentration); + if (this->pm_10_0_std_sensor_ != nullptr) + this->pm_10_0_std_sensor_->publish_state(pm_10_0_std_concentration); - uint16_t pm_1_0_concentration = this->get_16_bit_uint_(10); - uint16_t pm_2_5_concentration = this->get_16_bit_uint_(12); - uint16_t pm_10_0_concentration = this->get_16_bit_uint_(14); + if (this->pm_1_0_sensor_ != nullptr) + this->pm_1_0_sensor_->publish_state(pm_1_0_concentration); + if (this->pm_2_5_sensor_ != nullptr) + this->pm_2_5_sensor_->publish_state(pm_2_5_concentration); + if (this->pm_10_0_sensor_ != nullptr) + this->pm_10_0_sensor_->publish_state(pm_10_0_concentration); - uint16_t pm_particles_03um = this->get_16_bit_uint_(16); - uint16_t pm_particles_05um = this->get_16_bit_uint_(18); - uint16_t pm_particles_10um = this->get_16_bit_uint_(20); - uint16_t pm_particles_25um = this->get_16_bit_uint_(22); - uint16_t pm_particles_50um = this->get_16_bit_uint_(24); - uint16_t pm_particles_100um = this->get_16_bit_uint_(26); + if (this->pm_particles_03um_sensor_ != nullptr) + this->pm_particles_03um_sensor_->publish_state(pm_particles_03um); + if (this->pm_particles_05um_sensor_ != nullptr) + this->pm_particles_05um_sensor_->publish_state(pm_particles_05um); + if (this->pm_particles_10um_sensor_ != nullptr) + this->pm_particles_10um_sensor_->publish_state(pm_particles_10um); + if (this->pm_particles_25um_sensor_ != nullptr) + this->pm_particles_25um_sensor_->publish_state(pm_particles_25um); - ESP_LOGD(TAG, - "Got PM1.0 Concentration: %u µg/m^3, PM2.5 Concentration %u µg/m^3, PM10.0 Concentration: %u µg/m^3", - pm_1_0_concentration, pm_2_5_concentration, pm_10_0_concentration); + if (this->type_ == PMSX003_TYPE_5003T) { + ESP_LOGD(TAG, + "Got PM0.3 Particles: %u Count/0.1L, PM0.5 Particles: %u Count/0.1L, PM1.0 Particles: %u Count/0.1L, " + "PM2.5 Particles %u Count/0.1L", + pm_particles_03um, pm_particles_05um, pm_particles_10um, pm_particles_25um); + } else { + // Note the pm particles 50um & 100um are not returned, + // as PMS5003T uses those data values for temperature and humidity. + const uint16_t pm_particles_50um = this->get_16_bit_uint_(24); + const uint16_t pm_particles_100um = this->get_16_bit_uint_(26); - if (this->pm_1_0_std_sensor_ != nullptr) - this->pm_1_0_std_sensor_->publish_state(pm_1_0_std_concentration); - if (this->pm_2_5_std_sensor_ != nullptr) - this->pm_2_5_std_sensor_->publish_state(pm_2_5_std_concentration); - if (this->pm_10_0_std_sensor_ != nullptr) - this->pm_10_0_std_sensor_->publish_state(pm_10_0_std_concentration); + ESP_LOGD(TAG, + "Got PM0.3 Particles: %u Count/0.1L, PM0.5 Particles: %u Count/0.1L, PM1.0 Particles: %u Count/0.1L, " + "PM2.5 Particles %u Count/0.1L, PM5.0 Particles: %u Count/0.1L, PM10.0 Particles %u Count/0.1L", + pm_particles_03um, pm_particles_05um, pm_particles_10um, pm_particles_25um, pm_particles_50um, + pm_particles_100um); - if (this->pm_1_0_sensor_ != nullptr) - this->pm_1_0_sensor_->publish_state(pm_1_0_concentration); - if (this->pm_2_5_sensor_ != nullptr) - this->pm_2_5_sensor_->publish_state(pm_2_5_concentration); - if (this->pm_10_0_sensor_ != nullptr) - this->pm_10_0_sensor_->publish_state(pm_10_0_concentration); + if (this->pm_particles_50um_sensor_ != nullptr) + this->pm_particles_50um_sensor_->publish_state(pm_particles_50um); + if (this->pm_particles_100um_sensor_ != nullptr) + this->pm_particles_100um_sensor_->publish_state(pm_particles_100um); + } - if (this->pm_particles_03um_sensor_ != nullptr) - this->pm_particles_03um_sensor_->publish_state(pm_particles_03um); - if (this->pm_particles_05um_sensor_ != nullptr) - this->pm_particles_05um_sensor_->publish_state(pm_particles_05um); - if (this->pm_particles_10um_sensor_ != nullptr) - this->pm_particles_10um_sensor_->publish_state(pm_particles_10um); - if (this->pm_particles_25um_sensor_ != nullptr) - this->pm_particles_25um_sensor_->publish_state(pm_particles_25um); - if (this->pm_particles_50um_sensor_ != nullptr) - this->pm_particles_50um_sensor_->publish_state(pm_particles_50um); - if (this->pm_particles_100um_sensor_ != nullptr) - this->pm_particles_100um_sensor_->publish_state(pm_particles_100um); - break; - } - case PMSX003_TYPE_5003T: { - uint16_t pm_1_0_std_concentration = this->get_16_bit_uint_(4); - uint16_t pm_2_5_std_concentration = this->get_16_bit_uint_(6); - uint16_t pm_10_0_std_concentration = this->get_16_bit_uint_(8); + // Formaldehyde + if (this->type_ == PMSX003_TYPE_5003ST || this->type_ == PMSX003_TYPE_5003S) { + const uint16_t formaldehyde = this->get_16_bit_uint_(28); - uint16_t pm_1_0_concentration = this->get_16_bit_uint_(10); - uint16_t pm_2_5_concentration = this->get_16_bit_uint_(12); - uint16_t pm_10_0_concentration = this->get_16_bit_uint_(14); + ESP_LOGD(TAG, "Got Formaldehyde: %u µg/m^3", formaldehyde); - uint16_t pm_particles_03um = this->get_16_bit_uint_(16); - uint16_t pm_particles_05um = this->get_16_bit_uint_(18); - uint16_t pm_particles_10um = this->get_16_bit_uint_(20); - uint16_t pm_particles_25um = this->get_16_bit_uint_(22); - // Note the pm particles 50um & 100um are not returned, - // as PMS5003T uses those data values for temperature and humidity. + if (this->formaldehyde_sensor_ != nullptr) + this->formaldehyde_sensor_->publish_state(formaldehyde); + } - float temperature = (int16_t) this->get_16_bit_uint_(24) / 10.0f; - float humidity = this->get_16_bit_uint_(26) / 10.0f; + // Temperature and Humidity + if (this->type_ == PMSX003_TYPE_5003ST || this->type_ == PMSX003_TYPE_5003T) { + const uint8_t temperature_offset = (this->type_ == PMSX003_TYPE_5003T) ? 24 : 30; - ESP_LOGD(TAG, - "Got PM1.0 Concentration: %u µg/m^3, PM2.5 Concentration %u µg/m^3, PM10.0 Concentration: %u µg/m^3, " - "Temperature: %.1f°C, Humidity: %.1f%%", - pm_1_0_concentration, pm_2_5_concentration, pm_10_0_concentration, temperature, humidity); + const float temperature = static_cast(this->get_16_bit_uint_(temperature_offset)) / 10.0f; + const float humidity = this->get_16_bit_uint_(temperature_offset + 2) / 10.0f; - if (this->pm_1_0_std_sensor_ != nullptr) - this->pm_1_0_std_sensor_->publish_state(pm_1_0_std_concentration); - if (this->pm_2_5_std_sensor_ != nullptr) - this->pm_2_5_std_sensor_->publish_state(pm_2_5_std_concentration); - if (this->pm_10_0_std_sensor_ != nullptr) - this->pm_10_0_std_sensor_->publish_state(pm_10_0_std_concentration); + ESP_LOGD(TAG, "Got Temperature: %.1f°C, Humidity: %.1f%%", temperature, humidity); - if (this->pm_1_0_sensor_ != nullptr) - this->pm_1_0_sensor_->publish_state(pm_1_0_concentration); - if (this->pm_2_5_sensor_ != nullptr) - this->pm_2_5_sensor_->publish_state(pm_2_5_concentration); - if (this->pm_10_0_sensor_ != nullptr) - this->pm_10_0_sensor_->publish_state(pm_10_0_concentration); + if (this->temperature_sensor_ != nullptr) + this->temperature_sensor_->publish_state(temperature); + if (this->humidity_sensor_ != nullptr) + this->humidity_sensor_->publish_state(humidity); + } - if (this->pm_particles_03um_sensor_ != nullptr) - this->pm_particles_03um_sensor_->publish_state(pm_particles_03um); - if (this->pm_particles_05um_sensor_ != nullptr) - this->pm_particles_05um_sensor_->publish_state(pm_particles_05um); - if (this->pm_particles_10um_sensor_ != nullptr) - this->pm_particles_10um_sensor_->publish_state(pm_particles_10um); - if (this->pm_particles_25um_sensor_ != nullptr) - this->pm_particles_25um_sensor_->publish_state(pm_particles_25um); + // Firmware Version and Error Code + if (this->type_ == PMSX003_TYPE_5003ST) { + const uint8_t firmware_version = this->data_[36]; + const uint8_t error_code = this->data_[37]; - if (this->temperature_sensor_ != nullptr) - this->temperature_sensor_->publish_state(temperature); - if (this->humidity_sensor_ != nullptr) - this->humidity_sensor_->publish_state(humidity); - break; - } + ESP_LOGD(TAG, "Got Firmware Version: 0x%02X, Error Code: 0x%02X", firmware_version, error_code); } // Spin down the sensor again if we aren't going to need it until more time has // passed than it takes to stabilise if (this->update_interval_ > PMS_STABILISING_MS) { - this->send_command_(PMS_CMD_ON_STANDBY, 0); + this->send_command_(PMS_CMD_SLEEP_MODE, PMS_CMD_SLEEP_MODE_SLEEP); this->state_ = PMSX003_STATE_IDLE; } this->status_clear_warning(); } + uint16_t PMSX003Component::get_16_bit_uint_(uint8_t start_index) { return (uint16_t(this->data_[start_index]) << 8) | uint16_t(this->data_[start_index + 1]); } -void PMSX003Component::dump_config() { - ESP_LOGCONFIG(TAG, "PMSX003:"); - LOG_SENSOR(" ", "PM1.0STD", this->pm_1_0_std_sensor_); - LOG_SENSOR(" ", "PM2.5STD", this->pm_2_5_std_sensor_); - LOG_SENSOR(" ", "PM10.0STD", this->pm_10_0_std_sensor_); - - LOG_SENSOR(" ", "PM1.0", this->pm_1_0_sensor_); - LOG_SENSOR(" ", "PM2.5", this->pm_2_5_sensor_); - LOG_SENSOR(" ", "PM10.0", this->pm_10_0_sensor_); - - LOG_SENSOR(" ", "PM0.3um", this->pm_particles_03um_sensor_); - LOG_SENSOR(" ", "PM0.5um", this->pm_particles_05um_sensor_); - LOG_SENSOR(" ", "PM1.0um", this->pm_particles_10um_sensor_); - LOG_SENSOR(" ", "PM2.5um", this->pm_particles_25um_sensor_); - LOG_SENSOR(" ", "PM5.0um", this->pm_particles_50um_sensor_); - LOG_SENSOR(" ", "PM10.0um", this->pm_particles_100um_sensor_); - - LOG_SENSOR(" ", "Temperature", this->temperature_sensor_); - LOG_SENSOR(" ", "Humidity", this->humidity_sensor_); - LOG_SENSOR(" ", "Formaldehyde", this->formaldehyde_sensor_); - this->check_uart_settings(9600); -} } // namespace pmsx003 } // namespace esphome diff --git a/esphome/components/pmsx003/pmsx003.h b/esphome/components/pmsx003/pmsx003.h index cb5c16aecf..85bb1ff9f3 100644 --- a/esphome/components/pmsx003/pmsx003.h +++ b/esphome/components/pmsx003/pmsx003.h @@ -7,13 +7,12 @@ namespace esphome { namespace pmsx003 { -// known command bytes -static const uint8_t PMS_CMD_AUTO_MANUAL = - 0xE1; // data=0: perform measurement manually, data=1: perform measurement automatically -static const uint8_t PMS_CMD_TRIG_MANUAL = 0xE2; // trigger a manual measurement -static const uint8_t PMS_CMD_ON_STANDBY = 0xE4; // data=0: go to standby mode, data=1: go to normal mode - -static const uint16_t PMS_STABILISING_MS = 30000; // time taken for the sensor to become stable after power on +enum PMSX0003Command : uint8_t { + PMS_CMD_MEASUREMENT_MODE = + 0xE1, // Data Options: `PMS_CMD_MEASUREMENT_MODE_PASSIVE`, `PMS_CMD_MEASUREMENT_MODE_ACTIVE` + PMS_CMD_MANUAL_MEASUREMENT = 0xE2, + PMS_CMD_SLEEP_MODE = 0xE4, // Data Options: `PMS_CMD_SLEEP_MODE_SLEEP`, `PMS_CMD_SLEEP_MODE_WAKEUP` +}; enum PMSX003Type { PMSX003_TYPE_X003 = 0, @@ -31,37 +30,53 @@ enum PMSX003State { class PMSX003Component : public uart::UARTDevice, public Component { public: PMSX003Component() = default; - void loop() override; - float get_setup_priority() const override; + float get_setup_priority() const override { return setup_priority::DATA; } void dump_config() override; + void loop() override; - void set_type(PMSX003Type type) { type_ = type; } + void set_update_interval(uint32_t update_interval) { this->update_interval_ = update_interval; } - void set_update_interval(uint32_t val) { update_interval_ = val; }; + void set_type(PMSX003Type type) { this->type_ = type; } - void set_pm_1_0_std_sensor(sensor::Sensor *pm_1_0_std_sensor); - void set_pm_2_5_std_sensor(sensor::Sensor *pm_2_5_std_sensor); - void set_pm_10_0_std_sensor(sensor::Sensor *pm_10_0_std_sensor); + void set_pm_1_0_std_sensor(sensor::Sensor *pm_1_0_std_sensor) { this->pm_1_0_std_sensor_ = pm_1_0_std_sensor; } + void set_pm_2_5_std_sensor(sensor::Sensor *pm_2_5_std_sensor) { this->pm_2_5_std_sensor_ = pm_2_5_std_sensor; } + void set_pm_10_0_std_sensor(sensor::Sensor *pm_10_0_std_sensor) { this->pm_10_0_std_sensor_ = pm_10_0_std_sensor; } - void set_pm_1_0_sensor(sensor::Sensor *pm_1_0_sensor); - void set_pm_2_5_sensor(sensor::Sensor *pm_2_5_sensor); - void set_pm_10_0_sensor(sensor::Sensor *pm_10_0_sensor); + void set_pm_1_0_sensor(sensor::Sensor *pm_1_0_sensor) { this->pm_1_0_sensor_ = pm_1_0_sensor; } + void set_pm_2_5_sensor(sensor::Sensor *pm_2_5_sensor) { this->pm_2_5_sensor_ = pm_2_5_sensor; } + void set_pm_10_0_sensor(sensor::Sensor *pm_10_0_sensor) { this->pm_10_0_sensor_ = pm_10_0_sensor; } - void set_pm_particles_03um_sensor(sensor::Sensor *pm_particles_03um_sensor); - void set_pm_particles_05um_sensor(sensor::Sensor *pm_particles_05um_sensor); - void set_pm_particles_10um_sensor(sensor::Sensor *pm_particles_10um_sensor); - void set_pm_particles_25um_sensor(sensor::Sensor *pm_particles_25um_sensor); - void set_pm_particles_50um_sensor(sensor::Sensor *pm_particles_50um_sensor); - void set_pm_particles_100um_sensor(sensor::Sensor *pm_particles_100um_sensor); + void set_pm_particles_03um_sensor(sensor::Sensor *pm_particles_03um_sensor) { + this->pm_particles_03um_sensor_ = pm_particles_03um_sensor; + } + void set_pm_particles_05um_sensor(sensor::Sensor *pm_particles_05um_sensor) { + this->pm_particles_05um_sensor_ = pm_particles_05um_sensor; + } + void set_pm_particles_10um_sensor(sensor::Sensor *pm_particles_10um_sensor) { + this->pm_particles_10um_sensor_ = pm_particles_10um_sensor; + } + void set_pm_particles_25um_sensor(sensor::Sensor *pm_particles_25um_sensor) { + this->pm_particles_25um_sensor_ = pm_particles_25um_sensor; + } + void set_pm_particles_50um_sensor(sensor::Sensor *pm_particles_50um_sensor) { + this->pm_particles_50um_sensor_ = pm_particles_50um_sensor; + } + void set_pm_particles_100um_sensor(sensor::Sensor *pm_particles_100um_sensor) { + this->pm_particles_100um_sensor_ = pm_particles_100um_sensor; + } - void set_temperature_sensor(sensor::Sensor *temperature_sensor); - void set_humidity_sensor(sensor::Sensor *humidity_sensor); - void set_formaldehyde_sensor(sensor::Sensor *formaldehyde_sensor); + void set_formaldehyde_sensor(sensor::Sensor *formaldehyde_sensor) { + this->formaldehyde_sensor_ = formaldehyde_sensor; + } + + void set_temperature_sensor(sensor::Sensor *temperature_sensor) { this->temperature_sensor_ = temperature_sensor; } + void set_humidity_sensor(sensor::Sensor *humidity_sensor) { this->humidity_sensor_ = humidity_sensor; } protected: optional check_byte_(); void parse_data_(); - void send_command_(uint8_t cmd, uint16_t data); + bool check_payload_length_(uint16_t payload_length); + void send_command_(PMSX0003Command cmd, uint16_t data); uint16_t get_16_bit_uint_(uint8_t start_index); uint8_t data_[64]; @@ -92,9 +107,12 @@ class PMSX003Component : public uart::UARTDevice, public Component { sensor::Sensor *pm_particles_50um_sensor_{nullptr}; sensor::Sensor *pm_particles_100um_sensor_{nullptr}; + // Formaldehyde + sensor::Sensor *formaldehyde_sensor_{nullptr}; + + // Temperature and Humidity sensor::Sensor *temperature_sensor_{nullptr}; sensor::Sensor *humidity_sensor_{nullptr}; - sensor::Sensor *formaldehyde_sensor_{nullptr}; }; } // namespace pmsx003 diff --git a/esphome/components/pmsx003/sensor.py b/esphome/components/pmsx003/sensor.py index 1556b3c983..bebd3a01ee 100644 --- a/esphome/components/pmsx003/sensor.py +++ b/esphome/components/pmsx003/sensor.py @@ -33,6 +33,7 @@ from esphome.const import ( UNIT_PERCENT, ) +CODEOWNERS = ["@ximex"] DEPENDENCIES = ["uart"] pmsx003_ns = cg.esphome_ns.namespace("pmsx003") @@ -57,9 +58,18 @@ SENSORS_TO_TYPE = { CONF_PM_1_0: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], CONF_PM_2_5: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], CONF_PM_10_0: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_1_0_STD: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_2_5_STD: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_10_0_STD: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_0_3UM: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_0_5UM: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_1_0UM: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_2_5UM: [TYPE_PMSX003, TYPE_PMS5003T, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_5_0UM: [TYPE_PMSX003, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_PM_10_0UM: [TYPE_PMSX003, TYPE_PMS5003ST, TYPE_PMS5003S], + CONF_FORMALDEHYDE: [TYPE_PMS5003ST, TYPE_PMS5003S], CONF_TEMPERATURE: [TYPE_PMS5003T, TYPE_PMS5003ST], CONF_HUMIDITY: [TYPE_PMS5003T, TYPE_PMS5003ST], - CONF_FORMALDEHYDE: [TYPE_PMS5003ST, TYPE_PMS5003S], } @@ -164,6 +174,12 @@ CONFIG_SCHEMA = ( accuracy_decimals=0, state_class=STATE_CLASS_MEASUREMENT, ), + cv.Optional(CONF_FORMALDEHYDE): sensor.sensor_schema( + unit_of_measurement=UNIT_MICROGRAMS_PER_CUBIC_METER, + icon=ICON_CHEMICAL_WEAPON, + accuracy_decimals=0, + state_class=STATE_CLASS_MEASUREMENT, + ), cv.Optional(CONF_TEMPERATURE): sensor.sensor_schema( unit_of_measurement=UNIT_CELSIUS, accuracy_decimals=1, @@ -176,12 +192,6 @@ CONFIG_SCHEMA = ( device_class=DEVICE_CLASS_HUMIDITY, state_class=STATE_CLASS_MEASUREMENT, ), - cv.Optional(CONF_FORMALDEHYDE): sensor.sensor_schema( - unit_of_measurement=UNIT_MICROGRAMS_PER_CUBIC_METER, - icon=ICON_CHEMICAL_WEAPON, - accuracy_decimals=0, - state_class=STATE_CLASS_MEASUREMENT, - ), cv.Optional(CONF_UPDATE_INTERVAL, default="0s"): validate_update_interval, } ) @@ -256,6 +266,10 @@ async def to_code(config): sens = await sensor.new_sensor(config[CONF_PM_10_0UM]) cg.add(var.set_pm_particles_100um_sensor(sens)) + if CONF_FORMALDEHYDE in config: + sens = await sensor.new_sensor(config[CONF_FORMALDEHYDE]) + cg.add(var.set_formaldehyde_sensor(sens)) + if CONF_TEMPERATURE in config: sens = await sensor.new_sensor(config[CONF_TEMPERATURE]) cg.add(var.set_temperature_sensor(sens)) @@ -264,8 +278,4 @@ async def to_code(config): sens = await sensor.new_sensor(config[CONF_HUMIDITY]) cg.add(var.set_humidity_sensor(sens)) - if CONF_FORMALDEHYDE in config: - sens = await sensor.new_sensor(config[CONF_FORMALDEHYDE]) - cg.add(var.set_formaldehyde_sensor(sens)) - cg.add(var.set_update_interval(config[CONF_UPDATE_INTERVAL])) diff --git a/esphome/components/pn7150/pn7150.h b/esphome/components/pn7150/pn7150.h index 54038f5085..87af7d629b 100644 --- a/esphome/components/pn7150/pn7150.h +++ b/esphome/components/pn7150/pn7150.h @@ -123,8 +123,8 @@ enum class NCIState : uint8_t { RFST_POLL_ACTIVE, EP_DEACTIVATING, EP_SELECTING, - TEST = 0XFE, - FAILED = 0XFF, + TEST = 0xFE, + FAILED = 0xFF, }; enum class TestMode : uint8_t { diff --git a/esphome/components/pn7160/pn7160.h b/esphome/components/pn7160/pn7160.h index f2e05ea1d0..ff8a492b7b 100644 --- a/esphome/components/pn7160/pn7160.h +++ b/esphome/components/pn7160/pn7160.h @@ -138,8 +138,8 @@ enum class NCIState : uint8_t { RFST_POLL_ACTIVE, EP_DEACTIVATING, EP_SELECTING, - TEST = 0XFE, - FAILED = 0XFF, + TEST = 0xFE, + FAILED = 0xFF, }; enum class TestMode : uint8_t { diff --git a/esphome/components/prometheus/prometheus_handler.cpp b/esphome/components/prometheus/prometheus_handler.cpp index 794df299a1..2677860c7c 100644 --- a/esphome/components/prometheus/prometheus_handler.cpp +++ b/esphome/components/prometheus/prometheus_handler.cpp @@ -89,6 +89,12 @@ void PrometheusHandler::handleRequest(AsyncWebServerRequest *req) { this->valve_row_(stream, obj, area, node, friendly_name); #endif +#ifdef USE_CLIMATE + this->climate_type_(stream); + for (auto *obj : App.get_climates()) + this->climate_row_(stream, obj, area, node, friendly_name); +#endif + req->send(stream); } @@ -824,6 +830,174 @@ void PrometheusHandler::valve_row_(AsyncResponseStream *stream, valve::Valve *ob } #endif +#ifdef USE_CLIMATE +void PrometheusHandler::climate_type_(AsyncResponseStream *stream) { + stream->print(F("#TYPE esphome_climate_setting gauge\n")); + stream->print(F("#TYPE esphome_climate_value gauge\n")); + stream->print(F("#TYPE esphome_climate_failed gauge\n")); +} + +void PrometheusHandler::climate_setting_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, + std::string &node, std::string &friendly_name, std::string &setting, + const LogString *setting_value) { + stream->print(F("esphome_climate_setting{id=\"")); + stream->print(relabel_id_(obj).c_str()); + add_area_label_(stream, area); + add_node_label_(stream, node); + add_friendly_name_label_(stream, friendly_name); + stream->print(F("\",name=\"")); + stream->print(relabel_name_(obj).c_str()); + stream->print(F("\",category=\"")); + stream->print(setting.c_str()); + stream->print(F("\",setting_value=\"")); + stream->print(LOG_STR_ARG(setting_value)); + stream->print(F("\"} ")); + stream->print(F("1.0")); + stream->print(F("\n")); +} + +void PrometheusHandler::climate_value_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, + std::string &node, std::string &friendly_name, std::string &category, + std::string &climate_value) { + stream->print(F("esphome_climate_value{id=\"")); + stream->print(relabel_id_(obj).c_str()); + add_area_label_(stream, area); + add_node_label_(stream, node); + add_friendly_name_label_(stream, friendly_name); + stream->print(F("\",name=\"")); + stream->print(relabel_name_(obj).c_str()); + stream->print(F("\",category=\"")); + stream->print(category.c_str()); + stream->print(F("\"} ")); + stream->print(climate_value.c_str()); + stream->print(F("\n")); +} + +void PrometheusHandler::climate_failed_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, + std::string &node, std::string &friendly_name, std::string &category, + bool is_failed_value) { + stream->print(F("esphome_climate_failed{id=\"")); + stream->print(relabel_id_(obj).c_str()); + add_area_label_(stream, area); + add_node_label_(stream, node); + add_friendly_name_label_(stream, friendly_name); + stream->print(F("\",name=\"")); + stream->print(relabel_name_(obj).c_str()); + stream->print(F("\",category=\"")); + stream->print(category.c_str()); + stream->print(F("\"} ")); + if (is_failed_value) { + stream->print(F("1.0")); + } else { + stream->print(F("0.0")); + } + stream->print(F("\n")); +} + +void PrometheusHandler::climate_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, + std::string &node, std::string &friendly_name) { + if (obj->is_internal() && !this->include_internal_) + return; + // Data itself + bool any_failures = false; + std::string climate_mode_category = "mode"; + const auto *climate_mode_value = climate::climate_mode_to_string(obj->mode); + climate_setting_row_(stream, obj, area, node, friendly_name, climate_mode_category, climate_mode_value); + const auto traits = obj->get_traits(); + // Now see if traits is supported + int8_t target_accuracy = traits.get_target_temperature_accuracy_decimals(); + int8_t current_accuracy = traits.get_current_temperature_accuracy_decimals(); + // max temp + std::string max_temp = "maximum_temperature"; + auto max_temp_value = value_accuracy_to_string(traits.get_visual_max_temperature(), target_accuracy); + climate_value_row_(stream, obj, area, node, friendly_name, max_temp, max_temp_value); + // max temp + std::string min_temp = "mininum_temperature"; + auto min_temp_value = value_accuracy_to_string(traits.get_visual_min_temperature(), target_accuracy); + climate_value_row_(stream, obj, area, node, friendly_name, min_temp, min_temp_value); + // now check optional traits + if (traits.get_supports_current_temperature()) { + std::string current_temp = "current_temperature"; + if (std::isnan(obj->current_temperature)) { + climate_failed_row_(stream, obj, area, node, friendly_name, current_temp, true); + any_failures = true; + } else { + auto current_temp_value = value_accuracy_to_string(obj->current_temperature, current_accuracy); + climate_value_row_(stream, obj, area, node, friendly_name, current_temp, current_temp_value); + climate_failed_row_(stream, obj, area, node, friendly_name, current_temp, false); + } + } + if (traits.get_supports_current_humidity()) { + std::string current_humidity = "current_humidity"; + if (std::isnan(obj->current_humidity)) { + climate_failed_row_(stream, obj, area, node, friendly_name, current_humidity, true); + any_failures = true; + } else { + auto current_humidity_value = value_accuracy_to_string(obj->current_humidity, 0); + climate_value_row_(stream, obj, area, node, friendly_name, current_humidity, current_humidity_value); + climate_failed_row_(stream, obj, area, node, friendly_name, current_humidity, false); + } + } + if (traits.get_supports_target_humidity()) { + std::string target_humidity = "target_humidity"; + if (std::isnan(obj->target_humidity)) { + climate_failed_row_(stream, obj, area, node, friendly_name, target_humidity, true); + any_failures = true; + } else { + auto target_humidity_value = value_accuracy_to_string(obj->target_humidity, 0); + climate_value_row_(stream, obj, area, node, friendly_name, target_humidity, target_humidity_value); + climate_failed_row_(stream, obj, area, node, friendly_name, target_humidity, false); + } + } + if (traits.get_supports_two_point_target_temperature()) { + std::string target_temp_low = "target_temperature_low"; + auto target_temp_low_value = value_accuracy_to_string(obj->target_temperature_low, target_accuracy); + climate_value_row_(stream, obj, area, node, friendly_name, target_temp_low, target_temp_low_value); + std::string target_temp_high = "target_temperature_high"; + auto target_temp_high_value = value_accuracy_to_string(obj->target_temperature_high, target_accuracy); + climate_value_row_(stream, obj, area, node, friendly_name, target_temp_high, target_temp_high_value); + } else { + std::string target_temp = "target_temperature"; + auto target_temp_value = value_accuracy_to_string(obj->target_temperature, target_accuracy); + climate_value_row_(stream, obj, area, node, friendly_name, target_temp, target_temp_value); + } + if (traits.get_supports_action()) { + std::string climate_trait_category = "action"; + const auto *climate_trait_value = climate::climate_action_to_string(obj->action); + climate_setting_row_(stream, obj, area, node, friendly_name, climate_trait_category, climate_trait_value); + } + if (traits.get_supports_fan_modes()) { + std::string climate_trait_category = "fan_mode"; + if (obj->fan_mode.has_value()) { + const auto *climate_trait_value = climate::climate_fan_mode_to_string(obj->fan_mode.value()); + climate_setting_row_(stream, obj, area, node, friendly_name, climate_trait_category, climate_trait_value); + climate_failed_row_(stream, obj, area, node, friendly_name, climate_trait_category, false); + } else { + climate_failed_row_(stream, obj, area, node, friendly_name, climate_trait_category, true); + any_failures = true; + } + } + if (traits.get_supports_presets()) { + std::string climate_trait_category = "preset"; + if (obj->preset.has_value()) { + const auto *climate_trait_value = climate::climate_preset_to_string(obj->preset.value()); + climate_setting_row_(stream, obj, area, node, friendly_name, climate_trait_category, climate_trait_value); + climate_failed_row_(stream, obj, area, node, friendly_name, climate_trait_category, false); + } else { + climate_failed_row_(stream, obj, area, node, friendly_name, climate_trait_category, true); + any_failures = true; + } + } + if (traits.get_supports_swing_modes()) { + std::string climate_trait_category = "swing_mode"; + const auto *climate_trait_value = climate::climate_swing_mode_to_string(obj->swing_mode); + climate_setting_row_(stream, obj, area, node, friendly_name, climate_trait_category, climate_trait_value); + } + std::string all_climate_category = "all"; + climate_failed_row_(stream, obj, area, node, friendly_name, all_climate_category, any_failures); +} +#endif + } // namespace prometheus } // namespace esphome #endif diff --git a/esphome/components/prometheus/prometheus_handler.h b/esphome/components/prometheus/prometheus_handler.h index b77dbc462b..bdc3d971ce 100644 --- a/esphome/components/prometheus/prometheus_handler.h +++ b/esphome/components/prometheus/prometheus_handler.h @@ -8,6 +8,9 @@ #include "esphome/core/component.h" #include "esphome/core/controller.h" #include "esphome/core/entity_base.h" +#ifdef USE_CLIMATE +#include "esphome/core/log.h" +#endif namespace esphome { namespace prometheus { @@ -169,6 +172,20 @@ class PrometheusHandler : public AsyncWebHandler, public Component { std::string &friendly_name); #endif +#ifdef USE_CLIMATE + /// Return the type for prometheus + void climate_type_(AsyncResponseStream *stream); + /// Return the climate state as prometheus data point + void climate_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, std::string &node, + std::string &friendly_name); + void climate_failed_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, std::string &node, + std::string &friendly_name, std::string &category, bool is_failed_value); + void climate_setting_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, std::string &node, + std::string &friendly_name, std::string &setting, const LogString *setting_value); + void climate_value_row_(AsyncResponseStream *stream, climate::Climate *obj, std::string &area, std::string &node, + std::string &friendly_name, std::string &category, std::string &climate_value); +#endif + web_server_base::WebServerBase *base_; bool include_internal_{false}; std::map relabel_map_id_; diff --git a/esphome/components/psram/__init__.py b/esphome/components/psram/__init__.py index f268d5747f..53ba54dd28 100644 --- a/esphome/components/psram/__init__.py +++ b/esphome/components/psram/__init__.py @@ -16,6 +16,8 @@ from esphome.const import ( CONF_ID, CONF_MODE, CONF_SPEED, + KEY_CORE, + KEY_FRAMEWORK_VERSION, PLATFORM_ESP32, ) from esphome.core import CORE @@ -110,11 +112,11 @@ async def to_code(config): add_idf_sdkconfig_option(f"{SPIRAM_MODES[config[CONF_MODE]]}", True) add_idf_sdkconfig_option(f"{SPIRAM_SPEEDS[config[CONF_SPEED]]}", True) if config[CONF_MODE] == TYPE_OCTAL and config[CONF_SPEED] == 120e6: - add_idf_sdkconfig_option("CONFIG_ESP32S3_DEFAULT_CPU_FREQ_240", True) - # This works only on IDF 5.4.x but does no harm on earlier versions - add_idf_sdkconfig_option( - "CONFIG_SPIRAM_TIMING_TUNING_POINT_VIA_TEMPERATURE_SENSOR", True - ) + add_idf_sdkconfig_option("CONFIG_ESP_DEFAULT_CPU_FREQ_MHZ_240", True) + if CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] >= cv.Version(5, 4, 0): + add_idf_sdkconfig_option( + "CONFIG_SPIRAM_TIMING_TUNING_POINT_VIA_TEMPERATURE_SENSOR", True + ) if config[CONF_ENABLE_ECC]: add_idf_sdkconfig_option("CONFIG_SPIRAM_ECC_ENABLE", True) diff --git a/esphome/components/pulse_meter/pulse_meter_sensor.cpp b/esphome/components/pulse_meter/pulse_meter_sensor.cpp index 836a84b391..b82cb7a15c 100644 --- a/esphome/components/pulse_meter/pulse_meter_sensor.cpp +++ b/esphome/components/pulse_meter/pulse_meter_sensor.cpp @@ -18,6 +18,9 @@ void PulseMeterSensor::setup() { this->pin_->setup(); this->isr_pin_ = pin_->to_isr(); + // Set the pin value to the current value to avoid a false edge + this->last_pin_val_ = this->pin_->digital_read(); + // Set the last processed edge to now for the first timeout this->last_processed_edge_us_ = micros(); @@ -25,23 +28,37 @@ void PulseMeterSensor::setup() { this->pin_->attach_interrupt(PulseMeterSensor::edge_intr, this, gpio::INTERRUPT_RISING_EDGE); } else if (this->filter_mode_ == FILTER_PULSE) { // Set the pin value to the current value to avoid a false edge - this->pulse_state_.last_pin_val_ = this->isr_pin_.digital_read(); - this->pulse_state_.latched_ = this->pulse_state_.last_pin_val_; + this->pulse_state_.latched_ = this->last_pin_val_; this->pin_->attach_interrupt(PulseMeterSensor::pulse_intr, this, gpio::INTERRUPT_ANY_EDGE); } } void PulseMeterSensor::loop() { - const uint32_t now = micros(); - // Reset the count in get before we pass it back to the ISR as set this->get_->count_ = 0; - // Swap out set and get to get the latest state from the ISR - // The ISR could interrupt on any of these lines and the results would be consistent - auto *temp = this->set_; - this->set_ = this->get_; - this->get_ = temp; + { + // Lock the interrupt so the interrupt code doesn't interfere with itself + InterruptLock lock; + + // Sometimes ESP devices miss interrupts if the edge rises or falls too slowly. + // See https://github.com/espressif/arduino-esp32/issues/4172 + // If the edges are rising too slowly it also implies that the pulse rate is slow. + // Therefore the update rate of the loop is likely fast enough to detect the edges. + // When the main loop detects an edge that the ISR didn't it will run the ISR functions directly. + bool current = this->pin_->digital_read(); + if (this->filter_mode_ == FILTER_EDGE && current && !this->last_pin_val_) { + PulseMeterSensor::edge_intr(this); + } else if (this->filter_mode_ == FILTER_PULSE && current != this->last_pin_val_) { + PulseMeterSensor::pulse_intr(this); + } + this->last_pin_val_ = current; + + // Swap out set and get to get the latest state from the ISR + std::swap(this->set_, this->get_); + } + + const uint32_t now = micros(); // If an edge was peeked, repay the debt if (this->peeked_edge_ && this->get_->count_ > 0) { @@ -131,6 +148,9 @@ void IRAM_ATTR PulseMeterSensor::edge_intr(PulseMeterSensor *sensor) { set.last_rising_edge_us_ = now; set.count_++; } + + // This ISR is bound to rising edges, so the pin is high + sensor->last_pin_val_ = true; } void IRAM_ATTR PulseMeterSensor::pulse_intr(PulseMeterSensor *sensor) { @@ -144,9 +164,9 @@ void IRAM_ATTR PulseMeterSensor::pulse_intr(PulseMeterSensor *sensor) { // Filter length has passed since the last interrupt const bool length = now - state.last_intr_ >= sensor->filter_us_; - if (length && state.latched_ && !state.last_pin_val_) { // Long enough low edge + if (length && state.latched_ && !sensor->last_pin_val_) { // Long enough low edge state.latched_ = false; - } else if (length && !state.latched_ && state.last_pin_val_) { // Long enough high edge + } else if (length && !state.latched_ && sensor->last_pin_val_) { // Long enough high edge state.latched_ = true; set.last_detected_edge_us_ = state.last_intr_; set.count_++; @@ -158,7 +178,7 @@ void IRAM_ATTR PulseMeterSensor::pulse_intr(PulseMeterSensor *sensor) { set.last_rising_edge_us_ = !state.latched_ && pin_val ? now : set.last_detected_edge_us_; state.last_intr_ = now; - state.last_pin_val_ = pin_val; + sensor->last_pin_val_ = pin_val; } } // namespace pulse_meter diff --git a/esphome/components/pulse_meter/pulse_meter_sensor.h b/esphome/components/pulse_meter/pulse_meter_sensor.h index 76c4a35f03..748bab29ac 100644 --- a/esphome/components/pulse_meter/pulse_meter_sensor.h +++ b/esphome/components/pulse_meter/pulse_meter_sensor.h @@ -49,9 +49,7 @@ class PulseMeterSensor : public sensor::Sensor, public Component { // This struct (and the two pointers) are used to pass data between the ISR and loop. // These two pointers are exchanged each loop. - // Therefore you can't use data in the pointer to loop receives to set values in the pointer to loop sends. - // As a result it's easiest if you only use these pointers to send data from the ISR to the loop. - // (except for resetting the values) + // Use these to send data from the ISR to the loop not the other way around (except for resetting the values). struct State { uint32_t last_detected_edge_us_ = 0; uint32_t last_rising_edge_us_ = 0; @@ -61,9 +59,12 @@ class PulseMeterSensor : public sensor::Sensor, public Component { volatile State *set_ = state_; volatile State *get_ = state_ + 1; - // Only use these variables in the ISR + // Only use the following variables in the ISR or while guarded by an InterruptLock ISRInternalGPIOPin isr_pin_; + /// The last pin value seen + bool last_pin_val_ = false; + /// Filter state for edge mode struct EdgeState { uint32_t last_sent_edge_us_ = 0; @@ -74,7 +75,6 @@ class PulseMeterSensor : public sensor::Sensor, public Component { struct PulseState { uint32_t last_intr_ = 0; bool latched_ = false; - bool last_pin_val_ = false; }; PulseState pulse_state_{}; }; diff --git a/esphome/components/qspi_dbi/__init__.py b/esphome/components/qspi_dbi/__init__.py index a4b833f6d7..290a864335 100644 --- a/esphome/components/qspi_dbi/__init__.py +++ b/esphome/components/qspi_dbi/__init__.py @@ -1,4 +1,3 @@ CODEOWNERS = ["@clydebarrow"] CONF_DRAW_FROM_ORIGIN = "draw_from_origin" -CONF_DRAW_ROUNDING = "draw_rounding" diff --git a/esphome/components/qspi_dbi/display.py b/esphome/components/qspi_dbi/display.py index 8c29991f37..5b01bcc6ca 100644 --- a/esphome/components/qspi_dbi/display.py +++ b/esphome/components/qspi_dbi/display.py @@ -1,6 +1,7 @@ from esphome import pins import esphome.codegen as cg from esphome.components import display, spi +from esphome.components.const import CONF_DRAW_ROUNDING import esphome.config_validation as cv from esphome.const import ( CONF_BRIGHTNESS, @@ -24,7 +25,7 @@ from esphome.const import ( ) from esphome.core import TimePeriod -from . import CONF_DRAW_FROM_ORIGIN, CONF_DRAW_ROUNDING +from . import CONF_DRAW_FROM_ORIGIN from .models import DriverChip DEPENDENCIES = ["spi"] diff --git a/esphome/components/qspi_dbi/models.py b/esphome/components/qspi_dbi/models.py index 7ae1a10ec0..8ce592e0cf 100644 --- a/esphome/components/qspi_dbi/models.py +++ b/esphome/components/qspi_dbi/models.py @@ -1,8 +1,7 @@ # Commands +from esphome.components.const import CONF_DRAW_ROUNDING from esphome.const import CONF_INVERT_COLORS, CONF_SWAP_XY -from . import CONF_DRAW_ROUNDING - SW_RESET_CMD = 0x01 SLEEP_IN = 0x10 SLEEP_OUT = 0x11 diff --git a/esphome/components/remote_base/__init__.py b/esphome/components/remote_base/__init__.py index daea4e5c11..836b98104b 100644 --- a/esphome/components/remote_base/__init__.py +++ b/esphome/components/remote_base/__init__.py @@ -28,6 +28,7 @@ from esphome.const import ( CONF_RC_CODE_2, CONF_REPEAT, CONF_SECOND, + CONF_SOURCE, CONF_STATE, CONF_SYNC, CONF_TIMES, @@ -265,6 +266,53 @@ async def build_dumpers(config): return dumpers +# Beo4 +Beo4Data, Beo4BinarySensor, Beo4Trigger, Beo4Action, Beo4Dumper = declare_protocol( + "Beo4" +) +BEO4_SCHEMA = cv.Schema( + { + cv.Required(CONF_SOURCE): cv.hex_uint8_t, + cv.Required(CONF_COMMAND): cv.hex_uint8_t, + cv.Optional(CONF_COMMAND_REPEATS, default=1): cv.uint8_t, + } +) + + +@register_binary_sensor("beo4", Beo4BinarySensor, BEO4_SCHEMA) +def beo4_binary_sensor(var, config): + cg.add( + var.set_data( + cg.StructInitializer( + Beo4Data, + ("source", config[CONF_SOURCE]), + ("command", config[CONF_COMMAND]), + ("repeats", config[CONF_COMMAND_REPEATS]), + ) + ) + ) + + +@register_trigger("beo4", Beo4Trigger, Beo4Data) +def beo4_trigger(var, config): + pass + + +@register_dumper("beo4", Beo4Dumper) +def beo4_dumper(var, config): + pass + + +@register_action("beo4", Beo4Action, BEO4_SCHEMA) +async def beo4_action(var, config, args): + template_ = await cg.templatable(config[CONF_SOURCE], args, cg.uint8) + cg.add(var.set_source(template_)) + template_ = await cg.templatable(config[CONF_COMMAND], args, cg.uint8) + cg.add(var.set_command(template_)) + template_ = await cg.templatable(config[CONF_COMMAND_REPEATS], args, cg.uint8) + cg.add(var.set_repeats(template_)) + + # ByronSX ( ByronSXData, @@ -881,6 +929,49 @@ async def pronto_action(var, config, args): cg.add(var.set_data(template_)) +# Gobox +( + GoboxData, + GoboxBinarySensor, + GoboxTrigger, + GoboxAction, + GoboxDumper, +) = declare_protocol("Gobox") +GOBOX_SCHEMA = cv.Schema( + { + cv.Required(CONF_CODE): cv.int_, + } +) + + +@register_binary_sensor("gobox", GoboxBinarySensor, GOBOX_SCHEMA) +def gobox_binary_sensor(var, config): + cg.add( + var.set_data( + cg.StructInitializer( + GoboxData, + ("code", config[CONF_CODE]), + ) + ) + ) + + +@register_trigger("gobox", GoboxTrigger, GoboxData) +def gobox_trigger(var, config): + pass + + +@register_dumper("gobox", GoboxDumper) +def gobox_dumper(var, config): + pass + + +@register_action("gobox", GoboxAction, GOBOX_SCHEMA) +async def gobox_action(var, config, args): + template_ = await cg.templatable(config[CONF_CODE], args, cg.int_) + cg.add(var.set_code(template_)) + + # Roomba ( RoombaData, diff --git a/esphome/components/remote_base/beo4_protocol.cpp b/esphome/components/remote_base/beo4_protocol.cpp new file mode 100644 index 0000000000..8f5a642401 --- /dev/null +++ b/esphome/components/remote_base/beo4_protocol.cpp @@ -0,0 +1,153 @@ +#include "beo4_protocol.h" +#include "esphome/core/log.h" + +#include + +namespace esphome { +namespace remote_base { + +static const char *const TAG = "remote.beo4"; + +// beo4 pulse width, high=carrier_pulse low=data_pulse +constexpr uint16_t PW_CARR_US = 200; // carrier pulse length +constexpr uint16_t PW_ZERO_US = 2925; // + 200 = 3125 µs +constexpr uint16_t PW_SAME_US = 6050; // + 200 = 6250 µs +constexpr uint16_t PW_ONE_US = 9175; // + 200 = 9375 µs +constexpr uint16_t PW_STOP_US = 12300; // + 200 = 12500 µs +constexpr uint16_t PW_START_US = 15425; // + 200 = 15625 µs + +// beo4 pulse codes +constexpr uint8_t PC_ZERO = (PW_CARR_US + PW_ZERO_US) / 3125; // =1 +constexpr uint8_t PC_SAME = (PW_CARR_US + PW_SAME_US) / 3125; // =2 +constexpr uint8_t PC_ONE = (PW_CARR_US + PW_ONE_US) / 3125; // =3 +constexpr uint8_t PC_STOP = (PW_CARR_US + PW_STOP_US) / 3125; // =4 +constexpr uint8_t PC_START = (PW_CARR_US + PW_START_US) / 3125; // =5 + +// beo4 number of data bits = beoLink+beoSrc+beoCmd = 1+8+8 = 17 +constexpr uint32_t N_BITS = 1 + 8 + 8; + +// required symbols = 2*(start_sequence + n_bits + stop) = 2*(3+17+1) = 42 +constexpr uint32_t N_SYM = 2 + ((3 + 17 + 1) * 2u); // + 2 = 44 + +// states finite-state-machine decoder +enum class RxSt { RX_IDLE, RX_DATA, RX_STOP }; + +void Beo4Protocol::encode(RemoteTransmitData *dst, const Beo4Data &data) { + uint32_t beo_code = ((uint32_t) data.source << 8) + (uint32_t) data.command; + uint32_t jc = 0, ic = 0; + uint32_t cur_bit = 0; + uint32_t pre_bit = 0; + dst->set_carrier_frequency(455000); + dst->reserve(N_SYM); + + // start sequence=zero,zero,start + dst->item(PW_CARR_US, PW_ZERO_US); + dst->item(PW_CARR_US, PW_ZERO_US); + dst->item(PW_CARR_US, PW_START_US); + + // the data-bit BeoLink is always 0 + dst->item(PW_CARR_US, PW_ZERO_US); + + // The B&O trick to avoid extra long and extra short + // code-frames by extracting the data-bits from left + // to right, then comparing current with previous bit + // and set pulse to "same" "one" or "zero" + for (jc = 15, ic = 0; ic < 16; ic++, jc--) { + cur_bit = ((beo_code) >> jc) & 1; + if (cur_bit == pre_bit) { + dst->item(PW_CARR_US, PW_SAME_US); + } else if (1 == cur_bit) { + dst->item(PW_CARR_US, PW_ONE_US); + } else { + dst->item(PW_CARR_US, PW_ZERO_US); + } + pre_bit = cur_bit; + } + // complete the frame with stop-symbol and final carrier pulse + dst->item(PW_CARR_US, PW_STOP_US); + dst->mark(PW_CARR_US); +} + +optional Beo4Protocol::decode(RemoteReceiveData src) { + int32_t n_sym = src.size(); + Beo4Data data{ + .source = 0, + .command = 0, + .repeats = 0, + }; + // suppress dummy codes (TSO7000 hiccups) + if (n_sym > 42) { + static uint32_t beo_code = 0; + RxSt fsm = RxSt::RX_IDLE; + int32_t ic = 0; + int32_t jc = 0; + uint32_t pre_bit = 0; + uint32_t cnt_bit = 0; + ESP_LOGD(TAG, "Beo4: n_sym=%" PRId32, n_sym); + for (jc = 0, ic = 0; ic < (n_sym - 1); ic += 2, jc++) { + int32_t pulse_width = src[ic] - src[ic + 1]; + // suppress TSOP7000 (dummy pulses) + if (pulse_width > 1500) { + int32_t pulse_code = (pulse_width + 1560) / 3125; + switch (fsm) { + case RxSt::RX_IDLE: { + beo_code = 0; + cnt_bit = 0; + pre_bit = 0; + if (PC_START == pulse_code) { + fsm = RxSt::RX_DATA; + } + break; + } + case RxSt::RX_DATA: { + uint32_t cur_bit = 0; + switch (pulse_code) { + case PC_ZERO: { + cur_bit = pre_bit = 0; + break; + } + case PC_SAME: { + cur_bit = pre_bit; + break; + } + case PC_ONE: { + cur_bit = pre_bit = 1; + break; + } + default: { + fsm = RxSt::RX_IDLE; + break; + } + } + beo_code = (beo_code << 1) + cur_bit; + if (++cnt_bit == N_BITS) { + fsm = RxSt::RX_STOP; + } + break; + } + case RxSt::RX_STOP: { + if (PC_STOP == pulse_code) { + data.source = (uint8_t) ((beo_code >> 8) & 0xff); + data.command = (uint8_t) ((beo_code) &0xff); + data.repeats++; + } + if ((n_sym - ic) < 42) { + return data; + } else { + fsm = RxSt::RX_IDLE; + } + break; + } + } + } + } + } + return {}; // decoding failed +} + +void Beo4Protocol::dump(const Beo4Data &data) { + ESP_LOGI(TAG, "Beo4: source=0x%02x command=0x%02x repeats=%d ", data.source, data.command, data.repeats); +} + +} // namespace remote_base +} // namespace esphome diff --git a/esphome/components/remote_base/beo4_protocol.h b/esphome/components/remote_base/beo4_protocol.h new file mode 100644 index 0000000000..445e792cbc --- /dev/null +++ b/esphome/components/remote_base/beo4_protocol.h @@ -0,0 +1,43 @@ +#pragma once + +#include "remote_base.h" + +#include + +namespace esphome { +namespace remote_base { + +struct Beo4Data { + uint8_t source; // beoSource, e.g. video, audio, light... + uint8_t command; // beoCommend, e.g. volume+, mute,... + uint8_t repeats; // beoRepeat for repeat commands, e.g. up, down... + + bool operator==(const Beo4Data &rhs) const { return source == rhs.source && command == rhs.command; } +}; + +class Beo4Protocol : public RemoteProtocol { + public: + void encode(RemoteTransmitData *dst, const Beo4Data &data) override; + optional decode(RemoteReceiveData src) override; + void dump(const Beo4Data &data) override; +}; + +DECLARE_REMOTE_PROTOCOL(Beo4) + +template class Beo4Action : public RemoteTransmitterActionBase { + public: + TEMPLATABLE_VALUE(uint8_t, source) + TEMPLATABLE_VALUE(uint8_t, command) + TEMPLATABLE_VALUE(uint8_t, repeats) + + void encode(RemoteTransmitData *dst, Ts... x) override { + Beo4Data data{}; + data.source = this->source_.value(x...); + data.command = this->command_.value(x...); + data.repeats = this->repeats_.value(x...); + Beo4Protocol().encode(dst, data); + } +}; + +} // namespace remote_base +} // namespace esphome diff --git a/esphome/components/remote_base/gobox_protocol.cpp b/esphome/components/remote_base/gobox_protocol.cpp new file mode 100644 index 0000000000..54e0dff663 --- /dev/null +++ b/esphome/components/remote_base/gobox_protocol.cpp @@ -0,0 +1,131 @@ +#include "gobox_protocol.h" +#include "esphome/core/log.h" + +namespace esphome { +namespace remote_base { + +static const char *const TAG = "remote.gobox"; + +constexpr uint32_t BIT_MARK_US = 580; // 70us seems like a safe time delta for the receiver... +constexpr uint32_t BIT_ONE_SPACE_US = 1640; +constexpr uint32_t BIT_ZERO_SPACE_US = 545; +constexpr uint64_t HEADER = 0b011001001100010uL; // 15 bits +constexpr uint64_t HEADER_SIZE = 15; +constexpr uint64_t CODE_SIZE = 17; + +void GoboxProtocol::dump_timings_(const RawTimings &timings) const { + ESP_LOGD(TAG, "Gobox: size=%u", timings.size()); + for (int32_t timing : timings) { + ESP_LOGD(TAG, "Gobox: timing=%ld", (long) timing); + } +} + +void GoboxProtocol::encode(RemoteTransmitData *dst, const GoboxData &data) { + ESP_LOGI(TAG, "Send Gobox: code=0x%x", data.code); + dst->set_carrier_frequency(38000); + dst->reserve((HEADER_SIZE + CODE_SIZE + 1) * 2); + uint64_t code = (HEADER << CODE_SIZE) | (data.code & ((1UL << CODE_SIZE) - 1)); + ESP_LOGI(TAG, "Send Gobox: code=0x%Lx", code); + for (int16_t i = (HEADER_SIZE + CODE_SIZE - 1); i >= 0; i--) { + if (code & ((uint64_t) 1 << i)) { + dst->item(BIT_MARK_US, BIT_ONE_SPACE_US); + } else { + dst->item(BIT_MARK_US, BIT_ZERO_SPACE_US); + } + } + dst->item(BIT_MARK_US, 2000); + + dump_timings_(dst->get_data()); +} + +optional GoboxProtocol::decode(RemoteReceiveData src) { + if (src.size() < ((HEADER_SIZE + CODE_SIZE) * 2 + 1)) { + return {}; + } + + // First check for the header + uint64_t code = HEADER; + for (int16_t i = HEADER_SIZE - 1; i >= 0; i--) { + if (code & ((uint64_t) 1 << i)) { + if (!src.expect_item(BIT_MARK_US, BIT_ONE_SPACE_US)) { + return {}; + } + } else { + if (!src.expect_item(BIT_MARK_US, BIT_ZERO_SPACE_US)) { + return {}; + } + } + } + + // Next, build up the code + code = 0UL; + for (int16_t i = CODE_SIZE - 1; i >= 0; i--) { + if (!src.expect_mark(BIT_MARK_US)) { + return {}; + } + if (src.expect_space(BIT_ONE_SPACE_US)) { + code |= (1UL << i); + } else if (!src.expect_space(BIT_ZERO_SPACE_US)) { + return {}; + } + } + + if (!src.expect_mark(BIT_MARK_US)) { + return {}; + } + + dump_timings_(src.get_raw_data()); + + GoboxData out; + out.code = code; + + return out; +} + +void GoboxProtocol::dump(const GoboxData &data) { + ESP_LOGI(TAG, "Received Gobox: code=0x%x", data.code); + switch (data.code) { + case GOBOX_MENU: + ESP_LOGI(TAG, "Received Gobox: key=MENU"); + break; + case GOBOX_RETURN: + ESP_LOGI(TAG, "Received Gobox: key=RETURN"); + break; + case GOBOX_UP: + ESP_LOGI(TAG, "Received Gobox: key=UP"); + break; + case GOBOX_LEFT: + ESP_LOGI(TAG, "Received Gobox: key=LEFT"); + break; + case GOBOX_RIGHT: + ESP_LOGI(TAG, "Received Gobox: key=RIGHT"); + break; + case GOBOX_DOWN: + ESP_LOGI(TAG, "Received Gobox: key=DOWN"); + break; + case GOBOX_OK: + ESP_LOGI(TAG, "Received Gobox: key=OK"); + break; + case GOBOX_TOGGLE: + ESP_LOGI(TAG, "Received Gobox: key=TOGGLE"); + break; + case GOBOX_PROFILE: + ESP_LOGI(TAG, "Received Gobox: key=PROFILE"); + break; + case GOBOX_FASTER: + ESP_LOGI(TAG, "Received Gobox: key=FASTER"); + break; + case GOBOX_SLOWER: + ESP_LOGI(TAG, "Received Gobox: key=SLOWER"); + break; + case GOBOX_LOUDER: + ESP_LOGI(TAG, "Received Gobox: key=LOUDER"); + break; + case GOBOX_SOFTER: + ESP_LOGI(TAG, "Received Gobox: key=SOFTER"); + break; + } +} + +} // namespace remote_base +} // namespace esphome diff --git a/esphome/components/remote_base/gobox_protocol.h b/esphome/components/remote_base/gobox_protocol.h new file mode 100644 index 0000000000..7e18b61458 --- /dev/null +++ b/esphome/components/remote_base/gobox_protocol.h @@ -0,0 +1,54 @@ +#pragma once + +#include "esphome/core/component.h" +#include "remote_base.h" + +namespace esphome { +namespace remote_base { + +struct GoboxData { + int code; + bool operator==(const GoboxData &rhs) const { return code == rhs.code; } +}; + +enum { + GOBOX_MENU = 0xaa55, + GOBOX_RETURN = 0x22dd, + GOBOX_UP = 0x0af5, + GOBOX_LEFT = 0x8a75, + GOBOX_RIGHT = 0x48b7, + GOBOX_DOWN = 0xa25d, + GOBOX_OK = 0xc837, + GOBOX_TOGGLE = 0xb847, + GOBOX_PROFILE = 0xfa05, + GOBOX_FASTER = 0xf00f, + GOBOX_SLOWER = 0xd02f, + GOBOX_LOUDER = 0xb04f, + GOBOX_SOFTER = 0xf807, +}; + +class GoboxProtocol : public RemoteProtocol { + private: + void dump_timings_(const RawTimings &timings) const; + + public: + void encode(RemoteTransmitData *dst, const GoboxData &data) override; + optional decode(RemoteReceiveData src) override; + void dump(const GoboxData &data) override; +}; + +DECLARE_REMOTE_PROTOCOL(Gobox) + +template class GoboxAction : public RemoteTransmitterActionBase { + public: + TEMPLATABLE_VALUE(uint64_t, code); + + void encode(RemoteTransmitData *dst, Ts... x) override { + GoboxData data{}; + data.code = this->code_.value(x...); + GoboxProtocol().encode(dst, data); + } +}; + +} // namespace remote_base +} // namespace esphome diff --git a/esphome/components/remote_receiver/remote_receiver_esp32.cpp b/esphome/components/remote_receiver/remote_receiver_esp32.cpp index 2b6032cdf2..a8ee186d70 100644 --- a/esphome/components/remote_receiver/remote_receiver_esp32.cpp +++ b/esphome/components/remote_receiver/remote_receiver_esp32.cpp @@ -208,7 +208,6 @@ void RemoteReceiverComponent::loop() { this->store_.buffer_read = next_read; if (!this->temp_.empty()) { - this->temp_.push_back(-this->idle_us_); this->call_listeners_dumpers_(); } } @@ -219,11 +218,9 @@ void RemoteReceiverComponent::loop() { this->decode_rmt_(item, len / sizeof(rmt_item32_t)); vRingbufferReturnItem(this->ringbuf_, item); - if (this->temp_.empty()) - return; - - this->temp_.push_back(-this->idle_us_); - this->call_listeners_dumpers_(); + if (!this->temp_.empty()) { + this->call_listeners_dumpers_(); + } } #endif } @@ -234,6 +231,7 @@ void RemoteReceiverComponent::decode_rmt_(rmt_symbol_word_t *item, size_t item_c void RemoteReceiverComponent::decode_rmt_(rmt_item32_t *item, size_t item_count) { #endif bool prev_level = false; + bool idle_level = false; uint32_t prev_length = 0; this->temp_.clear(); int32_t multiplier = this->pin_->is_inverted() ? -1 : 1; @@ -266,7 +264,7 @@ void RemoteReceiverComponent::decode_rmt_(rmt_item32_t *item, size_t item_count) } else if ((bool(item[i].level0) == prev_level) || (item[i].duration0 < filter_ticks)) { prev_length += item[i].duration0; } else { - if (prev_length > 0) { + if (prev_length >= filter_ticks) { if (prev_level) { this->temp_.push_back(this->to_microseconds_(prev_length) * multiplier); } else { @@ -276,6 +274,7 @@ void RemoteReceiverComponent::decode_rmt_(rmt_item32_t *item, size_t item_count) prev_level = bool(item[i].level0); prev_length = item[i].duration0; } + idle_level = !bool(item[i].level0); if (item[i].duration1 == 0u) { // EOF, sometimes garbage follows, break early @@ -283,7 +282,7 @@ void RemoteReceiverComponent::decode_rmt_(rmt_item32_t *item, size_t item_count) } else if ((bool(item[i].level1) == prev_level) || (item[i].duration1 < filter_ticks)) { prev_length += item[i].duration1; } else { - if (prev_length > 0) { + if (prev_length >= filter_ticks) { if (prev_level) { this->temp_.push_back(this->to_microseconds_(prev_length) * multiplier); } else { @@ -293,14 +292,22 @@ void RemoteReceiverComponent::decode_rmt_(rmt_item32_t *item, size_t item_count) prev_level = bool(item[i].level1); prev_length = item[i].duration1; } + idle_level = !bool(item[i].level1); } - if (prev_length > 0) { + if (prev_length >= filter_ticks && prev_level != idle_level) { if (prev_level) { this->temp_.push_back(this->to_microseconds_(prev_length) * multiplier); } else { this->temp_.push_back(-int32_t(this->to_microseconds_(prev_length)) * multiplier); } } + if (!this->temp_.empty()) { + if (idle_level) { + this->temp_.push_back(this->idle_us_ * multiplier); + } else { + this->temp_.push_back(-int32_t(this->idle_us_) * multiplier); + } + } } } // namespace remote_receiver diff --git a/esphome/components/resampler/speaker/resampler_speaker.cpp b/esphome/components/resampler/speaker/resampler_speaker.cpp index 9bb46ad78c..5e5615cbb9 100644 --- a/esphome/components/resampler/speaker/resampler_speaker.cpp +++ b/esphome/components/resampler/speaker/resampler_speaker.cpp @@ -43,13 +43,18 @@ void ResamplerSpeaker::setup() { return; } - this->output_speaker_->add_audio_output_callback( - [this](uint32_t new_playback_ms, uint32_t remainder_us, uint32_t pending_ms, uint32_t write_timestamp) { - int32_t adjustment = this->playback_differential_ms_; - this->playback_differential_ms_ -= adjustment; - int32_t adjusted_playback_ms = static_cast(new_playback_ms) + adjustment; - this->audio_output_callback_(adjusted_playback_ms, remainder_us, pending_ms, write_timestamp); - }); + this->output_speaker_->add_audio_output_callback([this](uint32_t new_frames, int64_t write_timestamp) { + if (this->audio_stream_info_.get_sample_rate() != this->target_stream_info_.get_sample_rate()) { + // Convert the number of frames from the target sample rate to the source sample rate. Track the remainder to + // avoid losing frames from integer division truncation. + const uint64_t numerator = new_frames * this->audio_stream_info_.get_sample_rate() + this->callback_remainder_; + const uint64_t denominator = this->target_stream_info_.get_sample_rate(); + this->callback_remainder_ = numerator % denominator; + this->audio_output_callback_(numerator / denominator, write_timestamp); + } else { + this->audio_output_callback_(new_frames, write_timestamp); + } + }); } void ResamplerSpeaker::loop() { @@ -283,7 +288,6 @@ void ResamplerSpeaker::resample_task(void *params) { xEventGroupSetBits(this_resampler->event_group_, ResamplingEventGroupBits::ERR_ESP_NOT_SUPPORTED); } - this_resampler->playback_differential_ms_ = 0; while (err == ESP_OK) { uint32_t event_bits = xEventGroupGetBits(this_resampler->event_group_); @@ -295,8 +299,6 @@ void ResamplerSpeaker::resample_task(void *params) { int32_t ms_differential = 0; audio::AudioResamplerState resampler_state = resampler->resample(false, &ms_differential); - this_resampler->playback_differential_ms_ += ms_differential; - if (resampler_state == audio::AudioResamplerState::FINISHED) { break; } else if (resampler_state == audio::AudioResamplerState::FAILED) { diff --git a/esphome/components/resampler/speaker/resampler_speaker.h b/esphome/components/resampler/speaker/resampler_speaker.h index d5e3f2b6d6..51790069d2 100644 --- a/esphome/components/resampler/speaker/resampler_speaker.h +++ b/esphome/components/resampler/speaker/resampler_speaker.h @@ -100,7 +100,7 @@ class ResamplerSpeaker : public Component, public speaker::Speaker { uint32_t buffer_duration_ms_; - int32_t playback_differential_ms_{0}; + uint64_t callback_remainder_{0}; }; } // namespace resampler diff --git a/esphome/components/scd30/sensor.py b/esphome/components/scd30/sensor.py index 83fb9738ec..fb3ad713bb 100644 --- a/esphome/components/scd30/sensor.py +++ b/esphome/components/scd30/sensor.py @@ -18,6 +18,8 @@ from esphome.const import ( UNIT_CELSIUS, UNIT_PARTS_PER_MILLION, UNIT_PERCENT, + CONF_AUTOMATIC_SELF_CALIBRATION, + CONF_AMBIENT_PRESSURE_COMPENSATION, ) DEPENDENCIES = ["i2c"] @@ -33,10 +35,7 @@ ForceRecalibrationWithReference = scd30_ns.class_( "ForceRecalibrationWithReference", automation.Action ) -CONF_AUTOMATIC_SELF_CALIBRATION = "automatic_self_calibration" CONF_ALTITUDE_COMPENSATION = "altitude_compensation" -CONF_AMBIENT_PRESSURE_COMPENSATION = "ambient_pressure_compensation" - CONFIG_SCHEMA = ( cv.Schema( diff --git a/esphome/components/scd4x/sensor.py b/esphome/components/scd4x/sensor.py index f050c3ec34..f753f54c3b 100644 --- a/esphome/components/scd4x/sensor.py +++ b/esphome/components/scd4x/sensor.py @@ -20,6 +20,10 @@ from esphome.const import ( UNIT_CELSIUS, UNIT_PARTS_PER_MILLION, UNIT_PERCENT, + CONF_AUTOMATIC_SELF_CALIBRATION, + CONF_AMBIENT_PRESSURE_COMPENSATION, + CONF_AMBIENT_PRESSURE_COMPENSATION_SOURCE, + CONF_MEASUREMENT_MODE, ) CODEOWNERS = ["@sjtrny", "@martgras"] @@ -47,11 +51,6 @@ FactoryResetAction = scd4x_ns.class_("FactoryResetAction", automation.Action) CONF_ALTITUDE_COMPENSATION = "altitude_compensation" -CONF_AMBIENT_PRESSURE_COMPENSATION = "ambient_pressure_compensation" -CONF_AMBIENT_PRESSURE_COMPENSATION_SOURCE = "ambient_pressure_compensation_source" -CONF_AUTOMATIC_SELF_CALIBRATION = "automatic_self_calibration" -CONF_MEASUREMENT_MODE = "measurement_mode" - CONFIG_SCHEMA = ( cv.Schema( diff --git a/esphome/components/sdp3x/sensor.py b/esphome/components/sdp3x/sensor.py index 67f3f9561f..7cda2779ce 100644 --- a/esphome/components/sdp3x/sensor.py +++ b/esphome/components/sdp3x/sensor.py @@ -5,6 +5,7 @@ from esphome.const import ( DEVICE_CLASS_PRESSURE, STATE_CLASS_MEASUREMENT, UNIT_HECTOPASCAL, + CONF_MEASUREMENT_MODE, ) DEPENDENCIES = ["i2c"] @@ -22,7 +23,7 @@ MEASUREMENT_MODE = { "mass_flow": MeasurementMode.MASS_FLOW_AVG, "differential_pressure": MeasurementMode.DP_AVG, } -CONF_MEASUREMENT_MODE = "measurement_mode" + CONFIG_SCHEMA = ( sensor.sensor_schema( diff --git a/esphome/components/select/__init__.py b/esphome/components/select/__init__.py index 5a3271fdfd..ecbba8677b 100644 --- a/esphome/components/select/__init__.py +++ b/esphome/components/select/__init__.py @@ -48,7 +48,7 @@ SELECT_OPERATION_OPTIONS = { } -SELECT_SCHEMA = ( +_SELECT_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( @@ -64,29 +64,28 @@ SELECT_SCHEMA = ( ) ) -_UNDEF = object() - def select_schema( - class_: MockObjClass = _UNDEF, + class_: MockObjClass, *, - entity_category: str = _UNDEF, - icon: str = _UNDEF, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, ): - schema = cv.Schema({}) - if class_ is not _UNDEF: - schema = schema.extend({cv.GenerateID(): cv.declare_id(class_)}) - if entity_category is not _UNDEF: - schema = schema.extend( - { - cv.Optional( - CONF_ENTITY_CATEGORY, default=entity_category - ): cv.entity_category - } - ) - if icon is not _UNDEF: - schema = schema.extend({cv.Optional(CONF_ICON, default=icon): cv.icon}) - return SELECT_SCHEMA.extend(schema) + schema = {cv.GenerateID(): cv.declare_id(class_)} + + for key, default, validator in [ + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_ICON, icon, cv.icon), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _SELECT_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +SELECT_SCHEMA = select_schema(Select) +SELECT_SCHEMA.add_extra(cv.deprecated_schema_constant("select")) async def setup_select_core_(var, config, *, options: list[str]): diff --git a/esphome/components/sen5x/sensor.py b/esphome/components/sen5x/sensor.py index a8a796853e..f52de5fe85 100644 --- a/esphome/components/sen5x/sensor.py +++ b/esphome/components/sen5x/sensor.py @@ -4,6 +4,7 @@ import esphome.codegen as cg from esphome.components import i2c, sensirion_common, sensor import esphome.config_validation as cv from esphome.const import ( + CONF_GAIN_FACTOR, CONF_HUMIDITY, CONF_ID, CONF_OFFSET, @@ -43,7 +44,6 @@ RhtAccelerationMode = sen5x_ns.enum("RhtAccelerationMode") CONF_ACCELERATION_MODE = "acceleration_mode" CONF_ALGORITHM_TUNING = "algorithm_tuning" CONF_AUTO_CLEANING_INTERVAL = "auto_cleaning_interval" -CONF_GAIN_FACTOR = "gain_factor" CONF_GATING_MAX_DURATION_MINUTES = "gating_max_duration_minutes" CONF_INDEX_OFFSET = "index_offset" CONF_LEARNING_TIME_GAIN_HOURS = "learning_time_gain_hours" diff --git a/esphome/components/sensor/__init__.py b/esphome/components/sensor/__init__.py index 9dbad27102..051098f6e4 100644 --- a/esphome/components/sensor/__init__.py +++ b/esphome/components/sensor/__init__.py @@ -1,3 +1,4 @@ +import logging import math from esphome import automation @@ -9,6 +10,7 @@ from esphome.const import ( CONF_ACCURACY_DECIMALS, CONF_ALPHA, CONF_BELOW, + CONF_CALIBRATION, CONF_DEVICE_CLASS, CONF_ENTITY_CATEGORY, CONF_EXPIRE_AFTER, @@ -30,6 +32,7 @@ from esphome.const import ( CONF_SEND_EVERY, CONF_SEND_FIRST_AT, CONF_STATE_CLASS, + CONF_TEMPERATURE, CONF_TIMEOUT, CONF_TO, CONF_TRIGGER_ID, @@ -153,6 +156,8 @@ DEVICE_CLASSES = [ DEVICE_CLASS_WIND_SPEED, ] +_LOGGER = logging.getLogger(__name__) + sensor_ns = cg.esphome_ns.namespace("sensor") StateClasses = sensor_ns.enum("StateClass") STATE_CLASSES = { @@ -246,6 +251,8 @@ HeartbeatFilter = sensor_ns.class_("HeartbeatFilter", Filter, cg.Component) DeltaFilter = sensor_ns.class_("DeltaFilter", Filter) OrFilter = sensor_ns.class_("OrFilter", Filter) CalibrateLinearFilter = sensor_ns.class_("CalibrateLinearFilter", Filter) +ToNTCResistanceFilter = sensor_ns.class_("ToNTCResistanceFilter", Filter) +ToNTCTemperatureFilter = sensor_ns.class_("ToNTCTemperatureFilter", Filter) CalibratePolynomialFilter = sensor_ns.class_("CalibratePolynomialFilter", Filter) SensorInRangeCondition = sensor_ns.class_("SensorInRangeCondition", Filter) ClampFilter = sensor_ns.class_("ClampFilter", Filter) @@ -257,7 +264,7 @@ validate_accuracy_decimals = cv.int_ validate_icon = cv.icon validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True, space="_") -SENSOR_SCHEMA = ( +_SENSOR_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMPONENT_SCHEMA) .extend( @@ -302,22 +309,20 @@ SENSOR_SCHEMA = ( ) ) -_UNDEF = object() - def sensor_schema( - class_: MockObjClass = _UNDEF, + class_: MockObjClass = cv.UNDEFINED, *, - unit_of_measurement: str = _UNDEF, - icon: str = _UNDEF, - accuracy_decimals: int = _UNDEF, - device_class: str = _UNDEF, - state_class: str = _UNDEF, - entity_category: str = _UNDEF, + unit_of_measurement: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, + accuracy_decimals: int = cv.UNDEFINED, + device_class: str = cv.UNDEFINED, + state_class: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, ) -> cv.Schema: schema = {} - if class_ is not _UNDEF: + if class_ is not cv.UNDEFINED: # Not optional. schema[cv.GenerateID()] = cv.declare_id(class_) @@ -329,10 +334,15 @@ def sensor_schema( (CONF_STATE_CLASS, state_class, validate_state_class), (CONF_ENTITY_CATEGORY, entity_category, sensor_entity_category), ]: - if default is not _UNDEF: + if default is not cv.UNDEFINED: schema[cv.Optional(key, default=default)] = validator - return SENSOR_SCHEMA.extend(schema) + return _SENSOR_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +SENSOR_SCHEMA = sensor_schema() +SENSOR_SCHEMA.add_extra(cv.deprecated_schema_constant("sensor")) @FILTER_REGISTRY.register("offset", OffsetFilter, cv.templatable(cv.float_)) @@ -804,7 +814,9 @@ async def setup_sensor_core_(var, config): mqtt_ = cg.new_Pvariable(mqtt_id, var) await mqtt.register_mqtt_component(mqtt_, config) - if (expire_after := config.get(CONF_EXPIRE_AFTER, _UNDEF)) is not _UNDEF: + if ( + expire_after := config.get(CONF_EXPIRE_AFTER, cv.UNDEFINED) + ) is not cv.UNDEFINED: if expire_after is None: cg.add(mqtt_.disable_expire_after()) else: @@ -852,6 +864,138 @@ async def sensor_in_range_to_code(config, condition_id, template_arg, args): return var +def validate_ntc_calibration_parameter(value): + if isinstance(value, dict): + return cv.Schema( + { + cv.Required(CONF_TEMPERATURE): cv.temperature, + cv.Required(CONF_VALUE): cv.resistance, + } + )(value) + + value = cv.string(value) + parts = value.split("->") + if len(parts) != 2: + raise cv.Invalid("Calibration parameter must be of form 3000 -> 23°C") + resistance = cv.resistance(parts[0].strip()) + temperature = cv.temperature(parts[1].strip()) + return validate_ntc_calibration_parameter( + { + CONF_TEMPERATURE: temperature, + CONF_VALUE: resistance, + } + ) + + +CONF_A = "a" +CONF_B = "b" +CONF_C = "c" +ZERO_POINT = 273.15 + + +def ntc_calc_steinhart_hart(value): + r1 = value[0][CONF_VALUE] + r2 = value[1][CONF_VALUE] + r3 = value[2][CONF_VALUE] + t1 = value[0][CONF_TEMPERATURE] + ZERO_POINT + t2 = value[1][CONF_TEMPERATURE] + ZERO_POINT + t3 = value[2][CONF_TEMPERATURE] + ZERO_POINT + + l1 = math.log(r1) + l2 = math.log(r2) + l3 = math.log(r3) + + y1 = 1 / t1 + y2 = 1 / t2 + y3 = 1 / t3 + + g2 = (y2 - y1) / (l2 - l1) + g3 = (y3 - y1) / (l3 - l1) + + c = (g3 - g2) / (l3 - l2) * 1 / (l1 + l2 + l3) + b = g2 - c * (l1 * l1 + l1 * l2 + l2 * l2) + a = y1 - (b + l1 * l1 * c) * l1 + return a, b, c + + +def ntc_get_abc(value): + a = value[CONF_A] + b = value[CONF_B] + c = value[CONF_C] + return a, b, c + + +def ntc_process_calibration(value): + if isinstance(value, dict): + value = cv.Schema( + { + cv.Required(CONF_A): cv.float_, + cv.Required(CONF_B): cv.float_, + cv.Required(CONF_C): cv.float_, + } + )(value) + a, b, c = ntc_get_abc(value) + elif isinstance(value, list): + if len(value) != 3: + raise cv.Invalid( + "Steinhart–Hart Calibration must consist of exactly three values" + ) + value = cv.Schema([validate_ntc_calibration_parameter])(value) + a, b, c = ntc_calc_steinhart_hart(value) + else: + raise cv.Invalid( + f"Calibration parameter accepts either a list for steinhart-hart calibration, or mapping for b-constant calibration, not {type(value)}" + ) + _LOGGER.info("Coefficient: a:%s, b:%s, c:%s", a, b, c) + return { + CONF_A: a, + CONF_B: b, + CONF_C: c, + } + + +@FILTER_REGISTRY.register( + "to_ntc_resistance", + ToNTCResistanceFilter, + cv.All( + cv.Schema( + { + cv.Required(CONF_CALIBRATION): ntc_process_calibration, + } + ), + ), +) +async def calibrate_ntc_resistance_filter_to_code(config, filter_id): + calib = config[CONF_CALIBRATION] + return cg.new_Pvariable( + filter_id, + calib[CONF_A], + calib[CONF_B], + calib[CONF_C], + ) + + +@FILTER_REGISTRY.register( + "to_ntc_temperature", + ToNTCTemperatureFilter, + cv.All( + cv.Schema( + { + cv.Required(CONF_CALIBRATION): ntc_process_calibration, + } + ), + ), +) +async def calibrate_ntc_temperature_filter_to_code(config, filter_id): + calib = config[CONF_CALIBRATION] + return cg.new_Pvariable( + filter_id, + calib[CONF_A], + calib[CONF_B], + calib[CONF_C], + ) + + def _mean(xs): return sum(xs) / len(xs) diff --git a/esphome/components/sensor/filter.cpp b/esphome/components/sensor/filter.cpp index 0a8740dd5b..ce23c1f800 100644 --- a/esphome/components/sensor/filter.cpp +++ b/esphome/components/sensor/filter.cpp @@ -481,5 +481,28 @@ optional RoundMultipleFilter::new_value(float value) { return value; } +optional ToNTCResistanceFilter::new_value(float value) { + if (!std::isfinite(value)) { + return NAN; + } + double k = 273.15; + // https://de.wikipedia.org/wiki/Steinhart-Hart-Gleichung#cite_note-stein2_s4-3 + double t = value + k; + double y = (this->a_ - 1 / (t)) / (2 * this->c_); + double x = sqrt(pow(this->b_ / (3 * this->c_), 3) + y * y); + double resistance = exp(pow(x - y, 1 / 3.0) - pow(x + y, 1 / 3.0)); + return resistance; +} + +optional ToNTCTemperatureFilter::new_value(float value) { + if (!std::isfinite(value)) { + return NAN; + } + double lr = log(double(value)); + double v = this->a_ + this->b_ * lr + this->c_ * lr * lr * lr; + double temp = float(1.0 / v - 273.15); + return temp; +} + } // namespace sensor } // namespace esphome diff --git a/esphome/components/sensor/filter.h b/esphome/components/sensor/filter.h index 86586b458d..3cfaebb708 100644 --- a/esphome/components/sensor/filter.h +++ b/esphome/components/sensor/filter.h @@ -439,5 +439,27 @@ class RoundMultipleFilter : public Filter { float multiple_; }; +class ToNTCResistanceFilter : public Filter { + public: + ToNTCResistanceFilter(double a, double b, double c) : a_(a), b_(b), c_(c) {} + optional new_value(float value) override; + + protected: + double a_; + double b_; + double c_; +}; + +class ToNTCTemperatureFilter : public Filter { + public: + ToNTCTemperatureFilter(double a, double b, double c) : a_(a), b_(b), c_(c) {} + optional new_value(float value) override; + + protected: + double a_; + double b_; + double c_; +}; + } // namespace sensor } // namespace esphome diff --git a/esphome/components/sgp4x/sensor.py b/esphome/components/sgp4x/sensor.py index 9317187df3..4f29248881 100644 --- a/esphome/components/sgp4x/sensor.py +++ b/esphome/components/sgp4x/sensor.py @@ -3,6 +3,7 @@ from esphome.components import i2c, sensirion_common, sensor import esphome.config_validation as cv from esphome.const import ( CONF_COMPENSATION, + CONF_GAIN_FACTOR, CONF_ID, CONF_STORE_BASELINE, CONF_TEMPERATURE_SOURCE, @@ -24,7 +25,6 @@ SGP4xComponent = sgp4x_ns.class_( ) CONF_ALGORITHM_TUNING = "algorithm_tuning" -CONF_GAIN_FACTOR = "gain_factor" CONF_GATING_MAX_DURATION_MINUTES = "gating_max_duration_minutes" CONF_HUMIDITY_SOURCE = "humidity_source" CONF_INDEX_OFFSET = "index_offset" diff --git a/esphome/components/sht4x/sht4x.cpp b/esphome/components/sht4x/sht4x.cpp index dea542ea9e..e4fa16d87a 100644 --- a/esphome/components/sht4x/sht4x.cpp +++ b/esphome/components/sht4x/sht4x.cpp @@ -12,14 +12,22 @@ void SHT4XComponent::start_heater_() { uint8_t cmd[] = {MEASURECOMMANDS[this->heater_command_]}; ESP_LOGD(TAG, "Heater turning on"); - this->write(cmd, 1); + if (this->write(cmd, 1) != i2c::ERROR_OK) { + this->status_set_error("Failed to turn on heater"); + } } void SHT4XComponent::setup() { ESP_LOGCONFIG(TAG, "Setting up sht4x..."); - if (this->duty_cycle_ > 0.0) { - uint32_t heater_interval = (uint32_t) (this->heater_time_ / this->duty_cycle_); + auto err = this->write(nullptr, 0); + if (err != i2c::ERROR_OK) { + this->mark_failed(); + return; + } + + if (std::isfinite(this->duty_cycle_) && this->duty_cycle_ > 0.0f) { + uint32_t heater_interval = static_cast(static_cast(this->heater_time_) / this->duty_cycle_); ESP_LOGD(TAG, "Heater interval: %" PRIu32, heater_interval); if (this->heater_power_ == SHT4X_HEATERPOWER_HIGH) { @@ -47,37 +55,50 @@ void SHT4XComponent::setup() { } } -void SHT4XComponent::dump_config() { LOG_I2C_DEVICE(this); } +void SHT4XComponent::dump_config() { + ESP_LOGCONFIG(TAG, "SHT4x:"); + LOG_I2C_DEVICE(this); + if (this->is_failed()) { + ESP_LOGE(TAG, "Communication with SHT4x failed!"); + } +} void SHT4XComponent::update() { // Send command - this->write_command(MEASURECOMMANDS[this->precision_]); + if (!this->write_command(MEASURECOMMANDS[this->precision_])) { + // Warning will be printed only if warning status is not set yet + this->status_set_warning("Failed to send measurement command"); + return; + } this->set_timeout(10, [this]() { uint16_t buffer[2]; // Read measurement - bool read_status = this->read_data(buffer, 2); + if (!this->read_data(buffer, 2)) { + // Using ESP_LOGW to force the warning to be printed + ESP_LOGW(TAG, "Sensor read failed"); + this->status_set_warning(); + return; + } - if (read_status) { - // Evaluate and publish measurements - if (this->temp_sensor_ != nullptr) { - // Temp is contained in the first result word - float sensor_value_temp = buffer[0]; - float temp = -45 + 175 * sensor_value_temp / 65535; + this->status_clear_warning(); - this->temp_sensor_->publish_state(temp); - } + // Evaluate and publish measurements + if (this->temp_sensor_ != nullptr) { + // Temp is contained in the first result word + float sensor_value_temp = buffer[0]; + float temp = -45 + 175 * sensor_value_temp / 65535; - if (this->humidity_sensor_ != nullptr) { - // Relative humidity is in the second result word - float sensor_value_rh = buffer[1]; - float rh = -6 + 125 * sensor_value_rh / 65535; + this->temp_sensor_->publish_state(temp); + } - this->humidity_sensor_->publish_state(rh); - } - } else { - ESP_LOGD(TAG, "Sensor read failed"); + if (this->humidity_sensor_ != nullptr) { + // Relative humidity is in the second result word + float sensor_value_rh = buffer[1]; + float rh = -6 + 125 * sensor_value_rh / 65535; + + this->humidity_sensor_->publish_state(rh); } }); } diff --git a/esphome/components/sht4x/sht4x.h b/esphome/components/sht4x/sht4x.h index 46037bb0e8..98e0629b50 100644 --- a/esphome/components/sht4x/sht4x.h +++ b/esphome/components/sht4x/sht4x.h @@ -13,7 +13,7 @@ enum SHT4XPRECISION { SHT4X_PRECISION_HIGH = 0, SHT4X_PRECISION_MED, SHT4X_PRECI enum SHT4XHEATERPOWER { SHT4X_HEATERPOWER_HIGH, SHT4X_HEATERPOWER_MED, SHT4X_HEATERPOWER_LOW }; -enum SHT4XHEATERTIME { SHT4X_HEATERTIME_LONG = 1100, SHT4X_HEATERTIME_SHORT = 110 }; +enum SHT4XHEATERTIME : uint16_t { SHT4X_HEATERTIME_LONG = 1100, SHT4X_HEATERTIME_SHORT = 110 }; class SHT4XComponent : public PollingComponent, public sensirion_common::SensirionI2CDevice { public: diff --git a/esphome/components/sml/text_sensor/__init__.py b/esphome/components/sml/text_sensor/__init__.py index 401db9c582..9c9da26c3a 100644 --- a/esphome/components/sml/text_sensor/__init__.py +++ b/esphome/components/sml/text_sensor/__init__.py @@ -1,7 +1,7 @@ import esphome.codegen as cg from esphome.components import text_sensor import esphome.config_validation as cv -from esphome.const import CONF_FORMAT, CONF_ID +from esphome.const import CONF_FORMAT from .. import CONF_OBIS_CODE, CONF_SERVER_ID, CONF_SML_ID, Sml, obis_code, sml_ns @@ -19,9 +19,8 @@ SML_TYPES = { SmlTextSensor = sml_ns.class_("SmlTextSensor", text_sensor.TextSensor, cg.Component) -CONFIG_SCHEMA = text_sensor.TEXT_SENSOR_SCHEMA.extend( +CONFIG_SCHEMA = text_sensor.text_sensor_schema(SmlTextSensor).extend( { - cv.GenerateID(): cv.declare_id(SmlTextSensor), cv.GenerateID(CONF_SML_ID): cv.use_id(Sml), cv.Required(CONF_OBIS_CODE): obis_code, cv.Optional(CONF_SERVER_ID, default=""): cv.string, @@ -31,13 +30,12 @@ CONFIG_SCHEMA = text_sensor.TEXT_SENSOR_SCHEMA.extend( async def to_code(config): - var = cg.new_Pvariable( - config[CONF_ID], + var = await text_sensor.new_text_sensor( + config, config[CONF_SERVER_ID], config[CONF_OBIS_CODE], config[CONF_FORMAT], ) await cg.register_component(var, config) - await text_sensor.register_text_sensor(var, config) sml = await cg.get_variable(config[CONF_SML_ID]) cg.add(sml.register_sml_listener(var)) diff --git a/esphome/components/sound_level/__init__.py b/esphome/components/sound_level/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/esphome/components/sound_level/sensor.py b/esphome/components/sound_level/sensor.py new file mode 100644 index 0000000000..292efadab8 --- /dev/null +++ b/esphome/components/sound_level/sensor.py @@ -0,0 +1,97 @@ +from esphome import automation +import esphome.codegen as cg +from esphome.components import microphone, sensor +import esphome.config_validation as cv +from esphome.const import ( + CONF_ID, + CONF_MEASUREMENT_DURATION, + CONF_MICROPHONE, + DEVICE_CLASS_SOUND_PRESSURE, + PLATFORM_ESP32, + STATE_CLASS_MEASUREMENT, + UNIT_DECIBEL, +) + +AUTOLOAD = ["audio"] +CODEOWNERS = ["@kahrendt"] +DEPENDENCIES = ["microphone"] + + +CONF_PASSIVE = "passive" +CONF_PEAK = "peak" +CONF_RMS = "rms" + +sound_level_ns = cg.esphome_ns.namespace("sound_level") +SoundLevelComponent = sound_level_ns.class_("SoundLevelComponent", cg.Component) + +StartAction = sound_level_ns.class_("StartAction", automation.Action) +StopAction = sound_level_ns.class_("StopAction", automation.Action) + +CONFIG_SCHEMA = cv.All( + cv.Schema( + { + cv.GenerateID(): cv.declare_id(SoundLevelComponent), + cv.Optional(CONF_MEASUREMENT_DURATION, default="1000ms"): cv.All( + cv.positive_time_period_milliseconds, + cv.Range( + min=cv.TimePeriod(milliseconds=50), + max=cv.TimePeriod(seconds=60), + ), + ), + cv.Optional( + CONF_MICROPHONE, default={} + ): microphone.microphone_source_schema( + min_bits_per_sample=16, + max_bits_per_sample=16, + ), + cv.Required(CONF_PASSIVE): cv.boolean, + cv.Optional(CONF_PEAK): sensor.sensor_schema( + unit_of_measurement=UNIT_DECIBEL, + accuracy_decimals=1, + device_class=DEVICE_CLASS_SOUND_PRESSURE, + state_class=STATE_CLASS_MEASUREMENT, + ), + cv.Optional(CONF_RMS): sensor.sensor_schema( + unit_of_measurement=UNIT_DECIBEL, + accuracy_decimals=1, + device_class=DEVICE_CLASS_SOUND_PRESSURE, + state_class=STATE_CLASS_MEASUREMENT, + ), + } + ).extend(cv.COMPONENT_SCHEMA), + cv.only_on([PLATFORM_ESP32]), +) + + +async def to_code(config): + var = cg.new_Pvariable(config[CONF_ID]) + await cg.register_component(var, config) + + mic_source = await microphone.microphone_source_to_code( + config[CONF_MICROPHONE], passive=config[CONF_PASSIVE] + ) + cg.add(var.set_microphone_source(mic_source)) + + cg.add(var.set_measurement_duration(config[CONF_MEASUREMENT_DURATION])) + + if peak_config := config.get(CONF_PEAK): + sens = await sensor.new_sensor(peak_config) + cg.add(var.set_peak_sensor(sens)) + if rms_config := config.get(CONF_RMS): + sens = await sensor.new_sensor(rms_config) + cg.add(var.set_rms_sensor(sens)) + + +SOUND_LEVEL_ACTION_SCHEMA = automation.maybe_simple_id( + { + cv.GenerateID(): cv.use_id(SoundLevelComponent), + } +) + + +@automation.register_action("sound_level.start", StartAction, SOUND_LEVEL_ACTION_SCHEMA) +@automation.register_action("sound_level.stop", StopAction, SOUND_LEVEL_ACTION_SCHEMA) +async def sound_level_action_to_code(config, action_id, template_arg, args): + var = cg.new_Pvariable(action_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) + return var diff --git a/esphome/components/sound_level/sound_level.cpp b/esphome/components/sound_level/sound_level.cpp new file mode 100644 index 0000000000..f8447ce436 --- /dev/null +++ b/esphome/components/sound_level/sound_level.cpp @@ -0,0 +1,194 @@ +#include "sound_level.h" + +#ifdef USE_ESP32 + +#include "esphome/core/log.h" + +#include +#include + +namespace esphome { +namespace sound_level { + +static const char *const TAG = "sound_level"; + +static const uint32_t AUDIO_BUFFER_DURATION_MS = 30; +static const uint32_t RING_BUFFER_DURATION_MS = 120; + +// Square INT16_MIN since INT16_MIN^2 > INT16_MAX^2 +static const double MAX_SAMPLE_SQUARED_DENOMINATOR = INT16_MIN * INT16_MIN; + +void SoundLevelComponent::dump_config() { + ESP_LOGCONFIG(TAG, "Sound Level Component:"); + ESP_LOGCONFIG(TAG, " Measurement Duration: %" PRIu32 " ms", measurement_duration_ms_); + LOG_SENSOR(" ", "Peak:", this->peak_sensor_); + + LOG_SENSOR(" ", "RMS:", this->rms_sensor_); +} + +void SoundLevelComponent::setup() { + this->microphone_source_->add_data_callback([this](const std::vector &data) { + std::shared_ptr temp_ring_buffer = this->ring_buffer_.lock(); + if (this->ring_buffer_.use_count() == 2) { + // ``audio_buffer_`` and ``temp_ring_buffer`` share ownership of a ring buffer, so its safe/useful to write + temp_ring_buffer->write((void *) data.data(), data.size()); + } + }); + + if (!this->microphone_source_->is_passive()) { + // Automatically start the microphone if not in passive mode + this->microphone_source_->start(); + } +} + +void SoundLevelComponent::loop() { + if ((this->peak_sensor_ == nullptr) && (this->rms_sensor_ == nullptr)) { + // No sensors configured, nothing to do + return; + } + + if (this->microphone_source_->is_running() && !this->status_has_error()) { + // Allocate buffers + if (this->start_()) { + this->status_clear_warning(); + } + } else { + if (!this->status_has_warning()) { + this->status_set_warning("Microphone isn't running, can't compute statistics"); + + // Deallocate buffers, if necessary + this->stop_(); + + // Reset sensor outputs + if (this->peak_sensor_ != nullptr) { + this->peak_sensor_->publish_state(NAN); + } + if (this->rms_sensor_ != nullptr) { + this->rms_sensor_->publish_state(NAN); + } + + // Reset accumulators + this->squared_peak_ = 0; + this->squared_samples_sum_ = 0; + this->sample_count_ = 0; + } + + return; + } + + if (this->status_has_error()) { + return; + } + + // Copy data from ring buffer into the transfer buffer - don't block to avoid slowing the main loop + this->audio_buffer_->transfer_data_from_source(0); + + if (this->audio_buffer_->available() == 0) { + // No new audio available for processing + return; + } + + const uint32_t samples_in_window = + this->microphone_source_->get_audio_stream_info().ms_to_samples(this->measurement_duration_ms_); + const uint32_t samples_available_to_process = + this->microphone_source_->get_audio_stream_info().bytes_to_samples(this->audio_buffer_->available()); + const uint32_t samples_to_process = std::min(samples_in_window - this->sample_count_, samples_available_to_process); + + // MicrophoneSource always provides int16 samples due to Python codegen settings + const int16_t *audio_data = reinterpret_cast(this->audio_buffer_->get_buffer_start()); + + // Process all the new audio samples + for (uint32_t i = 0; i < samples_to_process; ++i) { + // Squaring int16 samples won't overflow an int32 + int32_t squared_sample = static_cast(audio_data[i]) * static_cast(audio_data[i]); + + if (this->peak_sensor_ != nullptr) { + this->squared_peak_ = std::max(this->squared_peak_, squared_sample); + } + + if (this->rms_sensor_ != nullptr) { + // Squared sum is an uint64 type - at max levels, an uint32 type would overflow after ~8 samples + this->squared_samples_sum_ += squared_sample; + } + + ++this->sample_count_; + } + + // Remove the processed samples from ``audio_buffer_`` + this->audio_buffer_->decrease_buffer_length( + this->microphone_source_->get_audio_stream_info().samples_to_bytes(samples_to_process)); + + if (this->sample_count_ == samples_in_window) { + // Processed enough samples for the measurement window, compute and publish the sensor values + if (this->peak_sensor_ != nullptr) { + const float peak_db = 10.0f * log10(static_cast(this->squared_peak_) / MAX_SAMPLE_SQUARED_DENOMINATOR); + this->peak_sensor_->publish_state(peak_db); + + this->squared_peak_ = 0; // reset accumulator + } + + if (this->rms_sensor_ != nullptr) { + // Calculations are done with doubles instead of floats - floats lose precision for even modest window durations + const double rms_db = 10.0 * log10((this->squared_samples_sum_ / MAX_SAMPLE_SQUARED_DENOMINATOR) / + static_cast(samples_in_window)); + this->rms_sensor_->publish_state(rms_db); + + this->squared_samples_sum_ = 0; // reset accumulator + } + + this->sample_count_ = 0; // reset counter + } +} + +void SoundLevelComponent::start() { + if (this->microphone_source_->is_passive()) { + ESP_LOGW(TAG, "Can't start the microphone in passive mode"); + return; + } + this->microphone_source_->start(); +} + +void SoundLevelComponent::stop() { + if (this->microphone_source_->is_passive()) { + ESP_LOGW(TAG, "Can't stop microphone in passive mode"); + return; + } + this->microphone_source_->stop(); +} + +bool SoundLevelComponent::start_() { + if (this->audio_buffer_ != nullptr) { + return true; + } + + // Allocate a transfer buffer + this->audio_buffer_ = audio::AudioSourceTransferBuffer::create( + this->microphone_source_->get_audio_stream_info().ms_to_bytes(AUDIO_BUFFER_DURATION_MS)); + if (this->audio_buffer_ == nullptr) { + this->status_momentary_error("Failed to allocate transfer buffer", 15000); + return false; + } + + // Allocates a new ring buffer, adds it as a source for the transfer buffer, and points ring_buffer_ to it + this->ring_buffer_.reset(); // Reset pointer to any previous ring buffer allocation + std::shared_ptr temp_ring_buffer = + RingBuffer::create(this->microphone_source_->get_audio_stream_info().ms_to_bytes(RING_BUFFER_DURATION_MS)); + if (temp_ring_buffer.use_count() == 0) { + this->status_momentary_error("Failed to allocate ring buffer", 15000); + this->stop_(); + return false; + } else { + this->ring_buffer_ = temp_ring_buffer; + this->audio_buffer_->set_source(temp_ring_buffer); + } + + this->status_clear_error(); + return true; +} + +void SoundLevelComponent::stop_() { this->audio_buffer_.reset(); } + +} // namespace sound_level +} // namespace esphome + +#endif diff --git a/esphome/components/sound_level/sound_level.h b/esphome/components/sound_level/sound_level.h new file mode 100644 index 0000000000..6a80a60ac7 --- /dev/null +++ b/esphome/components/sound_level/sound_level.h @@ -0,0 +1,73 @@ +#pragma once + +#ifdef USE_ESP32 + +#include "esphome/components/audio/audio_transfer_buffer.h" +#include "esphome/components/microphone/microphone_source.h" +#include "esphome/components/sensor/sensor.h" + +#include "esphome/core/component.h" +#include "esphome/core/ring_buffer.h" + +namespace esphome { +namespace sound_level { + +class SoundLevelComponent : public Component { + public: + void dump_config() override; + void setup() override; + void loop() override; + + float get_setup_priority() const override { return setup_priority::AFTER_CONNECTION; } + + void set_measurement_duration(uint32_t measurement_duration_ms) { + this->measurement_duration_ms_ = measurement_duration_ms; + } + void set_microphone_source(microphone::MicrophoneSource *microphone_source) { + this->microphone_source_ = microphone_source; + } + void set_peak_sensor(sensor::Sensor *peak_sensor) { this->peak_sensor_ = peak_sensor; } + void set_rms_sensor(sensor::Sensor *rms_sensor) { this->rms_sensor_ = rms_sensor; } + + /// @brief Starts the MicrophoneSource to start measuring sound levels + void start(); + + /// @brief Stops the MicrophoneSource + void stop(); + + protected: + /// @brief Internal start command that, if necessary, allocates ``audio_buffer_`` and a ring buffer which + /// ``audio_buffer_`` owns and ``ring_buffer_`` points to. Returns true if allocations were successful. + bool start_(); + + /// @brief Internal stop command the deallocates ``audio_buffer_`` (which automatically deallocates its ring buffer) + void stop_(); + + microphone::MicrophoneSource *microphone_source_{nullptr}; + + sensor::Sensor *peak_sensor_{nullptr}; + sensor::Sensor *rms_sensor_{nullptr}; + + std::unique_ptr audio_buffer_; + std::weak_ptr ring_buffer_; + + int32_t squared_peak_{0}; + uint64_t squared_samples_sum_{0}; + uint32_t sample_count_{0}; + + uint32_t measurement_duration_ms_; +}; + +template class StartAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->start(); } +}; + +template class StopAction : public Action, public Parented { + public: + void play(Ts... x) override { this->parent_->stop(); } +}; + +} // namespace sound_level +} // namespace esphome +#endif diff --git a/esphome/components/speaker/media_player/__init__.py b/esphome/components/speaker/media_player/__init__.py index 14b72cacc0..35d763b1f8 100644 --- a/esphome/components/speaker/media_player/__init__.py +++ b/esphome/components/speaker/media_player/__init__.py @@ -332,14 +332,12 @@ async def to_code(config): esp32.add_idf_sdkconfig_option("CONFIG_TCP_MSS", 1436) esp32.add_idf_sdkconfig_option("CONFIG_TCP_MSL", 60000) esp32.add_idf_sdkconfig_option("CONFIG_TCP_SND_BUF_DEFAULT", 65535) - esp32.add_idf_sdkconfig_option( - "CONFIG_TCP_WND_DEFAULT", 65535 - ) # Adjusted from referenced settings to avoid compilation error + esp32.add_idf_sdkconfig_option("CONFIG_TCP_WND_DEFAULT", 512000) esp32.add_idf_sdkconfig_option("CONFIG_TCP_RECVMBOX_SIZE", 512) esp32.add_idf_sdkconfig_option("CONFIG_TCP_QUEUE_OOSEQ", True) esp32.add_idf_sdkconfig_option("CONFIG_TCP_OVERSIZE_MSS", True) esp32.add_idf_sdkconfig_option("CONFIG_LWIP_WND_SCALE", True) - esp32.add_idf_sdkconfig_option("CONFIG_TCP_RCV_SCALE", 3) + esp32.add_idf_sdkconfig_option("CONFIG_LWIP_TCP_RCV_SCALE", 3) esp32.add_idf_sdkconfig_option("CONFIG_LWIP_TCPIP_RECVMBOX_SIZE", 512) # Allocate wifi buffers in PSRAM diff --git a/esphome/components/speaker/media_player/speaker_media_player.cpp b/esphome/components/speaker/media_player/speaker_media_player.cpp index e143920010..fed0207c93 100644 --- a/esphome/components/speaker/media_player/speaker_media_player.cpp +++ b/esphome/components/speaker/media_player/speaker_media_player.cpp @@ -106,16 +106,6 @@ void SpeakerMediaPlayer::setup() { ESP_LOGE(TAG, "Failed to create media pipeline"); this->mark_failed(); } - - // Setup callback to track the duration of audio played by the media pipeline - this->media_speaker_->add_audio_output_callback( - [this](uint32_t new_playback_ms, uint32_t remainder_us, uint32_t pending_ms, uint32_t write_timestamp) { - this->playback_ms_ += new_playback_ms; - this->remainder_us_ = remainder_us; - this->pending_ms_ = pending_ms; - this->last_audio_write_timestamp_ = write_timestamp; - this->playback_us_ = this->playback_ms_ * 1000 + this->remainder_us_; - }); } ESP_LOGI(TAG, "Set up speaker media player"); @@ -321,7 +311,6 @@ void SpeakerMediaPlayer::loop() { AudioPipelineState old_media_pipeline_state = this->media_pipeline_state_; if (this->media_pipeline_ != nullptr) { this->media_pipeline_state_ = this->media_pipeline_->process_state(); - this->decoded_playback_ms_ = this->media_pipeline_->get_playback_ms(); } if (this->media_pipeline_state_ == AudioPipelineState::ERROR_READING) { @@ -379,13 +368,6 @@ void SpeakerMediaPlayer::loop() { } else if (this->media_pipeline_state_ == AudioPipelineState::PLAYING) { this->state = media_player::MEDIA_PLAYER_STATE_PLAYING; } else if (this->media_pipeline_state_ == AudioPipelineState::STOPPED) { - // Reset playback durations - this->decoded_playback_ms_ = 0; - this->playback_us_ = 0; - this->playback_ms_ = 0; - this->remainder_us_ = 0; - this->pending_ms_ = 0; - if (!media_playlist_.empty()) { uint32_t timeout_ms = 0; if (old_media_pipeline_state == AudioPipelineState::PLAYING) { diff --git a/esphome/components/speaker/media_player/speaker_media_player.h b/esphome/components/speaker/media_player/speaker_media_player.h index 81eb72a830..67e9859a13 100644 --- a/esphome/components/speaker/media_player/speaker_media_player.h +++ b/esphome/components/speaker/media_player/speaker_media_player.h @@ -73,10 +73,6 @@ class SpeakerMediaPlayer : public Component, public media_player::MediaPlayer { void play_file(audio::AudioFile *media_file, bool announcement, bool enqueue); - uint32_t get_playback_ms() const { return this->playback_ms_; } - uint32_t get_playback_us() const { return this->playback_us_; } - uint32_t get_decoded_playback_ms() const { return this->decoded_playback_ms_; } - void set_playlist_delay_ms(AudioPipelineType pipeline_type, uint32_t delay_ms); protected: @@ -141,13 +137,6 @@ class SpeakerMediaPlayer : public Component, public media_player::MediaPlayer { Trigger<> *mute_trigger_ = new Trigger<>(); Trigger<> *unmute_trigger_ = new Trigger<>(); Trigger *volume_trigger_ = new Trigger(); - - uint32_t decoded_playback_ms_{0}; - uint32_t playback_us_{0}; - uint32_t playback_ms_{0}; - uint32_t remainder_us_{0}; - uint32_t pending_ms_{0}; - uint32_t last_audio_write_timestamp_{0}; }; } // namespace speaker diff --git a/esphome/components/speaker/speaker.h b/esphome/components/speaker/speaker.h index c4cf912fa6..373d2e3a74 100644 --- a/esphome/components/speaker/speaker.h +++ b/esphome/components/speaker/speaker.h @@ -104,12 +104,9 @@ class Speaker { /// Callback function for sending the duration of the audio written to the speaker since the last callback. /// Parameters: - /// - Duration in milliseconds. Never rounded and should always be less than or equal to the actual duration. - /// - Remainder duration in microseconds. Rounded duration after subtracting the previous parameter from the actual - /// duration. - /// - Duration of remaining, unwritten audio buffered in the speaker in milliseconds. - /// - System time in microseconds when the last write was completed. - void add_audio_output_callback(std::function &&callback) { + /// - Frames played + /// - System time in microseconds when the frames were written to the DAC + void add_audio_output_callback(std::function &&callback) { this->audio_output_callback_.add(std::move(callback)); } @@ -123,7 +120,7 @@ class Speaker { audio_dac::AudioDac *audio_dac_{nullptr}; #endif - CallbackManager audio_output_callback_{}; + CallbackManager audio_output_callback_{}; }; } // namespace speaker diff --git a/esphome/components/spi/spi.h b/esphome/components/spi/spi.h index 378d95e7b9..f96d3da251 100644 --- a/esphome/components/spi/spi.h +++ b/esphome/components/spi/spi.h @@ -1,5 +1,4 @@ #pragma once - #include "esphome/core/application.h" #include "esphome/core/component.h" #include "esphome/core/hal.h" @@ -28,6 +27,11 @@ using SPIInterface = spi_host_device_t; #endif // USE_ESP_IDF +#ifdef USE_ZEPHYR +// TODO supprse clang-tidy. Remove after SPI driver for nrf52 is added. +using SPIInterface = void *; +#endif + /** * Implementation of SPI Controller mode. */ @@ -351,6 +355,12 @@ class SPIComponent : public Component { void setup() override; void dump_config() override; + size_t get_bus_width() const { + if (this->data_pins_.empty()) { + return 1; + } + return this->data_pins_.size(); + } protected: GPIOPin *clk_pin_{nullptr}; diff --git a/esphome/components/sprinkler/__init__.py b/esphome/components/sprinkler/__init__.py index 2c59309b1f..3c94d97739 100644 --- a/esphome/components/sprinkler/__init__.py +++ b/esphome/components/sprinkler/__init__.py @@ -4,7 +4,6 @@ import esphome.codegen as cg from esphome.components import number, switch import esphome.config_validation as cv from esphome.const import ( - CONF_ENTITY_CATEGORY, CONF_ID, CONF_INITIAL_VALUE, CONF_MAX_VALUE, @@ -296,12 +295,11 @@ SPRINKLER_VALVE_SCHEMA = cv.Schema( cv.Optional(CONF_PUMP_SWITCH_ID): cv.use_id(switch.Switch), cv.Optional(CONF_RUN_DURATION): cv.positive_time_period_seconds, cv.Optional(CONF_RUN_DURATION_NUMBER): cv.maybe_simple_value( - number.NUMBER_SCHEMA.extend( + number.number_schema( + SprinklerControllerNumber, entity_category=ENTITY_CATEGORY_CONFIG + ) + .extend( { - cv.GenerateID(): cv.declare_id(SprinklerControllerNumber), - cv.Optional( - CONF_ENTITY_CATEGORY, default=ENTITY_CATEGORY_CONFIG - ): cv.entity_category, cv.Optional(CONF_INITIAL_VALUE, default=900): cv.positive_int, cv.Optional(CONF_MAX_VALUE, default=86400): cv.positive_int, cv.Optional(CONF_MIN_VALUE, default=1): cv.positive_int, @@ -314,7 +312,8 @@ SPRINKLER_VALVE_SCHEMA = cv.Schema( CONF_UNIT_OF_MEASUREMENT, default=UNIT_SECOND ): cv.one_of(UNIT_MINUTE, UNIT_SECOND, lower="True"), } - ).extend(cv.COMPONENT_SCHEMA), + ) + .extend(cv.COMPONENT_SCHEMA), validate_min_max, key=CONF_NAME, ), @@ -371,12 +370,11 @@ SPRINKLER_CONTROLLER_SCHEMA = cv.Schema( cv.Optional(CONF_NEXT_PREV_IGNORE_DISABLED, default=False): cv.boolean, cv.Optional(CONF_MANUAL_SELECTION_DELAY): cv.positive_time_period_seconds, cv.Optional(CONF_MULTIPLIER_NUMBER): cv.maybe_simple_value( - number.NUMBER_SCHEMA.extend( + number.number_schema( + SprinklerControllerNumber, entity_category=ENTITY_CATEGORY_CONFIG + ) + .extend( { - cv.GenerateID(): cv.declare_id(SprinklerControllerNumber), - cv.Optional( - CONF_ENTITY_CATEGORY, default=ENTITY_CATEGORY_CONFIG - ): cv.entity_category, cv.Optional(CONF_INITIAL_VALUE, default=1): cv.positive_float, cv.Optional(CONF_MAX_VALUE, default=10): cv.positive_float, cv.Optional(CONF_MIN_VALUE, default=0): cv.positive_float, @@ -386,18 +384,18 @@ SPRINKLER_CONTROLLER_SCHEMA = cv.Schema( single=True ), } - ).extend(cv.COMPONENT_SCHEMA), + ) + .extend(cv.COMPONENT_SCHEMA), validate_min_max, key=CONF_NAME, ), cv.Optional(CONF_REPEAT): cv.positive_int, cv.Optional(CONF_REPEAT_NUMBER): cv.maybe_simple_value( - number.NUMBER_SCHEMA.extend( + number.number_schema( + SprinklerControllerNumber, entity_category=ENTITY_CATEGORY_CONFIG + ) + .extend( { - cv.GenerateID(): cv.declare_id(SprinklerControllerNumber), - cv.Optional( - CONF_ENTITY_CATEGORY, default=ENTITY_CATEGORY_CONFIG - ): cv.entity_category, cv.Optional(CONF_INITIAL_VALUE, default=0): cv.positive_int, cv.Optional(CONF_MAX_VALUE, default=10): cv.positive_int, cv.Optional(CONF_MIN_VALUE, default=0): cv.positive_int, @@ -407,7 +405,8 @@ SPRINKLER_CONTROLLER_SCHEMA = cv.Schema( single=True ), } - ).extend(cv.COMPONENT_SCHEMA), + ) + .extend(cv.COMPONENT_SCHEMA), validate_min_max, key=CONF_NAME, ), diff --git a/esphome/components/switch/__init__.py b/esphome/components/switch/__init__.py index 0f159f69ec..e7445051e0 100644 --- a/esphome/components/switch/__init__.py +++ b/esphome/components/switch/__init__.py @@ -72,6 +72,9 @@ _SWITCH_SCHEMA = ( { cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id(mqtt.MQTTSwitchComponent), cv.Optional(CONF_INVERTED): cv.boolean, + cv.Optional(CONF_RESTORE_MODE, default="ALWAYS_OFF"): cv.enum( + RESTORE_MODES, upper=True, space="_" + ), cv.Optional(CONF_ON_TURN_ON): automation.validate_automation( { cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(SwitchTurnOnTrigger), @@ -87,57 +90,44 @@ _SWITCH_SCHEMA = ( ) ) -_UNDEF = object() - def switch_schema( - class_: MockObjClass = _UNDEF, + class_: MockObjClass, *, - entity_category: str = _UNDEF, - device_class: str = _UNDEF, - icon: str = _UNDEF, block_inverted: bool = False, - default_restore_mode: str = "ALWAYS_OFF", + default_restore_mode: str = cv.UNDEFINED, + device_class: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, ): - schema = _SWITCH_SCHEMA.extend( - { - cv.Optional(CONF_RESTORE_MODE, default=default_restore_mode): cv.enum( - RESTORE_MODES, upper=True, space="_" - ), - } - ) - if class_ is not _UNDEF: - schema = schema.extend({cv.GenerateID(): cv.declare_id(class_)}) - if entity_category is not _UNDEF: - schema = schema.extend( - { - cv.Optional( - CONF_ENTITY_CATEGORY, default=entity_category - ): cv.entity_category - } - ) - if device_class is not _UNDEF: - schema = schema.extend( - { - cv.Optional( - CONF_DEVICE_CLASS, default=device_class - ): validate_device_class - } - ) - if icon is not _UNDEF: - schema = schema.extend({cv.Optional(CONF_ICON, default=icon): cv.icon}) + schema = {cv.GenerateID(): cv.declare_id(class_)} + + for key, default, validator in [ + (CONF_DEVICE_CLASS, device_class, validate_device_class), + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_ICON, icon, cv.icon), + ( + CONF_RESTORE_MODE, + default_restore_mode, + cv.enum(RESTORE_MODES, upper=True, space="_") + if default_restore_mode is not cv.UNDEFINED + else cv.UNDEFINED, + ), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + if block_inverted: - schema = schema.extend( - { - cv.Optional(CONF_INVERTED): cv.invalid( - "Inverted is not supported for this platform!" - ) - } + schema[cv.Optional(CONF_INVERTED)] = cv.invalid( + "Inverted is not supported for this platform!" ) - return schema + + return _SWITCH_SCHEMA.extend(schema) -SWITCH_SCHEMA = switch_schema() # for compatibility +# Remove before 2025.11.0 +SWITCH_SCHEMA = switch_schema(Switch) +SWITCH_SCHEMA.add_extra(cv.deprecated_schema_constant("switch")) async def setup_switch_core_(var, config): diff --git a/esphome/components/syslog/__init__.py b/esphome/components/syslog/__init__.py new file mode 100644 index 0000000000..80b79d2040 --- /dev/null +++ b/esphome/components/syslog/__init__.py @@ -0,0 +1,41 @@ +import esphome.codegen as cg +from esphome.components import udp +from esphome.components.logger import LOG_LEVELS, is_log_level +from esphome.components.time import RealTimeClock +from esphome.components.udp import CONF_UDP_ID +import esphome.config_validation as cv +from esphome.const import CONF_ID, CONF_LEVEL, CONF_PORT, CONF_TIME_ID +from esphome.cpp_types import Component, Parented + +CODEOWNERS = ["@clydebarrow"] + +DEPENDENCIES = ["udp", "logger", "time"] + +syslog_ns = cg.esphome_ns.namespace("syslog") +Syslog = syslog_ns.class_("Syslog", Component, Parented.template(udp.UDPComponent)) + +CONF_STRIP = "strip" +CONF_FACILITY = "facility" +CONFIG_SCHEMA = udp.UDP_SCHEMA.extend( + { + cv.GenerateID(): cv.declare_id(Syslog), + cv.GenerateID(CONF_TIME_ID): cv.use_id(RealTimeClock), + cv.Optional(CONF_PORT, default=514): cv.port, + cv.Optional(CONF_LEVEL, default="DEBUG"): is_log_level, + cv.Optional(CONF_STRIP, default=True): cv.boolean, + cv.Optional(CONF_FACILITY, default=16): cv.int_range(0, 23), + } +) + + +async def to_code(config): + parent = await cg.get_variable(config[CONF_UDP_ID]) + time = await cg.get_variable(config[CONF_TIME_ID]) + cg.add(parent.set_broadcast_port(config[CONF_PORT])) + cg.add(parent.set_should_broadcast()) + level = LOG_LEVELS[config[CONF_LEVEL]] + var = cg.new_Pvariable(config[CONF_ID], level, time) + await cg.register_component(var, config) + await cg.register_parented(var, parent) + cg.add(var.set_strip(config[CONF_STRIP])) + cg.add(var.set_facility(config[CONF_FACILITY])) diff --git a/esphome/components/syslog/esphome_syslog.cpp b/esphome/components/syslog/esphome_syslog.cpp new file mode 100644 index 0000000000..9d2cda549b --- /dev/null +++ b/esphome/components/syslog/esphome_syslog.cpp @@ -0,0 +1,49 @@ +#include "esphome_syslog.h" + +#include "esphome/components/logger/logger.h" +#include "esphome/core/application.h" +#include "esphome/core/time.h" + +namespace esphome { +namespace syslog { + +// Map log levels to syslog severity using an array, indexed by ESPHome log level (1-7) +constexpr int LOG_LEVEL_TO_SYSLOG_SEVERITY[] = { + 3, // NONE + 3, // ERROR + 4, // WARN + 5, // INFO + 6, // CONFIG + 7, // DEBUG + 7, // VERBOSE + 7 // VERY_VERBOSE +}; + +void Syslog::setup() { + logger::global_logger->add_on_log_callback( + [this](int level, const char *tag, const char *message) { this->log_(level, tag, message); }); +} + +void Syslog::log_(const int level, const char *tag, const char *message) const { + if (level > this->log_level_) + return; + // Syslog PRI calculation: facility * 8 + severity + int severity = 7; + if ((unsigned) level <= 7) { + severity = LOG_LEVEL_TO_SYSLOG_SEVERITY[level]; + } + int pri = this->facility_ * 8 + severity; + auto timestamp = this->time_->now().strftime("%b %d %H:%M:%S"); + unsigned len = strlen(message); + // remove color formatting + if (this->strip_ && message[0] == 0x1B && len > 11) { + message += 7; + len -= 11; + } + + auto data = str_sprintf("<%d>%s %s %s: %.*s", pri, timestamp.c_str(), App.get_name().c_str(), tag, len, message); + this->parent_->send_packet((const uint8_t *) data.data(), data.size()); +} + +} // namespace syslog +} // namespace esphome diff --git a/esphome/components/syslog/esphome_syslog.h b/esphome/components/syslog/esphome_syslog.h new file mode 100644 index 0000000000..421a9bee73 --- /dev/null +++ b/esphome/components/syslog/esphome_syslog.h @@ -0,0 +1,27 @@ +#pragma once +#include "esphome/core/component.h" +#include "esphome/core/helpers.h" +#include "esphome/core/log.h" +#include "esphome/components/udp/udp_component.h" +#include "esphome/components/time/real_time_clock.h" + +#ifdef USE_NETWORK +namespace esphome { +namespace syslog { +class Syslog : public Component, public Parented { + public: + Syslog(int level, time::RealTimeClock *time) : log_level_(level), time_(time) {} + void setup() override; + void set_strip(bool strip) { this->strip_ = strip; } + void set_facility(int facility) { this->facility_ = facility; } + + protected: + int log_level_; + void log_(int level, const char *tag, const char *message) const; + time::RealTimeClock *time_; + bool strip_{true}; + int facility_{16}; +}; +} // namespace syslog +} // namespace esphome +#endif diff --git a/esphome/components/tca9555/tca9555.cpp b/esphome/components/tca9555/tca9555.cpp index cf0894427f..e065398c46 100644 --- a/esphome/components/tca9555/tca9555.cpp +++ b/esphome/components/tca9555/tca9555.cpp @@ -76,15 +76,20 @@ bool TCA9555Component::read_gpio_modes_() { bool TCA9555Component::digital_read_hw(uint8_t pin) { if (this->is_failed()) return false; - bool success; - uint8_t data[2]; - success = this->read_bytes(TCA9555_INPUT_PORT_REGISTER_0, data, 2); - this->input_mask_ = (uint16_t(data[1]) << 8) | (uint16_t(data[0]) << 0); - - if (!success) { + uint8_t data; + uint8_t bank_number = pin < 8 ? 0 : 1; + uint8_t register_to_read = bank_number ? TCA9555_INPUT_PORT_REGISTER_1 : TCA9555_INPUT_PORT_REGISTER_0; + if (!this->read_bytes(register_to_read, &data, 1)) { this->status_set_warning("Failed to read input register"); return false; } + uint8_t second_half = this->input_mask_ >> 8; + uint8_t first_half = this->input_mask_; + if (bank_number) { + this->input_mask_ = (data << 8) | (uint16_t(first_half) << 0); + } else { + this->input_mask_ = (uint16_t(second_half) << 8) | (data << 0); + } this->status_clear_warning(); return true; diff --git a/esphome/components/template/alarm_control_panel/__init__.py b/esphome/components/template/alarm_control_panel/__init__.py index 0f213857dc..a406c626ee 100644 --- a/esphome/components/template/alarm_control_panel/__init__.py +++ b/esphome/components/template/alarm_control_panel/__init__.py @@ -1,7 +1,7 @@ import esphome.codegen as cg from esphome.components import alarm_control_panel, binary_sensor import esphome.config_validation as cv -from esphome.const import CONF_BINARY_SENSORS, CONF_ID, CONF_INPUT, CONF_RESTORE_MODE +from esphome.const import CONF_BINARY_SENSORS, CONF_INPUT, CONF_RESTORE_MODE from .. import template_ns @@ -51,6 +51,7 @@ ALARM_SENSOR_TYPES = { "DELAYED": AlarmSensorType.ALARM_SENSOR_TYPE_DELAYED, "INSTANT": AlarmSensorType.ALARM_SENSOR_TYPE_INSTANT, "DELAYED_FOLLOWER": AlarmSensorType.ALARM_SENSOR_TYPE_DELAYED_FOLLOWER, + "INSTANT_ALWAYS": AlarmSensorType.ALARM_SENSOR_TYPE_INSTANT_ALWAYS, } @@ -76,9 +77,9 @@ TEMPLATE_ALARM_CONTROL_PANEL_BINARY_SENSOR_SCHEMA = cv.maybe_simple_value( ) TEMPLATE_ALARM_CONTROL_PANEL_SCHEMA = ( - alarm_control_panel.ALARM_CONTROL_PANEL_SCHEMA.extend( + alarm_control_panel.alarm_control_panel_schema(TemplateAlarmControlPanel) + .extend( { - cv.GenerateID(): cv.declare_id(TemplateAlarmControlPanel), cv.Optional(CONF_CODES): cv.ensure_list(cv.string_strict), cv.Optional(CONF_REQUIRES_CODE_TO_ARM): cv.boolean, cv.Optional(CONF_ARMING_HOME_TIME): cv.positive_time_period_milliseconds, @@ -99,7 +100,8 @@ TEMPLATE_ALARM_CONTROL_PANEL_SCHEMA = ( RESTORE_MODES, upper=True ), } - ).extend(cv.COMPONENT_SCHEMA) + ) + .extend(cv.COMPONENT_SCHEMA) ) CONFIG_SCHEMA = cv.All( @@ -109,9 +111,8 @@ CONFIG_SCHEMA = cv.All( async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await alarm_control_panel.new_alarm_control_panel(config) await cg.register_component(var, config) - await alarm_control_panel.register_alarm_control_panel(var, config) if CONF_CODES in config: for acode in config[CONF_CODES]: cg.add(var.add_code(acode)) diff --git a/esphome/components/template/alarm_control_panel/template_alarm_control_panel.cpp b/esphome/components/template/alarm_control_panel/template_alarm_control_panel.cpp index 99843417fa..bf1338edbe 100644 --- a/esphome/components/template/alarm_control_panel/template_alarm_control_panel.cpp +++ b/esphome/components/template/alarm_control_panel/template_alarm_control_panel.cpp @@ -58,6 +58,9 @@ void TemplateAlarmControlPanel::dump_config() { case ALARM_SENSOR_TYPE_DELAYED_FOLLOWER: sensor_type = "delayed_follower"; break; + case ALARM_SENSOR_TYPE_INSTANT_ALWAYS: + sensor_type = "instant_always"; + break; case ALARM_SENSOR_TYPE_DELAYED: default: sensor_type = "delayed"; @@ -145,24 +148,25 @@ void TemplateAlarmControlPanel::loop() { continue; } - // If sensor type is of type instant - if (sensor_info.second.type == ALARM_SENSOR_TYPE_INSTANT) { - instant_sensor_not_ready = true; - break; - } - // If sensor type is of type interior follower - if (sensor_info.second.type == ALARM_SENSOR_TYPE_DELAYED_FOLLOWER) { - // Look to see if we are in the pending state - if (this->current_state_ == ACP_STATE_PENDING) { - delayed_sensor_not_ready = true; - } else { + switch (sensor_info.second.type) { + case ALARM_SENSOR_TYPE_INSTANT: instant_sensor_not_ready = true; - } - } - // If sensor type is of type delayed - if (sensor_info.second.type == ALARM_SENSOR_TYPE_DELAYED) { - delayed_sensor_not_ready = true; - break; + break; + case ALARM_SENSOR_TYPE_INSTANT_ALWAYS: + instant_sensor_not_ready = true; + future_state = ACP_STATE_TRIGGERED; + break; + case ALARM_SENSOR_TYPE_DELAYED_FOLLOWER: + // Look to see if we are in the pending state + if (this->current_state_ == ACP_STATE_PENDING) { + delayed_sensor_not_ready = true; + } else { + instant_sensor_not_ready = true; + } + break; + case ALARM_SENSOR_TYPE_DELAYED: + default: + delayed_sensor_not_ready = true; } } } diff --git a/esphome/components/template/alarm_control_panel/template_alarm_control_panel.h b/esphome/components/template/alarm_control_panel/template_alarm_control_panel.h index 9ae69a0422..b29a48dfd7 100644 --- a/esphome/components/template/alarm_control_panel/template_alarm_control_panel.h +++ b/esphome/components/template/alarm_control_panel/template_alarm_control_panel.h @@ -27,7 +27,8 @@ enum BinarySensorFlags : uint16_t { enum AlarmSensorType : uint16_t { ALARM_SENSOR_TYPE_DELAYED = 0, ALARM_SENSOR_TYPE_INSTANT, - ALARM_SENSOR_TYPE_DELAYED_FOLLOWER + ALARM_SENSOR_TYPE_DELAYED_FOLLOWER, + ALARM_SENSOR_TYPE_INSTANT_ALWAYS, }; #endif diff --git a/esphome/components/template/cover/__init__.py b/esphome/components/template/cover/__init__.py index 5129e6b1af..a4fb0b7021 100644 --- a/esphome/components/template/cover/__init__.py +++ b/esphome/components/template/cover/__init__.py @@ -34,31 +34,37 @@ RESTORE_MODES = { CONF_HAS_POSITION = "has_position" CONF_TOGGLE_ACTION = "toggle_action" -CONFIG_SCHEMA = cover.COVER_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(TemplateCover), - cv.Optional(CONF_LAMBDA): cv.returning_lambda, - cv.Optional(CONF_OPTIMISTIC, default=False): cv.boolean, - cv.Optional(CONF_ASSUMED_STATE, default=False): cv.boolean, - cv.Optional(CONF_HAS_POSITION, default=False): cv.boolean, - cv.Optional(CONF_OPEN_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_CLOSE_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_STOP_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_TILT_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_TILT_LAMBDA): cv.returning_lambda, - cv.Optional(CONF_TOGGLE_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_POSITION_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_RESTORE_MODE, default="RESTORE"): cv.enum( - RESTORE_MODES, upper=True - ), - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + cover.cover_schema(TemplateCover) + .extend( + { + cv.Optional(CONF_LAMBDA): cv.returning_lambda, + cv.Optional(CONF_OPTIMISTIC, default=False): cv.boolean, + cv.Optional(CONF_ASSUMED_STATE, default=False): cv.boolean, + cv.Optional(CONF_HAS_POSITION, default=False): cv.boolean, + cv.Optional(CONF_OPEN_ACTION): automation.validate_automation(single=True), + cv.Optional(CONF_CLOSE_ACTION): automation.validate_automation(single=True), + cv.Optional(CONF_STOP_ACTION): automation.validate_automation(single=True), + cv.Optional(CONF_TILT_ACTION): automation.validate_automation(single=True), + cv.Optional(CONF_TILT_LAMBDA): cv.returning_lambda, + cv.Optional(CONF_TOGGLE_ACTION): automation.validate_automation( + single=True + ), + cv.Optional(CONF_POSITION_ACTION): automation.validate_automation( + single=True + ), + cv.Optional(CONF_RESTORE_MODE, default="RESTORE"): cv.enum( + RESTORE_MODES, upper=True + ), + } + ) + .extend(cv.COMPONENT_SCHEMA) +) async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await cover.new_cover(config) await cg.register_component(var, config) - await cover.register_cover(var, config) if CONF_LAMBDA in config: template_ = await cg.process_lambda( config[CONF_LAMBDA], [], return_type=cg.optional.template(float) diff --git a/esphome/components/template/lock/__init__.py b/esphome/components/template/lock/__init__.py index 2dcb90e038..4c74a521fa 100644 --- a/esphome/components/template/lock/__init__.py +++ b/esphome/components/template/lock/__init__.py @@ -17,17 +17,11 @@ from .. import template_ns TemplateLock = template_ns.class_("TemplateLock", lock.Lock, cg.Component) -LockState = lock.lock_ns.enum("LockState") - -LOCK_STATES = { - "LOCKED": LockState.LOCK_STATE_LOCKED, - "UNLOCKED": LockState.LOCK_STATE_UNLOCKED, - "JAMMED": LockState.LOCK_STATE_JAMMED, - "LOCKING": LockState.LOCK_STATE_LOCKING, - "UNLOCKING": LockState.LOCK_STATE_UNLOCKING, -} - -validate_lock_state = cv.enum(LOCK_STATES, upper=True) +TemplateLockPublishAction = template_ns.class_( + "TemplateLockPublishAction", + automation.Action, + cg.Parented.template(TemplateLock), +) def validate(config): @@ -42,9 +36,9 @@ def validate(config): CONFIG_SCHEMA = cv.All( - lock.LOCK_SCHEMA.extend( + lock.lock_schema(TemplateLock) + .extend( { - cv.GenerateID(): cv.declare_id(TemplateLock), cv.Optional(CONF_LAMBDA): cv.returning_lambda, cv.Optional(CONF_OPTIMISTIC, default=False): cv.boolean, cv.Optional(CONF_ASSUMED_STATE, default=False): cv.boolean, @@ -54,19 +48,19 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_LOCK_ACTION): automation.validate_automation(single=True), cv.Optional(CONF_OPEN_ACTION): automation.validate_automation(single=True), } - ).extend(cv.COMPONENT_SCHEMA), + ) + .extend(cv.COMPONENT_SCHEMA), validate, ) async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await lock.new_lock(config) await cg.register_component(var, config) - await lock.register_lock(var, config) if CONF_LAMBDA in config: template_ = await cg.process_lambda( - config[CONF_LAMBDA], [], return_type=cg.optional.template(LockState) + config[CONF_LAMBDA], [], return_type=cg.optional.template(lock.LockState) ) cg.add(var.set_state_lambda(template_)) if CONF_UNLOCK_ACTION in config: @@ -88,17 +82,18 @@ async def to_code(config): @automation.register_action( "lock.template.publish", - lock.LockPublishAction, - cv.Schema( + TemplateLockPublishAction, + cv.maybe_simple_value( { - cv.Required(CONF_ID): cv.use_id(lock.Lock), - cv.Required(CONF_STATE): cv.templatable(validate_lock_state), - } + cv.GenerateID(): cv.use_id(TemplateLock), + cv.Required(CONF_STATE): cv.templatable(lock.validate_lock_state), + }, + key=CONF_STATE, ), ) async def lock_template_publish_to_code(config, action_id, template_arg, args): - paren = await cg.get_variable(config[CONF_ID]) - var = cg.new_Pvariable(action_id, template_arg, paren) - template_ = await cg.templatable(config[CONF_STATE], args, LockState) + var = cg.new_Pvariable(action_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) + template_ = await cg.templatable(config[CONF_STATE], args, lock.LockState) cg.add(var.set_state(template_)) return var diff --git a/esphome/components/template/lock/automation.h b/esphome/components/template/lock/automation.h new file mode 100644 index 0000000000..6124546592 --- /dev/null +++ b/esphome/components/template/lock/automation.h @@ -0,0 +1,18 @@ +#pragma once + +#include "template_lock.h" + +#include "esphome/core/automation.h" + +namespace esphome { +namespace template_ { + +template class TemplateLockPublishAction : public Action, public Parented { + public: + TEMPLATABLE_VALUE(lock::LockState, state) + + void play(Ts... x) override { this->parent_->publish_state(this->state_.value(x...)); } +}; + +} // namespace template_ +} // namespace esphome diff --git a/esphome/components/template/text/__init__.py b/esphome/components/template/text/__init__.py index b0fea38aaf..572b5ba0f4 100644 --- a/esphome/components/template/text/__init__.py +++ b/esphome/components/template/text/__init__.py @@ -46,9 +46,9 @@ def validate(config): CONFIG_SCHEMA = cv.All( - text.TEXT_SCHEMA.extend( + text.text_schema(TemplateText) + .extend( { - cv.GenerateID(): cv.declare_id(TemplateText), cv.Optional(CONF_MIN_LENGTH, default=0): cv.int_range(min=0, max=255), cv.Optional(CONF_MAX_LENGTH, default=255): cv.int_range(min=0, max=255), cv.Optional(CONF_PATTERN): cv.string, @@ -58,7 +58,8 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_INITIAL_VALUE): cv.string_strict, cv.Optional(CONF_RESTORE_VALUE, default=False): cv.boolean, } - ).extend(cv.polling_component_schema("60s")), + ) + .extend(cv.polling_component_schema("60s")), validate, ) diff --git a/esphome/components/template/valve/__init__.py b/esphome/components/template/valve/__init__.py index 12e5174168..526751564d 100644 --- a/esphome/components/template/valve/__init__.py +++ b/esphome/components/template/valve/__init__.py @@ -21,6 +21,10 @@ from .. import template_ns TemplateValve = template_ns.class_("TemplateValve", valve.Valve, cg.Component) +TemplateValvePublishAction = template_ns.class_( + "TemplateValvePublishAction", automation.Action, cg.Parented.template(TemplateValve) +) + TemplateValveRestoreMode = template_ns.enum("TemplateValveRestoreMode") RESTORE_MODES = { "NO_RESTORE": TemplateValveRestoreMode.VALVE_NO_RESTORE, @@ -31,23 +35,30 @@ RESTORE_MODES = { CONF_HAS_POSITION = "has_position" CONF_TOGGLE_ACTION = "toggle_action" -CONFIG_SCHEMA = valve.VALVE_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(TemplateValve), - cv.Optional(CONF_LAMBDA): cv.returning_lambda, - cv.Optional(CONF_OPTIMISTIC, default=False): cv.boolean, - cv.Optional(CONF_ASSUMED_STATE, default=False): cv.boolean, - cv.Optional(CONF_HAS_POSITION, default=False): cv.boolean, - cv.Optional(CONF_OPEN_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_CLOSE_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_STOP_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_TOGGLE_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_POSITION_ACTION): automation.validate_automation(single=True), - cv.Optional(CONF_RESTORE_MODE, default="NO_RESTORE"): cv.enum( - RESTORE_MODES, upper=True - ), - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + valve.valve_schema(TemplateValve) + .extend( + { + cv.Optional(CONF_LAMBDA): cv.returning_lambda, + cv.Optional(CONF_OPTIMISTIC, default=False): cv.boolean, + cv.Optional(CONF_ASSUMED_STATE, default=False): cv.boolean, + cv.Optional(CONF_HAS_POSITION, default=False): cv.boolean, + cv.Optional(CONF_OPEN_ACTION): automation.validate_automation(single=True), + cv.Optional(CONF_CLOSE_ACTION): automation.validate_automation(single=True), + cv.Optional(CONF_STOP_ACTION): automation.validate_automation(single=True), + cv.Optional(CONF_TOGGLE_ACTION): automation.validate_automation( + single=True + ), + cv.Optional(CONF_POSITION_ACTION): automation.validate_automation( + single=True + ), + cv.Optional(CONF_RESTORE_MODE, default="NO_RESTORE"): cv.enum( + RESTORE_MODES, upper=True + ), + } + ) + .extend(cv.COMPONENT_SCHEMA) +) async def to_code(config): @@ -90,10 +101,10 @@ async def to_code(config): @automation.register_action( "valve.template.publish", - valve.ValvePublishAction, + TemplateValvePublishAction, cv.Schema( { - cv.Required(CONF_ID): cv.use_id(valve.Valve), + cv.GenerateID(): cv.use_id(TemplateValve), cv.Exclusive(CONF_STATE, "pos"): cv.templatable(valve.validate_valve_state), cv.Exclusive(CONF_POSITION, "pos"): cv.templatable(cv.percentage), cv.Optional(CONF_CURRENT_OPERATION): cv.templatable( @@ -103,8 +114,8 @@ async def to_code(config): ), ) async def valve_template_publish_to_code(config, action_id, template_arg, args): - paren = await cg.get_variable(config[CONF_ID]) - var = cg.new_Pvariable(action_id, template_arg, paren) + var = cg.new_Pvariable(action_id, template_arg) + await cg.register_parented(var, config[CONF_ID]) if state_config := config.get(CONF_STATE): template_ = await cg.templatable(state_config, args, float) cg.add(var.set_position(template_)) diff --git a/esphome/components/template/valve/automation.h b/esphome/components/template/valve/automation.h new file mode 100644 index 0000000000..af9b070c60 --- /dev/null +++ b/esphome/components/template/valve/automation.h @@ -0,0 +1,24 @@ +#pragma once + +#include "template_valve.h" + +#include "esphome/core/automation.h" + +namespace esphome { +namespace template_ { + +template class TemplateValvePublishAction : public Action, public Parented { + TEMPLATABLE_VALUE(float, position) + TEMPLATABLE_VALUE(valve::ValveOperation, current_operation) + + void play(Ts... x) override { + if (this->position_.has_value()) + this->parent_->position = this->position_.value(x...); + if (this->current_operation_.has_value()) + this->parent_->current_operation = this->current_operation_.value(x...); + this->parent_->publish_state(); + } +}; + +} // namespace template_ +} // namespace esphome diff --git a/esphome/components/text/__init__.py b/esphome/components/text/__init__.py index 20e5a645d1..1cc9283e45 100644 --- a/esphome/components/text/__init__.py +++ b/esphome/components/text/__init__.py @@ -5,6 +5,8 @@ import esphome.codegen as cg from esphome.components import mqtt, web_server import esphome.config_validation as cv from esphome.const import ( + CONF_ENTITY_CATEGORY, + CONF_ICON, CONF_ID, CONF_MODE, CONF_MQTT_ID, @@ -14,6 +16,7 @@ from esphome.const import ( CONF_WEB_SERVER, ) from esphome.core import CORE, coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity CODEOWNERS = ["@mauritskorse"] @@ -39,7 +42,7 @@ TEXT_MODES = { "PASSWORD": TextMode.TEXT_MODE_PASSWORD, # to be implemented for keys, passwords, etc. } -TEXT_SCHEMA = ( +_TEXT_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMPONENT_SCHEMA) .extend( @@ -57,6 +60,34 @@ TEXT_SCHEMA = ( ) +def text_schema( + class_: MockObjClass = cv.UNDEFINED, + *, + icon: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + mode: str = cv.UNDEFINED, +) -> cv.Schema: + schema = {} + + if class_ is not cv.UNDEFINED: + schema[cv.GenerateID()] = cv.declare_id(class_) + + for key, default, validator in [ + (CONF_ICON, icon, cv.icon), + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_MODE, mode, cv.enum(TEXT_MODES, upper=True)), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _TEXT_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +TEXT_SCHEMA = text_schema() +TEXT_SCHEMA.add_extra(cv.deprecated_schema_constant("text")) + + async def setup_text_core_( var, config, diff --git a/esphome/components/text_sensor/__init__.py b/esphome/components/text_sensor/__init__.py index 12993d9ffc..888b65745f 100644 --- a/esphome/components/text_sensor/__init__.py +++ b/esphome/components/text_sensor/__init__.py @@ -125,7 +125,7 @@ async def map_filter_to_code(config, filter_id): validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True, space="_") -TEXT_SENSOR_SCHEMA = ( +_TEXT_SENSOR_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMPONENT_SCHEMA) .extend( @@ -152,38 +152,33 @@ TEXT_SENSOR_SCHEMA = ( ) ) -_UNDEF = object() - def text_sensor_schema( - class_: MockObjClass = _UNDEF, + class_: MockObjClass = cv.UNDEFINED, *, - icon: str = _UNDEF, - entity_category: str = _UNDEF, - device_class: str = _UNDEF, + device_class: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, ) -> cv.Schema: - schema = TEXT_SENSOR_SCHEMA - if class_ is not _UNDEF: - schema = schema.extend({cv.GenerateID(): cv.declare_id(class_)}) - if icon is not _UNDEF: - schema = schema.extend({cv.Optional(CONF_ICON, default=icon): cv.icon}) - if device_class is not _UNDEF: - schema = schema.extend( - { - cv.Optional( - CONF_DEVICE_CLASS, default=device_class - ): validate_device_class - } - ) - if entity_category is not _UNDEF: - schema = schema.extend( - { - cv.Optional( - CONF_ENTITY_CATEGORY, default=entity_category - ): cv.entity_category - } - ) - return schema + schema = {} + + if class_ is not cv.UNDEFINED: + schema[cv.GenerateID()] = cv.declare_id(class_) + + for key, default, validator in [ + (CONF_ICON, icon, cv.icon), + (CONF_DEVICE_CLASS, device_class, validate_device_class), + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _TEXT_SENSOR_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +TEXT_SENSOR_SCHEMA = text_sensor_schema() +TEXT_SENSOR_SCHEMA.add_extra(cv.deprecated_schema_constant("text_sensor")) async def build_filters(config): diff --git a/esphome/components/time_based/cover.py b/esphome/components/time_based/cover.py index c723345370..d14332d453 100644 --- a/esphome/components/time_based/cover.py +++ b/esphome/components/time_based/cover.py @@ -6,7 +6,6 @@ from esphome.const import ( CONF_ASSUMED_STATE, CONF_CLOSE_ACTION, CONF_CLOSE_DURATION, - CONF_ID, CONF_OPEN_ACTION, CONF_OPEN_DURATION, CONF_STOP_ACTION, @@ -18,25 +17,27 @@ TimeBasedCover = time_based_ns.class_("TimeBasedCover", cover.Cover, cg.Componen CONF_HAS_BUILT_IN_ENDSTOP = "has_built_in_endstop" CONF_MANUAL_CONTROL = "manual_control" -CONFIG_SCHEMA = cover.COVER_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(TimeBasedCover), - cv.Required(CONF_STOP_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_OPEN_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_OPEN_DURATION): cv.positive_time_period_milliseconds, - cv.Required(CONF_CLOSE_ACTION): automation.validate_automation(single=True), - cv.Required(CONF_CLOSE_DURATION): cv.positive_time_period_milliseconds, - cv.Optional(CONF_HAS_BUILT_IN_ENDSTOP, default=False): cv.boolean, - cv.Optional(CONF_MANUAL_CONTROL, default=False): cv.boolean, - cv.Optional(CONF_ASSUMED_STATE, default=True): cv.boolean, - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + cover.cover_schema(TimeBasedCover) + .extend( + { + cv.Required(CONF_STOP_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_OPEN_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_OPEN_DURATION): cv.positive_time_period_milliseconds, + cv.Required(CONF_CLOSE_ACTION): automation.validate_automation(single=True), + cv.Required(CONF_CLOSE_DURATION): cv.positive_time_period_milliseconds, + cv.Optional(CONF_HAS_BUILT_IN_ENDSTOP, default=False): cv.boolean, + cv.Optional(CONF_MANUAL_CONTROL, default=False): cv.boolean, + cv.Optional(CONF_ASSUMED_STATE, default=True): cv.boolean, + } + ) + .extend(cv.COMPONENT_SCHEMA) +) async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await cover.new_cover(config) await cg.register_component(var, config) - await cover.register_cover(var, config) await automation.build_automation( var.get_stop_trigger(), [], config[CONF_STOP_ACTION] diff --git a/esphome/components/tm1638/switch/__init__.py b/esphome/components/tm1638/switch/__init__.py index 8832cf8b92..90ff87938c 100644 --- a/esphome/components/tm1638/switch/__init__.py +++ b/esphome/components/tm1638/switch/__init__.py @@ -8,13 +8,16 @@ from ..display import CONF_TM1638_ID, TM1638Component, tm1638_ns TM1638SwitchLed = tm1638_ns.class_("TM1638SwitchLed", switch.Switch, cg.Component) -CONFIG_SCHEMA = switch.SWITCH_SCHEMA.extend( - { - cv.GenerateID(): cv.declare_id(TM1638SwitchLed), - cv.GenerateID(CONF_TM1638_ID): cv.use_id(TM1638Component), - cv.Required(CONF_LED): cv.int_range(min=0, max=7), - } -).extend(cv.COMPONENT_SCHEMA) +CONFIG_SCHEMA = ( + switch.switch_schema(TM1638SwitchLed) + .extend( + { + cv.GenerateID(CONF_TM1638_ID): cv.use_id(TM1638Component), + cv.Required(CONF_LED): cv.int_range(min=0, max=7), + } + ) + .extend(cv.COMPONENT_SCHEMA) +) async def to_code(config): diff --git a/esphome/components/tormatic/cover.py b/esphome/components/tormatic/cover.py index 627ae6b63d..447920326b 100644 --- a/esphome/components/tormatic/cover.py +++ b/esphome/components/tormatic/cover.py @@ -1,17 +1,17 @@ import esphome.codegen as cg from esphome.components import cover, uart import esphome.config_validation as cv -from esphome.const import CONF_CLOSE_DURATION, CONF_ID, CONF_OPEN_DURATION +from esphome.const import CONF_CLOSE_DURATION, CONF_OPEN_DURATION tormatic_ns = cg.esphome_ns.namespace("tormatic") Tormatic = tormatic_ns.class_("Tormatic", cover.Cover, cg.PollingComponent) CONFIG_SCHEMA = ( - cover.COVER_SCHEMA.extend(uart.UART_DEVICE_SCHEMA) + cover.cover_schema(Tormatic) + .extend(uart.UART_DEVICE_SCHEMA) .extend(cv.polling_component_schema("300ms")) .extend( { - cv.GenerateID(): cv.declare_id(Tormatic), cv.Optional( CONF_OPEN_DURATION, default="15s" ): cv.positive_time_period_milliseconds, @@ -34,9 +34,8 @@ FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema( async def to_code(config): - var = cg.new_Pvariable(config[CONF_ID]) + var = await cover.new_cover(config) await cg.register_component(var, config) - await cover.register_cover(var, config) await uart.register_uart_device(var, config) cg.add(var.set_close_duration(config[CONF_CLOSE_DURATION])) diff --git a/esphome/components/touchscreen/touchscreen.cpp b/esphome/components/touchscreen/touchscreen.cpp index 11207908fa..dcf3209752 100644 --- a/esphome/components/touchscreen/touchscreen.cpp +++ b/esphome/components/touchscreen/touchscreen.cpp @@ -50,13 +50,15 @@ void Touchscreen::loop() { tp.second.x_prev = tp.second.x; tp.second.y_prev = tp.second.y; } + // The interrupt flag must be reset BEFORE calling update_touches, otherwise we might miss an interrupt that was + // triggered while we were reading touch data. + this->store_.touched = false; this->update_touches(); if (this->skip_update_) { for (auto &tp : this->touches_) { tp.second.state &= ~STATE_RELEASING; } } else { - this->store_.touched = false; this->defer([this]() { this->send_touches_(); }); if (this->touch_timeout_ > 0) { // Simulate a touch after touch_timeout_> ms. This will reset any existing timeout operation. diff --git a/esphome/components/tuya/cover/__init__.py b/esphome/components/tuya/cover/__init__.py index 61029b6daa..8c610c0272 100644 --- a/esphome/components/tuya/cover/__init__.py +++ b/esphome/components/tuya/cover/__init__.py @@ -1,12 +1,7 @@ import esphome.codegen as cg from esphome.components import cover import esphome.config_validation as cv -from esphome.const import ( - CONF_MAX_VALUE, - CONF_MIN_VALUE, - CONF_OUTPUT_ID, - CONF_RESTORE_MODE, -) +from esphome.const import CONF_MAX_VALUE, CONF_MIN_VALUE, CONF_RESTORE_MODE from .. import CONF_TUYA_ID, Tuya, tuya_ns @@ -38,9 +33,9 @@ def validate_range(config): CONFIG_SCHEMA = cv.All( - cover.COVER_SCHEMA.extend( + cover.cover_schema(TuyaCover) + .extend( { - cv.GenerateID(CONF_OUTPUT_ID): cv.declare_id(TuyaCover), cv.GenerateID(CONF_TUYA_ID): cv.use_id(Tuya), cv.Optional(CONF_CONTROL_DATAPOINT): cv.uint8_t, cv.Optional(CONF_DIRECTION_DATAPOINT): cv.uint8_t, @@ -54,15 +49,15 @@ CONFIG_SCHEMA = cv.All( RESTORE_MODES, upper=True ), }, - ).extend(cv.COMPONENT_SCHEMA), + ) + .extend(cv.COMPONENT_SCHEMA), validate_range, ) async def to_code(config): - var = cg.new_Pvariable(config[CONF_OUTPUT_ID]) + var = await cover.new_cover(config) await cg.register_component(var, config) - await cover.register_cover(var, config) if CONF_CONTROL_DATAPOINT in config: cg.add(var.set_control_id(config[CONF_CONTROL_DATAPOINT])) diff --git a/esphome/components/tuya/select/__init__.py b/esphome/components/tuya/select/__init__.py index a34e279746..e5b2e36ce7 100644 --- a/esphome/components/tuya/select/__init__.py +++ b/esphome/components/tuya/select/__init__.py @@ -1,7 +1,12 @@ import esphome.codegen as cg from esphome.components import select import esphome.config_validation as cv -from esphome.const import CONF_ENUM_DATAPOINT, CONF_OPTIMISTIC, CONF_OPTIONS +from esphome.const import ( + CONF_ENUM_DATAPOINT, + CONF_INT_DATAPOINT, + CONF_OPTIMISTIC, + CONF_OPTIONS, +) from .. import CONF_TUYA_ID, Tuya, tuya_ns @@ -26,17 +31,19 @@ def ensure_option_map(value): return value -CONFIG_SCHEMA = ( +CONFIG_SCHEMA = cv.All( select.select_schema(TuyaSelect) .extend( { cv.GenerateID(CONF_TUYA_ID): cv.use_id(Tuya), - cv.Required(CONF_ENUM_DATAPOINT): cv.uint8_t, + cv.Optional(CONF_ENUM_DATAPOINT): cv.uint8_t, + cv.Optional(CONF_INT_DATAPOINT): cv.uint8_t, cv.Required(CONF_OPTIONS): ensure_option_map, cv.Optional(CONF_OPTIMISTIC, default=False): cv.boolean, } ) - .extend(cv.COMPONENT_SCHEMA) + .extend(cv.COMPONENT_SCHEMA), + cv.has_exactly_one_key(CONF_ENUM_DATAPOINT, CONF_INT_DATAPOINT), ) @@ -47,5 +54,8 @@ async def to_code(config): cg.add(var.set_select_mappings(list(options_map.keys()))) parent = await cg.get_variable(config[CONF_TUYA_ID]) cg.add(var.set_tuya_parent(parent)) - cg.add(var.set_select_id(config[CONF_ENUM_DATAPOINT])) + if enum_datapoint := config.get(CONF_ENUM_DATAPOINT, None) is not None: + cg.add(var.set_select_id(enum_datapoint, False)) + if int_datapoint := config.get(CONF_INT_DATAPOINT, None) is not None: + cg.add(var.set_select_id(int_datapoint, True)) cg.add(var.set_optimistic(config[CONF_OPTIMISTIC])) diff --git a/esphome/components/tuya/select/tuya_select.cpp b/esphome/components/tuya/select/tuya_select.cpp index a4df0873b0..02643e97f4 100644 --- a/esphome/components/tuya/select/tuya_select.cpp +++ b/esphome/components/tuya/select/tuya_select.cpp @@ -31,7 +31,11 @@ void TuyaSelect::control(const std::string &value) { if (idx.has_value()) { uint8_t mapping = this->mappings_.at(idx.value()); ESP_LOGV(TAG, "Setting %u datapoint value to %u:%s", this->select_id_, mapping, value.c_str()); - this->parent_->set_enum_datapoint_value(this->select_id_, mapping); + if (this->is_int_) { + this->parent_->set_integer_datapoint_value(this->select_id_, mapping); + } else { + this->parent_->set_enum_datapoint_value(this->select_id_, mapping); + } return; } @@ -41,6 +45,7 @@ void TuyaSelect::control(const std::string &value) { void TuyaSelect::dump_config() { LOG_SELECT("", "Tuya Select", this); ESP_LOGCONFIG(TAG, " Select has datapoint ID %u", this->select_id_); + ESP_LOGCONFIG(TAG, " Data type: %s", this->is_int_ ? "int" : "enum"); ESP_LOGCONFIG(TAG, " Options are:"); auto options = this->traits.get_options(); for (auto i = 0; i < this->mappings_.size(); i++) { diff --git a/esphome/components/tuya/select/tuya_select.h b/esphome/components/tuya/select/tuya_select.h index 6a7e5c7ed0..12d7b507d4 100644 --- a/esphome/components/tuya/select/tuya_select.h +++ b/esphome/components/tuya/select/tuya_select.h @@ -16,7 +16,10 @@ class TuyaSelect : public select::Select, public Component { void set_tuya_parent(Tuya *parent) { this->parent_ = parent; } void set_optimistic(bool optimistic) { this->optimistic_ = optimistic; } - void set_select_id(uint8_t select_id) { this->select_id_ = select_id; } + void set_select_id(uint8_t select_id, bool is_int) { + this->select_id_ = select_id; + this->is_int_ = is_int; + } void set_select_mappings(std::vector mappings) { this->mappings_ = std::move(mappings); } protected: @@ -26,6 +29,7 @@ class TuyaSelect : public select::Select, public Component { bool optimistic_ = false; uint8_t select_id_; std::vector mappings_; + bool is_int_ = false; }; } // namespace tuya diff --git a/esphome/components/uart/packet_transport/__init__.py b/esphome/components/uart/packet_transport/__init__.py new file mode 100644 index 0000000000..58c6296e2f --- /dev/null +++ b/esphome/components/uart/packet_transport/__init__.py @@ -0,0 +1,20 @@ +from esphome.components.packet_transport import ( + PacketTransport, + new_packet_transport, + transport_schema, +) +from esphome.cpp_types import PollingComponent + +from .. import UART_DEVICE_SCHEMA, register_uart_device, uart_ns + +CODEOWNERS = ["@clydebarrow"] +DEPENDENCIES = ["uart"] + +UARTTransport = uart_ns.class_("UARTTransport", PacketTransport, PollingComponent) + +CONFIG_SCHEMA = transport_schema(UARTTransport).extend(UART_DEVICE_SCHEMA) + + +async def to_code(config): + var, _ = await new_packet_transport(config) + await register_uart_device(var, config) diff --git a/esphome/components/uart/packet_transport/uart_transport.cpp b/esphome/components/uart/packet_transport/uart_transport.cpp new file mode 100644 index 0000000000..423b657532 --- /dev/null +++ b/esphome/components/uart/packet_transport/uart_transport.cpp @@ -0,0 +1,88 @@ +#include "esphome/core/log.h" +#include "esphome/core/application.h" +#include "uart_transport.h" + +namespace esphome { +namespace uart { + +static const char *const TAG = "uart_transport"; + +void UARTTransport::loop() { + PacketTransport::loop(); + + while (this->parent_->available()) { + uint8_t byte; + if (!this->parent_->read_byte(&byte)) { + ESP_LOGW(TAG, "Failed to read byte from UART"); + return; + } + if (byte == FLAG_BYTE) { + if (this->rx_started_ && this->receive_buffer_.size() > 6) { + auto len = this->receive_buffer_.size(); + auto crc = crc16(this->receive_buffer_.data(), len - 2); + if (crc != (this->receive_buffer_[len - 2] | (this->receive_buffer_[len - 1] << 8))) { + ESP_LOGD(TAG, "CRC mismatch, discarding packet"); + this->rx_started_ = false; + this->receive_buffer_.clear(); + continue; + } + this->receive_buffer_.resize(len - 2); + this->process_(this->receive_buffer_); + this->rx_started_ = false; + } else { + this->rx_started_ = true; + } + this->receive_buffer_.clear(); + this->rx_control_ = false; + continue; + } + if (!this->rx_started_) + continue; + if (byte == CONTROL_BYTE) { + this->rx_control_ = true; + continue; + } + if (this->rx_control_) { + byte ^= 0x20; + this->rx_control_ = false; + } + if (this->receive_buffer_.size() == MAX_PACKET_SIZE) { + ESP_LOGD(TAG, "Packet too large, discarding"); + this->rx_started_ = false; + this->receive_buffer_.clear(); + continue; + } + this->receive_buffer_.push_back(byte); + } +} + +void UARTTransport::update() { + this->updated_ = true; + this->resend_data_ = true; + PacketTransport::update(); +} + +/** + * Write a byte to the UART bus. If the byte is a flag or control byte, it will be escaped. + * @param byte The byte to write. + */ +void UARTTransport::write_byte_(uint8_t byte) const { + if (byte == FLAG_BYTE || byte == CONTROL_BYTE) { + this->parent_->write_byte(CONTROL_BYTE); + byte ^= 0x20; + } + this->parent_->write_byte(byte); +} + +void UARTTransport::send_packet(const std::vector &buf) const { + this->parent_->write_byte(FLAG_BYTE); + for (uint8_t byte : buf) { + this->write_byte_(byte); + } + auto crc = crc16(buf.data(), buf.size()); + this->write_byte_(crc & 0xFF); + this->write_byte_(crc >> 8); + this->parent_->write_byte(FLAG_BYTE); +} +} // namespace uart +} // namespace esphome diff --git a/esphome/components/uart/packet_transport/uart_transport.h b/esphome/components/uart/packet_transport/uart_transport.h new file mode 100644 index 0000000000..f1431e948c --- /dev/null +++ b/esphome/components/uart/packet_transport/uart_transport.h @@ -0,0 +1,41 @@ +#pragma once + +#include "esphome/core/component.h" +#include "esphome/components/packet_transport/packet_transport.h" +#include +#include "../uart.h" + +namespace esphome { +namespace uart { + +/** + * A transport protocol for sending and receiving packets over a UART connection. + * The protocol is based on Asynchronous HDLC framing. (https://en.wikipedia.org/wiki/High-Level_Data_Link_Control) + * There are two special bytes: FLAG_BYTE and CONTROL_BYTE. + * A 16-bit CRC is appended to the packet, then + * the protocol wraps the resulting data between FLAG_BYTEs. + * Any occurrence of FLAG_BYTE or CONTROL_BYTE in the data is escaped by emitting CONTROL_BYTE followed by the byte + * XORed with 0x20. + */ +static const uint16_t MAX_PACKET_SIZE = 508; +static const uint8_t FLAG_BYTE = 0x7E; +static const uint8_t CONTROL_BYTE = 0x7D; + +class UARTTransport : public packet_transport::PacketTransport, public UARTDevice { + public: + void loop() override; + void update() override; + float get_setup_priority() const override { return setup_priority::PROCESSOR; } + + protected: + void write_byte_(uint8_t byte) const; + void send_packet(const std::vector &buf) const override; + bool should_send() override { return true; }; + size_t get_max_packet_size() override { return MAX_PACKET_SIZE; } + std::vector receive_buffer_{}; + bool rx_started_{}; + bool rx_control_{}; +}; + +} // namespace uart +} // namespace esphome diff --git a/esphome/components/udp/__init__.py b/esphome/components/udp/__init__.py index 140d1e4236..ed405d7c22 100644 --- a/esphome/components/udp/__init__.py +++ b/esphome/components/udp/__init__.py @@ -1,164 +1,162 @@ -import hashlib - +from esphome import automation +from esphome.automation import Trigger import esphome.codegen as cg -from esphome.components.api import CONF_ENCRYPTION -from esphome.components.binary_sensor import BinarySensor -from esphome.components.sensor import Sensor -import esphome.config_validation as cv -from esphome.const import ( +from esphome.components.packet_transport import ( CONF_BINARY_SENSORS, - CONF_ID, - CONF_INTERNAL, - CONF_KEY, - CONF_NAME, - CONF_PORT, + CONF_ENCRYPTION, + CONF_PING_PONG_ENABLE, + CONF_PROVIDERS, + CONF_ROLLING_CODE_ENABLE, CONF_SENSORS, ) -from esphome.cpp_generator import MockObjClass +import esphome.config_validation as cv +from esphome.const import CONF_DATA, CONF_ID, CONF_PORT, CONF_TRIGGER_ID +from esphome.core import Lambda +from esphome.cpp_generator import ExpressionStatement, MockObj CODEOWNERS = ["@clydebarrow"] DEPENDENCIES = ["network"] -AUTO_LOAD = ["socket", "xxtea"] +AUTO_LOAD = ["socket"] + MULTI_CONF = True - udp_ns = cg.esphome_ns.namespace("udp") -UDPComponent = udp_ns.class_("UDPComponent", cg.PollingComponent) +UDPComponent = udp_ns.class_("UDPComponent", cg.Component) +UDPWriteAction = udp_ns.class_("UDPWriteAction", automation.Action) +trigger_args = cg.std_vector.template(cg.uint8) -CONF_BROADCAST = "broadcast" -CONF_BROADCAST_ID = "broadcast_id" CONF_ADDRESSES = "addresses" CONF_LISTEN_ADDRESS = "listen_address" -CONF_PROVIDER = "provider" -CONF_PROVIDERS = "providers" -CONF_REMOTE_ID = "remote_id" CONF_UDP_ID = "udp_id" -CONF_PING_PONG_ENABLE = "ping_pong_enable" -CONF_PING_PONG_RECYCLE_TIME = "ping_pong_recycle_time" -CONF_ROLLING_CODE_ENABLE = "rolling_code_enable" +CONF_ON_RECEIVE = "on_receive" +CONF_LISTEN_PORT = "listen_port" +CONF_BROADCAST_PORT = "broadcast_port" - -def sensor_validation(cls: MockObjClass): - return cv.maybe_simple_value( - cv.Schema( - { - cv.Required(CONF_ID): cv.use_id(cls), - cv.Optional(CONF_BROADCAST_ID): cv.validate_id_name, - } - ), - key=CONF_ID, - ) - - -ENCRYPTION_SCHEMA = { - cv.Optional(CONF_ENCRYPTION): cv.maybe_simple_value( - cv.Schema( - { - cv.Required(CONF_KEY): cv.string, - } - ), - key=CONF_KEY, - ) -} - -PROVIDER_SCHEMA = cv.Schema( +UDP_SCHEMA = cv.Schema( { - cv.Required(CONF_NAME): cv.valid_name, - } -).extend(ENCRYPTION_SCHEMA) - - -def validate_(config): - if CONF_ENCRYPTION in config: - if CONF_SENSORS not in config and CONF_BINARY_SENSORS not in config: - raise cv.Invalid("No sensors or binary sensors to encrypt") - elif config[CONF_ROLLING_CODE_ENABLE]: - raise cv.Invalid("Rolling code requires an encryption key") - if config[CONF_PING_PONG_ENABLE]: - if not any(CONF_ENCRYPTION in p for p in config.get(CONF_PROVIDERS) or ()): - raise cv.Invalid("Ping-pong requires at least one encrypted provider") - return config - - -CONFIG_SCHEMA = cv.All( - cv.polling_component_schema("15s") - .extend( - { - cv.GenerateID(): cv.declare_id(UDPComponent), - cv.Optional(CONF_PORT, default=18511): cv.port, - cv.Optional( - CONF_LISTEN_ADDRESS, default="255.255.255.255" - ): cv.ipv4address_multi_broadcast, - cv.Optional(CONF_ADDRESSES, default=["255.255.255.255"]): cv.ensure_list( - cv.ipv4address, - ), - cv.Optional(CONF_ROLLING_CODE_ENABLE, default=False): cv.boolean, - cv.Optional(CONF_PING_PONG_ENABLE, default=False): cv.boolean, - cv.Optional( - CONF_PING_PONG_RECYCLE_TIME, default="600s" - ): cv.positive_time_period_seconds, - cv.Optional(CONF_SENSORS): cv.ensure_list(sensor_validation(Sensor)), - cv.Optional(CONF_BINARY_SENSORS): cv.ensure_list( - sensor_validation(BinarySensor) - ), - cv.Optional(CONF_PROVIDERS): cv.ensure_list(PROVIDER_SCHEMA), - }, - ) - .extend(ENCRYPTION_SCHEMA), - validate_, -) - -SENSOR_SCHEMA = cv.Schema( - { - cv.Optional(CONF_REMOTE_ID): cv.string_strict, - cv.Required(CONF_PROVIDER): cv.valid_name, cv.GenerateID(CONF_UDP_ID): cv.use_id(UDPComponent), } ) -def require_internal_with_name(config): - if CONF_NAME in config and CONF_INTERNAL not in config: - raise cv.Invalid("Must provide internal: config when using name:") - return config +def is_relocated(option): + def validator(value): + raise cv.Invalid( + f"The '{option}' option should now be configured in the 'packet_transport' component" + ) + + return validator -def hash_encryption_key(config: dict): - return list(hashlib.sha256(config[CONF_KEY].encode()).digest()) +RELOCATED = { + cv.Optional(x): is_relocated(x) + for x in ( + CONF_PROVIDERS, + CONF_ENCRYPTION, + CONF_PING_PONG_ENABLE, + CONF_ROLLING_CODE_ENABLE, + CONF_SENSORS, + CONF_BINARY_SENSORS, + ) +} + +CONFIG_SCHEMA = cv.COMPONENT_SCHEMA.extend( + { + cv.GenerateID(): cv.declare_id(UDPComponent), + cv.Optional(CONF_PORT, default=18511): cv.Any( + cv.port, + cv.Schema( + { + cv.Required(CONF_LISTEN_PORT): cv.port, + cv.Required(CONF_BROADCAST_PORT): cv.port, + } + ), + ), + cv.Optional( + CONF_LISTEN_ADDRESS, default="255.255.255.255" + ): cv.ipv4address_multi_broadcast, + cv.Optional(CONF_ADDRESSES, default=["255.255.255.255"]): cv.ensure_list( + cv.ipv4address, + ), + cv.Optional(CONF_ON_RECEIVE): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id( + Trigger.template(trigger_args) + ), + } + ), + } +).extend(RELOCATED) + + +async def register_udp_client(var, config): + udp_var = await cg.get_variable(config[CONF_UDP_ID]) + cg.add(var.set_parent(udp_var)) + return udp_var async def to_code(config): cg.add_define("USE_UDP") cg.add_global(udp_ns.using) var = cg.new_Pvariable(config[CONF_ID]) - await cg.register_component(var, config) - cg.add(var.set_port(config[CONF_PORT])) - cg.add(var.set_rolling_code_enable(config[CONF_ROLLING_CODE_ENABLE])) - cg.add(var.set_ping_pong_enable(config[CONF_PING_PONG_ENABLE])) - cg.add( - var.set_ping_pong_recycle_time( - config[CONF_PING_PONG_RECYCLE_TIME].total_seconds - ) - ) - for sens_conf in config.get(CONF_SENSORS, ()): - sens_id = sens_conf[CONF_ID] - sensor = await cg.get_variable(sens_id) - bcst_id = sens_conf.get(CONF_BROADCAST_ID, sens_id.id) - cg.add(var.add_sensor(bcst_id, sensor)) - for sens_conf in config.get(CONF_BINARY_SENSORS, ()): - sens_id = sens_conf[CONF_ID] - sensor = await cg.get_variable(sens_id) - bcst_id = sens_conf.get(CONF_BROADCAST_ID, sens_id.id) - cg.add(var.add_binary_sensor(bcst_id, sensor)) + var = await cg.register_component(var, config) + conf_port = config[CONF_PORT] + if isinstance(conf_port, int): + cg.add(var.set_listen_port(conf_port)) + cg.add(var.set_broadcast_port(conf_port)) + else: + cg.add(var.set_listen_port(conf_port[CONF_LISTEN_PORT])) + cg.add(var.set_broadcast_port(conf_port[CONF_BROADCAST_PORT])) + if (listen_address := str(config[CONF_LISTEN_ADDRESS])) != "255.255.255.255": + cg.add(var.set_listen_address(listen_address)) for address in config[CONF_ADDRESSES]: cg.add(var.add_address(str(address))) + if on_receive := config.get(CONF_ON_RECEIVE): + on_receive = on_receive[0] + trigger = cg.new_Pvariable(on_receive[CONF_TRIGGER_ID]) + trigger = await automation.build_automation( + trigger, [(trigger_args, "data")], on_receive + ) + trigger = Lambda(str(ExpressionStatement(trigger.trigger(MockObj("data"))))) + trigger = await cg.process_lambda(trigger, [(trigger_args, "data")]) + cg.add(var.add_listener(trigger)) + cg.add(var.set_should_listen()) - if encryption := config.get(CONF_ENCRYPTION): - cg.add(var.set_encryption_key(hash_encryption_key(encryption))) - for provider in config.get(CONF_PROVIDERS, ()): - name = provider[CONF_NAME] - cg.add(var.add_provider(name)) - if (listen_address := str(config[CONF_LISTEN_ADDRESS])) != "255.255.255.255": - cg.add(var.set_listen_address(listen_address)) - if encryption := provider.get(CONF_ENCRYPTION): - cg.add(var.set_provider_encryption(name, hash_encryption_key(encryption))) +def validate_raw_data(value): + if isinstance(value, str): + return value.encode("utf-8") + if isinstance(value, str): + return value + if isinstance(value, list): + return cv.Schema([cv.hex_uint8_t])(value) + raise cv.Invalid( + "data must either be a string wrapped in quotes or a list of bytes" + ) + + +@automation.register_action( + "udp.write", + UDPWriteAction, + cv.maybe_simple_value( + { + cv.GenerateID(): cv.use_id(UDPComponent), + cv.Required(CONF_DATA): cv.templatable(validate_raw_data), + }, + key=CONF_DATA, + ), +) +async def udp_write_to_code(config, action_id, template_arg, args): + var = cg.new_Pvariable(action_id, template_arg) + udp_var = await cg.get_variable(config[CONF_ID]) + await cg.register_parented(var, udp_var) + cg.add(udp_var.set_should_broadcast()) + data = config[CONF_DATA] + if isinstance(data, bytes): + data = list(data) + + if cg.is_template(data): + templ = await cg.templatable(data, args, cg.std_vector.template(cg.uint8)) + cg.add(var.set_data_template(templ)) + else: + cg.add(var.set_data_static(data)) + return var diff --git a/esphome/components/udp/automation.h b/esphome/components/udp/automation.h new file mode 100644 index 0000000000..f75e6d35bf --- /dev/null +++ b/esphome/components/udp/automation.h @@ -0,0 +1,40 @@ +#pragma once + +#include "udp_component.h" +#ifdef USE_NETWORK +#include "esphome/core/automation.h" + +#include + +namespace esphome { +namespace udp { + +template class UDPWriteAction : public Action, public Parented { + public: + void set_data_template(std::function(Ts...)> func) { + this->data_func_ = func; + this->static_ = false; + } + void set_data_static(const std::vector &data) { + this->data_static_ = data; + this->static_ = true; + } + + void play(Ts... x) override { + if (this->static_) { + this->parent_->send_packet(this->data_static_); + } else { + auto val = this->data_func_(x...); + this->parent_->send_packet(val); + } + } + + protected: + bool static_{false}; + std::function(Ts...)> data_func_{}; + std::vector data_static_{}; +}; + +} // namespace udp +} // namespace esphome +#endif diff --git a/esphome/components/udp/binary_sensor.py b/esphome/components/udp/binary_sensor.py index d90e495527..7d449efbfd 100644 --- a/esphome/components/udp/binary_sensor.py +++ b/esphome/components/udp/binary_sensor.py @@ -1,27 +1,5 @@ -import esphome.codegen as cg -from esphome.components import binary_sensor -from esphome.config_validation import All, has_at_least_one_key -from esphome.const import CONF_ID +import esphome.config_validation as cv -from . import ( - CONF_PROVIDER, - CONF_REMOTE_ID, - CONF_UDP_ID, - SENSOR_SCHEMA, - require_internal_with_name, +CONFIG_SCHEMA = cv.invalid( + "The 'udp.binary_sensor' component has been migrated to the 'packet_transport.binary_sensor' component." ) - -DEPENDENCIES = ["udp"] - -CONFIG_SCHEMA = All( - binary_sensor.binary_sensor_schema().extend(SENSOR_SCHEMA), - has_at_least_one_key(CONF_ID, CONF_REMOTE_ID), - require_internal_with_name, -) - - -async def to_code(config): - var = await binary_sensor.new_binary_sensor(config) - comp = await cg.get_variable(config[CONF_UDP_ID]) - remote_id = str(config.get(CONF_REMOTE_ID) or config.get(CONF_ID)) - cg.add(comp.add_remote_binary_sensor(config[CONF_PROVIDER], remote_id, var)) diff --git a/esphome/components/udp/packet_transport/__init__.py b/esphome/components/udp/packet_transport/__init__.py new file mode 100644 index 0000000000..b6957a372b --- /dev/null +++ b/esphome/components/udp/packet_transport/__init__.py @@ -0,0 +1,29 @@ +import esphome.codegen as cg +from esphome.components.api import CONF_ENCRYPTION +from esphome.components.packet_transport import ( + CONF_PING_PONG_ENABLE, + PacketTransport, + new_packet_transport, + transport_schema, +) +from esphome.const import CONF_BINARY_SENSORS, CONF_SENSORS +from esphome.cpp_types import PollingComponent + +from .. import UDP_SCHEMA, register_udp_client, udp_ns + +UDPTransport = udp_ns.class_("UDPTransport", PacketTransport, PollingComponent) + +CONFIG_SCHEMA = transport_schema(UDPTransport).extend(UDP_SCHEMA) + + +async def to_code(config): + var, providers = await new_packet_transport(config) + udp_var = await register_udp_client(var, config) + if CONF_ENCRYPTION in config or providers: + cg.add(udp_var.set_should_listen()) + if ( + config[CONF_PING_PONG_ENABLE] + or config.get(CONF_SENSORS, ()) + or config.get(CONF_BINARY_SENSORS, ()) + ): + cg.add(udp_var.set_should_broadcast()) diff --git a/esphome/components/udp/packet_transport/udp_transport.cpp b/esphome/components/udp/packet_transport/udp_transport.cpp new file mode 100644 index 0000000000..b92e0d64df --- /dev/null +++ b/esphome/components/udp/packet_transport/udp_transport.cpp @@ -0,0 +1,36 @@ +#include "esphome/core/log.h" +#include "esphome/core/application.h" +#include "esphome/components/network/util.h" +#include "udp_transport.h" + +namespace esphome { +namespace udp { + +static const char *const TAG = "udp_transport"; + +bool UDPTransport::should_send() { return this->should_broadcast_ && network::is_connected(); } +void UDPTransport::setup() { + PacketTransport::setup(); + this->should_broadcast_ = this->ping_pong_enable_; +#ifdef USE_SENSOR + this->should_broadcast_ |= !this->sensors_.empty(); +#endif +#ifdef USE_BINARY_SENSOR + this->should_broadcast_ |= !this->binary_sensors_.empty(); +#endif + if (this->should_broadcast_) + this->parent_->set_should_broadcast(); + if (!this->providers_.empty() || this->is_encrypted_()) { + this->parent_->add_listener([this](std::vector &buf) { this->process_(buf); }); + } +} + +void UDPTransport::update() { + PacketTransport::update(); + this->updated_ = true; + this->resend_data_ = this->should_broadcast_; +} + +void UDPTransport::send_packet(const std::vector &buf) const { this->parent_->send_packet(buf); } +} // namespace udp +} // namespace esphome diff --git a/esphome/components/udp/packet_transport/udp_transport.h b/esphome/components/udp/packet_transport/udp_transport.h new file mode 100644 index 0000000000..c87eb62780 --- /dev/null +++ b/esphome/components/udp/packet_transport/udp_transport.h @@ -0,0 +1,28 @@ +#pragma once + +#include "../udp_component.h" +#ifdef USE_NETWORK +#include "esphome/core/component.h" +#include "esphome/components/packet_transport/packet_transport.h" +#include + +namespace esphome { +namespace udp { + +class UDPTransport : public packet_transport::PacketTransport, public Parented { + public: + void setup() override; + void update() override; + + float get_setup_priority() const override { return setup_priority::AFTER_WIFI; } + + protected: + void send_packet(const std::vector &buf) const override; + bool should_send() override; + bool should_broadcast_{false}; + size_t get_max_packet_size() override { return MAX_PACKET_SIZE; } +}; + +} // namespace udp +} // namespace esphome +#endif diff --git a/esphome/components/udp/sensor.py b/esphome/components/udp/sensor.py index 860c277c44..9ce05e7ffb 100644 --- a/esphome/components/udp/sensor.py +++ b/esphome/components/udp/sensor.py @@ -1,27 +1,5 @@ -import esphome.codegen as cg -from esphome.components.sensor import new_sensor, sensor_schema -from esphome.config_validation import All, has_at_least_one_key -from esphome.const import CONF_ID +import esphome.config_validation as cv -from . import ( - CONF_PROVIDER, - CONF_REMOTE_ID, - CONF_UDP_ID, - SENSOR_SCHEMA, - require_internal_with_name, +CONFIG_SCHEMA = cv.invalid( + "The 'udp.sensor' component has been migrated to the 'packet_transport.sensor' component." ) - -DEPENDENCIES = ["udp"] - -CONFIG_SCHEMA = All( - sensor_schema().extend(SENSOR_SCHEMA), - has_at_least_one_key(CONF_ID, CONF_REMOTE_ID), - require_internal_with_name, -) - - -async def to_code(config): - var = await new_sensor(config) - comp = await cg.get_variable(config[CONF_UDP_ID]) - remote_id = str(config.get(CONF_REMOTE_ID) or config.get(CONF_ID)) - cg.add(comp.add_remote_sensor(config[CONF_PROVIDER], remote_id, var)) diff --git a/esphome/components/udp/udp_component.cpp b/esphome/components/udp/udp_component.cpp index 59cba8c7fe..222c73f82e 100644 --- a/esphome/components/udp/udp_component.cpp +++ b/esphome/components/udp/udp_component.cpp @@ -1,164 +1,24 @@ +#include "esphome/core/defines.h" +#ifdef USE_NETWORK #include "esphome/core/log.h" #include "esphome/core/application.h" #include "esphome/components/network/util.h" #include "udp_component.h" -#include "esphome/components/xxtea/xxtea.h" - namespace esphome { namespace udp { -/** - * Structure of a data packet; everything is little-endian - * - * --- In clear text --- - * MAGIC_NUMBER: 16 bits - * host name length: 1 byte - * host name: (length) bytes - * padding: 0 or more null bytes to a 4 byte boundary - * - * --- Encrypted (if key set) ---- - * DATA_KEY: 1 byte: OR ROLLING_CODE_KEY: - * Rolling code (if enabled): 8 bytes - * Ping keys: if any - * repeat: - * PING_KEY: 1 byte - * ping code: 4 bytes - * Sensors: - * repeat: - * SENSOR_KEY: 1 byte - * float value: 4 bytes - * name length: 1 byte - * name - * Binary Sensors: - * repeat: - * BINARY_SENSOR_KEY: 1 byte - * bool value: 1 bytes - * name length: 1 byte - * name - * - * Padded to a 4 byte boundary with nulls - * - * Structure of a ping request packet: - * --- In clear text --- - * MAGIC_PING: 16 bits - * host name length: 1 byte - * host name: (length) bytes - * Ping key (4 bytes) - * - */ static const char *const TAG = "udp"; -static size_t round4(size_t value) { return (value + 3) & ~3; } - -union FuData { - uint32_t u32; - float f32; -}; - -static const size_t MAX_PACKET_SIZE = 508; -static const uint16_t MAGIC_NUMBER = 0x4553; -static const uint16_t MAGIC_PING = 0x5048; -static const uint32_t PREF_HASH = 0x45535043; -enum DataKey { - ZERO_FILL_KEY, - DATA_KEY, - SENSOR_KEY, - BINARY_SENSOR_KEY, - PING_KEY, - ROLLING_CODE_KEY, -}; - -static const size_t MAX_PING_KEYS = 4; - -static inline void add(std::vector &vec, uint32_t data) { - vec.push_back(data & 0xFF); - vec.push_back((data >> 8) & 0xFF); - vec.push_back((data >> 16) & 0xFF); - vec.push_back((data >> 24) & 0xFF); -} - -static inline uint32_t get_uint32(uint8_t *&buf) { - uint32_t data = *buf++; - data += *buf++ << 8; - data += *buf++ << 16; - data += *buf++ << 24; - return data; -} - -static inline uint16_t get_uint16(uint8_t *&buf) { - uint16_t data = *buf++; - data += *buf++ << 8; - return data; -} - -static inline void add(std::vector &vec, uint8_t data) { vec.push_back(data); } -static inline void add(std::vector &vec, uint16_t data) { - vec.push_back((uint8_t) data); - vec.push_back((uint8_t) (data >> 8)); -} -static inline void add(std::vector &vec, DataKey data) { vec.push_back(data); } -static void add(std::vector &vec, const char *str) { - auto len = strlen(str); - vec.push_back(len); - for (size_t i = 0; i != len; i++) { - vec.push_back(*str++); - } -} - void UDPComponent::setup() { - this->name_ = App.get_name().c_str(); - if (strlen(this->name_) > 255) { - this->mark_failed(); - this->status_set_error("Device name exceeds 255 chars"); - return; - } - this->resend_ping_key_ = this->ping_pong_enable_; - // restore the upper 32 bits of the rolling code, increment and save. - this->pref_ = global_preferences->make_preference(PREF_HASH, true); - this->pref_.load(&this->rolling_code_[1]); - this->rolling_code_[1]++; - this->pref_.save(&this->rolling_code_[1]); - this->ping_key_ = random_uint32(); - ESP_LOGV(TAG, "Rolling code incremented, upper part now %u", (unsigned) this->rolling_code_[1]); -#ifdef USE_SENSOR - for (auto &sensor : this->sensors_) { - sensor.sensor->add_on_state_callback([this, &sensor](float x) { - this->updated_ = true; - sensor.updated = true; - }); - } -#endif -#ifdef USE_BINARY_SENSOR - for (auto &sensor : this->binary_sensors_) { - sensor.sensor->add_on_state_callback([this, &sensor](bool value) { - this->updated_ = true; - sensor.updated = true; - }); - } -#endif - this->should_send_ = this->ping_pong_enable_; -#ifdef USE_SENSOR - this->should_send_ |= !this->sensors_.empty(); -#endif -#ifdef USE_BINARY_SENSOR - this->should_send_ |= !this->binary_sensors_.empty(); -#endif - this->should_listen_ = !this->providers_.empty() || this->is_encrypted_(); - // initialise the header. This is invariant. - add(this->header_, MAGIC_NUMBER); - add(this->header_, this->name_); - // pad to a multiple of 4 bytes - while (this->header_.size() & 0x3) - this->header_.push_back(0); #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) for (const auto &address : this->addresses_) { struct sockaddr saddr {}; - socket::set_sockaddr(&saddr, sizeof(saddr), address, this->port_); + socket::set_sockaddr(&saddr, sizeof(saddr), address, this->broadcast_port_); this->sockaddrs_.push_back(saddr); } // set up broadcast socket - if (this->should_send_) { + if (this->should_broadcast_) { this->broadcast_socket_ = socket::socket(AF_INET, SOCK_DGRAM, IPPROTO_IP); if (this->broadcast_socket_ == nullptr) { this->mark_failed(); @@ -202,14 +62,14 @@ void UDPComponent::setup() { server.sin_family = AF_INET; server.sin_addr.s_addr = ESPHOME_INADDR_ANY; - server.sin_port = htons(this->port_); + server.sin_port = htons(this->listen_port_); if (this->listen_address_.has_value()) { struct ip_mreq imreq = {}; imreq.imr_interface.s_addr = ESPHOME_INADDR_ANY; inet_aton(this->listen_address_.value().str().c_str(), &imreq.imr_multiaddr); server.sin_addr.s_addr = imreq.imr_multiaddr.s_addr; - ESP_LOGV(TAG, "Join multicast %s", this->listen_address_.value().str().c_str()); + ESP_LOGD(TAG, "Join multicast %s", this->listen_address_.value().str().c_str()); err = this->listen_socket_->setsockopt(IPPROTO_IP, IP_ADD_MEMBERSHIP, &imreq, sizeof(imreq)); if (err < 0) { ESP_LOGE(TAG, "Failed to set IP_ADD_MEMBERSHIP. Error %d", errno); @@ -236,341 +96,48 @@ void UDPComponent::setup() { this->ipaddrs_.push_back(ipaddr); } if (this->should_listen_) - this->udp_client_.begin(this->port_); + this->udp_client_.begin(this->listen_port_); #endif } -void UDPComponent::init_data_() { - this->data_.clear(); - if (this->rolling_code_enable_) { - add(this->data_, ROLLING_CODE_KEY); - add(this->data_, this->rolling_code_[0]); - add(this->data_, this->rolling_code_[1]); - this->increment_code_(); - } else { - add(this->data_, DATA_KEY); - } - for (auto pkey : this->ping_keys_) { - add(this->data_, PING_KEY); - add(this->data_, pkey.second); - } -} - -void UDPComponent::flush_() { - if (!network::is_connected() || this->data_.empty()) - return; - uint32_t buffer[MAX_PACKET_SIZE / 4]; - memset(buffer, 0, sizeof buffer); - // len must be a multiple of 4 - auto header_len = round4(this->header_.size()) / 4; - auto len = round4(data_.size()) / 4; - memcpy(buffer, this->header_.data(), this->header_.size()); - memcpy(buffer + header_len, this->data_.data(), this->data_.size()); - if (this->is_encrypted_()) { - xxtea::encrypt(buffer + header_len, len, (uint32_t *) this->encryption_key_.data()); - } - auto total_len = (header_len + len) * 4; - this->send_packet_(buffer, total_len); -} - -void UDPComponent::add_binary_data_(uint8_t key, const char *id, bool data) { - auto len = 1 + 1 + 1 + strlen(id); - if (len + this->header_.size() + this->data_.size() > MAX_PACKET_SIZE) { - this->flush_(); - } - add(this->data_, key); - add(this->data_, (uint8_t) data); - add(this->data_, id); -} -void UDPComponent::add_data_(uint8_t key, const char *id, float data) { - FuData udata{.f32 = data}; - this->add_data_(key, id, udata.u32); -} - -void UDPComponent::add_data_(uint8_t key, const char *id, uint32_t data) { - auto len = 4 + 1 + 1 + strlen(id); - if (len + this->header_.size() + this->data_.size() > MAX_PACKET_SIZE) { - this->flush_(); - } - add(this->data_, key); - add(this->data_, data); - add(this->data_, id); -} -void UDPComponent::send_data_(bool all) { - if (!this->should_send_ || !network::is_connected()) - return; - this->init_data_(); -#ifdef USE_SENSOR - for (auto &sensor : this->sensors_) { - if (all || sensor.updated) { - sensor.updated = false; - this->add_data_(SENSOR_KEY, sensor.id, sensor.sensor->get_state()); - } - } -#endif -#ifdef USE_BINARY_SENSOR - for (auto &sensor : this->binary_sensors_) { - if (all || sensor.updated) { - sensor.updated = false; - this->add_binary_data_(BINARY_SENSOR_KEY, sensor.id, sensor.sensor->state); - } - } -#endif - this->flush_(); - this->updated_ = false; - this->resend_data_ = false; -} - -void UDPComponent::update() { - this->updated_ = true; - this->resend_data_ = this->should_send_; - auto now = millis() / 1000; - if (this->last_key_time_ + this->ping_pong_recyle_time_ < now) { - this->resend_ping_key_ = this->ping_pong_enable_; - this->last_key_time_ = now; - } -} - void UDPComponent::loop() { - uint8_t buf[MAX_PACKET_SIZE]; + auto buf = std::vector(MAX_PACKET_SIZE); if (this->should_listen_) { for (;;) { #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) - auto len = this->listen_socket_->read(buf, sizeof(buf)); + auto len = this->listen_socket_->read(buf.data(), buf.size()); #endif #ifdef USE_SOCKET_IMPL_LWIP_TCP auto len = this->udp_client_.parsePacket(); if (len > 0) - len = this->udp_client_.read(buf, sizeof(buf)); + len = this->udp_client_.read(buf.data(), buf.size()); #endif - if (len > 0) { - this->process_(buf, len); - continue; - } - break; + if (len <= 0) + break; + buf.resize(len); + ESP_LOGV(TAG, "Received packet of length %zu", len); + this->packet_listeners_.call(buf); } } - if (this->resend_ping_key_) - this->send_ping_pong_request_(); - if (this->updated_) { - this->send_data_(this->resend_data_); - } -} - -void UDPComponent::add_key_(const char *name, uint32_t key) { - if (!this->is_encrypted_()) - return; - if (this->ping_keys_.count(name) == 0 && this->ping_keys_.size() == MAX_PING_KEYS) { - ESP_LOGW(TAG, "Ping key from %s discarded", name); - return; - } - this->ping_keys_[name] = key; - this->resend_data_ = true; - ESP_LOGV(TAG, "Ping key from %s now %X", name, (unsigned) key); -} - -void UDPComponent::process_ping_request_(const char *name, uint8_t *ptr, size_t len) { - if (len != 4) { - ESP_LOGW(TAG, "Bad ping request"); - return; - } - auto key = get_uint32(ptr); - this->add_key_(name, key); - ESP_LOGV(TAG, "Updated ping key for %s to %08X", name, (unsigned) key); -} - -static bool process_rolling_code(Provider &provider, uint8_t *&buf, const uint8_t *end) { - if (end - buf < 8) - return false; - auto code0 = get_uint32(buf); - auto code1 = get_uint32(buf); - if (code1 < provider.last_code[1] || (code1 == provider.last_code[1] && code0 <= provider.last_code[0])) { - ESP_LOGW(TAG, "Rolling code for %s %08lX:%08lX is old", provider.name, (unsigned long) code1, - (unsigned long) code0); - return false; - } - provider.last_code[0] = code0; - provider.last_code[1] = code1; - return true; -} - -/** - * Process a received packet - */ -void UDPComponent::process_(uint8_t *buf, const size_t len) { - auto ping_key_seen = !this->ping_pong_enable_; - if (len < 8) { - ESP_LOGV(TAG, "Bad length %zu", len); - return; - } - char namebuf[256]{}; - uint8_t byte; - uint8_t *start_ptr = buf; - const uint8_t *end = buf + len; - FuData rdata{}; - auto magic = get_uint16(buf); - if (magic != MAGIC_NUMBER && magic != MAGIC_PING) { - ESP_LOGV(TAG, "Bad magic %X", magic); - return; - } - - auto hlen = *buf++; - if (hlen > len - 3) { - ESP_LOGV(TAG, "Bad hostname length %u > %zu", hlen, len - 3); - return; - } - memcpy(namebuf, buf, hlen); - if (strcmp(this->name_, namebuf) == 0) { - ESP_LOGV(TAG, "Ignoring our own data"); - return; - } - buf += hlen; - if (magic == MAGIC_PING) { - this->process_ping_request_(namebuf, buf, end - buf); - return; - } - if (round4(len) != len) { - ESP_LOGW(TAG, "Bad length %zu", len); - return; - } - hlen = round4(hlen + 3); - buf = start_ptr + hlen; - if (buf == end) { - ESP_LOGV(TAG, "No data after header"); - return; - } - - if (this->providers_.count(namebuf) == 0) { - ESP_LOGVV(TAG, "Unknown hostname %s", namebuf); - return; - } - auto &provider = this->providers_[namebuf]; - // if encryption not used with this host, ping check is pointless since it would be easily spoofed. - if (provider.encryption_key.empty()) - ping_key_seen = true; - - ESP_LOGV(TAG, "Found hostname %s", namebuf); -#ifdef USE_SENSOR - auto &sensors = this->remote_sensors_[namebuf]; -#endif -#ifdef USE_BINARY_SENSOR - auto &binary_sensors = this->remote_binary_sensors_[namebuf]; -#endif - - if (!provider.encryption_key.empty()) { - xxtea::decrypt((uint32_t *) buf, (end - buf) / 4, (uint32_t *) provider.encryption_key.data()); - } - byte = *buf++; - if (byte == ROLLING_CODE_KEY) { - if (!process_rolling_code(provider, buf, end)) - return; - } else if (byte != DATA_KEY) { - ESP_LOGV(TAG, "Expected rolling_key or data_key, got %X", byte); - return; - } - while (buf < end) { - byte = *buf++; - if (byte == ZERO_FILL_KEY) - continue; - if (byte == PING_KEY) { - if (end - buf < 4) { - ESP_LOGV(TAG, "PING_KEY requires 4 more bytes"); - return; - } - auto key = get_uint32(buf); - if (key == this->ping_key_) { - ping_key_seen = true; - ESP_LOGV(TAG, "Found good ping key %X", (unsigned) key); - } else { - ESP_LOGV(TAG, "Unknown ping key %X", (unsigned) key); - } - continue; - } - if (!ping_key_seen) { - ESP_LOGW(TAG, "Ping key not seen"); - this->resend_ping_key_ = true; - break; - } - if (byte == BINARY_SENSOR_KEY) { - if (end - buf < 3) { - ESP_LOGV(TAG, "Binary sensor key requires at least 3 more bytes"); - return; - } - rdata.u32 = *buf++; - } else if (byte == SENSOR_KEY) { - if (end - buf < 6) { - ESP_LOGV(TAG, "Sensor key requires at least 6 more bytes"); - return; - } - rdata.u32 = get_uint32(buf); - } else { - ESP_LOGW(TAG, "Unknown key byte %X", byte); - return; - } - - hlen = *buf++; - if (end - buf < hlen) { - ESP_LOGV(TAG, "Name length of %u not available", hlen); - return; - } - memset(namebuf, 0, sizeof namebuf); - memcpy(namebuf, buf, hlen); - ESP_LOGV(TAG, "Found sensor key %d, id %s, data %lX", byte, namebuf, (unsigned long) rdata.u32); - buf += hlen; -#ifdef USE_SENSOR - if (byte == SENSOR_KEY && sensors.count(namebuf) != 0) - sensors[namebuf]->publish_state(rdata.f32); -#endif -#ifdef USE_BINARY_SENSOR - if (byte == BINARY_SENSOR_KEY && binary_sensors.count(namebuf) != 0) - binary_sensors[namebuf]->publish_state(rdata.u32 != 0); -#endif - } } void UDPComponent::dump_config() { ESP_LOGCONFIG(TAG, "UDP:"); - ESP_LOGCONFIG(TAG, " Port: %u", this->port_); - ESP_LOGCONFIG(TAG, " Encrypted: %s", YESNO(this->is_encrypted_())); - ESP_LOGCONFIG(TAG, " Ping-pong: %s", YESNO(this->ping_pong_enable_)); + ESP_LOGCONFIG(TAG, " Listen Port: %u", this->listen_port_); + ESP_LOGCONFIG(TAG, " Broadcast Port: %u", this->broadcast_port_); for (const auto &address : this->addresses_) ESP_LOGCONFIG(TAG, " Address: %s", address.c_str()); if (this->listen_address_.has_value()) { ESP_LOGCONFIG(TAG, " Listen address: %s", this->listen_address_.value().str().c_str()); } -#ifdef USE_SENSOR - for (auto sensor : this->sensors_) - ESP_LOGCONFIG(TAG, " Sensor: %s", sensor.id); -#endif -#ifdef USE_BINARY_SENSOR - for (auto sensor : this->binary_sensors_) - ESP_LOGCONFIG(TAG, " Binary Sensor: %s", sensor.id); -#endif - for (const auto &host : this->providers_) { - ESP_LOGCONFIG(TAG, " Remote host: %s", host.first.c_str()); - ESP_LOGCONFIG(TAG, " Encrypted: %s", YESNO(!host.second.encryption_key.empty())); -#ifdef USE_SENSOR - for (const auto &sensor : this->remote_sensors_[host.first.c_str()]) - ESP_LOGCONFIG(TAG, " Sensor: %s", sensor.first.c_str()); -#endif -#ifdef USE_BINARY_SENSOR - for (const auto &sensor : this->remote_binary_sensors_[host.first.c_str()]) - ESP_LOGCONFIG(TAG, " Binary Sensor: %s", sensor.first.c_str()); -#endif - } + ESP_LOGCONFIG(TAG, " Broadcasting: %s", YESNO(this->should_broadcast_)); + ESP_LOGCONFIG(TAG, " Listening: %s", YESNO(this->should_listen_)); } -void UDPComponent::increment_code_() { - if (this->rolling_code_enable_) { - if (++this->rolling_code_[0] == 0) { - this->rolling_code_[1]++; - this->pref_.save(&this->rolling_code_[1]); - } - } -} -void UDPComponent::send_packet_(void *data, size_t len) { + +void UDPComponent::send_packet(const uint8_t *data, size_t size) { #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) for (const auto &saddr : this->sockaddrs_) { - auto result = this->broadcast_socket_->sendto(data, len, 0, &saddr, sizeof(saddr)); + auto result = this->broadcast_socket_->sendto(data, size, 0, &saddr, sizeof(saddr)); if (result < 0) ESP_LOGW(TAG, "sendto() error %d", errno); } @@ -578,8 +145,8 @@ void UDPComponent::send_packet_(void *data, size_t len) { #ifdef USE_SOCKET_IMPL_LWIP_TCP auto iface = IPAddress(0, 0, 0, 0); for (const auto &saddr : this->ipaddrs_) { - if (this->udp_client_.beginPacketMulticast(saddr, this->port_, iface, 128) != 0) { - this->udp_client_.write((const uint8_t *) data, len); + if (this->udp_client_.beginPacketMulticast(saddr, this->broadcast_port_, iface, 128) != 0) { + this->udp_client_.write(data, size); auto result = this->udp_client_.endPacket(); if (result == 0) ESP_LOGW(TAG, "udp.write() error"); @@ -587,18 +154,7 @@ void UDPComponent::send_packet_(void *data, size_t len) { } #endif } - -void UDPComponent::send_ping_pong_request_() { - if (!this->ping_pong_enable_ || !network::is_connected()) - return; - this->ping_key_ = random_uint32(); - this->ping_header_.clear(); - add(this->ping_header_, MAGIC_PING); - add(this->ping_header_, this->name_); - add(this->ping_header_, this->ping_key_); - this->send_packet_(this->ping_header_.data(), this->ping_header_.size()); - this->resend_ping_key_ = false; - ESP_LOGV(TAG, "Sent new ping request %08X", (unsigned) this->ping_key_); -} } // namespace udp } // namespace esphome + +#endif diff --git a/esphome/components/udp/udp_component.h b/esphome/components/udp/udp_component.h index 02f998ded7..065789ae28 100644 --- a/esphome/components/udp/udp_component.h +++ b/esphome/components/udp/udp_component.h @@ -1,13 +1,8 @@ #pragma once -#include "esphome/core/component.h" +#include "esphome/core/defines.h" +#ifdef USE_NETWORK #include "esphome/components/network/ip_address.h" -#ifdef USE_SENSOR -#include "esphome/components/sensor/sensor.h" -#endif -#ifdef USE_BINARY_SENSOR -#include "esphome/components/binary_sensor/binary_sensor.h" -#endif #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) #include "esphome/components/socket/socket.h" #endif @@ -15,116 +10,35 @@ #include #endif #include -#include namespace esphome { namespace udp { -struct Provider { - std::vector encryption_key; - const char *name; - uint32_t last_code[2]; -}; - -#ifdef USE_SENSOR -struct Sensor { - sensor::Sensor *sensor; - const char *id; - bool updated; -}; -#endif -#ifdef USE_BINARY_SENSOR -struct BinarySensor { - binary_sensor::BinarySensor *sensor; - const char *id; - bool updated; -}; -#endif - -class UDPComponent : public PollingComponent { +static const size_t MAX_PACKET_SIZE = 508; +class UDPComponent : public Component { public: + void add_address(const char *addr) { this->addresses_.emplace_back(addr); } + void set_listen_address(const char *listen_addr) { this->listen_address_ = network::IPAddress(listen_addr); } + void set_listen_port(uint16_t port) { this->listen_port_ = port; } + void set_broadcast_port(uint16_t port) { this->broadcast_port_ = port; } + void set_should_broadcast() { this->should_broadcast_ = true; } + void set_should_listen() { this->should_listen_ = true; } + void add_listener(std::function &)> &&listener) { + this->packet_listeners_.add(std::move(listener)); + } void setup() override; void loop() override; - void update() override; void dump_config() override; - -#ifdef USE_SENSOR - void add_sensor(const char *id, sensor::Sensor *sensor) { - Sensor st{sensor, id, true}; - this->sensors_.push_back(st); - } - void add_remote_sensor(const char *hostname, const char *remote_id, sensor::Sensor *sensor) { - this->add_provider(hostname); - this->remote_sensors_[hostname][remote_id] = sensor; - } -#endif -#ifdef USE_BINARY_SENSOR - void add_binary_sensor(const char *id, binary_sensor::BinarySensor *sensor) { - BinarySensor st{sensor, id, true}; - this->binary_sensors_.push_back(st); - } - - void add_remote_binary_sensor(const char *hostname, const char *remote_id, binary_sensor::BinarySensor *sensor) { - this->add_provider(hostname); - this->remote_binary_sensors_[hostname][remote_id] = sensor; - } -#endif - void add_address(const char *addr) { this->addresses_.emplace_back(addr); } -#ifdef USE_NETWORK - void set_listen_address(const char *listen_addr) { this->listen_address_ = network::IPAddress(listen_addr); } -#endif - void set_port(uint16_t port) { this->port_ = port; } - float get_setup_priority() const override { return setup_priority::AFTER_WIFI; } - - void add_provider(const char *hostname) { - if (this->providers_.count(hostname) == 0) { - Provider provider; - provider.encryption_key = std::vector{}; - provider.last_code[0] = 0; - provider.last_code[1] = 0; - provider.name = hostname; - this->providers_[hostname] = provider; -#ifdef USE_SENSOR - this->remote_sensors_[hostname] = std::map(); -#endif -#ifdef USE_BINARY_SENSOR - this->remote_binary_sensors_[hostname] = std::map(); -#endif - } - } - - void set_encryption_key(std::vector key) { this->encryption_key_ = std::move(key); } - void set_rolling_code_enable(bool enable) { this->rolling_code_enable_ = enable; } - void set_ping_pong_enable(bool enable) { this->ping_pong_enable_ = enable; } - void set_ping_pong_recycle_time(uint32_t recycle_time) { this->ping_pong_recyle_time_ = recycle_time; } - void set_provider_encryption(const char *name, std::vector key) { - this->providers_[name].encryption_key = std::move(key); - } + void send_packet(const uint8_t *data, size_t size); + void send_packet(const std::vector &buf) { this->send_packet(buf.data(), buf.size()); } + float get_setup_priority() const override { return setup_priority::AFTER_WIFI; }; protected: - void send_data_(bool all); - void process_(uint8_t *buf, size_t len); - void flush_(); - void add_data_(uint8_t key, const char *id, float data); - void add_data_(uint8_t key, const char *id, uint32_t data); - void increment_code_(); - void add_binary_data_(uint8_t key, const char *id, bool data); - void init_data_(); - - bool updated_{}; - uint16_t port_{18511}; - uint32_t ping_key_{}; - uint32_t rolling_code_[2]{}; - bool rolling_code_enable_{}; - bool ping_pong_enable_{}; - uint32_t ping_pong_recyle_time_{}; - uint32_t last_key_time_{}; - bool resend_ping_key_{}; - bool resend_data_{}; - bool should_send_{}; - const char *name_{}; + uint16_t listen_port_{}; + uint16_t broadcast_port_{}; + bool should_broadcast_{}; bool should_listen_{}; - ESPPreferenceObject pref_; + CallbackManager &)> packet_listeners_{}; #if defined(USE_SOCKET_IMPL_BSD_SOCKETS) || defined(USE_SOCKET_IMPL_LWIP_SOCKETS) std::unique_ptr broadcast_socket_ = nullptr; @@ -135,32 +49,11 @@ class UDPComponent : public PollingComponent { std::vector ipaddrs_{}; WiFiUDP udp_client_{}; #endif - std::vector encryption_key_{}; std::vector addresses_{}; -#ifdef USE_SENSOR - std::vector sensors_{}; - std::map> remote_sensors_{}; -#endif -#ifdef USE_BINARY_SENSOR - std::vector binary_sensors_{}; - std::map> remote_binary_sensors_{}; -#endif -#ifdef USE_NETWORK optional listen_address_{}; -#endif - std::map providers_{}; - std::vector ping_header_{}; - std::vector header_{}; - std::vector data_{}; - std::map ping_keys_{}; - void add_key_(const char *name, uint32_t key); - void send_ping_pong_request_(); - void send_packet_(void *data, size_t len); - void process_ping_request_(const char *name, uint8_t *ptr, size_t len); - - inline bool is_encrypted_() { return !this->encryption_key_.empty(); } }; } // namespace udp } // namespace esphome +#endif diff --git a/esphome/components/update/__init__.py b/esphome/components/update/__init__.py index 4729d954ee..c2654520fd 100644 --- a/esphome/components/update/__init__.py +++ b/esphome/components/update/__init__.py @@ -6,6 +6,7 @@ from esphome.const import ( CONF_DEVICE_CLASS, CONF_ENTITY_CATEGORY, CONF_FORCE_UPDATE, + CONF_ICON, CONF_ID, CONF_MQTT_ID, CONF_WEB_SERVER, @@ -14,6 +15,7 @@ from esphome.const import ( ENTITY_CATEGORY_CONFIG, ) from esphome.core import CORE, coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity CODEOWNERS = ["@jesserockz"] @@ -38,7 +40,7 @@ DEVICE_CLASSES = [ CONF_ON_UPDATE_AVAILABLE = "on_update_available" -UPDATE_SCHEMA = ( +_UPDATE_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( @@ -56,6 +58,34 @@ UPDATE_SCHEMA = ( ) +def update_schema( + class_: MockObjClass = cv.UNDEFINED, + *, + icon: str = cv.UNDEFINED, + device_class: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, +) -> cv.Schema: + schema = {} + + if class_ is not cv.UNDEFINED: + schema[cv.GenerateID()] = cv.declare_id(class_) + + for key, default, validator in [ + (CONF_ICON, icon, cv.icon), + (CONF_DEVICE_CLASS, device_class, cv.one_of(*DEVICE_CLASSES, lower=True)), + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _UPDATE_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +UPDATE_SCHEMA = update_schema() +UPDATE_SCHEMA.add_extra(cv.deprecated_schema_constant("update")) + + async def setup_update_core_(var, config): await setup_entity(var, config) diff --git a/esphome/components/uptime/text_sensor/__init__.py b/esphome/components/uptime/text_sensor/__init__.py index e4a7ac6517..6b91b526c0 100644 --- a/esphome/components/uptime/text_sensor/__init__.py +++ b/esphome/components/uptime/text_sensor/__init__.py @@ -1,19 +1,59 @@ import esphome.codegen as cg from esphome.components import text_sensor import esphome.config_validation as cv -from esphome.const import ENTITY_CATEGORY_DIAGNOSTIC, ICON_TIMER +from esphome.const import ( + CONF_FORMAT, + CONF_HOURS, + CONF_ID, + CONF_MINUTES, + CONF_SECONDS, + ENTITY_CATEGORY_DIAGNOSTIC, + ICON_TIMER, +) uptime_ns = cg.esphome_ns.namespace("uptime") UptimeTextSensor = uptime_ns.class_( "UptimeTextSensor", text_sensor.TextSensor, cg.PollingComponent ) -CONFIG_SCHEMA = text_sensor.text_sensor_schema( - UptimeTextSensor, - icon=ICON_TIMER, - entity_category=ENTITY_CATEGORY_DIAGNOSTIC, -).extend(cv.polling_component_schema("30s")) + +CONF_SEPARATOR = "separator" +CONF_DAYS = "days" +CONF_EXPAND = "expand" + +CONFIG_SCHEMA = ( + text_sensor.text_sensor_schema( + UptimeTextSensor, + icon=ICON_TIMER, + entity_category=ENTITY_CATEGORY_DIAGNOSTIC, + ) + .extend( + { + cv.Optional(CONF_FORMAT, default={}): cv.Schema( + { + cv.Optional(CONF_DAYS, default="d"): cv.string_strict, + cv.Optional(CONF_HOURS, default="h"): cv.string_strict, + cv.Optional(CONF_MINUTES, default="m"): cv.string_strict, + cv.Optional(CONF_SECONDS, default="s"): cv.string_strict, + cv.Optional(CONF_SEPARATOR, default=""): cv.string_strict, + cv.Optional(CONF_EXPAND, default=False): cv.boolean, + } + ) + } + ) + .extend(cv.polling_component_schema("30s")) +) async def to_code(config): - var = await text_sensor.new_text_sensor(config) + format = config[CONF_FORMAT] + var = cg.new_Pvariable( + config[CONF_ID], + format[CONF_DAYS], + format[CONF_HOURS], + format[CONF_MINUTES], + format[CONF_SECONDS], + format[CONF_SEPARATOR], + format[CONF_EXPAND], + ) + await text_sensor.register_text_sensor(var, config) await cg.register_component(var, config) diff --git a/esphome/components/uptime/text_sensor/uptime_text_sensor.cpp b/esphome/components/uptime/text_sensor/uptime_text_sensor.cpp index 409af6e4ff..94585379fe 100644 --- a/esphome/components/uptime/text_sensor/uptime_text_sensor.cpp +++ b/esphome/components/uptime/text_sensor/uptime_text_sensor.cpp @@ -16,6 +16,11 @@ void UptimeTextSensor::setup() { this->update(); } +void UptimeTextSensor::insert_buffer_(std::string &buffer, const char *key, unsigned value) const { + buffer.insert(0, this->separator_); + buffer.insert(0, str_sprintf("%u%s", value, key)); +} + void UptimeTextSensor::update() { auto now = millis(); // get whole seconds since last update. Note that even if the millis count has overflowed between updates, @@ -32,25 +37,25 @@ void UptimeTextSensor::update() { unsigned remainder = uptime % 60; uptime /= 60; if (interval < 30) { - buffer.insert(0, str_sprintf("%us", remainder)); - if (uptime == 0) + this->insert_buffer_(buffer, this->seconds_text_, remainder); + if (!this->expand_ && uptime == 0) break; } remainder = uptime % 60; uptime /= 60; if (interval < 1800) { - buffer.insert(0, str_sprintf("%um", remainder)); - if (uptime == 0) + this->insert_buffer_(buffer, this->minutes_text_, remainder); + if (!this->expand_ && uptime == 0) break; } remainder = uptime % 24; uptime /= 24; if (interval < 12 * 3600) { - buffer.insert(0, str_sprintf("%uh", remainder)); - if (uptime == 0) + this->insert_buffer_(buffer, this->hours_text_, remainder); + if (!this->expand_ && uptime == 0) break; } - buffer.insert(0, str_sprintf("%ud", (unsigned) uptime)); + this->insert_buffer_(buffer, this->days_text_, (unsigned) uptime); break; } this->publish_state(buffer); diff --git a/esphome/components/uptime/text_sensor/uptime_text_sensor.h b/esphome/components/uptime/text_sensor/uptime_text_sensor.h index 5719ef38a2..8dd058998c 100644 --- a/esphome/components/uptime/text_sensor/uptime_text_sensor.h +++ b/esphome/components/uptime/text_sensor/uptime_text_sensor.h @@ -10,13 +10,32 @@ namespace uptime { class UptimeTextSensor : public text_sensor::TextSensor, public PollingComponent { public: + UptimeTextSensor(const char *days_text, const char *hours_text, const char *minutes_text, const char *seconds_text, + const char *separator, bool expand) + : days_text_(days_text), + hours_text_(hours_text), + minutes_text_(minutes_text), + seconds_text_(seconds_text), + separator_(separator), + expand_(expand) {} void update() override; void dump_config() override; void setup() override; float get_setup_priority() const override; + void set_days(const char *days_text) { this->days_text_ = days_text; } + void set_hours(const char *hours_text) { this->hours_text_ = hours_text; } + void set_minutes(const char *minutes_text) { this->minutes_text_ = minutes_text; } + void set_seconds(const char *seconds_text) { this->seconds_text_ = seconds_text; } protected: + void insert_buffer_(std::string &buffer, const char *key, unsigned value) const; + const char *days_text_; + const char *hours_text_; + const char *minutes_text_; + const char *seconds_text_; + const char *separator_; + bool expand_{}; uint32_t uptime_{0}; // uptime in seconds, will overflow after 136 years uint32_t last_ms_{0}; }; diff --git a/esphome/components/valve/__init__.py b/esphome/components/valve/__init__.py index e55bb522de..f3c0353777 100644 --- a/esphome/components/valve/__init__.py +++ b/esphome/components/valve/__init__.py @@ -5,6 +5,8 @@ from esphome.components import mqtt, web_server import esphome.config_validation as cv from esphome.const import ( CONF_DEVICE_CLASS, + CONF_ENTITY_CATEGORY, + CONF_ICON, CONF_ID, CONF_MQTT_ID, CONF_ON_OPEN, @@ -20,6 +22,7 @@ from esphome.const import ( DEVICE_CLASS_WATER, ) from esphome.core import CORE, coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity IS_PLATFORM_COMPONENT = True @@ -71,7 +74,7 @@ ValveClosedTrigger = valve_ns.class_( CONF_ON_CLOSED = "on_closed" -VALVE_SCHEMA = ( +_VALVE_SCHEMA = ( cv.ENTITY_BASE_SCHEMA.extend(web_server.WEBSERVER_SORTING_SCHEMA) .extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA) .extend( @@ -100,7 +103,35 @@ VALVE_SCHEMA = ( ) -async def setup_valve_core_(var, config): +def valve_schema( + class_: MockObjClass = cv.UNDEFINED, + *, + device_class: str = cv.UNDEFINED, + entity_category: str = cv.UNDEFINED, + icon: str = cv.UNDEFINED, +) -> cv.Schema: + schema = {} + + if class_ is not cv.UNDEFINED: + schema[cv.GenerateID()] = cv.declare_id(class_) + + for key, default, validator in [ + (CONF_DEVICE_CLASS, device_class, cv.one_of(*DEVICE_CLASSES, lower=True)), + (CONF_ENTITY_CATEGORY, entity_category, cv.entity_category), + (CONF_ICON, icon, cv.icon), + ]: + if default is not cv.UNDEFINED: + schema[cv.Optional(key, default=default)] = validator + + return _VALVE_SCHEMA.extend(schema) + + +# Remove before 2025.11.0 +VALVE_SCHEMA = valve_schema() +VALVE_SCHEMA.add_extra(cv.deprecated_schema_constant("valve")) + + +async def _setup_valve_core(var, config): await setup_entity(var, config) if device_class_config := config.get(CONF_DEVICE_CLASS): @@ -132,7 +163,7 @@ async def register_valve(var, config): if not CORE.has_id(config[CONF_ID]): var = cg.Pvariable(config[CONF_ID], var) cg.add(cg.App.register_valve(var)) - await setup_valve_core_(var, config) + await _setup_valve_core(var, config) async def new_valve(config, *args): diff --git a/esphome/components/valve/automation.h b/esphome/components/valve/automation.h index 24c94a5570..f2c06270c0 100644 --- a/esphome/components/valve/automation.h +++ b/esphome/components/valve/automation.h @@ -1,7 +1,7 @@ #pragma once -#include "esphome/core/component.h" #include "esphome/core/automation.h" +#include "esphome/core/component.h" #include "valve.h" namespace esphome { @@ -67,24 +67,6 @@ template class ControlAction : public Action { Valve *valve_; }; -template class ValvePublishAction : public Action { - public: - ValvePublishAction(Valve *valve) : valve_(valve) {} - TEMPLATABLE_VALUE(float, position) - TEMPLATABLE_VALUE(ValveOperation, current_operation) - - void play(Ts... x) override { - if (this->position_.has_value()) - this->valve_->position = this->position_.value(x...); - if (this->current_operation_.has_value()) - this->valve_->current_operation = this->current_operation_.value(x...); - this->valve_->publish_state(); - } - - protected: - Valve *valve_; -}; - template class ValveIsOpenCondition : public Condition { public: ValveIsOpenCondition(Valve *valve) : valve_(valve) {} diff --git a/esphome/components/vl53l0x/sensor.py b/esphome/components/vl53l0x/sensor.py index 8055d5ff77..583d6ccca9 100644 --- a/esphome/components/vl53l0x/sensor.py +++ b/esphome/components/vl53l0x/sensor.py @@ -20,6 +20,7 @@ VL53L0XSensor = vl53l0x_ns.class_( CONF_SIGNAL_RATE_LIMIT = "signal_rate_limit" CONF_LONG_RANGE = "long_range" +CONF_TIMING_BUDGET = "timing_budget" def check_keys(obj): @@ -54,6 +55,13 @@ CONFIG_SCHEMA = cv.All( cv.Optional(CONF_LONG_RANGE, default=False): cv.boolean, cv.Optional(CONF_TIMEOUT, default="10ms"): check_timeout, cv.Optional(CONF_ENABLE_PIN): pins.gpio_output_pin_schema, + cv.Optional(CONF_TIMING_BUDGET): cv.All( + cv.positive_time_period_microseconds, + cv.Range( + min=cv.TimePeriod(microseconds=20000), + max=cv.TimePeriod(microseconds=4294967295), + ), + ), } ) .extend(cv.polling_component_schema("60s")) @@ -73,4 +81,7 @@ async def to_code(config): enable = await cg.gpio_pin_expression(config[CONF_ENABLE_PIN]) cg.add(var.set_enable_pin(enable)) + if timing_budget := config.get(CONF_TIMING_BUDGET): + cg.add(var.set_timing_budget(timing_budget)) + await i2c.register_i2c_device(var, config) diff --git a/esphome/components/vl53l0x/vl53l0x_sensor.cpp b/esphome/components/vl53l0x/vl53l0x_sensor.cpp index b07779a653..d0b7116eb8 100644 --- a/esphome/components/vl53l0x/vl53l0x_sensor.cpp +++ b/esphome/components/vl53l0x/vl53l0x_sensor.cpp @@ -28,6 +28,7 @@ void VL53L0XSensor::dump_config() { LOG_PIN(" Enable Pin: ", this->enable_pin_); } ESP_LOGCONFIG(TAG, " Timeout: %u%s", this->timeout_us_, this->timeout_us_ > 0 ? "us" : " (no timeout)"); + ESP_LOGCONFIG(TAG, " Timing Budget %uus ", this->measurement_timing_budget_us_); } void VL53L0XSensor::setup() { @@ -230,7 +231,10 @@ void VL53L0XSensor::setup() { reg(0x84) &= ~0x10; reg(0x0B) = 0x01; - measurement_timing_budget_us_ = get_measurement_timing_budget_(); + if (this->measurement_timing_budget_us_ == 0) { + this->measurement_timing_budget_us_ = this->get_measurement_timing_budget_(); + } + reg(0x01) = 0xE8; set_measurement_timing_budget_(measurement_timing_budget_us_); reg(0x01) = 0x01; diff --git a/esphome/components/vl53l0x/vl53l0x_sensor.h b/esphome/components/vl53l0x/vl53l0x_sensor.h index 971fb458bb..dd76e8e0ab 100644 --- a/esphome/components/vl53l0x/vl53l0x_sensor.h +++ b/esphome/components/vl53l0x/vl53l0x_sensor.h @@ -39,6 +39,7 @@ class VL53L0XSensor : public sensor::Sensor, public PollingComponent, public i2c void set_long_range(bool long_range) { long_range_ = long_range; } void set_timeout_us(uint32_t timeout_us) { this->timeout_us_ = timeout_us; } void set_enable_pin(GPIOPin *enable) { this->enable_pin_ = enable; } + void set_timing_budget(uint32_t timing_budget) { this->measurement_timing_budget_us_ = timing_budget; } protected: uint32_t get_measurement_timing_budget_(); @@ -59,7 +60,7 @@ class VL53L0XSensor : public sensor::Sensor, public PollingComponent, public i2c float signal_rate_limit_; bool long_range_; GPIOPin *enable_pin_{nullptr}; - uint32_t measurement_timing_budget_us_; + uint32_t measurement_timing_budget_us_{0}; bool initiated_read_{false}; bool waiting_for_interrupt_{false}; uint8_t stop_variable_; diff --git a/esphome/components/voice_assistant/__init__.py b/esphome/components/voice_assistant/__init__.py index a4fb572208..b9309ab422 100644 --- a/esphome/components/voice_assistant/__init__.py +++ b/esphome/components/voice_assistant/__init__.py @@ -1,7 +1,7 @@ from esphome import automation from esphome.automation import register_action, register_condition import esphome.codegen as cg -from esphome.components import media_player, microphone, speaker +from esphome.components import media_player, micro_wake_word, microphone, speaker import esphome.config_validation as cv from esphome.const import ( CONF_ID, @@ -41,6 +41,7 @@ CONF_AUTO_GAIN = "auto_gain" CONF_NOISE_SUPPRESSION_LEVEL = "noise_suppression_level" CONF_VOLUME_MULTIPLIER = "volume_multiplier" +CONF_MICRO_WAKE_WORD = "micro_wake_word" CONF_WAKE_WORD = "wake_word" CONF_CONVERSATION_TIMEOUT = "conversation_timeout" @@ -88,14 +89,22 @@ CONFIG_SCHEMA = cv.All( cv.Schema( { cv.GenerateID(): cv.declare_id(VoiceAssistant), - cv.GenerateID(CONF_MICROPHONE): cv.use_id(microphone.Microphone), - cv.Exclusive(CONF_SPEAKER, "output"): cv.use_id(speaker.Speaker), + cv.Optional( + CONF_MICROPHONE, default={} + ): microphone.microphone_source_schema( + min_bits_per_sample=16, + max_bits_per_sample=16, + min_channels=1, + max_channels=1, + ), cv.Exclusive(CONF_MEDIA_PLAYER, "output"): cv.use_id( media_player.MediaPlayer ), + cv.Exclusive(CONF_SPEAKER, "output"): cv.use_id(speaker.Speaker), cv.Optional(CONF_USE_WAKE_WORD, default=False): cv.boolean, - cv.Optional(CONF_VAD_THRESHOLD): cv.All( - cv.requires_component("esp_adf"), cv.only_with_esp_idf, cv.uint8_t + cv.Optional(CONF_MICRO_WAKE_WORD): cv.use_id(micro_wake_word.MicroWakeWord), + cv.Optional(CONF_VAD_THRESHOLD): cv.invalid( + "VAD threshold is no longer supported, as it requires the deprecated esp_adf external component. Use an i2s_audio microphone/speaker instead. Additionally, you may need to configure the audio_adc and audio_dac components depending on your hardware." ), cv.Optional(CONF_NOISE_SUPPRESSION_LEVEL, default=0): cv.int_range(0, 4), cv.Optional(CONF_AUTO_GAIN, default="0dBFS"): cv.All( @@ -163,22 +172,39 @@ CONFIG_SCHEMA = cv.All( tts_stream_validate, ) +FINAL_VALIDATE_SCHEMA = cv.All( + cv.Schema( + { + cv.Optional( + CONF_MICROPHONE + ): microphone.final_validate_microphone_source_schema( + "voice_assistant", sample_rate=16000 + ), + }, + extra=cv.ALLOW_EXTRA, + ), +) + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) - mic = await cg.get_variable(config[CONF_MICROPHONE]) - cg.add(var.set_microphone(mic)) + mic_source = await microphone.microphone_source_to_code(config[CONF_MICROPHONE]) + cg.add(var.set_microphone_source(mic_source)) - if CONF_SPEAKER in config: - spkr = await cg.get_variable(config[CONF_SPEAKER]) - cg.add(var.set_speaker(spkr)) + if CONF_MICRO_WAKE_WORD in config: + mww = await cg.get_variable(config[CONF_MICRO_WAKE_WORD]) + cg.add(var.set_micro_wake_word(mww)) if CONF_MEDIA_PLAYER in config: mp = await cg.get_variable(config[CONF_MEDIA_PLAYER]) cg.add(var.set_media_player(mp)) + if CONF_SPEAKER in config: + spkr = await cg.get_variable(config[CONF_SPEAKER]) + cg.add(var.set_speaker(spkr)) + cg.add(var.set_use_wake_word(config[CONF_USE_WAKE_WORD])) if (vad_threshold := config.get(CONF_VAD_THRESHOLD)) is not None: diff --git a/esphome/components/voice_assistant/voice_assistant.cpp b/esphome/components/voice_assistant/voice_assistant.cpp index 4b02867967..1aafea7d85 100644 --- a/esphome/components/voice_assistant/voice_assistant.cpp +++ b/esphome/components/voice_assistant/voice_assistant.cpp @@ -18,14 +18,25 @@ static const char *const TAG = "voice_assistant"; #endif static const size_t SAMPLE_RATE_HZ = 16000; -static const size_t INPUT_BUFFER_SIZE = 32 * SAMPLE_RATE_HZ / 1000; // 32ms * 16kHz / 1000ms -static const size_t BUFFER_SIZE = 512 * SAMPLE_RATE_HZ / 1000; -static const size_t SEND_BUFFER_SIZE = INPUT_BUFFER_SIZE * sizeof(int16_t); + +static const size_t RING_BUFFER_SAMPLES = 512 * SAMPLE_RATE_HZ / 1000; // 512 ms * 16 kHz/ 1000 ms +static const size_t RING_BUFFER_SIZE = RING_BUFFER_SAMPLES * sizeof(int16_t); +static const size_t SEND_BUFFER_SAMPLES = 32 * SAMPLE_RATE_HZ / 1000; // 32ms * 16kHz / 1000ms +static const size_t SEND_BUFFER_SIZE = SEND_BUFFER_SAMPLES * sizeof(int16_t); static const size_t RECEIVE_SIZE = 1024; static const size_t SPEAKER_BUFFER_SIZE = 16 * RECEIVE_SIZE; VoiceAssistant::VoiceAssistant() { global_voice_assistant = this; } +void VoiceAssistant::setup() { + this->mic_source_->add_data_callback([this](const std::vector &data) { + std::shared_ptr temp_ring_buffer = this->ring_buffer_; + if (this->ring_buffer_.use_count() > 1) { + temp_ring_buffer->write((void *) data.data(), data.size()); + } + }); +} + float VoiceAssistant::get_setup_priority() const { return setup_priority::AFTER_CONNECTION; } bool VoiceAssistant::start_udp_socket_() { @@ -72,12 +83,8 @@ bool VoiceAssistant::start_udp_socket_() { } bool VoiceAssistant::allocate_buffers_() { - if (this->send_buffer_ != nullptr) { - return true; // Already allocated - } - #ifdef USE_SPEAKER - if (this->speaker_ != nullptr) { + if ((this->speaker_ != nullptr) && (this->speaker_buffer_ == nullptr)) { ExternalRAMAllocator speaker_allocator(ExternalRAMAllocator::ALLOW_FAILURE); this->speaker_buffer_ = speaker_allocator.allocate(SPEAKER_BUFFER_SIZE); if (this->speaker_buffer_ == nullptr) { @@ -87,28 +94,21 @@ bool VoiceAssistant::allocate_buffers_() { } #endif - ExternalRAMAllocator allocator(ExternalRAMAllocator::ALLOW_FAILURE); - this->input_buffer_ = allocator.allocate(INPUT_BUFFER_SIZE); - if (this->input_buffer_ == nullptr) { - ESP_LOGW(TAG, "Could not allocate input buffer"); - return false; + if (this->ring_buffer_.use_count() == 0) { + this->ring_buffer_ = RingBuffer::create(RING_BUFFER_SIZE); + if (this->ring_buffer_.use_count() == 0) { + ESP_LOGE(TAG, "Could not allocate ring buffer"); + return false; + } } -#ifdef USE_ESP_ADF - this->vad_instance_ = vad_create(VAD_MODE_4); -#endif - - this->ring_buffer_ = RingBuffer::create(BUFFER_SIZE * sizeof(int16_t)); - if (this->ring_buffer_ == nullptr) { - ESP_LOGW(TAG, "Could not allocate ring buffer"); - return false; - } - - ExternalRAMAllocator send_allocator(ExternalRAMAllocator::ALLOW_FAILURE); - this->send_buffer_ = send_allocator.allocate(SEND_BUFFER_SIZE); - if (send_buffer_ == nullptr) { - ESP_LOGW(TAG, "Could not allocate send buffer"); - return false; + if (this->send_buffer_ == nullptr) { + ExternalRAMAllocator send_allocator(ExternalRAMAllocator::ALLOW_FAILURE); + this->send_buffer_ = send_allocator.allocate(SEND_BUFFER_SIZE); + if (send_buffer_ == nullptr) { + ESP_LOGW(TAG, "Could not allocate send buffer"); + return false; + } } return true; @@ -119,10 +119,6 @@ void VoiceAssistant::clear_buffers_() { memset(this->send_buffer_, 0, SEND_BUFFER_SIZE); } - if (this->input_buffer_ != nullptr) { - memset(this->input_buffer_, 0, INPUT_BUFFER_SIZE * sizeof(int16_t)); - } - if (this->ring_buffer_ != nullptr) { this->ring_buffer_->reset(); } @@ -139,26 +135,16 @@ void VoiceAssistant::clear_buffers_() { } void VoiceAssistant::deallocate_buffers_() { - ExternalRAMAllocator send_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); - send_deallocator.deallocate(this->send_buffer_, SEND_BUFFER_SIZE); - this->send_buffer_ = nullptr; + if (this->send_buffer_ != nullptr) { + ExternalRAMAllocator send_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); + send_deallocator.deallocate(this->send_buffer_, SEND_BUFFER_SIZE); + this->send_buffer_ = nullptr; + } - if (this->ring_buffer_ != nullptr) { + if (this->ring_buffer_.use_count() > 0) { this->ring_buffer_.reset(); - this->ring_buffer_ = nullptr; } -#ifdef USE_ESP_ADF - if (this->vad_instance_ != nullptr) { - vad_destroy(this->vad_instance_); - this->vad_instance_ = nullptr; - } -#endif - - ExternalRAMAllocator input_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); - input_deallocator.deallocate(this->input_buffer_, INPUT_BUFFER_SIZE); - this->input_buffer_ = nullptr; - #ifdef USE_SPEAKER if ((this->speaker_ != nullptr) && (this->speaker_buffer_ != nullptr)) { ExternalRAMAllocator speaker_deallocator(ExternalRAMAllocator::ALLOW_FAILURE); @@ -173,26 +159,10 @@ void VoiceAssistant::reset_conversation_id() { ESP_LOGD(TAG, "reset conversation ID"); } -int VoiceAssistant::read_microphone_() { - size_t bytes_read = 0; - if (this->mic_->is_running()) { // Read audio into input buffer - bytes_read = this->mic_->read(this->input_buffer_, INPUT_BUFFER_SIZE * sizeof(int16_t)); - if (bytes_read == 0) { - memset(this->input_buffer_, 0, INPUT_BUFFER_SIZE * sizeof(int16_t)); - return 0; - } - // Write audio into ring buffer - this->ring_buffer_->write((void *) this->input_buffer_, bytes_read); - } else { - ESP_LOGD(TAG, "microphone not running"); - } - return bytes_read; -} - void VoiceAssistant::loop() { if (this->api_client_ == nullptr && this->state_ != State::IDLE && this->state_ != State::STOP_MICROPHONE && this->state_ != State::STOPPING_MICROPHONE) { - if (this->mic_->is_running() || this->state_ == State::STARTING_MICROPHONE) { + if (this->mic_source_->is_running() || this->state_ == State::STARTING_MICROPHONE) { this->set_state_(State::STOP_MICROPHONE, State::IDLE); } else { this->set_state_(State::IDLE, State::IDLE); @@ -206,16 +176,9 @@ void VoiceAssistant::loop() { case State::IDLE: { if (this->continuous_ && this->desired_state_ == State::IDLE) { this->idle_trigger_->trigger(); -#ifdef USE_ESP_ADF - if (this->use_wake_word_) { - this->set_state_(State::START_MICROPHONE, State::WAIT_FOR_VAD); - } else -#endif - { - this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); - } + this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); } else { - this->high_freq_.stop(); + this->deallocate_buffers_(); } break; } @@ -230,53 +193,20 @@ void VoiceAssistant::loop() { } this->clear_buffers_(); - this->mic_->start(); - this->high_freq_.start(); + this->mic_source_->start(); this->set_state_(State::STARTING_MICROPHONE); break; } case State::STARTING_MICROPHONE: { - if (this->mic_->is_running()) { + if (this->mic_source_->is_running()) { this->set_state_(this->desired_state_); } break; } -#ifdef USE_ESP_ADF - case State::WAIT_FOR_VAD: { - this->read_microphone_(); - ESP_LOGD(TAG, "Waiting for speech..."); - this->set_state_(State::WAITING_FOR_VAD); - break; - } - case State::WAITING_FOR_VAD: { - size_t bytes_read = this->read_microphone_(); - if (bytes_read > 0) { - vad_state_t vad_state = - vad_process(this->vad_instance_, this->input_buffer_, SAMPLE_RATE_HZ, VAD_FRAME_LENGTH_MS); - if (vad_state == VAD_SPEECH) { - if (this->vad_counter_ < this->vad_threshold_) { - this->vad_counter_++; - } else { - ESP_LOGD(TAG, "VAD detected speech"); - this->set_state_(State::START_PIPELINE, State::STREAMING_MICROPHONE); - - // Reset for next time - this->vad_counter_ = 0; - } - } else { - if (this->vad_counter_ > 0) { - this->vad_counter_--; - } - } - } - break; - } -#endif case State::START_PIPELINE: { - this->read_microphone_(); ESP_LOGD(TAG, "Requesting start..."); uint32_t flags = 0; - if (this->use_wake_word_) + if (!this->continue_conversation_ && this->use_wake_word_) flags |= api::enums::VOICE_ASSISTANT_REQUEST_USE_WAKE_WORD; if (this->silence_detection_) flags |= api::enums::VOICE_ASSISTANT_REQUEST_USE_VAD; @@ -306,11 +236,9 @@ void VoiceAssistant::loop() { break; } case State::STARTING_PIPELINE: { - this->read_microphone_(); break; // State changed when udp server port received } case State::STREAMING_MICROPHONE: { - this->read_microphone_(); size_t available = this->ring_buffer_->available(); while (available >= SEND_BUFFER_SIZE) { size_t read_bytes = this->ring_buffer_->read((void *) this->send_buffer_, SEND_BUFFER_SIZE, 0); @@ -334,8 +262,8 @@ void VoiceAssistant::loop() { break; } case State::STOP_MICROPHONE: { - if (this->mic_->is_running()) { - this->mic_->stop(); + if (this->mic_source_->is_running()) { + this->mic_source_->stop(); this->set_state_(State::STOPPING_MICROPHONE); } else { this->set_state_(this->desired_state_); @@ -343,7 +271,7 @@ void VoiceAssistant::loop() { break; } case State::STOPPING_MICROPHONE: { - if (this->mic_->is_stopped()) { + if (this->mic_source_->is_stopped()) { this->set_state_(this->desired_state_); } break; @@ -387,6 +315,25 @@ void VoiceAssistant::loop() { #ifdef USE_MEDIA_PLAYER if (this->media_player_ != nullptr) { playing = (this->media_player_->state == media_player::MediaPlayerState::MEDIA_PLAYER_STATE_ANNOUNCING); + + if (playing && this->media_player_wait_for_announcement_start_) { + // Announcement has started playing, wait for it to finish + this->media_player_wait_for_announcement_start_ = false; + this->media_player_wait_for_announcement_end_ = true; + } + + if (!playing && this->media_player_wait_for_announcement_end_) { + // Announcement has finished playing + this->media_player_wait_for_announcement_end_ = false; + this->cancel_timeout("playing"); + ESP_LOGD(TAG, "Announcement finished playing"); + this->set_state_(State::RESPONSE_FINISHED, State::RESPONSE_FINISHED); + + api::VoiceAssistantAnnounceFinished msg; + msg.success = true; + this->api_client_->send_voice_assistant_announce_finished(msg); + break; + } } #endif if (playing) { @@ -417,7 +364,11 @@ void VoiceAssistant::loop() { this->tts_stream_end_trigger_->trigger(); } #endif - this->set_state_(State::IDLE, State::IDLE); + if (this->continue_conversation_) { + this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); + } else { + this->set_state_(State::IDLE, State::IDLE); + } break; } default: @@ -527,7 +478,7 @@ void VoiceAssistant::start_streaming() { ESP_LOGD(TAG, "Client started, streaming microphone"); this->audio_mode_ = AUDIO_MODE_API; - if (this->mic_->is_running()) { + if (this->mic_source_->is_running()) { this->set_state_(State::STREAMING_MICROPHONE, State::STREAMING_MICROPHONE); } else { this->set_state_(State::START_MICROPHONE, State::STREAMING_MICROPHONE); @@ -557,7 +508,7 @@ void VoiceAssistant::start_streaming(struct sockaddr_storage *addr, uint16_t por return; } - if (this->mic_->is_running()) { + if (this->mic_source_->is_running()) { this->set_state_(State::STREAMING_MICROPHONE, State::STREAMING_MICROPHONE); } else { this->set_state_(State::START_MICROPHONE, State::STREAMING_MICROPHONE); @@ -574,19 +525,14 @@ void VoiceAssistant::request_start(bool continuous, bool silence_detection) { if (this->state_ == State::IDLE) { this->continuous_ = continuous; this->silence_detection_ = silence_detection; -#ifdef USE_ESP_ADF - if (this->use_wake_word_) { - this->set_state_(State::START_MICROPHONE, State::WAIT_FOR_VAD); - } else -#endif - { - this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); - } + + this->set_state_(State::START_MICROPHONE, State::START_PIPELINE); } } void VoiceAssistant::request_stop() { this->continuous_ = false; + this->continue_conversation_ = false; switch (this->state_) { case State::IDLE: @@ -611,6 +557,16 @@ void VoiceAssistant::request_stop() { this->signal_stop_(); break; case State::STREAMING_RESPONSE: +#ifdef USE_MEDIA_PLAYER + // Stop any ongoing media player announcement + if (this->media_player_ != nullptr) { + this->media_player_->make_call() + .set_command(media_player::MEDIA_PLAYER_COMMAND_STOP) + .set_announcement(true) + .perform(); + } +#endif + break; case State::RESPONSE_FINISHED: break; // Let the incoming audio stream finish then it will go to idle. } @@ -628,9 +584,9 @@ void VoiceAssistant::signal_stop_() { } void VoiceAssistant::start_playback_timeout_() { - this->set_timeout("playing", 100, [this]() { + this->set_timeout("playing", 2000, [this]() { this->cancel_timeout("speaker-timeout"); - this->set_state_(State::IDLE, State::IDLE); + this->set_state_(State::RESPONSE_FINISHED, State::RESPONSE_FINISHED); api::VoiceAssistantAnnounceFinished msg; msg.success = true; @@ -679,6 +635,8 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) { for (auto arg : msg.data) { if (arg.name == "conversation_id") { this->conversation_id_ = std::move(arg.value); + } else if (arg.name == "continue_conversation") { + this->continue_conversation_ = (arg.value == "1"); } } this->defer([this]() { this->intent_end_trigger_->trigger(); }); @@ -722,6 +680,9 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) { #ifdef USE_MEDIA_PLAYER if (this->media_player_ != nullptr) { this->media_player_->make_call().set_media_url(url).set_announcement(true).perform(); + + this->media_player_wait_for_announcement_start_ = true; + this->media_player_wait_for_announcement_end_ = false; // Start the playback timeout, as the media player state isn't immediately updated this->start_playback_timeout_(); } @@ -734,21 +695,13 @@ void VoiceAssistant::on_event(const api::VoiceAssistantEventResponse &msg) { } case api::enums::VOICE_ASSISTANT_RUN_END: { ESP_LOGD(TAG, "Assist Pipeline ended"); - if ((this->state_ == State::STARTING_PIPELINE) || (this->state_ == State::AWAITING_RESPONSE)) { - // Pipeline ended before starting microphone - // Or there wasn't a TTS start event ("nevermind") + if ((this->state_ == State::START_PIPELINE) || (this->state_ == State::STARTING_PIPELINE) || + (this->state_ == State::STREAMING_MICROPHONE)) { + // Microphone is running, stop it + this->set_state_(State::STOP_MICROPHONE, State::IDLE); + } else if (this->state_ == State::AWAITING_RESPONSE) { + // No TTS start event ("nevermind") this->set_state_(State::IDLE, State::IDLE); - } else if (this->state_ == State::STREAMING_MICROPHONE) { - this->ring_buffer_->reset(); -#ifdef USE_ESP_ADF - if (this->use_wake_word_) { - // No need to stop the microphone since we didn't use the speaker - this->set_state_(State::WAIT_FOR_VAD, State::WAITING_FOR_VAD); - } else -#endif - { - this->set_state_(State::IDLE, State::IDLE); - } } this->defer([this]() { this->end_trigger_->trigger(); }); break; @@ -888,14 +841,87 @@ void VoiceAssistant::on_announce(const api::VoiceAssistantAnnounceRequest &msg) #ifdef USE_MEDIA_PLAYER if (this->media_player_ != nullptr) { this->tts_start_trigger_->trigger(msg.text); - this->media_player_->make_call().set_media_url(msg.media_id).set_announcement(true).perform(); - this->set_state_(State::STREAMING_RESPONSE, State::STREAMING_RESPONSE); + if (!msg.preannounce_media_id.empty()) { + this->media_player_->make_call().set_media_url(msg.preannounce_media_id).set_announcement(true).perform(); + } + // Enqueueing a URL with an empty playlist will still play the file immediately + this->media_player_->make_call() + .set_command(media_player::MEDIA_PLAYER_COMMAND_ENQUEUE) + .set_media_url(msg.media_id) + .set_announcement(true) + .perform(); + this->continue_conversation_ = msg.start_conversation; + + this->media_player_wait_for_announcement_start_ = true; + this->media_player_wait_for_announcement_end_ = false; + // Start the playback timeout, as the media player state isn't immediately updated + this->start_playback_timeout_(); + + if (this->continuous_) { + this->set_state_(State::STOP_MICROPHONE, State::STREAMING_RESPONSE); + } else { + this->set_state_(State::STREAMING_RESPONSE, State::STREAMING_RESPONSE); + } + this->tts_end_trigger_->trigger(msg.media_id); this->end_trigger_->trigger(); } #endif } +void VoiceAssistant::on_set_configuration(const std::vector &active_wake_words) { +#ifdef USE_MICRO_WAKE_WORD + if (this->micro_wake_word_) { + // Disable all wake words first + for (auto &model : this->micro_wake_word_->get_wake_words()) { + model->disable(); + } + + // Enable only active wake words + for (auto ww_id : active_wake_words) { + for (auto &model : this->micro_wake_word_->get_wake_words()) { + if (model->get_id() == ww_id) { + model->enable(); + ESP_LOGD(TAG, "Enabled wake word: %s (id=%s)", model->get_wake_word().c_str(), model->get_id().c_str()); + } + } + } + } +#endif +}; + +const Configuration &VoiceAssistant::get_configuration() { + this->config_.available_wake_words.clear(); + this->config_.active_wake_words.clear(); + +#ifdef USE_MICRO_WAKE_WORD + if (this->micro_wake_word_) { + this->config_.max_active_wake_words = 1; + + for (auto &model : this->micro_wake_word_->get_wake_words()) { + if (model->is_enabled()) { + this->config_.active_wake_words.push_back(model->get_id()); + } + + WakeWord wake_word; + wake_word.id = model->get_id(); + wake_word.wake_word = model->get_wake_word(); + for (const auto &lang : model->get_trained_languages()) { + wake_word.trained_languages.push_back(lang); + } + this->config_.available_wake_words.push_back(std::move(wake_word)); + } + } else { +#endif + // No microWakeWord + this->config_.max_active_wake_words = 0; +#ifdef USE_MICRO_WAKE_WORD + } +#endif + + return this->config_; +}; + VoiceAssistant *global_voice_assistant = nullptr; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) } // namespace voice_assistant diff --git a/esphome/components/voice_assistant/voice_assistant.h b/esphome/components/voice_assistant/voice_assistant.h index 12124c1486..865731522f 100644 --- a/esphome/components/voice_assistant/voice_assistant.h +++ b/esphome/components/voice_assistant/voice_assistant.h @@ -11,18 +11,17 @@ #include "esphome/components/api/api_connection.h" #include "esphome/components/api/api_pb2.h" -#include "esphome/components/microphone/microphone.h" -#ifdef USE_SPEAKER -#include "esphome/components/speaker/speaker.h" -#endif +#include "esphome/components/microphone/microphone_source.h" #ifdef USE_MEDIA_PLAYER #include "esphome/components/media_player/media_player.h" #endif -#include "esphome/components/socket/socket.h" - -#ifdef USE_ESP_ADF -#include +#ifdef USE_MICRO_WAKE_WORD +#include "esphome/components/micro_wake_word/micro_wake_word.h" #endif +#ifdef USE_SPEAKER +#include "esphome/components/speaker/speaker.h" +#endif +#include "esphome/components/socket/socket.h" #include #include @@ -41,6 +40,7 @@ enum VoiceAssistantFeature : uint32_t { FEATURE_API_AUDIO = 1 << 2, FEATURE_TIMERS = 1 << 3, FEATURE_ANNOUNCE = 1 << 4, + FEATURE_START_CONVERSATION = 1 << 5, }; enum class State { @@ -95,12 +95,16 @@ class VoiceAssistant : public Component { VoiceAssistant(); void loop() override; + void setup() override; float get_setup_priority() const override; void start_streaming(); void start_streaming(struct sockaddr_storage *addr, uint16_t port); void failed_to_start(); - void set_microphone(microphone::Microphone *mic) { this->mic_ = mic; } + void set_microphone_source(microphone::MicrophoneSource *mic_source) { this->mic_source_ = mic_source; } +#ifdef USE_MICRO_WAKE_WORD + void set_micro_wake_word(micro_wake_word::MicroWakeWord *mww) { this->micro_wake_word_ = mww; } +#endif #ifdef USE_SPEAKER void set_speaker(speaker::Speaker *speaker) { this->speaker_ = speaker; @@ -140,6 +144,7 @@ class VoiceAssistant : public Component { #ifdef USE_MEDIA_PLAYER if (this->media_player_ != nullptr) { flags |= VoiceAssistantFeature::FEATURE_ANNOUNCE; + flags |= VoiceAssistantFeature::FEATURE_START_CONVERSATION; } #endif @@ -153,17 +158,14 @@ class VoiceAssistant : public Component { void on_audio(const api::VoiceAssistantAudio &msg); void on_timer_event(const api::VoiceAssistantTimerEventResponse &msg); void on_announce(const api::VoiceAssistantAnnounceRequest &msg); - void on_set_configuration(const std::vector &active_wake_words){}; - const Configuration &get_configuration() { return this->config_; }; + void on_set_configuration(const std::vector &active_wake_words); + const Configuration &get_configuration(); bool is_running() const { return this->state_ != State::IDLE; } void set_continuous(bool continuous) { this->continuous_ = continuous; } bool is_continuous() const { return this->continuous_; } void set_use_wake_word(bool use_wake_word) { this->use_wake_word_ = use_wake_word; } -#ifdef USE_ESP_ADF - void set_vad_threshold(uint8_t vad_threshold) { this->vad_threshold_ = vad_threshold; } -#endif void set_noise_suppression_level(uint8_t noise_suppression_level) { this->noise_suppression_level_ = noise_suppression_level; @@ -212,7 +214,6 @@ class VoiceAssistant : public Component { void clear_buffers_(); void deallocate_buffers_(); - int read_microphone_(); void set_state_(State state); void set_state_(State state, State desired_state); void signal_stop_(); @@ -254,7 +255,7 @@ class VoiceAssistant : public Component { bool has_timers_{false}; bool timer_tick_running_{false}; - microphone::Microphone *mic_{nullptr}; + microphone::MicrophoneSource *mic_source_{nullptr}; #ifdef USE_SPEAKER void write_speaker_(); speaker::Speaker *speaker_{nullptr}; @@ -267,6 +268,8 @@ class VoiceAssistant : public Component { #endif #ifdef USE_MEDIA_PLAYER media_player::MediaPlayer *media_player_{nullptr}; + bool media_player_wait_for_announcement_start_{false}; + bool media_player_wait_for_announcement_end_{false}; #endif bool local_output_{false}; @@ -275,14 +278,7 @@ class VoiceAssistant : public Component { std::string wake_word_{""}; - HighFrequencyLoopRequester high_freq_; - -#ifdef USE_ESP_ADF - vad_handle_t vad_instance_; - uint8_t vad_threshold_{5}; - uint8_t vad_counter_{0}; -#endif - std::unique_ptr ring_buffer_; + std::shared_ptr ring_buffer_; bool use_wake_word_; uint8_t noise_suppression_level_; @@ -291,11 +287,12 @@ class VoiceAssistant : public Component { uint32_t conversation_timeout_; uint8_t *send_buffer_{nullptr}; - int16_t *input_buffer_{nullptr}; bool continuous_{false}; bool silence_detection_; + bool continue_conversation_{false}; + State state_{State::IDLE}; State desired_state_{State::IDLE}; @@ -304,6 +301,10 @@ class VoiceAssistant : public Component { bool start_udp_socket_(); Configuration config_{}; + +#ifdef USE_MICRO_WAKE_WORD + micro_wake_word::MicroWakeWord *micro_wake_word_{nullptr}; +#endif }; template class StartAction : public Action, public Parented { diff --git a/esphome/components/waveshare_epaper/display.py b/esphome/components/waveshare_epaper/display.py index 8acb6ac68f..cea0b2be5e 100644 --- a/esphome/components/waveshare_epaper/display.py +++ b/esphome/components/waveshare_epaper/display.py @@ -70,12 +70,16 @@ WaveshareEPaper4P2InBV2 = waveshare_epaper_ns.class_( WaveshareEPaper4P2InBV2BWR = waveshare_epaper_ns.class_( "WaveshareEPaper4P2InBV2BWR", WaveshareEPaperBWR ) +WaveshareEPaper5P65InF = waveshare_epaper_ns.class_( + "WaveshareEPaper5P65InF", WaveshareEPaper7C +) WaveshareEPaper5P8In = waveshare_epaper_ns.class_( "WaveshareEPaper5P8In", WaveshareEPaper ) WaveshareEPaper5P8InV2 = waveshare_epaper_ns.class_( "WaveshareEPaper5P8InV2", WaveshareEPaper ) +GDEY0583T81 = waveshare_epaper_ns.class_("GDEY0583T81", WaveshareEPaper) WaveshareEPaper7P3InF = waveshare_epaper_ns.class_( "WaveshareEPaper7P3InF", WaveshareEPaper7C ) @@ -150,8 +154,10 @@ MODELS = { "4.20in": ("b", WaveshareEPaper4P2In), "4.20in-bv2": ("b", WaveshareEPaper4P2InBV2), "4.20in-bv2-bwr": ("b", WaveshareEPaper4P2InBV2BWR), + "5.65in-f": ("b", WaveshareEPaper5P65InF), "5.83in": ("b", WaveshareEPaper5P8In), "5.83inv2": ("b", WaveshareEPaper5P8InV2), + "gdey0583t81": ("c", GDEY0583T81), "7.30in-f": ("b", WaveshareEPaper7P3InF), "7.50in": ("b", WaveshareEPaper7P5In), "7.50in-bv2": ("b", WaveshareEPaper7P5InBV2), diff --git a/esphome/components/waveshare_epaper/waveshare_epaper.cpp b/esphome/components/waveshare_epaper/waveshare_epaper.cpp index 96fc82fcdd..79aae70e41 100644 --- a/esphome/components/waveshare_epaper/waveshare_epaper.cpp +++ b/esphome/components/waveshare_epaper/waveshare_epaper.cpp @@ -258,6 +258,47 @@ void WaveshareEPaper7C::fill(Color color) { } } } +void WaveshareEPaper7C::send_buffers_() { + if (this->buffers_[0] == nullptr) { + ESP_LOGE(TAG, "Buffer unavailable!"); + return; + } + + uint32_t small_buffer_length = this->get_buffer_length_() / NUM_BUFFERS; + uint8_t byte_to_send; + for (auto &buffer : this->buffers_) { + for (uint32_t buffer_pos = 0; buffer_pos < small_buffer_length; buffer_pos += 3) { + std::bitset<24> triplet = + buffer[buffer_pos + 0] << 16 | buffer[buffer_pos + 1] << 8 | buffer[buffer_pos + 2] << 0; + // 8 bitset<3> are stored in 3 bytes + // |aaabbbaa|abbbaaab|bbaaabbb| + // | byte 1 | byte 2 | byte 3 | + byte_to_send = ((triplet >> 17).to_ulong() & 0b01110000) | ((triplet >> 18).to_ulong() & 0b00000111); + this->data(byte_to_send); + + byte_to_send = ((triplet >> 11).to_ulong() & 0b01110000) | ((triplet >> 12).to_ulong() & 0b00000111); + this->data(byte_to_send); + + byte_to_send = ((triplet >> 5).to_ulong() & 0b01110000) | ((triplet >> 6).to_ulong() & 0b00000111); + this->data(byte_to_send); + + byte_to_send = ((triplet << 1).to_ulong() & 0b01110000) | ((triplet << 0).to_ulong() & 0b00000111); + this->data(byte_to_send); + } + App.feed_wdt(); + } +} +void WaveshareEPaper7C::reset_() { + if (this->reset_pin_ != nullptr) { + this->reset_pin_->digital_write(true); + delay(20); + this->reset_pin_->digital_write(false); + delay(1); + this->reset_pin_->digital_write(true); + delay(20); + } +} + void HOT WaveshareEPaper::draw_absolute_pixel_internal(int x, int y, Color color) { if (x >= this->get_width_internal() || y >= this->get_height_internal() || x < 0 || y < 0) return; @@ -963,7 +1004,7 @@ void WaveshareEPaper1P54InBV2::initialize() { this->command(0x4E); // set RAM x address count to 0; this->data(0x00); - this->command(0x4F); // set RAM y address count to 0X199; + this->command(0x4F); // set RAM y address count to 0x199; this->data(0xC7); this->data(0x00); @@ -1837,7 +1878,7 @@ void GDEY029T94::initialize() { this->command(0x4E); // set RAM x address count to 0; this->data(0x00); - this->command(0x4F); // set RAM y address count to 0X199; + this->command(0x4F); // set RAM y address count to 0x199; this->command(0x00); this->command(0x00); this->wait_until_idle_(); @@ -2029,7 +2070,7 @@ void GDEW029T5::init_full_() { this->init_display_(); this->command(0x82); // vcom_DC setting this->data(0x08); - this->command(0X50); // VCOM AND DATA INTERVAL SETTING + this->command(0x50); // VCOM AND DATA INTERVAL SETTING this->data(0x97); // WBmode:VBDF 17|D7 VBDW 97 VBDB 57 WBRmode:VBDF F7 VBDW 77 VBDB 37 VBDR B7 this->command(0x20); this->write_lut_(LUT_20_VCOMDC_29_5, sizeof(LUT_20_VCOMDC_29_5)); @@ -2049,7 +2090,7 @@ void GDEW029T5::init_partial_() { this->init_display_(); this->command(0x82); // vcom_DC setting this->data(0x08); - this->command(0X50); // VCOM AND DATA INTERVAL SETTING + this->command(0x50); // VCOM AND DATA INTERVAL SETTING this->data(0x17); // WBmode:VBDF 17|D7 VBDW 97 VBDB 57 WBRmode:VBDF F7 VBDW 77 VBDB 37 VBDR B7 this->command(0x20); this->write_lut_(LUT_20_VCOMDC_PARTIAL_29_5, sizeof(LUT_20_VCOMDC_PARTIAL_29_5)); @@ -2897,6 +2938,223 @@ void WaveshareEPaper5P8InV2::dump_config() { LOG_UPDATE_INTERVAL(this); } +// ======================================================== +// Good Display 5.83in black/white GDEY0583T81 +// Product page: +// - https://www.good-display.com/product/440.html +// - https://www.seeedstudio.com/5-83-Monochrome-ePaper-Display-with-648x480-Pixels-p-5785.html +// Datasheet: +// - +// https://www.good-display.com/public/html/pdfjs/viewer/viewernew.html?file=https://v4.cecdn.yun300.cn/100001_1909185148/GDEY0583T81-new.pdf +// - https://v4.cecdn.yun300.cn/100001_1909185148/GDEY0583T81-new.pdf +// Reference code from GoodDisplay: +// - https://www.good-display.com/companyfile/903.html +// ======================================================== + +void GDEY0583T81::initialize() { + // Allocate buffer for old data for partial updates + RAMAllocator allocator{}; + this->old_buffer_ = allocator.allocate(this->get_buffer_length_()); + if (this->old_buffer_ == nullptr) { + ESP_LOGE(TAG, "Could not allocate old buffer for display!"); + return; + } + memset(this->old_buffer_, 0xFF, this->get_buffer_length_()); + + this->init_full_(); + + this->wait_until_idle_(); + + this->deep_sleep(); +} + +void GDEY0583T81::power_on_() { + if (!this->power_is_on_) { + this->command(0x04); + this->wait_until_idle_(); + } + this->power_is_on_ = true; + this->is_deep_sleep_ = false; +} + +void GDEY0583T81::power_off_() { + this->command(0x02); + this->wait_until_idle_(); + this->power_is_on_ = false; +} + +void GDEY0583T81::deep_sleep() { + if (this->is_deep_sleep_) { + return; + } + + // VCOM and data interval setting (CDI) + this->command(0x50); + this->data(0xf7); + + this->power_off_(); + delay(10); + + // Deep sleep (DSLP) + this->command(0x07); + this->data(0xA5); + this->is_deep_sleep_ = true; +} + +void GDEY0583T81::reset_() { + if (this->reset_pin_ != nullptr) { + this->reset_pin_->digital_write(false); + delay(10); + this->reset_pin_->digital_write(true); + delay(10); + } +} + +// Initialize for full screen update in fast mode +void GDEY0583T81::init_full_() { + this->init_display_(); + + // Based on the GD sample code + // VCOM and data interval setting (CDI) + this->command(0x50); + this->data(0x29); + this->data(0x07); + + // Cascade Setting (CCSET) + this->command(0xE0); + this->data(0x02); + + // Force Temperature (TSSET) + this->command(0xE5); + this->data(0x5A); +} + +// Initialize for a partial update of the full screen +void GDEY0583T81::init_partial_() { + this->init_display_(); + + // Cascade Setting (CCSET) + this->command(0xE0); + this->data(0x02); + + // Force Temperature (TSSET) + this->command(0xE5); + this->data(0x6E); +} + +void GDEY0583T81::init_display_() { + this->reset_(); + + // Panel Setting (PSR) + this->command(0x00); + // Sets: REG=0, LUT from OTP (set by CDI) + // KW/R=1, Sets KW mode (Black/White) + // as opposed to the default KWR mode (Black/White/Red) + // UD=1, Gate Scan Direction, 1 = up (default) + // SHL=1, Source Shift Direction, 1 = right (default) + // SHD_N=1, Booster Switch, 1 = ON (default) + // RST_N=1, Soft reset, 1 = No effect (default) + this->data(0x1F); + + // Resolution setting (TRES) + this->command(0x61); + + // Horizontal display resolution (HRES) + this->data(get_width_internal() / 256); + this->data(get_width_internal() % 256); + + // Vertical display resolution (VRES) + this->data(get_height_internal() / 256); + this->data(get_height_internal() % 256); + + this->power_on_(); +} + +void HOT GDEY0583T81::display() { + bool full_update = this->at_update_ == 0; + if (full_update) { + this->init_full_(); + } else { + this->init_partial_(); + + // VCOM and data interval setting (CDI) + this->command(0x50); + this->data(0xA9); + this->data(0x07); + + // Partial In (PTIN), makes the display enter partial mode + this->command(0x91); + + // Partial Window (PTL) + // We use the full screen as the window + this->command(0x90); + + // Horizontal start/end channel bank (HRST/HRED) + this->data(0); + this->data(0); + this->data((get_width_internal() - 1) / 256); + this->data((get_width_internal() - 1) % 256); + + // Vertical start/end line (VRST/VRED) + this->data(0); + this->data(0); + this->data((get_height_internal() - 1) / 256); + this->data((get_height_internal() - 1) % 256); + + this->data(0x01); + + // Display Start Transmission 1 (DTM1) + // in KW mode this writes "OLD" data to SRAM + this->command(0x10); + this->start_data_(); + this->write_array(this->old_buffer_, this->get_buffer_length_()); + this->end_data_(); + } + + // Display Start Transmission 2 (DTM2) + // in KW mode this writes "NEW" data to SRAM + this->command(0x13); + this->start_data_(); + this->write_array(this->buffer_, this->get_buffer_length_()); + this->end_data_(); + + for (size_t i = 0; i < this->get_buffer_length_(); i++) { + this->old_buffer_[i] = this->buffer_[i]; + } + + // Display Refresh (DRF) + this->command(0x12); + delay(10); + this->wait_until_idle_(); + + if (full_update) { + ESP_LOGD(TAG, "Full update done"); + } else { + // Partial out (PTOUT), makes the display exit partial mode + this->command(0x92); + ESP_LOGD(TAG, "Partial update done, next full update after %d cycles", + this->full_update_every_ - this->at_update_ - 1); + } + + this->at_update_ = (this->at_update_ + 1) % this->full_update_every_; + + this->deep_sleep(); +} + +void GDEY0583T81::set_full_update_every(uint32_t full_update_every) { this->full_update_every_ = full_update_every; } +int GDEY0583T81::get_width_internal() { return 648; } +int GDEY0583T81::get_height_internal() { return 480; } +uint32_t GDEY0583T81::idle_timeout_() { return 5000; } +void GDEY0583T81::dump_config() { + LOG_DISPLAY("", "GoodDisplay E-Paper", this); + ESP_LOGCONFIG(TAG, " Model: 5.83in B/W GDEY0583T81"); + ESP_LOGCONFIG(TAG, " Full Update Every: %" PRIu32, this->full_update_every_); + LOG_PIN(" Reset Pin: ", this->reset_pin_); + LOG_PIN(" DC Pin: ", this->dc_pin_); + LOG_PIN(" Busy Pin: ", this->busy_pin_); + LOG_UPDATE_INTERVAL(this); +} + void WaveshareEPaper7P5InBV2::initialize() { // COMMAND POWER SETTING this->command(0x01); @@ -3307,6 +3565,175 @@ void WaveshareEPaper7P5In::dump_config() { LOG_PIN(" Busy Pin: ", this->busy_pin_); LOG_UPDATE_INTERVAL(this); } + +// Waveshare 5.65F ======================================================== + +namespace cmddata_5P65InF { +// WaveshareEPaper5P65InF commands +// https://www.waveshare.com/wiki/5.65inch_e-Paper_Module_(F) + +// R00H (PSR): Panel setting Register +// UD(1): scan up +// SHL(1) shift right +// SHD_N(1) DC-DC on +// RST_N(1) no reset +static const uint8_t R00_CMD_PSR[] = {0x00, 0xEF, 0x08}; + +// R01H (PWR): Power setting Register +// internal DC-DC power generation +static const uint8_t R01_CMD_PWR[] = {0x01, 0x07, 0x00, 0x00, 0x00}; + +// R02H (POF): Power OFF Command +static const uint8_t R02_CMD_POF[] = {0x02}; + +// R03H (PFS): Power off sequence setting Register +// T_VDS_OFF (00) = 1 frame +static const uint8_t R03_CMD_PFS[] = {0x03, 0x00}; + +// R04H (PON): Power ON Command +static const uint8_t R04_CMD_PON[] = {0x04}; + +// R06h (BTST): Booster Soft Start +static const uint8_t R06_CMD_BTST[] = {0x06, 0xC7, 0xC7, 0x1D}; + +// R07H (DSLP): Deep sleep# +// Note Documentation @ Waveshare shows cmd code as 0x10 in table, but +// 0x10 is DTM1. +static const uint8_t R07_CMD_DSLP[] = {0x07, 0xA5}; + +// R10H (DTM1): Data Start Transmission 1 + +static const uint8_t R10_CMD_DTM1[] = {0x10}; + +// R11H (DSP): Data Stop +static const uint8_t R11_CMD_DSP[] = {0x11}; + +// R12H (DRF): Display Refresh +static const uint8_t R12_CMD_DRF[] = {0x12}; + +// R13H (IPC): Image Process Command +static const uint8_t R13_CMD_IPC[] = {0x13, 0x00}; + +// R30H (PLL): PLL Control +// 0x3C = 50Hz +static const uint8_t R30_CMD_PLL[] = {0x30, 0x3C}; + +// R41H (TSE): Temperature Sensor Enable +// TSE(0) enable, TO(0000) +0 degree offset +static const uint8_t R41_CMD_TSE[] = {0x41, 0x00}; + +// R50H (CDI) VCOM and Data interval setting +// CDI(0111) 10 +// DDX(1), VBD(001) Border output "White" +static const uint8_t R50_CMD_CDI[] = {0x50, 0x37}; + +// R60H (TCON) Gate and Source non overlap period command +// S2G(10) 12 units +// G2S(10) 12 units +static const uint8_t R60_CMD_TCON[] = {0x60, 0x22}; + +// R61H (TRES) Resolution Setting +// 0x258 = 600 +// 0x1C0 = 448 +static const uint8_t R61_CMD_TRES[] = {0x61, 0x02, 0x58, 0x01, 0xC0}; + +// RE3H (PWS) Power Savings +static const uint8_t RE3_CMD_PWS[] = {0xE3, 0xAA}; +} // namespace cmddata_5P65InF + +void WaveshareEPaper5P65InF::initialize() { + if (this->buffers_[0] == nullptr) { + ESP_LOGE(TAG, "Buffer unavailable!"); + return; + } + + this->reset_(); + delay(20); + this->wait_until_(IDLE); + + using namespace cmddata_5P65InF; + + this->cmd_data(R00_CMD_PSR, sizeof(R00_CMD_PSR)); + this->cmd_data(R01_CMD_PWR, sizeof(R01_CMD_PWR)); + this->cmd_data(R03_CMD_PFS, sizeof(R03_CMD_PFS)); + this->cmd_data(R06_CMD_BTST, sizeof(R06_CMD_BTST)); + this->cmd_data(R30_CMD_PLL, sizeof(R30_CMD_PLL)); + this->cmd_data(R41_CMD_TSE, sizeof(R41_CMD_TSE)); + this->cmd_data(R50_CMD_CDI, sizeof(R50_CMD_CDI)); + this->cmd_data(R60_CMD_TCON, sizeof(R60_CMD_TCON)); + this->cmd_data(R61_CMD_TRES, sizeof(R61_CMD_TRES)); + this->cmd_data(RE3_CMD_PWS, sizeof(RE3_CMD_PWS)); + + delay(100); // NOLINT + this->cmd_data(R50_CMD_CDI, sizeof(R50_CMD_CDI)); + + ESP_LOGI(TAG, "Display initialized successfully"); +} + +void HOT WaveshareEPaper5P65InF::display() { + // INITIALIZATION + ESP_LOGI(TAG, "Initialise the display"); + this->initialize(); + + using namespace cmddata_5P65InF; + + // COMMAND DATA START TRANSMISSION + ESP_LOGI(TAG, "Sending data to the display"); + this->cmd_data(R61_CMD_TRES, sizeof(R61_CMD_TRES)); + this->cmd_data(R10_CMD_DTM1, sizeof(R10_CMD_DTM1)); + this->send_buffers_(); + + // COMMAND POWER ON + ESP_LOGI(TAG, "Power on the display"); + this->cmd_data(R04_CMD_PON, sizeof(R04_CMD_PON)); + this->wait_until_(IDLE); + + // COMMAND REFRESH SCREEN + ESP_LOGI(TAG, "Refresh the display"); + this->cmd_data(R12_CMD_DRF, sizeof(R12_CMD_DRF)); + this->wait_until_(IDLE); + + // COMMAND POWER OFF + ESP_LOGI(TAG, "Power off the display"); + this->cmd_data(R02_CMD_POF, sizeof(R02_CMD_POF)); + this->wait_until_(BUSY); + + if (this->deep_sleep_between_updates_) { + ESP_LOGI(TAG, "Set the display to deep sleep"); + this->cmd_data(R07_CMD_DSLP, sizeof(R07_CMD_DSLP)); + } +} + +int WaveshareEPaper5P65InF::get_width_internal() { return 600; } +int WaveshareEPaper5P65InF::get_height_internal() { return 448; } +uint32_t WaveshareEPaper5P65InF::idle_timeout_() { return 35000; } + +void WaveshareEPaper5P65InF::dump_config() { + LOG_DISPLAY("", "Waveshare E-Paper", this); + ESP_LOGCONFIG(TAG, " Model: 5.65in-F"); + LOG_PIN(" Reset Pin: ", this->reset_pin_); + LOG_PIN(" DC Pin: ", this->dc_pin_); + LOG_PIN(" Busy Pin: ", this->busy_pin_); + LOG_UPDATE_INTERVAL(this); +} + +bool WaveshareEPaper5P65InF::wait_until_(WaitForState busy_state) { + if (this->busy_pin_ == nullptr) { + return true; + } + + const uint32_t start = millis(); + while (busy_state != this->busy_pin_->digital_read()) { + if (millis() - start > this->idle_timeout_()) { + ESP_LOGE(TAG, "Timeout while displaying image!"); + return false; + } + App.feed_wdt(); + delay(10); + } + return true; +} + void WaveshareEPaper7P3InF::initialize() { if (this->buffers_[0] == nullptr) { ESP_LOGE(TAG, "Buffer unavailable!"); @@ -3411,11 +3838,6 @@ void WaveshareEPaper7P3InF::initialize() { ESP_LOGI(TAG, "Display initialized successfully"); } void HOT WaveshareEPaper7P3InF::display() { - if (this->buffers_[0] == nullptr) { - ESP_LOGE(TAG, "Buffer unavailable!"); - return; - } - // INITIALIZATION ESP_LOGI(TAG, "Initialise the display"); this->initialize(); @@ -3423,29 +3845,7 @@ void HOT WaveshareEPaper7P3InF::display() { // COMMAND DATA START TRANSMISSION ESP_LOGI(TAG, "Sending data to the display"); this->command(0x10); - uint32_t small_buffer_length = this->get_buffer_length_() / NUM_BUFFERS; - uint8_t byte_to_send; - for (auto &buffer : this->buffers_) { - for (uint32_t buffer_pos = 0; buffer_pos < small_buffer_length; buffer_pos += 3) { - std::bitset<24> triplet = - buffer[buffer_pos + 0] << 16 | buffer[buffer_pos + 1] << 8 | buffer[buffer_pos + 2] << 0; - // 8 bitset<3> are stored in 3 bytes - // |aaabbbaa|abbbaaab|bbaaabbb| - // | byte 1 | byte 2 | byte 3 | - byte_to_send = ((triplet >> 17).to_ulong() & 0b01110000) | ((triplet >> 18).to_ulong() & 0b00000111); - this->data(byte_to_send); - - byte_to_send = ((triplet >> 11).to_ulong() & 0b01110000) | ((triplet >> 12).to_ulong() & 0b00000111); - this->data(byte_to_send); - - byte_to_send = ((triplet >> 5).to_ulong() & 0b01110000) | ((triplet >> 6).to_ulong() & 0b00000111); - this->data(byte_to_send); - - byte_to_send = ((triplet << 1).to_ulong() & 0b01110000) | ((triplet << 0).to_ulong() & 0b00000111); - this->data(byte_to_send); - } - App.feed_wdt(); - } + this->send_buffers_(); // COMMAND POWER ON ESP_LOGI(TAG, "Power on the display"); @@ -3464,9 +3864,11 @@ void HOT WaveshareEPaper7P3InF::display() { this->data(0x00); this->wait_until_idle_(); - ESP_LOGI(TAG, "Set the display to deep sleep"); - this->command(0x07); - this->data(0xA5); + if (this->deep_sleep_between_updates_) { + ESP_LOGI(TAG, "Set the display to deep sleep"); + this->command(0x07); + this->data(0xA5); + } } int WaveshareEPaper7P3InF::get_width_internal() { return 800; } int WaveshareEPaper7P3InF::get_height_internal() { return 480; } @@ -4079,10 +4481,10 @@ void WaveshareEPaper7P5InHDB::initialize() { this->data(0x01); // LUT1, for white this->command(0x18); - this->data(0X80); + this->data(0x80); this->command(0x22); - this->data(0XB1); // Load Temperature and waveform setting. + this->data(0xB1); // Load Temperature and waveform setting. this->command(0x20); diff --git a/esphome/components/waveshare_epaper/waveshare_epaper.h b/esphome/components/waveshare_epaper/waveshare_epaper.h index d6387cd643..74bb153519 100644 --- a/esphome/components/waveshare_epaper/waveshare_epaper.h +++ b/esphome/components/waveshare_epaper/waveshare_epaper.h @@ -94,7 +94,10 @@ class WaveshareEPaper7C : public WaveshareEPaperBase { void draw_absolute_pixel_internal(int x, int y, Color color) override; uint32_t get_buffer_length_() override; void setup() override; + void init_internal_7c_(uint32_t buffer_length); + void send_buffers_(); + void reset_(); static const int NUM_BUFFERS = 10; uint8_t *buffers_[NUM_BUFFERS]; @@ -683,6 +686,63 @@ class WaveshareEPaper5P8InV2 : public WaveshareEPaper { int get_height_internal() override; }; +class GDEY0583T81 : public WaveshareEPaper { + public: + void initialize() override; + + void display() override; + + void dump_config() override; + + void deep_sleep() override; + + void set_full_update_every(uint32_t full_update_every); + + protected: + int get_width_internal() override; + int get_height_internal() override; + uint32_t idle_timeout_() override; + + private: + void power_on_(); + void power_off_(); + void reset_(); + void update_full_(); + void update_part_(); + void init_full_(); + void init_partial_(); + void init_display_(); + + uint32_t full_update_every_{30}; + uint32_t at_update_{0}; + bool power_is_on_{false}; + bool is_deep_sleep_{false}; + uint8_t *old_buffer_{nullptr}; +}; + +class WaveshareEPaper5P65InF : public WaveshareEPaper7C { + public: + void initialize() override; + + void display() override; + + void dump_config() override; + + protected: + int get_width_internal() override; + + int get_height_internal() override; + + uint32_t idle_timeout_() override; + + void deep_sleep() override { ; } + + enum WaitForState { BUSY = true, IDLE = false }; + bool wait_until_(WaitForState state); + + bool deep_sleep_between_updates_{true}; +}; + class WaveshareEPaper7P3InF : public WaveshareEPaper7C { public: void initialize() override; @@ -703,17 +763,6 @@ class WaveshareEPaper7P3InF : public WaveshareEPaper7C { bool wait_until_idle_(); bool deep_sleep_between_updates_{true}; - - void reset_() { - if (this->reset_pin_ != nullptr) { - this->reset_pin_->digital_write(true); - delay(20); - this->reset_pin_->digital_write(false); - delay(1); - this->reset_pin_->digital_write(true); - delay(20); - } - }; }; class WaveshareEPaper7P5In : public WaveshareEPaper { diff --git a/esphome/components/xpt2046/touchscreen/xpt2046.cpp b/esphome/components/xpt2046/touchscreen/xpt2046.cpp index a4e2b84656..aa11ed4b77 100644 --- a/esphome/components/xpt2046/touchscreen/xpt2046.cpp +++ b/esphome/components/xpt2046/touchscreen/xpt2046.cpp @@ -32,7 +32,7 @@ void XPT2046Component::update_touches() { int16_t touch_pressure_1 = this->read_adc_(0xB1 /* touch_pressure_1 */); int16_t touch_pressure_2 = this->read_adc_(0xC1 /* touch_pressure_2 */); - z_raw = touch_pressure_1 + 0Xfff - touch_pressure_2; + z_raw = touch_pressure_1 + 0xfff - touch_pressure_2; ESP_LOGVV(TAG, "Touchscreen Update z = %d", z_raw); touch = (z_raw >= this->threshold_); if (touch) { diff --git a/esphome/config_validation.py b/esphome/config_validation.py index e6927fd20c..88a805591d 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -56,7 +56,6 @@ from esphome.const import ( KEY_CORE, KEY_FRAMEWORK_VERSION, KEY_TARGET_FRAMEWORK, - KEY_TARGET_PLATFORM, PLATFORM_ESP32, PLATFORM_ESP8266, PLATFORM_RP2040, @@ -117,7 +116,7 @@ RequiredFieldInvalid = vol.RequiredFieldInvalid ROOT_CONFIG_PATH = object() RESERVED_IDS = [ - # C++ keywords http://en.cppreference.com/w/cpp/keyword + # C++ keywords https://en.cppreference.com/w/cpp/keyword "alarm", "alignas", "alignof", @@ -1942,70 +1941,28 @@ def platformio_version_constraint(value): def require_framework_version( *, - esp_idf=None, - esp32_arduino=None, - esp8266_arduino=None, - rp2040_arduino=None, - bk72xx_libretiny=None, - host=None, max_version=False, extra_message=None, + **kwargs, ): def validator(value): core_data = CORE.data[KEY_CORE] framework = core_data[KEY_TARGET_FRAMEWORK] - if framework == "esp-idf": - if esp_idf is None: - msg = "This feature is incompatible with esp-idf" - if extra_message: - msg += f". {extra_message}" - raise Invalid(msg) - required = esp_idf - elif CORE.is_bk72xx and framework == "arduino": - if bk72xx_libretiny is None: - msg = "This feature is incompatible with BK72XX" - if extra_message: - msg += f". {extra_message}" - raise Invalid(msg) - required = bk72xx_libretiny - elif CORE.is_esp32 and framework == "arduino": - if esp32_arduino is None: - msg = "This feature is incompatible with ESP32 using arduino framework" - if extra_message: - msg += f". {extra_message}" - raise Invalid(msg) - required = esp32_arduino - elif CORE.is_esp8266 and framework == "arduino": - if esp8266_arduino is None: - msg = "This feature is incompatible with ESP8266" - if extra_message: - msg += f". {extra_message}" - raise Invalid(msg) - required = esp8266_arduino - elif CORE.is_rp2040 and framework == "arduino": - if rp2040_arduino is None: - msg = "This feature is incompatible with RP2040" - if extra_message: - msg += f". {extra_message}" - raise Invalid(msg) - required = rp2040_arduino - elif CORE.is_host and framework == "host": - if host is None: - msg = "This feature is incompatible with host platform" - if extra_message: - msg += f". {extra_message}" - raise Invalid(msg) - required = host - else: - raise Invalid( - f""" - Internal Error: require_framework_version does not support this platform configuration - platform: {core_data[KEY_TARGET_PLATFORM]} - framework: {framework} - Please report this issue on GitHub -> https://github.com/esphome/issues/issues/new?template=bug_report.yml. - """ - ) + if CORE.is_host and framework == "host": + key = "host" + elif framework == "esp-idf": + key = "esp_idf" + else: + key = CORE.target_platform + "_" + framework + + if key not in kwargs: + msg = f"This feature is incompatible with {CORE.target_platform.upper()} using {framework} framework" + if extra_message: + msg += f". {extra_message}" + raise Invalid(msg) + + required = kwargs[key] if max_version: if core_data[KEY_FRAMEWORK_VERSION] > required: @@ -2115,3 +2072,20 @@ def rename_key(old_key, new_key): return config return validator + + +# Remove before 2025.11.0 +def deprecated_schema_constant(entity_type: str): + def validator(config): + _LOGGER.warning( + "Using `%s.%s_SCHEMA` is deprecated and will be removed in ESPHome 2025.11.0. " + "Please use `%s.%s_schema(...)` instead. " + "If you are seeing this, report an issue to the external_component author and ask them to update it.", + entity_type, + entity_type.upper(), + entity_type, + entity_type, + ) + return config + + return validator diff --git a/esphome/const.py b/esphome/const.py index 20d64513c9..0974a673ec 100644 --- a/esphome/const.py +++ b/esphome/const.py @@ -1,6 +1,6 @@ """Constants used by esphome.""" -__version__ = "2025.4.2" +__version__ = "2025.5.0b1" ALLOWED_NAME_CHARS = "abcdefghijklmnopqrstuvwxyz0123456789-_" VALID_SUBSTITUTIONS_CHARACTERS = ( @@ -45,6 +45,8 @@ CONF_ALLOW_OTHER_USES = "allow_other_uses" CONF_ALPHA = "alpha" CONF_ALTITUDE = "altitude" CONF_AMBIENT_LIGHT = "ambient_light" +CONF_AMBIENT_PRESSURE_COMPENSATION = "ambient_pressure_compensation" +CONF_AMBIENT_PRESSURE_COMPENSATION_SOURCE = "ambient_pressure_compensation_source" CONF_AMMONIA = "ammonia" CONF_ANALOG = "analog" CONF_AND = "and" @@ -63,6 +65,7 @@ CONF_AUTH = "auth" CONF_AUTO_CLEAR_ENABLED = "auto_clear_enabled" CONF_AUTO_MODE = "auto_mode" CONF_AUTOCONF = "autoconf" +CONF_AUTOMATIC_SELF_CALIBRATION = "automatic_self_calibration" CONF_AUTOMATION_ID = "automation_id" CONF_AVAILABILITY = "availability" CONF_AWAY = "away" @@ -157,6 +160,7 @@ CONF_CONDITION = "condition" CONF_CONDITION_ID = "condition_id" CONF_CONDUCTIVITY = "conductivity" CONF_CONSTANT_BRIGHTNESS = "constant_brightness" +CONF_CONTINUOUS = "continuous" CONF_CONTRAST = "contrast" CONF_COOL_ACTION = "cool_action" CONF_COOL_DEADBAND = "cool_deadband" @@ -217,7 +221,9 @@ CONF_DIMENSIONS = "dimensions" CONF_DIO_PIN = "dio_pin" CONF_DIR_PIN = "dir_pin" CONF_DIRECTION = "direction" +CONF_DIRECTION_COMMAND_TOPIC = "direction_command_topic" CONF_DIRECTION_OUTPUT = "direction_output" +CONF_DIRECTION_STATE_TOPIC = "direction_state_topic" CONF_DISABLE_CRC = "disable_crc" CONF_DISABLED = "disabled" CONF_DISABLED_BY_DEFAULT = "disabled_by_default" @@ -330,6 +336,7 @@ CONF_FULL_SPECTRUM = "full_spectrum" CONF_FULL_SPECTRUM_COUNTS = "full_spectrum_counts" CONF_FULL_UPDATE_EVERY = "full_update_every" CONF_GAIN = "gain" +CONF_GAIN_FACTOR = "gain_factor" CONF_GAMMA_CORRECT = "gamma_correct" CONF_GAS_RESISTANCE = "gas_resistance" CONF_GATEWAY = "gateway" @@ -401,6 +408,7 @@ CONF_INITIAL_OPTION = "initial_option" CONF_INITIAL_STATE = "initial_state" CONF_INITIAL_VALUE = "initial_value" CONF_INPUT = "input" +CONF_INT_DATAPOINT = "int_datapoint" CONF_INTEGRATION_TIME = "integration_time" CONF_INTENSITY = "intensity" CONF_INTERLOCK = "interlock" @@ -477,6 +485,7 @@ CONF_MAX_VALUE = "max_value" CONF_MAX_VOLTAGE = "max_voltage" CONF_MDNS = "mdns" CONF_MEASUREMENT_DURATION = "measurement_duration" +CONF_MEASUREMENT_MODE = "measurement_mode" CONF_MEASUREMENT_SEQUENCE_NUMBER = "measurement_sequence_number" CONF_MEDIA_PLAYER = "media_player" CONF_MEDIUM = "medium" @@ -795,6 +804,7 @@ CONF_SHUTDOWN_MESSAGE = "shutdown_message" CONF_SIGNAL_STRENGTH = "signal_strength" CONF_SINGLE_LIGHT_ID = "single_light_id" CONF_SIZE = "size" +CONF_SKIP_CERT_CN_CHECK = "skip_cert_cn_check" CONF_SLEEP_DURATION = "sleep_duration" CONF_SLEEP_PIN = "sleep_pin" CONF_SLEEP_WHEN_DONE = "sleep_when_done" @@ -891,6 +901,8 @@ CONF_TIMES = "times" CONF_TIMEZONE = "timezone" CONF_TIMING = "timing" CONF_TO = "to" +CONF_TO_NTC_RESISTANCE = "to_ntc_resistance" +CONF_TO_NTC_TEMPERATURE = "to_ntc_temperature" CONF_TOLERANCE = "tolerance" CONF_TOPIC = "topic" CONF_TOPIC_PREFIX = "topic_prefix" @@ -1080,6 +1092,7 @@ UNIT_KILOWATT = "kW" UNIT_KILOWATT_HOURS = "kWh" UNIT_LITRE = "L" UNIT_LUX = "lx" +UNIT_MEGAJOULE = "MJ" UNIT_METER = "m" UNIT_METER_PER_SECOND_SQUARED = "m/s²" UNIT_MICROAMP = "µA" @@ -1093,6 +1106,7 @@ UNIT_MILLIGRAMS_PER_CUBIC_METER = "mg/m³" UNIT_MILLIMETER = "mm" UNIT_MILLISECOND = "ms" UNIT_MILLISIEMENS_PER_CENTIMETER = "mS/cm" +UNIT_MILLIVOLT = "mV" UNIT_MINUTE = "min" UNIT_OHM = "Ω" UNIT_PARTS_PER_BILLION = "ppb" diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 1a81a6d6cd..3a02c95c82 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -518,6 +518,8 @@ class EsphomeCore: self.verbose = False # Whether ESPHome was started in quiet mode self.quiet = False + # A list of all known ID classes + self.id_classes = {} def reset(self): from esphome.pins import PIN_SCHEMA_REGISTRY diff --git a/esphome/core/application.cpp b/esphome/core/application.cpp index a4550bcd9e..3f5a283fd8 100644 --- a/esphome/core/application.cpp +++ b/esphome/core/application.cpp @@ -70,6 +70,7 @@ void Application::loop() { this->feed_wdt(); for (Component *component : this->looping_components_) { { + this->set_current_component(component); WarnIfComponentBlockingGuard guard{component}; component->call(); } diff --git a/esphome/core/application.h b/esphome/core/application.h index 462beb1f25..e64e2b7655 100644 --- a/esphome/core/application.h +++ b/esphome/core/application.h @@ -97,6 +97,9 @@ class Application { this->compilation_time_ = compilation_time; } + void set_current_component(Component *component) { this->current_component_ = component; } + Component *get_current_component() { return this->current_component_; } + #ifdef USE_BINARY_SENSOR void register_binary_sensor(binary_sensor::BinarySensor *binary_sensor) { this->binary_sensors_.push_back(binary_sensor); @@ -547,6 +550,7 @@ class Application { uint32_t loop_interval_{16}; size_t dump_config_at_{SIZE_MAX}; uint32_t app_state_{0}; + Component *current_component_{nullptr}; }; /// Global storage of Application pointer - only one Application can exist. diff --git a/esphome/core/automation.h b/esphome/core/automation.h index e77e453431..02c9d44f16 100644 --- a/esphome/core/automation.h +++ b/esphome/core/automation.h @@ -1,10 +1,11 @@ #pragma once -#include #include "esphome/core/component.h" -#include "esphome/core/helpers.h" #include "esphome/core/defines.h" +#include "esphome/core/helpers.h" #include "esphome/core/preferences.h" +#include +#include namespace esphome { @@ -27,7 +28,7 @@ template class TemplatableValue { TemplatableValue() : type_(NONE) {} template::value, int> = 0> - TemplatableValue(F value) : type_(VALUE), value_(value) {} + TemplatableValue(F value) : type_(VALUE), value_(std::move(value)) {} template::value, int> = 0> TemplatableValue(F f) : type_(LAMBDA), f_(f) {} diff --git a/esphome/core/component.cpp b/esphome/core/component.cpp index b20964b872..a7e451b93d 100644 --- a/esphome/core/component.cpp +++ b/esphome/core/component.cpp @@ -39,6 +39,9 @@ const uint32_t STATUS_LED_OK = 0x0000; const uint32_t STATUS_LED_WARNING = 0x0100; const uint32_t STATUS_LED_ERROR = 0x0200; +const uint32_t WARN_IF_BLOCKING_OVER_MS = 50U; ///< Initial blocking time allowed without warning +const uint32_t WARN_IF_BLOCKING_INCREMENT_MS = 10U; ///< How long the blocking time must be larger to warn again + uint32_t global_state = 0; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) float Component::get_loop_priority() const { return 0.0f; } @@ -115,6 +118,13 @@ const char *Component::get_component_source() const { return ""; return this->component_source_; } +bool Component::should_warn_of_blocking(uint32_t blocking_time) { + if (blocking_time > this->warn_if_blocking_over_) { + this->warn_if_blocking_over_ = blocking_time + WARN_IF_BLOCKING_INCREMENT_MS; + return true; + } + return false; +} void Component::mark_failed() { ESP_LOGE(TAG, "Component %s was marked as failed.", this->get_component_source()); this->component_state_ &= ~COMPONENT_STATE_MASK; @@ -233,10 +243,16 @@ void PollingComponent::set_update_interval(uint32_t update_interval) { this->upd WarnIfComponentBlockingGuard::WarnIfComponentBlockingGuard(Component *component) : started_(millis()), component_(component) {} WarnIfComponentBlockingGuard::~WarnIfComponentBlockingGuard() { - uint32_t now = millis(); - if (now - started_ > 50) { + uint32_t blocking_time = millis() - this->started_; + bool should_warn; + if (this->component_ != nullptr) { + should_warn = this->component_->should_warn_of_blocking(blocking_time); + } else { + should_warn = blocking_time > WARN_IF_BLOCKING_OVER_MS; + } + if (should_warn) { const char *src = component_ == nullptr ? "" : component_->get_component_source(); - ESP_LOGW(TAG, "Component %s took a long time for an operation (%" PRIu32 " ms).", src, (now - started_)); + ESP_LOGW(TAG, "Component %s took a long time for an operation (%" PRIu32 " ms).", src, blocking_time); ESP_LOGW(TAG, "Components should block for at most 30 ms."); ; } diff --git a/esphome/core/component.h b/esphome/core/component.h index f5c56459b1..412074282d 100644 --- a/esphome/core/component.h +++ b/esphome/core/component.h @@ -65,6 +65,8 @@ extern const uint32_t STATUS_LED_ERROR; enum class RetryResult { DONE, RETRY }; +extern const uint32_t WARN_IF_BLOCKING_OVER_MS; + class Component { public: /** Where the component's initialization should happen. @@ -158,6 +160,8 @@ class Component { */ const char *get_component_source() const; + bool should_warn_of_blocking(uint32_t blocking_time); + protected: friend class Application; @@ -284,6 +288,7 @@ class Component { uint32_t component_state_{0x0000}; ///< State of this component. float setup_priority_override_{NAN}; const char *component_source_{nullptr}; + uint32_t warn_if_blocking_over_{WARN_IF_BLOCKING_OVER_MS}; std::string error_message_{}; }; diff --git a/esphome/core/defines.h b/esphome/core/defines.h index 64de41f23a..8bc554d5f4 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -20,11 +20,6 @@ // Feature flags #define USE_ALARM_CONTROL_PANEL -#define USE_AUDIO_FLAC_SUPPORT -#define USE_AUDIO_MP3_SUPPORT -#define USE_API -#define USE_API_NOISE -#define USE_API_PLAINTEXT #define USE_BINARY_SENSOR #define USE_BUTTON #define USE_CLIMATE @@ -79,20 +74,10 @@ #define USE_LVGL_TEXTAREA #define USE_LVGL_TILEVIEW #define USE_LVGL_TOUCHSCREEN -#define USE_MD5 #define USE_MDNS #define USE_MEDIA_PLAYER -#define USE_MQTT -#define USE_NETWORK #define USE_NEXTION_TFT_UPLOAD #define USE_NUMBER -#define USE_ONLINE_IMAGE_BMP_SUPPORT -#define USE_ONLINE_IMAGE_PNG_SUPPORT -#define USE_ONLINE_IMAGE_JPEG_SUPPORT -#define USE_OTA -#define USE_OTA_PASSWORD -#define USE_OTA_STATE_CALLBACK -#define USE_OTA_VERSION 2 #define USE_OUTPUT #define USE_POWER_SUPPLY #define USE_QR_CODE @@ -107,14 +92,34 @@ #define USE_UART_DEBUGGER #define USE_UPDATE #define USE_VALVE + +// Feature flags which do not work for zephyr +#ifndef USE_ZEPHYR +#define USE_AUDIO_FLAC_SUPPORT +#define USE_AUDIO_MP3_SUPPORT +#define USE_API +#define USE_API_NOISE +#define USE_API_PLAINTEXT +#define USE_MD5 +#define USE_MQTT +#define USE_NETWORK +#define USE_ONLINE_IMAGE_BMP_SUPPORT +#define USE_ONLINE_IMAGE_PNG_SUPPORT +#define USE_ONLINE_IMAGE_JPEG_SUPPORT +#define USE_OTA +#define USE_OTA_PASSWORD +#define USE_OTA_STATE_CALLBACK +#define USE_OTA_VERSION 2 #define USE_WIFI #define USE_WIFI_AP #define USE_WIREGUARD +#endif // Arduino-specific feature flags #ifdef USE_ARDUINO #define USE_PROMETHEUS #define USE_WIFI_WPA2_EAP +#define USE_I2S_LEGACY #endif // IDF-specific feature flags @@ -131,7 +136,6 @@ #define USE_ESP32_BLE_SERVER #define USE_ESP32_CAMERA #define USE_IMPROV -#define USE_MICRO_WAKE_WORD_VAD #define USE_MICROPHONE #define USE_PSRAM #define USE_SOCKET_IMPL_BSD_SOCKETS @@ -148,7 +152,9 @@ #endif #ifdef USE_ESP_IDF -#define USE_ESP_IDF_VERSION_CODE VERSION_CODE(5, 1, 5) +#define USE_ESP_IDF_VERSION_CODE VERSION_CODE(5, 1, 6) +#define USE_MICRO_WAKE_WORD +#define USE_MICRO_WAKE_WORD_VAD #endif #if defined(USE_ESP32_VARIANT_ESP32S2) diff --git a/esphome/core/macros.h b/esphome/core/macros.h index ee53d20ad1..8b2383321b 100644 --- a/esphome/core/macros.h +++ b/esphome/core/macros.h @@ -2,3 +2,7 @@ // Helper macro to define a version code, whose value can be compared against other version codes. #define VERSION_CODE(major, minor, patch) ((major) << 16 | (minor) << 8 | (patch)) + +#ifdef USE_ARDUINO +#include +#endif diff --git a/esphome/core/scheduler.cpp b/esphome/core/scheduler.cpp index 7e83b3b705..b4f617d405 100644 --- a/esphome/core/scheduler.cpp +++ b/esphome/core/scheduler.cpp @@ -1,4 +1,6 @@ #include "scheduler.h" + +#include "application.h" #include "esphome/core/defines.h" #include "esphome/core/log.h" #include "esphome/core/helpers.h" @@ -215,6 +217,7 @@ void HOT Scheduler::call() { this->pop_raw_(); continue; } + App.set_current_component(item->component); #ifdef ESPHOME_DEBUG_SCHEDULER ESP_LOGV(TAG, "Running %s '%s/%s' with interval=%" PRIu32 " next_execution=%" PRIu64 " (now=%" PRIu64 ")", diff --git a/esphome/cpp_generator.py b/esphome/cpp_generator.py index eb0bd25d1d..93ebb4cb95 100644 --- a/esphome/cpp_generator.py +++ b/esphome/cpp_generator.py @@ -789,13 +789,17 @@ class MockObj(Expression): def class_(self, name: str, *parents: "MockObjClass") -> "MockObjClass": op = "" if self.op == "" else "::" - return MockObjClass(f"{self.base}{op}{name}", ".", parents=parents) + result = MockObjClass(f"{self.base}{op}{name}", ".", parents=parents) + CORE.id_classes[str(result)] = result + return result def struct(self, name: str) -> "MockObjClass": return self.class_(name) def enum(self, name: str, is_class: bool = False) -> "MockObj": - return MockObjEnum(enum=name, is_class=is_class, base=self.base, op=self.op) + result = MockObjEnum(enum=name, is_class=is_class, base=self.base, op=self.op) + CORE.id_classes[str(result)] = result + return result def operator(self, name: str) -> "MockObj": """Various other operations. diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index 9c20cf4f58..6196e01760 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -38,7 +38,7 @@ import yaml from yaml.nodes import Node from esphome import const, platformio_api, yaml_util -from esphome.helpers import get_bool_env, mkdir_p +from esphome.helpers import get_bool_env, mkdir_p, sort_ip_addresses from esphome.storage_json import ( StorageJSON, archive_storage_path, @@ -336,7 +336,7 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): # Use the IP address if available but only # if the API is loaded and the device is online # since MQTT logging will not work otherwise - port = address_list[0] + port = sort_ip_addresses(address_list)[0] elif ( entry.address and ( @@ -347,7 +347,7 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): and not isinstance(address_list, Exception) ): # If mdns is not available, try to use the DNS cache - port = address_list[0] + port = sort_ip_addresses(address_list)[0] return [ *DASHBOARD_COMMAND, diff --git a/esphome/helpers.py b/esphome/helpers.py index 8aae43c2bb..b649465d69 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -200,6 +200,45 @@ def resolve_ip_address(host, port): return res +def sort_ip_addresses(address_list: list[str]) -> list[str]: + """Takes a list of IP addresses in string form, e.g. from mDNS or MQTT, + and sorts them into the best order to actually try connecting to them. + + This is roughly based on RFC6724 but a lot simpler: First we choose + IPv6 addresses, then Legacy IP addresses, and lowest priority is + link-local IPv6 addresses that don't have a link specified (which + are useless, but mDNS does provide them in that form). Addresses + which cannot be parsed are silently dropped. + """ + import socket + + # First "resolve" all the IP addresses to getaddrinfo() tuples of the form + # (family, type, proto, canonname, sockaddr) + res: list[ + tuple[ + int, + int, + int, + Union[str, None], + Union[tuple[str, int], tuple[str, int, int, int]], + ] + ] = [] + for addr in address_list: + # This should always work as these are supposed to be IP addresses + try: + res += socket.getaddrinfo( + addr, 0, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST + ) + except OSError: + _LOGGER.info("Failed to parse IP address '%s'", addr) + + # Now use that information to sort them. + res.sort(key=addr_preference_) + + # Finally, turn the getaddrinfo() tuples back into plain hostnames. + return [socket.getnameinfo(r[4], socket.NI_NUMERICHOST)[0] for r in res] + + def get_bool_env(var, default=False): value = os.getenv(var, default) if isinstance(value, str): diff --git a/esphome/loader.py b/esphome/loader.py index 0fb4187b04..dbaa2ac661 100644 --- a/esphome/loader.py +++ b/esphome/loader.py @@ -91,6 +91,10 @@ class ComponentManifest: def codeowners(self) -> list[str]: return getattr(self.module, "CODEOWNERS", []) + @property + def instance_type(self) -> list[str]: + return getattr(self.module, "INSTANCE_TYPE", None) + @property def final_validate_schema(self) -> Optional[Callable[[ConfigType], None]]: """Components can declare a `FINAL_VALIDATE_SCHEMA` cv.Schema that gets called diff --git a/esphome/mqtt.py b/esphome/mqtt.py index 2f90c49025..2403a4a1d9 100644 --- a/esphome/mqtt.py +++ b/esphome/mqtt.py @@ -3,6 +3,7 @@ import hashlib import json import logging import ssl +import tempfile import time import paho.mqtt.client as mqtt @@ -10,6 +11,8 @@ import paho.mqtt.client as mqtt from esphome.const import ( CONF_BROKER, CONF_CERTIFICATE_AUTHORITY, + CONF_CLIENT_CERTIFICATE, + CONF_CLIENT_CERTIFICATE_KEY, CONF_DISCOVERY_PREFIX, CONF_ESPHOME, CONF_LOG_TOPIC, @@ -17,6 +20,7 @@ from esphome.const import ( CONF_NAME, CONF_PASSWORD, CONF_PORT, + CONF_SKIP_CERT_CN_CHECK, CONF_SSL_FINGERPRINTS, CONF_TOPIC, CONF_TOPIC_PREFIX, @@ -102,15 +106,24 @@ def prepare( if config[CONF_MQTT].get(CONF_SSL_FINGERPRINTS) or config[CONF_MQTT].get( CONF_CERTIFICATE_AUTHORITY ): - tls_version = ssl.PROTOCOL_TLS # pylint: disable=no-member - client.tls_set( - ca_certs=None, - certfile=None, - keyfile=None, - cert_reqs=ssl.CERT_REQUIRED, - tls_version=tls_version, - ciphers=None, + context = ssl.create_default_context( + cadata=config[CONF_MQTT].get(CONF_CERTIFICATE_AUTHORITY) ) + if config[CONF_MQTT].get(CONF_SKIP_CERT_CN_CHECK): + context.check_hostname = False + if config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE) and config[CONF_MQTT].get( + CONF_CLIENT_CERTIFICATE_KEY + ): + with ( + tempfile.NamedTemporaryFile(mode="w+") as cert_file, + tempfile.NamedTemporaryFile(mode="w+") as key_file, + ): + cert_file.write(config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE)) + cert_file.flush() + key_file.write(config[CONF_MQTT].get(CONF_CLIENT_CERTIFICATE_KEY)) + key_file.flush() + context.load_cert_chain(cert_file, key_file) + client.tls_set_context(context) try: host = str(config[CONF_MQTT][CONF_BROKER]) diff --git a/esphome/platformio_api.py b/esphome/platformio_api.py index b81ec4ab37..ed95fa125e 100644 --- a/esphome/platformio_api.py +++ b/esphome/platformio_api.py @@ -53,7 +53,7 @@ FILTER_PLATFORMIO_LINES = [ f"You can ignore this message, if `.*{IGNORE_LIB_WARNINGS}.*` is a built-in library.*", r"Scanning dependencies...", r"Found \d+ compatible libraries", - r"Memory Usage -> http://bit.ly/pio-memory-usage", + r"Memory Usage -> https://bit.ly/pio-memory-usage", r"Found: https://platformio.org/lib/show/.*", r"Using cache: .*", r"Installing dependencies", diff --git a/esphome/schema_extractors.py b/esphome/schema_extractors.py index 5491bc88c4..a84e08a8d3 100644 --- a/esphome/schema_extractors.py +++ b/esphome/schema_extractors.py @@ -42,7 +42,6 @@ def schema_extractor_extended(func): def decorate(*args, **kwargs): ret = func(*args, **kwargs) - assert len(args) == 2 extended_schemas[repr(ret)] = args return ret diff --git a/esphome/vscode.py b/esphome/vscode.py index fb62b60eac..d8cfe91938 100644 --- a/esphome/vscode.py +++ b/esphome/vscode.py @@ -7,6 +7,7 @@ from typing import Any from esphome.config import Config, _format_vol_invalid, validate_config import esphome.config_validation as cv +from esphome.const import __version__ as ESPHOME_VERSION from esphome.core import CORE, DocumentRange from esphome.yaml_util import parse_yaml @@ -97,7 +98,21 @@ def _ace_loader(fname: str) -> dict[str, Any]: return parse_yaml(fname, raw_yaml_stream) +def _print_version(): + """Print ESPHome version.""" + print( + json.dumps( + { + "type": "version", + "value": ESPHOME_VERSION, + } + ) + ) + + def read_config(args): + _print_version() + while True: CORE.reset() data = json.loads(input()) diff --git a/esphome/wizard.py b/esphome/wizard.py index 7fdf245c76..8c5bd07e1f 100644 --- a/esphome/wizard.py +++ b/esphome/wizard.py @@ -361,11 +361,11 @@ def wizard(path): if platform == "ESP32": board_link = ( - "http://docs.platformio.org/en/latest/platforms/espressif32.html#boards" + "https://docs.platformio.org/en/latest/platforms/espressif32.html#boards" ) elif platform == "ESP8266": board_link = ( - "http://docs.platformio.org/en/latest/platforms/espressif8266.html#boards" + "https://docs.platformio.org/en/latest/platforms/espressif8266.html#boards" ) elif platform == "RP2040": board_link = ( diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index b235f06786..c6a143a42f 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -18,6 +18,8 @@ from esphome.storage_json import StorageJSON, ext_storage_path _LOGGER = logging.getLogger(__name__) +DEFAULT_TIMEOUT = 10.0 +DEFAULT_TIMEOUT_MS = DEFAULT_TIMEOUT * 1000 _BACKGROUND_TASKS: set[asyncio.Task] = set() @@ -107,7 +109,7 @@ class DashboardImportDiscovery: self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str ) -> None: """Process a service info.""" - if await info.async_request(zeroconf, timeout=3000): + if await info.async_request(zeroconf, timeout=DEFAULT_TIMEOUT_MS): self._process_service_info(name, info) def _process_service_info(self, name: str, info: ServiceInfo) -> None: @@ -164,7 +166,9 @@ class DashboardImportDiscovery: class EsphomeZeroconf(Zeroconf): - def resolve_host(self, host: str, timeout: float = 3.0) -> list[str] | None: + def resolve_host( + self, host: str, timeout: float = DEFAULT_TIMEOUT + ) -> list[str] | None: """Resolve a host name to an IP address.""" info = AddressResolver(f"{host.partition('.')[0]}.local.") if ( @@ -177,7 +181,7 @@ class EsphomeZeroconf(Zeroconf): class AsyncEsphomeZeroconf(AsyncZeroconf): async def async_resolve_host( - self, host: str, timeout: float = 3.0 + self, host: str, timeout: float = DEFAULT_TIMEOUT ) -> list[str] | None: """Resolve a host name to an IP address.""" info = AddressResolver(f"{host.partition('.')[0]}.local.") diff --git a/platformio.ini b/platformio.ini index 88e7c3b331..ccfd52c3ca 100644 --- a/platformio.ini +++ b/platformio.ini @@ -63,7 +63,7 @@ lib_deps = Wire ; i2c (Arduino built-int) heman/AsyncMqttClient-esphome@1.0.0 ; mqtt esphome/ESPAsyncWebServer-esphome@3.3.0 ; web_server_base - fastled/FastLED@3.3.2 ; fastled_base + fastled/FastLED@3.9.16 ; fastled_base mikalhart/TinyGPSPlus@1.0.2 ; gps freekode/TM1651@1.0.1 ; tm1651 glmnet/Dsmr@0.7 ; dsmr @@ -128,7 +128,7 @@ lib_deps = DNSServer ; captive_portal (Arduino built-in) esphome/ESP32-audioI2S@2.0.7 ; i2s_audio droscy/esp_wireguard@0.4.2 ; wireguard - esphome/esp-audio-libs@1.1.3 ; audio + esphome/esp-audio-libs@1.1.4 ; audio build_flags = ${common:arduino.build_flags} @@ -142,14 +142,14 @@ extra_scripts = post:esphome/components/esp32/post_build.py.script extends = common:idf platform = https://github.com/pioarduino/platform-espressif32/releases/download/51.03.06/platform-espressif32.zip platform_packages = - pioarduino/framework-espidf@https://github.com/pioarduino/esp-idf/releases/download/v5.1.5/esp-idf-v5.1.5.zip + pioarduino/framework-espidf@https://github.com/pioarduino/esp-idf/releases/download/v5.1.6/esp-idf-v5.1.6.zip framework = espidf lib_deps = ${common:idf.lib_deps} droscy/esp_wireguard@0.4.2 ; wireguard kahrendt/ESPMicroSpeechFeatures@1.1.0 ; micro_wake_word - esphome/esp-audio-libs@1.1.3 ; audio + esphome/esp-audio-libs@1.1.4 ; audio build_flags = ${common:idf.build_flags} -Wno-nonnull-compare @@ -194,6 +194,26 @@ build_flags = -DUSE_LIBRETINY build_src_flags = -include Arduino.h +; This is the common settings for the nRF52 using Zephyr. +[common:nrf52-zephyr] +extends = common +platform = https://github.com/tomaszduda23/platform-nordicnrf52/archive/refs/tags/v10.3.0-1.zip +framework = zephyr +platform_packages = + platformio/framework-zephyr @ https://github.com/tomaszduda23/framework-sdk-nrf/archive/refs/tags/v2.6.1-4.zip + platformio/toolchain-gccarmnoneeabi@https://github.com/tomaszduda23/toolchain-sdk-ng/archive/refs/tags/v0.16.1-1.zip +build_flags = + ${common.build_flags} + -DUSE_ZEPHYR + -DUSE_NRF52 +lib_deps = + bblanchon/ArduinoJson@7.0.0 ; json + wjtje/qr-code-generator-library@1.7.0 ; qr_code + pavlodn/HaierProtocol@0.9.31 ; haier + functionpointer/arduino-MLX90393@1.0.2 ; mlx90393 + https://github.com/Sensirion/arduino-gas-index-algorithm.git#3.2.1 ; Sensirion Gas Index Algorithm Arduino Library + lvgl/lvgl@8.4.0 ; lvgl + ; All the actual environments are defined below. ;;;;;;;; ESP8266 ;;;;;;;; @@ -440,3 +460,19 @@ build_flags = ${common.build_flags} -DUSE_HOST -std=c++17 + +;;;;;;;; nRF52 ;;;;;;;; + +[env:nrf52] +extends = common:nrf52-zephyr +board = adafruit_feather_nrf52840 +build_flags = + ${common:nrf52-zephyr.build_flags} + ${flags:runtime.build_flags} + +[env:nrf52-tidy] +extends = common:nrf52-zephyr +board = adafruit_feather_nrf52840 +build_flags = + ${common:nrf52-zephyr.build_flags} + ${flags:clangtidy.build_flags} diff --git a/pyproject.toml b/pyproject.toml index 77dcaf1fab..1971f033c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools==78.1.0", "wheel>=0.43,<0.46"] +requires = ["setuptools==80.4.0", "wheel>=0.43,<0.46"] build-backend = "setuptools.build_meta" [project] @@ -48,7 +48,6 @@ version = {attr = "esphome.const.__version__"} [tool.setuptools.dynamic.optional-dependencies] dev = { file = ["requirements_dev.txt"] } test = { file = ["requirements_test.txt"] } -displays = { file = ["requirements_optional.txt"] } [tool.setuptools.packages.find] include = ["esphome*"] diff --git a/requirements.txt b/requirements.txt index d40ce6c145..9547cd0ef0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,12 +13,13 @@ platformio==6.1.18 # When updating platformio, also update /docker/Dockerfile esptool==4.8.1 click==8.1.7 esphome-dashboard==20250415.0 -aioesphomeapi==29.10.0 -zeroconf==0.146.5 -puremagic==1.28 +aioesphomeapi==30.2.0 +zeroconf==0.147.0 +puremagic==1.29 ruamel.yaml==0.18.10 # dashboard_import esphome-glyphsets==0.2.0 pillow==10.4.0 +cairosvg==2.7.1 freetype-py==2.5.1 # esp-idf requires this, but doesn't bundle it by default diff --git a/requirements_dev.txt b/requirements_dev.txt index d77ccaff69..16e051fcd7 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,4 +1,4 @@ # Useful stuff when working in a development environment clang-format==13.0.1 # also change in .pre-commit-config.yaml and Dockerfile when updating clang-tidy==18.1.8 # When updating clang-tidy, also update Dockerfile -yamllint==1.37.0 # also change in .pre-commit-config.yaml when updating +yamllint==1.37.1 # also change in .pre-commit-config.yaml when updating diff --git a/requirements_optional.txt b/requirements_optional.txt deleted file mode 100644 index 7416753d55..0000000000 --- a/requirements_optional.txt +++ /dev/null @@ -1 +0,0 @@ -cairosvg==2.7.1 diff --git a/requirements_test.txt b/requirements_test.txt index e43df6703f..6dd8d883ba 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,12 +1,12 @@ -pylint==3.3.6 +pylint==3.3.7 flake8==7.2.0 # also change in .pre-commit-config.yaml when updating -ruff==0.11.2 # also change in .pre-commit-config.yaml when updating +ruff==0.11.9 # also change in .pre-commit-config.yaml when updating pyupgrade==3.19.1 # also change in .pre-commit-config.yaml when updating pre-commit # Unit tests pytest==8.3.5 -pytest-cov==6.0.0 +pytest-cov==6.1.1 pytest-mock==3.14.0 pytest-asyncio==0.26.0 asyncmock==0.4.2 diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index 7771922697..63c1efa1ee 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -1,4 +1,39 @@ #!/usr/bin/env python3 +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import IntEnum +import os +from pathlib import Path +import re +from subprocess import call +import sys +from textwrap import dedent +from typing import Any + +import aioesphomeapi.api_options_pb2 as pb +import google.protobuf.descriptor_pb2 as descriptor + + +class WireType(IntEnum): + """Protocol Buffer wire types as defined in the protobuf spec. + + As specified in the Protocol Buffers encoding guide: + https://protobuf.dev/programming-guides/encoding/#structure + """ + + VARINT = 0 # int32, int64, uint32, uint64, sint32, sint64, bool, enum + FIXED64 = 1 # fixed64, sfixed64, double + LENGTH_DELIMITED = 2 # string, bytes, embedded messages, packed repeated fields + START_GROUP = 3 # groups (deprecated) + END_GROUP = 4 # groups (deprecated) + FIXED32 = 5 # fixed32, sfixed32, float + + +# Generate with +# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto + + """Python 3 script to automatically generate C++ classes for ESPHome's native API. It's pretty crappy spaghetti code, but it works. @@ -17,25 +52,14 @@ then run this script with python3 and the files will be generated, they still need to be formatted """ -from abc import ABC, abstractmethod -import os -from pathlib import Path -import re -from subprocess import call -import sys -from textwrap import dedent - -# Generate with -# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto -import aioesphomeapi.api_options_pb2 as pb -import google.protobuf.descriptor_pb2 as descriptor FILE_HEADER = """// This file was automatically generated with a tool. -// See scripts/api_protobuf/api_protobuf.py +// See script/api_protobuf/api_protobuf.py """ -def indent_list(text, padding=" "): +def indent_list(text: str, padding: str = " ") -> list[str]: + """Indent each line of the given text with the specified padding.""" lines = [] for line in text.splitlines(): if line == "": @@ -48,54 +72,72 @@ def indent_list(text, padding=" "): return lines -def indent(text, padding=" "): +def indent(text: str, padding: str = " ") -> str: return "\n".join(indent_list(text, padding)) -def camel_to_snake(name): +def camel_to_snake(name: str) -> str: # https://stackoverflow.com/a/1176023 s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() +def force_str(force: bool) -> str: + """Convert a boolean force value to string format for C++ code.""" + return str(force).lower() + + class TypeInfo(ABC): - def __init__(self, field): + """Base class for all type information.""" + + def __init__(self, field: descriptor.FieldDescriptorProto) -> None: self._field = field @property - def default_value(self): + def default_value(self) -> str: + """Get the default value.""" return "" @property - def name(self): + def name(self) -> str: + """Get the name of the field.""" return self._field.name @property - def arg_name(self): + def arg_name(self) -> str: + """Get the argument name.""" return self.name @property - def field_name(self): + def field_name(self) -> str: + """Get the field name.""" return self.name @property - def number(self): + def number(self) -> int: + """Get the field number.""" return self._field.number @property - def repeated(self): + def repeated(self) -> bool: + """Check if the field is repeated.""" return self._field.label == 3 @property - def cpp_type(self): + def wire_type(self) -> WireType: + """Get the wire type for the field.""" raise NotImplementedError @property - def reference_type(self): + def cpp_type(self) -> str: + raise NotImplementedError + + @property + def reference_type(self) -> str: return f"{self.cpp_type} " @property - def const_reference_type(self): + def const_reference_type(self) -> str: return f"{self.cpp_type} " @property @@ -171,28 +213,60 @@ class TypeInfo(ABC): decode_64bit = None @property - def encode_content(self): + def encode_content(self) -> str: return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});" encode_func = None @property - def dump_content(self): + def dump_content(self) -> str: o = f'out.append(" {self.name}: ");\n' o += self.dump(f"this->{self.field_name}") + "\n" o += 'out.append("\\n");\n' return o @abstractmethod - def dump(self, name: str): - pass + def dump(self, name: str) -> str: + """Dump the value to the output.""" + + def calculate_field_id_size(self) -> int: + """Calculates the size of a field ID in bytes. + + Returns: + The number of bytes needed to encode the field ID + """ + # Calculate the tag by combining field_id and wire_type + tag = (self.number << 3) | (self.wire_type & 0b111) + + # Calculate the varint size + if tag < 128: + return 1 # 7 bits + if tag < 16384: + return 2 # 14 bits + if tag < 2097152: + return 3 # 21 bits + if tag < 268435456: + return 4 # 28 bits + return 5 # 32 bits (maximum for uint32_t) + + @abstractmethod + def get_size_calculation(self, name: str, force: bool = False) -> str: + """Calculate the size needed for encoding this field. + + Args: + name: The name of the field + force: Whether to force encoding the field even if it has a default value + """ -TYPE_INFO = {} +TYPE_INFO: dict[int, TypeInfo] = {} -def register_type(name): - def func(value): +def register_type(name: int): + """Decorator to register a type with a name and number.""" + + def func(value: TypeInfo) -> TypeInfo: + """Register the type with the given name and number.""" TYPE_INFO[name] = value return value @@ -205,12 +279,18 @@ class DoubleType(TypeInfo): default_value = "0.0" decode_64bit = "value.as_double()" encode_func = "encode_double" + wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%g", {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0.0, {force_str(force)});" + return o + @register_type(2) class FloatType(TypeInfo): @@ -218,12 +298,18 @@ class FloatType(TypeInfo): default_value = "0.0f" decode_32bit = "value.as_float()" encode_func = "encode_float" + wire_type = WireType.FIXED32 # Uses wire type 5 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%g", {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0.0f, {force_str(force)});" + return o + @register_type(3) class Int64Type(TypeInfo): @@ -231,12 +317,18 @@ class Int64Type(TypeInfo): default_value = "0" decode_varint = "value.as_int64()" encode_func = "encode_int64" + wire_type = WireType.VARINT # Uses wire type 0 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_int64_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(4) class UInt64Type(TypeInfo): @@ -244,12 +336,18 @@ class UInt64Type(TypeInfo): default_value = "0" decode_varint = "value.as_uint64()" encode_func = "encode_uint64" + wire_type = WireType.VARINT # Uses wire type 0 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%llu", {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_uint64_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(5) class Int32Type(TypeInfo): @@ -257,12 +355,18 @@ class Int32Type(TypeInfo): default_value = "0" decode_varint = "value.as_int32()" encode_func = "encode_int32" + wire_type = WireType.VARINT # Uses wire type 0 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_int32_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(6) class Fixed64Type(TypeInfo): @@ -270,12 +374,18 @@ class Fixed64Type(TypeInfo): default_value = "0" decode_64bit = "value.as_fixed64()" encode_func = "encode_fixed64" + wire_type = WireType.FIXED64 # Uses wire type 1 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%llu", {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0, {force_str(force)});" + return o + @register_type(7) class Fixed32Type(TypeInfo): @@ -283,12 +393,18 @@ class Fixed32Type(TypeInfo): default_value = "0" decode_32bit = "value.as_fixed32()" encode_func = "encode_fixed32" + wire_type = WireType.FIXED32 # Uses wire type 5 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRIu32, {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0, {force_str(force)});" + return o + @register_type(8) class BoolType(TypeInfo): @@ -296,11 +412,17 @@ class BoolType(TypeInfo): default_value = "false" decode_varint = "value.as_bool()" encode_func = "encode_bool" + wire_type = WireType.VARINT # Uses wire type 0 - def dump(self, name): + def dump(self, name: str) -> str: o = f"out.append(YESNO({name}));" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_bool_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(9) class StringType(TypeInfo): @@ -310,40 +432,52 @@ class StringType(TypeInfo): const_reference_type = "const std::string &" decode_length = "value.as_string()" encode_func = "encode_string" + wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2 def dump(self, name): o = f'out.append("\'").append({name}).append("\'");' return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_string_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(11) class MessageType(TypeInfo): @property - def cpp_type(self): + def cpp_type(self) -> str: return self._field.type_name[1:] default_value = "" + wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2 @property - def reference_type(self): + def reference_type(self) -> str: return f"{self.cpp_type} &" @property - def const_reference_type(self): + def const_reference_type(self) -> str: return f"const {self.cpp_type} &" @property - def encode_func(self): + def encode_func(self) -> str: return f"encode_message<{self.cpp_type}>" @property - def decode_length(self): + def decode_length(self) -> str: return f"value.as_message<{self.cpp_type}>()" - def dump(self, name): + def dump(self, name: str) -> str: o = f"{name}.dump_to(out);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_message_object(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(12) class BytesType(TypeInfo): @@ -353,11 +487,17 @@ class BytesType(TypeInfo): const_reference_type = "const std::string &" decode_length = "value.as_string()" encode_func = "encode_string" + wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2 - def dump(self, name): + def dump(self, name: str) -> str: o = f'out.append("\'").append({name}).append("\'");' return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_string_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(13) class UInt32Type(TypeInfo): @@ -365,33 +505,45 @@ class UInt32Type(TypeInfo): default_value = "0" decode_varint = "value.as_uint32()" encode_func = "encode_uint32" + wire_type = WireType.VARINT # Uses wire type 0 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRIu32, {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_uint32_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(14) class EnumType(TypeInfo): @property - def cpp_type(self): + def cpp_type(self) -> str: return f"enums::{self._field.type_name[1:]}" @property - def decode_varint(self): + def decode_varint(self) -> str: return f"value.as_enum<{self.cpp_type}>()" default_value = "" + wire_type = WireType.VARINT # Uses wire type 0 @property - def encode_func(self): + def encode_func(self) -> str: return f"encode_enum<{self.cpp_type}>" - def dump(self, name): + def dump(self, name: str) -> str: o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_enum_field(total_size, {field_id_size}, static_cast({name}), {force_str(force)});" + return o + @register_type(15) class SFixed32Type(TypeInfo): @@ -399,12 +551,18 @@ class SFixed32Type(TypeInfo): default_value = "0" decode_32bit = "value.as_sfixed32()" encode_func = "encode_sfixed32" + wire_type = WireType.FIXED32 # Uses wire type 5 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0, {force_str(force)});" + return o + @register_type(16) class SFixed64Type(TypeInfo): @@ -412,12 +570,18 @@ class SFixed64Type(TypeInfo): default_value = "0" decode_64bit = "value.as_sfixed64()" encode_func = "encode_sfixed64" + wire_type = WireType.FIXED64 # Uses wire type 1 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0, {force_str(force)});" + return o + @register_type(17) class SInt32Type(TypeInfo): @@ -425,12 +589,18 @@ class SInt32Type(TypeInfo): default_value = "0" decode_varint = "value.as_sint32()" encode_func = "encode_sint32" + wire_type = WireType.VARINT # Uses wire type 0 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%" PRId32, {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_sint32_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + @register_type(18) class SInt64Type(TypeInfo): @@ -438,30 +608,44 @@ class SInt64Type(TypeInfo): default_value = "0" decode_varint = "value.as_sint64()" encode_func = "encode_sint64" + wire_type = WireType.VARINT # Uses wire type 0 - def dump(self, name): + def dump(self, name: str) -> str: o = f'sprintf(buffer, "%lld", {name});\n' o += "out.append(buffer);" return o + def get_size_calculation(self, name: str, force: bool = False) -> str: + field_id_size = self.calculate_field_id_size() + o = f"ProtoSize::add_sint64_field(total_size, {field_id_size}, {name}, {force_str(force)});" + return o + class RepeatedTypeInfo(TypeInfo): - def __init__(self, field): + def __init__(self, field: descriptor.FieldDescriptorProto) -> None: super().__init__(field) - self._ti = TYPE_INFO[field.type](field) + self._ti: TypeInfo = TYPE_INFO[field.type](field) @property - def cpp_type(self): + def cpp_type(self) -> str: return f"std::vector<{self._ti.cpp_type}>" @property - def reference_type(self): + def reference_type(self) -> str: return f"{self.cpp_type} &" @property - def const_reference_type(self): + def const_reference_type(self) -> str: return f"const {self.cpp_type} &" + @property + def wire_type(self) -> WireType: + """Get the wire type for this repeated field. + + For repeated fields, we use the same wire type as the underlying field. + """ + return self._ti.wire_type + @property def decode_varint_content(self) -> str: content = self._ti.decode_varint @@ -515,19 +699,19 @@ class RepeatedTypeInfo(TypeInfo): ) @property - def _ti_is_bool(self): + def _ti_is_bool(self) -> bool: # std::vector is specialized for bool, reference does not work return isinstance(self._ti, BoolType) @property - def encode_content(self): + def encode_content(self) -> str: o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n" o += "}" return o @property - def dump_content(self): + def dump_content(self) -> str: o = f"for (const auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o += f' out.append(" {self.name}: ");\n' o += indent(self._ti.dump("it")) + "\n" @@ -538,8 +722,25 @@ class RepeatedTypeInfo(TypeInfo): def dump(self, _: str): pass + def get_size_calculation(self, name: str, force: bool = False) -> str: + # For repeated fields, we always need to pass force=True to the underlying type's calculation + # This is because the encode method always sets force=true for repeated fields + if isinstance(self._ti, MessageType): + # For repeated messages, use the dedicated helper that handles iteration internally + field_id_size = self._ti.calculate_field_id_size() + o = f"ProtoSize::add_repeated_message(total_size, {field_id_size}, {name});" + return o + # For other repeated types, use the underlying type's size calculation with force=True + o = f"if (!{name}.empty()) {{\n" + o += f" for (const auto {'' if self._ti_is_bool else '&'}it : {name}) {{\n" + o += f" {self._ti.get_size_calculation('it', True)}\n" + o += " }\n" + o += "}" + return o -def build_enum_type(desc): + +def build_enum_type(desc) -> tuple[str, str]: + """Builds the enum type.""" name = desc.name out = f"enum {name} : uint32_t {{\n" for v in desc.value: @@ -561,15 +762,16 @@ def build_enum_type(desc): return out, cpp -def build_message_type(desc): - public_content = [] - protected_content = [] - decode_varint = [] - decode_length = [] - decode_32bit = [] - decode_64bit = [] - encode = [] - dump = [] +def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]: + public_content: list[str] = [] + protected_content: list[str] = [] + decode_varint: list[str] = [] + decode_length: list[str] = [] + decode_32bit: list[str] = [] + decode_64bit: list[str] = [] + encode: list[str] = [] + dump: list[str] = [] + size_calc: list[str] = [] for field in desc.field: if field.label == 3: @@ -579,6 +781,7 @@ def build_message_type(desc): protected_content.extend(ti.protected_content) public_content.extend(ti.public_content) encode.append(ti.encode_content) + size_calc.append(ti.get_size_calculation(f"this->{ti.field_name}")) if ti.decode_varint_content: decode_varint.append(ti.decode_varint_content) @@ -645,6 +848,25 @@ def build_message_type(desc): prot = "void encode(ProtoWriteBuffer buffer) const override;" public_content.append(prot) + # Add calculate_size method + o = f"void {desc.name}::calculate_size(uint32_t &total_size) const {{" + + # Add a check for empty/default objects to short-circuit the calculation + # Only add this optimization if we have fields to check + if size_calc: + # For a single field, just inline it for simplicity + if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120: + o += f" {size_calc[0]} " + else: + # For multiple fields, add a short-circuit check + o += "\n" + # Performance optimization: add all the size calculations + o += indent("\n".join(size_calc)) + "\n" + o += "}\n" + cpp += o + prot = "void calculate_size(uint32_t &total_size) const override;" + public_content.append(prot) + o = f"void {desc.name}::dump_to(std::string &out) const {{" if dump: if len(dump) == 1 and len(dump[0]) + len(o) + 3 < 120: @@ -687,27 +909,35 @@ SOURCE_BOTH = 0 SOURCE_SERVER = 1 SOURCE_CLIENT = 2 -RECEIVE_CASES = {} +RECEIVE_CASES: dict[int, str] = {} -ifdefs = {} +ifdefs: dict[str, str] = {} -def get_opt(desc, opt, default=None): +def get_opt( + desc: descriptor.DescriptorProto, + opt: descriptor.MessageOptions, + default: Any = None, +) -> Any: + """Get the option from the descriptor.""" if not desc.options.HasExtension(opt): return default return desc.options.Extensions[opt] -def build_service_message_type(mt): +def build_service_message_type( + mt: descriptor.DescriptorProto, +) -> tuple[str, str] | None: + """Builds the service message type.""" snake = camel_to_snake(mt.name) - id_ = get_opt(mt, pb.id) + id_: int | None = get_opt(mt, pb.id) if id_ is None: return None - source = get_opt(mt, pb.source, 0) + source: int = get_opt(mt, pb.source, 0) - ifdef = get_opt(mt, pb.ifdef) - log = get_opt(mt, pb.log, True) + ifdef: str | None = get_opt(mt, pb.ifdef) + log: bool = get_opt(mt, pb.log, True) hout = "" cout = "" @@ -754,7 +984,8 @@ def build_service_message_type(mt): return hout, cout -def main(): +def main() -> None: + """Main function to generate the C++ classes.""" cwd = Path(__file__).resolve().parent root = cwd.parent.parent / "esphome" / "components" / "api" prot_file = root / "api.protoc" @@ -770,6 +1001,7 @@ def main(): #pragma once #include "proto.h" + #include "api_pb2_size.h" namespace esphome { namespace api { @@ -779,6 +1011,7 @@ def main(): cpp = FILE_HEADER cpp += """\ #include "api_pb2.h" + #include "api_pb2_size.h" #include "esphome/core/log.h" #include @@ -959,7 +1192,7 @@ def main(): try: import clang_format - def exec_clang_format(path): + def exec_clang_format(path: Path) -> None: clang_format_path = os.path.join( os.path.dirname(clang_format.__file__), "data", "bin", "clang-format" ) diff --git a/script/build_language_schema.py b/script/build_language_schema.py old mode 100644 new mode 100755 index 7152e23e8f..4473ec1b5a --- a/script/build_language_schema.py +++ b/script/build_language_schema.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 import argparse import glob import inspect @@ -36,6 +37,7 @@ parser = argparse.ArgumentParser() parser.add_argument( "--output-path", default=".", help="Output path", type=os.path.abspath ) +parser.add_argument("--check", action="store_true", help="Check only for CI") args = parser.parse_args() @@ -66,31 +68,42 @@ def get_component_names(): from esphome.loader import CORE_COMPONENTS_PATH component_names = ["esphome", "sensor", "esp32", "esp8266"] + skip_components = [] for d in os.listdir(CORE_COMPONENTS_PATH): if not d.startswith("__") and os.path.isdir( os.path.join(CORE_COMPONENTS_PATH, d) ): - if d not in component_names: + if d not in component_names and d not in skip_components: component_names.append(d) - return component_names + return sorted(component_names) def load_components(): from esphome.config import get_component for domain in get_component_names(): - components[domain] = get_component(domain) + components[domain] = get_component(domain, exception=True) + assert components[domain] is not None -# pylint: disable=wrong-import-position -from esphome.const import CONF_TYPE, KEY_CORE, KEY_TARGET_PLATFORM # noqa: E402 +from esphome.const import ( # noqa: E402 + CONF_TYPE, + KEY_CORE, + KEY_FRAMEWORK_VERSION, + KEY_TARGET_FRAMEWORK, + KEY_TARGET_PLATFORM, +) from esphome.core import CORE # noqa: E402 -# pylint: enable=wrong-import-position +CORE.data[KEY_CORE] = { + KEY_TARGET_PLATFORM: "esp8266", + KEY_TARGET_FRAMEWORK: "arduino", + KEY_FRAMEWORK_VERSION: "0", +} + -CORE.data[KEY_CORE] = {KEY_TARGET_PLATFORM: None} load_components() # Import esphome after loading components (so schema is tracked) @@ -98,7 +111,6 @@ load_components() from esphome import automation, pins # noqa: E402 from esphome.components import remote_base # noqa: E402 import esphome.config_validation as cv # noqa: E402 -import esphome.core as esphome_core # noqa: E402 from esphome.helpers import write_file_if_changed # noqa: E402 from esphome.loader import CORE_COMPONENTS_PATH, get_platform # noqa: E402 from esphome.util import Registry # noqa: E402 @@ -523,11 +535,14 @@ def shrink(): # then are all simple types, integer and strings for x, paths in referenced_schemas.items(): key_s = get_str_path_schema(x) - if key_s and key_s[S_TYPE] in ["enum", "registry", "integer", "string"]: + if key_s and key_s.get(S_TYPE) in ["enum", "registry", "integer", "string"]: if key_s[S_TYPE] == "registry": print("Spreading registry: " + x) for target in paths: target_s = get_arr_path_schema(target) + if S_SCHEMA not in target_s: + print("skipping simple spread for " + ".".join(target)) + continue assert target_s[S_SCHEMA][S_EXTENDS] == [x] target_s.pop(S_SCHEMA) target_s |= key_s @@ -542,14 +557,14 @@ def shrink(): # an empty schema like speaker.SPEAKER_SCHEMA target_s[S_EXTENDS].remove(x) continue - assert target_s[S_SCHEMA][S_EXTENDS] == [x] + assert x in target_s[S_SCHEMA][S_EXTENDS] target_s.pop(S_SCHEMA) target_s.pop(S_TYPE) # undefined target_s["data_type"] = x.split(".")[1] # remove this dangling again pop_str_path_schema(x) - # remove dangling items (unreachable schemas) + # remove unreachable schemas for domain, domain_schemas in output.items(): for schema_name in list(domain_schemas.get(S_SCHEMAS, {}).keys()): s = f"{domain}.{schema_name}" @@ -558,7 +573,6 @@ def shrink(): and s not in referenced_schemas and not is_platform_schema(s) ): - print(f"Removing {s}") domain_schemas[S_SCHEMAS].pop(schema_name) @@ -659,6 +673,9 @@ def build_schema(): # bundle core inside esphome data["esphome"]["core"] = data.pop("core")["core"] + if args.check: # do not gen files + return + for c, s in data.items(): write_file(c, s) delete_extra_files(data.keys()) @@ -721,15 +738,8 @@ def convert(schema, config_var, path): # Extended schemas are tracked when the .extend() is used in a schema if repr_schema in ejs.extended_schemas: extended = ejs.extended_schemas.get(repr_schema) - # The midea actions are extending an empty schema (resulted in the templatize not templatizing anything) - # this causes a recursion in that this extended looks the same in extended schema as the extended[1] - if repr_schema == repr(extended[1]): - assert path.startswith("midea_ac/") - return - - assert len(extended) == 2 - convert(extended[0], config_var, path + "/extL") - convert(extended[1], config_var, path + "/extR") + for idx, ext in enumerate(extended): + convert(ext, config_var, f"{path}/ext{idx}") return if isinstance(schema, cv.All): @@ -881,15 +891,22 @@ def convert(schema, config_var, path): "class": "i2c::I2CBus", "parents": ["Component"], } - elif path == "uart/CONFIG_SCHEMA/val 1/extL/all/id": + elif path == "uart/CONFIG_SCHEMA/val 1/ext0/all/id": config_var["id_type"] = { "class": "uart::UARTComponent", "parents": ["Component"], } + elif path == "http_request/CONFIG_SCHEMA/val 1/ext0/all/id": + config_var["id_type"] = { + "class": "http_request::HttpRequestComponent", + "parents": ["Component"], + } elif path == "pins/esp32/val 1/id": config_var["id_type"] = "pin" else: - raise TypeError("Cannot determine id_type for " + path) + print("Cannot determine id_type for " + path) + + # raise TypeError("Cannot determine id_type for " + path) elif repr_schema in ejs.registry_schemas: solve_registry.append((ejs.registry_schemas[repr_schema], config_var)) @@ -965,9 +982,6 @@ def convert_keys(converted, schema, path): else: converted["key_type"] = str(k) - esphome_core.CORE.data = { - esphome_core.KEY_CORE: {esphome_core.KEY_TARGET_PLATFORM: "esp8266"} - } if hasattr(k, "default") and str(k.default) != "...": default_value = k.default() if default_value is not None: diff --git a/script/ci-custom.py b/script/ci-custom.py index d5d3ab88c8..a3a31b2259 100755 --- a/script/ci-custom.py +++ b/script/ci-custom.py @@ -292,6 +292,7 @@ def highlight(s): "esphome/core/log.h", "esphome/components/socket/headers.h", "esphome/core/defines.h", + "esphome/components/http_request/httplib.h", ], ) def lint_no_defines(fname, match): @@ -317,7 +318,12 @@ def lint_no_long_delays(fname, match): ) -@lint_content_check(include=["esphome/const.py"]) +@lint_content_check( + include=[ + "esphome/const.py", + "esphome/components/const/__init__.py", + ] +) def lint_const_ordered(fname, content): """Lint that value in const.py are ordered. @@ -552,6 +558,8 @@ def lint_relative_py_import(fname): "esphome/components/rp2040/core.cpp", "esphome/components/libretiny/core.cpp", "esphome/components/host/core.cpp", + "esphome/components/zephyr/core.cpp", + "esphome/components/http_request/httplib.h", ], ) def lint_namespace(fname, content): diff --git a/script/clang-tidy b/script/clang-tidy index a857274b01..5baaaf6b3a 100755 --- a/script/clang-tidy +++ b/script/clang-tidy @@ -40,12 +40,37 @@ def clang_options(idedata): else: cmd.append(f"--target={triplet}") + omit_flags = ( + "-free", + "-fipa-pta", + "-fstrict-volatile-bitfields", + "-mlongcalls", + "-mtext-section-literals", + "-mdisable-hardware-atomics", + "-mfix-esp32-psram-cache-issue", + "-mfix-esp32-psram-cache-strategy=memw", + "-fno-tree-switch-conversion", + ) + + if "zephyr" in triplet: + omit_flags += ( + "-fno-reorder-functions", + "-mfp16-format=ieee", + "--param=min-pagesize=0", + ) + else: + cmd.extend( + [ + # disable built-in include directories from the host + "-nostdinc++", + ] + ) + # set flags cmd.extend( [ # disable built-in include directories from the host "-nostdinc", - "-nostdinc++", # replace pgmspace.h, as it uses GNU extensions clang doesn't support # https://github.com/earlephilhower/newlib-xtensa/pull/18 "-D_PGMSPACE_H_", @@ -70,22 +95,7 @@ def clang_options(idedata): ) # copy compiler flags, except those clang doesn't understand. - cmd.extend( - flag - for flag in idedata["cxx_flags"] - if flag - not in ( - "-free", - "-fipa-pta", - "-fstrict-volatile-bitfields", - "-mlongcalls", - "-mtext-section-literals", - "-mdisable-hardware-atomics", - "-mfix-esp32-psram-cache-issue", - "-mfix-esp32-psram-cache-strategy=memw", - "-fno-tree-switch-conversion", - ) - ) + cmd.extend(flag for flag in idedata["cxx_flags"] if flag not in omit_flags) # defines cmd.extend(f"-D{define}" for define in idedata["defines"]) @@ -100,13 +110,16 @@ def clang_options(idedata): # add library include directories using -isystem to suppress their errors for directory in list(idedata["includes"]["build"]): # skip our own directories, we add those later - if not directory.startswith(f"{root_path}") or directory.startswith( - ( - f"{root_path}/.pio", - f"{root_path}/.platformio", - f"{root_path}/.temp", - f"{root_path}/managed_components", + if ( + not directory.startswith(f"{root_path}") + or directory.startswith( + ( + f"{root_path}/.platformio", + f"{root_path}/.temp", + f"{root_path}/managed_components", + ) ) + or (directory.startswith(f"{root_path}") and "/.pio/" in directory) ): cmd.extend(["-isystem", directory]) diff --git a/script/helpers.py b/script/helpers.py index 6148371e32..3c1b0c0ddd 100644 --- a/script/helpers.py +++ b/script/helpers.py @@ -5,6 +5,7 @@ import re import subprocess import colorama +import helpers_zephyr root_path = os.path.abspath(os.path.normpath(os.path.join(__file__, "..", ".."))) basepath = os.path.join(root_path, "esphome") @@ -147,10 +148,14 @@ def load_idedata(environment): # ensure temp directory exists before running pio, as it writes sdkconfig to it Path(temp_folder).mkdir(exist_ok=True) - stdout = subprocess.check_output(["pio", "run", "-t", "idedata", "-e", environment]) - match = re.search(r'{\s*".*}', stdout.decode("utf-8")) - data = json.loads(match.group()) - + if "nrf" in environment: + data = helpers_zephyr.load_idedata(environment, temp_folder, platformio_ini) + else: + stdout = subprocess.check_output( + ["pio", "run", "-t", "idedata", "-e", environment] + ) + match = re.search(r'{\s*".*}', stdout.decode("utf-8")) + data = json.loads(match.group()) temp_idedata.write_text(json.dumps(data, indent=2) + "\n") return data diff --git a/script/helpers_zephyr.py b/script/helpers_zephyr.py new file mode 100644 index 0000000000..c3ba149005 --- /dev/null +++ b/script/helpers_zephyr.py @@ -0,0 +1,124 @@ +import json +from pathlib import Path +import re +import subprocess + + +def load_idedata(environment, temp_folder, platformio_ini): + build_environment = environment.replace("-tidy", "") + build_dir = Path(temp_folder) / f"build-{build_environment}" + Path(build_dir).mkdir(exist_ok=True) + Path(build_dir / "platformio.ini").write_text( + Path(platformio_ini).read_text(encoding="utf-8"), encoding="utf-8" + ) + esphome_dir = Path(build_dir / "esphome") + esphome_dir.mkdir(exist_ok=True) + Path(esphome_dir / "main.cpp").write_text( + """ +#include +int main() { return 0;} +""", + encoding="utf-8", + ) + zephyr_dir = Path(build_dir / "zephyr") + zephyr_dir.mkdir(exist_ok=True) + Path(zephyr_dir / "prj.conf").write_text( + """ +CONFIG_NEWLIB_LIBC=y +""", + encoding="utf-8", + ) + subprocess.run(["pio", "run", "-e", build_environment, "-d", build_dir], check=True) + + def extract_include_paths(command): + include_paths = [] + include_pattern = re.compile(r'("-I\s*[^"]+)|(-isystem\s*[^\s]+)|(-I\s*[^\s]+)') + for match in include_pattern.findall(command): + split_strings = re.split( + r"\s*-\s*(?:I|isystem)", list(filter(lambda x: x, match))[0] + ) + include_paths.append(split_strings[1]) + return include_paths + + def extract_defines(command): + defines = [] + define_pattern = re.compile(r"-D\s*([^\s]+)") + for match in define_pattern.findall(command): + if match not in ("_ASMLANGUAGE"): + defines.append(match) + return defines + + def find_cxx_path(commands): + for entry in commands: + command = entry["command"] + cxx_path = command.split()[0] + if not cxx_path.endswith("++"): + continue + return cxx_path + + def get_builtin_include_paths(compiler): + result = subprocess.run( + [compiler, "-E", "-x", "c++", "-", "-v"], + input="", + text=True, + stderr=subprocess.PIPE, + stdout=subprocess.DEVNULL, + check=True, + ) + include_paths = [] + start_collecting = False + for line in result.stderr.splitlines(): + if start_collecting: + if line.startswith(" "): + include_paths.append(line.strip()) + else: + break + if "#include <...> search starts here:" in line: + start_collecting = True + return include_paths + + def extract_cxx_flags(command): + flags = [] + # Extracts CXXFLAGS from the command string, excluding includes and defines. + flag_pattern = re.compile( + r"(-O[0-3s]|-g|-std=[^\s]+|-Wall|-Wextra|-Werror|--[^\s]+|-f[^\s]+|-m[^\s]+|-imacros\s*[^\s]+)" + ) + for match in flag_pattern.findall(command): + flags.append(match.replace("-imacros ", "-imacros")) + return flags + + def transform_to_idedata_format(compile_commands): + cxx_path = find_cxx_path(compile_commands) + idedata = { + "includes": { + "toolchain": get_builtin_include_paths(cxx_path), + "build": set(), + }, + "defines": set(), + "cxx_path": cxx_path, + "cxx_flags": set(), + } + + for entry in compile_commands: + command = entry["command"] + exec = command.split()[0] + if exec != cxx_path: + continue + + idedata["includes"]["build"].update(extract_include_paths(command)) + idedata["defines"].update(extract_defines(command)) + idedata["cxx_flags"].update(extract_cxx_flags(command)) + + # Convert sets to lists for JSON serialization + idedata["includes"]["build"] = list(idedata["includes"]["build"]) + idedata["defines"] = list(idedata["defines"]) + idedata["cxx_flags"] = list(idedata["cxx_flags"]) + + return idedata + + compile_commands = json.loads( + Path( + build_dir / ".pio" / "build" / build_environment / "compile_commands.json" + ).read_text(encoding="utf-8") + ) + return transform_to_idedata_format(compile_commands) diff --git a/script/setup b/script/setup index 824840c392..acc2ec58b4 100755 --- a/script/setup +++ b/script/setup @@ -4,23 +4,28 @@ set -e cd "$(dirname "$0")/.." -location="venv/bin/activate" if [ ! -n "$DEVCONTAINER" ] && [ ! -n "$VIRTUAL_ENV" ] && [ ! "$ESPHOME_NO_VENV" ]; then - python3 -m venv venv - if [ -f venv/Scripts/activate ]; then - location="venv/Scripts/activate" + if [ -x "$(command -v uv)" ]; then + uv venv venv + else + python3 -m venv venv fi - source $location + source venv/bin/activate fi -pip3 install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt -r requirements_dev.txt -pip3 install setuptools wheel -pip3 install -e ".[dev,test,displays]" --config-settings editable_mode=compat +if ! [ -x "$(command -v uv)" ]; then + python3 -m pip install uv +fi + +uv pip install setuptools wheel +uv pip install -e ".[dev,test]" --config-settings editable_mode=compat pre-commit install script/platformio_install_deps.py platformio.ini --libraries --tools --platforms +mkdir -p .temp + echo echo -echo "Virtual environment created. Run 'source $location' to use it." +echo "Virtual environment created. Run 'source venv/bin/activate' to use it." diff --git a/script/setup.bat b/script/setup.bat index 0b49768139..ea2591bb71 100644 --- a/script/setup.bat +++ b/script/setup.bat @@ -15,9 +15,9 @@ echo Installing required packages... python.exe -m pip install --upgrade pip -pip3 install -r requirements.txt -r requirements_optional.txt -r requirements_test.txt -r requirements_dev.txt +pip3 install -r requirements.txt -r requirements_test.txt -r requirements_dev.txt pip3 install setuptools wheel -pip3 install -e ".[dev,test,displays]" --config-settings editable_mode=compat +pip3 install -e ".[dev,test]" --config-settings editable_mode=compat pre-commit install diff --git a/script/test_build_components b/script/test_build_components index 62fe0f1b55..83ab947fc1 100755 --- a/script/test_build_components +++ b/script/test_build_components @@ -53,7 +53,7 @@ start_esphome() { echo "> [$target_component] [$test_name] [$target_platform_with_version]" set -x # TODO: Validate escape of Command line substitution value - python -m esphome -s component_name $target_component -s component_dir ../../components/$target_component -s test_name $test_name -s target_platform $target_platform $esphome_command $component_test_file + python3 -m esphome -s component_name $target_component -s component_dir ../../components/$target_component -s test_name $test_name -s target_platform $target_platform $esphome_command $component_test_file { set +x; } 2>/dev/null } diff --git a/tests/component_tests/packages/test_packages.py b/tests/component_tests/packages/test_packages.py index 3fbbf49afd..4712daad0d 100644 --- a/tests/component_tests/packages/test_packages.py +++ b/tests/component_tests/packages/test_packages.py @@ -76,10 +76,11 @@ def test_package_unused(basic_esphome, basic_wifi): def test_package_invalid_dict(basic_esphome, basic_wifi): """ - Ensures an error is raised if packages is not valid. + If a url: key is present, it's expected to be well-formed remote package spec. Ensure an error is raised if not. + Any other simple dict passed as a package will be merged as usual but may fail later validation. """ - config = {CONF_ESPHOME: basic_esphome, CONF_PACKAGES: basic_wifi} + config = {CONF_ESPHOME: basic_esphome, CONF_PACKAGES: basic_wifi | {CONF_URL: ""}} with pytest.raises(cv.Invalid): do_packages_pass(config) diff --git a/tests/components/analog_threshold/common.yaml b/tests/components/analog_threshold/common.yaml index b5c14dfe56..44d79756b5 100644 --- a/tests/components/analog_threshold/common.yaml +++ b/tests/components/analog_threshold/common.yaml @@ -26,3 +26,17 @@ binary_sensor: threshold: 100 filters: - invert: + - platform: analog_threshold + name: Analog Threshold 3 + sensor_id: template_sensor + threshold: !lambda return 100; + filters: + - invert: + - platform: analog_threshold + name: Analog Threshold 4 + sensor_id: template_sensor + threshold: + upper: !lambda return 110; + lower: !lambda return 90; + filters: + - invert: diff --git a/tests/components/api/test-dynamic-encryption.esp32-idf.yaml b/tests/components/api/test-dynamic-encryption.esp32-idf.yaml new file mode 100644 index 0000000000..d8f8c247f4 --- /dev/null +++ b/tests/components/api/test-dynamic-encryption.esp32-idf.yaml @@ -0,0 +1,10 @@ +packages: + common: !include common.yaml + +wifi: + ssid: MySSID + password: password1 + +api: + encryption: + key: !remove diff --git a/tests/components/atm90e32/common.yaml b/tests/components/atm90e32/common.yaml index 156d00b4e0..3eeed8395f 100644 --- a/tests/components/atm90e32/common.yaml +++ b/tests/components/atm90e32/common.yaml @@ -17,10 +17,22 @@ sensor: name: EMON Active Power CT1 reactive_power: name: EMON Reactive Power CT1 + apparent_power: + name: EMON Apparent Power CT1 + harmonic_power: + name: EMON Harmonic Power CT1 power_factor: name: EMON Power Factor CT1 + phase_angle: + name: EMON Phase Angle CT1 + peak_current: + name: EMON Peak Current CT1 gain_voltage: 7305 gain_ct: 27961 + offset_voltage: 0 + offset_current: 0 + offset_active_power: 0 + offset_reactive_power: 0 phase_b: current: name: EMON CT2 Current @@ -28,10 +40,22 @@ sensor: name: EMON Active Power CT2 reactive_power: name: EMON Reactive Power CT2 + apparent_power: + name: EMON Apparent Power CT2 + harmonic_power: + name: EMON Harmonic Power CT2 power_factor: name: EMON Power Factor CT2 + phase_angle: + name: EMON Phase Angle CT2 + peak_current: + name: EMON Peak Current CT2 gain_voltage: 7305 gain_ct: 27961 + offset_voltage: 0 + offset_current: 0 + offset_active_power: 0 + offset_reactive_power: 0 phase_c: current: name: EMON CT3 Current @@ -39,23 +63,75 @@ sensor: name: EMON Active Power CT3 reactive_power: name: EMON Reactive Power CT3 + apparent_power: + name: EMON Apparent Power CT3 + harmonic_power: + name: EMON Harmonic Power CT3 power_factor: name: EMON Power Factor CT3 + phase_angle: + name: EMON Phase Angle CT3 + peak_current: + name: EMON Peak Current CT3 gain_voltage: 7305 gain_ct: 27961 + offset_voltage: 0 + offset_current: 0 + offset_active_power: 0 + offset_reactive_power: 0 frequency: name: EMON Line Frequency chip_temperature: - name: EMON Chip Temp A + name: EMON Chip Temp line_frequency: 60Hz current_phases: 3 - gain_pga: 2X + gain_pga: 1X enable_offset_calibration: True + enable_gain_calibration: True + +text_sensor: + - platform: atm90e32 + id: atm90e32_chip1 + phase_status: + phase_a: + name: "Phase A Status" + phase_b: + name: "Phase B Status" + phase_c: + name: "Phase C Status" + frequency_status: + name: "Frequency Status" button: - platform: atm90e32 id: atm90e32_chip1 + run_gain_calibration: + name: "Run Gain Calibration" + clear_gain_calibration: + name: "Clear Gain Calibration" run_offset_calibration: - name: Chip1 - Run Offset Calibration + name: "Run Offset Calibration" clear_offset_calibration: - name: Chip1 - Clear Offset Calibration + name: "Clear Offset Calibration" + run_power_offset_calibration: + name: "Run Power Offset Calibration" + clear_power_offset_calibration: + name: "Clear Power Offset Calibration" + +number: + - platform: atm90e32 + id: atm90e32_chip1 + reference_voltage: + phase_a: + name: "Phase A Ref Voltage" + phase_b: + name: "Phase B Ref Voltage" + phase_c: + name: "Phase C Ref Voltage" + reference_current: + phase_a: + name: "Phase A Ref Current" + phase_b: + name: "Phase B Ref Current" + phase_c: + name: "Phase C Ref Current" diff --git a/tests/components/const/common.yaml b/tests/components/const/common.yaml new file mode 100644 index 0000000000..655af304af --- /dev/null +++ b/tests/components/const/common.yaml @@ -0,0 +1,44 @@ +spi: + id: quad_spi + clk_pin: 15 + type: quad + data_pins: [14, 10, 16, 12] + +display: + - platform: qspi_dbi + model: RM690B0 + data_rate: 80MHz + spi_mode: mode0 + dimensions: + width: 450 + height: 600 + offset_width: 16 + color_order: rgb + invert_colors: false + brightness: 255 + cs_pin: 11 + reset_pin: 13 + enable_pin: 9 + + - platform: qspi_dbi + model: CUSTOM + id: main_lcd + draw_from_origin: true + dimensions: + height: 240 + width: 536 + transform: + mirror_x: true + swap_xy: true + color_order: rgb + brightness: 255 + cs_pin: 6 + reset_pin: 17 + enable_pin: 38 + init_sequence: + - [0x3A, 0x66] + - [0x11] + - delay 120ms + - [0x29] + - delay 20ms + diff --git a/tests/components/const/test.esp32-s3-idf.yaml b/tests/components/const/test.esp32-s3-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/const/test.esp32-s3-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/cst226/common.yaml b/tests/components/cst226/common.yaml index c12d8d872c..d0b8ea3a86 100644 --- a/tests/components/cst226/common.yaml +++ b/tests/components/cst226/common.yaml @@ -23,3 +23,9 @@ touchscreen: interrupt_pin: ${interrupt_pin} reset_pin: ${reset_pin} +binary_sensor: + - id: cst226_touch + platform: cst226 + on_press: + then: + - component.update: ts_cst226 diff --git a/tests/components/debug/common.yaml b/tests/components/debug/common.yaml index 5845beaa80..a9d74e6865 100644 --- a/tests/components/debug/common.yaml +++ b/tests/components/debug/common.yaml @@ -1 +1,18 @@ debug: + +text_sensor: + - platform: debug + device: + name: "Device Info" + reset_reason: + name: "Reset Reason" + +sensor: + - platform: debug + free: + name: "Heap Free" + loop_time: + name: "Loop Time" + cpu_frequency: + name: "CPU Frequency" + diff --git a/tests/components/debug/test.esp32-ard.yaml b/tests/components/debug/test.esp32-ard.yaml index dade44d145..8e19a4d627 100644 --- a/tests/components/debug/test.esp32-ard.yaml +++ b/tests/components/debug/test.esp32-ard.yaml @@ -1 +1,4 @@ <<: !include common.yaml + +esp32: + cpu_frequency: 240MHz diff --git a/tests/components/debug/test.esp32-c3-ard.yaml b/tests/components/debug/test.esp32-c3-ard.yaml index dade44d145..7d43491862 100644 --- a/tests/components/debug/test.esp32-c3-ard.yaml +++ b/tests/components/debug/test.esp32-c3-ard.yaml @@ -1 +1,4 @@ <<: !include common.yaml + +esp32: + cpu_frequency: 80MHz diff --git a/tests/components/debug/test.esp32-idf.yaml b/tests/components/debug/test.esp32-idf.yaml index dade44d145..f7483a54b3 100644 --- a/tests/components/debug/test.esp32-idf.yaml +++ b/tests/components/debug/test.esp32-idf.yaml @@ -1 +1,13 @@ <<: !include common.yaml + +esp32: + cpu_frequency: 240MHz + +sensor: + - platform: debug + free: + name: "Heap Free" + psram: + name: "Free PSRAM" + +psram: diff --git a/tests/components/demo/test.esp32-idf.yaml b/tests/components/demo/test.esp32-idf.yaml new file mode 100644 index 0000000000..80027786df --- /dev/null +++ b/tests/components/demo/test.esp32-idf.yaml @@ -0,0 +1 @@ +demo: diff --git a/tests/components/dfrobot_sen0395/common.yaml b/tests/components/dfrobot_sen0395/common.yaml index 69bcebf182..8c349911d3 100644 --- a/tests/components/dfrobot_sen0395/common.yaml +++ b/tests/components/dfrobot_sen0395/common.yaml @@ -26,3 +26,17 @@ dfrobot_sen0395: binary_sensor: - platform: dfrobot_sen0395 id: mmwave_detected + +switch: + - platform: dfrobot_sen0395 + type: sensor_active + id: mmwave_sensor_active + - platform: dfrobot_sen0395 + type: turn_on_led + id: mmwave_turn_on_led + - platform: dfrobot_sen0395 + type: presence_via_uart + id: mmwave_presence_via_uart + - platform: dfrobot_sen0395 + type: start_after_boot + id: mmwave_start_after_boot diff --git a/tests/components/esp32_ble_server/common.yaml b/tests/components/esp32_ble_server/common.yaml index 696f4ea8fe..e9576a8262 100644 --- a/tests/components/esp32_ble_server/common.yaml +++ b/tests/components/esp32_ble_server/common.yaml @@ -2,6 +2,7 @@ esp32_ble_server: id: ble_server manufacturer_data: [0x72, 0x4, 0x00, 0x23] manufacturer: ESPHome + appearance: 0x1 model: Test on_connect: - lambda: |- diff --git a/tests/components/esp32_ble_tracker/test.esp32-ard.yaml b/tests/components/esp32_ble_tracker/test.esp32-ard.yaml index 070fffd68b..3bfdb8773f 100644 --- a/tests/components/esp32_ble_tracker/test.esp32-ard.yaml +++ b/tests/components/esp32_ble_tracker/test.esp32-ard.yaml @@ -1,4 +1,5 @@ <<: !include common.yaml esp32_ble_tracker: + software_coexistence: true max_connections: 3 diff --git a/tests/components/esp32_ble_tracker/test.esp32-c3-ard.yaml b/tests/components/esp32_ble_tracker/test.esp32-c3-ard.yaml index 070fffd68b..2e3c48117a 100644 --- a/tests/components/esp32_ble_tracker/test.esp32-c3-ard.yaml +++ b/tests/components/esp32_ble_tracker/test.esp32-c3-ard.yaml @@ -2,3 +2,4 @@ esp32_ble_tracker: max_connections: 3 + software_coexistence: false diff --git a/tests/components/esp32_ble_tracker/test.esp32-c3-idf.yaml b/tests/components/esp32_ble_tracker/test.esp32-c3-idf.yaml index 5e09f5020e..b71896bad5 100644 --- a/tests/components/esp32_ble_tracker/test.esp32-c3-idf.yaml +++ b/tests/components/esp32_ble_tracker/test.esp32-c3-idf.yaml @@ -2,3 +2,4 @@ esp32_ble_tracker: max_connections: 9 + software_coexistence: false diff --git a/tests/components/esp32_ble_tracker/test.esp32-idf.yaml b/tests/components/esp32_ble_tracker/test.esp32-idf.yaml index 5e09f5020e..1ffcfb9988 100644 --- a/tests/components/esp32_ble_tracker/test.esp32-idf.yaml +++ b/tests/components/esp32_ble_tracker/test.esp32-idf.yaml @@ -1,4 +1,5 @@ <<: !include common.yaml esp32_ble_tracker: + software_coexistence: true max_connections: 9 diff --git a/tests/components/http_request/common.yaml b/tests/components/http_request/common.yaml index 8408f27a05..af4852901f 100644 --- a/tests/components/http_request/common.yaml +++ b/tests/components/http_request/common.yaml @@ -1,5 +1,4 @@ -substitutions: - verify_ssl: "true" +<<: !include http_request.yaml wifi: ssid: MySSID @@ -10,27 +9,30 @@ esphome: then: - http_request.get: url: https://esphome.io - headers: + request_headers: Content-Type: application/json + collect_headers: + - age on_error: logger.log: "Request failed" on_response: then: - logger.log: - format: "Response status: %d, Duration: %lu ms" + format: "Response status: %d, Duration: %lu ms, age: %s" args: - response->status_code - (long) response->duration_ms + - response->get_response_header("age").c_str() - http_request.post: url: https://esphome.io - headers: + request_headers: Content-Type: application/json json: key: value - http_request.send: method: PUT url: https://esphome.io - headers: + request_headers: Content-Type: application/json body: "Some data" diff --git a/tests/components/http_request/http_request.yaml b/tests/components/http_request/http_request.yaml new file mode 100644 index 0000000000..ea7f6bf5a7 --- /dev/null +++ b/tests/components/http_request/http_request.yaml @@ -0,0 +1,46 @@ +substitutions: + verify_ssl: "true" + +network: + +esphome: + on_boot: + then: + - http_request.get: + url: https://esphome.io + request_headers: + Content-Type: application/json + on_error: + logger.log: "Request failed" + on_response: + then: + - logger.log: + format: "Response status: %d, Duration: %lu ms" + args: + - response->status_code + - (long) response->duration_ms + - http_request.post: + url: https://esphome.io + request_headers: + Content-Type: application/json + json: + key: value + - http_request.send: + method: PUT + url: https://esphome.io + request_headers: + Content-Type: application/json + body: "Some data" + +http_request: + useragent: esphome/tagreader + timeout: 10s + verify_ssl: ${verify_ssl} + +script: + - id: does_not_compile + parameters: + api_url: string + then: + - http_request.get: + url: "http://google.com" diff --git a/tests/components/http_request/test.host.yaml b/tests/components/http_request/test.host.yaml new file mode 100644 index 0000000000..e91445fb2d --- /dev/null +++ b/tests/components/http_request/test.host.yaml @@ -0,0 +1,7 @@ +substitutions: + verify_ssl: "true" +http_request: + # Just a file we can be sure exists + ca_certificate_path: /etc/passwd + +<<: !include http_request.yaml diff --git a/tests/components/image/common.yaml b/tests/components/image/common.yaml index 4c9b9ed670..864ca41c44 100644 --- a/tests/components/image/common.yaml +++ b/tests/components/image/common.yaml @@ -69,3 +69,18 @@ image: - id: another_alert_icon file: mdi:alert-outline type: BINARY + - file: mdil:arrange-bring-to-front + id: mdil_id + resize: 50x50 + type: binary + transparency: chroma_key + - file: mdi:beer + id: mdi_id + resize: 50x50 + type: binary + transparency: chroma_key + - file: memory:alert-octagon + id: memory_id + resize: 50x50 + type: binary + transparency: chroma_key diff --git a/tests/components/key_collector/common.yaml b/tests/components/key_collector/common.yaml index d58922ca91..12e541c865 100644 --- a/tests/components/key_collector/common.yaml +++ b/tests/components/key_collector/common.yaml @@ -26,3 +26,11 @@ key_collector: - logger.log: format: "input timeout: '%s', started by '%c'" args: ['x.c_str()', "(start == 0 ? '~' : start)"] + enable_on_boot: false + +button: + - platform: template + id: button0 + on_press: + - key_collector.enable: + - key_collector.disable: diff --git a/tests/components/lock/common.yaml b/tests/components/lock/common.yaml index 82297a3da4..67da46653c 100644 --- a/tests/components/lock/common.yaml +++ b/tests/components/lock/common.yaml @@ -27,9 +27,7 @@ lock: id: test_lock1 state: !lambda "return LOCK_STATE_UNLOCKED;" on_lock: - - lock.template.publish: - id: test_lock1 - state: !lambda "return LOCK_STATE_LOCKED;" + - lock.template.publish: LOCKED - platform: output name: Generic Output Lock id: test_lock2 diff --git a/tests/components/logger/test-custom_buffer_size.esp32-idf.yaml b/tests/components/logger/test-custom_buffer_size.esp32-idf.yaml new file mode 100644 index 0000000000..9a396ca023 --- /dev/null +++ b/tests/components/logger/test-custom_buffer_size.esp32-idf.yaml @@ -0,0 +1,5 @@ +<<: !include common-default_uart.yaml + +logger: + id: logger_id + task_log_buffer_size: 1024B # Set a custom buffer size diff --git a/tests/components/logger/test-disable_log_buffer.esp32-idf.yaml b/tests/components/logger/test-disable_log_buffer.esp32-idf.yaml new file mode 100644 index 0000000000..4260f178f9 --- /dev/null +++ b/tests/components/logger/test-disable_log_buffer.esp32-idf.yaml @@ -0,0 +1,5 @@ +<<: !include common-default_uart.yaml + +logger: + id: logger_id + task_log_buffer_size: 0 diff --git a/tests/components/logger/test-max_buffer_size.esp32-idf.yaml b/tests/components/logger/test-max_buffer_size.esp32-idf.yaml new file mode 100644 index 0000000000..f6c3eae677 --- /dev/null +++ b/tests/components/logger/test-max_buffer_size.esp32-idf.yaml @@ -0,0 +1,5 @@ +<<: !include common-default_uart.yaml + +logger: + id: logger_id + task_log_buffer_size: 32768B # Maximum buffer size diff --git a/tests/components/logger/test-min_buffer_size.esp32-idf.yaml b/tests/components/logger/test-min_buffer_size.esp32-idf.yaml new file mode 100644 index 0000000000..715b0580ed --- /dev/null +++ b/tests/components/logger/test-min_buffer_size.esp32-idf.yaml @@ -0,0 +1,5 @@ +<<: !include common-default_uart.yaml + +logger: + id: logger_id + task_log_buffer_size: 640B # Minimum buffer size with thread names diff --git a/tests/components/lvgl/lvgl-package.yaml b/tests/components/lvgl/lvgl-package.yaml index 6fd0b5e3c4..db55da9225 100644 --- a/tests/components/lvgl/lvgl-package.yaml +++ b/tests/components/lvgl/lvgl-package.yaml @@ -212,7 +212,7 @@ lvgl: - animimg: height: 60 id: anim_img - src: [cat_image, dog_image] + src: !lambda "return {dog_image, cat_image};" repeat_count: 10 duration: 1s auto_start: true @@ -224,6 +224,7 @@ lvgl: id: anim_img src: !lambda "return {dog_image, cat_image};" duration: 2s + - lvgl.widget.refresh: anim_img - label: on_boot: lvgl.label.update: diff --git a/tests/components/mapping/common.yaml b/tests/components/mapping/common.yaml new file mode 100644 index 0000000000..07ca458146 --- /dev/null +++ b/tests/components/mapping/common.yaml @@ -0,0 +1,71 @@ +image: + grayscale: + alpha_channel: + - file: ../../pnglogo.png + id: image_1 + resize: 50x50 + - file: ../../pnglogo.png + id: image_2 + resize: 50x50 + +mapping: + - id: weather_map + from: string + to: "image::Image" + entries: + clear-night: image_1 + sunny: image_2 + - id: weather_map_1 + from: string + to: esphome::image::Image + entries: + clear-night: image_1 + sunny: image_2 + - id: weather_map_2 + from: string + to: image + entries: + clear-night: image_1 + sunny: image_2 + - id: int_map + from: int + to: string + entries: + 1: "one" + 2: "two" + 3: "three" + 77: "seventy-seven" + - id: string_map + from: string + to: int + entries: + one: 1 + two: 2 + three: 3 + seventy-seven: 77 + - id: color_map + from: string + to: color + entries: + red: red_id + blue: blue_id + green: green_id + +color: + - id: red_id + red: 1.0 + green: 0.0 + blue: 0.0 + - id: green_id + red: 0.0 + green: 1.0 + blue: 0.0 + - id: blue_id + red: 0.0 + green: 0.0 + blue: 1.0 + +display: + lambda: |- + it.image(0, 0, id(weather_map)[0]); + it.image(0, 100, id(weather_map)[1]); diff --git a/tests/components/mapping/test.esp32-ard.yaml b/tests/components/mapping/test.esp32-ard.yaml new file mode 100644 index 0000000000..951a6061f6 --- /dev/null +++ b/tests/components/mapping/test.esp32-ard.yaml @@ -0,0 +1,17 @@ +spi: + - id: spi_main_lcd + clk_pin: 16 + mosi_pin: 17 + miso_pin: 15 + +display: + - platform: ili9xxx + id: main_lcd + model: ili9342 + cs_pin: 12 + dc_pin: 13 + reset_pin: 21 + invert_colors: false + +packages: + map: !include common.yaml diff --git a/tests/components/mapping/test.esp32-c3-ard.yaml b/tests/components/mapping/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..55e5719e50 --- /dev/null +++ b/tests/components/mapping/test.esp32-c3-ard.yaml @@ -0,0 +1,17 @@ +spi: + - id: spi_main_lcd + clk_pin: 6 + mosi_pin: 7 + miso_pin: 5 + +display: + - platform: ili9xxx + id: main_lcd + model: ili9342 + cs_pin: 8 + dc_pin: 9 + reset_pin: 10 + invert_colors: false + +packages: + map: !include common.yaml diff --git a/tests/components/mapping/test.esp32-c3-idf.yaml b/tests/components/mapping/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..55e5719e50 --- /dev/null +++ b/tests/components/mapping/test.esp32-c3-idf.yaml @@ -0,0 +1,17 @@ +spi: + - id: spi_main_lcd + clk_pin: 6 + mosi_pin: 7 + miso_pin: 5 + +display: + - platform: ili9xxx + id: main_lcd + model: ili9342 + cs_pin: 8 + dc_pin: 9 + reset_pin: 10 + invert_colors: false + +packages: + map: !include common.yaml diff --git a/tests/components/mapping/test.esp32-idf.yaml b/tests/components/mapping/test.esp32-idf.yaml new file mode 100644 index 0000000000..951a6061f6 --- /dev/null +++ b/tests/components/mapping/test.esp32-idf.yaml @@ -0,0 +1,17 @@ +spi: + - id: spi_main_lcd + clk_pin: 16 + mosi_pin: 17 + miso_pin: 15 + +display: + - platform: ili9xxx + id: main_lcd + model: ili9342 + cs_pin: 12 + dc_pin: 13 + reset_pin: 21 + invert_colors: false + +packages: + map: !include common.yaml diff --git a/tests/components/mapping/test.esp8266-ard.yaml b/tests/components/mapping/test.esp8266-ard.yaml new file mode 100644 index 0000000000..dd4642b8fe --- /dev/null +++ b/tests/components/mapping/test.esp8266-ard.yaml @@ -0,0 +1,17 @@ +spi: + - id: spi_main_lcd + clk_pin: 14 + mosi_pin: 13 + miso_pin: 12 + +display: + - platform: ili9xxx + id: main_lcd + model: ili9342 + cs_pin: 5 + dc_pin: 15 + reset_pin: 16 + invert_colors: false + +packages: + map: !include common.yaml diff --git a/tests/components/mapping/test.host.yaml b/tests/components/mapping/test.host.yaml new file mode 100644 index 0000000000..98406767a4 --- /dev/null +++ b/tests/components/mapping/test.host.yaml @@ -0,0 +1,12 @@ +display: + - platform: sdl + id: sdl_display + update_interval: 1s + auto_clear_enabled: false + show_test_card: true + dimensions: + width: 450 + height: 600 + +packages: + map: !include common.yaml diff --git a/tests/components/mapping/test.rp2040-ard.yaml b/tests/components/mapping/test.rp2040-ard.yaml new file mode 100644 index 0000000000..1b7e796246 --- /dev/null +++ b/tests/components/mapping/test.rp2040-ard.yaml @@ -0,0 +1,17 @@ +spi: + - id: spi_main_lcd + clk_pin: 2 + mosi_pin: 3 + miso_pin: 4 + +display: + - platform: ili9xxx + id: main_lcd + model: ili9342 + cs_pin: 20 + dc_pin: 21 + reset_pin: 22 + invert_colors: false + +packages: + map: !include common.yaml diff --git a/tests/components/mdns/common-enabled.yaml b/tests/components/mdns/common-enabled.yaml index bc31e32783..8b3d81cf69 100644 --- a/tests/components/mdns/common-enabled.yaml +++ b/tests/components/mdns/common-enabled.yaml @@ -4,3 +4,10 @@ wifi: mdns: disabled: false + services: + - service: _test_service + protocol: _tcp + port: 8888 + txt: + static_string: Anything + templated_string: !lambda return "Something else"; diff --git a/tests/components/micro_wake_word/common.yaml b/tests/components/micro_wake_word/common.yaml index 8bd7345307..c051c8dd57 100644 --- a/tests/components/micro_wake_word/common.yaml +++ b/tests/components/micro_wake_word/common.yaml @@ -8,12 +8,30 @@ microphone: i2s_din_pin: GPIO17 adc_type: external pdm: true + bits_per_sample: 16bit micro_wake_word: + microphone: echo_microphone on_wake_word_detected: - logger.log: "Wake word detected" + - micro_wake_word.stop: + - if: + condition: + - micro_wake_word.model_is_enabled: hey_jarvis_model + then: + - micro_wake_word.disable_model: hey_jarvis_model + else: + - micro_wake_word.enable_model: hey_jarvis_model + - if: + condition: + - not: + - micro_wake_word.is_running: + then: + micro_wake_word.start: + stop_after_detection: false models: - model: hey_jarvis probability_cutoff: 0.7 + id: hey_jarvis_model - model: okay_nabu sliding_window_size: 5 diff --git a/tests/components/microphone/common.yaml b/tests/components/microphone/common.yaml index ea79266281..00d33bcc3d 100644 --- a/tests/components/microphone/common.yaml +++ b/tests/components/microphone/common.yaml @@ -9,3 +9,13 @@ microphone: i2s_din_pin: ${i2s_din_pin} adc_type: external pdm: false + mclk_multiple: 384 + correct_dc_offset: true + on_data: + - if: + condition: + - microphone.is_muted: + then: + - microphone.unmute: + else: + - microphone.mute: diff --git a/tests/components/microphone/test.esp32-idf.yaml b/tests/components/microphone/test.esp32-idf.yaml index 392df582cc..fe9feb9888 100644 --- a/tests/components/microphone/test.esp32-idf.yaml +++ b/tests/components/microphone/test.esp32-idf.yaml @@ -4,9 +4,18 @@ substitutions: i2s_mclk_pin: GPIO17 i2s_din_pin: GPIO33 -<<: !include common.yaml +i2s_audio: + i2s_bclk_pin: ${i2s_bclk_pin} + i2s_lrclk_pin: ${i2s_lrclk_pin} + i2s_mclk_pin: ${i2s_mclk_pin} + use_legacy: true microphone: + - platform: i2s_audio + id: mic_id_external + i2s_din_pin: ${i2s_din_pin} + adc_type: external + pdm: false - platform: i2s_audio id: mic_id_adc adc_pin: 32 diff --git a/tests/components/mipi_spi/common.yaml b/tests/components/mipi_spi/common.yaml new file mode 100644 index 0000000000..e4b1e2b30c --- /dev/null +++ b/tests/components/mipi_spi/common.yaml @@ -0,0 +1,38 @@ +spi: + - id: spi_single + clk_pin: + number: ${clk_pin} + allow_other_uses: true + mosi_pin: + number: ${mosi_pin} + +display: + - platform: mipi_spi + spi_16: true + pixel_mode: 18bit + model: ili9488 + dc_pin: ${dc_pin} + cs_pin: ${cs_pin} + reset_pin: ${reset_pin} + data_rate: 20MHz + invert_colors: true + show_test_card: true + spi_mode: mode0 + draw_rounding: 8 + use_axis_flips: true + init_sequence: + - [0xd0, 1, 2, 3] + - delay 10ms + transform: + swap_xy: true + mirror_x: false + mirror_y: true + dimensions: + width: 100 + height: 200 + enable_pin: + - number: ${clk_pin} + allow_other_uses: true + - number: ${enable_pin} + bus_mode: single + diff --git a/tests/components/mipi_spi/test-esp32-2432s028.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-esp32-2432s028.esp32-s3-idf.yaml new file mode 100644 index 0000000000..a28776798c --- /dev/null +++ b/tests/components/mipi_spi/test-esp32-2432s028.esp32-s3-idf.yaml @@ -0,0 +1,41 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: octal_spi + type: octal + interface: hardware + clk_pin: + number: 0 + data_pins: + - 36 + - 37 + - 38 + - 39 + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: ESP32-2432S028 diff --git a/tests/components/mipi_spi/test-jc3248w535.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-jc3248w535.esp32-s3-idf.yaml new file mode 100644 index 0000000000..02b8f78d58 --- /dev/null +++ b/tests/components/mipi_spi/test-jc3248w535.esp32-s3-idf.yaml @@ -0,0 +1,41 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: octal_spi + type: octal + interface: hardware + clk_pin: + number: 0 + data_pins: + - 36 + - 37 + - 38 + - 39 + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: JC3248W535 diff --git a/tests/components/mipi_spi/test-jc3636w518.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-jc3636w518.esp32-s3-idf.yaml new file mode 100644 index 0000000000..147d4833ac --- /dev/null +++ b/tests/components/mipi_spi/test-jc3636w518.esp32-s3-idf.yaml @@ -0,0 +1,19 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 36 + data_pins: + - number: 40 + - number: 41 + - number: 42 + - number: 43 + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: JC3636W518 diff --git a/tests/components/mipi_spi/test-pico-restouch-lcd-35.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-pico-restouch-lcd-35.esp32-s3-idf.yaml new file mode 100644 index 0000000000..8d96f31fd5 --- /dev/null +++ b/tests/components/mipi_spi/test-pico-restouch-lcd-35.esp32-s3-idf.yaml @@ -0,0 +1,9 @@ +spi: + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: Pico-ResTouch-LCD-3.5 diff --git a/tests/components/mipi_spi/test-s3box.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-s3box.esp32-s3-idf.yaml new file mode 100644 index 0000000000..98f6955bf3 --- /dev/null +++ b/tests/components/mipi_spi/test-s3box.esp32-s3-idf.yaml @@ -0,0 +1,41 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: octal_spi + type: octal + interface: hardware + clk_pin: + number: 0 + data_pins: + - 36 + - 37 + - 38 + - 39 + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: S3BOX diff --git a/tests/components/mipi_spi/test-s3boxlite.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-s3boxlite.esp32-s3-idf.yaml new file mode 100644 index 0000000000..11ad869d54 --- /dev/null +++ b/tests/components/mipi_spi/test-s3boxlite.esp32-s3-idf.yaml @@ -0,0 +1,41 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: octal_spi + type: octal + interface: hardware + clk_pin: + number: 0 + data_pins: + - 36 + - 37 + - 38 + - 39 + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: S3BOXLITE diff --git a/tests/components/mipi_spi/test-t-display-s3-amoled-plus.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-t-display-s3-amoled-plus.esp32-s3-idf.yaml new file mode 100644 index 0000000000..dc328f950c --- /dev/null +++ b/tests/components/mipi_spi/test-t-display-s3-amoled-plus.esp32-s3-idf.yaml @@ -0,0 +1,9 @@ +spi: + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: T-DISPLAY-S3-AMOLED-PLUS diff --git a/tests/components/mipi_spi/test-t-display-s3-amoled.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-t-display-s3-amoled.esp32-s3-idf.yaml new file mode 100644 index 0000000000..f0432270dc --- /dev/null +++ b/tests/components/mipi_spi/test-t-display-s3-amoled.esp32-s3-idf.yaml @@ -0,0 +1,15 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - number: 40 + - number: 41 + - number: 42 + - number: 43 + +display: + - platform: mipi_spi + model: T-DISPLAY-S3-AMOLED diff --git a/tests/components/mipi_spi/test-t-display-s3-pro.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-t-display-s3-pro.esp32-s3-idf.yaml new file mode 100644 index 0000000000..5cda38e096 --- /dev/null +++ b/tests/components/mipi_spi/test-t-display-s3-pro.esp32-s3-idf.yaml @@ -0,0 +1,9 @@ +spi: + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 40 + +display: + - platform: mipi_spi + model: T-DISPLAY-S3-PRO diff --git a/tests/components/mipi_spi/test-t-display-s3.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-t-display-s3.esp32-s3-idf.yaml new file mode 100644 index 0000000000..144bde8366 --- /dev/null +++ b/tests/components/mipi_spi/test-t-display-s3.esp32-s3-idf.yaml @@ -0,0 +1,37 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: octal_spi + type: octal + interface: hardware + clk_pin: + number: 0 + data_pins: + - 36 + - 37 + - 38 + - 39 + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + +display: + - platform: mipi_spi + model: T-DISPLAY-S3 diff --git a/tests/components/mipi_spi/test-t-display.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-t-display.esp32-s3-idf.yaml new file mode 100644 index 0000000000..39339b5ae2 --- /dev/null +++ b/tests/components/mipi_spi/test-t-display.esp32-s3-idf.yaml @@ -0,0 +1,41 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: octal_spi + type: octal + interface: hardware + clk_pin: + number: 0 + data_pins: + - 36 + - 37 + - 38 + - 39 + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: T-DISPLAY diff --git a/tests/components/mipi_spi/test-t-embed.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-t-embed.esp32-s3-idf.yaml new file mode 100644 index 0000000000..6c9edb25b3 --- /dev/null +++ b/tests/components/mipi_spi/test-t-embed.esp32-s3-idf.yaml @@ -0,0 +1,9 @@ +spi: + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 40 + +display: + - platform: mipi_spi + model: T-EMBED diff --git a/tests/components/mipi_spi/test-t4-s3.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-t4-s3.esp32-s3-idf.yaml new file mode 100644 index 0000000000..46eaedb7cb --- /dev/null +++ b/tests/components/mipi_spi/test-t4-s3.esp32-s3-idf.yaml @@ -0,0 +1,41 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: octal_spi + type: octal + interface: hardware + clk_pin: + number: 0 + data_pins: + - 36 + - 37 + - 38 + - 39 + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: spi_id_3 + interface: any + clk_pin: 8 + mosi_pin: 9 + +display: + - platform: mipi_spi + model: T4-S3 diff --git a/tests/components/mipi_spi/test-wt32-sc01-plus.esp32-s3-idf.yaml b/tests/components/mipi_spi/test-wt32-sc01-plus.esp32-s3-idf.yaml new file mode 100644 index 0000000000..3efb05ec89 --- /dev/null +++ b/tests/components/mipi_spi/test-wt32-sc01-plus.esp32-s3-idf.yaml @@ -0,0 +1,37 @@ +spi: + - id: quad_spi + type: quad + interface: spi3 + clk_pin: + number: 47 + data_pins: + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + - id: octal_spi + type: octal + interface: hardware + clk_pin: + number: 9 + data_pins: + - 36 + - 37 + - 38 + - 39 + - allow_other_uses: true + number: 40 + - allow_other_uses: true + number: 41 + - allow_other_uses: true + number: 42 + - allow_other_uses: true + number: 43 + +display: + - platform: mipi_spi + model: WT32-SC01-PLUS diff --git a/tests/components/mipi_spi/test.esp32-ard.yaml b/tests/components/mipi_spi/test.esp32-ard.yaml new file mode 100644 index 0000000000..a5ef77dabc --- /dev/null +++ b/tests/components/mipi_spi/test.esp32-ard.yaml @@ -0,0 +1,15 @@ +substitutions: + clk_pin: GPIO16 + mosi_pin: GPIO17 + miso_pin: GPIO15 + dc_pin: GPIO14 + cs_pin: GPIO13 + enable_pin: GPIO19 + reset_pin: GPIO20 + +display: + - platform: mipi_spi + model: LANBON-L8 + +packages: + display: !include common.yaml diff --git a/tests/components/mipi_spi/test.esp32-c3-ard.yaml b/tests/components/mipi_spi/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..c17748c569 --- /dev/null +++ b/tests/components/mipi_spi/test.esp32-c3-ard.yaml @@ -0,0 +1,10 @@ +substitutions: + clk_pin: GPIO6 + mosi_pin: GPIO7 + miso_pin: GPIO5 + dc_pin: GPIO21 + cs_pin: GPIO18 + enable_pin: GPIO19 + reset_pin: GPIO20 + +<<: !include common.yaml diff --git a/tests/components/mipi_spi/test.esp32-c3-idf.yaml b/tests/components/mipi_spi/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..c17748c569 --- /dev/null +++ b/tests/components/mipi_spi/test.esp32-c3-idf.yaml @@ -0,0 +1,10 @@ +substitutions: + clk_pin: GPIO6 + mosi_pin: GPIO7 + miso_pin: GPIO5 + dc_pin: GPIO21 + cs_pin: GPIO18 + enable_pin: GPIO19 + reset_pin: GPIO20 + +<<: !include common.yaml diff --git a/tests/components/mipi_spi/test.esp32-idf.yaml b/tests/components/mipi_spi/test.esp32-idf.yaml new file mode 100644 index 0000000000..653ccb4910 --- /dev/null +++ b/tests/components/mipi_spi/test.esp32-idf.yaml @@ -0,0 +1,15 @@ +substitutions: + clk_pin: GPIO16 + mosi_pin: GPIO17 + miso_pin: GPIO15 + dc_pin: GPIO21 + cs_pin: GPIO18 + enable_pin: GPIO19 + reset_pin: GPIO20 + +packages: + display: !include common.yaml + +display: + - platform: mipi_spi + model: m5core diff --git a/tests/components/mipi_spi/test.rp2040-ard.yaml b/tests/components/mipi_spi/test.rp2040-ard.yaml new file mode 100644 index 0000000000..5d7333853b --- /dev/null +++ b/tests/components/mipi_spi/test.rp2040-ard.yaml @@ -0,0 +1,10 @@ +substitutions: + clk_pin: GPIO2 + mosi_pin: GPIO3 + miso_pin: GPIO4 + dc_pin: GPIO14 + cs_pin: GPIO13 + enable_pin: GPIO19 + reset_pin: GPIO20 + +<<: !include common.yaml diff --git a/tests/components/mlx90393/common.yaml b/tests/components/mlx90393/common.yaml index 0b074f9be3..58f3b6ecf5 100644 --- a/tests/components/mlx90393/common.yaml +++ b/tests/components/mlx90393/common.yaml @@ -5,8 +5,7 @@ i2c: sensor: - platform: mlx90393 - oversampling: 1 - filter: 0 + oversampling: 3 gain: 1X temperature_compensation: true x_axis: diff --git a/tests/components/mqtt/common.yaml b/tests/components/mqtt/common.yaml index a4bdf58809..1ab8872fdb 100644 --- a/tests/components/mqtt/common.yaml +++ b/tests/components/mqtt/common.yaml @@ -293,6 +293,8 @@ fan: - platform: template name: Template Fan state_topic: some/topic/fan + direction_state_topic: some/topic/direction/state + direction_command_topic: some/topic/direction/command qos: 2 on_state: - logger.log: on_state diff --git a/tests/components/nextion/common.yaml b/tests/components/nextion/common.yaml index 589afcfefb..44d6cdfbc9 100644 --- a/tests/components/nextion/common.yaml +++ b/tests/components/nextion/common.yaml @@ -280,6 +280,7 @@ display: - platform: nextion id: main_lcd update_interval: 5s + command_spacing: 5ms on_sleep: then: lambda: 'ESP_LOGD("display","Display went to sleep");' diff --git a/tests/components/packages/package.yaml b/tests/components/packages/package.yaml new file mode 100644 index 0000000000..672d66151e --- /dev/null +++ b/tests/components/packages/package.yaml @@ -0,0 +1,3 @@ +sensor: + - platform: template + id: package_sensor diff --git a/tests/components/packages/test.esp32-ard.yaml b/tests/components/packages/test.esp32-ard.yaml new file mode 100644 index 0000000000..d35c27d997 --- /dev/null +++ b/tests/components/packages/test.esp32-ard.yaml @@ -0,0 +1,11 @@ +packages: + - sensor: + - platform: template + id: inline_sensor + - !include package.yaml + - github://esphome/esphome/tests/components/template/common.yaml@dev + - url: https://github.com/esphome/esphome + file: tests/components/binary_sensor_map/common.yaml + ref: dev + refresh: 1d + diff --git a/tests/components/packages/test.esp32-idf.yaml b/tests/components/packages/test.esp32-idf.yaml new file mode 100644 index 0000000000..9f1484d1fd --- /dev/null +++ b/tests/components/packages/test.esp32-idf.yaml @@ -0,0 +1,13 @@ +packages: + sensor: + sensor: + - platform: template + id: inline_sensor + local: !include package.yaml + shorthand: github://esphome/esphome/tests/components/template/common.yaml@dev + github: + url: https://github.com/esphome/esphome + file: tests/components/binary_sensor_map/common.yaml + ref: dev + refresh: 1d + diff --git a/tests/components/packet_transport/common.yaml b/tests/components/packet_transport/common.yaml new file mode 100644 index 0000000000..cbb34c4572 --- /dev/null +++ b/tests/components/packet_transport/common.yaml @@ -0,0 +1,40 @@ +wifi: + ssid: MySSID + password: password1 + +udp: + listen_address: 239.0.60.53 + addresses: ["239.0.60.53"] + +packet_transport: + platform: udp + update_interval: 5s + encryption: "our key goes here" + rolling_code_enable: true + ping_pong_enable: true + binary_sensors: + - binary_sensor_id1 + - id: binary_sensor_id1 + broadcast_id: other_id + sensors: + - sensor_id1 + - id: sensor_id1 + broadcast_id: other_id + providers: + - name: some-device-name + encryption: "their key goes here" + +sensor: + - platform: template + id: sensor_id1 + - platform: packet_transport + provider: some-device-name + id: our_id + remote_id: some_sensor_id + +binary_sensor: + - platform: packet_transport + provider: unencrypted-device + id: other_binary_sensor_id + - platform: template + id: binary_sensor_id1 diff --git a/tests/components/packet_transport/test.bk72xx-ard.yaml b/tests/components/packet_transport/test.bk72xx-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.bk72xx-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp32-ard.yaml b/tests/components/packet_transport/test.esp32-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp32-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp32-c3-ard.yaml b/tests/components/packet_transport/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp32-c3-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp32-c3-idf.yaml b/tests/components/packet_transport/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp32-c3-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp32-idf.yaml b/tests/components/packet_transport/test.esp32-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp32-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.esp8266-ard.yaml b/tests/components/packet_transport/test.esp8266-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.esp8266-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/packet_transport/test.host.yaml b/tests/components/packet_transport/test.host.yaml new file mode 100644 index 0000000000..e735c37e4d --- /dev/null +++ b/tests/components/packet_transport/test.host.yaml @@ -0,0 +1,4 @@ +packages: + common: !include common.yaml + +wifi: !remove diff --git a/tests/components/packet_transport/test.rp2040-ard.yaml b/tests/components/packet_transport/test.rp2040-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/packet_transport/test.rp2040-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/pm2005/common.yaml b/tests/components/pm2005/common.yaml new file mode 100644 index 0000000000..b8f6683b22 --- /dev/null +++ b/tests/components/pm2005/common.yaml @@ -0,0 +1,13 @@ +i2c: + - id: i2c_pm2005 + scl: ${scl_pin} + sda: ${sda_pin} + +sensor: + - platform: pm2005 + pm_1_0: + name: PM1.0 + pm_2_5: + name: PM2.5 + pm_10_0: + name: PM10.0 diff --git a/tests/components/pm2005/test.esp32-ard.yaml b/tests/components/pm2005/test.esp32-ard.yaml new file mode 100644 index 0000000000..63c3bd6afd --- /dev/null +++ b/tests/components/pm2005/test.esp32-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO16 + sda_pin: GPIO17 + +<<: !include common.yaml diff --git a/tests/components/pm2005/test.esp32-c3-ard.yaml b/tests/components/pm2005/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..ee2c29ca4e --- /dev/null +++ b/tests/components/pm2005/test.esp32-c3-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO5 + sda_pin: GPIO4 + +<<: !include common.yaml diff --git a/tests/components/pm2005/test.esp32-c3-idf.yaml b/tests/components/pm2005/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..ee2c29ca4e --- /dev/null +++ b/tests/components/pm2005/test.esp32-c3-idf.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO5 + sda_pin: GPIO4 + +<<: !include common.yaml diff --git a/tests/components/pm2005/test.esp32-idf.yaml b/tests/components/pm2005/test.esp32-idf.yaml new file mode 100644 index 0000000000..63c3bd6afd --- /dev/null +++ b/tests/components/pm2005/test.esp32-idf.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO16 + sda_pin: GPIO17 + +<<: !include common.yaml diff --git a/tests/components/pm2005/test.esp8266-ard.yaml b/tests/components/pm2005/test.esp8266-ard.yaml new file mode 100644 index 0000000000..ee2c29ca4e --- /dev/null +++ b/tests/components/pm2005/test.esp8266-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO5 + sda_pin: GPIO4 + +<<: !include common.yaml diff --git a/tests/components/pm2005/test.rp2040-ard.yaml b/tests/components/pm2005/test.rp2040-ard.yaml new file mode 100644 index 0000000000..ee2c29ca4e --- /dev/null +++ b/tests/components/pm2005/test.rp2040-ard.yaml @@ -0,0 +1,5 @@ +substitutions: + scl_pin: GPIO5 + sda_pin: GPIO4 + +<<: !include common.yaml diff --git a/tests/components/prometheus/common.yaml b/tests/components/prometheus/common.yaml index 7c226b6782..131d135f8b 100644 --- a/tests/components/prometheus/common.yaml +++ b/tests/components/prometheus/common.yaml @@ -1,6 +1,3 @@ -substitutions: - verify_ssl: "false" - esphome: name: livingroomdevice friendly_name: Living Room Device @@ -129,6 +126,14 @@ valve: optimistic: true has_position: true +remote_transmitter: + pin: ${pin} + carrier_duty_percent: 50% + +climate: + - platform: climate_ir_lg + name: LG Climate + prometheus: include_internal: true relabel: diff --git a/tests/components/prometheus/test.esp32-ard.yaml b/tests/components/prometheus/test.esp32-ard.yaml index 3045a6db13..9eedaabd82 100644 --- a/tests/components/prometheus/test.esp32-ard.yaml +++ b/tests/components/prometheus/test.esp32-ard.yaml @@ -1,3 +1,7 @@ +substitutions: + verify_ssl: "false" + pin: GPIO5 + <<: !include common.yaml i2s_audio: diff --git a/tests/components/prometheus/test.esp32-c3-ard.yaml b/tests/components/prometheus/test.esp32-c3-ard.yaml index dade44d145..f00bca5947 100644 --- a/tests/components/prometheus/test.esp32-c3-ard.yaml +++ b/tests/components/prometheus/test.esp32-c3-ard.yaml @@ -1 +1,5 @@ +substitutions: + verify_ssl: "false" + pin: GPIO2 + <<: !include common.yaml diff --git a/tests/components/prometheus/test.esp32-c3-idf.yaml b/tests/components/prometheus/test.esp32-c3-idf.yaml index dade44d145..f00bca5947 100644 --- a/tests/components/prometheus/test.esp32-c3-idf.yaml +++ b/tests/components/prometheus/test.esp32-c3-idf.yaml @@ -1 +1,5 @@ +substitutions: + verify_ssl: "false" + pin: GPIO2 + <<: !include common.yaml diff --git a/tests/components/prometheus/test.esp32-idf.yaml b/tests/components/prometheus/test.esp32-idf.yaml index dade44d145..f00bca5947 100644 --- a/tests/components/prometheus/test.esp32-idf.yaml +++ b/tests/components/prometheus/test.esp32-idf.yaml @@ -1 +1,5 @@ +substitutions: + verify_ssl: "false" + pin: GPIO2 + <<: !include common.yaml diff --git a/tests/components/prometheus/test.esp8266-ard.yaml b/tests/components/prometheus/test.esp8266-ard.yaml index dade44d145..6ee1831769 100644 --- a/tests/components/prometheus/test.esp8266-ard.yaml +++ b/tests/components/prometheus/test.esp8266-ard.yaml @@ -1 +1,5 @@ +substitutions: + verify_ssl: "false" + pin: GPIO5 + <<: !include common.yaml diff --git a/tests/components/remote_receiver/common-actions.yaml b/tests/components/remote_receiver/common-actions.yaml index c1f576d20e..ca7713f58a 100644 --- a/tests/components/remote_receiver/common-actions.yaml +++ b/tests/components/remote_receiver/common-actions.yaml @@ -3,6 +3,11 @@ on_abbwelcome: - logger.log: format: "on_abbwelcome: %u" args: ["x.data()[0]"] +on_beo4: + then: + - logger.log: + format: "on_beo4: %u %u" + args: ["x.source", "x.command"] on_aeha: then: - logger.log: @@ -43,6 +48,11 @@ on_drayton: - logger.log: format: "on_drayton: %u %u %u" args: ["x.address", "x.channel", "x.command"] +on_gobox: + then: + - logger.log: + format: "on_gobox: %d" + args: ["x.code"] on_jvc: then: - logger.log: diff --git a/tests/components/remote_transmitter/common-buttons.yaml b/tests/components/remote_transmitter/common-buttons.yaml index b037c50e12..1fb7ef6dbe 100644 --- a/tests/components/remote_transmitter/common-buttons.yaml +++ b/tests/components/remote_transmitter/common-buttons.yaml @@ -1,4 +1,11 @@ button: + - platform: template + name: Beo4 audio mute + id: beo4_audio_mute + on_press: + remote_transmitter.transmit_beo4: + source: 0x01 + command: 0x0C - platform: template name: JVC Off id: living_room_lights_on diff --git a/tests/components/sound_level/common.yaml b/tests/components/sound_level/common.yaml new file mode 100644 index 0000000000..cc04f5bf79 --- /dev/null +++ b/tests/components/sound_level/common.yaml @@ -0,0 +1,26 @@ +i2s_audio: + i2s_lrclk_pin: ${i2s_bclk_pin} + i2s_bclk_pin: ${i2s_lrclk_pin} + +microphone: + - platform: i2s_audio + id: i2s_microphone + i2s_din_pin: ${i2s_dout_pin} + adc_type: external + bits_per_sample: 16bit + +sensor: + - platform: sound_level + microphone: i2s_microphone + measurement_duration: 2000ms + passive: false + peak: + name: "Peak Sound Level" + on_value_range: + - above: -1.0 + then: + - sound_level.stop: + - delay: 5s + - sound_level.start: + rms: + name: "RMS Sound Level" diff --git a/tests/components/sound_level/test.esp32-ard.yaml b/tests/components/sound_level/test.esp32-ard.yaml new file mode 100644 index 0000000000..c6d1bfa330 --- /dev/null +++ b/tests/components/sound_level/test.esp32-ard.yaml @@ -0,0 +1,6 @@ +substitutions: + i2s_bclk_pin: GPIO25 + i2s_lrclk_pin: GPIO26 + i2s_dout_pin: GPIO27 + +<<: !include common.yaml diff --git a/tests/components/sound_level/test.esp32-c3-ard.yaml b/tests/components/sound_level/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..aeb7d9f0af --- /dev/null +++ b/tests/components/sound_level/test.esp32-c3-ard.yaml @@ -0,0 +1,6 @@ +substitutions: + i2s_bclk_pin: GPIO6 + i2s_lrclk_pin: GPIO7 + i2s_dout_pin: GPIO8 + +<<: !include common.yaml diff --git a/tests/components/sound_level/test.esp32-c3-idf.yaml b/tests/components/sound_level/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..aeb7d9f0af --- /dev/null +++ b/tests/components/sound_level/test.esp32-c3-idf.yaml @@ -0,0 +1,6 @@ +substitutions: + i2s_bclk_pin: GPIO6 + i2s_lrclk_pin: GPIO7 + i2s_dout_pin: GPIO8 + +<<: !include common.yaml diff --git a/tests/components/sound_level/test.esp32-idf.yaml b/tests/components/sound_level/test.esp32-idf.yaml new file mode 100644 index 0000000000..c6d1bfa330 --- /dev/null +++ b/tests/components/sound_level/test.esp32-idf.yaml @@ -0,0 +1,6 @@ +substitutions: + i2s_bclk_pin: GPIO25 + i2s_lrclk_pin: GPIO26 + i2s_dout_pin: GPIO27 + +<<: !include common.yaml diff --git a/tests/components/sound_level/test.esp32-s3-ard.yaml b/tests/components/sound_level/test.esp32-s3-ard.yaml new file mode 100644 index 0000000000..9c1f32d5bd --- /dev/null +++ b/tests/components/sound_level/test.esp32-s3-ard.yaml @@ -0,0 +1,6 @@ +substitutions: + i2s_bclk_pin: GPIO4 + i2s_lrclk_pin: GPIO5 + i2s_dout_pin: GPIO6 + +<<: !include common.yaml diff --git a/tests/components/sound_level/test.esp32-s3-idf.yaml b/tests/components/sound_level/test.esp32-s3-idf.yaml new file mode 100644 index 0000000000..9c1f32d5bd --- /dev/null +++ b/tests/components/sound_level/test.esp32-s3-idf.yaml @@ -0,0 +1,6 @@ +substitutions: + i2s_bclk_pin: GPIO4 + i2s_lrclk_pin: GPIO5 + i2s_dout_pin: GPIO6 + +<<: !include common.yaml diff --git a/tests/components/syslog/common.yaml b/tests/components/syslog/common.yaml new file mode 100644 index 0000000000..cd6e63c9ec --- /dev/null +++ b/tests/components/syslog/common.yaml @@ -0,0 +1,15 @@ +wifi: + ssid: MySSID + password: password1 + +udp: + addresses: ["239.0.60.53"] + +time: + platform: host + +syslog: + port: 514 + strip: true + level: info + facility: 16 diff --git a/tests/components/syslog/test.bk72xx-ard.yaml b/tests/components/syslog/test.bk72xx-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.bk72xx-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp32-ard.yaml b/tests/components/syslog/test.esp32-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp32-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp32-c3-ard.yaml b/tests/components/syslog/test.esp32-c3-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp32-c3-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp32-c3-idf.yaml b/tests/components/syslog/test.esp32-c3-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp32-c3-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp32-idf.yaml b/tests/components/syslog/test.esp32-idf.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp32-idf.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.esp8266-ard.yaml b/tests/components/syslog/test.esp8266-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.esp8266-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/syslog/test.host.yaml b/tests/components/syslog/test.host.yaml new file mode 100644 index 0000000000..e735c37e4d --- /dev/null +++ b/tests/components/syslog/test.host.yaml @@ -0,0 +1,4 @@ +packages: + common: !include common.yaml + +wifi: !remove diff --git a/tests/components/syslog/test.rp2040-ard.yaml b/tests/components/syslog/test.rp2040-ard.yaml new file mode 100644 index 0000000000..dade44d145 --- /dev/null +++ b/tests/components/syslog/test.rp2040-ard.yaml @@ -0,0 +1 @@ +<<: !include common.yaml diff --git a/tests/components/template/common.yaml b/tests/components/template/common.yaml index 79201fbe07..fd9342b3e5 100644 --- a/tests/components/template/common.yaml +++ b/tests/components/template/common.yaml @@ -28,6 +28,16 @@ sensor: value: 20.0 - timeout: timeout: 1d + - to_ntc_resistance: + calibration: + - 10.0kOhm -> 25°C + - 27.219kOhm -> 0°C + - 14.674kOhm -> 15°C + - to_ntc_temperature: + calibration: + - 10.0kOhm -> 25°C + - 27.219kOhm -> 0°C + - 14.674kOhm -> 15°C esphome: on_boot: @@ -164,6 +174,8 @@ valve: - logger.log: open_action close_action: - logger.log: close_action + - valve.template.publish: + state: CLOSED stop_action: - logger.log: stop_action optimistic: true diff --git a/tests/components/tuya/common.yaml b/tests/components/tuya/common.yaml index fcf8a2d96b..2c40628139 100644 --- a/tests/components/tuya/common.yaml +++ b/tests/components/tuya/common.yaml @@ -60,12 +60,19 @@ number: select: - platform: tuya - id: tuya_select + id: tuya_select_enum enum_datapoint: 42 options: 0: Internal 1: Floor 2: Both + - platform: tuya + id: tuya_select_int + int_datapoint: 43 + options: + 0: Internal + 1: Floor + 2: Both sensor: - platform: tuya diff --git a/tests/components/uart/test.esp32-idf.yaml b/tests/components/uart/test.esp32-idf.yaml index bef5b460ab..5a0ed7eba7 100644 --- a/tests/components/uart/test.esp32-idf.yaml +++ b/tests/components/uart/test.esp32-idf.yaml @@ -13,3 +13,6 @@ uart: rx_buffer_size: 512 parity: EVEN stop_bits: 2 + +packet_transport: + - platform: uart diff --git a/tests/components/udp/common.yaml b/tests/components/udp/common.yaml index e533cb965e..79da02a692 100644 --- a/tests/components/udp/common.yaml +++ b/tests/components/udp/common.yaml @@ -3,34 +3,18 @@ wifi: password: password1 udp: - update_interval: 5s - encryption: "our key goes here" - rolling_code_enable: true - ping_pong_enable: true + id: my_udp listen_address: 239.0.60.53 - binary_sensors: - - binary_sensor_id1 - - id: binary_sensor_id1 - broadcast_id: other_id - sensors: - - sensor_id1 - - id: sensor_id1 - broadcast_id: other_id - providers: - - name: some-device-name - encryption: "their key goes here" + addresses: ["239.0.60.53"] + on_receive: + - logger.log: + format: "Received %d bytes" + args: [data.size()] + - udp.write: + id: my_udp + data: "hello world" + - udp.write: + id: my_udp + data: !lambda |- + return std::vector{1,3,4,5,6}; -sensor: - - platform: template - id: sensor_id1 - - platform: udp - provider: some-device-name - id: our_id - remote_id: some_sensor_id - -binary_sensor: - - platform: udp - provider: unencrypted-device - id: other_binary_sensor_id - - platform: template - id: binary_sensor_id1 diff --git a/tests/components/uptime/common.yaml b/tests/components/uptime/common.yaml index d78ef8eca9..86b764e7ff 100644 --- a/tests/components/uptime/common.yaml +++ b/tests/components/uptime/common.yaml @@ -17,3 +17,13 @@ sensor: text_sensor: - platform: uptime name: Uptime Text + - platform: uptime + name: Uptime Text With Separator + format: + separator: "-" + expand: true + days: "Days" + hours: "H" + minutes: "M" + seconds: "S" + update_interval: 10s diff --git a/tests/components/vl53l0x/common.yaml b/tests/components/vl53l0x/common.yaml index 973e481b1a..8346eae854 100644 --- a/tests/components/vl53l0x/common.yaml +++ b/tests/components/vl53l0x/common.yaml @@ -10,3 +10,4 @@ sensor: enable_pin: 3 timeout: 200us update_interval: 60s + timing_budget: 30000us diff --git a/tests/components/voice_assistant/common-idf.yaml b/tests/components/voice_assistant/common-idf.yaml new file mode 100644 index 0000000000..b1d249d5b4 --- /dev/null +++ b/tests/components/voice_assistant/common-idf.yaml @@ -0,0 +1,69 @@ +esphome: + on_boot: + then: + - voice_assistant.start + - voice_assistant.start_continuous + - voice_assistant.stop + +wifi: + ssid: MySSID + password: password1 + +api: + +i2s_audio: + i2s_lrclk_pin: ${i2s_lrclk_pin} + i2s_bclk_pin: ${i2s_bclk_pin} + i2s_mclk_pin: ${i2s_mclk_pin} + +micro_wake_word: + id: mww_id + on_wake_word_detected: + - voice_assistant.start: + wake_word: !lambda return wake_word; + models: + - model: okay_nabu + +microphone: + - platform: i2s_audio + id: mic_id_external + i2s_din_pin: ${i2s_din_pin} + adc_type: external + pdm: false + +speaker: + - platform: i2s_audio + id: speaker_id + dac_type: external + i2s_dout_pin: ${i2s_dout_pin} + +voice_assistant: + microphone: + microphone: mic_id_external + gain_factor: 4 + channels: 0 + speaker: speaker_id + micro_wake_word: mww_id + conversation_timeout: 60s + on_listening: + - logger.log: "Voice assistant microphone listening" + on_start: + - logger.log: "Voice assistant started" + on_stt_end: + - logger.log: + format: "Voice assistant STT ended with result %s" + args: [x.c_str()] + on_tts_start: + - logger.log: + format: "Voice assistant TTS started with text %s" + args: [x.c_str()] + on_tts_end: + - logger.log: + format: "Voice assistant TTS ended with url %s" + args: [x.c_str()] + on_end: + - logger.log: "Voice assistant ended" + on_error: + - logger.log: + format: "Voice assistant error - code %s, message: %s" + args: [code.c_str(), message.c_str()] diff --git a/tests/components/voice_assistant/common.yaml b/tests/components/voice_assistant/common.yaml index e7374941f7..f248154b7e 100644 --- a/tests/components/voice_assistant/common.yaml +++ b/tests/components/voice_assistant/common.yaml @@ -30,7 +30,10 @@ speaker: i2s_dout_pin: ${i2s_dout_pin} voice_assistant: - microphone: mic_id_external + microphone: + microphone: mic_id_external + gain_factor: 4 + channels: 0 speaker: speaker_id conversation_timeout: 60s on_listening: diff --git a/tests/components/voice_assistant/test.esp32-c3-idf.yaml b/tests/components/voice_assistant/test.esp32-c3-idf.yaml index f596d927cb..46745e4308 100644 --- a/tests/components/voice_assistant/test.esp32-c3-idf.yaml +++ b/tests/components/voice_assistant/test.esp32-c3-idf.yaml @@ -5,4 +5,4 @@ substitutions: i2s_din_pin: GPIO3 i2s_dout_pin: GPIO2 -<<: !include common.yaml +<<: !include common-idf.yaml diff --git a/tests/components/voice_assistant/test.esp32-idf.yaml b/tests/components/voice_assistant/test.esp32-idf.yaml index f6e553f9dc..0fe5d347be 100644 --- a/tests/components/voice_assistant/test.esp32-idf.yaml +++ b/tests/components/voice_assistant/test.esp32-idf.yaml @@ -5,4 +5,4 @@ substitutions: i2s_din_pin: GPIO13 i2s_dout_pin: GPIO12 -<<: !include common.yaml +<<: !include common-idf.yaml diff --git a/tests/components/waveshare_epaper/common.yaml b/tests/components/waveshare_epaper/common.yaml index 09ba1af778..a2aa3134b5 100644 --- a/tests/components/waveshare_epaper/common.yaml +++ b/tests/components/waveshare_epaper/common.yaml @@ -541,6 +541,26 @@ display: lambda: |- it.rectangle(0, 0, it.get_width(), it.get_height()); + # 5.65 inch displays + - platform: waveshare_epaper + id: epd_5_65 + model: 5.65in-f + spi_id: spi_waveshare_epaper + cs_pin: + allow_other_uses: true + number: ${cs_pin} + dc_pin: + allow_other_uses: true + number: ${dc_pin} + busy_pin: + allow_other_uses: true + number: ${busy_pin} + reset_pin: + allow_other_uses: true + number: ${reset_pin} + lambda: |- + it.rectangle(0, 0, it.get_width(), it.get_height()); + # 5.83 inch displays - platform: waveshare_epaper id: epd_5_83 diff --git a/tests/unit_tests/test_config_validation.py b/tests/unit_tests/test_config_validation.py index 3b2c72af2c..7a1354589c 100644 --- a/tests/unit_tests/test_config_validation.py +++ b/tests/unit_tests/test_config_validation.py @@ -284,3 +284,93 @@ def test_split_default(framework, platform, variant, full, idf, arduino, simple) assert schema({}).get("idf") == idf assert schema({}).get("arduino") == arduino assert schema({}).get("simple") == simple + + +@pytest.mark.parametrize( + "framework, platform, message", + [ + ("esp-idf", PLATFORM_ESP32, "ESP32 using esp-idf framework"), + ("arduino", PLATFORM_ESP32, "ESP32 using arduino framework"), + ("arduino", PLATFORM_ESP8266, "ESP8266 using arduino framework"), + ("arduino", PLATFORM_RP2040, "RP2040 using arduino framework"), + ("arduino", PLATFORM_BK72XX, "BK72XX using arduino framework"), + ("host", PLATFORM_HOST, "HOST using host framework"), + ], +) +def test_require_framework_version(framework, platform, message): + import voluptuous as vol + + from esphome.const import ( + KEY_CORE, + KEY_FRAMEWORK_VERSION, + KEY_TARGET_FRAMEWORK, + KEY_TARGET_PLATFORM, + ) + + CORE.data[KEY_CORE] = {} + CORE.data[KEY_CORE][KEY_TARGET_PLATFORM] = platform + CORE.data[KEY_CORE][KEY_TARGET_FRAMEWORK] = framework + CORE.data[KEY_CORE][KEY_FRAMEWORK_VERSION] = config_validation.Version(1, 0, 0) + + assert ( + config_validation.require_framework_version( + esp_idf=config_validation.Version(0, 5, 0), + esp32_arduino=config_validation.Version(0, 5, 0), + esp8266_arduino=config_validation.Version(0, 5, 0), + rp2040_arduino=config_validation.Version(0, 5, 0), + bk72xx_arduino=config_validation.Version(0, 5, 0), + host=config_validation.Version(0, 5, 0), + extra_message="test 1", + )("test") + == "test" + ) + + with pytest.raises( + vol.error.Invalid, + match="This feature requires at least framework version 2.0.0. test 2", + ): + config_validation.require_framework_version( + esp_idf=config_validation.Version(2, 0, 0), + esp32_arduino=config_validation.Version(2, 0, 0), + esp8266_arduino=config_validation.Version(2, 0, 0), + rp2040_arduino=config_validation.Version(2, 0, 0), + bk72xx_arduino=config_validation.Version(2, 0, 0), + host=config_validation.Version(2, 0, 0), + extra_message="test 2", + )("test") + + assert ( + config_validation.require_framework_version( + esp_idf=config_validation.Version(1, 5, 0), + esp32_arduino=config_validation.Version(1, 5, 0), + esp8266_arduino=config_validation.Version(1, 5, 0), + rp2040_arduino=config_validation.Version(1, 5, 0), + bk72xx_arduino=config_validation.Version(1, 5, 0), + host=config_validation.Version(1, 5, 0), + max_version=True, + extra_message="test 3", + )("test") + == "test" + ) + + with pytest.raises( + vol.error.Invalid, + match="This feature requires framework version 0.5.0 or lower. test 4", + ): + config_validation.require_framework_version( + esp_idf=config_validation.Version(0, 5, 0), + esp32_arduino=config_validation.Version(0, 5, 0), + esp8266_arduino=config_validation.Version(0, 5, 0), + rp2040_arduino=config_validation.Version(0, 5, 0), + bk72xx_arduino=config_validation.Version(0, 5, 0), + host=config_validation.Version(0, 5, 0), + max_version=True, + extra_message="test 4", + )("test") + + with pytest.raises( + vol.error.Invalid, match=f"This feature is incompatible with {message}. test 5" + ): + config_validation.require_framework_version( + extra_message="test 5", + )("test") diff --git a/tests/unit_tests/test_helpers.py b/tests/unit_tests/test_helpers.py index 862320b09e..b353d1aa99 100644 --- a/tests/unit_tests/test_helpers.py +++ b/tests/unit_tests/test_helpers.py @@ -267,3 +267,13 @@ def test_sanitize(text, expected): actual = helpers.sanitize(text) assert actual == expected + + +@pytest.mark.parametrize( + "text, expected", + ((["127.0.0.1", "fe80::1", "2001::2"], ["2001::2", "127.0.0.1", "fe80::1"]),), +) +def test_sort_ip_addresses(text: list[str], expected: list[str]) -> None: + actual = helpers.sort_ip_addresses(text) + + assert actual == expected diff --git a/tests/unit_tests/test_vscode.py b/tests/unit_tests/test_vscode.py index f5ebd63f60..6e0bde23b2 100644 --- a/tests/unit_tests/test_vscode.py +++ b/tests/unit_tests/test_vscode.py @@ -18,8 +18,12 @@ def _run_repl_test(input_data): vscode.read_config(args) # Capture printed output - full_output = "".join(call[0][0] for call in mock_stdout.write.call_args_list) - return full_output.strip().split("\n") + full_output = "".join( + call[0][0] for call in mock_stdout.write.call_args_list + ).strip() + splitted_output = full_output.split("\n") + remove_version = splitted_output[1:] # remove first entry with version info + return remove_version def _validate(file_path: str):