additional review comments

commit 98272fbd58 (parent b27e8f3f10)

@@ -402,7 +402,10 @@ func (b *Backend) NewContext() ml.Context {
 }
 
 func (b *Backend) NewContextSize(n int) ml.Context {
-	n = min(n, b.maxGraphNodes)
+	if n > b.maxGraphNodes {
+		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
+	}
+
 	return &Context{
 		b:             b,
 		maxGraphNodes: n,
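For context, a minimal sketch of the behavior change in the hunk above (not the ollama implementation; the backend type here is reduced to the one field the hunk touches): an oversized graph-node request now fails loudly instead of being silently clamped to the backend maximum.

package main

import "fmt"

// backend is an illustrative stand-in for the real Backend; only
// maxGraphNodes matters for this sketch.
type backend struct{ maxGraphNodes int }

// newContextSize mirrors the guard added above: the old code clamped with
// n = min(n, maxGraphNodes); the new code treats an oversized request as a
// programming error and panics.
func (b *backend) newContextSize(n int) int {
	if n > b.maxGraphNodes {
		panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
	}
	return n
}

func main() {
	b := &backend{maxGraphNodes: 8192}
	fmt.Println(b.newContextSize(1024)) // 1024: within budget, returned unchanged
	// b.newContextSize(100000)         // would panic instead of clamping to 8192
}
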
@@ -534,7 +537,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 		panic("unsupported dtype")
 	}
 
-	if len(shape) < 1 {
+	if len(shape) < 1 || shape[0] == 0 {
 		var shape C.int64_t = 0
 		return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
 	} else if len(shape) > 4 {
@@ -565,6 +568,11 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 
 func checkShape[S ~[]E, E any](s S, shape ...int) error {
 	n := len(s)
+
+	if n == 0 {
+		return nil
+	}
+
 	for _, v := range shape {
 		n /= v
 	}
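A hedged, self-contained sketch of the effect of the early return added above (elementsMatch is an illustrative stand-in, not the real checkShape, whose error branch lies outside this hunk): an empty slice is now accepted for any requested shape, while non-empty slices must still divide out to exactly one element per position.

package main

import "fmt"

// elementsMatch reproduces checkShape's counting logic with the same
// zero-length short-circuit the hunk above introduces: it reports whether
// len(s) is consistent with the product of the shape dimensions.
func elementsMatch[S ~[]E, E any](s S, shape ...int) bool {
	if len(s) == 0 {
		return true // empty inputs skip the shape check entirely
	}
	n := len(s)
	for _, v := range shape {
		n /= v
	}
	return n == 1
}

func main() {
	fmt.Println(elementsMatch([]float32{}, 2, 3))           // true: empty slice accepted for any shape
	fmt.Println(elementsMatch([]float32{1, 2, 3, 4}, 2, 2)) // true: 4 elements fill a 2x2 shape
	fmt.Println(elementsMatch([]float32{1, 2, 3}, 2, 2))    // false: 3 elements cannot fill 2x2
}
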
@@ -577,22 +585,28 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
 }
 
 func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
-	if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
+	if err := checkShape(s, shape...); err != nil {
 		return nil, err
 	}
 
 	t := c.newTensor(ml.DTypeF32, shape)
-	C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	if len(s) > 0 {
+		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	}
+
 	return t, nil
 }
 
 func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
-	if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
+	if err := checkShape(s, shape...); err != nil {
 		return nil, err
 	}
 
 	t := c.newTensor(ml.DTypeI32, shape)
-	C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	if len(s) > 0 {
+		C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
+	}
+
 	return t, nil
 }
 
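The hunks that follow appear to be from a different file in the same commit, the text-model code. For the two hunks just above, a hedged, self-contained illustration (dataPtr is not ollama code) of why the ggml_backend_tensor_set call is now wrapped in len(s) > 0: taking the address of s[0] on an empty slice panics, so zero-length inputs must skip the copy and return the freshly created empty tensor as-is.

package main

import "fmt"

// dataPtr is an illustrative helper, not ollama code: it returns a pointer to
// the first element of a slice, or nil when the slice is empty. Without the
// guard, &s[0] on an empty slice panics with an index-out-of-range error,
// which is exactly what the len(s) > 0 checks above avoid.
func dataPtr(s []float32) *float32 {
	if len(s) == 0 {
		return nil
	}
	return &s[0]
}

func main() {
	fmt.Println(dataPtr(nil))              // <nil>: no data to copy
	fmt.Println(*dataPtr([]float32{3, 1})) // 3: first element is addressable
}
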
@@ -14,6 +14,7 @@ type TextSelfAttention struct {
 	Key         *nn.Linear `gguf:"attn_k"`
 	Value       *nn.Linear `gguf:"attn_v"`
 	Output      *nn.Linear `gguf:"attn_output"`
+	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
 }
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
@@ -22,11 +23,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -39,8 +40,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	// This will only get called for layers in the causal cache, which are just the self attention layers
-	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
+	}
+
+	return key, nil
 }
 
 type TextMLP struct {
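A hedged, self-contained sketch (illustrative types, not the ollama ones) of the dispatch pattern the new Shift uses above: only layers whose concrete type carries per-layer rope factors are rotated, and every other layer kind returns its key unchanged.

package main

import "fmt"

// layer is an illustrative stand-in for the model's decoder-layer interface.
type layer interface{ kind() string }

// selfAttentionLayer carries per-layer rope factors, playing the role of
// TextSelfAttentionDecoderLayer in the hunk above.
type selfAttentionLayer struct{ ropeFactors []float32 }

func (selfAttentionLayer) kind() string { return "self-attention" }

// crossAttentionLayer has no rope factors and is left alone by the shift.
type crossAttentionLayer struct{}

func (crossAttentionLayer) kind() string { return "cross-attention" }

// shift mirrors the type assertion in TextModel.Shift: rotate only the layers
// that actually have per-layer factors, pass everything else through.
func shift(l layer) string {
	if sa, ok := l.(selfAttentionLayer); ok {
		return fmt.Sprintf("rope shift using %d per-layer factors", len(sa.ropeFactors))
	}
	return "key returned unchanged"
}

func main() {
	fmt.Println(shift(selfAttentionLayer{ropeFactors: make([]float32, 64)}))
	fmt.Println(shift(crossAttentionLayer{}))
}
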
@@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 }
 
 type TextModelOptions struct {
-	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
-
 	hiddenSize, numHeads, numKVHeads int
 	eps, ropeBase, ropeScale         float32
 	ropeDim                          uint32