Compare commits

...

21 Commits

Author SHA1 Message Date
ParthSareen
23e8ac9428 wip? 2025-05-07 19:00:44 -07:00
ParthSareen
611d3a17ed server: add python tool parsing logic 2025-05-02 16:23:54 -07:00
Michael Yang
5cfc1c39f3 model: fix build (#10416) 2025-04-25 19:24:48 -07:00
Michael Yang
f0ad49ea17 memory 2025-04-25 16:59:20 -07:00
Michael Yang
7ba9fa9c7d fixes for maverick 2025-04-25 16:59:20 -07:00
Michael Yang
8bf11b84c1 chunked attention 2025-04-25 16:59:20 -07:00
Michael Yang
470af8ab89 connect vision to text 2025-04-25 16:59:20 -07:00
Michael Yang
178761aef3 image processing
Co-authored-by: Patrick Devine <patrick@infrahq.com>
2025-04-25 16:59:20 -07:00
Michael Yang
f0c66e6dea llama4 2025-04-25 16:59:20 -07:00
Michael Yang
54055a6dae fix test 2025-04-25 16:59:01 -07:00
Michael Yang
340448d2d1 explicitly decode maxarraysize 1024 2025-04-25 16:59:01 -07:00
Michael Yang
ced7d0e53d fix parameter count 2025-04-25 16:59:01 -07:00
Michael Yang
a0dba0f8ae default slice values 2025-04-25 16:59:01 -07:00
Michael Yang
5e20b170a7 update comment 2025-04-25 16:59:01 -07:00
Michael Yang
d26c18e25c fix token type 2025-04-25 16:59:01 -07:00
Michael Yang
8d376acc9b zero means zero
use a default of 1024 when asking for zero is confusing since most calls
seem to assume 0 means do not ready any data
2025-04-25 16:59:01 -07:00
Michael Yang
dc1e81f027 convert: use -1 for read all 2025-04-25 16:59:01 -07:00
Michael Yang
5d0279164c generic ggml.array 2025-04-25 16:59:01 -07:00
Michael Yang
214a7678ea fix superfluous call to WriteHeader
the first call to http.ResponseWriter.Write implicitly calls WriteHeader
with http.StatusOK if it hasn't already been called. once WriteHeader
has been called, subsequent calls has no effect. Write is called when
JSON encoding progressUpdateJSON{}. calls to
http.ResponseWriter.WriteHeader after the first encode is useless and
produces a warning:

http: superfluous response.WriteHeader call from github.com/ollama/ollama/server/internal/registry.(*statusCodeRecorder).WriteHeader (server.go:77)
2025-04-25 16:58:49 -07:00
Michael Yang
4892872c18 convert: change to colmajor 2025-04-25 15:27:39 -07:00
Michael Yang
0b9198bf47 ci: silence deprecated gpu targets warning 2025-04-25 13:37:54 -07:00
41 changed files with 2237 additions and 204 deletions

View File

@@ -21,14 +21,16 @@
"name": "CUDA 11", "name": "CUDA 11",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86" "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
} }
}, },
{ {
"name": "CUDA 12", "name": "CUDA 12",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120" "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
} }
}, },
{ {

View File

@@ -7,6 +7,7 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"slices"
"strings" "strings"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
@@ -84,14 +85,6 @@ func (ModelParameters) specialTokenTypes() []string {
} }
} }
func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(ws, kv, ts)
}
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(ws, kv, ts)
}
type ModelConverter interface { type ModelConverter interface {
// KV maps parameters to LLM key-values // KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV KV(*Tokenizer) ggml.KV
@@ -103,8 +96,6 @@ type ModelConverter interface {
// specialTokenTypes returns any special token types the model uses // specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
} }
type moreParser interface { type moreParser interface {
@@ -119,8 +110,6 @@ type AdapterConverter interface {
// Replacements returns a list of string pairs to replace in tensor names. // Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details // See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string Replacements() []string
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
} }
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error { func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
@@ -158,7 +147,7 @@ func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
return err return err
} }
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts)) return writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
} }
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
@@ -184,6 +173,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
switch p.Architectures[0] { switch p.Architectures[0] {
case "LlamaForCausalLM": case "LlamaForCausalLM":
conv = &llamaModel{} conv = &llamaModel{}
case "Llama4ForConditionalGeneration":
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration": case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{} conv = &mistral3Model{}
case "MixtralForCausalLM": case "MixtralForCausalLM":
@@ -248,5 +239,13 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
return err return err
} }
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts)) return writeFile(ws, conv.KV(t), conv.Tensors(ts))
}
func writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)
}
return ggml.WriteGGUF(ws, kv, ts)
} }

View File

