Compare commits

...

1 Commits

Author SHA1 Message Date
Blake Mizerany
b48b6f85cd server/internal/client/ollama: hold DiskCache on Registry
Previously, clients of a Registry had to carry around a DiskCache to use
it. This change makes the DiskCache an optional field on the Registry
struct.

This also changes DefaultCache to initialize one on first use. This
prevents overhead of building the cache if it is never used, or per
Registry request that involves use of DefaultCache.

Also, slip in some minor docs on Trace.
2025-03-02 15:43:24 -08:00
8 changed files with 136 additions and 105 deletions

View File

@ -27,6 +27,7 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -73,19 +74,22 @@ const (
DefaultMaxChunkSize = 8 << 20 DefaultMaxChunkSize = 8 << 20
) )
// DefaultCache returns a new disk cache for storing models. If the var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
// OLLAMA_MODELS environment variable is set, it uses that directory;
// otherwise, it uses $HOME/.ollama/models.
func DefaultCache() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS") dir := os.Getenv("OLLAMA_MODELS")
if dir == "" { if dir == "" {
home, err := os.UserHomeDir() home, _ := os.UserHomeDir()
if err != nil { home = cmp.Or(home, ".")
return nil, err
}
dir = filepath.Join(home, ".ollama", "models") dir = filepath.Join(home, ".ollama", "models")
} }
return blob.Open(dir) return blob.Open(dir)
})
// DefaultCache returns the default cache used by the registry. It is
// configured from the OLLAMA_MODELS environment variable, or defaults to
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
// it uses the current working directory.
func DefaultCache() (*blob.DiskCache, error) {
return defaultCache()
} }
// Error is the standard error returned by Ollama APIs. It can represent a // Error is the standard error returned by Ollama APIs. It can represent a
@ -168,6 +172,10 @@ func CompleteName(name string) string {
// Registry is a client for performing push and pull operations against an // Registry is a client for performing push and pull operations against an
// Ollama registry. // Ollama registry.
type Registry struct { type Registry struct {
// Cache is the cache used to store models. If nil, [DefaultCache] is
// used.
Cache *blob.DiskCache
// UserAgent is the User-Agent header to send with requests to the // UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient. // registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string UserAgent string
@ -206,12 +214,18 @@ type Registry struct {
// It is only used when a layer is larger than [MaxChunkingThreshold]. // It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64 MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified // Mask, if set, is the name used to convert non-fully qualified names
// names to fully qualified names. If empty, the default mask // to fully qualified names. If empty, [DefaultMask] is used.
// ("registry.ollama.ai/library/_:latest") is used.
Mask string Mask string
} }
func (r *Registry) cache() (*blob.DiskCache, error) {
if r.Cache != nil {
return r.Cache, nil
}
return defaultCache()
}
func (r *Registry) parseName(name string) (names.Name, error) { func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask mask := defaultMask
if r.Mask != "" { if r.Mask != "" {
@ -241,6 +255,10 @@ func DefaultRegistry() (*Registry, error) {
} }
var rc Registry var rc Registry
rc.Cache, err = defaultCache()
if err != nil {
return nil, err
}
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
if err != nil { if err != nil {
return nil, err return nil, err
@ -282,12 +300,17 @@ type PushParams struct {
} }
// Push pushes the model with the name in the cache to the remote registry. // Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error { func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
if p == nil { if p == nil {
p = &PushParams{} p = &PushParams{}
} }
m, err := r.ResolveLocal(c, cmp.Or(p.From, name)) c, err := r.cache()
if err != nil {
return err
}
m, err := r.ResolveLocal(cmp.Or(p.From, name))
if err != nil { if err != nil {
return err return err
} }
@ -403,7 +426,7 @@ func canRetry(err error) bool {
// chunks of the specified size, and then reassembled and verified. This is // chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly // typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image". // utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error { func (r *Registry) Pull(ctx context.Context, name string) error {
scheme, n, _, err := r.parseNameExtended(name) scheme, n, _, err := r.parseNameExtended(name)
if err != nil { if err != nil {
return err return err
@ -417,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return fmt.Errorf("%w: no layers", ErrManifestInvalid) return fmt.Errorf("%w: no layers", ErrManifestInvalid)
} }
c, err := r.cache()
if err != nil {
return err
}
exists := func(l *Layer) bool { exists := func(l *Layer) bool {
info, err := c.Get(l.Digest) info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size return err == nil && info.Size == l.Size
@ -554,11 +582,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model. // before attempting to unlink the model.
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) { func (r *Registry) Unlink(name string) (ok bool, _ error) {
n, err := r.parseName(name) n, err := r.parseName(name)
if err != nil { if err != nil {
return false, err return false, err
} }
c, err := r.cache()
if err != nil {
return false, err
}
return c.Unlink(n.String()) return c.Unlink(n.String())
} }
@ -631,12 +663,17 @@ type Layer struct {
} }
// ResolveLocal resolves a name to a Manifest in the local cache. // ResolveLocal resolves a name to a Manifest in the local cache.
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) { func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
_, n, d, err := r.parseNameExtended(name) _, n, d, err := r.parseNameExtended(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c, err := r.cache()
if err != nil {
return nil, err
}
if !d.IsValid() { if !d.IsValid() {
// No digest, so resolve the manifest by name.
d, err = c.Resolve(n.String()) d, err = c.Resolve(n.String())
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -73,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// To simulate a network error, pass a handler that returns a 499 status code. // To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper() t.Helper()
c, err := blob.Open(t.TempDir()) c, err := blob.Open(t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -86,6 +87,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
} }
r := &Registry{ r := &Registry{
Cache: c,
HTTPClient: &http.Client{ HTTPClient: &http.Client{
Transport: recordRoundTripper(h), Transport: recordRoundTripper(h),
}, },
@ -152,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
} }
func TestPushZero(t *testing.T) { func TestPushZero(t *testing.T) {
rc, c := newClient(t, okHandler) rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "empty", nil) err := rc.Push(t.Context(), "empty", nil)
if !errors.Is(err, ErrManifestInvalid) { if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid) t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
} }
} }
func TestPushSingle(t *testing.T) { func TestPushSingle(t *testing.T) {
rc, c := newClient(t, okHandler) rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "single", nil) err := rc.Push(t.Context(), "single", nil)
testutil.Check(t, err) testutil.Check(t, err)
} }
func TestPushMultiple(t *testing.T) { func TestPushMultiple(t *testing.T) {
rc, c := newClient(t, okHandler) rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "multiple", nil) err := rc.Push(t.Context(), "multiple", nil)
testutil.Check(t, err) testutil.Check(t, err)
} }
func TestPushNotFound(t *testing.T) { func TestPushNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected request: %v", r) t.Errorf("unexpected request: %v", r)
}) })
err := rc.Push(t.Context(), c, "notfound", nil) err := rc.Push(t.Context(), "notfound", nil)
if !errors.Is(err, fs.ErrNotExist) { if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want %v", err, fs.ErrNotExist) t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
} }
} }
func TestPushNullLayer(t *testing.T) { func TestPushNullLayer(t *testing.T) {
rc, c := newClient(t, nil) rc, _ := newClient(t, nil)
err := rc.Push(t.Context(), c, "null", nil) err := rc.Push(t.Context(), "null", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") { if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err) t.Errorf("err = %v; want invalid manifest", err)
} }
} }
func TestPushSizeMismatch(t *testing.T) { func TestPushSizeMismatch(t *testing.T) {
rc, c := newClient(t, nil) rc, _ := newClient(t, nil)
ctx, _ := withTraceUnexpected(t.Context()) ctx, _ := withTraceUnexpected(t.Context())
got := rc.Push(ctx, c, "sizemismatch", nil) got := rc.Push(ctx, "sizemismatch", nil)
if got == nil || !strings.Contains(got.Error(), "size mismatch") { if got == nil || !strings.Contains(got.Error(), "size mismatch") {
t.Errorf("err = %v; want size mismatch", got) t.Errorf("err = %v; want size mismatch", got)
} }
} }
func TestPushInvalid(t *testing.T) { func TestPushInvalid(t *testing.T) {
rc, c := newClient(t, nil) rc, _ := newClient(t, nil)
err := rc.Push(t.Context(), c, "invalid", nil) err := rc.Push(t.Context(), "invalid", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") { if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err) t.Errorf("err = %v; want invalid manifest", err)
} }
@ -208,7 +210,7 @@ func TestPushInvalid(t *testing.T) {
func TestPushExistsAtRemote(t *testing.T) { func TestPushExistsAtRemote(t *testing.T) {
var pushed bool var pushed bool
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/uploads/") { if strings.Contains(r.URL.Path, "/uploads/") {
if !pushed { if !pushed {
// First push. Return an uploadURL. // First push. Return an uploadURL.
@ -236,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {
check := testutil.Checker(t) check := testutil.Checker(t)
err := rc.Push(ctx, c, "single", nil) err := rc.Push(ctx, "single", nil)
check(err) check(err)
if !errors.Is(errors.Join(errs...), nil) { if !errors.Is(errors.Join(errs...), nil) {
t.Errorf("errs = %v; want %v", errs, []error{ErrCached}) t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
} }
err = rc.Push(ctx, c, "single", nil) err = rc.Push(ctx, "single", nil)
check(err) check(err)
} }
func TestPushRemoteError(t *testing.T) { func TestPushRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") { if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(500) w.WriteHeader(500)
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`) io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
return return
} }
}) })
got := rc.Push(t.Context(), c, "single", nil) got := rc.Push(t.Context(), "single", nil)
checkErrCode(t, got, 500, "blob_error") checkErrCode(t, got, 500, "blob_error")
} }
func TestPushLocationError(t *testing.T) { func TestPushLocationError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", ":///x") w.Header().Set("Location", ":///x")
w.WriteHeader(http.StatusAccepted) w.WriteHeader(http.StatusAccepted)
}) })
got := rc.Push(t.Context(), c, "single", nil) got := rc.Push(t.Context(), "single", nil)
wantContains := "invalid upload URL" wantContains := "invalid upload URL"
if got == nil || !strings.Contains(got.Error(), wantContains) { if got == nil || !strings.Contains(got.Error(), wantContains) {
t.Errorf("err = %v; want to contain %v", got, wantContains) t.Errorf("err = %v; want to contain %v", got, wantContains)
@ -272,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
} }
func TestPushUploadRoundtripError(t *testing.T) { func TestPushUploadRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Host == "blob.store" { if r.Host == "blob.store" {
w.WriteHeader(499) // force RoundTrip error on upload w.WriteHeader(499) // force RoundTrip error on upload
return return
} }
w.Header().Set("Location", "http://blob.store/blobs/123") w.Header().Set("Location", "http://blob.store/blobs/123")
}) })
got := rc.Push(t.Context(), c, "single", nil) got := rc.Push(t.Context(), "single", nil)
if !errors.Is(got, errRoundTrip) { if !errors.Is(got, errRoundTrip) {
t.Errorf("got = %v; want %v", got, errRoundTrip) t.Errorf("got = %v; want %v", got, errRoundTrip)
} }
@ -295,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
os.Remove(c.GetFile(l.Digest)) os.Remove(c.GetFile(l.Digest))
}, },
}) })
got := rc.Push(ctx, c, "single", nil) got := rc.Push(ctx, "single", nil)
if !errors.Is(got, fs.ErrNotExist) { if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got) t.Errorf("got = %v; want fs.ErrNotExist", got)
} }
} }
func TestPushCommitRoundtripError(t *testing.T) { func TestPushCommitRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") { if strings.Contains(r.URL.Path, "/blobs/") {
panic("unexpected") panic("unexpected")
} }
w.WriteHeader(499) // force RoundTrip error w.WriteHeader(499) // force RoundTrip error
}) })
err := rc.Push(t.Context(), c, "zero", nil) err := rc.Push(t.Context(), "zero", nil)
if !errors.Is(err, errRoundTrip) { if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip) t.Errorf("err = %v; want %v", err, errRoundTrip)
} }
@ -322,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
} }
func TestRegistryPullInvalidName(t *testing.T) { func TestRegistryPullInvalidName(t *testing.T) {
rc, c := newClient(t, nil) rc, _ := newClient(t, nil)
err := rc.Pull(t.Context(), c, "://") err := rc.Pull(t.Context(), "://")
if !errors.Is(err, ErrNameInvalid) { if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid) t.Errorf("err = %v; want %v", err, ErrNameInvalid)
} }
@ -338,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
} }
for _, resp := range cases { for _, resp := range cases {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp) io.WriteString(w, resp)
}) })
err := rc.Pull(t.Context(), c, "x") err := rc.Pull(t.Context(), "x")
if !errors.Is(err, ErrManifestInvalid) { if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err) t.Errorf("err = %v; want invalid manifest", err)
} }
@ -364,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
}) })
// Confirm that the layer does not exist locally // Confirm that the layer does not exist locally
_, err := rc.ResolveLocal(c, "model") _, err := rc.ResolveLocal("model")
checkNotExist(t, err) checkNotExist(t, err)
_, err = c.Get(d) _, err = c.Get(d)
checkNotExist(t, err) checkNotExist(t, err)
err = rc.Pull(t.Context(), c, "model") err = rc.Pull(t.Context(), "model")
check(err) check(err)
mw, err := rc.Resolve(t.Context(), "model") mw, err := rc.Resolve(t.Context(), "model")
check(err) check(err)
mg, err := rc.ResolveLocal(c, "model") mg, err := rc.ResolveLocal("model")
check(err) check(err)
if !reflect.DeepEqual(mw, mg) { if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg) t.Errorf("mw = %v; mg = %v", mw, mg)
@ -400,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {
func TestRegistryPullCached(t *testing.T) { func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists") cached := blob.DigestFromBytes("exists")
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") { if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called w.WriteHeader(499) // should not be called
return return
@ -423,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel() defer cancel()
err := rc.Pull(ctx, c, "single") err := rc.Pull(ctx, "single")
testutil.Check(t, err) testutil.Check(t, err)
want := []int64{6} want := []int64{6}
@ -436,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
} }
func TestRegistryPullManifestNotFound(t *testing.T) { func TestRegistryPullManifestNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
}) })
err := rc.Pull(t.Context(), c, "notfound") err := rc.Pull(t.Context(), "notfound")
checkErrCode(t, err, 404, "") checkErrCode(t, err, 404, "")
} }
func TestRegistryPullResolveRemoteError(t *testing.T) { func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`) io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
}) })
err := rc.Pull(t.Context(), c, "single") err := rc.Pull(t.Context(), "single")
checkErrCode(t, err, 500, "an_error") checkErrCode(t, err, 500, "an_error")
} }
func TestRegistryPullResolveRoundtripError(t *testing.T) { func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") { if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error w.WriteHeader(499) // force RoundTrip error
return return
} }
}) })
err := rc.Pull(t.Context(), c, "single") err := rc.Pull(t.Context(), "single")
if !errors.Is(err, errRoundTrip) { if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip) t.Errorf("err = %v; want %v", err, errRoundTrip)
} }
@ -512,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
// Check that we pull all layers that we can. // Check that we pull all layers that we can.
err := rc.Pull(ctx, c, "mixed") err := rc.Pull(ctx, "mixed")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -530,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
} }
func TestRegistryPullChunking(t *testing.T) { func TestRegistryPullChunking(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) { rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range")) t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" { if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store. // The production registry redirects to the blob store.
@ -568,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
}, },
}) })
err := rc.Pull(ctx, c, "remote") err := rc.Pull(ctx, "remote")
testutil.Check(t, err) testutil.Check(t, err)
want := []int64{0, 3, 6} want := []int64{0, 3, 6}
@ -785,27 +787,27 @@ func TestParseNameExtended(t *testing.T) {
func TestUnlink(t *testing.T) { func TestUnlink(t *testing.T) {
t.Run("found by name", func(t *testing.T) { t.Run("found by name", func(t *testing.T) {
rc, c := newClient(t, nil) rc, _ := newClient(t, nil)
// confirm linked // confirm linked
_, err := rc.ResolveLocal(c, "single") _, err := rc.ResolveLocal("single")
if err != nil { if err != nil {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
// unlink // unlink
_, err = rc.Unlink(c, "single") _, err = rc.Unlink("single")
testutil.Check(t, err) testutil.Check(t, err)
// confirm unlinked // confirm unlinked
_, err = rc.ResolveLocal(c, "single") _, err = rc.ResolveLocal("single")
if !errors.Is(err, fs.ErrNotExist) { if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want fs.ErrNotExist", err) t.Errorf("err = %v; want fs.ErrNotExist", err)
} }
}) })
t.Run("not found by name", func(t *testing.T) { t.Run("not found by name", func(t *testing.T) {
rc, c := newClient(t, nil) rc, _ := newClient(t, nil)
ok, err := rc.Unlink(c, "manifestNotFound") ok, err := rc.Unlink("manifestNotFound")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -6,6 +6,9 @@ import (
// Trace is a set of functions that are called to report progress during blob // Trace is a set of functions that are called to report progress during blob
// downloads and uploads. // downloads and uploads.
//
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
// and [Registry.Pull].
type Trace struct { type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to // Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads. // report the progress of blob uploads and downloads.

View File

@ -63,25 +63,28 @@ func main() {
} }
flag.Parse() flag.Parse()
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
ctx := context.Background() ctx := context.Background()
err = func() error { err := func() error {
switch cmd := flag.Arg(0); cmd { switch cmd := flag.Arg(0); cmd {
case "pull": case "pull":
return cmdPull(ctx, rc, c) rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
return cmdPull(ctx, rc)
case "push": case "push":
return cmdPush(ctx, rc, c) rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
return cmdPush(ctx, rc)
case "import": case "import":
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
return cmdImport(ctx, c) return cmdImport(ctx, c)
default: default:
if cmd == "" { if cmd == "" {
@ -99,7 +102,7 @@ func main() {
} }
} }
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { func cmdPull(ctx context.Context, rc *ollama.Registry) error {
model := flag.Arg(1) model := flag.Arg(1)
if model == "" { if model == "" {
flag.Usage() flag.Usage()
@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
errc := make(chan error) errc := make(chan error)
go func() { go func() {
errc <- rc.Pull(ctx, c, model) errc <- rc.Pull(ctx, model)
}() }()
t := time.NewTicker(time.Second) t := time.NewTicker(time.Second)
@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
} }
} }
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error { func cmdPush(ctx context.Context, rc *ollama.Registry) error {
args := flag.Args()[1:] args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError) flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.") flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
} }
from := cmp.Or(*flagFrom, model) from := cmp.Or(*flagFrom, model)
m, err := rc.ResolveLocal(c, from) m, err := rc.ResolveLocal(from)
if err != nil { if err != nil {
return err return err
} }
@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}, },
}) })
return rc.Push(ctx, c, model, &ollama.PushParams{ return rc.Push(ctx, model, &ollama.PushParams{
From: from, From: from,
}) })
} }

View File

@ -11,7 +11,6 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
) )
@ -27,7 +26,6 @@ import (
// directly to the blob disk cache. // directly to the blob disk cache.
type Local struct { type Local struct {
Client *ollama.Registry // required Client *ollama.Registry // required
Cache *blob.DiskCache // required
Logger *slog.Logger // required Logger *slog.Logger // required
// Fallback, if set, is used to handle requests that are not handled by // Fallback, if set, is used to handle requests that are not handled by
@ -199,7 +197,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if err != nil { if err != nil {
return err return err
} }
ok, err := s.Client.Unlink(s.Cache, p.model()) ok, err := s.Client.Unlink(p.model())
if err != nil { if err != nil {
return err return err
} }

View File

@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local {
t.Fatal(err) t.Fatal(err)
} }
rc := &ollama.Registry{ rc := &ollama.Registry{
Cache: c,
HTTPClient: panicOnRoundTrip, HTTPClient: panicOnRoundTrip,
} }
l := &Local{ l := &Local{
Cache: c,
Client: rc, Client: rc,
Logger: testutil.Slogger(t), Logger: testutil.Slogger(t),
} }
@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) {
s := newTestServer(t) s := newTestServer(t)
_, err := s.Client.ResolveLocal(s.Cache, "smol") _, err := s.Client.ResolveLocal("smol")
check(err) check(err)
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`) got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) {
t.Fatalf("Code = %d; want 200", got.Code) t.Fatalf("Code = %d; want 200", got.Code)
} }
_, err = s.Client.ResolveLocal(s.Cache, "smol") _, err = s.Client.ResolveLocal("smol")
if err == nil { if err == nil {
t.Fatal("expected smol to have been deleted") t.Fatal("expected smol to have been deleted")
} }

View File

@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
} }
} }
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) { func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
corsConfig := cors.DefaultConfig() corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true corsConfig.AllowBrowserExtensions = true
@ -1197,7 +1196,6 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
// wrap old with new // wrap old with new
rs := &registry.Local{ rs := &registry.Local{
Cache: c,
Client: rc, Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Fallback: r, Fallback: r,
@ -1258,16 +1256,12 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()} s := &Server{addr: ln.Addr()}
c, err := ollama.DefaultCache()
if err != nil {
return err
}
rc, err := ollama.DefaultRegistry() rc, err := ollama.DefaultRegistry()
if err != nil { if err != nil {
return err return err
} }
h, err := s.GenerateRoutes(c, rc) h, err := s.GenerateRoutes(rc)
if err != nil { if err != nil {
return err return err
} }

View File

@ -23,7 +23,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
modelsDir := t.TempDir() modelsDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", modelsDir) t.Setenv("OLLAMA_MODELS", modelsDir)
c, err := blob.Open(modelsDir)
if err != nil {
t.Fatalf("failed to open models dir: %v", err)
}
rc := &ollama.Registry{ rc := &ollama.Registry{
// This is a temporary measure to allow us to move forward, // This is a temporary measure to allow us to move forward,
// surfacing any code contacting ollama.com we do not intended // surfacing any code contacting ollama.com we do not intended
@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
} }
s := &Server{} s := &Server{}
router, err := s.GenerateRoutes(c, rc) router, err := s.GenerateRoutes(rc)
if err != nil { if err != nil {
t.Fatalf("failed to generate routes: %v", err) t.Fatalf("failed to generate routes: %v", err)
} }