package org.tensorflow.op;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.BaseGradientAdapter;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.internal.c_api.NativeOperation;
import org.tensorflow.internal.c_api.NativeOutputVector;
import org.tensorflow.internal.c_api.NativeStatus;
import org.tensorflow.internal.c_api.TF_Scope;
import org.tensorflow.op.RawOpInputs;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/tensorflow/op/TypedGradientAdapter.class */
/**
 * Adapts a typed {@link CustomGradient} to the gradient-function callback defined by
 * {@link BaseGradientAdapter}.
 *
 * <p>For each forward operation, an instance of the op-inputs class {@code T} is created
 * reflectively. The user gradient is then invoked with an {@link Ops} bound to a gradient
 * sub-scope named after that operation.
 *
 * @param <T> the op-inputs type handed to the custom gradient
 */
public final class TypedGradientAdapter<T extends RawOpInputs<?>> extends BaseGradientAdapter {
    /** User-supplied gradient function invoked for each forward op of type {@code T}. */
    private final CustomGradient<T> gradient;
    /** Op-inputs class instantiated for each forward op; also used in error messages. */
    private final Class<T> opInputClass;
    /** Reflective constructor used to wrap the forward op in an instance of {@code T}. */
    private final Constructor<T> ctor;

    /**
     * Creates an adapter that invokes {@code customGradient} with instances of {@code cls}.
     *
     * <p>NOTE(review): this uses the first entry of {@link Class#getDeclaredConstructors()}. That
     * array has no specified order, so this assumes {@code cls} declares exactly one constructor,
     * and that it accepts the graph operation passed in {@link #call}. Confirm this for every
     * op-inputs class.
     *
     * @param customGradient gradient function to invoke
     * @param cls op-inputs class to instantiate for each forward op
     */
    public TypedGradientAdapter(CustomGradient<T> customGradient, Class<T> cls) {
        this.gradient = customGradient;
        this.opInputClass = cls;
        // Safe: getDeclaredConstructors() on a Class<T> only returns constructors of T.
        @SuppressWarnings("unchecked")
        Constructor<T> firstCtor = (Constructor<T>) this.opInputClass.getDeclaredConstructors()[0];
        this.ctor = firstCtor;
    }

    /**
     * Builds the gradient of {@code nativeOperation} using the wrapped {@link CustomGradient}.
     *
     * @param tF_Scope native scope whose graph hosts the gradient ops
     * @param nativeOperation the forward operation being differentiated
     * @param nativeOutputVector incoming gradients, converted to Java {@link Output}s and passed
     *     to the custom gradient
     * @param nativeOutputVector2 receives the operands returned by the custom gradient
     * @return {@link NativeStatus#OK()} once the outputs have been written
     * @throws IllegalStateException if no Java {@link Graph} is registered for the scope's graph
     * @throws RuntimeException if the op-inputs class cannot be instantiated reflectively
     */
    @Override // org.tensorflow.internal.c_api.GradFunc
    public NativeStatus call(TF_Scope tF_Scope, NativeOperation nativeOperation,
            NativeOutputVector nativeOutputVector, NativeOutputVector nativeOutputVector2) {
        // try-with-resources: the PointerScope must be closed on every path, not only on success.
        try (PointerScope pointerScope = new PointerScope()) {
            Graph graph = Graph.findGraphForPointer(tF_Scope.graph());
            if (graph == null) {
                throw new IllegalStateException("No graph found for native gradient scope.");
            }
            T opInputs = this.ctor.newInstance(
                    BaseGradientAdapter.getGraphOp(graph, nativeOperation.node()));
            Ops ops = new Ops(new GradientScope(tF_Scope, graph, null)
                    .withSubScope(opInputs.getOutputs().op().name()));
            List<Output<?>> gradInputs =
                    BaseGradientAdapter.fromNativeOutputs(graph, nativeOutputVector);

            final List<Operand<?>> gradOutputs;
            BaseGradientAdapter.useDangerousLockedBuilders(graph, true);
            try {
                gradOutputs = this.gradient.call(ops, opInputs, gradInputs);
            } finally {
                // Always leave the locked-builder mode, even if the custom gradient throws.
                BaseGradientAdapter.useDangerousLockedBuilders(graph, false);
            }
            BaseGradientAdapter.putToNativeOutputs(gradOutputs, nativeOutputVector2);
        } catch (IllegalAccessException | InstantiationException | InvocationTargetException e) {
            throw new RuntimeException("Could not instantiate Op class " + this.opInputClass, e);
        }
        // Evaluated after the PointerScope is closed, matching the original ordering.
        // NOTE(review): presumably this keeps the returned status from being attached to, and
        // released with, the scope. Confirm before moving it inside the try.
        return NativeStatus.OK();
    }
}
