input handling and handler testing

This commit is contained in:
Roy Han 2024-07-03 12:48:54 -07:00
parent c0fa2236cf
commit 922b8f2584
2 changed files with 76 additions and 19 deletions

View File

@ -392,6 +392,29 @@ func (s *Server) EmbedHandler(c *gin.Context) {
sessionDuration = req.KeepAlive.Duration
}
switch reqEmbed := req.Input.(type) {
case string:
if reqEmbed == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
case []any:
if reqEmbed == nil {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
for _, v := range reqEmbed {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
var runner *runnerRef
select {
@ -424,10 +447,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
switch reqEmbed := req.Input.(type) {
case string:
if reqEmbed == "" {
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
return
}
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -435,24 +454,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
case []any:
if reqEmbed == nil {
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
return
}
reqEmbedArray := make([]string, len(reqEmbed))
for i, v := range reqEmbed {
if s, ok := v.(string); ok {
s, err = checkFit(s, *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbedArray[i] = s
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
s, err := checkFit(v.(string), *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
reqEmbedArray[i] = s
}
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
default:

View File

@ -273,6 +273,54 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "library", retrieveResp.OwnedBy)
},
},
{
Name: "Embed Handler Empty Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: "",
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json; charset=utf-8", contentType)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
var embedResp api.EmbedResponse
err = json.Unmarshal(body, &embedResp)
require.NoError(t, err)
assert.Equal(t, "t-bone", embedResp.Model)
assert.Nil(t, embedResp.Embeddings)
},
},
{
Name: "Embed Handler Invalid Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: 2,
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json; charset=utf-8", contentType)
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
},
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())
@ -454,5 +502,5 @@ func TestNormalize(t *testing.T) {
t.Errorf("Vector %v is not normalized", tc.input)
}
})
}
}
}