From 8eb648e05c1b5ad3ce8387c095f1ac517ab2618a Mon Sep 17 00:00:00 2001 From: "maofeng.huang" Date: Fri, 7 Mar 2025 16:23:15 +0800 Subject: [PATCH] Fix the issue of handling policy msg --- pkg/ixdcgm/policy.go | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/pkg/ixdcgm/policy.go b/pkg/ixdcgm/policy.go index e1efae7..3d77d87 100644 --- a/pkg/ixdcgm/policy.go +++ b/pkg/ixdcgm/policy.go @@ -40,6 +40,8 @@ import ( "github.com/creasty/defaults" ) +const PolicyChanCapMultiplier = 3 + // At least one policy must be enabled. type PolicyConditionParams struct { // DbePolicyEnabled indicates whether the DbePolicy is enabled. Default is false (disabled). @@ -133,7 +135,6 @@ var ( // callbacks maps PolicyViolation channels with policy // captures C callback() value for each violation condition callbacks map[string]chan PolicyViolation - conChanCap = 5 conChanLcks map[string]*sync.Mutex // paramMap maps C.dcgmPolicy_t.parms index and limits @@ -143,14 +144,15 @@ var ( registerCh = make(chan struct{}) ) -func makePolicyChannels() { +func makePolicyChannels(gpuCnt int) { + policyChanCap := PolicyChanCapMultiplier * (gpuCnt + 1) policyChanOnce.Do(func() { callbacks = make(map[string]chan PolicyViolation) - callbacks["dbe"] = make(chan PolicyViolation, conChanCap) - callbacks["pcie"] = make(chan PolicyViolation, conChanCap) - callbacks["maxrtpg"] = make(chan PolicyViolation, conChanCap) - callbacks["thermal"] = make(chan PolicyViolation, conChanCap) - callbacks["power"] = make(chan PolicyViolation, conChanCap) + callbacks["dbe"] = make(chan PolicyViolation, policyChanCap) + callbacks["pcie"] = make(chan PolicyViolation, policyChanCap) + callbacks["maxrtpg"] = make(chan PolicyViolation, policyChanCap) + callbacks["thermal"] = make(chan PolicyViolation, policyChanCap) + callbacks["power"] = make(chan PolicyViolation, policyChanCap) conChanLcks = make(map[string]*sync.Mutex) conChanLcks["dbe"] = &sync.Mutex{} @@ -275,15 +277,22 @@ func registerPolicyForGpus(ctx context.Context, params *PolicyConditionParams, g // registerPolicy sets GPU usage and error policies and notifies in case of any violations on GPUs within a specific group func registerPolicy(ctx context.Context, groupId GroupHandle, params *PolicyConditionParams) (<-chan PolicyViolation, error) { + var err error if params == nil { return nil, fmt.Errorf("PolicyConditionParams is required") } - if err := validatePolicy(params); err != nil { + if err = validatePolicy(params); err != nil { return nil, err } + grpInfo, err := GetGroupInfo(groupId) + if err != nil { + return nil, fmt.Errorf("Error getting group info for group %v: %v", groupId, err) + } + gpuCnt := len(grpInfo.EntityList) + // init policy globals for internal API - makePolicyChannels() + makePolicyChannels(gpuCnt) makePolicyParamsMap(params) // make a list of policy conditions for setting their parameters @@ -318,7 +327,6 @@ func registerPolicy(ctx context.Context, groupId GroupHandle, params *PolicyCond condition |= C.DCGM_POLICY_COND_POWER } - var err error if err = setPolicy(groupId, condition, paramKeys); err != nil { return nil, err } @@ -333,7 +341,7 @@ func registerPolicy(ctx context.Context, groupId GroupHandle, params *PolicyCond return nil, &DcgmError{msg: C.GoString(C.errorString(result)), Code: result} } - vioChanCap := conTypes * 2 + vioChanCap := conTypes * (gpuCnt + 1) violation := make(chan PolicyViolation, vioChanCap) go func() { @@ -368,7 +376,7 @@ func registerPolicy(ctx context.Context, groupId GroupHandle, params *PolicyCond } }() - return violation, err + return violation, nil } func unregisterPolicy(groupId GroupHandle, condition C.dcgmPolicyCondition_t) { @@ -484,10 +492,10 @@ func writeToCallbacks(con string, vioErr PolicyViolation) { conChanLcks[con].Lock() defer conChanLcks[con].Unlock() - if len(callbacks[con]) == conChanCap { + if len(callbacks[con]) == cap(callbacks[con]) { log.Printf("Error: The channel of %s condition is already full. New messages will be discarded.\n", con) return - } else if len(callbacks[con]) == conChanCap-1 { + } else if len(callbacks[con]) == cap(callbacks[con])-1 { log.Printf("Warning: The channel of %s condition is almost full. Please read it as soon as possible.\n", con) } callbacks[con] <- vioErr -- Gitee