Skip to content

Commit

Permalink
context support for command cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
antonsergeyev committed Oct 1, 2021
1 parent bce9642 commit 87a38e1
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 7 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ You can find the docs at [go docs](https://pkg.go.dev/github.com/melbahja/goph).
- Supports connections with **ssh agent** (Unix systems only).
- Supports adding new hosts to **known_hosts file**.
- Supports **file system operations** like: `Open, Create, Chmod...`
- Supports **context.Context** for command cancellation.

## 📄  Usage

Expand Down Expand Up @@ -119,6 +120,14 @@ err := client.Download("/path/to/remote/file", "/path/to/local/file")
out, err := client.Run("bash -c 'printenv'")
```

#### ☛ Execute Bash Command with timeout:
```go
context, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
// will send SIGINT and return error after 1 second
out, err := client.RunContext(ctx, "sleep 5")
```

#### ☛ Execute Bash Command With Env Variables:
```go
out, err := client.Run(`env MYVAR="MY VALUE" bash -c 'echo $MYVAR;'`)
Expand All @@ -132,6 +141,9 @@ out, err := client.Run(`env MYVAR="MY VALUE" bash -c 'echo $MYVAR;'`)
// Get new `Goph.Cmd`
cmd, err := client.Command("ls", "-alh", "/tmp")

// or with context:
// cmd, err := client.CommandContext(ctx, "ls", "-alh", "/tmp")

if err != nil {
// handle the error!
}
Expand Down
24 changes: 24 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package goph

import (
"context"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -106,6 +107,16 @@ func (c Client) Run(cmd string) ([]byte, error) {
return sess.CombinedOutput(cmd)
}

// Run starts a new SSH session with context and runs the cmd. It returns CombinedOutput and err if any.
func (c Client) RunContext(ctx context.Context, name string) ([]byte, error) {
cmd, err := c.CommandContext(ctx, name)
if err != nil {
return nil, err
}

return cmd.CombinedOutput()
}

// Command returns new Cmd and error if any.
func (c Client) Command(name string, args ...string) (*Cmd, error) {

Expand All @@ -122,9 +133,22 @@ func (c Client) Command(name string, args ...string) (*Cmd, error) {
Path: name,
Args: args,
Session: sess,
Context: context.Background(),
}, nil
}

// Command returns new Cmd with context and error, if any.
func (c Client) CommandContext(ctx context.Context, name string, args ...string) (*Cmd, error) {
cmd, err := c.Command(name, args...)
if err != nil {
return cmd, err
}

cmd.Context = ctx

return cmd, nil
}

// NewSftp returns new sftp client and error if any.
func (c Client) NewSftp(opts ...sftp.ClientOption) (*sftp.Client, error) {
return sftp.NewClient(c.Client, opts...)
Expand Down
49 changes: 44 additions & 5 deletions cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
package goph

import (
"context"
"fmt"
"strings"

"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"strings"
)

// Cmd it's like os/exec.Cmd but for ssh session.
Expand All @@ -25,30 +25,44 @@ type Cmd struct {

// SSH session.
*ssh.Session

// Context for cancellation
Context context.Context
}

// CombinedOutput runs cmd on the remote host and returns its combined stdout and stderr.
func (c *Cmd) CombinedOutput() ([]byte, error) {
if err := c.init(); err != nil {
return nil, errors.Wrap(err, "cmd init")
}
return c.Session.CombinedOutput(c.String())

return c.runWithContext(func() ([]byte, error) {
return c.Session.CombinedOutput(c.String())
})
}

// Output runs cmd on the remote host and returns its stdout.
func (c *Cmd) Output() ([]byte, error) {
if err := c.init(); err != nil {
return nil, errors.Wrap(err, "cmd init")
}
return c.Session.Output(c.String())

return c.runWithContext(func() ([]byte, error) {
return c.Session.Output(c.String())
})
}

// Run runs cmd on the remote host.
func (c *Cmd) Run() error {
if err := c.init(); err != nil {
return errors.Wrap(err, "cmd init")
}
return c.Session.Run(c.String())

_, err := c.runWithContext(func() ([]byte, error) {
return nil, c.Session.Run(c.String())
})

return err
}

// Start runs the command on the remote host.
Expand Down Expand Up @@ -78,3 +92,28 @@ func (c *Cmd) init() (err error) {

return nil
}

// Executes the given callback within session. Sends SIGINT when the context is canceled.
func (c *Cmd) runWithContext(callback func() ([]byte, error)) ([]byte, error) {
type commandOutput struct {
output []byte
err error
}
outputChan := make(chan commandOutput)
go func() {
output, err := callback()
outputChan <- commandOutput{
output: output,
err: err,
}
}()

select {
case <-c.Context.Done():
_ = c.Session.Signal(ssh.SIGINT)

return nil, c.Context.Err()
case result := <-outputChan:
return result.output, result.err
}
}
17 changes: 15 additions & 2 deletions examples/goph/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package main

import (
"bufio"
"context"
"errors"
"flag"
"fmt"
"log"
"net"
"os"
"strings"
"time"

"github.com/melbahja/goph"
"github.com/pkg/sftp"
Expand All @@ -28,8 +30,10 @@ import (
// Run command and auth with private key and passphrase:
// > go run main.go --ip 192.168.122.102 --passphrase --cmd ls
//
// Run a command and interrupt it after 1 second:
// > go run main.go --ip 192.168.122.102 --cmd "sleep 10" --timeout=1s
//
// You can test with the interactive mode without passing --cmd falg.
// You can test with the interactive mode without passing --cmd flag.
//

var (
Expand All @@ -43,6 +47,7 @@ var (
cmd string
pass bool
passphrase bool
timeout time.Duration
agent bool
sftpc *sftp.Client
)
Expand All @@ -57,6 +62,7 @@ func init() {
flag.BoolVar(&pass, "pass", false, "ask for ssh password instead of private key.")
flag.BoolVar(&agent, "agent", false, "use ssh agent for authentication (unix systems only).")
flag.BoolVar(&passphrase, "passphrase", false, "ask for private key passphrase.")
flag.DurationVar(&timeout, "timeout", 0, "interrupt a command with SIGINT after a given timeout (0 means no timeout)")
}

func VerifyHost(host string, remote net.Addr, key ssh.PublicKey) error {
Expand Down Expand Up @@ -135,8 +141,15 @@ func main() {

// If the cmd flag exists
if cmd != "" {
ctx := context.Background()
// create a context with timeout, if supplied in the argumetns
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}

out, err := client.Run(cmd)
out, err := client.RunContext(ctx, cmd)

fmt.Println(string(out), err)
return
Expand Down

0 comments on commit 87a38e1

Please sign in to comment.