Compare commits

...

16 Commits

Author SHA1 Message Date
jmorganca
cfeca27133 wip 2025-03-23 01:01:23 -07:00
jmorganca
4530661799 wip 2025-03-22 23:20:39 -07:00
jmorganca
8dd2a81f8c wip 2025-03-22 22:33:39 -07:00
jmorganca
caddb1e4cf rebased 2025-03-22 10:15:52 -07:00
Bruce MacDonald
4d8dac8ffc wip 2025-03-22 10:03:23 -07:00
Bruce MacDonald
63e6509ec0 vision conversion 2025-03-22 10:03:22 -07:00
Bruce MacDonald
6f34126dcc image processing 2025-03-22 10:03:22 -07:00
Bruce MacDonald
ecc0ef468f split text model to its own file 2025-03-22 10:03:22 -07:00
Bruce MacDonald
9b57238834 ... 2025-03-22 10:03:22 -07:00
Bruce MacDonald
3b4ad00a4b mistral3 arch 2025-03-22 10:03:22 -07:00
Bruce MacDonald
9a12fd1067 wip: test fixes 2025-03-22 10:03:22 -07:00
Bruce MacDonald
edac05387f convert: mistral-3.1-2503 text component 2025-03-22 10:03:22 -07:00
Bruce MacDonald
e65cf9dc94 minimal convert 2025-03-22 10:03:22 -07:00
jmorganca
7e3c62f388 wip 2025-03-22 10:03:22 -07:00
jmorganca
a75703b2cc wip 2025-03-22 10:03:22 -07:00
Bruce MacDonald
c24e8860c1 model: support for mistral-small in the ollama runner
Mistral is a popular research lab making open source models. This updates
the forward pass of llama architecture models to support both llama models
and mistral models by accounting for additional metadata present in mistral
models, and finding the correct dimensions for the output projection.
2025-03-22 10:03:22 -07:00
20 changed files with 1228089 additions and 34 deletions

View File

@ -182,8 +182,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
var conv ModelConverter
switch p.Architectures[0] {
case "LlamaForCausalLM", "MistralForCausalLM":
case "LlamaForCausalLM":
conv = &llamaModel{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":
@ -246,5 +248,10 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
return err
}
// iterate through all ts and print the name
for _, t := range ts {
fmt.Print(t.Name(), "\n")
}
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
}

194
convert/convert_mistral.go Normal file
View File

