ml: update Dump to handle precision

This commit is contained in:
Michael Yang 2025-02-10 16:50:49 -08:00
parent c4f127ee6d
commit 95eb87a052

View File

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"os" "os"
"strconv"
"strings" "strings"
) )
@ -126,15 +127,19 @@ func Dump(t Tensor, opts ...DumpOptions) string {
switch t.DType() { switch t.DType() {
case DTypeF32: case DTypeF32:
return dump[[]float32](t, opts[0]) return dump[[]float32](t, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeI32: case DTypeI32:
return dump[[]int32](t, opts[0]) return dump[[]int32](t, opts[0].Items, func(i int32) string {
return strconv.FormatInt(int64(i), 10)
})
default: default:
return "<unsupported>" return "<unsupported>"
} }
} }
func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string { func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
bts := t.Bytes() bts := t.Bytes()
if bts == nil { if bts == nil {
return "<nil>" return "<nil>"
@ -154,10 +159,10 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
fmt.Fprint(&sb, "[") fmt.Fprint(&sb, "[")
defer func() { fmt.Fprint(&sb, "]") }() defer func() { fmt.Fprint(&sb, "]") }()
for i := int64(0); i < dims[0]; i++ { for i := int64(0); i < dims[0]; i++ {
if i >= opts.Items && i < dims[0]-opts.Items { if i >= items && i < dims[0]-items {
fmt.Fprint(&sb, "..., ") fmt.Fprint(&sb, "..., ")
// skip to next printable element // skip to next printable element
skip := dims[0] - 2*opts.Items skip := dims[0] - 2*items
if len(dims) > 1 { if len(dims) > 1 {
stride += mul(append(dims[1:], skip)...) stride += mul(append(dims[1:], skip)...)
fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix) fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
@ -170,7 +175,7 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix) fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
} }
} else { } else {
fmt.Fprint(&sb, s[stride+i]) fmt.Fprint(&sb, fn(s[stride+i]))
if i < dims[0]-1 { if i < dims[0]-1 {
fmt.Fprint(&sb, ", ") fmt.Fprint(&sb, ", ")
} }