diff --git a/pkg/ixdcgm/policy.go b/pkg/ixdcgm/policy.go index ada6c23fe7daba7cead8b4a230c4c1b3ea5be188..e1efae7ddf1f8f1d37048fce370aa29058bbd2e6 100644 --- a/pkg/ixdcgm/policy.go +++ b/pkg/ixdcgm/policy.go @@ -132,7 +132,9 @@ var ( // callbacks maps PolicyViolation channels with policy // captures C callback() value for each violation condition - callbacks map[string]chan PolicyViolation + callbacks map[string]chan PolicyViolation + conChanCap = 5 + conChanLcks map[string]*sync.Mutex // paramMap maps C.dcgmPolicy_t.parms index and limits // to be used in setPolicy() for setting user selected policies @@ -144,11 +146,18 @@ var ( func makePolicyChannels() { policyChanOnce.Do(func() { callbacks = make(map[string]chan PolicyViolation) - callbacks["dbe"] = make(chan PolicyViolation, 1) - callbacks["pcie"] = make(chan PolicyViolation, 1) - callbacks["maxrtpg"] = make(chan PolicyViolation, 1) - callbacks["thermal"] = make(chan PolicyViolation, 1) - callbacks["power"] = make(chan PolicyViolation, 1) + callbacks["dbe"] = make(chan PolicyViolation, conChanCap) + callbacks["pcie"] = make(chan PolicyViolation, conChanCap) + callbacks["maxrtpg"] = make(chan PolicyViolation, conChanCap) + callbacks["thermal"] = make(chan PolicyViolation, conChanCap) + callbacks["power"] = make(chan PolicyViolation, conChanCap) + + conChanLcks = make(map[string]*sync.Mutex) + conChanLcks["dbe"] = &sync.Mutex{} + conChanLcks["pcie"] = &sync.Mutex{} + conChanLcks["maxrtpg"] = &sync.Mutex{} + conChanLcks["thermal"] = &sync.Mutex{} + conChanLcks["power"] = &sync.Mutex{} }) } @@ -314,18 +323,18 @@ func registerPolicy(ctx context.Context, groupId GroupHandle, params *PolicyCond return nil, err } + log.Println("Listening for violations...") result := C.dcgmPolicyRegister(handle.handle, groupId.handle, C.dcgmPolicyCondition_t(condition), C.fpRecvUpdates(C.violationNotify), C.fpRecvUpdates(C.voidCallback), ) - if err = errorString(result); err != nil { return nil, &DcgmError{msg: C.GoString(C.errorString(result)), Code: result} } - log.Println("Listening for violations...") - violation := make(chan PolicyViolation, conTypes) + vioChanCap := conTypes * 2 + violation := make(chan PolicyViolation, vioChanCap) go func() { defer func() { @@ -335,6 +344,13 @@ func registerPolicy(ctx context.Context, groupId GroupHandle, params *PolicyCond close(registerCh) }() for { + if len(violation) == vioChanCap { + log.Println("Error: The violation channel is already full. New messages will be discarded.") + continue + } else if len(violation) == vioChanCap-1 { + log.Println("Warning: The violation channel is almost full. Please read it as soon as possible.") + } + select { case dbe := <-callbacks["dbe"]: violation <- dbe @@ -396,6 +412,7 @@ func VoidCallback(data unsafe.Pointer) int { // //export ViolationRegistration func ViolationRegistration(data unsafe.Pointer) int { + // log.Println("A policy violation is coming ...") var con policyCondition var timestamp time.Time var val interface{} @@ -450,15 +467,28 @@ func ViolationRegistration(data unsafe.Pointer) int { switch con { case DbePolicy: - callbacks["dbe"] <- err + writeToCallbacks("dbe", err) case PCIePolicy: - callbacks["pcie"] <- err + writeToCallbacks("pcie", err) case MaxRtPgPolicy: - callbacks["maxrtpg"] <- err + writeToCallbacks("maxrtpg", err) case ThermalPolicy: - callbacks["thermal"] <- err + writeToCallbacks("thermal", err) case PowerPolicy: - callbacks["power"] <- err + writeToCallbacks("power", err) } return 0 } + +func writeToCallbacks(con string, vioErr PolicyViolation) { + conChanLcks[con].Lock() + defer conChanLcks[con].Unlock() + + if len(callbacks[con]) == conChanCap { + log.Printf("Error: The channel of %s condition is already full. New messages will be discarded.\n", con) + return + } else if len(callbacks[con]) == conChanCap-1 { + log.Printf("Warning: The channel of %s condition is almost full. Please read it as soon as possible.\n", con) + } + callbacks[con] <- vioErr +} diff --git a/samples/policy/main.go b/samples/policy/main.go index 9a81a958d0b9b6774a8ffc0a4b4918c82ee2b7a0..633d9b251758e77409e4997cbb0d87649b944451 100644 --- a/samples/policy/main.go +++ b/samples/policy/main.go @@ -47,20 +47,20 @@ func main() { PCIePolicyEnabled: true, ThermalPolicyEnabled: true, ThermalPolicyThreshold: 60, // °C + PowerPolicyEnabled: true, + PowerPolicyThreshold: 250, // W } // Monitor policy violations for all GPUs + // Note: if you want to monitor policy violations for special GPUs (e.g., gpuId0 and gpuId1), + // use the api: ixdcgm.ListenForPolicyViolationsForGPUs(ctx, params, gpuId0, gpuId1) ch, err := ixdcgm.ListenForPolicyViolationsForAllGPUs(ctx, params) - - // If you want to monitor policy violations for particular GPUs (e.g., gpuId0 and gpuId1), - // use the following code: - // ch, err := ixdcgm.ListenForPolicyViolationsForGPUs(ctx, params, 0, 1) - if err != nil { fmt.Printf("Failed to monitor policy violations, err: %v", err) return } + // Read the policy violations from the channel as soon as possible. for { select { case pe := <-ch: