wip: apply gbnf vocab to logits
This commit is contained in:
parent
05a01fdecb
commit
81888abbe4
135
llama/grammar.go
Normal file
135
llama/grammar.go
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
package llama
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CFLAGS: -std=c11
|
||||||
|
#cgo CXXFLAGS: -std=c++17
|
||||||
|
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/include
|
||||||
|
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/common
|
||||||
|
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/src
|
||||||
|
#cgo CPPFLAGS: -I${SRCDIR}
|
||||||
|
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "llama.h"
|
||||||
|
#include "grammar_ext.h"
|
||||||
|
|
||||||
|
// Helper function to handle Go string arrays to C
|
||||||
|
static char** makeCharArray(int size) {
|
||||||
|
return (char**)malloc(size * sizeof(char*));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void setArrayString(char** a, int i, const char* s) {
|
||||||
|
a[i] = (char*)s;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void freeCharArray(char** a, int size) {
|
||||||
|
free(a);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"runtime"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Grammar represents the interface for grammar-based sampling
|
||||||
|
type Grammar interface {
|
||||||
|
Apply(logits []float32) ([]float32, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// CGrammar is a wrapper around the C++ grammar implementation
|
||||||
|
type CGrammar struct {
|
||||||
|
grammar *C.struct_llama_grammar
|
||||||
|
model *C.struct_llama_model
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGrammarWithTokens creates a new grammar using a custom vocabulary defined by tokens
|
||||||
|
func NewGrammarWithTokens(grammarStr, grammarRoot string, tokens []string) (Grammar, error) {
|
||||||
|
if grammarStr == "" {
|
||||||
|
return nil, errors.New("empty grammar string")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return nil, errors.New("empty token list")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create C array of strings for tokens
|
||||||
|
cTokens := C.makeCharArray(C.int(len(tokens)))
|
||||||
|
defer C.freeCharArray(cTokens, C.int(len(tokens)))
|
||||||
|
|
||||||
|
// Convert Go strings to C strings and set them in the array
|
||||||
|
cStrings := make([]*C.char, len(tokens))
|
||||||
|
for i, token := range tokens {
|
||||||
|
cStrings[i] = C.CString(token)
|
||||||
|
C.setArrayString(cTokens, C.int(i), cStrings[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create vocabulary from tokens
|
||||||
|
cVocab := C.vocab_bridge_from_tokens((**C.char)(unsafe.Pointer(cTokens)), C.int(len(tokens)))
|
||||||
|
|
||||||
|
// Free the C strings after creating the vocab
|
||||||
|
for _, str := range cStrings {
|
||||||
|
C.free(unsafe.Pointer(str))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cVocab == nil {
|
||||||
|
return nil, errors.New("failed to create vocabulary from tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure to free the vocabulary when we're done
|
||||||
|
defer C.vocab_bridge_free(cVocab)
|
||||||
|
|
||||||
|
cGrammarStr := C.CString(grammarStr)
|
||||||
|
defer C.free(unsafe.Pointer(cGrammarStr))
|
||||||
|
|
||||||
|
cGrammarRoot := C.CString(grammarRoot)
|
||||||
|
defer C.free(unsafe.Pointer(cGrammarRoot))
|
||||||
|
|
||||||
|
// Create grammar using our C wrapper function with the correct signature
|
||||||
|
grammar := C.grammar_create_from_string(cVocab, cGrammarStr, cGrammarRoot)
|
||||||
|
if grammar == nil {
|
||||||
|
return nil, errors.New("failed to initialize grammar")
|
||||||
|
}
|
||||||
|
|
||||||
|
cg := &CGrammar{
|
||||||
|
grammar: grammar,
|
||||||
|
closed: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up finalizer to free resources when the object is garbage collected
|
||||||
|
runtime.SetFinalizer(cg, func(g *CGrammar) {
|
||||||
|
g.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return cg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies grammar constraints to logits
|
||||||
|
func (g *CGrammar) Apply(logits []float32) ([]float32, error) {
|
||||||
|
if g.closed || g.grammar == nil {
|
||||||
|
return nil, errors.New("grammar not initialized or already closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a copy of logits to modify
|
||||||
|
result := make([]float32, len(logits))
|
||||||
|
copy(result, logits)
|
||||||
|
|
||||||
|
// Apply grammar constraints using our C wrapper function
|
||||||
|
C.grammar_apply_to_logits(g.grammar, (*C.float)(&result[0]), C.int(len(result)))
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases resources associated with the grammar
|
||||||
|
func (g *CGrammar) Close() error {
|
||||||
|
if !g.closed && g.grammar != nil {
|
||||||
|
C.grammar_free(g.grammar)
|
||||||
|
g.grammar = nil
|
||||||
|
g.closed = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
83
llama/grammar_ext.cpp
vendored
Normal file
83
llama/grammar_ext.cpp
vendored
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
#include <stdlib.h>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
#include "llama-sampling.h"
|
||||||
|
#include "llama-grammar.h"
|
||||||
|
#include "llama-vocab.h"
|
||||||
|
#include "grammar_ext.h"
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root) {
|
||||||
|
try {
|
||||||
|
// Initialize grammar sampler directly with the model
|
||||||
|
struct llama_sampler* sampler = llama_sampler_init_grammar(vocab, grammar_str, grammar_root);
|
||||||
|
if (!sampler) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cast the sampler to a grammar and return it
|
||||||
|
return (struct llama_grammar*)sampler;
|
||||||
|
} catch (const std::exception &err) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits) {
|
||||||
|
if (!grammar || !logits || n_logits <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token data array for the grammar application
|
||||||
|
llama_token_data* token_data = (llama_token_data*)malloc(n_logits * sizeof(llama_token_data));
|
||||||
|
if (!token_data) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize token data from logits
|
||||||
|
for (int i = 0; i < n_logits; i++) {
|
||||||
|
token_data[i].id = i;
|
||||||
|
token_data[i].logit = logits[i];
|
||||||
|
token_data[i].p = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token data array structure
|
||||||
|
llama_token_data_array arr = {
|
||||||
|
.data = token_data,
|
||||||
|
.size = (size_t)n_logits,
|
||||||
|
.sorted = false,
|
||||||
|
.selected = -1
|
||||||
|
};
|
||||||
|
|
||||||
|
// Apply grammar constraints to the token data array
|
||||||
|
llama_grammar_apply_impl(*grammar, &arr);
|
||||||
|
|
||||||
|
// Copy back the modified logits
|
||||||
|
for (int i = 0; i < n_logits; i++) {
|
||||||
|
logits[i] = token_data[i].logit;
|
||||||
|
}
|
||||||
|
|
||||||
|
free(token_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
void grammar_free(struct llama_grammar* grammar) {
|
||||||
|
if (grammar) {
|
||||||
|
// Free the grammar as a sampler
|
||||||
|
llama_sampler_free((struct llama_sampler*)grammar);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens) {
|
||||||
|
// Call the C++ function from llama-vocab.cpp
|
||||||
|
return llama_vocab_from_tokens(tokens, n_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
void vocab_bridge_free(struct llama_vocab* vocab) {
|
||||||
|
// Call the C++ function from llama-vocab.cpp
|
||||||
|
llama_vocab_free(vocab);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // extern "C"
|
33
llama/grammar_ext.h
vendored
Normal file
33
llama/grammar_ext.h
vendored
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
#ifndef GRAMMAR_EXT_H
|
||||||
|
#define GRAMMAR_EXT_H
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Forward declarations
|
||||||
|
struct llama_grammar;
|
||||||
|
struct llama_vocab;
|
||||||
|
|
||||||
|
// Create a new grammar from a string (returns a grammar implemented as a sampler)
|
||||||
|
struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root);
|
||||||
|
|
||||||
|
// Apply grammar constraints to logits
|
||||||
|
void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits);
|
||||||
|
|
||||||
|
// Free grammar resources (frees the underlying sampler)
|
||||||
|
void grammar_free(struct llama_grammar* grammar);
|
||||||
|
|
||||||
|
// C wrapper for llama_vocab_from_tokens
|
||||||
|
struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens);
|
||||||
|
|
||||||
|
// C wrapper for llama_vocab_free
|
||||||
|
void vocab_bridge_free(struct llama_vocab* vocab);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // GRAMMAR_EXT_H
|
@ -18,6 +18,7 @@ package llama
|
|||||||
|
|
||||||
#include "mllama.h"
|
#include "mllama.h"
|
||||||
#include "sampling_ext.h"
|
#include "sampling_ext.h"
|
||||||
|
#include "grammar_ext.h"
|
||||||
|
|
||||||
extern bool llamaProgressCallback(float progress, void *user_data);
|
extern bool llamaProgressCallback(float progress, void *user_data);
|
||||||
extern void llamaLog(int level, char* text, void* user_data);
|
extern void llamaLog(int level, char* text, void* user_data);
|
||||||
|
117
llama/patches/0019-expose-llama_vocab-from-tokens.patch
Normal file
117
llama/patches/0019-expose-llama_vocab-from-tokens.patch
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
From 668a974433edccf2c5fcc2192c39aed601e575f2 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Bruce MacDonald <brucewmacdonald@gmail.com>
|
||||||
|
Date: Thu, 6 Mar 2025 21:07:06 -0800
|
||||||
|
Subject: [PATCH] expose llama_vocab from tokens
|
||||||
|
|
||||||
|
---
|
||||||
|
llama/llama.cpp/src/llama-vocab.cpp | 73 +++++++++++++++++++++++++++++
|
||||||
|
llama/llama.cpp/src/llama-vocab.h | 11 ++++-
|
||||||
|
2 files changed, 83 insertions(+), 1 deletion(-)
|
||||||
|
|
||||||
|
diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp
|
||||||
|
index c7ff28be..ad6e7ad8 100644
|
||||||
|
--- a/llama/llama.cpp/src/llama-vocab.cpp
|
||||||
|
+++ b/llama/llama.cpp/src/llama-vocab.cpp
|
||||||
|
@@ -3253,3 +3253,76 @@ int32_t llama_detokenize(
|
||||||
|
return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
|
||||||
|
}
|
||||||
|
|
||||||
|
+struct llama_vocab *llama_vocab_from_tokens(const char **tokens, int n_tokens)
|
||||||
|
+{
|
||||||
|
+ if (!tokens || n_tokens <= 0)
|
||||||
|
+ {
|
||||||
|
+ return nullptr;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ try
|
||||||
|
+ {
|
||||||
|
+ // Create a new vocabulary instance
|
||||||
|
+ llama_vocab *vocab = new llama_vocab();
|
||||||
|
+ vocab->pimpl = std::make_unique<llama_vocab::impl>(*vocab);
|
||||||
|
+
|
||||||
|
+ // Resize the token data vectors
|
||||||
|
+ vocab->pimpl->id_to_token.resize(n_tokens);
|
||||||
|
+
|
||||||
|
+ // Create mappings for all tokens
|
||||||
|
+ for (int i = 0; i < n_tokens; i++)
|
||||||
|
+ {
|
||||||
|
+ std::string word = tokens[i];
|
||||||
|
+ if (word.empty())
|
||||||
|
+ {
|
||||||
|
+ word = "[EMPTY_" + std::to_string(i) + "]";
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Add to token mappings
|
||||||
|
+ vocab->pimpl->token_to_id[word] = i;
|
||||||
|
+
|
||||||
|
+ // Set up token data
|
||||||
|
+ auto &token_data = vocab->pimpl->id_to_token[i];
|
||||||
|
+ token_data.text = std::move(word);
|
||||||
|
+ token_data.score = 0.0f; // Default score
|
||||||
|
+ token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
|
||||||
|
+
|
||||||
|
+ // Detect special tokens
|
||||||
|
+ if (word == "<s>" || word == "<bos>")
|
||||||
|
+ {
|
||||||
|
+ vocab->pimpl->special_bos_id = i;
|
||||||
|
+ }
|
||||||
|
+ else if (word == "</s>" || word == "<eos>" || word == "<|endoftext|>")
|
||||||
|
+ {
|
||||||
|
+ vocab->pimpl->special_eos_id = i;
|
||||||
|
+ vocab->pimpl->special_eog_ids.insert(i);
|
||||||
|
+ }
|
||||||
|
+ else if (word == "<unk>")
|
||||||
|
+ {
|
||||||
|
+ vocab->pimpl->special_unk_id = i;
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ // Initialize the token-to-piece cache
|
||||||
|
+ vocab->pimpl->cache_token_to_piece.resize(n_tokens);
|
||||||
|
+ for (int i = 0; i < n_tokens; i++)
|
||||||
|
+ {
|
||||||
|
+ vocab->pimpl->cache_token_to_piece[i] = vocab->pimpl->id_to_token[i].text;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ return vocab;
|
||||||
|
+ }
|
||||||
|
+ catch (const std::exception &err)
|
||||||
|
+ {
|
||||||
|
+ return nullptr;
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
+// Helper function to free the vocab
|
||||||
|
+void llama_vocab_free(struct llama_vocab *vocab)
|
||||||
|
+{
|
||||||
|
+ if (vocab)
|
||||||
|
+ {
|
||||||
|
+ delete vocab;
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
\ No newline at end of file
|
||||||
|
diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h
|
||||||
|
index 5ce35521..eceb28f3 100644
|
||||||
|
--- a/llama/llama.cpp/src/llama-vocab.h
|
||||||
|
+++ b/llama/llama.cpp/src/llama-vocab.h
|
||||||
|
@@ -119,7 +119,16 @@ struct llama_vocab {
|
||||||
|
|
||||||
|
void print_info() const;
|
||||||
|
|
||||||
|
-private:
|
||||||
|
struct impl;
|
||||||
|
std::unique_ptr<impl> pimpl;
|
||||||
|
};
|
||||||
|
+
|
||||||
|
+// Create a vocabulary from an array of token strings
|
||||||
|
+// tokens: Array of token strings
|
||||||
|
+// n_tokens: Number of tokens in the array
|
||||||
|
+// Returns: A new llama_vocab instance, or nullptr on failure
|
||||||
|
+// The caller is responsible for freeing the vocabulary using llama_vocab_free
|
||||||
|
+LLAMA_API struct llama_vocab * llama_vocab_from_tokens(const char ** tokens, int n_tokens);
|
||||||
|
+
|
||||||
|
+// Free a vocabulary created with llama_vocab_from_tokens
|
||||||
|
+LLAMA_API void llama_vocab_free(struct llama_vocab * vocab);
|
||||||
|
--
|
||||||
|
2.39.3 (Apple Git-145)
|
||||||
|
|
@ -428,7 +428,8 @@ func (s *Server) processBatch() error {
|
|||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
vocabSize := len(logits) / len(options.Outputs)
|
vocabSize := len(logits) / len(options.Outputs)
|
||||||
|
// TODO: need access to vocab to apply grammar
|
||||||
|
// token = sampler.Grammar.Apply(logits)
|
||||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to sample token: %w", err)
|
return fmt.Errorf("failed to sample token: %w", err)
|
||||||
@ -575,6 +576,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: if grammar is provided, load it
|
||||||
|
// if req.Grammar != "" {
|
||||||
|
// grammar := llama.NewGrammarWithTokens(req.Grammar, "root", s.model.Vocabulary)
|
||||||
|
// }
|
||||||
|
// defer grammar.Close()
|
||||||
|
// sampler := sample.WithGrammar(sample.Greedy(), grammar)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.NumPredict,
|
numPredict: req.NumPredict,
|
||||||
stop: req.Stop,
|
stop: req.Stop,
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llama"
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
"gonum.org/v1/gonum/stat/sampleuv"
|
"gonum.org/v1/gonum/stat/sampleuv"
|
||||||
)
|
)
|
||||||
@ -57,12 +58,24 @@ func (s weighted) Sample(logits []float32) (int32, error) {
|
|||||||
return -1, errors.New("weighted sampler failed, no valid token found")
|
return -1, errors.New("weighted sampler failed, no valid token found")
|
||||||
}
|
}
|
||||||
|
|
||||||
type greedy struct{}
|
type greedy struct {
|
||||||
|
grammar llama.Grammar
|
||||||
|
}
|
||||||
|
|
||||||
func Greedy() Sampler {
|
func Greedy() Sampler {
|
||||||
return greedy{}
|
return greedy{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithGrammar(s Sampler, grammar llama.Grammar) Sampler {
|
||||||
|
switch t := s.(type) {
|
||||||
|
case greedy:
|
||||||
|
t.grammar = grammar
|
||||||
|
return t
|
||||||
|
default:
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Sample returns the index of the maximum value in logits.
|
// Sample returns the index of the maximum value in logits.
|
||||||
func (s greedy) Sample(logits []float32) (int32, error) {
|
func (s greedy) Sample(logits []float32) (int32, error) {
|
||||||
if len(logits) == 0 {
|
if len(logits) == 0 {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user