diff --git a/model/llama/model.go b/model/llama/model.go index 294661740..6efcc9bb7 100644 --- a/model/llama/model.go +++ b/model/llama/model.go @@ -35,8 +35,8 @@ func New(c ml.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Uints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: c.Uint("tokenizer.ggml.bos_token_id"), - EOS: c.Uint("tokenizer.ggml.eos_token_id"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), }, ), Layers: make([]Layer, c.Uint("block_count")), diff --git a/model/mllama/model.go b/model/mllama/model.go index d0c59a3e2..e5b275b0b 100644 --- a/model/mllama/model.go +++ b/model/mllama/model.go @@ -26,8 +26,8 @@ func New(c ml.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Uints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: c.Uint("tokenizer.ggml.bos_token_id"), - EOS: c.Uint("tokenizer.ggml.eos_token_id"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), }, ), ImageProcessor: newImageProcessor(c), diff --git a/model/process_text.go b/model/process_text.go index 1610a884d..df1e68f4c 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -21,7 +21,7 @@ const ( type TextProcessor interface { Encode(string) ([]int32, error) Decode([]int32) (string, error) - Is(uint32, Special) bool + Is(int32, Special) bool } type Vocabulary struct { @@ -30,7 +30,7 @@ type Vocabulary struct { Scores []uint32 Merges []string - BOS, EOS uint32 + BOS, EOS int32 specialOnce sync.Once special []string @@ -42,7 +42,7 @@ type Vocabulary struct { merge map[string]int32 } -func (v *Vocabulary) Is(id uint32, special Special) bool { +func (v *Vocabulary) Is(id int32, special Special) bool { switch special { case SpecialBOS: return id == v.BOS @@ -111,7 +111,7 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { } } -func (bpe BytePairEncoding) Is(id uint32, special Special) bool { +func (bpe BytePairEncoding) Is(id int32, special Special) bool { return bpe.vocab.Is(id, special) }