package com.github.tjake.jlama.safetensors;

import com.github.tjake.jlama.math.ActivationFunction;
import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.model.DistributedContext;
import com.github.tjake.jlama.tensor.TensorCache;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.io.Files;
import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/* loaded from: input_file:com/github/tjake/jlama/safetensors/Config.class */
/**
 * Immutable description of a model's dimensions and numeric hyper-parameters, plus two
 * mutable, volatile pieces of runtime state: the {@link DistributedContext} and an
 * optional working directory.
 *
 * <p>All constructors funnel into the 21-argument constructor, which derives
 * {@link #attentionLength}, {@link #kvLength}, {@link #headGroupSize} and {@link #isGQA}
 * from the head counts and head size.
 */
public class Config {
    /** Context length as passed to the constructor; also handed to {@code VectorMath.precomputeFreqsCis}. */
    public final int contextLength;
    /** Embedding length as passed to the constructor; used to derive the default head size. */
    public final int embeddingLength;
    /** {@code numberOfHeads * headSize}. */
    public final int attentionLength;
    /** Hidden length as passed to the constructor. */
    public final int hiddenLength;
    /** Number of attention heads as passed to the constructor. */
    public final int numberOfHeads;
    /** Number of key/value heads as passed to the constructor. */
    public final int numberOfKeyValueHeads;
    /** Size of one head: the explicit value if supplied, otherwise {@code embeddingLength / numberOfHeads}. */
    public final int headSize;
    /** Activation function type as passed to the constructor. */
    public final ActivationFunction.Type activationFunction;
    /** {@code numberOfHeads / numberOfKeyValueHeads} (integer division). */
    public final int headGroupSize;
    /** {@code numberOfKeyValueHeads * headSize}. */
    public final int kvLength;
    /** {@code true} when {@code numberOfKeyValueHeads < numberOfHeads}. */
    public final boolean isGQA;
    /** Number of layers as passed to the constructor. */
    public final int numberOfLayers;
    /** Layer-norm epsilon as passed to the constructor. */
    public final float layerNormEps;
    /** Optional final-logit soft-capping value; {@code null} when not supplied. */
    public final Float finalLogitSoftCapping;
    /** Optional attention-logit soft-capping value; {@code null} when not supplied. */
    public final Float attnLogitSoftCapping;
    /** Optional residual multiplier; {@code null} when not supplied. */
    public final Float residualMultiplier;
    /** Optional attention multiplier; {@code null} when not supplied. */
    public final Float attentionMultiplier;
    /** Optional embedding multiplier; {@code null} when not supplied. */
    public final Float embeddingMultiplier;
    /** Optional logit multiplier; {@code null} when not supplied. */
    public final Float logitMultiplier;
    /** Vocabulary size as passed to the constructor. */
    public final int vocabularySize;
    /** Beginning-of-sequence token id as passed to the constructor. */
    public final int bosToken;
    /** End-of-sequence token ids; the caller's list is stored as-is (not copied). */
    public final List<Integer> eosTokens;
    /**
     * Result of {@code VectorMath.precomputeFreqsCis(headSize, contextLength, d, d2-or-1.0)};
     * empty when the constructor's {@code d} argument is {@code null}.
     */
    public final Optional<float[][]> ropeFreqs;
    /**
     * Immutable copy of the label map given to the constructor, or empty when none was given.
     * (The field name's spelling is part of the public API and is kept as-is.)
     */
    public final Optional<BiMap<String, Integer>> classifcationLabels;
    private volatile DistributedContext dctx;
    private volatile File workingDirectory;
    /** Shared tensor cache ({@code TensorCache.instance}). */
    public final TensorCache tensorCache;

    /**
     * Creates a config with an optional explicit head size and optional logit soft-capping values.
     *
     * @param i    contextLength
     * @param i2   embeddingLength
     * @param i3   hiddenLength
     * @param i4   numberOfHeads
     * @param i5   numberOfKeyValueHeads
     * @param i6   numberOfLayers
     * @param f    layerNormEps
     * @param i7   vocabularySize
     * @param i8   bosToken
     * @param list eosTokens
     * @param type activationFunction
     * @param d    third argument to {@code VectorMath.precomputeFreqsCis}; {@code null} disables {@link #ropeFreqs}
     * @param d2   fourth argument to {@code VectorMath.precomputeFreqsCis}; {@code null} means {@code 1.0}
     * @param num  headSize, or {@code null} to use {@code embeddingLength / numberOfHeads}
     * @param f2   finalLogitSoftCapping (may be {@code null})
     * @param f3   attnLogitSoftCapping (may be {@code null})
     */
    public Config(int i, int i2, int i3, int i4, int i5, int i6, float f, int i7, int i8, List<Integer> list, ActivationFunction.Type type, Double d, Double d2, Integer num, Float f2, Float f3) {
        // A null head size is resolved by the canonical constructor.
        this(i, i2, i3, i4, i5, i6, f, i7, i8, list, type, d, d2, null, num, f2, f3, null, null, null, null);
    }

