package org.deeplearning4j.spark.impl.paramavg.aggregator;

import java.util.Collection;
import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.core.storage.Persistable;
import org.deeplearning4j.core.storage.StorageMetaData;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementCombineFunction.class */
/**
 * Spark {@link Function2} that combines two partial {@link ParameterAveragingAggregationTuple}s into a single
 * tuple by summing their numeric contents and concatenating their listener/statistics payloads.
 */
public class ParameterAveragingElementCombineFunction implements Function2<ParameterAveragingAggregationTuple, ParameterAveragingAggregationTuple, ParameterAveragingAggregationTuple> {

    /**
     * Combines two partial aggregation tuples.
     *
     * <p>If either tuple is {@code null}, or has a {@code null} parameter sum, the other tuple is returned as-is.
     * Otherwise:
     * <ul>
     *   <li>parameter sums are added element-wise;</li>
     *   <li>updater state sums are added element-wise when both are present, otherwise whichever one is
     *       non-null (possibly {@code null}) is kept;</li>
     *   <li>score sums and aggregation counts are added;</li>
     *   <li>training stats are merged via {@code addOtherTrainingStats}, or the non-null one is kept;</li>
     *   <li>listener metadata, static info and updates collections are concatenated, or the non-null one is
     *       kept.</li>
     * </ul>
     *
     * <p>NOTE: this mutates {@code v1} in place. Its parameter/updater arrays are updated with {@code addi}, its
     * stats object absorbs {@code v2}'s stats, and its listener collections receive {@code v2}'s elements via
     * {@code addAll}. Those same objects are reused in the returned tuple.
     *
     * @param v1 first partial result (may be {@code null}); mutated and reused when both inputs carry parameters
     * @param v2 second partial result (may be {@code null}); its contents are read but not modified
     * @return the combined tuple, or one of the inputs unchanged per the null-handling rules above
     * @throws Exception declared by {@link Function2#call}
     */
    public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple v1, ParameterAveragingAggregationTuple v2) throws Exception {
        if (v1 == null) {
            return v2;
        }
        if (v2 == null) {
            return v1;
        }

        // A tuple without a parameter sum has nothing to contribute to the average; keep the other one.
        // (Presumably produced by a worker/partition that did no fitting -- confirm against the producer.)
        if (v1.getParametersSum() == null) {
            return v2;
        }
        if (v2.getParametersSum() == null) {
            return v1;
        }

        INDArray parametersSum = v1.getParametersSum().addi(v2.getParametersSum());

        INDArray updaterStateSum = v1.getUpdaterStateSum();
        if (updaterStateSum == null) {
            updaterStateSum = v2.getUpdaterStateSum();
        } else if (v2.getUpdaterStateSum() != null) {
            updaterStateSum.addi(v2.getUpdaterStateSum());
        }

        double scoreSum = v1.getScoreSum() + v2.getScoreSum();
        int aggregationsCount = v1.getAggregationsCount() + v2.getAggregationsCount();

        SparkTrainingStats stats = v1.getSparkTrainingStats();
        if (v2.getSparkTrainingStats() != null) {
            if (stats == null) {
                stats = v2.getSparkTrainingStats();
            } else {
                stats.addOtherTrainingStats(v2.getSparkTrainingStats());
            }
        }

        // NOTE(review): presumably flushes any pending/queued ND4J operations (e.g. the addi calls above)
        // before the arrays are handed back -- confirm against the executioner implementation in use.
        Nd4j.getExecutioner().commit();

        Collection<StorageMetaData> listenerMetaData = mergeCollections(v1.getListenerMetaData(), v2.getListenerMetaData());
        Collection<Persistable> listenerStaticInfo = mergeCollections(v1.getListenerStaticInfo(), v2.getListenerStaticInfo());
        Collection<Persistable> listenerUpdates = mergeCollections(v1.getListenerUpdates(), v2.getListenerUpdates());

        return new ParameterAveragingAggregationTuple(parametersSum, updaterStateSum, scoreSum, aggregationsCount, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
    }

    /**
     * Null-tolerant concatenation of two collections.
     *
     * <p>If {@code first} is {@code null}, {@code second} is returned (possibly {@code null}). Otherwise the
     * elements of {@code second} (if non-null) are appended to {@code first} in place, and {@code first} is returned.
     *
     * @param first  collection to append into; mutated when non-null
     * @param second collection whose elements are appended; not modified
     * @param <T>    element type
     * @return {@code first} after appending, or {@code second} when {@code first} is {@code null}
     */
    private static <T> Collection<T> mergeCollections(Collection<T> first, Collection<T> second) {
        if (first == null) {
            return second;
        }
        if (second != null) {
            first.addAll(second);
        }
        return first;
    }
}
