server/internal: replace model delete API with new registry handler. (#9347)

This commit introduces a new API implementation for handling
interactions with the registry and the local model cache. The new API is
located in server/internal/registry. The package name is "registry" and
should be considered temporary; it is hidden and not bleeding outside of
the server package. As the commits roll in, we'll start consuming more
of the API and then let reverse osmosis take effect, at which point it
will surface closer to the root level packages as much as needed.
This commit is contained in:
Blake Mizerany 2025-02-27 12:04:53 -08:00 committed by GitHub
parent be2ac1ed93
commit 2412adf42b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 705 additions and 90 deletions

14
go.mod
View File

@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.10.0
golang.org/x/sync v0.11.0
)
require (
@ -69,12 +69,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.28.0
golang.org/x/term v0.27.0
golang.org/x/text v0.21.0
golang.org/x/crypto v0.33.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0
golang.org/x/term v0.29.0
golang.org/x/text v0.22.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

28
go.sum
View File

@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@ -279,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
// TODO(bmizerany): Move link handling from cache to registry.
//
// We originally placed links in the cache due to its storage
// knowledge. However, the registry likely offers better context for
// naming concerns, and our API design shouldn't be tightly coupled to
// our on-disk format.
//
// Links work effectively when independent from physical location -
// they can reference content with matching SHA regardless of storage
// location. In an upcoming change, we plan to shift this
// responsibility to the registry where it better aligns with the
// system's conceptual model.
manifest, err := c.manifestPath(name)
if err != nil {
return err
@ -304,21 +316,19 @@ func (c *DiskCache) Link(name string, d Digest) error {
return c.copyNamedFile(manifest, f, d, info.Size())
}
// Unlink removes the any link for name. If the link does not exist, nothing
// happens, and no error is returned.
//
// It returns an error if the name is invalid or if the link removal encounters
// any issues.
func (c *DiskCache) Unlink(name string) error {
// Unlink unlinks the manifest by name from the cache. If the name is not
// found. If a manifest is removed ok will be true, otherwise false. If an
// error occurs, it returns ok false, and the error.
func (c *DiskCache) Unlink(name string) (ok bool, _ error) {
manifest, err := c.manifestPath(name)
if err != nil {
return err
return false, err
}
err = os.Remove(manifest)
if errors.Is(err, fs.ErrNotExist) {
return nil
return false, nil
}
return err
return true, err
}
// GetFile returns the absolute path to the file, in the cache, for the given

View File

@ -13,7 +13,7 @@ import (
"testing"
"time"
"github.com/ollama/ollama/server/internal/internal/testutil"
"github.com/ollama/ollama/server/internal/testutil"
)
func init() {
@ -479,8 +479,11 @@ func testManifestNameReuse(t *testing.T) {
}
// relink with different case
err = c.Unlink("h/n/m:t")
unlinked, err := c.Unlink("h/n/m:t")
check(err)
if !unlinked {
t.Fatal("expected unlinked")
}
err = c.Link("h/n/m:T", d1)
check(err)

View File

@ -86,7 +86,7 @@ func useCaseInsensitiveTempDir(t *testing.T) bool {
// link to docs on that topic.
lines := strings.Split(volumeHint, "\n")
for _, line := range lines {
t.Log(line)
t.Skip(line)
}
}
return false

View File

@ -19,6 +19,7 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
"path/filepath"
@ -86,9 +87,23 @@ func DefaultCache() (*blob.DiskCache, error) {
return blob.Open(dir)
}
// Error is the standard error returned by Ollama APIs.
// Error is the standard error returned by Ollama APIs. It can represent a
// single or multiple error response.
//
// Single error responses have the following format:
//
// {"code": "optional_code","error":"error message"}
//
// Multiple error responses have the following format:
//
// {"errors": [{"code": "optional_code","message":"error message"}]}
//
// Note, that the error field is used in single error responses, while the
// message field is used in multiple error responses.
//
// In both cases, the code field is optional and may be empty.
type Error struct {
Status int `json:"-"`
Status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"`
Message string `json:"message"`
}
@ -97,13 +112,34 @@ func (e *Error) Error() string {
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
}
func (e *Error) LogValue() slog.Value {
return slog.GroupValue(
slog.Int("status", e.Status),
slog.String("code", e.Code),
slog.String("message", e.Message),
)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *Error) UnmarshalJSON(b []byte) error {
type E Error
var v struct{ Errors []E }
var v struct {
// Single error
Code string
Error string
// Multiple errors
Errors []E
}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
if v.Error != "" {
// Single error case
e.Code = v.Code
e.Message = v.Error
return nil
}
if len(v.Errors) == 0 {
return fmt.Errorf("no messages in error response: %s", string(b))
}
@ -111,9 +147,8 @@ func (e *Error) UnmarshalJSON(b []byte) error {
return nil
}
// TODO(bmizerany): make configurable on [Registry]
var defaultName = func() names.Name {
n := names.Parse("ollama.com/library/_:latest")
n := names.Parse("registry.ollama.ai/library/_:latest")
if !n.IsFullyQualified() {
panic("default name is not fully qualified")
}
@ -160,21 +195,26 @@ type Registry struct {
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// NameMask, if set, is the name used to convert non-fully qualified
// names to fully qualified names. If empty, the default mask
// ("registry.ollama.ai/library/_:latest") is used.
NameMask string
}
// RegistryFromEnv returns a new Registry configured from the environment. The
// DefaultRegistry returns a new Registry configured from the environment. The
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
// system's temporary directory.
//
// It returns an error if any configuration in the environment is invalid.
func RegistryFromEnv() (*Registry, error) {
func DefaultRegistry() (*Registry, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
if err != nil {
if err != nil && errors.Is(err, fs.ErrNotExist) {
return nil, err
}
@ -208,9 +248,19 @@ type PushParams struct {
// any, is invalid.
//
// The scheme is returned as provided by [names.ParseExtended].
func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
maskName := defaultName
if mask != "" {
maskName = names.Parse(mask)
if !maskName.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask)
}
}
scheme, n, ds := names.ParseExtended(s)
n = names.Merge(n, defaultName)
if !n.IsValid() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
}
n = names.Merge(n, maskName)
if ds != "" {
// Digest is present. Validate it.
d, err = blob.ParseDigest(ds)
@ -223,7 +273,7 @@ func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error)
// say that digests take precedence over names, and so should there
// errors when being parsed.
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, ErrNameInvalid
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
}
scheme = cmp.Or(scheme, "https")
@ -255,7 +305,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
p = &PushParams{}
}
m, err := ResolveLocal(c, cmp.Or(p.From, name))
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
if err != nil {
return err
}
@ -278,7 +328,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
t := traceFromContext(ctx)
scheme, n, _, err := parseName(name)
scheme, n, _, err := parseName(name, r.NameMask)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@ -372,7 +422,7 @@ func canRetry(err error) bool {
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name)
scheme, n, _, err := parseName(name, r.NameMask)
if err != nil {
return err
}
@ -520,6 +570,16 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return c.Link(m.Name, md)
}
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
_, n, _, err := parseName(name, r.NameMask)
if err != nil {
return false, err
}
return c.Unlink(n.String())
}
// Manifest represents a [ollama.com/manifest].
type Manifest struct {
Name string `json:"-"` // the canonical name of the model
@ -590,8 +650,8 @@ type Layer struct {
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.ParseExtended] but the scheme is ignored.
func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name)
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name, r.NameMask)
if err != nil {
return nil, err
}
@ -617,7 +677,7 @@ func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name)
scheme, n, d, err := parseName(name, r.NameMask)
if err != nil {
return nil, err
}

View File

@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/testutil"
"github.com/ollama/ollama/server/internal/testutil"
)
func TestManifestMarshalJSON(t *testing.T) {
@ -37,20 +37,6 @@ func TestManifestMarshalJSON(t *testing.T) {
}
}
func link(c *blob.DiskCache, name string, manifest string) {
_, n, _, err := parseName(name)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
var errRoundTrip = errors.New("forced roundtrip error")
type recordRoundTripper http.HandlerFunc
@ -98,29 +84,44 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
}
}
rc := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
link := func(name string, manifest string) {
_, n, _, err := parseName(name, rc.NameMask)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
commit := func(name string, layers ...*Layer) {
t.Helper()
data, err := json.Marshal(&Manifest{Layers: layers})
if err != nil {
t.Fatal(err)
}
link(c, name, string(data))
link(name, string(data))
}
link(c, "empty", "")
link("empty", "")
commit("zero")
commit("single", mklayer("exists"))
commit("multiple", mklayer("exists"), mklayer("present"))
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
commit("null", nil)
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
link(c, "invalid", "!!!!!")
link("invalid", "!!!!!")
rc := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
return rc, c
}
@ -385,7 +386,7 @@ func TestRegistryPullNotCached(t *testing.T) {
})
// Confirm that the layer does not exist locally
_, err := ResolveLocal(c, "model")
_, err := rc.ResolveLocal(c, "model")
checkNotExist(t, err)
_, err = c.Get(d)
@ -396,7 +397,7 @@ func TestRegistryPullNotCached(t *testing.T) {
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := ResolveLocal(c, "model")
mg, err := rc.ResolveLocal(c, "model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
@ -654,3 +655,72 @@ func TestCanRetry(t *testing.T) {
}
}
}
func TestErrorUnmarshal(t *testing.T) {
cases := []struct {
name string
data string
want *Error
wantErr bool
}{
{
name: "errors empty",
data: `{"errors":[]}`,
wantErr: true,
},
{
name: "errors empty",
data: `{"errors":[]}`,
wantErr: true,
},
{
name: "errors single",
data: `{"errors":[{"code":"blob_unknown"}]}`,
want: &Error{Code: "blob_unknown", Message: ""},
},
{
name: "errors multiple",
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
want: &Error{Code: "blob_unknown", Message: ""},
},
{
name: "error empty",
data: `{"error":""}`,
wantErr: true,
},
{
name: "error very empty",
data: `{}`,
wantErr: true,
},
{
name: "error message",
data: `{"error":"message", "code":"code"}`,
want: &Error{Code: "code", Message: "message"},
},
{
name: "invalid value",
data: `{"error": 1}`,
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var got Error
err := json.Unmarshal([]byte(tt.data), &got)
if err != nil {
if tt.wantErr {
return
}
t.Errorf("Unmarshal() error = %v", err)
// fallthrough and check got
}
if tt.want == nil {
tt.want = &Error{}
}
if !reflect.DeepEqual(got, *tt.want) {
t.Errorf("got = %v; want %v", got, *tt.want)
}
})
}
}

