diff --git a/llama/runner/README.md b/llama/runner/README.md index 703b1dd12..74e163b1d 100644 --- a/llama/runner/README.md +++ b/llama/runner/README.md @@ -1,5 +1,7 @@ # `runner` +> Note: this is a work in progress + A minimial runner for loading a model and running inference via a http web server. ``` @@ -13,3 +15,12 @@ curl -X POST -H "Content-Type: application/json" -d '{"prompt": "hi"}' http://lo ``` ### Embeddings + +``` +curl -X POST -H "Content-Type: application/json" -d '{"prompt": "turn me into an embedding"}' http://localhost:8080/embeddings +``` + +### TODO + +- [ ] Parallelization +- [ ] More tests diff --git a/llama/runner/runner.go b/llama/runner/runner.go index 97129fef8..6fad0d952 100644 --- a/llama/runner/runner.go +++ b/llama/runner/runner.go @@ -55,7 +55,7 @@ func (s *Sequence) prompt() bool { return s.nPast < len(s.tokens)-1 } -func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence { +func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence { tokens, err := s.lc.Model().Tokenize(prompt, false, true) if err != nil { panic(err) @@ -148,8 +148,10 @@ func (s *Server) run(ctx context.Context) { continue } + hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict + // if past the num predict limit - if seq.numPredicted > seq.numPredict || seq.nPast > s.numCtx { + if hitLimit || seq.nPast > s.numCtx { seq.doneReason = "limit" close(seq.responses) s.lc.KvCacheSeqRm(i, 0, -1) @@ -317,7 +319,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { samplingParams.Seed = uint32(req.Seed) samplingParams.Grammar = req.Grammar - seq := s.NewSequence(req.Prompt, req.Stop, &samplingParams, false) + seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false) // TODO (jmorganca): add to sequence queue instead of // failing if a slot isn't available @@ -368,7 +370,7 @@ func (s 
*Server) embeddings(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - seq := s.NewSequence(req.Prompt, nil, nil, true) + seq := s.NewSequence(req.Prompt, 0, nil, nil, true) s.mu.Lock() for i, sq := range s.seqs { @@ -413,7 +415,7 @@ func main() { ppath := flag.String("projector", "", "Path to projector binary file") parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously") batchSize := flag.Int("batch-size", 512, "Batch size") - nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") + nGpuLayers := flag.Int("num-gpu", 0, "Number of layers to offload to GPU") mainGpu := flag.Int("main-gpu", 0, "Main GPU") flashAttention := flag.Bool("flash-attention", false, "Enable flash attention") numCtx := flag.Int("num-ctx", 2048, "Context (or KV cache) size")