diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 19a2ab8c4..47a88043e 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -13,9 +13,9 @@ import (
)
type Options struct {
- hiddenSize, numHeads, numKVHeads int
- eps, ropeBase, ropeScale float32
- ropeDim uint32
+ hiddenSize, numHeads, numKVHeads, headDim int
+ eps, ropeBase, ropeScale float32
+ ropeDim uint32
}
type Model struct {
@@ -37,6 +37,8 @@ func New(c ml.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
+ // TODO: need to set this in the conversion for mistral:
+ // tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
@@ -53,6 +55,7 @@ func New(c ml.Config) (model.Model, error) {
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
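+ // attention.key_length is optional; it's assumed to read as zero when
+ // unset, which selects the hiddenSize/numHeads fallback in Forward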
+ headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
@@ -75,24 +78,36 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
- headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
+ // Use the explicit head dimension if set, otherwise derive it
+ headDim := opts.headDim
+ if headDim == 0 {
+ headDim = opts.hiddenSize / opts.numHeads
+ }
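+ // e.g. mistral-small reportedly uses headDim 128 where
+ // hiddenSize/numHeads would give 160, so the explicit value matters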
+ // Query projection and reshape
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+ // Key projection and reshape
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+ // Value projection and reshape
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
+ // Attention computation
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
- kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
+ // Reshape attention output for final projection
+ outputDim := headDim * opts.numHeads
+ kqv = kqv.Reshape(ctx, outputDim, batchSize)
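+ // with an explicit headDim, outputDim can differ from opts.hiddenSize;
+ // the output projection maps the result back to the model width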
+
+ // Apply output projection
return sa.Output.Forward(ctx, kqv)
}
diff --git a/model/process_text_test.go b/model/process_text_test.go
index f48303212..8654f6d27 100644
--- a/model/process_text_test.go
+++ b/model/process_text_test.go
@@ -209,6 +209,326 @@ func TestLlama(t *testing.T) {
})
}
+// tekken loads the Tekken tokenizer for testing
+func tekken(t testing.TB) TextProcessor {
+ t.Helper()
+
+ // Load tokenizer config from mistral-small
+ tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
+ configFile, err := os.Open(tokenizerConfigPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer configFile.Close()
+
+ var config struct {
+ AddBosToken bool `json:"add_bos_token"`
+ AddEosToken bool `json:"add_eos_token"`
+ BosToken struct {
+ Content string `json:"content"`
+ } `json:"bos_token"`
+ EosToken struct {
+ Content string `json:"content"`
+ } `json:"eos_token"`
+ }
+ if err := json.NewDecoder(configFile).Decode(&config); err != nil {
+ t.Fatal(err)
+ }
+
+ // Load tokenizer.json, which contains the vocabulary and other settings
+ tokenizerJSONPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
+ tokenizerFile, err := os.Open(tokenizerJSONPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tokenizerFile.Close()
+
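+ // The relevant parts of an HF tokenizer.json are expected to look like:
+ //   {"model": {"type": "BPE", "vocab": {"hello": 123, ...}, "merges": [...]},
+ //    "added_tokens": [{"id": 1, "content": "<s>", "special": true}, ...],
+ //    "pre_tokenizer": {"type": "Sequence", "pretokenizers": [
+ //      {"type": "Split", "pattern": {"String": "..."}, "behavior": "Isolated"}, ...]}}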
+ var tokenizerData struct {
+ Model struct {
+ Type string `json:"type"`
+ Vocab map[string]int32 `json:"vocab"`
+ Merges []string `json:"merges"`
+ } `json:"model"`
+ AddedTokens []struct {
+ ID int32 `json:"id"`
+ Content string `json:"content"`
+ Special bool `json:"special"`
+ } `json:"added_tokens"`
+ PreTokenizer struct {
+ Type string `json:"type"`
+ Pretokenizers []struct {
+ Type string `json:"type"`
+ Pattern struct {
+ String string `json:"String"`
+ } `json:"pattern"`
+ Behavior string `json:"behavior"`
+ } `json:"pretokenizers"`
+ } `json:"pre_tokenizer"`
+ }
+ if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
+ t.Fatal(err)
+ }
+
+ // Extract the pattern from pre_tokenizer if available
+ var pattern string
+ if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
+ pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
+ }
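+ // Only the first pretokenizer's pattern is used; tekken's Sequence is
+ // assumed to lead with the main Split regex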
+
+ // Combine regular vocab and added tokens
+ vocab := tokenizerData.Model.Vocab
+
+ // Add special tokens from added_tokens
+ for _, token := range tokenizerData.AddedTokens {
+ vocab[token.Content] = token.ID
+ }
+
+ // Create vocabulary arrays
+ maxID := int32(-1)
+ for _, id := range vocab {
+ if id > maxID {
+ maxID = id
+ }
+ }
+
+ vocabSize := int(maxID + 1)
+ types := make([]uint32, vocabSize)
+ tokens := make([]string, vocabSize)
+ scores := make([]float32, vocabSize)
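+ // ids absent from vocab keep their zero values (empty string, zero
+ // type, zero score)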
+
+ for token, id := range vocab {
+ tokens[id] = token
+ types[id] = TOKEN_TYPE_NORMAL
+
+ // Mark known special tokens (BOS, EOS, instruction delimiters) as control tokens
+ switch token {
+ case "<s>", "</s>", "[INST]", "[/INST]":
+ types[id] = TOKEN_TYPE_CONTROL
+ }
+ }
+
+ // No explicit merge list is loaded; tekken's merge ranks are assumed
+ // to come from the vocabulary itself
+ var merges []string
+
+ // Create vocabulary object
+ vocabObj := &Vocabulary{
+ Values: tokens,
+ Types: types,
+ Scores: scores,
+ Merges: merges,
+ BOS: vocab[config.BosToken.Content],
+ EOS: vocab[config.EosToken.Content],
+ AddBOS: config.AddBosToken,
+ AddEOS: config.AddEosToken,
+ }
+
+ // Use pattern from tokenizer.json if available
+ if pattern != "" {
+ // Re-add the backslash to \p{...} classes if it was stripped, guarding
+ // against double-escaping a pattern that already has it
+ if !strings.Contains(pattern, `\p{`) {
+ pattern = strings.ReplaceAll(pattern, "p{", `\p{`)
+ }
+ return NewBytePairEncoding(pattern, vocabObj)
+ }
+
+ // Fallback pattern if not found
+ return NewBytePairEncoding(
+ `\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
+ vocabObj,
+ )
+}
+
+func TestTekken(t *testing.T) {
+ // Skip if the test data isn't available
+ if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
+ t.Skip("Mistral-small test data not available")
+ }
+
+ tokenizer := tekken(t)
+
+ t.Run("whitespace_handling", func(t *testing.T) {
+ t.Parallel()
+
+ // The key difference from SentencePiece is that Tekken doesn't prepend whitespace
+ cases := []struct {
+ input string
+ expected string
+ }{
+ {" hello", " hello"},
+ {"hello ", "hello "},
+ {"hello world", "hello world"},
+ {" hello world ", " hello world "},
+ }
+
+ for _, tc := range cases {
+ ids, err := tokenizer.Encode(tc.input, false)
+ if err != nil {
+ t.Errorf("Failed to encode %q: %v", tc.input, err)
+ continue
+ }
+
+ decoded, err := tokenizer.Decode(ids)
+ if err != nil {
+ t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
+ continue
+ }
+
+ if decoded != tc.expected {
+ t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
+ }
+ }
+ })
+
+ t.Run("chat_templates", func(t *testing.T) {
+ t.Parallel()
+
+ // Tekken's chat template doesn't insert spaces around special tokens;
+ // verify that whitespace in the input round-trips exactly as given
+ templates := []struct {
+ input string
+ expectSpace bool // whether whitespace adjacent to the special tokens should survive the round trip
+ }{
+ {"[INST]user message[/INST]", false},
+ {"[INST] user message[/INST]", true},
+ {"[INST]user message [/INST]", true},
+ }
+
+ for _, tc := range templates {
+ ids, err := tokenizer.Encode(tc.input, false)
+ if err != nil {
+ t.Errorf("Failed to encode %q: %v", tc.input, err)
+ continue
+ }
+
+ decoded, err := tokenizer.Decode(ids)
+ if err != nil {
+ t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
+ continue
+ }
+
+ // Check whether whitespace adjacent to a special token survived
+ hasSpace := strings.Contains(decoded, "[INST] ") || strings.Contains(decoded, " [/INST]")
+
+ if hasSpace != tc.expectSpace {
+ t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
+ hasSpace, tc.expectSpace, tc.input)
+ }
+ }
+ })
+
+ t.Run("special_tokens", func(t *testing.T) {
+ t.Parallel()
+
+ // Test how Tekken handles special tokens
+ cases := []struct {
+ input string
+ expected []string // We'll check if these tokens are in the decoded output
+ }{
+ {"[INST]hello[/INST]", []string{"", "[INST]", "hello", "[/INST]"}},
+ {"[INST]hello[/INST]", []string{"[INST]", "hello", "[/INST]", ""}},
+ {"[INST]hello[/INST][INST]again[/INST]", []string{"", "[INST]", "hello", "[/INST]", "", "[INST]", "again", "[/INST]"}},
+ }
+
+ for _, tc := range cases {
+ ids, err := tokenizer.Encode(tc.input, false)
+ if err != nil {
+ t.Errorf("Failed to encode %q: %v", tc.input, err)
+ continue
+ }
+
+ decoded, err := tokenizer.Decode(ids)
+ if err != nil {
+ t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
+ continue
+ }
+
+ for _, expected := range tc.expected {
+ if !strings.Contains(decoded, expected) {
+ t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
+ }
+ }
+ }
+ })
+
+ t.Run("vocabulary_coverage", func(t *testing.T) {
+ t.Parallel()
+
+ // Tekken has a larger vocabulary, so test coverage of various token types
+ samples := []string{
+ "Hello world!",
+ "This is a test of the Tekken tokenizer.",
+ "It has a considerably larger vocabulary size.",
+ "Special characters: !@#$%^&*()",
+ "Numbers: 1234567890",
+ "Multiple languages: こんにちは 你好 안녕하세요",
+ "Code snippets: def function(): return True",
+ }
+
+ for _, sample := range samples {
+ ids, err := tokenizer.Encode(sample, false)
+ if err != nil {
+ t.Errorf("Failed to encode %q: %v", sample, err)
+ continue
+ }
+
+ decoded, err := tokenizer.Decode(ids)
+ if err != nil {
+ t.Errorf("Failed to decode tokens for %q: %v", sample, err)
+ continue
+ }
+
+ if decoded != sample {
+ t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
+ }
+ }
+ })
+
+ t.Run("splitting_behavior", func(t *testing.T) {
+ t.Parallel()
+
+ // Test the splitting behavior which might differ from SentencePiece
+ cases := map[string][]string{
+ "Hello World!": {"Hello", " World", "!"},
+ "user message": {"user", " message"},
+ "[INST]hello": {"[INST]", "hello"},
+ "hello[/INST]": {"hello", "[/INST]"},
+ }
+
+ for s, want := range cases {
+ got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
+ }
+ }
+ })
+
+ t.Run("full_chat_sequence", func(t *testing.T) {
+ t.Parallel()
+
+ // Test a complete chat sequence with Tekken's format
+ chatSequence := "[INST]user message[/INST]assistant message[INST]new user message[/INST]"
+
+ ids, err := tokenizer.Encode(chatSequence, false)
+ if err != nil {
+ t.Fatalf("Failed to encode chat sequence: %v", err)
+ }
+
+ decoded, err := tokenizer.Decode(ids)
+ if err != nil {
+ t.Fatalf("Failed to decode chat sequence tokens: %v", err)
+ }
+
+ // Tekken shouldn't add whitespace after special tokens
+ if strings.Contains(decoded, "[INST] ") {
+ t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
+ }
+
+ if strings.Contains(decoded, "[/INST] ") {
+ t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
+ }
+ })
+}
+
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))