draft: mlx

commit 6f09d63862 (parent 67bcb55941)
CMakeLists.txt

@@ -2,33 +2,9 @@ cmake_minimum_required(VERSION 3.21)
 
 project(Ollama C CXX)
 
-include(CheckLanguage)
-
-find_package(Threads REQUIRED)
-
 set(CMAKE_BUILD_TYPE Release)
 set(BUILD_SHARED_LIBS ON)
 
-set(CMAKE_CXX_STANDARD 11)
-set(CMAKE_CXX_STANDARD_REQUIRED ON)
-set(CMAKE_CXX_EXTENSIONS OFF)
-
-set(GGML_CCACHE ON)
-set(GGML_SCHED_MAX_COPIES 4)
-set(GGML_CPU_ALL_VARIANTS ON)
-set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
-set(GGML_LLAMAFILE ON)
-
-add_compile_definitions(GGML_BUILD)
-add_compile_definitions(GGML_SHARED)
-add_compile_definitions(GGML_BACKEND_DL)
-add_compile_definitions(GGML_BACKEND_SHARED)
-
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
-
 function(set_target_output_directory _target)
   if(TARGET ${_target})
     set_target_properties(${_target} PROPERTIES
@@ -39,30 +15,8 @@ function(set_target_output_directory _target)
   endif()
 endfunction()
 
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src EXCLUDE_FROM_ALL)
-set_target_output_directory(ggml-base)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml)
 
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
-set_target_output_directory(ggml-cpu)
-
-find_package(BLAS)
-if(NOT BLAS_VENDOR)
-  set(GGML_BLAS_VENDOR "Generic")
-else()
-  set(GGML_BLAS_VENDOR ${BLAS_VENDOR})
-endif()
-
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-blas)
-set_target_output_directory(ggml-blas)
-
-check_language(CUDA)
-if(CMAKE_CUDA_COMPILER)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
-  set_target_output_directory(ggml-cuda)
-endif()
-
-check_language(HIP)
-if(CMAKE_HIP_COMPILER)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
-  set_target_output_directory(ggml-hip)
-endif()
+if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/mlx)
+endif()
@@ -35,7 +35,7 @@ func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
 }
 
 func NewBackend(f *os.File) (Backend, error) {
-	if backend, ok := backends["ggml"]; ok {
+	if backend, ok := backends["mlx"]; ok {
 		return backend(f)
 	}
 
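NewBackend resolves a constructor from a name-keyed registry that each backend package populates at init time; the new mlx package further down in this diff registers itself with

	func init() {
		ml.RegisterBackend("mlx", New)
	}

so the blank import added in the package backend hunk below is all it takes for the lookup to succeed. Note that the lookup here is hard-coded to prefer "mlx" over "ggml" rather than selected at runtime.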
@@ -67,8 +67,8 @@ type Tensor interface {
 	Mulmat(ctx Context, t2 Tensor) Tensor
 
 	Softmax(ctx Context) Tensor
-	Norm(ctx Context, eps float32) Tensor
-	RMSNorm(ctx Context, eps float32) Tensor
+	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
+	RMSNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor
 
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
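The signature change pushes weight and bias into the normalization ops themselves, which lets a backend dispatch to a fused kernel. A minimal sketch of the call-site change, compiling against the new interface (names are illustrative; the old form is shown in the comment):

	package example

	import "github.com/ollama/ollama/ml"

	// Old API: callers normalized, then applied weight and bias by hand:
	//	t = t.RMSNorm(ctx, eps).Mul(ctx, weight) // plus an optional Add for bias
	// New API: the op takes weight and bias directly.
	func applyRMSNorm(ctx ml.Context, t, weight, bias ml.Tensor, eps float32) ml.Tensor {
		return t.RMSNorm(ctx, weight, bias, eps)
	}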
@@ -2,4 +2,5 @@ package backend
 
 import (
 	_ "github.com/ollama/ollama/ml/backend/ggml"
+	_ "github.com/ollama/ollama/ml/backend/mlx"
 )
ml/backend/ggml/CMakeLists.txt (new file, 51 lines)
@@ -0,0 +1,51 @@
include(CheckLanguage)

find_package(Threads REQUIRED)

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

set(GGML_CCACHE ON)
set(GGML_SCHED_MAX_COPIES 4)
set(GGML_CPU_ALL_VARIANTS ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_LLAMAFILE ON)

add_compile_definitions(GGML_BUILD)
add_compile_definitions(GGML_SHARED)
add_compile_definitions(GGML_BACKEND_DL)
add_compile_definitions(GGML_BACKEND_SHARED)

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-cpu/amx)

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src EXCLUDE_FROM_ALL)
set_target_output_directory(ggml-base)

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-cpu)
set_target_output_directory(ggml-cpu)

find_package(BLAS)
if(NOT BLAS_VENDOR)
  set(GGML_BLAS_VENDOR "Generic")
else()
  set(GGML_BLAS_VENDOR ${BLAS_VENDOR})
endif()

add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-blas)
set_target_output_directory(ggml-blas)

check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-cuda)
  set_target_output_directory(ggml-cuda)
endif()

check_language(HIP)
if(CMAKE_HIP_COMPILER)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ggml/src/ggml-hip)
  set_target_output_directory(ggml-hip)
endif()
@@ -423,16 +423,28 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }
 
-func (t *Tensor) Norm(ctx ml.Context, eps float32) ml.Tensor {
-	return &Tensor{
-		t: C.ggml_norm(ctx.(*Context).ctx, t.t, (C.float)(eps)),
-	}
+func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
+	r := (&Tensor{
+		t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+	}).Mul(ctx, w)
+
+	if b != nil {
+		r = r.Add(ctx, b)
+	}
+
+	return r
 }
 
-func (t *Tensor) RMSNorm(ctx ml.Context, eps float32) ml.Tensor {
-	return &Tensor{
+func (t *Tensor) RMSNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
+	r := (&Tensor{
 		t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
-	}
+	}).Mul(ctx, w)
+
+	if b != nil {
+		r = r.Add(ctx, b)
+	}
+
+	return r
 }
 
 func (t *Tensor) Pad(ctx ml.Context, shape ...int64) ml.Tensor {
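For reference, the fused helpers above (and their mlx counterparts below) compute the standard definitions, with w and b as in the new signatures and b optional:

	\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} \cdot w + b,
	\qquad
	\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\mathrm{mean}(x^2) + \varepsilon}} \cdot w \;(+\, b)

where μ and σ² are the mean and variance over the normalized dimension; ggml_norm and ggml_rms_norm produce the normalized fraction, and the Go wrappers apply w and b.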
ml/backend/mlx/CMakeLists.txt (new file, 26 lines)
@@ -0,0 +1,26 @@
include(FetchContent)

set(MLX_C_BUILD_EXAMPLES OFF)

set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS OFF)

execute_process(
  COMMAND
    zsh "-c"
    "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
  OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)

if(NOT MLX_METAL_VERSION)
  message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
  set(MLX_BUILD_METAL OFF)
endif()

FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG v0.1.0)
FetchContent_MakeAvailable(mlx-c)

set_target_output_directory(mlx)
set_target_output_directory(mlxc)
ml/backend/mlx/mlx.go (new file, 364 lines)
@@ -0,0 +1,364 @@
package mlx

// #cgo CPPFLAGS: -I${SRCDIR}/../../../build/_deps/mlx-c-src
// #cgo LDFLAGS: -L${SRCDIR}/../../../build/lib -lmlxc -lmlx
// #cgo LDFLAGS: -framework Accelerate
// #cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../build/lib
// #include <stdlib.h>
// #include "mlx/c/array.h"
// #include "mlx/c/fast.h"
// #include "mlx/c/ops.h"
// #include "mlx/c/stream.h"
import "C"

import (
	"bytes"
	"fmt"
	"io"
	"log/slog"
	"os"
	"sync"
	"unsafe"

	fs "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
	"golang.org/x/sync/errgroup"
)

func init() {
	ml.RegisterBackend("mlx", New)
}

func New(r *os.File) (ml.Backend, error) {
	meta, n, err := fs.Decode(r, -1)
	if err != nil {
		return nil, err
	}

	tensors := make(map[string]*Array, len(meta.Tensors().Items()))
	sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))

	stream := C.mlx_default_cpu_stream_new()

	// Read tensor data concurrently; mlx array construction is serialized by mu.
	var g errgroup.Group
	var mu sync.Mutex
	for _, t := range meta.Tensors().Items() {
		g.Go(func() error {
			var b bytes.Buffer
			n, err := io.Copy(&b, io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())))
			if err != nil {
				return err
			}

			if n != int64(t.Size()) {
				return fmt.Errorf("expected %d bytes, got %d", t.Size(), n)
			}

			cbytes := C.CBytes(b.Bytes())
			defer C.free(cbytes)

			shape := make([]C.int, len(t.Shape))
			for i, dim := range t.Shape {
				shape[i] = C.int(dim)
			}

			var dtype C.mlx_dtype
			switch t.Kind {
			case 0:
				dtype = C.MLX_FLOAT32
			case 1:
				dtype = C.MLX_FLOAT16
			default:
				return fmt.Errorf("unsupported dtype %d", t.Kind)
			}

			mu.Lock()
			defer mu.Unlock()

			var a C.mlx_array
			C.mlx_transpose_all(
				&a,
				C.mlx_array_new_data(
					cbytes,
					(*C.int)(&shape[0]),
					C.int(len(shape)),
					dtype,
				),
				stream,
			)

			tensors[t.Name] = &Array{
				name: t.Name,
				a:    a,
			}
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return nil, err
	}

	return &Backend{
		meta:    meta,
		tensors: tensors,
	}, nil
}

type Backend struct {
	meta    *fs.GGML
	tensors map[string]*Array
}

// Config implements ml.Backend.
func (b *Backend) Config() ml.Config {
	return b.meta.KV()
}

// Get implements ml.Backend.
func (b *Backend) Get(name string) ml.Tensor {
	if a, ok := b.tensors[name]; ok {
		return a
	}

	return nil
}

func (b *Backend) NewContext() ml.Context {
	return &Context{
		stream: C.mlx_default_cpu_stream_new(),
	}
}

type Context struct {
	stream C.mlx_stream
}

// Close implements ml.Context.
func (c *Context) Close() error {
	panic("unimplemented")
}

// Compute implements ml.Context.
func (c *Context) Compute(ml.Tensor) ml.Tensor {
	panic("unimplemented")
}

// Forward implements ml.Context.
func (c *Context) Forward(ml.Tensor) {
	panic("unimplemented")
}

// FromFloatSlice implements ml.Context.
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
	panic("unimplemented")
}

// FromIntSlice implements ml.Context.
func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
	cshape := make([]C.int, len(shape))
	for i, dim := range shape {
		cshape[i] = C.int(dim)
	}

	return &Array{
		a: C.mlx_array_new_data(
			unsafe.Pointer(&s[0]),
			(*C.int)(&cshape[0]),
			C.int(len(cshape)),
			C.MLX_INT32,
		),
	}, nil
}

// Zeros implements ml.Context.
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
	panic("unimplemented")
}

type Array struct {
	name string
	a    C.mlx_array
}

func (a *Array) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", a.name),
		slog.Any("shape", a.Shape()),
	)
}

// Add implements ml.Tensor.
func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	panic("unimplemented")
}

// Bytes implements ml.Tensor.
func (a *Array) Bytes() []byte {
	panic("unimplemented")
}

// Concat implements ml.Tensor.
func (a *Array) Concat(ctx ml.Context, a2 ml.Tensor, dim int) ml.Tensor {
	panic("unimplemented")
}

// Contiguous implements ml.Tensor.
func (a *Array) Contiguous(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}

// Conv2D implements ml.Tensor.
func (a *Array) Conv2D(ctx ml.Context, weight ml.Tensor, s0 int, s1 int, p0 int, p1 int, d0 int, d1 int) ml.Tensor {
	panic("unimplemented")
}

// Copy implements ml.Tensor.
func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	panic("unimplemented")
}

// DType implements ml.Tensor.
func (a *Array) DType() ml.DType {
	panic("unimplemented")
}

// Dim implements ml.Tensor.
func (a *Array) Dim(n int) int64 {
	return int64(C.mlx_array_dim(a.a, C.int(n)))
}

// Floats implements ml.Tensor.
func (a *Array) Floats() []float32 {
	panic("unimplemented")
}

// GELU implements ml.Tensor.
func (a *Array) GELU(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}

// Mul implements ml.Tensor.
func (a *Array) Mul(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	panic("unimplemented")
}

// Mulmat implements ml.Tensor.
func (a *Array) Mulmat(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	slog.Info("mulmat", "a", a, "a2", a2)
	var r C.mlx_array
	C.mlx_matmul(&r, a2.(*Array).a, a.Permute(1, 0, 2, 3), ctx.(*Context).stream)
	return &Array{a: r}
}

// LayerNorm implements ml.Tensor.
func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
	var r C.mlx_array
	C.mlx_fast_layer_norm(
		&r,
		a.a,
		w.(*Array).a,
		b.(*Array).a,
		C.float(eps),
		ctx.(*Context).stream,
	)
	return &Array{a: r}
}

// Pad implements ml.Tensor.
func (a *Array) Pad(ctx ml.Context, shape ...int64) ml.Tensor {
	panic("unimplemented")
}

// Permute implements ml.Tensor.
func (a *Array) Permute(ctx ml.Context, shape ...int) ml.Tensor {
	panic("unimplemented")
}

// RMSNorm implements ml.Tensor.
func (a *Array) RMSNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
	var r C.mlx_array
	C.mlx_fast_rms_norm(
		&r,
		a.a,
		w.(*Array).a,
		C.float(eps),
		ctx.(*Context).stream,
	)
	return &Array{a: r}
}

// Reshape implements ml.Tensor.
func (a *Array) Reshape(ctx ml.Context, shape ...int64) ml.Tensor {
	cshape := make([]C.int, len(shape))
	for i, dim := range shape {
		cshape[i] = C.int(dim)
	}

	var r C.mlx_array
	C.mlx_reshape(&r, a.a, (*C.int)(&cshape[0]), C.size_t(len(cshape)), ctx.(*Context).stream)
	return &Array{a: r}
}

// Rope implements ml.Tensor.
func (a *Array) Rope(ctx ml.Context, positionIDs ml.Tensor, ropeFactors ml.Tensor, dim uint32, base float32, scale float32) ml.Tensor {
	panic("unimplemented")
}

// Rows implements ml.Tensor.
func (a *Array) Rows(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
	var r C.mlx_array
	slog.Info("rows", "a", a, "a2", a2)
	C.mlx_take(&r, a.a, a2.(*Array).a, 0, ctx.(*Context).stream)
	return &Array{a: r}
}

// SILU implements ml.Tensor.
func (a *Array) SILU(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}

// Scale implements ml.Tensor.
func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor {
	panic("unimplemented")
}

// Shape implements ml.Tensor.
func (a *Array) Shape() []int64 {
	shape := make([]int64, C.mlx_array_ndim(a.a))
	for i := range shape {
		shape[i] = int64(C.mlx_array_dim(a.a, C.int(i)))
	}

	return shape
}

// Softmax implements ml.Tensor.
func (a *Array) Softmax(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}

// Stack implements ml.Tensor.
func (a *Array) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
	panic("unimplemented")
}

// Stride implements ml.Tensor.
func (a *Array) Stride(n int) int64 {
	panic("unimplemented")
}

// Tanh implements ml.Tensor.
func (a *Array) Tanh(ctx ml.Context) ml.Tensor {
	panic("unimplemented")
}

// Unpad implements ml.Tensor.
func (a *Array) Unpad(ctx ml.Context, shape ...int64) ml.Tensor {
	panic("unimplemented")
}

// View implements ml.Tensor.
func (a *Array) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
	panic("unimplemented")
}
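Taken together with the registry change above, the new backend can be exercised end to end. A minimal sketch of a harness, assuming a GGUF file at a hypothetical model.gguf path and an illustrative tensor name; only methods the draft actually implements (Get, Shape) are used, since much of the Context/Tensor surface still panics with "unimplemented":

	package main

	import (
		"fmt"
		"log"
		"os"

		"github.com/ollama/ollama/ml"
		_ "github.com/ollama/ollama/ml/backend/mlx" // init() registers "mlx"
	)

	func main() {
		f, err := os.Open("model.gguf") // hypothetical path
		if err != nil {
			log.Fatal(err)
		}
		defer f.Close()

		b, err := ml.NewBackend(f) // dispatches to mlx.New via the registry
		if err != nil {
			log.Fatal(err)
		}

		// Tensors are keyed by their GGUF names; this one is illustrative.
		if t := b.Get("token_embd.weight"); t != nil {
			fmt.Println("shape:", t.Shape())
		}
	}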
@@ -10,12 +10,7 @@ type LayerNorm struct {
 }
 
 func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
-	t = t.Norm(ctx, eps).Mul(ctx, m.Weight)
-	if m.Bias != nil {
-		t = t.Add(ctx, m.Bias)
-	}
-
-	return t
+	return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
 }
 
 type RMSNorm struct {
@@ -24,10 +19,5 @@ type RMSNorm struct {
 }
 
 func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
-	t = t.RMSNorm(ctx, eps).Mul(ctx, m.Weight)
-	if m.Bias != nil {
-		t = t.Add(ctx, m.Bias)
-	}
-
-	return t
+	return t.RMSNorm(ctx, m.Weight, m.Bias, eps)
 }
@@ -1,6 +1,7 @@
 package llama
 
 import (
+	"log/slog"
 	"math"
 
 	"github.com/ollama/ollama/ml"
@@ -52,16 +53,16 @@ type SelfAttention struct {
 }
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
-	batchSize := hiddenState.Dim(1)
+	batchSize := hiddenState.Dim(0)
 	headDim := opts.hiddenSize / opts.numHeads
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.Rope(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	// q = q.Rope(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.Rope(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
+	// k = k.Rope(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -72,6 +73,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
+	slog.Info("self attention", "q", q, "k", k, "v", v)
+
 	kq := k.Mulmat(ctx, q)
 	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
 	kq = kq.Softmax(ctx)
@@ -127,6 +130,8 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 
+	slog.Info("breakpoint", "inputs", inputs, "positions", positions, "hiddenState", hiddenState)
+
 	for i, layer := range m.Layers {
 		hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
 	}
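The slog.Info calls added here bracket a standard scaled dot-product attention: kq := k.Mulmat(ctx, q) is scaled by 1/√headDim and softmaxed, i.e.

	\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}}\right) V,
	\qquad d_{\mathrm{head}} = \frac{\mathrm{hiddenSize}}{\mathrm{numHeads}}

with the RoPE positional rotation temporarily commented out while the MLX path is brought up.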
model/mllama/bench_test.go.1 (new file, 50 lines)
@@ -0,0 +1,50 @@
package mllama

import (
	"errors"
	"os"
	"path/filepath"
	"testing"

	"github.com/ollama/ollama/model"
)

func BenchmarkProcessText(b *testing.B) {
	ours, err := model.New(filepath.Join("testdata", "model.bin"))
	if errors.Is(err, os.ErrNotExist) {
		b.Skip("no model.bin")
	} else if err != nil {
		b.Fatal(err)
	}

	var ids []int32

	b.Run("encode", func(b *testing.B) {
		txt, err := os.ReadFile(filepath.Join("..", "testdata", "war-and-peace.txt"))
		if err != nil {
			b.Fatal(err)
		}

		for i := 0; i < b.N; i++ {
			b.StopTimer()
			b.StartTimer()

			ids, err = ours.(model.TextProcessor).Encode(string(txt))
			if err != nil {
				b.Fatal(err)
			}
		}
	})

	b.Run("decode", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			b.StopTimer()
			b.StartTimer()

			_, err = ours.(model.TextProcessor).Decode(ids)
			if err != nil {
				b.Fatal(err)
			}
		}
	})
}
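Assuming testdata/model.bin and the vendored war-and-peace.txt are present, the benchmark runs with go test -bench=ProcessText -benchmem ./model/mllama. The back-to-back StopTimer/StartTimer calls are currently no-ops; they leave a slot for per-iteration setup outside the measured region. The .go.1 suffix keeps the file out of the build for now.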
model/testdata/war-and-peace.txt (vendored, new file, 63845 lines)
File diff suppressed because it is too large.