From f33ccd5d27f521dac79bba0312371414a0b3bc08 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 11 Mar 2025 16:06:06 -0700 Subject: [PATCH] ggml: Use pointer receivers for Context Context is currently mixed between pointer and value receivers. Change this to be all pointer receivers so don't have to reason about whether the things we are updating in the struct will be retained. --- ml/backend/ggml/ggml.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 5af39d75f..727a7b8e5 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -484,7 +484,7 @@ type Context struct { maxGraphNodes int } -func (c Context) Input() ml.Context { +func (c *Context) Input() ml.Context { if c.b.input != nil { return &Context{ b: c.b, @@ -494,10 +494,10 @@ func (c Context) Input() ml.Context { } } - return &c + return c } -func (c Context) Layer(i int) ml.Context { +func (c *Context) Layer(i int) ml.Context { if buft, ok := c.b.layers[i]; ok { return &Context{ b: c.b, @@ -507,7 +507,7 @@ func (c Context) Layer(i int) ml.Context { } } - return &c + return c } func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { @@ -522,7 +522,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { return c } -func (c Context) Compute(tensors ...ml.Tensor) { +func (c *Context) Compute(tensors ...ml.Tensor) { C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph) C.ggml_backend_sched_reset(c.b.sched) @@ -541,7 +541,7 @@ func (c Context) Compute(tensors ...ml.Tensor) { } } -func (c Context) Reserve() error { +func (c *Context) Reserve() error { if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) { C.ggml_backend_sched_reset(c.b.sched) return errors.New("failed to reserve graph") @@ -559,7 +559,7 @@ func (c Context) Reserve() error { return nil } -func (c Context) MaxGraphNodes() int { +func (c *Context) MaxGraphNodes() int { return c.maxGraphNodes } @@ -576,7 +576,7 @@ func pad(length, pad C.size_t) C.size_t { return ((length + pad - 1) / pad) * pad } -func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { +func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { if c.buft == nil { panic("set Input or Layer before creating tensors") } @@ -621,7 +621,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) { return &Tensor{b: c.b, t: t}, nil } -func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { +func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { t, err := c.newTensor(dtype, shape) if err != nil { panic(err) @@ -630,7 +630,7 @@ func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { return t } -func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { +func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { t, err := c.newTensor(dtype, shape) if err != nil { panic(err) @@ -658,7 +658,7 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error { return nil } -func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { +func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { if err := checkShape(s, shape...); err != nil { return nil, err } @@ -675,7 +675,7 @@ func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { return t, nil } -func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { +func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { if err := checkShape(s, shape...); err != nil { return nil, err }