// ------------   D O   N O T   E D I T !   ------------
// This file is generated from a config definition file.

package com.yahoo.embedding;

import java.util.*;
import java.io.File;
import java.nio.file.Path;
import com.yahoo.config.*;

/**
 * This class represents the root node of splade-embedder
 *
 * Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
 */
public final class SpladeEmbedderConfig extends ConfigInstance {

  /** MD5 checksum identifying this version of the config definition. */
  public static final String CONFIG_DEF_MD5 = "18fdf7febcb71658372f618e0e800429";
  /** Name of the config definition this class was generated from. */
  public static final String CONFIG_DEF_NAME = "splade-embedder";
  /** Namespace of the config definition this class was generated from. */
  public static final String CONFIG_DEF_NAMESPACE = "embedding";
  /** The config definition, one entry per line (field name, type and optional default). */
  public static final String[] CONFIG_DEF_SCHEMA = {
    "namespace=embedding",
    "tokenizerPath model",
    "transformerModel model",
    "transformerMaxTokens int default=512",
    "transformerInputIds string default=input_ids",
    "transformerAttentionMask string default=attention_mask",
    "transformerTokenTypeIds string default=token_type_ids",
    "transformerOutput string default=logits",
    "termScoreThreshold double default=0.0",
    "transformerExecutionMode enum { parallel, sequential } default=sequential",
    "transformerInterOpThreads int default=1",
    "transformerIntraOpThreads int default=-4",
    "transformerGpuDevice int default=0"
  };

  /** @return {@link #CONFIG_DEF_MD5} */
  public static String getDefMd5() {
    return CONFIG_DEF_MD5;
  }

  /** @return {@link #CONFIG_DEF_NAME} */
  public static String getDefName() {
    return CONFIG_DEF_NAME;
  }

  /** @return {@link #CONFIG_DEF_NAMESPACE} */
  public static String getDefNamespace() {
    return CONFIG_DEF_NAMESPACE;
  }

  /**
   * Implemented by components able to populate a {@link Builder} for this config.
   */
  public interface Producer extends ConfigInstance.Producer {
    /**
     * Fills in the given builder.
     *
     * @param builder the builder to populate
     */
    void getConfig(Builder builder);
  }

  /**
   * Mutable builder for {@link SpladeEmbedderConfig}.
   *
   * Every value field starts out as {@code null}, meaning "not set"; unset fields with a
   * default in the schema get that default when the config is constructed. The two model
   * fields have no default and are tracked in {@code __uninitialized} until they are set.
   */
  public static final class Builder implements ConfigInstance.Builder {
    // Parameters without a default; each name is removed once its setter has been called.
    private Set<String> __uninitialized = new HashSet<>(List.of(
        "tokenizerPath",
        "transformerModel"));

    private ModelReference tokenizerPath = null;
    private ModelReference transformerModel = null;
    private Integer transformerMaxTokens = null;
    private String transformerInputIds = null;
    private String transformerAttentionMask = null;
    private String transformerTokenTypeIds = null;
    private String transformerOutput = null;
    private Double termScoreThreshold = null;
    private TransformerExecutionMode.Enum transformerExecutionMode = null;
    private Integer transformerInterOpThreads = null;
    private Integer transformerIntraOpThreads = null;
    private Integer transformerGpuDevice = null;

    // Backing value for getApplyOnRestart()/setApplyOnRestart(boolean).
    private boolean _applyOnRestart = false;

    /** Creates a builder with no values set. */
    public Builder() { }

    /**
     * Creates a builder pre-populated with every value of an existing config.
     *
     * @param config the config to copy values from
     */
    public Builder(SpladeEmbedderConfig config) {
      // Each setter returns this builder, so the copies are chained in field order.
      tokenizerPath(config.tokenizerPathReference())
          .transformerModel(config.transformerModelReference())
          .transformerMaxTokens(config.transformerMaxTokens())
          .transformerInputIds(config.transformerInputIds())
          .transformerAttentionMask(config.transformerAttentionMask())
          .transformerTokenTypeIds(config.transformerTokenTypeIds())
          .transformerOutput(config.transformerOutput())
          .termScoreThreshold(config.termScoreThreshold())
          .transformerExecutionMode(config.transformerExecutionMode())
          .transformerInterOpThreads(config.transformerInterOpThreads())
          .transformerIntraOpThreads(config.transformerIntraOpThreads())
          .transformerGpuDevice(config.transformerGpuDevice());
    }

