diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 19a2ab8c4..47a88043e 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -13,9 +13,9 @@ import (
 )
 
 type Options struct {
-	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	hiddenSize, numHeads, numKVHeads, headDim int
+	eps, ropeBase, ropeScale                  float32
+	ropeDim                                   uint32
 }
 
 type Model struct {
@@ -37,6 +37,8 @@ func New(c ml.Config) (model.Model, error) {
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
+			// TODO: need to set this in the conversion for mistral:
+			// tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
@@ -53,6 +55,7 @@ func New(c ml.Config) (model.Model, error) {
 			hiddenSize: int(c.Uint("embedding_length")),
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			headDim:    int(c.Uint("attention.key_length")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
 			ropeBase:   c.Float("rope.freq_base"),
 			ropeScale:  c.Float("rope.freq_scale", 1),
@@ -75,24 +78,36 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	headDim := opts.hiddenSize / opts.numHeads
 	ropeType := uint32(0)
 
+	// Get head dimension - use explicit value if available, otherwise calculate
+	headDim := opts.headDim
+	if headDim == 0 {
+		headDim = opts.hiddenSize / opts.numHeads
+	}
+
+	// Query projection and reshape
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
+	// Key projection and reshape
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
+	// Value projection and reshape
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
+	// Attention computation
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
 	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
-	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
+
+	// Reshape attention output for final projection
+	outputDim := headDim * opts.numHeads
+	kqv = kqv.Reshape(ctx, outputDim, batchSize)
+
+	// Apply output projection
 	return sa.Output.Forward(ctx, kqv)
 }
diff --git a/model/process_text_test.go b/model/process_text_test.go
index f48303212..8654f6d27 100644
--- a/model/process_text_test.go
+++ b/model/process_text_test.go
@@ -209,6 +209,326 @@ func TestLlama(t *testing.T) {
 	})
 }
 
+// tekken loads the Tekken tokenizer for testing
+func tekken(t testing.TB) TextProcessor {
+	t.Helper()
+
+	// Load tokenizer config from mistral-small
+	tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
+	configFile, err := os.Open(tokenizerConfigPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer configFile.Close()
+
+	var config struct {
+		AddBosToken bool `json:"add_bos_token"`
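+		// These flags and the token structs below come from the HF
+		// tokenizer_config.json and are copied into the Vocabulary
+		// (AddBOS/AddEOS, BOS/EOS) further down.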
+		AddEosToken bool `json:"add_eos_token"`
+		BosToken    struct {
+			Content string `json:"content"`
+		} `json:"bos_token"`
+		EosToken struct {
+			Content string `json:"content"`
+		} `json:"eos_token"`
+	}
+	if err := json.NewDecoder(configFile).Decode(&config); err != nil {
+		t.Fatal(err)
+	}
+
+	// Load tokenizer.json which contains the vocabulary and other settings
+	tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
+	tokenizerFile, err := os.Open(tokenizerJsonPath)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer tokenizerFile.Close()
+
+	var tokenizerData struct {
+		Model struct {
+			Type   string           `json:"type"`
+			Vocab  map[string]int32 `json:"vocab"`
+			Merges []string         `json:"merges"`
+		} `json:"model"`
+		AddedTokens []struct {
+			Id      int32  `json:"id"`
+			Content string `json:"content"`
+			Special bool   `json:"special"`
+		} `json:"added_tokens"`
+		PreTokenizer struct {
+			Type          string `json:"type"`
+			Pretokenizers []struct {
+				Type    string `json:"type"`
+				Pattern struct {
+					String string `json:"String"`
+				} `json:"pattern"`
+				Behavior string `json:"behavior"`
+			} `json:"pretokenizers"`
+		} `json:"pre_tokenizer"`
+	}
+	if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
+		t.Fatal(err)
+	}
+
+	// Extract the pattern from pre_tokenizer if available
+	var pattern string
+	if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
+		pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
+	}
+
+	// Combine regular vocab and added tokens
+	vocab := tokenizerData.Model.Vocab
+
+	// Add special tokens from added_tokens
+	for _, token := range tokenizerData.AddedTokens {
+		vocab[token.Content] = token.Id
+	}
+
+	// Create vocabulary arrays
+	maxId := int32(-1)
+	for _, id := range vocab {
+		if id > maxId {
+			maxId = id
+		}
+	}
+
+	vocabSize := int(maxId + 1)
+	types := make([]uint32, vocabSize)
+	tokens := make([]string, vocabSize)
+	scores := make([]float32, vocabSize)
+
+	for token, id := range vocab {
+		tokens[id] = token
+		types[id] = TOKEN_TYPE_NORMAL
+
+		// Assign appropriate token types for special tokens
+		if token == "<s>" {
+			types[id] = TOKEN_TYPE_CONTROL
+		} else if token == "</s>" {
+			types[id] = TOKEN_TYPE_CONTROL
+		} else if token == "[INST]" || token == "[/INST]" {
+			types[id] = TOKEN_TYPE_CONTROL
+		}
+	}
+
+	// In Tekken, we don't need to load merges separately as they're part of the model
+	var merges []string
+
+	// Create vocabulary object
+	vocabObj := &Vocabulary{
+		Values: tokens,
+		Types:  types,
+		Scores: scores,
+		Merges: merges,
+		BOS:    vocab[config.BosToken.Content],
+		EOS:    vocab[config.EosToken.Content],
+		AddBOS: config.AddBosToken,
+		AddEOS: config.AddEosToken,
+	}
+
+	// Use pattern from tokenizer.json if available
+	if pattern != "" {
+		// Ensure pattern has proper escaping for Go regexp, without
+		// double-escaping a pattern that already contains \p{
+		if !strings.Contains(pattern, `\p{`) {
+			pattern = strings.ReplaceAll(pattern, "p{", `\p{`)
+		}
+		return NewBytePairEncoding(pattern, vocabObj)
+	}
+
+	// Fallback pattern if not found
+	return NewBytePairEncoding(
+		`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
+		vocabObj,
+	)
+}
+
+func TestTekken(t *testing.T) {
+	// Skip if the test data isn't available
+	if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
+		t.Skip("Mistral-small test data not available")
+	}
+
+	tokenizer := tekken(t)
+
+	t.Run("whitespace_handling", func(t *testing.T) {
+		t.Parallel()
+
+		// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
+		cases := []struct {
+			input    string
+			expected string
+		}{
+			{" hello", " hello"},
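+			// trailing and interior whitespace should survive the round trip unchanged as well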
hello", " hello"}, + {"hello ", "hello "}, + {"hello world", "hello world"}, + {" hello world ", " hello world "}, + } + + for _, tc := range cases { + ids, err := tokenizer.Encode(tc.input, false) + if err != nil { + t.Errorf("Failed to encode %q: %v", tc.input, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) + continue + } + + if decoded != tc.expected { + t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected) + } + } + }) + + t.Run("chat_templates", func(t *testing.T) { + t.Parallel() + + // Test the Tekken chat template format which doesn't have spaces after special tokens + templates := []struct { + input string + expectSpace bool // whether we expect a space after special tokens + }{ + {"[INST]user message[/INST]", false}, + {"[INST] user message[/INST]", true}, + {"[INST]user message [/INST]", true}, + } + + for _, tc := range templates { + ids, err := tokenizer.Encode(tc.input, false) + if err != nil { + t.Errorf("Failed to encode %q: %v", tc.input, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) + continue + } + + // Check if there's a space after special tokens + hasSpaceAfterINST := strings.Contains(decoded, "[INST] ") + + if hasSpaceAfterINST != tc.expectSpace { + t.Errorf("Chat template space handling: got space=%v, want space=%v for %q", + hasSpaceAfterINST, tc.expectSpace, tc.input) + } + } + }) + + t.Run("special_tokens", func(t *testing.T) { + t.Parallel() + + // Test how Tekken handles special tokens + cases := []struct { + input string + expected []string // We'll check if these tokens are in the decoded output + }{ + {"[INST]hello[/INST]", []string{"", "[INST]", "hello", "[/INST]"}}, + {"[INST]hello[/INST]", []string{"[INST]", "hello", "[/INST]", ""}}, + {"[INST]hello[/INST][INST]again[/INST]", []string{"", "[INST]", "hello", "[/INST]", "", "[INST]", "again", "[/INST]"}}, + } + + for _, tc := range cases { + ids, err := tokenizer.Encode(tc.input, false) + if err != nil { + t.Errorf("Failed to encode %q: %v", tc.input, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) + continue + } + + for _, expected := range tc.expected { + if !strings.Contains(decoded, expected) { + t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded) + } + } + } + }) + + t.Run("vocabulary_coverage", func(t *testing.T) { + t.Parallel() + + // Tekken has a larger vocabulary, so test coverage of various token types + samples := []string{ + "Hello world!", + "This is a test of the Tekken tokenizer.", + "It has a considerably larger vocabulary size.", + "Special characters: !@#$%^&*()", + "Numbers: 1234567890", + "Multiple languages: こんにちは 你好 안녕하세요", + "Code snippets: def function(): return True", + } + + for _, sample := range samples { + ids, err := tokenizer.Encode(sample, false) + if err != nil { + t.Errorf("Failed to encode %q: %v", sample, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("Failed to decode tokens for %q: %v", sample, err) + continue + } + + if decoded != sample { + t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample) + } + } + }) + + t.Run("splitting_behavior", func(t *testing.T) { + t.Parallel() + + // Test the splitting behavior which might differ from SentencePiece + cases := map[string][]string{ 
+ "Hello World!": {"Hello", " World", "!"}, + "user message": {"user", " message"}, + "[INST]hello": {"[INST]", "hello"}, + "hello[/INST]": {"hello", "[/INST]"}, + } + + for s, want := range cases { + got := slices.Collect(tokenizer.(*BytePairEncoding).split(s)) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("Splitting behavior no match (-want +got):\n%s", diff) + } + } + }) + + t.Run("full_chat_sequence", func(t *testing.T) { + t.Parallel() + + // Test a complete chat sequence with Tekken's format + chatSequence := "[INST]user message[/INST]assistant message[INST]new user message[/INST]" + + ids, err := tokenizer.Encode(chatSequence, false) + if err != nil { + t.Fatalf("Failed to encode chat sequence: %v", err) + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Fatalf("Failed to decode chat sequence tokens: %v", err) + } + + // In Tekken, the whitespace shouldn't be added after special tokens + if strings.Contains(decoded, "[INST] ") { + t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded) + } + + if strings.Contains(decoded, "[/INST] ") { + t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded) + } + }) +} + func BenchmarkBytePairEncoding(b *testing.B) { tokenizer := llama(b) bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))