package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/* loaded from: input_file:ai/djl/training/loss/TabNetRegressionLoss.class */
/**
 * Regression loss for TabNet-style models.
 *
 * <p>The loss is the mean squared error between the labels and the first prediction, plus an
 * additional term that the model supplies as its second prediction.
 */
public class TabNetRegressionLoss extends Loss {

    /** Creates the loss using the default name {@code "TabNetRegressionLoss"}. */
    public TabNetRegressionLoss() {
        this("TabNetRegressionLoss");
    }

    /**
     * Creates the loss with a caller-chosen name.
     *
     * @param name the name reported for this loss
     */
    public TabNetRegressionLoss(String name) {
        super(name);
    }

    /**
     * Computes {@code mean((label - prediction[0])^2) + prediction[1]}.
     *
     * @param labels a list that must contain exactly one array (the regression target)
     * @param predictions a list whose element 0 is the regression output and whose element 1 is
     *     an extra term added to the loss as-is
     * @return the combined loss
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray target = labels.singletonOrThrow();
        NDArray estimate = predictions.get(0);
        NDArray meanSquaredError = target.sub(estimate).square().mean();
        // NOTE(review): element 1 presumably carries a regularization (e.g. sparsity) term
        // computed by the model's forward pass; this method only adds it — confirm upstream.
        NDArray extraTerm = predictions.get(1);
        return meanSquaredError.add(extraTerm);
    }
}
