wip
This commit is contained in:
parent
5ec6bb52a0
commit
4450f871db
@ -582,16 +582,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
// jsonSampler = nil
|
// jsonSampler = nil
|
||||||
// pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil)
|
pythonSampler := &sample.PythonSampler{}
|
||||||
// pythonSampler := &sample.PythonSampler{}
|
functions := []sample.PythonFunction{
|
||||||
// functions := []sample.PythonFunction{
|
{
|
||||||
// {
|
Name: "add_two_strings",
|
||||||
// Name: "add_two_strings",
|
Args: []string{"s1", "s2"},
|
||||||
// Args: []string{"s1", "s2"},
|
Types: []string{"string", "string"},
|
||||||
// Types: []string{"string", "string"},
|
},
|
||||||
// },
|
}
|
||||||
// }
|
pythonSampler.Init(functions, s.model.(model.TextProcessor))
|
||||||
// pythonSampler.Init(functions, s.model.(model.TextProcessor))
|
|
||||||
sampler := sample.NewSampler(
|
sampler := sample.NewSampler(
|
||||||
req.Options.Temperature,
|
req.Options.Temperature,
|
||||||
req.Options.TopK,
|
req.Options.TopK,
|
||||||
@ -600,7 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
req.Options.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
grammar,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
pythonSampler,
|
||||||
|
// nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
|
@ -108,10 +108,13 @@ type PythonSampler struct {
|
|||||||
proc model.TextProcessor
|
proc model.TextProcessor
|
||||||
decodedToks []string
|
decodedToks []string
|
||||||
curNode *Node
|
curNode *Node
|
||||||
|
completed int
|
||||||
|
functions []PythonFunction
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
|
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
|
||||||
s.proc = proc
|
s.proc = proc
|
||||||
|
s.functions = functions
|
||||||
decodedToks := make([]string, len(proc.Vocab().Values))
|
decodedToks := make([]string, len(proc.Vocab().Values))
|
||||||
for i := range proc.Vocab().Values {
|
for i := range proc.Vocab().Values {
|
||||||
token, err := proc.Decode([]int32{int32(i)})
|
token, err := proc.Decode([]int32{int32(i)})
|
||||||
@ -194,7 +197,7 @@ func (s *PythonSampler) BuildGraph() error {
|
|||||||
|
|
||||||
// String end
|
// String end
|
||||||
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
|
||||||
s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||||
// Number
|
// Number
|
||||||
for _, r := range validNumberRunes {
|
for _, r := range validNumberRunes {
|
||||||
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
|
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
|
||||||
@ -237,6 +240,16 @@ func (s *PythonSampler) UpdateState(token int32) error {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("invalid token: %q", mappedString)
|
return fmt.Errorf("invalid token: %q", mappedString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if mappedString == "\"" {
|
||||||
|
if s.curNode.State == PStateInStringEnd {
|
||||||
|
s.completed++
|
||||||
|
}
|
||||||
|
if s.completed == len(s.functions) {
|
||||||
|
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
|
||||||
|
s.CreateMask(s.curNode)
|
||||||
|
}
|
||||||
|
}
|
||||||
s.curNode = nextNode
|
s.curNode = nextNode
|
||||||
fmt.Println("curNode", s.curNode.State)
|
fmt.Println("curNode", s.curNode.State)
|
||||||
for r, node := range s.curNode.TransitionEdges {
|
for r, node := range s.curNode.TransitionEdges {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user