diff --git a/server/routes.go b/server/routes.go index cd005fc25..0e57251d4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -9,6 +9,7 @@ import ( "io" "io/fs" "log/slog" + "math" "net" "net/http" "net/netip" @@ -309,32 +310,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { streamResponse(c, ch) } -func getDefaultSessionDuration() time.Duration { - if envconfig.KeepAlive != "" { - v, err := strconv.Atoi(envconfig.KeepAlive) - if err != nil { - d, err := time.ParseDuration(envconfig.KeepAlive) - if err != nil { - return defaultSessionDuration - } - - if d < 0 { - return time.Duration(math.MaxInt64) - } - - return d - } - - d := time.Duration(v) * time.Second - if d < 0 { - return time.Duration(math.MaxInt64) - } - return d - } - - return defaultSessionDuration -} - func (s *Server) EmbedHandler(c *gin.Context) { var req api.EmbedRequest err := c.ShouldBindJSON(&req) @@ -374,13 +349,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - var sessionDuration time.Duration - if req.KeepAlive == nil { - sessionDuration = getDefaultSessionDuration() - } else { - sessionDuration = req.KeepAlive.Duration - } - switch reqEmbed := req.Input.(type) { case string: if reqEmbed == "" { @@ -404,7 +372,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, req.KeepAlive) var runner *runnerRef select { case runner = <-rCh: