Compare commits

...

9 Commits

Author SHA1 Message Date
Patrick Devine
b7349a4efd more linter feeding 2025-02-18 13:32:58 -08:00
Patrick Devine
4cda3e3622 feed the linter 2025-02-18 13:16:43 -08:00
Patrick Devine
95fbf1da12 fix causal test 2025-02-18 13:02:44 -08:00
Patrick Devine
83d1a1ab55 cleanup 2025-02-18 12:47:34 -08:00
Patrick Devine
035e69799e clean up 2025-02-18 12:40:12 -08:00
Patrick Devine
10e06d0a45 gemma2 ftw 2025-02-18 12:40:02 -08:00
Patrick Devine
8cf1ea4fd8 add sentence piece tokenizer 2025-02-18 12:39:45 -08:00
Patrick Devine
d231229122 cache is king 2025-02-18 12:39:27 -08:00
Patrick Devine
fad98fabab gemma2 impl 2025-02-18 12:39:17 -08:00
12 changed files with 455 additions and 14 deletions

View File

@ -120,6 +120,15 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return s
}
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
r := keyValue(kv, key, &array{})
s := make([]float32, r.size)
for i := range r.size {
s[i] = float32(r.values[i].(float32))
}
return s
}
func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key

1
go.mod
View File

@ -18,6 +18,7 @@ require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods v1.18.1
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14

2
go.sum
View File

@ -44,6 +44,8 @@ github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU=
github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=

View File

@ -434,7 +434,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}

View File

@ -17,6 +17,7 @@ type Config interface {
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32
}
type Backend interface {
@ -76,7 +77,7 @@ type Tensor interface {
Scale(ctx Context, s float64) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor

View File

@ -596,10 +596,13 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
const (
ropeTypeNorm C.int = iota
ropeTypeNorm C.int = 0
ropeTypeNeox C.int = 2
ropeTypeMrope C.int = 8
ropeTypeVision C.int = 24
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{}
}
@ -613,8 +616,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
131072, // YaRN n_ctx_train
ropeTypeNorm, // ROPE_TYPE_NORM
C.int(ropeType),
131072, // YaRN n_ctx_train
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor

View File

@ -0,0 +1,193 @@
package gemma2
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeBase, ropeScale float32
attnLogitSoftcap float32
finalLogitSoftcap float32
}
type Model struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"` // is this supposed to be root means square?
Output *nn.Linear `gguf:"output,alt:token_embd"` // just set to token_embd?
*Options
}
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
// todo: this should be 1.0/math.Sqrt(float64(headDim)) for 27B models
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
for i, layer := range m.Layers {
cacheType := i % 2
m.Cache.SetLayer(i)
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
}
func init() {
model.Register("gemma2", New)
}

View File

@ -67,14 +67,15 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -99,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, uint32(0), m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {

View File

@ -19,14 +19,15 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -52,7 +53,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
type TextMLP struct {

View File

@ -1,6 +1,7 @@
package models
import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@ -18,6 +18,15 @@ const (
SpecialEOS
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(string) ([]int32, error)
Decode([]int32) (string, error)
@ -27,7 +36,7 @@ type TextProcessor interface {
type Vocabulary struct {
Values []string
Types []uint32
Scores []uint32
Scores []float32
Merges []string
BOS, EOS int32
@ -75,7 +84,7 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == 3 {
if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}

220
model/process_text_spm.go Normal file
View File

@ -0,0 +1,220 @@
package model
import (
"iter"
"log/slog"
"strings"
"github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/queues/priorityqueue"
)
const spmWhitespaceSep = "▁"
func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}
type SentencePieceModel struct {
maxTokenLen int
pre *regexp2.Regexp
vocab *Vocabulary
}
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:3], "scores", vocab.Scores[:3], "types", vocab.Types[:3])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePieceModel{
maxTokenLen: maxTokenLen,
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
vocab: vocab,
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
}
}
}
}
func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
slog.Debug("fragments", "frags", fragments)
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range spm.split(frag.value) {
split = replaceWhitespaceBySeperator(split)
var sb strings.Builder
sb.Write([]byte(split))
if id := spm.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
pq := queue.NewWith(func(a, b any) int {
priA := a.(*candidate)
priB := b.(*candidate)
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
return 1
}
return -1
})
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
length: len(left + " " + right),
score: spm.vocab.Scores[id],
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pq.Enqueue(pair)
}
}
pqv := pq.Values()
for _, v := range pqv {
e := v.(*candidate)
slog.Debug("candidate", "candidate", e)
}
for !pq.Empty() {
v, _ := pq.Dequeue()
pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pq.Enqueue(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pq.Enqueue(pair)
}
}
slog.Debug("merges", "merges", merges)
for _, merge := range merges {
if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
} else {
slog.Debug("missing token", "token", string(merge.runes))
}
}
}
}
}
slog.Debug("encoded", "ids", ids)
return ids, nil
}
type candidate struct {
a, b int
score float32
length int
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}