    /**
     * Applies every value that has been explicitly set on {@code __superior} to this builder,
     * leaving the values {@code __superior} has not set untouched.
     * Not called from within this file; presumably used by the config framework (to be confirmed).
     *
     * @param __superior the builder whose set values take precedence
     * @return this builder
     */
    private Builder override(Builder __superior) {
      if (__superior.tokenizerPath != null) {
        tokenizerPath(__superior.tokenizerPath);
      }
      if (__superior.transformerModel != null) {
        transformerModel(__superior.transformerModel);
      }
      if (__superior.transformerMaxTokens != null) {
        transformerMaxTokens(__superior.transformerMaxTokens);
      }
      if (__superior.transformerInputIds != null) {
        transformerInputIds(__superior.transformerInputIds);
      }
      if (__superior.transformerAttentionMask != null) {
        transformerAttentionMask(__superior.transformerAttentionMask);
      }
      if (__superior.transformerTokenTypeIds != null) {
        transformerTokenTypeIds(__superior.transformerTokenTypeIds);
      }
      if (__superior.transformerOutput != null) {
        transformerOutput(__superior.transformerOutput);
      }
      if (__superior.termScoreThreshold != null) {
        termScoreThreshold(__superior.termScoreThreshold);
      }
      if (__superior.transformerExecutionMode != null) {
        transformerExecutionMode(__superior.transformerExecutionMode);
      }
      if (__superior.transformerInterOpThreads != null) {
        transformerInterOpThreads(__superior.transformerInterOpThreads);
      }
      if (__superior.transformerIntraOpThreads != null) {
        transformerIntraOpThreads(__superior.transformerIntraOpThreads);
      }
      if (__superior.transformerGpuDevice != null) {
        transformerGpuDevice(__superior.transformerGpuDevice);
      }
      return this;
    }

    /**
     * Sets tokenizerPath and marks it as initialized.
     *
     * @param __value the model reference, never null
     * @return this builder
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder tokenizerPath(ModelReference __value) {
      if (__value == null) {
        throw new IllegalArgumentException("Null value is not allowed.");
      }
      this.tokenizerPath = __value;
      __uninitialized.remove("tokenizerPath");
      return this;
    }

    /**
     * Sets transformerModel and marks it as initialized.
     *
     * @param __value the model reference, never null
     * @return this builder
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerModel(ModelReference __value) {
      if (__value == null) {
        throw new IllegalArgumentException("Null value is not allowed.");
      }
      this.transformerModel = __value;
      __uninitialized.remove("transformerModel");
      return this;
    }

    /**
     * Sets transformerMaxTokens.
     *
     * @param __value the new value
     * @return this builder
     */
    public Builder transformerMaxTokens(int __value) {
      this.transformerMaxTokens = __value;
      return this;
    }

    // String-typed variant: parses the text as an int and delegates to the typed setter.
    private Builder transformerMaxTokens(String __value) {
      return transformerMaxTokens(Integer.parseInt(__value));
    }