@@ -42,6 +42,8 @@ type llamaModel struct {
LayerNormEpsilon float32 `json:"layer_norm_epsilon"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"`
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
skipRepack bool
} }
var _ ModelConverter = (*llamaModel)(nil) var _ ModelConverter = (*llamaModel)(nil)
@@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv["llama.rope.dimension_count"] = p.HiddenSize / headCount kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
} }
if p.HeadDim > 0 {
kv["llama.attention.head_dim"] = p.HeadDim
}
if p.RopeTheta > 0 { if p.RopeTheta > 0 {
kv["llama.rope.freq_base"] = p.RopeTheta kv["llama.rope.freq_base"] = p.RopeTheta
} }
@@ -133,9 +139,10 @@ func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
} }
for _, t := range ts { for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") || if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
strings.HasSuffix(t.Name(), "attn_k.weight") { if !p.skipRepack {
t.SetRepacker(p.repack) t.SetRepacker(p.repack)
}
} }
out = append(out, ggml.Tensor{ out = append(out, ggml.Tensor{

169
convert/convert_llama4.go Normal file
View File

@@ -0,0 +1,169 @@
package convert
import (
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type llama4Model struct {
ModelParameters
TextModel struct {
llamaModel
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumLocalExperts uint32 `json:"num_local_experts"`
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
UseQKNorm bool `json:"use_qk_norm"`
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
AttentionChunkSize uint32 `json:"attention_chunk_size"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
NormEpsilon float32 `json:"norm_eps"`
PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
} `json:"vision_config"`
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"
for k, v := range p.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
}
}
kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
return kv
}
// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
return append(
p.TextModel.Replacements(),
"language_model.", "",
"vision_model", "v",
"multi_modal_projector", "mm",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.", "ffn_",
"shared_expert.down_proj", "down_shexp",
"shared_expert.gate_proj", "gate_shexp",
"shared_expert.up_proj", "up_shexp",
"experts.down_proj", "down_exps.weight",
"experts.gate_up_proj", "gate_up_exps.weight",
"router", "gate_inp",
"patch_embedding.linear", "patch_embedding",
)
}
// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
var textTensors []Tensor
for _, t := range ts {
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
// gate and up projectors are fused
// dims[1], dims[2] must be swapped
// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
halfDim := int(t.Shape()[2]) / 2
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2]/2, newShape[1]
for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
// clone tensor since we need separate repackers
tt := t.Clone()
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
out = append(out, ggml.Tensor{
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
Kind: tt.Kind(),
Shape: newShape,
WriterTo: tt,
})
}
} else if strings.Contains(t.Name(), "ffn_down_exps") {
// dims[1], dims[2] must be swapped
// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
t.SetRepacker(p.repack())
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2], newShape[1]
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
} else {
textTensors = append(textTensors, t)
}
}
p.TextModel.skipRepack = true
out = append(out, p.TextModel.Tensors(textTensors)...)
return out
}
func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
return func(name string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
if err := t.T(0, 2, 1); err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@@ -11,7 +11,6 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"math"
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
@@ -48,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
} }
t.Cleanup(func() { r.Close() }) t.Cleanup(func() { r.Close() })
m, _, err := ggml.Decode(r, math.MaxInt) m, _, err := ggml.Decode(r, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -332,7 +331,7 @@ func TestConvertAdapter(t *testing.T) {
} }
defer r.Close() defer r.Close()
m, _, err := ggml.Decode(r, math.MaxInt) m, _, err := ggml.Decode(r, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -11,14 +11,15 @@ type Tensor interface {
Name() string Name() string
Shape() []uint64 Shape() []uint64
Kind() uint32 Kind() uint32
SetRepacker(repacker) SetRepacker(Repacker)
WriteTo(io.Writer) (int64, error) WriteTo(io.Writer) (int64, error)
Clone() Tensor
} }
type tensorBase struct { type tensorBase struct {
name string name string
shape []uint64 shape []uint64
repacker repacker Repacker
} }
func (t tensorBase) Name() string { func (t tensorBase) Name() string {
@@ -36,7 +37,8 @@ const (
func (t tensorBase) Kind() uint32 { func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") || if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
t.name == "token_types.weight" { t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" {
// these tensors are always F32 // these tensors are always F32
return 0 return 0
} }
@@ -51,11 +53,11 @@ func (t tensorBase) Kind() uint32 {
} }
} }
func (t *tensorBase) SetRepacker(fn repacker) { func (t *tensorBase) SetRepacker(fn Repacker) {
t.repacker = fn t.repacker = fn
} }
type repacker func(string, []float32, []uint64) ([]float32, error) type Repacker func(string, []float32, []uint64) ([]float32, error)
func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) { func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
patterns := []struct { patterns := []struct {

View File

@@ -94,6 +94,21 @@ type safetensor struct {
*tensorBase *tensorBase
} }
func (st safetensor) Clone() Tensor {
return &safetensor{
fs: st.fs,
path: st.path,
dtype: st.dtype,
offset: st.offset,
size: st.size,
tensorBase: &tensorBase{
name: st.name,
repacker: st.repacker,
shape: slices.Clone(st.shape),
},
}
}
func (st safetensor) WriteTo(w io.Writer) (int64, error) { func (st safetensor) WriteTo(w io.Writer) (int64, error) {
f, err := st.fs.Open(st.path) f, err := st.fs.Open(st.path)
if err != nil { if err != nil {

View File

@@ -43,6 +43,17 @@ type torch struct {
*tensorBase *tensorBase
} }
func (t torch) Clone() Tensor {
return torch{
storage: t.storage,
tensorBase: &tensorBase{
name: t.name,
shape: t.shape,
repacker: t.repacker,
},
}
}
func (pt torch) WriteTo(w io.Writer) (int64, error) { func (pt torch) WriteTo(w io.Writer) (int64, error) {
return 0, nil return 0, nil
} }

View File

@@ -8,6 +8,6 @@ type Config interface {
Bool(string, ...bool) bool Bool(string, ...bool) bool
Strings(string, ...[]string) []string Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32 Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32 Floats(string, ...[]float32) []float32
} }

View File

@@ -33,7 +33,7 @@ func (kv KV) Kind() string {
} }
func (kv KV) ParameterCount() uint64 { func (kv KV) ParameterCount() uint64 {
return keyValue[uint64](kv, "general.parameter_count") return keyValue(kv, "general.parameter_count", uint64(0))
} }
func (kv KV) FileType() fileType { func (kv KV) FileType() fileType {
@@ -105,42 +105,42 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
} }
func (kv KV) Strings(key string, defaultValue ...[]string) []string { func (kv KV) Strings(key string, defaultValue ...[]string) []string {
r := keyValue(kv, key, &array{}) return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
s := make([]string, r.size) }
for i := range r.size {
s[i] = r.values[i].(string)
}
return s func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
} }
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
r := keyValue(kv, key, &array{}) return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
s := make([]uint32, r.size)
for i := range r.size {
s[i] = uint32(r.values[i].(int32))
}
return s
} }
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
r := keyValue(kv, key, &array{}) return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
} }
func (kv KV) OllamaEngineRequired() bool { func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{ return slices.Contains([]string{
"gemma3", "gemma3",
"mistral3", "mistral3",
"llama4",
}, kv.Architecture()) }, kv.Architecture())
} }
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T { type valueTypes interface {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
*array[string] | *array[float32] | *array[float64] | *array[bool]
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key key = kv.Architecture() + "." + key
} }
@@ -375,13 +375,8 @@ func DetectContentType(b []byte) string {
// Decode decodes a GGML model from the given reader. // Decode decodes a GGML model from the given reader.
// //
// It collects array values for arrays with a size less than or equal to // It collects array values for arrays with a size less than or equal to
// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If // maxArraySize. If the maxArraySize is negative, all arrays are collected.
// the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) { func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
if maxArraySize == 0 {
maxArraySize = 1024
}
rs = bufioutil.NewBufferedSeeker(rs, 32<<10) rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32 var magic uint32
@@ -420,7 +415,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding := f.KV().EmbeddingLength() embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount() heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV() headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size) vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCount() embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK() embeddingHeadsK := f.KV().EmbeddingHeadCountK()
@@ -435,7 +430,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
} }
switch f.KV().Architecture() { switch f.KV().Architecture() {
case "llama": case "llama", "llama4":
fullOffload = max( fullOffload = max(
4*batch*(1+4*embedding+context*(1+heads)), 4*batch*(1+4*embedding+context*(1+heads)),
4*batch*(embedding+vocab), 4*batch*(embedding+vocab),
@@ -449,7 +444,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok { if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b // mixtral 8x22b
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32)) ff := uint64(f.KV().Uint("feed_forward_length"))
partialOffload = max( partialOffload = max(
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV), 3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch), 4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@@ -466,9 +461,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
case "mllama": case "mllama":
var visionTokens, tiles uint64 = 1601, 4 var visionTokens, tiles uint64 = 1601, 4
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers") crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
for i := range kv { for i := range kv {
if slices.Contains(crossAttentionLayers, uint32(i)) { if slices.Contains(crossAttentionLayers, int32(i)) {
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) * kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
4 * // sizeof(float32) 4 * // sizeof(float32)
visionTokens * visionTokens *
@@ -645,6 +640,9 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
graphSize = 4 * (imageSize*imageSize*numChannels + graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize + embeddingLength*patchSize +
numPatches*numPatches*headCount) numPatches*numPatches*headCount)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph
} }
return weights, graphSize return weights, graphSize

View File

@@ -2,6 +2,7 @@ package ggml
import ( import (
"maps" "maps"
"math"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -210,3 +211,61 @@ func TestTensorTypes(t *testing.T) {
}) })
} }
} }
func TestKeyValue(t *testing.T) {
kv := KV{
"general.architecture": "test",
"test.strings": &array[string]{size: 3, values: []string{"a", "b", "c"}},
"test.float32s": &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
"test.int32s": &array[int32]{size: 3, values: []int32{1, 2, 3}},
"test.uint32s": &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
}
if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
}

View File

@@ -36,10 +36,6 @@ type containerGGUF struct {
maxArraySize int maxArraySize int
} }
func (c *containerGGUF) canCollectArray(size int) bool {
return c.maxArraySize < 0 || size <= c.maxArraySize
}
func (c *containerGGUF) Name() string { func (c *containerGGUF) Name() string {
return "gguf" return "gguf"
} }
@@ -295,6 +291,23 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
return b.String(), nil return b.String(), nil
} }
func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
for i := range a.size {
if a.values != nil {
e, err := readGGUFV1String(llm, r)
if err != nil {
return nil, err
}
a.values[i] = e
} else {
discardGGUFString(llm, r)
}
}
return a, nil
}
func discardGGUFString(llm *gguf, r io.Reader) error { func discardGGUFString(llm *gguf, r io.Reader) error {
buf := llm.scratch[:8] buf := llm.scratch[:8]
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
@@ -352,78 +365,44 @@ func writeGGUFString(w io.Writer, s string) error {
return err return err
} }
type array struct { func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
size int for i := range a.size {
values []any
}
func (a *array) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
n, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, 0, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
e, err = readGGUFV1String(llm, r)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil {
return nil, err
}
if a.values != nil { if a.values != nil {
e, err := readGGUFString(llm, r)
if err != nil {
return nil, err
}
a.values[i] = e a.values[i] = e
} else {
discardGGUFString(llm, r)
} }
} }
return a, nil return a, nil
} }
func readGGUFArray(llm *gguf, r io.Reader) (*array, error) { type array[T any] struct {
if llm.Version == 1 { // size is the actual size of the array
return readGGUFV1Array(llm, r) size int
}
// values is the array of values. this is nil if the array is larger than configured maxSize
values []T
}
func (a *array[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func newArray[T any](size, maxSize int) *array[T] {
a := array[T]{size: size}
if maxSize < 0 || size <= maxSize {
a.values = make([]T, size)
}
return &a
}
func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
t, err := readGGUF[uint32](llm, r) t, err := readGGUF[uint32](llm, r)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -434,45 +413,55 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
return nil, err return nil, err
} }
a := &array{size: int(n)} switch t {
if llm.canCollectArray(int(n)) { case ggufTypeUint8:
a.values = make([]any, int(n)) a := newArray[uint8](int(n), llm.maxArraySize)
} return readGGUFArrayData(llm, r, a)
case ggufTypeInt8:
for i := range n { a := newArray[int8](int(n), llm.maxArraySize)
var e any return readGGUFArrayData(llm, r, a)
switch t { case ggufTypeUint16:
case ggufTypeUint8: a := newArray[uint16](int(n), llm.maxArraySize)
e, err = readGGUF[uint8](llm, r) return readGGUFArrayData(llm, r, a)
case ggufTypeInt8: case ggufTypeInt16:
e, err = readGGUF[int8](llm, r) a := newArray[int16](int(n), llm.maxArraySize)
case ggufTypeUint16: return readGGUFArrayData(llm, r, a)
e, err = readGGUF[uint16](llm, r) case ggufTypeUint32:
case ggufTypeInt16: a := newArray[uint32](int(n), llm.maxArraySize)
e, err = readGGUF[int16](llm, r) return readGGUFArrayData(llm, r, a)
case ggufTypeUint32: case ggufTypeInt32:
e, err = readGGUF[uint32](llm, r) a := newArray[int32](int(n), llm.maxArraySize)
case ggufTypeInt32: return readGGUFArrayData(llm, r, a)
e, err = readGGUF[int32](llm, r) case ggufTypeUint64:
case ggufTypeUint64: a := newArray[uint64](int(n), llm.maxArraySize)
e, err = readGGUF[uint64](llm, r) return readGGUFArrayData(llm, r, a)
case ggufTypeInt64: case ggufTypeInt64:
e, err = readGGUF[int64](llm, r) a := newArray[int64](int(n), llm.maxArraySize)
case ggufTypeFloat32: return readGGUFArrayData(llm, r, a)
e, err = readGGUF[float32](llm, r) case ggufTypeFloat32:
case ggufTypeFloat64: a := newArray[float32](int(n), llm.maxArraySize)
e, err = readGGUF[float64](llm, r) return readGGUFArrayData(llm, r, a)
case ggufTypeBool: case ggufTypeFloat64:
e, err = readGGUF[bool](llm, r) a := newArray[float64](int(n), llm.maxArraySize)
case ggufTypeString: return readGGUFArrayData(llm, r, a)
if a.values != nil { case ggufTypeBool:
e, err = readGGUFString(llm, r) a := newArray[bool](int(n), llm.maxArraySize)
} else { return readGGUFArrayData(llm, r, a)
err = discardGGUFString(llm, r) case ggufTypeString:
} a := newArray[string](int(n), llm.maxArraySize)
default: if llm.Version == 1 {
return nil, fmt.Errorf("invalid array type: %d", t) return readGGUFV1StringsData(llm, r, a)
} }
return readGGUFStringsData(llm, r, a)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
}
func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
for i := range a.size {
e, err := readGGUF[T](llm, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -627,8 +616,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
return err return err
} }
for i := range len(t.Shape) { for _, n := range t.Shape {
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil { if err := binary.Write(ws, binary.LittleEndian, n); err != nil {
return err return err
} }
} }

View File

@@ -21,6 +21,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct { type Causal struct {
DType ml.DType DType ml.DType
windowSize int32 windowSize int32
chunkSize int32
opts CausalOptions opts CausalOptions
@@ -97,6 +98,17 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
} }
} }
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
return &Causal{
windowSize: math.MaxInt32,
chunkSize: chunkSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil { if c.config == nil {
var config ml.CacheConfig var config ml.CacheConfig
@@ -300,6 +312,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) || (enabled && c.cells[j].pos > c.curPositions[i]) ||
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
c.cells[j].pos < c.curPositions[i]-c.windowSize { c.cells[j].pos < c.curPositions[i]-c.windowSize {
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
} }

View File

@@ -86,6 +86,64 @@ func TestSWA(t *testing.T) {
testCache(t, backend, cache, tests) testCache(t, backend, cache, tests)
} }
func TestChunkedAttention(t *testing.T) {
cache := NewChunkedAttentionCache(2, nil)
defer cache.Close()
var b testBackend
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
testCache(
t, &b, cache,
[]testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, x,
0, 0, x, x,
x, x, 0, x,
x, x, 0, 0,
},
},
{
name: "SecondBatch",
in: []float32{5, 6, 7},
inShape: []int{1, 1, 3},
seqs: []int{0, 0, 0},
pos: []int32{4, 5, 6},
expected: []float32{1, 2, 3, 4, 5, 6, 7},
expectedShape: []int{1, 1, 7},
expectedMask: []float32{
x, x, x, x, 0, x, x,
x, x, x, x, 0, 0, x,
x, x, x, x, x, x, 0,
},
},
{
name: "ThirdBatch",
in: []float32{8, 9},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{7, 8},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
expectedShape: []int{1, 1, 9},
expectedMask: []float32{
x, x, x, x, x, x, 0, 0, x,
x, x, x, x, x, x, x, x, 0,
},
},
},
)
}
func TestSequences(t *testing.T) { func TestSequences(t *testing.T) {
backend := &testBackend{} backend := &testBackend{}
cache := NewCausalCache(nil) cache := NewCausalCache(nil)
@@ -293,8 +351,16 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context.Forward(out, mask).Compute(out, mask) context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) { if !slices.Equal(out.Floats(), test.expected) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask) t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
}
if !slices.Equal(out.Shape(), test.expectedShape) {
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
}
if !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
} }
}) })
} }

View File

@@ -414,7 +414,7 @@ func projectorMemoryRequirements(filename string) (weights, graphSize uint64) {
} }
defer file.Close() defer file.Close()
ggml, _, err := ggml.Decode(file, 0) ggml, _, err := ggml.Decode(file, 1024)
if err != nil { if err != nil {
return 0, 0 return 0, 0
} }

View File

@@ -133,6 +133,7 @@ type Tensor interface {
Mul(ctx Context, t2 Tensor) Tensor Mul(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor Mulmat(ctx Context, t2 Tensor) Tensor
MulmatFullPrec(ctx Context, t2 Tensor) Tensor MulmatFullPrec(ctx Context, t2 Tensor) Tensor
MulmatID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor Softmax(ctx Context) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
@@ -150,6 +151,7 @@ type Tensor interface {
Tanh(ctx Context) Tensor Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor GELU(ctx Context) Tensor
SILU(ctx Context) Tensor SILU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor
Reshape(ctx Context, shape ...int) Tensor Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor
@@ -168,6 +170,8 @@ type Tensor interface {
Rows(ctx Context, t2 Tensor) Tensor Rows(ctx Context, t2 Tensor) Tensor
Copy(ctx Context, t2 Tensor) Tensor Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor Duplicate(ctx Context) Tensor
TopK(ctx Context, k int) Tensor
} }
// ScaledDotProductAttention implements a fused attention // ScaledDotProductAttention implements a fused attention

View File

@@ -884,17 +884,32 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
} }
} }
func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if b != nil { if w != nil {
tt = tt.Add(ctx, b) tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
if b != nil {
tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
}
} }
return tt return &Tensor{b: t.b, t: tt}
} }
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil {
tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
}
return &Tensor{b: t.b, t: tt}
} }
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -995,6 +1010,13 @@ func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
} }
} }
func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
if len(shape) != 4 { if len(shape) != 4 {
panic("expected 4 dimensions") panic("expected 4 dimensions")
@@ -1158,3 +1180,10 @@ func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
t: C.ggml_dup(ctx.(*Context).ctx, t.t), t: C.ggml_dup(ctx.(*Context).ctx, t.t),
} }
} }
func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
}
}

View File

@@ -42,7 +42,7 @@ func New(c fs.Config) (model.Model, error) {
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
}, },

View File

@@ -59,7 +59,7 @@ func New(c fs.Config) (model.Model, error) {
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(1), EOS: int32(1),

View File

@@ -49,7 +49,7 @@ func newTextModel(c fs.Config) *TextModel {
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"), Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
}, },

View File

@@ -41,7 +41,7 @@ func New(c fs.Config) (model.Model, error) {
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"), Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),

View File

@@ -0,0 +1,189 @@
package llama4
import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.BytePairEncoding
ImageProcessor
*VisionModel `gguf:"v,vision"`
*Projector `gguf:"mm"`
*TextModel
}
type Projector struct {
Linear1 *nn.Linear `gguf:"linear_1"`
}
func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
return p.Linear1.Forward(ctx, visionOutputs)
}
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewChunkedAttentionCache(int32(c.Uint("attention.chunk_size", 8192)), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) < 1 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
pixelsLocal, pixelsGlobal, size, err := m.ProcessImage(img)
if err != nil {
return nil, err
}
tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
if err != nil {
return nil, err
}
ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, ratioW, size.Y, m.numChannels).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW*size.Y/ratioH, ratioH, ratioW, m.numChannels).Permute(ctx, 0, 3, 2, 1).Contiguous(ctx)
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, size.Y/ratioH, m.numChannels, ratioH*ratioW)
pixelValues := tilesLocal
if len(pixelsGlobal) > 0 {
tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
if err != nil {
return nil, err
}
pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
}
type chunks struct {
*Model
ml.Tensor
aspectRatio image.Point
dataOnce sync.Once
data []float32
}
type chunk struct {
*chunks
s, n int
}
func (r *chunk) floats() []float32 {
r.dataOnce.Do(func() {
temp := r.Backend().NewContext()
defer temp.Close()
temp.Forward(r.Tensor).Compute(r.Tensor)
r.data = r.Floats()
})
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
continue
}
t := inp.Multimodal.(*chunks)
var imageInputs []input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
var offset int
patchesPerChunk := t.Dim(1)
if t.aspectRatio.Y*t.aspectRatio.X > 1 {
patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
for range t.aspectRatio.Y {
for x := range t.aspectRatio.X {
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if x < t.aspectRatio.X-1 {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
}
offset += patchesPerChunk
}
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
}
}
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
result = append(result, imageInputs...)
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("llama4", New)
}

View File

@@ -0,0 +1,259 @@
package llama4
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
)
type TextAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_factors"`
}
func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope {
query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
}
if opts.useQKNorm {
query = query.RMSNorm(ctx, nil, opts.eps)
key = key.RMSNorm(ctx, nil, opts.eps)
}
if attentionScales != nil && !useRope {
query = query.Mul(ctx, attentionScales)
}
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextExperts struct {
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
experts := routerLogits.TopK(ctx, opts.numExpertsUsed)
scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
}
return nextStates
}
// TextSharedExpert is TextMLP with different tensor names
type TextSharedExpert struct {
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
Up *nn.Linear `gguf:"ffn_up_shexp"`
Down *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextMOE struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Experts *TextExperts
SharedExpert *TextSharedExpert
}
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
routerLogits := moe.Router.Forward(ctx, hiddenStates)
sharedStates := moe.SharedExpert.Forward(ctx, hiddenStates, opts)
routedStates := moe.Experts.Forward(ctx, hiddenStates, routerLogits, opts)
return sharedStates.Add(ctx, routedStates)
}
type TextFeedForward interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor
}
type TextLayer struct {
AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
Attention *TextAttention
FFNNorm *nn.LayerNorm `gguf:"ffn_norm"`
FeedForward TextFeedForward
}
func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
residual := hiddenStates
// self attention
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, useRope, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts)
return residual.Add(ctx, hiddenStates)
}
type TextOptions struct {
hiddenSize int
numHeads, numKVHeads, headDim int
numExperts, numExpertsUsed int
ropeDim int
ropeBase, ropeScale float32
eps float32
interleaveLayerStep int
noRopeInterval int
useQKNorm bool
attentionTemperatureTuning bool
attentionScale float64
attentionFloorScale float64
}
type TextModel struct {
Layers []TextLayer `gguf:"blk"`
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.LayerNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
func newTextModel(c fs.Config) *TextModel {
layers := make([]TextLayer, c.Uint("block_count"))
interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1)
for i := range layers {
if (i+1)%int(interleaveLayerStep) == 0 {
layers[i] = TextLayer{FeedForward: &TextMOE{}}
} else {
layers[i] = TextLayer{FeedForward: &TextMLP{}}
}
}
return &TextModel{
Layers: layers,
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.head_dim", 128)),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
noRopeInterval: int(c.Uint("no_rope_interval", 4)),
useQKNorm: c.Bool("use_qk_norm", true),
attentionTemperatureTuning: c.Bool("attention.temperature_tuning", true),
attentionScale: float64(c.Float("attention.scale", 0.1)),
attentionFloorScale: float64(c.Float("attention.floor_scale", 8192)),
},
}
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
f32s := mi.Multimodal.(*chunk).floats()
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
if err != nil {
panic(err)
}
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}
var attentionScales ml.Tensor
if m.attentionTemperatureTuning {
scales := make([]float32, len(batch.Positions))
for i, p := range batch.Positions {
scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
}
var err error
attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
if err != nil {
panic(err)
}
}
for i, layer := range m.Layers {
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(1)
useChunkedAttention := (i+1)%m.noRopeInterval != 0
if useChunkedAttention {
wc.SetLayerType(0)
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
}

View File

@@ -0,0 +1,256 @@
package llama4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type VisionAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the Pytorch implmentation using half rotations:
//
// cos, sin = torch.cos(freqs), torch.sin(freqs)
// cos = cos.unsqueeze(-1)
// sin = sin.unsqueeze(-1)
// t = t.reshape(*t.shape[:-1], -1, 2)
// t_out = (t * cos) + (_rotate_half(t) * sin)
// t_out = t_out.flatten(3)
//
// Which is equivalent to the Pytorch implementation using complex numbers:
//
// t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
// freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_) # freqs_ci[:,:,None,:]
// freqs_ci = freqs_ci.to(t_.device)
// t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to the 1) the dimensional and 2) the datatype limitations of current backends,
// we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
// t1 = t[..., 0::2]
t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
// t2 = t[..., 1::2]
t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
return cosOut.Add(ctx, sinOut)
}
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))
query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
key = applyVisionRotaryEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx)
hiddenStates = mlp.FC2.Forward(ctx, hiddenStates)
return hiddenStates
}
type VisionLayer struct {
InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
*VisionAttention
PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
*VisionMLP `gguf:"mlp"`
}
func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
residual := hiddenStates
// self attention
hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
// MLP
residual = hiddenStates
hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type VisionAdapter struct {
FC1 *nn.Linear `gguf:"mlp.fc1"`
FC2 *nn.Linear `gguf:"mlp.fc2"`
}
func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
patches := hiddenStates.Dim(1)
patchSize := int(math.Sqrt(float64(patches)))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))
channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)
hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)
hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)
hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)
return hiddenStates
}
type VisionOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
ropeTheta float32
eps float32
pixelShuffleRatio float32
}
type PatchEmbedding struct {
*nn.Linear
}
func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
return p.Linear.Forward(ctx, hiddenStates)
}
type VisionModel struct {
Layers []VisionLayer `gguf:"blk"`
*PatchEmbedding `gguf:"patch_embedding"`
ClassEmbedding ml.Tensor `gguf:"class_embedding"`
PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`
LayerNormPre *nn.LayerNorm `gguf:"layernorm_pre"`
LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`
*VisionAdapter `gguf:"vision_adapter"`
*VisionOptions
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionLayer, c.Uint("vision.block_count")),
VisionOptions: &VisionOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
ropeTheta: float32(c.Float("vision.rope.freq_base")),
eps: c.Float("vision.layer_norm_epsilon"),
pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
},
}
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)
hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)
hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)
cos, sin := m.rotaryEmbedding(ctx)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
}
hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
hiddenStates = hiddenStates.Unpad(ctx, 0, 1, 0, 0)
hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
return hiddenStates
}
// floorDiv is a helper function to perform floor division. This mimics PyTorch's div(round_mode='floor') function
// which in turn mimics Python's // operator.
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
if b == 0 {
panic("division by zero")
}
if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
return a / b
}
return a/b - 1
}
func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
patchesPerSide := m.imageSize / m.patchSize
numPatches := patchesPerSide*patchesPerSide + 1
headDim := m.hiddenSize / m.numHeads
freqDim := headDim / 2
freqs := make([]float32, numPatches*freqDim)
for i := range numPatches - 1 {
for j := 0; j < freqDim; j += 2 {
positionX := i*freqDim/2 + j/2
positionY := (i+numPatches)*freqDim/2 + j/2
ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
}
}
ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
if err != nil {
panic(err)
}
ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}

