Compare commits

...

5 Commits

Author           SHA1        Message                                                 Date
Bruce MacDonald  ed14ce2db8  convert mistral-3.1-2503                                2025-03-18 15:16:23 -07:00
Bruce MacDonald  f94155fba2  do not add both consolidated and parts to model         2025-03-17 16:33:43 -07:00
jmorganca        8025781dce  wip                                                     2025-03-17 10:57:10 -07:00
jmorganca        afb34b0e60  wip                                                     2025-03-17 10:56:20 -07:00
Bruce MacDonald  191b1b1eb3  model: support for mistral-small in the ollama runner   2025-03-17 10:56:20 -07:00

    Mistral is a popular research lab making open-source models. This updates
    the forward pass of llama-architecture models to support both llama and
    mistral models by accounting for additional metadata present in mistral
    models and by finding the correct dimensions for the output projection.
7 changed files with 784 additions and 0 deletions
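The head-dim handling the last commit message describes is the heart of the change: read an explicit head_dim from model metadata, fall back to hidden_size / num_heads when it is absent, and size the output projection from the result. A minimal sketch (names mirror SelfAttention.Forward in the new model file below):

	// Prefer the explicit head_dim from model metadata; fall back to the
	// conventional hidden_size / num_heads split used by llama checkpoints.
	headDim := opts.headDim
	if headDim == 0 {
		headDim = opts.hiddenSize / opts.numHeads
	}
	// The attention output width is headDim*numHeads, which for mistral
	// models may differ from hiddenSize.
	outputDim := headDim * opts.numHeads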


@@ -184,6 +184,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
switch p.Architectures[0] {
case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llamaModel{}
case "Mistral3ForConditionalGeneration":
conv = &mistralModel{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":

convert/convert_mistral.go (new file, +246 lines)

@@ -0,0 +1,246 @@
package convert
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistralModel struct {
ModelParameters
// Text model parameters
TextConfig struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
} `json:"text_config"`
// Vision model parameters
VisionConfig struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
// Multimodal specific parameters
ImageTokenIndex uint32 `json:"image_token_index"`
MultimodalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
VisionFeatureLayer int32 `json:"vision_feature_layer"`
// For RoPE scaling if needed
RopeScaling struct {
Type string `json:"type"`
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
factors ropeFactor
} `json:"rope_scaling"`
}
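// KV builds the GGUF key/value metadata for the converted model, mapping
// the HF text, vision, and projector configs onto mistral.* keys.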
func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral"
kv["mistral.vocab_size"] = p.VocabSize
kv["mistral.image_token_index"] = p.ImageTokenIndex
kv["mistral.multimodal_projector_bias"] = p.MultimodalProjectorBias
kv["mistral.projector_hidden_act"] = p.ProjectorHiddenAct
kv["mistral.spatial_merge_size"] = p.SpatialMergeSize
// kv["mistral.vision_feature_layer"] = p.VisionFeatureLayer
// Text model config
kv["mistral.block_count"] = p.TextConfig.NumHiddenLayers
kv["mistral.context_length"] = p.TextConfig.MaxPositionEmbeddings
kv["mistral.embedding_length"] = p.TextConfig.HiddenSize
kv["mistral.feed_forward_length"] = p.TextConfig.IntermediateSize
kv["mistral.attention.head_count"] = p.TextConfig.NumAttentionHeads
kv["mistral.attention.head_count_kv"] = p.TextConfig.NumKeyValueHeads
kv["mistral.rope.dimension_count"] = p.TextConfig.HiddenSize / p.TextConfig.NumAttentionHeads
kv["mistral.rope.freq_base"] = p.TextConfig.RopeTheta
kv["mistral.attention.layer_norm_rms_epsilon"] = p.TextConfig.RMSNormEPS
kv["mistral.attention.key_length"] = p.TextConfig.HeadDim
kv["mistral.attention.value_length"] = p.TextConfig.HeadDim
// Vision model config
kv["mistral.vision.block_count"] = p.VisionConfig.NumHiddenLayers
kv["mistral.vision.embedding_length"] = p.VisionConfig.HiddenSize
kv["mistral.vision.feed_forward_length"] = p.VisionConfig.IntermediateSize
kv["mistral.vision.attention.head_count"] = p.VisionConfig.NumAttentionHeads
kv["mistral.vision.image_size"] = p.VisionConfig.ImageSize
kv["mistral.vision.patch_size"] = p.VisionConfig.PatchSize
kv["mistral.vision.rope.freq_base"] = p.VisionConfig.RopeTheta
// If RoPE scaling is present
if p.RopeScaling.Type == "linear" {
kv["mistral.rope.scaling.type"] = p.RopeScaling.Type
kv["mistral.rope.scaling.factor"] = p.RopeScaling.Factor
} else if p.RopeScaling.RopeType == "llama3" {
dim := p.TextConfig.HiddenSize / p.TextConfig.NumAttentionHeads
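// llama3-style scaling: assign each RoPE frequency a factor based on its
// wavelength. Wavelengths below the high-frequency cutoff pass through
// unscaled, those above the low-frequency cutoff are stretched by factor,
// and the band in between is blended smoothly.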
for i := uint32(0); i < dim; i += 2 {
factor := cmp.Or(p.RopeScaling.Factor, 8.0)
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
lambdaLow := float32(original) / factorLow
lambdaHigh := float32(original) / factorHigh
lambda := 2 * math.Pi * math.Pow(float64(p.TextConfig.RopeTheta), float64(i)/float64(dim))
if lambda < float64(lambdaHigh) {
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0)
} else if lambda > float64(lambdaLow) {
p.RopeScaling.factors = append(p.RopeScaling.factors, factor)
} else {
smooth := (float32(original)/float32(lambda) - factorLow) / (factorHigh - factorLow)
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0/((1-smooth)/factor+smooth))
}
}
}
return kv
}
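// Tensors emits the precomputed RoPE factors, when scaling produced them,
// and forwards every input tensor, tagging attention Q/K weights for repacking.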
func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
if p.RopeScaling.factors != nil {
out = append(out, ggml.Tensor{
Name: "rope_freqs.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
WriterTo: p.RopeScaling.factors,
})
}
for _, t := range ts {
// Process tensors that require repacking
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(p.repack)
}
// Add all tensors to output
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
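// Replacements lists prefix rewrites from HF tensor names to their GGUF equivalents.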
func (p *mistralModel) Replacements() []string {
return []string{
// Language model replacements
"language_model.model.embed_tokens", "token_embd",
"language_model.model.norm", "output_norm",
"language_model.model.layers", "blk",
"language_model.model.layers.*.input_layernorm", "input_layernorm",
"language_model.model.layers.*.self_attn.q_proj", "self_attn.q_proj",
"language_model.model.layers.*.self_attn.k_proj", "self_attn.k_proj",
"language_model.model.layers.*.self_attn.v_proj", "self_attn.v_proj",
"language_model.model.layers.*.self_attn.o_proj", "self_attn.o_proj",
"language_model.model.layers.*.mlp.gate_proj", "mlp.gate_proj",
"language_model.model.layers.*.mlp.down_proj", "mlp.down_proj",
"language_model.model.layers.*.mlp.up_proj", "mlp.up_proj",
"language_model.model.layers.*.post_attention_layernorm", "post_attention_layernorm",
"language_model.lm_head", "output",
// Vision model replacements - map to shorter prefixes
"vision_tower", "v",
"multi_modal_projector", "mm",
// Vision transformer blocks
"vision_tower.transformer.layers", "v.blk",
"vision_tower.transformer.layers.*.attention_norm", "v.attn_norm",
"vision_tower.transformer.layers.*.attention.q_proj", "v.attn_q",
"vision_tower.transformer.layers.*.attention.k_proj", "v.attn_k",
"vision_tower.transformer.layers.*.attention.v_proj", "v.attn_v",
"vision_tower.transformer.layers.*.attention.o_proj", "v.attn_output",
"vision_tower.transformer.layers.*.feed_forward.gate_proj", "v.ffn_gate",
"vision_tower.transformer.layers.*.feed_forward.down_proj", "v.ffn_down",
"vision_tower.transformer.layers.*.feed_forward.up_proj", "v.ffn_up",
"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm",
"vision_tower.ln_pre", "v.encoder_norm",
"vision_tower.patch_conv", "v.patch_conv",
// Multimodal projector components
"multi_modal_projector.patch_merger", "mm.patch_merger",
"multi_modal_projector.norm", "mm.norm",
}
}
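// repack permutes attn_q/attn_k weights from the HF half-split rotary
// layout into the interleaved ordering expected by GGML's RoPE kernel.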
func (p *mistralModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, "attn_q.weight") {
if strings.Contains(name, "vision") {
heads = p.VisionConfig.NumAttentionHeads
} else {
heads = p.TextConfig.NumAttentionHeads
}
} else if strings.HasSuffix(name, "attn_k.weight") {
if strings.Contains(name, "vision") {
heads = p.VisionConfig.NumAttentionHeads
} else {
heads = cmp.Or(p.TextConfig.NumKeyValueHeads, p.TextConfig.NumAttentionHeads)
}
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}
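
For intuition, the repack above converts each head's weight rows from the HF half-split rotary layout (two contiguous blocks of head_dim/2 rows) into the interleaved ordering GGML's RoPE kernel expects. A dependency-free sketch of the same index shuffle on a flat slice (hypothetical helper, not part of the diff; the converter does this via tensor reshape/transpose):

	// permuteRope mirrors repack's Reshape/T(0,2,1,3)/Reshape sequence on a
	// flat [rows, cols] weight: within each head, row (j*half + i) of the
	// half-split layout moves to row (i*2 + j) of the interleaved layout.
	// Toy sketch under stated assumptions, not the converter's actual code.
	func permuteRope(w []float32, heads, headDim, cols int) []float32 {
		out := make([]float32, len(w))
		half := headDim / 2
		for h := 0; h < heads; h++ {
			for j := 0; j < 2; j++ { // which rotary half in the source layout
				for i := 0; i < half; i++ { // row within that half
					for c := 0; c < cols; c++ {
						src := ((h*2+j)*half+i)*cols + c // [heads, 2, half, cols]
						dst := (h*headDim+i*2+j)*cols + c // [heads, half, 2, cols]
						out[dst] = w[src]
					}
				}
			}
		}
		return out
	}

With heads=1, headDim=4, cols=1, rows (r0, r1, r2, r3) map to (r0, r2, r1, r3).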


@@ -0,0 +1,207 @@
package mistral
import (
"fmt"
"math"
"strings"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c ml.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
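// Temporary debug dump of the loaded hyperparameters (wip).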
fmt.Println("Model Parameters:")
fmt.Printf(" model_type: %q\n", "gpt2")
fmt.Printf(" vocab_size: %d\n", len(c.Strings("tokenizer.ggml.tokens")))
fmt.Printf(" hidden_size: %d\n", m.Options.hiddenSize)
fmt.Printf(" num_hidden_layers: %d\n", c.Uint("block_count"))
fmt.Printf(" num_attention_heads: %d\n", m.Options.numHeads)
fmt.Printf(" num_key_value_heads: %d\n", m.Options.numKVHeads)
fmt.Printf(" rms_norm_eps: %g\n", m.Options.eps)
fmt.Printf(" rope_theta: %g\n", m.Options.ropeBase)
fmt.Printf(" bos_token_id: %d\n", c.Uint("tokenizer.ggml.bos_token_id"))
fmt.Printf(" eos_token_id: %d\n", c.Uint("tokenizer.ggml.eos_token_id"))
fmt.Printf(" pad_token_id: %d\n", c.Uint("tokenizer.ggml.pad_token_id", 0))
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"self_attn.q_proj"`
Key *nn.Linear `gguf:"self_attn.k_proj"`
Value *nn.Linear `gguf:"self_attn.v_proj"`
Output *nn.Linear `gguf:"self_attn.o_proj"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(0)
// Get head dimension - use explicit value if available, otherwise calculate
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
// Query projection and reshape
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Key projection and reshape
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Value projection and reshape
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
// Attention computation
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
// Reshape attention output for final projection
outputDim := headDim * opts.numHeads
kqv = kqv.Reshape(ctx, outputDim, batchSize)
// Apply output projection
return sa.Output.Forward(ctx, kqv)
}
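// Shift re-applies RoPE to cached keys when their positions change.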
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"mlp.up_proj"`
Down *nn.Linear `gguf:"mlp.down_proj"`
Gate *nn.Linear `gguf:"mlp.gate_proj"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
// Get token embeddings
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
// Apply output normalization
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
// Apply output projection
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("mistral", New)
}


@@ -4,5 +4,6 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mistral"
_ "github.com/ollama/ollama/model/models/mllama"
)


@@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil


@@ -209,6 +209,326 @@ func TestLlama(t *testing.T) {
})
}
// tekken loads the Tekken tokenizer for testing
func tekken(t testing.TB) TextProcessor {
t.Helper()
// Load tokenizer config from mistral-small
tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
configFile, err := os.Open(tokenizerConfigPath)
if err != nil {
t.Fatal(err)
}
defer configFile.Close()
var config struct {
AddBosToken bool `json:"add_bos_token"`
AddEosToken bool `json:"add_eos_token"`
BosToken struct {
Content string `json:"content"`
} `json:"bos_token"`
EosToken struct {
Content string `json:"content"`
} `json:"eos_token"`
}
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
t.Fatal(err)
}
// Load tokenizer.json which contains the vocabulary and other settings
tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
tokenizerFile, err := os.Open(tokenizerJsonPath)
if err != nil {
t.Fatal(err)
}
defer tokenizerFile.Close()
var tokenizerData struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
AddedTokens []struct {
Id int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
Behavior string `json:"behavior"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
}
if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
t.Fatal(err)
}
// Extract the pattern from pre_tokenizer if available
var pattern string
if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
}
// Combine regular vocab and added tokens
vocab := tokenizerData.Model.Vocab
// Add special tokens from added_tokens
for _, token := range tokenizerData.AddedTokens {
vocab[token.Content] = token.Id
}
// Create vocabulary arrays
maxId := int32(-1)
for _, id := range vocab {
if id > maxId {
maxId = id
}
}
vocabSize := int(maxId + 1)
types := make([]uint32, vocabSize)
tokens := make([]string, vocabSize)
scores := make([]float32, vocabSize)
for token, id := range vocab {
tokens[id] = token
types[id] = TOKEN_TYPE_NORMAL
// Assign appropriate token types for special tokens
if token == "<s>" {
types[id] = TOKEN_TYPE_CONTROL
} else if token == "</s>" {
types[id] = TOKEN_TYPE_CONTROL
} else if token == "[INST]" || token == "[/INST]" {
types[id] = TOKEN_TYPE_CONTROL
}
}
// In Tekken, we don't need to load merges separately as they're part of the model
var merges []string
// Create vocabulary object
vocabObj := &Vocabulary{
Values: tokens,
Types: types,
Scores: scores,
Merges: merges,
BOS: vocab[config.BosToken.Content],
EOS: vocab[config.EosToken.Content],
AddBOS: config.AddBosToken,
AddEOS: config.AddEosToken,
}
// Use pattern from tokenizer.json if available
if pattern != "" {
// Ensure pattern has proper escaping for Go regexp
pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
return NewBytePairEncoding(pattern, vocabObj)
}
// Fallback pattern if not found
return NewBytePairEncoding(
`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
vocabObj,
)
}
func TestTekken(t *testing.T) {
// Skip if the test data isn't available
if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
t.Skip("Mistral-small test data not available")
}
tokenizer := tekken(t)
t.Run("whitespace_handling", func(t *testing.T) {
t.Parallel()
// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
cases := []struct {
input string
expected string
}{
{" hello", " hello"},
{"hello ", "hello "},
{"hello world", "hello world"},
{" hello world ", " hello world "},
}
for _, tc := range cases {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
if decoded != tc.expected {
t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
}
}
})
t.Run("chat_templates", func(t *testing.T) {
t.Parallel()
// Test the Tekken chat template format, which doesn't have spaces after special tokens
templates := []struct {
input string
expectSpace bool // whether we expect a space after special tokens
}{
{"<s>[INST]user message[/INST]", false},
{"<s>[INST] user message[/INST]", true},
{"<s>[INST]user message [/INST]", true},
}
for _, tc := range templates {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
// Check if there's a space after special tokens
hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
if hasSpaceAfterINST != tc.expectSpace {
t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
hasSpaceAfterINST, tc.expectSpace, tc.input)
}
}
})
t.Run("special_tokens", func(t *testing.T) {
t.Parallel()
// Test how Tekken handles special tokens
cases := []struct {
input string
expected []string // We'll check if these tokens are in the decoded output
}{
{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
}
for _, tc := range cases {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
for _, expected := range tc.expected {
if !strings.Contains(decoded, expected) {
t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
}
}
}
})
t.Run("vocabulary_coverage", func(t *testing.T) {
t.Parallel()
// Tekken has a larger vocabulary, so test coverage of various token types
samples := []string{
"Hello world!",
"This is a test of the Tekken tokenizer.",
"It has a considerably larger vocabulary size.",
"Special characters: !@#$%^&*()",
"Numbers: 1234567890",
"Multiple languages: こんにちは 你好 안녕하세요",
"Code snippets: def function(): return True",
}
for _, sample := range samples {
ids, err := tokenizer.Encode(sample, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", sample, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", sample, err)
continue
}
if decoded != sample {
t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
}
}
})
t.Run("splitting_behavior", func(t *testing.T) {
t.Parallel()
// Test the splitting behavior, which may differ from SentencePiece
cases := map[string][]string{
"Hello World!": {"Hello", " World", "!"},
"user message": {"user", " message"},
"[INST]hello": {"[INST]", "hello"},
"hello[/INST]": {"hello", "[/INST]"},
}
for s, want := range cases {
got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
}
}
})
t.Run("full_chat_sequence", func(t *testing.T) {
t.Parallel()
// Test a complete chat sequence with Tekken's format
chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
ids, err := tokenizer.Encode(chatSequence, false)
if err != nil {
t.Fatalf("Failed to encode chat sequence: %v", err)
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Fatalf("Failed to decode chat sequence tokens: %v", err)
}
// In Tekken, whitespace shouldn't be added after special tokens
if strings.Contains(decoded, "[INST] ") {
t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
}
if strings.Contains(decoded, "[/INST] ") {
t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))


@@ -179,6 +179,10 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
return nil, nil, err
}
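// Temporary debug output: print each prompt token with its decoded text (wip).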
for _, t := range tokens {
decoded, _ := s.model.(model.TextProcessor).Decode([]int32{t})
fmt.Println("token", t, "decoded", decoded)
}
for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t})
}