ml: update Dump to handle precision
This commit is contained in:
parent
c4f127ee6d
commit
95eb87a052
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -126,15 +127,19 @@ func Dump(t Tensor, opts ...DumpOptions) string {
|
|||||||
|
|
||||||
switch t.DType() {
|
switch t.DType() {
|
||||||
case DTypeF32:
|
case DTypeF32:
|
||||||
return dump[[]float32](t, opts[0])
|
return dump[[]float32](t, opts[0].Items, func(f float32) string {
|
||||||
|
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||||
|
})
|
||||||
case DTypeI32:
|
case DTypeI32:
|
||||||
return dump[[]int32](t, opts[0])
|
return dump[[]int32](t, opts[0].Items, func(i int32) string {
|
||||||
|
return strconv.FormatInt(int64(i), 10)
|
||||||
|
})
|
||||||
default:
|
default:
|
||||||
return "<unsupported>"
|
return "<unsupported>"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
|
||||||
bts := t.Bytes()
|
bts := t.Bytes()
|
||||||
if bts == nil {
|
if bts == nil {
|
||||||
return "<nil>"
|
return "<nil>"
|
||||||
@ -154,10 +159,10 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
|||||||
fmt.Fprint(&sb, "[")
|
fmt.Fprint(&sb, "[")
|
||||||
defer func() { fmt.Fprint(&sb, "]") }()
|
defer func() { fmt.Fprint(&sb, "]") }()
|
||||||
for i := int64(0); i < dims[0]; i++ {
|
for i := int64(0); i < dims[0]; i++ {
|
||||||
if i >= opts.Items && i < dims[0]-opts.Items {
|
if i >= items && i < dims[0]-items {
|
||||||
fmt.Fprint(&sb, "..., ")
|
fmt.Fprint(&sb, "..., ")
|
||||||
// skip to next printable element
|
// skip to next printable element
|
||||||
skip := dims[0] - 2*opts.Items
|
skip := dims[0] - 2*items
|
||||||
if len(dims) > 1 {
|
if len(dims) > 1 {
|
||||||
stride += mul(append(dims[1:], skip)...)
|
stride += mul(append(dims[1:], skip)...)
|
||||||
fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
|
fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
|
||||||
@ -170,7 +175,7 @@ func dump[S ~[]E, E number](t Tensor, opts DumpOptions) string {
|
|||||||
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprint(&sb, s[stride+i])
|
fmt.Fprint(&sb, fn(s[stride+i]))
|
||||||
if i < dims[0]-1 {
|
if i < dims[0]-1 {
|
||||||
fmt.Fprint(&sb, ", ")
|
fmt.Fprint(&sb, ", ")
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user