@ -0,0 +1,194 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3Model struct {
ModelParameters
ImageTokenIndex uint32 `json:"image_token_index"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
VisionFeatureLayer int32 `json:"vision_feature_layer"`
TextModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
ImageSize uint32 `json:"image_size"`
NumChannels uint32 `json:"num_channels"`
PatchSize uint32 `json:"patch_size"`
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
}
func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.TextModel.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex
kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
if p.ProjectorHiddenAct != "" {
kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
}
return kv
}
func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(p.repack)
}
// Skip certain vision model tensors that might need special handling
if strings.HasPrefix(t.Name(), "patch_merger.") || strings.HasPrefix(t.Name(), "pre_mm_projector_output_norm.") {
continue
}
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3Model) Replacements() []string {
return []string{
"language_model.model.norm", "output_norm",
"language_model.model.", "",
"language_model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"patch_merger.merging_layer", "merger",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, "attn_q.weight") {
heads = p.TextModel.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@ -62,10 +62,7 @@ func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
Pattern string
Func func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
}{
{"model-*-of-*.safetensors", parseSafetensors},
{"model.safetensors", parseSafetensors},
{"adapters.safetensors", parseSafetensors},
{"adapter_model.safetensors", parseSafetensors},
{"*.safetensors", parseSafetensors},
{"pytorch_model-*-of-*.bin", parseTorch},
{"pytorch_model.bin", parseTorch},
{"consolidated.*.pth", parseTorch},

View File

@ -144,6 +144,9 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, ropeDim uint32, sections [4]int, ropeType uint32, base, scale float32) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor

View File

@ -958,6 +958,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}
}
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
dequant := t.t
if C.ggml_is_quantized(t.t._type) {
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
return &Tensor{
b: t.b,
t: C.ggml_rope_multi(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
(*C.int)(unsafe.Pointer(&sections[0])),
C.int(ropeType),
131072, // YaRN n_ctx_train
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast
1., // YaRN beta_slow
),
}
}
func (t *Tensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_im2col(ctx.(*Context).ctx, t.t, weight.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32),
}
}
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,

View File

@ -2186,6 +2186,10 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_MUL_MAT:
{
if (ne00 != ne10) {
printf("mul_mat, ne00: %d, ne01: %d, ne02: %d, ne03: %d, ne10: %d, ne11: %d, ne12: %d, ne13: %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13);
}
GGML_ASSERT(ne00 == ne10);
GGML_ASSERT(ne12 % ne02 == 0);

View File

@ -10,7 +10,7 @@ import (
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
@ -27,7 +27,7 @@ type TextModel struct {
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
*TextConfig
}
const (
@ -55,7 +55,7 @@ func newTextModel(c ml.Config) *TextModel {
},
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
@ -84,7 +84,7 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
ropeBase := m.TextConfig.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
ropeBase = m.TextConfig.ropeGlobalBase
}
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
}
type TextMLP struct {
@ -134,7 +134,7 @@ type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@ -148,7 +148,7 @@ type TextLayer struct {
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@ -173,7 +173,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
var except []int
@ -206,7 +206,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

View File

@ -13,9 +13,9 @@ import (
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
@ -37,6 +37,8 @@ func New(c ml.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
// TODO: need to set this in the conversion for mistral:
// tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
@ -53,6 +55,7 @@ func New(c ml.Config) (model.Model, error) {
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
@ -75,24 +78,36 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
// Get head dimension - use explicit value if available, otherwise calculate
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
// Query projection and reshape
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Key projection and reshape
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Value projection and reshape
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
// Attention computation
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
// Reshape attention output for final projection
outputDim := headDim * opts.numHeads
kqv = kqv.Reshape(ctx, outputDim, batchSize)
// Apply output projection
return sa.Output.Forward(ctx, kqv)
}

View File

@ -1,4 +1,4 @@
package pixtral
package mistral3
import (
"fmt"
@ -8,6 +8,7 @@ import (
"io"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
@ -20,15 +21,14 @@ func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
b := img.Bounds()
le := float64(longestEdge)
ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
ratio := math.Max(float64(b.Max.Y)/float64(longestEdge), float64(b.Max.X)/float64(longestEdge))
newSize := img.Bounds().Max
if ratio > 1.0 {
newSize = image.Point{
int(math.Ceil(float64(b.Max.X) / ratio)),
int(math.Ceil(float64(b.Max.Y) / ratio)),
int(math.Floor(float64(b.Max.X) / ratio)),
int(math.Floor(float64(b.Max.Y) / ratio)),
}
}
@ -66,3 +66,27 @@ func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
opts := map[string]any{}
return data, opts, nil
}
type ImageProcessor struct {
imageSize int
patchSize int
numChannels int
longestEdge int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
longestEdge: int(c.Uint("vision.longest_edge", 1540)),
}
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
return data, nil
}

View File

@ -1,4 +1,4 @@
package pixtral
package mistral3
import (
"bytes"

View File

@ -0,0 +1,120 @@
package mistral3
import (
"bytes"
"image"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
*TextModel
*VisionModel `gguf:"v,vision"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c ml.Config) (model.Model, error) {
textModel, err := NewTextModel(c)
if err != nil {
return nil, err
}
m := &Model{
TextModel: textModel,
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
MultiModalProjector: newMultiModalProjector(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
// Create tensor from image data
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
1036, // TODO (jmorganca): this should be returned from ProcessImage
m.ImageProcessor.numChannels,
)
if err != nil {
return nil, err
}
// fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues))
// Forward pass through vision model
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
// fmt.Println("visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
// Project to text embedding space
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
// fmt.Println("visionOutputs after projector", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
return visionOutputs, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.(ml.Tensor)
// Add special image tokens - using the imageTokenIndex from config
result = append(result, input.Input{Token: 10}) // [IMG]
result = append(result, input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}) // image data
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, inputMultimodal.Dim(1)-1)...) // [IMG] placeholders
result = append(result, input.Input{Token: 13}) // [IMG_END]
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("mistral3", New)
}

View File

