add session expiration

This commit is contained in:
Michael Yang
2023-07-19 15:00:28 -07:00
parent 3003fc03fc
commit f62a882760
3 changed files with 100 additions and 20 deletions

View File

@@ -22,16 +22,19 @@ import (
"github.com/jmorganca/ollama/llama"
)
var mu sync.Mutex
// activeSession is the package-level state for the one model session that
// GenerateHandler keeps loaded between requests.
//
// NOTE(review): this text is a rendered diff with the +/- markers stripped.
// The ID and embedded *llama.LLM lines appear to be the removed (pre-change)
// fields, and the lower-case fields their replacements — confirm against the
// raw patch before treating this as the final declaration.
var activeSession struct {
ID int64
*llama.LLM
// mu is locked for the duration of GenerateHandler and inside the expiry
// timer callback before either touches the fields below.
mu sync.Mutex
// id is set from time.Now().UnixNano() when a model is (re)loaded and is
// reset to 0 by the expiry callback; requests carrying a different
// SessionID (or 0) trigger a reload.
id int64
// llm is the loaded model; it is closed and set to nil when the session is
// replaced or expires.
llm *llama.LLM
// expireAt is pushed forward to now + the request's session duration on
// each request and on each generated response.
expireAt time.Time
// expireTimer is created lazily with time.AfterFunc and Reset alongside
// expireAt; its callback closes llm only if the session id is unchanged
// and expireAt has actually passed.
expireTimer *time.Timer
}
func GenerateHandler(c *gin.Context) {
mu.Lock()
defer mu.Unlock()
activeSession.mu.Lock()
defer activeSession.mu.Unlock()
checkpointStart := time.Now()
@@ -47,10 +50,10 @@ func GenerateHandler(c *gin.Context) {
return
}
if req.SessionID == 0 || req.SessionID != activeSession.ID {
if activeSession.LLM != nil {
activeSession.Close()
activeSession.LLM = nil
if req.SessionID == 0 || req.SessionID != activeSession.id {
if activeSession.llm != nil {
activeSession.llm.Close()
activeSession.llm = nil
}
opts := api.DefaultOptions()
@@ -70,10 +73,34 @@ func GenerateHandler(c *gin.Context) {
return
}
activeSession.ID = time.Now().UnixNano()
activeSession.LLM = llm
activeSession.id = time.Now().UnixNano()
activeSession.llm = llm
}
sessionDuration := req.SessionDuration
sessionID := activeSession.id
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
if activeSession.expireTimer == nil {
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
activeSession.mu.Lock()
defer activeSession.mu.Unlock()
if sessionID != activeSession.id {
return
}
if time.Now().Before(activeSession.expireAt) {
return
}
activeSession.llm.Close()
activeSession.llm = nil
activeSession.id = 0
})
}
activeSession.expireTimer.Reset(sessionDuration.Duration)
checkpointLoaded := time.Now()
prompt, err := model.Prompt(req)
@@ -86,9 +113,13 @@ func GenerateHandler(c *gin.Context) {
go func() {
defer close(ch)
fn := func(r api.GenerateResponse) {
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
activeSession.expireTimer.Reset(sessionDuration.Duration)
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
r.SessionID = activeSession.ID
r.SessionID = activeSession.id
r.SessionExpiresAt = activeSession.expireAt.UTC()
if r.Done {
r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -97,7 +128,7 @@ func GenerateHandler(c *gin.Context) {
ch <- r
}
if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -247,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
return
}
c.JSON(http.StatusOK, api.ListResponse{models})
c.JSON(http.StatusOK, api.ListResponse{Models: models})
}
func CopyModelHandler(c *gin.Context) {