    /**
     * Sets transformerInputIds.
     *
     * @param __value the new value, never null
     * @return this builder
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerInputIds(String __value) {
      if (__value == null) {
        throw new IllegalArgumentException("Null value is not allowed.");
      }
      this.transformerInputIds = __value;
      return this;
    }

    /**
     * Sets transformerAttentionMask.
     *
     * @param __value the new value, never null
     * @return this builder
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerAttentionMask(String __value) {
      if (__value == null) {
        throw new IllegalArgumentException("Null value is not allowed.");
      }
      this.transformerAttentionMask = __value;
      return this;
    }

    /**
     * Sets transformerTokenTypeIds.
     *
     * @param __value the new value, never null
     * @return this builder
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerTokenTypeIds(String __value) {
      if (__value == null) {
        throw new IllegalArgumentException("Null value is not allowed.");
      }
      this.transformerTokenTypeIds = __value;
      return this;
    }

    /**
     * Sets transformerOutput.
     *
     * @param __value the new value, never null
     * @return this builder
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerOutput(String __value) {
      if (__value == null) {
        throw new IllegalArgumentException("Null value is not allowed.");
      }
      this.transformerOutput = __value;
      return this;
    }

    /**
     * Sets termScoreThreshold.
     *
     * @param __value the new value
     * @return this builder
     */
    public Builder termScoreThreshold(double __value) {
      this.termScoreThreshold = __value;
      return this;
    }

    // String-typed variant: parses the text as a double and delegates to the typed setter.
    private Builder termScoreThreshold(String __value) {
      return termScoreThreshold(Double.parseDouble(__value));
    }

    /**
     * Sets transformerExecutionMode.
     *
     * @param __value the new value, never null
     * @return this builder
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerExecutionMode(TransformerExecutionMode.Enum __value) {
      if (__value == null) {
        throw new IllegalArgumentException("Null value is not allowed.");
      }
      this.transformerExecutionMode = __value;
      return this;
    }

    // String-typed variant: resolves the enum constant by name and delegates to the typed setter.
    private Builder transformerExecutionMode(String __value) {
      return transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(__value));
    }

    /**
     * Sets transformerInterOpThreads.
     *
     * @param __value the new value
     * @return this builder
     */
    public Builder transformerInterOpThreads(int __value) {
      this.transformerInterOpThreads = __value;
      return this;
    }

    // String-typed variant: parses the text as an int and delegates to the typed setter.
    private Builder transformerInterOpThreads(String __value) {
      return transformerInterOpThreads(Integer.parseInt(__value));
    }

    /**
     * Sets transformerIntraOpThreads.
     *
     * @param __value the new value
     * @return this builder
     */
    public Builder transformerIntraOpThreads(int __value) {
      this.transformerIntraOpThreads = __value;
      return this;
    }

    // String-typed variant: parses the text as an int and delegates to the typed setter.
    private Builder transformerIntraOpThreads(String __value) {
      return transformerIntraOpThreads(Integer.parseInt(__value));
    }

    /**
     * Sets transformerGpuDevice.
     *
     * @param __value the new value
     * @return this builder
     */
    public Builder transformerGpuDevice(int __value) {
      this.transformerGpuDevice = __value;
      return this;
    }

    // String-typed variant: parses the text as an int and delegates to the typed setter.
    private Builder transformerGpuDevice(String __value) {
      return transformerGpuDevice(Integer.parseInt(__value));
    }

    /**
     * Lets the producer populate this builder if it is a {@link Producer} of this config type.
     *
     * @param producer the candidate producer
     * @return true if the producer was of the right type and has been invoked, false otherwise
     */
    @java.lang.Override
    public final boolean dispatchGetConfig(ConfigInstance.Producer producer) {
      if (!(producer instanceof Producer)) {
        return false;
      }
      ((Producer) producer).getConfig(this);
      return true;
    }

    @java.lang.Override
    public final String getDefMd5() {
      return CONFIG_DEF_MD5;
    }

    @java.lang.Override
    public final String getDefName() {
      return CONFIG_DEF_NAME;
    }

    @java.lang.Override
    public final String getDefNamespace() {
      return CONFIG_DEF_NAMESPACE;
    }

    @java.lang.Override
    public final boolean getApplyOnRestart() {
      return _applyOnRestart;
    }

    @java.lang.Override
    public final void setApplyOnRestart(boolean applyOnRestart) {
      _applyOnRestart = applyOnRestart;
    }

