diff --git a/server/python_tools.go b/server/python_tools.go new file mode 100644 index 000000000..d669f0be8 --- /dev/null +++ b/server/python_tools.go @@ -0,0 +1,226 @@ +package server + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/ollama/ollama/api" +) + +var ( + pythonFuncRegex = regexp.MustCompile(`(\w+)\((.*?)\)`) + braces = map[rune]rune{ + '[': ']', + '{': '}', + '(': ')', + '"': '"', + '\'': '\'', + } +) + +// parsePythonValue converts a Python value string to its appropriate Go type +func parsePythonValue(value string) (any, error) { + value = strings.TrimSpace(value) + + // string + if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) || + (strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) { + // Remove quotes + result := value[1 : len(value)-1] + return result, nil + } + + // bool + switch strings.ToLower(value) { + case "true": + return true, nil + case "false": + return false, nil + case "none": + return nil, nil + } + + // int + if i, err := strconv.Atoi(value); err == nil { + return i, nil + } + + // float + if f, err := strconv.ParseFloat(value, 64); err == nil { + return f, nil + } + + // list + if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") { + listStr := value[1 : len(value)-1] + var list []any + stack := []rune{} + start := 0 + + for i, char := range listStr { + if len(stack) != 0 && char == braces[stack[len(stack)-1]] { + stack = stack[:len(stack)-1] + } else if _, ok := braces[char]; ok { + stack = append(stack, char) + } + + if len(stack) == 0 && (char == ',' || i == len(listStr)-1) { + end := i + if i == len(listStr)-1 { + end = i + 1 + } + item := strings.TrimSpace(listStr[start:end]) + if val, err := parsePythonValue(item); err == nil { + list = append(list, val) + } else { + return nil, fmt.Errorf("invalid list item: %s", item) + } + start = i + 1 + } + } + return list, nil + } + + // dictionary + if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") && strings.Contains(value, ":") { + dictStr := value[1 : len(value)-1] + dict := make(map[any]any) + stack := []rune{} + start := 0 + for i, char := range dictStr { + if len(stack) != 0 && char == braces[stack[len(stack)-1]] { + stack = stack[:len(stack)-1] + } else if _, ok := braces[char]; ok { + stack = append(stack, char) + } + if len(stack) == 0 && (char == ',' || i == len(dictStr)-1) { + end := i + if i == len(dictStr)-1 { + end = i + 1 + } + item := strings.TrimSpace(dictStr[start:end]) + kv := strings.SplitN(item, ":", 2) + if len(kv) != 2 { + return nil, fmt.Errorf("invalid dictionary key-value pair: %s", item) + } + + key, err := parsePythonValue(strings.TrimSpace(kv[0])) + if err != nil { + return nil, fmt.Errorf("invalid dictionary key: %s", kv[0]) + } + + val, err := parsePythonValue(strings.TrimSpace(kv[1])) + if err != nil { + return nil, fmt.Errorf("invalid dictionary value: %s", kv[1]) + } + + dict[key] = val + start = i + 1 + } + } + return dict, nil + } + + // sets (stored as lists) + if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") { + setStr := value[1 : len(value)-1] + var list []any + stack := []rune{} + start := 0 + for i, char := range setStr { + if len(stack) != 0 && char == braces[stack[len(stack)-1]] { + stack = stack[:len(stack)-1] + } else if _, ok := braces[char]; ok { + stack = append(stack, char) + } + if len(stack) == 0 && (char == ',' || i == len(setStr)-1) { + end := i + if i == len(setStr)-1 { + end = i + 1 + } + item := strings.TrimSpace(setStr[start:end]) + if val, err := parsePythonValue(item); err == nil { + list = append(list, val) + } else { + return nil, fmt.Errorf("invalid set item: %s", item) + } + start = i + 1 + } + } + return list, nil + } + + return nil, fmt.Errorf("invalid Python value: %s", value) +} + +// parsePythonFunctionCall parses Python function calls from a string +// it supports keyword arguments, as well as multiple functions in a single string +func parsePythonFunctionCall(s string) ([]api.ToolCall, error) { + matches := pythonFuncRegex.FindAllStringSubmatchIndex(s, -1) + if len(matches) == 0 { + return nil, fmt.Errorf("no Python function calls found") + } + + var toolCalls []api.ToolCall + for _, match := range matches { + name := s[match[2]:match[3]] + args := s[match[4]:match[5]] + arguments := make(api.ToolCallFunctionArguments) + if len(args) == 0 { + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: name, + }, + }) + continue + } + + start := 0 + stack := []rune{} + for i, char := range args { + if len(stack) != 0 && char == braces[stack[len(stack)-1]] { + stack = stack[:len(stack)-1] + } else if _, ok := braces[char]; ok { + stack = append(stack, char) + } + if len(stack) == 0 && (char == ',' || i == len(args)-1) { + end := i + if i == len(args)-1 { + end = i + 1 + } + kv := strings.SplitN(args[start:end], "=", 2) + if len(kv) == 2 { + key := strings.TrimSpace(kv[0]) + valueStr := strings.TrimSpace(kv[1]) + + // Parse the value into appropriate type + value, err := parsePythonValue(valueStr) + if err != nil { + return nil, fmt.Errorf("failed to parse value for key %q: %v", key, err) + } + + arguments[key] = value + } else { + return nil, fmt.Errorf("invalid argument format: %q", args[start:end]) + } + start = i + 1 + } + } + + if len(arguments) > 0 { + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: name, + Arguments: arguments, + }, + }) + } + } + + if len(toolCalls) > 0 { + return toolCalls, nil + } + return nil, fmt.Errorf("failed to parse any valid tool calls") +} diff --git a/server/python_tools_test.go b/server/python_tools_test.go new file mode 100644 index 000000000..e6908a8a2 --- /dev/null +++ b/server/python_tools_test.go @@ -0,0 +1,269 @@ +package server + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestParsePythonFunctionCall(t *testing.T) { + t1 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "San Francisco, CA", + "format": "fahrenheit", + }, + }, + } + + t2 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_forecast", + Arguments: api.ToolCallFunctionArguments{ + "days": 5, + "location": "Seattle", + }, + }, + } + + t3 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "list": []any{1, 2, 3}, + "int": -1, + "float": 1.23, + "string": "hello", + }, + }, + } + t4 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + }, + } + + cases := []struct { + name string + input string + want []api.ToolCall + err bool + }{ + { + name: "malformed function call - missing closing paren", + input: "get_current_weather(location=\"San Francisco\"", + err: true, + }, + { + name: "empty function call", + input: "get_current_weather()", + want: []api.ToolCall{t4}, + err: false, + }, + { + name: "single valid function call", + input: "get_current_weather(location=\"San Francisco, CA\", format=\"fahrenheit\")", + want: []api.ToolCall{t1}, + }, + { + name: "multiple valid function calls", + input: "get_current_weather(location=\"San Francisco, CA\", format=\"fahrenheit\") get_forecast(days=5, location=\"Seattle\")", + want: []api.ToolCall{t1, t2}, + }, + { + name: "multiple valid function calls with list", + input: "get_current_weather(list=[1,2,3], int=-1, float=1.23, string=\"hello\")", + want: []api.ToolCall{t3}, + }, + { + name: "positional arguments not supported", + input: "get_current_weather(1, 2, 3)", + err: true, + }, + { + name: "invalid argument format without equals", + input: "get_current_weather(\"San Francisco\")", + err: true, + }, + { + name: "nested lists", + input: "get_current_weather(data=[[1,2],[3,4]])", + want: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "data": []any{[]any{1, 2}, []any{3, 4}}, + }, + }, + }}, + }, + { + name: "boolean and none values", + input: "get_current_weather(active=true, enabled=false, value=None)", + want: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "active": true, + "enabled": false, + "value": nil, + }, + }, + }}, + }, + { + name: "single vs double quotes", + input: "get_current_weather(str1='single', str2=\"double\")", + want: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "str1": "single", + "str2": "double", + }, + }, + }}, + }, + { + name: "whitespace handling", + input: "get_current_weather( location = \"San Francisco\" , temp = 72 )", + want: []api.ToolCall{{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "San Francisco", + "temp": 72, + }, + }, + }}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := parsePythonFunctionCall(tt.input) + if (err != nil) != tt.err { + t.Fatalf("expected error: %v, got error: %v", tt.err, err) + } + if tt.err { + return + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestParsePythonValue(t *testing.T) { + cases := []struct { + name string + input string + want any + err bool + }{ + { + name: "string with double quotes", + input: "\"hello\"", + want: "hello", + }, + { + name: "string with single quotes", + input: "'world'", + want: "world", + }, + { + name: "integer", + input: "42", + want: 42, + }, + { + name: "float", + input: "3.14", + want: 3.14, + }, + { + name: "boolean true", + input: "True", + want: true, + }, + { + name: "boolean false", + input: "False", + want: false, + }, + { + name: "none/null", + input: "None", + want: nil, + }, + { + name: "simple list", + input: "[1, 2, 3]", + want: []any{1, 2, 3}, + }, + { + name: "nested list", + input: "[1, [2, 3], 4]", + want: []any{1, []any{2, 3}, 4}, + }, + { + name: "mixed type list", + input: "[1, \"two\", 3.0, true]", + want: []any{1, "two", 3.0, true}, + }, + { + name: "invalid list", + input: "[1, 2,", + want: nil, + err: true, + }, + { + name: "dictionaries", + input: "{'a': 1, 'b': 2}", + want: map[any]any{"a": 1, "b": 2}, + err: false, + }, + { + name: "int dictionary", + input: "{1: 2}", + want: map[any]any{1: 2}, + err: false, + }, + { + name: "mixed type dictionary", + input: "{'a': 1, 'b': 2.0, 'c': True}", + want: map[any]any{"a": 1, "b": 2.0, "c": true}, + err: false, + }, + { + name: "invalid dictionary - missing closing brace", + input: "{'a': 1, 'b': 2", + want: nil, + err: true, + }, + { + name: "sets", + input: "{1, 2, 3}", + want: []any{1, 2, 3}, + err: false, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := parsePythonValue(tt.input) + if (err != nil) != tt.err { + t.Fatalf("expected error: %v, got error: %v", tt.err, err) + } + if tt.err { + return + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + } +}