diff --git a/sample/decode.go b/sample/decode.go new file mode 100644 index 000000000..528cd0938 --- /dev/null +++ b/sample/decode.go @@ -0,0 +1,171 @@ +package sample + +import ( + "bytes" + "encoding/json" + "errors" +) + +// Schema holds a JSON schema. +type Schema struct { + // Name is the name of the property. For the parent/root property, this + // is "root". For child properties, this is the name of the property. + Name string `json:"-"` + + // Type is the type of the property. + // + // TODO: Union types (e.g. make this a []string). + Type string + + // PrefixItems is a list of schemas for each item in a tuple. By + // default, the tuple is "closed." unless Items is set to true or a + // valid Schema. + PrefixItems []*Schema + + // Items is the schema for each item in a list. + // + // If it is missing, or its JSON value is "null" or "false", it is nil. + // If the JSON value is "true", it is set to the empty Schema. If the + // JSON value is an object, it will be decoded as a Schema. + Items *Schema + + // MinItems specifies the minimum number of items allowed in a list. + MinItems int + + // MaxItems specifies the maximum number of items allowed in a list. + MaxItems int + + // Properties is the schema for each property of an object. + Properties []*Schema + + // Format is the format of the property. This is used to validate the + // property against a specific format. + // + // It is the callers responsibility to validate the property against + // the format. + Format string + + // Minimum specifies the minimum value for numeric properties. + Minimum float64 + + // Maximum specifies the maximum value for numeric properties. + Maximum float64 + + // Enum is a list of valid values for the property. + Enum []json.RawMessage +} + +func (s *Schema) UnmarshalJSON(data []byte) error { + type S Schema + w := struct { + Properties props + Items items + *S + }{ + S: (*S)(s), + } + if err := json.Unmarshal(data, &w); err != nil { + return err + } + if w.Items.set { + s.Items = &w.Items.Schema + } + s.Properties = w.Properties + return nil +} + +type items struct { + Schema + set bool +} + +func (s *items) UnmarshalJSON(data []byte) error { + switch b := data[0]; b { + case 't': + *s = items{set: true} + case '{': + type I items + if err := json.Unmarshal(data, (*I)(s)); err != nil { + return err + } + s.set = true + case 'n', 'f': + default: + return errors.New("invalid Items") + } + return nil +} + +// EffectiveType returns the effective type of the schema. If the Type field is +// not empty, it is returned; otherwise: +// +// - If the schema has both Properties and Items, it returns an empty string. +// - If the schema has Properties, it returns "object". +// - If the schema has Items, it returns "array". +// - If the schema has neither Properties nor Items, it returns "value". +// +// The returned string is never empty. +func (d *Schema) EffectiveType() string { + if d.Type == "" { + if len(d.Properties) > 0 { + return "object" + } + if len(d.PrefixItems) > 0 || d.Items != nil { + return "array" + } + return "value" + } + return d.Type +} + +// props is an ordered list of properties. The order of the properties +// is the order in which they were defined in the schema. +type props []*Schema + +var _ json.Unmarshaler = (*props)(nil) + +func (v *props) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return nil + } + if data[0] != '{' { + return errors.New("expected object") + } + + d := json.NewDecoder(bytes.NewReader(data)) + + // TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like + // llama.cpp, ignore unknown fields, which could be lead to unexpected + // behavior for clients of this package, since they may not be aware + // that "additionalFields", "itemsPrefix", etc, are being ignored. + // + // For now, just do what llama.cpp does. + + t, err := d.Token() + if err != nil { + return err + } + if t != json.Delim('{') { + return errors.New("expected object") + } + for d.More() { + // Use the first token (map key) as the property name, then + // decode the rest of the object fields into a Schema and + // append. + t, err := d.Token() + if err != nil { + return err + } + if t == json.Delim('}') { + return nil + } + s := &Schema{ + Name: t.(string), + } + if err := d.Decode(s); err != nil { + return err + } + *v = append(*v, s) + } + return nil +} diff --git a/sample/fast_json.go b/sample/fast_json.go index 6104b8b3e..ee65d6f6f 100644 --- a/sample/fast_json.go +++ b/sample/fast_json.go @@ -30,7 +30,10 @@ const ( StateInList StateInListComma StateListEnd + StateInValue + StateInValueEnd StateInListEnd + StateInListObjectEnd StateInNewline StateInNumber StateInNumberEnd @@ -38,6 +41,7 @@ const ( StateInObjectKeyEnd StateTerminate StateInObjectEnd + StateTransitioningToTerminate ) func (s JSONState) String() string { @@ -76,6 +80,8 @@ func (s JSONState) String() string { return "StateInObjSpace" case StateInList: return "StateInList" + case StateInListObjectEnd: + return "StateInListObjectEnd" case StateInListComma: return "StateInListComma" case StateListEnd: diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index d58f23cc4..85c5f35da 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -6,7 +6,9 @@ import ( "github.com/ollama/ollama/model" ) -var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','} +// TODO: / should be valid but an escape character + +var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'} var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} @@ -34,6 +36,7 @@ func NewPDANode(state JSONState) *PDANode { func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) { stateToNodeMap := make(map[JSONState]*PDANode) + // TODO: make this a loop startNode := NewPDANode(StateStart) stateToNodeMap[StateStart] = startNode @@ -95,6 +98,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err intNode := NewPDANode(StateInInt) stateToNodeMap[StateInInt] = intNode + listObjEndNode := NewPDANode(StateInListObjectEnd) + stateToNodeMap[StateInListObjectEnd] = listObjEndNode + // TODO: // consider adding a node to just point to values, could be good to compute that // mask rather than many different nodes @@ -105,108 +111,84 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err objNode.TransitionEdges['"'] = objKeyNode objNode.TransitionEdges['\n'] = newlineNode - // objNode.TransitionEdges['\t'] = tabNode + objNode.TransitionEdges[' '] = spaceObjNode + //new line newlineNode.TransitionEdges['"'] = objKeyNode newlineNode.TransitionEdges['\t'] = tabNode tabNode.TransitionEdges['"'] = objKeyNode - // tabNode.TransitionEdges['\t'] = tabNode objKeyNode.TransitionEdges[rune(-1)] = objKeyNode objKeyNode.TransitionEdges['"'] = objKeyEndNode objKeyEndNode.TransitionEdges[':'] = colonNode - objEndNode.TransitionEdges[' '] = spaceNode + + objEndNode.TransitionEdges[','] = commaNode + objEndNode.TransitionEdges['}'] = objEndNode // where values should be // this could be combined but the probs might change, we're alr doing a skip ahead colonNode.TransitionEdges[' '] = spaceNode + colonNode.TransitionEdges['['] = listNode + colonNode.TransitionEdges['{'] = objNode + addValueConnections(colonNode, stateToNodeMap) // Leads to a value - spaceNode.TransitionEdges['"'] = stringNode spaceNode.TransitionEdges['['] = listNode spaceNode.TransitionEdges['{'] = objNode - - for _, r := range validNumberRunes { - spaceNode.TransitionEdges[r] = numberNode - } - for _, r := range validBoolRunes { - spaceNode.TransitionEdges[r] = boolNode - } - - for _, r := range validNullRunes { - spaceNode.TransitionEdges[r] = nullNode - } + addValueConnections(spaceNode, stateToNodeMap) // Values // string node stringNode.TransitionEdges[rune(-1)] = stringNode stringNode.TransitionEdges['"'] = stringEndNode - stringEndNode.TransitionEdges[','] = commaNode - stringEndNode.TransitionEdges['}'] = objEndNode - stringEndNode.TransitionEdges[']'] = listEndNode + // String end node + addEnds(stringEndNode, stateToNodeMap) // TODO: add counters for allowable number of decimals, e, E, etc // number node for _, r := range validNumberRunes { numberNode.TransitionEdges[r] = numberNode } - numberNode.TransitionEdges[','] = commaNode - numberNode.TransitionEdges['}'] = objEndNode - numberNode.TransitionEdges[']'] = listEndNode - - for _, r := range validBoolRunes { - boolNode.TransitionEdges[r] = boolNode - } - - // list node - listNode.TransitionEdges[','] = commaNode - listNode.TransitionEdges['"'] = stringNode - // squash states to a value - for _, r := range validNumberRunes { - listNode.TransitionEdges[r] = numberNode - } - for _, r := range validBoolRunes { - listNode.TransitionEdges[r] = boolNode - } - for _, r := range validNullRunes { - listNode.TransitionEdges[r] = nullNode - } - - // null node - for _, r := range validNullRunes { - nullNode.TransitionEdges[r] = nullNode - } - nullNode.TransitionEdges[','] = commaNode - nullNode.TransitionEdges['}'] = objEndNode - nullNode.TransitionEdges[']'] = listEndNode - - // list comma - // should point to values - listCommaNode.TransitionEdges['"'] = stringNode - listCommaNode.TransitionEdges[' '] = listCommaNode - listCommaNode.TransitionEdges['{'] = objNode - listCommaNode.TransitionEdges['\n'] = newlineNode - - for _, r := range validNumberRunes { - listCommaNode.TransitionEdges[r] = numberNode - } - for _, r := range validBoolRunes { - listCommaNode.TransitionEdges[r] = boolNode - } - for _, r := range validNullRunes { - listCommaNode.TransitionEdges[r] = nullNode - } + addEnds(numberNode, stateToNodeMap) // bool node for _, r := range validBoolRunes { boolNode.TransitionEdges[r] = boolNode } - boolNode.TransitionEdges['}'] = objEndNode - boolNode.TransitionEdges[']'] = listEndNode - boolNode.TransitionEdges[','] = commaNode + addEnds(boolNode, stateToNodeMap) + + // list node + listNode.TransitionEdges[','] = commaNode + listNode.TransitionEdges['{'] = objNode + listNode.TransitionEdges[' '] = listNode + listNode.TransitionEdges['\n'] = listNode + addValueConnections(listNode, stateToNodeMap) + + // null node + for _, r := range validNullRunes { + nullNode.TransitionEdges[r] = nullNode + } + addEnds(nullNode, stateToNodeMap) + + // list comma + // should point to values + listCommaNode.TransitionEdges[' '] = listCommaNode + listCommaNode.TransitionEdges['{'] = objNode + listCommaNode.TransitionEdges['\n'] = newlineNode + addValueConnections(listCommaNode, stateToNodeMap) + + // list object end + listObjEndNode.TransitionEdges[','] = listCommaNode + listObjEndNode.TransitionEdges[']'] = listEndNode + + // bool node + for _, r := range validBoolRunes { + boolNode.TransitionEdges[r] = boolNode + } + addEnds(boolNode, stateToNodeMap) listEndNode.TransitionEdges['}'] = objEndNode listEndNode.TransitionEdges[','] = commaNode @@ -218,10 +200,27 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err commaNode.TransitionEdges[' '] = spaceObjNode spaceObjNode.TransitionEdges['"'] = objKeyNode + spaceObjNode.TransitionEdges['\n'] = newlineNode return startNode, stateToNodeMap, nil } +func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { + node.TransitionEdges[','] = stateToNodeMap[StateInComma] + node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + node.TransitionEdges[']'] = stateToNodeMap[StateListEnd] +} + +func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { + node.TransitionEdges['"'] = stateToNodeMap[StateInString] + for _, r := range validNumberRunes { + node.TransitionEdges[r] = stateToNodeMap[StateInNumber] + } + node.TransitionEdges['t'] = stateToNodeMap[StateInBool] + node.TransitionEdges['f'] = stateToNodeMap[StateInBool] + node.TransitionEdges['n'] = stateToNodeMap[StateInNull] +} + func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error { vocab := proc.GetVocabulary() @@ -240,7 +239,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex for i := range vocab.Values { token := decodedToks[i] // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON - if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" { + if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" { continue } valid := true @@ -263,6 +262,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex return nil } +// garbage interface plz fix func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) { if consumedSpecialRunes[r] { return false, nil, nil @@ -281,7 +281,6 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) ( if curNode.State == nextNode.State { return false, nil, nil } - // fmt.Println("special rune", r, "consumed") consumedSpecialRunes[r] = true } return true, nextNode, nil diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index a97b5f29a..cb467e81c 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -9,6 +9,8 @@ import ( "github.com/ollama/ollama/model" ) +// TODO: safety in case of invalid json +// TODO: interfaces to cleanup with return values type PushdownSampler struct { // stateful curNode *PDANode @@ -18,6 +20,7 @@ type PushdownSampler struct { stateCounter uint32 } +// graph should be built once and reused per tokenizer func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { start := time.Now() @@ -39,14 +42,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) fmt.Printf("Graph build time = %v\n", time.Since(start)) - // for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode { - // token, err := proc.Decode([]int32{int32(id)}) - // if err != nil { - // panic(err) - // } - // fmt.Println("id", id, "node", node, "token", token) - // } - // time.Sleep(10 * time.Second) + return &PushdownSampler{ curNode: startNode, proc: proc, @@ -57,9 +53,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { } func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { - fmt.Println("sample:", s.curNode.State) - + // fmt.Println(">>> sample:", s.curNode.State) switch s.curNode.State { + case StateInString: + return s.maskLogits(logits, s.curNode) + case StateInObjectEnd: // force finish if no braces left if len(s.braceStack) == 0 { @@ -73,24 +71,24 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { } return logits, nil } - valid, err := s.proc.Encode("}") + + peek := s.braceStack[len(s.braceStack)-1] + if peek == rune('[') { + s.curNode = s.stateToNodeMap[StateInListObjectEnd] + // fmt.Println("switching to list object end", s.curNode.State) + } + + logits, err := s.maskLogits(logits, s.curNode) if err != nil { return nil, err } - for i := range logits { - for _, token := range valid { - if i != int(token) { - logits[i] = math.NaN() - } - } - } return logits, nil case StateInComma: peek := s.braceStack[len(s.braceStack)-1] if peek == rune('[') { s.curNode = s.stateToNodeMap[StateInListComma] - fmt.Println("switching to list comma", s.curNode.State) + // fmt.Println("switching to list comma", s.curNode.State) } logits, err := s.maskLogits(logits, s.curNode) if err != nil { @@ -109,7 +107,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { return logits, nil default: - fmt.Println("masking logits current state", s.curNode.State) + // fmt.Println("masking logits current state", s.curNode.State) logits, err := s.maskLogits(logits, s.curNode) if err != nil { return nil, err @@ -119,54 +117,48 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { } func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { - fmt.Println("update state", s.curNode.State) - - // TODO: need to handle end states and entering object case, and list case - if s.curNode.State == StateInObjectEnd { - fmt.Println("in object end") - if len(s.braceStack) > 0 { - s.braceStack = s.braceStack[:len(s.braceStack)-1] - return nil - } - s.curNode = NewPDANode(StateTerminate) - // TODO: return here? - } - // need this cause there could be multiple transitions + // fmt.Println("update state", s.curNode.State) mappedString, err := s.proc.Decode(tokenSlice) if err != nil { return err } - // TODO: should force closing for all braces + + // TODO: should force closing for all braces - not doing square yet for _, r := range mappedString { if r == rune('{') { s.braceStack = append(s.braceStack, r) + // fmt.Println("pushing { brace stack", r) } if r == rune('[') { s.braceStack = append(s.braceStack, r) + // fmt.Println("pushing [ brace stack", r) } if r == rune('}') { - if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') { - return fmt.Errorf("unmatched closing brace") + top := s.braceStack[len(s.braceStack)-1] + if len(s.braceStack) == 0 || top != rune('{') { + return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{') } s.braceStack = s.braceStack[:len(s.braceStack)-1] - fmt.Println("popping brace stack", s.braceStack) + // fmt.Println("popping { brace stack", top) } if r == rune(']') { - if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') { - return fmt.Errorf("unmatched closing brace") + top := s.braceStack[len(s.braceStack)-1] + if len(s.braceStack) == 0 || top != rune('[') { + return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[') } s.braceStack = s.braceStack[:len(s.braceStack)-1] - fmt.Println("popping brace stack", s.braceStack) + // fmt.Println("popping [ brace stack", top) } } + for _, tokenID := range tokenSlice { // transition to the next node nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID] if !ok { return fmt.Errorf("invalid token: %q", mappedString) } - fmt.Println("transitioning to", nextNodeState) + // fmt.Println("transitioning to", nextNodeState) // TODO: add a penalty for staying in the same state too long if nextNodeState == s.curNode.State { diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go new file mode 100644 index 000000000..a02ad9fc2 --- /dev/null +++ b/sample/structured_outputs.go @@ -0,0 +1,98 @@ +package sample + +import "github.com/ollama/ollama/model" + +type StructuredOutput struct { + schema *Schema +} + +func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode { + // _, stateToNodeMap, err := BuildGraph(proc) + // if err != nil { + // panic(err) + // } + + return nil +} + +// func constrainGraph(graph *PDANode, schema *Schema) *PDANode { +// // If no schema constraints, return original graph node +// if schema == nil { +// return graph +// } + +// // Create a new node with same state +// constrainedNode := NewPDANode(graph.State) + +// // Copy over existing transitions and masks +// constrainedNode.TransitionEdges = make(map[rune]*PDANode) +// for r, node := range graph.TransitionEdges { +// constrainedNode.TransitionEdges[r] = node +// } +// constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode + +// // Apply schema constraints based on type +// switch schema.EffectiveType() { +// case "object": +// // Only allow defined property names in object keys +// if graph.State == StateInObjectKey { +// // TODO: Add property name validation +// } + +// // Constrain property values based on schema +// if graph.State == StateInColon || graph.State == StateInSpace { +// // Clear transitions to only allow valid types +// constrainedNode.TransitionEdges = make(map[rune]*PDANode) + +// // Add transitions based on property schemas +// for _, prop := range schema.Properties { +// switch prop.EffectiveType() { +// case "object": +// if objNode, ok := graph.TransitionEdges['{']; ok { +// constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop) +// } +// case "array": +// if arrNode, ok := graph.TransitionEdges['[']; ok { +// constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop) +// } +// case "string": +// if strNode, ok := graph.TransitionEdges['"']; ok { +// constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop) +// } +// case "number": +// for _, r := range validNumberRunes { +// if numNode, ok := graph.TransitionEdges[r]; ok { +// constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop) +// } +// } +// case "integer": +// for _, r := range validIntRunes { +// if intNode, ok := graph.TransitionEdges[r]; ok { +// constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop) +// } +// } +// case "boolean": +// for _, r := range []rune{'t', 'f'} { +// if boolNode, ok := graph.TransitionEdges[r]; ok { +// constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop) +// } +// } +// case "null": +// if nullNode, ok := graph.TransitionEdges['n']; ok { +// constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop) +// } +// } +// } +// } + +// case "array": +// // Constrain array items based on schema +// if schema.Items != nil { +// for r, node := range graph.TransitionEdges { +// constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items) +// } +// } +// } + +// return constrainedNode +// }