model: add a test for model forward pass during implementation
Adds a new test file to verify model forward pass behavior through JSON-specified test cases. The framework loads model files (.gguf) and their corresponding test specifications to validate expected outputs using greedy sampling.
This commit is contained in:
parent
7b5d916a9a
commit
31d04eb795
3
.gitignore
vendored
3
.gitignore
vendored
@ -14,3 +14,6 @@ test_data
|
||||
__debug_bin*
|
||||
llama/build
|
||||
llama/vendor
|
||||
model/testdata/models/*
|
||||
!model/testdata/models/*.md
|
||||
!model/testdata/models/*.json
|
||||
|
138
model/model_external_test.go
Normal file
138
model/model_external_test.go
Normal file
@ -0,0 +1,138 @@
|
||||
// Package model_test provides external tests for the model package.
|
||||
// This test file specifically tests the forward pass functionality on models.
|
||||
// It is in a separate package (model_test) to avoid import cycles while still
|
||||
// being able to test the public API of the model package.
|
||||
package model_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/sample"
|
||||
|
||||
_ "github.com/ollama/ollama/model/models"
|
||||
)
|
||||
|
||||
// modelTest is the schema of a JSON test specification stored next to a
// model file in testdata/models.
type modelTest struct {
	// Prompt is the raw text that is encoded and fed to the model.
	Prompt string `json:"prompt"`

	// OutputContainsOne lists acceptable substrings; the generated output
	// only has to contain one of them for the test case to pass.
	OutputContainsOne []string `json:"output_contains_one"`
}
|
||||
|
||||
func TestForwardSimple(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
// Read all JSON files from testdata/models
|
||||
files, err := os.ReadDir("testdata/models")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if !strings.HasSuffix(file.Name(), ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonPath := filepath.Join("testdata/models", file.Name())
|
||||
ggufPath := filepath.Join("testdata/models", strings.TrimSuffix(file.Name(), ".json")+".gguf")
|
||||
|
||||
// Skip if no corresponding .gguf file exists
|
||||
if _, err := os.Stat(ggufPath); err != nil {
|
||||
t.Logf("skipping %s: no corresponding GGUF file found", file.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(jsonPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var test modelTest
|
||||
if err := json.Unmarshal(data, &test); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run(strings.TrimSuffix(file.Name(), ".json"), func(t *testing.T) {
|
||||
m, err := model.New(ggufPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)
|
||||
|
||||
inputs, err := m.(model.TextProcessor).Encode(test.Prompt)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var result []string
|
||||
for len(result) < 100 { // Limit to 100 tokens max
|
||||
options := model.Options{
|
||||
Inputs: inputs,
|
||||
Positions: make([]int32, len(inputs)),
|
||||
Sequences: make([]int, len(inputs)),
|
||||
Outputs: []int32{int32(len(inputs) - 1)},
|
||||
}
|
||||
for i := range options.Positions {
|
||||
options.Positions[i] = int32(i)
|
||||
options.Sequences[i] = 0
|
||||
}
|
||||
|
||||
ctx := m.Backend().NewContext()
|
||||
|
||||
modelOutput, err := model.Forward(ctx, m, options)
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
t.Fatal(fmt.Errorf("forward pass failed: %v", err))
|
||||
}
|
||||
|
||||
f32s := modelOutput.Floats()
|
||||
logits := make([]float64, len(f32s))
|
||||
for i, f32 := range f32s {
|
||||
logits[i] = float64(f32)
|
||||
}
|
||||
|
||||
token, err := sample.Sample(logits, sample.Greedy())
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
t.Fatal(fmt.Errorf("sampling failed: %v", err))
|
||||
}
|
||||
|
||||
ctx.Close()
|
||||
|
||||
// Greedy sampling: take the token with the highest logit
|
||||
nextToken := int32(token[0])
|
||||
if m.(model.TextProcessor).Is(nextToken, model.SpecialEOS) {
|
||||
break
|
||||
}
|
||||
|
||||
piece, err := m.(model.TextProcessor).Decode([]int32{nextToken})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result = append(result, piece)
|
||||
output := strings.Join(result, "")
|
||||
|
||||
for _, expectedOutput := range test.OutputContainsOne {
|
||||
if strings.Contains(output, expectedOutput) {
|
||||
t.Logf("Test passed with output: %q (matched expected: %q)", output, expectedOutput)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Maintain full context by appending new token
|
||||
inputs = append(inputs, nextToken)
|
||||
}
|
||||
|
||||
t.Fatalf("Expected output containing one of %q but got: %q", test.OutputContainsOne, strings.Join(result, ""))
|
||||
})
|
||||
}
|
||||
}
|
10
model/testdata/models/README.md
vendored
Normal file
10
model/testdata/models/README.md
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
# Test Model Directory
|
||||
|
||||
This directory is used for storing model files (like `.gguf` files) that are required to run the tests in `model_external_test.go`.
|
||||
|
||||
## Usage
|
||||
|
||||
- Place any model files you need for testing in this directory
|
||||
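- For each `<name>.json` test specification in this directory, the test looks for a matching `<name>.gguf` model file here (e.g., `qwen2_5.gguf` for `qwen2_5.json`); specifications without a model file are skipped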
|
||||
- All files in this directory other than `.md` and `.json` files are git-ignored to prevent large model files from being committed to the repository
|
||||
- Only `.md` files (like this README) and `.json` test specifications will be tracked in git
|
7
model/testdata/models/qwen2_5.json
vendored
Normal file
7
model/testdata/models/qwen2_5.json
vendored
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"prompt": "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n",
|
||||
"output_contains_one": [
|
||||
"Hello",
|
||||
"Hi"
|
||||
]
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user