Skip to content

Commit

Permalink
fix: persist default provider deletion (#1288)
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Dagelic <[email protected]>
  • Loading branch information
idagelic authored Nov 1, 2024
1 parent c0b2bad commit 05aa014
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 45 deletions.
4 changes: 2 additions & 2 deletions pkg/api/controllers/provider/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ func InstallProvider(ctx *gin.Context) {
}
}

downloadPath, err := server.ProviderManager.DownloadProvider(ctx.Request.Context(), req.DownloadUrls, req.Name, true)
downloadPath, err := server.ProviderManager.DownloadProvider(ctx.Request.Context(), req.DownloadUrls, req.Name)
if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to download provider: %w", err))
return
}

err = server.ProviderManager.RegisterProvider(downloadPath)
err = server.ProviderManager.RegisterProvider(downloadPath, true)
if err != nil {
ctx.AbortWithError(http.StatusInternalServerError, fmt.Errorf("failed to register provider: %w", err))
return
Expand Down
22 changes: 22 additions & 0 deletions pkg/provider/manager/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 Daytona Platforms Inc.
// SPDX-License-Identifier: Apache-2.0

package manager

import "fmt"

func IsProviderAlreadyDownloaded(err error, name string) bool {
return err.Error() == providerAlreadyDownloadedError(name).Error()
}

func IsNoPluginFound(err error, dir string) bool {
return err.Error() == noPluginFoundError(dir).Error()
}

func providerAlreadyDownloadedError(name string) error {
return fmt.Errorf("provider %s already installed", name)
}

func noPluginFoundError(dir string) error {
return fmt.Errorf("no plugin found in %s", dir)
}
7 changes: 2 additions & 5 deletions pkg/provider/manager/installer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,14 @@ func (m *ProviderManager) GetProvidersManifest() (*ProvidersManifest, error) {
return &manifest, nil
}

func (m *ProviderManager) DownloadProvider(ctx context.Context, downloadUrls map[os.OperatingSystem]string, providerName string, throwIfPresent bool) (string, error) {
func (m *ProviderManager) DownloadProvider(ctx context.Context, downloadUrls map[os.OperatingSystem]string, providerName string) (string, error) {
downloadPath := filepath.Join(m.baseDir, providerName, providerName)
if runtime.GOOS == "windows" {
downloadPath += ".exe"
}

if _, err := goos.Stat(downloadPath); err == nil {
if throwIfPresent {
return "", fmt.Errorf("provider %s already downloaded", providerName)
}
return "", nil
return "", providerAlreadyDownloadedError(providerName)
}

log.Info("Downloading " + providerName)
Expand Down
87 changes: 59 additions & 28 deletions pkg/provider/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package manager
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
Expand All @@ -22,6 +23,8 @@ import (
log "github.com/sirupsen/logrus"
)

const INITIAL_SETUP_LOCK_FILE_NAME = "initial-setup.lock"

type pluginRef struct {
client *plugin.Client
path string
Expand All @@ -35,11 +38,11 @@ var ProviderHandshakeConfig = plugin.HandshakeConfig{
}

type IProviderManager interface {
DownloadProvider(ctx context.Context, downloadUrls map[os_util.OperatingSystem]string, providerName string, throwIfPresent bool) (string, error)
DownloadProvider(ctx context.Context, downloadUrls map[os_util.OperatingSystem]string, providerName string) (string, error)
GetProvider(name string) (*Provider, error)
GetProviders() map[string]Provider
GetProvidersManifest() (*ProvidersManifest, error)
RegisterProvider(pluginPath string) error
RegisterProvider(pluginPath string, manualInstall bool) error
TerminateProviderProcesses(providersBasePath string) error
UninstallProvider(name string) error
Purge() error
Expand Down Expand Up @@ -129,44 +132,48 @@ func (m *ProviderManager) GetProviders() map[string]Provider {
return providers
}

func (m *ProviderManager) RegisterProvider(pluginPath string) error {
func (m *ProviderManager) RegisterProvider(pluginPath string, manualInstall bool) error {
pluginRef, err := m.initializeProvider(pluginPath)
if err != nil {
return err
}

m.pluginRefs[pluginRef.name] = pluginRef

p, err := m.dispenseProvider(pluginRef.client, pluginRef.name)
if err != nil {
return err
}

existingTargets, err := m.providerTargetService.Map()
if err != nil {
return errors.New("failed to get targets: " + err.Error())
}

presetTargets, err := (*p).GetPresetTargets()
if err != nil {
return errors.New("failed to get preset targets: " + err.Error())
}
lockFilePath := filepath.Join(pluginRef.path, INITIAL_SETUP_LOCK_FILE_NAME)
_, err = os.Stat(lockFilePath)
if os.IsNotExist(err) || manualInstall {
p, err := m.GetProvider(pluginRef.name)
if err != nil {
return fmt.Errorf("failed to get provider: %w", err)
}

log.Info("Setting preset targets")
for _, target := range *presetTargets {
if _, ok := existingTargets[target.Name]; ok {
log.Infof("Target %s already exists. Skipping...", target.Name)
continue
existingTargets, err := m.providerTargetService.Map()
if err != nil {
return errors.New("failed to get targets: " + err.Error())
}

err := m.providerTargetService.Save(&target)
presetTargets, err := (*p).GetPresetTargets()
if err != nil {
log.Errorf("Failed to set target %s: %s", target.Name, err)
} else {
log.Infof("Target %s set", target.Name)
return errors.New("failed to get preset targets: " + err.Error())
}

log.Infof("Setting preset targets for %s", pluginRef.name)
for _, target := range *presetTargets {
if _, ok := existingTargets[target.Name]; ok {
log.Infof("Target %s already exists. Skipping...", target.Name)
continue
}

err := m.providerTargetService.Save(&target)
if err != nil {
log.Errorf("Failed to set target %s: %s", target.Name, err)
} else {
log.Infof("Target %s set", target.Name)
}
}
log.Infof("Preset targets set for %s", pluginRef.name)
}
log.Info("Preset targets set")

log.Infof("Provider %s initialized", pluginRef.name)

Expand All @@ -180,11 +187,35 @@ func (m *ProviderManager) UninstallProvider(name string) error {
}
pluginRef.client.Kill()

err := os.RemoveAll(pluginRef.path)
lockFileExisted := false
lockFilePath := filepath.Join(pluginRef.path, INITIAL_SETUP_LOCK_FILE_NAME)
_, err := os.Stat(lockFilePath)
if err == nil {
lockFileExisted = true
}

err = os.RemoveAll(pluginRef.path)
if err != nil {
return errors.New("failed to remove provider: " + err.Error())
}

if lockFileExisted {
// After clearing up the contents, remake the directory and add a lock file that
// will be used to ensure that the provider is not reinstalled automatically
err = os.MkdirAll(pluginRef.path, os.ModePerm)
if err != nil {
return err
}

lockFilePath := filepath.Join(pluginRef.path, INITIAL_SETUP_LOCK_FILE_NAME)

file, err := os.Create(lockFilePath)
if err != nil {
return err
}
defer file.Close()
}

delete(m.pluginRefs, name)

return nil
Expand Down
47 changes: 37 additions & 10 deletions pkg/server/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"path/filepath"

"github.com/daytonaio/daytona/pkg/provider/manager"
log "github.com/sirupsen/logrus"
)

Expand All @@ -22,9 +23,19 @@ func (s *Server) downloadDefaultProviders() error {

log.Info("Downloading default providers")
for providerName, provider := range defaultProviders {
_, err = s.ProviderManager.DownloadProvider(context.Background(), provider.DownloadUrls, providerName, false)
lockFilePath := filepath.Join(s.config.ProvidersDir, providerName, manager.INITIAL_SETUP_LOCK_FILE_NAME)

_, err := os.Stat(lockFilePath)
if err == nil {
continue
}

_, err = s.ProviderManager.DownloadProvider(context.Background(), provider.DownloadUrls, providerName)
if err != nil {
log.Error(err)
if !manager.IsProviderAlreadyDownloaded(err, providerName) {
log.Error(err)
}
continue
}
}

Expand All @@ -41,7 +52,7 @@ func (s *Server) registerProviders() error {
return err
}

files, err := os.ReadDir(s.config.ProvidersDir)
directoryEntries, err := os.ReadDir(s.config.ProvidersDir)
if err != nil {
if os.IsNotExist(err) {
log.Info("No providers found")
Expand All @@ -50,22 +61,38 @@ func (s *Server) registerProviders() error {
return err
}

for _, file := range files {
if file.IsDir() {
pluginPath, err := s.getPluginPath(filepath.Join(s.config.ProvidersDir, file.Name()))
for _, entry := range directoryEntries {
if entry.IsDir() {
providerDir := filepath.Join(s.config.ProvidersDir, entry.Name())

pluginPath, err := s.getPluginPath(providerDir)
if err != nil {
log.Error(err)
if !manager.IsNoPluginFound(err, providerDir) {
log.Error(err)
}
continue
}

err = s.ProviderManager.RegisterProvider(pluginPath)
err = s.ProviderManager.RegisterProvider(pluginPath, false)
if err != nil {
log.Error(err)
continue
}

// Lock the initial setup
lockFilePath := filepath.Join(s.config.ProvidersDir, entry.Name(), manager.INITIAL_SETUP_LOCK_FILE_NAME)

_, err = os.Stat(lockFilePath)
if err != nil {
file, err := os.Create(lockFilePath)
if err != nil {
return err
}
defer file.Close()
}

// Check for updates
provider, err := s.ProviderManager.GetProvider(file.Name())
provider, err := s.ProviderManager.GetProvider(entry.Name())
if err != nil {
log.Error(err)
continue
Expand Down Expand Up @@ -95,7 +122,7 @@ func (s *Server) getPluginPath(dir string) (string, error) {
}

for _, file := range files {
if !file.IsDir() {
if !file.IsDir() && file.Name() != manager.INITIAL_SETUP_LOCK_FILE_NAME {
return filepath.Join(dir, file.Name()), nil
}
}
Expand Down

0 comments on commit 05aa014

Please sign in to comment.