saved state

parent c56a8b7749
commit b973dedb4b

go.mod
@@ -24,8 +24,8 @@ require (
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	golang.org/x/image v0.22.0
-	gonum.org/v1/gonum v0.15.0
 	golang.org/x/tools v0.28.0
+	gonum.org/v1/gonum v0.15.0
 )

 require (
@@ -72,7 +72,7 @@ require (
 	golang.org/x/arch v0.8.0 // indirect
 	golang.org/x/crypto v0.31.0
 	golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
-	golang.org/x/net v0.25.0 // indirect
+	golang.org/x/net v0.32.0 // indirect
 	golang.org/x/sys v0.28.0
 	golang.org/x/term v0.27.0
 	golang.org/x/text v0.21.0
@@ -10,6 +10,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"time"

 	"github.com/ollama/ollama/cache"
 	"github.com/ollama/ollama/ml"
@@ -27,6 +28,7 @@ var args struct {
 }

 func temp() error {
+	start := time.Now()
 	flag.IntVar(&args.n, "n", 10, "number of samples")
 	flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
 	flag.StringVar(&args.image, "image", "", "path to image file")
@@ -104,9 +106,11 @@ func temp() error {
 		}
 	}

-	pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
-	var stringBuffer string
+	pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
 	var offset int
+	var stringBuffer string
+	var firstTokenTime time.Duration
 	for range args.n {
 		logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
 		if err != nil {
@@ -118,15 +122,21 @@ func temp() error {
 		for i, f32 := range f32s {
 			f64s[i] = float64(f32)
 		}
+		sampleTime := time.Now()
+		samplers := []sample.Sampler{
+			pushdownSampler,
+			// sample.Weighed(),
+			// sample.TopP(0.9),
+			// sample.Weighed(),
+			sample.Greedy(),
+		}

-		// do sampling
-		// []ints back
-		// ints map to sampled logits
-		f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy())
+		f64s, err = sample.Sample(f64s, samplers...)

 		if err != nil {
 			return err
 		}
+		finishTime := time.Now()
+		fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())

 		var outputIDs []int32
 		for _, f64 := range f64s {
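Worth noting in the hunk above: sample.Sample now takes a slice of samplers and applies them in order, so the pushdown sampler first narrows the logits to tokens that keep the output valid for the current JSON state, and sample.Greedy() then picks from whatever remains. A minimal sketch of that chaining pattern, assuming a Sampler interface with a single Sample([]float64) ([]float64, error) method (illustrative, not the repository's exact API):

package main

import (
	"fmt"
	"math"
)

// Sampler transforms or reduces a slice of logits (assumed shape of the interface).
type Sampler interface {
	Sample(logits []float64) ([]float64, error)
}

// Greedy reduces the logits to the index of the largest value.
type Greedy struct{}

func (Greedy) Sample(logits []float64) ([]float64, error) {
	best, bestVal := 0, math.Inf(-1)
	for i, l := range logits {
		if l > bestVal {
			best, bestVal = i, l
		}
	}
	return []float64{float64(best)}, nil
}

// Sample applies each sampler in order, feeding the output of one into the next.
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
	var err error
	for _, s := range samplers {
		logits, err = s.Sample(logits)
		if err != nil {
			return nil, err
		}
	}
	return logits, nil
}

func main() {
	out, _ := Sample([]float64{0.1, 2.3, -1.0}, Greedy{})
	fmt.Println(out) // [1] — index of the largest logit
}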
@@ -134,7 +144,6 @@ func temp() error {
 				outputIDs = append(outputIDs, int32(f64))
 			}
 		}
-		pdaSampler.UpdateState(outputIDs)

 		if len(outputIDs) == 0 {
 			break
@@ -147,14 +156,29 @@ func temp() error {
 			return err
 		}

-		// fmt.Print(s)
+		if firstTokenTime == 0 {
+			firstTokenTime = time.Since(start)
+			fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds())
+		}
+
+		// fmt.Printf("--- token: %q\n", s)
+		// fmt.Printf("--- outputIDs: %v\n", outputIDs)
 		stringBuffer += s
 		fmt.Println("--- stringBuffer", stringBuffer)

+		err = pushdownSampler.UpdateState(outputIDs)
+		if err != nil {
+			return err
+		}
+
 		inputIDs = append(inputIDs, outputIDs...)
 		if args.cache {
 			offset = len(inputIDs) - 1
 		}
 	}
+	fmt.Println("\n------ Output: ------")
+	fmt.Println(stringBuffer)
+	fmt.Println("--------------------")
+
 	return nil
 }
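The instrumentation added here is plain wall-clock timing: start is captured at the top of temp(), and firstTokenTime is recorded the first time a token is decoded. The same pattern in isolation (names and the sleep are illustrative stand-ins, not the repository's code):

package main

import (
	"fmt"
	"time"
)

func main() {
	start := time.Now()
	var firstTokenTime time.Duration

	for i := 0; i < 3; i++ {
		time.Sleep(50 * time.Millisecond) // stand-in for a forward pass plus sampling

		if firstTokenTime == 0 {
			firstTokenTime = time.Since(start)
			fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds())
		}
	}
	fmt.Printf("Total: %vms\n", time.Since(start).Milliseconds())
}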
@@ -21,6 +21,7 @@ type TextProcessor interface {
 	Encode(string) ([]int32, error)
 	Decode([]int32) (string, error)
 	Is(uint32, Special) bool
+
 	GetVocabulary() *Vocabulary
 }

@@ -99,16 +100,16 @@ func (v *Vocabulary) Merge(left, right string) int {
 	return -1
 }

+func (v *Vocabulary) GetVocabulary() *Vocabulary {
+	return v
+}
+
 type BytePairEncoding struct {
 	Pretokenizer string

 	*Vocabulary
 }

-func (bpe BytePairEncoding) GetVocabulary() *Vocabulary {
-	return bpe.Vocabulary
-}
-
 func (bpe BytePairEncoding) split(s string) ([]string, error) {
 	re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
 	if err != nil {
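The GetVocabulary accessor moves from BytePairEncoding onto *Vocabulary itself. Because BytePairEncoding embeds *Vocabulary, the method is still promoted, so the type keeps satisfying a TextProcessor-style interface. A minimal sketch of that promotion (types trimmed down, not the real ones):

package main

import "fmt"

type Vocabulary struct{ Values []string }

// The method is defined on *Vocabulary...
func (v *Vocabulary) GetVocabulary() *Vocabulary { return v }

// ...and BytePairEncoding embeds *Vocabulary, so GetVocabulary is promoted onto it.
type BytePairEncoding struct {
	Pretokenizer string
	*Vocabulary
}

type TextProcessor interface {
	GetVocabulary() *Vocabulary
}

func main() {
	var p TextProcessor = BytePairEncoding{Vocabulary: &Vocabulary{Values: []string{"a"}}}
	fmt.Println(len(p.GetVocabulary().Values)) // 1
}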
@@ -44,8 +44,6 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
 	// consider adding a node to just point to values, could be good to compute that
 	// mask rather than many different nodes

-	// Connect nodes
-	// TODO: if all are single tokens then this can just be connected instead of defining the token
 	stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
 	stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]

@@ -161,6 +159,7 @@ func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
 	node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
 }

+// TODO: tough life fr. plz fix.
 func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {

 	vocab := proc.GetVocabulary()
@@ -176,33 +175,42 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex

 	var err error
 	for _, node := range stateToNodeMap {
-		for i := range vocab.Values {
-			token := decodedToks[i]
-			// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
-			if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
-				continue
-			}
-			valid := true
-			curNode := node
-			consumedSpecialRunes := make(map[rune]bool)
-			for _, r := range token {
-				valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
-				if err != nil {
-					return err
-				}
-				if !valid {
-					break
-				}
-			}
-			if valid {
-				node.MaskTokenIDToNode[int32(i)] = curNode.State
-			}
+		err = createMask(node, proc, decodedToks, vocab)
+		if err != nil {
+			return err
 		}
 	}
 	return nil
 }

-// garbage interface plz fix
+func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
+	for i := range vocab.Values {
+		token := decodedToks[i]
+		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
+		if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
+			continue
+		}
+		valid := true
+		curNode := node
+		consumedSpecialRunes := make(map[rune]bool)
+		var err error
+		for _, r := range token {
+			valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
+			if err != nil {
+				return err
+			}
+			if !valid {
+				break
+			}
+		}
+		if valid {
+			node.MaskTokenIDToNode[int32(i)] = curNode.State
+		}
+	}
+	return nil
+}
+
+// TODO: garbage interface plz fix
 func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
 	if consumedSpecialRunes[r] {
 		return false, nil, nil
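PreComputeValidStates now delegates the inner loop to createMask: for each automaton node, every vocabulary token is decoded and its runes are walked along the transition edges, and tokens that survive the whole walk are recorded in MaskTokenIDToNode together with the state they end in. A simplified, self-contained sketch of that idea (stand-in types, not the real PDANode/JSONState):

package main

import "fmt"

type State int

type Node struct {
	State       State
	Transitions map[rune]*Node // edges of the automaton
	Mask        map[int]State  // token ID -> state reached after consuming the token
}

// buildMask records, for one node, every token whose runes can all be consumed
// by following transition edges, plus the state the walk ends in.
func buildMask(node *Node, tokens []string) {
	for id, tok := range tokens {
		if tok == "" {
			continue // skip empty tokens; they carry no transition
		}
		cur, valid := node, true
		for _, r := range tok {
			next, ok := cur.Transitions[r]
			if !ok {
				valid = false
				break
			}
			cur = next
		}
		if valid {
			node.Mask[id] = cur.State
		}
	}
}

func main() {
	inString := &Node{State: 1, Transitions: map[rune]*Node{}, Mask: map[int]State{}}
	inString.Transitions['a'] = inString // stay in-string on letters
	inString.Transitions['"'] = &Node{State: 2, Transitions: map[rune]*Node{}}

	buildMask(inString, []string{"aa", `a"`, "{"})
	fmt.Println(inString.Mask) // map[0:1 1:2] — "{" is rejected from this state
}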
@@ -52,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
 	}
 }

+// TODO: need to add resampling logic if the first sample was not good
 func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
 	// fmt.Println(">>> sample:", s.curNode.State)
 	switch s.curNode.State {
@@ -156,8 +157,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 			// fmt.Println("pushing [ brace stack", r)
 		}
 		if r == rune('}') {
+			if len(s.braceStack) == 0 {
+				return fmt.Errorf("stack is empty, extra closing brace %c", r)
+			}
 			top := s.braceStack[len(s.braceStack)-1]
-			if len(s.braceStack) == 0 || top != rune('{') {
+			if top != rune('{') {
 				return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
 			}
 			s.braceStack = s.braceStack[:len(s.braceStack)-1]
@@ -165,8 +169,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 		}

 		if r == rune(']') {
+			if len(s.braceStack) == 0 {
+				return fmt.Errorf("stack is empty, extra closing brace %c", r)
+			}
 			top := s.braceStack[len(s.braceStack)-1]
-			if len(s.braceStack) == 0 || top != rune('[') {
+			if top != rune('[') {
 				return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
 			}
 			s.braceStack = s.braceStack[:len(s.braceStack)-1]
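Both closing-bracket branches now check for an empty stack before reading the top element; previously the indexing at braceStack[len(braceStack)-1] could panic on a stray '}' or ']', because the emptiness check came after the read. The underlying technique is ordinary stack-based bracket matching; a standalone sketch:

package main

import (
	"errors"
	"fmt"
)

// checkBrackets validates that '{'/'}' and '['/']' nest correctly.
func checkBrackets(s string) error {
	var stack []rune
	pairs := map[rune]rune{'}': '{', ']': '['}

	for _, r := range s {
		switch r {
		case '{', '[':
			stack = append(stack, r)
		case '}', ']':
			if len(stack) == 0 {
				return fmt.Errorf("extra closing bracket %c", r)
			}
			top := stack[len(stack)-1]
			if top != pairs[r] {
				return fmt.Errorf("unmatched closing bracket %c, open was %c", r, top)
			}
			stack = stack[:len(stack)-1]
		}
	}
	if len(stack) != 0 {
		return errors.New("unclosed brackets remain")
	}
	return nil
}

func main() {
	fmt.Println(checkBrackets(`{"a":[1,2]}`)) // <nil>
	fmt.Println(checkBrackets(`}`))           // extra closing bracket }
}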
@@ -194,6 +201,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
 }

 func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
+	// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
+	// Should be possible through bitwise ops as well
 	for i := range logits {
 		_, exists := node.MaskTokenIDToNode[int32(i)]
 		if !exists {
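maskLogits keeps only logits whose token IDs appear in the node's precomputed MaskTokenIDToNode map (the new TODO suggests iterating just the allowed IDs, or a bitset, instead of scanning every index). One common way to express such a mask, shown here as an assumption rather than the package's actual implementation, is to push disallowed logits to negative infinity so they receive zero probability after softmax:

package main

import (
	"fmt"
	"math"
)

// maskLogits suppresses (via -Inf) every logit whose index is not in allowed.
// After a softmax, those tokens end up with probability 0.
func maskLogits(logits []float64, allowed map[int32]bool) []float64 {
	out := make([]float64, len(logits))
	for i, l := range logits {
		if allowed[int32(i)] {
			out[i] = l
		} else {
			out[i] = math.Inf(-1)
		}
	}
	return out
}

func main() {
	logits := []float64{1.5, 0.2, -0.7, 3.1}
	allowed := map[int32]bool{0: true, 3: true}
	fmt.Println(maskLogits(logits, allowed)) // [1.5 -Inf -Inf 3.1]
}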
|
@ -165,11 +165,12 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
|||||||
if len(logitsCopy) == 0 {
|
if len(logitsCopy) == 0 {
|
||||||
return nil, errors.New("no valid tokens found")
|
return nil, errors.New("no valid tokens found")
|
||||||
}
|
}
|
||||||
logitsCopy, err := computeSoftmax(logitsCopy)
|
|
||||||
|
softmax, err := computeSoftmax(logitsCopy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
w := sampleuv.NewWeighted(logitsCopy, nil)
|
w := sampleuv.NewWeighted(softmax, nil)
|
||||||
if v, ok := w.Take(); ok {
|
if v, ok := w.Take(); ok {
|
||||||
// returns the token ID
|
// returns the token ID
|
||||||
return []float64{float64(indices[v])}, nil
|
return []float64{float64(indices[v])}, nil
|
||||||
|
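The weighted-sampler fix stops reusing logitsCopy to hold its own softmax: the probabilities now land in a separate softmax slice before being handed to gonum's sampleuv.NewWeighted. For reference, softmax plus cumulative-probability sampling can be written by hand as below (a sketch; computeSoftmax here is not the package's function):

package main

import (
	"fmt"
	"math"
	"math/rand"
)

// softmax converts logits into probabilities that sum to 1.
func softmax(logits []float64) []float64 {
	maxL := math.Inf(-1)
	for _, l := range logits {
		maxL = math.Max(maxL, l)
	}
	var sum float64
	probs := make([]float64, len(logits))
	for i, l := range logits {
		probs[i] = math.Exp(l - maxL) // subtract max for numerical stability
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs
}

// sampleWeighted draws an index with probability proportional to probs.
func sampleWeighted(probs []float64, rng *rand.Rand) int {
	r := rng.Float64()
	var cum float64
	for i, p := range probs {
		cum += p
		if r < cum {
			return i
		}
	}
	return len(probs) - 1 // guard against floating-point rounding
}

func main() {
	rng := rand.New(rand.NewSource(1))
	probs := softmax([]float64{2.0, 0.5, -1.0})
	fmt.Println(probs, sampleWeighted(probs, rng))
}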
@@ -3,10 +3,52 @@ package sample
 import "github.com/ollama/ollama/model"

 type StructuredOutput struct {
 	schema *Schema
+	stateToNodeMap map[JSONState]*PDANode
 }

-func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
+func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput {
+	_, stateToNodeMap, err := BuildGraph(proc)
+	if err != nil {
+		panic(err)
+	}
+
+	return &StructuredOutput{
+		schema: schema,
+		stateToNodeMap: stateToNodeMap,
+	}
+}
+
+func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode {
+
+	schemaType := so.schema.EffectiveType()
+	switch schemaType {
+	case "object":
+		// each prop is a key
+		// prevState := StateInObjectKey
+		for _, prop := range so.schema.Properties {
+			// name of key
+			name := prop.Name
+			prevState := StateInObjectKey
+			for i, r := range name {
+				newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune
+
+				// Create new node for this state if it doesn't exist
+				if _, exists := so.stateToNodeMap[newState]; !exists {
+					so.stateToNodeMap[newState] = &PDANode{
+						State: newState,
+						TransitionEdges: make(map[rune]*PDANode),
+						MaskTokenIDToNode: make(map[int32]JSONState),
+					}
+				}
+
+				// Connect previous state to this state via the rune
+				so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState]
+				prevState = newState
+			}
+			// type of value
+			// propType := prop.Type
+		}
+	}
 	return nil
 }
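schemaToGraph (still returning nil here) sketches how schema property names get baked into the automaton: each rune of a key becomes its own state, so decoding is forced to spell the key exactly. The chaining step in isolation, with a stand-in Node type and purely illustrative state numbering:

package main

import "fmt"

type Node struct {
	ID    int
	Edges map[rune]*Node
}

// chainKey appends one node per rune of key, starting at start,
// so only the exact key can be spelled from that state.
func chainKey(start *Node, key string, nextID int) *Node {
	cur := start
	for _, r := range key {
		n := &Node{ID: nextID, Edges: map[rune]*Node{}}
		nextID++
		cur.Edges[r] = n
		cur = n
	}
	return cur // state reached after the whole key has been emitted
}

func main() {
	start := &Node{ID: 0, Edges: map[rune]*Node{}}
	end := chainKey(start, "name", 1)
	fmt.Println("final state:", end.ID) // final state: 4

	// Walk the chain: from the start state only 'n', then 'a', 'm', 'e' are valid.
	cur := start
	for _, r := range "name" {
		cur = cur.Edges[r]
	}
	fmt.Println("walked to:", cur.ID) // walked to: 4
}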