diff --git a/model/process_text.go b/model/process_text.go index 01af65b62..017e0c0b9 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -32,6 +32,8 @@ type TextProcessor interface { Encode(s string, addSpecial bool) ([]int32, error) Decode([]int32) (string, error) Is(int32, Special) bool + + Vocab() *Vocabulary } type Vocabulary struct { diff --git a/model/process_text_spm.go b/model/process_text_spm.go index 68e3ed015..8f96f11f1 100644 --- a/model/process_text_spm.go +++ b/model/process_text_spm.go @@ -49,6 +49,10 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel { } } +func (spm SentencePieceModel) Vocab() *Vocabulary { + return spm.vocab +} + func (spm SentencePieceModel) Is(id int32, special Special) bool { return spm.vocab.Is(id, special) }