/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.constraint;

import java.util.Collections;
import java.util.Set;
import org.deeplearning4j.nn.conf.constraint.BaseConstraint;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Layer constraint that keeps the L2 norm of each slice of a parameter array within
 * {@code [min, max]}. Norms are taken along the configured {@code dimensions}.
 *
 * <p>Each slice is rescaled by {@code desired / (norm + epsilon)}, where
 * {@code desired = rate * clip(norm, min, max) + (1 - rate) * norm}.
 * <ul>
 *   <li>With {@code rate == 1.0}, norms are moved fully into the interval.</li>
 *   <li>With smaller rates, norms are moved only part of the way.</li>
 * </ul>
 */
public class MinMaxNormConstraint extends BaseConstraint {
    /** Rate used by constructors that do not take one: the constraint is applied fully. */
    public static final double DEFAULT_RATE = 1.0;

    /** Lower bound for the per-slice L2 norm. */
    private double min;
    /** Upper bound for the per-slice L2 norm. */
    private double max;
    /** Interpolation rate in (0, 1]; 1.0 means the norm is clipped exactly into [min, max]. */
    private double rate;

    /** No-arg constructor; presumably kept for reflective (de)serialization — confirm. */
    private MinMaxNormConstraint() {
    }

    /**
     * Creates a constraint with {@link #DEFAULT_RATE}.
     *
     * @param min        minimum allowed L2 norm
     * @param max        maximum allowed L2 norm
     * @param dimensions dimensions along which the L2 norm is computed
     */
    public MinMaxNormConstraint(double min, double max, int ... dimensions) {
        // NOTE(review): passes null paramNames, whereas the rate-taking constructor passes an
        // empty set. Left unchanged; confirm that BaseConstraint/callers handle null here.
        this(min, max, DEFAULT_RATE, null, dimensions);
    }

    /**
     * Creates a constraint with an explicit rate and an empty set of parameter names.
     *
     * @param min        minimum allowed L2 norm
     * @param max        maximum allowed L2 norm
     * @param rate       interpolation rate, must be in (0, 1]
     * @param dimensions dimensions along which the L2 norm is computed
     * @throws IllegalStateException if {@code rate} is not in (0, 1]
     */
    public MinMaxNormConstraint(double min, double max, double rate, int ... dimensions) {
        this(min, max, rate, Collections.emptySet(), dimensions);
    }

    /**
     * Creates a fully specified constraint.
     *
     * @param min        minimum allowed L2 norm
     * @param max        maximum allowed L2 norm
     * @param rate       interpolation rate, must be in (0, 1]
     * @param paramNames names of the parameters this constraint applies to (passed to the superclass)
     * @param dimensions dimensions along which the L2 norm is computed
     * @throws IllegalStateException if {@code rate} is not in (0, 1] (NaN is rejected too)
     */
    public MinMaxNormConstraint(double min, double max, double rate, Set<String> paramNames, int ... dimensions) {
        super(paramNames, dimensions);
        this.rate = validateRate(rate);
        this.min = min;
        this.max = max;
    }

    /** Checks that {@code rate} lies in (0, 1]; written as a negated range test so NaN fails. */
    private static double validateRate(double rate) {
        if (!(rate > 0.0 && rate <= 1.0)) {
            throw new IllegalStateException("Invalid rate: must be in interval (0,1]: got " + rate);
        }
        return rate;
    }

    /**
     * Rescales {@code param} in place so that each slice's L2 norm is moved toward
     * {@code [min, max]} by the configured rate.
     *
     * @param param parameter array to constrain; modified in place
     */
    @Override
    public void apply(INDArray param) {
        // Raw per-slice L2 norms; kept unmodified until the final division.
        INDArray norm = param.norm2(this.dimensions);

        // desired = clip(norm, min, max), computed in place on a copy of the norms.
        INDArray desired = norm.unsafeDuplication();
        CustomOp clip = DynamicCustomOp.builder("clipbyvalue")
                .addInputs(desired)
                .callInplace(true)
                .addFloatingPointArguments(this.min, this.max)
                .build();
        Nd4j.getExecutioner().exec(clip);

        // Partial rescaling: interpolate between the clipped norm and the raw norm BEFORE
        // turning it into a scale factor. (Previously the ratio clip/(norm+eps) was mixed with
        // the raw norm itself, which produced a dimensionally wrong scale when rate != 1.)
        if (this.rate != 1.0) {
            desired.muli(this.rate).addi(norm.mul(1.0 - this.rate));
        }

        // scale = desired / (norm + epsilon); epsilon avoids division by zero for all-zero slices.
        desired.divi(norm.addi(this.epsilon));

        Broadcast.mul(param, desired, param, getBroadcastDims(this.dimensions, param.rank()));
    }

    /** Returns a new constraint with the same bounds, rate, parameter names and dimensions. */
    @Override
    public MinMaxNormConstraint clone() {
        return new MinMaxNormConstraint(this.min, this.max, this.rate, this.params, this.dimensions);
    }

    public double getMin() {
        return this.min;
    }

    public double getMax() {
        return this.max;
    }

    public double getRate() {
        return this.rate;
    }

    public void setMin(double min) {
        this.min = min;
    }

    public void setMax(double max) {
        this.max = max;
    }

    /**
     * Sets the interpolation rate.
     *
     * @param rate new rate, must be in (0, 1]
     * @throws IllegalStateException if {@code rate} is not in (0, 1]; same check as the constructor
     */
    public void setRate(double rate) {
        this.rate = validateRate(rate);
    }

    @Override
    public String toString() {
        return "MinMaxNormConstraint(min=" + this.getMin() + ", max=" + this.getMax() + ", rate=" + this.getRate() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MinMaxNormConstraint)) {
            return false;
        }
        MinMaxNormConstraint other = (MinMaxNormConstraint) o;
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        return Double.compare(this.getMin(), other.getMin()) == 0
                && Double.compare(this.getMax(), other.getMax()) == 0
                && Double.compare(this.getRate(), other.getRate()) == 0;
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof MinMaxNormConstraint;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        // Double.hashCode(d) == (int) (bits ^ (bits >>> 32)), identical to the manual folding.
        result = result * prime + Double.hashCode(this.getMin());
        result = result * prime + Double.hashCode(this.getMax());
        result = result * prime + Double.hashCode(this.getRate());
        return result;
    }
}

