input handling and handler testing
This commit is contained in:
parent
c0fa2236cf
commit
922b8f2584
@ -392,6 +392,29 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
sessionDuration = req.KeepAlive.Duration
|
||||
}
|
||||
|
||||
switch reqEmbed := req.Input.(type) {
|
||||
case string:
|
||||
if reqEmbed == "" {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
case []any:
|
||||
if reqEmbed == nil {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range reqEmbed {
|
||||
if _, ok := v.(string); !ok {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
return
|
||||
}
|
||||
}
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
return
|
||||
}
|
||||
|
||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
@ -424,10 +447,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
|
||||
switch reqEmbed := req.Input.(type) {
|
||||
case string:
|
||||
if reqEmbed == "" {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@ -435,24 +454,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
||||
case []any:
|
||||
if reqEmbed == nil {
|
||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
||||
return
|
||||
}
|
||||
|
||||
reqEmbedArray := make([]string, len(reqEmbed))
|
||||
for i, v := range reqEmbed {
|
||||
if s, ok := v.(string); ok {
|
||||
s, err = checkFit(s, *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
reqEmbedArray[i] = s
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||
s, err := checkFit(v.(string), *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
reqEmbedArray[i] = s
|
||||
}
|
||||
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
|
||||
default:
|
||||
|
@ -273,6 +273,54 @@ func Test_Routes(t *testing.T) {
|
||||
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Embed Handler Empty Input",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/embed",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
embedReq := api.EmbedRequest{
|
||||
Model: "t-bone",
|
||||
Input: "",
|
||||
}
|
||||
jsonData, err := json.Marshal(embedReq)
|
||||
require.NoError(t, err)
|
||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
var embedResp api.EmbedResponse
|
||||
err = json.Unmarshal(body, &embedResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "t-bone", embedResp.Model)
|
||||
assert.Nil(t, embedResp.Embeddings)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Embed Handler Invalid Input",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/embed",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
embedReq := api.EmbedRequest{
|
||||
Model: "t-bone",
|
||||
Input: 2,
|
||||
}
|
||||
jsonData, err := json.Marshal(embedReq)
|
||||
require.NoError(t, err)
|
||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
@ -454,5 +502,5 @@ func TestNormalize(t *testing.T) {
|
||||
t.Errorf("Vector %v is not normalized", tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user