@ -0,0 +1,171 @@
package mistral3
import (
"fmt"
"math"
"strings"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type TextModel struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(0)
// Get head dimension - use explicit value if available, otherwise calculate
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
// Query projection and reshape
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Key projection and reshape
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Value projection and reshape
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
// Attention computation
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
// Reshape attention output for final projection
outputDim := headDim * opts.numHeads
kqv = kqv.Reshape(ctx, outputDim, batchSize)
// Apply output projection
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
// Process text inputs
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
// Process through text transformer layers
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
func NewTextModel(c ml.Config) (*TextModel, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
textModel := &TextModel{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
return textModel, nil
}

View File

@ -0,0 +1,201 @@
package mistral3
import (
"fmt"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
type PatchMerger struct {
MergingLayer *nn.Linear `gguf:"merging_layer"`
}
func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
// TODO: pass these in
w := 110
h := 74
// tokensPerImage := w * h
d := visionOutputs.Dim(0)
// TODO: handle multiple images, this currently assumes one
// fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
// Reshape to [h, w, hidden_size]
imageGrid := visionOutputs.Reshape(ctx, h, w, d)
// fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid))
// TODO: load from config
spatialMergeSize := 2
kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1)
// fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel))
patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
// fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches))
// fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2))
reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
// fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped))
return pm.MergingLayer.Forward(ctx, reshaped)
}
type MultiModalProjector struct {
Norm *nn.RMSNorm `gguf:"norm"`
Linear1 *nn.Linear `gguf:"linear_1"`
Linear2 *nn.Linear `gguf:"linear_2"`
PatchMerger *PatchMerger `gguf:"patch_merger"`
spatialMergeSize int
imageTokenIndex int
hasBias bool
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
visionOutputs = visionOutputs.GELU(ctx)
return p.Linear2.Forward(ctx, visionOutputs)
}
func newMultiModalProjector(c ml.Config) *MultiModalProjector {
return &MultiModalProjector{
spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
imageTokenIndex: int(c.Uint("image_token_index", 10)),
hasBias: c.Bool("mm.projector_bias", false),
}
}
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
q := sa.Query.Forward(ctx, hiddenState)
k := sa.Key.Forward(ctx, hiddenState)
v := sa.Value.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.headDim, opts.numHeads, q.Dim(1), batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numHeads, k.Dim(1), batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numHeads, v.Dim(1), batchSize)
ropeType := uint32(24) // 2d vision rope
q = q.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
k = k.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
attention := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *VisionSelfAttention
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
fmt.Println("after attention norm", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeBase float32
ropeScale float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_conv"`
EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatchesH := pixelValues.Dim(1) / m.patchSize
numPatchesW := pixelValues.Dim(0) / m.patchSize
numPatches := numPatchesH * numPatchesW
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
totalPositions := numPatchesH * numPatchesW
positions := make([]int32, totalPositions*4)
for h := 0; h < numPatchesH; h++ {
for w := 0; w < numPatchesW; w++ {
index := h*numPatchesW + w
positions[totalPositions+index] = int32(h)
positions[totalPositions*2+index] = int32(w)
}
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
}
// fmt.Println("after layers", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
return hiddenState
}
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),
headDim: int(c.Uint("vision.attention.key_length", 64)),
intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
ropeBase: c.Float("vision.rope.freq_base", 10000.0),
ropeScale: c.Float("vision.rope.freq_scale", 1.0),
},
}
}

View File

@ -4,5 +4,6 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil

View File

