Skip to content

Commit

Permalink
feat: pull existing or new model (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
sammcj authored Jul 14, 2024
1 parent ec855a7 commit c4829e0
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 26 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,13 @@ echo "alias g=gollama" >> ~/.zshrc
- `Space`: Select
- `Enter`: Run model (Ollama run)
- `i`: Inspect model
- `t`: Top (show running models) _**(Work in progress)**_
- `t`: Top (show running models)
- `D`: Delete model
- `e`: Edit model **new**
- `c`: Copy model
- `U`: Unload all models
- `p`: Pull model **new** _**(Work in progress, no progress bar yet)**_
- `p`: Pull an existing model **new**
- `g`: Pull (get) new model **new**
- `P`: Push model
- `n`: Sort by name
- `s`: Sort by size
Expand Down
173 changes: 158 additions & 15 deletions app_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import (
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/table"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"

"github.com/sammcj/gollama/logging"
)

type View int

const (
MainView View = iota
TopView
Expand All @@ -41,6 +41,46 @@ var topRunning = false

func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd

if m.pulling {
switch msg := msg.(type) {
case tea.KeyMsg:
if m.newModelPull {
switch msg.Type {
case tea.KeyEnter:
m.newModelPull = false
m.pullProgress = 0.01 // Start progress immediately
return m, tea.Batch(
m.startPullModel(m.pullInput.Value()),
m.updateProgressCmd(),
)
case tea.KeyCtrlC, tea.KeyEsc:
m.pulling = false
m.newModelPull = false
m.pullInput.Reset()
return m, nil
}
var cmd tea.Cmd
m.pullInput, cmd = m.pullInput.Update(msg)
return m, cmd
} else {
if msg.Type == tea.KeyCtrlC {
m.pulling = false
m.pullProgress = 0
return m, nil
}
}
case pullSuccessMsg:
return m.handlePullSuccessMsg(msg)
case pullErrorMsg:
return m.handlePullErrorMsg(msg)
case progressMsg:
if m.pullProgress < 1.0 {
m.pullProgress = msg.progress
return m, m.updateProgressCmd()
}
}
}
switch msg := msg.(type) {
case tea.KeyMsg:
return m.handleKeyMsg(msg)
Expand All @@ -61,7 +101,6 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.height = msg.Height
m.list.SetSize(m.width, m.height)
return m, nil

default:
m.list, cmd = m.list.Update(msg)
return m, cmd
Expand Down Expand Up @@ -101,6 +140,18 @@ func (m *AppModel) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {

// Handle other keys
switch msg.String() {
case "ctrl+c":
if m.pulling {
m.pulling = false
m.pullProgress = 0
m.pullInput.Reset()
return m, nil
}
if m.editing {
m.editing = false
return m, nil
}
return m, tea.Quit
case "q":
if m.list.FilterState() == list.FilterApplied {
logging.DebugLogger.Println("Clearing filter with 'q' key")
Expand Down Expand Up @@ -129,12 +180,6 @@ func (m *AppModel) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
} else {
return m, nil
}
case "ctrl+c":
if m.editing {
m.editing = false
return m, nil
}
return m, tea.Quit
}

if m.confirmDeletion {
Expand Down Expand Up @@ -197,6 +242,8 @@ func (m *AppModel) handleKeyMsg(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
return m.handlePullModelKey()
case key.Matches(msg, m.keys.RenameModel):
return m.handleRenameModelKey()
case key.Matches(msg, m.keys.PullNewModel):
return m.handlePullNewModelKey()
case key.Matches(msg, m.keys.InspectModel):
return m.handleInspectModelKey()
case key.Matches(msg, m.keys.Top):
Expand Down Expand Up @@ -313,14 +360,36 @@ func (m *AppModel) handlePushErrorMsg(msg pushErrorMsg) (tea.Model, tea.Cmd) {
}

func (m *AppModel) handlePullSuccessMsg(msg pullSuccessMsg) (tea.Model, tea.Cmd) {
m.message = fmt.Sprintf("Successfully pulled model: %s\n", msg.modelName)
return m, nil
m.pulling = false
m.newModelPull = false
m.pullProgress = 0
m.message = fmt.Sprintf("Successfully pulled model: %s", msg.modelName)
return m, tea.Batch(
m.refreshModelsAfterPull(),
func() tea.Msg {
// This will force a refresh of the main view
return tea.WindowSizeMsg{Width: m.width, Height: m.height}
},
)
}

func (m *AppModel) handlePullErrorMsg(msg pullErrorMsg) (tea.Model, tea.Cmd) {
logging.ErrorLogger.Printf("Error pulling model: %v\n", msg.err)
m.message = fmt.Sprintf("Error pulling model: %v\n", msg.err)
return m, nil
m.pulling = false
m.pullProgress = 0
m.message = fmt.Sprintf("Error pulling model: %v", msg.err)
return m, func() tea.Msg {
// This will force a refresh of the main view
return tea.WindowSizeMsg{Width: m.width, Height: m.height}
}
}

func (m *AppModel) updateProgressCmd() tea.Cmd {
return tea.Tick(time.Millisecond*100, func(t time.Time) tea.Msg {
return progressMsg{
modelName: m.pullInput.Value(),
progress: m.pullProgress,
}
})
}

func (m *AppModel) handleGenericMsg(msg genericMsg) (tea.Model, tea.Cmd) {
Expand Down Expand Up @@ -552,15 +621,59 @@ func (m *AppModel) handlePushModelKey() (tea.Model, tea.Cmd) {
}

func (m *AppModel) handlePullModelKey() (tea.Model, tea.Cmd) {
// TODO: Add progress bar
logging.DebugLogger.Println("PullModel key matched")
if item, ok := m.list.SelectedItem().(Model); ok {
m.message = lipgloss.NewStyle().Foreground(lipgloss.Color("129")).Render(fmt.Sprintf("Pulling model: %s\n", item.Name))
m.pulling = true
m.pullProgress = 0
return m, m.startPullModel(item.Name)
}
return m, nil
}

func (m *AppModel) handlePullNewModelKey() (tea.Model, tea.Cmd) {
m.pullInput = textinput.New()
m.pullInput.Placeholder = "Enter model name (e.g. llama3:8b-instruct)"
m.pullInput.Focus()
m.pulling = true
m.newModelPull = true
return m, textinput.Blink
}

func (m *AppModel) updatePullInput(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.pullInput, cmd = m.pullInput.Update(msg)

switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.Type {
case tea.KeyEnter:
return m, m.startPullModel(m.pullInput.Value())
case tea.KeyEsc:
m.pulling = false
m.pullInput.Reset()
return m, nil
}
}

return m, cmd
}

func (m *AppModel) startPullNewModel(modelName string) tea.Cmd {
return func() tea.Msg {
ctx := context.Background()
req := &api.PullRequest{Name: modelName}
err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error {
m.pullProgress = float64(resp.Completed) / float64(resp.Total)
return nil
})
if err != nil {
return pullErrorMsg{err}
}
return pullSuccessMsg{modelName}
}
}

func (m *AppModel) handleInspectModelKey() (tea.Model, tea.Cmd) {
logging.DebugLogger.Println("InspectModel key matched")
selectedItem := m.list.SelectedItem()
Expand Down Expand Up @@ -629,6 +742,22 @@ func (m *AppModel) View() string {
return m.filterView()
}

if m.pulling {
if m.newModelPull && m.pullProgress == 0 {
return fmt.Sprintf(
"%s\n%s",
"Enter model name to pull:",
m.pullInput.View(),
)
}
return fmt.Sprintf(
"Pulling model: %.0f%%\n%s\n%s",
m.pullProgress*100,
m.progress.ViewAs(m.pullProgress),
"Press Ctrl+C to cancel",
)
}

view := m.list.View()

if m.message != "" {
Expand Down Expand Up @@ -827,3 +956,17 @@ func (m *AppModel) printFullHelp() string {
return "\n" + t.View() + "\nPress 'q' or `esc` to return to the main view."

}

// Add this method to refresh the model list after pulling:
func (m *AppModel) refreshModelsAfterPull() tea.Cmd {
return func() tea.Msg {
ctx := context.Background()
resp, err := m.client.List(ctx)
if err != nil {
return pullErrorMsg{err}
}
m.models = parseAPIResponse(resp)
m.refreshList()
return nil
}
}
2 changes: 2 additions & 0 deletions keymap.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type KeyMap struct {
UnloadModels key.Binding
Help key.Binding
RenameModel key.Binding
PullNewModel key.Binding
SortOrder string
}

Expand All @@ -54,6 +55,7 @@ func NewKeyMap() *KeyMap {
LinkModel: key.NewBinding(key.WithKeys("l"), key.WithHelp("l", "link (L=all)")),
PushModel: key.NewBinding(key.WithKeys("P"), key.WithHelp("P", "push")),
PullModel: key.NewBinding(key.WithKeys("p"), key.WithHelp("p", "pull")),
PullNewModel: key.NewBinding(key.WithKeys("g"), key.WithHelp("g", "get")),
Quit: key.NewBinding(key.WithKeys("q")),
RunModel: key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "run")),
SortByFamily: key.NewBinding(key.WithKeys("f"), key.WithHelp("f", "^family")),
Expand Down
12 changes: 11 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/progress"
"github.com/charmbracelet/bubbles/table"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/ollama/ollama/api"
Expand Down Expand Up @@ -48,12 +49,16 @@ type AppModel struct {
altScreenActive bool
view View
showProgress bool
remoteHost bool
pullInput textinput.Model
pulling bool
pullProgress float64
newModelPull bool
}

// TODO: Refactor: we don't need unique message types for every single action
type progressMsg struct {
modelName string
progress float64
}

type runFinishedMessage struct{ err error }
Expand All @@ -78,6 +83,8 @@ type genericMsg struct {
message string
}

type View int

var Version string // Version is set by the build system

func main() {
Expand Down Expand Up @@ -199,6 +206,9 @@ func main() {
noCleanup: *noCleanupFlag,
cfg: &cfg,
progress: progress.New(progress.WithDefaultGradient()),
pullInput: textinput.New(),
pulling: false,
pullProgress: 0,
}

if *ollamaDirFlag == "" {
Expand Down
33 changes: 25 additions & 8 deletions operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,27 @@ func (m *AppModel) startPushModel(modelName string) tea.Cmd {
}

func (m *AppModel) startPullModel(modelName string) tea.Cmd {
logging.InfoLogger.Printf("Pulling model: %s\n", modelName)
return m.pullModelCmd(modelName)
return func() tea.Msg {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

req := &api.PullRequest{Name: modelName}
err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error {
if !m.pulling {
return context.Canceled
}
m.pullProgress = float64(resp.Completed) / float64(resp.Total)
return nil
})

if err == context.Canceled {
return pullErrorMsg{fmt.Errorf("pull cancelled")}
}
if err != nil {
return pullErrorMsg{err}
}
return pullSuccessMsg{modelName}
}
}

func (m *AppModel) pushModelCmd(modelName string) tea.Cmd {
Expand All @@ -114,15 +133,13 @@ func (m *AppModel) pullModelCmd(modelName string) tea.Cmd {
ctx := context.Background()
req := &api.PullRequest{Name: modelName}
err := m.client.Pull(ctx, req, func(resp api.ProgressResponse) error {
m.progress.SetPercent(float64(resp.Completed) / float64(resp.Total))
m.pullProgress = float64(resp.Completed) / float64(resp.Total)
return nil
})
// If the progress is 100%, the model has been successfully pulled, return a success message
if err == nil {
return pullSuccessMsg{modelName}
if err != nil {
return pullErrorMsg{err}
}
// If the progress is not 100%, the model has not been successfully pulled, return an error message
return pullErrorMsg{err}
return pullSuccessMsg{modelName}
}
}

Expand Down

0 comments on commit c4829e0

Please sign in to comment.