WIP working with other stuff

This commit is contained in:
ParthSareen 2025-01-29 10:28:22 -08:00
parent e93db4d20e
commit 73098a2973
5 changed files with 377 additions and 111 deletions

171
sample/decode.go Normal file
View File

@ -0,0 +1,171 @@
package sample
import (
"bytes"
"encoding/json"
"errors"
)
// Schema holds a JSON schema.
type Schema struct {
// Name is the name of the property. For the parent/root property, this
// is "root". For child properties, this is the name of the property.
Name string `json:"-"`
// Type is the type of the property.
//
// TODO: Union types (e.g. make this a []string).
Type string
// PrefixItems is a list of schemas for each item in a tuple. By
// default, the tuple is "closed." unless Items is set to true or a
// valid Schema.
PrefixItems []*Schema
// Items is the schema for each item in a list.
//
// If it is missing, or its JSON value is "null" or "false", it is nil.
// If the JSON value is "true", it is set to the empty Schema. If the
// JSON value is an object, it will be decoded as a Schema.
Items *Schema
// MinItems specifies the minimum number of items allowed in a list.
MinItems int
// MaxItems specifies the maximum number of items allowed in a list.
MaxItems int
// Properties is the schema for each property of an object.
Properties []*Schema
// Format is the format of the property. This is used to validate the
// property against a specific format.
//
// It is the callers responsibility to validate the property against
// the format.
Format string
// Minimum specifies the minimum value for numeric properties.
Minimum float64
// Maximum specifies the maximum value for numeric properties.
Maximum float64
// Enum is a list of valid values for the property.
Enum []json.RawMessage
}
func (s *Schema) UnmarshalJSON(data []byte) error {
type S Schema
w := struct {
Properties props
Items items
*S
}{
S: (*S)(s),
}
if err := json.Unmarshal(data, &w); err != nil {
return err
}
if w.Items.set {
s.Items = &w.Items.Schema
}
s.Properties = w.Properties
return nil
}
type items struct {
Schema
set bool
}
func (s *items) UnmarshalJSON(data []byte) error {
switch b := data[0]; b {
case 't':
*s = items{set: true}
case '{':
type I items
if err := json.Unmarshal(data, (*I)(s)); err != nil {
return err
}
s.set = true
case 'n', 'f':
default:
return errors.New("invalid Items")
}
return nil
}
// EffectiveType returns the effective type of the schema. If the Type field is
// not empty, it is returned; otherwise:
//
// - If the schema has both Properties and Items, it returns an empty string.
// - If the schema has Properties, it returns "object".
// - If the schema has Items, it returns "array".
// - If the schema has neither Properties nor Items, it returns "value".
//
// The returned string is never empty.
func (d *Schema) EffectiveType() string {
if d.Type == "" {
if len(d.Properties) > 0 {
return "object"
}
if len(d.PrefixItems) > 0 || d.Items != nil {
return "array"
}
return "value"
}
return d.Type
}
// props is an ordered list of properties. The order of the properties
// is the order in which they were defined in the schema.
type props []*Schema
var _ json.Unmarshaler = (*props)(nil)
func (v *props) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil
}
if data[0] != '{' {
return errors.New("expected object")
}
d := json.NewDecoder(bytes.NewReader(data))
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
// llama.cpp, ignore unknown fields, which could be lead to unexpected
// behavior for clients of this package, since they may not be aware
// that "additionalFields", "itemsPrefix", etc, are being ignored.
//
// For now, just do what llama.cpp does.
t, err := d.Token()
if err != nil {
return err
}
if t != json.Delim('{') {
return errors.New("expected object")
}
for d.More() {
// Use the first token (map key) as the property name, then
// decode the rest of the object fields into a Schema and
// append.
t, err := d.Token()
if err != nil {
return err
}
if t == json.Delim('}') {
return nil
}
s := &Schema{
Name: t.(string),
}
if err := d.Decode(s); err != nil {
return err
}
*v = append(*v, s)
}
return nil
}

View File

@ -30,7 +30,10 @@ const (
StateInList
StateInListComma
StateListEnd
StateInValue
StateInValueEnd
StateInListEnd
StateInListObjectEnd
StateInNewline
StateInNumber
StateInNumberEnd
@ -38,6 +41,7 @@ const (
StateInObjectKeyEnd
StateTerminate
StateInObjectEnd
StateTransitioningToTerminate
)
func (s JSONState) String() string {
@ -76,6 +80,8 @@ func (s JSONState) String() string {
return "StateInObjSpace"
case StateInList:
return "StateInList"
case StateInListObjectEnd:
return "StateInListObjectEnd"
case StateInListComma:
return "StateInListComma"
case StateListEnd:

View File

@ -6,7 +6,9 @@ import (
"github.com/ollama/ollama/model"
)
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
// TODO: / should be valid but an escape character
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
@ -34,6 +36,7 @@ func NewPDANode(state JSONState) *PDANode {
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
stateToNodeMap := make(map[JSONState]*PDANode)
// TODO: make this a loop
startNode := NewPDANode(StateStart)
stateToNodeMap[StateStart] = startNode
@ -95,6 +98,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
intNode := NewPDANode(StateInInt)
stateToNodeMap[StateInInt] = intNode
listObjEndNode := NewPDANode(StateInListObjectEnd)
stateToNodeMap[StateInListObjectEnd] = listObjEndNode
// TODO:
// consider adding a node to just point to values, could be good to compute that
// mask rather than many different nodes
@ -105,108 +111,84 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
objNode.TransitionEdges['"'] = objKeyNode
objNode.TransitionEdges['\n'] = newlineNode
// objNode.TransitionEdges['\t'] = tabNode
objNode.TransitionEdges[' '] = spaceObjNode
//new line
newlineNode.TransitionEdges['"'] = objKeyNode
newlineNode.TransitionEdges['\t'] = tabNode
tabNode.TransitionEdges['"'] = objKeyNode
// tabNode.TransitionEdges['\t'] = tabNode
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
objKeyNode.TransitionEdges['"'] = objKeyEndNode
objKeyEndNode.TransitionEdges[':'] = colonNode
objEndNode.TransitionEdges[' '] = spaceNode
objEndNode.TransitionEdges[','] = commaNode
objEndNode.TransitionEdges['}'] = objEndNode
// where values should be
// this could be combined but the probs might change, we're alr doing a skip ahead
colonNode.TransitionEdges[' '] = spaceNode
colonNode.TransitionEdges['['] = listNode
colonNode.TransitionEdges['{'] = objNode
addValueConnections(colonNode, stateToNodeMap)
// Leads to a value
spaceNode.TransitionEdges['"'] = stringNode
spaceNode.TransitionEdges['['] = listNode
spaceNode.TransitionEdges['{'] = objNode
for _, r := range validNumberRunes {
spaceNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
spaceNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
spaceNode.TransitionEdges[r] = nullNode
}
addValueConnections(spaceNode, stateToNodeMap)
// Values
// string node
stringNode.TransitionEdges[rune(-1)] = stringNode
stringNode.TransitionEdges['"'] = stringEndNode
stringEndNode.TransitionEdges[','] = commaNode
stringEndNode.TransitionEdges['}'] = objEndNode
stringEndNode.TransitionEdges[']'] = listEndNode
// String end node
addEnds(stringEndNode, stateToNodeMap)
// TODO: add counters for allowable number of decimals, e, E, etc
// number node
for _, r := range validNumberRunes {
numberNode.TransitionEdges[r] = numberNode
}
numberNode.TransitionEdges[','] = commaNode
numberNode.TransitionEdges['}'] = objEndNode
numberNode.TransitionEdges[']'] = listEndNode
for _, r := range validBoolRunes {
boolNode.TransitionEdges[r] = boolNode
}
// list node
listNode.TransitionEdges[','] = commaNode
listNode.TransitionEdges['"'] = stringNode
// squash states to a value
for _, r := range validNumberRunes {
listNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
listNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
listNode.TransitionEdges[r] = nullNode
}
// null node
for _, r := range validNullRunes {
nullNode.TransitionEdges[r] = nullNode
}
nullNode.TransitionEdges[','] = commaNode
nullNode.TransitionEdges['}'] = objEndNode
nullNode.TransitionEdges[']'] = listEndNode
// list comma
// should point to values
listCommaNode.TransitionEdges['"'] = stringNode
listCommaNode.TransitionEdges[' '] = listCommaNode
listCommaNode.TransitionEdges['{'] = objNode
listCommaNode.TransitionEdges['\n'] = newlineNode
for _, r := range validNumberRunes {
listCommaNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
listCommaNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
listCommaNode.TransitionEdges[r] = nullNode
}
addEnds(numberNode, stateToNodeMap)
// bool node
for _, r := range validBoolRunes {
boolNode.TransitionEdges[r] = boolNode
}
boolNode.TransitionEdges['}'] = objEndNode
boolNode.TransitionEdges[']'] = listEndNode
boolNode.TransitionEdges[','] = commaNode
addEnds(boolNode, stateToNodeMap)
// list node
listNode.TransitionEdges[','] = commaNode
listNode.TransitionEdges['{'] = objNode
listNode.TransitionEdges[' '] = listNode
listNode.TransitionEdges['\n'] = listNode
addValueConnections(listNode, stateToNodeMap)
// null node
for _, r := range validNullRunes {
nullNode.TransitionEdges[r] = nullNode
}
addEnds(nullNode, stateToNodeMap)
// list comma
// should point to values
listCommaNode.TransitionEdges[' '] = listCommaNode
listCommaNode.TransitionEdges['{'] = objNode
listCommaNode.TransitionEdges['\n'] = newlineNode
addValueConnections(listCommaNode, stateToNodeMap)
// list object end
listObjEndNode.TransitionEdges[','] = listCommaNode
listObjEndNode.TransitionEdges[']'] = listEndNode
// bool node
for _, r := range validBoolRunes {
boolNode.TransitionEdges[r] = boolNode
}
addEnds(boolNode, stateToNodeMap)
listEndNode.TransitionEdges['}'] = objEndNode
listEndNode.TransitionEdges[','] = commaNode
@ -218,10 +200,27 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
commaNode.TransitionEdges[' '] = spaceObjNode
spaceObjNode.TransitionEdges['"'] = objKeyNode
spaceObjNode.TransitionEdges['\n'] = newlineNode
return startNode, stateToNodeMap, nil
}
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = stateToNodeMap[StateListEnd]
}
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
for _, r := range validNumberRunes {
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
}
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
vocab := proc.GetVocabulary()
@ -240,7 +239,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
for i := range vocab.Values {
token := decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
valid := true
@ -263,6 +262,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
return nil
}
// garbage interface plz fix
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
if consumedSpecialRunes[r] {
return false, nil, nil
@ -281,7 +281,6 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
if curNode.State == nextNode.State {
return false, nil, nil
}
// fmt.Println("special rune", r, "consumed")
consumedSpecialRunes[r] = true
}
return true, nextNode, nil

View File

@ -9,6 +9,8 @@ import (
"github.com/ollama/ollama/model"
)
// TODO: safety in case of invalid json
// TODO: interfaces to cleanup with return values
type PushdownSampler struct {
// stateful
curNode *PDANode
@ -18,6 +20,7 @@ type PushdownSampler struct {
stateCounter uint32
}
// graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
start := time.Now()
@ -39,14 +42,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
// token, err := proc.Decode([]int32{int32(id)})
// if err != nil {
// panic(err)
// }
// fmt.Println("id", id, "node", node, "token", token)
// }
// time.Sleep(10 * time.Second)
return &PushdownSampler{
curNode: startNode,
proc: proc,
@ -57,9 +53,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
}
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
fmt.Println("sample:", s.curNode.State)
// fmt.Println(">>> sample:", s.curNode.State)
switch s.curNode.State {
case StateInString:
return s.maskLogits(logits, s.curNode)
case StateInObjectEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
@ -73,24 +71,24 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
}
return logits, nil
}
valid, err := s.proc.Encode("}")
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
// fmt.Println("switching to list object end", s.curNode.State)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
for i := range logits {
for _, token := range valid {
if i != int(token) {
logits[i] = math.NaN()
}
}
}
return logits, nil
case StateInComma:
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
fmt.Println("switching to list comma", s.curNode.State)
// fmt.Println("switching to list comma", s.curNode.State)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
@ -109,7 +107,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
return logits, nil
default:
fmt.Println("masking logits current state", s.curNode.State)
// fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
@ -119,54 +117,48 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
fmt.Println("update state", s.curNode.State)
// TODO: need to handle end states and entering object case, and list case
if s.curNode.State == StateInObjectEnd {
fmt.Println("in object end")
if len(s.braceStack) > 0 {
s.braceStack = s.braceStack[:len(s.braceStack)-1]
return nil
}
s.curNode = NewPDANode(StateTerminate)
// TODO: return here?
}
// need this cause there could be multiple transitions
// fmt.Println("update state", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice)
if err != nil {
return err
}
// TODO: should force closing for all braces
// TODO: should force closing for all braces - not doing square yet
for _, r := range mappedString {
if r == rune('{') {
s.braceStack = append(s.braceStack, r)
// fmt.Println("pushing { brace stack", r)
}
if r == rune('[') {
s.braceStack = append(s.braceStack, r)
// fmt.Println("pushing [ brace stack", r)
}
if r == rune('}') {
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
return fmt.Errorf("unmatched closing brace")
top := s.braceStack[len(s.braceStack)-1]
if len(s.braceStack) == 0 || top != rune('{') {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
fmt.Println("popping brace stack", s.braceStack)
// fmt.Println("popping { brace stack", top)
}
if r == rune(']') {
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
return fmt.Errorf("unmatched closing brace")
top := s.braceStack[len(s.braceStack)-1]
if len(s.braceStack) == 0 || top != rune('[') {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
fmt.Println("popping brace stack", s.braceStack)
// fmt.Println("popping [ brace stack", top)
}
}
for _, tokenID := range tokenSlice {
// transition to the next node
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
fmt.Println("transitioning to", nextNodeState)
// fmt.Println("transitioning to", nextNodeState)
// TODO: add a penalty for staying in the same state too long
if nextNodeState == s.curNode.State {

View File

@ -0,0 +1,98 @@
package sample
import "github.com/ollama/ollama/model"
type StructuredOutput struct {
schema *Schema
}
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
// _, stateToNodeMap, err := BuildGraph(proc)
// if err != nil {
// panic(err)
// }
return nil
}
// func constrainGraph(graph *PDANode, schema *Schema) *PDANode {
// // If no schema constraints, return original graph node
// if schema == nil {
// return graph
// }
// // Create a new node with same state
// constrainedNode := NewPDANode(graph.State)
// // Copy over existing transitions and masks
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
// for r, node := range graph.TransitionEdges {
// constrainedNode.TransitionEdges[r] = node
// }
// constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode
// // Apply schema constraints based on type
// switch schema.EffectiveType() {
// case "object":
// // Only allow defined property names in object keys
// if graph.State == StateInObjectKey {
// // TODO: Add property name validation
// }
// // Constrain property values based on schema
// if graph.State == StateInColon || graph.State == StateInSpace {
// // Clear transitions to only allow valid types
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
// // Add transitions based on property schemas
// for _, prop := range schema.Properties {
// switch prop.EffectiveType() {
// case "object":
// if objNode, ok := graph.TransitionEdges['{']; ok {
// constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop)
// }
// case "array":
// if arrNode, ok := graph.TransitionEdges['[']; ok {
// constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop)
// }
// case "string":
// if strNode, ok := graph.TransitionEdges['"']; ok {
// constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop)
// }
// case "number":
// for _, r := range validNumberRunes {
// if numNode, ok := graph.TransitionEdges[r]; ok {
// constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop)
// }
// }
// case "integer":
// for _, r := range validIntRunes {
// if intNode, ok := graph.TransitionEdges[r]; ok {
// constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop)
// }
// }
// case "boolean":
// for _, r := range []rune{'t', 'f'} {
// if boolNode, ok := graph.TransitionEdges[r]; ok {
// constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop)
// }
// }
// case "null":
// if nullNode, ok := graph.TransitionEdges['n']; ok {
// constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop)
// }
// }
// }
// }
// case "array":
// // Constrain array items based on schema
// if schema.Items != nil {
// for r, node := range graph.TransitionEdges {
// constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items)
// }
// }
// }
// return constrainedNode
// }