Compare commits

..

4 Commits

Author SHA1 Message Date
jmorganca
9622b928b4 extras 2025-03-12 18:28:59 +01:00
ParthSareen
7fa6ea0da7 sample: update tests and add test logits 2025-03-12 00:55:18 -04:00
ParthSareen
310b235626 sample: use partial sort for sorting 2025-03-12 00:46:12 -04:00
ParthSareen
448fc4cd2a sample: use container/heap for top_k 2025-03-12 00:45:41 -04:00
21 changed files with 444 additions and 433 deletions

View File

@@ -54,10 +54,6 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download | | Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | -------------------------------- | | ------------------ | ---------- | ----- | -------------------------------- |
| Gemma 3 | 1B | 815MB | `ollama run gemma3:1b` |
| Gemma 3 | 4B | 3.3GB | `ollama run gemma3` |
| Gemma 3 | 12B | 8.1GB | `ollama run gemma3:12b` |
| Gemma 3 | 27B | 17GB | `ollama run gemma3:27b` |
| QwQ | 32B | 20GB | `ollama run qwq` | | QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` | | DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` | | DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
@@ -70,6 +66,9 @@ Here are some example models that can be downloaded:
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` | | Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
| Phi 4 | 14B | 9.1GB | `ollama run phi4` | | Phi 4 | 14B | 9.1GB | `ollama run phi4` |
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` | | Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
| Mistral | 7B | 4.1GB | `ollama run mistral` | | Mistral | 7B | 4.1GB | `ollama run mistral` |
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` | | Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` | | Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |

View File

