From 69cc5795a73339429a97bc66f7513d327c186a9c Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 14 Aug 2024 10:35:49 -0700 Subject: [PATCH] runner.go: Shift context window when KV cache space is exceeded Currently, once the KV cache is full, text generation stops. Instead, we should shift out the oldest context so that new generation can continue based on more recent context. This uses the algorithm from llama.cpp that is currently used by Ollama with the server.cpp code. There are others but they are never turned on through Ollama, so this restores parity. The algorithm is: - Retain a configurable number of tokens at the beginning (for things like beginning of sequence tokens - Drop the oldest half of the remaining tokens - Shift the remaining new tokens to the back of the cache --- llama/llama.go | 14 ++++++++ llama/runner/runner.go | 80 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index 98f864383..b169cf51a 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -157,6 +157,10 @@ func (c *Context) SampleTokenGreedy(logits []float32) int { })) } +func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) { + C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta)) +} + func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool { return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1))) } @@ -191,6 +195,16 @@ func (m *Model) TokenIsEog(token int) bool { return bool(C.llama_token_is_eog(m.c, C.llama_token(token))) } +func (m *Model) ShouldAddBOSToken() bool { + addBos := int(C.llama_add_bos_token(m.c)) + + if addBos != -1 { + return addBos != 0 + } else { + return C.llama_vocab_type(m.c) == C.LLAMA_VOCAB_TYPE_SPM + } +} + func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error { cLoraPath := C.CString(loraPath) defer C.free(unsafe.Pointer(cLoraPath)) diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 43a70f307..264252b29 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -49,6 +49,9 @@ type Sequence struct { // stop sequences stop []string + // number of tokens to keep at the beginning when shifting context window + numKeep int + // true if an embedding are to be returned instead of text generation embeddingOnly bool @@ -61,22 +64,38 @@ type Sequence struct { n_prompt_tokens int } -func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence { +type NewSequenceParams struct { + numPredict int + stop []string + numKeep int + samplingParams *llama.SamplingParams + embedding bool +} + +func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence { tokens, err := s.lc.Model().Tokenize(prompt, true, true) if err != nil { panic(err) } - // truncate to last n tokens - // TODO: this shouldn't happen and will severely impact generation - // quality. instead we should ensure to cut prompt in the API. + if params.numKeep < 0 { + params.numKeep = len(tokens) + } + // Subtracting 4 ensures that at least 1 token can be discarded during shift + params.numKeep = min(params.numKeep, s.numCtx-4) + params.numKeep += s.bosToken + + // truncate to fit in context window if len(tokens) > s.numCtx { - tokens = tokens[:s.numCtx] + slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep) + newTokens := tokens[:params.numKeep] + newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...) + tokens = newTokens } var sc *llama.SamplingContext - if params != nil { - sc = llama.NewSamplingContext(*params) + if params.samplingParams != nil { + sc = llama.NewSamplingContext(*params.samplingParams) for _, t := range tokens { sc.Accept(s.lc, t, false) } @@ -85,12 +104,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param return &Sequence{ tokens: tokens, n_prompt_tokens: len(tokens), - numPredict: numPredict, + numPredict: params.numPredict, responses: make(chan string, 1), embedding: make(chan []float32, 1), samplingCtx: sc, - embeddingOnly: embedding, - stop: stop, + embeddingOnly: params.embedding, + stop: params.stop, + numKeep: params.numKeep, } } @@ -111,6 +131,9 @@ type Server struct { // context window size numCtx int + // does this model require a beginning of sequence token? + bosToken int + mu sync.Mutex cond *sync.Cond @@ -129,6 +152,21 @@ func (s *Server) allNil() bool { return true } +func (s *Server) shiftContext(seqIndex int) { + seq := s.seqs[seqIndex] + + numLeft := seq.nPast - seq.numKeep + numDiscard := numLeft / 2 + + slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast, + "numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard) + + s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard) + s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard) + + seq.nPast -= numDiscard +} + func (s *Server) run(ctx context.Context) { // TODO - should this be n_ctx / parallel like the old server.cpp setup? batch := llama.NewBatch(s.batchSize, 0, s.parallel) @@ -155,10 +193,8 @@ func (s *Server) run(ctx context.Context) { continue } - hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict - // if past the num predict limit - if hitLimit || seq.nPast > s.numCtx { + if seq.numPredict > 0 && seq.numPredicted > seq.numPredict { seq.doneReason = "limit" close(seq.responses) s.lc.KvCacheSeqRm(i, 0, -1) @@ -166,6 +202,10 @@ func (s *Server) run(ctx context.Context) { continue } + if seq.nPast+len(seq.tokens) > s.numCtx { + s.shiftContext(i) + } + if seq.t_start_process_prompt.IsZero() { seq.t_start_process_prompt = time.Now() } @@ -350,7 +390,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { samplingParams.Seed = uint32(req.Seed) samplingParams.Grammar = req.Grammar - seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false) + seq := s.NewSequence(req.Prompt, NewSequenceParams{ + numPredict: req.NumPredict, + stop: req.Stop, + numKeep: req.NumKeep, + samplingParams: &samplingParams, + embedding: false, + }) // TODO (jmorganca): add to sequence queue instead of // failing if a slot isn't available @@ -428,7 +474,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embeddings := make([][]float32, len(req.Content)) var processed int for i, content := range req.Content { - seqs[i] = s.NewSequence(content, 0, nil, nil, true) + seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true}) } // TODO - refactor to go routines to add seq's and drain the responses @@ -563,6 +609,10 @@ func main() { ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention) server.lc = llama.NewContextWithModel(server.model, ctxParams) + if server.model.ShouldAddBOSToken() { + server.bosToken = 1 + } + if *ppath != "" { server.cc = llama.NewClipContext(*ppath) }