diff --git a/llama/llama.go b/llama/llama.go
index 15e719798..13334b045 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -449,9 +449,24 @@ type Model struct {
 	c *C.struct_llama_model
 }
 
+func (m *Model) Detokenize(tokens []int) (string, error) {
+	var text string
+	for _, token := range tokens {
+		piece := m.TokenToPiece(token)
+		if piece == "" {
+			return "", fmt.Errorf("failed to convert token %d to piece", token)
+		}
+		text += piece
+	}
+	return text, nil
+}
+
 func (m *Model) TokenToPiece(token int) string {
 	tokenLen := 12
 	buf := make([]byte, tokenLen)
+	if token > m.NumVocab() {
+		return ""
+	}
 	tokenLen = int(C.llama_token_to_piece(
 		m.c,
 		C.int32_t(token),
diff --git a/server/routes.go b/server/routes.go
index 1552798bb..a3aba7655 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -634,9 +634,10 @@ func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	var text string
-	for _, token := range req.Tokens {
-		text += loadedModel.model.TokenToPiece(token)
+	text, err := loadedModel.model.Detokenize(req.Tokens)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("failed to detokenize text: %v", err), http.StatusInternalServerError)
+		return
 	}
 
 	w.Header().Set("Content-Type", "application/json")
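
For context, a minimal sketch of how the new Detokenize method could be exercised from a test helper. The import path, the helper name checkDetokenize, and the assumption that a *llama.Model has already been loaded by the caller are illustrative only and are not part of this patch.

package llama_test

import (
	"testing"

	// Assumed import path for the package containing Model; adjust to the
	// actual module path of this repository.
	"github.com/ollama/ollama/llama"
)

// checkDetokenize exercises Model.Detokenize against an already-loaded model.
// Loading the model and producing the token IDs is left to the caller.
func checkDetokenize(t *testing.T, m *llama.Model, tokens []int) {
	t.Helper()

	// Valid tokens should be converted piece by piece and concatenated.
	text, err := m.Detokenize(tokens)
	if err != nil {
		t.Fatalf("Detokenize(%v): %v", tokens, err)
	}
	if len(tokens) > 0 && text == "" {
		t.Fatalf("Detokenize(%v) returned empty text", tokens)
	}

	// A token beyond the vocabulary should now surface an error instead of
	// being passed through to llama_token_to_piece.
	if _, err := m.Detokenize([]int{m.NumVocab() + 1}); err == nil {
		t.Fatal("expected an error for an out-of-vocabulary token")
	}
}

The second check relies on the bounds check added to TokenToPiece in this patch: an out-of-range token makes TokenToPiece return an empty string, which Detokenize reports as an error rather than silently dropping the token.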