@@ -349,7 +349,6 @@ type ShowResponse struct {
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"` ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"` ModifiedAt time.Time `json:"modified_at,omitempty"`
} }
@@ -468,13 +467,6 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"` QuantizationLevel string `json:"quantization_level"`
} }
// Tensor describes the metadata for a given tensor.
type Tensor struct {
Name string `json:"name"`
Type string `json:"type"`
Shape []uint64 `json:"shape"`
}
func (m *Metrics) Summary() { func (m *Metrics) Summary() {
if m.TotalDuration > 0 { if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)

View File

@@ -18,7 +18,6 @@ import (
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -569,9 +568,8 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
parameters, errParams := cmd.Flags().GetBool("parameters") parameters, errParams := cmd.Flags().GetBool("parameters")
system, errSystem := cmd.Flags().GetBool("system") system, errSystem := cmd.Flags().GetBool("system")
template, errTemplate := cmd.Flags().GetBool("template") template, errTemplate := cmd.Flags().GetBool("template")
verbose, errVerbose := cmd.Flags().GetBool("verbose")
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} { for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
if boolErr != nil { if boolErr != nil {
return errors.New("error retrieving flags") return errors.New("error retrieving flags")
} }
@@ -609,7 +607,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified") return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
} }
req := api.ShowRequest{Name: args[0], Verbose: verbose} req := api.ShowRequest{Name: args[0]}
resp, err := client.Show(cmd.Context(), &req) resp, err := client.Show(cmd.Context(), &req)
if err != nil { if err != nil {
return err return err
@@ -632,10 +630,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
return showInfo(resp, verbose, os.Stdout) return showInfo(resp, os.Stdout)
} }
func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { func showInfo(resp *api.ShowResponse, w io.Writer) error {
tableRender := func(header string, rows func() [][]string) { tableRender := func(header string, rows func() [][]string) {
fmt.Fprintln(w, " ", header) fmt.Fprintln(w, " ", header)
table := tablewriter.NewWriter(w) table := tablewriter.NewWriter(w)
@@ -692,45 +690,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}) })
} }
if resp.ModelInfo != nil && verbose {
tableRender("Metadata", func() (rows [][]string) {
keys := make([]string, 0, len(resp.ModelInfo))
for k := range resp.ModelInfo {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case string:
v = vData
case float64:
v = fmt.Sprintf("%g", vData)
case []any:
n := 3
if len(vData) < n {
n = len(vData)
}
v = fmt.Sprintf("%v", vData[:n])
default:
v = fmt.Sprintf("%T", vData)
}
rows = append(rows, []string{"", k, v})
}
return
})
}
if len(resp.Tensors) > 0 && verbose {
tableRender("Tensors", func() (rows [][]string) {
for _, t := range resp.Tensors {
rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
}
return
})
}
head := func(s string, n int) (rows [][]string) { head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s)) scanner := bufio.NewScanner(strings.NewReader(s))
for scanner.Scan() && (len(rows) < n || n < 0) { for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -1237,7 +1196,6 @@ func NewCLI() *cobra.Command {
showCmd.Flags().Bool("parameters", false, "Show parameters of a model") showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
showCmd.Flags().Bool("template", false, "Show template of a model") showCmd.Flags().Bool("template", false, "Show template of a model")
showCmd.Flags().Bool("system", false, "Show system message of a model") showCmd.Flags().Bool("system", false, "Show system message of a model")
showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
runCmd := &cobra.Command{ runCmd := &cobra.Command{
Use: "run MODEL [PROMPT]", Use: "run MODEL [PROMPT]",

View File

@@ -27,7 +27,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -57,7 +57,7 @@ func TestShowInfo(t *testing.T) {
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -68,56 +68,6 @@ func TestShowInfo(t *testing.T) {
embedding length 0 embedding length 0
quantization FP16 quantization FP16
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
t.Run("verbose model", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "8B",
QuantizationLevel: "FP16",
},
Parameters: `
stop up`,
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
Tensors: []api.Tensor{
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
},
}, true, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 8B
context length 1000
embedding length 11434
quantization FP16
Parameters
stop up
Metadata
general.architecture test
general.parameter_count 8e+09
test.context_length 1000
test.embedding_length 11434
Tensors
blk.0.attn_k.weight BF16 [42 3117]
blk.0.attn_q.weight FP16 [3117 42]
` `
if diff := cmp.Diff(expect, b.String()); diff != "" { if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff) t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -139,7 +89,7 @@ func TestShowInfo(t *testing.T) {
stop you stop you
stop up stop up
temperature 99`, temperature 99`,
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -176,7 +126,7 @@ func TestShowInfo(t *testing.T) {
"clip.vision.embedding_length": float64(0), "clip.vision.embedding_length": float64(0),
"clip.vision.projection_dim": float64(0), "clip.vision.projection_dim": float64(0),
}, },
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -209,7 +159,7 @@ func TestShowInfo(t *testing.T) {
Ahoy, matey! Ahoy, matey!
Weigh anchor! Weigh anchor!
`, `,
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -238,7 +188,7 @@ Weigh anchor!
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
License: license, License: license,
}, false, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -195,10 +195,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err)
continue
}
return err return err
} }
continue continue
@@ -347,7 +343,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch args[1] { switch args[1] {
case "info": case "info":
_ = showInfo(resp, false, os.Stderr) _ = showInfo(resp, os.Stderr)
case "license": case "license":
if resp.License == "" { if resp.License == "" {
fmt.Println("No license was specified for this model.") fmt.Println("No license was specified for this model.")

View File

@@ -87,7 +87,7 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.embedding_length"] = p.HiddenSize kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize kv["gemma3.feed_forward_length"] = p.IntermediateSize
default: default:
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072) kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 8192)
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow

View File

@@ -187,13 +187,6 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`. Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
```
# Allow all Chrome, Firefox, and Safari extensions
OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
```
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform. Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Where are models stored? ## Where are models stored?

View File

@@ -327,10 +327,6 @@ func (t Tensor) Size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize() return t.parameters() * t.typeSize() / t.blockSize()
} }
func (t Tensor) Type() string {
return fileType(t.Kind).String()
}
type container interface { type container interface {
Name() string Name() string
Decode(io.ReadSeeker) (model, error) Decode(io.ReadSeeker) (model, error)
@@ -583,52 +579,39 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
} }
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() { switch llm.KV().Architecture() {
case "mllama": case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}
kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
}
return 0
}
imageSize := kv("image_size")
maxNumTiles := kv("max_num_tiles")
embeddingLength := kv("embedding_length")
headCount := kv("attention.head_count")
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
numPaddedPatches := numPatches + 8 - (numPatches%8)%8 numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 + graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles + imageSize*imageSize*kv("num_channels")*maxNumTiles +
embeddingLength*numPatches*maxNumTiles + embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles + 9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount) numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
} }
return weights, graphSize return weights, graphSize
} }

View File

@@ -218,8 +218,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok { if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size() layerSize = blk.Size()
layerSize += kv / f.KV().BlockCount() layerSize += kv / f.KV().BlockCount()
memoryWeights += blk.Size()
} }
memoryWeights += layerSize
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU { if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
// Stop allocating on GPU(s) once we hit the users target NumGPU // Stop allocating on GPU(s) once we hit the users target NumGPU
@@ -376,7 +376,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
// memory of the weights // memory of the weights
"total", format.HumanBytes2(m.memoryWeights), "total", format.HumanBytes2(m.memoryWeights),
// memory of repeating layers // memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights), "repeating", format.HumanBytes2(m.memoryWeights-m.memoryLayerOutput),
// memory of non-repeating layers // memory of non-repeating layers
"nonrepeating", format.HumanBytes2(m.memoryLayerOutput), "nonrepeating", format.HumanBytes2(m.memoryLayerOutput),
), ),

View File

@@ -1,5 +1,4 @@
#include <string.h> #include <string.h>
#include <inttypes.h>
#include "ollama-debug.h" #include "ollama-debug.h"
@@ -25,7 +24,7 @@ static void print_tensor(const void *tensor, void (*cb)(const void *, int),
fprintf(stderr, "["); fprintf(stderr, "[");
for (int i = 0; i < dims[0]; i++) { for (int i = 0; i < dims[0]; i++) {
if (i >= nitems && i < dims[0] - nitems) { if (i >= nitems && i < dims[0] - nitems) {
fprintf(stderr, "... (%" PRIi64 " more), ", dims[0] - 2 * nitems); fprintf(stderr, "... (%lld more), ", dims[0] - 2 * nitems);
int skip = dims[0] - 2 * nitems; int skip = dims[0] - 2 * nitems;
if (ndims > 1) { if (ndims > 1) {
stride += mul(dims + 1, ndims - 1) * skip; stride += mul(dims + 1, ndims - 1) * skip;
@@ -68,7 +67,7 @@ static void print_tensor_i32(const void *tensor, int i) {
} }
static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) { static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
fprintf(stderr, "%s%s %s (%s): [%" PRIi64 " %" PRIi64 " %" PRIi64 " %" PRIi64 "]\n", prefix, tensor->name, fprintf(stderr, "%s%s %s (%s): [%lld %lld %lld %lld]\n", prefix, tensor->name,
ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0], ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
tensor->ne[1], tensor->ne[2], tensor->ne[3]); tensor->ne[1], tensor->ne[2], tensor->ne[3]);

View File

@@ -22,8 +22,6 @@ import (
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface { type Model interface {
Forward(ml.Context, input.Options) (ml.Tensor, error) Forward(ml.Context, input.Options) (ml.Tensor, error)

View File

@@ -84,10 +84,6 @@ func New(c ml.Config) (model.Model, error) {
} }
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData)) image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -15,6 +15,7 @@ type TextOptions struct {
attnKeyLen, attnValLen int attnKeyLen, attnValLen int
eps, ropeScale float32 eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32 ropeLocalBase, ropeGlobalBase float32
finalLogitSoftcap float32
largeModelScaling bool largeModelScaling bool
} }
@@ -56,15 +57,16 @@ func newTextModel(c ml.Config) *TextModel {
), ),
Layers: make([]TextLayer, numBlocks), Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{ TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")), numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")), numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)), attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)), attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0), ropeScale: c.Float("rope.freq_scale", 1.0),
finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
}, },
} }
@@ -243,5 +245,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
} }
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState) hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
} }

View File

@@ -63,10 +63,6 @@ func New(c ml.Config) (model.Model, error) {
} }
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData)) image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -116,9 +116,19 @@ func (i *Instance) Readline() (string, error) {
switch r { switch r {
case KeyUp: case KeyUp:
i.historyPrev(buf, &currentLineBuf) if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
currentLineBuf = []rune(buf.String())
}
buf.Replace([]rune(i.History.Prev()))
}
case KeyDown: case KeyDown:
i.historyNext(buf, &currentLineBuf) if i.History.Pos < i.History.Size() {
buf.Replace([]rune(i.History.Next()))
if i.History.Pos == i.History.Size() {
buf.Replace(currentLineBuf)
}
}
case KeyLeft: case KeyLeft:
buf.MoveLeft() buf.MoveLeft()
case KeyRight: case KeyRight:
@@ -175,10 +185,6 @@ func (i *Instance) Readline() (string, error) {
esc = true esc = true
case CharInterrupt: case CharInterrupt:
return "", ErrInterrupt return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
case CharNext:
i.historyNext(buf, &currentLineBuf)
case CharLineStart: case CharLineStart:
buf.MoveToStart() buf.MoveToStart()
case CharLineEnd: case CharLineEnd:
@@ -240,24 +246,6 @@ func (i *Instance) HistoryDisable() {
i.History.Enabled = false i.History.Enabled = false
} }
func (i *Instance) historyPrev(buf *Buffer, currentLineBuf *[]rune) {
if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
*currentLineBuf = []rune(buf.String())
}
buf.Replace([]rune(i.History.Prev()))
}
}
func (i *Instance) historyNext(buf *Buffer, currentLineBuf *[]rune) {
if i.History.Pos < i.History.Size() {
buf.Replace([]rune(i.History.Next()))
if i.History.Pos == i.History.Size() {
buf.Replace(*currentLineBuf)
}
}
}
func NewTerminal() (*Terminal, error) { func NewTerminal() (*Terminal, error) {
fd := os.Stdin.Fd() fd := os.Stdin.Fd()
termios, err := SetRawMode(fd) termios, err := SetRawMode(fd)

View File

@@ -691,6 +691,65 @@ type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"` Embedding []float32 `json:"embedding"`
} }
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
slog.Debug("embedding request", "content", req.Content)
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
// Ensure there is a place to put the sequence, released when removed from s.seqs
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embeddings request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct { type HealthResponse struct {
Status string `json:"status"` Status string `json:"status"`
Progress float32 `json:"progress"` Progress float32 `json:"progress"`
@@ -868,13 +927,9 @@ func Execute(args []string) error {
defer listener.Close() defer listener.Close()
mux := http.NewServeMux() mux := http.NewServeMux()
// TODO: support embeddings mux.HandleFunc("/embedding", server.embeddings)
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/completion", server.completion)
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) mux.HandleFunc("/health", server.health)
})
mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health)
httpServer := http.Server{ httpServer := http.Server{
Handler: mux, Handler: mux,

View File

@@ -1,11 +1,10 @@
package sample package sample
import ( import (
"errors"
"math" "math"
"math/rand/v2" "math/rand"
"slices"
"sync" "sync"
"time"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
) )
@@ -84,56 +83,59 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return greedy(tokens), nil return greedy(tokens), nil
} }
// topK also sorts the tokens in descending order of logits if s.topK > 0 {
tokens = topK(tokens, s.topK) tokens = topK(tokens, s.topK)
} else {
tokens = temperature(tokens, s.temperature) sortLogits(tokens)
tokens = softmax(tokens) }
tokens = topP(tokens, s.topP) tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP) tokens = minP(tokens, s.minP)
// TODO: this should fall back to greedy sampling // token logit values are updated to probabilities
// or topP, topK values etc should be such that temperature(tokens, s.temperature)
// there are always tokens to sample from softmax(tokens)
if len(tokens) == 0 { return tokens[dist(tokens, s.rng.Int63())], nil
return token{}, errors.New("no tokens to sample from")
}
var r float32 // // TODO: this should fall back to greedy sampling
if s.rng != nil { // // or topP, topK values etc should be such that
r = s.rng.Float32() // // there are always tokens to sample from
} else { // if len(tokens) == 0 {
r = rand.Float32() // return token{}, errors.New("no tokens to sample from")
} // }
// Calculate cumulative sum of probabilities // var r float32
var sum float32 // if s.rng != nil {
for i := range tokens { // r = s.rng.Float32()
sum += tokens[i].value // } else {
tokens[i].value = sum // r = rand.Float32()
} // }
r *= tokens[len(tokens)-1].value
idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int { // // Calculate cumulative sum of probabilities
if token.value < target { // var sum float32
return -1 // for i := range tokens {
} // sum += tokens[i].value
return 1 // tokens[i].value = sum
}) // }
// r *= tokens[len(tokens)-1].value
return tokens[idx], nil // idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
// if token.value < target {
// return -1
// }
// return 1
// })
// return tokens[idx], nil
} }
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler { func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
var rng *rand.Rand var rng *rand.Rand
if seed != -1 { if seed != -1 {
// PCG requires two parameters: sequence and stream rng = rand.New(rand.NewSource(int64(seed)))
// Use original seed for sequence } else {
sequence := uint64(seed) rng = rand.New(rand.NewSource(time.Now().UnixNano()))
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
} }
if temperature < 0.0 { if temperature < 0.0 {
temperature = 0.0 temperature = 0.0

1
sample/testdata/logits.bin vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -3,6 +3,7 @@ package sample
import ( import (
"container/heap" "container/heap"
"math" "math"
"math/rand"
"slices" "slices"
) )
@@ -10,7 +11,7 @@ import (
type tokenHeap []token type tokenHeap []token
func (h tokenHeap) Len() int { return len(h) } func (h tokenHeap) Len() int { return len(h) }
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value } func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value } // Use < for min-heap to track largest elements
func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *tokenHeap) Push(x any) { func (h *tokenHeap) Push(x any) {
@@ -25,54 +26,10 @@ func (h *tokenHeap) Pop() any {
return x return x
} }
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) []token {
// Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
}
return ts
}
// softmax applies normalization to the logits
func softmax(ts []token) []token {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Compute exp(x - max)
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
sum += ts[i].value
}
// exp(x - max) / sum(exp(x - max))
for i := range ts {
ts[i].value /= sum
}
return ts
}
// topK limits the number of tokens considered to the k highest logits // topK limits the number of tokens considered to the k highest logits
func topK(ts []token, k int) []token { func topK(ts []token, k int) []token {
if k >= len(ts) || k <= 0 { if k >= len(ts) {
slices.SortFunc(ts, func(a, b token) int { sortLogits(ts)
switch {
case a.value < b.value:
return 1
case a.value > b.value:
return -1
default:
return 0
}
})
return ts return ts
} }
@@ -90,7 +47,7 @@ func topK(ts []token, k int) []token {
} }
// Convert heap to sorted slice in descending order // Convert heap to sorted slice in descending order
result := make([]token, len(h)) result := make([]token, k)
for i := k - 1; i >= 0; i-- { for i := k - 1; i >= 0; i-- {
result[i] = heap.Pop(&h).(token) result[i] = heap.Pop(&h).(token)
} }
@@ -143,3 +100,134 @@ func minP(ts []token, p float32) []token {
ts = validTokens ts = validTokens
return ts return ts
} }
// partialSortLogits uses quickselect to efficiently find and sort the top n tokens
func partialSortLogits(ts []token, n int) []token {
if n >= len(ts) {
n = len(ts)
}
left, right := 0, len(ts)-1
target := n - 1
// Quickselect algorithm to partition array around pivot
for left < right {
// Choose middle element as pivot and move it to the end
pivot := left + (right-left)/2
ts[pivot], ts[right] = ts[right], ts[pivot]
// storeIndex tracks where to put next element greater than pivot
storeIndex := left
pivotValue := ts[right].value
// Partition array into elements >= pivot and < pivot
// Elements >= pivot go to the left side
for i := left; i < right; i++ {
if ts[i].value >= pivotValue {
ts[storeIndex], ts[i] = ts[i], ts[storeIndex]
storeIndex++
}
}
// Move pivot to its final position
ts[right], ts[storeIndex] = ts[storeIndex], ts[right]
// If pivot is at target position, we're done
// Otherwise recursively partition the half containing target
if storeIndex == target {
break
} else if storeIndex < target {
left = storeIndex + 1 // Target is in right half
} else {
right = storeIndex - 1 // Target is in left half
}
}
// Sort just the top n elements in descending order
slices.SortFunc(ts[:n], func(a, b token) int {
if a.value > b.value {
return -1
}
if a.value < b.value {
return 1
}
return 0
})
return ts[:n]
}
// sortLogits uses partialSortLogits to efficiently sort tokens
// It sorts approximately sqrt(len(tokens)) elements which balances
// between having enough tokens for sampling while avoiding full sort
func sortLogits(ts []token) {
// Use sqrt of token length as a heuristic for partial sort size
// This provides a good balance between performance and having enough tokens
n := int(math.Sqrt(float64(len(ts)))) + 1
// Ensure we have at least 100 tokens and at most 1000
switch {
case n < 100:
n = 100
case n > 1000:
n = 1000
}
partialSortLogits(ts, n)
}
func temperature(ts []token, temp float32) {
for i := range ts {
ts[i].value /= temp
}
}
func softmax(ts []token) {
if len(ts) == 0 {
return
}
// Find max logit for numerical stability
maxLogit := ts[0].value
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Compute exp(logit - maxLogit) and sum them
var sumExp float32
for i, t := range ts {
expVal := float32(math.Exp(float64(t.value - maxLogit)))
ts[i].value = expVal
sumExp += expVal
}
// Normalize probabilities
for i := range ts {
ts[i].value /= sumExp
}
}
// applyDist selects a token based on probabilities and seed
func dist(ts []token, seed int64) int {
rng := rand.New(rand.NewSource(seed))
cdf := make([]float32, len(ts))
var cumSum float32
for i, t := range ts {
cumSum += t.value
cdf[i] = cumSum
}
r := rng.Float32() * cumSum
// Select token based on CDF
for i, probSum := range cdf {
if r < probSum {
return i
}
}
return len(ts) - 1
}