View File

@ -68,7 +68,7 @@ func main() {
log.Fatal(err)
}
rc, err := ollama.RegistryFromEnv()
rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
@ -177,7 +177,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}
from := cmp.Or(*flagFrom, model)
m, err := ollama.ResolveLocal(c, from)
m, err := rc.ResolveLocal(c, from)
if err != nil {
return err
}

View File

@ -0,0 +1,215 @@
// Package registry provides an http.Handler for handling local Ollama API
// requests for performing tasks related to the ollama.com model registry and
// the local disk cache.
package registry
import (
"cmp"
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
)
// Local is an http.Handler for handling local Ollama API requests for
// performing tasks related to the ollama.com model registry combined with the
// local disk cache.
//
// It is not concern of Local, or this package, to handle model creation, which
// proceeds any registry operations for models it produces.
//
// NOTE: The package built for dealing with model creation should use
// [DefaultCache] to access the blob store and not attempt to read or write
// directly to the blob disk cache.
type Local struct {
Client *ollama.Registry // required
Cache *blob.DiskCache // required
Logger *slog.Logger // required
// Fallback, if set, is used to handle requests that are not handled by
// this handler.
Fallback http.Handler
}
// serverError is like ollama.Error, but with a Status field for the HTTP
// response code. We want to avoid adding that field to ollama.Error because it
// would always be 0 to clients (we don't want to leak the status code in
// errors), and so it would be confusing to have a field that is always 0.
type serverError struct {
Status int `json:"-"`
// TODO(bmizerany): Decide if we want to keep this and maybe
// bring back later.
Code string `json:"code"`
Message string `json:"error"`
}
func (e serverError) Error() string {
return e.Message
}
// Common API errors
var (
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
errNotFound = &serverError{404, "not_found", "not found"}
errInternalError = &serverError{500, "internal_error", "internal server error"}
)
type statusCodeRecorder struct {
_status int // use status() to get the status code
http.ResponseWriter
}
func (r *statusCodeRecorder) WriteHeader(status int) {
if r._status == 0 {
r._status = status
}
r.ResponseWriter.WriteHeader(status)
}
func (r *statusCodeRecorder) status() int {
return cmp.Or(r._status, 200)
}
func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rec := &statusCodeRecorder{ResponseWriter: w}
s.serveHTTP(rec, r)
}
func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
var errattr slog.Attr
proxied, err := func() (bool, error) {
switch r.URL.Path {
case "/api/delete":
return false, s.handleDelete(rec, r)
default:
if s.Fallback != nil {
s.Fallback.ServeHTTP(rec, r)
return true, nil
}
return false, errNotFound
}
}()
if err != nil {
// We always log the error, so fill in the error log attribute
errattr = slog.String("error", err.Error())
var e *serverError
switch {
case errors.As(err, &e):
case errors.Is(err, ollama.ErrNameInvalid):
e = &serverError{400, "bad_request", err.Error()}
default:
e = errInternalError
}
data, err := json.Marshal(e)
if err != nil {
// unreachable
panic(err)
}
rec.Header().Set("Content-Type", "application/json")
rec.WriteHeader(e.Status)
rec.Write(data)
// fallthrough to log
}
if !proxied {
// we're only responsible for logging if we handled the request
var level slog.Level
if rec.status() >= 500 {
level = slog.LevelError
} else if rec.status() >= 400 {
level = slog.LevelWarn
}
s.Logger.LogAttrs(r.Context(), level, "http",
errattr, // report first in line to make it easy to find
// TODO(bmizerany): Write a test to ensure that we are logging
// all of this correctly. That also goes for the level+error
// logic above.
slog.Int("status", rec.status()),
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int64("content-length", r.ContentLength),
slog.String("remote", r.RemoteAddr),
slog.String("proto", r.Proto),
slog.String("query", r.URL.RawQuery),
)
}
}
type params struct {
DeprecatedName string `json:"name"` // Use [params.model]
Model string `json:"model"` // Use [params.model]
// AllowNonTLS is a flag that indicates a client using HTTP
// is doing so, deliberately.
//
// Deprecated: This field is ignored and only present for this
// deprecation message. It should be removed in a future release.
//
// Users can just use http or https+insecure to show intent to
// communicate they want to do insecure things, without awkward and
// confusing flags such as this.
AllowNonTLS bool `json:"insecure"`
// ProgressStream is a flag that indicates the client is expecting a stream of
// progress updates.
ProgressStream bool `json:"stream"`
}
// model returns the model name for both old and new API requests.
func (p params) model() string {
return cmp.Or(p.Model, p.DeprecatedName)
}
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" {
return errMethodNotAllowed
}
p, err := decodeUserJSON[*params](r.Body)
if err != nil {
return err
}
ok, err := s.Client.Unlink(s.Cache, p.model())
if err != nil {
return err
}
if !ok {
return &serverError{404, "manifest_not_found", "manifest not found"}
}
return nil
}
func decodeUserJSON[T any](r io.Reader) (T, error) {
var v T
err := json.NewDecoder(r).Decode(&v)
if err == nil {
return v, nil
}
var zero T
// Not sure why, but I can't seem to be able to use:
//
// errors.As(err, &json.UnmarshalTypeError{})
//
// This is working fine in stdlib, so I'm not sure what rules changed
// and why this no longer works here. So, we do it the verbose way.
var a *json.UnmarshalTypeError
var b *json.SyntaxError
if errors.As(err, &a) || errors.As(err, &b) {
err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
}
if errors.Is(err, io.EOF) {
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
}
return zero, err
}

