WIP
parent 1e545ea7a0
commit 11acb85ff3
@@ -9,14 +9,18 @@ import (
 )
 
 type loadedModel struct {
-	model     *llama.Model
+	model     llama.Model
 	modelPath string
 }
 
+type modelLoader struct {
+	cache sync.Map
+}
+
-// modelCache stores loaded models keyed by their full path and params hash
-var modelCache sync.Map // map[string]*loadedModel
-
-func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
+func (ml *modelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
 	modelName := model.ParseName(name)
 	if !modelName.IsValid() {
 		return nil, fmt.Errorf("invalid model name: %s", modelName)
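The field type changes here from `*llama.Model` to a plain `llama.Model` value, which is why `LoadModel` below stores `*model` and why eviction frees `&cached.(*loadedModel).model`. A minimal self-contained sketch of that value-in-struct, address-out pattern, with stand-in types (`Model` and `freeModel` are illustrative, not the real llama bindings):

```go
package main

import "fmt"

// Stand-ins for the llama bindings; only the shapes matter here.
type Model struct{ handle uintptr }

func freeModel(m *Model) { fmt.Println("freeing handle", m.handle) }

// loadedModel stores the model by value, as in the diff.
type loadedModel struct {
	model     Model
	modelPath string
}

func main() {
	m := &Model{handle: 42} // loaders typically return a pointer
	loaded := &loadedModel{model: *m, modelPath: "/models/example"}

	// Freeing takes the address of the embedded value, mirroring
	// llama.FreeModel(&cached.(*loadedModel).model) in the diff.
	freeModel(&loaded.model)
}
```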
@@ -34,7 +38,7 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
 	}
 
 	// Evict existing model if any
-	evictExistingModel()
+	ml.evictExistingModel()
 
 	model, err := llama.LoadModelFromFile(modelPath.ModelPath, params)
 	if err != nil {
@@ -42,7 +46,7 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
 	}
 
 	loaded := &loadedModel{
-		model:     model,
+		model:     *model,
 		modelPath: modelPath.ModelPath,
 	}
 	modelCache.Store(cacheKey, loaded)
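One loose end in this WIP hunk: the package-level `modelCache` is deleted above, yet the unchanged context line still calls `modelCache.Store(cacheKey, loaded)`, so this presumably becomes `ml.cache.Store(cacheKey, loaded)` once the refactor is finished. A sketch of that store path on the method-based loader, assumed rather than taken from the commit (the key format is illustrative; the removed comment said keys combine the full path and a params hash):

```go
package main

import (
	"fmt"
	"sync"
)

// Minimal stand-ins for the diff's types.
type loadedModel struct{ modelPath string }

type modelLoader struct{ cache sync.Map }

// Assumed follow-up: store through the loader's own cache rather than
// the deleted package-level modelCache.
func (ml *modelLoader) store(cacheKey string, loaded *loadedModel) {
	ml.cache.Store(cacheKey, loaded)
}

func main() {
	var ml modelLoader
	ml.store("/models/example|vocab-only", &loadedModel{modelPath: "/models/example"})
	if v, ok := ml.cache.Load("/models/example|vocab-only"); ok {
		fmt.Println("cached:", v.(*loadedModel).modelPath)
	}
}
```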
@@ -53,10 +57,10 @@ func LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
 // evictExistingModel removes any currently loaded model from the cache
 // Currently only supports a single model in cache at a time
 // TODO: Add proper cache eviction policy (LRU/size/TTL based)
-func evictExistingModel() {
-	modelCache.Range(func(key, value any) bool {
-		if cached, ok := modelCache.LoadAndDelete(key); ok {
-			llama.FreeModel(cached.(*loadedModel).model)
+func (ml *modelLoader) evictExistingModel() {
+	ml.cache.Range(func(key, value any) bool {
+		if cached, ok := ml.cache.LoadAndDelete(key); ok {
+			llama.FreeModel(&cached.(*loadedModel).model)
 		}
 		return true
 	})
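The eviction logic carries over unchanged apart from moving onto the receiver: `Range` walks every entry while `LoadAndDelete` atomically claims each one, so a model is freed at most once even if two evictions race. A self-contained sketch of the same pattern, where the `free` callback stands in for `llama.FreeModel`:

```go
package main

import (
	"fmt"
	"sync"
)

// evictAll mirrors evictExistingModel: Range visits every key, and
// LoadAndDelete atomically removes the entry, so each value is handed
// to free exactly once even with concurrent callers.
func evictAll(cache *sync.Map, free func(v any)) {
	cache.Range(func(key, value any) bool {
		if cached, ok := cache.LoadAndDelete(key); ok {
			free(cached)
		}
		return true // keep iterating; today the cache holds one model at most
	})
}

func main() {
	var cache sync.Map
	cache.Store("/models/example", "model-handle")
	evictAll(&cache, func(v any) { fmt.Println("freed", v) })
}
```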
|
@ -47,6 +47,7 @@ var mode string = gin.DebugMode
|
||||
type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
ml modelLoader
|
||||
}
|
||||
|
||||
func init() {
|
||||
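As declared, `ml` is the concrete `modelLoader` struct, while the new tests below assign a `mockModelLoader` to it; the WIP state leaves that mismatch open. One way to reconcile the two, sketched here as an assumption rather than something this commit does, is to make the field an interface that both loaders satisfy:

```go
package main

import "fmt"

// Stand-ins so the sketch compiles on its own.
type ModelParams struct{ VocabOnly bool }
type loadedModel struct{ modelPath string }

// Hypothetical interface; the commit declares modelLoader as a struct.
type modelLoader interface {
	LoadModel(name string, params ModelParams) (*loadedModel, error)
}

// The real loader and a test mock would both satisfy it.
type diskLoader struct{}

func (diskLoader) LoadModel(name string, _ ModelParams) (*loadedModel, error) {
	return &loadedModel{modelPath: "/models/" + name}, nil
}

type mockLoader struct{}

func (mockLoader) LoadModel(string, ModelParams) (*loadedModel, error) {
	return &loadedModel{modelPath: "mock"}, nil
}

type Server struct{ ml modelLoader }

func main() {
	for _, s := range []Server{{ml: diskLoader{}}, {ml: mockLoader{}}} {
		m, _ := s.ml.LoadModel("test", ModelParams{VocabOnly: true})
		fmt.Println(m.modelPath)
	}
}
```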
@@ -575,7 +576,7 @@ func (s *Server) TokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
+	loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
 	})
 	if err != nil {
@@ -625,7 +626,7 @@ func (s *Server) DetokenizeHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	loadedModel, err := LoadModel(req.Model, llama.ModelParams{
+	loadedModel, err := s.ml.LoadModel(req.Model, llama.ModelParams{
 		VocabOnly: true,
 	})
 	if err != nil {
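Both handlers load with `VocabOnly: true`, so only the vocabulary needed for tokenizing is read rather than the full weights. For reference, a self-contained sketch of the request flow the tests below exercise: method check, body validation, then a JSON token response. The handler and types here are stand-ins shaped after `api.TokenizeRequest`/`api.TokenizeResponse` and the tests' expectations, not the real implementation:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// Stand-ins shaped after api.TokenizeRequest / api.TokenizeResponse.
type tokenizeRequest struct {
	Model string `json:"model"`
	Text  string `json:"text"`
}

type tokenizeResponse struct {
	Tokens []int `json:"tokens"`
}

// handler sketches TokenizeHandler's control flow: method check,
// body validation, then a JSON token response.
func handler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req tokenizeRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Model == "" || req.Text == "" {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}
	// A real handler would call s.ml.LoadModel and the model's Tokenize here.
	json.NewEncoder(w).Encode(tokenizeResponse{Tokens: []int{1, 2, 3}})
}

func main() {
	// Wrong method is rejected, as the "method not allowed" test expects.
	w := httptest.NewRecorder()
	handler(w, httptest.NewRequest(http.MethodGet, "/api/tokenize", nil)) // path illustrative
	fmt.Println(w.Code)

	// A well-formed request returns the token IDs as JSON.
	body, _ := json.Marshal(tokenizeRequest{Model: "test", Text: "test text"})
	w = httptest.NewRecorder()
	handler(w, httptest.NewRequest(http.MethodPost, "/api/tokenize", bytes.NewReader(body)))
	fmt.Println(w.Code, w.Body.String())
}
```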

server/routes_tokenize_test.go (new file, 144 lines)
@@ -0,0 +1,144 @@
+package server
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/discover"
+	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/llm"
+)
+
+type mockModelLoader struct {
+	LoadModelFn func(string, llama.ModelParams) (*loadedModel, error)
+}
+
+func (ml *mockModelLoader) LoadModel(name string, params llama.ModelParams) (*loadedModel, error) {
+	if ml.LoadModelFn != nil {
+		return ml.LoadModelFn(name, params)
+	}
+
+	return &loadedModel{
+		model: mockModel{},
+	}, nil
+}
+
+type mockModel struct {
+	llama.Model
+	TokenizeFn     func(text string, addBos bool, addEos bool) ([]int, error)
+	TokenToPieceFn func(token int) string
+}
+
+func (m *mockModel) Tokenize(text string, addBos bool, addEos bool) ([]int, error) {
+	return []int{1, 2, 3}, nil
+}
+
+func (m *mockModel) TokenToPiece(token int) string {
+	return fmt.Sprint(token)
+}
+
+func TestTokenizeHandler(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	mockLoader := mockModelLoader{
+		LoadModelFn: func(name string, params llama.ModelParams) (*loadedModel, error) {
+			return &loadedModel{
+				model: mockModel{},
+			}, nil
+		},
+	}
+
+	s := Server{
+		sched: &Scheduler{
+			pendingReqCh:  make(chan *LlmRequest, 1),
+			finishedReqCh: make(chan *LlmRequest, 1),
+			expiredCh:     make(chan *runnerRef, 1),
+			unloadedCh:    make(chan any, 1),
+			loaded:        make(map[string]*runnerRef),
+			newServerFn:   newMockServer(&mockRunner{}),
+			getGpuFn:      discover.GetGPUInfo,
+			getCpuFn:      discover.GetCPUInfo,
+			reschedDelay:  250 * time.Millisecond,
+			loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus discover.GpuInfoList, numParallel int) {
+				time.Sleep(time.Millisecond)
+				req.successCh <- &runnerRef{
+					llama: &mockRunner{},
+				}
+			},
+		},
+		ml: mockLoader,
+	}
+
+	t.Run("method not allowed", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
+		if w.Code != http.StatusMethodNotAllowed {
+			t.Errorf("expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
+		}
+	})
+
+	t.Run("missing body", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), nil)
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("missing text", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "test",
+		})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("missing model", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Text: "test text",
+		})
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code)
+		}
+	})
+
+	t.Run("model not found", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "nonexistent",
+			Text:  "test text",
+		})
+		if w.Code != http.StatusInternalServerError {
+			t.Errorf("expected status %d, got %d", http.StatusInternalServerError, w.Code)
+		}
+	})
+
+	t.Run("successful tokenization", func(t *testing.T) {
+		w := createRequest(t, gin.WrapF(s.TokenizeHandler), api.TokenizeRequest{
+			Model: "test",
+			Text:  "test text",
+		})
+
+		if w.Code != http.StatusOK {
+			t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
+		}
+
+		var resp api.TokenizeResponse
+		if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+			t.Fatal(err)
+		}
+
+		expectedTokens := []int{1, 2, 3}
+		if len(resp.Tokens) != len(expectedTokens) {
+			t.Fatalf("expected %d tokens, got %d", len(expectedTokens), len(resp.Tokens))
+		}
+		for i, token := range resp.Tokens {
+			if token != expectedTokens[i] {
+				t.Errorf("expected token %d at position %d, got %d", expectedTokens[i], i, token)
+			}
+		}
+	})
+}