package ai.djl.training.initializer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;

/**
 * {@code TruncatedNormalInitializer} fills a parameter with values drawn from a normal distribution
 * with mean {@code 0} and standard deviation {@code sigma}. Any sample that does not lie strictly
 * between {@code -2 * sigma} and {@code 2 * sigma} is discarded and re-drawn.
 */
public class TruncatedNormalInitializer implements Initializer {

    /**
     * Number of sampling rounds attempted before giving up. Roughly 95% of untruncated normal
     * samples fall inside +/-2 sigma, so one oversampled round is normally enough.
     */
    private static final int MAX_SAMPLING_ROUNDS = 10;

    /** Factor by which each round over-requests samples to compensate for rejected ones. */
    private static final double OVERSAMPLING_FACTOR = 1.1d;

    private final float sigma;

    /** Creates a truncated normal initializer with {@code sigma = 0.01}. */
    public TruncatedNormalInitializer() {
        this(0.01f);
    }

    /**
     * Creates a truncated normal initializer.
     *
     * @param sigma the standard deviation of the underlying normal distribution; values are kept
     *     only if they lie strictly within {@code 2 * sigma} of zero
     */
    public TruncatedNormalInitializer(float sigma) {
        this.sigma = sigma;
    }

    /**
     * Initializes an array of the given shape and data type with truncated normal values.
     *
     * @param manager the manager that will own the returned array
     * @param shape the shape of the array to create; must be fully determined
     * @param dataType the data type of the array to create
     * @return a new array of {@code shape}, attached to {@code manager}
     * @throws IllegalArgumentException if {@code shape} has an undetermined (negative) size
     * @throws IllegalStateException if enough in-range values could not be produced within the
     *     allowed number of sampling rounds
     */
    @Override // ai.djl.training.initializer.Initializer
    public NDArray initialize(NDManager manager, Shape shape, DataType dataType) {
        long size = shape.size();
        if (size < 0) {
            throw new IllegalArgumentException("Shape is not determined.");
        }
        float lowerBound = -2.0f * sigma;
        float upperBound = 2.0f * sigma;

        // All intermediate arrays live in a scoped sub-manager that is released on every exit
        // path, including the exception below; only the final result is moved to `manager`.
        try (NDManager scope = manager.newSubManager()) {
            // Start from an empty array of the requested type so concat never mixes dtypes.
            NDArray accepted = scope.zeros(new Shape(0), dataType);
            for (int round = 0; accepted.size() < size; round++) {
                if (round == MAX_SAMPLING_ROUNDS) {
                    throw new IllegalStateException(
                            "Initialization of truncated normal takes too long - This is incredibly unlikely, something must be seriously wrong.");
                }
                long remaining = size - accepted.size();
                long requested = (long) (remaining * OVERSAMPLING_FACTOR);
                NDArray samples =
                        scope.randomNormal(
                                0.0f, sigma, new Shape(requested), dataType, scope.getDevice());
                NDArray inRange = samples.gt(lowerBound).logicalAnd(samples.lt(upperBound));
                accepted = accepted.concat(samples.get(inRange));
            }
            // Earlier rounds may have produced more values than needed; keep exactly `size`.
            NDArray result = accepted.get(new NDIndex().addSliceDim(0L, size)).reshape(shape);
            result.attach(manager);
            return result;
        }
    }
}
