From 92ce438de06c6225ffc004600df7a1ec5a439f8e Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 6 May 2025 13:05:01 -0700 Subject: [PATCH] server: remove internal cmd (#10595) --- .../opp/internal/safetensors/safetensors.go | 224 ----------- server/internal/cmd/opp/opp.go | 375 ------------------ 2 files changed, 599 deletions(-) delete mode 100644 server/internal/cmd/opp/internal/safetensors/safetensors.go delete mode 100644 server/internal/cmd/opp/opp.go diff --git a/server/internal/cmd/opp/internal/safetensors/safetensors.go b/server/internal/cmd/opp/internal/safetensors/safetensors.go deleted file mode 100644 index 7a45b91df..000000000 --- a/server/internal/cmd/opp/internal/safetensors/safetensors.go +++ /dev/null @@ -1,224 +0,0 @@ -// safetensors provides a reader for the safetensor directories and files. -package safetensors - -import ( - "encoding/json" - "fmt" - "io" - "io/fs" - "iter" - "slices" - "strconv" - "strings" -) - -// Tensor represents a single tensor in a safetensors file. -// -// It's zero value is not valid. Use [Model.Tensors] to get valid tensors. -// -// It is not safe for use across multiple goroutines. -type Tensor struct { - name string - dataType string - shape []int64 - - fsys fs.FS - fname string // entry name in fsys - offset int64 - size int64 -} - -type Model struct { - fsys fs.FS -} - -func Read(fsys fs.FS) (*Model, error) { - return &Model{fsys: fsys}, nil -} - -func (m *Model) Tensors() iter.Seq2[*Tensor, error] { - return func(yield func(*Tensor, error) bool) { - entries, err := fs.Glob(m.fsys, "*.safetensors") - if err != nil { - yield(nil, err) - return - } - for _, e := range entries { - tt, err := m.readTensors(e) - if err != nil { - yield(nil, err) - return - } - for _, t := range tt { - if !yield(t, nil) { - return - } - } - } - } -} - -func (m *Model) readTensors(fname string) ([]*Tensor, error) { - f, err := m.fsys.Open(fname) - if err != nil { - return nil, err - } - defer f.Close() - - finfo, err := f.Stat() - if err != nil { - return nil, err - } - - headerSize, err := readInt64(f) - if err != nil { - return nil, err - } - - data := make([]byte, headerSize) - _, err = io.ReadFull(f, data) - if err != nil { - return nil, err - } - - var raws map[string]json.RawMessage - if err := json.Unmarshal(data, &raws); err != nil { - return nil, err - } - - endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself - - // TODO(bmizerany): do something with metadata? This could be another - // header read if needed. We also need to figure out if the metadata is - // present in only one .safetensors file or if each file may have their - // own and if it needs to follow each tensor. Currently, I (bmizerany) - // am only seeing them show up with one entry for file type which is - // always "pt". - - tt := make([]*Tensor, 0, len(raws)) - for name, raw := range raws { - if name == "__metadata__" { - // TODO(bmizerany): do something with metadata? - continue - } - var v struct { - DataType string `json:"dtype"` - Shape []int64 `json:"shape"` - Offsets []int64 `json:"data_offsets"` - } - if err := json.Unmarshal(raw, &v); err != nil { - return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err) - } - if len(v.Offsets) != 2 { - return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets) - } - - // TODO(bmizerany): after collecting, validate all offests make - // tensors contiguous? - begin := endOfHeader + v.Offsets[0] - end := endOfHeader + v.Offsets[1] - if err := checkBeginEnd(finfo.Size(), begin, end); err != nil { - return nil, err - } - - // TODO(bmizerany): just yield.. don't be silly and make a slice :) - tt = append(tt, &Tensor{ - name: name, - dataType: v.DataType, - shape: v.Shape, - fsys: m.fsys, - fname: fname, - offset: begin, - size: end - begin, - }) - } - return tt, nil -} - -func checkBeginEnd(size, begin, end int64) error { - if begin < 0 { - return fmt.Errorf("begin must not be negative: %d", begin) - } - if end < 0 { - return fmt.Errorf("end must not be negative: %d", end) - } - if end < begin { - return fmt.Errorf("end must be >= begin: %d < %d", end, begin) - } - if end > size { - return fmt.Errorf("end must be <= size: %d > %d", end, size) - } - return nil -} - -func readInt64(r io.Reader) (int64, error) { - var v uint64 - var buf [8]byte - if _, err := io.ReadFull(r, buf[:]); err != nil { - return 0, err - } - for i := range buf { - v |= uint64(buf[i]) << (8 * i) - } - return int64(v), nil -} - -type Shape []int64 - -func (s Shape) String() string { - var b strings.Builder - b.WriteByte('[') - for i, v := range s { - if i > 0 { - b.WriteByte(',') - } - b.WriteString(strconv.FormatInt(v, 10)) - } - b.WriteByte(']') - return b.String() -} - -func (t *Tensor) Name() string { return t.name } -func (t *Tensor) DataType() string { return t.dataType } -func (t *Tensor) Size() int64 { return t.size } -func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) } - -func (t *Tensor) Reader() (io.ReadCloser, error) { - f, err := t.fsys.Open(t.fname) - if err != nil { - return nil, err - } - r := newSectionReader(f, t.offset, t.size) - rc := struct { - io.Reader - io.Closer - }{r, f} - return rc, nil -} - -// newSectionReader returns a new io.Reader that reads from r starting at -// offset. It is a convenience function for creating a io.SectionReader when r -// may not be an io.ReaderAt. -// -// If r is already a ReaderAt, it is returned directly, otherwise if r is an -// io.Seeker, a new io.ReaderAt is returned that wraps r after seeking to the -// beginning of the file. -// -// If r is an io.Seeker, -// or slow path. The slow path is used when r does not implement io.ReaderAt, -// in which case it must discard the data it reads. -func newSectionReader(r io.Reader, offset, n int64) io.Reader { - if r, ok := r.(io.ReaderAt); ok { - return io.NewSectionReader(r, offset, n) - } - if r, ok := r.(io.ReadSeeker); ok { - r.Seek(offset, io.SeekStart) - return io.LimitReader(r, n) - } - // Discard to offset and return a limited reader. - _, err := io.CopyN(io.Discard, r, offset) - if err != nil { - return nil - } - return io.LimitReader(r, n) -} diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go deleted file mode 100644 index 6976927c7..000000000 --- a/server/internal/cmd/opp/opp.go +++ /dev/null @@ -1,375 +0,0 @@ -package main - -import ( - "bytes" - "cmp" - "context" - "encoding/json" - "errors" - "flag" - "fmt" - "io" - "log" - "mime" - "net/http" - "os" - "runtime" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/ollama/ollama/server/internal/cache/blob" - "github.com/ollama/ollama/server/internal/client/ollama" - "github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors" - "golang.org/x/sync/errgroup" -) - -var stdout io.Writer = os.Stdout - -const usage = `Opp is a tool for pushing and pulling Ollama models. - -Usage: - - opp [flags] - -Commands: - - push Upload a model to the Ollama server. - pull Download a model from the Ollama server. - import Import a model from a local safetensor directory. - -Examples: - - # Pull a model from the Ollama server. - opp pull library/llama3.2:latest - - # Push a model to the Ollama server. - opp push username/my_model:8b - - # Import a model from a local safetensor directory. - opp import /path/to/safetensor - -Envionment Variables: - - OLLAMA_MODELS - The directory where models are pushed and pulled from - (default ~/.ollama/models). -` - -func main() { - flag.Usage = func() { - fmt.Fprint(os.Stderr, usage) - } - flag.Parse() - - ctx := context.Background() - - err := func() error { - switch cmd := flag.Arg(0); cmd { - case "pull": - rc, err := ollama.DefaultRegistry() - if err != nil { - log.Fatal(err) - } - - return cmdPull(ctx, rc) - case "push": - rc, err := ollama.DefaultRegistry() - if err != nil { - log.Fatal(err) - } - return cmdPush(ctx, rc) - case "import": - c, err := ollama.DefaultCache() - if err != nil { - log.Fatal(err) - } - return cmdImport(ctx, c) - default: - if cmd == "" { - flag.Usage() - } else { - fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd) - } - os.Exit(2) - return errors.New("unreachable") - } - }() - if err != nil { - fmt.Fprintf(os.Stderr, "opp: %v\n", err) - os.Exit(1) - } -} - -func cmdPull(ctx context.Context, rc *ollama.Registry) error { - model := flag.Arg(1) - if model == "" { - flag.Usage() - os.Exit(1) - } - - tr := http.DefaultTransport.(*http.Transport).Clone() - // TODO(bmizerany): configure transport? - rc.HTTPClient = &http.Client{Transport: tr} - - var mu sync.Mutex - p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded] - - var pb bytes.Buffer - printProgress := func() { - pb.Reset() - mu.Lock() - for d, s := range p { - // Write progress to a buffer first to avoid blocking - // on stdout while holding the lock. - stamp := time.Now().Format("2006/01/02 15:04:05") - fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0])) - if s[0] == s[1] { - delete(p, d) - } - } - mu.Unlock() - io.Copy(stdout, &pb) - } - - ctx = ollama.WithTrace(ctx, &ollama.Trace{ - Update: func(l *ollama.Layer, n int64, err error) { - if err != nil && !errors.Is(err, ollama.ErrCached) { - fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err) - return - } - - mu.Lock() - p[l.Digest] = [2]int64{l.Size, n} - mu.Unlock() - }, - }) - - errc := make(chan error) - go func() { - errc <- rc.Pull(ctx, model) - }() - - t := time.NewTicker(time.Second) - defer t.Stop() - for { - select { - case <-t.C: - printProgress() - case err := <-errc: - printProgress() - return err - } - } -} - -func cmdPush(ctx context.Context, rc *ollama.Registry) error { - args := flag.Args()[1:] - flag := flag.NewFlagSet("push", flag.ExitOnError) - flagFrom := flag.String("from", "", "Use the manifest from a model by another name.") - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: opp push \n") - flag.PrintDefaults() - } - flag.Parse(args) - - model := flag.Arg(0) - if model == "" { - return fmt.Errorf("missing model argument") - } - - from := cmp.Or(*flagFrom, model) - m, err := rc.ResolveLocal(from) - if err != nil { - return err - } - - ctx = ollama.WithTrace(ctx, &ollama.Trace{ - Update: func(l *ollama.Layer, n int64, err error) { - switch { - case errors.Is(err, ollama.ErrCached): - fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n) - case err != nil: - fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err) - case n == 0: - l := m.Layer(l.Digest) - mt, p, _ := mime.ParseMediaType(l.MediaType) - mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.") - switch mt { - case "tensor": - fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"]) - default: - fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType) - } - } - }, - }) - - return rc.Push(ctx, model, &ollama.PushParams{ - From: from, - }) -} - -type trackingReader struct { - io.Reader - n *atomic.Int64 -} - -func (r *trackingReader) Read(p []byte) (n int, err error) { - n, err = r.Reader.Read(p) - r.n.Add(int64(n)) - return n, err -} - -func cmdImport(ctx context.Context, c *blob.DiskCache) error { - args := flag.Args()[1:] - flag := flag.NewFlagSet("import", flag.ExitOnError) - flagAs := flag.String("as", "", "Import using the provided name.") - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: opp import \n") - flag.PrintDefaults() - } - flag.Parse(args) - if *flagAs == "" { - return fmt.Errorf("missing -as flag") - } - as := ollama.CompleteName(*flagAs) - - dir := cmp.Or(flag.Arg(0), ".") - fmt.Fprintf(os.Stderr, "Reading %s\n", dir) - - m, err := safetensors.Read(os.DirFS(dir)) - if err != nil { - return err - } - - var total int64 - var tt []*safetensors.Tensor - for t, err := range m.Tensors() { - if err != nil { - return err - } - tt = append(tt, t) - total += t.Size() - } - - var n atomic.Int64 - done := make(chan error) - go func() { - layers := make([]*ollama.Layer, len(tt)) - var g errgroup.Group - g.SetLimit(runtime.GOMAXPROCS(0)) - var ctxErr error - for i, t := range tt { - if ctx.Err() != nil { - // The context may cancel AFTER we exit the - // loop, and so if we use ctx.Err() after the - // loop we may report it as the error that - // broke the loop, when it was not. This can - // manifest as a false-negative, leading the - // user to think their import failed when it - // did not, so capture it if and only if we - // exit the loop because of a ctx.Err() and - // report it. - ctxErr = ctx.Err() - break - } - g.Go(func() (err error) { - rc, err := t.Reader() - if err != nil { - return err - } - defer rc.Close() - tr := &trackingReader{rc, &n} - d, err := c.Import(tr, t.Size()) - if err != nil { - return err - } - if err := rc.Close(); err != nil { - return err - } - - layers[i] = &ollama.Layer{ - Digest: d, - Size: t.Size(), - MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{ - "name": t.Name(), - "dtype": t.DataType(), - "shape": t.Shape().String(), - }), - } - - return nil - }) - } - - done <- func() error { - if err := errors.Join(g.Wait(), ctxErr); err != nil { - return err - } - m := &ollama.Manifest{Layers: layers} - data, err := json.MarshalIndent(m, "", " ") - if err != nil { - return err - } - d := blob.DigestFromBytes(data) - err = blob.PutBytes(c, d, data) - if err != nil { - return err - } - return c.Link(as, d) - }() - }() - - fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir) - - csiHideCursor(stdout) - defer csiShowCursor(stdout) - - csiSavePos(stdout) - writeProgress := func() { - csiRestorePos(stdout) - nn := n.Load() - fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n", - formatNatural(nn), - formatNatural(total), - nn*100/total, - ansiClearToEndOfLine, - ) - } - - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - writeProgress() - case err := <-done: - writeProgress() - fmt.Println() - fmt.Println("Successfully imported", as) - return err - } - } -} - -func formatNatural(n int64) string { - switch { - case n < 1024: - return fmt.Sprintf("%d B", n) - case n < 1024*1024: - return fmt.Sprintf("%.1f KB", float64(n)/1024) - case n < 1024*1024*1024: - return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024)) - default: - return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024)) - } -} - -const ansiClearToEndOfLine = "\033[K" - -func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") } -func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") } -func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") } -func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }