From 5c2f35d846a462c87155be512e652d774ddddb62 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 30 Jan 2025 13:16:15 -0800 Subject: [PATCH] Add tests --- openai/openai_test.go | 88 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 15 deletions(-) diff --git a/openai/openai_test.go b/openai/openai_test.go index d8c821d39..32f680505 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "net/http/httptest" - "reflect" "strings" "testing" "time" @@ -82,7 +81,7 @@ func TestChatMiddleware(t *testing.T) { {"role": "user", "content": "Hello"} ], "stream": true, - "max_tokens": 999, + "max_completion_tokens": 999, "seed": 123, "stop": ["\n", "stop"], "temperature": 3.0, @@ -315,6 +314,61 @@ func TestChatMiddleware(t *testing.T) { Stream: &True, }, }, + { + name: "chat handler with num_ctx", + body: `{ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "num_ctx": 4096 + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{{Role: "user", Content: "Hello"}}, + Options: map[string]any{ + "num_ctx": 4096.0, // float because JSON doesn't distinguish between float and int + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with max_completion_tokens < num_ctx", + body: `{ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "max_completion_tokens": 2 + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{{Role: "user", Content: "Hello"}}, + Options: map[string]any{ + "num_predict": 2.0, // float because JSON doesn't distinguish between float and int + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with max_completion_tokens > num_ctx", + body: `{ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "max_completion_tokens": 4096 + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{{Role: "user", Content: "Hello"}}, + Options: map[string]any{ + "num_predict": 4096.0, // float because JSON doesn't distinguish between float and int + "num_ctx": 4096.0, + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, { name: "chat handler error forwarding", body: `{ @@ -359,7 +413,7 @@ func TestChatMiddleware(t *testing.T) { return } if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" { - t.Fatalf("requests did not match: %+v", diff) + t.Fatalf("requests did not match (-want +got):\n%s", diff) } if diff := cmp.Diff(tc.err, errResp); diff != "" { t.Fatalf("errors did not match for %s:\n%s", tc.name, diff) @@ -493,12 +547,14 @@ func TestCompletionsMiddleware(t *testing.T) { } } - if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") + if capturedRequest != nil { + if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" { + t.Fatalf("requests did not match (-want +got):\n%s", diff) + } } - if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") + if diff := cmp.Diff(tc.err, errResp); diff != "" { + t.Fatalf("errors did not match (-want +got):\n%s", diff) } capturedRequest = nil @@ -577,12 +633,14 @@ func TestEmbeddingsMiddleware(t *testing.T) { } } - if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { - t.Fatal("requests did not match") + if capturedRequest != nil { + if diff := cmp.Diff(tc.req, *capturedRequest); diff != "" { + t.Fatalf("requests did not match (-want +got):\n%s", diff) + } } - if !reflect.DeepEqual(tc.err, errResp) { - t.Fatal("errors did not match") + if diff := cmp.Diff(tc.err, errResp); diff != "" { + t.Fatalf("errors did not match (-want +got):\n%s", diff) } capturedRequest = nil @@ -656,8 +714,8 @@ func TestListMiddleware(t *testing.T) { t.Fatalf("failed to unmarshal actual response: %v", err) } - if !reflect.DeepEqual(expected, actual) { - t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + if diff := cmp.Diff(expected, actual); diff != "" { + t.Errorf("responses did not match (-want +got):\n%s", diff) } } } @@ -722,8 +780,8 @@ func TestRetrieveMiddleware(t *testing.T) { t.Fatalf("failed to unmarshal actual response: %v", err) } - if !reflect.DeepEqual(expected, actual) { - t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + if diff := cmp.Diff(expected, actual); diff != "" { + t.Errorf("responses did not match (-want +got):\n%s", diff) } } }