package ai.djl.engine.rust;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.ndarray.BaseNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.Charset;

/**
 * Rust-engine implementation of {@link NDManager}.
 *
 * <p>Array-creation methods build their arrays through native functions in {@code RustLibrary},
 * passing this manager's device type and device id, and wrap the result in an {@link RsNDArray}
 * bound to this manager. DJL data types are translated to native type codes with {@link
 * #toRustDataType(DataType)}.
 */
public class RsNDManager extends BaseNDManager {

    /** The shared root manager returned by {@link #getSystemManager()}. */
    private static final RsNDManager SYSTEM_MANAGER = new SystemManager();

    private RsNDManager(NDManager parent, Device device) {
        super(parent, device);
    }

    /**
     * Returns the shared root manager for the Rust engine.
     *
     * <p>NOTE(review): the decompiler reports this method as package-private in the compiled
     * class; it is left {@code public} so existing callers keep compiling.
     *
     * @return the system manager singleton
     */
    public static RsNDManager getSystemManager() {
        return SYSTEM_MANAGER;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The buffer uses the platform's native byte order, since buffers allocated here may be
     * handed to native code by {@link #create(Buffer, Shape, DataType)}.
     */
    @Override
    public ByteBuffer allocateDirect(int capacity) {
        return ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder());
    }

    /**
     * Converts {@code array} into an {@link RsNDArray}.
     *
     * @param array the array to convert, possibly from another engine; may be {@code null}
     * @return {@code null} for {@code null} input, the same instance if it already is an {@link
     *     RsNDArray}, otherwise a copy of its bytes, shape, data type and name created through
     *     this manager
     */
    @Override
    public RsNDArray from(NDArray array) {
        if (array == null || array instanceof RsNDArray) {
            return (RsNDArray) array;
        }
        // Declared as Buffer so the call binds to create(Buffer, Shape, DataType).
        Buffer data = array.toByteBuffer();
        RsNDArray copy = create(data, array.getShape(), array.getDataType());
        copy.setName(array.getName());
        return copy;
    }

    /** {@inheritDoc} The array is filled by {@code RustLibrary.zeros}. */
    @Override
    public RsNDArray create(Shape shape, DataType dataType) {
        int rustType = toRustDataType(dataType);
        return new RsNDArray(
                this,
                RustLibrary.zeros(
                        shape.getShape(), rustType, device.getDeviceType(), device.getDeviceId()),
                dataType);
    }

    /**
     * {@inheritDoc}
     *
     * <p>A direct {@link ByteBuffer} is passed to {@code RustLibrary.tensorOf} without a
     * Java-side copy; any other buffer is first copied into a buffer obtained from {@link
     * #allocateDirect(int)}. The buffer actually used is also handed to the {@link RsNDArray}
     * constructor.
     *
     * @throws ArithmeticException if the element count, or (when copying) the byte size,
     *     overflows an {@code int}
     */
    @Override
    public RsNDArray create(Buffer data, Shape shape, DataType dataType) {
        int size = Math.toIntExact(shape.size());
        BaseNDManager.validateBuffer(data, dataType, size);
        ByteBuffer buf;
        if (data.isDirect() && data instanceof ByteBuffer) {
            buf = (ByteBuffer) data;
        } else {
            // multiplyExact: fail loudly instead of wrapping to a bogus capacity for huge arrays.
            buf = allocateDirect(Math.multiplyExact(size, dataType.getNumOfBytes()));
            copyBuffer(data, buf);
        }
        int rustType = toRustDataType(dataType);
        return new RsNDArray(
                this,
                RustLibrary.tensorOf(
                        buf,
                        shape.getShape(),
                        rustType,
                        device.getDeviceType(),
                        device.getDeviceId()),
                dataType,
                buf);
    }

    /**
     * Not supported by the Rust engine.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray create(String[] data, Charset charset, Shape shape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Not supported by the Rust engine.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public NDArray createCoo(Buffer data, long[][] indices, Shape shape) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /** {@inheritDoc} Equivalent to {@link #create(Shape, DataType)}. */
    @Override
    public NDArray zeros(Shape shape, DataType dataType) {
        return create(shape, dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray ones(Shape shape, DataType dataType) {
        int rustType = toRustDataType(dataType);
        return new RsNDArray(
                this,
                RustLibrary.ones(
                        shape.getShape(), rustType, device.getDeviceType(), device.getDeviceId()),
                dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray full(Shape shape, float value, DataType dataType) {
        int rustType = toRustDataType(dataType);
        return new RsNDArray(
                this,
                RustLibrary.full(
                        value,
                        shape.getShape(),
                        rustType,
                        device.getDeviceType(),
                        device.getDeviceId()),
                dataType);
    }

    /** {@inheritDoc} Delegates to the device-aware overload with this manager's device. */
    @Override
    public NDArray arange(int start, int stop, int step, DataType dataType) {
        return arange(start, stop, step, dataType, device);
    }

    /**
     * {@inheritDoc}
     *
     * <p>Returns an empty array, without calling native code, whenever the sign of {@code step}
     * differs from the sign of {@code stop - start} (for example a step pointing away from
     * {@code stop}, or a zero step over a non-empty span).
     */
    @Override
    public NDArray arange(float start, float stop, float step, DataType dataType) {
        if (Math.signum(stop - start) != Math.signum(step)) {
            return create(new Shape(0), dataType, device);
        }
        int rustType = toRustDataType(dataType);
        return new RsNDArray(
                this,
                RustLibrary.arange(
                        start, stop, step, rustType, device.getDeviceType(), device.getDeviceId()),
                dataType);
    }

    /**
     * {@inheritDoc}
     *
     * @throws UnsupportedOperationException if {@code k != 0} or {@code rows != cols}; this
     *     implementation only builds square matrices with the main diagonal set
     */
    @Override
    public NDArray eye(int rows, int cols, int k, DataType dataType) {
        if (k != 0) {
            throw new UnsupportedOperationException("index of the diagonal is not supported in Rust");
        }
        if (rows != cols) {
            throw new UnsupportedOperationException("rows must equals to columns in Rust");
        }
        int rustType = toRustDataType(dataType);
        return new RsNDArray(
                this,
                RustLibrary.eye(rows, cols, rustType, device.getDeviceType(), device.getDeviceId()),
                dataType);
    }

    /**
     * {@inheritDoc}
     *
     * <p>The result is always {@link DataType#FLOAT32}.
     *
     * @throws UnsupportedOperationException if {@code endpoint} is {@code false}
     */
    @Override
    public NDArray linspace(float start, float stop, int num, boolean endpoint) {
        if (!endpoint) {
            throw new UnsupportedOperationException("endpoint only support true");
        }
        return new RsNDArray(
                this,
                RustLibrary.linspace(
                        start,
                        stop,
                        num,
                        toRustDataType(DataType.FLOAT32),
                        device.getDeviceType(),
                        device.getDeviceId()),
                DataType.FLOAT32);
    }

    /**
     * {@inheritDoc}
     *
     * <p>NOTE(review): {@code dataType} is currently ignored; the native call is made with, and
     * the result is labelled as, {@link DataType#FLOAT32}. Confirm whether the native {@code
     * randint} supports integer types before honouring {@code dataType}.
     */
    @Override
    public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) {
        return new RsNDArray(
                this,
                RustLibrary.randint(
                        low,
                        high,
                        shape.getShape(),
                        toRustDataType(DataType.FLOAT32),
                        device.getDeviceType(),
                        device.getDeviceId()),
                DataType.FLOAT32);
    }

    /** {@inheritDoc} The result's data type is not supplied from Java. */
    @Override
    public NDArray randomPermutation(long n) {
        return new RsNDArray(
                this,
                RustLibrary.randomPermutation(n, device.getDeviceType(), device.getDeviceId()));
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) {
        int rustType = toRustDataType(dataType);
        return new RsNDArray(
                this,
                RustLibrary.uniform(
                        low,
                        high,
                        shape.getShape(),
                        rustType,
                        device.getDeviceType(),
                        device.getDeviceId()),
                dataType);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) {
        int rustType = toRustDataType(dataType);
        return new RsNDArray(
                this,
                RustLibrary.randomNormal(
                        loc,
                        scale,
                        shape.getShape(),
                        rustType,
                        device.getDeviceType(),
                        device.getDeviceId()),
                dataType);
    }

    /** {@inheritDoc} The result's data type is not supplied from Java. */
    @Override
    public NDArray hanningWindow(long numPoints) {
        return new RsNDArray(
                this,
                RustLibrary.hannWindow(numPoints, device.getDeviceType(), device.getDeviceId()));
    }

    /** {@inheritDoc} The child is attached to this manager under its uid. */
    @Override
    public RsNDManager newSubManager(Device device) {
        RsNDManager child = new RsNDManager(this, device);
        attachUncappedInternal(child.uid, child);
        return child;
    }

    /** {@inheritDoc} */
    @Override
    public final Engine getEngine() {
        return Engine.getEngine(RsEngine.ENGINE_NAME);
    }

    /**
     * Translates a DJL {@link DataType} into the integer type code passed to {@code RustLibrary}.
     *
     * <p>The code is a {@link DataType#ordinal()} value, so the mapping relies on {@code
     * DataType}'s declaration order matching what the native library expects. {@code BOOLEAN}
     * and {@code INT8} are sent as {@code UINT8}, and {@code INT32} as {@code UINT32} —
     * presumably because the native backend lacks those types (TODO confirm). {@code FLOAT16},
     * {@code BFLOAT16}, {@code FLOAT32}, {@code FLOAT64}, {@code UINT8}, {@code UINT32} and
     * {@code INT64} are sent unchanged.
     *
     * <p>NOTE(review): the decompiler reports this method as package-private in the compiled
     * class; it is left {@code public} so existing callers keep compiling.
     *
     * @param dataType the DJL data type
     * @return the native type code
     * @throws UnsupportedOperationException for any other data type
     */
    public int toRustDataType(DataType dataType) {
        switch (dataType) {
            case BOOLEAN:
            case INT8:
                return DataType.UINT8.ordinal();
            case INT32:
                return DataType.UINT32.ordinal();
            case FLOAT16:
            case BFLOAT16:
            case FLOAT32:
            case FLOAT64:
            case UINT8:
            case UINT32:
            case INT64:
                return dataType.ordinal();
            default:
                throw new UnsupportedOperationException("Unsupported data type: " + dataType);
        }
    }

    /**
     * The root {@link RsNDManager}, constructed with a {@code null} parent and a {@code null}
     * device (how a {@code null} device is resolved is left to {@code BaseNDManager}).
     *
     * <p>Covariant-return bridge methods (for {@code create}, {@code from}, {@code
     * newSubManager}) are generated by javac and must not be declared here: a source override
     * returning {@code NDArray} or {@code NDManager} would not be return-type-substitutable for
     * the inherited {@code RsNDArray}/{@code RsNDManager} versions.
     */
    private static final class SystemManager extends RsNDManager
            implements NDManager.SystemNDManager {

        SystemManager() {
            super(null, null);
        }
    }
}