View File

@@ -0,0 +1,167 @@
package llama4
import (
"cmp"
"image"
"math"
"slices"
"sort"
"golang.org/x/image/draw"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels, maxUpscalingSize int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels", 3)),
maxUpscalingSize: int(c.Uint("vision.max_upscaling_size", 448)),
}
}
func factors(n int) []int {
var result []int
seen := make(map[int]bool)
for i := 1; i <= n/2; i++ {
if n%i == 0 && !seen[i] {
result = append(result, i)
seen[i] = true
}
}
result = append(result, n)
sort.Ints(result)
return result
}
func (p ImageProcessor) supportedResolutions() []image.Point {
var resolutions []image.Point
aspectMap := make(map[float64][]image.Point)
for i := p.patchSize; i >= 1; i-- {
for _, f := range factors(i) {
x := f
y := i / f
k := float64(y) / float64(x)
aspectMap[k] = append(aspectMap[k], image.Point{x, y})
}
}
for _, v := range aspectMap {
for _, i := range v {
resolutions = append(resolutions, image.Point{i.X * p.imageSize, i.Y * p.imageSize})
}
}
return resolutions
}
func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []image.Point, resizeToMaxCanvas bool) image.Point {
w, h := img.X, img.Y
scales := make([]float64, len(possibleResolutions))
for i, res := range possibleResolutions {
scaleW := float64(res.X) / float64(w)
scaleH := float64(res.Y) / float64(h)
scale := math.Min(scaleW, scaleH)
scales[i] = scale
}
minAboveOne := func(scales []float64) (float64, bool) {
min := math.MaxFloat64
found := false
for _, s := range scales {
if s >= 1.0 && s < min {
min = s
found = true
}
}
return min, found
}
bestScale, ok := minAboveOne(scales)
if resizeToMaxCanvas || !ok {
bestScale = slices.Max(scales)
}
var bestOptions []image.Point
for i, scale := range scales {
if math.Abs(scale-bestScale) < 1e-6 {
bestOptions = append(bestOptions, possibleResolutions[i])
}
}
var chosenResolution image.Point
if len(bestOptions) > 1 {
chosenResolution = slices.MinFunc(bestOptions, func(a, b image.Point) int {
return cmp.Compare(a.X*a.Y, b.X*b.Y)
})
} else {
chosenResolution = bestOptions[0]
}
return chosenResolution
}
func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Point {
scaleW := float64(targetRes.X) / float64(imageRes.X)
scaleH := float64(targetRes.Y) / float64(imageRes.Y)
var newRes image.Point
if scaleW < scaleH {
newRes = image.Point{
targetRes.X,
int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
}
} else {
newRes = image.Point{
int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
targetRes.Y,
}
}
return newRes
}
func (p ImageProcessor) pad(src image.Image, outputSize image.Point) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, outputSize.X, outputSize.Y))
draw.Draw(dst, src.Bounds(), src, image.Point{}, draw.Over)
return dst
}
func (p ImageProcessor) ProcessImage(img image.Image) (pixelsLocal, pixelsGlobal []float32, targetSize image.Point, _ error) {
img = imageproc.Composite(img)
targetSize = p.bestResolution(img.Bounds().Max, p.supportedResolutions(), false)
targetSizeWithoutDistortion := targetSize
if p.maxUpscalingSize > 0 {
targetSizeWithoutDistortion = p.maxResolution(img.Bounds().Max, targetSize)
targetSizeWithoutDistortion.X = min(max(img.Bounds().Max.X, p.maxUpscalingSize), targetSize.X)
targetSizeWithoutDistortion.Y = min(max(img.Bounds().Max.Y, p.maxUpscalingSize), targetSize.Y)
}
newSizeWithoutDistortion := p.maxResolution(img.Bounds().Max, targetSizeWithoutDistortion)
padded := p.pad(imageproc.Resize(img, newSizeWithoutDistortion, imageproc.ResizeBilinear), targetSize)
pixelsLocal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)
if targetSize.X/p.imageSize*targetSize.Y/p.imageSize > 1 {
padded := imageproc.Resize(img, image.Point{p.imageSize, p.imageSize}, imageproc.ResizeBilinear)
pixelsGlobal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)
}
return pixelsLocal, pixelsGlobal, targetSize, nil
}

