Code Diff
diff --git a/pkg/kubelet/nodeshutdown/nodeshutdown_manager.go b/pkg/kubelet/nodeshutdown/nodeshutdown_manager.go
index fa36cd11c1455..2a03130985f81 100644
--- a/pkg/kubelet/nodeshutdown/nodeshutdown_manager.go
+++ b/pkg/kubelet/nodeshutdown/nodeshutdown_manager.go
@@ -126,9 +126,15 @@ func newPodManager(conf *Config) *podManager {
}
// killPods terminates pods by priority.
-func (m *podManager) killPods(activePods []*v1.Pod) error {
+func (m *podManager) killPods(ctx context.Context, activePods []*v1.Pod) error {
groups := groupByPriority(m.shutdownGracePeriodByPodPriority, activePods)
for _, group := range groups {
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
// If there are no pods in a particular range,
// then do not wait for pods in that priority range.
if len(group.Pods) == 0 {
@@ -176,9 +182,9 @@ func (m *podManager) killPods(activePods []*v1.Pod) error {
// to terminate before proceeding to the next group.
var groupTerminationWaitDuration = time.Duration(group.ShutdownGracePeriodSeconds) * time.Second
var (
- doneCh = make(chan struct{})
- timer = m.clock.NewTimer(groupTerminationWaitDuration)
- ctx, ctxCancel = context.WithTimeout(context.Background(), groupTerminationWaitDuration)
+ doneCh = make(chan struct{})
+ timer = m.clock.NewTimer(groupTerminationWaitDuration)
+ groupCtx, ctxCancel = context.WithTimeout(ctx, groupTerminationWaitDuration)
)
go func() {
defer close(doneCh)
@@ -188,7 +194,9 @@ func (m *podManager) killPods(activePods []*v1.Pod) error {
// let's wait until all the volumes are unmounted from all the pods before
// continuing to the next group. This is done so that the CSI Driver (assuming
// that it's part of the highest group) has a chance to perform unmounts.
- if err := m.volumeManager.WaitForAllPodsUnmount(ctx, group.Pods); err != nil {
+ // The wait is derived from the kubelet context so it can stop early if the
+ // parent kubelet context is cancelled while shutdown processing is in progress.
+ if err := m.volumeManager.WaitForAllPodsUnmount(groupCtx, group.Pods); err != nil {
var podIdentifiers []string
for _, pod := range group.Pods {
podIdentifiers = append(podIdentifiers, fmt.Sprintf("%s/%s", pod.Namespace, pod.Name))
diff --git a/pkg/kubelet/nodeshutdown/nodeshutdown_manager_linux.go b/pkg/kubelet/nodeshutdown/nodeshutdown_manager_linux.go
index 25e17774bef96..8ef18464c2484 100644
--- a/pkg/kubelet/nodeshutdown/nodeshutdown_manager_linux.go
+++ b/pkg/kubelet/nodeshutdown/nodeshutdown_manager_linux.go
@@ -288,9 +288,12 @@ func (m *managerImpl) start(ctx context.Context) (chan struct{}, error) {
if isShuttingDown {
// Update node status and ready condition
- go m.syncNodeStatus(ctx)
+ nodeStatusCtx := klog.NewContext(ctx, m.logger)
+ go m.syncNodeStatus(nodeStatusCtx)
- m.processShutdownEvent()
+ if err := m.processShutdownEvent(ctx); err != nil {
+ m.logger.Error(err, "Shutdown manager failed to process shutdown event")
+ }
} else {
_ = m.acquireInhibitLock()
}
@@ -323,7 +326,7 @@ func (m *managerImpl) ShutdownStatus() error {
return nil
}
-func (m *managerImpl) processShutdownEvent() error {
+func (m *managerImpl) processShutdownEvent(ctx context.Context) error {
m.logger.V(1).Info("Shutdown manager processing shutdown event")
activePods := m.getPods()
@@ -356,5 +359,5 @@ func (m *managerImpl) processShutdownEvent() error {
}()
}
- return m.podManager.killPods(activePods)
+ return m.podManager.killPods(ctx, activePods)
}
diff --git a/pkg/kubelet/nodeshutdown/nodeshutdown_manager_linux_test.go b/pkg/kubelet/nodeshutdown/nodeshutdown_manager_linux_test.go
index d6b15f12efc69..85996f12af1a6 100644
--- a/pkg/kubelet/nodeshutdown/nodeshutdown_manager_linux_test.go
+++ b/pkg/kubelet/nodeshutdown/nodeshutdown_manager_linux_test.go
@@ -46,6 +46,7 @@ import (
"k8s.io/kubernetes/pkg/kubelet/eviction"
"k8s.io/kubernetes/pkg/kubelet/nodeshutdown/systemd"
"k8s.io/kubernetes/pkg/kubelet/volumemanager"
+ testutilsktesting "k8s.io/kubernetes/test/utils/ktesting"
"k8s.io/utils/clock"
testingclock "k8s.io/utils/clock/testing"
)
@@ -597,6 +598,8 @@ func TestStartDoesNotReconnectAfterContextCancel(t *testing.T) {
}
func Test_managerImpl_processShutdownEvent(t *testing.T) {
+ tCtx := testutilsktesting.Init(t)
+
var (
fakeRecorder = &record.FakeRecorder{}
fakeVolumeManager = volumemanager.NewFakeVolumeManager([]v1.UniqueVolumeName{}, 0, nil, false)
@@ -661,6 +664,7 @@ func Test_managerImpl_processShutdownEvent(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ // Use a buffered logger because this test asserts log output.
logger := ktesting.NewLogger(t,
ktesting.NewConfig(
ktesting.BufferLogs(true),
@@ -684,8 +688,11 @@ func Test_managerImpl_processShutdownEvent(t *testing.T) {
clock: tt.fields.clock,
},
}
- if err := m.processShutdownEvent(); (err != nil) != tt.wantErr {
- t.Errorf("managerImpl.processShutdownEvent() error = %v, wantErr %v", err, tt.wantErr)
+ err := m.processShutdownEvent(tCtx)
+ if tt.wantErr {
+ require.Error(t, err, "managerImpl.processShutdownEvent() should return an error")
+ } else {
+ require.NoError(t, err, "managerImpl.processShutdownEvent() should not return an error")
}
underlier, ok := logger.GetSink().(ktesting.Underlier)
@@ -704,6 +711,8 @@ func Test_managerImpl_processShutdownEvent(t *testing.T) {
}
func Test_processShutdownEvent_VolumeUnmountTimeout(t *testing.T) {
+ tCtx := testutilsktesting.Init(t)
+
var (
fakeRecorder = &record.FakeRecorder{}
syncNodeStatus = func(context.Context) {}
@@ -718,6 +727,7 @@ func Test_processShutdownEvent_VolumeUnmountTimeout(t *testing.T) {
// for volume unmount operations that take longer than the allowed grace period.
fmt.Errorf("unmount timeout"), false,
)
+ // Use a buffered logger because this test asserts log output.
logger := ktesting.NewLogger(t, ktesting.NewConfig(ktesting.BufferLogs(true)))
m := &managerImpl{
logger: logger,
@@ -747,7 +757,7 @@ func Test_processShutdownEvent_VolumeUnmountTimeout(t *testing.T) {
}
start := fakeclock.Now()
- err := m.processShutdownEvent()
+ err := m.processShutdownEvent(tCtx)
end := fakeclock.Now()
require.NoError(t, err, "managerImpl.processShutdownEvent() should not return an error")
diff --git a/pkg/kubelet/nodeshutdown/nodeshutdown_manager_windows.go b/pkg/kubelet/nodeshutdown/nodeshutdown_manager_windows.go
index a47de9386d814..c0e7fcc61f1f6 100644
--- a/pkg/kubelet/nodeshutdown/nodeshutdown_manager_windows.go
+++ b/pkg/kubelet/nodeshutdown/nodeshutdown_manager_windows.go
@@ -66,6 +66,15 @@ type managerImpl struct {
storage storage
}
+// preShutdownHandler wraps a managerImpl and exposes a ProcessShutdownEvent
+// method that takes no arguments. An instance of it, rather than the manager
+// itself, is what Start registers through service.SetPreShutdownHandler.
+type preShutdownHandler struct {
+ // manager is the shutdown manager that the callback is forwarded to.
+ manager *managerImpl
+}
+
+// ProcessShutdownEvent forwards the preshutdown callback to the wrapped
+// manager's ProcessShutdownEvent and returns its error unchanged.
+func (h *preShutdownHandler) ProcessShutdownEvent() error {
+ // Windows SCM preshutdown callback has no context, so start a root context at this callback boundary.
+ // NOTE(review): Start discards its own context argument (the parameter is `_`),
+ // so cancelling that context does not reach the shutdown processing started here.
+ return h.manager.ProcessShutdownEvent(context.Background())
+}
+
// NewManager returns a new node shutdown manager.
func NewManager(conf *Config) Manager {
if !utilfeature.DefaultFeatureGate.Enabled(features.WindowsGracefulNodeShutdown) {
@@ -145,7 +154,7 @@ func (m *managerImpl) Start(_ context.Context) error {
return err
}
- service.SetPreShutdownHandler(m)
+ service.SetPreShutdownHandler(&preShutdownHandler{manager: m})
m.setMetrics()
@@ -238,7 +247,7 @@ func (m *managerImpl) ShutdownStatus() error {
return nil
}
-func (m *managerImpl) ProcessShutdownEvent() error {
+func (m *managerImpl) ProcessShutdownEvent(ctx context.Context) error {
m.logger.V(1).Info("Shutdown manager detected new preshutdown event", "event", "preshutdown")
m.recorder.Event(m.nodeRef, v1.EventTypeNormal, kubeletevents.NodeShutdown, "Shutdown manager detected preshutdown event")
@@ -247,7 +256,8 @@ func (m *managerImpl) ProcessShutdownEvent() error {
m.nodeShuttingDownNow = true
m.nodeShuttingDownMutex.Unlock()
- go m.syncNodeStatus(context.TODO())
+ nodeStatusCtx := klog.NewContext(ctx, m.logger)
+ go m.syncNodeStatus(nodeStatusCtx)
m.logger.V(1).Info("Shutdown manager processing preshutdown event")
activePods := m.getPods()
@@ -280,7 +290,7 @@ func (m *managerImpl) ProcessShutdownEvent() error {
}()
}
- return m.podManager.killPods(activePods)
+ return m.podManager.killPods(ctx, activePods)
}
func (m *managerImpl) periodRequested() time.Duration {
diff --git a/pkg/kubelet/nodeshutdown/nodeshutdown_manager_windows_test.go b/pkg/kubelet/nodeshutdown/nodeshutdown_manager_windows_test.go
index fc278c5d4d3ce..8fe1335423234 100644
--- a/pkg/kubelet/nodeshutdown/nodeshutdown_manager_windows_test.go
+++ b/pkg/kubelet/nodeshutdown/nodeshutdown_manager_windows_test.go
@@ -28,6 +28,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
utilfeature "k8s.io/apiserver/pkg/util/feature"
@@ -39,6 +40,7 @@ import (
kubeletconfig "k8s.io/kubernetes/pkg/kubelet/apis/config"
"k8s.io/kubernetes/pkg/kubelet/eviction"
"k8s.io/kubernetes/pkg/kubelet/volumemanager"
+ testutilsktesting "k8s.io/kubernetes/test/utils/ktesting"
"k8s.io/utils/clock"
testingclock "k8s.io/utils/clock/testing"
)
@@ -226,8 +228,11 @@ func Test_managerImpl_ProcessShutdownEvent(t *testing.T) {
},
}
+ tCtx := testutilsktesting.Init(t)
+
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ // Use a buffered logger because this test asserts log output.
logger := ktesting.NewLogger(t,
ktesting.NewConfig(
ktesting.BufferLogs(true),
@@ -249,8 +254,11 @@ func Test_managerImpl_ProcessShutdownEvent(t *testing.T) {
clock: tt.fields.clock,
},
}
- if err := m.ProcessShutdownEvent(); (err != nil) != tt.wantErr {
- t.Errorf("managerImpl.processShutdownEvent() error = %v, wantErr %v", err, tt.wantErr)
+ err := m.ProcessShutdownEvent(tCtx)
+ if tt.wantErr {
+ require.Error(t, err, "managerImpl.processShutdownEvent() should return an error")
+ } else {
+ require.NoError(t, err, "managerImpl.processShutdownEvent() should not return an error")
}
underlier, ok := logger.GetSink().(ktesting.Underlier)