View File

@@ -1,8 +1,13 @@
package sample package sample
import ( import (
"encoding/binary"
"errors"
"math" "math"
"math/rand/v2" "math/rand/v2"
"os"
"path/filepath"
"runtime"
"testing" "testing"
) )
@@ -32,90 +37,34 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
} }
} }
func TestTemperature(t *testing.T) { func TestTemperatureAndSoftmax(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0} input := []float32{1, 4, -2, 0}
got := temperature(toTokens(input), 0.5) got := temperature(toTokens(input), 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, got)
got = temperature(toTokens(input), 1.0) // Check probabilities sum to 1
want = []float32{1.0, 4.0, -2.0, 0.0} var sum float32
compareLogits(t, "temperature(1)", want, got) for _, token := range got {
sum += token.value
got = temperature(toTokens(input), 0.0) }
want = []float32{1e7, 4e7, -2e7, 0.0} if math.Abs(float64(sum-1.0)) > 1e-6 {
compareLogits(t, "temperature(0)", want, got) t.Errorf("probabilities don't sum to 1: got %f", sum)
}
func TestSoftmax(t *testing.T) {
tests := []struct {
name string
input []float32
expected []float32
}{
{
name: "correctness softmax",
input: []float32{1, -2, 3, 0},
expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
},
{
name: "normal distribution",
input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
},
{
name: "single value",
input: []float32{1.0},
},
{
name: "identical values",
input: []float32{0.9, 0.9, 0.9},
},
{
name: "large values",
input: []float32{1000.0, 2000.0, 3000.0},
},
{
name: "small values",
input: []float32{1e-6, 2e-6, 3e-6},
},
{
name: "negative values",
input: []float32{-1.0, -2.0, -3.0},
},
{
name: "mixed values",
input: []float32{-100.0, 0.0, 100.0},
},
} }
for _, tt := range tests { got = temperature(toTokens(input), 1)
t.Run(tt.name, func(t *testing.T) { // Check probabilities sum to 1
got := softmax(toTokens(tt.input)) sum = 0.0
for _, token := range got {
if tt.expected != nil { sum += token.value
compareLogits(t, tt.name, tt.expected, got) }
return if math.Abs(float64(sum-1.0)) > 1e-6 {
} t.Errorf("probabilities don't sum to 1: got %f", sum)
// Check probabilities sum to 1
var sum float32
for _, token := range got {
sum += token.value
if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value)
}
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
})
} }
} }
func TestTopK(t *testing.T) { func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
// Test k=5 // Test k=3
got := topK(toTokens(input), 5) got := topK(toTokens(input), 5)
if len(got) != 5 { if len(got) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(got)) t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
@@ -128,24 +77,6 @@ func TestTopK(t *testing.T) {
if len(got) != len(input) { if len(got) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got)) t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
} }
// Test k=-1
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), -1)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, got)
// Test k=0
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), 0)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, got)
} }
func TestTopP(t *testing.T) { func TestTopP(t *testing.T) {
@@ -153,8 +84,8 @@ func TestTopP(t *testing.T) {
tokens := toTokens(input) tokens := toTokens(input)
// First apply temperature and softmax to get probabilities // First apply temperature and softmax to get probabilities
tokens = softmax(tokens) tokens = temperature(tokens, 1)
tokens = topK(tokens, 20) sortLogits(tokens)
// Then apply topP // Then apply topP
got := topP(tokens, 0.95) got := topP(tokens, 0.95)
@@ -171,7 +102,7 @@ func TestMinP(t *testing.T) {
tokens := toTokens(input) tokens := toTokens(input)
// First apply temperature and softmax // First apply temperature and softmax
tokens = softmax(tokens) tokens = temperature(tokens, 1)
// Then apply minP // Then apply minP
got := minP(tokens, 0.2) got := minP(tokens, 0.2)
@@ -186,7 +117,7 @@ func TestSortLogits(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input) tokens := toTokens(input)
tokens = topK(tokens, 20) sortLogits(tokens)
for i := 1; i < len(tokens); i++ { for i := 1; i < len(tokens); i++ {
if tokens[i].value > tokens[i-1].value { if tokens[i].value > tokens[i-1].value {
@@ -199,6 +130,98 @@ func TestSortLogits(t *testing.T) {
compareLogits(t, "sortLogits", want, tokens) compareLogits(t, "sortLogits", want, tokens)
} }
// TestSortLogitsWithRealData tests sorting behavior using real model logit distributions
func TestSortLogitsWithRealData(t *testing.T) {
// This will be populated from testdata/logits.bin
// Format: 32-bit float array in binary format
logits, err := loadTestLogits(t)
if err != nil {
t.Skipf("Skipping real logit test: %v", err)
return
}
tokens := toTokens(logits)
sortLogits(tokens)
// Calculate n for verification
n := int(math.Sqrt(float64(len(tokens)))) + 1
if n > 1000 {
n = 1000
} else if n < 100 {
n = 100
}
t.Logf("Testing with %d tokens, partial sorting top %d", len(tokens), n)
// Only verify the top n elements are sorted (which is what we guarantee)
// This is much faster than checking the entire array
topN := tokens[:n]
for i := 1; i < len(topN); i++ {
if topN[i].value > topN[i-1].value {
t.Fatalf("top %d tokens not properly sorted at index %d: %.15f > %.15f",
n, i, topN[i].value, topN[i-1].value)
}
}
// Verify we didn't lose any high value tokens by checking that
// all tokens after position n are <= the nth token
// Do this in chunks to avoid timeouts on large arrays
nthValue := tokens[n-1].value
const chunkSize = 1000
for start := n; start < len(tokens); start += chunkSize {
end := min(start+chunkSize, len(tokens))
for i := start; i < end; i++ {
if tokens[i].value > nthValue {
t.Fatalf("found higher value token after position %d: tokens[%d].value = %.15f > %.15f",
n, i, tokens[i].value, nthValue)
}
}
}
}
// loadTestLogits loads logit test data from testdata/logits.bin
func loadTestLogits(t *testing.T) ([]float32, error) {
t.Helper()
_, currFile, _, ok := runtime.Caller(0)
if !ok {
return nil, errors.New("could not determine test file path")
}
testDataPath := filepath.Join(filepath.Dir(currFile), "testdata", "logits.bin")
file, err := os.Open(testDataPath)
if err != nil {
return nil, err
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return nil, err
}
numFloats := stat.Size() / 4 // each float32 is 4 bytes
if numFloats*4 != stat.Size() {
return nil, errors.New("logits.bin has invalid size: not a multiple of 4 bytes")
}
logits := make([]float32, numFloats)
for i := range logits {
var val uint32
if err := binary.Read(file, binary.LittleEndian, &val); err != nil {
return nil, err
}
logits[i] = math.Float32frombits(val)
}
if len(logits) == 0 {
return nil, errors.New("logits.bin is empty")
}
return logits, nil
}
func BenchmarkTransforms(b *testing.B) { func BenchmarkTransforms(b *testing.B) {
// Generate random logits // Generate random logits
tokens := make([]token, 1<<16) tokens := make([]token, 1<<16)
@@ -219,14 +242,6 @@ func BenchmarkTransforms(b *testing.B) {
} }
}) })
b.Run("Softmax", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
softmax(tokensCopy)
}
})
b.Run("TopK", func(b *testing.B) { b.Run("TopK", func(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
@@ -255,7 +270,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for b.Loop() {
copy(tokensCopy, tokens) copy(tokensCopy, tokens)
topK(tokensCopy, 200000) sortLogits(tokensCopy)
} }
}) })
} }

View File

@@ -435,7 +435,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
kvData, _, err := getModelData(m.ModelPath, false) kvData, err := getKVData(m.ModelPath, false)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -483,7 +483,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
if err := g.Wait(); err != nil { if err := g.Wait(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
return return
} }
@@ -544,7 +545,8 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
embedding, err := r.Embedding(c.Request.Context(), req.Prompt) embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())}) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embedding: %v", err)})
return return
} }
@@ -848,23 +850,16 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
fmt.Fprint(&sb, m.String()) fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String() resp.Modelfile = sb.String()
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose) kvData, err := getKVData(m.ModelPath, req.Verbose)
if err != nil { if err != nil {
return nil, err return nil, err
} }
delete(kvData, "general.name") delete(kvData, "general.name")
delete(kvData, "tokenizer.chat_template") delete(kvData, "tokenizer.chat_template")
resp.ModelInfo = kvData resp.ModelInfo = kvData
tensorData := make([]api.Tensor, len(tensors.Items()))
for cnt, t := range tensors.Items() {
tensorData[cnt] = api.Tensor{Name: t.Name, Type: t.Type(), Shape: t.Shape}
}
resp.Tensors = tensorData
if len(m.ProjectorPaths) > 0 { if len(m.ProjectorPaths) > 0 {
projectorData, _, err := getModelData(m.ProjectorPaths[0], req.Verbose) projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -874,17 +869,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return resp, nil return resp, nil
} }
func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) { func getKVData(digest string, verbose bool) (ggml.KV, error) {
maxArraySize := 0 maxArraySize := 0
if verbose { if verbose {
maxArraySize = -1 maxArraySize = -1
} }
data, err := llm.LoadModel(digest, maxArraySize) kvData, err := llm.LoadModel(digest, maxArraySize)
if err != nil { if err != nil {
return nil, ggml.Tensors{}, err return nil, err
} }
kv := data.KV() kv := kvData.KV()
if !verbose { if !verbose {
for k := range kv { for k := range kv {
@@ -894,7 +889,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
} }
} }
return kv, data.Tensors(), nil return kv, nil
} }
func (s *Server) ListHandler(c *gin.Context) { func (s *Server) ListHandler(c *gin.Context) {