Mirror of https://github.com/home-assistant/supervisor.git
Synced 2025-12-03 14:38:15 +00:00

Compare commits: 5 commits (2025.12.1 ... refactor-d)

| Author | SHA1 | Date |
|---|---|---|
| | f2596feb17 | |
| | dc46b5b45b | |
| | 204206640f | |
| | 8c3ccfd83d | |
| | 2615199172 | |
@@ -2,12 +2,9 @@
from __future__ import annotations

from contextlib import suppress
from enum import Enum, StrEnum
from functools import total_ordering
from enum import StrEnum
from pathlib import PurePath
import re
from typing import cast

from docker.types import Mount

@@ -15,9 +12,8 @@ from ..const import MACHINE_ID

RE_RETRYING_DOWNLOAD_STATUS = re.compile(r"Retrying in \d+ seconds?")

# Docker Hub registry identifier (official default)
# Docker's default registry is docker.io
DOCKER_HUB = "docker.io"
# Docker Hub registry identifier
DOCKER_HUB = "hub.docker.com"

# Legacy Docker Hub identifier for backward compatibility
DOCKER_HUB_LEGACY = "hub.docker.com"

@@ -82,57 +78,6 @@ class PropagationMode(StrEnum):
    RSLAVE = "rslave"


@total_ordering
class PullImageLayerStage(Enum):
    """Job stages for pulling an image layer.

    These are a subset of the statuses in a docker pull image log. They
    are the standardized ones that are the most useful to us.
    """

    PULLING_FS_LAYER = 1, "Pulling fs layer"
    RETRYING_DOWNLOAD = 2, "Retrying download"
    DOWNLOADING = 2, "Downloading"
    VERIFYING_CHECKSUM = 3, "Verifying Checksum"
    DOWNLOAD_COMPLETE = 4, "Download complete"
    EXTRACTING = 5, "Extracting"
    PULL_COMPLETE = 6, "Pull complete"

    def __init__(self, order: int, status: str) -> None:
        """Set fields from values."""
        self.order = order
        self.status = status

    def __eq__(self, value: object, /) -> bool:
        """Check equality, allow StrEnum style comparisons on status."""
        with suppress(AttributeError):
            return self.status == cast(PullImageLayerStage, value).status
        return self.status == value

    def __lt__(self, other: object) -> bool:
        """Order instances."""
        with suppress(AttributeError):
            return self.order < cast(PullImageLayerStage, other).order
        return False

    def __hash__(self) -> int:
        """Hash instance."""
        return hash(self.status)

    @classmethod
    def from_status(cls, status: str) -> PullImageLayerStage | None:
        """Return stage instance from pull log status."""
        for i in cls:
            if i.status == status:
                return i

        # This one includes the number of seconds until download, so it's not constant
        if RE_RETRYING_DOWNLOAD_STATUS.match(status):
            return cls.RETRYING_DOWNLOAD

        return None


ENV_TIME = "TZ"
ENV_TOKEN = "SUPERVISOR_TOKEN"
ENV_TOKEN_OLD = "HASSIO_TOKEN"
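For context, the removed PullImageLayerStage enum (replaced by LayerPullStatus in the new pull_progress.py below) supported both string-style equality and stage ordering. A quick illustrative sketch of those semantics, assuming the class definition above (not part of the diff):

# Illustration only - mirrors the __eq__/__lt__/from_status semantics defined above.
stage = PullImageLayerStage.from_status("Downloading")
assert stage == "Downloading"                  # StrEnum-style equality on status
assert stage < PullImageLayerStage.EXTRACTING  # ordered by the 'order' field
assert (
    PullImageLayerStage.from_status("Retrying in 5 seconds")
    is PullImageLayerStage.RETRYING_DOWNLOAD
)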
@@ -19,7 +19,6 @@ import docker
from docker.models.containers import Container
import requests

from ..bus import EventListener
from ..const import (
    ATTR_PASSWORD,
    ATTR_REGISTRY,
@@ -35,25 +34,18 @@ from ..exceptions import (
    DockerError,
    DockerHubRateLimitExceeded,
    DockerJobError,
    DockerLogOutOfOrder,
    DockerNotFound,
    DockerRequestError,
)
from ..jobs import SupervisorJob
from ..jobs.const import JOB_GROUP_DOCKER_INTERFACE, JobConcurrency
from ..jobs.decorator import Job
from ..jobs.job_group import JobGroup
from ..resolution.const import ContextType, IssueType, SuggestionType
from ..utils.sentry import async_capture_exception
from .const import (
    DOCKER_HUB,
    DOCKER_HUB_LEGACY,
    ContainerState,
    PullImageLayerStage,
    RestartPolicy,
)
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY, ContainerState, RestartPolicy
from .manager import CommandReturn, PullLogEntry
from .monitor import DockerContainerStateEvent
from .pull_progress import ImagePullProgress
from .stats import DockerStats

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -202,159 +194,6 @@ class DockerInterface(JobGroup, ABC):

        return credentials

    def _process_pull_image_log(  # noqa: C901
        self, install_job_id: str, reference: PullLogEntry
    ) -> None:
        """Process events fired from docker while pulling an image, filtered to a given job id."""
        if (
            reference.job_id != install_job_id
            or not reference.id
            or not reference.status
            or not (stage := PullImageLayerStage.from_status(reference.status))
        ):
            return

        # Pulling FS Layer is our marker for a layer that needs to be downloaded
        # and extracted. Otherwise it already exists and we can ignore it
        job: SupervisorJob | None = None
        if stage == PullImageLayerStage.PULLING_FS_LAYER:
            job = self.sys_jobs.new_job(
                name="Pulling container image layer",
                initial_stage=stage.status,
                reference=reference.id,
                parent_id=install_job_id,
                internal=True,
            )
            job.done = False
            return

        # Find our sub job to update details of
        for j in self.sys_jobs.jobs:
            if j.parent_id == install_job_id and j.reference == reference.id:
                job = j
                break

        # There should no longer be any real risk of logs out of order.
        # However, tests with very small images have shown that Docker sometimes
        # skips stages in the log. So keep this as a safety check on a null job
        if not job:
            raise DockerLogOutOfOrder(
                f"Received pull image log with status {reference.status} for image id {reference.id} and parent job {install_job_id} but could not find a matching job, skipping",
                _LOGGER.debug,
            )

        # For progress calculation we assume downloading is 70% of the time,
        # extracting is 30%, and the other stages are negligible
        progress = job.progress
        match stage:
            case PullImageLayerStage.DOWNLOADING | PullImageLayerStage.EXTRACTING:
                if (
                    reference.progress_detail
                    and reference.progress_detail.current
                    and reference.progress_detail.total
                ):
                    progress = (
                        reference.progress_detail.current
                        / reference.progress_detail.total
                    )
                    if stage == PullImageLayerStage.DOWNLOADING:
                        progress = 70 * progress
                    else:
                        progress = 70 + 30 * progress
            case (
                PullImageLayerStage.VERIFYING_CHECKSUM
                | PullImageLayerStage.DOWNLOAD_COMPLETE
            ):
                progress = 70
            case PullImageLayerStage.PULL_COMPLETE:
                progress = 100
            case PullImageLayerStage.RETRYING_DOWNLOAD:
                progress = 0

        # No real risk of getting things out of order in the current implementation,
        # but keep this in case another change to these trips us up.
        if stage != PullImageLayerStage.RETRYING_DOWNLOAD and progress < job.progress:
            raise DockerLogOutOfOrder(
                f"Received pull image log with status {reference.status} for job {job.uuid} that implied progress was {progress} but current progress is {job.progress}, skipping",
                _LOGGER.debug,
            )

        # Our filters have all passed. Time to update the job.
        # Only downloading and extracting have progress details; use them to set extra.
        # We leave extra around in later stages as the total bytes may still be useful.
        # Enforce the range to prevent float drift errors
        progress = max(0, min(progress, 100))
        if (
            stage in {PullImageLayerStage.DOWNLOADING, PullImageLayerStage.EXTRACTING}
            and reference.progress_detail
            and reference.progress_detail.current is not None
            and reference.progress_detail.total is not None
        ):
            job.update(
                progress=progress,
                stage=stage.status,
                extra={
                    "current": reference.progress_detail.current,
                    "total": reference.progress_detail.total,
                },
            )
        else:
            # If we reach DOWNLOAD_COMPLETE without ever having set extra (small layers that skip
            # the downloading phase), set a minimal extra so aggregate progress calculation can proceed
            extra = job.extra
            if stage == PullImageLayerStage.DOWNLOAD_COMPLETE and not job.extra:
                extra = {"current": 1, "total": 1}

            job.update(
                progress=progress,
                stage=stage.status,
                done=stage == PullImageLayerStage.PULL_COMPLETE,
                extra=None if stage == PullImageLayerStage.RETRYING_DOWNLOAD else extra,
            )

        # Once we have received a progress update for every child job, start to set status of the main one
        install_job = self.sys_jobs.get_job(install_job_id)
        layer_jobs = [
            job
            for job in self.sys_jobs.jobs
            if job.parent_id == install_job.uuid
            and job.name == "Pulling container image layer"
        ]

        # First set the total bytes to be downloaded/extracted on the main job
        if not install_job.extra:
            total = 0
            for job in layer_jobs:
                if not job.extra:
                    return
                total += job.extra["total"]
            install_job.extra = {"total": total}
        else:
            total = install_job.extra["total"]

        # Then determine total progress based on the progress of each sub-job, factoring in the size of each compared to total
        progress = 0.0
        stage = PullImageLayerStage.PULL_COMPLETE
        for job in layer_jobs:
            if not job.extra or not job.extra.get("total"):
                return
            progress += job.progress * (job.extra["total"] / total)
            job_stage = PullImageLayerStage.from_status(cast(str, job.stage))

            if job_stage < PullImageLayerStage.EXTRACTING:
                stage = PullImageLayerStage.DOWNLOADING
            elif (
                stage == PullImageLayerStage.PULL_COMPLETE
                and job_stage < PullImageLayerStage.PULL_COMPLETE
            ):
                stage = PullImageLayerStage.EXTRACTING

        # Ensure progress is 100 at this point to prevent float drift
        if stage == PullImageLayerStage.PULL_COMPLETE:
            progress = 100

        # To reduce noise, limit updates to when the result has changed by an entire percent or when the stage changed
        if stage != install_job.stage or progress >= install_job.progress + 1:
            install_job.update(stage=stage.status, progress=max(0, min(progress, 100)))
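To make the 70/30 weighting in the removed method concrete, here is a small self-contained sketch of the per-layer math (a hypothetical helper for illustration, not part of the diff):

# Hypothetical helper illustrating the per-layer 70/30 weighting above.
def layer_progress(current: int, total: int, extracting: bool) -> float:
    """Return 0-100 layer progress: download is 70% of the work, extraction 30%."""
    fraction = current / total if total else 0.0
    return 70 + 30 * fraction if extracting else 70 * fraction

assert layer_progress(5_000, 10_000, extracting=False) == 35.0  # halfway through download
assert layer_progress(0, 10_000, extracting=True) == 70.0       # download done, extraction starting
assert layer_progress(5_000, 10_000, extracting=True) == 85.0   # halfway through extraction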
    @Job(
        name="docker_interface_install",
        on_condition=DockerJobError,
@@ -374,33 +213,55 @@ class DockerInterface(JobGroup, ABC):
            raise ValueError("Cannot pull without an image!")

        image_arch = arch or self.sys_arch.supervisor
        listener: EventListener | None = None
        platform = MAP_ARCH[image_arch]
        pull_progress = ImagePullProgress()
        current_job = self.sys_jobs.current

        # Try to fetch manifest for accurate size-based progress
        # This is optional - if it fails, we fall back to count-based progress
        try:
            manifest = await self.sys_docker.manifest_fetcher.get_manifest(
                image, str(version), platform=platform
            )
            if manifest:
                pull_progress.set_manifest(manifest)
                _LOGGER.debug(
                    "Using manifest for progress: %d layers, %d bytes",
                    manifest.layer_count,
                    manifest.total_size,
                )
        except Exception as err:  # noqa: BLE001
            _LOGGER.debug("Could not fetch manifest for progress: %s", err)

        async def process_pull_event(event: PullLogEntry) -> None:
            """Process pull event and update job progress."""
            if event.job_id != current_job.uuid:
                return

            # Process event through progress tracker
            pull_progress.process_event(event)

            # Update job if progress changed significantly (>= 1%)
            should_update, progress = pull_progress.should_update_job()
            if should_update:
                stage = pull_progress.get_stage()
                current_job.update(progress=progress, stage=stage)

        listener = self.sys_bus.register_event(
            BusEvent.DOCKER_IMAGE_PULL_UPDATE, process_pull_event
        )

        _LOGGER.info("Downloading docker image %s with tag %s.", image, version)
        try:
            # Get credentials for private registries to pass to aiodocker
            credentials = self._get_credentials(image) or None

            curr_job_id = self.sys_jobs.current.uuid

            async def process_pull_image_log(reference: PullLogEntry) -> None:
                try:
                    self._process_pull_image_log(curr_job_id, reference)
                except DockerLogOutOfOrder as err:
                    # Send all these to sentry. Missing a few progress updates
                    # shouldn't matter to users but matters to us
                    await async_capture_exception(err)

            listener = self.sys_bus.register_event(
                BusEvent.DOCKER_IMAGE_PULL_UPDATE, process_pull_image_log
            )

            # Pull new image, passing credentials to aiodocker
            docker_image = await self.sys_docker.pull_image(
                self.sys_jobs.current.uuid,
                current_job.uuid,
                image,
                str(version),
                platform=MAP_ARCH[image_arch],
                platform=platform,
                auth=credentials,
            )

@@ -445,8 +306,7 @@ class DockerInterface(JobGroup, ABC):
                f"Unknown error with {image}:{version!s} -> {err!s}", _LOGGER.error
            ) from err
        finally:
            if listener:
                self.sys_bus.remove_listener(listener)
            self.sys_bus.remove_listener(listener)

        self._meta = docker_image
@@ -50,6 +50,7 @@ from ..exceptions import (
from ..utils.common import FileConfiguration
from ..validate import SCHEMA_DOCKER_CONFIG
from .const import DOCKER_HUB, DOCKER_HUB_LEGACY, LABEL_MANAGED
from .manifest import RegistryManifestFetcher
from .monitor import DockerMonitor
from .network import DockerNetwork
from .utils import get_registry_from_image
@@ -258,6 +259,9 @@ class DockerAPI(CoreSysAttributes):
        self._info: DockerInfo | None = None
        self.config: DockerConfig = DockerConfig()
        self._monitor: DockerMonitor = DockerMonitor(coresys)
        self._manifest_fetcher: RegistryManifestFetcher = RegistryManifestFetcher(
            coresys
        )

    async def post_init(self) -> Self:
        """Post init actions that must be done in event loop."""
@@ -323,6 +327,11 @@ class DockerAPI(CoreSysAttributes):
        """Return docker events monitor."""
        return self._monitor

    @property
    def manifest_fetcher(self) -> RegistryManifestFetcher:
        """Return manifest fetcher for registry access."""
        return self._manifest_fetcher

    async def load(self) -> None:
        """Start docker events monitor."""
        await self.monitor.load()
supervisor/docker/manifest.py (new file, 354 lines)
@@ -0,0 +1,354 @@
"""Docker registry manifest fetcher.

Fetches image manifests directly from container registries to get layer sizes
before pulling an image. This enables accurate size-based progress tracking.
"""

from __future__ import annotations

from dataclasses import dataclass
import logging
import re
from typing import TYPE_CHECKING

import aiohttp

from .const import DOCKER_HUB, IMAGE_WITH_HOST

if TYPE_CHECKING:
    from ..coresys import CoreSys

_LOGGER = logging.getLogger(__name__)

# Default registry for images without explicit host
DEFAULT_REGISTRY = "registry-1.docker.io"

# Media types for manifest requests
MANIFEST_MEDIA_TYPES = (
    "application/vnd.docker.distribution.manifest.v2+json",
    "application/vnd.oci.image.manifest.v1+json",
    "application/vnd.docker.distribution.manifest.list.v2+json",
    "application/vnd.oci.image.index.v1+json",
)


@dataclass
class ImageManifest:
    """Container image manifest with layer information."""

    digest: str
    total_size: int
    layers: dict[str, int]  # digest -> size in bytes

    @property
    def layer_count(self) -> int:
        """Return number of layers."""
        return len(self.layers)


def parse_image_reference(image: str, tag: str) -> tuple[str, str, str]:
    """Parse image reference into (registry, repository, tag).

    Examples:
        ghcr.io/home-assistant/home-assistant:2025.1.0
            -> (ghcr.io, home-assistant/home-assistant, 2025.1.0)
        homeassistant/home-assistant:latest
            -> (registry-1.docker.io, homeassistant/home-assistant, latest)
        alpine:3.18
            -> (registry-1.docker.io, library/alpine, 3.18)

    """
    # Check if image has an explicit registry host
    match = IMAGE_WITH_HOST.match(image)
    if match:
        registry = match.group(1)
        repository = image[len(registry) + 1 :]  # Remove "registry/" prefix
    else:
        registry = DEFAULT_REGISTRY
        repository = image
        # Docker Hub requires "library/" prefix for official images
        if "/" not in repository:
            repository = f"library/{repository}"

    return registry, repository, tag
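The docstring examples translate directly into assertions; a quick illustrative sanity check (not part of the diff):

# Illustrative checks mirroring the docstring examples above.
assert parse_image_reference("alpine", "3.18") == (
    "registry-1.docker.io", "library/alpine", "3.18"
)
assert parse_image_reference("ghcr.io/home-assistant/home-assistant", "2025.1.0") == (
    "ghcr.io", "home-assistant/home-assistant", "2025.1.0"
)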
class RegistryManifestFetcher:
    """Fetches manifests from container registries."""

    def __init__(self, coresys: CoreSys) -> None:
        """Initialize the fetcher."""
        self.coresys = coresys
        self._session: aiohttp.ClientSession | None = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Get or create aiohttp session."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self) -> None:
        """Close the session."""
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None

    def _get_credentials(self, registry: str) -> tuple[str, str] | None:
        """Get credentials for registry from Docker config.

        Returns (username, password) tuple or None if no credentials.
        """
        registries = self.coresys.docker.config.registries

        # Map registry hostname to config key.
        # Docker Hub can be stored as "hub.docker.com" in config
        if registry == DEFAULT_REGISTRY:
            if DOCKER_HUB in registries:
                creds = registries[DOCKER_HUB]
                return creds.get("username"), creds.get("password")
        elif registry in registries:
            creds = registries[registry]
            return creds.get("username"), creds.get("password")

        return None

    async def _get_auth_token(
        self,
        session: aiohttp.ClientSession,
        registry: str,
        repository: str,
    ) -> str | None:
        """Get authentication token for registry.

        Uses the WWW-Authenticate header from a 401 response to discover
        the token endpoint, then requests a token with appropriate scope.
        """
        # First, make an unauthenticated request to get the WWW-Authenticate header
        manifest_url = f"https://{registry}/v2/{repository}/manifests/latest"

        try:
            async with session.get(manifest_url) as resp:
                if resp.status == 200:
                    # No auth required
                    return None

                if resp.status != 401:
                    _LOGGER.warning(
                        "Unexpected status %d from registry %s", resp.status, registry
                    )
                    return None

                www_auth = resp.headers.get("WWW-Authenticate", "")
        except aiohttp.ClientError as err:
            _LOGGER.warning("Failed to connect to registry %s: %s", registry, err)
            return None

        # Parse WWW-Authenticate: Bearer realm="...",service="...",scope="..."
        if not www_auth.startswith("Bearer "):
            _LOGGER.warning("Unsupported auth type from %s: %s", registry, www_auth)
            return None

        params = {}
        for match in re.finditer(r'(\w+)="([^"]*)"', www_auth):
            params[match.group(1)] = match.group(2)

        realm = params.get("realm")
        service = params.get("service")

        if not realm:
            _LOGGER.warning("No realm in WWW-Authenticate from %s", registry)
            return None

        # Build token request URL
        token_url = f"{realm}?scope=repository:{repository}:pull"
        if service:
            token_url += f"&service={service}"

        # Check for credentials
        auth = None
        credentials = self._get_credentials(registry)
        if credentials:
            username, password = credentials
            if username and password:
                auth = aiohttp.BasicAuth(username, password)
                _LOGGER.debug("Using credentials for %s", registry)

        try:
            async with session.get(token_url, auth=auth) as resp:
                if resp.status != 200:
                    _LOGGER.warning(
                        "Failed to get token from %s: %d", realm, resp.status
                    )
                    return None

                data = await resp.json()
                return data.get("token") or data.get("access_token")
        except aiohttp.ClientError as err:
            _LOGGER.warning("Failed to get auth token: %s", err)
            return None
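For readers unfamiliar with the registry token flow, here is what a typical exchange implemented by _get_auth_token looks like (the hostnames and values are illustrative, not taken from the diff):

# Illustrative handshake:
#   GET https://ghcr.io/v2/org/image/manifests/latest      -> 401
#   WWW-Authenticate: Bearer realm="https://ghcr.io/token",service="ghcr.io"
# Parsed params: {"realm": "https://ghcr.io/token", "service": "ghcr.io"}
# Token request:
#   GET https://ghcr.io/token?scope=repository:org/image:pull&service=ghcr.io
# Response body: {"token": "..."} (or "access_token" on some registries)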
    async def _fetch_manifest(
        self,
        session: aiohttp.ClientSession,
        registry: str,
        repository: str,
        reference: str,
        token: str | None,
        platform: str | None = None,
    ) -> dict | None:
        """Fetch manifest from registry.

        If the manifest is a manifest list (multi-arch), fetches the
        platform-specific manifest.
        """
        manifest_url = f"https://{registry}/v2/{repository}/manifests/{reference}"

        headers = {"Accept": ", ".join(MANIFEST_MEDIA_TYPES)}
        if token:
            headers["Authorization"] = f"Bearer {token}"

        try:
            async with session.get(manifest_url, headers=headers) as resp:
                if resp.status != 200:
                    _LOGGER.warning(
                        "Failed to fetch manifest for %s/%s:%s - %d",
                        registry,
                        repository,
                        reference,
                        resp.status,
                    )
                    return None

                manifest = await resp.json()
        except aiohttp.ClientError as err:
            _LOGGER.warning("Failed to fetch manifest: %s", err)
            return None

        media_type = manifest.get("mediaType", "")

        # Check if this is a manifest list (multi-arch image)
        if "list" in media_type or "index" in media_type:
            manifests = manifest.get("manifests", [])
            if not manifests:
                _LOGGER.warning("Empty manifest list for %s/%s", registry, repository)
                return None

            # Find matching platform
            target_os = "linux"
            target_arch = "amd64"  # Default

            if platform:
                # Platform format is "linux/amd64", "linux/arm64", etc.
                parts = platform.split("/")
                if len(parts) >= 2:
                    target_os, target_arch = parts[0], parts[1]

            platform_manifest = None
            for m in manifests:
                plat = m.get("platform", {})
                if (
                    plat.get("os") == target_os
                    and plat.get("architecture") == target_arch
                ):
                    platform_manifest = m
                    break

            if not platform_manifest:
                # Fall back to the first manifest
                _LOGGER.debug(
                    "Platform %s/%s not found, using first manifest",
                    target_os,
                    target_arch,
                )
                platform_manifest = manifests[0]

            # Fetch the platform-specific manifest
            return await self._fetch_manifest(
                session,
                registry,
                repository,
                platform_manifest["digest"],
                token,
                platform,
            )

        return manifest
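The manifest-list branch above keys off the standard OCI/Docker multi-arch index shape. An abridged, illustrative example of the document it walks (digests shortened):

# Abridged manifest list (multi-arch index) as matched above:
# {
#     "mediaType": "application/vnd.oci.image.index.v1+json",
#     "manifests": [
#         {"digest": "sha256:aaa...", "platform": {"os": "linux", "architecture": "amd64"}},
#         {"digest": "sha256:bbb...", "platform": {"os": "linux", "architecture": "arm64"}},
#     ],
# }
# For platform="linux/arm64" the second entry wins and its digest is fetched recursively.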
    async def get_manifest(
        self,
        image: str,
        tag: str,
        platform: str | None = None,
    ) -> ImageManifest | None:
        """Fetch manifest and extract layer sizes.

        Args:
            image: Image name (e.g., "ghcr.io/home-assistant/home-assistant")
            tag: Image tag (e.g., "2025.1.0")
            platform: Target platform (e.g., "linux/amd64")

        Returns:
            ImageManifest with layer sizes, or None if fetch failed.

        """
        registry, repository, tag = parse_image_reference(image, tag)

        _LOGGER.debug(
            "Fetching manifest for %s/%s:%s (platform=%s)",
            registry,
            repository,
            tag,
            platform,
        )

        session = await self._get_session()

        # Get auth token
        token = await self._get_auth_token(session, registry, repository)

        # Fetch manifest
        manifest = await self._fetch_manifest(
            session, registry, repository, tag, token, platform
        )

        if not manifest:
            return None

        # Extract layer information
        layers = manifest.get("layers", [])
        if not layers:
            _LOGGER.warning(
                "No layers in manifest for %s/%s:%s", registry, repository, tag
            )
            return None

        layer_sizes: dict[str, int] = {}
        total_size = 0

        for layer in layers:
            digest = layer.get("digest", "")
            size = layer.get("size", 0)
            if digest and size:
                # Store by short digest (first 12 chars after sha256:)
                short_digest = (
                    digest.split(":")[1][:12] if ":" in digest else digest[:12]
                )
                layer_sizes[short_digest] = size
                total_size += size

        digest = manifest.get("config", {}).get("digest", "")

        _LOGGER.debug(
            "Manifest for %s/%s:%s - %d layers, %d bytes total",
            registry,
            repository,
            tag,
            len(layer_sizes),
            total_size,
        )

        return ImageManifest(
            digest=digest,
            total_size=total_size,
            layers=layer_sizes,
        )
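A minimal usage sketch of the fetcher, assuming a coresys instance and a running event loop (mirrors how interface.py calls it above; not part of the diff):

# Hypothetical usage:
# fetcher = RegistryManifestFetcher(coresys)
# manifest = await fetcher.get_manifest("alpine", "3.18", platform="linux/amd64")
# if manifest:
#     print(f"{manifest.layer_count} layers, {manifest.total_size} bytes")
# await fetcher.close()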
supervisor/docker/pull_progress.py (new file, 371 lines)
@@ -0,0 +1,371 @@
"""Image pull progress tracking."""

from __future__ import annotations

from contextlib import suppress
from dataclasses import dataclass, field
from enum import Enum
import logging
from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    from .manager import PullLogEntry
    from .manifest import ImageManifest

_LOGGER = logging.getLogger(__name__)

# Progress weight distribution: 70% downloading, 30% extraction
DOWNLOAD_WEIGHT = 70.0
EXTRACT_WEIGHT = 30.0


class LayerPullStatus(Enum):
    """Status values for pulling an image layer.

    These are a subset of the statuses in a docker pull image log.
    The order field allows comparing which stage is further along.
    """

    PULLING_FS_LAYER = 1, "Pulling fs layer"
    WAITING = 1, "Waiting"
    RETRYING = 2, "Retrying"  # Matches "Retrying in N seconds"
    DOWNLOADING = 3, "Downloading"
    VERIFYING_CHECKSUM = 4, "Verifying Checksum"
    DOWNLOAD_COMPLETE = 5, "Download complete"
    EXTRACTING = 6, "Extracting"
    PULL_COMPLETE = 7, "Pull complete"
    ALREADY_EXISTS = 7, "Already exists"

    def __init__(self, order: int, status: str) -> None:
        """Set fields from values."""
        self.order = order
        self.status = status

    def __eq__(self, value: object, /) -> bool:
        """Check equality, allow string comparisons on status."""
        with suppress(AttributeError):
            return self.status == cast(LayerPullStatus, value).status
        return self.status == value

    def __hash__(self) -> int:
        """Return hash based on status string."""
        return hash(self.status)

    def __lt__(self, other: object) -> bool:
        """Order instances by stage progression."""
        with suppress(AttributeError):
            return self.order < cast(LayerPullStatus, other).order
        return False

    @classmethod
    def from_status(cls, status: str) -> LayerPullStatus | None:
        """Get enum from status string, or None if not recognized."""
        # Handle the "Retrying in N seconds" pattern
        if status.startswith("Retrying in "):
            return cls.RETRYING
        for member in cls:
            if member.status == status:
                return member
        return None
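A few illustrative mappings for from_status, as implied by the definition above (not part of the diff):

# Illustrative status mappings:
# LayerPullStatus.from_status("Downloading")           -> LayerPullStatus.DOWNLOADING
# LayerPullStatus.from_status("Retrying in 5 seconds") -> LayerPullStatus.RETRYING
# LayerPullStatus.from_status("Pulling from repo/img") -> None (not a layer status)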
@dataclass
class LayerProgress:
    """Track progress of a single layer."""

    layer_id: str
    total_size: int = 0  # Size in bytes (from downloading, reused for extraction)
    download_current: int = 0
    extract_current: int = 0  # Extraction progress in bytes (overlay2 only)
    download_complete: bool = False
    extract_complete: bool = False
    already_exists: bool = False  # Layer was already locally available

    def calculate_progress(self) -> float:
        """Calculate layer progress 0-100.

        Progress is weighted: 70% download, 30% extraction.
        For overlay2, we have byte-based extraction progress.
        For containerd, extraction jumps from 70% to 100% on completion.
        """
        if self.already_exists or self.extract_complete:
            return 100.0

        if self.download_complete:
            # Check if we have extraction progress (overlay2)
            if self.extract_current > 0 and self.total_size > 0:
                extract_pct = min(1.0, self.extract_current / self.total_size)
                return DOWNLOAD_WEIGHT + (extract_pct * EXTRACT_WEIGHT)
            # No extraction progress yet - return 70%
            return DOWNLOAD_WEIGHT

        if self.total_size > 0:
            download_pct = min(1.0, self.download_current / self.total_size)
            return download_pct * DOWNLOAD_WEIGHT

        return 0.0


@dataclass
class ImagePullProgress:
    """Track overall progress of pulling an image.

    When manifest layer sizes are provided, uses size-weighted progress where
    each layer contributes proportionally to its size. This gives accurate
    progress based on actual bytes to download.

    When the manifest is not available, falls back to count-based progress where
    each layer contributes equally.

    Layers that already exist locally are excluded from the progress calculation.
    """

    layers: dict[str, LayerProgress] = field(default_factory=dict)
    _last_reported_progress: float = field(default=0.0, repr=False)
    _seen_downloading: bool = field(default=False, repr=False)
    _manifest_layer_sizes: dict[str, int] = field(default_factory=dict, repr=False)
    _total_manifest_size: int = field(default=0, repr=False)

    def set_manifest(self, manifest: ImageManifest) -> None:
        """Set manifest layer sizes for accurate size-based progress.

        Should be called before processing pull events.
        """
        self._manifest_layer_sizes = dict(manifest.layers)
        self._total_manifest_size = manifest.total_size
        _LOGGER.debug(
            "Manifest set: %d layers, %d bytes total",
            len(self._manifest_layer_sizes),
            self._total_manifest_size,
        )

    def get_or_create_layer(self, layer_id: str) -> LayerProgress:
        """Get existing layer or create a new one."""
        if layer_id not in self.layers:
            # If we have manifest sizes, pre-populate the layer's total_size
            manifest_size = self._manifest_layer_sizes.get(layer_id, 0)
            self.layers[layer_id] = LayerProgress(
                layer_id=layer_id, total_size=manifest_size
            )
        return self.layers[layer_id]

    def process_event(self, entry: PullLogEntry) -> None:
        """Process a pull log event and update layer state."""
        # Skip events without layer ID or status
        if not entry.id or not entry.status:
            return

        # Skip metadata events that aren't layer-specific.
        # "Pulling from X" has id=tag but isn't a layer
        if entry.status.startswith("Pulling from "):
            return

        # Parse status to enum (returns None for unrecognized statuses)
        status = LayerPullStatus.from_status(entry.status)
        if status is None:
            return

        layer = self.get_or_create_layer(entry.id)

        # Handle "Already exists" - layer is locally available
        if status is LayerPullStatus.ALREADY_EXISTS:
            layer.already_exists = True
            layer.download_complete = True
            layer.extract_complete = True
            return

        # Handle "Pulling fs layer" / "Waiting" - layer is being tracked
        if status in (LayerPullStatus.PULLING_FS_LAYER, LayerPullStatus.WAITING):
            return

        # Handle "Downloading" - update download progress
        if status is LayerPullStatus.DOWNLOADING:
            # Mark that we've seen downloading - now we know the layer count is complete
            self._seen_downloading = True
            if (
                entry.progress_detail
                and entry.progress_detail.current is not None
                and entry.progress_detail.total is not None
            ):
                layer.download_current = entry.progress_detail.current
                # Only set total_size if not already set or if this is larger
                # (handles the case where total changes during download)
                layer.total_size = max(layer.total_size, entry.progress_detail.total)
            return

        # Handle "Verifying Checksum" - download is essentially complete
        if status is LayerPullStatus.VERIFYING_CHECKSUM:
            if layer.total_size > 0:
                layer.download_current = layer.total_size
            return

        # Handle "Download complete" - download phase done
        if status is LayerPullStatus.DOWNLOAD_COMPLETE:
            layer.download_complete = True
            if layer.total_size > 0:
                layer.download_current = layer.total_size
            elif layer.total_size == 0:
                # Small layer that skipped the downloading phase.
                # Set minimal size so it doesn't distort the weighted average
                layer.total_size = 1
                layer.download_current = 1
            return

        # Handle "Extracting" - extraction in progress
        if status is LayerPullStatus.EXTRACTING:
            # For overlay2: progressDetail has {current, total} in bytes
            # For containerd: progressDetail has {current, units: "s"} (time elapsed)
            # We can only use byte-based progress (overlay2)
            layer.download_complete = True
            if layer.total_size > 0:
                layer.download_current = layer.total_size

            # Check if this is byte-based extraction progress (overlay2).
            # Overlay2 has {current, total} in bytes, no units field.
            # Containerd has {current, units: "s"}, which is useless for progress
            if (
                entry.progress_detail
                and entry.progress_detail.current is not None
                and entry.progress_detail.units is None
            ):
                # Use the layer's total_size from the downloading phase (doesn't change)
                layer.extract_current = entry.progress_detail.current
                _LOGGER.debug(
                    "Layer %s extracting: %d/%d (%.1f%%)",
                    layer.layer_id,
                    layer.extract_current,
                    layer.total_size,
                    (layer.extract_current / layer.total_size * 100)
                    if layer.total_size > 0
                    else 0,
                )
            return

        # Handle "Pull complete" - layer is fully done
        if status is LayerPullStatus.PULL_COMPLETE:
            layer.download_complete = True
            layer.extract_complete = True
            if layer.total_size > 0:
                layer.download_current = layer.total_size
            return

        # Handle "Retrying in N seconds" - reset download progress
        if status is LayerPullStatus.RETRYING:
            layer.download_current = 0
            layer.download_complete = False
            return

    def calculate_progress(self) -> float:
        """Calculate overall progress 0-100.

        When manifest layer sizes are available, uses size-weighted progress
        where each layer contributes proportionally to its size.

        When the manifest is not available, falls back to count-based progress
        where each layer contributes equally.

        Layers that already exist locally are excluded from the calculation.

        Returns 0 until we've seen the first "Downloading" event, since Docker
        reports "Already exists" and "Pulling fs layer" events before we know
        the complete layer count.
        """
        # Don't report progress until we've seen downloading start.
        # This ensures we know the full layer count before calculating progress
        if not self._seen_downloading or not self.layers:
            return 0.0

        # Only count layers that need pulling (exclude already_exists)
        layers_to_pull = [
            layer for layer in self.layers.values() if not layer.already_exists
        ]

        if not layers_to_pull:
            # All layers already exist, nothing to download
            return 100.0

        # Use size-weighted progress if manifest sizes are available
        if self._manifest_layer_sizes:
            return min(100, self._calculate_size_weighted_progress(layers_to_pull))

        # Fall back to count-based progress
        total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
        return min(100, total_progress / len(layers_to_pull))

    def _calculate_size_weighted_progress(
        self, layers_to_pull: list[LayerProgress]
    ) -> float:
        """Calculate size-weighted progress.

        Each layer contributes to progress proportionally to its size:
        progress = sum(layer_progress * layer_size) / total_size
        """
        # Calculate the total size of layers that need pulling
        total_size = sum(layer.total_size for layer in layers_to_pull)

        if total_size == 0:
            # No size info available, fall back to count-based
            total_progress = sum(layer.calculate_progress() for layer in layers_to_pull)
            return total_progress / len(layers_to_pull)

        # Weight each layer's progress by its size
        weighted_progress = 0.0
        for layer in layers_to_pull:
            if layer.total_size > 0:
                layer_weight = layer.total_size / total_size
                weighted_progress += layer.calculate_progress() * layer_weight

        return weighted_progress
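A worked example of the size weighting above, using illustrative numbers (not part of the diff):

# Two layers: 900 MB at 50% progress, 100 MB fully done.
# total_size = 1000 MB
#   weighted = 50 * (900/1000) + 100 * (100/1000)
#            = 45 + 10 = 55 (overall percent)
# A count-based average would report (50 + 100) / 2 = 75, overstating progress.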
    def get_stage(self) -> str | None:
        """Get current stage based on layer states."""
        if not self.layers:
            return None

        # Check if any layer is still downloading
        for layer in self.layers.values():
            if layer.already_exists:
                continue
            if not layer.download_complete:
                return "Downloading"

        # All downloads complete, check if extracting
        for layer in self.layers.values():
            if layer.already_exists:
                continue
            if not layer.extract_complete:
                return "Extracting"

        # All done
        return "Pull complete"

    def should_update_job(self, threshold: float = 1.0) -> tuple[bool, float]:
        """Check if the job should be updated based on progress change.

        Returns (should_update, current_progress).
        Updates are triggered when progress changes by at least threshold%.
        Progress is guaranteed to only increase (monotonic).
        """
        current_progress = self.calculate_progress()

        # Ensure monotonic progress - never report a decrease.
        # This can happen when new layers get size info and change the weighted average
        if current_progress < self._last_reported_progress:
            _LOGGER.debug(
                "Progress decreased from %.1f%% to %.1f%%, keeping last reported",
                self._last_reported_progress,
                current_progress,
            )
            return False, self._last_reported_progress

        if current_progress >= self._last_reported_progress + threshold:
            _LOGGER.debug(
                "Progress update: %.1f%% -> %.1f%% (delta: %.1f%%)",
                self._last_reported_progress,
                current_progress,
                current_progress - self._last_reported_progress,
            )
            self._last_reported_progress = current_progress
            return True, current_progress

        return False, self._last_reported_progress
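Putting the pieces together, the intended consumption pattern (as wired up in interface.py above) looks like this sketch, using names from the diff (not part of the diff itself):

# Sketch of the consumer loop:
# pull_progress = ImagePullProgress()
# if manifest := await fetcher.get_manifest(image, tag, platform=platform):
#     pull_progress.set_manifest(manifest)
# for entry in pull_log_entries:  # PullLogEntry events from Docker
#     pull_progress.process_event(entry)
#     should_update, progress = pull_progress.should_update_job()
#     if should_update:
#         job.update(progress=progress, stage=pull_progress.get_stage())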
@@ -632,10 +632,6 @@ class DockerNotFound(DockerError):
    """Docker object don't Exists."""


class DockerLogOutOfOrder(DockerError):
    """Raise when log from docker action was out of order."""


class DockerNoSpaceOnDevice(DockerError):
    """Raise if a docker pull fails due to available space."""
@@ -305,6 +305,8 @@ async def test_api_progress_updates_home_assistant_update(
        and evt.args[0]["data"]["event"] == WSEvent.JOB
        and evt.args[0]["data"]["data"]["name"] == "home_assistant_core_update"
    ]
    # Count-based progress: 2 layers need pulling (each worth 50%)
    # Layers that already exist are excluded from progress calculation
    assert events[:5] == [
        {
            "stage": None,
@@ -318,36 +320,36 @@ async def test_api_progress_updates_home_assistant_update(
        },
        {
            "stage": None,
            "progress": 0.1,
            "progress": 9.2,
            "done": False,
        },
        {
            "stage": None,
            "progress": 1.7,
            "progress": 25.6,
            "done": False,
        },
        {
            "stage": None,
            "progress": 4.0,
            "progress": 35.4,
            "done": False,
        },
    ]
    assert events[-5:] == [
        {
            "stage": None,
            "progress": 95.5,
            "done": False,
        },
        {
            "stage": None,
            "progress": 96.9,
            "done": False,
        },
        {
            "stage": None,
            "progress": 98.2,
            "done": False,
        },
        {
            "stage": None,
            "progress": 98.3,
            "done": False,
        },
        {
            "stage": None,
            "progress": 99.3,
            "done": False,
        },
        {
            "stage": None,
            "progress": 100,

@@ -764,6 +764,8 @@ async def test_api_progress_updates_addon_install_update(
        and evt.args[0]["data"]["data"]["name"] == job_name
        and evt.args[0]["data"]["data"]["reference"] == addon_slug
    ]
    # Count-based progress: 2 layers need pulling (each worth 50%)
    # Layers that already exist are excluded from progress calculation
    assert events[:4] == [
        {
            "stage": None,
@@ -772,36 +774,36 @@ async def test_api_progress_updates_addon_install_update(
        },
        {
            "stage": None,
            "progress": 0.1,
            "progress": 9.2,
            "done": False,
        },
        {
            "stage": None,
            "progress": 1.7,
            "progress": 25.6,
            "done": False,
        },
        {
            "stage": None,
            "progress": 4.0,
            "progress": 35.4,
            "done": False,
        },
    ]
    assert events[-5:] == [
        {
            "stage": None,
            "progress": 95.5,
            "done": False,
        },
        {
            "stage": None,
            "progress": 96.9,
            "done": False,
        },
        {
            "stage": None,
            "progress": 98.2,
            "done": False,
        },
        {
            "stage": None,
            "progress": 98.3,
            "done": False,
        },
        {
            "stage": None,
            "progress": 99.3,
            "done": False,
        },
        {
            "stage": None,
            "progress": 100,

@@ -358,6 +358,8 @@ async def test_api_progress_updates_supervisor_update(
        and evt.args[0]["data"]["event"] == WSEvent.JOB
        and evt.args[0]["data"]["data"]["name"] == "supervisor_update"
    ]
    # Count-based progress: 2 layers need pulling (each worth 50%)
    # Layers that already exist are excluded from progress calculation
    assert events[:4] == [
        {
            "stage": None,
@@ -366,36 +368,36 @@ async def test_api_progress_updates_supervisor_update(
        },
        {
            "stage": None,
            "progress": 0.1,
            "progress": 9.2,
            "done": False,
        },
        {
            "stage": None,
            "progress": 1.7,
            "progress": 25.6,
            "done": False,
        },
        {
            "stage": None,
            "progress": 4.0,
            "progress": 35.4,
            "done": False,
        },
    ]
    assert events[-5:] == [
        {
            "stage": None,
            "progress": 95.5,
            "done": False,
        },
        {
            "stage": None,
            "progress": 96.9,
            "done": False,
        },
        {
            "stage": None,
            "progress": 98.2,
            "done": False,
        },
        {
            "stage": None,
            "progress": 98.3,
            "done": False,
        },
        {
            "stage": None,
            "progress": 99.3,
            "done": False,
        },
        {
            "stage": None,
            "progress": 100,
@@ -154,6 +154,9 @@ async def docker() -> DockerAPI:
    docker_obj.info.storage = "overlay2"
    docker_obj.info.version = AwesomeVersion("1.0.0")

    # Mock the manifest fetcher to return None (falls back to count-based progress)
    docker_obj._manifest_fetcher.get_manifest = AsyncMock(return_value=None)

    yield docker_obj


@@ -709,11 +709,18 @@ async def test_install_progress_handles_layers_skipping_download(
    await install_task
    await event.wait()

    # First update from the layer download should have rather low progress ((260937/25445459) / 2 ~ 0.5%)
    assert install_job_snapshots[0]["progress"] < 1
    # With the new progress calculation approach:
    # - Progress is weighted by layer size
    # - Small layers that skip downloading get minimal size (1 byte)
    # - Progress should increase monotonically
    assert len(install_job_snapshots) > 0

    # A total of 8 events should lead to a progress update on the install job
    assert len(install_job_snapshots) == 8
    # Verify progress is monotonically increasing (or stable)
    for i in range(1, len(install_job_snapshots)):
        assert (
            install_job_snapshots[i]["progress"]
            >= install_job_snapshots[i - 1]["progress"]
        )

    # Job should complete successfully
    assert job.done is True
@@ -844,24 +851,24 @@ async def test_install_progress_containerd_snapshot(
    }

    assert [c.args[0] for c in ha_ws_client.async_send_command.call_args_list] == [
        # During downloading we get continuous progress updates from download status
        # Count-based progress: 2 layers, each = 50%. Download = 0-35%, Extract = 35-50%
        job_event(0),
        job_event(1.7),
        job_event(3.4),
        job_event(8.5),
        job_event(8.4),
        job_event(10.2),
        job_event(15.3),
        job_event(18.8),
        job_event(29.0),
        job_event(35.8),
        job_event(42.6),
        job_event(49.5),
        job_event(56.0),
        job_event(62.8),
        # Downloading phase is considered 70% of the total. After that we only get one
        # update per image downloaded when extraction is finished. It uses the total size
        # received during downloading to determine percent complete then.
        job_event(15.2),
        job_event(18.7),
        job_event(28.8),
        job_event(35.7),
        job_event(42.4),
        job_event(49.3),
        job_event(55.8),
        job_event(62.7),
        # Downloading phase is considered 70% of the layer's progress.
        # After download complete, extraction takes the remaining 30% per layer.
        job_event(70.0),
        job_event(84.8),
        job_event(85.0),
        job_event(100),
        job_event(100, True),
    ]
tests/docker/test_manifest.py (new file, 164 lines)
@@ -0,0 +1,164 @@
"""Tests for registry manifest fetcher."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from supervisor.docker.manifest import (
    DEFAULT_REGISTRY,
    ImageManifest,
    RegistryManifestFetcher,
    parse_image_reference,
)


class TestParseImageReference:
    """Tests for parse_image_reference function."""

    def test_ghcr_io_image(self):
        """Test parsing ghcr.io image."""
        registry, repo, tag = parse_image_reference(
            "ghcr.io/home-assistant/home-assistant", "2025.1.0"
        )
        assert registry == "ghcr.io"
        assert repo == "home-assistant/home-assistant"
        assert tag == "2025.1.0"

    def test_docker_hub_with_org(self):
        """Test parsing Docker Hub image with organization."""
        registry, repo, tag = parse_image_reference(
            "homeassistant/home-assistant", "latest"
        )
        assert registry == DEFAULT_REGISTRY
        assert repo == "homeassistant/home-assistant"
        assert tag == "latest"

    def test_docker_hub_official_image(self):
        """Test parsing Docker Hub official image (no org)."""
        registry, repo, tag = parse_image_reference("alpine", "3.18")
        assert registry == DEFAULT_REGISTRY
        assert repo == "library/alpine"
        assert tag == "3.18"

    def test_gcr_io_image(self):
        """Test parsing gcr.io image."""
        registry, repo, tag = parse_image_reference("gcr.io/project/image", "v1")
        assert registry == "gcr.io"
        assert repo == "project/image"
        assert tag == "v1"


class TestImageManifest:
    """Tests for ImageManifest dataclass."""

    def test_layer_count(self):
        """Test layer_count property."""
        manifest = ImageManifest(
            digest="sha256:abc",
            total_size=1000,
            layers={"layer1": 500, "layer2": 500},
        )
        assert manifest.layer_count == 2


class TestRegistryManifestFetcher:
    """Tests for RegistryManifestFetcher class."""

    @pytest.fixture
    def mock_coresys(self):
        """Create mock coresys."""
        coresys = MagicMock()
        coresys.docker.config.registries = {}
        return coresys

    @pytest.fixture
    def fetcher(self, mock_coresys):
        """Create fetcher instance."""
        return RegistryManifestFetcher(mock_coresys)

    async def test_get_manifest_success(self, fetcher):
        """Test successful manifest fetch by mocking internal methods."""
        manifest_data = {
            "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
            "config": {"digest": "sha256:abc123"},
            "layers": [
                {"digest": "sha256:layer1abc123def456789012", "size": 1000},
                {"digest": "sha256:layer2def456abc789012345", "size": 2000},
            ],
        }

        # Mock the internal methods instead of the session
        with (
            patch.object(
                fetcher, "_get_auth_token", new=AsyncMock(return_value="test-token")
            ),
            patch.object(
                fetcher, "_fetch_manifest", new=AsyncMock(return_value=manifest_data)
            ),
            patch.object(fetcher, "_get_session", new=AsyncMock()),
        ):
            result = await fetcher.get_manifest(
                "test.io/org/image", "v1.0", platform="linux/amd64"
            )

        assert result is not None
        assert result.total_size == 3000
        assert result.layer_count == 2
        # First 12 chars after sha256:
        assert "layer1abc123" in result.layers
        assert result.layers["layer1abc123"] == 1000

    async def test_get_manifest_returns_none_on_failure(self, fetcher):
        """Test that get_manifest returns None on failure."""
        with (
            patch.object(
                fetcher, "_get_auth_token", new=AsyncMock(return_value="test-token")
            ),
            patch.object(fetcher, "_fetch_manifest", new=AsyncMock(return_value=None)),
            patch.object(fetcher, "_get_session", new=AsyncMock()),
        ):
            result = await fetcher.get_manifest(
                "test.io/org/image", "v1.0", platform="linux/amd64"
            )

        assert result is None

    async def test_close_session(self, fetcher):
        """Test session cleanup."""
        mock_session = AsyncMock()
        mock_session.closed = False
        mock_session.close = AsyncMock()
        fetcher._session = mock_session

        await fetcher.close()

        mock_session.close.assert_called_once()
        assert fetcher._session is None

    def test_get_credentials_docker_hub(self, mock_coresys, fetcher):
        """Test getting Docker Hub credentials."""
        mock_coresys.docker.config.registries = {
            "hub.docker.com": {"username": "user", "password": "pass"}
        }

        creds = fetcher._get_credentials(DEFAULT_REGISTRY)

        assert creds == ("user", "pass")

    def test_get_credentials_custom_registry(self, mock_coresys, fetcher):
        """Test getting credentials for custom registry."""
        mock_coresys.docker.config.registries = {
            "ghcr.io": {"username": "user", "password": "token"}
        }

        creds = fetcher._get_credentials("ghcr.io")

        assert creds == ("user", "token")

    def test_get_credentials_not_found(self, mock_coresys, fetcher):
        """Test no credentials found."""
        mock_coresys.docker.config.registries = {}

        creds = fetcher._get_credentials("unknown.io")

        assert creds is None
tests/docker/test_pull_progress.py (new file, 1002 lines; diff suppressed because it is too large)