ggml-backend: Ensure data is available after async computation

We need to sync before retrieving data after async computation.
It is also important to ensure that the Go buffer is not moved by
the GC across function calls so we do a synchronous copy.
This commit is contained in:
Jesse Gross 2025-02-05 13:18:36 -08:00 committed by Jesse Gross
parent 01d9a46854
commit 60830695c2

View File

@ -9,8 +9,6 @@ package ggml
import "C" import "C"
import ( import (
"bytes"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
@ -245,12 +243,17 @@ func (c *Context) Forward(t ml.Tensor) {
func (c *Context) Compute(tensors ...ml.Tensor) { func (c *Context) Compute(tensors ...ml.Tensor) {
C.ggml_backend_sched_graph_compute_async(c.sched, c.graph) C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)
for _, t := range tensors { needSync := true
if C.ggml_nbytes(t.(*Tensor).t) != 0 { sync := func() {
backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t) if needSync {
C.ggml_backend_sched_synchronize(c.sched)
needSync = false
}
}
t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t)) for _, t := range tensors {
C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) if C.ggml_nbytes(t.(*Tensor).t) > 0 {
t.(*Tensor).sync = sync
} }
} }
} }
@ -330,7 +333,7 @@ func (c *Context) Close() {
type Tensor struct { type Tensor struct {
t *C.struct_ggml_tensor t *C.struct_ggml_tensor
data []byte sync func()
} }
func (t *Tensor) LogValue() slog.Value { func (t *Tensor) LogValue() slog.Value {
@ -358,14 +361,23 @@ func (t *Tensor) Shape() []int {
return shape return shape
} }
func (t *Tensor) Bytes() []byte { func (t *Tensor) Bytes() (data []byte) {
return t.data if t.sync != nil {
data = make([]byte, C.ggml_nbytes(t.t))
t.sync()
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
}
return
} }
func (t *Tensor) Floats() (f32s []float32) { func (t *Tensor) Floats() (data []float32) {
if t.data != nil { if t.sync != nil {
f32s = make([]float32, C.ggml_nelements(t.t)) data = make([]float32, C.ggml_nelements(t.t))
_ = binary.Read(bytes.NewReader(t.data), binary.LittleEndian, f32s)
t.sync()
C.ggml_backend_tensor_get(t.t, unsafe.Pointer(&data[0]), 0, C.ggml_nbytes(t.t))
} }
return return