input handling and handler testing
This commit is contained in:
parent
c0fa2236cf
commit
922b8f2584
@ -392,6 +392,29 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
sessionDuration = req.KeepAlive.Duration
|
sessionDuration = req.KeepAlive.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch reqEmbed := req.Input.(type) {
|
||||||
|
case string:
|
||||||
|
if reqEmbed == "" {
|
||||||
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case []any:
|
||||||
|
if reqEmbed == nil {
|
||||||
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range reqEmbed {
|
||||||
|
if _, ok := v.(string); !ok {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
||||||
var runner *runnerRef
|
var runner *runnerRef
|
||||||
select {
|
select {
|
||||||
@ -424,10 +447,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
|
|
||||||
switch reqEmbed := req.Input.(type) {
|
switch reqEmbed := req.Input.(type) {
|
||||||
case string:
|
case string:
|
||||||
if reqEmbed == "" {
|
|
||||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@ -435,24 +454,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
|
||||||
case []any:
|
case []any:
|
||||||
if reqEmbed == nil {
|
|
||||||
c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
reqEmbedArray := make([]string, len(reqEmbed))
|
reqEmbedArray := make([]string, len(reqEmbed))
|
||||||
for i, v := range reqEmbed {
|
for i, v := range reqEmbed {
|
||||||
if s, ok := v.(string); ok {
|
s, err := checkFit(v.(string), *req.Truncate)
|
||||||
s, err = checkFit(s, *req.Truncate)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqEmbedArray[i] = s
|
reqEmbedArray[i] = s
|
||||||
} else {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
|
embeddings, err = runner.llama.Embed(c.Request.Context(), reqEmbedArray)
|
||||||
default:
|
default:
|
||||||
|
@ -273,6 +273,54 @@ func Test_Routes(t *testing.T) {
|
|||||||
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "Embed Handler Empty Input",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/embed",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
embedReq := api.EmbedRequest{
|
||||||
|
Model: "t-bone",
|
||||||
|
Input: "",
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(embedReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var embedResp api.EmbedResponse
|
||||||
|
err = json.Unmarshal(body, &embedResp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "t-bone", embedResp.Model)
|
||||||
|
assert.Nil(t, embedResp.Embeddings)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Embed Handler Invalid Input",
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Path: "/api/embed",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
embedReq := api.EmbedRequest{
|
||||||
|
Model: "t-bone",
|
||||||
|
Input: 2,
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(embedReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, resp *http.Response) {
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||||
|
_, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 400, resp.StatusCode)
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user