Skip to content

Commit

Permalink
fix: Authentication for plugins. Fixes #8144 (#8147)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Collins <[email protected]>
  • Loading branch information
alexec committed Mar 17, 2022
1 parent d4b1afe commit 0e707cd
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 110 deletions.
73 changes: 66 additions & 7 deletions cmd/argoexec/commands/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"k8s.io/apimachinery/pkg/util/rand"
"k8s.io/client-go/kubernetes"
restclient "k8s.io/client-go/rest"

Expand All @@ -20,9 +22,59 @@ import (
)

func NewAgentCommand() *cobra.Command {
return &cobra.Command{
cmd := cobra.Command{
Use: "agent",
SilenceUsage: true, // this prevents confusing usage message being printed on error
}
cmd.AddCommand(NewAgentInitCommand())
cmd.AddCommand(NewAgentMainCommand())
return &cmd
}

func NewAgentInitCommand() *cobra.Command {
return &cobra.Command{
Use: "init",
Run: func(cmd *cobra.Command, args []string) {
for _, name := range getPluginNames() {
filename := tokenFilename(name)
log.WithField("plugin", name).
WithField("filename", filename).
Info("creating token file for plugin")
if err := os.Mkdir(filepath.Dir(filename), 0o770); err != nil {
log.Fatal(err)
}
token := rand.String(32) // this could have 26^32 ~= 2 x 10^45 possible values, not guessable in reasonable time
if err := os.WriteFile(filename, []byte(token), 0o440); err != nil {
log.Fatal(err)
}
}
},
}
}

func tokenFilename(name string) string {
return filepath.Join(common.VarRunArgoPath, name, "token")
}

func getPluginNames() []string {
var names []string
if err := json.Unmarshal([]byte(os.Getenv(common.EnvVarPluginNames)), &names); err != nil {
log.Fatal(err)
}
return names
}

func getPluginAddresses() []string {
var addresses []string
if err := json.Unmarshal([]byte(os.Getenv(common.EnvVarPluginAddresses)), &addresses); err != nil {
log.Fatal(err)
}
return addresses
}

func NewAgentMainCommand() *cobra.Command {
return &cobra.Command{
Use: "main",
RunE: func(cmd *cobra.Command, args []string) error {
return initAgentExecutor().Agent(context.Background())
},
Expand Down Expand Up @@ -52,13 +104,20 @@ func initAgentExecutor() *executor.AgentExecutor {
log.Fatalf("Unable to determine workflow name from environment variable %s", common.EnvVarWorkflowName)
}

var addresses []string
if err := json.Unmarshal([]byte(os.Getenv(common.EnvVarPluginAddresses)), &addresses); err != nil {
log.Fatal(err)
}
addresses := getPluginAddresses()
names := getPluginNames()
var plugins []executorplugins.TemplateExecutor
for _, address := range addresses {
plugins = append(plugins, rpc.New(address))
for i, address := range addresses {
name := names[i]
filename := tokenFilename(name)
log.WithField("plugin", name).
WithField("filename", filename).
Info("loading token file for plugin")
data, err := os.ReadFile(filename)
if err != nil {
log.Fatal(err)
}
plugins = append(plugins, rpc.New(address, string(data)))
}

return executor.NewAgentExecutor(clientSet, restClient, config, namespace, workflowName, plugins)
Expand Down
2 changes: 1 addition & 1 deletion cmd/argoexec/commands/emissary.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
)

var (
varRunArgo = "/var/run/argo"
varRunArgo = common.VarRunArgoPath
containerName = os.Getenv(common.EnvVarContainerName)
includeScriptOutput = os.Getenv(common.EnvVarIncludeScriptOutput) == "true" // capture stdout/combined
template = &wfv1.Template{}
Expand Down
18 changes: 14 additions & 4 deletions docs/executor_plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ We'll need to create a script that starts a HTTP server. Save this as `server.py
```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from pprint import pprint

with open("/var/run/argo/token") as f:
token = f.read().strip()


class Plugin(BaseHTTPRequestHandler):
Expand All @@ -97,16 +99,23 @@ class Plugin(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(json.dumps(reply).encode("UTF-8"))

def forbidden(self):
self.send_response(403)
self.end_headers()

def unsupported(self):
self.send_response(404)
self.end_headers()

def do_POST(self):
if self.path == '/api/v1/template.execute':
if self.headers.get("Authorization") != "Bearer " + token:
self.forbidden()
elif self.path == '/api/v1/template.execute':
args = self.args()
pprint(args)
if 'hello' in args['template'].get('plugin', {}):
self.reply({'node': {'phase': 'Succeeded', 'message': 'Hello template!'}})
self.reply(
{'node': {'phase': 'Succeeded', 'message': 'Hello template!',
'outputs': {'parameters': [{'name': 'foo', 'value': 'bar'}]}}})
else:
self.reply({})
else:
Expand All @@ -125,6 +134,7 @@ Some things to note here:

* You only need to implement the calls you need. Return 404 and it won't be called again.
* The path is the RPC method name.
* You should check that the `Authorization` header contains the same value as `/var/run/argo/token`. Return 403 if not
* The request body contains the template's input parameters.
* The response body may contain the node's result, including the phase (e.g. "Succeeded" or "Failed") and a message.
* If the response is `{}`, then the plugin is saying it cannot execute the plugin template, e.g. it is a Slack plugin,
Expand Down
23 changes: 21 additions & 2 deletions test/e2e/manifests/plugins/hello-executor-plugin-configmap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,33 @@ data:
- |
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
with open("/var/run/argo/token") as f:
token = f.read().strip()
class Plugin(BaseHTTPRequestHandler):
def args(self):
return json.loads(self.rfile.read(int(self.headers.get('Content-Length'))))
def reply(self, reply):
self.send_response(200)
self.end_headers()
self.wfile.write(json.dumps(reply).encode("UTF-8"))
def forbidden(self):
self.send_response(403)
self.end_headers()
def unsupported(self):
self.send_response(404)
self.end_headers()
def do_POST(self):
if self.path == '/api/v1/template.execute':
if self.headers.get("Authorization") != "Bearer " + token:
self.forbidden()
elif self.path == '/api/v1/template.execute':
args = self.args()
if 'hello' in args['template'].get('plugin', {}):
self.reply(
Expand All @@ -27,6 +42,8 @@ data:
self.reply({})
else:
self.unsupported()
if __name__ == '__main__':
httpd = HTTPServer(('', 4355), Plugin)
httpd.serve_forever()
Expand Down Expand Up @@ -57,7 +74,9 @@ metadata:
annotations:
workflows.argoproj.io/description: |
This is the "hello world" plugin.
Example:
```yaml
apiVersion: argoproj.io/v1alpha1
kind: Workflow
Expand All @@ -74,4 +93,4 @@ metadata:
creationTimestamp: null
labels:
workflows.argoproj.io/configmap-type: ExecutorPlugin
name: hello-executor-plugin
name: hello-executor-plugin
7 changes: 6 additions & 1 deletion workflow/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ const (
EnvVarNodeID = "ARGO_NODE_ID"
// EnvVarPluginAddresses is a list of plugin addresses
EnvVarPluginAddresses = "ARGO_PLUGIN_ADDRESSES"
// EnvVarPluginNames is a list of plugin names
EnvVarPluginNames = "ARGO_PLUGIN_NAMES"
// EnvVarContainerName container the container's name for the current pod
EnvVarContainerName = "ARGO_CONTAINER_NAME"
// EnvVarDeadline is the deadline for the pod
Expand Down Expand Up @@ -233,8 +235,11 @@ const (
// CACertificatesVolumeMountName is the name of the secret that contains the CA certificates.
CACertificatesVolumeMountName = "argo-workflows-agent-ca-certificates"

// VarRunArgoPath is the standard path for the shared volume
VarRunArgoPath = "/var/run/argo"

// ArgoProgressPath defines the path to a file used for self reporting progress
ArgoProgressPath = "/var/run/argo/progress"
ArgoProgressPath = VarRunArgoPath + "/progress"

// ErrDeadlineExceeded is the pod status reason when exceed deadline
ErrDeadlineExceeded = "DeadlineExceeded"
Expand Down
Loading

0 comments on commit 0e707cd

Please sign in to comment.