diff --git a/api/platform/validation/machine.go b/api/platform/validation/machine.go index 37f8b2e62..069ab547e 100644 --- a/api/platform/validation/machine.go +++ b/api/platform/validation/machine.go @@ -23,6 +23,7 @@ import ( "fmt" "math" "net" + "strings" "time" "k8s.io/apimachinery/pkg/api/errors" @@ -137,9 +138,6 @@ func ValidateWorkerTimeOffset(fldPath *field.Path, worker *ssh.SSH, masters []*s func ValidateSSH(fldPath *field.Path, ip string, port int, user string, password []byte, privateKey []byte, passPhrase []byte) field.ErrorList { allErrs := field.ErrorList{} - if user != "root" { - allErrs = append(allErrs, field.Invalid(fldPath.Child("user"), user, "must be root")) - } for _, msg := range validation.IsValidIP(ip) { allErrs = append(allErrs, field.Invalid(fldPath.Child("ip"), ip, msg)) @@ -169,10 +167,13 @@ func ValidateSSH(fldPath *field.Path, ip string, port int, user string, password if err != nil { allErrs = append(allErrs, field.Invalid(fldPath, "", err.Error())) } else { - err = s.Ping() + output, err := s.CombinedOutput("whoami") if err != nil { allErrs = append(allErrs, field.Invalid(fldPath, "", err.Error())) } + if strings.TrimSpace(string(output)) != "root" { + allErrs = append(allErrs, field.Invalid(fldPath.Child("user"), user, `must be root or set sudo without password`)) + } } return allErrs diff --git a/pkg/platform/provider/baremetal/cluster/create.go b/pkg/platform/provider/baremetal/cluster/create.go index 2ee1e0441..7ca0d7066 100644 --- a/pkg/platform/provider/baremetal/cluster/create.go +++ b/pkg/platform/provider/baremetal/cluster/create.go @@ -244,7 +244,7 @@ func (p *Provider) EnsureDisableSwap(ctx context.Context, c *v1.Cluster) error { return err } - _, err = machineSSH.CombinedOutput("swapoff -a && sed -i 's/^[^#]*swap/#&/' /etc/fstab") + _, err = machineSSH.CombinedOutput(`swapoff -a && sed -i "s/^[^#]*swap/#&/" /etc/fstab`) if err != nil { return errors.Wrap(err, machine.IP) } diff --git a/pkg/platform/provider/baremetal/machine/create.go b/pkg/platform/provider/baremetal/machine/create.go index 27b8a3b16..4156c9978 100644 --- a/pkg/platform/provider/baremetal/machine/create.go +++ b/pkg/platform/provider/baremetal/machine/create.go @@ -229,7 +229,7 @@ func (p *Provider) EnsureDisableSwap(ctx context.Context, machine *platformv1.Ma return err } - _, err = machineSSH.CombinedOutput("swapoff -a && sed -i 's/^[^#]*swap/#&/' /etc/fstab") + _, err = machineSSH.CombinedOutput(`swapoff -a && sed -i "s/^[^#]*swap/#&/" /etc/fstab`) if err != nil { return err } diff --git a/pkg/platform/provider/baremetal/phases/kubelet/kubelet.go b/pkg/platform/provider/baremetal/phases/kubelet/kubelet.go index a1357d6a5..fa85ad47d 100644 --- a/pkg/platform/provider/baremetal/phases/kubelet/kubelet.go +++ b/pkg/platform/provider/baremetal/phases/kubelet/kubelet.go @@ -38,7 +38,7 @@ func Install(s ssh.Interface, version string) (err error) { for _, file := range []string{"kubelet", "kubectl"} { file = path.Join(constants.DstBinDir, file) - if _, err := s.Stat(file); err == nil { + if ok, err := s.Exist(file); err == nil && ok { backupFile, err := ssh.BackupFile(s, file) if err != nil { return fmt.Errorf("backup file %q error: %w", file, err) diff --git a/pkg/platform/provider/baremetal/preflight/checks.go b/pkg/platform/provider/baremetal/preflight/checks.go index 610106254..ee9d4e218 100644 --- a/pkg/platform/provider/baremetal/preflight/checks.go +++ b/pkg/platform/provider/baremetal/preflight/checks.go @@ -242,8 +242,7 @@ func (fac FileAvailableCheck) Name() string { // Check validates if the given file does not already exist. func (fac FileAvailableCheck) Check() (warnings, errorList []error) { - - if _, err := fac.Stat(fac.Path); err == nil { + if ok, err := fac.Exist(fac.Path); err == nil && ok { errorList = append(errorList, errors.Errorf("%s already exists", fac.Path)) } return nil, errorList @@ -365,8 +364,7 @@ func (dac DirAvailableCheck) Name() string { // Check validates if a directory does not exist or empty. func (dac DirAvailableCheck) Check() (warnings, errorList []error) { - - if _, err := dac.Stat(dac.Path); err == nil { + if ok, err := dac.Exist(dac.Path); err == nil && ok { errorList = append(errorList, errors.Errorf("%s already exists", dac.Path)) } diff --git a/pkg/util/cmdstring/cmdstring.go b/pkg/util/cmdstring/cmdstring.go index 8030dbc26..23725ecbe 100644 --- a/pkg/util/cmdstring/cmdstring.go +++ b/pkg/util/cmdstring/cmdstring.go @@ -22,7 +22,7 @@ import "fmt" // SetFileContent generates cmd for set file content. func SetFileContent(file, pattern, content string) string { - return fmt.Sprintf("grep -Pq '%s' %s && sed -i 's;%s;%s;g' %s|| echo '%s' >> %s", + return fmt.Sprintf(`grep -Pq "%s" %s && sed -i "s;%s;%s;g" %s|| echo "%s" >> %s`, pattern, file, pattern, content, file, content, file) diff --git a/pkg/util/ssh/helper.go b/pkg/util/ssh/helper.go new file mode 100644 index 000000000..35de79a37 --- /dev/null +++ b/pkg/util/ssh/helper.go @@ -0,0 +1,72 @@ +/* + * Tencent is pleased to support the open source community by making TKEStack + * available. + * + * Copyright (C) 2012-2020 Tencent. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the “License”); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * https://opensource.org/licenses/Apache-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an “AS IS” BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package ssh + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + + "golang.org/x/crypto/ssh" +) + +func MakePrivateKeySignerFromFile(key string) (ssh.Signer, error) { + // Create an actual signer. + buffer, err := ioutil.ReadFile(key) + if err != nil { + return nil, fmt.Errorf("error reading SSH key %s: '%v'", key, err) + } + return MakePrivateKeySigner(buffer, nil) +} + +func MakePrivateKeySigner(privateKey []byte, passPhrase []byte) (ssh.Signer, error) { + var signer ssh.Signer + var err error + if passPhrase == nil { + signer, err = ssh.ParsePrivateKey(privateKey) + } else { + signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, passPhrase) + } + if err != nil { + return nil, fmt.Errorf("error parsing SSH key: '%v'", err) + } + return signer, nil +} + +func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) { + buffer, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("error reading SSH key %s: '%v'", keyFile, err) + } + keyBlock, _ := pem.Decode(buffer) + if keyBlock == nil { + return nil, fmt.Errorf("error parsing SSH key %s: 'invalid PEM format'", keyFile) + } + key, err := x509.ParsePKIXPublicKey(keyBlock.Bytes) + if err != nil { + return nil, fmt.Errorf("error parsing SSH key %s: '%v'", keyFile, err) + } + rsaKey, ok := key.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("SSH key could not be parsed as rsa public key") + } + return rsaKey, nil +} diff --git a/pkg/util/ssh/interface.go b/pkg/util/ssh/interface.go new file mode 100644 index 000000000..f9a4ff7f7 --- /dev/null +++ b/pkg/util/ssh/interface.go @@ -0,0 +1,36 @@ +/* + * Tencent is pleased to support the open source community by making TKEStack + * available. + * + * Copyright (C) 2012-2020 Tencent. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the “License”); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * https://opensource.org/licenses/Apache-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an “AS IS” BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package ssh + +import "io" + +type Interface interface { + Ping() error + + CombinedOutput(cmd string) ([]byte, error) + Execf(format string, a ...interface{}) (stdout string, stderr string, exit int, err error) + Exec(cmd string) (stdout string, stderr string, exit int, err error) + + CopyFile(src, dst string) error + WriteFile(src io.Reader, dst string) error + ReadFile(filename string) ([]byte, error) + Exist(filename string) (bool, error) + + LookPath(file string) (string, error) +} diff --git a/pkg/util/ssh/helpers.go b/pkg/util/ssh/os.go similarity index 100% rename from pkg/util/ssh/helpers.go rename to pkg/util/ssh/os.go diff --git a/pkg/util/ssh/ssh.go b/pkg/util/ssh/ssh.go index 7f2b6f130..11f0d1411 100644 --- a/pkg/util/ssh/ssh.go +++ b/pkg/util/ssh/ssh.go @@ -21,9 +21,6 @@ package ssh import ( "bytes" "crypto/md5" - "crypto/rsa" - "crypto/x509" - "encoding/pem" "errors" "fmt" "io" @@ -31,6 +28,7 @@ import ( "net" "os" "path" + "strings" "time" "github.com/pkg/sftp" @@ -40,20 +38,23 @@ import ( "tkestack.io/tke/pkg/util/log" ) +const ( + tmpDir = "/tmp" +) + type SSH struct { - User string - Host string - Port int - addr string + *Config authMethods []ssh.AuthMethod dialer sshDialer - Retry int } +var _ Interface = &SSH{} + type Config struct { User string `validate:"required"` Host string `validate:"required"` Port int `validate:"required"` + Sudo bool Password string PrivateKey []byte PassPhrase []byte @@ -63,18 +64,8 @@ type Config struct { Retry int } -type Interface interface { - Ping() error - Exec(cmd string) (stdout string, stderr string, exit int, err error) - Execf(format string, a ...interface{}) (stdout string, stderr string, exit int, err error) - CombinedOutput(cmd string) ([]byte, error) - - CopyFile(src, dst string) error - WriteFile(src io.Reader, dst string) error - ReadFile(filename string) ([]byte, error) - Stat(p string) (os.FileInfo, error) - - LookPath(file string) (string, error) +func (c *Config) addr() string { + return fmt.Sprintf("%s:%d", c.Host, c.Port) } func New(c *Config) (*SSH, error) { @@ -98,20 +89,19 @@ func New(c *Config) (*SSH, error) { } authMethods = append(authMethods, ssh.PublicKeys(signer)) } - addr := fmt.Sprintf("%s:%d", c.Host, c.Port) if c.DialTimeOut == 0 { c.DialTimeOut = 5 * time.Second } + if c.User != "root" { + c.Sudo = true + } + return &SSH{ - User: c.User, - Host: c.Host, - Port: c.Port, - addr: addr, + Config: c, authMethods: authMethods, dialer: &timeoutDialer{&realSSHDialer{}, c.DialTimeOut}, - Retry: c.Retry, }, nil } @@ -137,32 +127,19 @@ func (s *SSH) Execf(format string, a ...interface{}) (stdout string, stderr stri } func (s *SSH) Exec(cmd string) (stdout string, stderr string, exit int, err error) { - log.Debugf("[%s] Exec %q", s.addr, cmd) - // Setup the config, dial the server, and open a session. - config := &ssh.ClientConfig{ - User: s.User, - Auth: s.authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - client, err := s.dialer.Dial("tcp", s.addr, config) - if err != nil && s.Retry > 0 { - err = wait.Poll(5*time.Second, time.Duration(s.Retry)*5*time.Second, func() (bool, error) { - if client, err = s.dialer.Dial("tcp", s.addr, config); err != nil { - return false, err - } - return true, nil - }) + if s.Sudo { + cmd = fmt.Sprintf(`sudo bash << EOF +%s +EOF +`, cmd) } - if err != nil { - return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", s.User, s.addr, err) - } - defer client.Close() + log.Debugf("[%s] Exec %q", s.addr(), cmd) - session, err := client.NewSession() + session, closer, err := s.newSession() if err != nil { - return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", s.User, s.addr, err) + return "", "", 0, err } - defer session.Close() + defer closer() // Run the command. code := 0 @@ -180,62 +157,20 @@ func (s *SSH) Exec(cmd string) (stdout string, stderr string, exit int, err erro } else { // Some other kind of error happened (e.g. an IOError); consider the // SSH unsuccessful. - err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, s.User, s.addr, err) + err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, s.User, s.addr(), err) } } return bout.String(), berr.String(), code, err } func (s *SSH) CopyFile(src, dst string) error { - data, err := ioutil.ReadFile(src) + file, err := os.Open(src) if err != nil { return err } - needWriteFile, err := s.needWriteFile(data, dst) - if err != nil { - return err - } - if !needWriteFile { - log.Debugf("[%s] Skip copy %q because already existed", s.addr, src) - return nil - } - - log.Debugf("[%s] Copy %q to %q", s.addr, src, dst) + defer file.Close() - config := &ssh.ClientConfig{ - User: s.User, - Auth: s.authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - client, err := s.dialer.Dial("tcp", s.addr, config) - if err != nil { - err = wait.Poll(5*time.Second, time.Duration(s.Retry)*5*time.Second, func() (bool, error) { - if client, err = s.dialer.Dial("tcp", s.addr, config); err != nil { - return false, err - } - return true, nil - }) - } - if err != nil { - return err - } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) - if err != nil { - return err - } - defer sftpClient.Close() - - sftpClient.MkdirAll(path.Dir(dst)) - dstFile, err := sftpClient.Create(dst) - if err != nil { - return fmt.Errorf("create file error:%s:%s", dst, err) - } - defer dstFile.Close() - - _, err = dstFile.ReadFrom(bytes.NewBuffer(data)) - return err + return s.WriteFile(file, dst) } func (s *SSH) WriteFile(src io.Reader, dst string) error { @@ -248,40 +183,46 @@ func (s *SSH) WriteFile(src io.Reader, dst string) error { return err } if !needWriteFile { - log.Debugf("[%s] Skip write %q because already existed", s.addr, dst) + log.Debugf("[%s] Skip write %q because already existed", s.addr(), dst) return nil } return s.writeFile(bytes.NewBuffer(data), dst) } -func (s *SSH) writeFile(src io.Reader, dst string) error { - log.Debugf("[%s] Write data to %q", s.addr, dst) +func (s *SSH) ReadFile(filename string) ([]byte, error) { + return s.CombinedOutput(fmt.Sprintf("cat %s", filename)) +} - config := &ssh.ClientConfig{ - User: s.User, - Auth: s.authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - client, err := s.dialer.Dial("tcp", s.addr, config) +func (s *SSH) Exist(filename string) (bool, error) { + _, _, exit, err := s.Execf("ls %s", filename) if err != nil { - err = wait.Poll(5*time.Second, time.Duration(s.Retry)*5*time.Second, func() (bool, error) { - if client, err = s.dialer.Dial("tcp", s.addr, config); err != nil { - return false, err - } - return true, nil - }) + return false, fmt.Errorf("ssh exec error: %w", err) } + + return exit == 0, nil +} + +func (s *SSH) LookPath(file string) (string, error) { + data, err := s.CombinedOutput(fmt.Sprintf("which %s", file)) + return string(data), err +} + +func (s *SSH) writeFile(src io.Reader, dst string) error { + log.Debugf("[%s] Write data to %q", s.addr(), dst) + + sftpClient, closer, err := s.newSFTPClient() if err != nil { return err } - defer client.Close() + defer closer() - sftpClient, err := sftp.NewClient(client) - if err != nil { - return err + needMove := false + realDst := dst + if !strings.HasPrefix(dst, tmpDir) { + needMove = true + dst = path.Join(tmpDir, dst) } - defer sftpClient.Close() err = sftpClient.MkdirAll(path.Dir(dst)) if err != nil { @@ -294,13 +235,23 @@ func (s *SSH) writeFile(src io.Reader, dst string) error { defer dstFile.Close() _, err = dstFile.ReadFrom(src) + if err != nil { + return err + } + if needMove { + _, err = s.CombinedOutput(fmt.Sprintf("mkdir -p $(dirname %s); mv %s %s", realDst, dst, realDst)) + if err != nil { + return err + } + } + return err } func (s *SSH) needWriteFile(data []byte, dst string) (bool, error) { srcHash := md5.Sum(data) - hashFile := "/tmp" + dst + ".md5" + hashFile := tmpDir + dst + ".md5" buffer := new(bytes.Buffer) buffer.WriteString(fmt.Sprintf("%x %s\n", srcHash, dst)) err := s.writeFile(buffer, hashFile) @@ -316,77 +267,70 @@ func (s *SSH) needWriteFile(data []byte, dst string) (bool, error) { return true, nil } -func (s *SSH) Stat(p string) (os.FileInfo, error) { - config := &ssh.ClientConfig{ - User: s.User, - Auth: s.authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), +func (s *SSH) newSFTPClient() (*sftp.Client, func(), error) { + client, closer, err := s.newClient() + if err != nil { + return nil, nil, err } - client, err := s.dialer.Dial("tcp", s.addr, config) + + sftpClient, err := sftp.NewClient(client) if err != nil { - err = wait.Poll(5*time.Second, time.Duration(s.Retry)*5*time.Second, func() (bool, error) { - if client, err = s.dialer.Dial("tcp", s.addr, config); err != nil { - return false, err - } - return true, nil - }) + return nil, nil, err } + + return sftpClient, + func() { + closer() + sftpClient.Close() + }, + nil +} + +// newClient returns ssh session and closer which need defer run! +func (s *SSH) newSession() (*ssh.Session, func(), error) { + client, closer, err := s.newClient() if err != nil { - return nil, err + return nil, nil, err } - defer client.Close() - sftpClient, err := sftp.NewClient(client) + session, err := client.NewSession() if err != nil { - return nil, err + return nil, nil, err } - defer sftpClient.Close() - return sftpClient.Stat(p) + return session, + func() { + closer() + session.Close() + }, + nil } -func (s *SSH) ReadFile(filename string) ([]byte, error) { +// newClient returns ssh client and closer which need defer run! +func (s *SSH) newClient() (*ssh.Client, func(), error) { config := &ssh.ClientConfig{ User: s.User, Auth: s.authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - client, err := s.dialer.Dial("tcp", s.addr, config) + client, err := s.dialer.Dial("tcp", s.addr(), config) if err != nil { err = wait.Poll(5*time.Second, time.Duration(s.Retry)*5*time.Second, func() (bool, error) { - if client, err = s.dialer.Dial("tcp", s.addr, config); err != nil { - return false, fmt.Errorf("read file %s error: %w", filename, err) + if client, err = s.dialer.Dial("tcp", s.addr(), config); err != nil { + return false, err } return true, nil }) } if err != nil { - return nil, fmt.Errorf("read file %s error: %w", filename, err) - } - defer client.Close() - - sftpClient, err := sftp.NewClient(client) - if err != nil { - return nil, fmt.Errorf("read file %s error: %w", filename, err) - } - defer sftpClient.Close() - - f, err := sftpClient.Open(filename) - if err != nil { - return nil, fmt.Errorf("read file %s error: %w", filename, err) - } - data := new(bytes.Buffer) - _, err = f.WriteTo(data) - if err != nil { - return nil, fmt.Errorf("read file %s error: %w", filename, err) + return nil, nil, err } - return data.Bytes(), nil -} - -func (s *SSH) LookPath(file string) (string, error) { - data, err := s.CombinedOutput(fmt.Sprintf("which %s", file)) - return string(data), err + return client, + func() { + client.Close() + }, + nil } // Interface to allow mocking of ssh.Dial, for testing SSH @@ -426,46 +370,3 @@ func (d *timeoutDialer) Dial(network, addr string, config *ssh.ClientConfig) (*s config.Timeout = d.timeout return d.dialer.Dial(network, addr, config) } - -func MakePrivateKeySignerFromFile(key string) (ssh.Signer, error) { - // Create an actual signer. - buffer, err := ioutil.ReadFile(key) - if err != nil { - return nil, fmt.Errorf("error reading SSH key %s: '%v'", key, err) - } - return MakePrivateKeySigner(buffer, nil) -} - -func MakePrivateKeySigner(privateKey []byte, passPhrase []byte) (ssh.Signer, error) { - var signer ssh.Signer - var err error - if passPhrase == nil { - signer, err = ssh.ParsePrivateKey(privateKey) - } else { - signer, err = ssh.ParsePrivateKeyWithPassphrase(privateKey, passPhrase) - } - if err != nil { - return nil, fmt.Errorf("error parsing SSH key: '%v'", err) - } - return signer, nil -} - -func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) { - buffer, err := ioutil.ReadFile(keyFile) - if err != nil { - return nil, fmt.Errorf("error reading SSH key %s: '%v'", keyFile, err) - } - keyBlock, _ := pem.Decode(buffer) - if keyBlock == nil { - return nil, fmt.Errorf("error parsing SSH key %s: 'invalid PEM format'", keyFile) - } - key, err := x509.ParsePKIXPublicKey(keyBlock.Bytes) - if err != nil { - return nil, fmt.Errorf("error parsing SSH key %s: '%v'", keyFile, err) - } - rsaKey, ok := key.(*rsa.PublicKey) - if !ok { - return nil, fmt.Errorf("SSH key could not be parsed as rsa public key") - } - return rsaKey, nil -} diff --git a/pkg/util/ssh/ssh_test.go b/pkg/util/ssh/ssh_test.go new file mode 100644 index 000000000..873fc3ddb --- /dev/null +++ b/pkg/util/ssh/ssh_test.go @@ -0,0 +1,128 @@ +/* + * Tencent is pleased to support the open source community by making TKEStack + * available. + * + * Copyright (C) 2012-2020 Tencent. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the “License”); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * https://opensource.org/licenses/Apache-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an “AS IS” BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package ssh_test + +import ( + "bytes" + "io/ioutil" + "os" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "tkestack.io/tke/pkg/util/ssh" + + // env for load env + _ "tkestack.io/tke/test/util/env" +) + +var s *ssh.SSH + +func init() { + port, err := strconv.Atoi(os.Getenv("SSH_PORT")) + utilruntime.Must(err) + s, _ = ssh.New(&ssh.Config{ + Host: os.Getenv("SSH_HOST"), + Port: port, + User: os.Getenv("SSH_USER"), + Password: os.Getenv("SSH_PASSWORD"), + }) +} + +func TestSudo(t *testing.T) { + output, err := s.CombinedOutput("whoami") + assert.Nil(t, err) + assert.Equal(t, "root", strings.TrimSpace(string(output))) +} + +func TestQuote(t *testing.T) { + output, err := s.CombinedOutput(`echo "a" 'b'`) + assert.Nil(t, err) + assert.Equal(t, "a b", strings.TrimSpace(string(output))) +} + +func TestWriteFile(t *testing.T) { + data := []byte("Hello") + dst := "/tmp/test" + + err := s.WriteFile(bytes.NewBuffer(data), dst) + assert.Nil(t, err) + + output, err := s.ReadFile(dst) + assert.Nil(t, err) + assert.Equal(t, data, output) +} + +func TestCoppyFile(t *testing.T) { + src := os.Args[0] + srcData, err := ioutil.ReadFile(src) + assert.Nil(t, err) + + dst := "/tmp/test" + err = s.CopyFile(src, dst) + assert.Nil(t, err) + + output, err := s.ReadFile(dst) + assert.Nil(t, err) + + assert.Equal(t, srcData, output) +} + +func TestExist(t *testing.T) { + type args struct { + filename string + } + tests := []struct { + name string + args args + want bool + wantErr bool + }{ + { + "exist", + args{ + filename: "/tmp", + }, + true, + false, + }, + { + "not exist", + args{ + filename: "/tmpfda", + }, + false, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := s.Exist(tt.args.filename) + if (err != nil) != tt.wantErr { + t.Errorf("Exist() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Exist() got = %v, want %v", got, tt.want) + } + }) + } +}