Add tests
This commit is contained in:
parent
6de3227841
commit
5c2f35d846
@ -7,7 +7,6 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -82,7 +81,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 999,
|
||||
"max_completion_tokens": 999,
|
||||
"seed": 123,
|
||||
"stop": ["\n", "stop"],
|
||||
"temperature": 3.0,
|
||||
@ -315,6 +314,61 @@ func TestChatMiddleware(t *testing.T) {
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with num_ctx",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"num_ctx": 4096
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Options: map[string]any{
|
||||
"num_ctx": 4096.0, // float because JSON doesn't distinguish between float and int
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with max_completion_tokens < num_ctx",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_completion_tokens": 2
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Options: map[string]any{
|
||||
"num_predict": 2.0, // float because JSON doesn't distinguish between float and int
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with max_completion_tokens > num_ctx",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_completion_tokens": 4096
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Options: map[string]any{
|
||||
"num_predict": 4096.0, // float because JSON doesn't distinguish between float and int
|
||||
"num_ctx": 4096.0,
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler error forwarding",
|
||||
body: `{
|
||||
@ -359,7 +413,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match: %+v", diff)
|
||||
t.Fatalf("requests did not match (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
|
||||
@ -493,12 +547,14 @@ func TestCompletionsMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
if capturedRequest != nil {
|
||||
if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
@ -577,12 +633,14 @@ func TestEmbeddingsMiddleware(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
if capturedRequest != nil {
|
||||
if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
@ -656,8 +714,8 @@ func TestListMiddleware(t *testing.T) {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
if diff := cmp.Diff(expected, actual); diff != "" {
|
||||
t.Errorf("responses did not match (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -722,8 +780,8 @@ func TestRetrieveMiddleware(t *testing.T) {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
if diff := cmp.Diff(expected, actual); diff != "" {
|
||||
t.Errorf("responses did not match (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user