View File

@@ -0,0 +1,300 @@
package llama4
import (
"cmp"
"image"
"image/color"
"reflect"
"slices"
"testing"
gocmp "github.com/google/go-cmp/cmp"
)
func TestFactors(t *testing.T) {
tests := []struct {
name string
input int
expected []int
}{
{
name: "factors of 1",
input: 1,
expected: []int{1},
},
{
name: "factors of 2",
input: 2,
expected: []int{1, 2},
},
{
name: "factors of 6",
input: 6,
expected: []int{1, 2, 3, 6},
},
{
name: "factors of 28",
input: 28,
expected: []int{1, 2, 4, 7, 14, 28},
},
{
name: "factors of 49",
input: 49,
expected: []int{1, 7, 49},
},
{
name: "factors of 97 (prime)",
input: 97,
expected: []int{1, 97},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := factors(tt.input)
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("factors(%d) = %v; want %v", tt.input, actual, tt.expected)
}
})
}
}
func TestSupportedResolutions(t *testing.T) {
expectedResolutions := []image.Point{
{X: 3360, Y: 336},
{X: 672, Y: 2688},
{X: 336, Y: 1344},
{X: 336, Y: 4032},
{X: 1008, Y: 1344},
{X: 1344, Y: 1008},
{X: 336, Y: 1680},
{X: 1680, Y: 336},
{X: 336, Y: 5040},
{X: 4032, Y: 336},
{X: 2352, Y: 336},
{X: 2688, Y: 672},
{X: 1344, Y: 336},
{X: 5376, Y: 336},
{X: 2352, Y: 672},
{X: 672, Y: 1008},
{X: 1008, Y: 672},
{X: 336, Y: 5376},
{X: 1680, Y: 1008},
{X: 5040, Y: 336},
{X: 336, Y: 3024},
{X: 3024, Y: 336},
{X: 336, Y: 2688},
{X: 672, Y: 1344},
{X: 336, Y: 672},
{X: 336, Y: 2352},
{X: 2016, Y: 672},
{X: 1008, Y: 336},
{X: 336, Y: 3360},
{X: 336, Y: 4368},
{X: 1008, Y: 1680},
{X: 336, Y: 4704},
{X: 4704, Y: 336},
{X: 1344, Y: 672},
{X: 672, Y: 336},
{X: 2688, Y: 336},
{X: 3696, Y: 336},
{X: 2016, Y: 336},
{X: 1344, Y: 1344},
{X: 1008, Y: 1008},
{X: 672, Y: 672},
{X: 336, Y: 336},
{X: 4368, Y: 336},
{X: 672, Y: 2016},
{X: 336, Y: 1008},
{X: 336, Y: 3696},
{X: 672, Y: 1680},
{X: 1680, Y: 672},
{X: 336, Y: 2016},
{X: 672, Y: 2352},
}
sortResolutionFunc := func(a, b image.Point) int {
return cmp.Or(cmp.Compare(a.X, b.X), cmp.Compare(a.Y, b.Y))
}
slices.SortStableFunc(expectedResolutions, sortResolutionFunc)
imgProc := ImageProcessor{
imageSize: 336,
patchSize: 16,
numChannels: 3,
maxUpscalingSize: 448,
}
actualResolutions := imgProc.supportedResolutions()
slices.SortStableFunc(actualResolutions, sortResolutionFunc)
if diff := gocmp.Diff(expectedResolutions, actualResolutions); diff != "" {
t.Errorf("supportedResolutions() mismatch (-want +got):\n%s", diff)
}
}
func TestBestResolution(t *testing.T) {
tests := []struct {
name string
size image.Point
resolutions []image.Point
max bool
expected image.Point
}{
{
"normal",
image.Point{800, 600},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{800, 600},
},
{
"max",
image.Point{800, 600},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
true,
image.Point{1600, 1200},
},
{
"mid",
image.Point{1000, 700},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{1024, 768},
},
{
"smol",
image.Point{100, 100},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{300, 200},
},
{
"huge",
image.Point{10000, 10000},
[]image.Point{
{300, 200},
{640, 480},
{800, 600},
{1024, 768},
{1600, 1200},
},
false,
image.Point{1600, 1200},
},
}
p := ImageProcessor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := p.bestResolution(tt.size, tt.resolutions, tt.max)
if diff := gocmp.Diff(tt.expected, actual); diff != "" {
t.Errorf("best resolution mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestMaxResolution(t *testing.T) {
tests := []struct {
name string
origRes image.Point
targetRes image.Point
expected image.Point
}{
{
"normal",
image.Point{800, 600},
image.Point{800, 600},
image.Point{800, 600},
},
{
"skew",
image.Point{800, 600},
image.Point{1100, 700},
image.Point{933, 700},
},
}
p := ImageProcessor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := p.maxResolution(tt.origRes, tt.targetRes)
if !reflect.DeepEqual(actual, tt.expected) {
t.Errorf("max resolution; got %v want %v", actual, tt.expected)
}
})
}
}
func TestProcessImage(t *testing.T) {
imgProc := ImageProcessor{
imageSize: 336,
patchSize: 16,
numChannels: 3,
maxUpscalingSize: 448,
}
generateImage := func(seed int) image.Image {
width, height := 20, 10
img := image.NewRGBA(image.Rect(0, 0, width, height))
for x := range width {
// Use the seed to vary color generation
r := uint8((seed + x*11) % 256)
g := uint8((seed + x*17) % 256)
b := uint8((seed + x*23) % 256)
c := color.RGBA{R: r, G: g, B: b, A: 255}
for y := range height {
img.Set(x, y, c)
}
}
return img
}
pixelsLocal, pixelsGlobal, targetSize, err := imgProc.ProcessImage(generateImage(12))
if err != nil {
t.Error(err)
}
if n := len(pixelsLocal); n != 336*336*3 {
t.Errorf("unexpected size of f32s: %d", n)
}
if n := len(pixelsGlobal); n > 0 {
t.Errorf("unexpected size of f32s: %d", n)
}
if !targetSize.Eq(image.Point{336, 336}) {
t.Errorf("unexpected target size: %v", targetSize)
}
}

