Working e2e logits

This commit is contained in:
ParthSareen 2025-01-03 16:06:28 -08:00
parent c92d418a7c
commit f9928b677f
3 changed files with 7 additions and 9 deletions

View File

@ -310,12 +310,10 @@ func flushPending(seq *Sequence) bool {
// Add logits if requested and available
if seq.returnLogits && seq.logits != nil {
slog.Info("returning logits - flushPending")
resp.Logits = seq.logits
seq.logits = nil
}
slog.Info("returning logits - flushPending", "logits", resp.Logits[0])
select {
case seq.responses <- resp:
return true
@ -503,9 +501,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
// Before sampling:
if seq.returnLogits { // New flag we need to add to Sequence struct
slog.Info("returning logits")
seq.logits = s.lc.GetLogits() // Using our new GetLogits() method
logits := s.lc.GetLogits()
seq.logits = make([]float32, len(logits))
slog.Info("copying logits")
copy(seq.logits, logits)
slog.Info("copying logits success")
}
// Then sample token
@ -728,7 +728,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
close(seq.quit)
return
case content, ok := <-seq.responses:
slog.Info("logits in last chan", "content", content.Logits[0])
if ok {
slog.Info("content", "content", content.Content)
if err := json.NewEncoder(w).Encode(&content); err != nil {

View File

@ -633,7 +633,8 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
ws ::= ([ \t\n] ws)?
`
const maxBufferSize = 512 * format.KiloByte
// TODO: change back to 512 * format.KiloByte
const maxBufferSize = 2048 * format.KiloByte
type ImageData struct {
Data []byte `json:"data"`

View File

@ -1543,8 +1543,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
slog.Debug("chat request", "images", len(images), "prompt", prompt)
slog.Info("chat request", "return_logits", req.ReturnLogits)
ch := make(chan any)
go func() {
defer close(ch)