    /**
     * Builds the config from the current state of this builder.
     *
     * @return a new config instance
     * @throws IllegalArgumentException if tokenizerPath or transformerModel has not been set
     */
    public SpladeEmbedderConfig build() {
      return new SpladeEmbedderConfig(this);
    }

  }

  // Path to tokenizer.json
  private final ModelNode tokenizerPath;
  // Path to model.onnx
  private final ModelNode transformerModel;
  // Max length of token sequence model can handle
  private final IntegerNode transformerMaxTokens;
  // Input names
  private final StringNode transformerInputIds;
  private final StringNode transformerAttentionMask;
  private final StringNode transformerTokenTypeIds;
  // Output name
  private final StringNode transformerOutput;
  // score threshold to control sparseness
  private final DoubleNode termScoreThreshold;
  // Settings for ONNX model evaluation
  private final TransformerExecutionMode transformerExecutionMode;
  private final IntegerNode transformerInterOpThreads;
  private final IntegerNode transformerIntraOpThreads;
  // GPU device id, -1 for CPU
  private final IntegerNode transformerGpuDevice;

  /**
   * Creates a config from a builder, requiring all parameters without a default to be set.
   *
   * @param builder the builder holding the values
   * @throws IllegalArgumentException if any parameter without a default is still unset
   */
  public SpladeEmbedderConfig(Builder builder) {
    this(builder, true);
  }

  /**
   * Creates a config from a builder. Fields the builder has not set get the default from the
   * schema, or an empty node for the two model fields, which have no default.
   *
   * @param builder the builder holding the values
   * @param throwIfUninitialized whether to reject a builder with unset mandatory parameters
   */
  private SpladeEmbedderConfig(Builder builder, boolean throwIfUninitialized) {
    if (throwIfUninitialized && !builder.__uninitialized.isEmpty()) {
      throw new IllegalArgumentException(
          "The following builder parameters for splade-embedder must be initialized: "
              + builder.__uninitialized);
    }

    tokenizerPath = (builder.tokenizerPath != null)
        ? new ModelNode(builder.tokenizerPath)
        : new ModelNode();
    transformerModel = (builder.transformerModel != null)
        ? new ModelNode(builder.transformerModel)
        : new ModelNode();
    transformerMaxTokens = (builder.transformerMaxTokens != null)
        ? new IntegerNode(builder.transformerMaxTokens)
        : new IntegerNode(512);
    transformerInputIds = (builder.transformerInputIds != null)
        ? new StringNode(builder.transformerInputIds)
        : new StringNode("input_ids");
    transformerAttentionMask = (builder.transformerAttentionMask != null)
        ? new StringNode(builder.transformerAttentionMask)
        : new StringNode("attention_mask");
    transformerTokenTypeIds = (builder.transformerTokenTypeIds != null)
        ? new StringNode(builder.transformerTokenTypeIds)
        : new StringNode("token_type_ids");
    transformerOutput = (builder.transformerOutput != null)
        ? new StringNode(builder.transformerOutput)
        : new StringNode("logits");
    termScoreThreshold = (builder.termScoreThreshold != null)
        ? new DoubleNode(builder.termScoreThreshold)
        : new DoubleNode(0.0D);
    transformerExecutionMode = (builder.transformerExecutionMode != null)
        ? new TransformerExecutionMode(builder.transformerExecutionMode)
        : new TransformerExecutionMode(TransformerExecutionMode.sequential);
    transformerInterOpThreads = (builder.transformerInterOpThreads != null)
        ? new IntegerNode(builder.transformerInterOpThreads)
        : new IntegerNode(1);
    transformerIntraOpThreads = (builder.transformerIntraOpThreads != null)
        ? new IntegerNode(builder.transformerIntraOpThreads)
        : new IntegerNode(-4);
    transformerGpuDevice = (builder.transformerGpuDevice != null)
        ? new IntegerNode(builder.transformerGpuDevice)
        : new IntegerNode(0);
  }

