additional review comments

This commit is contained in:
Jesse Gross 2025-03-07 11:19:03 -08:00 committed by Michael Yang
parent b27e8f3f10
commit 98272fbd58
2 changed files with 32 additions and 16 deletions

View File

@ -402,7 +402,10 @@ func (b *Backend) NewContext() ml.Context {
} }
func (b *Backend) NewContextSize(n int) ml.Context { func (b *Backend) NewContextSize(n int) ml.Context {
n = min(n, b.maxGraphNodes) if n > b.maxGraphNodes {
panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
}
return &Context{ return &Context{
b: b, b: b,
maxGraphNodes: n, maxGraphNodes: n,
@ -534,7 +537,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
panic("unsupported dtype") panic("unsupported dtype")
} }
if len(shape) < 1 { if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0 var shape C.int64_t = 0
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)} return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
} else if len(shape) > 4 { } else if len(shape) > 4 {
@ -565,6 +568,11 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func checkShape[S ~[]E, E any](s S, shape ...int) error { func checkShape[S ~[]E, E any](s S, shape ...int) error {
n := len(s) n := len(s)
if n == 0 {
return nil
}
for _, v := range shape { for _, v := range shape {
n /= v n /= v
} }
@ -577,22 +585,28 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
} }
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil && len(shape) > 0 { if err := checkShape(s, shape...); err != nil {
return nil, err return nil, err
} }
t := c.newTensor(ml.DTypeF32, shape) t := c.newTensor(ml.DTypeF32, shape)
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
return t, nil return t, nil
} }
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil && len(shape) > 0 { if err := checkShape(s, shape...); err != nil {
return nil, err return nil, err
} }
t := c.newTensor(ml.DTypeI32, shape) t := c.newTensor(ml.DTypeI32, shape)
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
return t, nil return t, nil
} }

View File

@ -10,10 +10,11 @@ import (
) )
type TextSelfAttention struct { type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"` Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"` Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"` Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"` Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
} }
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
@ -22,11 +23,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
query := sa.Query.Forward(ctx, hiddenState) query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState) key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -39,8 +40,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the causal cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}
return key, nil
} }
type TextMLP struct { type TextMLP struct {
@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
} }
type TextModelOptions struct { type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32 eps, ropeBase, ropeScale float32
ropeDim uint32 ropeDim uint32