From 5d4a331de3b8ec554b5f0a21a30b00600f946e80 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 5 Jul 2024 23:11:20 -0700 Subject: [PATCH] more unittests --- convert/convert_adapter.go | 2 +- convert/convert_test.go | 28 ++++++++++++++++++++++------ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/convert/convert_adapter.go b/convert/convert_adapter.go index 6f27fd445..f74829991 100644 --- a/convert/convert_adapter.go +++ b/convert/convert_adapter.go @@ -16,7 +16,7 @@ func (p *adapter) KV(t *Tokenizer) llm.KV { // todo - need a way to pass these in kv := llm.KV{ "r": uint32(8), - "alpha": uint32(16), + "alpha": uint32(160), } return kv } diff --git a/convert/convert_test.go b/convert/convert_test.go index c4fd5dbd5..89c9b544d 100644 --- a/convert/convert_test.go +++ b/convert/convert_test.go @@ -135,15 +135,15 @@ func TestConvertNPZ(t *testing.T) { for _, fn := range cases { ts, err := parseNPZ(filepath.Join("testdata", fn)) assert.Nil(t, err) - assert.Equal(t, len(ts), 16*2*2) // 16 layers, 2 tensors, 2 loras + assert.Equal(t, 16*2*2, len(ts)) // 16 layers, 2 tensors, 2 loras a := adapter{} for _, m := range ts { at := m.(adapterTensor) - assert.Equal(t, at.path, filepath.Join("testdata", fn)) - assert.Equal(t, at.dtype, "F32") // only float32s supported - assert.Equal(t, len(at.tensorBase.shape), 2) + assert.Equal(t, filepath.Join("testdata", fn), at.path) + assert.Equal(t, "F32", at.dtype) // only float32s supported + assert.Equal(t, 2, len(at.tensorBase.shape)) } var ws io.WriteSeeker = &memWriter{} @@ -152,10 +152,26 @@ func TestConvertNPZ(t *testing.T) { mw := ws.(*memWriter) slog.Info(fmt.Sprintf("buffer len = %d", len(mw.buf))) + assert.NotEqual(t, 0, len(mw.buf)) rs := bytes.NewReader(mw.buf) ggml, _, err := llm.DecodeGGML(rs, len(mw.buf)) - assert.Nil(t, err) - assert.NotNil(t, ggml) + assert.Nil(t, err, "decode ggml failed") + assert.NotNil(t, ggml, "ggml was empty") + + kv := ggml.KV() + assert.NotNil(t, kv, "lora KVs not found") + + r, ok := kv["r"] + assert.Equal(t, true, ok, "lora rank not set") + assert.Equal(t, uint32(8), r, "lora rank was incorrect") + + alpha, ok := kv["alpha"] + assert.Equal(t, true, ok, "lora alpha not set") + assert.Equal(t, uint32(160), alpha, "lora alpha value was incorrect") + + gts := ggml.Tensors() + assert.NotNil(t, gts, "no tensors found") + assert.Equal(t, len(ts), len(gts.Items)) } }