From 7fa9694359ed908e9193c3780930c5d1605a74c1 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 18 Feb 2025 14:10:11 -0800 Subject: [PATCH] model: add a test for model forward pass during implementation Adds a new test file to verify model forward pass behavior through JSON-specified test cases. The framework loads model files (.gguf) and their corresponding test specifications to validate expected outputs using greedy sampling. --- .gitignore | 3 + model/model_external_test.go | 138 +++++++++++++++++++++++++++++ model/testdata/models/README.md | 10 +++ model/testdata/models/qwen2_5.json | 7 ++ 4 files changed, 158 insertions(+) create mode 100644 model/model_external_test.go create mode 100644 model/testdata/models/README.md create mode 100644 model/testdata/models/qwen2_5.json diff --git a/.gitignore b/.gitignore index 551abec87..a9172c7f1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ test_data __debug_bin* llama/build llama/vendor +model/testdata/models/* +!model/testdata/models/*.md +!model/testdata/models/*.json diff --git a/model/model_external_test.go b/model/model_external_test.go new file mode 100644 index 000000000..dc8f6fb61 --- /dev/null +++ b/model/model_external_test.go @@ -0,0 +1,138 @@ +// Package model_test provides external tests for the model package. +// This test file specifically tests the forward pass functionality on models. +// It is in a separate package (model_test) to avoid import cycles while still +// being able to test the public API of the model package. 
+package model_test + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/sample" + + _ "github.com/ollama/ollama/model/models" +) + +type modelTest struct { + Prompt string `json:"prompt"` + OutputContainsOne []string `json:"output_contains_one"` +} + +func TestForwardSimple(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + // List testdata/models; every *.json entry found there is treated as a test spec + files, err := os.ReadDir("testdata/models") + if err != nil { + t.Fatal(err) + } + + for _, file := range files { + if !strings.HasSuffix(file.Name(), ".json") { + continue + } + + jsonPath := filepath.Join("testdata/models", file.Name()) + ggufPath := filepath.Join("testdata/models", strings.TrimSuffix(file.Name(), ".json")+".gguf") + + // Skip this spec if os.Stat fails on the matching <name>.gguf (any error, not only "does not exist") + if _, err := os.Stat(ggufPath); err != nil { + t.Logf("skipping %s: no corresponding GGUF file found", file.Name()) + continue + } + + data, err := os.ReadFile(jsonPath) + if err != nil { + t.Fatal(err) + } + + var test modelTest + if err := json.Unmarshal(data, &test); err != nil { + t.Fatal(err) + } + + t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) { + m, err := model.New(ggufPath) + if err != nil { + t.Fatal(err) + } + + m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048) + + inputs, err := m.(model.TextProcessor).Encode(test.Prompt) + if err != nil { + t.Fatal(err) + } + + var result []string + for len(result) < 100 { // Generate at most 100 decoded pieces before giving up + options := model.Options{ + Inputs: inputs, + Positions: make([]int32, len(inputs)), + Sequences: make([]int, len(inputs)), + Outputs: []int32{int32(len(inputs) - 1)}, // index of the last input; presumably selects the position whose logits are returned (confirm against model.Options) + } + for i := range options.Positions { + options.Positions[i] = int32(i) // positions run 0..len(inputs)-1 on every iteration + options.Sequences[i] = 0 // every input is assigned sequence 0 + } + + ctx := m.Backend().NewContext() + + modelOutput, err := model.Forward(ctx, m, options) + if err != nil { +
ctx.Close() + t.Fatal(fmt.Errorf("forward pass failed: %v", err)) + } + + f32s := modelOutput.Floats() + logits := make([]float64, len(f32s)) + for i, f32 := range f32s { + logits[i] = float64(f32) + } + + token, err := sample.Sample(logits, sample.Greedy()) + if err != nil { + ctx.Close() + t.Fatal(fmt.Errorf("sampling failed: %v", err)) + } + + ctx.Close() + + // Sampling above used sample.Greedy(); the first returned element is taken as the next token + nextToken := int32(token[0]) + if m.(model.TextProcessor).Is(nextToken, model.SpecialEOS) { + break // token is model.SpecialEOS: stop generating; control reaches the t.Fatalf after the loop + } + + piece, err := m.(model.TextProcessor).Decode([]int32{nextToken}) + if err != nil { + t.Fatal(err) + } + + result = append(result, piece) + output := strings.Join(result, "") + + for _, expectedOutput := range test.OutputContainsOne { + if strings.Contains(output, expectedOutput) { + t.Logf("Test passed with output: %q (matched expected: %q)", output, expectedOutput) + return // pass as soon as any expected substring appears in the output so far + } + } + + // Append the new token so the next iteration re-submits the whole sequence + inputs = append(inputs, nextToken) + } + + t.Fatalf("Expected output containing one of %q but got: %q", test.OutputContainsOne, strings.Join(result, "")) // reached on SpecialEOS or after 100 pieces without a match + }) + } +} diff --git a/model/testdata/models/README.md b/model/testdata/models/README.md new file mode 100644 index 000000000..3c6308d1a --- /dev/null +++ b/model/testdata/models/README.md @@ -0,0 +1,10 @@ +# Test Model Directory + +This directory is used for storing model files (like `.gguf` files) that are required to run the tests in `model_external_test.go`.
+ +## Usage + +- Place the `.gguf` model files you need for testing in this directory +- For each `<name>.json` test spec here, the test loads the matching `<name>.gguf` (e.g., `qwen2_5.json` pairs with `qwen2_5.gguf`); specs without a matching model file are skipped +- Everything in this directory except `.md` and `.json` files is git-ignored to prevent large model files from being committed to the repository +- Only `.md` files (like this README) and `.json` test specs are tracked in git diff --git a/model/testdata/models/qwen2_5.json b/model/testdata/models/qwen2_5.json new file mode 100644 index 000000000..182003bc2 --- /dev/null +++ b/model/testdata/models/qwen2_5.json @@ -0,0 +1,7 @@ +{ + "prompt": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n", + "output_contains_one": [ + "Hello", + "Hi" + ] +}