From 1e545ea7a061df0bd357c2fc60f1031151081c3c Mon Sep 17 00:00:00 2001
From: ParthSareen
Date: Tue, 17 Dec 2024 15:31:02 -0800
Subject: [PATCH] Add caching for model loading

---
 server/model_loader.go | 63 ++++++++++++++++++++++++++++++++++++++++++
 server/routes.go       | 48 +++----------------------------
 2 files changed, 67 insertions(+), 44 deletions(-)
 create mode 100644 server/model_loader.go

diff --git a/server/model_loader.go b/server/model_loader.go
new file mode 100644
index 000000000..ff8e5fda9
--- /dev/null
+++ b/server/model_loader.go
@@ -0,0 +1,63 @@
+package server
+
+import (
+	"fmt"
+	"sync"
+
+	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/types/model"
+)
+
+type loadedModel struct {
+	model     *llama.Model
+	modelPath string
+}
+
+// modelCache stores loaded models keyed by their full path and params hash
+var modelCache sync.Map // map[string]*loadedModel
+
+func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
+	modelName := model.ParseName(name)
+	if !modelName.IsValid() {
+		return nil, fmt.Errorf("invalid model name: %s", modelName)
+	}
+
+	modelPath, err := GetModel(modelName.String())
+	if err != nil {
+		return nil, fmt.Errorf("model not found: %s", modelName)
+	}
+
+	// Create cache key from model path and params hash
+	cacheKey := fmt.Sprintf("%s-%+v", modelPath.ModelPath, params)
+	if cached, ok := modelCache.Load(cacheKey); ok {
+		return cached.(*loadedModel), nil
+	}
+
+	// Evict existing model if any
+	evictExistingModel()
+
+	model, err := llama.LoadModelFromFile(modelPath.ModelPath, params)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load model: %v", err)
+	}
+
+	loaded := &loadedModel{
+		model:     model,
+		modelPath: modelPath.ModelPath,
+	}
+	modelCache.Store(cacheKey, loaded)
+
+	return loaded, nil
+}
+
+// evictExistingModel removes any currently loaded model from the cache
+// Currently only supports a single model in cache at a time
+// TODO: Add proper cache eviction policy (LRU/size/TTL based)
+func evictExistingModel() {
+	modelCache.Range(func(key, value any) bool {
+		if cached, ok := modelCache.LoadAndDelete(key); ok {
+			llama.FreeModel(cached.(*loadedModel).model)
+		}
+		return true
+	})
+}
diff --git a/server/routes.go b/server/routes.go
index 6788b6d9a..5b73a8d18 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -575,36 +575,16 @@ func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	name := model.ParseName(req.Model)
-	if !name.IsValid() {
-		http.Error(w, fmt.Sprintf("model name `%q` is invalid", req.Model), http.StatusBadRequest)
-		return
-	}
-	name, err := getExistingName(name)
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	// Get local model path
-	modelPath, err := GetModel(name.String())
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	model, err := llama.LoadModelFromFile(modelPath.ModelPath, llama.ModelParams{
+	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
-		UseMmap:   true,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
 		return
 	}
-	defer llama.FreeModel(model)
 
 	// Tokenize the text
-	tokens, err := model.Tokenize(req.Text, false, true)
+	tokens, err := loadedModel.model.Tokenize(req.Text, false, true)
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to tokenize text: %v", err), http.StatusInternalServerError)
 		return
@@ -645,37 +625,17 @@ func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	name := model.ParseName(req.Model)
-	if !name.IsValid() {
-		http.Error(w, fmt.Sprintf("model name `%q` is invalid", req.Model), http.StatusBadRequest)
-		return
-	}
-	name, err := getExistingName(name)
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	// Get local model path
-	modelPath, err := GetModel(name.String())
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	model, err := llama.LoadModelFromFile(modelPath.ModelPath, llama.ModelParams{
+	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
-		UseMmap:   true,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
 		return
 	}
-	defer llama.FreeModel(model)
 
 	var text string
 	for _, token := range req.Tokens {
-		text += model.TokenToPiece(token)
+		text += loadedModel.model.TokenToPiece(token)
 	}
 
 	w.Header().Set("Content-Type", "application/json")
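
Notes (not part of the patch): below is a minimal, standalone sketch of the cache-key scheme LoadModel uses above, where the key is the model path concatenated with the "%+v"-formatted params, so a cache hit requires the path and every parameter field to match exactly. The fakeParams struct and the file path here are hypothetical stand-ins for llama.ModelParams and a real model path, not part of the Ollama API.

package main

import (
	"fmt"
	"sync"
)

// fakeParams is a simplified, hypothetical stand-in for llama.ModelParams.
type fakeParams struct {
	VocabOnly bool
	UseMmap   bool
}

var cache sync.Map // keyed the same way as modelCache in the patch

// cacheKey mirrors the key construction in LoadModel: path plus formatted params.
func cacheKey(path string, p fakeParams) string {
	return fmt.Sprintf("%s-%+v", path, p)
}

func main() {
	a := cacheKey("/models/example.gguf", fakeParams{VocabOnly: true})
	b := cacheKey("/models/example.gguf", fakeParams{VocabOnly: true})
	c := cacheKey("/models/example.gguf", fakeParams{VocabOnly: true, UseMmap: true})

	cache.Store(a, struct{}{})
	_, hit := cache.Load(b)  // true: identical path and params reuse the entry
	_, miss := cache.Load(c) // false: different params produce a separate key
	fmt.Println(hit, miss)   // prints: true false
}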