View File

@@ -152,7 +152,7 @@ func NewTextModel(c fs.Config) (*TextModel, error) {
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"), Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),

View File

@@ -43,7 +43,7 @@ func New(c fs.Config) (model.Model, error) {
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"), Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),

View File

@@ -177,7 +177,7 @@ type TextDecoder struct {
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers { for i, layer := range d.Layers {
layerType := selfAttentionLayer layerType := selfAttentionLayer
if slices.Contains(opts.crossAttentionLayers, uint32(i)) { if slices.Contains(opts.crossAttentionLayers, int32(i)) {
layerType = crossAttentionLayer layerType = crossAttentionLayer
} }
@@ -202,7 +202,7 @@ type TextModelOptions struct {
eps, ropeBase, ropeScale float32 eps, ropeBase, ropeScale float32
ropeDim uint32 ropeDim uint32
crossAttentionLayers []uint32 crossAttentionLayers []int32
} }
type TextModel struct { type TextModel struct {
@@ -225,7 +225,7 @@ func newTextModel(c fs.Config) *TextModel {
var decoderLayers []TextDecoderLayer var decoderLayers []TextDecoderLayer
for i := range c.Uint("block_count") { for i := range c.Uint("block_count") {
var textDecoderLayer TextDecoderLayer var textDecoderLayer TextDecoderLayer
if slices.Contains(c.Uints("attention.cross_attention_layers"), i) { if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) {
textDecoderLayer = &TextCrossAttentionDecoderLayer{} textDecoderLayer = &TextCrossAttentionDecoderLayer{}
} else { } else {
textDecoderLayer = &TextSelfAttentionDecoderLayer{} textDecoderLayer = &TextSelfAttentionDecoderLayer{}
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"), ropeDim: c.Uint("rope.dimension_count"),
crossAttentionLayers: c.Uints("attention.cross_attention_layers"), crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
}, },
} }
} }

View File

@@ -96,10 +96,10 @@ type VisionEncoder struct {
Layers []VisionEncoderLayer Layers []VisionEncoderLayer
} }
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) { func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []int32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
var intermediateHiddenStates []ml.Tensor var intermediateHiddenStates []ml.Tensor
for i, layer := range e.Layers { for i, layer := range e.Layers {
if slices.Contains(intermediateLayersIndices, uint32(i)) { if slices.Contains(intermediateLayersIndices, int32(i)) {
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...)) intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
} }
@@ -154,7 +154,7 @@ type VisionModelOptions struct {
imageSize, patchSize int imageSize, patchSize int
eps float32 eps float32
intermediateLayersIndices []uint32 intermediateLayersIndices []int32
} }
type VisionModel struct { type VisionModel struct {
@@ -229,7 +229,7 @@ func newVisionModel(c fs.Config) *VisionModel {
eps: c.Float("vision.attention.layer_norm_epsilon"), eps: c.Float("vision.attention.layer_norm_epsilon"),
intermediateLayersIndices: c.Uints("vision.intermediate_layers_indices"), intermediateLayersIndices: c.Ints("vision.intermediate_layers_indices"),
}, },
} }
} }

