more unittests
This commit is contained in:
parent
2e055e3af8
commit
5d4a331de3
@ -16,7 +16,7 @@ func (p *adapter) KV(t *Tokenizer) llm.KV {
|
|||||||
// todo - need a way to pass these in
|
// todo - need a way to pass these in
|
||||||
kv := llm.KV{
|
kv := llm.KV{
|
||||||
"r": uint32(8),
|
"r": uint32(8),
|
||||||
"alpha": uint32(16),
|
"alpha": uint32(160),
|
||||||
}
|
}
|
||||||
return kv
|
return kv
|
||||||
}
|
}
|
||||||
|
@ -135,15 +135,15 @@ func TestConvertNPZ(t *testing.T) {
|
|||||||
for _, fn := range cases {
|
for _, fn := range cases {
|
||||||
ts, err := parseNPZ(filepath.Join("testdata", fn))
|
ts, err := parseNPZ(filepath.Join("testdata", fn))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, len(ts), 16*2*2) // 16 layers, 2 tensors, 2 loras
|
assert.Equal(t, 16*2*2, len(ts)) // 16 layers, 2 tensors, 2 loras
|
||||||
|
|
||||||
a := adapter{}
|
a := adapter{}
|
||||||
|
|
||||||
for _, m := range ts {
|
for _, m := range ts {
|
||||||
at := m.(adapterTensor)
|
at := m.(adapterTensor)
|
||||||
assert.Equal(t, at.path, filepath.Join("testdata", fn))
|
assert.Equal(t, filepath.Join("testdata", fn), at.path)
|
||||||
assert.Equal(t, at.dtype, "F32") // only float32s supported
|
assert.Equal(t, "F32", at.dtype) // only float32s supported
|
||||||
assert.Equal(t, len(at.tensorBase.shape), 2)
|
assert.Equal(t, 2, len(at.tensorBase.shape))
|
||||||
}
|
}
|
||||||
|
|
||||||
var ws io.WriteSeeker = &memWriter{}
|
var ws io.WriteSeeker = &memWriter{}
|
||||||
@ -152,10 +152,26 @@ func TestConvertNPZ(t *testing.T) {
|
|||||||
|
|
||||||
mw := ws.(*memWriter)
|
mw := ws.(*memWriter)
|
||||||
slog.Info(fmt.Sprintf("buffer len = %d", len(mw.buf)))
|
slog.Info(fmt.Sprintf("buffer len = %d", len(mw.buf)))
|
||||||
|
assert.NotEqual(t, 0, len(mw.buf))
|
||||||
rs := bytes.NewReader(mw.buf)
|
rs := bytes.NewReader(mw.buf)
|
||||||
ggml, _, err := llm.DecodeGGML(rs, len(mw.buf))
|
ggml, _, err := llm.DecodeGGML(rs, len(mw.buf))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err, "decode ggml failed")
|
||||||
assert.NotNil(t, ggml)
|
assert.NotNil(t, ggml, "ggml was empty")
|
||||||
|
|
||||||
|
kv := ggml.KV()
|
||||||
|
assert.NotNil(t, kv, "lora KVs not found")
|
||||||
|
|
||||||
|
r, ok := kv["r"]
|
||||||
|
assert.Equal(t, true, ok, "lora rank not set")
|
||||||
|
assert.Equal(t, uint32(8), r, "lora rank was incorrect")
|
||||||
|
|
||||||
|
alpha, ok := kv["alpha"]
|
||||||
|
assert.Equal(t, true, ok, "lora alpha not set")
|
||||||
|
assert.Equal(t, uint32(160), alpha, "lora alpha value was incorrect")
|
||||||
|
|
||||||
|
gts := ggml.Tensors()
|
||||||
|
assert.NotNil(t, gts, "no tensors found")
|
||||||
|
assert.Equal(t, len(ts), len(gts.Items))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user