View File

@ -0,0 +1,168 @@
package registry
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"regexp"
"strings"
"testing"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/testutil"
)
type panicTransport struct{}
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
panic("unexpected RoundTrip call")
}
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
// bytesResetter is an interface for types that can be reset and return a byte
// slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
// etc for the purpose of checking logs.
type bytesResetter interface {
Bytes() []byte
Reset()
}
func newTestServer(t *testing.T) *Local {
t.Helper()
dir := t.TempDir()
err := os.CopyFS(dir, os.DirFS("testdata/models"))
if err != nil {
t.Fatal(err)
}
c, err := blob.Open(dir)
if err != nil {
t.Fatal(err)
}
rc := &ollama.Registry{
HTTPClient: panicOnRoundTrip,
}
l := &Local{
Cache: c,
Client: rc,
Logger: testutil.Slogger(t),
}
return l
}
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
t.Helper()
req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
return s.sendRequest(t, req)
}
func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
t.Helper()
w := httptest.NewRecorder()
s.ServeHTTP(w, req)
return w
}
type invalidReader struct{}
func (r *invalidReader) Read(p []byte) (int, error) {
return 0, os.ErrInvalid
}
// captureLogs is a helper to capture logs from the server. It returns a
// shallow copy of the server with a new logger and a bytesResetter for the
// logs.
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
t.Helper()
log, logs := testutil.SlogBuffer()
l := *s // shallow copy
l.Logger = log
return &l, logs
}
func TestServerDelete(t *testing.T) {
check := testutil.Checker(t)
s := newTestServer(t)
_, err := s.Client.ResolveLocal(s.Cache, "smol")
check(err)
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
_, err = s.Client.ResolveLocal(s.Cache, "smol")
if err == nil {
t.Fatal("expected smol to have been deleted")
}
got = s.send(t, "DELETE", "/api/delete", `!`)
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
got = s.send(t, "DELETE", "/api/delete", ``)
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`)
checkErrorResponse(t, got, 404, "manifest_not_found", "not found")
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
checkErrorResponse(t, got, 400, "bad_request", "invalid name")
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
checkErrorResponse(t, got, 404, "not_found", "not found")
s, logs := captureLogs(t, s)
req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
got = s.sendRequest(t, req)
checkErrorResponse(t, got, 500, "internal_error", "internal server error")
ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
check(err)
if !ok {
t.Logf("logs:\n%s", logs)
t.Fatalf("expected log to contain ERROR with invalid argument")
}
}
func TestServerUnknownPath(t *testing.T) {
s := newTestServer(t)
got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found")
}
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
t.Helper()
var printedBody bool
errorf := func(format string, args ...any) {
t.Helper()
if !printedBody {
t.Logf("BODY:\n%s", got.Body.String())
printedBody = true
}
t.Errorf(format, args...)
}
if got.Code != status {
errorf("Code = %d; want %d", got.Code, status)
}
// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
var e *ollama.Error
if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
errorf("unmarshal error: %v", err)
t.FailNow()
}
if e.Code != code {
errorf("Code = %q; want %q", e.Code, code)
}
if !strings.Contains(e.Message, msg) {
errorf("Message = %q; want to contain %q", e.Message, msg)
}
}

View File

@ -0,0 +1 @@
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}

View File

@ -0,0 +1 @@
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}

View File

@ -1,12 +1,40 @@
package testutil
import (
"bytes"
"io"
"log/slog"
"os"
"path/filepath"
"testing"
"time"
)
// LogWriter returns an [io.Writer] that logs each Write using t.Log.
func LogWriter(t *testing.T) io.Writer {
return testWriter{t}
}
type testWriter struct{ t *testing.T }
func (w testWriter) Write(b []byte) (int, error) {
w.t.Logf("%s", b)
return len(b), nil
}
// Slogger returns a [*slog.Logger] that writes each message
// using t.Log.
func Slogger(t *testing.T) *slog.Logger {
return slog.New(slog.NewTextHandler(LogWriter(t), nil))
}
// SlogBuffer returns a [*slog.Logger] that writes each message to out.
func SlogBuffer() (lg *slog.Logger, out *bytes.Buffer) {
var buf bytes.Buffer
lg = slog.New(slog.NewTextHandler(&buf, nil))
return lg, &buf
}
// Check calls t.Fatal(err) if err is not nil.
func Check(t *testing.T, err error) {
if err != nil {

View File

@ -34,6 +34,9 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@ -1126,7 +1129,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
}
func (s *Server) GenerateRoutes() http.Handler {
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
@ -1165,10 +1168,9 @@ func (s *Server) GenerateRoutes() http.Handler {
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
// Local model cache management
// Local model cache management (new implementation is at end of function)
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.HEAD("/api/tags", s.ListHandler)
r.GET("/api/tags", s.ListHandler)
r.POST("/api/show", s.ShowHandler)
@ -1193,7 +1195,15 @@ func (s *Server) GenerateRoutes() http.Handler {
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
return r
// wrap old with new
rs := &registry.Local{
Cache: c,
Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Fallback: r,
}
return rs, nil
}
func Serve(ln net.Listener) error {
@ -1246,12 +1256,27 @@ func Serve(ln net.Listener) error {
}
}
s := &Server{addr: ln.Addr()}
c, err := ollama.DefaultCache()
if err != nil {
return err
}
rc, err := ollama.DefaultRegistry()
if err != nil {
return err
}
h, err := s.GenerateRoutes(c, rc)
if err != nil {
return err
}
http.Handle("/", h)
ctx, done := context.WithCancel(context.Background())
schedCtx, schedDone := context.WithCancel(ctx)
sched := InitScheduler(schedCtx)
s := &Server{addr: ln.Addr(), sched: sched}
http.Handle("/", s.GenerateRoutes())
s.sched = sched
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
srvr := &http.Server{

View File

@ -23,6 +23,8 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@ -91,7 +93,15 @@ func equalStringSlices(a, b []string) bool {
return true
}
func Test_Routes(t *testing.T) {
type panicTransport struct{}
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
panic("unexpected RoundTrip call")
}
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
func TestRoutes(t *testing.T) {
type testCase struct {
Name string
Method string
@ -241,10 +251,10 @@ func Test_Routes(t *testing.T) {
Method: http.MethodDelete,
Path: "/api/delete",
Setup: func(t *testing.T, req *http.Request) {
createTestModel(t, "model-to-delete")
createTestModel(t, "model_to_delete")
deleteReq := api.DeleteRequest{
Name: "model-to-delete",
Name: "model_to_delete",
}
jsonData, err := json.Marshal(deleteReq)
if err != nil {
@ -271,7 +281,7 @@ func Test_Routes(t *testing.T) {
Path: "/api/delete",
Setup: func(t *testing.T, req *http.Request) {
deleteReq := api.DeleteRequest{
Name: "non-existent-model",
Name: "non_existent_model",
}
jsonData, err := json.Marshal(deleteReq)
if err != nil {
@ -477,10 +487,34 @@ func Test_Routes(t *testing.T) {
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())
modelsDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", modelsDir)
c, err := blob.Open(modelsDir)
if err != nil {
t.Fatalf("failed to open models dir: %v", err)
}
rc := &ollama.Registry{
// This is a temporary measure to allow us to move forward,
// surfacing any code contacting ollama.com we do not intended
// to.
//
// Currently, this only handles DELETE /api/delete, which
// should not make any contact with the ollama.com registry, so
// be clear about that.
//
// Tests that do need to contact the registry here, will be
// consumed into our new server/api code packages and removed
// from here.
HTTPClient: panicOnRoundTrip,
}
s := &Server{}
router := s.GenerateRoutes()
router, err := s.GenerateRoutes(c, rc)
if err != nil {
t.Fatalf("failed to generate routes: %v", err)
}
httpSrv := httptest.NewServer(router)
t.Cleanup(httpSrv.Close)