diff --git a/server/routes.go b/server/routes.go index 9017f84ea..4611ff084 100644 --- a/server/routes.go +++ b/server/routes.go @@ -392,6 +392,29 @@ func (s *Server) EmbedHandler(c *gin.Context) { sessionDuration = req.KeepAlive.Duration } + switch reqEmbed := req.Input.(type) { + case string: + if reqEmbed == "" { + c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) + return + } + case []any: + if reqEmbed == nil { + c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}}) + return + } + + for _, v := range reqEmbed { + if _, ok := v.(string); !ok { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + return + } + } + default: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + return + } + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) var runner *runnerRef select { @@ -424,10 +447,6 @@ func (s *Server) EmbedHandler(c *gin.Context) { switch reqEmbed := req.Input.(type) { case string: - if reqEmbed == "" { - c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}}) - return - } reqEmbed, err = checkFit(reqEmbed, *req.Truncate) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -435,24 +454,14 @@ func (s *Server) EmbedHandler(c *gin.Context) { } embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed}) case []any: - if reqEmbed == nil { - c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}}) - return - } - reqEmbedArray := make([]string, len(reqEmbed)) for i, v := range reqEmbed { - if s, ok := v.(string); ok { - s, err = checkFit(s, *req.Truncate) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - reqEmbedArray[i] = s - } else { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) + s, err := checkFit(v.(string), *req.Truncate) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + reqEmbedArray[i] = s } embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray) default: diff --git a/server/routes_test.go b/server/routes_test.go index 7c3def336..8666311cf 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -273,6 +273,54 @@ func Test_Routes(t *testing.T) { assert.Equal(t, "library", retrieveResp.OwnedBy) }, }, + { + Name: "Embed Handler Empty Input", + Method: http.MethodPost, + Path: "/api/embed", + Setup: func(t *testing.T, req *http.Request) { + embedReq := api.EmbedRequest{ + Model: "t-bone", + Input: "", + } + jsonData, err := json.Marshal(embedReq) + require.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + assert.Equal(t, "application/json; charset=utf-8", contentType) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var embedResp api.EmbedResponse + err = json.Unmarshal(body, &embedResp) + require.NoError(t, err) + + assert.Equal(t, "t-bone", embedResp.Model) + assert.Nil(t, embedResp.Embeddings) + }, + }, + { + Name: "Embed Handler Invalid Input", + Method: http.MethodPost, + Path: "/api/embed", + Setup: func(t *testing.T, req *http.Request) { + embedReq := api.EmbedRequest{ + Model: "t-bone", + Input: 2, + } + jsonData, err := json.Marshal(embedReq) + require.NoError(t, err) + req.Body = io.NopCloser(bytes.NewReader(jsonData)) + }, + Expected: func(t *testing.T, resp *http.Response) { + contentType := resp.Header.Get("Content-Type") + assert.Equal(t, "application/json; charset=utf-8", contentType) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir()) @@ -454,5 +502,5 @@ func TestNormalize(t *testing.T) { t.Errorf("Vector %v is not normalized", tc.input) } }) - } + } }