    /**
     * Creates a config with the default head size ({@code embeddingLength / numberOfHeads}),
     * no classification labels and none of the optional Float settings.
     * Parameters have the same meaning as in the 21-argument constructor.
     */
    public Config(int i, int i2, int i3, int i4, int i5, int i6, float f, int i7, int i8, List<Integer> list, ActivationFunction.Type type, Double d, Double d2) {
        this(i, i2, i3, i4, i5, i6, f, i7, i8, list, type, d, d2, null, null, null, null, null, null, null, null);
    }

    /**
     * Creates a config with the default head size and the four multiplier settings.
     * Leading parameters have the same meaning as in the 21-argument constructor.
     *
     * @param f2 residualMultiplier (may be {@code null})
     * @param f3 attentionMultiplier (may be {@code null})
     * @param f4 embeddingMultiplier (may be {@code null})
     * @param f5 logitMultiplier (may be {@code null})
     */
    public Config(int i, int i2, int i3, int i4, int i5, int i6, float f, int i7, int i8, List<Integer> list, ActivationFunction.Type type, Double d, Double d2, Float f2, Float f3, Float f4, Float f5) {
        this(i, i2, i3, i4, i5, i6, f, i7, i8, list, type, d, d2, null, null, null, null, f2, f3, f4, f5);
    }

    /**
     * Creates a config with the default head size and a classification label map.
     * Leading parameters have the same meaning as in the 21-argument constructor.
     *
     * @param map classification labels, or {@code null} for none
     */
    public Config(int i, int i2, int i3, int i4, int i5, int i6, float f, int i7, int i8, List<Integer> list, ActivationFunction.Type type, Double d, Double d2, Map<String, Integer> map) {
        this(i, i2, i3, i4, i5, i6, f, i7, i8, list, type, d, d2, map, null, null, null, null, null, null, null);
    }

    /**
     * Canonical constructor; every other constructor delegates here.
     *
     * @param i    contextLength
     * @param i2   embeddingLength
     * @param i3   hiddenLength
     * @param i4   numberOfHeads
     * @param i5   numberOfKeyValueHeads; must not be zero
     * @param i6   numberOfLayers
     * @param f    layerNormEps
     * @param i7   vocabularySize
     * @param i8   bosToken
     * @param list eosTokens
     * @param type activationFunction
     * @param d    third argument to {@code VectorMath.precomputeFreqsCis}; {@code null} leaves {@link #ropeFreqs} empty
     * @param d2   fourth argument to {@code VectorMath.precomputeFreqsCis}; {@code null} means {@code 1.0}
     * @param map  classification labels, or {@code null} for none; copied into an immutable bi-map
     * @param num  headSize, or {@code null} to use {@code embeddingLength / numberOfHeads}
     * @param f2   finalLogitSoftCapping (may be {@code null})
     * @param f3   attnLogitSoftCapping (may be {@code null})
     * @param f4   residualMultiplier (may be {@code null})
     * @param f5   attentionMultiplier (may be {@code null})
     * @param f6   embeddingMultiplier (may be {@code null})
     * @param f7   logitMultiplier (may be {@code null})
     * @throws IllegalArgumentException if {@code i5} is zero, or if {@code num} is {@code null}
     *                                  and {@code i4} is zero
     */
    public Config(int i, int i2, int i3, int i4, int i5, int i6, float f, int i7, int i8, List<Integer> list, ActivationFunction.Type type, Double d, Double d2, Map<String, Integer> map, Integer num, Float f2, Float f3, Float f4, Float f5, Float f6, Float f7) {
        // Same fallback the 16-argument overload has always applied; previously a null here was an NPE.
        final int resolvedHeadSize = num != null ? num.intValue() : defaultHeadSize(i2, i4);
        // Guard the division below so a bad config fails with a readable message.
        Preconditions.checkArgument(i5 != 0, "numberOfKeyValueHeads must not be zero");

        this.contextLength = i;
        this.attentionLength = i4 * resolvedHeadSize;
        this.embeddingLength = i2;
        this.hiddenLength = i3;
        this.numberOfHeads = i4;
        this.numberOfKeyValueHeads = i5;
        this.numberOfLayers = i6;
        this.layerNormEps = f;
        this.vocabularySize = i7;
        this.bosToken = i8;
        this.eosTokens = list;
        this.tensorCache = TensorCache.instance;
        this.headSize = resolvedHeadSize;
        this.headGroupSize = i4 / i5;
        this.kvLength = i5 * resolvedHeadSize;
        this.isGQA = i5 < i4;
        this.activationFunction = type;

        if (d == null) {
            this.ropeFreqs = Optional.empty();
        } else {
            final double scale = d2 == null ? 1.0d : d2.doubleValue();
            this.ropeFreqs = Optional.of(VectorMath.precomputeFreqsCis(resolvedHeadSize, i, d.doubleValue(), scale));
        }

        this.classifcationLabels = map == null ? Optional.empty() : Optional.of(ImmutableBiMap.copyOf(map));
        this.finalLogitSoftCapping = f2;
        this.attnLogitSoftCapping = f3;
        this.residualMultiplier = f4;
        this.attentionMultiplier = f5;
        this.embeddingMultiplier = f6;
        this.logitMultiplier = f7;

        // Kept last: the builder receives this instance, so every final field is assigned first.
        this.dctx = DistributedContext.builder(this).build();
    }

