diff --git a/server/routes.go b/server/routes.go index 7a79856b0..a975e94fe 100644 --- a/server/routes.go +++ b/server/routes.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "io/fs" "log/slog" "net" "net/http" @@ -309,17 +310,14 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive) - var runner *runnerRef - select { - case runner = <-rCh: - case err = <-eCh: - handleErrorResponse(c, err) + r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) + if err != nil { + handleScheduleError(c, req.Model, err) return } checkFit := func(s string, truncate bool) (string, error) { - tokens, err := runner.llama.Tokenize(c.Request.Context(), s) + tokens, err := r.Tokenize(c.Request.Context(), s) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return "", err @@ -328,7 +326,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { if len(tokens) > opts.NumCtx { if truncate { tokens = tokens[:opts.NumCtx] - return runner.llama.Detokenize(c.Request.Context(), tokens) + return r.Detokenize(c.Request.Context(), tokens) } else { return "", fmt.Errorf("input length exceeds maximum context length") } @@ -346,7 +344,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed}) + embeddings, err = r.Embed(c.Request.Context(), []string{reqEmbed}) case []any: reqEmbedArray := make([]string, len(reqEmbed)) for i, v := range reqEmbed { @@ -357,7 +355,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { } reqEmbedArray[i] = s } - embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray) + embeddings, err = r.Embed(c.Request.Context(), reqEmbedArray) default: c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) } @@ -418,7 +416,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - embedding, err := runner.llama.Embed(c.Request.Context(), []string{req.Prompt}) + embedding, err := r.Embed(c.Request.Context(), []string{req.Prompt}) if err != nil { slog.Info(fmt.Sprintf("embedding generation failed: %v", err))