server/internal/client/ollama: hold DiskCache on Registry (#9463)
Previously, using a Registry required a DiskCache to be passed in to various methods. This was cumbersome: the DiskCache is required for most operations, and in most of those cases the DefaultCache is used. This change makes the DiskCache an optional field on the Registry struct; if it is nil, DefaultCache is used. It also changes DefaultCache to initialize on first use, so that clients are not burdened with the cost of creating a new cache per use or with having to hold onto a cache for the lifetime of the Registry. Also, slip in some minor docs updates for Trace.
This commit is contained in:
parent
e41c4cbea7
commit
3519dd1c6e
@ -27,6 +27,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -73,19 +74,22 @@ const (
|
|||||||
DefaultMaxChunkSize = 8 << 20
|
DefaultMaxChunkSize = 8 << 20
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultCache returns a new disk cache for storing models. If the
|
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||||
// OLLAMA_MODELS environment variable is set, it uses that directory;
|
|
||||||
// otherwise, it uses $HOME/.ollama/models.
|
|
||||||
func DefaultCache() (*blob.DiskCache, error) {
|
|
||||||
dir := os.Getenv("OLLAMA_MODELS")
|
dir := os.Getenv("OLLAMA_MODELS")
|
||||||
if dir == "" {
|
if dir == "" {
|
||||||
home, err := os.UserHomeDir()
|
home, _ := os.UserHomeDir()
|
||||||
if err != nil {
|
home = cmp.Or(home, ".")
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dir = filepath.Join(home, ".ollama", "models")
|
dir = filepath.Join(home, ".ollama", "models")
|
||||||
}
|
}
|
||||||
return blob.Open(dir)
|
return blob.Open(dir)
|
||||||
|
})
|
||||||
|
|
||||||
|
// DefaultCache returns the default cache used by the registry. It is
|
||||||
|
// configured from the OLLAMA_MODELS environment variable, or defaults to
|
||||||
|
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
|
||||||
|
// it uses .ollama/models under the current working directory.
|
||||||
|
func DefaultCache() (*blob.DiskCache, error) {
|
||||||
|
return defaultCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error is the standard error returned by Ollama APIs. It can represent a
|
// Error is the standard error returned by Ollama APIs. It can represent a
|
||||||
@ -168,6 +172,10 @@ func CompleteName(name string) string {
|
|||||||
// Registry is a client for performing push and pull operations against an
|
// Registry is a client for performing push and pull operations against an
|
||||||
// Ollama registry.
|
// Ollama registry.
|
||||||
type Registry struct {
|
type Registry struct {
|
||||||
|
// Cache is the cache used to store models. If nil, [DefaultCache] is
|
||||||
|
// used.
|
||||||
|
Cache *blob.DiskCache
|
||||||
|
|
||||||
// UserAgent is the User-Agent header to send with requests to the
|
// UserAgent is the User-Agent header to send with requests to the
|
||||||
// registry. If empty, the User-Agent is determined by HTTPClient.
|
// registry. If empty, the User-Agent is determined by HTTPClient.
|
||||||
UserAgent string
|
UserAgent string
|
||||||
@ -206,12 +214,18 @@ type Registry struct {
|
|||||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||||
MaxChunkSize int64
|
MaxChunkSize int64
|
||||||
|
|
||||||
// Mask, if set, is the name used to convert non-fully qualified
|
// Mask, if set, is the name used to convert non-fully qualified names
|
||||||
// names to fully qualified names. If empty, the default mask
|
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||||
// ("registry.ollama.ai/library/_:latest") is used.
|
|
||||||
Mask string
|
Mask string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Registry) cache() (*blob.DiskCache, error) {
|
||||||
|
if r.Cache != nil {
|
||||||
|
return r.Cache, nil
|
||||||
|
}
|
||||||
|
return defaultCache()
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Registry) parseName(name string) (names.Name, error) {
|
func (r *Registry) parseName(name string) (names.Name, error) {
|
||||||
mask := defaultMask
|
mask := defaultMask
|
||||||
if r.Mask != "" {
|
if r.Mask != "" {
|
||||||
@ -282,12 +296,17 @@ type PushParams struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Push pushes the model with the name in the cache to the remote registry.
|
// Push pushes the model with the name in the cache to the remote registry.
|
||||||
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
p = &PushParams{}
|
p = &PushParams{}
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
|
c, err := r.cache()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := r.ResolveLocal(cmp.Or(p.From, name))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -403,7 +422,7 @@ func canRetry(err error) bool {
|
|||||||
// chunks of the specified size, and then reassembled and verified. This is
|
// chunks of the specified size, and then reassembled and verified. This is
|
||||||
// typically slower than splitting the model up across layers, and is mostly
|
// typically slower than splitting the model up across layers, and is mostly
|
||||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||||
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||||
scheme, n, _, err := r.parseNameExtended(name)
|
scheme, n, _, err := r.parseNameExtended(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -417,6 +436,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
|||||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c, err := r.cache()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
exists := func(l *Layer) bool {
|
exists := func(l *Layer) bool {
|
||||||
info, err := c.Get(l.Digest)
|
info, err := c.Get(l.Digest)
|
||||||
return err == nil && info.Size == l.Size
|
return err == nil && info.Size == l.Size
|
||||||
@ -554,11 +578,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
|||||||
|
|
||||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||||
// before attempting to unlink the model.
|
// before attempting to unlink the model.
|
||||||
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
func (r *Registry) Unlink(name string) (ok bool, _ error) {
|
||||||
n, err := r.parseName(name)
|
n, err := r.parseName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
c, err := r.cache()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
return c.Unlink(n.String())
|
return c.Unlink(n.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -631,12 +659,17 @@ type Layer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResolveLocal resolves a name to a Manifest in the local cache.
|
// ResolveLocal resolves a name to a Manifest in the local cache.
|
||||||
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
|
||||||
_, n, d, err := r.parseNameExtended(name)
|
_, n, d, err := r.parseNameExtended(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
c, err := r.cache()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if !d.IsValid() {
|
if !d.IsValid() {
|
||||||
|
// No digest, so resolve the manifest by name.
|
||||||
d, err = c.Resolve(n.String())
|
d, err = c.Resolve(n.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -73,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
|||||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||||
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
c, err := blob.Open(t.TempDir())
|
c, err := blob.Open(t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -86,6 +87,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := &Registry{
|
r := &Registry{
|
||||||
|
Cache: c,
|
||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
Transport: recordRoundTripper(h),
|
Transport: recordRoundTripper(h),
|
||||||
},
|
},
|
||||||
@ -152,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPushZero(t *testing.T) {
|
func TestPushZero(t *testing.T) {
|
||||||
rc, c := newClient(t, okHandler)
|
rc, _ := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), c, "empty", nil)
|
err := rc.Push(t.Context(), "empty", nil)
|
||||||
if !errors.Is(err, ErrManifestInvalid) {
|
if !errors.Is(err, ErrManifestInvalid) {
|
||||||
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushSingle(t *testing.T) {
|
func TestPushSingle(t *testing.T) {
|
||||||
rc, c := newClient(t, okHandler)
|
rc, _ := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), c, "single", nil)
|
err := rc.Push(t.Context(), "single", nil)
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushMultiple(t *testing.T) {
|
func TestPushMultiple(t *testing.T) {
|
||||||
rc, c := newClient(t, okHandler)
|
rc, _ := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), c, "multiple", nil)
|
err := rc.Push(t.Context(), "multiple", nil)
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushNotFound(t *testing.T) {
|
func TestPushNotFound(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Errorf("unexpected request: %v", r)
|
t.Errorf("unexpected request: %v", r)
|
||||||
})
|
})
|
||||||
err := rc.Push(t.Context(), c, "notfound", nil)
|
err := rc.Push(t.Context(), "notfound", nil)
|
||||||
if !errors.Is(err, fs.ErrNotExist) {
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushNullLayer(t *testing.T) {
|
func TestPushNullLayer(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
err := rc.Push(t.Context(), c, "null", nil)
|
err := rc.Push(t.Context(), "null", nil)
|
||||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushSizeMismatch(t *testing.T) {
|
func TestPushSizeMismatch(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
ctx, _ := withTraceUnexpected(t.Context())
|
ctx, _ := withTraceUnexpected(t.Context())
|
||||||
got := rc.Push(ctx, c, "sizemismatch", nil)
|
got := rc.Push(ctx, "sizemismatch", nil)
|
||||||
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
||||||
t.Errorf("err = %v; want size mismatch", got)
|
t.Errorf("err = %v; want size mismatch", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushInvalid(t *testing.T) {
|
func TestPushInvalid(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
err := rc.Push(t.Context(), c, "invalid", nil)
|
err := rc.Push(t.Context(), "invalid", nil)
|
||||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
@ -208,7 +210,7 @@ func TestPushInvalid(t *testing.T) {
|
|||||||
|
|
||||||
func TestPushExistsAtRemote(t *testing.T) {
|
func TestPushExistsAtRemote(t *testing.T) {
|
||||||
var pushed bool
|
var pushed bool
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/uploads/") {
|
if strings.Contains(r.URL.Path, "/uploads/") {
|
||||||
if !pushed {
|
if !pushed {
|
||||||
// First push. Return an uploadURL.
|
// First push. Return an uploadURL.
|
||||||
@ -236,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {
|
|||||||
|
|
||||||
check := testutil.Checker(t)
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
err := rc.Push(ctx, c, "single", nil)
|
err := rc.Push(ctx, "single", nil)
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
if !errors.Is(errors.Join(errs...), nil) {
|
if !errors.Is(errors.Join(errs...), nil) {
|
||||||
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rc.Push(ctx, c, "single", nil)
|
err = rc.Push(ctx, "single", nil)
|
||||||
check(err)
|
check(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushRemoteError(t *testing.T) {
|
func TestPushRemoteError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
w.WriteHeader(500)
|
w.WriteHeader(500)
|
||||||
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), c, "single", nil)
|
got := rc.Push(t.Context(), "single", nil)
|
||||||
checkErrCode(t, got, 500, "blob_error")
|
checkErrCode(t, got, 500, "blob_error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushLocationError(t *testing.T) {
|
func TestPushLocationError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Location", ":///x")
|
w.Header().Set("Location", ":///x")
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), c, "single", nil)
|
got := rc.Push(t.Context(), "single", nil)
|
||||||
wantContains := "invalid upload URL"
|
wantContains := "invalid upload URL"
|
||||||
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
||||||
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
||||||
@ -272,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPushUploadRoundtripError(t *testing.T) {
|
func TestPushUploadRoundtripError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Host == "blob.store" {
|
if r.Host == "blob.store" {
|
||||||
w.WriteHeader(499) // force RoundTrip error on upload
|
w.WriteHeader(499) // force RoundTrip error on upload
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), c, "single", nil)
|
got := rc.Push(t.Context(), "single", nil)
|
||||||
if !errors.Is(got, errRoundTrip) {
|
if !errors.Is(got, errRoundTrip) {
|
||||||
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
||||||
}
|
}
|
||||||
@ -295,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
|
|||||||
os.Remove(c.GetFile(l.Digest))
|
os.Remove(c.GetFile(l.Digest))
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
got := rc.Push(ctx, c, "single", nil)
|
got := rc.Push(ctx, "single", nil)
|
||||||
if !errors.Is(got, fs.ErrNotExist) {
|
if !errors.Is(got, fs.ErrNotExist) {
|
||||||
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushCommitRoundtripError(t *testing.T) {
|
func TestPushCommitRoundtripError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
w.WriteHeader(499) // force RoundTrip error
|
w.WriteHeader(499) // force RoundTrip error
|
||||||
})
|
})
|
||||||
err := rc.Push(t.Context(), c, "zero", nil)
|
err := rc.Push(t.Context(), "zero", nil)
|
||||||
if !errors.Is(err, errRoundTrip) {
|
if !errors.Is(err, errRoundTrip) {
|
||||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||||
}
|
}
|
||||||
@ -322,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullInvalidName(t *testing.T) {
|
func TestRegistryPullInvalidName(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
err := rc.Pull(t.Context(), c, "://")
|
err := rc.Pull(t.Context(), "://")
|
||||||
if !errors.Is(err, ErrNameInvalid) {
|
if !errors.Is(err, ErrNameInvalid) {
|
||||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||||
}
|
}
|
||||||
@ -338,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, resp := range cases {
|
for _, resp := range cases {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
io.WriteString(w, resp)
|
io.WriteString(w, resp)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), c, "x")
|
err := rc.Pull(t.Context(), "x")
|
||||||
if !errors.Is(err, ErrManifestInvalid) {
|
if !errors.Is(err, ErrManifestInvalid) {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
@ -364,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Confirm that the layer does not exist locally
|
// Confirm that the layer does not exist locally
|
||||||
_, err := rc.ResolveLocal(c, "model")
|
_, err := rc.ResolveLocal("model")
|
||||||
checkNotExist(t, err)
|
checkNotExist(t, err)
|
||||||
|
|
||||||
_, err = c.Get(d)
|
_, err = c.Get(d)
|
||||||
checkNotExist(t, err)
|
checkNotExist(t, err)
|
||||||
|
|
||||||
err = rc.Pull(t.Context(), c, "model")
|
err = rc.Pull(t.Context(), "model")
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
mw, err := rc.Resolve(t.Context(), "model")
|
mw, err := rc.Resolve(t.Context(), "model")
|
||||||
check(err)
|
check(err)
|
||||||
mg, err := rc.ResolveLocal(c, "model")
|
mg, err := rc.ResolveLocal("model")
|
||||||
check(err)
|
check(err)
|
||||||
if !reflect.DeepEqual(mw, mg) {
|
if !reflect.DeepEqual(mw, mg) {
|
||||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||||
@ -400,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {
|
|||||||
|
|
||||||
func TestRegistryPullCached(t *testing.T) {
|
func TestRegistryPullCached(t *testing.T) {
|
||||||
cached := blob.DigestFromBytes("exists")
|
cached := blob.DigestFromBytes("exists")
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
w.WriteHeader(499) // should not be called
|
w.WriteHeader(499) // should not be called
|
||||||
return
|
return
|
||||||
@ -423,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err := rc.Pull(ctx, c, "single")
|
err := rc.Pull(ctx, "single")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
want := []int64{6}
|
want := []int64{6}
|
||||||
@ -436,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
func TestRegistryPullManifestNotFound(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), c, "notfound")
|
err := rc.Pull(t.Context(), "notfound")
|
||||||
checkErrCode(t, err, 404, "")
|
checkErrCode(t, err, 404, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), c, "single")
|
err := rc.Pull(t.Context(), "single")
|
||||||
checkErrCode(t, err, 500, "an_error")
|
checkErrCode(t, err, 500, "an_error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||||
w.WriteHeader(499) // force RoundTrip error
|
w.WriteHeader(499) // force RoundTrip error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), c, "single")
|
err := rc.Pull(t.Context(), "single")
|
||||||
if !errors.Is(err, errRoundTrip) {
|
if !errors.Is(err, errRoundTrip) {
|
||||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||||
}
|
}
|
||||||
@ -512,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
|||||||
|
|
||||||
// Check that we pull all layers that we can.
|
// Check that we pull all layers that we can.
|
||||||
|
|
||||||
err := rc.Pull(ctx, c, "mixed")
|
err := rc.Pull(ctx, "mixed")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -530,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullChunking(t *testing.T) {
|
func TestRegistryPullChunking(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
||||||
if r.URL.Host != "blob.store" {
|
if r.URL.Host != "blob.store" {
|
||||||
// The production registry redirects to the blob store.
|
// The production registry redirects to the blob store.
|
||||||
@ -568,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
err := rc.Pull(ctx, c, "remote")
|
err := rc.Pull(ctx, "remote")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
want := []int64{0, 3, 6}
|
want := []int64{0, 3, 6}
|
||||||
@ -785,27 +787,27 @@ func TestParseNameExtended(t *testing.T) {
|
|||||||
|
|
||||||
func TestUnlink(t *testing.T) {
|
func TestUnlink(t *testing.T) {
|
||||||
t.Run("found by name", func(t *testing.T) {
|
t.Run("found by name", func(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
|
|
||||||
// confirm linked
|
// confirm linked
|
||||||
_, err := rc.ResolveLocal(c, "single")
|
_, err := rc.ResolveLocal("single")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlink
|
// unlink
|
||||||
_, err = rc.Unlink(c, "single")
|
_, err = rc.Unlink("single")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
// confirm unlinked
|
// confirm unlinked
|
||||||
_, err = rc.ResolveLocal(c, "single")
|
_, err = rc.ResolveLocal("single")
|
||||||
if !errors.Is(err, fs.ErrNotExist) {
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("not found by name", func(t *testing.T) {
|
t.Run("not found by name", func(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
ok, err := rc.Unlink(c, "manifestNotFound")
|
ok, err := rc.Unlink("manifestNotFound")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,9 @@ import (
|
|||||||
|
|
||||||
// Trace is a set of functions that are called to report progress during blob
|
// Trace is a set of functions that are called to report progress during blob
|
||||||
// downloads and uploads.
|
// downloads and uploads.
|
||||||
|
//
|
||||||
|
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
|
||||||
|
// and [Registry.Pull].
|
||||||
type Trace struct {
|
type Trace struct {
|
||||||
// Update is called during [Registry.Push] and [Registry.Pull] to
|
// Update is called during [Registry.Push] and [Registry.Pull] to
|
||||||
// report the progress of blob uploads and downloads.
|
// report the progress of blob uploads and downloads.
|
||||||
|
@ -63,25 +63,28 @@ func main() {
|
|||||||
}
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
c, err := ollama.DefaultCache()
|
ctx := context.Background()
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
err := func() error {
|
||||||
|
switch cmd := flag.Arg(0); cmd {
|
||||||
|
case "pull":
|
||||||
rc, err := ollama.DefaultRegistry()
|
rc, err := ollama.DefaultRegistry()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
return cmdPull(ctx, rc)
|
||||||
|
|
||||||
err = func() error {
|
|
||||||
switch cmd := flag.Arg(0); cmd {
|
|
||||||
case "pull":
|
|
||||||
return cmdPull(ctx, rc, c)
|
|
||||||
case "push":
|
case "push":
|
||||||
return cmdPush(ctx, rc, c)
|
rc, err := ollama.DefaultRegistry()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return cmdPush(ctx, rc)
|
||||||
case "import":
|
case "import":
|
||||||
|
c, err := ollama.DefaultCache()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
return cmdImport(ctx, c)
|
return cmdImport(ctx, c)
|
||||||
default:
|
default:
|
||||||
if cmd == "" {
|
if cmd == "" {
|
||||||
@ -99,7 +102,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
|
||||||
model := flag.Arg(1)
|
model := flag.Arg(1)
|
||||||
if model == "" {
|
if model == "" {
|
||||||
flag.Usage()
|
flag.Usage()
|
||||||
@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
|||||||
|
|
||||||
errc := make(chan error)
|
errc := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
errc <- rc.Pull(ctx, c, model)
|
errc <- rc.Pull(ctx, model)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
t := time.NewTicker(time.Second)
|
t := time.NewTicker(time.Second)
|
||||||
@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
|
||||||
args := flag.Args()[1:]
|
args := flag.Args()[1:]
|
||||||
flag := flag.NewFlagSet("push", flag.ExitOnError)
|
flag := flag.NewFlagSet("push", flag.ExitOnError)
|
||||||
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
|
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
|
||||||
@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
from := cmp.Or(*flagFrom, model)
|
from := cmp.Or(*flagFrom, model)
|
||||||
m, err := rc.ResolveLocal(c, from)
|
m, err := rc.ResolveLocal(from)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return rc.Push(ctx, c, model, &ollama.PushParams{
|
return rc.Push(ctx, model, &ollama.PushParams{
|
||||||
From: from,
|
From: from,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,7 +26,6 @@ import (
|
|||||||
// directly to the blob disk cache.
|
// directly to the blob disk cache.
|
||||||
type Local struct {
|
type Local struct {
|
||||||
Client *ollama.Registry // required
|
Client *ollama.Registry // required
|
||||||
Cache *blob.DiskCache // required
|
|
||||||
Logger *slog.Logger // required
|
Logger *slog.Logger // required
|
||||||
|
|
||||||
// Fallback, if set, is used to handle requests that are not handled by
|
// Fallback, if set, is used to handle requests that are not handled by
|
||||||
@ -199,7 +197,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ok, err := s.Client.Unlink(s.Cache, p.model())
|
ok, err := s.Client.Unlink(p.model())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
rc := &ollama.Registry{
|
rc := &ollama.Registry{
|
||||||
|
Cache: c,
|
||||||
HTTPClient: panicOnRoundTrip,
|
HTTPClient: panicOnRoundTrip,
|
||||||
}
|
}
|
||||||
l := &Local{
|
l := &Local{
|
||||||
Cache: c,
|
|
||||||
Client: rc,
|
Client: rc,
|
||||||
Logger: testutil.Slogger(t),
|
Logger: testutil.Slogger(t),
|
||||||
}
|
}
|
||||||
@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) {
|
|||||||
|
|
||||||
s := newTestServer(t)
|
s := newTestServer(t)
|
||||||
|
|
||||||
_, err := s.Client.ResolveLocal(s.Cache, "smol")
|
_, err := s.Client.ResolveLocal("smol")
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
||||||
@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) {
|
|||||||
t.Fatalf("Code = %d; want 200", got.Code)
|
t.Fatalf("Code = %d; want 200", got.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.Client.ResolveLocal(s.Cache, "smol")
|
_, err = s.Client.ResolveLocal("smol")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected smol to have been deleted")
|
t.Fatal("expected smol to have been deleted")
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,6 @@ import (
|
|||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/model/models/mllama"
|
"github.com/ollama/ollama/model/models/mllama"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/registry"
|
"github.com/ollama/ollama/server/internal/registry"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
|
func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||||
corsConfig := cors.DefaultConfig()
|
corsConfig := cors.DefaultConfig()
|
||||||
corsConfig.AllowWildcard = true
|
corsConfig.AllowWildcard = true
|
||||||
corsConfig.AllowBrowserExtensions = true
|
corsConfig.AllowBrowserExtensions = true
|
||||||
@ -1197,7 +1196,6 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
|
|||||||
|
|
||||||
// wrap old with new
|
// wrap old with new
|
||||||
rs := &registry.Local{
|
rs := &registry.Local{
|
||||||
Cache: c,
|
|
||||||
Client: rc,
|
Client: rc,
|
||||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
||||||
Fallback: r,
|
Fallback: r,
|
||||||
@ -1258,16 +1256,12 @@ func Serve(ln net.Listener) error {
|
|||||||
|
|
||||||
s := &Server{addr: ln.Addr()}
|
s := &Server{addr: ln.Addr()}
|
||||||
|
|
||||||
c, err := ollama.DefaultCache()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
rc, err := ollama.DefaultRegistry()
|
rc, err := ollama.DefaultRegistry()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
h, err := s.GenerateRoutes(c, rc)
|
h, err := s.GenerateRoutes(rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,6 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
|
|||||||
modelsDir := t.TempDir()
|
modelsDir := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||||
|
|
||||||
c, err := blob.Open(modelsDir)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to open models dir: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rc := &ollama.Registry{
|
rc := &ollama.Registry{
|
||||||
// This is a temporary measure to allow us to move forward,
|
// This is a temporary measure to allow us to move forward,
|
||||||
// surfacing any code contacting ollama.com we do not intend
|
// surfacing any code contacting ollama.com we do not intend
|
||||||
@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{}
|
s := &Server{}
|
||||||
router, err := s.GenerateRoutes(c, rc)
|
router, err := s.GenerateRoutes(rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to generate routes: %v", err)
|
t.Fatalf("failed to generate routes: %v", err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user