more unittests

This commit is contained in:
Patrick Devine 2024-07-05 23:11:20 -07:00
parent 2e055e3af8
commit 5d4a331de3
2 changed files with 23 additions and 7 deletions

View File

@ -16,7 +16,7 @@ func (p *adapter) KV(t *Tokenizer) llm.KV {
// todo - need a way to pass these in
kv := llm.KV{
"r": uint32(8),
"alpha": uint32(16),
"alpha": uint32(160),
}
return kv
}

View File

@ -135,15 +135,15 @@ func TestConvertNPZ(t *testing.T) {
for _, fn := range cases {
ts, err := parseNPZ(filepath.Join("testdata", fn))
assert.Nil(t, err)
assert.Equal(t, len(ts), 16*2*2) // 16 layers, 2 tensors, 2 loras
assert.Equal(t, 16*2*2, len(ts)) // 16 layers, 2 tensors, 2 loras
a := adapter{}
for _, m := range ts {
at := m.(adapterTensor)
assert.Equal(t, at.path, filepath.Join("testdata", fn))
assert.Equal(t, at.dtype, "F32") // only float32s supported
assert.Equal(t, len(at.tensorBase.shape), 2)
assert.Equal(t, filepath.Join("testdata", fn), at.path)
assert.Equal(t, "F32", at.dtype) // only float32s supported
assert.Equal(t, 2, len(at.tensorBase.shape))
}
var ws io.WriteSeeker = &memWriter{}
@ -152,10 +152,26 @@ func TestConvertNPZ(t *testing.T) {
mw := ws.(*memWriter)
slog.Info(fmt.Sprintf("buffer len = %d", len(mw.buf)))
assert.NotEqual(t, 0, len(mw.buf))
rs := bytes.NewReader(mw.buf)
ggml, _, err := llm.DecodeGGML(rs, len(mw.buf))
assert.Nil(t, err)
assert.NotNil(t, ggml)
assert.Nil(t, err, "decode ggml failed")
assert.NotNil(t, ggml, "ggml was empty")
kv := ggml.KV()
assert.NotNil(t, kv, "lora KVs not found")
r, ok := kv["r"]
assert.Equal(t, true, ok, "lora rank not set")
assert.Equal(t, uint32(8), r, "lora rank was incorrect")
alpha, ok := kv["alpha"]
assert.Equal(t, true, ok, "lora alpha not set")
assert.Equal(t, uint32(160), alpha, "lora alpha value was incorrect")
gts := ggml.Tensors()
assert.NotNil(t, gts, "no tensors found")
assert.Equal(t, len(ts), len(gts.Items))
}
}