From 8cf1ea4fd830f991c0ef216615b2a88030304fa7 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Sun, 9 Feb 2025 16:52:17 -0800 Subject: [PATCH] add sentence piece tokenizer --- fs/ggml/ggml.go | 9 ++ go.mod | 1 + go.sum | 2 + ml/backend.go | 1 + model/models/gemma2/model.go | 12 +- model/process_text.go | 15 ++- model/process_text_spm.go | 232 +++++++++++++++++++++++++++++++++++ 7 files changed, 263 insertions(+), 9 deletions(-) create mode 100644 model/process_text_spm.go diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 90d1d4406..10e81ac71 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -120,6 +120,15 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { return s } +func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { + r := keyValue(kv, key, &array{}) + s := make([]float32, r.size) + for i := range r.size { + s[i] = float32(r.values[i].(float32)) + } + return s +} + func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { key = kv.Architecture() + "." + key diff --git a/go.mod b/go.mod index 1c99c0946..d4782dc98 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/agnivade/levenshtein v1.1.1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/dlclark/regexp2 v1.11.4 + github.com/emirpasic/gods v1.18.1 github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/google/go-cmp v0.6.0 github.com/mattn/go-runewidth v0.0.14 diff --git a/go.sum b/go.sum index 8eb8d84ab..122a4610c 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+ github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU= github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= diff --git a/ml/backend.go b/ml/backend.go index aebf86f76..c27a89a12 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -17,6 +17,7 @@ type Config interface { Strings(string, ...[]string) []string Uints(string, ...[]uint32) []uint32 + Floats(string, ...[]float32) []float32 } type Backend interface { diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index f8ba82110..4e8fe6d66 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -1,7 +1,6 @@ package gemma2 import ( - "fmt" "math" "github.com/ollama/ollama/kvcache" @@ -20,7 +19,7 @@ type Options struct { type Model struct { model.Base - model.BytePairEncoding + model.SentencePieceModel TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -32,10 +31,11 @@ type Model struct { func New(c ml.Config) (model.Model, error) { m := Model{ - BytePairEncoding: model.NewBytePairEncoding( + SentencePieceModel: model.NewSentencePieceModel( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), Types: c.Uints("tokenizer.ggml.token_type"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), @@ -55,7 +55,7 @@ func New(c ml.Config) (model.Model, error) { } slidingWindowLen := int32(c.Uint("attention.sliding_window")) - m.Cache = kvcache.NewWrapperCache(kvcache.NewCausalCache(m.Shift), kvcache.NewSWACache(slidingWindowLen, m.Shift)) + m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) return &m, nil } @@ -76,7 +76,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q = q.RoPE(ctx, positionIDs, opts.RopeFactors, uint32(headDim), opts.ropeBase, opts.ropeScale) // todo: this should be 1.0/math.Sqrt(float64(headDim)) for 27B models - q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen))) + //q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen))) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) @@ -140,8 +140,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach } func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { - fmt.Printf("HELLO THERE!!\n") - inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err diff --git a/model/process_text.go b/model/process_text.go index df1e68f4c..6c60bf2e0 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -2,6 +2,7 @@ package model import ( "cmp" + "fmt" "iter" "log/slog" "strings" @@ -18,6 +19,15 @@ const ( SpecialEOS ) +const ( + TOKEN_TYPE_NORMAL = iota + 1 + TOKEN_TYPE_UNKNOWN + TOKEN_TYPE_CONTROL + TOKEN_TYPE_USER_DEFINED + TOKEN_TYPE_UNUSED + TOKEN_TYPE_BYTE +) + type TextProcessor interface { Encode(string) ([]int32, error) Decode([]int32) (string, error) @@ -27,7 +37,7 @@ type TextProcessor interface { type Vocabulary struct { Values []string Types []uint32 - Scores []uint32 + Scores []float32 Merges []string BOS, EOS int32 @@ -75,7 +85,7 @@ func (v *Vocabulary) Decode(id int32) string { func (v *Vocabulary) SpecialVocabulary() []string { v.specialOnce.Do(func() { for i := range v.Values { - if v.Types[i] == 3 { + if v.Types[i] == TOKEN_TYPE_CONTROL { v.special = append(v.special, v.Values[i]) } } @@ -171,6 +181,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) } } + fmt.Printf("frags = %#v\n", fragments) var ids []int32 for _, frag := range fragments { diff --git a/model/process_text_spm.go b/model/process_text_spm.go new file mode 100644 index 000000000..5fd19bb96 --- /dev/null +++ b/model/process_text_spm.go @@ -0,0 +1,232 @@ +package model + +import ( + "fmt" + "iter" + "strings" + //"unicode/utf8" + + "github.com/dlclark/regexp2" + queue "github.com/emirpasic/gods/queues/priorityqueue" +) + +const spmWhitespaceSep = "▁" + +func replaceWhitespaceBySeperator(s string) string { + return strings.ReplaceAll(s, " ", spmWhitespaceSep) +} + +type SentencePieceModel struct { + maxTokenLen int + pre *regexp2.Regexp + vocab *Vocabulary +} + +func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel { + fmt.Printf("Tokens (%d): %5s %5s %5s ...\n", len(vocab.Values), vocab.Values[0], vocab.Values[1], vocab.Values[2]) + fmt.Printf("Scores (%d): %0.3f %0.3f %0.3f ...\n", len(vocab.Scores), vocab.Scores[0], vocab.Scores[1], vocab.Scores[2]) + fmt.Printf("Types (%d): %5d %5d %5d ...\n", len(vocab.Types), vocab.Types[0], vocab.Types[1], vocab.Types[2]) + + counter := map[int]int{} + var maxTokenLen int + + for cnt, _ := range vocab.Types { + switch vocab.Types[cnt] { + case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED: + maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt])) + fallthrough + default: + counter[int(vocab.Types[cnt])] += 1 + } + } + + fmt.Printf("Normal: %d\n", counter[TOKEN_TYPE_NORMAL]) + fmt.Printf("Unknown: %d\n", counter[TOKEN_TYPE_UNKNOWN]) + fmt.Printf("Control: %d\n", counter[TOKEN_TYPE_CONTROL]) + fmt.Printf("User Defined: %d\n", counter[TOKEN_TYPE_USER_DEFINED]) + fmt.Printf("Unused: %d\n", counter[TOKEN_TYPE_UNUSED]) + fmt.Printf("Byte: %d\n", counter[TOKEN_TYPE_BYTE]) + fmt.Printf("Max token len: %d\n", maxTokenLen) + + return SentencePieceModel{ + maxTokenLen: maxTokenLen, + pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2), + vocab: vocab, + } +} + +func (spm SentencePieceModel) Is(id int32, special Special) bool { + return spm.vocab.Is(id, special) +} + +func (spm *SentencePieceModel) split(s string) iter.Seq[string] { + return func(yield func(string) bool) { + for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) { + if !yield(m.String()) { + break + } + } + } +} + +func (spm SentencePieceModel) Encode(s string) ([]int32, error) { + fragments := []fragment{{value: s}} + for _, special := range spm.vocab.SpecialVocabulary() { + // TODO: process special tokens concurrently + id := spm.vocab.Encode(special) + for i := 0; i < len(fragments); i++ { + frag := fragments[i] + if len(frag.ids) > 0 { + continue + } + + var middle []fragment + switch i := strings.Index(frag.value, special); { + case i < 0: + middle = append(middle, frag) + case i > 0: + middle = append(middle, fragment{value: frag.value[:i]}) + fallthrough + default: + middle = append(middle, fragment{value: special, ids: []int32{id}}) + if rest := frag.value[i+len(special):]; rest != "" { + middle = append(middle, fragment{value: rest}) + } + } + + fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) + } + } + fmt.Printf("frags = %#v\n", fragments) + + var ids []int32 + for _, frag := range fragments { + if len(frag.ids) > 0 { + ids = append(ids, frag.ids...) + continue + } + + for split := range spm.split(frag.value) { + split = replaceWhitespaceBySeperator(split) + + var sb strings.Builder + sb.Write([]byte(split)) + if id := spm.vocab.Encode(sb.String()); id >= 0 { + ids = append(ids, id) + continue + } + + runes := []rune(sb.String()) + pq := queue.NewWith(func(a, b any) int { + priA := a.(*candidate) + priB := b.(*candidate) + if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) { + return 1 + } + return -1 + }) + + merges := make([]merge, len(runes)) + for r := range runes { + merges[r] = merge{ + p: r - 1, + n: r + 1, + runes: []rune{runes[r]}, + } + } + fmt.Printf("remaining runes = %#v\n", runes) + fmt.Printf("merges = %#v\n", merges) + + pairwise := func(a, b int) *candidate { + if a < 0 || b >= len(runes) { + return nil + } + + left, right := string(merges[a].runes), string(merges[b].runes) + fmt.Printf("looking up '%s'\n", left+right) + if id := spm.vocab.Encode(left + right); id >= 0 { + return &candidate{ + a: a, + b: b, + length: len(left + " " + right), + score: spm.vocab.Scores[id], + } + } + return nil + } + + for i := range len(runes) - 1 { + if pair := pairwise(i, i+1); pair != nil { + pq.Enqueue(pair) + } + } + + pqv := pq.Values() + for _, v := range pqv { + e := v.(*candidate) + fmt.Printf("candidate = %#v\n", e) + } + + for !pq.Empty() { + v, _ := pq.Dequeue() + pair := v.(*candidate) + left, right := merges[pair.a], merges[pair.b] + + if len(left.runes) == 0 || len(right.runes) == 0 { + continue + } + + merges[pair.a].runes = append(left.runes, right.runes...) + merges[pair.b].runes = nil + merges[pair.a].n = right.n + if right.n < len(merges) { + merges[right.n].p = pair.a + } + + if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { + pq.Enqueue(pair) + } + + if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { + pq.Enqueue(pair) + } + } + + fmt.Printf("merges = %#v\n", merges) + + for _, merge := range merges { + if len(merge.runes) > 0 { + if id := spm.vocab.Encode(string(merge.runes)); id >= 0 { + ids = append(ids, id) + } else { + fmt.Printf("!!! missing token for '%s'\n", string(merge.runes)) + } + } + } + } + + } + fmt.Printf("tokens = %#v\n", ids) + + return ids, nil +} + +type candidate struct { + a, b int + score float32 + length int +} + +func (spm SentencePieceModel) Decode(ids []int32) (string, error) { + var sb strings.Builder + for _, id := range ids { + for _, r := range spm.vocab.Decode(id) { + // todo - do we need to introspect the chars here? + if err := sb.WriteByte(byte(r)); err != nil { + return "", err + } + } + } + + return sb.String(), nil +}