diff --git a/README.md b/README.md index 7d1f38d..79fd7fd 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ The application allows users to interactively select models, sort, filter, edit, - [go install (recommended)](#go-install-recommended) - [curl](#curl) - [Manually](#manually) + - [if "command not found: gollama"](#if-command-not-found-gollama) - [Usage](#usage) - [Key Bindings](#key-bindings) - [Top](#top) @@ -110,11 +111,11 @@ echo "alias g=gollama" >> ~/.zshrc - `i`: Inspect model - `t`: Top (show running models) - `D`: Delete model -- `e`: Edit model **new** +- `e`: Edit model - `c`: Copy model - `U`: Unload all models -- `p`: Pull an existing model **new** -- `g`: Pull (get) new model **new** +- `p`: Pull an existing model +- `ctrl+p`: Pull (get) new model - `P`: Push model - `n`: Sort by name - `s`: Sort by size @@ -159,7 +160,7 @@ Note: Requires Admin privileges if you're running Windows. - `-u`: Unload all running models - `-v`: Print the version and exit - `-h`, or `--host`: Specify the host for the Ollama API -- `-H`: Shortcut for `-h http://localhost:11434` (connect to local Ollama API) **new** +- `-H`: Shortcut for `-h http://localhost:11434` (connect to local Ollama API) - `--vram`: Estimate vRAM usage for a model. Accepts: - Ollama models (e.g. `llama3.1:8b-instruct-q6_K`, `qwen2:14b-q4_0`) - HuggingFace models (e.g. `NousResearch/Hermes-2-Theta-Llama-3-8B`) diff --git a/app_model.go b/app_model.go index 587c69e..ce6b8ac 100644 --- a/app_model.go +++ b/app_model.go @@ -87,9 +87,17 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m.handlePullErrorMsg(msg) case progressMsg: if m.pullProgress < 1.0 { - m.pullProgress = msg.progress - return m, m.updateProgressCmd() + return m, tea.Batch( + m.updateProgressCmd(), + func() tea.Msg { + return progressMsg{ + modelName: msg.modelName, + progress: m.pullProgress, + } + }, + ) } + return m, nil } } switch msg := msg.(type) { @@ -412,7 +420,7 @@ func (m *AppModel) handlePullErrorMsg(msg pullErrorMsg) (tea.Model, tea.Cmd) { } func (m *AppModel) updateProgressCmd() tea.Cmd { - return tea.Tick(time.Millisecond*100, func(t time.Time) tea.Msg { + return tea.Tick(time.Second, func(t time.Time) tea.Msg { return progressMsg{ modelName: m.pullInput.Value(), progress: m.pullProgress, diff --git a/keymap.go b/keymap.go index 6d55314..9a2bc24 100644 --- a/keymap.go +++ b/keymap.go @@ -57,7 +57,7 @@ func NewKeyMap() *KeyMap { LinkModel: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "link (L=all)")), PushModel: key.NewBinding(key.WithKeys("P"), key.WithHelp("P", "push")), PullModel: key.NewBinding(key.WithKeys("p"), key.WithHelp("p", "pull")), - PullNewModel: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "get")), + PullNewModel: key.NewBinding(key.WithKeys("ctrl+p"), key.WithHelp("ctrl+p", "pull new model")), Quit: key.NewBinding(key.WithKeys("q")), RunModel: key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "run")), SortByFamily: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "^family")), diff --git a/operations.go b/operations.go index 9e23d37..277aebe 100644 --- a/operations.go +++ b/operations.go @@ -96,22 +96,54 @@ func (m *AppModel) startPullModel(modelName string) tea.Cmd { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - req := &api.PullRequest{Name: modelName} - err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error { - if !m.pulling { - return context.Canceled + progressChan := make(chan float64) + errChan := make(chan error) + + go func() { + req := &api.PullRequest{Name: modelName} + err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error { + if !m.pulling { + return context.Canceled + } + progress := float64(resp.Completed) / float64(resp.Total) + m.pullProgress = progress + progressChan <- progress + return nil + }) + + if err == context.Canceled { + errChan <- fmt.Errorf("pull cancelled") + return } - m.pullProgress = float64(resp.Completed) / float64(resp.Total) - return nil - }) + if err != nil { + errChan <- err + return + } + close(progressChan) + }() - if err == context.Canceled { - return pullErrorMsg{fmt.Errorf("pull cancelled")} - } - if err != nil { - return pullErrorMsg{err} + // Start a ticker to send progress updates + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case err := <-errChan: + if err != nil { + return pullErrorMsg{err} + } + return pullSuccessMsg{modelName} + case <-ticker.C: + return progressMsg{ + modelName: modelName, + progress: m.pullProgress, + } + case progress := <-progressChan: + if progress >= 1.0 { + return pullSuccessMsg{modelName} + } + } } - return pullSuccessMsg{modelName} } }