Add tests

This commit is contained in:
ParthSareen 2025-01-30 13:16:15 -08:00
parent 6de3227841
commit 5c2f35d846

View File

@ -7,7 +7,6 @@ import (
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
@ -82,7 +81,7 @@ func TestChatMiddleware(t *testing.T) {
{"role": "user", "content": "Hello"}
],
"stream": true,
"max_tokens": 999,
"max_completion_tokens": 999,
"seed": 123,
"stop": ["\n", "stop"],
"temperature": 3.0,
@ -315,6 +314,61 @@ func TestChatMiddleware(t *testing.T) {
Stream: &True,
},
},
{
name: "chat handler with num_ctx",
body: `{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"num_ctx": 4096
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Options: map[string]any{
"num_ctx": 4096.0, // float because JSON doesn't distinguish between float and int
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with max_completion_tokens < num_ctx",
body: `{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"max_completion_tokens": 2
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Options: map[string]any{
"num_predict": 2.0, // float because JSON doesn't distinguish between float and int
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with max_completion_tokens > num_ctx",
body: `{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"max_completion_tokens": 4096
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Options: map[string]any{
"num_predict": 4096.0, // float because JSON doesn't distinguish between float and int
"num_ctx": 4096.0,
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler error forwarding",
body: `{
@ -359,7 +413,7 @@ func TestChatMiddleware(t *testing.T) {
return
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff)
t.Fatalf("requests did not match (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
@ -493,12 +547,14 @@ func TestCompletionsMiddleware(t *testing.T) {
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
if capturedRequest != nil {
if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
t.Fatalf("requests did not match (-want +got):\n%s", diff)
}
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match (-want +got):\n%s", diff)
}
capturedRequest = nil
@ -577,12 +633,14 @@ func TestEmbeddingsMiddleware(t *testing.T) {
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
if capturedRequest != nil {
if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
t.Fatalf("requests did not match (-want +got):\n%s", diff)
}
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match (-want +got):\n%s", diff)
}
capturedRequest = nil
@ -656,8 +714,8 @@ func TestListMiddleware(t *testing.T) {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match (-want +got):\n%s", diff)
}
}
}
@ -722,8 +780,8 @@ func TestRetrieveMiddleware(t *testing.T) {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match (-want +got):\n%s", diff)
}
}
}