fix: Fixed memoization being unchecked after mutex synchronization. Fixes #11219 (#11456)

Signed-off-by: shmruin <[email protected]>
shmruin committed Aug 7, 2023
1 parent 545bf38 commit 143d0f5
Showing 2 changed files with 173 additions and 46 deletions.
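
In short, the memoization check in executeTemplate used to run only when the node did not exist yet (node == nil). A node that was created while waiting on a synchronization lock therefore never consulted the cache once the lock was granted. This commit moves the cache lookup after synchronization and also runs it for a node that has just acquired its lock. Below is a minimal, hedged sketch of that control flow, with simplified stand-in types and helpers (Node, CacheEntry, tryAcquire, loadCache are illustrative only, not the controller's real API):

package main

import (
	"fmt"
	"time"
)

// Hypothetical, simplified stand-ins for the controller's real types; they
// exist only so this sketch compiles on its own.
type Node struct {
	Name     string
	CacheHit bool
	Done     bool
}

type CacheEntry struct {
	Outputs string
	Age     time.Duration
}

// Placeholders for syncManager.TryAcquire and memoizationCache.Load.
func tryAcquire(nodeName string) (bool, error) { return true, nil }

func loadCache(key string) (*CacheEntry, error) {
	return &CacheEntry{Outputs: "cached outputs", Age: time.Minute}, nil
}

// executeTemplate sketches the reordered logic: the lock is resolved first,
// and the memoization cache is consulted afterwards -- both for a brand new
// node and for a node that was waiting on the lock and only now acquired it.
func executeTemplate(node *Node, nodeName, memoKey string, maxAge time.Duration) (*Node, error) {
	unlockedNode := false

	acquired, err := tryAcquire(nodeName)
	if err != nil {
		return nil, err
	}
	if !acquired {
		// Stand-in for markNodeWaitingForLock: the node stays pending.
		return node, nil
	}
	unlockedNode = true

	// Previously this block ran only when node == nil, so a node created while
	// waiting on a mutex never checked the cache after being unblocked.
	if memoKey != "" && (node == nil || unlockedNode) {
		entry, err := loadCache(memoKey)
		if err != nil {
			return nil, err
		}
		hit := entry != nil && entry.Age <= maxAge
		if node == nil {
			node = &Node{Name: nodeName}
		}
		node.CacheHit = hit
		node.Done = hit // a cache hit completes the node with the cached outputs
	}
	return node, nil
}

func main() {
	// A node that was parked on a mutex re-checks the cache after unlocking.
	waiting := &Node{Name: "job-2"}
	n, _ := executeTemplate(waiting, "job-2", "memo-key-1", time.Hour)
	fmt.Printf("cache hit: %v, done: %v\n", n.CacheHit, n.Done)
}

The actual implementation, shown in the diff below, does the same with wfv1.NodeStatus, the controller's cacheFactory, and syncManager.TryAcquire.
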
128 changes: 82 additions & 46 deletions workflow/controller/operator.go
@@ -1812,52 +1812,6 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string,
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

// If memoization is on, check if node output exists in cache
if node == nil && processedTmpl.Memoize != nil {
memoizationCache := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, processedTmpl.Memoize.Cache.ConfigMap.Name)
if memoizationCache == nil {
err := fmt.Errorf("cache could not be found or created")
woc.log.WithFields(log.Fields{"cacheName": processedTmpl.Memoize.Cache.ConfigMap.Name}).WithError(err)
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

entry, err := memoizationCache.Load(ctx, processedTmpl.Memoize.Key)
if err != nil {
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

hit := entry.Hit()
var outputs *wfv1.Outputs
if processedTmpl.Memoize.MaxAge != "" {
maxAge, err := time.ParseDuration(processedTmpl.Memoize.MaxAge)
if err != nil {
err := fmt.Errorf("invalid maxAge: %s", err)
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}
maxAgeOutputs, ok := entry.GetOutputsWithMaxAge(maxAge)
if !ok {
// The outputs are expired, so this cache entry is not hit
hit = false
}
outputs = maxAgeOutputs
} else {
outputs = entry.GetOutputs()
}

memoizationStatus := &wfv1.MemoizationStatus{
Hit: hit,
Key: processedTmpl.Memoize.Key,
CacheName: processedTmpl.Memoize.Cache.ConfigMap.Name,
}
if hit {
node = woc.initializeCacheHitNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, outputs, memoizationStatus)
} else {
node = woc.initializeCacheNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, memoizationStatus)
}
woc.wf.Status.Nodes.Set(node.ID, *node)
woc.updated = true
}

if node != nil {
if node.Fulfilled() {
woc.controller.syncManager.Release(woc.wf, node.ID, processedTmpl.Synchronization)
@@ -1905,6 +1859,8 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string,
return node, err
}

unlockedNode := false

if processedTmpl.Synchronization != nil {
lockAcquired, wfUpdated, msg, err := woc.controller.syncManager.TryAcquire(woc.wf, woc.wf.NodeID(nodeName), processedTmpl.Synchronization)
if err != nil {
@@ -1920,6 +1876,7 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string,
// unexpected behavior and is a bug.
panic("bug: GetLockName should not return an error after a call to TryAcquire")
}
woc.log.Infof("Could not acquire lock named: %s", lockName)
return woc.markNodeWaitingForLock(node.Name, lockName.EncodeName())
} else {
woc.log.Infof("Node %s acquired synchronization lock", nodeName)
Expand All @@ -1930,10 +1887,71 @@ func (woc *wfOperationCtx) executeTemplate(ctx context.Context, nodeName string,
return nil, err
}
}
// Set this value to check that this node is using synchronization, and has acquired the lock
unlockedNode = true
}

woc.updated = woc.updated || wfUpdated
}

// Check memoization cache if the node is about to be created, or was created in the past but is only now allowed to run due to acquiring a lock
if processedTmpl.Memoize != nil {
if node == nil || unlockedNode {
memoizationCache := woc.controller.cacheFactory.GetCache(controllercache.ConfigMapCache, processedTmpl.Memoize.Cache.ConfigMap.Name)
if memoizationCache == nil {
err := fmt.Errorf("cache could not be found or created")
woc.log.WithFields(log.Fields{"cacheName": processedTmpl.Memoize.Cache.ConfigMap.Name}).WithError(err)
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

entry, err := memoizationCache.Load(ctx, processedTmpl.Memoize.Key)
if err != nil {
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}

hit := entry.Hit()
var outputs *wfv1.Outputs
if processedTmpl.Memoize.MaxAge != "" {
maxAge, err := time.ParseDuration(processedTmpl.Memoize.MaxAge)
if err != nil {
err := fmt.Errorf("invalid maxAge: %s", err)
return woc.initializeNodeOrMarkError(node, nodeName, templateScope, orgTmpl, opts.boundaryID, err), err
}
maxAgeOutputs, ok := entry.GetOutputsWithMaxAge(maxAge)
if !ok {
// The outputs are expired, so this cache entry is not hit
hit = false
}
outputs = maxAgeOutputs
} else {
outputs = entry.GetOutputs()
}

memoizationStatus := &wfv1.MemoizationStatus{
Hit: hit,
Key: processedTmpl.Memoize.Key,
CacheName: processedTmpl.Memoize.Cache.ConfigMap.Name,
}
if hit {
if node == nil {
node = woc.initializeCacheHitNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, outputs, memoizationStatus)
} else {
woc.log.Infof("Node %s is using mutex with memoize. Cache is hit.", nodeName)
woc.updateAsCacheHitNode(node, outputs, memoizationStatus)
}
} else {
if node == nil {
node = woc.initializeCacheNode(nodeName, processedTmpl, templateScope, orgTmpl, opts.boundaryID, memoizationStatus)
} else {
woc.log.Infof("Node %s is using mutex with memoize. Cache is NOT hit", nodeName)
woc.updateAsCacheNode(node, memoizationStatus)
}
}
woc.wf.Status.Nodes.Set(node.ID, *node)
woc.updated = true
}
}

// If the user has specified retries, node becomes a special retry node.
// This node acts as a parent of all retries that will be done for
// the container. The status of this node should be "Success" if any
@@ -2395,6 +2413,24 @@ func (woc *wfOperationCtx) initializeNode(nodeName string, nodeType wfv1.NodeTyp
return &node
}

// Update a node status with cache status
func (woc *wfOperationCtx) updateAsCacheNode(node *wfv1.NodeStatus, memStat *wfv1.MemoizationStatus) {
node.MemoizationStatus = memStat

woc.wf.Status.Nodes.Set(node.ID, *node)
woc.updated = true
}

// Update a node status that has been cached and marked as finished
func (woc *wfOperationCtx) updateAsCacheHitNode(node *wfv1.NodeStatus, outputs *wfv1.Outputs, memStat *wfv1.MemoizationStatus, message ...string) {
node.Phase = wfv1.NodeSucceeded
node.Outputs = outputs
node.FinishedAt = metav1.Time{Time: time.Now().UTC()}

woc.updateAsCacheNode(node, memStat)
woc.log.Infof("%s node %v updated %s%s", node.Type, node.ID, node.Phase, message)
}

// markNodePhase marks a node with the given phase, creating the node if necessary and handles timestamps
func (woc *wfOperationCtx) markNodePhase(nodeName string, phase wfv1.NodePhase, message ...string) *wfv1.NodeStatus {
node, err := woc.wf.GetNodeByName(nodeName)
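
The two helpers introduced above, updateAsCacheNode and updateAsCacheHitNode, complete an already-existing node in place, which is what a node parked on a mutex needs once the cache turns out to hold its outputs; initializeCacheHitNode, by contrast, creates a fresh node that is born Succeeded. A rough, self-contained sketch of that distinction, using hypothetical simplified stand-ins for wfv1.NodeStatus and wfv1.MemoizationStatus:

package main

import (
	"fmt"
	"time"
)

// Hypothetical simplified stand-ins for wfv1.NodeStatus and
// wfv1.MemoizationStatus -- the real types live in the Argo Workflows API.
type MemoizationStatus struct {
	Hit       bool
	Key       string
	CacheName string
}

type NodeStatus struct {
	ID                string
	Phase             string
	Outputs           string
	FinishedAt        time.Time
	MemoizationStatus *MemoizationStatus
}

// updateAsCacheNode mirrors the new helper: it only attaches the memoization
// status to a node that already exists (e.g. one created while waiting on a mutex).
func updateAsCacheNode(node *NodeStatus, memStat *MemoizationStatus) {
	node.MemoizationStatus = memStat
}

// updateAsCacheHitNode additionally marks the existing node as succeeded and
// copies in the cached outputs, so it never has to run a pod.
func updateAsCacheHitNode(node *NodeStatus, outputs string, memStat *MemoizationStatus) {
	node.Phase = "Succeeded"
	node.Outputs = outputs
	node.FinishedAt = time.Now().UTC()
	updateAsCacheNode(node, memStat)
}

func main() {
	// job-2 already exists because it was parked on the mutex; after acquiring
	// the lock and finding a cache hit it is completed in place.
	job2 := &NodeStatus{ID: "job-2", Phase: "Pending"}
	memo := &MemoizationStatus{Hit: true, Key: "memo-key-1", CacheName: "cache-example-steps-simple"}
	updateAsCacheHitNode(job2, "cached outputs", memo)
	fmt.Println(job2.Phase, job2.MemoizationStatus.Hit)
}

The test added below exercises exactly this path: job-2 waits on the mutex while job-1 runs, and once job-1 finishes and releases the lock, job-2 is completed directly from the cache.
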
91 changes: 91 additions & 0 deletions workflow/controller/operator_concurrency_test.go
@@ -3,12 +3,14 @@ package controller
import (
"context"
"encoding/json"
"fmt"
"os"
"strconv"
"strings"
"testing"

"github.com/stretchr/testify/assert"
apiv1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
@@ -1046,3 +1048,92 @@ func TestSynchronizationForPendingShuttingdownWfs(t *testing.T) {
assert.Equal(t, wfv1.WorkflowRunning, wocTwo.execWf.Status.Phase)
})
}

func TestWorkflowMemoizationWithMutex(t *testing.T) {
wf := wfv1.MustUnmarshalWorkflow(`apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: example-steps-simple
namespace: default
spec:
entrypoint: main
templates:
- name: main
steps:
- - name: job-1
template: sleep
arguments:
parameters:
- name: sleep_duration
value: 10
- name: job-2
template: sleep
arguments:
parameters:
- name: sleep_duration
value: 5
- name: sleep
synchronization:
mutex:
name: mutex-example-steps-simple
inputs:
parameters:
- name: sleep_duration
script:
image: alpine:latest
command: [/bin/sh]
source: |
echo "Sleeping for {{ inputs.parameters.sleep_duration }}"
sleep {{ inputs.parameters.sleep_duration }}
memoize:
key: "memo-key-1"
cache:
configMap:
name: cache-example-steps-simple
`)
cancel, controller := newController(wf)
defer cancel()

ctx := context.Background()

woc := newWorkflowOperationCtx(wf, controller)
woc.operate(ctx)

holdingJobs := make(map[string]string)
for _, node := range woc.wf.Status.Nodes {
holdingJobs[node.ID] = node.DisplayName
}

// Check initial status: job-1 acquired the lock
job1AcquiredLock := false
if woc.wf.Status.Synchronization != nil && woc.wf.Status.Synchronization.Mutex != nil {
for _, holding := range woc.wf.Status.Synchronization.Mutex.Holding {
if holdingJobs[holding.Holder] == "job-1" {
fmt.Println("acquired: ", holding.Holder)
job1AcquiredLock = true
}
}
}
assert.True(t, job1AcquiredLock)

// Make job-1's pod succeed
makePodsPhase(ctx, woc, apiv1.PodSucceeded, func(pod *apiv1.Pod) {
if pod.ObjectMeta.Name == "job-1" {
pod.Status.Phase = apiv1.PodSucceeded
}
})
woc.operate(ctx)

// Check final status: both job-1 and job-2 succeeded, job-2 simply hit the cache
for _, node := range woc.wf.Status.Nodes {
switch node.DisplayName {
case "job-1":
assert.Equal(t, wfv1.NodeSucceeded, node.Phase)
assert.False(t, node.MemoizationStatus.Hit)
case "job-2":
assert.Equal(t, wfv1.NodeSucceeded, node.Phase)
assert.True(t, node.MemoizationStatus.Hit)
}
}
}
