Code Diff: replace the scheduler DRA claim tracker's two sync.Maps (in-flight allocations + sharer counts) with a single mutex-guarded map of claim+refcount, fold Add/RemoveSharedClaimPendingAllocation into SignalClaimPendingAllocation/MaybeRemoveClaimPendingAllocation, and carry a PodGroupCycleState in CycleState instead of a bool flag.
diff --git a/pkg/scheduler/framework/autoscaler_contract/lister_contract_test.go b/pkg/scheduler/framework/autoscaler_contract/lister_contract_test.go
index ed53485c95e89..bcb522c9a2328 100644
--- a/pkg/scheduler/framework/autoscaler_contract/lister_contract_test.go
+++ b/pkg/scheduler/framework/autoscaler_contract/lister_contract_test.go
@@ -148,26 +148,18 @@ func (r *resourceClaimTrackerContract) GatherAllocatedState() (*schedulerapi.All
return nil, nil
}
-func (r *resourceClaimTrackerContract) SignalClaimPendingAllocation(_ types.UID, _ *resourceapi.ResourceClaim) error {
+func (r *resourceClaimTrackerContract) GetPendingAllocation(_ types.UID) *resourceapi.AllocationResult {
return nil
}
-func (r *resourceClaimTrackerContract) GetPendingAllocation(_ types.UID) (*resourceapi.AllocationResult, bool) {
- return nil, false
+func (r *resourceClaimTrackerContract) SignalClaimPendingAllocation(_ types.UID, _ *resourceapi.ResourceClaim) error {
+ return nil
}
func (r *resourceClaimTrackerContract) MaybeRemoveClaimPendingAllocation(_ types.UID, _ bool) (deleted bool) {
return false
}
-func (r *resourceClaimTrackerContract) AddSharedClaimPendingAllocation(claimUID types.UID, allocatedClaim *resourceapi.ResourceClaim) error {
- return nil
-}
-
-func (r *resourceClaimTrackerContract) RemoveSharedClaimPendingAllocation(claimUID types.UID, allocatedClaim *resourceapi.ResourceClaim) error {
- return nil
-}
-
func (r *resourceClaimTrackerContract) AssumeClaimAfterAPICall(_ *resourceapi.ResourceClaim) error {
return nil
}
diff --git a/pkg/scheduler/framework/cycle_state.go b/pkg/scheduler/framework/cycle_state.go
index 2676c263956eb..31b787253568f 100644
--- a/pkg/scheduler/framework/cycle_state.go
+++ b/pkg/scheduler/framework/cycle_state.go
@@ -41,11 +41,11 @@ type CycleState struct {
// GetParallelPreBindPlugins returns plugins that can be run in parallel with other plugins
// in the PreBind extension point.
parallelPreBindPlugins sets.Set[string]
- // isPodGroupSchedulingCycle indicates whether this cycle is a pod group scheduling cycle or not.
- // If set to false, it means that the pod referencing this CycleState either passed the pod group cycle
+ // podGroupCycleState contains the CycleState for this pod's PodGroup.
+ // If set to nil, it means that the pod referencing this CycleState either passed the pod group cycle
// or doesn't belong to any pod group.
- // This field can only be set to true when GenericWorkload feature flag is enabled.
- isPodGroupSchedulingCycle bool
+ // This field can only be non-nil when GenericWorkload feature flag is enabled.
+ podGroupCycleState fwk.PodGroupCycleState
}
// NewCycleState initializes a new CycleState and returns its pointer.
@@ -102,11 +102,15 @@ func (c *CycleState) GetParallelPreBindPlugins() sets.Set[string] {
}
func (c *CycleState) IsPodGroupSchedulingCycle() bool {
- return c.isPodGroupSchedulingCycle
+ return c.podGroupCycleState != nil
}
-func (c *CycleState) SetPodGroupSchedulingCycle(isPodGroupSchedulingCycle bool) {
- c.isPodGroupSchedulingCycle = isPodGroupSchedulingCycle
+func (c *CycleState) SetPodGroupSchedulingCycle(podGroupCycleState fwk.PodGroupCycleState) {
+ c.podGroupCycleState = podGroupCycleState
+}
+
+func (c *CycleState) GetPodGroupSchedulingCycle() fwk.PodGroupCycleState {
+ return c.podGroupCycleState
}
func (c *CycleState) SetSkipAllPostFilterPlugins(flag bool) {
@@ -135,7 +139,7 @@ func (c *CycleState) Clone() fwk.CycleState {
copy.skipScorePlugins = c.skipScorePlugins
copy.skipPreBindPlugins = c.skipPreBindPlugins
copy.parallelPreBindPlugins = c.parallelPreBindPlugins
- copy.isPodGroupSchedulingCycle = c.isPodGroupSchedulingCycle
+ copy.podGroupCycleState = c.podGroupCycleState
copy.skipAllPostFilterPlugins = c.skipAllPostFilterPlugins
return copy
diff --git a/pkg/scheduler/framework/plugins/dynamicresources/dra_manager.go b/pkg/scheduler/framework/plugins/dynamicresources/dra_manager.go
index 4687e55c4c522..a1703319f6c0c 100644
--- a/pkg/scheduler/framework/plugins/dynamicresources/dra_manager.go
+++ b/pkg/scheduler/framework/plugins/dynamicresources/dra_manager.go
@@ -20,6 +20,7 @@ import (
"context"
"errors"
"fmt"
+ "iter"
"slices"
"sync"
@@ -63,7 +64,7 @@ func NewDRAManager(ctx context.Context, claimsCache *assumecache.AssumeCache, re
manager := &DefaultDRAManager{
resourceClaimTracker: &claimTracker{
cache: claimsCache,
- inFlightAllocations: &sync.Map{},
+ inFlightAllocations: make(map[types.UID]inFlightAllocation),
allocatedDevices: newAllocatedDevices(logger),
logger: logger,
},
@@ -74,9 +75,6 @@ func NewDRAManager(ctx context.Context, claimsCache *assumecache.AssumeCache, re
if utilfeature.DefaultFeatureGate.Enabled(features.DRAExtendedResource) {
manager.extendedResourceCache = extendedresourcecache.NewExtendedResourceCache(logger)
}
- if utilfeature.DefaultFeatureGate.Enabled(features.GenericWorkload) {
- manager.resourceClaimTracker.inFlightAllocationSharers = &sync.Map{}
- }
pgLister := &podGroupLister{}
if utilfeature.DefaultFeatureGate.Enabled(features.DRAWorkloadResourceClaims) {
@@ -178,13 +176,16 @@ type claimTracker struct {
// - would make integration with cluster autoscaler harder because it would need
// to trigger informer callbacks.
cache *assumecache.AssumeCache
+ // inFlightMutex syncs access to inFlightAllocations.
+ inFlightMutex sync.RWMutex
// inFlightAllocations is a map from claim UUIDs to claim objects for those claims
// for which allocation was triggered during a scheduling cycle and the
// corresponding claim status update call in PreBind has not been done
- // yet. If another pod needs the claim, the pod is treated as "not
- // schedulable yet" unless the pod is a member of a PodGroup. For ungrouped
- // pods, the cluster event for the claim status update will make it
- // schedulable.
+ // yet. It also includes a reference count tracking how many actively
+ // scheduling Pods in a PodGroup are using that pending allocation. If
+ // another pod outside the PodGroup needs the claim, the pod is treated as
+ // "not schedulable yet". For those pods, the cluster event for the
+ // claim status update will make them schedulable.
//
// This mechanism avoids the following problem:
// - Pod A triggers allocation for claim X.
@@ -215,25 +216,45 @@ type claimTracker struct {
// pods is expected to be rare compared to per-pod claim, so we end up
// hitting the "multiple goroutines read, write, and overwrite entries
// for disjoint sets of keys" case that sync.Map is optimized for.
- inFlightAllocations *sync.Map
- // inFlightAllocationSharers counts the actively scheduling pods
- // sharing a given ResourceClaim.
- inFlightAllocationSharers *sync.Map
- allocatedDevices *allocatedDevices
- logger klog.Logger
+ inFlightAllocations map[types.UID]inFlightAllocation
+ allocatedDevices *allocatedDevices
+ logger klog.Logger
+}
+
+type inFlightAllocation struct {
+ claim *resourceapi.ResourceClaim
+ sharers int
}
-func (c *claimTracker) GetPendingAllocation(claimUID types.UID) (*resourceapi.AllocationResult, bool) {
- var allocation *resourceapi.AllocationResult
- claim, found := c.inFlightAllocations.Load(claimUID)
- if found && claim != nil {
- allocation = claim.(*resourceapi.ResourceClaim).Status.Allocation
+func (c *claimTracker) GetPendingAllocation(claimUID types.UID) *resourceapi.AllocationResult {
+ c.inFlightMutex.RLock()
+ defer c.inFlightMutex.RUnlock()
+
+ inFlight, found := c.inFlightAllocations[claimUID]
+ if !found || inFlight.claim == nil {
+ return nil
}
- return allocation, found
+ return inFlight.claim.Status.Allocation
}
func (c *claimTracker) SignalClaimPendingAllocation(claimUID types.UID, allocatedClaim *resourceapi.ResourceClaim) error {
- c.inFlightAllocations.Store(claimUID, allocatedClaim)
+ c.inFlightMutex.Lock()
+ defer c.inFlightMutex.Unlock()
+
+ inFlight, found := c.inFlightAllocations[claimUID]
+ if found {
+ inFlight.sharers++
+ c.inFlightAllocations[claimUID] = inFlight
+
+ claim := inFlight.claim
+ c.logger.V(5).Info("Added share for in-flight claim", "claim", klog.KObj(claim), "uid", claimUID, "version", claim.ResourceVersion, "sharers", inFlight.sharers)
+ return nil
+ }
+
+ c.inFlightAllocations[claimUID] = inFlightAllocation{
+ claim: allocatedClaim,
+ sharers: 1,
+ }
// This is the same verbosity as the corresponding log in the assume cache.
c.logger.V(5).Info("Added in-flight claim", "claim", klog.KObj(allocatedClaim), "uid", claimUID, "version", allocatedClaim.ResourceVersion)
// There's no reason to return an error in this implementation, but the error is helpful for other implementations.
@@ -242,67 +263,28 @@ func (c *claimTracker) SignalClaimPendingAllocation(claimUID types.UID, allocate
return nil
}
-func (c *claimTracker) MaybeRemoveClaimPendingAllocation(claimUID types.UID, shareable bool) (deleted bool) {
- if c.inFlightAllocationSharers != nil && shareable {
- value, ok := c.inFlightAllocationSharers.Load(claimUID)
- if ok && value.(int) > 0 {
- if loggerV := c.logger.V(5); loggerV.Enabled() {
- claim, found := c.inFlightAllocations.Load(claimUID)
- if found {
- claim := claim.(*resourceapi.ResourceClaim)
- c.logger.V(5).Info("Claim is still shared by other pods, not removing in-flight claim", "claim", klog.KObj(claim), "uid", claimUID, "version", claim.ResourceVersion)
- }
- }
- return false
- }
- }
+func (c *claimTracker) MaybeRemoveClaimPendingAllocation(claimUID types.UID, forceRemove bool) (deleted bool) {
+ c.inFlightMutex.Lock()
+ defer c.inFlightMutex.Unlock()
- claim, found := c.inFlightAllocations.LoadAndDelete(claimUID)
+ inFlight, found := c.inFlightAllocations[claimUID]
// The assume cache doesn't log this, but maybe it should.
- if found {
- claim := claim.(*resourceapi.ResourceClaim)
- c.logger.V(5).Info("Removed in-flight claim", "claim", klog.KObj(claim), "uid", claimUID, "version", claim.ResourceVersion)
- } else {
+ if !found {
c.logger.V(5).Info("Redundant remove of in-flight claim, not found", "uid", claimUID)
+ return false
}
- return found
-}
+ claim := inFlight.claim
-func (c *claimTracker) AddSharedClaimPendingAllocation(claimUID types.UID, allocatedClaim *resourceapi.ResourceClaim) error {
- newSharers := 1
- value, loaded := c.inFlightAllocationSharers.LoadOrStore(claimUID, newSharers)
- if loaded {
- oldSharers := value.(int)
- newSharers = oldSharers + 1
- swapped := c.inFlightAllocationSharers.CompareAndSwap(claimUID, oldSharers, newSharers)
- if !swapped {
- // The value must have changed since we loaded
- return fmt.Errorf("conflict adding in-flight allocation sharer for claim %s/%s, UID=%s", allocatedClaim.Namespace, allocatedClaim.Name, claimUID)
- }
+ if forceRemove || inFlight.sharers == 1 {
+ delete(c.inFlightAllocations, claimUID)
+ c.logger.V(5).Info("Removed in-flight claim", "claim", klog.KObj(claim), "uid", claimUID, "version", claim.ResourceVersion)
+ return true
}
- c.logger.V(5).Info("Added share for in-flight claim", "claim", klog.KObj(allocatedClaim), "uid", claimUID, "version", allocatedClaim.ResourceVersion, "sharers", newSharers)
- return nil
-}
-func (c *claimTracker) RemoveSharedClaimPendingAllocation(claimUID types.UID, allocatedClaim *resourceapi.ResourceClaim) error {
- value, ok := c.inFlightAllocationSharers.Load(claimUID)
- if !ok {
- return nil
- }
- oldSharers := value.(int)
- newSharers := oldSharers - 1
- var written bool
- if newSharers == 0 {
- written = c.inFlightAllocationSharers.CompareAndDelete(claimUID, oldSharers)
- } else {
- written = c.inFlightAllocationSharers.CompareAndSwap(claimUID, oldSharers, newSharers)
- }
- if !written {
- // The value must have changed since we loaded
- return fmt.Errorf("conflict removing in-flight allocation sharer for claim %s/%s, UID=%s", allocatedClaim.Namespace, allocatedClaim.Name, claimUID)
- }
- c.logger.V(5).Info("Removed share for in-flight claim", "claim", klog.KObj(allocatedClaim), "uid", claimUID, "version", allocatedClaim.ResourceVersion, "sharers", newSharers)
- return nil
+ inFlight.sharers--
+ c.inFlightAllocations[claimUID] = inFlight
+ c.logger.V(5).Info("Claim is still shared by other pods, not removing in-flight claim", "claim", klog.KObj(claim), "uid", claimUID, "version", claim.ResourceVersion, "sharers", inFlight.sharers)
+ return false
}
func (c *claimTracker) Get(namespace, claimName string) (*resourceapi.ResourceClaim, error) {
@@ -360,14 +342,13 @@ func (c *claimTracker) ListAllAllocatedDevices() (a sets.Set[structured.DeviceID
allocated, revision := c.allocatedDevices.Get()
// Whatever is in flight also has to be checked.
- c.inFlightAllocations.Range(func(key, value any) bool {
- claim := value.(*resourceapi.ResourceClaim)
+ for _, inFlight := range c.allInFlightAllocationsRLocked() {
+ claim := inFlight.claim
foreachAllocatedDevice(claim, func(deviceID structured.DeviceID) {
c.logger.V(6).Info("Device is in flight for allocation", "device", deviceID, "claim", klog.KObj(claim))
allocated.Insert(deviceID)
}, false, func(structured.SharedDeviceID) {}, func(structured.DeviceConsumedCapacity) {})
- return true
- })
+ }
if revision == c.allocatedDevices.Revision() {
// Our current result is valid, nothing changed in the meantime.
@@ -433,8 +414,8 @@ func (c *claimTracker) GatherAllocatedState() (s *structured.AllocatedState, err
}
// Whatever is in flight also has to be checked.
- c.inFlightAllocations.Range(func(key, value any) bool {
- claim := value.(*resourceapi.ResourceClaim)
+ for _, inFlight := range c.allInFlightAllocationsRLocked() {
+ claim := inFlight.claim
foreachAllocatedDevice(claim,
func(deviceID structured.DeviceID) { // dedicatedDeviceCallback
c.logger.V(6).Info("Device is in flight for allocation", "device", deviceID, "claim", klog.KObj(claim))
@@ -449,9 +430,7 @@ func (c *claimTracker) GatherAllocatedState() (s *structured.AllocatedState, err
c.logger.V(6).Info("Device is in flight for allocation", "consumed capacity", capacity, "claim", klog.KObj(claim))
aggregatedCapacity.Insert(capacity)
})
- return true
- })
-
+ }
if revision1 == c.allocatedDevices.Revision() {
// Our current result is valid, nothing changed in the meantime.
return &structured.AllocatedState{
@@ -464,6 +443,18 @@ func (c *claimTracker) GatherAllocatedState() (s *structured.AllocatedState, err
return nil, errClaimTrackerConcurrentModification
}
+func (c *claimTracker) allInFlightAllocationsRLocked() iter.Seq2[types.UID, inFlightAllocation] {
+ return func(yield func(types.UID, inFlightAllocation) bool) {
+ c.inFlightMutex.RLock()
+ defer c.inFlightMutex.RUnlock()
+ for uid, inFlight := range c.inFlightAllocations {
+ if !yield(uid, inFlight) {
+ return
+ }
+ }
+ }
+}
+
fun
... [truncated]