diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index 67c69ee86..b8f5f0666 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -38,7 +38,6 @@ const (
 func New(c ml.Config) (model.Model, error) {
 	m := Model{
 		SentencePieceModel: model.NewSentencePieceModel(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index 567ad1a45..f9c53343a 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -55,7 +55,6 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
 func New(c ml.Config) (model.Model, error) {
 	m := Model{
 		SentencePieceModel: model.NewSentencePieceModel(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 7d8b6577e..7b2b83c02 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -45,7 +45,6 @@
 func newTextModel(c ml.Config) *TextModel {
 	m := TextModel{
 		SentencePieceModel: model.NewSentencePieceModel(
-			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),
diff --git a/model/process_text_spm.go b/model/process_text_spm.go
index 68e3ed015..c6e08dbd4 100644
--- a/model/process_text_spm.go
+++ b/model/process_text_spm.go
@@ -1,29 +1,23 @@
 package model
 
 import (
-	"iter"
+	"container/heap"
+	"fmt"
 	"log/slog"
+	"strconv"
 	"strings"
-
-	"github.com/dlclark/regexp2"
-	queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
 )
 
 const spmWhitespaceSep = "▁"
 
-func replaceWhitespaceBySeperator(s string) string {
-	return strings.ReplaceAll(s, " ", spmWhitespaceSep)
-}
-
 type SentencePieceModel struct {
 	maxTokenLen int
-	pre         *regexp2.Regexp
 	vocab       *Vocabulary
 }
 
 var _ TextProcessor = (*SentencePieceModel)(nil)
 
-func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
+func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
 	slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
 
 	counter := map[int]int{}
@@ -44,7 +38,6 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
 
 	return SentencePieceModel{
 		maxTokenLen: maxTokenLen,
-		pre:         regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
 		vocab:       vocab,
 	}
 }
@@ -53,20 +46,9 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
 	return spm.vocab.Is(id, special)
 }
 
-func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
-	return func(yield func(string) bool) {
-		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
-			if !yield(m.String()) {
-				break
-			}
-		}
-	}
-}
-
 func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
 	fragments := []fragment{{value: s}}
 	for _, special := range spm.vocab.SpecialVocabulary() {
-		// TODO: process special tokens concurrently
 		id := spm.vocab.Encode(special)
 		for i := 0; i < len(fragments); i++ {
 			frag := fragments[i]
@@ -91,7 +73,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
 		}
 	}
-	slog.Debug("fragments", "frags", fragments)
 
 	var ids []int32
 	for _, frag := range fragments {
@@ -100,105 +81,96 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			continue
 		}
 
-		for split := range spm.split(frag.value) {
-			split = replaceWhitespaceBySeperator(split)
+		text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
 
-			var sb strings.Builder
-			sb.Write([]byte(split))
-			if id := spm.vocab.Encode(sb.String()); id >= 0 {
-				ids = append(ids, id)
-				continue
+		if id := spm.vocab.Encode(text); id >= 0 {
+			ids = append(ids, id)
+			continue
+		}
+
+		q := &queue{}
+		heap.Init(q)
+
+		runes := []rune(text)
+		merges := make([]merge, len(runes))
+		for r := range runes {
+			merges[r] = merge{
+				p:     r - 1,
+				n:     r + 1,
+				runes: []rune{runes[r]},
 			}
+		}
 
-			runes := []rune(sb.String())
-			pq := queue.NewWith(func(a, b any) int {
-				priA := a.(*candidate)
-				priB := b.(*candidate)
-				if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
-					return -1
-				}
-				return 1
-			})
-
-			merges := make([]merge, len(runes))
-			for r := range runes {
-				merges[r] = merge{
-					p:     r - 1,
-					n:     r + 1,
-					runes: []rune{runes[r]},
-				}
-			}
-
-			slog.Debug("tokenizer", "merges", merges)
-
-			pairwise := func(a, b int) *candidate {
-				if a < 0 || b >= len(runes) {
-					return nil
-				}
-
-				left, right := string(merges[a].runes), string(merges[b].runes)
-				if id := spm.vocab.Encode(left + right); id >= 0 {
-					return &candidate{
-						a:     a,
-						b:     b,
-						score: spm.vocab.Scores[id],
-					}
-				}
+		pairwise := func(a, b int) *candidate {
+			if a < 0 || b >= len(runes) {
 				return nil
 			}
 
-			for i := range len(runes) - 1 {
-				if pair := pairwise(i, i+1); pair != nil {
-					pq.Enqueue(pair)
+			left, right := string(merges[a].runes), string(merges[b].runes)
+			if id := spm.vocab.Encode(left + right); id >= 0 {
+				return &candidate{
+					a:     a,
+					b:     b,
+					score: spm.vocab.Scores[id],
+					size:  len(left) + len(right),
 				}
 			}
 
-			pqv := pq.Values()
-			for _, v := range pqv {
-				e := v.(*candidate)
-				slog.Debug("candidate", "candidate", e)
+			return nil
+		}
+
+		for i := range len(runes) - 1 {
+			if pair := pairwise(i, i+1); pair != nil {
+				heap.Push(q, pair)
+			}
+		}
+
+		for q.Len() > 0 {
+			pair := heap.Pop(q).(*candidate)
+			left, right := merges[pair.a], merges[pair.b]
+
+			if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
+				continue
 			}
 
-			for !pq.Empty() {
-				v, _ := pq.Dequeue()
-				pair := v.(*candidate)
-				left, right := merges[pair.a], merges[pair.b]
+			merges[pair.a].runes = append(left.runes, right.runes...)
+			merges[pair.b].runes = nil
+			merges[pair.a].n = right.n
+			if right.n < len(merges) {
+				merges[right.n].p = pair.a
+			}
 
-				slog.Debug("pair", "left", left, "right", right)
-				if len(left.runes) == 0 || len(right.runes) == 0 {
+			if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
+				heap.Push(q, pair)
+			}
+
+			if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
+				heap.Push(q, pair)
+			}
+		}
+
+		for _, merge := range merges {
+			if token := string(merge.runes); token != "" {
+				id := spm.vocab.Encode(token)
+
+				if id >= 0 {
+					ids = append(ids, id)
 					continue
 				}
 
-				if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
-					continue
-				}
-
-				merges[pair.a].runes = append(left.runes, right.runes...)
-				merges[pair.b].runes = nil
-				merges[pair.a].n = right.n
-				if right.n < len(merges) {
-					merges[right.n].p = pair.a
-				}
-
-				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
-					pq.Enqueue(pair)
-				}
-
-				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
-					pq.Enqueue(pair)
-				}
-			}
-
-			slog.Debug("merges", "merges", merges)
-
-			for _, merge := range merges {
-				if len(merge.runes) > 0 {
-					if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
-						ids = append(ids, id)
+				// Fallback to byte tokenization
+				var result []int32
+				for _, b := range []byte(token) {
+					byteToken := fmt.Sprintf("<0x%02X>", b)
+					unknownID := spm.vocab.Encode(byteToken)
+					if unknownID >= 0 {
+						result = append(result, unknownID)
 					} else {
-						slog.Debug("missing token", "token", string(merge.runes))
+						slog.Debug("unknown byte token", "byte", b, "token", byteToken)
 					}
 				}
+
+				ids = append(ids, result...)
 			}
 		}
 	}
@@ -229,6 +201,30 @@
 type candidate struct {
 	a, b  int
 	score float32
+	size  int
+}
+
+type queue []*candidate
+
+func (q queue) Len() int { return len(q) }
+
+func (q queue) Less(i, j int) bool {
+	return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
+}
+
+func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
+
+func (q *queue) Push(x interface{}) {
+	item := x.(*candidate)
+	*q = append(*q, item)
+}
+
+func (q *queue) Pop() interface{} {
+	old := *q
+	n := len(old)
+	item := old[n-1]
+	*q = old[0 : n-1]
+	return item
 }
 
 func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
@@ -236,11 +232,26 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
 	for _, id := range ids {
 		data := spm.vocab.Decode(id)
 		data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
-		if _, err := sb.WriteString(data); err != nil {
-			return "", err
+
+		// For tokenizers that use byte tokens like "<0xEA>"
+		// convert them to the partial unicode character
+		// so they are buffered correctly by the runner instead
+		// of being sent back to the api as "<0xEA>"
+		if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
+			byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
+			if err != nil {
+				return "", fmt.Errorf("failed to parse hex byte: %v", err)
+			}

+			if err := sb.WriteByte(byte(byteVal)); err != nil {
+				return "", err
+			}
+		} else {
+			if _, err := sb.WriteString(data); err != nil {
+				return "", err
+			}
 		}
 	}
 
-	slog.Debug("decoded", "ids", ids, "text", sb.String())
 	return sb.String(), nil
 }
diff --git a/model/process_text_spm_test.go b/model/process_text_spm_test.go
index a43004db1..4813333ee 100644
--- a/model/process_text_spm_test.go
+++ b/model/process_text_spm_test.go
@@ -25,8 +25,6 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
 		t.Fatal(err)
 	}
 
-	preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
-
 	var v Vocabulary
 
 	for _, piece := range spm.GetPieces() {
@@ -47,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
 		}
 	}
 
-	return NewSentencePieceModel(preTokenizer, &v)
+	return NewSentencePieceModel(&v)
 }
 
 func TestSentencePieceEncode(t *testing.T) {
@@ -116,3 +114,59 @@ func TestSentencePieceEncode(t *testing.T) {
 		}
 	})
 }
+
+func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
+	vocab := &Vocabulary{
+		Values: []string{
+			"normal",
+			"<0xEA>",
+			"<0x41>",
+			"<0xC3>",
+			"<0xA3>",
+		},
+		Types: []uint32{
+			TOKEN_TYPE_NORMAL,
+			TOKEN_TYPE_BYTE,
+			TOKEN_TYPE_BYTE,
+			TOKEN_TYPE_BYTE,
+			TOKEN_TYPE_BYTE,
+		},
+		Scores: []float32{0, 0, 0, 0, 0},
+	}
+
+	spm := NewSentencePieceModel(vocab)
+
+	tests := []struct {
+		name     string
+		ids      []int32
+		expected string
+	}{
+		{
+			name:     "single byte token",
+			ids:      []int32{1},
+			expected: "\xea",
+		},
+		{
+			name:     "ASCII byte token",
+			ids:      []int32{2},
+			expected: "A",
+		},
+		{
+			name:     "multiple byte tokens forming UTF-8 character",
+			ids:      []int32{3, 4},
+			expected: "ã",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, err := spm.Decode(tt.ids)
+			if err != nil {
+				t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
+			}
+			if result != tt.expected {
+				t.Errorf("got %q, want %q", result, tt.expected)
+			}
+		})
+	}
+}
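
Note on the merge queue: container/heap has no decrease-key or remove operation, so the rewritten merge loop relies on lazy invalidation. Each candidate records in `size` the combined byte length of the pair at the time it was pushed; when a candidate is popped, the loop skips it if either side has since been merged away (empty runes) or has grown so the lengths no longer add up to `size`. Below is a minimal standalone sketch of that pattern; the `entry`/`pq` names and the scenario are illustrative only, not part of this patch:

    // Stale heap entries are detected by comparing the size recorded at
    // push time against the current combined size of the pair.
    package main

    import (
        "container/heap"
        "fmt"
    )

    type entry struct {
        score float32
        size  int // combined length of the pair when pushed
        a, b  int // indices of the pair in the symbol table
    }

    type pq []*entry

    func (q pq) Len() int            { return len(q) }
    func (q pq) Less(i, j int) bool  { return q[i].score > q[j].score }
    func (q pq) Swap(i, j int)       { q[i], q[j] = q[j], q[i] }
    func (q *pq) Push(x interface{}) { *q = append(*q, x.(*entry)) }
    func (q *pq) Pop() interface{} {
        old := *q
        item := old[len(old)-1]
        *q = old[:len(old)-1]
        return item
    }

    func main() {
        q := &pq{}
        heap.Init(q)

        // Entry recorded when the table was {"a", "c"}; slot 0 has since
        // grown to "ab", so the recorded size (2) is stale.
        symbols := []string{"ab", "c"}
        heap.Push(q, &entry{score: 1.0, size: len("a") + len("c"), a: 0, b: 1})

        for q.Len() > 0 {
            e := heap.Pop(q).(*entry)
            if len(symbols[e.a])+len(symbols[e.b]) != e.size {
                fmt.Println("stale entry skipped") // this branch runs
                continue
            }
            fmt.Println("merging", symbols[e.a], symbols[e.b])
        }
    }

Skipping stale entries on pop costs one extra heap round trip per invalidated pair but keeps Push and Pop at O(log n), which is why the patch stores `size` in candidate rather than scanning the queue.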
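Note on the byte fallback round trip: on encode, a token with no vocabulary entry is now decomposed into one <0xNN> piece per UTF-8 byte, and on decode those pieces are folded back into raw bytes so multi-byte characters reassemble before reaching the API. The following is a standalone sketch of both directions, assuming a SentencePiece-style vocabulary with byte pieces; encodeBytes and decodeBytePiece are hypothetical helpers, whereas the patch implements the same logic inline in Encode and Decode:

    package main

    import (
        "fmt"
        "strconv"
        "strings"
    )

    // encodeBytes splits an out-of-vocabulary token into one <0xNN>
    // piece per UTF-8 byte, mirroring the fallback in Encode.
    func encodeBytes(token string) []string {
        pieces := make([]string, 0, len(token))
        for _, b := range []byte(token) {
            pieces = append(pieces, fmt.Sprintf("<0x%02X>", b))
        }
        return pieces
    }

    // decodeBytePiece turns "<0xEA>" back into the byte 0xEA. piece[1:5]
    // is "0xNN"; base 0 lets ParseUint consume the 0x prefix, and the
    // 8-bit size rejects anything that does not fit in a byte.
    func decodeBytePiece(piece string) (byte, bool) {
        if len(piece) != 6 || !strings.HasPrefix(piece, "<0x") || !strings.HasSuffix(piece, ">") {
            return 0, false
        }
        v, err := strconv.ParseUint(piece[1:5], 0, 8)
        if err != nil {
            return 0, false
        }
        return byte(v), true
    }

    func main() {
        var sb strings.Builder
        for _, p := range encodeBytes("ã") { // "ã" is the two bytes 0xC3 0xA3
            if b, ok := decodeBytePiece(p); ok {
                sb.WriteByte(b)
            }
        }
        fmt.Println(sb.String()) // prints: ã
    }

This is also the behavior the new TestSentencePieceModelDecodeByteTokens pins down: the pair <0xC3><0xA3> must decode to the single character ã, not to the literal piece strings.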