package org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.nvgpu;

import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import com.sun.jna.ptr.IntByReference;
import com.sun.jna.ptr.PointerByReference;
import java.util.Arrays;
import java.util.HashMap;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.GpuHealth;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.GpuMemoryStat;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.SharedGpuDiscoverer;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.nvgpu.NvmlMemoryt;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.nvgpu.NvmlPciInfo;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.nvgpu.NvmlProcessInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/yarn/server/nodemanager/containermanager/linux/sharedresource/gpu/nvgpu/NvGpuDiscoverer.class */
public class NvGpuDiscoverer implements SharedGpuDiscoverer {
    private static final Logger LOG = LoggerFactory.getLogger(NvGpuDiscoverer.class);
    private static final int CHAR_BUFFER_SIZE = 128;
    private static final String DEFAULT_VENDOR_NAME = "nvgpu";
    private NvmlLibrary nvmlHandle;

    public NvGpuDiscoverer(Library library) {
        try {
            NvmlLibrary nvmlLibrary = (NvmlLibrary) NvmlLibrary.class.cast(library);
            int nvmlInit = nvmlLibrary.nvmlInit();
            if (nvmlInit != NvmlGpuAPIException.SUCCESS) {
                LOG.warn("NvGpu discoverer init failed, ret is {}", Integer.valueOf(nvmlInit));
            } else {
                this.nvmlHandle = nvmlLibrary;
            }
        } catch (Exception | UnsatisfiedLinkError e) {
            LOG.warn("NvGpu discoverer init failed, {}", e.getMessage());
        }
    }

    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.SharedGpuDiscoverer
    public String getVendor() {
        return DEFAULT_VENDOR_NAME;
    }

    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.SharedGpuDiscoverer
    public String getPciBusID(int i) {
        NvmlPciInfo nvmlPciInfo = null;
        try {
            nvmlPciInfo = deviceGetPciInfo(i);
        } catch (NvmlGpuAPIException e) {
            LOG.error("Get pci info, Error Message: {}", e.getMessage());
        }
        if (nvmlPciInfo == null || nvmlPciInfo.busId.length == 0) {
            return "";
        }
        int i2 = 0;
        while (i2 < nvmlPciInfo.busId.length && nvmlPciInfo.busId[i2] != 0) {
            i2++;
        }
        return new String(nvmlPciInfo.busId).substring(0, i2);
    }

    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.SharedGpuDiscoverer
    public GpuMemoryStat getGpuMemoryStat(int i) {
        GpuMemoryStat gpuMemoryStat = null;
        try {
            gpuMemoryStat = getGpuMemoryStatByIndex(i);
            return gpuMemoryStat;
        } catch (NvmlGpuAPIException e) {
            LOG.error("Unable to get stat for GPU idx {}, Error Message: {}", Integer.valueOf(i), e.getMessage());
            return gpuMemoryStat;
        }
    }

    private int getCount() {
        int i = 0;
        try {
            i = deviceGetCount();
        } catch (NvmlGpuAPIException e) {
            LOG.error("Get device count, Error Message: {}", e.getMessage());
        }
        return i;
    }

    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.SharedGpuDiscoverer
    public int[] getDeviceIndexList() {
        int count = getCount();
        if (count == 0) {
            return new int[0];
        }
        int[] iArr = new int[count];
        for (int i = 0; i < count; i++) {
            iArr[i] = i;
        }
        return iArr;
    }

    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.SharedGpuDiscoverer
    public GpuHealth getHealth(int i) {
        return GpuHealth.HEALTH_BUTT;
    }

    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.SharedGpuDiscoverer
    public String getModelName(int i) {
        String str = "";
        try {
            str = deviceGetName(i);
        } catch (NvmlGpuAPIException e) {
            LOG.error("Unable to get name for GPU idx {}, Error Message: {}", Integer.valueOf(i), e.getMessage());
        }
        return str;
    }

    @Override // org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.sharedresource.gpu.SharedGpuDiscoverer
    public String getUUID(int i) {
        String str = "";
        try {
            str = deviceGetUUIDByIndex(i);
        } catch (NvmlGpuAPIException e) {
            LOG.error("Unable to get uuid for GPU idx {}, Error Message: {}", Integer.valueOf(i), e.getMessage());
        }
        return str;
    }

    private NvmlProcessInfo[] deviceGetComputeRunningProcesses(int i) throws NvmlGpuAPIException {
        Pointer deviceGetHandleByIndex = deviceGetHandleByIndex(i);
        IntByReference intByReference = new IntByReference();
        NvmlProcessInfo.ByReference byReference = new NvmlProcessInfo.ByReference();
        int nvmlDeviceGetComputeRunningProcesses = this.nvmlHandle.nvmlDeviceGetComputeRunningProcesses(deviceGetHandleByIndex, intByReference, byReference);
        if (nvmlDeviceGetComputeRunningProcesses != NvmlGpuAPIException.SUCCESS && nvmlDeviceGetComputeRunningProcesses != NvmlGpuAPIException.INSUFFICIENT_SIZE) {
            throw new NvmlGpuAPIException(nvmlDeviceGetComputeRunningProcesses);
        }
        if (intByReference.getValue() <= 0) {
            return new NvmlProcessInfo[0];
        }
        NvmlProcessInfo.ByReference[] byReferenceArr = (NvmlProcessInfo.ByReference[]) byReference.toArray(intByReference.getValue());
        int nvmlDeviceGetComputeRunningProcesses2 = this.nvmlHandle.nvmlDeviceGetComputeRunningProcesses(deviceGetHandleByIndex, intByReference, byReferenceArr[0]);
        if (nvmlDeviceGetComputeRunningProcesses2 != NvmlGpuAPIException.SUCCESS) {
            throw new NvmlGpuAPIException(nvmlDeviceGetComputeRunningProcesses2);
        }
        return byReferenceArr;
    }

