add sentence piece tokenizer
Parent: d231229122
Commit: 8cf1ea4fd8
@@ -120,6 +120,15 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
 	return s
 }
 
+func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
+	r := keyValue(kv, key, &array{})
+	s := make([]float32, r.size)
+	for i := range r.size {
+		s[i] = float32(r.values[i].(float32))
+	}
+	return s
+}
+
 func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key
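A minimal sketch of a call site for the new accessor (tokenScores is a hypothetical helper; the real consumer is the gemma2 model below, via the Config interface):

	// Hypothetical helper: read per-token SentencePiece scores out of the
	// GGUF key/value metadata. Keys prefixed with "tokenizer." or "general."
	// skip the architecture namespacing inside keyValue, so this key is
	// passed through unchanged.
	func tokenScores(kv KV) []float32 {
		return kv.Floats("tokenizer.ggml.scores")
	}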
go.mod (+1)
@@ -18,6 +18,7 @@ require (
 	github.com/agnivade/levenshtein v1.1.1
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/dlclark/regexp2 v1.11.4
+	github.com/emirpasic/gods v1.18.1
 	github.com/emirpasic/gods/v2 v2.0.0-alpha
 	github.com/google/go-cmp v0.6.0
 	github.com/mattn/go-runewidth v0.0.14
go.sum (+2)
@@ -44,6 +44,8 @@ github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+
 github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
 github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
 github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
+github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
 github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU=
 github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
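The gods v1 dependency is pulled in for its priority queue, which drives the merge loop in the new SentencePiece tokenizer below. A minimal self-contained sketch of its comparator semantics (assuming the gods v1 API, where an element that compares negative sorts first and dequeues first):

	package main

	import (
		"fmt"

		queue "github.com/emirpasic/gods/queues/priorityqueue"
	)

	func main() {
		// Returning -1 when a outranks b yields a max-priority queue:
		// the highest score dequeues first.
		pq := queue.NewWith(func(a, b any) int {
			if a.(float32) > b.(float32) {
				return -1
			}
			return 1
		})
		pq.Enqueue(float32(0.5))
		pq.Enqueue(float32(2.5))
		pq.Enqueue(float32(1.5))
		for !pq.Empty() {
			v, _ := pq.Dequeue()
			fmt.Println(v) // 2.5, then 1.5, then 0.5
		}
	}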
@@ -17,6 +17,7 @@ type Config interface {
 
 	Strings(string, ...[]string) []string
 	Uints(string, ...[]uint32) []uint32
+	Floats(string, ...[]float32) []float32
 }
 
 type Backend interface {
@@ -1,7 +1,6 @@
 package gemma2
 
 import (
-	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/kvcache"
@@ -20,7 +19,7 @@ type Options struct {
 
 type Model struct {
 	model.Base
-	model.BytePairEncoding
+	model.SentencePieceModel
 
 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers         []Layer       `gguf:"blk"`
@@ -32,10 +31,11 @@ type Model struct {
 
 func New(c ml.Config) (model.Model, error) {
 	m := Model{
-		BytePairEncoding: model.NewBytePairEncoding(
+		SentencePieceModel: model.NewSentencePieceModel(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
+				Scores: c.Floats("tokenizer.ggml.scores"),
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
 				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
@@ -55,7 +55,7 @@ func New(c ml.Config) (model.Model, error) {
 	}
 
 	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewCausalCache(m.Shift), kvcache.NewSWACache(slidingWindowLen, m.Shift))
+	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
 
 	return &m, nil
 }
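Gemma 2 interleaves sliding-window and global attention layers, beginning with a sliding-window layer, so moving the SWA cache into the wrapper's first slot presumably lines the per-layer cache selection up with that layer pattern.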
@@ -76,7 +76,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, uint32(headDim), opts.ropeBase, opts.ropeScale)
 
 	// todo: this should be 1.0/math.Sqrt(float64(headDim)) for 27B models
-	q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
+	//q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
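On the TODO above: in the published Gemma 2 configs the query scale is 1/sqrt(query_pre_attn_scalar). For the 2B and 9B variants that scalar equals the key length (256), so scaling by attnKeyLen is correct there; for 27B it is hiddenSize/numHeads = 4608/32 = 144 while the key length is 128, so the scale should be 1/sqrt(144) ≈ 0.083 rather than 1/sqrt(128) ≈ 0.088, which is presumably what the TODO's headDim computes. (These figures come from the published configs, not from this diff.)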
@@ -140,8 +140,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach
 }
 
 func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
-	fmt.Printf("HELLO THERE!!\n")
-
 	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
@@ -2,6 +2,7 @@ package model
 
 import (
 	"cmp"
+	"fmt"
 	"iter"
 	"log/slog"
 	"strings"
@@ -18,6 +19,15 @@ const (
 	SpecialEOS
 )
 
+const (
+	TOKEN_TYPE_NORMAL = iota + 1
+	TOKEN_TYPE_UNKNOWN
+	TOKEN_TYPE_CONTROL
+	TOKEN_TYPE_USER_DEFINED
+	TOKEN_TYPE_UNUSED
+	TOKEN_TYPE_BYTE
+)
+
 type TextProcessor interface {
 	Encode(string) ([]int32, error)
 	Decode([]int32) (string, error)
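These constants mirror the token-type convention that GGUF inherits from SentencePiece's protobuf (NORMAL=1, UNKNOWN=2, CONTROL=3, USER_DEFINED=4, UNUSED=5, BYTE=6), which is what makes it safe to replace the literal 3 with TOKEN_TYPE_CONTROL in SpecialVocabulary below. A small sketch of the mapping for reference:

	// tokenizer.ggml.token_type values, per the SentencePiece/GGUF convention.
	var tokenTypeNames = map[uint32]string{
		TOKEN_TYPE_NORMAL:       "normal",
		TOKEN_TYPE_UNKNOWN:      "unknown",
		TOKEN_TYPE_CONTROL:      "control",
		TOKEN_TYPE_USER_DEFINED: "user defined",
		TOKEN_TYPE_UNUSED:       "unused",
		TOKEN_TYPE_BYTE:         "byte",
	}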
@@ -27,7 +37,7 @@ type TextProcessor interface {
 type Vocabulary struct {
 	Values []string
 	Types  []uint32
-	Scores []uint32
+	Scores []float32
 	Merges []string
 
 	BOS, EOS int32
@@ -75,7 +85,7 @@ func (v *Vocabulary) Decode(id int32) string {
 func (v *Vocabulary) SpecialVocabulary() []string {
 	v.specialOnce.Do(func() {
 		for i := range v.Values {
-			if v.Types[i] == 3 {
+			if v.Types[i] == TOKEN_TYPE_CONTROL {
 				v.special = append(v.special, v.Values[i])
 			}
 		}
@@ -171,6 +181,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
 			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
 		}
 	}
+	fmt.Printf("frags = %#v\n", fragments)
 
 	var ids []int32
 	for _, frag := range fragments {
model/process_text_spm.go (new file, +232)
@@ -0,0 +1,232 @@
+package model
+
+import (
+	"fmt"
+	"iter"
+	"strings"
+	//"unicode/utf8"
+
+	"github.com/dlclark/regexp2"
+	queue "github.com/emirpasic/gods/queues/priorityqueue"
+)
+
+const spmWhitespaceSep = "▁"
+
+func replaceWhitespaceBySeperator(s string) string {
+	return strings.ReplaceAll(s, " ", spmWhitespaceSep)
+}
+
+type SentencePieceModel struct {
+	maxTokenLen int
+	pre         *regexp2.Regexp
+	vocab       *Vocabulary
+}
+
+func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
+	fmt.Printf("Tokens (%d): %5s %5s %5s ...\n", len(vocab.Values), vocab.Values[0], vocab.Values[1], vocab.Values[2])
+	fmt.Printf("Scores (%d): %0.3f %0.3f %0.3f ...\n", len(vocab.Scores), vocab.Scores[0], vocab.Scores[1], vocab.Scores[2])
+	fmt.Printf("Types (%d): %5d %5d %5d ...\n", len(vocab.Types), vocab.Types[0], vocab.Types[1], vocab.Types[2])
+
+	counter := map[int]int{}
+	var maxTokenLen int
+
+	for cnt, _ := range vocab.Types {
+		switch vocab.Types[cnt] {
+		case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
+			maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
+			fallthrough
+		default:
+			counter[int(vocab.Types[cnt])] += 1
+		}
+	}
+
+	fmt.Printf("Normal: %d\n", counter[TOKEN_TYPE_NORMAL])
+	fmt.Printf("Unknown: %d\n", counter[TOKEN_TYPE_UNKNOWN])
+	fmt.Printf("Control: %d\n", counter[TOKEN_TYPE_CONTROL])
+	fmt.Printf("User Defined: %d\n", counter[TOKEN_TYPE_USER_DEFINED])
+	fmt.Printf("Unused: %d\n", counter[TOKEN_TYPE_UNUSED])
+	fmt.Printf("Byte: %d\n", counter[TOKEN_TYPE_BYTE])
+	fmt.Printf("Max token len: %d\n", maxTokenLen)
+
+	return SentencePieceModel{
+		maxTokenLen: maxTokenLen,
+		pre:         regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
+		vocab:       vocab,
+	}
+}
+
+func (spm SentencePieceModel) Is(id int32, special Special) bool {
+	return spm.vocab.Is(id, special)
+}
+
+func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
+	return func(yield func(string) bool) {
+		for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
+			if !yield(m.String()) {
+				break
+			}
+		}
+	}
+}
+
+func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
+	fragments := []fragment{{value: s}}
+	for _, special := range spm.vocab.SpecialVocabulary() {
+		// TODO: process special tokens concurrently
+		id := spm.vocab.Encode(special)
+		for i := 0; i < len(fragments); i++ {
+			frag := fragments[i]
+			if len(frag.ids) > 0 {
+				continue
+			}
+
+			var middle []fragment
+			switch i := strings.Index(frag.value, special); {
+			case i < 0:
+				middle = append(middle, frag)
+			case i > 0:
+				middle = append(middle, fragment{value: frag.value[:i]})
+				fallthrough
+			default:
+				middle = append(middle, fragment{value: special, ids: []int32{id}})
+				if rest := frag.value[i+len(special):]; rest != "" {
+					middle = append(middle, fragment{value: rest})
+				}
+			}
+
+			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
+		}
+	}
+	fmt.Printf("frags = %#v\n", fragments)
+
+	var ids []int32
+	for _, frag := range fragments {
+		if len(frag.ids) > 0 {
+			ids = append(ids, frag.ids...)
+			continue
+		}
+
+		for split := range spm.split(frag.value) {
+			split = replaceWhitespaceBySeperator(split)
+
+			var sb strings.Builder
+			sb.Write([]byte(split))
+			if id := spm.vocab.Encode(sb.String()); id >= 0 {
+				ids = append(ids, id)
+				continue
+			}
+
+			runes := []rune(sb.String())
+			pq := queue.NewWith(func(a, b any) int {
+				priA := a.(*candidate)
+				priB := b.(*candidate)
+				if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
+					return 1
+				}
+				return -1
+			})
+
+			merges := make([]merge, len(runes))
+			for r := range runes {
+				merges[r] = merge{
+					p:     r - 1,
+					n:     r + 1,
+					runes: []rune{runes[r]},
+				}
+			}
+			fmt.Printf("remaining runes = %#v\n", runes)
+			fmt.Printf("merges = %#v\n", merges)
+
+			pairwise := func(a, b int) *candidate {
+				if a < 0 || b >= len(runes) {
+					return nil
+				}
+
+				left, right := string(merges[a].runes), string(merges[b].runes)
+				fmt.Printf("looking up '%s'\n", left+right)
+				if id := spm.vocab.Encode(left + right); id >= 0 {
+					return &candidate{
+						a:      a,
+						b:      b,
+						length: len(left + " " + right),
+						score:  spm.vocab.Scores[id],
+					}
+				}
+				return nil
+			}
+
+			for i := range len(runes) - 1 {
+				if pair := pairwise(i, i+1); pair != nil {
+					pq.Enqueue(pair)
+				}
+			}
+
+			pqv := pq.Values()
+			for _, v := range pqv {
+				e := v.(*candidate)
+				fmt.Printf("candidate = %#v\n", e)
+			}
+
+			for !pq.Empty() {
+				v, _ := pq.Dequeue()
+				pair := v.(*candidate)
+				left, right := merges[pair.a], merges[pair.b]
+
+				if len(left.runes) == 0 || len(right.runes) == 0 {
+					continue
+				}
+
+				merges[pair.a].runes = append(left.runes, right.runes...)
+				merges[pair.b].runes = nil
+				merges[pair.a].n = right.n
+				if right.n < len(merges) {
+					merges[right.n].p = pair.a
+				}
+
+				if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
+					pq.Enqueue(pair)
+				}
+
+				if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
+					pq.Enqueue(pair)
+				}
+			}
+
+			fmt.Printf("merges = %#v\n", merges)
+
+			for _, merge := range merges {
+				if len(merge.runes) > 0 {
+					if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
+						ids = append(ids, id)
+					} else {
+						fmt.Printf("!!! missing token for '%s'\n", string(merge.runes))
+					}
+				}
+			}
+		}
+	}
+	fmt.Printf("tokens = %#v\n", ids)
+
+	return ids, nil
+}
+
+type candidate struct {
+	a, b   int
+	score  float32
+	length int
+}
+
+func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
+	var sb strings.Builder
+	for _, id := range ids {
+		for _, r := range spm.vocab.Decode(id) {
+			// todo - do we need to introspect the chars here?
+			if err := sb.WriteByte(byte(r)); err != nil {
+				return "", err
+			}
+		}
+	}
+
+	return sb.String(), nil
+}
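To make the new tokenizer concrete, a hedged usage sketch: the vocabulary below is invented for illustration (real values, types, and scores come from the GGUF metadata, as wired up in gemma2's New above), and pretok stands in for the pretokenizer regex string, called here from a hypothetical external package:

	// Toy vocabulary: ids 0-2 are control tokens; scores are log-probabilities,
	// so less-negative means more likely.
	vocab := &model.Vocabulary{
		Values: []string{"<pad>", "<s>", "</s>", "▁hi", "▁", "h", "i"},
		Types: []uint32{
			model.TOKEN_TYPE_CONTROL, model.TOKEN_TYPE_CONTROL, model.TOKEN_TYPE_CONTROL,
			model.TOKEN_TYPE_NORMAL, model.TOKEN_TYPE_NORMAL, model.TOKEN_TYPE_NORMAL,
			model.TOKEN_TYPE_NORMAL,
		},
		Scores: []float32{0, 0, 0, -1, -2, -3, -3},
		BOS:    1,
		EOS:    2,
	}

	spm := model.NewSentencePieceModel(pretok, vocab)

	// " hi" becomes "▁hi" after whitespace replacement and is found whole,
	// so Encode emits a single id (3) without entering the merge loop.
	ids, _ := spm.Encode(" hi")

When a piece has no whole-token match, Encode falls back to the rune-level merge: every adjacent pair whose concatenation exists in the vocabulary is queued as a candidate, the queue repeatedly dequeues a candidate and fuses its pair in the doubly linked merges list, and ids for the surviving segments are emitted at the end.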