diff --git a/convert/convert.go b/convert/convert.go index a31b0d6c7..26bc72cc2 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -182,8 +182,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error { var conv ModelConverter switch p.Architectures[0] { - case "LlamaForCausalLM", "MistralForCausalLM": + case "LlamaForCausalLM": conv = &llamaModel{} + case "Mistral3ForConditionalGeneration": + conv = &mistral3Model{} case "MixtralForCausalLM": conv = &mixtralModel{} case "GemmaForCausalLM": diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go new file mode 100644 index 000000000..6c224ae4f --- /dev/null +++ b/convert/convert_mistral.go @@ -0,0 +1,190 @@ +package convert + +import ( + "cmp" + "fmt" + "strings" + + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" + + "github.com/ollama/ollama/fs/ggml" +) + +type mistral3Model struct { + ModelParameters + ImageTokenIndex uint32 `json:"image_token_index"` + SpatialMergeSize uint32 `json:"spatial_merge_size"` + VisionFeatureLayer int32 `json:"vision_feature_layer"` + TextModel struct { + NumHiddenLayers uint32 `json:"num_hidden_layers"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + RopeTheta float32 `json:"rope_theta"` + RMSNormEPS float32 `json:"rms_norm_eps"` + HeadDim uint32 `json:"head_dim"` + SlidingWindow *uint32 `json:"sliding_window"` + HiddenAct string `json:"hidden_act"` + VocabSize uint32 `json:"vocab_size"` + } `json:"text_config"` + VisionModel struct { + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + ImageSize uint32 `json:"image_size"` + NumChannels uint32 `json:"num_channels"` + PatchSize uint32 `json:"patch_size"` + HeadDim uint32 `json:"head_dim"` + HiddenAct string `json:"hidden_act"` + RopeTheta float32 `json:"rope_theta"` + } `json:"vision_config"` + MultiModalProjectorBias bool `json:"multimodal_projector_bias"` + ProjectorHiddenAct string `json:"projector_hidden_act"` +} + +func (p *mistral3Model) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "mistral3" + kv["mistral3.vocab_size"] = p.TextModel.VocabSize + + // Text configuration + kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers + kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings + kv["mistral3.embedding_length"] = p.TextModel.HiddenSize + kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize + kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads + kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads + kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS + kv["mistral3.attention.key_length"] = p.TextModel.HeadDim + kv["mistral3.attention.value_length"] = p.TextModel.HeadDim + kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers + kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta + + // Vision configuration + kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers + kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize + kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize + kv["mistral3.vision.attention.head_count"] = 
p.VisionModel.NumAttentionHeads + kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim + kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize + kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize + kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels + // kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value + kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta + + // Multimodal configuration + kv["mistral3.image_token_index"] = p.ImageTokenIndex + kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize + + kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias + + if p.ProjectorHiddenAct != "" { + kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct + } + + return kv +} + +func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor { + var out []ggml.Tensor + + for _, t := range ts { + if !strings.HasPrefix(t.Name(), "v.") { + if strings.HasSuffix(t.Name(), ".attn_q.weight") || + strings.HasSuffix(t.Name(), ".attn_k.weight") { + t.SetRepacker(p.repack) + } + } + + out = append(out, ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + + return out +} + +func (p *mistral3Model) Replacements() []string { + return []string{ + "language_model.model.norm", "output_norm", + "language_model.model.", "", + "language_model.", "", + "layers", "blk", + "transformer.layers", "blk", + "vision_tower", "v", + "ln_pre", "encoder_norm", + "input_layernorm", "attn_norm", + "post_attention_layernorm", "ffn_norm", + "embed_tokens", "token_embd", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + "mlp.down_proj", "ffn_down", + "mlp.gate_proj", "ffn_gate", + "mlp.up_proj", "ffn_up", + "attention.q_proj", "attn_q", + "attention.k_proj", "attn_k", + "attention.v_proj", "attn_v", + "attention.o_proj", "attn_output", + "attention_norm", "attn_norm", + "feed_forward.gate_proj", "ffn_gate", + "feed_forward.down_proj", "ffn_down", + "feed_forward.up_proj", "ffn_up", + "multi_modal_projector", "mm", + "ffn_norm", "ffn_norm", + "lm_head", "output", + } +} + +func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) { + var dims []int + for _, dim := range shape { + dims = append(dims, int(dim)) + } + + var heads uint32 + if strings.HasSuffix(name, ".attn_q.weight") { + heads = p.TextModel.NumAttentionHeads + } else if strings.HasSuffix(name, ".attn_k.weight") { + heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads) + } else { + return nil, fmt.Errorf("unknown tensor for repack: %s", name) + } + + n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil { + return nil, err + } + + if err := n.T(0, 2, 1, 3); err != nil { + return nil, err + } + + if err := n.Reshape(dims...); err != nil { + return nil, err + } + + if err := n.Transpose(); err != nil { + return nil, err + } + + ts, err := native.SelectF32(n, 1) + if err != nil { + return nil, err + } + + var f32s []float32 + for _, t := range ts { + f32s = append(f32s, t...) 
+ } + + return f32s, nil +} diff --git a/convert/reader.go b/convert/reader.go index c1218e66d..904b13a42 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -62,10 +62,7 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) { Pattern string Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error) }{ - {"model-*-of-*.safetensors", parseSafetensors}, - {"model.safetensors", parseSafetensors}, - {"adapters.safetensors", parseSafetensors}, - {"adapter_model.safetensors", parseSafetensors}, + {"*.safetensors", parseSafetensors}, {"pytorch_model-*-of-*.bin", parseTorch}, {"pytorch_model.bin", parseTorch}, {"consolidated.*.pth", parseTorch}, diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index c88583fb8..9431e9cc1 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -134,7 +134,10 @@ func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { } func (kv KV) OllamaEngineRequired() bool { - return kv.Architecture() == "gemma3" + return slices.Contains([]string{ + "gemma3", + "mistral3", + }, kv.Architecture()) } func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T { @@ -638,7 +641,7 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { embeddingLength*numPatches*maxNumTiles + 9*embeddingLength*numPaddedPatches*maxNumTiles + numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount) - case "gemma3": + case "gemma3", "mistral3": graphSize = 4 * (imageSize*imageSize*numChannels + embeddingLength*patchSize + numPatches*numPatches*headCount) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 517e3726d..bd63214cb 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -484,6 +484,14 @@ func (t *testTensor) Floats() []float32 { return out } +func (t *testTensor) Neg(ctx ml.Context) ml.Tensor { + out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) + for i := range out.data { + out.data[i] = -t.data[i] + } + return out +} + func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor) @@ -538,17 +546,15 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, di panic("not implemented") } -func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { +func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { panic("not implemented") } -func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { - panic("not implemented") -} - -func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { - panic("not implemented") -} +func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") } +func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") } func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { panic("not implemented") @@ -600,6 +606,8 @@ func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { panic("not implemented") } +func (t *testTensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { panic("not implemented") } + func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { panic("not implemented") } @@ -612,3 +620,5 @@ func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { 
copy(t2.(*testTensor).data, t.data) return nil } + +func (t *testTensor) Duplicate(ctx ml.Context) ml.Tensor { panic("not implemented") } diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index b443fcd3f..13a0a9888 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -65,6 +65,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_SOLAR, "solar" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1371,6 +1372,22 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + { + LLM_ARCH_MISTRAL3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + } + }, { LLM_ARCH_UNKNOWN, { diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index aad92a5d2..8476ae0a1 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -69,6 +69,7 @@ enum llm_arch { LLM_ARCH_CHAMELEON, LLM_ARCH_SOLAR, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_MISTRAL3, LLM_ARCH_UNKNOWN, }; diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 701830418..db4f2685d 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -1277,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); } break; + case LLM_ARCH_MISTRAL3: break; default: throw std::runtime_error("unsupported model architecture"); } @@ -3537,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); } break; + case LLM_ARCH_MISTRAL3: break; default: throw std::runtime_error("unknown architecture"); } @@ -4015,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_CHAMELEON: case LLM_ARCH_SOLAR: + case LLM_ARCH_MISTRAL3: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index d2f3a5108..ebcbafa1c 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -738,13 +738,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? 
// don't quantize vision stuff - quantize &= name.find("v.blk.") == std::string::npos; - - quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; - quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; - quantize &= name.find("v.patch_embedding.weight") == std::string::npos; - quantize &= name.find("v.position_embedding.weight") == std::string::npos; - quantize &= name.find("v.post_layernorm.weight") == std::string::npos; + quantize &= name.find("v.") == std::string::npos; + quantize &= name.find("mm.") == std::string::npos; // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); diff --git a/llama/patches/0021-gemma3-quantization.patch b/llama/patches/0021-add-model-quantizations.patch similarity index 52% rename from llama/patches/0021-gemma3-quantization.patch rename to llama/patches/0021-add-model-quantizations.patch index 4f6dbc11b..cdc35a412 100644 --- a/llama/patches/0021-gemma3-quantization.patch +++ b/llama/patches/0021-add-model-quantizations.patch @@ -1,17 +1,19 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 14 Mar 2025 16:33:23 -0700 -Subject: [PATCH] gemma3 quantization +Subject: [PATCH] add model quantizations +- gemma3 +- mistral3 --- - src/llama-arch.cpp | 19 +++++++++++++++++++ - src/llama-arch.h | 1 + - src/llama-model.cpp | 7 +++++++ - src/llama-quant.cpp | 9 +++++++++ - 4 files changed, 36 insertions(+) + src/llama-arch.cpp | 36 ++++++++++++++++++++++++++++++++++++ + src/llama-arch.h | 2 ++ + src/llama-model.cpp | 10 ++++++++++ + src/llama-quant.cpp | 4 ++++ + 4 files changed, 52 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp -index b6f20286..b443fcd3 100644 +index b6f20286..13a0a988 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -37,6 +37,7 @@ static const std::map LLM_ARCH_NAMES = { @@ -22,7 +24,15 @@ index b6f20286..b443fcd3 100644 { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, -@@ -804,6 +805,24 @@ static const std::map> LLM_TENSOR_N +@@ -64,6 +65,7 @@ static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_SOLAR, "solar" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, ++ { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, + }; + +@@ -804,6 +806,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, @@ -47,8 +57,31 @@ index b6f20286..b443fcd3 100644 { LLM_ARCH_STARCODER2, { +@@ -1352,6 +1372,22 @@ static const std::map> LLM_TENSOR_N + { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, + }, + }, ++ { ++ LLM_ARCH_MISTRAL3, ++ { ++ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, ++ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, ++ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, ++ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, ++ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, ++ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, ++ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, ++ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, ++ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, ++ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ++ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, ++ } ++ }, + { + LLM_ARCH_UNKNOWN, + { diff --git a/src/llama-arch.h b/src/llama-arch.h -index ec742224..aad92a5d 100644 +index ec742224..8476ae0a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -41,6 +41,7 @@ enum llm_arch { @@ -59,8 +92,16 @@ index ec742224..aad92a5d 100644 LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, 
LLM_ARCH_XVERSE, +@@ -68,6 +69,7 @@ enum llm_arch { + LLM_ARCH_CHAMELEON, + LLM_ARCH_SOLAR, + LLM_ARCH_WAVTOKENIZER_DEC, ++ LLM_ARCH_MISTRAL3, + LLM_ARCH_UNKNOWN, + }; + diff --git a/src/llama-model.cpp b/src/llama-model.cpp -index ab1a07d1..70183041 100644 +index ab1a07d1..db4f2685 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { @@ -73,7 +114,15 @@ index ab1a07d1..70183041 100644 case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { +@@ -1274,6 +1277,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + } break; ++ case LLM_ARCH_MISTRAL3: break; + default: throw std::runtime_error("unsupported model architecture"); + } + +@@ -2537,6 +2541,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; @@ -83,7 +132,23 @@ index ab1a07d1..70183041 100644 case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); -@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { +@@ -3531,6 +3538,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + } break; ++ case LLM_ARCH_MISTRAL3: break; + default: + throw std::runtime_error("unknown architecture"); + } +@@ -4009,6 +4017,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { + case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_CHAMELEON: + case LLM_ARCH_SOLAR: ++ case LLM_ARCH_MISTRAL3: + return LLAMA_ROPE_TYPE_NORM; + + // the pairs of head values are offset by n_rot/2 +@@ -4029,6 +4038,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_PHIMOE: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: @@ -92,21 +157,16 @@ index ab1a07d1..70183041 100644 case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp -index 6eb1da08..d2f3a510 100644 +index 6eb1da08..ebcbafa1 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp -@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: +@@ -737,6 +737,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // This used to be a regex, but has an extreme cost to compile times. bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? 
+ // don't quantize vision stuff -+ quantize &= name.find("v.blk.") == std::string::npos; -+ -+ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; -+ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; -+ quantize &= name.find("v.patch_embedding.weight") == std::string::npos; -+ quantize &= name.find("v.position_embedding.weight") == std::string::npos; -+ quantize &= name.find("v.post_layernorm.weight") == std::string::npos; ++ quantize &= name.find("v.") == std::string::npos; ++ quantize &= name.find("mm.") == std::string::npos; + // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); diff --git a/llama/patches/0022-metal-add-op_neg.patch b/llama/patches/0022-metal-add-op_neg.patch new file mode 100644 index 000000000..a903535f2 --- /dev/null +++ b/llama/patches/0022-metal-add-op_neg.patch @@ -0,0 +1,75 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Wed, 2 Apr 2025 15:26:15 -0700 +Subject: [PATCH] metal: add op_neg + +--- + ggml/src/ggml-metal/ggml-metal.m | 15 +++++++++++++++ + ggml/src/ggml-metal/ggml-metal.metal | 7 +++++++ + 2 files changed, 22 insertions(+) + +diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m +index e4c093f9..d8422f1b 100644 +--- a/ggml/src/ggml-metal/ggml-metal.m ++++ b/ggml/src/ggml-metal/ggml-metal.m +@@ -423,6 +423,7 @@ enum ggml_metal_kernel_type { + GGML_METAL_KERNEL_TYPE_SQRT, + GGML_METAL_KERNEL_TYPE_SIN, + GGML_METAL_KERNEL_TYPE_COS, ++ GGML_METAL_KERNEL_TYPE_NEG, + GGML_METAL_KERNEL_TYPE_SUM_ROWS, + GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, + GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, +@@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); +@@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_ELU: ++ case GGML_UNARY_OP_NEG: + return ggml_is_contiguous(op->src[0]); + default: + return false; +@@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node( + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; ++ case GGML_UNARY_OP_NEG: ++ { ++ id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; ++ ++ [encoder setComputePipelineState:pipeline]; ++ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; ++ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; ++ ++ const int64_t n = ggml_nelements(dst); ++ ++ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; ++ } break; + default: + { + GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index f38909d0..bb0ff668 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -945,6 +945,13 @@ kernel void kernel_cos( + 
dst[tpig] = cos(src0[tpig]); + } + ++kernel void kernel_neg( ++ device const float * src0, ++ device float * dst, ++ uint tpig[[thread_position_in_grid]]) { ++ dst[tpig] = -src0[tpig]; ++} ++ + kernel void kernel_sum_rows( + device const float * src0, + device float * dst, diff --git a/ml/backend.go b/ml/backend.go index b22ba7952..fffc04a48 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -118,6 +118,7 @@ type Tensor interface { Bytes() []byte Floats() []float32 + Neg(ctx Context) Tensor Add(ctx Context, t2 Tensor) Tensor Mul(ctx Context, t2 Tensor) Tensor Mulmat(ctx Context, t2 Tensor) Tensor @@ -132,7 +133,10 @@ type Tensor interface { Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor + IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor + Sin(ctx Context) Tensor + Cos(ctx Context) Tensor Tanh(ctx Context) Tensor GELU(ctx Context) Tensor SILU(ctx Context) Tensor @@ -147,9 +151,13 @@ type Tensor interface { Unpad(ctx Context, shape ...int) Tensor Stack(ctx Context, dim int, s ...Tensor) Tensor + + // Repeat repeats the tensor n times along dimension dim + Repeat(ctx Context, dim, n int) Tensor Concat(ctx Context, t2 Tensor, dim int) Tensor Rows(ctx Context, t2 Tensor) Tensor Copy(ctx Context, t2 Tensor) Tensor + Duplicate(ctx Context) Tensor } // ScaledDotProductAttention implements a fused attention @@ -214,7 +222,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) }) case DTypeF16, DTypeQ80, DTypeQ40: - f32 := ctx.Empty(DTypeF32, t.Shape()...) + f32 := ctx.Input().Empty(DTypeF32, t.Shape()...) f32 = t.Copy(ctx, f32) return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 17f063840..a106fed5f 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -711,6 +711,13 @@ func (t *Tensor) DType() ml.DType { } } +func (t *Tensor) Neg(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_neg(ctx.(*Context).ctx, t.t), + } +} + func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ b: t.b, @@ -718,6 +725,27 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { } } +func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { + if dim < 0 || dim >= C.GGML_MAX_DIMS { + panic("invalid dimension") + } + + shape := make([]C.int64_t, C.GGML_MAX_DIMS) + for i := range C.GGML_MAX_DIMS { + if i == dim { + shape[i] = C.int64_t(t.Dim(i) * n) + } else { + shape[i] = C.int64_t(t.Dim(i)) + } + } + + tmpl := C.ggml_new_tensor(ctx.(*Context).ctx, t.t._type, C.int(len(shape)), unsafe.SliceData(shape)) + return &Tensor{ + b: t.b, + t: C.ggml_repeat(ctx.(*Context).ctx, t.t, tmpl), + } +} + func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { if len(s) > 0 { return t.Concat(ctx, s[0].Stack(ctx, dim, s[1:]...), dim) @@ -854,6 +882,20 @@ func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor { } } +func (t *Tensor) Sin(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_sin(ctx.(*Context).ctx, t.t), + } +} + +func (t *Tensor) Cos(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_cos(ctx.(*Context).ctx, t.t), + } +} + func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, @@ -942,6 +984,13 @@ func (t *Tensor) 
RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi } } +func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_im2col(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32), + } +} + func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, @@ -1010,3 +1059,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) } } + +func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_dup(ctx.(*Context).ctx, t.t), + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index c3610ac07..a2f599ce5 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -3083,6 +3083,13 @@ kernel void kernel_cos( dst[tpig] = cos(src0[tpig]); } +kernel void kernel_neg( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m index e4c093f9c..d8422f1b7 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m @@ -423,6 +423,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SQRT, GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, + GGML_METAL_KERNEL_TYPE_NEG, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, @@ -1039,6 +1040,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); @@ -1202,6 +1204,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_NEG: return ggml_is_contiguous(op->src[0]); default: return false; @@ -1873,6 +1876,18 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_NEG: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index f38909d0b..bb0ff6688 100644 --- 
a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -945,6 +945,13 @@ kernel void kernel_cos( dst[tpig] = cos(src0[tpig]); } +kernel void kernel_neg( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 3b640a968..2d7bb20a7 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -11,7 +11,7 @@ import ( "github.com/ollama/ollama/model/input" ) -type TextOptions struct { +type TextConfig struct { hiddenSize, numHeads, numKVHeads int attnKeyLen, attnValLen int eps, ropeScale float32 @@ -28,7 +28,7 @@ type TextModel struct { OutputNorm *nn.RMSNorm `gguf:"output_norm"` Output *nn.Linear `gguf:"output,alt:token_embd"` - *TextOptions + *TextConfig } const ( @@ -55,7 +55,7 @@ func newTextModel(c fs.Config) *TextModel { }, ), Layers: make([]TextLayer, numBlocks), - TextOptions: &TextOptions{ + TextConfig: &TextConfig{ hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), @@ -84,7 +84,7 @@ type TextSelfAttention struct { Output *nn.Linear `gguf:"attn_output"` } -func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { +func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor { batchSize := hiddenState.Dim(1) ropeType := uint32(2) @@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - ropeBase := m.TextOptions.ropeLocalBase + ropeBase := m.TextConfig.ropeLocalBase if (layer+1)%gemmaGlobalCacheCount == 0 { - ropeBase = m.TextOptions.ropeGlobalBase + ropeBase = m.TextConfig.ropeGlobalBase } - return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil + return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil } type TextMLP struct { @@ -134,7 +134,7 @@ type TextMLP struct { Gate *nn.Linear `gguf:"ffn_gate"` } -func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { +func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -148,7 +148,7 @@ type TextLayer struct { PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"` } -func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { +func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -173,7 +173,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache 
kvcache.Cache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) - hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) + hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) // set image embeddings var except []int @@ -206,7 +206,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor lastLayerOutputs = outputs } - hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions) + hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/mistral3/imageproc.go b/model/models/mistral3/imageproc.go new file mode 100644 index 000000000..3d464bca4 --- /dev/null +++ b/model/models/mistral3/imageproc.go @@ -0,0 +1,56 @@ +package mistral3 + +import ( + "image" + _ "image/jpeg" + _ "image/png" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/model/imageproc" +) + +type ImageProcessor struct { + imageSize int + patchSize int + numChannels int + longestEdge int +} + +func newImageProcessor(c fs.Config) ImageProcessor { + return ImageProcessor{ + imageSize: int(c.Uint("vision.image_size", 1540)), + patchSize: int(c.Uint("vision.patch_size", 14)), + numChannels: int(c.Uint("vision.num_channels", 3)), + longestEdge: int(c.Uint("vision.longest_edge", 1540)), + } +} + +// ProcessImage prepares an image for the vision model by: +// 1. Compositing transparent images +// 2. Resizing to fit model constraints while preserving aspect ratio +// 3. Normalizing pixel values +// Returns normalized image data and the final size in pixels +func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, image.Point, error) { + img = imageproc.Composite(img) + + size := img.Bounds().Size() + ratio := max(float64(size.Y)/float64(p.longestEdge), float64(size.X)/float64(p.longestEdge)) + if ratio > 1.0 { + size = image.Point{ + int(math.Floor(float64(size.X) / ratio)), + int(math.Floor(float64(size.Y) / ratio)), + } + } + + patchesX := (size.X-1)/p.patchSize + 1 + patchesY := (size.Y-1)/p.patchSize + 1 + size = image.Point{ + patchesX * p.patchSize, + patchesY * p.patchSize, + } + + img = imageproc.Resize(img, size, imageproc.ResizeBilinear) + data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) + return data, size, nil +} diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go new file mode 100644 index 000000000..fca3896c3 --- /dev/null +++ b/model/models/mistral3/model.go @@ -0,0 +1,189 @@ +package mistral3 + +import ( + "bytes" + "image" + "slices" + "sync" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + *TextModel + *VisionModel `gguf:"v,vision"` + *MultiModalProjector `gguf:"mm"` + + ImageProcessor +} + +// Implement MultimodalProcessor interface +var _ model.MultimodalProcessor = (*Model)(nil) + +func New(c fs.Config) (model.Model, error) { + textModel, err := NewTextModel(c) + if err != nil { + return nil, err + } + + m := &Model{ + TextModel: textModel, + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), + MultiModalProjector: newMultiModalProjector(c), + } + + m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) + + return m, nil +} + 
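+// PatchMerger folds each spatialMergeSize x spatialMergeSize neighborhood of
+// vision patches into a single embedding: IM2Col gathers the neighborhoods
+// from the 2D patch grid and a linear layer projects the concatenated
+// features back to the embedding width.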
+type PatchMerger struct { + MergingLayer *nn.Linear `gguf:"merging_layer"` +} + +func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor { + d := visionOutputs.Dim(0) + imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d) + kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d) + patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1) + reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2)) + return pm.MergingLayer.Forward(ctx, reshaped) +} + +type MultiModalProjector struct { + Norm *nn.RMSNorm `gguf:"norm"` + Linear1 *nn.Linear `gguf:"linear_1"` + Linear2 *nn.Linear `gguf:"linear_2"` + PatchMerger *PatchMerger `gguf:"patch_merger"` + + spatialMergeSize int + eps float32 + patchSize int +} + +func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) { + visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps) + patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize} + visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize) + visionOutputs = p.Linear1.Forward(ctx, visionOutputs) + visionOutputs = visionOutputs.GELU(ctx) + return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize} +} + +func newMultiModalProjector(c fs.Config) *MultiModalProjector { + return &MultiModalProjector{ + spatialMergeSize: int(c.Uint("spatial_merge_size", 2)), + eps: c.Float("text_config.rms_norm_eps", 1e-5), + patchSize: int(c.Uint("vision.patch_size", 14)), + } +} + +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { + if len(m.VisionModel.Layers) == 0 { + return nil, model.ErrNoVisionModel + } + + image, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, err + } + + f32s, size, err := m.ImageProcessor.ProcessImage(image) + if err != nil { + return nil, err + } + + pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) + if err != nil { + return nil, err + } + + visionOutputs := m.VisionModel.Forward(ctx, pixelValues) + features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) + + // split into patches to be sent to the text transformer + parent := imageFeatures{tensor: features} + rows := make([]*imageRow, size.Y) + for i := range rows { + rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}} + } + + return rows, nil +} + +type imageFeatures struct { + tensor ml.Tensor + + dataOnce sync.Once + data []float32 +} + +type imageRow struct { + parent *imageFeatures + s int + shape []int +} + +func (r *imageRow) data() []float32 { + n := 1 + for _, s := range r.shape { + n *= s + } + + return r.parent.data[r.s*n : (r.s+1)*n] +} + +// PostTokenize arranges Mistral 3's inputs for the forward pass +// In Mistral 3 and Pixtral, the input patches are arranged as follows: +// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END] +// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings +// that can be processed together. 
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { + var result []input.Input + for _, inp := range inputs { + if inp.Multimodal == nil { + result = append(result, inp) + } else { + inputMultimodal := inp.Multimodal.([]*imageRow) + for i, row := range inputMultimodal { + // [IMG] + result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]}) + result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...) + if i == len(inputMultimodal)-1 { + // [IMG_END] + result = append(result, input.Input{Token: 13}) + } else { + // [IMG_BREAK] + result = append(result, input.Input{Token: 12}) + } + } + } + } + + return result, nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + if err != nil { + return nil, err + } + + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } + + return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil +} + +func init() { + model.Register("mistral3", New) +} diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go new file mode 100644 index 000000000..c256cbf17 --- /dev/null +++ b/model/models/mistral3/model_text.go @@ -0,0 +1,177 @@ +package mistral3 + +import ( + "fmt" + "math" + "strings" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type TextOptions struct { + hiddenSize, numHeads, numKVHeads, headDim int + eps, ropeBase, ropeScale float32 + ropeDim uint32 +} + +type TextModel struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *TextOptions +} + +type SelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + batchSize := hiddenState.Dim(1) + ropeType := uint32(0) + headDim := opts.headDim + if headDim == 0 { + headDim = opts.hiddenSize / opts.numHeads + } + + q := sa.Query.Forward(ctx, hiddenState) + q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) + q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + + k := sa.Key.Forward(ctx, hiddenState) + k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + + v := sa.Value.Forward(ctx, hiddenState) + v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache) + kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize) + return sa.Output.Forward(ctx, kqv) +} + +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + 
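+// Forward applies the SwiGLU feed-forward used by the Mistral text blocks:
+// down(silu(gate(x)) * up(x)).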
+func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *SelfAttention + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + residual := hiddenState + + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + + // In the final layer (outputs != nil), optimize by pruning to just the token positions + // we need logits for. + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) + return hiddenState.Add(ctx, residual) +} + +func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { + hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx) + + // image embeddings + for _, image := range batch.Multimodal { + row := image.Multimodal.(*imageRow) + row.parent.dataOnce.Do(func() { + // use a new, throwaway context so the image tensor is not added to the graph + temp := m.Backend().NewContext() + temp.Forward(row.parent.tensor).Compute(row.parent.tensor) + row.parent.data = row.parent.tensor.Floats() + temp.Close() + }) + + imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...) 
+ if err != nil { + panic(err) + } + + ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1)))) + } + + for i, layer := range m.Layers { + cache.SetLayer(i) + + var lastLayerOutputs ml.Tensor + if i == len(m.Layers)-1 { + lastLayerOutputs = outputs + } + + hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return m.Output.Forward(ctx, hiddenState) +} + +func NewTextModel(c fs.Config) (*TextModel, error) { + if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") { + return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model")) + } + + textModel := &TextModel{ + BytePairEncoding: model.NewBytePairEncoding( + c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Uints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)), + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + }, + ), + Layers: make([]Layer, c.Uint("block_count")), + TextOptions: &TextOptions{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + ropeDim: c.Uint("rope.dimension_count"), + }, + } + + return textModel, nil +} diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go new file mode 100644 index 000000000..469dc40cb --- /dev/null +++ b/model/models/mistral3/model_vision.go @@ -0,0 +1,186 @@ +package mistral3 + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +var batchSize int = 1 + +func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { + x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)) + x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx) + return x2.Neg(ctx).Concat(ctx, x1, 0) +} + +func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { + return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +} + +type VisionSelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor { + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize) + key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize) + value = value.Reshape(ctx, opts.headDim, opts.numHeads, 
value.Dim(1), batchSize) + + query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) + key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + return sa.Output.Forward(ctx, attention) +} + +type VisionMLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type VisionEncoderLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *VisionSelfAttention + FFNNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *VisionMLP +} + +func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor { + residual := hiddenStates + hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + + residual = hiddenStates + hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +type VisionModelOptions struct { + hiddenSize int + numHeads int + headDim int + intermediateSize int + imageSize int + patchSize int + numChannels int + eps float32 + ropeBase float32 +} + +type VisionModel struct { + PatchEmbedding *nn.Conv2D `gguf:"patch_conv"` + EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"` + Layers []VisionEncoderLayer `gguf:"blk"` + + *VisionModelOptions +} + +func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor { + maxPatchesPerSide := m.imageSize / m.patchSize + frequencies := m.headDim / 2 + frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide) + frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide) + for i := range frequencies { + for j := range maxPatchesPerSide { + frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim))) + if i%2 == 0 { + frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency + } else { + frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency + } + } + } + + h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) + if err != nil { + panic(err) + } + + w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) + if err != nil { + panic(err) + } + + h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + + h = h.Repeat(ctx, 1, maxPatchesPerSide) + h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + w = w.Repeat(ctx, 2, maxPatchesPerSide) + + inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide) + inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0) + return inverseFrequencies.Rows(ctx, positionIDs) +} + +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { + numPatchesW := pixelValues.Dim(0) / m.patchSize + numPatchesH := pixelValues.Dim(1) / m.patchSize + numPatches := numPatchesW * numPatchesH + + 
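+	// Patch embedding: a strided Conv2D turns each patchSize x patchSize
+	// pixel block into one hiddenSize-dim embedding, which is then
+	// flattened into a sequence of numPatches embeddings for the encoder.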
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) + hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize) + hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps) + + // Prepare position IDs for 2D rope + positions := make([]int32, numPatches) + for h := range numPatchesH { + for w := range numPatchesW { + idx := h*numPatchesW + w + positions[idx] = int32(h*m.imageSize/m.patchSize + w) + } + } + + positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) + if err != nil { + panic(err) + } + + positionEmbedding := m.positionalEmbedding(ctx, positionIDs) + cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) + cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1)) + sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1)) + + for _, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions) + } + + return hiddenStates +} + +func newVisionModel(c fs.Config) *VisionModel { + return &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)), + VisionModelOptions: &VisionModelOptions{ + hiddenSize: int(c.Uint("vision.embedding_length", 1024)), + numHeads: int(c.Uint("vision.attention.head_count", 16)), + headDim: int(c.Uint("vision.attention.key_length", 64)), + intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)), + imageSize: int(c.Uint("vision.image_size", 1540)), + patchSize: int(c.Uint("vision.patch_size", 14)), + numChannels: int(c.Uint("vision.num_channels", 3)), + eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5), + ropeBase: c.Float("vision.rope.freq_base", 10000.0), + }, + } +} diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index 2f7d26ca2..8b10bde88 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -186,7 +186,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions) - hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1) + hiddenState = m.ClassEmbedding.Repeat(ctx, 2, m.numTiles).Concat(ctx, hiddenState, 1) hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions) hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps) diff --git a/model/models/models.go b/model/models/models.go index ce1d2ce03..c5da2894b 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -4,5 +4,6 @@ import ( _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/llama" + _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" ) diff --git a/model/models/pixtral/imageproc.go b/model/models/pixtral/imageproc.go deleted file mode 100644 index 16ec0c410..000000000 --- a/model/models/pixtral/imageproc.go +++ /dev/null @@ -1,68 +0,0 @@ -package pixtral - -import ( - "fmt" - "image" - _ "image/jpeg" - _ "image/png" - "io" - "math" - - "github.com/ollama/ollama/model/imageproc" -) - -func getNumImageTokens(imageSize, patchSize image.Point) image.Point { - return image.Point{ - 
diff --git a/model/models/pixtral/imageproc.go b/model/models/pixtral/imageproc.go
deleted file mode 100644
index 16ec0c410..000000000
--- a/model/models/pixtral/imageproc.go
+++ /dev/null
@@ -1,68 +0,0 @@
-package pixtral
-
-import (
-	"fmt"
-	"image"
-	_ "image/jpeg"
-	_ "image/png"
-	"io"
-	"math"
-
-	"github.com/ollama/ollama/model/imageproc"
-)
-
-func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
-	return image.Point{
-		(imageSize.X-1)/patchSize.X + 1,
-		(imageSize.Y-1)/patchSize.Y + 1,
-	}
-}
-
-func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
-	b := img.Bounds()
-	le := float64(longestEdge)
-	ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
-
-	newSize := img.Bounds().Max
-
-	if ratio > 1.0 {
-		newSize = image.Point{
-			int(math.Ceil(float64(b.Max.X) / ratio)),
-			int(math.Ceil(float64(b.Max.Y) / ratio)),
-		}
-	}
-
-	tokens := getNumImageTokens(newSize, patchSize)
-	return image.Point{
-		tokens.X * patchSize.X,
-		tokens.Y * patchSize.Y,
-	}
-}
-
-func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image {
-	if format == "png" {
-		img = imageproc.Composite(img)
-	}
-
-	newSize := getResizeOutputImageSize(img, longestEdge, patchSize)
-
-	// todo should be ResizeBicubic, but it doesn't exist
-	return imageproc.Resize(img, newSize, imageproc.ResizeBilinear)
-}
-
-func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
-	img, format, err := image.Decode(imageData)
-	if err != nil {
-		return nil, nil, fmt.Errorf("failed to decode image: %w", err)
-	}
-
-	longestEdge := 1024
-	patchSize := image.Point{16, 16}
-
-	img = resizeImage(img, format, longestEdge, patchSize)
-
-	data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
-
-	opts := map[string]any{}
-	return data, opts, nil
-}
diff --git a/model/models/pixtral/imageproc_test.go b/model/models/pixtral/imageproc_test.go
deleted file mode 100644
index 1d9e4ffe5..000000000
--- a/model/models/pixtral/imageproc_test.go
+++ /dev/null
@@ -1,219 +0,0 @@
-package pixtral
-
-import (
-	"bytes"
-	"encoding/binary"
-	"image"
-	"image/png"
-	"math"
-	"os"
-	"testing"
-
-	"github.com/google/go-cmp/cmp"
-)
-
-func TestGetNumImageTokens(t *testing.T) {
-	type numImageTokensCase struct {
-		ImageSize image.Point
-		PatchSize image.Point
-		Expected  image.Point
-	}
-
-	cases := []numImageTokensCase{
-		{
-			ImageSize: image.Point{1024, 764},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{64, 48},
-		},
-		{
-			ImageSize: image.Point{800, 600},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{50, 38},
-		},
-		{
-			ImageSize: image.Point{640, 480},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{40, 30},
-		},
-		{
-			ImageSize: image.Point{320, 200},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{20, 13},
-		},
-		{
-			ImageSize: image.Point{1320, 200},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{83, 13},
-		},
-		{
-			ImageSize: image.Point{2000, 200},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{125, 13},
-		},
-		{
-			ImageSize: image.Point{10000, 200},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{625, 13},
-		},
-		{
-			ImageSize: image.Point{1131, 577},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{71, 37},
-		},
-		{
-			ImageSize: image.Point{16, 16},
-			PatchSize: image.Point{16, 16},
-			Expected:  image.Point{1, 1},
-		},
-	}
-
-	for _, c := range cases {
-		actual := getNumImageTokens(c.ImageSize, c.PatchSize)
-
-		if diff := cmp.Diff(actual, c.Expected); diff != "" {
-			t.Errorf("mismatch (-got +want):\n%s", diff)
-		}
-	}
-}
-
-func TestGetResizeOutputImageSize(t *testing.T) {
-	type resizeCase struct {
-		Image       image.Image
-		LongestEdge int
-		PatchSize   image.Point
-		Expected    image.Point
-	}
-
-	cases := []resizeCase{
-		{
-			Image:       image.NewRGBA(image.Rect(0, 0, 1024, 768)),
-			LongestEdge: 1024,
-			PatchSize:   image.Point{16, 16},
-			Expected:    image.Point{1024, 768},
-		},
-		{
-			Image:       image.NewRGBA(image.Rect(0, 0, 1162, 690)),
-			LongestEdge: 1024,
-			PatchSize:   image.Point{16, 16},
-			Expected:    image.Point{1024, 624},
-		},
-		{
-			Image:       image.NewRGBA(image.Rect(0, 0, 300, 200)),
-			LongestEdge: 1024,
-			PatchSize:   image.Point{16, 16},
-			Expected:    image.Point{304, 208},
-		},
-		{
-			Image:       image.NewRGBA(image.Rect(0, 0, 1862, 522)),
-			LongestEdge: 1024,
-			PatchSize:   image.Point{16, 16},
-			Expected:    image.Point{1024, 288},
-		},
-	}
-
-	for _, c := range cases {
-		actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
-
-		if diff := cmp.Diff(actual, c.Expected); diff != "" {
-			t.Errorf("mismatch (-got +want):\n%s", diff)
-		}
-	}
-}
-
-func TestResize(t *testing.T) {
-	type resizeCase struct {
-		Image       image.Image
-		LongestEdge int
-		PatchSize   image.Point
-		Expected    image.Image
-	}
-
-	cases := []resizeCase{
-		{
-			Image:       image.NewRGBA(image.Rect(0, 0, 1862, 522)),
-			LongestEdge: 1024,
-			PatchSize:   image.Point{16, 16},
-			Expected:    image.NewRGBA(image.Rect(0, 0, 1024, 288)),
-		},
-		{
-			Image:       image.NewRGBA(image.Rect(0, 0, 10, 10)),
-			LongestEdge: 1024,
-			PatchSize:   image.Point{16, 16},
-			Expected:    image.NewRGBA(image.Rect(0, 0, 16, 16)),
-		},
-	}
-
-	for _, c := range cases {
-		actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
-
-		if actual.Bounds() != c.Expected.Bounds() {
-			t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
-		}
-	}
-}
-
-func TestPreprocess(t *testing.T) {
-	type preprocessCase struct {
-		TestImage   image.Image
-		ExpectedLen int
-	}
-
-	cases := []preprocessCase{
-		{
-			TestImage:   image.NewRGBA(image.Rect(0, 0, 10, 10)),
-			ExpectedLen: 16 * 16 * 3 * 1,
-		},
-		{
-			TestImage:   image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
-			ExpectedLen: 1024 * 1024 * 3 * 1,
-		},
-	}
-
-	for _, c := range cases {
-		var buf bytes.Buffer
-		err := png.Encode(&buf, c.TestImage)
-		if err != nil {
-			t.Fatal(err)
-		}
-
-		imgData, _, err := Preprocess(&buf)
-		if err != nil {
-			t.Fatalf("error processing: %q", err)
-		}
-
-		switch len(imgData) {
-		case 0:
-			t.Errorf("no image data returned")
-		case c.ExpectedLen:
-			// ok
-		default:
-			t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
-		}
-	}
-}
-
-func TestPreprocessImages(t *testing.T) {
-	for _, testFile := range []string{"flight.png", "sportsball.png"} {
-		f, err := os.Open(testFile)
-		if err != nil {
-			t.Skipf("skipping test, no test image found at %s", testFile)
-		}
-		defer f.Close()
-
-		imgData, _, err := Preprocess(f)
-		if err != nil {
-			t.Fatalf("error processing: %q", err)
-		}
-
-		byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes
-		for i, f := range imgData {
-			binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f))
-		}
-
-		outputPath := "processed_" + testFile + ".bin"
-		err = os.WriteFile(outputPath, byteData, 0o644)
-		if err != nil {
-			t.Fatalf("error writing processed image: %q", err)
-		}
-	}
-}
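Annotation (not part of the patch): the deleted pixtral preprocessor scaled images so the longest edge fit 1024 pixels, then rounded each side up to a whole number of 16-pixel patches; it is presumably superseded by the new mistral3 pipeline. A quick arithmetic check of one removed test case (1862x522 resizes to 1024x288), reproducing getResizeOutputImageSize step by step:

package main

import (
	"fmt"
	"math"
)

func main() {
	w, h, longest, patch := 1862.0, 522.0, 1024.0, 16.0

	// Scale so the longest edge fits, as getResizeOutputImageSize did.
	ratio := math.Max(w/longest, h/longest) // 1862/1024 = 1.8183...
	newW := math.Ceil(w / ratio)            // 1024
	newH := math.Ceil(h / ratio)            // 288

	// Snap each side up to whole patches via (n-1)/p + 1, as in
	// getNumImageTokens.
	tokensW := math.Floor((newW-1)/patch) + 1
	tokensH := math.Floor((newH-1)/patch) + 1
	fmt.Println(tokensW*patch, tokensH*patch) // 1024 288, matching the test
}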
diff --git a/model/process_text.go b/model/process_text.go
index 01af65b62..f0fb77872 100644
--- a/model/process_text.go
+++ b/model/process_text.go
@@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
 			continue
 		}
 
+		if id := bpe.vocab.Encode(pair.value); id < 0 {
+			continue
+		}
+
 		merges[pair.a].runes = append(left.runes, right.runes...)
 		merges[pair.b].runes = nil
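Annotation (not part of the patch): the BytePairEncoding change skips a candidate merge whenever the merged string does not encode to a vocabulary id (Encode is negative for unknown strings here), so merging can never produce a token the vocabulary cannot represent. A reduced sketch of that guard; toyVocab is an illustrative stand-in for the real vocabulary type.

package main

import "fmt"

// toyVocab mimics the vocabulary lookup: Encode returns -1 for strings
// that are not actual tokens.
type toyVocab map[string]int32

func (v toyVocab) Encode(s string) int32 {
	if id, ok := v[s]; ok {
		return id
	}
	return -1
}

func main() {
	vocab := toyVocab{"ab": 7} // "abc" is not a token

	for _, pair := range []string{"ab", "abc"} {
		if id := vocab.Encode(pair); id < 0 {
			fmt.Println("skip merge:", pair) // mirrors the new continue
			continue
		}
		fmt.Println("merge ok:", pair)
	}
}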
diff --git a/parser/parser.go b/parser/parser.go
index 6832351fb..9a98c8ea7 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -211,16 +211,10 @@ func filesForModel(path string) ([]string, error) {
 	}
 
 	var files []string
-	if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
+	if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
 		// safetensors files might be unresolved git lfs references; skip if they are
 		// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
 		files = append(files, st...)
-	} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
-		// covers adapters.safetensors
-		files = append(files, st...)
-	} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
-		// covers adapter_model.safetensors
-		files = append(files, st...)
 	} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
 		// pytorch files might also be unresolved git lfs references; skip if they are
 		// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
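Annotation (not part of the patch): widening the glob from model*.safetensors to *.safetensors makes the two adapter-specific branches redundant, since the general pattern already matches adapters.safetensors and adapter_model.safetensors as well as the sharded model files. A quick filepath.Match check:

package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	// Every file the removed branches handled is covered by the new pattern.
	for _, name := range []string{
		"model-00001-of-00002.safetensors",
		"adapters.safetensors",
		"adapter_model.safetensors",
	} {
		ok, _ := filepath.Match("*.safetensors", name)
		fmt.Println(name, ok) // all true
	}
}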