    private GpuMemoryStat getGpuMemoryStatByIndex(int i) throws NvmlGpuAPIException {
        HashMap hashMap = new HashMap();
        for (NvmlProcessInfo nvmlProcessInfo : deviceGetComputeRunningProcesses(i)) {
            hashMap.put(Integer.toString(nvmlProcessInfo.pid), Long.valueOf(nvmlProcessInfo.usedGpuMemory));
        }
        Pointer deviceGetHandleByIndex = deviceGetHandleByIndex(i);
        NvmlMemoryt.ByReference byReference = new NvmlMemoryt.ByReference();
        int nvmlDeviceGetMemoryInfo = this.nvmlHandle.nvmlDeviceGetMemoryInfo(deviceGetHandleByIndex, byReference);
        if (nvmlDeviceGetMemoryInfo != NvmlGpuAPIException.SUCCESS) {
            throw new NvmlGpuAPIException(nvmlDeviceGetMemoryInfo);
        }
        GpuMemoryStat gpuMemoryStat = new GpuMemoryStat();
        gpuMemoryStat.setUsage(hashMap);
        gpuMemoryStat.setFreeMem(Long.valueOf(byReference.free));
        gpuMemoryStat.setUsedMem(Long.valueOf(byReference.used));
        gpuMemoryStat.setTotalMem(Long.valueOf(byReference.total));
        gpuMemoryStat.setSupportProcessUsage(true);
        return gpuMemoryStat;
    }

    private Pointer deviceGetHandleByIndex(int i) throws NvmlGpuAPIException {
        PointerByReference pointerByReference = new PointerByReference();
        int nvmlDeviceGetHandleByIndex = this.nvmlHandle.nvmlDeviceGetHandleByIndex(i, pointerByReference);
        if (nvmlDeviceGetHandleByIndex != NvmlGpuAPIException.SUCCESS) {
            throw new NvmlGpuAPIException(nvmlDeviceGetHandleByIndex);
        }
        if (pointerByReference.getValue() == null) {
            throw new NvmlGpuAPIException("Unable to retrieve device pointer for gpu #" + i);
        }
        return pointerByReference.getValue();
    }

    private NvmlPciInfo deviceGetPciInfo(int i) throws NvmlGpuAPIException {
        if (this.nvmlHandle == null) {
            return null;
        }
        Pointer deviceGetHandleByIndex = deviceGetHandleByIndex(i);
        NvmlPciInfo.ByReference byReference = new NvmlPciInfo.ByReference();
        int nvmlDeviceGetPciInfo = this.nvmlHandle.nvmlDeviceGetPciInfo(deviceGetHandleByIndex, byReference);
        if (nvmlDeviceGetPciInfo != NvmlGpuAPIException.SUCCESS) {
            throw new NvmlGpuAPIException(nvmlDeviceGetPciInfo);
        }
        return byReference;
    }

    private int deviceGetCount() throws NvmlGpuAPIException {
        if (this.nvmlHandle == null) {
            return 0;
        }
        IntByReference intByReference = new IntByReference();
        int nvmlDeviceGetCount = this.nvmlHandle.nvmlDeviceGetCount(intByReference);
        if (nvmlDeviceGetCount != NvmlGpuAPIException.SUCCESS || intByReference.getValue() < 0) {
            throw new NvmlGpuAPIException(nvmlDeviceGetCount);
        }
        return intByReference.getValue();
    }

    private String deviceGetName(int i) throws NvmlGpuAPIException {
        if (this.nvmlHandle == null) {
            return "";
        }
        byte[] bArr = new byte[128];
        int nvmlDeviceGetName = this.nvmlHandle.nvmlDeviceGetName(deviceGetHandleByIndex(i), bArr, 128L);
        if (nvmlDeviceGetName != NvmlGpuAPIException.SUCCESS) {
            throw new NvmlGpuAPIException(nvmlDeviceGetName);
        }
        return Native.toString(bArr);
    }

    private String deviceGetUUIDByIndex(int i) throws NvmlGpuAPIException {
        if (this.nvmlHandle == null) {
            return "";
        }
        byte[] bArr = new byte[128];
        int nvmlDeviceGetUUID = this.nvmlHandle.nvmlDeviceGetUUID(deviceGetHandleByIndex(i), bArr, 128);
        if (nvmlDeviceGetUUID != NvmlGpuAPIException.SUCCESS) {
            throw new NvmlGpuAPIException(nvmlDeviceGetUUID);
        }
        return Native.toString(bArr);
    }
}
