ml/backend/ggml: fix rms norm
This commit is contained in:
parent
5d81c1a184
commit
2192a28eed
@ -485,7 +485,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
||||||
return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user