just for fun
This commit is contained in:
parent
918fd32884
commit
42009d2974
256
cmd/cmd.go
256
cmd/cmd.go
@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/zip"
|
"archive/zip"
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
@ -16,6 +17,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -31,6 +33,11 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/term"
|
"golang.org/x/term"
|
||||||
|
"gonum.org/v1/gonum/mat"
|
||||||
|
"gonum.org/v1/gonum/stat"
|
||||||
|
"gonum.org/v1/plot"
|
||||||
|
"gonum.org/v1/plot/plotter"
|
||||||
|
"gonum.org/v1/plot/vg"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/auth"
|
"github.com/ollama/ollama/auth"
|
||||||
@ -370,6 +377,90 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return generate(cmd, opts)
|
return generate(cmd, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EmbedHandler(cmd *cobra.Command, args []string) error {
|
||||||
|
interactive := true
|
||||||
|
|
||||||
|
opts := runOptions{
|
||||||
|
Model: args[0],
|
||||||
|
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||||
|
Options: map[string]interface{}{},
|
||||||
|
}
|
||||||
|
|
||||||
|
format, err := cmd.Flags().GetString("format")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
opts.Format = format
|
||||||
|
|
||||||
|
keepAlive, err := cmd.Flags().GetString("keepalive")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if keepAlive != "" {
|
||||||
|
d, err := time.ParseDuration(keepAlive)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
opts.KeepAlive = &api.Duration{Duration: d}
|
||||||
|
}
|
||||||
|
|
||||||
|
prompts := args[1:]
|
||||||
|
// prepend stdin to the prompt if provided
|
||||||
|
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||||
|
in, err := io.ReadAll(os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
prompts = append([]string{string(in)}, prompts...)
|
||||||
|
opts.WordWrap = false
|
||||||
|
interactive = false
|
||||||
|
}
|
||||||
|
opts.Prompt = strings.Join(prompts, " ")
|
||||||
|
if len(prompts) > 0 {
|
||||||
|
interactive = false
|
||||||
|
}
|
||||||
|
|
||||||
|
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
opts.WordWrap = !nowrap
|
||||||
|
|
||||||
|
// Fill out the rest of the options based on information about the
|
||||||
|
// model.
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
name := args[0]
|
||||||
|
info, err := func() (*api.ShowResponse, error) {
|
||||||
|
showReq := &api.ShowRequest{Name: name}
|
||||||
|
info, err := client.Show(cmd.Context(), showReq)
|
||||||
|
var se api.StatusError
|
||||||
|
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||||
|
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||||
|
}
|
||||||
|
return info, err
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
opts.MultiModal = slices.Contains(info.Details.Families, "clip")
|
||||||
|
opts.ParentModel = info.Details.ParentModel
|
||||||
|
opts.Messages = append(opts.Messages, info.Messages...)
|
||||||
|
|
||||||
|
if interactive {
|
||||||
|
return generateInteractive(cmd, opts)
|
||||||
|
}
|
||||||
|
return embed(cmd, opts)
|
||||||
|
}
|
||||||
|
|
||||||
func errFromUnknownKey(unknownKeyErr error) error {
|
func errFromUnknownKey(unknownKeyErr error) error {
|
||||||
// find SSH public key in the error message
|
// find SSH public key in the error message
|
||||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||||
@ -979,6 +1070,154 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
|||||||
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func embed(cmd *cobra.Command, opts runOptions) error {
|
||||||
|
line := opts.Prompt
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("error: couldn't connect to ollama server")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs := strings.Split(line, "\n\n")
|
||||||
|
|
||||||
|
req := &api.EmbedRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Input: inputs,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Embed(cmd.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("error: couldn't get embeddings")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings := resp.Embeddings
|
||||||
|
|
||||||
|
r, c := len(embeddings), len(embeddings[0])
|
||||||
|
data := make([]float64, r*c)
|
||||||
|
for i := range r {
|
||||||
|
for j := range c {
|
||||||
|
data[i*c+j] = float64(embeddings[i][j])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
X := mat.NewDense(r, c, data)
|
||||||
|
|
||||||
|
// Initialize PCA
|
||||||
|
var pca stat.PC
|
||||||
|
|
||||||
|
// Perform PCA
|
||||||
|
if !pca.PrincipalComponents(X, nil) {
|
||||||
|
return fmt.Errorf("PCA failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract principal component vectors
|
||||||
|
var vectors mat.Dense
|
||||||
|
pca.VectorsTo(&vectors)
|
||||||
|
|
||||||
|
// // Extract variances of the principal components
|
||||||
|
// var variances []float64
|
||||||
|
// variances = pca.VarsTo(variances)
|
||||||
|
|
||||||
|
W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
|
||||||
|
|
||||||
|
// Perform PCA reduction
|
||||||
|
var reducedData mat.Dense
|
||||||
|
reducedData.Mul(X, W)
|
||||||
|
|
||||||
|
for i, s := range inputs {
|
||||||
|
row := reducedData.RowView(i)
|
||||||
|
fmt.Print(i+1, ". ", s, "\n")
|
||||||
|
fmt.Printf("[%v, %v]\n\n", row.AtVec(0), row.AtVec(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
points := make(plotter.XYs, reducedData.RawMatrix().Rows)
|
||||||
|
for i := range len(points) {
|
||||||
|
row := reducedData.RowView(i)
|
||||||
|
points[i].X = row.AtVec(0)
|
||||||
|
points[i].Y = row.AtVec(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new plot
|
||||||
|
p := plot.New()
|
||||||
|
|
||||||
|
// Set plot title and axis labels
|
||||||
|
p.Title.Text = "Embedding Map"
|
||||||
|
|
||||||
|
// Create a scatter plot of the points
|
||||||
|
s, err := plotter.NewScatter(points)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
p.Add(s)
|
||||||
|
|
||||||
|
/// Create labels plotter and add it to the plot
|
||||||
|
|
||||||
|
labels := make([]string, reducedData.RawMatrix().Rows)
|
||||||
|
for i := range len(labels) {
|
||||||
|
labels[i] = fmt.Sprintf("%d", i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// plotter := plotter
|
||||||
|
|
||||||
|
l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
p.Add(l)
|
||||||
|
|
||||||
|
// Make the grid square
|
||||||
|
p.X.Min = -1
|
||||||
|
p.X.Max = 1
|
||||||
|
p.Y.Min = -1
|
||||||
|
p.Y.Max = 1
|
||||||
|
|
||||||
|
// Set the aspect ratio to be 1:1
|
||||||
|
p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
||||||
|
{Value: -1, Label: "-1"},
|
||||||
|
{Value: -0.5, Label: "-0.5"},
|
||||||
|
{Value: 0, Label: "0"},
|
||||||
|
{Value: 0.5, Label: "0.5"},
|
||||||
|
{Value: 1, Label: "1"},
|
||||||
|
})
|
||||||
|
p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
||||||
|
{Value: -1, Label: "-1"},
|
||||||
|
{Value: -0.5, Label: "-0.5"},
|
||||||
|
{Value: 0, Label: "0"},
|
||||||
|
{Value: 0.5, Label: "0.5"},
|
||||||
|
{Value: 1, Label: "1"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Save the plot to a svg file
|
||||||
|
if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.svg"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// open the plot
|
||||||
|
open := exec.Command("open", "plot.svg")
|
||||||
|
err = open.Run()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("error: couldn't open plot")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for Enter key press
|
||||||
|
fmt.Print("Press 'Enter' to continue")
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
_, _ = reader.ReadString('\n')
|
||||||
|
|
||||||
|
// close and delete the plot (defer this)
|
||||||
|
defer func() {
|
||||||
|
delete := exec.Command("rm", "plot.svg")
|
||||||
|
err = delete.Run()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("error: couldn't delete plot")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func generate(cmd *cobra.Command, opts runOptions) error {
|
func generate(cmd *cobra.Command, opts runOptions) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1247,11 +1486,26 @@ func NewCLI() *cobra.Command {
|
|||||||
RunE: RunHandler,
|
RunE: RunHandler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
embedCmd := &cobra.Command{
|
||||||
|
Use: "embed MODEL [PROMPT]",
|
||||||
|
Short: "Embed a model",
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
PreRunE: checkServerHeartbeat,
|
||||||
|
RunE: EmbedHandler,
|
||||||
|
}
|
||||||
|
|
||||||
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
||||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||||
|
|
||||||
|
embedCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
||||||
|
embedCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||||
|
embedCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||||
|
embedCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||||
|
embedCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||||
|
|
||||||
serveCmd := &cobra.Command{
|
serveCmd := &cobra.Command{
|
||||||
Use: "serve",
|
Use: "serve",
|
||||||
Aliases: []string{"start"},
|
Aliases: []string{"start"},
|
||||||
@ -1326,6 +1580,7 @@ func NewCLI() *cobra.Command {
|
|||||||
copyCmd,
|
copyCmd,
|
||||||
deleteCmd,
|
deleteCmd,
|
||||||
serveCmd,
|
serveCmd,
|
||||||
|
embedCmd,
|
||||||
} {
|
} {
|
||||||
switch cmd {
|
switch cmd {
|
||||||
case runCmd:
|
case runCmd:
|
||||||
@ -1361,6 +1616,7 @@ func NewCLI() *cobra.Command {
|
|||||||
psCmd,
|
psCmd,
|
||||||
copyCmd,
|
copyCmd,
|
||||||
deleteCmd,
|
deleteCmd,
|
||||||
|
embedCmd,
|
||||||
)
|
)
|
||||||
|
|
||||||
return rootCmd
|
return rootCmd
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -14,11 +13,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"gonum.org/v1/gonum/mat"
|
|
||||||
"gonum.org/v1/gonum/stat"
|
|
||||||
"gonum.org/v1/plot"
|
|
||||||
"gonum.org/v1/plot/plotter"
|
|
||||||
"gonum.org/v1/plot/vg"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
@ -453,140 +447,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
}
|
}
|
||||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||||
return nil
|
return nil
|
||||||
case strings.HasPrefix(line, "/embed"):
|
|
||||||
line = strings.TrimPrefix(line, "/embed")
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("error: couldn't connect to ollama server")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var strArray []string
|
|
||||||
fmt.Printf("line is %s\n", line)
|
|
||||||
err = json.Unmarshal([]byte(line), &strArray)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("error: couldn't parse input")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, s := range strArray {
|
|
||||||
fmt.Printf("strArray[%d] is %s\n", i, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &api.EmbedRequest{
|
|
||||||
Model: opts.Model,
|
|
||||||
Input: strArray,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Embed(cmd.Context(), req)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("error: couldn't get embeddings")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
embeddings := resp.Embeddings
|
|
||||||
|
|
||||||
r, c := len(embeddings), len(embeddings[0])
|
|
||||||
data := make([]float64, r*c)
|
|
||||||
for i := 0; i < r; i++ {
|
|
||||||
for j := 0; j < c; j++ {
|
|
||||||
data[i*c+j] = float64(embeddings[i][j])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
X := mat.NewDense(r, c, data)
|
|
||||||
|
|
||||||
// Initialize PCA
|
|
||||||
var pca stat.PC
|
|
||||||
|
|
||||||
// Perform PCA
|
|
||||||
if !pca.PrincipalComponents(X, nil) {
|
|
||||||
return fmt.Errorf("PCA failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract principal component vectors
|
|
||||||
var vectors mat.Dense
|
|
||||||
pca.VectorsTo(&vectors)
|
|
||||||
|
|
||||||
// // Extract variances of the principal components
|
|
||||||
// var variances []float64
|
|
||||||
// variances = pca.VarsTo(variances)
|
|
||||||
|
|
||||||
W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
|
|
||||||
|
|
||||||
// Perform PCA reduction
|
|
||||||
var reducedData mat.Dense
|
|
||||||
reducedData.Mul(X, W)
|
|
||||||
|
|
||||||
// Print the projected 2D points
|
|
||||||
fmt.Println("Reduced embeddings to 2D:")
|
|
||||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
|
||||||
row := reducedData.RowView(i)
|
|
||||||
fmt.Printf("[%v, %v]\n", row.AtVec(0), row.AtVec(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
points := make(plotter.XYs, reducedData.RawMatrix().Rows)
|
|
||||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
|
||||||
row := reducedData.RowView(i)
|
|
||||||
points[i].X = row.AtVec(0)
|
|
||||||
points[i].Y = row.AtVec(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new plot
|
|
||||||
p := plot.New()
|
|
||||||
|
|
||||||
// Set plot title and axis labels
|
|
||||||
p.Title.Text = "2D Data Plot"
|
|
||||||
p.X.Label.Text = "X"
|
|
||||||
p.Y.Label.Text = "Y"
|
|
||||||
|
|
||||||
// Create a scatter plot of the points
|
|
||||||
s, err := plotter.NewScatter(points)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
p.Add(s)
|
|
||||||
|
|
||||||
/// Create labels plotter and add it to the plot
|
|
||||||
|
|
||||||
labels := make([]string, reducedData.RawMatrix().Rows)
|
|
||||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
|
||||||
labels[i] = fmt.Sprintf("%d", i+1)
|
|
||||||
}
|
|
||||||
|
|
||||||
l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
p.Add(l)
|
|
||||||
|
|
||||||
// Make the grid square
|
|
||||||
p.X.Min = -1
|
|
||||||
p.X.Max = 1
|
|
||||||
p.Y.Min = -1
|
|
||||||
p.Y.Max = 1
|
|
||||||
|
|
||||||
// Set the aspect ratio to be 1:1
|
|
||||||
p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
|
||||||
{Value: -1, Label: "-1"},
|
|
||||||
{Value: -0.5, Label: "-0.5"},
|
|
||||||
{Value: 0, Label: "0"},
|
|
||||||
{Value: 0.5, Label: "0.5"},
|
|
||||||
{Value: 1, Label: "1"},
|
|
||||||
})
|
|
||||||
p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
|
||||||
{Value: -1, Label: "-1"},
|
|
||||||
{Value: -0.5, Label: "-0.5"},
|
|
||||||
{Value: 0, Label: "0"},
|
|
||||||
{Value: 0.5, Label: "0.5"},
|
|
||||||
{Value: 1, Label: "1"},
|
|
||||||
})
|
|
||||||
|
|
||||||
// Save the plot to a PNG file
|
|
||||||
if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.png"); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
case strings.HasPrefix(line, "/"):
|
case strings.HasPrefix(line, "/"):
|
||||||
args := strings.Fields(line)
|
args := strings.Fields(line)
|
||||||
isFile := false
|
isFile := false
|
||||||
|
Loading…
x
Reference in New Issue
Block a user