View File

@@ -4,6 +4,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/mllama"
) )

View File

@@ -37,7 +37,7 @@ type TextProcessor interface {
type Vocabulary struct { type Vocabulary struct {
Values []string Values []string
Types []uint32 Types []int32
Scores []float32 Scores []float32
Merges []string Merges []string

View File

@@ -35,9 +35,9 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
sentencepiece.ModelProto_SentencePiece_CONTROL, sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED, sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE: sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, uint32(t)) v.Types = append(v.Types, int32(t))
default: default:
tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL) tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file // todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and // - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed // <end_of_turn> tokens aren't processed
@@ -124,7 +124,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
"<0xC3>", "<0xC3>",
"<0xA3>", "<0xA3>",
}, },
Types: []uint32{ Types: []int32{
TOKEN_TYPE_NORMAL, TOKEN_TYPE_NORMAL,
TOKEN_TYPE_BYTE, TOKEN_TYPE_BYTE,
TOKEN_TYPE_BYTE, TOKEN_TYPE_BYTE,

View File

@@ -28,7 +28,7 @@ func llama(t testing.TB) BytePairEncoding {
t.Fatal(err) t.Fatal(err)
} }
types := make([]uint32, len(vocab)) types := make([]int32, len(vocab))
tokens := make([]string, len(vocab)) tokens := make([]string, len(vocab))
for token, id := range vocab { for token, id := range vocab {
tokens[id] = token tokens[id] = token

View File

@@ -74,7 +74,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
t.Fatal(err) t.Fatal(err)
} }
types := make([]uint32, len(vocab))
tokens := make([]string, len(vocab)) tokens := make([]string, len(vocab))
for token, id := range vocab { for token, id := range vocab {
tokens[id] = token tokens[id] = token
@@ -86,7 +85,7 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
``, ``,
&model.Vocabulary{ &model.Vocabulary{
Values: tokens, Values: tokens,
Types: types, Types: make([]int32, len(vocab)),
Merges: merges, Merges: merges,
}, },
) )

View File

@@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
} }
defer bin.Close() defer bin.Close()
f, _, err := ggml.Decode(bin, 0) f, _, err := ggml.Decode(bin, 1024)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -457,7 +457,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err return nil, err
} }
f, _, err := ggml.Decode(temp, 0) f, _, err := ggml.Decode(temp, 1024)
if err != nil { if err != nil {
slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err)) slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
return nil, err return nil, err
@@ -499,7 +499,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var offset int64 var offset int64
for offset < stat.Size() { for offset < stat.Size() {
f, n, err := ggml.Decode(blob, 0) f, n, err := ggml.Decode(blob, 1024)
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {

View File

@@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability {
if err == nil { if err == nil {
defer r.Close() defer r.Close()
f, _, err := ggml.Decode(r, 0) f, _, err := ggml.Decode(r, 1024)
if err == nil { if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding) capabilities = append(capabilities, model.CapabilityEmbedding)

View File

@@ -73,8 +73,13 @@ type statusCodeRecorder struct {
func (r *statusCodeRecorder) WriteHeader(status int) { func (r *statusCodeRecorder) WriteHeader(status int) {
if r._status == 0 { if r._status == 0 {
r._status = status r._status = status
r.ResponseWriter.WriteHeader(status)
} }
r.ResponseWriter.WriteHeader(status) }
func (r *statusCodeRecorder) Write(b []byte) (int, error) {
r._status = r.status()
return r.ResponseWriter.Write(b)
} }
var ( var (

View File

@@ -64,7 +64,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
} }
defer blob.Close() defer blob.Close()
f, _, err := ggml.Decode(blob, 0) f, _, err := ggml.Decode(blob, 1024)
if err != nil { if err != nil {
return nil, err return nil, err
} }

226
server/python_tools.go Normal file
View File

@@ -0,0 +1,226 @@
package server
import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/api"
)
var (
pythonFuncRegex = regexp.MustCompile(`(\w+)\((.*?)\)`)
braces = map[rune]rune{
'[': ']',
'{': '}',
'(': ')',
'"': '"',
'\'': '\'',
}
)
// parsePythonValue converts a Python value string to its appropriate Go type
func parsePythonValue(value string) (any, error) {
value = strings.TrimSpace(value)
// string
if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) ||
(strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) {
// Remove quotes
result := value[1 : len(value)-1]
return result, nil
}
// bool
switch strings.ToLower(value) {
case "true":
return true, nil
case "false":
return false, nil
case "none":
return nil, nil
}
// int
if i, err := strconv.Atoi(value); err == nil {
return i, nil
}
// float
if f, err := strconv.ParseFloat(value, 64); err == nil {
return f, nil
}
// list
if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") {
listStr := value[1 : len(value)-1]
var list []any
stack := []rune{}
start := 0
for i, char := range listStr {
if len(stack) != 0 && char == braces[stack[len(stack)-1]] {
stack = stack[:len(stack)-1]
} else if _, ok := braces[char]; ok {
stack = append(stack, char)
}
if len(stack) == 0 && (char == ',' || i == len(listStr)-1) {
end := i
if i == len(listStr)-1 {
end = i + 1
}
item := strings.TrimSpace(listStr[start:end])
if val, err := parsePythonValue(item); err == nil {
list = append(list, val)
} else {
return nil, fmt.Errorf("invalid list item: %s", item)
}
start = i + 1
}
}
return list, nil
}
// dictionary
if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") && strings.Contains(value, ":") {
dictStr := value[1 : len(value)-1]
dict := make(map[any]any)
stack := []rune{}
start := 0
for i, char := range dictStr {
if len(stack) != 0 && char == braces[stack[len(stack)-1]] {
stack = stack[:len(stack)-1]
} else if _, ok := braces[char]; ok {
stack = append(stack, char)
}
if len(stack) == 0 && (char == ',' || i == len(dictStr)-1) {
end := i
if i == len(dictStr)-1 {
end = i + 1
}
item := strings.TrimSpace(dictStr[start:end])
kv := strings.SplitN(item, ":", 2)
if len(kv) != 2 {
return nil, fmt.Errorf("invalid dictionary key-value pair: %s", item)
}
key, err := parsePythonValue(strings.TrimSpace(kv[0]))
if err != nil {
return nil, fmt.Errorf("invalid dictionary key: %s", kv[0])
}
val, err := parsePythonValue(strings.TrimSpace(kv[1]))
if err != nil {
return nil, fmt.Errorf("invalid dictionary value: %s", kv[1])
}
dict[key] = val
start = i + 1
}
}
return dict, nil
}
// sets (stored as lists)
if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") {
setStr := value[1 : len(value)-1]
var list []any
stack := []rune{}
start := 0
for i, char := range setStr {
if len(stack) != 0 && char == braces[stack[len(stack)-1]] {
stack = stack[:len(stack)-1]
} else if _, ok := braces[char]; ok {
stack = append(stack, char)
}
if len(stack) == 0 && (char == ',' || i == len(setStr)-1) {
end := i
if i == len(setStr)-1 {
end = i + 1
}
item := strings.TrimSpace(setStr[start:end])
if val, err := parsePythonValue(item); err == nil {
list = append(list, val)
} else {
return nil, fmt.Errorf("invalid set item: %s", item)
}
start = i + 1
}
}
return list, nil
}
return nil, fmt.Errorf("invalid Python value: %s", value)
}
// parsePythonToolCall parses Python function calls from a string
// it supports keyword arguments, as well as multiple functions in a single string
func parsePythonToolCall(s string) ([]api.ToolCall, error) {
matches := pythonFuncRegex.FindAllStringSubmatchIndex(s, -1)
if len(matches) == 0 {
return nil, fmt.Errorf("no Python function calls found")
}
var toolCalls []api.ToolCall
for _, match := range matches {
name := s[match[2]:match[3]]
args := s[match[4]:match[5]]
var arguments api.ToolCallFunctionArguments
if len(args) == 0 {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: name,
},
})
continue
}
start := 0
stack := []rune{}
for i, char := range args {
if len(stack) != 0 && char == braces[stack[len(stack)-1]] {
stack = stack[:len(stack)-1]
} else if _, ok := braces[char]; ok {
stack = append(stack, char)
}
if len(stack) == 0 && (char == ',' || i == len(args)-1) {
end := i
if i == len(args)-1 {
end = i + 1
}
kv := strings.SplitN(args[start:end], "=", 2)
if len(kv) == 2 {
key := strings.TrimSpace(kv[0])
valueStr := strings.TrimSpace(kv[1])
// Parse the value into appropriate type
value, err := parsePythonValue(valueStr)
if err != nil {
return nil, fmt.Errorf("failed to parse value for key %q: %v", key, err)
}
arguments[key] = value
} else {
return nil, fmt.Errorf("invalid argument format: %q", args[start:end])
}
start = i + 1
}
}
if len(arguments) > 0 {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: name,
Arguments: arguments,
},
})
}
}
if len(toolCalls) > 0 {
return toolCalls, nil
}
return nil, fmt.Errorf("failed to parse any valid tool calls")
}

