just for fun
This commit is contained in:
parent
918fd32884
commit
42009d2974
256
cmd/cmd.go
256
cmd/cmd.go
@ -2,6 +2,7 @@ package cmd
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
@ -16,6 +17,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@ -31,6 +33,11 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
"gonum.org/v1/gonum/stat"
|
||||
"gonum.org/v1/plot"
|
||||
"gonum.org/v1/plot/plotter"
|
||||
"gonum.org/v1/plot/vg"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
@ -370,6 +377,90 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
func EmbedHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
keepAlive, err := cmd.Flags().GetString("keepalive")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if keepAlive != "" {
|
||||
d, err := time.ParseDuration(keepAlive)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.KeepAlive = &api.Duration{Duration: d}
|
||||
}
|
||||
|
||||
prompts := args[1:]
|
||||
// prepend stdin to the prompt if provided
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
in, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
opts.WordWrap = false
|
||||
interactive = false
|
||||
}
|
||||
opts.Prompt = strings.Join(prompts, " ")
|
||||
if len(prompts) > 0 {
|
||||
interactive = false
|
||||
}
|
||||
|
||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
// Fill out the rest of the options based on information about the
|
||||
// model.
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
info, err := func() (*api.ShowResponse, error) {
|
||||
showReq := &api.ShowRequest{Name: name}
|
||||
info, err := client.Show(cmd.Context(), showReq)
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
}
|
||||
return info, err
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Details.Families, "clip")
|
||||
opts.ParentModel = info.Details.ParentModel
|
||||
opts.Messages = append(opts.Messages, info.Messages...)
|
||||
|
||||
if interactive {
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return embed(cmd, opts)
|
||||
}
|
||||
|
||||
func errFromUnknownKey(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
@ -979,6 +1070,154 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
||||
}
|
||||
|
||||
func embed(cmd *cobra.Command, opts runOptions) error {
|
||||
line := opts.Prompt
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
|
||||
inputs := strings.Split(line, "\n\n")
|
||||
|
||||
req := &api.EmbedRequest{
|
||||
Model: opts.Model,
|
||||
Input: inputs,
|
||||
}
|
||||
|
||||
resp, err := client.Embed(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get embeddings")
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings := resp.Embeddings
|
||||
|
||||
r, c := len(embeddings), len(embeddings[0])
|
||||
data := make([]float64, r*c)
|
||||
for i := range r {
|
||||
for j := range c {
|
||||
data[i*c+j] = float64(embeddings[i][j])
|
||||
}
|
||||
}
|
||||
|
||||
X := mat.NewDense(r, c, data)
|
||||
|
||||
// Initialize PCA
|
||||
var pca stat.PC
|
||||
|
||||
// Perform PCA
|
||||
if !pca.PrincipalComponents(X, nil) {
|
||||
return fmt.Errorf("PCA failed")
|
||||
}
|
||||
|
||||
// Extract principal component vectors
|
||||
var vectors mat.Dense
|
||||
pca.VectorsTo(&vectors)
|
||||
|
||||
// // Extract variances of the principal components
|
||||
// var variances []float64
|
||||
// variances = pca.VarsTo(variances)
|
||||
|
||||
W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
|
||||
|
||||
// Perform PCA reduction
|
||||
var reducedData mat.Dense
|
||||
reducedData.Mul(X, W)
|
||||
|
||||
for i, s := range inputs {
|
||||
row := reducedData.RowView(i)
|
||||
fmt.Print(i+1, ". ", s, "\n")
|
||||
fmt.Printf("[%v, %v]\n\n", row.AtVec(0), row.AtVec(1))
|
||||
}
|
||||
|
||||
points := make(plotter.XYs, reducedData.RawMatrix().Rows)
|
||||
for i := range len(points) {
|
||||
row := reducedData.RowView(i)
|
||||
points[i].X = row.AtVec(0)
|
||||
points[i].Y = row.AtVec(1)
|
||||
}
|
||||
|
||||
// Create a new plot
|
||||
p := plot.New()
|
||||
|
||||
// Set plot title and axis labels
|
||||
p.Title.Text = "Embedding Map"
|
||||
|
||||
// Create a scatter plot of the points
|
||||
s, err := plotter.NewScatter(points)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Add(s)
|
||||
|
||||
/// Create labels plotter and add it to the plot
|
||||
|
||||
labels := make([]string, reducedData.RawMatrix().Rows)
|
||||
for i := range len(labels) {
|
||||
labels[i] = fmt.Sprintf("%d", i+1)
|
||||
}
|
||||
|
||||
// plotter := plotter
|
||||
|
||||
l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Add(l)
|
||||
|
||||
// Make the grid square
|
||||
p.X.Min = -1
|
||||
p.X.Max = 1
|
||||
p.Y.Min = -1
|
||||
p.Y.Max = 1
|
||||
|
||||
// Set the aspect ratio to be 1:1
|
||||
p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
||||
{Value: -1, Label: "-1"},
|
||||
{Value: -0.5, Label: "-0.5"},
|
||||
{Value: 0, Label: "0"},
|
||||
{Value: 0.5, Label: "0.5"},
|
||||
{Value: 1, Label: "1"},
|
||||
})
|
||||
p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
||||
{Value: -1, Label: "-1"},
|
||||
{Value: -0.5, Label: "-0.5"},
|
||||
{Value: 0, Label: "0"},
|
||||
{Value: 0.5, Label: "0.5"},
|
||||
{Value: 1, Label: "1"},
|
||||
})
|
||||
|
||||
// Save the plot to a svg file
|
||||
if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.svg"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// open the plot
|
||||
open := exec.Command("open", "plot.svg")
|
||||
err = open.Run()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't open plot")
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for Enter key press
|
||||
fmt.Print("Press 'Enter' to continue")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
_, _ = reader.ReadString('\n')
|
||||
|
||||
// close and delete the plot (defer this)
|
||||
defer func() {
|
||||
delete := exec.Command("rm", "plot.svg")
|
||||
err = delete.Run()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't delete plot")
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@ -1247,11 +1486,26 @@ func NewCLI() *cobra.Command {
|
||||
RunE: RunHandler,
|
||||
}
|
||||
|
||||
embedCmd := &cobra.Command{
|
||||
Use: "embed MODEL [PROMPT]",
|
||||
Short: "Embed a model",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: EmbedHandler,
|
||||
}
|
||||
|
||||
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
|
||||
embedCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
||||
embedCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
embedCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
embedCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
embedCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
|
||||
serveCmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
Aliases: []string{"start"},
|
||||
@ -1326,6 +1580,7 @@ func NewCLI() *cobra.Command {
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
serveCmd,
|
||||
embedCmd,
|
||||
} {
|
||||
switch cmd {
|
||||
case runCmd:
|
||||
@ -1361,6 +1616,7 @@ func NewCLI() *cobra.Command {
|
||||
psCmd,
|
||||
copyCmd,
|
||||
deleteCmd,
|
||||
embedCmd,
|
||||
)
|
||||
|
||||
return rootCmd
|
||||
|
@ -1,7 +1,6 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -14,11 +13,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
"gonum.org/v1/gonum/stat"
|
||||
"gonum.org/v1/plot"
|
||||
"gonum.org/v1/plot/plotter"
|
||||
"gonum.org/v1/plot/vg"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@ -453,140 +447,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/embed"):
|
||||
line = strings.TrimPrefix(line, "/embed")
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't connect to ollama server")
|
||||
return err
|
||||
}
|
||||
|
||||
var strArray []string
|
||||
fmt.Printf("line is %s\n", line)
|
||||
err = json.Unmarshal([]byte(line), &strArray)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't parse input")
|
||||
return err
|
||||
}
|
||||
|
||||
for i, s := range strArray {
|
||||
fmt.Printf("strArray[%d] is %s\n", i, s)
|
||||
}
|
||||
|
||||
req := &api.EmbedRequest{
|
||||
Model: opts.Model,
|
||||
Input: strArray,
|
||||
}
|
||||
|
||||
resp, err := client.Embed(cmd.Context(), req)
|
||||
if err != nil {
|
||||
fmt.Println("error: couldn't get embeddings")
|
||||
return err
|
||||
}
|
||||
|
||||
embeddings := resp.Embeddings
|
||||
|
||||
r, c := len(embeddings), len(embeddings[0])
|
||||
data := make([]float64, r*c)
|
||||
for i := 0; i < r; i++ {
|
||||
for j := 0; j < c; j++ {
|
||||
data[i*c+j] = float64(embeddings[i][j])
|
||||
}
|
||||
}
|
||||
|
||||
X := mat.NewDense(r, c, data)
|
||||
|
||||
// Initialize PCA
|
||||
var pca stat.PC
|
||||
|
||||
// Perform PCA
|
||||
if !pca.PrincipalComponents(X, nil) {
|
||||
return fmt.Errorf("PCA failed")
|
||||
}
|
||||
|
||||
// Extract principal component vectors
|
||||
var vectors mat.Dense
|
||||
pca.VectorsTo(&vectors)
|
||||
|
||||
// // Extract variances of the principal components
|
||||
// var variances []float64
|
||||
// variances = pca.VarsTo(variances)
|
||||
|
||||
W := vectors.Slice(0, c, 0, 2).(*mat.Dense)
|
||||
|
||||
// Perform PCA reduction
|
||||
var reducedData mat.Dense
|
||||
reducedData.Mul(X, W)
|
||||
|
||||
// Print the projected 2D points
|
||||
fmt.Println("Reduced embeddings to 2D:")
|
||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
||||
row := reducedData.RowView(i)
|
||||
fmt.Printf("[%v, %v]\n", row.AtVec(0), row.AtVec(1))
|
||||
}
|
||||
|
||||
points := make(plotter.XYs, reducedData.RawMatrix().Rows)
|
||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
||||
row := reducedData.RowView(i)
|
||||
points[i].X = row.AtVec(0)
|
||||
points[i].Y = row.AtVec(1)
|
||||
}
|
||||
|
||||
// Create a new plot
|
||||
p := plot.New()
|
||||
|
||||
// Set plot title and axis labels
|
||||
p.Title.Text = "2D Data Plot"
|
||||
p.X.Label.Text = "X"
|
||||
p.Y.Label.Text = "Y"
|
||||
|
||||
// Create a scatter plot of the points
|
||||
s, err := plotter.NewScatter(points)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Add(s)
|
||||
|
||||
/// Create labels plotter and add it to the plot
|
||||
|
||||
labels := make([]string, reducedData.RawMatrix().Rows)
|
||||
for i := 0; i < reducedData.RawMatrix().Rows; i++ {
|
||||
labels[i] = fmt.Sprintf("%d", i+1)
|
||||
}
|
||||
|
||||
l, err := plotter.NewLabels(plotter.XYLabels{XYs: points, Labels: labels})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
p.Add(l)
|
||||
|
||||
// Make the grid square
|
||||
p.X.Min = -1
|
||||
p.X.Max = 1
|
||||
p.Y.Min = -1
|
||||
p.Y.Max = 1
|
||||
|
||||
// Set the aspect ratio to be 1:1
|
||||
p.X.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
||||
{Value: -1, Label: "-1"},
|
||||
{Value: -0.5, Label: "-0.5"},
|
||||
{Value: 0, Label: "0"},
|
||||
{Value: 0.5, Label: "0.5"},
|
||||
{Value: 1, Label: "1"},
|
||||
})
|
||||
p.Y.Tick.Marker = plot.ConstantTicks([]plot.Tick{
|
||||
{Value: -1, Label: "-1"},
|
||||
{Value: -0.5, Label: "-0.5"},
|
||||
{Value: 0, Label: "0"},
|
||||
{Value: 0.5, Label: "0.5"},
|
||||
{Value: 1, Label: "1"},
|
||||
})
|
||||
|
||||
// Save the plot to a PNG file
|
||||
if err := p.Save(6*vg.Inch, 6*vg.Inch, "plot.png"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
case strings.HasPrefix(line, "/"):
|
||||
args := strings.Fields(line)
|
||||
isFile := false
|
||||
|
Loading…
x
Reference in New Issue
Block a user