@ -209,6 +209,322 @@ func TestLlama(t *testing.T) {
})
}
// tekken loads the Tekken tokenizer for testing
func tekken(t testing.TB) TextProcessor {
t.Helper()
// Load tokenizer config from mistral-small
tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
configFile, err := os.Open(tokenizerConfigPath)
if err != nil {
t.Fatal(err)
}
defer configFile.Close()
var config struct {
AddBosToken bool `json:"add_bos_token"`
AddEosToken bool `json:"add_eos_token"`
BosToken string `json:"bos_token"`
EosToken string `json:"eos_token"`
}
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
t.Fatal(err)
}
// Load tokenizer.json which contains the vocabulary and other settings
tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
tokenizerFile, err := os.Open(tokenizerJsonPath)
if err != nil {
t.Fatal(err)
}
defer tokenizerFile.Close()
var tokenizerData struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
AddedTokens []struct {
Id int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
Behavior string `json:"behavior"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
}
if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
t.Fatal(err)
}
// Extract the pattern from pre_tokenizer if available
var pattern string
if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
}
// Combine regular vocab and added tokens
vocab := tokenizerData.Model.Vocab
// Add special tokens from added_tokens
for _, token := range tokenizerData.AddedTokens {
vocab[token.Content] = token.Id
}
// Create vocabulary arrays
maxId := int32(-1)
for _, id := range vocab {
if id > maxId {
maxId = id
}
}
vocabSize := int(maxId + 1)
types := make([]uint32, vocabSize)
tokens := make([]string, vocabSize)
scores := make([]float32, vocabSize)
for token, id := range vocab {
tokens[id] = token
types[id] = TOKEN_TYPE_NORMAL
// Assign appropriate token types for special tokens
if token == "<s>" {
types[id] = TOKEN_TYPE_CONTROL
} else if token == "</s>" {
types[id] = TOKEN_TYPE_CONTROL
} else if token == "[INST]" || token == "[/INST]" {
types[id] = TOKEN_TYPE_CONTROL
}
}
// In Tekken, we don't need to load merges separately as they're part of the model
var merges []string
// Create vocabulary object
vocabObj := &Vocabulary{
Values: tokens,
Types: types,
Scores: scores,
Merges: merges,
BOS: vocab[config.BosToken],
EOS: vocab[config.EosToken],
AddBOS: config.AddBosToken,
AddEOS: config.AddEosToken,
}
// Use pattern from tokenizer.json if available
if pattern != "" {
// Ensure pattern has proper escaping for Go regexp
pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
return NewBytePairEncoding(pattern, vocabObj)
}
// Fallback pattern if not found
return NewBytePairEncoding(
`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
vocabObj,
)
}
func TestTekken(t *testing.T) {
// Skip if the test data isn't available
if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
t.Skip("Mistral-small test data not available")
}
tokenizer := tekken(t)
t.Run("whitespace_handling", func(t *testing.T) {
t.Parallel()
// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
cases := []struct {
input string
expected string
}{
{" hello", " hello"},
{"hello ", "hello "},
{"hello world", "hello world"},
{" hello world ", " hello world "},
}
for _, tc := range cases {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
if decoded != tc.expected {
t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
}
}
})
t.Run("chat_templates", func(t *testing.T) {
t.Parallel()
// Test the Tekken chat template format which doesn't have spaces after special tokens
templates := []struct {
input string
expectSpace bool // whether we expect a space after special tokens
}{
{"<s>[INST]user message[/INST]", false},
{"<s>[INST] user message[/INST]", true},
{"<s>[INST]user message [/INST]", true},
}
for _, tc := range templates {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
// Check if there's a space after special tokens
hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
if hasSpaceAfterINST != tc.expectSpace {
t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
hasSpaceAfterINST, tc.expectSpace, tc.input)
}
}
})
t.Run("special_tokens", func(t *testing.T) {
t.Parallel()
// Test how Tekken handles special tokens
cases := []struct {
input string
expected []string // We'll check if these tokens are in the decoded output
}{
{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
}
for _, tc := range cases {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
for _, expected := range tc.expected {
if !strings.Contains(decoded, expected) {
t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
}
}
}
})
t.Run("vocabulary_coverage", func(t *testing.T) {
t.Parallel()
// Tekken has a larger vocabulary, so test coverage of various token types
samples := []string{
"Hello world!",
"This is a test of the Tekken tokenizer.",
"It has a considerably larger vocabulary size.",
"Special characters: !@#$%^&*()",
"Numbers: 1234567890",
"Multiple languages: こんにちは 你好 안녕하세요",
"Code snippets: def function(): return True",
}
for _, sample := range samples {
ids, err := tokenizer.Encode(sample, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", sample, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", sample, err)
continue
}
if decoded != sample {
t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
}
}
})
t.Run("splitting_behavior", func(t *testing.T) {
t.Parallel()
// Test the splitting behavior which might differ from SentencePiece
cases := map[string][]string{
"Hello World!": {"Hello", " World", "!"},
"user message": {"user", " message"},
"[INST]hello": {"[INST]", "hello"},
"hello[/INST]": {"hello", "[/INST]"},
}
for s, want := range cases {
got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
}
}
})
t.Run("full_chat_sequence", func(t *testing.T) {
t.Parallel()
// Test a complete chat sequence with Tekken's format
chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
ids, err := tokenizer.Encode(chatSequence, false)
if err != nil {
t.Fatalf("Failed to encode chat sequence: %v", err)
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Fatalf("Failed to decode chat sequence tokens: %v", err)
}
// In Tekken, the whitespace shouldn't be added after special tokens
if strings.Contains(decoded, "[INST] ") {
t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
}
if strings.Contains(decoded, "[/INST] ") {
t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))

1217945
model/testdata/mistral-small/tokenizer.json vendored Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -211,16 +211,10 @@ func filesForModel(path string) ([]string, error) {
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapters.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapter_model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin

View File

@ -182,6 +182,10 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
return nil, nil, err
}
for _, t := range tokens {
decoded, _ := s.model.(model.TextProcessor).Decode([]int32{t})
fmt.Println("token", t, "decoded", decoded)
}
for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t})
}