merge conflicts

This commit is contained in:
Roy Han 2024-07-09 14:00:13 -07:00
parent 786848dfd3
commit b686ac144c

View File

@@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
@@ -309,17 +310,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive) r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
var runner *runnerRef if err != nil {
select { handleScheduleError(c, req.Model, err)
case runner = <-rCh:
case err = <-eCh:
handleErrorResponse(c, err)
return return
} }
checkFit := func(s string, truncate bool) (string, error) { checkFit := func(s string, truncate bool) (string, error) {
tokens, err := runner.llama.Tokenize(c.Request.Context(), s) tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return "", err return "", err
@@ -328,7 +326,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
if len(tokens) > opts.NumCtx { if len(tokens) > opts.NumCtx {
if truncate { if truncate {
tokens = tokens[:opts.NumCtx] tokens = tokens[:opts.NumCtx]
return runner.llama.Detokenize(c.Request.Context(), tokens) return r.Detokenize(c.Request.Context(), tokens)
} else { } else {
return "", fmt.Errorf("input length exceeds maximum context length") return "", fmt.Errorf("input length exceeds maximum context length")
} }
@@ -346,7 +344,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed}) embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed})
case []any: case []any:
reqEmbedArray := make([]string, len(reqEmbed)) reqEmbedArray := make([]string, len(reqEmbed))
for i, v := range reqEmbed { for i, v := range reqEmbed {
@@ -357,7 +355,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
reqEmbedArray[i] = s reqEmbedArray[i] = s
} }
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray) embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray)
default: default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
} }
@@ -418,7 +416,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding, err := runner.llama.Embed(c.Request.Context(), []string{req.Prompt}) embedding, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))