package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.metric.Dimension;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.metric.Unit;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.util.cuda.CudaUtils;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Iterator;
import okhttp3.internal.ws.RealWebSocket;
import org.apache.tika.metadata.MachineMetadata;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:lib/api-0.31.1.jar:ai/djl/training/listener/MemoryTrainingListener.class */
/**
 * A {@link TrainingListener} that records memory-usage metrics into the trainer's {@link Metrics}
 * after every training and validation batch, and appends them to {@code memory.log} when training
 * ends.
 *
 * <p>Collection only happens when the {@code collect-memory} system property is {@code true}. The
 * recorded metrics are: JVM heap and non-heap usage ({@code Heap}, {@code NonHeap}), the committed
 * memory reported by {@link CudaUtils} for each GPU ({@code GPU-<i>}), and, on Linux and macOS
 * only, the process CPU percentage and resident set size as reported by {@code ps} ({@link
 * Device.Type#CPU}, {@code rss}).
 */
public class MemoryTrainingListener extends TrainingListenerAdapter {

    private static final Logger logger = LoggerFactory.getLogger(MemoryTrainingListener.class);

    /** {@code ps -o rss=} prints the resident set size in 1024-byte units; this converts to bytes. */
    private static final long BYTES_PER_KIB = 1024L;

    /** Directory that receives {@code memory.log}; {@code null} disables dumping. */
    private final String outputDir;

    /** Constructs a listener that collects memory metrics but never writes them to disk. */
    public MemoryTrainingListener() {
        this(null);
    }

    /**
     * Constructs a listener that collects memory metrics and appends them to {@code
     * outputDir/memory.log} when training ends.
     *
     * @param outputDir the output directory; {@code null} disables dumping
     */
    public MemoryTrainingListener(String outputDir) {
        this.outputDir = outputDir;
    }

    /** {@inheritDoc} */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        collectMemoryInfo(trainer.getMetrics());
    }

    /** {@inheritDoc} */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        collectMemoryInfo(trainer.getMetrics());
    }

    /** {@inheritDoc} */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        dumpMemoryInfo(trainer.getMetrics(), outputDir);
    }

    /**
     * Records the current memory usage into {@code metrics}.
     *
     * <p>Does nothing unless {@code metrics} is non-null and the {@code collect-memory} system
     * property is {@code true}.
     *
     * @param metrics the metrics sink to add to; may be {@code null}
     */
    public static void collectMemoryInfo(Metrics metrics) {
        if (metrics == null || !Boolean.getBoolean("collect-memory")) {
            return;
        }
        MemoryMXBean memoryBean = ManagementFactory.getMemoryMXBean();
        long heapUsed = memoryBean.getHeapMemoryUsage().getUsed();
        long nonHeapUsed = memoryBean.getNonHeapMemoryUsage().getUsed();

        getProcessInfo(metrics);

        metrics.addMetric("Heap", heapUsed, Unit.BYTES);
        metrics.addMetric("NonHeap", nonHeapUsed, Unit.BYTES);

        int gpuCount = CudaUtils.getGpuCount();
        for (int i = 0; i < gpuCount; i++) {
            MemoryUsage gpuUsage = CudaUtils.getGpuMemory(Device.gpu(i));
            metrics.addMetric("GPU-" + i, gpuUsage.getCommitted(), Unit.BYTES);
        }
    }

    /**
     * Appends the memory metrics recorded by {@link #collectMemoryInfo(Metrics)} to {@code
     * memory.log} inside {@code dir}, creating the directory if necessary.
     *
     * <p>Each metric is written on its own line using {@link Metric#toString()}. I/O failures are
     * logged rather than thrown.
     *
     * @param metrics the metrics to read from; nothing is written if {@code null}
     * @param dir the output directory; nothing is written if {@code null}
     */
    public static void dumpMemoryInfo(Metrics metrics, String dir) {
        if (metrics == null || dir == null) {
            return;
        }
        try {
            Path path = Paths.get(dir);
            Files.createDirectories(path);
            Path logFile = path.resolve("memory.log");
            // try-with-resources guarantees the writer is closed even when a write fails.
            try (BufferedWriter writer =
                    Files.newBufferedWriter(
                            logFile, StandardOpenOption.CREATE, StandardOpenOption.APPEND)) {
                writeMetrics(writer, metrics, "Heap");
                writeMetrics(writer, metrics, "NonHeap");
                writeMetrics(writer, metrics, Device.Type.CPU);
                writeMetrics(writer, metrics, "rss");
                int gpuCount = CudaUtils.getGpuCount();
                for (int i = 0; i < gpuCount; i++) {
                    writeMetrics(writer, metrics, "GPU-" + i);
                }
            }
        } catch (IOException e) {
            logger.error("Failed dump memory log", e);
        }
    }

    /** Writes every metric recorded under {@code name} to {@code writer}, one per line. */
    private static void writeMetrics(BufferedWriter writer, Metrics metrics, String name)
            throws IOException {
        for (Metric metric : metrics.getMetric(name)) {
            writer.append(metric.toString());
            writer.newLine();
        }
    }

    /**
     * Adds the process CPU percentage and resident set size, as reported by {@code ps}, to {@code
     * metrics}.
     *
     * <p>Only runs when {@code os.name} starts with {@code Linux} or {@code Mac}. Failures to run
     * {@code ps} or to parse its output are logged and otherwise ignored, so they never interrupt
     * training.
     */
    private static void getProcessInfo(Metrics metrics) {
        String osName = System.getProperty("os.name");
        if (osName == null || !(osName.startsWith("Linux") || osName.startsWith("Mac"))) {
            return;
        }
        // RuntimeMXBean#getName() is conventionally "<pid>@<host>"; take the pid part.
        String pid = ManagementFactory.getRuntimeMXBean().getName().split("@")[0];
        String cmd = "ps -o %cpu= -o rss= -p " + pid;

        String output;
        try {
            Process process =
                    new ProcessBuilder("ps", "-o", "%cpu=", "-o", "rss=", "-p", pid).start();
            try (InputStream is = process.getInputStream()) {
                output = new String(readAll(is), StandardCharsets.UTF_8).trim();
            }
        } catch (IOException e) {
            logger.error("Failed execute cmd: {}", cmd, e);
            return;
        }

        String[] tokens = output.split("\\s+");
        if (tokens.length != 2) {
            logger.error("Invalid ps output: {}", output);
            return;
        }
        try {
            float cpuPercent = Float.parseFloat(tokens[0]);
            long rssBytes = Long.parseLong(tokens[1]) * BYTES_PER_KIB;
            metrics.addMetric(Device.Type.CPU, cpuPercent, Unit.PERCENT);
            metrics.addMetric("rss", rssBytes, Unit.BYTES);
        } catch (NumberFormatException e) {
            // e.g. a locale that prints "0,5" for %cpu; must not abort the training batch.
            logger.error("Invalid ps output: {}", output, e);
        }
    }

    /** Reads {@code is} until end of stream and returns all bytes read. */
    private static byte[] readAll(InputStream is) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        byte[] buffer = new byte[8192];
        int read;
        while ((read = is.read(buffer)) != -1) {
            bos.write(buffer, 0, read);
        }
        return bos.toByteArray();
    }
}
