From 11acb85ff34eedd260b5b9835a6ea6b224448ce6 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 17 Dec 2024 16:54:47 -0800 Subject: [PATCH] WIP --- server/model_loader.go | 20 +++-- server/routes.go | 5 +- server/routes_tokenize_test.go | 144 +++++++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 10 deletions(-) create mode 100644 server/routes_tokenize_test.go diff --git a/server/model_loader.go b/server/model_loader.go index ff8e5fda9..a2e2bfcc9 100644 --- a/server/model_loader.go +++ b/server/model_loader.go @@ -9,14 +9,18 @@ import ( ) type loadedModel struct { - model *llama.Model + model llama.Model modelPath string } +type modelLoader struct { + cache sync.Map +} + // modelCache stores loaded models keyed by their full path and params hash var modelCache sync.Map // map[string]*loadedModel -func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) { +func (ml *modelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) { modelName := model.ParseName(name) if !modelName.IsValid() { return nil, fmt.Errorf("invalid model name: %s", modelName) @@ -34,7 +38,7 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) { } // Evict existing model if any - evictExistingModel() + ml.evictExistingModel() model, err := llama.LoadModelFromFile(modelPath.ModelPath, params) if err != nil { @@ -42,7 +46,7 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) { } loaded := &loadedModel{ - model: model, + model: *model, modelPath: modelPath.ModelPath, } modelCache.Store(cacheKey, loaded) @@ -53,10 +57,10 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) { // evictExistingModel removes any currently loaded model from the cache // Currently only supports a single model in cache at a time // TODO: Add proper cache eviction policy (LRU/size/TTL based) -func evictExistingModel() { - modelCache.Range(func(key, value any) bool { - if cached, ok := modelCache.LoadAndDelete(key); ok { - llama.FreeModel(cached.(*loadedModel).model) +func (ml *modelLoader) evictExistingModel() { + ml.cache.Range(func(key, value any) bool { + if cached, ok := ml.cache.LoadAndDelete(key); ok { + llama.FreeModel(&cached.(*loadedModel).model) } return true }) diff --git a/server/routes.go b/server/routes.go index 5b73a8d18..1552798bb 100644 --- a/server/routes.go +++ b/server/routes.go @@ -47,6 +47,7 @@ var mode string = gin.DebugMode type Server struct { addr net.Addr sched *Scheduler + ml modelLoader } func init() { @@ -575,7 +576,7 @@ func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) { return } - loadedModel, err := LoadModel(req.Model, llama.ModelParams{ + loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{ VocabOnly: true, }) if err != nil { @@ -625,7 +626,7 @@ func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) { return } - loadedModel, err := LoadModel(req.Model, llama.ModelParams{ + loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{ VocabOnly: true, }) if err != nil { diff --git a/server/routes_tokenize_test.go b/server/routes_tokenize_test.go new file mode 100644 index 000000000..86f82851d --- /dev/null +++ b/server/routes_tokenize_test.go @@ -0,0 +1,144 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/discover" + "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/llm" +) + 
+type mockModelLoader struct {
+	LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
+}
+
+func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
+	if ml.LoadModelFn != nil {
+		return ml.LoadModelFn(name, params)
+	}
+
+	return &loadedModel{
+		model: mockModel{},
+	}, nil
+}
+
+type mockModel struct {
+	llama.Model
+	TokenizeFn     func(text string, addBos bool, addEos bool) ([]int, error)
+	TokenToPieceFn func(token int) string
+}
+
+func (m *mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
+	return []int{1, 2, 3}, nil
+}
+
+func (m *mockModel) TokenToPiece(token int) string {
+	return fmt.Sprint(token)
+}
+
+func TestTokenizeHandler(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mockLoader := mockModelLoader{
+		LoadModelFn: func(name string, params llama.ModelParams) (*loadedModel, error) {
+			return &loadedModel{
+				model: mockModel{},
+			}, nil
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mockRunner{}),
+			getGpuFn:      discover.GetGPUInfo,
+			getCpuFn:      discover.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
+				time.Sleep(time.Millisecond)
+				req.successCh <- &runnerRef{
+					llama: &mockRunner{},
+				}
+			},
+		},
+		ml: mockLoader,
+	}
+
+	t.Run("method not allowed", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
+		if w.Code != http.StatusMethodNotAllowed {
+			t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
+		}
+	})
+
+	t.Run("missing body", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("missing text", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "test",
+		})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("missing model", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Text: "test text",
+		})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("model not found", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "nonexistent",
+			Text:  "test text",
+		})
+		if w.Code != http.StatusInternalServerError {
+			t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
+		}
+	})
+
+	t.Run("successful tokenization", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "test",
+			Text:  "test text",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
+		}
+
+		var resp api.TokenizeResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		expectedTokens := []int{1, 2, 3}
+		if len(resp.Tokens) != len(expectedTokens) {
+			t.Errorf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
+		}
+		for i, token := range resp.Tokens {
+			if token != expectedTokens[i] {
+				t.Errorf("expected token %d at position %d, got %d",
+					expectedTokens[i], i, token)
+			}
+		}
+	})
+}
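
Note: as posted, the test injects mockModelLoader into Server.ml, but routes.go declares ml as the concrete modelLoader struct, and loadedModel.model is now the llama.Model struct rather than an interface, so the mock cannot be substituted as written. A minimal, self-contained sketch of the interface-based wiring that would let both the real loader and a test mock plug into the server is below; the interface name and the stand-in types are illustrative assumptions, not part of this patch.

package main

import "fmt"

// Stand-ins so the sketch compiles on its own; in the patch these roles are
// played by llama.ModelParams and the server package's loadedModel.
type modelParams struct{ VocabOnly bool }
type loadedModel struct{ modelPath string }

// Assumed interface: both the real loader and the test mock satisfy it,
// so the Server field can hold either one.
type modelLoader interface {
	LoadModel(name string, params modelParams) (*loadedModel, error)
}

// Real loader (cache and eviction details omitted).
type diskModelLoader struct{}

func (l *diskModelLoader) LoadModel(name string, params modelParams) (*loadedModel, error) {
	return &loadedModel{modelPath: "/models/" + name}, nil
}

// Test mock, mirroring mockModelLoader in the patch.
type mockModelLoader struct {
	LoadModelFn func(name string, params modelParams) (*loadedModel, error)
}

func (m *mockModelLoader) LoadModel(name string, params modelParams) (*loadedModel, error) {
	if m.LoadModelFn != nil {
		return m.LoadModelFn(name, params)
	}
	return &loadedModel{}, nil
}

// server holds the interface, so tests can inject the mock without touching
// the real loader.
type server struct{ ml modelLoader }

func main() {
	prod := server{ml: &diskModelLoader{}}
	test := server{ml: &mockModelLoader{}}
	for _, s := range []server{prod, test} {
		m, _ := s.ml.LoadModel("test", modelParams{VocabOnly: true})
		fmt.Printf("loaded %q\n", m.modelPath)
	}
}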
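
Note: for manual verification once the handler is routed, a rough client-side sketch follows. The /api/tokenize path and the lowercase JSON field names are assumptions (route registration and the api package's struct tags are not part of this diff); the request and response shapes mirror only the Model/Text and Tokens fields exercised in the test, and the model name is just an example.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Mirrors the fields used in routes_tokenize_test.go; the real types live in
// github.com/ollama/ollama/api.
type tokenizeRequest struct {
	Model string `json:"model"`
	Text  string `json:"text"`
}

type tokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

func main() {
	// Any locally pulled model; "llama3.2" is only an example name.
	body, _ := json.Marshal(tokenizeRequest{Model: "llama3.2", Text: "hello world"})

	// Assumed endpoint path; the diff adds the handler but not the route.
	resp, err := http.Post("http://localhost:11434/api/tokenize", "application/json", bytes.NewReader(body))
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()

	var out tokenizeResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		fmt.Println("decode failed:", err)
		return
	}
	fmt.Println(out.Tokens)
}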