269
server/python_tools_test.go Normal file
View File

@@ -0,0 +1,269 @@
package server
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestParsePythonFunctionCall(t *testing.T) {
t1 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "San Francisco, CA",
"format": "fahrenheit",
},
},
}
t2 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_forecast",
Arguments: api.ToolCallFunctionArguments{
"days": 5,
"location": "Seattle",
},
},
}
t3 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"list": []any{1, 2, 3},
"int": -1,
"float": 1.23,
"string": "hello",
},
},
}
t4 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_weather",
},
}
cases := []struct {
name string
input string
want []api.ToolCall
err bool
}{
{
name: "malformed function call - missing closing paren",
input: "get_current_weather(location=\"San Francisco\"",
err: true,
},
{
name: "empty function call",
input: "get_current_weather()",
want: []api.ToolCall{t4},
err: false,
},
{
name: "single valid function call",
input: "get_current_weather(location=\"San Francisco, CA\", format=\"fahrenheit\")",
want: []api.ToolCall{t1},
},
{
name: "multiple valid function calls",
input: "get_current_weather(location=\"San Francisco, CA\", format=\"fahrenheit\") get_forecast(days=5, location=\"Seattle\")",
want: []api.ToolCall{t1, t2},
},
{
name: "multiple valid function calls with list",
input: "get_current_weather(list=[1,2,3], int=-1, float=1.23, string=\"hello\")",
want: []api.ToolCall{t3},
},
{
name: "positional arguments not supported",
input: "get_current_weather(1, 2, 3)",
err: true,
},
{
name: "invalid argument format without equals",
input: "get_current_weather(\"San Francisco\")",
err: true,
},
{
name: "nested lists",
input: "get_current_weather(data=[[1,2],[3,4]])",
want: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"data": []any{[]any{1, 2}, []any{3, 4}},
},
},
}},
},
{
name: "boolean and none values",
input: "get_current_weather(active=true, enabled=false, value=None)",
want: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"active": true,
"enabled": false,
"value": nil,
},
},
}},
},
{
name: "single vs double quotes",
input: "get_current_weather(str1='single', str2=\"double\")",
want: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"str1": "single",
"str2": "double",
},
},
}},
},
{
name: "whitespace handling",
input: "get_current_weather( location = \"San Francisco\" , temp = 72 )",
want: []api.ToolCall{{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "San Francisco",
"temp": 72,
},
},
}},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := parsePythonToolCall(tt.input)
if (err != nil) != tt.err {
t.Fatalf("expected error: %v, got error: %v", tt.err, err)
}
if tt.err {
return
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestParsePythonValue(t *testing.T) {
cases := []struct {
name string
input string
want any
err bool
}{
{
name: "string with double quotes",
input: "\"hello\"",
want: "hello",
},
{
name: "string with single quotes",
input: "'world'",
want: "world",
},
{
name: "integer",
input: "42",
want: 42,
},
{
name: "float",
input: "3.14",
want: 3.14,
},
{
name: "boolean true",
input: "True",
want: true,
},
{
name: "boolean false",
input: "False",
want: false,
},
{
name: "none/null",
input: "None",
want: nil,
},
{
name: "simple list",
input: "[1, 2, 3]",
want: []any{1, 2, 3},
},
{
name: "nested list",
input: "[1, [2, 3], 4]",
want: []any{1, []any{2, 3}, 4},
},
{
name: "mixed type list",
input: "[1, \"two\", 3.0, true]",
want: []any{1, "two", 3.0, true},
},
{
name: "invalid list",
input: "[1, 2,",
want: nil,
err: true,
},
{
name: "dictionaries",
input: "{'a': 1, 'b': 2}",
want: map[any]any{"a": 1, "b": 2},
err: false,
},
{
name: "int dictionary",
input: "{1: 2}",
want: map[any]any{1: 2},
err: false,
},
{
name: "mixed type dictionary",
input: "{'a': 1, 'b': 2.0, 'c': True}",
want: map[any]any{"a": 1, "b": 2.0, "c": true},
err: false,
},
{
name: "invalid dictionary - missing closing brace",
input: "{'a': 1, 'b': 2",
want: nil,
err: true,
},
{
name: "sets",
input: "{1, 2, 3}",
want: []any{1, 2, 3},
err: false,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := parsePythonValue(tt.input)
if (err != nil) != tt.err {
t.Fatalf("expected error: %v, got error: %v", tt.err, err)
}
if tt.err {
return
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}