Add tests

This commit is contained in:
ParthSareen 2025-01-30 13:16:15 -08:00
parent 6de3227841
commit 5c2f35d846

View File

@ -7,7 +7,6 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -82,7 +81,7 @@ func TestChatMiddleware(t *testing.T) {
{"role": "user", "content": "Hello"} {"role": "user", "content": "Hello"}
], ],
"stream": true, "stream": true,
"max_tokens": 999, "max_completion_tokens": 999,
"seed": 123, "seed": 123,
"stop": ["\n", "stop"], "stop": ["\n", "stop"],
"temperature": 3.0, "temperature": 3.0,
@ -315,6 +314,61 @@ func TestChatMiddleware(t *testing.T) {
Stream: &True, Stream: &True,
}, },
}, },
{
name: "chat handler with num_ctx",
body: `{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"num_ctx": 4096
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Options: map[string]any{
"num_ctx": 4096.0, // float because JSON doesn't distinguish between float and int
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with max_completion_tokens < num_ctx",
body: `{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"max_completion_tokens": 2
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Options: map[string]any{
"num_predict": 2.0, // float because JSON doesn't distinguish between float and int
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with max_completion_tokens > num_ctx",
body: `{
"model": "test-model",
"messages": [{"role": "user", "content": "Hello"}],
"max_completion_tokens": 4096
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Options: map[string]any{
"num_predict": 4096.0, // float because JSON doesn't distinguish between float and int
"num_ctx": 4096.0,
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{ {
name: "chat handler error forwarding", name: "chat handler error forwarding",
body: `{ body: `{
@ -359,7 +413,7 @@ func TestChatMiddleware(t *testing.T) {
return return
} }
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" { if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff) t.Fatalf("requests did not match (-want +got):\n%s", diff)
} }
if diff := cmp.Diff(tc.err, errResp); diff != "" { if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff) t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
@ -493,12 +547,14 @@ func TestCompletionsMiddleware(t *testing.T) {
} }
} }
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { if capturedRequest != nil {
t.Fatal("requests did not match") if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
t.Fatalf("requests did not match (-want +got):\n%s", diff)
}
} }
if !reflect.DeepEqual(tc.err, errResp) { if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatal("errors did not match") t.Fatalf("errors did not match (-want +got):\n%s", diff)
} }
capturedRequest = nil capturedRequest = nil
@ -577,12 +633,14 @@ func TestEmbeddingsMiddleware(t *testing.T) {
} }
} }
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { if capturedRequest != nil {
t.Fatal("requests did not match") if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" {
t.Fatalf("requests did not match (-want +got):\n%s", diff)
}
} }
if !reflect.DeepEqual(tc.err, errResp) { if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatal("errors did not match") t.Fatalf("errors did not match (-want +got):\n%s", diff)
} }
capturedRequest = nil capturedRequest = nil
@ -656,8 +714,8 @@ func TestListMiddleware(t *testing.T) {
t.Fatalf("failed to unmarshal actual response: %v", err) t.Fatalf("failed to unmarshal actual response: %v", err)
} }
if !reflect.DeepEqual(expected, actual) { if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) t.Errorf("responses did not match (-want +got):\n%s", diff)
} }
} }
} }
@ -722,8 +780,8 @@ func TestRetrieveMiddleware(t *testing.T) {
t.Fatalf("failed to unmarshal actual response: %v", err) t.Fatalf("failed to unmarshal actual response: %v", err)
} }
if !reflect.DeepEqual(expected, actual) { if diff := cmp.Diff(expected, actual); diff != "" {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) t.Errorf("responses did not match (-want +got):\n%s", diff)
} }
} }
} }