Compare commits: v0.6.1 ... pdevine/lo
1 commit: 73a1e99f8a
@@ -187,13 +187,6 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
 
 Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
 
-For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
-
-```
-# Allow all Chrome, Firefox, and Safari extensions
-OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
-```
-
 Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
 
 ## Where are models stored?
@@ -583,52 +583,39 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 }
 
 func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
-	if llm.KV().Uint("vision.block_count") == 0 {
-		return
-	}
-
-	for name, layer := range llm.Tensors().GroupLayers() {
-		if name == "v" || strings.HasPrefix(name, "v.") {
-			for _, tensor := range layer {
-				weights += tensor.Size()
-			}
-		}
-	}
-
-	imageSize := uint64(llm.KV().Uint("vision.image_size"))
-	patchSize := uint64(llm.KV().Uint("vision.patch_size"))
-	if patchSize == 0 {
-		slog.Warn("unknown patch size for vision model")
-		return
-	}
-
-	numChannels := uint64(llm.KV().Uint("vision.num_channels"))
-
-	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
-	if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
-		numPatches++
-	}
-
-	headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
-	embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
-
 	switch llm.KV().Architecture() {
 	case "mllama":
+		for _, layer := range llm.Tensors().GroupLayers()["v"] {
+			weights += layer.Size()
+		}
+
+		kv := func(n string) uint64 {
+			if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
+				return uint64(v)
+			}
+
+			return 0
+		}
+
+		imageSize := kv("image_size")
+
+		maxNumTiles := kv("max_num_tiles")
+		embeddingLength := kv("embedding_length")
+		headCount := kv("attention.head_count")
+
+		numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
+		if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
+			numPatches++
+		}
+
 		numPaddedPatches := numPatches + 8 - (numPatches%8)%8
 
-		maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
-
 		graphSize = 4 * (8 +
-			imageSize*imageSize*numChannels*maxNumTiles +
+			imageSize*imageSize*kv("num_channels")*maxNumTiles +
 			embeddingLength*numPatches*maxNumTiles +
 			9*embeddingLength*numPaddedPatches*maxNumTiles +
 			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
-	case "gemma3":
-		graphSize = 4 * (imageSize*imageSize*numChannels +
-			embeddingLength*patchSize +
-			numPatches*numPatches*headCount)
 	}
 
 	return weights, graphSize
 }
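For a sense of scale, the mllama branch above is plain arithmetic over a handful of vision KV fields. The following standalone sketch (not part of the diff; the metadata values are illustrative, roughly what a Llama 3.2 Vision style model reports) evaluates the same formula:

```go
package main

import "fmt"

func main() {
	// Illustrative vision metadata; in the real code these come from llm.KV().
	var (
		imageSize       uint64 = 560
		patchSize       uint64 = 14
		numChannels     uint64 = 3
		maxNumTiles     uint64 = 4
		embeddingLength uint64 = 1280
		headCount       uint64 = 16
	)

	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
	numPatches++ // assume a class embedding ("class_embd") tensor is present

	// Same padding expression as the diff.
	numPaddedPatches := numPatches + 8 - (numPatches%8)%8

	graphSize := 4 * (8 +
		imageSize*imageSize*numChannels*maxNumTiles +
		embeddingLength*numPatches*maxNumTiles +
		9*embeddingLength*numPaddedPatches*maxNumTiles +
		numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)

	fmt.Printf("estimated vision graph size: %.2f GiB\n", float64(graphSize)/(1<<30))
}
```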
@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
 			layerSize = blk.Size()
 			layerSize += kv / f.KV().BlockCount()
-			memoryWeights += blk.Size()
 		}
+		memoryWeights += layerSize
 
 		if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
 			// Stop allocating on GPU(s) once we hit the users target NumGPU
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
 			// memory of the weights
 			"total", format.HumanBytes2(m.memoryWeights),
 			// memory of repeating layers
-			"repeating", format.HumanBytes2(m.memoryWeights),
+			"repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
 			// memory of non-repeating layers
 			"nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
 		),
logging/log.go  (new file; 40 lines)
@@ -0,0 +1,40 @@
+package logging
+
+import (
+	"context"
+	"log/slog"
+	"os"
+)
+
+const LevelTrace slog.Level = slog.LevelDebug - 4
+
+type Logger struct {
+	logger *slog.Logger
+}
+
+func NewLogger() *Logger {
+	handler := slog.NewTextHandler(os.Stdout, nil)
+	return &Logger{
+		logger: slog.New(handler),
+	}
+}
+
+func (l *Logger) Trace(msg string, args ...any) {
+	l.logger.Log(context.Background(), LevelTrace, msg, args...)
+}
+
+func (l *Logger) Debug(msg string, args ...any) {
+	l.logger.Debug(msg, args...)
+}
+
+func (l *Logger) Info(msg string, args ...any) {
+	l.logger.Info(msg, args...)
+}
+
+func (l *Logger) Warn(msg string, args ...any) {
+	l.logger.Warn(msg, args...)
+}
+
+func (l *Logger) Error(msg string, args ...any) {
+	l.logger.Error(msg, args...)
+}
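One practical note on the handler above: `slog.NewTextHandler(os.Stdout, nil)` defaults to `slog.LevelInfo`, so Trace and Debug records from this Logger are dropped unless a handler with a lower level is used. A minimal standalone sketch (not part of the diff) of how a caller could surface trace output:

```go
package main

import (
	"context"
	"log/slog"
	"os"
)

// Mirrors logging.LevelTrace from the diff: four levels below Debug.
const LevelTrace slog.Level = slog.LevelDebug - 4

func main() {
	h := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: LevelTrace, // let trace-level records through
	})
	logger := slog.New(h)

	logger.Log(context.Background(), LevelTrace, "fragments", "frags", []string{"▁hello"})
	logger.Debug("token counts", "normal", 32000)
}
```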
@@ -2,15 +2,18 @@ package model
 
 import (
 	"iter"
-	"log/slog"
 	"strings"
 
 	"github.com/dlclark/regexp2"
 	queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
+
+	"github.com/ollama/ollama/logging"
 )
 
 const spmWhitespaceSep = "▁"
 
+var log = logging.NewLogger()
+
 func replaceWhitespaceBySeperator(s string) string {
 	return strings.ReplaceAll(s, " ", spmWhitespaceSep)
 }
@@ -24,7 +27,7 @@ type SentencePieceModel struct {
 var _ TextProcessor = (*SentencePieceModel)(nil)
 
 func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
-	slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
+	log.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
 
 	counter := map[int]int{}
 	var maxTokenLen int
@@ -38,7 +41,7 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
 		}
 	}
 
-	slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
+	log.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
 		"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
 		"max token len", maxTokenLen)
 
@@ -91,7 +94,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
 		}
 	}
-	slog.Debug("fragments", "frags", fragments)
+	log.Trace("fragments", "frags", fragments)
 
 	var ids []int32
 	for _, frag := range fragments {
@@ -129,7 +132,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			}
 		}
 
-		slog.Debug("tokenizer", "merges", merges)
+		log.Trace("tokenizer", "merges", merges)
 
 		pairwise := func(a, b int) *candidate {
 			if a < 0 || b >= len(runes) {
@@ -156,7 +159,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 		pqv := pq.Values()
 		for _, v := range pqv {
 			e := v.(*candidate)
-			slog.Debug("candidate", "candidate", e)
+			log.Trace("candidate", "candidate", e)
 		}
 
 		for !pq.Empty() {
@@ -164,7 +167,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			pair := v.(*candidate)
 			left, right := merges[pair.a], merges[pair.b]
 
-			slog.Debug("pair", "left", left, "right", right)
+			log.Trace("pair", "left", left, "right", right)
 			if len(left.runes) == 0 || len(right.runes) == 0 {
 				continue
 			}
@@ -189,14 +192,14 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			}
 		}
 
-		slog.Debug("merges", "merges", merges)
+		log.Trace("merges", "merges", merges)
 
 		for _, merge := range merges {
 			if len(merge.runes) > 0 {
 				if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
 					ids = append(ids, id)
 				} else {
-					slog.Debug("missing token", "token", string(merge.runes))
+					log.Error("missing token", "token", string(merge.runes))
 				}
 			}
 		}
@@ -206,19 +209,19 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 	if addSpecial && len(ids) > 0 {
 		if spm.vocab.AddBOS {
 			if ids[0] == spm.vocab.BOS {
-				slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
+				log.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
 			}
 
-			slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
+			log.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
 			ids = append([]int32{spm.vocab.BOS}, ids...)
 		}
 
 		if spm.vocab.AddEOS {
 			if ids[len(ids)-1] == spm.vocab.EOS {
-				slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
+				log.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
 			}
 
-			slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
+			log.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
 			ids = append(ids, spm.vocab.EOS)
 		}
 	}
@@ -241,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
 		}
 	}
 
-	slog.Debug("decoded", "ids", ids, "text", sb.String())
+	log.Debug("decoded", "ids", ids, "text", sb.String())
 	return sb.String(), nil
 }
server/internal/cache/blob/cache.go  (vendored; 54 changed lines)
@@ -146,7 +146,7 @@ func debugger(err *error) func(step string) {
 // be in either of the following forms:
 //
 //	@<digest>
-//	<name>@<digest>
+//	<name>
 //	<name>
 //
 // If a digest is provided, it is returned as is and nothing else happens.
@@ -160,6 +160,8 @@ func debugger(err *error) func(step string) {
 // hashed is passed to a PutBytes call to ensure that the manifest is in the
 // blob store. This is done to ensure that future calls to [Get] succeed in
 // these cases.
+//
+// TODO(bmizerany): Move Links/Resolve/etc. out of this package.
 func (c *DiskCache) Resolve(name string) (Digest, error) {
 	name, digest := splitNameDigest(name)
 	if digest != "" {
@@ -277,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
 // It returns an error if either the name or digest is invalid, or if link
 // creation encounters any issues.
 func (c *DiskCache) Link(name string, d Digest) error {
+	// TODO(bmizerany): Move link handling from cache to registry.
+	//
+	// We originally placed links in the cache due to its storage
+	// knowledge. However, the registry likely offers better context for
+	// naming concerns, and our API design shouldn't be tightly coupled to
+	// our on-disk format.
+	//
+	// Links work effectively when independent from physical location -
+	// they can reference content with matching SHA regardless of storage
+	// location. In an upcoming change, we plan to shift this
+	// responsibility to the registry where it better aligns with the
+	// system's conceptual model.
 	manifest, err := c.manifestPath(name)
 	if err != nil {
 		return err
@@ -327,9 +341,7 @@ func (c *DiskCache) GetFile(d Digest) string {
 	return absJoin(c.dir, "blobs", filename)
 }
 
-// Links returns a sequence of link names. The sequence is in lexical order.
-// Names are converted from their relative path form to their name form but are
-// not guaranteed to be valid. Callers should validate the names before using.
+// Links returns a sequence of links in the cache in lexical order.
 func (c *DiskCache) Links() iter.Seq2[string, error] {
 	return func(yield func(string, error) bool) {
 		for path, err := range c.links() {
@@ -402,14 +414,12 @@ func (c *DiskCache) links() iter.Seq2[string, error] {
 }
 
 type checkWriter struct {
-	size int64
 	d    Digest
-	f    *os.File
+	size int64
+	n    int64
 	h    hash.Hash
-	w    io.Writer // underlying writer; set by creator
-	n    int64
-	err  error
+	f    *os.File
+	err  error
 
 	testHookBeforeFinalWrite func(*os.File)
 }
@@ -425,10 +435,6 @@ func (w *checkWriter) seterr(err error) error {
 // underlying writer is guaranteed to be the last byte of p as verified by the
 // hash.
 func (w *checkWriter) Write(p []byte) (int, error) {
-	if w.err != nil {
-		return 0, w.err
-	}
-
 	_, err := w.h.Write(p)
 	if err != nil {
 		return 0, w.seterr(err)
@@ -447,7 +453,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
 	if nextSize > w.size {
 		return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", nextSize, w.size))
 	}
-	n, err := w.w.Write(p)
+	n, err := w.f.Write(p)
 	w.n += int64(n)
 	return n, w.seterr(err)
 }
@@ -487,12 +493,10 @@ func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size
 
 	// Copy file to f, but also into h to double-check hash.
 	cw := &checkWriter{
 		d:    out,
 		size: size,
 		h:    sha256.New(),
 		f:    f,
-		w:    f,
-
 		testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
 	}
 	n, err := io.Copy(cw, file)
@@ -528,6 +532,11 @@ func splitNameDigest(s string) (name, digest string) {
 var errInvalidName = errors.New("invalid name")
 
 func nameToPath(name string) (_ string, err error) {
+	if strings.Contains(name, "@") {
+		// TODO(bmizerany): HACK: Fix names.Parse to validate.
+		// TODO(bmizerany): merge with default parts (maybe names.Merge(a, b))
+		return "", errInvalidName
+	}
 	n := names.Parse(name)
 	if !n.IsFullyQualified() {
 		return "", errInvalidName
@@ -538,7 +547,8 @@ func nameToPath(name string) (_ string, err error) {
 func absJoin(pp ...string) string {
 	abs, err := filepath.Abs(filepath.Join(pp...))
 	if err != nil {
-		panic(err) // this should never happen
+		// Likely a bug bug or a bad OS problem. Just panic.
+		panic(err)
 	}
 	return abs
 }
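The checkWriter's job, as the copyNamedFile comment puts it, is to copy into the file "but also into h to double-check hash". A standalone sketch of that verify-while-writing idea (simplified; no size limit or test hooks, and the wrapper type here is illustrative rather than the vendored one):

```go
package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"hash"
	"io"
	"strings"
)

// verifyingWriter tees every write into a hash, the way checkWriter does.
type verifyingWriter struct {
	w io.Writer
	h hash.Hash
}

func (vw *verifyingWriter) Write(p []byte) (int, error) {
	if _, err := vw.h.Write(p); err != nil {
		return 0, err
	}
	return vw.w.Write(p)
}

func main() {
	content := "hello, blob"
	want := sha256.Sum256([]byte(content))

	var dst bytes.Buffer
	vw := &verifyingWriter{w: &dst, h: sha256.New()}
	if _, err := io.Copy(vw, strings.NewReader(content)); err != nil {
		panic(err)
	}

	fmt.Println("digest matches:", bytes.Equal(vw.h.Sum(nil), want[:])) // digest matches: true
}
```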
server/internal/cache/blob/chunked.go  (vendored; 73 changed lines)
@@ -1,73 +0,0 @@
-package blob
-
-import (
-	"crypto/sha256"
-	"errors"
-	"io"
-	"os"
-)
-
-// Chunk represents a range of bytes in a blob.
-type Chunk struct {
-	Start int64
-	End   int64
-}
-
-// Size returns end minus start plus one.
-func (c Chunk) Size() int64 {
-	return c.End - c.Start + 1
-}
-
-// Chunker writes to a blob in chunks.
-// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
-type Chunker struct {
-	digest Digest
-	size   int64
-	f      *os.File // nil means pre-validated
-}
-
-// Chunked returns a new Chunker, ready for use storing a blob of the given
-// size in chunks.
-//
-// Use [Chunker.Put] to write data to the blob at specific offsets.
-func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
-	name := c.GetFile(d)
-	info, err := os.Stat(name)
-	if err == nil && info.Size() == size {
-		return &Chunker{}, nil
-	}
-	f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
-	if err != nil {
-		return nil, err
-	}
-	return &Chunker{digest: d, size: size, f: f}, nil
-}
-
-// Put copies chunk.Size() bytes from r to the blob at the given offset,
-// merging the data with the existing blob. It returns an error if any. As a
-// special case, if r has less than chunk.Size() bytes, Put returns
-// io.ErrUnexpectedEOF.
-func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
-	if c.f == nil {
-		return nil
-	}
-
-	cw := &checkWriter{
-		d:    d,
-		size: chunk.Size(),
-		h:    sha256.New(),
-		f:    c.f,
-		w:    io.NewOffsetWriter(c.f, chunk.Start),
-	}
-
-	_, err := io.CopyN(cw, r, chunk.Size())
-	if err != nil && errors.Is(err, io.EOF) {
-		return io.ErrUnexpectedEOF
-	}
-	return err
-}
-
-// Close closes the underlying file.
-func (c *Chunker) Close() error {
-	return c.f.Close()
-}
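The core trick in the removed Chunker.Put is writing each chunk into the blob file at its own offset via io.NewOffsetWriter, so chunks can arrive in any order and still land in the right place. A standalone sketch of just that offset-write pattern (file name and data are made up for illustration):

```go
package main

import (
	"fmt"
	"io"
	"os"
	"strings"
)

func main() {
	f, err := os.CreateTemp("", "blob-*")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	defer f.Close()

	// Two "chunks" arriving out of order; each is written at its own offset,
	// which is what Chunker.Put did through io.NewOffsetWriter.
	parts := []struct {
		start int64
		data  string
	}{
		{6, "world!"},
		{0, "hello "},
	}
	for _, p := range parts {
		w := io.NewOffsetWriter(f, p.start)
		if _, err := io.Copy(w, strings.NewReader(p.data)); err != nil {
			panic(err)
		}
	}

	got, err := os.ReadFile(f.Name())
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", got) // hello world!
}
```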
server/internal/cache/blob/digest.go  (vendored; 4 changed lines)
@@ -63,10 +63,6 @@ func (d Digest) Short() string {
 	return fmt.Sprintf("%x", d.sum[:4])
 }
 
-func (d Digest) Sum() [32]byte {
-	return d.sum
-}
-
 func (d Digest) Compare(other Digest) int {
 	return slices.Compare(d.sum[:], other.sum[:])
 }
server/internal/chunks/chunks.go  (new file; 78 lines)
@@ -0,0 +1,78 @@
+package chunks
+
+import (
+	"fmt"
+	"iter"
+	"strconv"
+	"strings"
+)
+
+type Chunk struct {
+	Start, End int64
+}
+
+func New(start, end int64) Chunk {
+	return Chunk{start, end}
+}
+
+// ParseRange parses a string in the form "unit=range" where unit is a string
+// and range is a string in the form "start-end". It returns the unit and the
+// range as a Chunk.
+func ParseRange(s string) (unit string, _ Chunk, _ error) {
+	unit, r, _ := strings.Cut(s, "=")
+	if r == "" {
+		return unit, Chunk{}, nil
+	}
+	c, err := Parse(r)
+	if err != nil {
+		return "", Chunk{}, err
+	}
+	return unit, c, err
+}
+
+// Parse parses a string in the form "start-end" and returns the Chunk.
+func Parse(s string) (Chunk, error) {
+	startStr, endStr, _ := strings.Cut(s, "-")
+	start, err := strconv.ParseInt(startStr, 10, 64)
+	if err != nil {
+		return Chunk{}, fmt.Errorf("invalid start: %v", err)
+	}
+	end, err := strconv.ParseInt(endStr, 10, 64)
+	if err != nil {
+		return Chunk{}, fmt.Errorf("invalid end: %v", err)
+	}
+	if start > end {
+		return Chunk{}, fmt.Errorf("invalid range %d-%d: start > end", start, end)
+	}
+	return Chunk{start, end}, nil
+}
+
+// Of returns a sequence of contiguous Chunks of size chunkSize that cover
+// the range [0, size), in order.
+func Of(size, chunkSize int64) iter.Seq[Chunk] {
+	return func(yield func(Chunk) bool) {
+		for start := int64(0); start < size; start += chunkSize {
+			end := min(start+chunkSize-1, size-1)
+			if !yield(Chunk{start, end}) {
+				break
+			}
+		}
+	}
+}
+
+// Count returns the number of Chunks of size chunkSize needed to cover the
+// range [0, size).
+func Count(size, chunkSize int64) int64 {
+	return (size + chunkSize - 1) / chunkSize
+}
+
+// Size returns end minus start plus one.
+func (c Chunk) Size() int64 {
+	return c.End - c.Start + 1
+}
+
+// String returns the string representation of the Chunk in the form
+// "{start}-{end}".
+func (c Chunk) String() string {
+	return fmt.Sprintf("%d-%d", c.Start, c.End)
+}
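This package is what the reworked Pull path below uses to derive HTTP Range headers; Chunk.String deliberately matches the `start-end` wire form. A small standalone usage sketch (the import is the internal path from the diff, shown only for illustration):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/server/internal/chunks"
)

func main() {
	const size, chunkSize = 100 << 20, 8 << 20 // 100 MiB layer, 8 MiB chunks

	fmt.Println("chunks needed:", chunks.Count(size, chunkSize)) // 13

	for chunk := range chunks.Of(size, chunkSize) {
		// Same formatting the Pull loop uses for its Range header.
		header := fmt.Sprintf("bytes=%s", chunk)

		if chunk.Start == 0 {
			fmt.Println("first range:", header, "covers", chunk.Size(), "bytes")
		}
	}
}
```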
server/internal/chunks/chunks_test.go  (new file; 65 lines)
@@ -0,0 +1,65 @@
+package chunks
+
+import (
+	"slices"
+	"testing"
+)
+
+func TestOf(t *testing.T) {
+	cases := []struct {
+		total     int64
+		chunkSize int64
+		want      []Chunk
+	}{
+		{0, 1, nil},
+		{1, 1, []Chunk{{0, 0}}},
+		{1, 2, []Chunk{{0, 0}}},
+		{2, 1, []Chunk{{0, 0}, {1, 1}}},
+		{10, 9, []Chunk{{0, 8}, {9, 9}}},
+	}
+
+	for _, tt := range cases {
+		got := slices.Collect(Of(tt.total, tt.chunkSize))
+		if !slices.Equal(got, tt.want) {
+			t.Errorf("[%d/%d]: got %v; want %v", tt.total, tt.chunkSize, got, tt.want)
+		}
+	}
+}
+
+func TestSize(t *testing.T) {
+	cases := []struct {
+		c    Chunk
+		want int64
+	}{
+		{Chunk{0, 0}, 1},
+		{Chunk{0, 1}, 2},
+		{Chunk{3, 4}, 2},
+	}
+
+	for _, tt := range cases {
+		got := tt.c.Size()
+		if got != tt.want {
+			t.Errorf("%v: got %d; want %d", tt.c, got, tt.want)
+		}
+	}
+}
+
+func TestCount(t *testing.T) {
+	cases := []struct {
+		total     int64
+		chunkSize int64
+		want      int64
+	}{
+		{0, 1, 0},
+		{1, 1, 1},
+		{1, 2, 1},
+		{2, 1, 2},
+		{10, 9, 2},
+	}
+	for _, tt := range cases {
+		got := Count(tt.total, tt.chunkSize)
+		if got != tt.want {
+			t.Errorf("[%d/%d]: got %d; want %d", tt.total, tt.chunkSize, got, tt.want)
+		}
+	}
+}
@@ -19,7 +19,6 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
-	"iter"
 	"log/slog"
 	"net/http"
 	"os"
@@ -36,8 +35,10 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
+	"github.com/ollama/ollama/server/internal/chunks"
 	"github.com/ollama/ollama/server/internal/internal/backoff"
 	"github.com/ollama/ollama/server/internal/internal/names"
+	"github.com/ollama/ollama/server/internal/internal/syncs"
 
 	_ "embed"
 )
@@ -65,7 +66,12 @@ var (
 const (
 	// DefaultChunkingThreshold is the threshold at which a layer should be
 	// split up into chunks when downloading.
-	DefaultChunkingThreshold = 64 << 20
+	DefaultChunkingThreshold = 128 << 20
+
+	// DefaultMaxChunkSize is the default maximum size of a chunk to
+	// download. It is configured based on benchmarks and aims to strike a
+	// balance between download speed and memory usage.
+	DefaultMaxChunkSize = 8 << 20
 )
 
 var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
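To make the right-hand constants concrete: a layer is only chunked once it exceeds 128 MiB, and each chunk then tops out at 8 MiB. A quick standalone sketch of the resulting request counts (layer sizes are illustrative):

```go
package main

import "fmt"

func main() {
	const (
		chunkingThreshold = 128 << 20 // DefaultChunkingThreshold (new value)
		maxChunkSize      = 8 << 20   // DefaultMaxChunkSize
	)

	for _, layerSize := range []int64{64 << 20, 512 << 20, 4 << 30} {
		if layerSize <= chunkingThreshold {
			fmt.Printf("%5d MiB layer: single request\n", layerSize>>20)
			continue
		}
		n := (layerSize + maxChunkSize - 1) / maxChunkSize // same math as chunks.Count
		fmt.Printf("%5d MiB layer: %d range requests of up to 8 MiB\n", layerSize>>20, n)
	}
}
```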
@@ -205,7 +211,8 @@ type Registry struct {
 	// pushing or pulling models. If zero, the number of streams is
 	// determined by [runtime.GOMAXPROCS].
 	//
-	// A negative value means no limit.
+	// Clients that want "unlimited" streams should set this to a large
+	// number.
 	MaxStreams int
 
 	// ChunkingThreshold is the maximum size of a layer to download in a single
@@ -275,13 +282,24 @@ func DefaultRegistry() (*Registry, error) {
 }
 
 func (r *Registry) maxStreams() int {
-	return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
+	n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
+
+	// Large downloads require a writter stream, so ensure we have at least
+	// two streams to avoid a deadlock.
+	return max(n, 2)
 }
 
 func (r *Registry) maxChunkingThreshold() int64 {
 	return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
 }
 
+// chunkSizeFor returns the chunk size for a layer of the given size. If the
+// size is less than or equal to the max chunking threshold, the size is
+// returned; otherwise, the max chunk size is returned.
+func (r *Registry) maxChunkSize() int64 {
+	return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
+}
+
 type PushParams struct {
 	// From is an optional destination name for the model. If empty, the
 	// destination name is the same as the source name.
@@ -408,21 +426,6 @@ func canRetry(err error) bool {
 	return re.Status >= 500
 }
 
-// trackingReader is an io.Reader that tracks the number of bytes read and
-// calls the update function with the layer, the number of bytes read.
-//
-// It always calls update with a nil error.
-type trackingReader struct {
-	r io.Reader
-	n *atomic.Int64
-}
-
-func (r *trackingReader) Read(p []byte) (n int, err error) {
-	n, err = r.r.Read(p)
-	r.n.Add(int64(n))
-	return
-}
-
 // Pull pulls the model with the given name from the remote registry into the
 // cache.
 //
@@ -431,6 +434,11 @@ func (r *trackingReader) Read(p []byte) (n int, err error) {
 // typically slower than splitting the model up across layers, and is mostly
 // utilized for layers of type equal to "application/vnd.ollama.image".
 func (r *Registry) Pull(ctx context.Context, name string) error {
+	scheme, n, _, err := r.parseNameExtended(name)
+	if err != nil {
+		return err
+	}
+
 	m, err := r.Resolve(ctx, name)
 	if err != nil {
 		return err
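The removed trackingReader is a two-field wrapper: it forwards Read to the underlying body and adds whatever it got to a shared atomic counter, which is how the old per-chunk progress numbers were accumulated. A standalone sketch of the same pattern (the type is copied from the removed code; the surrounding usage is illustrative):

```go
package main

import (
	"fmt"
	"io"
	"strings"
	"sync/atomic"
)

// Same shape as the removed type: count bytes as they flow through Read.
type trackingReader struct {
	r io.Reader
	n *atomic.Int64
}

func (r *trackingReader) Read(p []byte) (n int, err error) {
	n, err = r.r.Read(p)
	r.n.Add(int64(n))
	return
}

func main() {
	var progress atomic.Int64
	body := &trackingReader{r: strings.NewReader("remote"), n: &progress}

	_, _ = io.Copy(io.Discard, body)
	fmt.Println("bytes read:", progress.Load()) // bytes read: 6
}
```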
@@ -449,95 +457,126 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
 		return err == nil && info.Size == l.Size
 	}
 
+	t := traceFromContext(ctx)
+
+	g, ctx := errgroup.WithContext(ctx)
+	g.SetLimit(r.maxStreams())
+
 	layers := m.Layers
 	if m.Config != nil && m.Config.Digest.IsValid() {
 		layers = append(layers, m.Config)
 	}
 
-	// Send initial layer trace events to allow clients to have an
-	// understanding of work to be done before work starts.
-	t := traceFromContext(ctx)
-	skip := make([]bool, len(layers))
-	for i, l := range layers {
-		t.update(l, 0, nil)
+	for _, l := range layers {
 		if exists(l) {
-			skip[i] = true
 			t.update(l, l.Size, ErrCached)
-		}
-	}
-
-	g, ctx := errgroup.WithContext(ctx)
-	g.SetLimit(r.maxStreams())
-	for i, l := range layers {
-		if skip[i] {
 			continue
 		}
 
-		chunked, err := c.Chunked(l.Digest, l.Size)
+		blobURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), l.Digest)
+		req, err := r.newRequest(ctx, "GET", blobURL, nil)
 		if err != nil {
 			t.update(l, 0, err)
 			continue
 		}
-		defer chunked.Close()
 
-		var progress atomic.Int64
-		for cs, err := range r.chunksums(ctx, name, l) {
-			if err != nil {
-				t.update(l, progress.Load(), err)
-				break
-			}
+		t.update(l, 0, nil)
+		if l.Size <= r.maxChunkingThreshold() {
+			g.Go(func() error {
+				// TODO(bmizerany): retry/backoff like below in
+				// the chunking case
+				res, err := sendRequest(r.client(), req)
+				if err != nil {
+					return err
+				}
+				defer res.Body.Close()
+				err = c.Put(l.Digest, res.Body, l.Size)
+				if err == nil {
+					t.update(l, l.Size, nil)
+				}
+				return err
+			})
+		} else {
+			q := syncs.NewRelayReader()
 
 			g.Go(func() (err error) {
-				defer func() { t.update(l, progress.Load(), err) }()
-
-				for _, err := range backoff.Loop(ctx, 3*time.Second) {
-					if err != nil {
-						return err
-					}
-					err := func() error {
-						req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
-						if err != nil {
-							return err
-						}
-						req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
-						res, err := sendRequest(r.client(), req)
-						if err != nil {
-							return err
-						}
-						defer res.Body.Close()
-
-						// Count bytes towards
-						// progress, as they arrive, so
-						// that our bytes piggyback
-						// other chunk updates on
-						// completion.
-						//
-						// This tactic is enough to
-						// show "smooth" progress given
-						// the current CLI client. In
-						// the near future, the server
-						// should report download rate
-						// since it knows better than
-						// a client that is measuring
-						// rate based on wall-clock
-						// time-since-last-update.
-						body := &trackingReader{r: res.Body, n: &progress}
-
-						err = chunked.Put(cs.Chunk, cs.Digest, body)
-						if err != nil {
-							return err
-						}
-
-						return nil
-					}()
-					if !canRetry(err) {
-						return err
-					}
-				}
-				return nil
+				defer func() { q.CloseWithError(err) }()
+				return c.Put(l.Digest, q, l.Size)
 			})
+
+			var progress atomic.Int64
+
+			// We want to avoid extra round trips per chunk due to
+			// redirects from the registry to the blob store, so
+			// fire an initial request to get the final URL and
+			// then use that URL for the chunk requests.
+			req.Header.Set("Range", "bytes=0-0")
+			res, err := sendRequest(r.client(), req)
+			if err != nil {
+				return err
+			}
+			res.Body.Close()
+			req = res.Request.WithContext(req.Context())
+
+			wp := writerPool{size: r.maxChunkSize()}
+
+			for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
+				if ctx.Err() != nil {
+					break
+				}
+
+				ticket := q.Take()
+				g.Go(func() (err error) {
+					defer func() {
+						if err != nil {
+							q.CloseWithError(err)
+						}
+						ticket.Close()
+						t.update(l, progress.Load(), err)
+					}()
+
+					for _, err := range backoff.Loop(ctx, 3*time.Second) {
+						if err != nil {
+							return err
+						}
+						err := func() error {
+							req := req.Clone(req.Context())
+							req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
+							res, err := sendRequest(r.client(), req)
+							if err != nil {
+								return err
+							}
+							defer res.Body.Close()
+
+							tw := wp.get()
+							tw.Reset(ticket)
+							defer wp.put(tw)
+
+							_, err = io.CopyN(tw, res.Body, chunk.Size())
+							if err != nil {
+								return maybeUnexpectedEOF(err)
+							}
+							if err := tw.Flush(); err != nil {
+								return err
+							}
+
+							total := progress.Add(chunk.Size())
+							if total >= l.Size {
+								q.Close()
+							}
+							return nil
+						}()
+						if !canRetry(err) {
+							return err
+						}
+					}
+					return nil
+				})
+			}
 		}
 	}
 
 	if err := g.Wait(); err != nil {
 		return err
 	}
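The "bytes=0-0" probe in the added code is a redirect-priming trick: ask for a single byte, let the client follow any registry-to-blob-store redirects, then clone the final request for the real range reads so each chunk does not pay the redirect round trip again. A standalone, self-contained sketch of that idea (servers and paths are illustrative):

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

func main() {
	// A blob store and a registry that redirects blob requests to it,
	// mimicking the production setup described in the diff comments.
	blobStore := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "r") // pretend to serve the requested byte range
	}))
	defer blobStore.Close()

	registry := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, blobStore.URL+r.URL.Path, http.StatusFound)
	}))
	defer registry.Close()

	req, _ := http.NewRequest("GET", registry.URL+"/v2/x/x/blobs/sha256-abc", nil)
	req.Header.Set("Range", "bytes=0-0") // cheap probe; the client follows the redirect

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	res.Body.Close()

	// res.Request is the request after redirects, so later range requests
	// can be cloned from it and go straight to the blob store.
	final := res.Request.Clone(req.Context())
	final.Header.Set("Range", "bytes=0-8388607")
	fmt.Println("subsequent chunks go straight to:", final.URL.Host)
}
```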
@@ -576,6 +615,8 @@ type Manifest struct {
 	Config *Layer `json:"config"`
 }
 
+var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
+
 // Layer returns the layer with the given
 // digest, or nil if not found.
 func (m *Manifest) Layer(d blob.Digest) *Layer {
@@ -602,9 +643,10 @@ func (m Manifest) MarshalJSON() ([]byte, error) {
 		// last phase of the commit which expects it, but does nothing
 		// with it. This will be fixed in a future release of
 		// ollama.com.
-		Config Layer `json:"config"`
+		Config *Layer `json:"config"`
 	}{
 		M: M(m),
+		Config: &Layer{Digest: emptyDigest},
 	}
 	return json.Marshal(v)
 }
@@ -694,123 +736,6 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
 	return m, nil
 }
 
-type chunksum struct {
-	URL    string
-	Chunk  blob.Chunk
-	Digest blob.Digest
-}
-
-// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
-// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
-// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
-func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
-	return func(yield func(chunksum, error) bool) {
-		scheme, n, _, err := r.parseNameExtended(name)
-		if err != nil {
-			yield(chunksum{}, err)
-			return
-		}
-
-		if l.Size < r.maxChunkingThreshold() {
-			// any layer under the threshold should be downloaded
-			// in one go.
-			cs := chunksum{
-				URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
-					scheme,
-					n.Host(),
-					n.Namespace(),
-					n.Model(),
-					l.Digest,
-				),
-				Chunk:  blob.Chunk{Start: 0, End: l.Size - 1},
-				Digest: l.Digest,
-			}
-			yield(cs, nil)
-			return
-		}
-
-		// A chunksums response is a sequence of chunksums in a
-		// simple, easy to parse line-oriented format.
-		//
-		// Example:
-		//
-		//	>> GET /v2/<namespace>/<model>/chunksums/<digest>
-		//
-		//	<< HTTP/1.1 200 OK
-		//	<< Content-Location: <blobURL>
-		//	<<
-		//	<< <digest> <start>-<end>
-		//	<< ...
-		//
-		// The blobURL is the URL to download the chunks from.
-
-		chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
-			scheme,
-			n.Host(),
-			n.Namespace(),
-			n.Model(),
-			l.Digest,
-		)
-
-		req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
-		if err != nil {
-			yield(chunksum{}, err)
-			return
-		}
-		res, err := sendRequest(r.client(), req)
-		if err != nil {
-			yield(chunksum{}, err)
-			return
-		}
-		defer res.Body.Close()
-		if res.StatusCode != 200 {
-			err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
-			yield(chunksum{}, err)
-			return
-		}
-		blobURL := res.Header.Get("Content-Location")
-
-		s := bufio.NewScanner(res.Body)
-		s.Split(bufio.ScanWords)
-		for {
-			if !s.Scan() {
-				if s.Err() != nil {
-					yield(chunksum{}, s.Err())
-				}
-				return
-			}
-			d, err := blob.ParseDigest(s.Bytes())
-			if err != nil {
-				yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
-				return
-			}
-
-			if !s.Scan() {
-				err := s.Err()
-				if err == nil {
-					err = fmt.Errorf("missing chunk range for digest %s", d)
-				}
-				yield(chunksum{}, err)
-				return
-			}
-			chunk, err := parseChunk(s.Bytes())
-			if err != nil {
-				yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
-				return
-			}
-
-			cs := chunksum{
-				URL:    blobURL,
-				Chunk:  chunk,
-				Digest: d,
-			}
-			if !yield(cs, nil) {
-				return
-			}
-		}
-	}
-}
-
 func (r *Registry) client() *http.Client {
 	if r.HTTPClient != nil {
 		return r.HTTPClient
@@ -973,6 +898,13 @@ func checkData(url string) string {
 	return fmt.Sprintf("GET,%s,%s", url, zeroSum)
 }
 
+func maybeUnexpectedEOF(err error) error {
+	if errors.Is(err, io.EOF) {
+		return io.ErrUnexpectedEOF
+	}
+	return err
+}
+
 type publicError struct {
 	wrapped error
 	message string
@@ -1059,22 +991,27 @@ func splitExtended(s string) (scheme, name, digest string) {
 	return scheme, s, digest
 }
 
-// parseChunk parses a string in the form "start-end" and returns the Chunk.
-func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
-	startPart, endPart, found := strings.Cut(string(s), "-")
-	if !found {
-		return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
-	}
-	start, err := strconv.ParseInt(startPart, 10, 64)
-	if err != nil {
-		return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
-	}
-	end, err := strconv.ParseInt(endPart, 10, 64)
-	if err != nil {
-		return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
-	}
-	if start > end {
-		return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
-	}
-	return blob.Chunk{Start: start, End: end}, nil
-}
+type writerPool struct {
+	size int64 // set by the caller
+
+	mu sync.Mutex
+	ws []*bufio.Writer
+}
+
+func (p *writerPool) get() *bufio.Writer {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if len(p.ws) == 0 {
+		return bufio.NewWriterSize(nil, int(p.size))
+	}
+	w := p.ws[len(p.ws)-1]
+	p.ws = p.ws[:len(p.ws)-1]
+	return w
+}
+
+func (p *writerPool) put(w *bufio.Writer) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	w.Reset(nil)
+	p.ws = append(p.ws, w)
+}
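The pool hands out bufio.Writers sized to the maximum chunk size. The apparent intent, judging from the Pull loop above (get, Reset to the relay ticket, CopyN, Flush, put), is that an entire chunk is buffered in memory while it downloads and only reaches the relay in one ordered write on Flush. A standalone sketch of just that buffering behavior (the bytes.Buffer stands in for the relay ticket):

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"strings"
)

func main() {
	const maxChunkSize = 8 << 20

	var ticket bytes.Buffer // stands in for the relay writer a chunk goroutine holds
	tw := bufio.NewWriterSize(nil, maxChunkSize)
	tw.Reset(&ticket) // same Reset/Flush dance the Pull loop performs

	chunkBody := strings.NewReader("one chunk of blob data") // stands in for res.Body
	if _, err := io.Copy(tw, chunkBody); err != nil {
		panic(err)
	}
	fmt.Println("buffered, not yet relayed:", ticket.Len() == 0) // true: nothing written until Flush

	if err := tw.Flush(); err != nil {
		panic(err)
	}
	fmt.Println("relayed after Flush:", ticket.String())
}
```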
@@ -21,6 +21,7 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
+	"github.com/ollama/ollama/server/internal/chunks"
 	"github.com/ollama/ollama/server/internal/testutil"
 )
 
@@ -427,7 +428,7 @@ func TestRegistryPullCached(t *testing.T) {
 	err := rc.Pull(ctx, "single")
 	testutil.Check(t, err)
 
-	want := []int64{0, 6}
+	want := []int64{6}
 	if !errors.Is(errors.Join(errs...), ErrCached) {
 		t.Errorf("errs = %v; want %v", errs, ErrCached)
 	}
@@ -530,6 +531,54 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
 	}
 }
 
+func TestRegistryPullChunking(t *testing.T) {
+	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+		t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
+		if r.URL.Host != "blob.store" {
+			// The production registry redirects to the blob store.
+			http.Redirect(w, r, "http://blob.store"+r.URL.Path, http.StatusFound)
+			return
+		}
+		if strings.Contains(r.URL.Path, "/blobs/") {
+			rng := r.Header.Get("Range")
+			if rng == "" {
+				http.Error(w, "missing range", http.StatusBadRequest)
+				return
+			}
+			_, c, err := chunks.ParseRange(r.Header.Get("Range"))
+			if err != nil {
+				panic(err)
+			}
+			io.WriteString(w, "remote"[c.Start:c.End+1])
+			return
+		}
+		fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, blob.DigestFromBytes("remote"))
+	})
+
+	// Force chunking by setting the threshold to less than the size of the
+	// layer.
+	rc.ChunkingThreshold = 3
+	rc.MaxChunkSize = 3
+
+	var reads []int64
+	ctx := WithTrace(t.Context(), &Trace{
+		Update: func(d *Layer, n int64, err error) {
+			if err != nil {
+				t.Errorf("update %v %d %v", d, n, err)
+			}
+			reads = append(reads, n)
+		},
+	})
+
+	err := rc.Pull(ctx, "remote")
+	testutil.Check(t, err)
+
+	want := []int64{0, 3, 6}
+	if !slices.Equal(reads, want) {
+		t.Errorf("reads = %v; want %v", reads, want)
+	}
+}
+
 func TestRegistryResolveByDigest(t *testing.T) {
 	check := testutil.Checker(t)
 
server/internal/cmd/oppbench/oppbench.go  (new file; 11 lines)
@@ -0,0 +1,11 @@
+package main
+
+import (
+	"fmt"
+	"os"
+)
+
+func main() {
+	fmt.Println("Run as 'go test -bench=.' to run the benchmarks")
+	os.Exit(1)
+}
server/internal/cmd/oppbench/oppbench_test.go  (new file; 107 lines)
@@ -0,0 +1,107 @@
+package main
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"runtime"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/server/internal/chunks"
+	"golang.org/x/sync/errgroup"
+)
+
+func BenchmarkDownload(b *testing.B) {
+	run := func(fileSize, chunkSize int64) {
+		name := fmt.Sprintf("size=%d/chunksize=%d", fileSize, chunkSize)
+		b.Run(name, func(b *testing.B) { benchmarkDownload(b, fileSize, chunkSize) })
+	}
+
+	run(100<<20, 8<<20)
+	run(100<<20, 16<<20)
+	run(100<<20, 32<<20)
+	run(100<<20, 64<<20)
+	run(100<<20, 128<<20) // 1 chunk
+}
+
+func run(ctx context.Context, c *http.Client, chunk chunks.Chunk) error {
+	const blobURL = "https://ollama.com/v2/x/x/blobs/sha256-4824460d29f2058aaf6e1118a63a7a197a09bed509f0e7d4e2efb1ee273b447d"
+	req, err := http.NewRequestWithContext(ctx, "GET", blobURL, nil)
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
+	res, err := c.Do(req)
+	if err != nil {
+		return err
+	}
+	defer res.Body.Close()
+
+	_, err = io.CopyN(io.Discard, res.Body, chunk.Size()) // will io.EOF on short read
+	return err
+}
+
+var sleepTime atomic.Int64
+
+func benchmarkDownload(b *testing.B, fileSize, chunkSize int64) {
+	client := &http.Client{
+		Transport: func() http.RoundTripper {
+			tr := http.DefaultTransport.(*http.Transport).Clone()
+			tr.DisableKeepAlives = true
+			return tr
+		}(),
+	}
+	defer client.CloseIdleConnections()
+
+	// warm up the client
+	run(context.Background(), client, chunks.New(0, 1<<20))
+
+	b.SetBytes(fileSize)
+	b.ReportAllocs()
+
+	// Give our CDN a min to breathe between benchmarks.
+	time.Sleep(time.Duration(sleepTime.Swap(3)))
+
+	for b.Loop() {
+		g, ctx := errgroup.WithContext(b.Context())
+		g.SetLimit(runtime.GOMAXPROCS(0))
+		for chunk := range chunks.Of(fileSize, chunkSize) {
+			g.Go(func() error { return run(ctx, client, chunk) })
+		}
+		if err := g.Wait(); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
+
+func BenchmarkWrite(b *testing.B) {
+	b.Run("chunksize=1MiB", func(b *testing.B) { benchmarkWrite(b, 1<<20) })
+}
+
+func benchmarkWrite(b *testing.B, chunkSize int) {
+	b.ReportAllocs()
+
+	dir := b.TempDir()
+	f, err := os.Create(filepath.Join(dir, "write-single"))
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer f.Close()
+
+	data := make([]byte, chunkSize)
+	b.SetBytes(int64(chunkSize))
+	r := bytes.NewReader(data)
+	for b.Loop() {
+		r.Reset(data)
+		_, err := io.Copy(f, r)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+}
@@ -1,5 +1,6 @@
-// Package registry implements an http.Handler for handling local Ollama API
-// model management requests. See [Local] for details.
+// Package registry provides an http.Handler for handling local Ollama API
+// requests for performing tasks related to the ollama.com model registry and
+// the local disk cache.
 package registry
 
 import (
@@ -9,7 +10,6 @@ import (
     "fmt"
     "io"
     "log/slog"
-    "maps"
     "net/http"
     "sync"
     "time"
@@ -18,11 +18,16 @@ import (
     "github.com/ollama/ollama/server/internal/client/ollama"
 )
 
-// Local implements an http.Handler for handling local Ollama API model
-// management requests, such as pushing, pulling, and deleting models.
+// Local is an http.Handler for handling local Ollama API requests for
+// performing tasks related to the ollama.com model registry combined with the
+// local disk cache.
 //
-// It can be arranged for all unknown requests to be passed through to a
-// fallback handler, if one is provided.
+// It is not concern of Local, or this package, to handle model creation, which
+// proceeds any registry operations for models it produces.
+//
+// NOTE: The package built for dealing with model creation should use
+// [DefaultCache] to access the blob store and not attempt to read or write
+// directly to the blob disk cache.
 type Local struct {
     Client *ollama.Registry // required
     Logger *slog.Logger     // required
@@ -58,7 +63,6 @@ func (e serverError) Error() string {
 var (
     errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
     errNotFound         = &serverError{404, "not_found", "not found"}
-    errModelNotFound    = &serverError{404, "not_found", "model not found"}
     errInternalError    = &serverError{500, "internal_error", "internal server error"}
 )
 
@@ -171,16 +175,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
 }
 
 type params struct {
-    // DeprecatedName is the name of the model to push, pull, or delete,
-    // but is deprecated. New clients should use [Model] instead.
-    //
-    // Use [model()] to get the model name for both old and new API requests.
-    DeprecatedName string `json:"name"`
-
-    // Model is the name of the model to push, pull, or delete.
-    //
-    // Use [model()] to get the model name for both old and new API requests.
-    Model string `json:"model"`
+    DeprecatedName string `json:"name"`  // Use [params.model]
+    Model          string `json:"model"` // Use [params.model]
 
     // AllowNonTLS is a flag that indicates a client using HTTP
     // is doing so, deliberately.
@@ -193,18 +189,9 @@ type params struct {
     // confusing flags such as this.
     AllowNonTLS bool `json:"insecure"`
 
-    // Stream, if true, will make the server send progress updates in a
-    // streaming of JSON objects. If false, the server will send a single
-    // JSON object with the final status as "success", or an error object
-    // if an error occurred.
-    //
-    // Unfortunately, this API was designed to be a bit awkward. Stream is
-    // defined to default to true if not present, so we need a way to check
-    // if the client decisively it to false. So, we use a pointer to a
-    // bool. Gross.
-    //
-    // Use [stream()] to get the correct value for this field.
-    Stream *bool `json:"stream"`
+    // ProgressStream is a flag that indicates the client is expecting a stream of
+    // progress updates.
+    ProgressStream bool `json:"stream"`
 }
 
 // model returns the model name for both old and new API requests.
@@ -212,13 +199,6 @@ func (p params) model() string {
     return cmp.Or(p.Model, p.DeprecatedName)
 }
 
-func (p params) stream() bool {
-    if p.Stream == nil {
-        return true
-    }
-    return *p.Stream
-}
-
 func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
     if r.Method != "DELETE" {
         return errMethodNotAllowed
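The `Stream *bool` field and `stream()` helper being removed here exist because the API treats a missing `stream` field as true, and a plain `bool` cannot distinguish "omitted" from "explicitly false". A standalone sketch of that decode behaviour, with an illustrative `pullParams` type rather than this package's `params`, is:

```go
package main

import (
    "encoding/json"
    "fmt"
)

type pullParams struct {
    // Stream is a pointer so an omitted field (nil) can default to true,
    // while an explicit `"stream": false` is still honored.
    Stream *bool `json:"stream"`
}

func (p pullParams) stream() bool {
    if p.Stream == nil {
        return true // default when the client says nothing
    }
    return *p.Stream
}

func main() {
    for _, body := range []string{`{}`, `{"stream": false}`, `{"stream": true}`} {
        var p pullParams
        _ = json.Unmarshal([]byte(body), &p)
        fmt.Printf("%-19s -> stream=%v\n", body, p.stream())
    }
}
```

Running it prints `true`, `false`, `true` for the three bodies, which is exactly the asymmetry the pointer exists to capture.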
@@ -232,16 +212,16 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
         return err
     }
     if !ok {
-        return errModelNotFound
+        return &serverError{404, "not_found", "model not found"}
     }
-    if s.Prune != nil {
-        return s.Prune()
+    if s.Prune == nil {
+        return nil
     }
-    return nil
+    return s.Prune()
 }
 
 type progressUpdateJSON struct {
-    Status    string      `json:"status,omitempty,omitzero"`
+    Status    string      `json:"status"`
     Digest    blob.Digest `json:"digest,omitempty,omitzero"`
     Total     int64       `json:"total,omitempty,omitzero"`
     Completed int64       `json:"completed,omitempty,omitzero"`
@@ -257,17 +237,6 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
         return err
     }
 
-    enc := json.NewEncoder(w)
-    if !p.stream() {
-        if err := s.Client.Pull(r.Context(), p.model()); err != nil {
-            if errors.Is(err, ollama.ErrModelNotFound) {
-                return errModelNotFound
-            }
-            return err
-        }
-        return enc.Encode(progressUpdateJSON{Status: "success"})
-    }
-
     maybeFlush := func() {
         fl, _ := w.(http.Flusher)
         if fl != nil {
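The tag change on Status above is about when the field is emitted: with `omitempty,omitzero` an empty status disappears from the JSON line, while a bare `json:"status"` always writes it. A quick standalone illustration, using a stand-in struct and the `omitzero` option that encoding/json gained in Go 1.24:

```go
package main

import (
    "encoding/json"
    "fmt"
)

type withOmit struct {
    Status    string `json:"status,omitempty,omitzero"`
    Completed int64  `json:"completed,omitempty,omitzero"`
}

type always struct {
    Status    string `json:"status"`
    Completed int64  `json:"completed,omitempty,omitzero"`
}

func main() {
    // Empty Status is dropped entirely with omitempty/omitzero...
    b, _ := json.Marshal(withOmit{Completed: 3})
    fmt.Println(string(b)) // {"completed":3}

    // ...but always present (possibly as "") without them.
    b, _ = json.Marshal(always{Completed: 3})
    fmt.Println(string(b)) // {"status":"","completed":3}
}
```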
@@ -277,67 +246,69 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
     defer maybeFlush()
 
     var mu sync.Mutex
-    progress := make(map[*ollama.Layer]int64)
+    enc := json.NewEncoder(w)
+    enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
 
-    progressCopy := make(map[*ollama.Layer]int64, len(progress))
-    pushUpdate := func() {
-        defer maybeFlush()
+    ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
+        Update: func(l *ollama.Layer, n int64, err error) {
+            mu.Lock()
+            defer mu.Unlock()
 
-        // TODO(bmizerany): This scales poorly with more layers due to
-        // needing to flush out them all in one big update. We _could_
-        // just flush on the changed ones, or just track the whole
-        // download. Needs more thought. This is fine for now.
-        mu.Lock()
-        maps.Copy(progressCopy, progress)
-        mu.Unlock()
-        for l, n := range progress {
+            // TODO(bmizerany): coalesce these updates; writing per
+            // update is expensive
             enc.Encode(progressUpdateJSON{
                 Digest:    l.Digest,
+                Status:    "pulling",
                 Total:     l.Size,
                 Completed: n,
             })
-        }
-    }
 
-    t := time.NewTicker(time.Hour) // "unstarted" timer
-    start := sync.OnceFunc(func() {
-        pushUpdate()
-        t.Reset(100 * time.Millisecond)
-    })
-    ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
-        Update: func(l *ollama.Layer, n int64, err error) {
-            if n > 0 {
-                start() // flush initial state
-            }
-            mu.Lock()
-            progress[l] = n
-            mu.Unlock()
         },
     })
 
     done := make(chan error, 1)
     go func() {
+        // TODO(bmizerany): continue to support non-streaming responses
         done <- s.Client.Pull(ctx, p.model())
     }()
 
-    for {
-        select {
-        case <-t.C:
-            pushUpdate()
-        case err := <-done:
-            pushUpdate()
-            if err != nil {
-                var status string
-                if errors.Is(err, ollama.ErrModelNotFound) {
-                    status = fmt.Sprintf("error: model %q not found", p.model())
-                } else {
-                    status = fmt.Sprintf("error: %v", err)
+    func() {
+        t := time.NewTicker(100 * time.Millisecond)
+        defer t.Stop()
+        for {
+            select {
+            case <-t.C:
+                mu.Lock()
+                maybeFlush()
+                mu.Unlock()
+            case err := <-done:
+                if err != nil {
+                    var status string
+                    if errors.Is(err, ollama.ErrModelNotFound) {
+                        status = fmt.Sprintf("error: model %q not found", p.model())
+                        enc.Encode(progressUpdateJSON{Status: status})
+                    } else {
+                        status = fmt.Sprintf("error: %v", err)
+                        enc.Encode(progressUpdateJSON{Status: status})
+                    }
+                    return
                 }
-                enc.Encode(progressUpdateJSON{Status: status})
+
+                // These final updates are not strictly necessary, because they have
+                // already happened at this point. Our pull handler code used to do
+                // these steps after, not during, the pull, and they were slow, so we
+                // wanted to provide feedback to users what was happening. For now, we
+                // keep them to not jar users who are used to seeing them. We can phase
+                // them out with a new and nicer UX later. One without progress bars
+                // and digests that no one cares about.
+                enc.Encode(progressUpdateJSON{Status: "verifying layers"})
+                enc.Encode(progressUpdateJSON{Status: "writing manifest"})
+                enc.Encode(progressUpdateJSON{Status: "success"})
+                return
             }
-            return nil
         }
-    }
-}
+    }()
+
+    return nil
+}
 
 func decodeUserJSON[T any](r io.Reader) (T, error) {
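The TODO in the new `Update` callback and the ticker machinery in the removed code are two answers to the same question: how often to emit progress lines. A rough, generic sketch of ticker-based coalescing, using made-up types rather than this package's `ollama.Layer` and `progressUpdateJSON`, looks like this:

```go
package main

import (
    "encoding/json"
    "fmt"
    "os"
    "sync"
    "time"
)

// progressBatcher collects per-item progress and emits one snapshot per tick,
// instead of one JSON line per callback.
type progressBatcher struct {
    mu    sync.Mutex
    state map[string]int64 // digest -> completed bytes (illustrative)
}

func (p *progressBatcher) update(digest string, n int64) {
    p.mu.Lock()
    defer p.mu.Unlock()
    p.state[digest] = n
}

func (p *progressBatcher) flush(enc *json.Encoder) {
    p.mu.Lock()
    defer p.mu.Unlock()
    for d, n := range p.state {
        enc.Encode(map[string]any{"digest": d, "completed": n})
    }
}

func main() {
    p := &progressBatcher{state: make(map[string]int64)}
    enc := json.NewEncoder(os.Stdout)

    done := make(chan struct{})
    go func() { // simulated download producing frequent updates
        for i := int64(1); i <= 5; i++ {
            p.update("sha256:example", i*100)
            time.Sleep(30 * time.Millisecond)
        }
        close(done)
    }()

    t := time.NewTicker(100 * time.Millisecond)
    defer t.Stop()
    for {
        select {
        case <-t.C:
            p.flush(enc) // at most one snapshot per tick
        case <-done:
            p.flush(enc) // final state
            fmt.Println(`{"status":"success"}`)
            return
        }
    }
}
```

The per-update approach in the diff is simpler and never goes stale; the batching approach trades a little latency for far fewer writes when many layers report progress at once.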
@@ -4,6 +4,7 @@ import (
     "bytes"
     "context"
     "encoding/json"
+    "fmt"
     "io"
     "io/fs"
     "net"
@@ -159,6 +160,7 @@ var registryFS = sync.OnceValue(func() fs.FS {
     // to \n when parsing the txtar on Windows.
     data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
     a := txtar.Parse(data)
+    fmt.Printf("%q\n", a.Comment)
     fsys, err := txtar.FS(a)
     if err != nil {
         panic(err)
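For readers unfamiliar with the fixture format used by `registryFS` above: a txtar archive is one text file holding an optional comment plus named files, and `txtar.FS` exposes those files as an `fs.FS`. A minimal sketch, with made-up fixture paths rather than the test registry's real layout:

```go
package main

import (
    "fmt"
    "io/fs"

    "golang.org/x/tools/txtar"
)

func main() {
    // Files are introduced by "-- name --" lines; everything before the
    // first marker is the archive comment.
    src := []byte(`comment describing the fixture
-- manifests/example/latest --
{"layers": []}
-- blobs/sha256-abc --
hello
`)

    a := txtar.Parse(src)
    fmt.Printf("comment: %q\n", a.Comment)

    fsys, err := txtar.FS(a) // read-only fs.FS over the archive's files
    if err != nil {
        panic(err)
    }
    data, _ := fs.ReadFile(fsys, "manifests/example/latest")
    fmt.Printf("manifest: %s", data)
}
```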
@@ -177,7 +179,7 @@ func TestServerPull(t *testing.T) {
             w.WriteHeader(404)
             io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
         default:
-            t.Logf("serving blob: %s", r.URL.Path)
+            t.Logf("serving file: %s", r.URL.Path)
             modelsHandler.ServeHTTP(w, r)
         }
     })
@@ -186,7 +188,7 @@ func TestServerPull(t *testing.T) {
         t.Helper()
 
         if got.Code != 200 {
-            t.Errorf("Code = %d; want 200", got.Code)
+            t.Fatalf("Code = %d; want 200", got.Code)
         }
         gotlines := got.Body.String()
         t.Logf("got:\n%s", gotlines)
@@ -195,29 +197,35 @@ func TestServerPull(t *testing.T) {
             want, unwanted := strings.CutPrefix(want, "!")
             want = strings.TrimSpace(want)
             if !unwanted && !strings.Contains(gotlines, want) {
-                t.Errorf("! missing %q in body", want)
+                t.Fatalf("! missing %q in body", want)
             }
             if unwanted && strings.Contains(gotlines, want) {
-                t.Errorf("! unexpected %q in body", want)
+                t.Fatalf("! unexpected %q in body", want)
             }
         }
     }
 
     got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
     checkResponse(got, `
+        {"status":"pulling manifest"}
         {"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
     `)
 
     got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
     checkResponse(got, `
-        {"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
-        {"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
-        {"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
-        {"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
+        {"status":"pulling manifest"}
+        {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
+        {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
+        {"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
+        {"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
+        {"status":"verifying layers"}
+        {"status":"writing manifest"}
+        {"status":"success"}
     `)
 
     got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
     checkResponse(got, `
+        {"status":"pulling manifest"}
         {"status":"error: model \"unknown\" not found"}
     `)
 
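The expected bodies above are newline-delimited JSON: one progress object per line, flushed as the pull advances. A client can read such a stream with a `json.Decoder` loop. The `pullProgress` struct, base URL, and model name below are illustrative assumptions, not exported pieces of this package; only the `/api/pull` endpoint and field names come from the test itself.

```go
package main

import (
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "net/http"
    "strings"
)

// pullProgress mirrors the fields seen in the expected NDJSON lines above.
type pullProgress struct {
    Status    string `json:"status"`
    Digest    string `json:"digest"`
    Total     int64  `json:"total"`
    Completed int64  `json:"completed"`
}

// pull posts to /api/pull and prints each progress line as it arrives.
func pull(baseURL, model string) error {
    res, err := http.Post(baseURL+"/api/pull", "application/json",
        strings.NewReader(fmt.Sprintf(`{"model": %q}`, model)))
    if err != nil {
        return err
    }
    defer res.Body.Close()

    dec := json.NewDecoder(res.Body) // one JSON object per line
    for {
        var p pullProgress
        if err := dec.Decode(&p); err != nil {
            if errors.Is(err, io.EOF) {
                return nil // stream finished
            }
            return err
        }
        if p.Total > 0 {
            fmt.Printf("%s %s: %d/%d bytes\n", p.Status, p.Digest, p.Completed, p.Total)
        } else {
            fmt.Println(p.Status)
        }
    }
}

func main() {
    // 11434 is the default Ollama port; the model name is just an example.
    if err := pull("http://localhost:11434", "smollm"); err != nil {
        fmt.Println("pull failed:", err)
    }
}
```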
@@ -232,39 +240,19 @@ func TestServerPull(t *testing.T) {
 
     got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
     checkResponse(got, `
+        {"status":"pulling manifest"}
         {"status":"error: invalid or missing name: \"\""}
-    `)
 
-    // Non-streaming pulls
-    got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
-    checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
-    got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
-    checkResponse(got, `
-        {"status":"success"}
-        !digest
-        !total
-        !completed
+        !verifying
+        !writing
+        !success
     `)
-    got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
-    checkErrorResponse(t, got, 404, "not_found", "model not found")
 }
 
 func TestServerUnknownPath(t *testing.T) {
     s := newTestServer(t, nil)
     got := s.send(t, "DELETE", "/api/unknown", `{}`)
     checkErrorResponse(t, got, 404, "not_found", "not found")
-
-    var fellback bool
-    s.Fallback = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-        fellback = true
-    })
-    got = s.send(t, "DELETE", "/api/unknown", `{}`)
-    if !fellback {
-        t.Fatal("expected Fallback to be called")
-    }
-    if got.Code != 200 {
-        t.Fatalf("Code = %d; want 200", got.Code)
-    }
 }
 
 func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {