WIP working with other stuff
This commit is contained in:
parent
e93db4d20e
commit
73098a2973
171
sample/decode.go
Normal file
171
sample/decode.go
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Schema describes one node of a JSON schema: either the root document or a
// nested property, tuple element, or list item schema.
type Schema struct {
	// Name is the property name this schema belongs to. The root schema is
	// named "root"; a child schema carries the name of its property.
	// Name is excluded from JSON encoding and decoding.
	Name string `json:"-"`

	// Type is the declared type of the property.
	//
	// TODO: Union types (e.g. make this a []string).
	Type string

	// PrefixItems holds one schema per position of a tuple. A tuple is
	// "closed" by default; it is only open to further elements when Items
	// is set (to true or to a valid Schema).
	PrefixItems []*Schema

	// Items is the schema that every element of a list must satisfy.
	//
	// It is nil when the JSON value is missing, "null", or "false". A JSON
	// value of "true" yields the empty Schema, and a JSON object is decoded
	// as a Schema.
	Items *Schema

	// MinItems is the smallest number of elements a list may contain.
	MinItems int

	// MaxItems is the largest number of elements a list may contain.
	MaxItems int

	// Properties holds the schema of each property of an object.
	Properties []*Schema

	// Format names a specific format the property value is expected to
	// follow.
	//
	// This package does not enforce it: validating the value against the
	// format is the caller's responsibility.
	Format string

	// Minimum is the lower bound for numeric properties.
	Minimum float64

	// Maximum is the upper bound for numeric properties.
	Maximum float64

	// Enum lists the values the property is allowed to take, kept as raw
	// JSON.
	Enum []json.RawMessage
}
|
||||||
|
|
||||||
|
func (s *Schema) UnmarshalJSON(data []byte) error {
|
||||||
|
type S Schema
|
||||||
|
w := struct {
|
||||||
|
Properties props
|
||||||
|
Items items
|
||||||
|
*S
|
||||||
|
}{
|
||||||
|
S: (*S)(s),
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &w); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if w.Items.set {
|
||||||
|
s.Items = &w.Items.Schema
|
||||||
|
}
|
||||||
|
s.Properties = w.Properties
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type items struct {
|
||||||
|
Schema
|
||||||
|
set bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *items) UnmarshalJSON(data []byte) error {
|
||||||
|
switch b := data[0]; b {
|
||||||
|
case 't':
|
||||||
|
*s = items{set: true}
|
||||||
|
case '{':
|
||||||
|
type I items
|
||||||
|
if err := json.Unmarshal(data, (*I)(s)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.set = true
|
||||||
|
case 'n', 'f':
|
||||||
|
default:
|
||||||
|
return errors.New("invalid Items")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EffectiveType returns the effective type of the schema. If the Type field is
|
||||||
|
// not empty, it is returned; otherwise:
|
||||||
|
//
|
||||||
|
// - If the schema has both Properties and Items, it returns an empty string.
|
||||||
|
// - If the schema has Properties, it returns "object".
|
||||||
|
// - If the schema has Items, it returns "array".
|
||||||
|
// - If the schema has neither Properties nor Items, it returns "value".
|
||||||
|
//
|
||||||
|
// The returned string is never empty.
|
||||||
|
func (d *Schema) EffectiveType() string {
|
||||||
|
if d.Type == "" {
|
||||||
|
if len(d.Properties) > 0 {
|
||||||
|
return "object"
|
||||||
|
}
|
||||||
|
if len(d.PrefixItems) > 0 || d.Items != nil {
|
||||||
|
return "array"
|
||||||
|
}
|
||||||
|
return "value"
|
||||||
|
}
|
||||||
|
return d.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
// props is an ordered list of properties. The order of the properties
|
||||||
|
// is the order in which they were defined in the schema.
|
||||||
|
type props []*Schema
|
||||||
|
|
||||||
|
var _ json.Unmarshaler = (*props)(nil)
|
||||||
|
|
||||||
|
func (v *props) UnmarshalJSON(data []byte) error {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data[0] != '{' {
|
||||||
|
return errors.New("expected object")
|
||||||
|
}
|
||||||
|
|
||||||
|
d := json.NewDecoder(bytes.NewReader(data))
|
||||||
|
|
||||||
|
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
|
||||||
|
// llama.cpp, ignore unknown fields, which could be lead to unexpected
|
||||||
|
// behavior for clients of this package, since they may not be aware
|
||||||
|
// that "additionalFields", "itemsPrefix", etc, are being ignored.
|
||||||
|
//
|
||||||
|
// For now, just do what llama.cpp does.
|
||||||
|
|
||||||
|
t, err := d.Token()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if t != json.Delim('{') {
|
||||||
|
return errors.New("expected object")
|
||||||
|
}
|
||||||
|
for d.More() {
|
||||||
|
// Use the first token (map key) as the property name, then
|
||||||
|
// decode the rest of the object fields into a Schema and
|
||||||
|
// append.
|
||||||
|
t, err := d.Token()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if t == json.Delim('}') {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s := &Schema{
|
||||||
|
Name: t.(string),
|
||||||
|
}
|
||||||
|
if err := d.Decode(s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*v = append(*v, s)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -30,7 +30,10 @@ const (
|
|||||||
StateInList
|
StateInList
|
||||||
StateInListComma
|
StateInListComma
|
||||||
StateListEnd
|
StateListEnd
|
||||||
|
StateInValue
|
||||||
|
StateInValueEnd
|
||||||
StateInListEnd
|
StateInListEnd
|
||||||
|
StateInListObjectEnd
|
||||||
StateInNewline
|
StateInNewline
|
||||||
StateInNumber
|
StateInNumber
|
||||||
StateInNumberEnd
|
StateInNumberEnd
|
||||||
@ -38,6 +41,7 @@ const (
|
|||||||
StateInObjectKeyEnd
|
StateInObjectKeyEnd
|
||||||
StateTerminate
|
StateTerminate
|
||||||
StateInObjectEnd
|
StateInObjectEnd
|
||||||
|
StateTransitioningToTerminate
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s JSONState) String() string {
|
func (s JSONState) String() string {
|
||||||
@ -76,6 +80,8 @@ func (s JSONState) String() string {
|
|||||||
return "StateInObjSpace"
|
return "StateInObjSpace"
|
||||||
case StateInList:
|
case StateInList:
|
||||||
return "StateInList"
|
return "StateInList"
|
||||||
|
case StateInListObjectEnd:
|
||||||
|
return "StateInListObjectEnd"
|
||||||
case StateInListComma:
|
case StateInListComma:
|
||||||
return "StateInListComma"
|
return "StateInListComma"
|
||||||
case StateListEnd:
|
case StateListEnd:
|
||||||
|
@ -6,7 +6,9 @@ import (
|
|||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
// TODO: / should be valid but an escape character
|
||||||
|
|
||||||
|
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||||
|
|
||||||
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||||
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||||
@ -34,6 +36,7 @@ func NewPDANode(state JSONState) *PDANode {
|
|||||||
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
||||||
stateToNodeMap := make(map[JSONState]*PDANode)
|
stateToNodeMap := make(map[JSONState]*PDANode)
|
||||||
|
|
||||||
|
// TODO: make this a loop
|
||||||
startNode := NewPDANode(StateStart)
|
startNode := NewPDANode(StateStart)
|
||||||
stateToNodeMap[StateStart] = startNode
|
stateToNodeMap[StateStart] = startNode
|
||||||
|
|
||||||
@ -95,6 +98,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
intNode := NewPDANode(StateInInt)
|
intNode := NewPDANode(StateInInt)
|
||||||
stateToNodeMap[StateInInt] = intNode
|
stateToNodeMap[StateInInt] = intNode
|
||||||
|
|
||||||
|
listObjEndNode := NewPDANode(StateInListObjectEnd)
|
||||||
|
stateToNodeMap[StateInListObjectEnd] = listObjEndNode
|
||||||
|
|
||||||
// TODO:
|
// TODO:
|
||||||
// consider adding a node to just point to values, could be good to compute that
|
// consider adding a node to just point to values, could be good to compute that
|
||||||
// mask rather than many different nodes
|
// mask rather than many different nodes
|
||||||
@ -105,108 +111,84 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
|
|
||||||
objNode.TransitionEdges['"'] = objKeyNode
|
objNode.TransitionEdges['"'] = objKeyNode
|
||||||
objNode.TransitionEdges['\n'] = newlineNode
|
objNode.TransitionEdges['\n'] = newlineNode
|
||||||
// objNode.TransitionEdges['\t'] = tabNode
|
objNode.TransitionEdges[' '] = spaceObjNode
|
||||||
|
|
||||||
|
//new line
|
||||||
newlineNode.TransitionEdges['"'] = objKeyNode
|
newlineNode.TransitionEdges['"'] = objKeyNode
|
||||||
newlineNode.TransitionEdges['\t'] = tabNode
|
newlineNode.TransitionEdges['\t'] = tabNode
|
||||||
|
|
||||||
tabNode.TransitionEdges['"'] = objKeyNode
|
tabNode.TransitionEdges['"'] = objKeyNode
|
||||||
// tabNode.TransitionEdges['\t'] = tabNode
|
|
||||||
|
|
||||||
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
||||||
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
||||||
|
|
||||||
objKeyEndNode.TransitionEdges[':'] = colonNode
|
objKeyEndNode.TransitionEdges[':'] = colonNode
|
||||||
objEndNode.TransitionEdges[' '] = spaceNode
|
|
||||||
|
objEndNode.TransitionEdges[','] = commaNode
|
||||||
|
objEndNode.TransitionEdges['}'] = objEndNode
|
||||||
|
|
||||||
// where values should be
|
// where values should be
|
||||||
// this could be combined but the probs might change, we're alr doing a skip ahead
|
// this could be combined but the probs might change, we're alr doing a skip ahead
|
||||||
colonNode.TransitionEdges[' '] = spaceNode
|
colonNode.TransitionEdges[' '] = spaceNode
|
||||||
|
colonNode.TransitionEdges['['] = listNode
|
||||||
|
colonNode.TransitionEdges['{'] = objNode
|
||||||
|
addValueConnections(colonNode, stateToNodeMap)
|
||||||
|
|
||||||
// Leads to a value
|
// Leads to a value
|
||||||
spaceNode.TransitionEdges['"'] = stringNode
|
|
||||||
spaceNode.TransitionEdges['['] = listNode
|
spaceNode.TransitionEdges['['] = listNode
|
||||||
spaceNode.TransitionEdges['{'] = objNode
|
spaceNode.TransitionEdges['{'] = objNode
|
||||||
|
addValueConnections(spaceNode, stateToNodeMap)
|
||||||
for _, r := range validNumberRunes {
|
|
||||||
spaceNode.TransitionEdges[r] = numberNode
|
|
||||||
}
|
|
||||||
for _, r := range validBoolRunes {
|
|
||||||
spaceNode.TransitionEdges[r] = boolNode
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range validNullRunes {
|
|
||||||
spaceNode.TransitionEdges[r] = nullNode
|
|
||||||
}
|
|
||||||
|
|
||||||
// Values
|
// Values
|
||||||
// string node
|
// string node
|
||||||
stringNode.TransitionEdges[rune(-1)] = stringNode
|
stringNode.TransitionEdges[rune(-1)] = stringNode
|
||||||
stringNode.TransitionEdges['"'] = stringEndNode
|
stringNode.TransitionEdges['"'] = stringEndNode
|
||||||
|
|
||||||
stringEndNode.TransitionEdges[','] = commaNode
|
// String end node
|
||||||
stringEndNode.TransitionEdges['}'] = objEndNode
|
addEnds(stringEndNode, stateToNodeMap)
|
||||||
stringEndNode.TransitionEdges[']'] = listEndNode
|
|
||||||
|
|
||||||
// TODO: add counters for allowable number of decimals, e, E, etc
|
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||||
// number node
|
// number node
|
||||||
for _, r := range validNumberRunes {
|
for _, r := range validNumberRunes {
|
||||||
numberNode.TransitionEdges[r] = numberNode
|
numberNode.TransitionEdges[r] = numberNode
|
||||||
}
|
}
|
||||||
numberNode.TransitionEdges[','] = commaNode
|
addEnds(numberNode, stateToNodeMap)
|
||||||
numberNode.TransitionEdges['}'] = objEndNode
|
|
||||||
numberNode.TransitionEdges[']'] = listEndNode
|
|
||||||
|
|
||||||
for _, r := range validBoolRunes {
|
|
||||||
boolNode.TransitionEdges[r] = boolNode
|
|
||||||
}
|
|
||||||
|
|
||||||
// list node
|
|
||||||
listNode.TransitionEdges[','] = commaNode
|
|
||||||
listNode.TransitionEdges['"'] = stringNode
|
|
||||||
// squash states to a value
|
|
||||||
for _, r := range validNumberRunes {
|
|
||||||
listNode.TransitionEdges[r] = numberNode
|
|
||||||
}
|
|
||||||
for _, r := range validBoolRunes {
|
|
||||||
listNode.TransitionEdges[r] = boolNode
|
|
||||||
}
|
|
||||||
for _, r := range validNullRunes {
|
|
||||||
listNode.TransitionEdges[r] = nullNode
|
|
||||||
}
|
|
||||||
|
|
||||||
// null node
|
|
||||||
for _, r := range validNullRunes {
|
|
||||||
nullNode.TransitionEdges[r] = nullNode
|
|
||||||
}
|
|
||||||
nullNode.TransitionEdges[','] = commaNode
|
|
||||||
nullNode.TransitionEdges['}'] = objEndNode
|
|
||||||
nullNode.TransitionEdges[']'] = listEndNode
|
|
||||||
|
|
||||||
// list comma
|
|
||||||
// should point to values
|
|
||||||
listCommaNode.TransitionEdges['"'] = stringNode
|
|
||||||
listCommaNode.TransitionEdges[' '] = listCommaNode
|
|
||||||
listCommaNode.TransitionEdges['{'] = objNode
|
|
||||||
listCommaNode.TransitionEdges['\n'] = newlineNode
|
|
||||||
|
|
||||||
for _, r := range validNumberRunes {
|
|
||||||
listCommaNode.TransitionEdges[r] = numberNode
|
|
||||||
}
|
|
||||||
for _, r := range validBoolRunes {
|
|
||||||
listCommaNode.TransitionEdges[r] = boolNode
|
|
||||||
}
|
|
||||||
for _, r := range validNullRunes {
|
|
||||||
listCommaNode.TransitionEdges[r] = nullNode
|
|
||||||
}
|
|
||||||
|
|
||||||
// bool node
|
// bool node
|
||||||
for _, r := range validBoolRunes {
|
for _, r := range validBoolRunes {
|
||||||
boolNode.TransitionEdges[r] = boolNode
|
boolNode.TransitionEdges[r] = boolNode
|
||||||
}
|
}
|
||||||
boolNode.TransitionEdges['}'] = objEndNode
|
addEnds(boolNode, stateToNodeMap)
|
||||||
boolNode.TransitionEdges[']'] = listEndNode
|
|
||||||
boolNode.TransitionEdges[','] = commaNode
|
// list node
|
||||||
|
listNode.TransitionEdges[','] = commaNode
|
||||||
|
listNode.TransitionEdges['{'] = objNode
|
||||||
|
listNode.TransitionEdges[' '] = listNode
|
||||||
|
listNode.TransitionEdges['\n'] = listNode
|
||||||
|
addValueConnections(listNode, stateToNodeMap)
|
||||||
|
|
||||||
|
// null node
|
||||||
|
for _, r := range validNullRunes {
|
||||||
|
nullNode.TransitionEdges[r] = nullNode
|
||||||
|
}
|
||||||
|
addEnds(nullNode, stateToNodeMap)
|
||||||
|
|
||||||
|
// list comma
|
||||||
|
// should point to values
|
||||||
|
listCommaNode.TransitionEdges[' '] = listCommaNode
|
||||||
|
listCommaNode.TransitionEdges['{'] = objNode
|
||||||
|
listCommaNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
addValueConnections(listCommaNode, stateToNodeMap)
|
||||||
|
|
||||||
|
// list object end
|
||||||
|
listObjEndNode.TransitionEdges[','] = listCommaNode
|
||||||
|
listObjEndNode.TransitionEdges[']'] = listEndNode
|
||||||
|
|
||||||
|
// bool node
|
||||||
|
for _, r := range validBoolRunes {
|
||||||
|
boolNode.TransitionEdges[r] = boolNode
|
||||||
|
}
|
||||||
|
addEnds(boolNode, stateToNodeMap)
|
||||||
|
|
||||||
listEndNode.TransitionEdges['}'] = objEndNode
|
listEndNode.TransitionEdges['}'] = objEndNode
|
||||||
listEndNode.TransitionEdges[','] = commaNode
|
listEndNode.TransitionEdges[','] = commaNode
|
||||||
@ -218,10 +200,27 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
|||||||
commaNode.TransitionEdges[' '] = spaceObjNode
|
commaNode.TransitionEdges[' '] = spaceObjNode
|
||||||
|
|
||||||
spaceObjNode.TransitionEdges['"'] = objKeyNode
|
spaceObjNode.TransitionEdges['"'] = objKeyNode
|
||||||
|
spaceObjNode.TransitionEdges['\n'] = newlineNode
|
||||||
|
|
||||||
return startNode, stateToNodeMap, nil
|
return startNode, stateToNodeMap, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||||
|
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||||
|
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||||
|
node.TransitionEdges[']'] = stateToNodeMap[StateListEnd]
|
||||||
|
}
|
||||||
|
|
||||||
|
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||||
|
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
||||||
|
for _, r := range validNumberRunes {
|
||||||
|
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||||
|
}
|
||||||
|
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
||||||
|
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
||||||
|
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
||||||
|
}
|
||||||
|
|
||||||
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||||
|
|
||||||
vocab := proc.GetVocabulary()
|
vocab := proc.GetVocabulary()
|
||||||
@ -240,7 +239,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
|||||||
for i := range vocab.Values {
|
for i := range vocab.Values {
|
||||||
token := decodedToks[i]
|
token := decodedToks[i]
|
||||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
|
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
valid := true
|
valid := true
|
||||||
@ -263,6 +262,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// garbage interface plz fix
|
||||||
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
||||||
if consumedSpecialRunes[r] {
|
if consumedSpecialRunes[r] {
|
||||||
return false, nil, nil
|
return false, nil, nil
|
||||||
@ -281,7 +281,6 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
|
|||||||
if curNode.State == nextNode.State {
|
if curNode.State == nextNode.State {
|
||||||
return false, nil, nil
|
return false, nil, nil
|
||||||
}
|
}
|
||||||
// fmt.Println("special rune", r, "consumed")
|
|
||||||
consumedSpecialRunes[r] = true
|
consumedSpecialRunes[r] = true
|
||||||
}
|
}
|
||||||
return true, nextNode, nil
|
return true, nextNode, nil
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO: safety in case of invalid json
|
||||||
|
// TODO: interfaces to cleanup with return values
|
||||||
type PushdownSampler struct {
|
type PushdownSampler struct {
|
||||||
// stateful
|
// stateful
|
||||||
curNode *PDANode
|
curNode *PDANode
|
||||||
@ -18,6 +20,7 @@ type PushdownSampler struct {
|
|||||||
stateCounter uint32
|
stateCounter uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// graph should be built once and reused per tokenizer
|
||||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
@ -39,14 +42,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
|||||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||||
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
|
|
||||||
// token, err := proc.Decode([]int32{int32(id)})
|
|
||||||
// if err != nil {
|
|
||||||
// panic(err)
|
|
||||||
// }
|
|
||||||
// fmt.Println("id", id, "node", node, "token", token)
|
|
||||||
// }
|
|
||||||
// time.Sleep(10 * time.Second)
|
|
||||||
return &PushdownSampler{
|
return &PushdownSampler{
|
||||||
curNode: startNode,
|
curNode: startNode,
|
||||||
proc: proc,
|
proc: proc,
|
||||||
@ -57,9 +53,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||||
fmt.Println("sample:", s.curNode.State)
|
// fmt.Println(">>> sample:", s.curNode.State)
|
||||||
|
|
||||||
switch s.curNode.State {
|
switch s.curNode.State {
|
||||||
|
case StateInString:
|
||||||
|
return s.maskLogits(logits, s.curNode)
|
||||||
|
|
||||||
case StateInObjectEnd:
|
case StateInObjectEnd:
|
||||||
// force finish if no braces left
|
// force finish if no braces left
|
||||||
if len(s.braceStack) == 0 {
|
if len(s.braceStack) == 0 {
|
||||||
@ -73,24 +71,24 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
return logits, nil
|
return logits, nil
|
||||||
}
|
}
|
||||||
valid, err := s.proc.Encode("}")
|
|
||||||
|
peek := s.braceStack[len(s.braceStack)-1]
|
||||||
|
if peek == rune('[') {
|
||||||
|
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
|
||||||
|
// fmt.Println("switching to list object end", s.curNode.State)
|
||||||
|
}
|
||||||
|
|
||||||
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for i := range logits {
|
|
||||||
for _, token := range valid {
|
|
||||||
if i != int(token) {
|
|
||||||
logits[i] = math.NaN()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return logits, nil
|
return logits, nil
|
||||||
|
|
||||||
case StateInComma:
|
case StateInComma:
|
||||||
peek := s.braceStack[len(s.braceStack)-1]
|
peek := s.braceStack[len(s.braceStack)-1]
|
||||||
if peek == rune('[') {
|
if peek == rune('[') {
|
||||||
s.curNode = s.stateToNodeMap[StateInListComma]
|
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||||
fmt.Println("switching to list comma", s.curNode.State)
|
// fmt.Println("switching to list comma", s.curNode.State)
|
||||||
}
|
}
|
||||||
logits, err := s.maskLogits(logits, s.curNode)
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -109,7 +107,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
return logits, nil
|
return logits, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
fmt.Println("masking logits current state", s.curNode.State)
|
// fmt.Println("masking logits current state", s.curNode.State)
|
||||||
logits, err := s.maskLogits(logits, s.curNode)
|
logits, err := s.maskLogits(logits, s.curNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -119,54 +117,48 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||||
fmt.Println("update state", s.curNode.State)
|
// fmt.Println("update state", s.curNode.State)
|
||||||
|
|
||||||
// TODO: need to handle end states and entering object case, and list case
|
|
||||||
if s.curNode.State == StateInObjectEnd {
|
|
||||||
fmt.Println("in object end")
|
|
||||||
if len(s.braceStack) > 0 {
|
|
||||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
s.curNode = NewPDANode(StateTerminate)
|
|
||||||
// TODO: return here?
|
|
||||||
}
|
|
||||||
// need this cause there could be multiple transitions
|
|
||||||
mappedString, err := s.proc.Decode(tokenSlice)
|
mappedString, err := s.proc.Decode(tokenSlice)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: should force closing for all braces
|
|
||||||
|
// TODO: should force closing for all braces - not doing square yet
|
||||||
for _, r := range mappedString {
|
for _, r := range mappedString {
|
||||||
if r == rune('{') {
|
if r == rune('{') {
|
||||||
s.braceStack = append(s.braceStack, r)
|
s.braceStack = append(s.braceStack, r)
|
||||||
|
// fmt.Println("pushing { brace stack", r)
|
||||||
}
|
}
|
||||||
if r == rune('[') {
|
if r == rune('[') {
|
||||||
s.braceStack = append(s.braceStack, r)
|
s.braceStack = append(s.braceStack, r)
|
||||||
|
// fmt.Println("pushing [ brace stack", r)
|
||||||
}
|
}
|
||||||
if r == rune('}') {
|
if r == rune('}') {
|
||||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
|
top := s.braceStack[len(s.braceStack)-1]
|
||||||
return fmt.Errorf("unmatched closing brace")
|
if len(s.braceStack) == 0 || top != rune('{') {
|
||||||
|
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||||
}
|
}
|
||||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
fmt.Println("popping brace stack", s.braceStack)
|
// fmt.Println("popping { brace stack", top)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r == rune(']') {
|
if r == rune(']') {
|
||||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
|
top := s.braceStack[len(s.braceStack)-1]
|
||||||
return fmt.Errorf("unmatched closing brace")
|
if len(s.braceStack) == 0 || top != rune('[') {
|
||||||
|
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||||
}
|
}
|
||||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||||
fmt.Println("popping brace stack", s.braceStack)
|
// fmt.Println("popping [ brace stack", top)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tokenID := range tokenSlice {
|
for _, tokenID := range tokenSlice {
|
||||||
// transition to the next node
|
// transition to the next node
|
||||||
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid token: %q", mappedString)
|
return fmt.Errorf("invalid token: %q", mappedString)
|
||||||
}
|
}
|
||||||
fmt.Println("transitioning to", nextNodeState)
|
// fmt.Println("transitioning to", nextNodeState)
|
||||||
|
|
||||||
// TODO: add a penalty for staying in the same state too long
|
// TODO: add a penalty for staying in the same state too long
|
||||||
if nextNodeState == s.curNode.State {
|
if nextNodeState == s.curNode.State {
|
||||||
|
98
sample/structured_outputs.go
Normal file
98
sample/structured_outputs.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package sample
|
||||||
|
|
||||||
|
import "github.com/ollama/ollama/model"
|
||||||
|
|
||||||
|
type StructuredOutput struct {
|
||||||
|
schema *Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
|
||||||
|
// _, stateToNodeMap, err := BuildGraph(proc)
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// func constrainGraph(graph *PDANode, schema *Schema) *PDANode {
|
||||||
|
// // If no schema constraints, return original graph node
|
||||||
|
// if schema == nil {
|
||||||
|
// return graph
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Create a new node with same state
|
||||||
|
// constrainedNode := NewPDANode(graph.State)
|
||||||
|
|
||||||
|
// // Copy over existing transitions and masks
|
||||||
|
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
|
||||||
|
// for r, node := range graph.TransitionEdges {
|
||||||
|
// constrainedNode.TransitionEdges[r] = node
|
||||||
|
// }
|
||||||
|
// constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode
|
||||||
|
|
||||||
|
// // Apply schema constraints based on type
|
||||||
|
// switch schema.EffectiveType() {
|
||||||
|
// case "object":
|
||||||
|
// // Only allow defined property names in object keys
|
||||||
|
// if graph.State == StateInObjectKey {
|
||||||
|
// // TODO: Add property name validation
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Constrain property values based on schema
|
||||||
|
// if graph.State == StateInColon || graph.State == StateInSpace {
|
||||||
|
// // Clear transitions to only allow valid types
|
||||||
|
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
|
||||||
|
|
||||||
|
// // Add transitions based on property schemas
|
||||||
|
// for _, prop := range schema.Properties {
|
||||||
|
// switch prop.EffectiveType() {
|
||||||
|
// case "object":
|
||||||
|
// if objNode, ok := graph.TransitionEdges['{']; ok {
|
||||||
|
// constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop)
|
||||||
|
// }
|
||||||
|
// case "array":
|
||||||
|
// if arrNode, ok := graph.TransitionEdges['[']; ok {
|
||||||
|
// constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop)
|
||||||
|
// }
|
||||||
|
// case "string":
|
||||||
|
// if strNode, ok := graph.TransitionEdges['"']; ok {
|
||||||
|
// constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop)
|
||||||
|
// }
|
||||||
|
// case "number":
|
||||||
|
// for _, r := range validNumberRunes {
|
||||||
|
// if numNode, ok := graph.TransitionEdges[r]; ok {
|
||||||
|
// constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// case "integer":
|
||||||
|
// for _, r := range validIntRunes {
|
||||||
|
// if intNode, ok := graph.TransitionEdges[r]; ok {
|
||||||
|
// constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// case "boolean":
|
||||||
|
// for _, r := range []rune{'t', 'f'} {
|
||||||
|
// if boolNode, ok := graph.TransitionEdges[r]; ok {
|
||||||
|
// constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// case "null":
|
||||||
|
// if nullNode, ok := graph.TransitionEdges['n']; ok {
|
||||||
|
// constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// case "array":
|
||||||
|
// // Constrain array items based on schema
|
||||||
|
// if schema.Items != nil {
|
||||||
|
// for r, node := range graph.TransitionEdges {
|
||||||
|
// constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return constrainedNode
|
||||||
|
// }
|
Loading…
x
Reference in New Issue
Block a user