From f7d64856d5802ea2ece3abad7f421de4bb38bbdf Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Tue, 2 Jul 2024 10:41:31 -0700 Subject: [PATCH] start tests --- api/client.go | 20 ++++++++++++ api/client_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++ cmd/cmd.go | 22 +++++++++++++ 3 files changed, 121 insertions(+) diff --git a/api/client.go b/api/client.go index c59fbc423..8ad68795f 100644 --- a/api/client.go +++ b/api/client.go @@ -383,3 +383,23 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version.Version, nil } + +// IsLocal checks whether the client is connecting to a local server. +func (c *Client) IsLocal() bool { + // Resolve the host to an IP address and check if the IP is local + // Currently, only checks if it is localhost or loopback + host, _, err := net.SplitHostPort(c.base.Host) + if err != nil { + host = c.base.Host + } + + if host == "" || host == "localhost" { + return true + } + + if ip := net.ParseIP(host); ip != nil { + return ip.IsLoopback() + } + + return false +} diff --git a/api/client_test.go b/api/client_test.go index fe9fd74f7..92bdb04f8 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -1,6 +1,8 @@ package api import ( + "net/http" + "net/url" "testing" "github.com/ollama/ollama/envconfig" @@ -46,3 +48,80 @@ func TestClientFromEnvironment(t *testing.T) { }) } } + +// Test function +func TestIsLocal(t *testing.T) { + type test struct { + client *Client + want bool + err error + } + + tests := map[string]test{ + "localhost": { + client: func() *Client { + baseURL, _ := url.Parse("http://localhost:1234") + return &Client{base: baseURL, http: &http.Client{}} + }(), + want: true, + err: nil, + }, + "127.0.0.1": { + client: func() *Client { + baseURL, _ := url.Parse("http://127.0.0.1:1234") + return &Client{base: baseURL, http: &http.Client{}} + }(), + want: true, + err: nil, + }, + "example.com": { + client: func() *Client { + baseURL, _ := url.Parse("http://example.com:1111") + return &Client{base: baseURL, http: &http.Client{}} + }(), + want: false, + err: nil, + }, + "8.8.8.8": { + client: func() *Client { + baseURL, _ := url.Parse("http://8.8.8.8:1234") + return &Client{base: baseURL, http: &http.Client{}} + }(), + want: false, + err: nil, + }, + "empty host with port": { + client: func() *Client { + baseURL, _ := url.Parse("http://:1234") + return &Client{base: baseURL, http: &http.Client{}} + }(), + want: true, + err: nil, + }, + "empty host without port": { + client: func() *Client { + baseURL, _ := url.Parse("http://") + return &Client{base: baseURL, http: &http.Client{}} + }(), + want: true, + err: nil, + }, + "remote host without port": { + client: func() *Client { + baseURL, _ := url.Parse("http://example.com") + return &Client{base: baseURL, http: &http.Client{}} + }(), + want: false, + err: nil, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + got := tc.client.IsLocal() + if got != tc.want { + t.Errorf("test %s failed: got %v, want %v", name, got, tc.want) + } + }) + } +} diff --git a/cmd/cmd.go b/cmd/cmd.go index c7cf0581a..a7c9cb54d 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -277,12 +277,34 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er } digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) + + // Here, we want to check if the server is local + // If true, call, createBlobLocal + // This should find the model directory, copy blob over, and return the digest + // If this fails, just upload it + // If this is successful, return the digest + + // Resolve server to IP + // Check if server is local + if client.IsLocal() { + err := createBlobLocal(cmd, client, digest) + if err == nil { + return digest, nil + } + } + if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { return "", err } return digest, nil } +func createBlobLocal(cmd *cobra.Command, client *api.Client, digest string) error { + // This function should be called if the server is local + // It should find the model directory, copy the blob over, and return the digest + +} + func RunHandler(cmd *cobra.Command, args []string) error { interactive := true