    /** Returns {@code embeddingLength / numberOfHeads}, rejecting a zero head count with a clear message. */
    private static int defaultHeadSize(int embeddingLength, int numberOfHeads) {
        Preconditions.checkArgument(numberOfHeads != 0, "numberOfHeads must not be zero when no explicit head size is given");
        return embeddingLength / numberOfHeads;
    }

    /**
     * Replaces the distributed context that was built by default in the constructor.
     *
     * @param distributedContext the context to return from {@link #dctx()}
     */
    public void setDistributedContext(DistributedContext distributedContext) {
        this.dctx = distributedContext;
    }

    /**
     * Sets the working directory.
     *
     * <p>When {@code file} is {@code null} a fresh temporary directory is created and registered
     * with {@link File#deleteOnExit()}. Otherwise {@code file} must be an existing directory.
     *
     * @param file the directory to use, or {@code null} to create a temporary one
     * @throws IllegalArgumentException if {@code file} is non-null and not a directory
     * @throws IllegalStateException    if the temporary directory cannot be created
     */
    public void setWorkingDirectory(File file) {
        if (file == null) {
            // java.nio is used instead of Guava's deprecated Files.createTempDir(); it is fully
            // qualified because the simple name Files is bound to the Guava import in this file.
            final File tempDir;
            try {
                tempDir = java.nio.file.Files.createTempDirectory("jlama").toFile();
            } catch (java.io.IOException e) {
                throw new IllegalStateException("Failed to create a temporary working directory", e);
            }
            tempDir.deleteOnExit();
            // Publish once so a concurrent setter cannot swap the field between write and use.
            this.workingDirectory = tempDir;
        } else {
            Preconditions.checkArgument(file.isDirectory(), "Working directory must be an existing directory: %s", file);
            this.workingDirectory = file;
        }
    }

    /**
     * @return the working directory, or empty if {@link #setWorkingDirectory(File)} has not been called
     */
    public Optional<File> workingDirectory() {
        return Optional.ofNullable(this.workingDirectory);
    }

    /**
     * @return the current distributed context
     */
    public DistributedContext dctx() {
        return this.dctx;
    }

    /**
     * Maps a head index to its key/value group index.
     *
     * @param i head index
     * @return {@code i} unchanged when {@link #isGQA} is {@code false}; otherwise
     *         {@code Math.floorDiv(i, headGroupSize)}
     */
    public int maybeMapToGroupHead(int i) {
        if (!this.isGQA) {
            return i;
        }
        return Math.floorDiv(i, this.headGroupSize);
    }

    /**
     * @return {@code true} when classification labels were supplied at construction
     */
    public boolean isClassifier() {
        return this.classifcationLabels.isPresent();
    }
}
