Compare commits
1 Commits
main
...
bmizerany/
Author | SHA1 | Date | |
---|---|---|---|
![]() |
b48b6f85cd |
@ -27,6 +27,7 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -73,19 +74,22 @@ const (
|
|||||||
DefaultMaxChunkSize = 8 << 20
|
DefaultMaxChunkSize = 8 << 20
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultCache returns a new disk cache for storing models. If the
|
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
||||||
// OLLAMA_MODELS environment variable is set, it uses that directory;
|
|
||||||
// otherwise, it uses $HOME/.ollama/models.
|
|
||||||
func DefaultCache() (*blob.DiskCache, error) {
|
|
||||||
dir := os.Getenv("OLLAMA_MODELS")
|
dir := os.Getenv("OLLAMA_MODELS")
|
||||||
if dir == "" {
|
if dir == "" {
|
||||||
home, err := os.UserHomeDir()
|
home, _ := os.UserHomeDir()
|
||||||
if err != nil {
|
home = cmp.Or(home, ".")
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dir = filepath.Join(home, ".ollama", "models")
|
dir = filepath.Join(home, ".ollama", "models")
|
||||||
}
|
}
|
||||||
return blob.Open(dir)
|
return blob.Open(dir)
|
||||||
|
})
|
||||||
|
|
||||||
|
// DefaultCache returns the default cache used by the registry. It is
|
||||||
|
// configured from the OLLAMA_MODELS environment variable, or defaults to
|
||||||
|
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
|
||||||
|
// it uses the current working directory.
|
||||||
|
func DefaultCache() (*blob.DiskCache, error) {
|
||||||
|
return defaultCache()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error is the standard error returned by Ollama APIs. It can represent a
|
// Error is the standard error returned by Ollama APIs. It can represent a
|
||||||
@ -168,6 +172,10 @@ func CompleteName(name string) string {
|
|||||||
// Registry is a client for performing push and pull operations against an
|
// Registry is a client for performing push and pull operations against an
|
||||||
// Ollama registry.
|
// Ollama registry.
|
||||||
type Registry struct {
|
type Registry struct {
|
||||||
|
// Cache is the cache used to store models. If nil, [DefaultCache] is
|
||||||
|
// used.
|
||||||
|
Cache *blob.DiskCache
|
||||||
|
|
||||||
// UserAgent is the User-Agent header to send with requests to the
|
// UserAgent is the User-Agent header to send with requests to the
|
||||||
// registry. If empty, the User-Agent is determined by HTTPClient.
|
// registry. If empty, the User-Agent is determined by HTTPClient.
|
||||||
UserAgent string
|
UserAgent string
|
||||||
@ -206,12 +214,18 @@ type Registry struct {
|
|||||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||||
MaxChunkSize int64
|
MaxChunkSize int64
|
||||||
|
|
||||||
// Mask, if set, is the name used to convert non-fully qualified
|
// Mask, if set, is the name used to convert non-fully qualified names
|
||||||
// names to fully qualified names. If empty, the default mask
|
// to fully qualified names. If empty, [DefaultMask] is used.
|
||||||
// ("registry.ollama.ai/library/_:latest") is used.
|
|
||||||
Mask string
|
Mask string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Registry) cache() (*blob.DiskCache, error) {
|
||||||
|
if r.Cache != nil {
|
||||||
|
return r.Cache, nil
|
||||||
|
}
|
||||||
|
return defaultCache()
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Registry) parseName(name string) (names.Name, error) {
|
func (r *Registry) parseName(name string) (names.Name, error) {
|
||||||
mask := defaultMask
|
mask := defaultMask
|
||||||
if r.Mask != "" {
|
if r.Mask != "" {
|
||||||
@ -241,6 +255,10 @@ func DefaultRegistry() (*Registry, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var rc Registry
|
var rc Registry
|
||||||
|
rc.Cache, err = defaultCache()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -282,12 +300,17 @@ type PushParams struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Push pushes the model with the name in the cache to the remote registry.
|
// Push pushes the model with the name in the cache to the remote registry.
|
||||||
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
p = &PushParams{}
|
p = &PushParams{}
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
|
c, err := r.cache()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := r.ResolveLocal(cmp.Or(p.From, name))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -403,7 +426,7 @@ func canRetry(err error) bool {
|
|||||||
// chunks of the specified size, and then reassembled and verified. This is
|
// chunks of the specified size, and then reassembled and verified. This is
|
||||||
// typically slower than splitting the model up across layers, and is mostly
|
// typically slower than splitting the model up across layers, and is mostly
|
||||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||||
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||||
scheme, n, _, err := r.parseNameExtended(name)
|
scheme, n, _, err := r.parseNameExtended(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -417,6 +440,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
|||||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c, err := r.cache()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
exists := func(l *Layer) bool {
|
exists := func(l *Layer) bool {
|
||||||
info, err := c.Get(l.Digest)
|
info, err := c.Get(l.Digest)
|
||||||
return err == nil && info.Size == l.Size
|
return err == nil && info.Size == l.Size
|
||||||
@ -554,11 +582,15 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
|||||||
|
|
||||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||||
// before attempting to unlink the model.
|
// before attempting to unlink the model.
|
||||||
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
func (r *Registry) Unlink(name string) (ok bool, _ error) {
|
||||||
n, err := r.parseName(name)
|
n, err := r.parseName(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
c, err := r.cache()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
return c.Unlink(n.String())
|
return c.Unlink(n.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -631,12 +663,17 @@ type Layer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ResolveLocal resolves a name to a Manifest in the local cache.
|
// ResolveLocal resolves a name to a Manifest in the local cache.
|
||||||
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
|
||||||
_, n, d, err := r.parseNameExtended(name)
|
_, n, d, err := r.parseNameExtended(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
c, err := r.cache()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if !d.IsValid() {
|
if !d.IsValid() {
|
||||||
|
// No digest, so resolve the manifest by name.
|
||||||
d, err = c.Resolve(n.String())
|
d, err = c.Resolve(n.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -73,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
|||||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||||
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
c, err := blob.Open(t.TempDir())
|
c, err := blob.Open(t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -86,6 +87,7 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := &Registry{
|
r := &Registry{
|
||||||
|
Cache: c,
|
||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
Transport: recordRoundTripper(h),
|
Transport: recordRoundTripper(h),
|
||||||
},
|
},
|
||||||
@ -152,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPushZero(t *testing.T) {
|
func TestPushZero(t *testing.T) {
|
||||||
rc, c := newClient(t, okHandler)
|
rc, _ := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), c, "empty", nil)
|
err := rc.Push(t.Context(), "empty", nil)
|
||||||
if !errors.Is(err, ErrManifestInvalid) {
|
if !errors.Is(err, ErrManifestInvalid) {
|
||||||
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushSingle(t *testing.T) {
|
func TestPushSingle(t *testing.T) {
|
||||||
rc, c := newClient(t, okHandler)
|
rc, _ := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), c, "single", nil)
|
err := rc.Push(t.Context(), "single", nil)
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushMultiple(t *testing.T) {
|
func TestPushMultiple(t *testing.T) {
|
||||||
rc, c := newClient(t, okHandler)
|
rc, _ := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), c, "multiple", nil)
|
err := rc.Push(t.Context(), "multiple", nil)
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushNotFound(t *testing.T) {
|
func TestPushNotFound(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Errorf("unexpected request: %v", r)
|
t.Errorf("unexpected request: %v", r)
|
||||||
})
|
})
|
||||||
err := rc.Push(t.Context(), c, "notfound", nil)
|
err := rc.Push(t.Context(), "notfound", nil)
|
||||||
if !errors.Is(err, fs.ErrNotExist) {
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushNullLayer(t *testing.T) {
|
func TestPushNullLayer(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
err := rc.Push(t.Context(), c, "null", nil)
|
err := rc.Push(t.Context(), "null", nil)
|
||||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushSizeMismatch(t *testing.T) {
|
func TestPushSizeMismatch(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
ctx, _ := withTraceUnexpected(t.Context())
|
ctx, _ := withTraceUnexpected(t.Context())
|
||||||
got := rc.Push(ctx, c, "sizemismatch", nil)
|
got := rc.Push(ctx, "sizemismatch", nil)
|
||||||
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
||||||
t.Errorf("err = %v; want size mismatch", got)
|
t.Errorf("err = %v; want size mismatch", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushInvalid(t *testing.T) {
|
func TestPushInvalid(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
err := rc.Push(t.Context(), c, "invalid", nil)
|
err := rc.Push(t.Context(), "invalid", nil)
|
||||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
@ -208,7 +210,7 @@ func TestPushInvalid(t *testing.T) {
|
|||||||
|
|
||||||
func TestPushExistsAtRemote(t *testing.T) {
|
func TestPushExistsAtRemote(t *testing.T) {
|
||||||
var pushed bool
|
var pushed bool
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/uploads/") {
|
if strings.Contains(r.URL.Path, "/uploads/") {
|
||||||
if !pushed {
|
if !pushed {
|
||||||
// First push. Return an uploadURL.
|
// First push. Return an uploadURL.
|
||||||
@ -236,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {
|
|||||||
|
|
||||||
check := testutil.Checker(t)
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
err := rc.Push(ctx, c, "single", nil)
|
err := rc.Push(ctx, "single", nil)
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
if !errors.Is(errors.Join(errs...), nil) {
|
if !errors.Is(errors.Join(errs...), nil) {
|
||||||
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rc.Push(ctx, c, "single", nil)
|
err = rc.Push(ctx, "single", nil)
|
||||||
check(err)
|
check(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushRemoteError(t *testing.T) {
|
func TestPushRemoteError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
w.WriteHeader(500)
|
w.WriteHeader(500)
|
||||||
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), c, "single", nil)
|
got := rc.Push(t.Context(), "single", nil)
|
||||||
checkErrCode(t, got, 500, "blob_error")
|
checkErrCode(t, got, 500, "blob_error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushLocationError(t *testing.T) {
|
func TestPushLocationError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Location", ":///x")
|
w.Header().Set("Location", ":///x")
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), c, "single", nil)
|
got := rc.Push(t.Context(), "single", nil)
|
||||||
wantContains := "invalid upload URL"
|
wantContains := "invalid upload URL"
|
||||||
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
||||||
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
||||||
@ -272,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPushUploadRoundtripError(t *testing.T) {
|
func TestPushUploadRoundtripError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Host == "blob.store" {
|
if r.Host == "blob.store" {
|
||||||
w.WriteHeader(499) // force RoundTrip error on upload
|
w.WriteHeader(499) // force RoundTrip error on upload
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), c, "single", nil)
|
got := rc.Push(t.Context(), "single", nil)
|
||||||
if !errors.Is(got, errRoundTrip) {
|
if !errors.Is(got, errRoundTrip) {
|
||||||
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
||||||
}
|
}
|
||||||
@ -295,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
|
|||||||
os.Remove(c.GetFile(l.Digest))
|
os.Remove(c.GetFile(l.Digest))
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
got := rc.Push(ctx, c, "single", nil)
|
got := rc.Push(ctx, "single", nil)
|
||||||
if !errors.Is(got, fs.ErrNotExist) {
|
if !errors.Is(got, fs.ErrNotExist) {
|
||||||
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushCommitRoundtripError(t *testing.T) {
|
func TestPushCommitRoundtripError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
w.WriteHeader(499) // force RoundTrip error
|
w.WriteHeader(499) // force RoundTrip error
|
||||||
})
|
})
|
||||||
err := rc.Push(t.Context(), c, "zero", nil)
|
err := rc.Push(t.Context(), "zero", nil)
|
||||||
if !errors.Is(err, errRoundTrip) {
|
if !errors.Is(err, errRoundTrip) {
|
||||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||||
}
|
}
|
||||||
@ -322,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullInvalidName(t *testing.T) {
|
func TestRegistryPullInvalidName(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
err := rc.Pull(t.Context(), c, "://")
|
err := rc.Pull(t.Context(), "://")
|
||||||
if !errors.Is(err, ErrNameInvalid) {
|
if !errors.Is(err, ErrNameInvalid) {
|
||||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||||
}
|
}
|
||||||
@ -338,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, resp := range cases {
|
for _, resp := range cases {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
io.WriteString(w, resp)
|
io.WriteString(w, resp)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), c, "x")
|
err := rc.Pull(t.Context(), "x")
|
||||||
if !errors.Is(err, ErrManifestInvalid) {
|
if !errors.Is(err, ErrManifestInvalid) {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
@ -364,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Confirm that the layer does not exist locally
|
// Confirm that the layer does not exist locally
|
||||||
_, err := rc.ResolveLocal(c, "model")
|
_, err := rc.ResolveLocal("model")
|
||||||
checkNotExist(t, err)
|
checkNotExist(t, err)
|
||||||
|
|
||||||
_, err = c.Get(d)
|
_, err = c.Get(d)
|
||||||
checkNotExist(t, err)
|
checkNotExist(t, err)
|
||||||
|
|
||||||
err = rc.Pull(t.Context(), c, "model")
|
err = rc.Pull(t.Context(), "model")
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
mw, err := rc.Resolve(t.Context(), "model")
|
mw, err := rc.Resolve(t.Context(), "model")
|
||||||
check(err)
|
check(err)
|
||||||
mg, err := rc.ResolveLocal(c, "model")
|
mg, err := rc.ResolveLocal("model")
|
||||||
check(err)
|
check(err)
|
||||||
if !reflect.DeepEqual(mw, mg) {
|
if !reflect.DeepEqual(mw, mg) {
|
||||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||||
@ -400,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {
|
|||||||
|
|
||||||
func TestRegistryPullCached(t *testing.T) {
|
func TestRegistryPullCached(t *testing.T) {
|
||||||
cached := blob.DigestFromBytes("exists")
|
cached := blob.DigestFromBytes("exists")
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
w.WriteHeader(499) // should not be called
|
w.WriteHeader(499) // should not be called
|
||||||
return
|
return
|
||||||
@ -423,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err := rc.Pull(ctx, c, "single")
|
err := rc.Pull(ctx, "single")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
want := []int64{6}
|
want := []int64{6}
|
||||||
@ -436,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
func TestRegistryPullManifestNotFound(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), c, "notfound")
|
err := rc.Pull(t.Context(), "notfound")
|
||||||
checkErrCode(t, err, 404, "")
|
checkErrCode(t, err, 404, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), c, "single")
|
err := rc.Pull(t.Context(), "single")
|
||||||
checkErrCode(t, err, 500, "an_error")
|
checkErrCode(t, err, 500, "an_error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||||
w.WriteHeader(499) // force RoundTrip error
|
w.WriteHeader(499) // force RoundTrip error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), c, "single")
|
err := rc.Pull(t.Context(), "single")
|
||||||
if !errors.Is(err, errRoundTrip) {
|
if !errors.Is(err, errRoundTrip) {
|
||||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||||
}
|
}
|
||||||
@ -512,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
|||||||
|
|
||||||
// Check that we pull all layers that we can.
|
// Check that we pull all layers that we can.
|
||||||
|
|
||||||
err := rc.Pull(ctx, c, "mixed")
|
err := rc.Pull(ctx, "mixed")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -530,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullChunking(t *testing.T) {
|
func TestRegistryPullChunking(t *testing.T) {
|
||||||
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
||||||
if r.URL.Host != "blob.store" {
|
if r.URL.Host != "blob.store" {
|
||||||
// The production registry redirects to the blob store.
|
// The production registry redirects to the blob store.
|
||||||
@ -568,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
err := rc.Pull(ctx, c, "remote")
|
err := rc.Pull(ctx, "remote")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
want := []int64{0, 3, 6}
|
want := []int64{0, 3, 6}
|
||||||
@ -785,27 +787,27 @@ func TestParseNameExtended(t *testing.T) {
|
|||||||
|
|
||||||
func TestUnlink(t *testing.T) {
|
func TestUnlink(t *testing.T) {
|
||||||
t.Run("found by name", func(t *testing.T) {
|
t.Run("found by name", func(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
|
|
||||||
// confirm linked
|
// confirm linked
|
||||||
_, err := rc.ResolveLocal(c, "single")
|
_, err := rc.ResolveLocal("single")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// unlink
|
// unlink
|
||||||
_, err = rc.Unlink(c, "single")
|
_, err = rc.Unlink("single")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
// confirm unlinked
|
// confirm unlinked
|
||||||
_, err = rc.ResolveLocal(c, "single")
|
_, err = rc.ResolveLocal("single")
|
||||||
if !errors.Is(err, fs.ErrNotExist) {
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("not found by name", func(t *testing.T) {
|
t.Run("not found by name", func(t *testing.T) {
|
||||||
rc, c := newClient(t, nil)
|
rc, _ := newClient(t, nil)
|
||||||
ok, err := rc.Unlink(c, "manifestNotFound")
|
ok, err := rc.Unlink("manifestNotFound")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,9 @@ import (
|
|||||||
|
|
||||||
// Trace is a set of functions that are called to report progress during blob
|
// Trace is a set of functions that are called to report progress during blob
|
||||||
// downloads and uploads.
|
// downloads and uploads.
|
||||||
|
//
|
||||||
|
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
|
||||||
|
// and [Registry.Pull].
|
||||||
type Trace struct {
|
type Trace struct {
|
||||||
// Update is called during [Registry.Push] and [Registry.Pull] to
|
// Update is called during [Registry.Push] and [Registry.Pull] to
|
||||||
// report the progress of blob uploads and downloads.
|
// report the progress of blob uploads and downloads.
|
||||||
|
@ -63,25 +63,28 @@ func main() {
|
|||||||
}
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
c, err := ollama.DefaultCache()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, err := ollama.DefaultRegistry()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
err = func() error {
|
err := func() error {
|
||||||
switch cmd := flag.Arg(0); cmd {
|
switch cmd := flag.Arg(0); cmd {
|
||||||
case "pull":
|
case "pull":
|
||||||
return cmdPull(ctx, rc, c)
|
rc, err := ollama.DefaultRegistry()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmdPull(ctx, rc)
|
||||||
case "push":
|
case "push":
|
||||||
return cmdPush(ctx, rc, c)
|
rc, err := ollama.DefaultRegistry()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return cmdPush(ctx, rc)
|
||||||
case "import":
|
case "import":
|
||||||
|
c, err := ollama.DefaultCache()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
return cmdImport(ctx, c)
|
return cmdImport(ctx, c)
|
||||||
default:
|
default:
|
||||||
if cmd == "" {
|
if cmd == "" {
|
||||||
@ -99,7 +102,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
|
||||||
model := flag.Arg(1)
|
model := flag.Arg(1)
|
||||||
if model == "" {
|
if model == "" {
|
||||||
flag.Usage()
|
flag.Usage()
|
||||||
@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
|||||||
|
|
||||||
errc := make(chan error)
|
errc := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
errc <- rc.Pull(ctx, c, model)
|
errc <- rc.Pull(ctx, model)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
t := time.NewTicker(time.Second)
|
t := time.NewTicker(time.Second)
|
||||||
@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
|
||||||
args := flag.Args()[1:]
|
args := flag.Args()[1:]
|
||||||
flag := flag.NewFlagSet("push", flag.ExitOnError)
|
flag := flag.NewFlagSet("push", flag.ExitOnError)
|
||||||
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
|
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
|
||||||
@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
from := cmp.Or(*flagFrom, model)
|
from := cmp.Or(*flagFrom, model)
|
||||||
m, err := rc.ResolveLocal(c, from)
|
m, err := rc.ResolveLocal(from)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return rc.Push(ctx, c, model, &ollama.PushParams{
|
return rc.Push(ctx, model, &ollama.PushParams{
|
||||||
From: from,
|
From: from,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,7 +26,6 @@ import (
|
|||||||
// directly to the blob disk cache.
|
// directly to the blob disk cache.
|
||||||
type Local struct {
|
type Local struct {
|
||||||
Client *ollama.Registry // required
|
Client *ollama.Registry // required
|
||||||
Cache *blob.DiskCache // required
|
|
||||||
Logger *slog.Logger // required
|
Logger *slog.Logger // required
|
||||||
|
|
||||||
// Fallback, if set, is used to handle requests that are not handled by
|
// Fallback, if set, is used to handle requests that are not handled by
|
||||||
@ -199,7 +197,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ok, err := s.Client.Unlink(s.Cache, p.model())
|
ok, err := s.Client.Unlink(p.model())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
rc := &ollama.Registry{
|
rc := &ollama.Registry{
|
||||||
|
Cache: c,
|
||||||
HTTPClient: panicOnRoundTrip,
|
HTTPClient: panicOnRoundTrip,
|
||||||
}
|
}
|
||||||
l := &Local{
|
l := &Local{
|
||||||
Cache: c,
|
|
||||||
Client: rc,
|
Client: rc,
|
||||||
Logger: testutil.Slogger(t),
|
Logger: testutil.Slogger(t),
|
||||||
}
|
}
|
||||||
@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) {
|
|||||||
|
|
||||||
s := newTestServer(t)
|
s := newTestServer(t)
|
||||||
|
|
||||||
_, err := s.Client.ResolveLocal(s.Cache, "smol")
|
_, err := s.Client.ResolveLocal("smol")
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
||||||
@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) {
|
|||||||
t.Fatalf("Code = %d; want 200", got.Code)
|
t.Fatalf("Code = %d; want 200", got.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.Client.ResolveLocal(s.Cache, "smol")
|
_, err = s.Client.ResolveLocal("smol")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected smol to have been deleted")
|
t.Fatal("expected smol to have been deleted")
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,6 @@ import (
|
|||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/model/models/mllama"
|
"github.com/ollama/ollama/model/models/mllama"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/registry"
|
"github.com/ollama/ollama/server/internal/registry"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
|
func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||||
corsConfig := cors.DefaultConfig()
|
corsConfig := cors.DefaultConfig()
|
||||||
corsConfig.AllowWildcard = true
|
corsConfig.AllowWildcard = true
|
||||||
corsConfig.AllowBrowserExtensions = true
|
corsConfig.AllowBrowserExtensions = true
|
||||||
@ -1197,7 +1196,6 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
|
|||||||
|
|
||||||
// wrap old with new
|
// wrap old with new
|
||||||
rs := ®istry.Local{
|
rs := ®istry.Local{
|
||||||
Cache: c,
|
|
||||||
Client: rc,
|
Client: rc,
|
||||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
||||||
Fallback: r,
|
Fallback: r,
|
||||||
@ -1258,16 +1256,12 @@ func Serve(ln net.Listener) error {
|
|||||||
|
|
||||||
s := &Server{addr: ln.Addr()}
|
s := &Server{addr: ln.Addr()}
|
||||||
|
|
||||||
c, err := ollama.DefaultCache()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
rc, err := ollama.DefaultRegistry()
|
rc, err := ollama.DefaultRegistry()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
h, err := s.GenerateRoutes(c, rc)
|
h, err := s.GenerateRoutes(rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,6 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
|
|||||||
modelsDir := t.TempDir()
|
modelsDir := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||||
|
|
||||||
c, err := blob.Open(modelsDir)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to open models dir: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rc := &ollama.Registry{
|
rc := &ollama.Registry{
|
||||||
// This is a temporary measure to allow us to move forward,
|
// This is a temporary measure to allow us to move forward,
|
||||||
// surfacing any code contacting ollama.com we do not intended
|
// surfacing any code contacting ollama.com we do not intended
|
||||||
@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{}
|
s := &Server{}
|
||||||
router, err := s.GenerateRoutes(c, rc)
|
router, err := s.GenerateRoutes(rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to generate routes: %v", err)
|
t.Fatalf("failed to generate routes: %v", err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user