  /**
   * @return the value of splade-embedder.tokenizerPath as a path
   */
  public Path tokenizerPath() {
    return tokenizerPath.value();
  }

  /**
   * @return the model reference behind splade-embedder.tokenizerPath
   */
  public ModelReference tokenizerPathReference() {
    return tokenizerPath.getModelReference();
  }

  /**
   * @return the value of splade-embedder.transformerModel as a path
   */
  public Path transformerModel() {
    return transformerModel.value();
  }

  /**
   * @return the model reference behind splade-embedder.transformerModel
   */
  public ModelReference transformerModelReference() {
    return transformerModel.getModelReference();
  }

  /**
   * @return the value of splade-embedder.transformerMaxTokens (schema default 512)
   */
  public int transformerMaxTokens() {
    return transformerMaxTokens.value();
  }

  /**
   * @return the value of splade-embedder.transformerInputIds (schema default input_ids)
   */
  public String transformerInputIds() {
    return transformerInputIds.value();
  }

  /**
   * @return the value of splade-embedder.transformerAttentionMask (schema default attention_mask)
   */
  public String transformerAttentionMask() {
    return transformerAttentionMask.value();
  }

  /**
   * @return the value of splade-embedder.transformerTokenTypeIds (schema default token_type_ids)
   */
  public String transformerTokenTypeIds() {
    return transformerTokenTypeIds.value();
  }

  /**
   * @return the value of splade-embedder.transformerOutput (schema default logits)
   */
  public String transformerOutput() {
    return transformerOutput.value();
  }

  /**
   * @return the value of splade-embedder.termScoreThreshold (schema default 0.0)
   */
  public double termScoreThreshold() {
    return termScoreThreshold.value();
  }

  /**
   * @return the value of splade-embedder.transformerExecutionMode (schema default sequential)
   */
  public TransformerExecutionMode.Enum transformerExecutionMode() {
    return transformerExecutionMode.value();
  }

  /**
   * @return the value of splade-embedder.transformerInterOpThreads (schema default 1)
   */
  public int transformerInterOpThreads() {
    return transformerInterOpThreads.value();
  }

  /**
   * @return the value of splade-embedder.transformerIntraOpThreads (schema default -4)
   */
  public int transformerIntraOpThreads() {
    return transformerIntraOpThreads.value();
  }

  /**
   * @return the value of splade-embedder.transformerGpuDevice (schema default 0)
   */
  public int transformerGpuDevice() {
    return transformerGpuDevice.value();
  }

  // Always returns an empty change set for "splade-embedder"; newConfig is not inspected.
  private ChangesRequiringRestart getChangesRequiringRestart(SpladeEmbedderConfig newConfig) {
    return new ChangesRequiringRestart("splade-embedder");
  }

  // Always false for this config.
  private static boolean containsFieldsFlaggedWithRestart() {
    return false;
  }

  /**
   * This class represents splade-embedder.transformerExecutionMode
   *
   * Settings for ONNX model evaluation
   */
  public static final class TransformerExecutionMode extends EnumNode<TransformerExecutionMode.Enum> {

    /** Creates a node without a value. */
    public TransformerExecutionMode() {
      this.value = null;
    }

    /**
     * Creates a node holding the given value.
     *
     * @param enumValue the value to hold; whether it is non-null is passed to the superclass
     */
    public TransformerExecutionMode(Enum enumValue) {
      super(enumValue != null);
      this.value = enumValue;
    }

    /** The legal values of this field. */
    public enum Enum { parallel, sequential }

    public static final Enum parallel = Enum.parallel;
    public static final Enum sequential = Enum.sequential;

    /**
     * Sets the value from the name of an enum constant.
     *
     * @param name the constant name
     * @return true if the name matched a constant and the value was set, false otherwise
     */
    @Override
    protected boolean doSetValue(String name) {
      try {
        value = Enum.valueOf(name);
      } catch (IllegalArgumentException ignored) {
        // Not a known constant name: leave the value unchanged and report failure.
        return false;
      }
      return true;
    }
  }

}
