This commit is contained in:
Patrick Devine 2025-02-12 11:19:30 -08:00
parent 10e06d0a45
commit 035e69799e
2 changed files with 13 additions and 24 deletions

View File

@ -1,11 +1,9 @@
package model package model
import ( import (
"fmt"
"iter" "iter"
"log/slog" "log/slog"
"strings" "strings"
//"unicode/utf8"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/queues/priorityqueue" queue "github.com/emirpasic/gods/queues/priorityqueue"
@ -24,9 +22,7 @@ type SentencePieceModel struct {
} }
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel { func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
fmt.Printf("Tokens (%d): %5s %5s %5s ...\n", len(vocab.Values), vocab.Values[0], vocab.Values[1], vocab.Values[2]) slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:3], "scores", vocab.Scores[:3], "types", vocab.Types[:3])
fmt.Printf("Scores (%d): %0.3f %0.3f %0.3f ...\n", len(vocab.Scores), vocab.Scores[0], vocab.Scores[1], vocab.Scores[2])
fmt.Printf("Types (%d): %5d %5d %5d ...\n", len(vocab.Types), vocab.Types[0], vocab.Types[1], vocab.Types[2])
counter := map[int]int{} counter := map[int]int{}
var maxTokenLen int var maxTokenLen int
@ -41,13 +37,9 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
} }
} }
fmt.Printf("Normal: %d\n", counter[TOKEN_TYPE_NORMAL]) slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
fmt.Printf("Unknown: %d\n", counter[TOKEN_TYPE_UNKNOWN]) "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
fmt.Printf("Control: %d\n", counter[TOKEN_TYPE_CONTROL]) "max token len", maxTokenLen)
fmt.Printf("User Defined: %d\n", counter[TOKEN_TYPE_USER_DEFINED])
fmt.Printf("Unused: %d\n", counter[TOKEN_TYPE_UNUSED])
fmt.Printf("Byte: %d\n", counter[TOKEN_TYPE_BYTE])
fmt.Printf("Max token len: %d\n", maxTokenLen)
return SentencePieceModel{ return SentencePieceModel{
maxTokenLen: maxTokenLen, maxTokenLen: maxTokenLen,
@ -98,7 +90,7 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
} }
} }
fmt.Printf("frags = %#v\n", fragments) slog.Debug("fragments", "frags", fragments)
var ids []int32 var ids []int32
for _, frag := range fragments { for _, frag := range fragments {
@ -135,8 +127,6 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
runes: []rune{runes[r]}, runes: []rune{runes[r]},
} }
} }
fmt.Printf("remaining runes = %#v\n", runes)
fmt.Printf("merges = %#v\n", merges)
pairwise := func(a, b int) *candidate { pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) { if a < 0 || b >= len(runes) {
@ -144,7 +134,6 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
} }
left, right := string(merges[a].runes), string(merges[b].runes) left, right := string(merges[a].runes), string(merges[b].runes)
fmt.Printf("looking up '%s'\n", left+right)
if id := spm.vocab.Encode(left + right); id >= 0 { if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{ return &candidate{
a: a, a: a,
@ -165,7 +154,7 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
pqv := pq.Values() pqv := pq.Values()
for _, v := range pqv { for _, v := range pqv {
e := v.(*candidate) e := v.(*candidate)
fmt.Printf("candidate = %#v\n", e) slog.Debug("candidate", "candidate", e)
} }
for !pq.Empty() { for !pq.Empty() {
@ -193,21 +182,21 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
} }
} }
fmt.Printf("merges = %#v\n", merges) slog.Debug("merges", "merges", merges)
for _, merge := range merges { for _, merge := range merges {
if len(merge.runes) > 0 { if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 { if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id) ids = append(ids, id)
} else { } else {
fmt.Printf("!!! missing token for '%s'\n", string(merge.runes)) slog.Debug("missing token", "token", string(merge.runes))
} }
} }
} }
} }
} }
fmt.Printf("tokens = %#v\n", ids) slog.Debug("encoded", "ids", ids)
return ids, nil return ids, nil
} }