Compare commits

...

5 Commits

Author SHA1 Message Date
Stefan Agner
f2596feb17 Clamp progress to 100 to prevent floating point precision issues
Floating point arithmetic in weighted progress calculations can produce
values slightly above 100 (e.g., 100.00000000000001). This causes
validation errors when the progress value is checked.

Add min(100, ...) clamping to both size-weighted and count-based
progress calculations to ensure the result never exceeds 100.
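
A minimal standalone sketch of the failure mode and the clamp (illustrative
values, not the Supervisor code):

    # Weighted sums of float ratios can land marginally above 100.
    progress = 100 * (0.1 + 0.2) / 0.3   # slightly greater than 100.0 with IEEE-754 doubles
    progress = min(100, progress)        # clamp before validation
    assert 0 <= progress <= 100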

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 21:18:36 +01:00
Stefan Agner
dc46b5b45b Add registry manifest fetcher for size-based pull progress
Fetch image manifests directly from container registries before pulling
to get accurate layer sizes upfront. This enables size-weighted progress
tracking where each layer contributes proportionally to its byte size
rather than receiving equal weight.

Key changes:
- Add RegistryManifestFetcher that handles auth discovery via
  WWW-Authenticate headers, token fetching with optional credentials,
  and multi-arch manifest list resolution
- Update ImagePullProgress to accept manifest layer sizes via
  set_manifest() and calculate size-weighted progress
- Fall back to count-based progress when manifest fetch fails
- Pre-populate layer sizes from manifest when creating layer trackers

The manifest fetcher supports ghcr.io, Docker Hub, and private
registries by using credentials from Docker config when available.
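
A rough usage sketch under the names introduced in this diff (runs inside an
async function; error handling elided):

    # Fetch the manifest up front, then derive per-layer size weights.
    fetcher = RegistryManifestFetcher(coresys)
    manifest = await fetcher.get_manifest(
        "ghcr.io/home-assistant/home-assistant", "2025.1.0", platform="linux/amd64"
    )
    if manifest:  # None means the fetch failed and count-based progress is used
        weights = {
            digest: size / manifest.total_size
            for digest, size in manifest.layers.items()
        }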

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 21:18:34 +01:00
Stefan Agner
204206640f Exclude already-existing layers from pull progress calculation
Layers that already exist locally should not count towards download
progress since there's nothing to download for them. Only layers that
need pulling are included in the progress calculation.
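
A minimal sketch of that filtering (mirrors the layers_to_pull logic in
pull_progress.py further down this diff):

    # Layers reported as "Already exists" are dropped before averaging.
    layers_to_pull = [l for l in layers.values() if not l.already_exists]
    if not layers_to_pull:
        progress = 100.0  # everything was already local
    else:
        progress = sum(
            l.calculate_progress() for l in layers_to_pull
        ) / len(layers_to_pull)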

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 21:17:03 +01:00
Stefan Agner
8c3ccfd83d Fix pytest
2025-12-02 21:17:03 +01:00
Stefan Agner
2615199172 Use count-based progress for Docker image pulls
Refactor Docker image pull progress to use a simpler count-based approach
where each layer contributes equally (100% / total_layers) regardless of
size. This replaces the previous size-weighted calculation that was
susceptible to progress regression.

The core issue was that Docker limits concurrent downloads (about 3 at a
time) and reports layer sizes only when a layer starts downloading. With
size-weighted progress, large layers appearing late would cause progress
to drop dramatically (e.g., 59% -> 29%) as the known total size increased.

The new approach:
- Each layer contributes equally to overall progress
- Per-layer progress: 70% download weight, 30% extraction weight
- Progress only starts after first "Downloading" event (when layer
  count is known)
- Always caps at 99% - job completion handles final 100%

This simplifies the code by moving progress tracking to a dedicated
module (pull_progress.py) and removing complex size-based scaling logic
that tried to account for unknown layer sizes.
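
A standalone sketch of the per-layer weighting (hypothetical helper; the real
logic lives in pull_progress.py in this diff):

    def layer_progress(downloaded: int, extracted: int, total: int) -> float:
        """0-100 for one layer: download covers 70%, extraction the other 30%."""
        if total <= 0:
            return 0.0
        if downloaded < total:
            return 70.0 * downloaded / total
        return 70.0 + 30.0 * min(1.0, extracted / total)

    # Count-based overall progress: each of the N layers contributes equally.
    layers = [(1000, 0, 1000), (500, 500, 500)]  # (downloaded, extracted, total)
    overall = sum(layer_progress(*layer) for layer in layers) / len(layers)  # 85.0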

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 21:16:57 +01:00
13 changed files with 2018 additions and 301 deletions

View File

@@ -2,12 +2,9 @@
from __future__ import annotations
from contextlib import suppress
from enum import Enum, StrEnum
from functools import total_ordering
from enum import StrEnum
from pathlib import PurePath
import re
from typing import cast
from docker.types import Mount
@@ -15,9 +12,8 @@ from ..const import MACHINE_ID
RE_RETRYING_DOWNLOAD_STATUS = re.compile(r"Retrying in \d+ seconds?")
# Docker Hub registry identifier (official default)
# Docker's default registry is docker.io
DOCKER_HUB = "docker.io"
# Docker Hub registry identifier
DOCKER_HUB = "hub.docker.com"
# Legacy Docker Hub identifier for backward compatibility
DOCKER_HUB_LEGACY = "hub.docker.com"
@@ -82,57 +78,6 @@ class PropagationMode(StrEnum):
RSLAVE = "rslave"
@total_ordering
class PullImageLayerStage(Enum):
"""Job stages for pulling an image layer.
These are a subset of the statuses in a docker pull image log. They
are the standardized ones that are the most useful to us.
"""
PULLING_FS_LAYER = 1, "Pulling fs layer"
RETRYING_DOWNLOAD = 2, "Retrying download"
DOWNLOADING = 2, "Downloading"
VERIFYING_CHECKSUM = 3, "Verifying Checksum"
DOWNLOAD_COMPLETE = 4, "Download complete"
EXTRACTING = 5, "Extracting"
PULL_COMPLETE = 6, "Pull complete"
def __init__(self, order: int, status: str) -> None:
"""Set fields from values."""
self.order = order
self.status = status
def __eq__(self, value: object, /) -> bool:
"""Check equality, allow StrEnum style comparisons on status."""
with suppress(AttributeError):
return self.status == cast(PullImageLayerStage, value).status
return self.status == value
def __lt__(self, other: object) -> bool:
"""Order instances."""
with suppress(AttributeError):
return self.order < cast(PullImageLayerStage, other).order
return False
def __hash__(self) -> int:
"""Hash instance."""
return hash(self.status)
@classmethod
def from_status(cls, status: str) -> PullImageLayerStage | None:
"""Return stage instance from pull log status."""
for i in cls:
if i.status == status:
return i
# This one includes the number of seconds until the retry, so it's not constant
if RE_RETRYING_DOWNLOAD_STATUS.match(status):
return cls.RETRYING_DOWNLOAD
return None
ENV_TIME = "TZ"
ENV_TOKEN = "SUPERVISOR_TOKEN"
ENV_TOKEN_OLD = "HASSIO_TOKEN"

View File

@@ -19,7 +19,6 @@ import docker
from docker.models.containers import Container
import requests
from ..bus import EventListener
from ..const import (
ATTR_PASSWORD,
ATTR_REGISTRY,
@@ -35,25 +34,18 @@ from ..exceptions import (
DockerError,
DockerHubRateLimitExceeded,
DockerJobError,
DockerLogOutOfOrder,
DockerNotFound,
DockerRequestError,
)
from ..jobs import SupervisorJob
from ..jobs.const import JOB_GROUP_DOCKER_INTERFACE, JobConcurrency
from ..jobs.decorator import Job
from ..jobs.job_group import JobGroup
from ..resolution.const import ContextType, IssueType, SuggestionType
from ..utils.sentry import async_capture_exception
from .const import (
DOCKER_HUB,
DOCKER_HUB_LEGACY,
ContainerState,
PullImageLayerStage,
RestartPolicy,
)
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY, ContainerState, RestartPolicy
from .manager import CommandReturn, PullLogEntry
from .monitor import DockerContainerStateEvent
from .pull_progress import ImagePullProgress
from .stats import DockerStats
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -202,159 +194,6 @@ class DockerInterface(JobGroup, ABC):
return credentials
def _process_pull_image_log( # noqa: C901
self, install_job_id: str, reference: PullLogEntry
) -> None:
"""Process events fired from a docker while pulling an image, filtered to a given job id."""
if (
reference.job_id != install_job_id
or not reference.id
or not reference.status
or not (stage := PullImageLayerStage.from_status(reference.status))
):
return
# Pulling FS Layer is our marker for a layer that needs to be downloaded and extracted. Otherwise the layer already exists and we can ignore it
job: SupervisorJob | None = None
if stage == PullImageLayerStage.PULLING_FS_LAYER:
job = self.sys_jobs.new_job(
name="Pulling container image layer",
initial_stage=stage.status,
reference=reference.id,
parent_id=install_job_id,
internal=True,
)
job.done = False
return
# Find our sub job to update details of
for j in self.sys_jobs.jobs:
if j.parent_id == install_job_id and j.reference == reference.id:
job = j
break
# There should no longer be any real risk of logs arriving out of order.
# However, tests with very small images have shown that Docker sometimes
# skips stages in the log, so keep this as a safety check for a missing job
if not job:
raise DockerLogOutOfOrder(
f"Received pull image log with status {reference.status} for image id {reference.id} and parent job {install_job_id} but could not find a matching job, skipping",
_LOGGER.debug,
)
# For progress calculation we assume downloading is 70% of time, extracting is 30% and others stages negligible
progress = job.progress
match stage:
case PullImageLayerStage.DOWNLOADING | PullImageLayerStage.EXTRACTING:
if (
reference.progress_detail
and reference.progress_detail.current
and reference.progress_detail.total
):
progress = (
reference.progress_detail.current
/ reference.progress_detail.total
)
if stage == PullImageLayerStage.DOWNLOADING:
progress = 70 * progress
else:
progress = 70 + 30 * progress
case (
PullImageLayerStage.VERIFYING_CHECKSUM
| PullImageLayerStage.DOWNLOAD_COMPLETE
):
progress = 70
case PullImageLayerStage.PULL_COMPLETE:
progress = 100
case PullImageLayerStage.RETRYING_DOWNLOAD:
progress = 0
# No real risk of getting things out of order in the current implementation,
# but keeping this check in case another change to these stages trips us up.
if stage != PullImageLayerStage.RETRYING_DOWNLOAD and progress < job.progress:
raise DockerLogOutOfOrder(
f"Received pull image log with status {reference.status} for job {job.uuid} that implied progress was {progress} but current progress is {job.progress}, skipping",
_LOGGER.debug,
)
# Our filters have all passed. Time to update the job
# Only downloading and extracting have progress details. Use that to set extra
# We'll leave it around on later stages as the total bytes may be useful after that stage
# Enforce range to prevent float drift error
progress = max(0, min(progress, 100))
if (
stage in {PullImageLayerStage.DOWNLOADING, PullImageLayerStage.EXTRACTING}
and reference.progress_detail
and reference.progress_detail.current is not None
and reference.progress_detail.total is not None
):
job.update(
progress=progress,
stage=stage.status,
extra={
"current": reference.progress_detail.current,
"total": reference.progress_detail.total,
},
)
else:
# If we reach DOWNLOAD_COMPLETE without ever having set extra (small layers that skip
# the downloading phase), set a minimal extra so aggregate progress calculation can proceed
extra = job.extra
if stage == PullImageLayerStage.DOWNLOAD_COMPLETE and not job.extra:
extra = {"current": 1, "total": 1}
job.update(
progress=progress,
stage=stage.status,
done=stage == PullImageLayerStage.PULL_COMPLETE,
extra=None if stage == PullImageLayerStage.RETRYING_DOWNLOAD else extra,
)
# Once we have received a progress update for every child job, start to set status of the main one
install_job = self.sys_jobs.get_job(install_job_id)
layer_jobs = [
job
for job in self.sys_jobs.jobs
if job.parent_id == install_job.uuid
and job.name == "Pulling container image layer"
]
# First set the total bytes to be downloaded/extracted on the main job
if not install_job.extra:
total = 0
for job in layer_jobs:
if not job.extra:
return
total += job.extra["total"]
install_job.extra = {"total": total}
else:
total = install_job.extra["total"]
# Then determine total progress based on progress of each sub-job, factoring in size of each compared to total
progress = 0.0
stage = PullImageLayerStage.PULL_COMPLETE
for job in layer_jobs:
if not job.extra or not job.extra.get("total"):
return
progress += job.progress * (job.extra["total"] / total)
job_stage = PullImageLayerStage.from_status(cast(str, job.stage))
if job_stage < PullImageLayerStage.EXTRACTING:
stage = PullImageLayerStage.DOWNLOADING
elif (
stage == PullImageLayerStage.PULL_COMPLETE
and job_stage < PullImageLayerStage.PULL_COMPLETE
):
stage = PullImageLayerStage.EXTRACTING
# Ensure progress is 100 at this point to prevent float drift
if stage == PullImageLayerStage.PULL_COMPLETE:
progress = 100
# To reduce noise, limit updates to when the result has changed by a full percent or the stage has changed
if stage != install_job.stage or progress >= install_job.progress + 1:
install_job.update(stage=stage.status, progress=max(0, min(progress, 100)))
@Job(
name="docker_interface_install",
on_condition=DockerJobError,
@@ -374,33 +213,55 @@ class DockerInterface(JobGroup, ABC):
raise ValueError("Cannot pull without an image!")
image_arch = arch or self.sys_arch.supervisor
listener: EventListener | None = None
platform = MAP_ARCH[image_arch]
pull_progress = ImagePullProgress()
current_job = self.sys_jobs.current
# Try to fetch manifest for accurate size-based progress
# This is optional - if it fails, we fall back to count-based progress
try:
manifest = await self.sys_docker.manifest_fetcher.get_manifest(
image, str(version), platform=platform
)
if manifest:
pull_progress.set_manifest(manifest)
_LOGGER.debug(
"Using manifest for progress: %d layers, %d bytes",
manifest.layer_count,
manifest.total_size,
)
except Exception as err: # noqa: BLE001
_LOGGER.debug("Could not fetch manifest for progress: %s", err)
async def process_pull_event(event: PullLogEntry) -> None:
"""Process pull event and update job progress."""
if event.job_id != current_job.uuid:
return
# Process event through progress tracker
pull_progress.process_event(event)
# Update job if progress changed significantly (>= 1%)
should_update, progress = pull_progress.should_update_job()
if should_update:
stage = pull_progress.get_stage()
current_job.update(progress=progress, stage=stage)
listener = self.sys_bus.register_event(
BusEvent.DOCKER_IMAGE_PULL_UPDATE, process_pull_event
)
_LOGGER.info("Downloading docker image %s with tag %s.", image, version)
try:
# Get credentials for private registries to pass to aiodocker
credentials = self._get_credentials(image) or None
curr_job_id = self.sys_jobs.current.uuid
async def process_pull_image_log(reference: PullLogEntry) -> None:
try:
self._process_pull_image_log(curr_job_id, reference)
except DockerLogOutOfOrder as err:
# Send all these to sentry. Missing a few progress updates
# shouldn't matter to users but matters to us
await async_capture_exception(err)
listener = self.sys_bus.register_event(
BusEvent.DOCKER_IMAGE_PULL_UPDATE, process_pull_image_log
)
# Pull new image, passing credentials to aiodocker
docker_image = await self.sys_docker.pull_image(
self.sys_jobs.current.uuid,
current_job.uuid,
image,
str(version),
platform=MAP_ARCH[image_arch],
platform=platform,
auth=credentials,
)
@@ -445,8 +306,7 @@ class DockerInterface(JobGroup, ABC):
f"Unknown error with {image}:{version!s} -> {err!s}", _LOGGER.error
) from err
finally:
if listener:
self.sys_bus.remove_listener(listener)
self.sys_bus.remove_listener(listener)
self._meta = docker_image

View File

@@ -50,6 +50,7 @@ from ..exceptions import (
from ..utils.common import FileConfiguration
from ..validate import SCHEMA_DOCKER_CONFIG
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY, LABEL_MANAGED
from .manifest import RegistryManifestFetcher
from .monitor import DockerMonitor
from .network import DockerNetwork
from .utils import get_registry_from_image
@@ -258,6 +259,9 @@ class DockerAPI(CoreSysAttributes):
self._info: DockerInfo | None = None
self.config: DockerConfig = DockerConfig()
self._monitor: DockerMonitor = DockerMonitor(coresys)
self._manifest_fetcher: RegistryManifestFetcher = RegistryManifestFetcher(
coresys
)
async def post_init(self) -> Self:
"""Post init actions that must be done in event loop."""
@@ -323,6 +327,11 @@ class DockerAPI(CoreSysAttributes):
"""Return docker events monitor."""
return self._monitor
@property
def manifest_fetcher(self) -> RegistryManifestFetcher:
"""Return manifest fetcher for registry access."""
return self._manifest_fetcher
async def load(self) -> None:
"""Start docker events monitor."""
await self.monitor.load()

View File

@@ -0,0 +1,354 @@
"""Docker registry manifest fetcher.
Fetches image manifests directly from container registries to get layer sizes
before pulling an image. This enables accurate size-based progress tracking.
"""
from __future__ import annotations
from dataclasses import dataclass
import logging
import re
from typing import TYPE_CHECKING
import aiohttp
from .const import DOCKER_HUB, IMAGE_WITH_HOST
if TYPE_CHECKING:
from ..coresys import CoreSys
_LOGGER = logging.getLogger(__name__)
# Default registry for images without explicit host
DEFAULT_REGISTRY = "registry-1.docker.io"
# Media types for manifest requests
MANIFEST_MEDIA_TYPES = (
"application/vnd.docker.distribution.manifest.v2+json",
"application/vnd.oci.image.manifest.v1+json",
"application/vnd.docker.distribution.manifest.list.v2+json",
"application/vnd.oci.image.index.v1+json",
)
@dataclass
class ImageManifest:
"""Container image manifest with layer information."""
digest: str
total_size: int
layers: dict[str, int] # digest -> size in bytes
@property
def layer_count(self) -> int:
"""Return number of layers."""
return len(self.layers)
def parse_image_reference(image: str, tag: str) -> tuple[str, str, str]:
"""Parse image reference into (registry, repository, tag).
Examples:
ghcr.io/home-assistant/home-assistant:2025.1.0
-> (ghcr.io, home-assistant/home-assistant, 2025.1.0)
homeassistant/home-assistant:latest
-> (registry-1.docker.io, homeassistant/home-assistant, latest)
alpine:3.18
-> (registry-1.docker.io, library/alpine, 3.18)
"""
# Check if image has explicit registry host
match = IMAGE_WITH_HOST.match(image)
if match:
registry = match.group(1)
repository = image[len(registry) + 1 :] # Remove "registry/" prefix
else:
registry = DEFAULT_REGISTRY
repository = image
# Docker Hub requires "library/" prefix for official images
if "/" not in repository:
repository = f"library/{repository}"
return registry, repository, tag
class RegistryManifestFetcher:
"""Fetches manifests from container registries."""
def __init__(self, coresys: CoreSys) -> None:
"""Initialize the fetcher."""
self.coresys = coresys
self._session: aiohttp.ClientSession | None = None
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create aiohttp session."""
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
return self._session
async def close(self) -> None:
"""Close the session."""
if self._session and not self._session.closed:
await self._session.close()
self._session = None
def _get_credentials(self, registry: str) -> tuple[str, str] | None:
"""Get credentials for registry from Docker config.
Returns (username, password) tuple or None if no credentials.
"""
registries = self.coresys.docker.config.registries
# Map registry hostname to config key
# Docker Hub can be stored as "hub.docker.com" in config
if registry == DEFAULT_REGISTRY:
if DOCKER_HUB in registries:
creds = registries[DOCKER_HUB]
return creds.get("username"), creds.get("password")
elif registry in registries:
creds = registries[registry]
return creds.get("username"), creds.get("password")
return None
async def _get_auth_token(
self,
session: aiohttp.ClientSession,
registry: str,
repository: str,
) -> str | None:
"""Get authentication token for registry.
Uses the WWW-Authenticate header from a 401 response to discover
the token endpoint, then requests a token with appropriate scope.
"""
# First, make an unauthenticated request to get WWW-Authenticate header
manifest_url = f"https://{registry}/v2/{repository}/manifests/latest"
try:
async with session.get(manifest_url) as resp:
if resp.status == 200:
# No auth required
return None
if resp.status != 401:
_LOGGER.warning(
"Unexpected status %d from registry %s", resp.status, registry
)
return None
www_auth = resp.headers.get("WWW-Authenticate", "")
except aiohttp.ClientError as err:
_LOGGER.warning("Failed to connect to registry %s: %s", registry, err)
return None
# Parse WWW-Authenticate: Bearer realm="...",service="...",scope="..."
if not www_auth.startswith("Bearer "):
_LOGGER.warning("Unsupported auth type from %s: %s", registry, www_auth)
return None
params = {}
for match in re.finditer(r'(\w+)="([^"]*)"', www_auth):
params[match.group(1)] = match.group(2)
realm = params.get("realm")
service = params.get("service")
if not realm:
_LOGGER.warning("No realm in WWW-Authenticate from %s", registry)
return None
# Build token request URL
token_url = f"{realm}?scope=repository:{repository}:pull"
if service:
token_url += f"&service={service}"
# Check for credentials
auth = None
credentials = self._get_credentials(registry)
if credentials:
username, password = credentials
if username and password:
auth = aiohttp.BasicAuth(username, password)
_LOGGER.debug("Using credentials for %s", registry)
try:
async with session.get(token_url, auth=auth) as resp:
if resp.status != 200:
_LOGGER.warning(
"Failed to get token from %s: %d", realm, resp.status
)
return None
data = await resp.json()
return data.get("token") or data.get("access_token")
except aiohttp.ClientError as err:
_LOGGER.warning("Failed to get auth token: %s", err)
return None
async def _fetch_manifest(
self,
session: aiohttp.ClientSession,
registry: str,
repository: str,
reference: str,
token: str | None,
platform: str | None = None,
) -> dict | None:
"""Fetch manifest from registry.
If the manifest is a manifest list (multi-arch), fetches the
platform-specific manifest.
"""
manifest_url = f"https://{registry}/v2/{repository}/manifests/{reference}"
headers = {"Accept": ", ".join(MANIFEST_MEDIA_TYPES)}
if token:
headers["Authorization"] = f"Bearer {token}"
try:
async with session.get(manifest_url, headers=headers) as resp:
if resp.status != 200:
_LOGGER.warning(
"Failed to fetch manifest for %s/%s:%s - %d",
registry,
repository,
reference,
resp.status,
)
return None
manifest = await resp.json()
except aiohttp.ClientError as err:
_LOGGER.warning("Failed to fetch manifest: %s", err)
return None
media_type = manifest.get("mediaType", "")
# Check if this is a manifest list (multi-arch image)
if "list" in media_type or "index" in media_type:
manifests = manifest.get("manifests", [])
if not manifests:
_LOGGER.warning("Empty manifest list for %s/%s", registry, repository)
return None
# Find matching platform
target_os = "linux"
target_arch = "amd64" # Default
if platform:
# Platform format is "linux/amd64", "linux/arm64", etc.
parts = platform.split("/")
if len(parts) >= 2:
target_os, target_arch = parts[0], parts[1]
platform_manifest = None
for m in manifests:
plat = m.get("platform", {})
if (
plat.get("os") == target_os
and plat.get("architecture") == target_arch
):
platform_manifest = m
break
if not platform_manifest:
# Fall back to first manifest
_LOGGER.debug(
"Platform %s/%s not found, using first manifest",
target_os,
target_arch,
)
platform_manifest = manifests[0]
# Fetch the platform-specific manifest
return await self._fetch_manifest(
session,
registry,
repository,
platform_manifest["digest"],
token,
platform,
)
return manifest
async def get_manifest(
self,
image: str,
tag: str,
platform: str | None = None,
) -> ImageManifest | None:
"""Fetch manifest and extract layer sizes.
Args:
image: Image name (e.g., "ghcr.io/home-assistant/home-assistant")
tag: Image tag (e.g., "2025.1.0")
platform: Target platform (e.g., "linux/amd64")
Returns:
ImageManifest with layer sizes, or None if fetch failed.
"""
registry, repository, tag = parse_image_reference(image, tag)
_LOGGER.debug(
"Fetching manifest for %s/%s:%s (platform=%s)",
registry,
repository,
tag,
platform,
)
session = await self._get_session()
# Get auth token
token = await self._get_auth_token(session, registry, repository)
# Fetch manifest
manifest = await self._fetch_manifest(
session, registry, repository, tag, token, platform
)
if not manifest:
return None
# Extract layer information
layers = manifest.get("layers", [])
if not layers:
_LOGGER.warning(
"No layers in manifest for %s/%s:%s", registry, repository, tag
)
return None
layer_sizes: dict[str, int] = {}
total_size = 0
for layer in layers:
digest = layer.get("digest", "")
size = layer.get("size", 0)
if digest and size:
# Store by short digest (first 12 chars after sha256:)
short_digest = (
digest.split(":")[1][:12] if ":" in digest else digest[:12]
)
layer_sizes[short_digest] = size
total_size += size
digest = manifest.get("config", {}).get("digest", "")
_LOGGER.debug(
"Manifest for %s/%s:%s - %d layers, %d bytes total",
registry,
repository,
tag,
len(layer_sizes),
total_size,
)
return ImageManifest(
digest=digest,
total_size=total_size,
layers=layer_sizes,
)

View File

@@ -0,0 +1,371 @@
"""Image pull progress tracking."""
from __future__ import annotations
from contextlib import suppress
from dataclasses import dataclass, field
from enum import Enum
import logging
from typing import TYPE_CHECKING, cast
if TYPE_CHECKING:
from .manager import PullLogEntry
from .manifest import ImageManifest
_LOGGER = logging.getLogger(__name__)
# Progress weight distribution: 70% downloading, 30% extraction
DOWNLOAD_WEIGHT = 70.0
EXTRACT_WEIGHT = 30.0
class LayerPullStatus(Enum):
"""Status values for pulling an image layer.
These are a subset of the statuses in a docker pull image log.
The order field allows comparing which stage is further along.
"""
PULLING_FS_LAYER = 1, "Pulling fs layer"
WAITING = 1, "Waiting"
RETRYING = 2, "Retrying" # Matches "Retrying in N seconds"
DOWNLOADING = 3, "Downloading"
VERIFYING_CHECKSUM = 4, "Verifying Checksum"
DOWNLOAD_COMPLETE = 5, "Download complete"
EXTRACTING = 6, "Extracting"
PULL_COMPLETE = 7, "Pull complete"
ALREADY_EXISTS = 7, "Already exists"
def __init__(self, order: int, status: str) -> None:
"""Set fields from values."""
self.order = order
self.status = status
def __eq__(self, value: object, /) -> bool:
"""Check equality, allow string comparisons on status."""
with suppress(AttributeError):
return self.status == cast(LayerPullStatus, value).status
return self.status == value
def __hash__(self) -> int:
"""Return hash based on status string."""
return hash(self.status)
def __lt__(self, other: object) -> bool:
"""Order instances by stage progression."""
with suppress(AttributeError):
return self.order < cast(LayerPullStatus, other).order
return False
@classmethod
def from_status(cls, status: str) -> LayerPullStatus | None:
"""Get enum from status string, or None if not recognized."""
# Handle "Retrying in N seconds" pattern
if status.startswith("Retrying in "):
return cls.RETRYING
for member in cls:
if member.status == status:
return member
return None
@dataclass
class LayerProgress:
"""Track progress of a single layer."""
layer_id: str
total_size: int = 0 # Size in bytes (from downloading, reused for extraction)
download_current: int = 0
extract_current: int = 0 # Extraction progress in bytes (overlay2 only)
download_complete: bool = False
extract_complete: bool = False
already_exists: bool = False # Layer was already locally available
def calculate_progress(self) -> float:
"""Calculate layer progress 0-100.
Progress is weighted: 70% download, 30% extraction.
For overlay2, we have byte-based extraction progress.
For containerd, extraction jumps from 70% to 100% on completion.
"""
if self.already_exists or self.extract_complete:
return 100.0
if self.download_complete:
# Check if we have extraction progress (overlay2)
if self.extract_current > 0 and self.total_size > 0:
extract_pct = min(1.0, self.extract_current / self.total_size)
return DOWNLOAD_WEIGHT + (extract_pct * EXTRACT_WEIGHT)
# No extraction progress yet - return 70%
return DOWNLOAD_WEIGHT
if self.total_size > 0:
download_pct = min(1.0, self.download_current / self.total_size)
return download_pct * DOWNLOAD_WEIGHT
return 0.0
@dataclass
class ImagePullProgress:
"""Track overall progress of pulling an image.
When manifest layer sizes are provided, uses size-weighted progress where
each layer contributes proportionally to its size. This gives accurate
progress based on actual bytes to download.
When manifest is not available, falls back to count-based progress where
each layer contributes equally.
Layers that already exist locally are excluded from the progress calculation.
"""
layers: dict[str, LayerProgress] = field(default_factory=dict)
_last_reported_progress: float = field(default=0.0, repr=False)
_seen_downloading: bool = field(default=False, repr=False)
_manifest_layer_sizes: dict[str, int] = field(default_factory=dict, repr=False)
_total_manifest_size: int = field(default=0, repr=False)
def set_manifest(self, manifest: ImageManifest) -> None:
"""Set manifest layer sizes for accurate size-based progress.
Should be called before processing pull events.
"""
self._manifest_layer_sizes = dict(manifest.layers)
self._total_manifest_size = manifest.total_size
_LOGGER.debug(
"Manifest set: %d layers, %d bytes total",
len(self._manifest_layer_sizes),
self._total_manifest_size,
)
def get_or_create_layer(self, layer_id: str) -> LayerProgress:
"""Get existing layer or create new one."""
if layer_id not in self.layers:
# If we have manifest sizes, pre-populate the layer's total_size
manifest_size = self._manifest_layer_sizes.get(layer_id, 0)
self.layers[layer_id] = LayerProgress(
layer_id=layer_id, total_size=manifest_size
)
return self.layers[layer_id]
def process_event(self, entry: PullLogEntry) -> None:
"""Process a pull log event and update layer state."""
# Skip events without layer ID or status
if not entry.id or not entry.status:
return
# Skip metadata events that aren't layer-specific
# "Pulling from X" has id=tag but isn't a layer
if entry.status.startswith("Pulling from "):
return
# Parse status to enum (returns None for unrecognized statuses)
status = LayerPullStatus.from_status(entry.status)
if status is None:
return
layer = self.get_or_create_layer(entry.id)
# Handle "Already exists" - layer is locally available
if status is LayerPullStatus.ALREADY_EXISTS:
layer.already_exists = True
layer.download_complete = True
layer.extract_complete = True
return
# Handle "Pulling fs layer" / "Waiting" - layer is being tracked
if status in (LayerPullStatus.PULLING_FS_LAYER, LayerPullStatus.WAITING):
return
# Handle "Downloading" - update download progress
if status is LayerPullStatus.DOWNLOADING:
# Mark that we've seen downloading - now we know layer count is complete
self._seen_downloading = True
if (
entry.progress_detail
and entry.progress_detail.current is not None
and entry.progress_detail.total is not None
):
layer.download_current = entry.progress_detail.current
# Only set total_size if not already set or if this is larger
# (handles case where total changes during download)
layer.total_size = max(layer.total_size, entry.progress_detail.total)
return
# Handle "Verifying Checksum" - download is essentially complete
if status is LayerPullStatus.VERIFYING_CHECKSUM:
if layer.total_size > 0:
layer.download_current = layer.total_size
return
# Handle "Download complete" - download phase done
if status is LayerPullStatus.DOWNLOAD_COMPLETE:
layer.download_complete = True
if layer.total_size > 0:
layer.download_current = layer.total_size
elif layer.total_size == 0:
# Small layer that skipped downloading phase
# Set minimal size so it doesn't distort weighted average
layer.total_size = 1
layer.download_current = 1
return
# Handle "Extracting" - extraction in progress
if status is LayerPullStatus.EXTRACTING:
# For overlay2: progressDetail has {current, total} in bytes
# For containerd: progressDetail has {current, units: "s"} (time elapsed)
# We can only use byte-based progress (overlay2)
layer.download_complete = True
if layer.total_size > 0:
layer.download_current = layer.total_size
# Check if this is byte-based extraction progress (overlay2)
# Overlay2 has {current, total} in bytes, no units field
# Containerd has {current, units: "s"} which is useless for progress
if (
entry.progress_detail
and entry.progress_detail.current is not None
and entry.progress_detail.units is None
):
# Use layer's total_size from downloading phase (doesn't change)
layer.extract_current = entry.progress_detail.current
_LOGGER.debug(
"Layer %s extracting: %d/%d (%.1f%%)",
layer.layer_id,
layer.extract_current,
layer.total_size,
(layer.extract_current / layer.total_size * 100)
if layer.total_size > 0
else 0,
)
return
# Handle "Pull complete" - layer is fully done
if status is LayerPullStatus.PULL_COMPLETE:
layer.download_complete = True
layer.extract_complete = True
if layer.total_size > 0:
layer.download_current = layer.total_size
return
# Handle "Retrying in N seconds" - reset download progress
if status is LayerPullStatus.RETRYING:
layer.download_current = 0
layer.download_complete = False
return
def calculate_progress(self) -> float:
"""Calculate overall progress 0-100.
When manifest layer sizes are available, uses size-weighted progress
where each layer contributes proportionally to its size.
When manifest is not available, falls back to count-based progress
where each layer contributes equally.
Layers that already exist locally are excluded from the calculation.
Returns 0 until we've seen the first "Downloading" event, since Docker
reports "Already exists" and "Pulling fs layer" events before we know
the complete layer count.
"""
# Don't report progress until we've seen downloading start
# This ensures we know the full layer count before calculating progress
if not self._seen_downloading or not self.layers:
return 0.0
# Only count layers that need pulling (exclude already_exists)
layers_to_pull = [
layer for layer in self.layers.values() if not layer.already_exists
]
if not layers_to_pull:
# All layers already exist, nothing to download
return 100.0
# Use size-weighted progress if manifest sizes are available
if self._manifest_layer_sizes:
return min(100, self._calculate_size_weighted_progress(layers_to_pull))
# Fall back to count-based progress
total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
return min(100, total_progress / len(layers_to_pull))
def _calculate_size_weighted_progress(
self, layers_to_pull: list[LayerProgress]
) -> float:
"""Calculate size-weighted progress.
Each layer contributes to progress proportionally to its size.
Progress = sum(layer_progress * layer_size) / total_size
"""
# Calculate total size of layers that need pulling
total_size = sum(layer.total_size for layer in layers_to_pull)
if total_size == 0:
# No size info available, fall back to count-based
total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
return total_progress / len(layers_to_pull)
# Weight each layer's progress by its size
weighted_progress = 0.0
for layer in layers_to_pull:
if layer.total_size > 0:
layer_weight = layer.total_size / total_size
weighted_progress += layer.calculate_progress() * layer_weight
return weighted_progress
def get_stage(self) -> str | None:
"""Get current stage based on layer states."""
if not self.layers:
return None
# Check if any layer is still downloading
for layer in self.layers.values():
if layer.already_exists:
continue
if not layer.download_complete:
return "Downloading"
# All downloads complete, check if extracting
for layer in self.layers.values():
if layer.already_exists:
continue
if not layer.extract_complete:
return "Extracting"
# All done
return "Pull complete"
def should_update_job(self, threshold: float = 1.0) -> tuple[bool, float]:
"""Check if job should be updated based on progress change.
Returns (should_update, current_progress).
Updates are triggered when progress changes by at least threshold%.
Progress is guaranteed to only increase (monotonic).
"""
current_progress = self.calculate_progress()
# Ensure monotonic progress - never report a decrease
# This can happen when new layers get size info and change the weighted average
if current_progress < self._last_reported_progress:
_LOGGER.debug(
"Progress decreased from %.1f%% to %.1f%%, keeping last reported",
self._last_reported_progress,
current_progress,
)
return False, self._last_reported_progress
if current_progress >= self._last_reported_progress + threshold:
_LOGGER.debug(
"Progress update: %.1f%% -> %.1f%% (delta: %.1f%%)",
self._last_reported_progress,
current_progress,
current_progress - self._last_reported_progress,
)
self._last_reported_progress = current_progress
return True, current_progress
return False, self._last_reported_progress

View File

@@ -632,10 +632,6 @@ class DockerNotFound(DockerError):
"""Docker object don't Exists."""
class DockerLogOutOfOrder(DockerError):
"""Raise when log from docker action was out of order."""
class DockerNoSpaceOnDevice(DockerError):
"""Raise if a docker pull fails due to available space."""

View File

@@ -305,6 +305,8 @@ async def test_api_progress_updates_home_assistant_update(
and evt.args[0]["data"]["event"] == WSEvent.JOB
and evt.args[0]["data"]["data"]["name"] == "home_assistant_core_update"
]
# Count-based progress: 2 layers need pulling (each worth 50%)
# Layers that already exist are excluded from progress calculation
assert events[:5] == [
{
"stage": None,
@@ -318,36 +320,36 @@ async def test_api_progress_updates_home_assistant_update(
},
{
"stage": None,
"progress": 0.1,
"progress": 9.2,
"done": False,
},
{
"stage": None,
"progress": 1.7,
"progress": 25.6,
"done": False,
},
{
"stage": None,
"progress": 4.0,
"progress": 35.4,
"done": False,
},
]
assert events[-5:] == [
{
"stage": None,
"progress": 95.5,
"done": False,
},
{
"stage": None,
"progress": 96.9,
"done": False,
},
{
"stage": None,
"progress": 98.2,
"done": False,
},
{
"stage": None,
"progress": 98.3,
"done": False,
},
{
"stage": None,
"progress": 99.3,
"done": False,
},
{
"stage": None,
"progress": 100,

View File

@@ -764,6 +764,8 @@ async def test_api_progress_updates_addon_install_update(
and evt.args[0]["data"]["data"]["name"] == job_name
and evt.args[0]["data"]["data"]["reference"] == addon_slug
]
# Count-based progress: 2 layers need pulling (each worth 50%)
# Layers that already exist are excluded from progress calculation
assert events[:4] == [
{
"stage": None,
@@ -772,36 +774,36 @@ async def test_api_progress_updates_addon_install_update(
},
{
"stage": None,
"progress": 0.1,
"progress": 9.2,
"done": False,
},
{
"stage": None,
"progress": 1.7,
"progress": 25.6,
"done": False,
},
{
"stage": None,
"progress": 4.0,
"progress": 35.4,
"done": False,
},
]
assert events[-5:] == [
{
"stage": None,
"progress": 95.5,
"done": False,
},
{
"stage": None,
"progress": 96.9,
"done": False,
},
{
"stage": None,
"progress": 98.2,
"done": False,
},
{
"stage": None,
"progress": 98.3,
"done": False,
},
{
"stage": None,
"progress": 99.3,
"done": False,
},
{
"stage": None,
"progress": 100,

View File

@@ -358,6 +358,8 @@ async def test_api_progress_updates_supervisor_update(
and evt.args[0]["data"]["event"] == WSEvent.JOB
and evt.args[0]["data"]["data"]["name"] == "supervisor_update"
]
# Count-based progress: 2 layers need pulling (each worth 50%)
# Layers that already exist are excluded from progress calculation
assert events[:4] == [
{
"stage": None,
@@ -366,36 +368,36 @@ async def test_api_progress_updates_supervisor_update(
},
{
"stage": None,
"progress": 0.1,
"progress": 9.2,
"done": False,
},
{
"stage": None,
"progress": 1.7,
"progress": 25.6,
"done": False,
},
{
"stage": None,
"progress": 4.0,
"progress": 35.4,
"done": False,
},
]
assert events[-5:] == [
{
"stage": None,
"progress": 95.5,
"done": False,
},
{
"stage": None,
"progress": 96.9,
"done": False,
},
{
"stage": None,
"progress": 98.2,
"done": False,
},
{
"stage": None,
"progress": 98.3,
"done": False,
},
{
"stage": None,
"progress": 99.3,
"done": False,
},
{
"stage": None,
"progress": 100,

View File

@@ -154,6 +154,9 @@ async def docker() -> DockerAPI:
docker_obj.info.storage = "overlay2"
docker_obj.info.version = AwesomeVersion("1.0.0")
# Mock manifest fetcher to return None (falls back to count-based progress)
docker_obj._manifest_fetcher.get_manifest = AsyncMock(return_value=None)
yield docker_obj

View File

@@ -709,11 +709,18 @@ async def test_install_progress_handles_layers_skipping_download(
await install_task
await event.wait()
# First update from layer download should have rather low progress ((260937/25445459) / 2 ~ 0.5%)
assert install_job_snapshots[0]["progress"] < 1
# With the new progress calculation approach:
# - Progress is weighted by layer size
# - Small layers that skip downloading get minimal size (1 byte)
# - Progress should increase monotonically
assert len(install_job_snapshots) > 0
# Total 8 events should lead to a progress update on the install job
assert len(install_job_snapshots) == 8
# Verify progress is monotonically increasing (or stable)
for i in range(1, len(install_job_snapshots)):
assert (
install_job_snapshots[i]["progress"]
>= install_job_snapshots[i - 1]["progress"]
)
# Job should complete successfully
assert job.done is True
@@ -844,24 +851,24 @@ async def test_install_progress_containerd_snapshot(
}
assert [c.args[0] for c in ha_ws_client.async_send_command.call_args_list] == [
# During downloading we get continuous progress updates from download status
# Count-based progress: 2 layers, each = 50%. Download = 0-35%, Extract = 35-50%
job_event(0),
job_event(1.7),
job_event(3.4),
job_event(8.5),
job_event(8.4),
job_event(10.2),
job_event(15.3),
job_event(18.8),
job_event(29.0),
job_event(35.8),
job_event(42.6),
job_event(49.5),
job_event(56.0),
job_event(62.8),
# Downloading phase is considered 70% of total. After we only get one update
# per image downloaded when extraction is finished. It uses the total size
# received during downloading to determine percent complete then.
job_event(15.2),
job_event(18.7),
job_event(28.8),
job_event(35.7),
job_event(42.4),
job_event(49.3),
job_event(55.8),
job_event(62.7),
# Downloading phase is considered 70% of layer's progress.
# After download complete, extraction takes remaining 30% per layer.
job_event(70.0),
job_event(84.8),
job_event(85.0),
job_event(100),
job_event(100, True),
]

View File

@@ -0,0 +1,164 @@
"""Tests for registry manifest fetcher."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from supervisor.docker.manifest import (
DEFAULT_REGISTRY,
ImageManifest,
RegistryManifestFetcher,
parse_image_reference,
)
class TestParseImageReference:
"""Tests for parse_image_reference function."""
def test_ghcr_io_image(self):
"""Test parsing ghcr.io image."""
registry, repo, tag = parse_image_reference(
"ghcr.io/home-assistant/home-assistant", "2025.1.0"
)
assert registry == "ghcr.io"
assert repo == "home-assistant/home-assistant"
assert tag == "2025.1.0"
def test_docker_hub_with_org(self):
"""Test parsing Docker Hub image with organization."""
registry, repo, tag = parse_image_reference(
"homeassistant/home-assistant", "latest"
)
assert registry == DEFAULT_REGISTRY
assert repo == "homeassistant/home-assistant"
assert tag == "latest"
def test_docker_hub_official_image(self):
"""Test parsing Docker Hub official image (no org)."""
registry, repo, tag = parse_image_reference("alpine", "3.18")
assert registry == DEFAULT_REGISTRY
assert repo == "library/alpine"
assert tag == "3.18"
def test_gcr_io_image(self):
"""Test parsing gcr.io image."""
registry, repo, tag = parse_image_reference("gcr.io/project/image", "v1")
assert registry == "gcr.io"
assert repo == "project/image"
assert tag == "v1"
class TestImageManifest:
"""Tests for ImageManifest dataclass."""
def test_layer_count(self):
"""Test layer_count property."""
manifest = ImageManifest(
digest="sha256:abc",
total_size=1000,
layers={"layer1": 500, "layer2": 500},
)
assert manifest.layer_count == 2
class TestRegistryManifestFetcher:
"""Tests for RegistryManifestFetcher class."""
@pytest.fixture
def mock_coresys(self):
"""Create mock coresys."""
coresys = MagicMock()
coresys.docker.config.registries = {}
return coresys
@pytest.fixture
def fetcher(self, mock_coresys):
"""Create fetcher instance."""
return RegistryManifestFetcher(mock_coresys)
async def test_get_manifest_success(self, fetcher):
"""Test successful manifest fetch by mocking internal methods."""
manifest_data = {
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": {"digest": "sha256:abc123"},
"layers": [
{"digest": "sha256:layer1abc123def456789012", "size": 1000},
{"digest": "sha256:layer2def456abc789012345", "size": 2000},
],
}
# Mock the internal methods instead of the session
with (
patch.object(
fetcher, "_get_auth_token", new=AsyncMock(return_value="test-token")
),
patch.object(
fetcher, "_fetch_manifest", new=AsyncMock(return_value=manifest_data)
),
patch.object(fetcher, "_get_session", new=AsyncMock()),
):
result = await fetcher.get_manifest(
"test.io/org/image", "v1.0", platform="linux/amd64"
)
assert result is not None
assert result.total_size == 3000
assert result.layer_count == 2
# First 12 chars after sha256:
assert "layer1abc123" in result.layers
assert result.layers["layer1abc123"] == 1000
async def test_get_manifest_returns_none_on_failure(self, fetcher):
"""Test that get_manifest returns None on failure."""
with (
patch.object(
fetcher, "_get_auth_token", new=AsyncMock(return_value="test-token")
),
patch.object(fetcher, "_fetch_manifest", new=AsyncMock(return_value=None)),
patch.object(fetcher, "_get_session", new=AsyncMock()),
):
result = await fetcher.get_manifest(
"test.io/org/image", "v1.0", platform="linux/amd64"
)
assert result is None
async def test_close_session(self, fetcher):
"""Test session cleanup."""
mock_session = AsyncMock()
mock_session.closed = False
mock_session.close = AsyncMock()
fetcher._session = mock_session
await fetcher.close()
mock_session.close.assert_called_once()
assert fetcher._session is None
def test_get_credentials_docker_hub(self, mock_coresys, fetcher):
"""Test getting Docker Hub credentials."""
mock_coresys.docker.config.registries = {
"hub.docker.com": {"username": "user", "password": "pass"}
}
creds = fetcher._get_credentials(DEFAULT_REGISTRY)
assert creds == ("user", "pass")
def test_get_credentials_custom_registry(self, mock_coresys, fetcher):
"""Test getting credentials for custom registry."""
mock_coresys.docker.config.registries = {
"ghcr.io": {"username": "user", "password": "token"}
}
creds = fetcher._get_credentials("ghcr.io")
assert creds == ("user", "token")
def test_get_credentials_not_found(self, mock_coresys, fetcher):
"""Test no credentials found."""
mock_coresys.docker.config.registries = {}
creds = fetcher._get_credentials("unknown.io")
assert creds is None

File diff suppressed because it is too large