Add caching for model loading
parent e679885733
commit 1e545ea7a0

server/model_loader.go | 63 lines added (new file)
@@ -0,0 +1,63 @@
+package server
+
+import (
+	"fmt"
+	"sync"
+
+	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/types/model"
+)
+
+type loadedModel struct {
+	model     *llama.Model
+	modelPath string
+}
+
+// modelCache stores loaded models keyed by their full path and params hash
+var modelCache sync.Map // map[string]*loadedModel
+
+func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
+	modelName := model.ParseName(name)
+	if !modelName.IsValid() {
+		return nil, fmt.Errorf("invalid model name: %s", modelName)
+	}
+
+	modelPath, err := GetModel(modelName.String())
+	if err != nil {
+		return nil, fmt.Errorf("model not found: %s", modelName)
+	}
+
+	// Create cache key from model path and params hash
+	cacheKey := fmt.Sprintf("%s-%+v", modelPath.ModelPath, params)
+	if cached, ok := modelCache.Load(cacheKey); ok {
+		return cached.(*loadedModel), nil
+	}
+
+	// Evict existing model if any
+	evictExistingModel()
+
+	model, err := llama.LoadModelFromFile(modelPath.ModelPath, params)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load model: %v", err)
+	}
+
+	loaded := &loadedModel{
+		model:     model,
+		modelPath: modelPath.ModelPath,
+	}
+	modelCache.Store(cacheKey, loaded)
+
+	return loaded, nil
+}
+
+// evictExistingModel removes any currently loaded model from the cache
+// Currently only supports a single model in cache at a time
+// TODO: Add proper cache eviction policy (LRU/size/TTL based)
+func evictExistingModel() {
+	modelCache.Range(func(key, value any) bool {
+		if cached, ok := modelCache.LoadAndDelete(key); ok {
+			llama.FreeModel(cached.(*loadedModel).model)
+		}
+		return true
+	})
+}
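
A note on concurrency (an observation about the code above, not part of the commit): the miss path is check, evict, load, store, and sync.Map only makes each individual call atomic, not the whole sequence. Two concurrent LoadModel calls with different cache keys can each evict the model the other has just stored, and evictExistingModel frees a model that another request may still be using. One way to serialize the sequence is sketched below, assuming a package-level mutex is acceptable; loadMu and loadModelSerialized are hypothetical names.

package server

import (
	"sync"

	"github.com/ollama/ollama/llama"
)

// loadMu serializes LoadModel's check-evict-load-store sequence so that two
// concurrent callers cannot evict each other's freshly stored model.
var loadMu sync.Mutex

// loadModelSerialized is a hypothetical wrapper around LoadModel from this commit.
func loadModelSerialized(name string, params llama.ModelParams) (*loadedModel, error) {
	loadMu.Lock()
	defer loadMu.Unlock()
	return LoadModel(name, params)
}

Serializing the loader still leaves open the case of a caller holding a model that a later miss frees; that is the kind of thing the eviction-policy TODO above would need to cover.
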
@@ -575,36 +575,16 @@ func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	name := model.ParseName(req.Model)
-	if !name.IsValid() {
-		http.Error(w, fmt.Sprintf("model name `%q` is invalid", req.Model), http.StatusBadRequest)
-		return
-	}
-	name, err := getExistingName(name)
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	// Get local model path
-	modelPath, err := GetModel(name.String())
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	model, err := llama.LoadModelFromFile(modelPath.ModelPath, llama.ModelParams{
+	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
-		UseMmap:   true,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
 		return
 	}
-	defer llama.FreeModel(model)
 
 	// Tokenize the text
-	tokens, err := model.Tokenize(req.Text, false, true)
+	tokens, err := loadedModel.model.Tokenize(req.Text, false, true)
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to tokenize text: %v", err), http.StatusInternalServerError)
 		return
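
The net effect of this hunk is an ownership change: the handler no longer loads a vocab-only model per request and no longer calls defer llama.FreeModel(model), because the cache now owns the model's lifetime. A small sketch of the resulting calling convention follows; tokenizeWithCache is a hypothetical helper, not part of the commit, and it returns only the token count so it stays agnostic about the token element type.

package server

import "github.com/ollama/ollama/llama"

// tokenizeWithCache mirrors how TokenizeHandler uses the cache: load (or reuse)
// the vocab-only model, tokenize, and do NOT free the model afterwards.
func tokenizeWithCache(name, text string) (int, error) {
	lm, err := LoadModel(name, llama.ModelParams{VocabOnly: true})
	if err != nil {
		return 0, err
	}
	// No llama.FreeModel here: modelCache owns the model and frees it via
	// evictExistingModel on the next cache miss.
	tokens, err := lm.model.Tokenize(text, false, true)
	if err != nil {
		return 0, err
	}
	return len(tokens), nil
}
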
@@ -645,37 +625,17 @@ func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	name := model.ParseName(req.Model)
-	if !name.IsValid() {
-		http.Error(w, fmt.Sprintf("model name `%q` is invalid", req.Model), http.StatusBadRequest)
-		return
-	}
-	name, err := getExistingName(name)
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	// Get local model path
-	modelPath, err := GetModel(name.String())
-	if err != nil {
-		http.Error(w, fmt.Sprintf("model `%s` not found", req.Model), http.StatusNotFound)
-		return
-	}
-
-	model, err := llama.LoadModelFromFile(modelPath.ModelPath, llama.ModelParams{
+	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
-		UseMmap:   true,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
 		return
 	}
-	defer llama.FreeModel(model)
 
 	var text string
 	for _, token := range req.Tokens {
-		text += model.TokenToPiece(token)
+		text += loadedModel.model.TokenToPiece(token)
 	}
 
 	w.Header().Set("Content-Type", "application/json")
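
The loader's TODO mentions an LRU/size/TTL eviction policy as a follow-up. For illustration, a count-bounded LRU keyed the same way as modelCache could look roughly like the sketch below, assuming a plain map plus mutex instead of sync.Map; every name here (lruModelCache, lruEntry, and so on) is hypothetical and not part of the commit.

package server

import (
	"container/list"
	"sync"

	"github.com/ollama/ollama/llama"
)

// lruModelCache is a hypothetical count-bounded LRU for loaded models,
// keyed the same way as modelCache (model path plus "%+v" of the params).
type lruModelCache struct {
	mu    sync.Mutex
	limit int
	order *list.List               // front = most recently used
	items map[string]*list.Element // key -> element whose Value is *lruEntry
}

type lruEntry struct {
	key   string
	model *loadedModel
}

func newLRUModelCache(limit int) *lruModelCache {
	return &lruModelCache{limit: limit, order: list.New(), items: map[string]*list.Element{}}
}

// Get returns a cached model and marks it most recently used.
func (c *lruModelCache) Get(key string) (*loadedModel, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if el, ok := c.items[key]; ok {
		c.order.MoveToFront(el)
		return el.Value.(*lruEntry).model, true
	}
	return nil, false
}

// Put inserts a model and frees the least recently used entries past the limit.
func (c *lruModelCache) Put(key string, m *loadedModel) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if el, ok := c.items[key]; ok {
		c.order.MoveToFront(el)
		el.Value.(*lruEntry).model = m
		return
	}
	c.items[key] = c.order.PushFront(&lruEntry{key: key, model: m})
	for c.order.Len() > c.limit {
		oldest := c.order.Back()
		entry := oldest.Value.(*lruEntry)
		c.order.Remove(oldest)
		delete(c.items, entry.key)
		llama.FreeModel(entry.model.model) // same cleanup evictExistingModel does today
	}
}

As with evictExistingModel today, freeing on eviction assumes no in-flight request still holds the evicted model.
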