diff --git a/llm/server.go b/llm/server.go index 09690a5ff..9553ba8f0 100644 --- a/llm/server.go +++ b/llm/server.go @@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) return s.llamaModel.Tokenize(content, false, true) } if s.textProcessor != nil { - tokens, err := s.textProcessor.Encode(content) + tokens, err := s.textProcessor.Encode(content, false) if err != nil { return nil, err } diff --git a/model/process_text.go b/model/process_text.go index 7083f36fd..bfb0a5f20 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -19,7 +19,7 @@ const ( ) type TextProcessor interface { - Encode(string) ([]int32, error) + Encode(s string, addSpecial bool) ([]int32, error) Decode([]int32) (string, error) Is(int32, Special) bool } @@ -144,7 +144,7 @@ type merge struct { runes []rune } -func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { +func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { fragments := []fragment{{value: s}} for _, special := range bpe.vocab.SpecialVocabulary() { // TODO: process special tokens concurrently @@ -282,7 +282,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) { } } - if len(ids) > 0 { + if addSpecial && len(ids) > 0 { if bpe.vocab.AddBOS { if ids[0] == bpe.vocab.BOS { slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) diff --git a/model/process_text_test.go b/model/process_text_test.go index cad1f94ff..f48303212 100644 --- a/model/process_text_test.go +++ b/model/process_text_test.go @@ -74,7 +74,7 @@ func TestLlama(t *testing.T) { t.Run("simple", func(t *testing.T) { t.Parallel() - ids, err := tokenizer.Encode("hello world") + ids, err := tokenizer.Encode("hello world", true) if err != nil { t.Error(err) } @@ -92,7 +92,7 @@ func TestLlama(t *testing.T) { t.Errorf("got %q, want hello world", s) } - ids, err = tokenizer.Encode("hello <|end_of_text|>") + ids, err = tokenizer.Encode("hello <|end_of_text|>", true) if err != nil { t.Error(err) } @@ -126,7 +126,7 @@ func TestLlama(t *testing.T) { } for s, want := range cases { - ids, err := tokenizer.Encode(s) + ids, err := tokenizer.Encode(s, true) if err != nil { t.Error(err) } @@ -152,7 +152,7 @@ func TestLlama(t *testing.T) { } for _, want := range cases { - ids, err := tokenizer.Encode(want) + ids, err := tokenizer.Encode(want, true) if err != nil { t.Error(err) } @@ -176,7 +176,7 @@ func TestLlama(t *testing.T) { } for s, want := range cases { - ids, err := tokenizer.Encode(s) + ids, err := tokenizer.Encode(s, true) if err != nil { t.Fatal(err) } @@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { b.ResetTimer() for range b.N { - _, err := tokenizer.Encode(string(bts)) + _, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) } @@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) { }) b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { - ids, err := tokenizer.Encode(string(bts)) + ids, err := tokenizer.Encode(string(bts), true) if err != nil { b.Fatal(err) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 1a4bbf19e..9ba6563f0 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -161,7 +161,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { for i, part := range parts { // text - tokenize - tokens, err := s.model.(model.TextProcessor).Encode(part) + tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) if err != nil { return nil, err }