// ------------   D O   N O T   E D I T !   ------------
// This file is generated from a config definition file.

package com.yahoo.embedding.huggingface;

import java.util.*;
import java.io.File;
import java.nio.file.Path;
import com.yahoo.config.*;

/**
 * Generated configuration class for the {@code hugging-face-embedder} config definition
 * (namespace {@code embedding.huggingface}); this class is the root node of that definition.
 *
 * <p>Instances are created through {@link Builder}. The two fields without a default in the
 * definition ({@code tokenizerPath}, {@code transformerModel}) must be set on the builder before
 * {@link Builder#build()}; every other field falls back to the default listed in
 * {@link #CONFIG_DEF_SCHEMA}. All fields are {@code private final} and exposed through accessors.
 *
 * <p>NOTE: this source is generated from a config definition file (see the file header); hand
 * edits will be lost on regeneration. Change the definition or the generator instead.
 *
 * Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
 */
public final class HuggingFaceEmbedderConfig extends ConfigInstance {

  // Identity of the config definition this class was generated from. The MD5 is generated together
  // with the schema lines below; do not hand-edit either.
  public final static String CONFIG_DEF_MD5 = "a82695dc8744f29bb2ce23053ac8d61d";
  public final static String CONFIG_DEF_NAME = "hugging-face-embedder";
  public final static String CONFIG_DEF_NAMESPACE = "embedding.huggingface";
  // The config definition itself: a namespace line, then one "name type [default=...]" entry per field.
  public final static String[] CONFIG_DEF_SCHEMA = {
    "namespace=embedding.huggingface",
    "tokenizerPath model",
    "transformerModel model",
    "transformerMaxTokens int default=512",
    "transformerInputIds string default=input_ids",
    "transformerAttentionMask string default=attention_mask",
    "transformerTokenTypeIds string default=token_type_ids",
    "transformerOutput string default=last_hidden_state",
    "prependQuery string default=\"\"",
    "prependDocument string default=\"\"",
    "normalize bool default=false",
    "poolingStrategy enum { cls, mean, none } default=mean",
    "transformerExecutionMode enum { parallel, sequential } default=sequential",
    "transformerInterOpThreads int default=1",
    "transformerIntraOpThreads int default=-4",
    "transformerGpuDevice int default=0",
    "transformerOnnxConfigPath path optional"
  };

  // Static accessors for the definition identity constants.
  public static String getDefMd5()       { return CONFIG_DEF_MD5; }
  public static String getDefName()      { return CONFIG_DEF_NAME; }
  public static String getDefNamespace() { return CONFIG_DEF_NAMESPACE; }

  /**
   * A source of values for this config. It populates a {@link Builder} in
   * {@link #getConfig(Builder)}; see {@link Builder#dispatchGetConfig}.
   */
  public interface Producer extends ConfigInstance.Producer {
    void getConfig(Builder builder);
  }

  /**
   * Mutable builder for {@link HuggingFaceEmbedderConfig}.
   *
   * <p>Every field starts out {@code null}, meaning "not set". The config constructor replaces each
   * unset field with its definition default. The exception is {@code transformerOnnxConfigPath},
   * which starts as {@link Optional#empty()} and can never be {@code null} through this API.
   * Public setters reject {@code null} with {@link IllegalArgumentException}.
   *
   * <p>NOTE(review): the private members here (the fields, {@code override}, and the
   * {@code String}/{@code FileReference} setter overloads) are not referenced anywhere in this
   * file. Presumably the config framework reaches them reflectively by name, so do not rename them.
   */
  public static final class Builder implements ConfigInstance.Builder {
    // Names of required fields (no default in the definition) that have not been set yet.
    // The config constructor throws while this set is non-empty.
    private Set<String> __uninitialized = new HashSet<String>(List.of(
      "tokenizerPath",
      "transformerModel"
      ));

    private ModelReference tokenizerPath = null;
    private ModelReference transformerModel = null;
    private Integer transformerMaxTokens = null;
    private String transformerInputIds = null;
    private String transformerAttentionMask = null;
    private String transformerTokenTypeIds = null;
    private String transformerOutput = null;
    private String prependQuery = null;
    private String prependDocument = null;
    private Boolean normalize = null;
    private PoolingStrategy.Enum poolingStrategy = null;
    private TransformerExecutionMode.Enum transformerExecutionMode = null;
    private Integer transformerInterOpThreads = null;
    private Integer transformerIntraOpThreads = null;
    private Integer transformerGpuDevice = null;
    // Unlike the fields above, "not set" is Optional.empty() rather than null.
    private Optional<FileReference> transformerOnnxConfigPath = Optional.empty();

    /** Creates an empty builder; the required fields must be set before {@link #build()}. */
    public Builder() { }

    /**
     * Creates a builder pre-populated with every value of {@code config}.
     * Model and path fields are copied as references, not as resolved {@link Path}s.
     *
     * @param config the config to copy
     */
    public Builder(HuggingFaceEmbedderConfig config) {
      tokenizerPath(config.tokenizerPath.getModelReference());
      transformerModel(config.transformerModel.getModelReference());
      transformerMaxTokens(config.transformerMaxTokens());
      transformerInputIds(config.transformerInputIds());
      transformerAttentionMask(config.transformerAttentionMask());
      transformerTokenTypeIds(config.transformerTokenTypeIds());
      transformerOutput(config.transformerOutput());
      prependQuery(config.prependQuery());
      prependDocument(config.prependDocument());
      normalize(config.normalize());
      poolingStrategy(config.poolingStrategy());
      transformerExecutionMode(config.transformerExecutionMode());
      transformerInterOpThreads(config.transformerInterOpThreads());
      transformerIntraOpThreads(config.transformerIntraOpThreads());
      transformerGpuDevice(config.transformerGpuDevice());
      transformerOnnxConfigPath(config.transformerOnnxConfigPath.getFileReference());
    }

    /**
     * Copies every field that is set (non-null) in {@code __superior} into this builder, so the
     * superior builder's explicit values take precedence over this builder's.
     *
     * <p>NOTE(review): {@code transformerOnnxConfigPath} is never null via the builder API. It is
     * initialised to {@code Optional.empty()}, and the setter rejects null. So its check below always
     * passes, and an unset (empty) value in {@code __superior} replaces any value set here. Confirm
     * whether that is intended; a fix belongs in the code generator, not in this file.
     *
     * @param __superior the builder whose set values win
     * @return this builder
     */
    private Builder override(Builder __superior) {
      if (__superior.tokenizerPath != null)
        tokenizerPath(__superior.tokenizerPath);
      if (__superior.transformerModel != null)
        transformerModel(__superior.transformerModel);
      if (__superior.transformerMaxTokens != null)
        transformerMaxTokens(__superior.transformerMaxTokens);
      if (__superior.transformerInputIds != null)
        transformerInputIds(__superior.transformerInputIds);
      if (__superior.transformerAttentionMask != null)
        transformerAttentionMask(__superior.transformerAttentionMask);
      if (__superior.transformerTokenTypeIds != null)
        transformerTokenTypeIds(__superior.transformerTokenTypeIds);
      if (__superior.transformerOutput != null)
        transformerOutput(__superior.transformerOutput);
      if (__superior.prependQuery != null)
        prependQuery(__superior.prependQuery);
      if (__superior.prependDocument != null)
        prependDocument(__superior.prependDocument);
      if (__superior.normalize != null)
        normalize(__superior.normalize);
      if (__superior.poolingStrategy != null)
        poolingStrategy(__superior.poolingStrategy);
      if (__superior.transformerExecutionMode != null)
        transformerExecutionMode(__superior.transformerExecutionMode);
      if (__superior.transformerInterOpThreads != null)
        transformerInterOpThreads(__superior.transformerInterOpThreads);
      if (__superior.transformerIntraOpThreads != null)
        transformerIntraOpThreads(__superior.transformerIntraOpThreads);
      if (__superior.transformerGpuDevice != null)
        transformerGpuDevice(__superior.transformerGpuDevice);
      // Always true through the builder API (see NOTE(review) above).
      if (__superior.transformerOnnxConfigPath != null)
        transformerOnnxConfigPath(__superior.transformerOnnxConfigPath);
      return this;
    }

    /**
     * Sets the required tokenizer model reference and marks it as initialized.
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder tokenizerPath(ModelReference __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      tokenizerPath = __value;
      __uninitialized.remove("tokenizerPath");
      return this;
    }


    /**
     * Sets the required transformer model reference and marks it as initialized.
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerModel(ModelReference __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      transformerModel = __value;
      __uninitialized.remove("transformerModel");
      return this;
    }


    /** Sets transformerMaxTokens (definition default 512). */
    public Builder transformerMaxTokens(int __value) {
      transformerMaxTokens = __value;
      return this;
    }

    // String overload: Integer.valueOf throws NumberFormatException on non-numeric input.
    private Builder transformerMaxTokens(String __value) {
      return transformerMaxTokens(Integer.valueOf(__value));
    }

    /**
     * Sets transformerInputIds (definition default "input_ids").
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerInputIds(String __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      transformerInputIds = __value;
      return this;
    }


    /**
     * Sets transformerAttentionMask (definition default "attention_mask").
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerAttentionMask(String __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      transformerAttentionMask = __value;
      return this;
    }


    /**
     * Sets transformerTokenTypeIds (definition default "token_type_ids").
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerTokenTypeIds(String __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      transformerTokenTypeIds = __value;
      return this;
    }


    /**
     * Sets transformerOutput (definition default "last_hidden_state").
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerOutput(String __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      transformerOutput = __value;
      return this;
    }


    /**
     * Sets prependQuery (definition default: empty string).
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder prependQuery(String __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      prependQuery = __value;
      return this;
    }


    /**
     * Sets prependDocument (definition default: empty string).
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder prependDocument(String __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      prependDocument = __value;
      return this;
    }


    /** Sets normalize (definition default false). */
    public Builder normalize(boolean __value) {
      normalize = __value;
      return this;
    }

    // String overload: Boolean.valueOf yields true only for "true" (ignoring case); any other
    // string, including malformed input, silently becomes false rather than raising an error.
    private Builder normalize(String __value) {
      return normalize(Boolean.valueOf(__value));
    }

    /**
     * Sets poolingStrategy (definition default mean).
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder poolingStrategy(PoolingStrategy.Enum __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      poolingStrategy = __value;
      return this;
    }

    // String overload: must exactly match a constant name (cls, mean, none), otherwise
    // Enum.valueOf throws IllegalArgumentException.
    private Builder poolingStrategy(String __value) {
      return poolingStrategy(PoolingStrategy.Enum.valueOf(__value));
    }

    /**
     * Sets transformerExecutionMode (definition default sequential).
     * @throws IllegalArgumentException if {@code __value} is null
     */
    public Builder transformerExecutionMode(TransformerExecutionMode.Enum __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      transformerExecutionMode = __value;
      return this;
    }

    // String overload: must exactly match a constant name (parallel, sequential), otherwise
    // Enum.valueOf throws IllegalArgumentException.
    private Builder transformerExecutionMode(String __value) {
      return transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(__value));
    }

    /** Sets transformerInterOpThreads (definition default 1). */
    public Builder transformerInterOpThreads(int __value) {
      transformerInterOpThreads = __value;
      return this;
    }

    // String overload: Integer.valueOf throws NumberFormatException on non-numeric input.
    private Builder transformerInterOpThreads(String __value) {
      return transformerInterOpThreads(Integer.valueOf(__value));
    }

    /** Sets transformerIntraOpThreads (definition default -4). No range check is done here. */
    public Builder transformerIntraOpThreads(int __value) {
      transformerIntraOpThreads = __value;
      return this;
    }

    // String overload: Integer.valueOf throws NumberFormatException on non-numeric input.
    private Builder transformerIntraOpThreads(String __value) {
      return transformerIntraOpThreads(Integer.valueOf(__value));
    }

    /** Sets transformerGpuDevice (definition default 0). No range check is done here. */
    public Builder transformerGpuDevice(int __value) {
      transformerGpuDevice = __value;
      return this;
    }

    // String overload: Integer.valueOf throws NumberFormatException on non-numeric input.
    private Builder transformerGpuDevice(String __value) {
      return transformerGpuDevice(Integer.valueOf(__value));
    }

    /**
     * Sets the optional transformerOnnxConfigPath. Pass {@link Optional#empty()} to clear it.
     * @throws IllegalArgumentException if {@code __value} itself is null
     */
    public Builder transformerOnnxConfigPath(Optional<FileReference> __value) {
    if (__value == null) throw new IllegalArgumentException("Null value is not allowed.");
      transformerOnnxConfigPath = __value;
      // No-op: this optional field is never in __uninitialized.
      __uninitialized.remove("transformerOnnxConfigPath");
      return this;
    }

    // FileReference overload: Optional.of throws NullPointerException if __value is null.
    private Builder transformerOnnxConfigPath(FileReference __value) {
      return transformerOnnxConfigPath(Optional.of(__value));
    }

    // Defaults to false; exposed through getApplyOnRestart/setApplyOnRestart below.
    private boolean _applyOnRestart = false;

    /**
     * Lets {@code producer} populate this builder if it implements this config's {@link Producer}.
     *
     * @return true if {@code producer.getConfig(this)} was invoked, false if the producer is of
     *         another type
     */
    @java.lang.Override
    public final boolean dispatchGetConfig(ConfigInstance.Producer producer) {
      if (producer instanceof Producer) {
        ((Producer)producer).getConfig(this);
        return true;
      }
      return false;
    }

    // Instance-level views of the definition identity constants.
    @java.lang.Override
    public final String getDefMd5() { return CONFIG_DEF_MD5; }

    @java.lang.Override
    public final String getDefName() { return CONFIG_DEF_NAME; }

    @java.lang.Override
    public final String getDefNamespace() { return CONFIG_DEF_NAMESPACE; }

    @java.lang.Override
    public final boolean getApplyOnRestart() { return _applyOnRestart; }

    @java.lang.Override
    public final void setApplyOnRestart(boolean applyOnRestart) { _applyOnRestart = applyOnRestart; }

    /**
     * Creates the config, substituting definition defaults for unset optional fields.
     *
     * @throws IllegalArgumentException if tokenizerPath or transformerModel has not been set
     */
    public HuggingFaceEmbedderConfig build() {
      return new HuggingFaceEmbedderConfig(this);
    }

  }

  // Field comments below come from the config definition.
  // Path to tokenizer.json
  private final ModelNode tokenizerPath;
  // Path to model.onnx
  private final ModelNode transformerModel;
  // Max length of token sequence model can handle
  private final IntegerNode transformerMaxTokens;
  // Input names
  private final StringNode transformerInputIds;
  private final StringNode transformerAttentionMask;
  private final StringNode transformerTokenTypeIds;
  // Output name
  private final StringNode transformerOutput;
  // NOTE(review): presumably text prepended to query / document input; the behaviour lives in the
  // consumer of this config and is not visible here.
  private final StringNode prependQuery;
  private final StringNode prependDocument;
  // Normalize tensors from tokenizer
  private final BooleanNode normalize;
  // One of cls, mean, none (definition default mean).
  private final PoolingStrategy poolingStrategy;
  // Settings for ONNX model evaluation
  private final TransformerExecutionMode transformerExecutionMode;
  private final IntegerNode transformerInterOpThreads;
  // Definition default is -4; how a negative count is interpreted is up to the consumer (not shown here).
  private final IntegerNode transformerIntraOpThreads;
  // GPU device id, -1 for CPU
  private final IntegerNode transformerGpuDevice;
  // Internal ONNX config file, e.g for Triton specific configuration
  private final OptionalPathNode transformerOnnxConfigPath;

  /**
   * Creates the config from {@code builder}.
   *
   * @throws IllegalArgumentException if a required field (tokenizerPath, transformerModel) is unset
   */
  public HuggingFaceEmbedderConfig(Builder builder) {
    this(builder, true);
  }

  /**
   * Creates the config, replacing every unset (null) builder field with its default from
   * {@link #CONFIG_DEF_SCHEMA}.
   *
   * <p>NOTE(review): only called with {@code true} in this file. When {@code false}, unset required
   * fields become empty {@code ModelNode}s instead of an error. Presumably that path is used
   * reflectively by the framework; confirm before relying on it.
   *
   * @param builder source of the values
   * @param throwIfUninitialized whether unset required fields are an error
   * @throws IllegalArgumentException if {@code throwIfUninitialized} and a required field is unset
   */
  private HuggingFaceEmbedderConfig(Builder builder, boolean throwIfUninitialized) {
    if (throwIfUninitialized && ! builder.__uninitialized.isEmpty())
      throw new IllegalArgumentException("The following builder parameters for " +
          "hugging-face-embedder must be initialized: " + builder.__uninitialized);

    // The literal defaults below must agree with CONFIG_DEF_SCHEMA (they currently do).
    tokenizerPath = (builder.tokenizerPath == null) ?
        new ModelNode() : new ModelNode(builder.tokenizerPath);
    transformerModel = (builder.transformerModel == null) ?
        new ModelNode() : new ModelNode(builder.transformerModel);
    transformerMaxTokens = (builder.transformerMaxTokens == null) ?
        new IntegerNode(512) : new IntegerNode(builder.transformerMaxTokens);
    transformerInputIds = (builder.transformerInputIds == null) ?
        new StringNode("input_ids") : new StringNode(builder.transformerInputIds);
    transformerAttentionMask = (builder.transformerAttentionMask == null) ?
        new StringNode("attention_mask") : new StringNode(builder.transformerAttentionMask);
    transformerTokenTypeIds = (builder.transformerTokenTypeIds == null) ?
        new StringNode("token_type_ids") : new StringNode(builder.transformerTokenTypeIds);
    transformerOutput = (builder.transformerOutput == null) ?
        new StringNode("last_hidden_state") : new StringNode(builder.transformerOutput);
    prependQuery = (builder.prependQuery == null) ?
        new StringNode("") : new StringNode(builder.prependQuery);
    prependDocument = (builder.prependDocument == null) ?
        new StringNode("") : new StringNode(builder.prependDocument);
    normalize = (builder.normalize == null) ?
        new BooleanNode(false) : new BooleanNode(builder.normalize);
    poolingStrategy = (builder.poolingStrategy == null) ?
        new PoolingStrategy(PoolingStrategy.mean) : new PoolingStrategy(builder.poolingStrategy);
    transformerExecutionMode = (builder.transformerExecutionMode == null) ?
        new TransformerExecutionMode(TransformerExecutionMode.sequential) : new TransformerExecutionMode(builder.transformerExecutionMode);
    transformerInterOpThreads = (builder.transformerInterOpThreads == null) ?
        new IntegerNode(1) : new IntegerNode(builder.transformerInterOpThreads);
    transformerIntraOpThreads = (builder.transformerIntraOpThreads == null) ?
        new IntegerNode(-4) : new IntegerNode(builder.transformerIntraOpThreads);
    transformerGpuDevice = (builder.transformerGpuDevice == null) ?
        new IntegerNode(0) : new IntegerNode(builder.transformerGpuDevice);
    // builder.transformerOnnxConfigPath is never null via the builder API (it starts as
    // Optional.empty() and null is rejected), so the first branch is effectively unreachable.
    transformerOnnxConfigPath = (builder.transformerOnnxConfigPath == null) ?
        new OptionalPathNode() : new OptionalPathNode(builder.transformerOnnxConfigPath);
  }

  /**
   * @return hugging-face-embedder.tokenizerPath as a {@link Path}
   */
  public Path tokenizerPath() {
    return tokenizerPath.value();
  }

  /**
   * @return hugging-face-embedder.tokenizerPath as the underlying {@link ModelReference}
   */
  public ModelReference tokenizerPathReference() {
    return tokenizerPath.getModelReference();
  }

  /**
   * @return hugging-face-embedder.transformerModel as a {@link Path}
   */
  public Path transformerModel() {
    return transformerModel.value();
  }

  /**
   * @return hugging-face-embedder.transformerModel as the underlying {@link ModelReference}
   */
  public ModelReference transformerModelReference() {
    return transformerModel.getModelReference();
  }

  /**
   * @return hugging-face-embedder.transformerMaxTokens (default 512)
   */
  public int transformerMaxTokens() {
    return transformerMaxTokens.value();
  }

  /**
   * @return hugging-face-embedder.transformerInputIds (default "input_ids")
   */
  public String transformerInputIds() {
    return transformerInputIds.value();
  }

  /**
   * @return hugging-face-embedder.transformerAttentionMask (default "attention_mask")
   */
  public String transformerAttentionMask() {
    return transformerAttentionMask.value();
  }

  /**
   * @return hugging-face-embedder.transformerTokenTypeIds (default "token_type_ids")
   */
  public String transformerTokenTypeIds() {
    return transformerTokenTypeIds.value();
  }

  /**
   * @return hugging-face-embedder.transformerOutput (default "last_hidden_state")
   */
  public String transformerOutput() {
    return transformerOutput.value();
  }

  /**
   * @return hugging-face-embedder.prependQuery (default empty string)
   */
  public String prependQuery() {
    return prependQuery.value();
  }

  /**
   * @return hugging-face-embedder.prependDocument (default empty string)
   */
  public String prependDocument() {
    return prependDocument.value();
  }

  /**
   * @return hugging-face-embedder.normalize (default false)
   */
  public boolean normalize() {
    return normalize.value();
  }

  /**
   * @return hugging-face-embedder.poolingStrategy (default mean)
   */
  public PoolingStrategy.Enum poolingStrategy() {
    return poolingStrategy.value();
  }

  /**
   * @return hugging-face-embedder.transformerExecutionMode (default sequential)
   */
  public TransformerExecutionMode.Enum transformerExecutionMode() {
    return transformerExecutionMode.value();
  }

  /**
   * @return hugging-face-embedder.transformerInterOpThreads (default 1)
   */
  public int transformerInterOpThreads() {
    return transformerInterOpThreads.value();
  }

  /**
   * @return hugging-face-embedder.transformerIntraOpThreads (default -4)
   */
  public int transformerIntraOpThreads() {
    return transformerIntraOpThreads.value();
  }

  /**
   * @return hugging-face-embedder.transformerGpuDevice (default 0)
   */
  public int transformerGpuDevice() {
    return transformerGpuDevice.value();
  }

  /**
   * @return hugging-face-embedder.transformerOnnxConfigPath, or {@link Optional#empty()} when unset
   */
  public Optional<Path> transformerOnnxConfigPath() {
    return transformerOnnxConfigPath.value();
  }

  // Always returns an empty change set: no field in this definition is flagged as requiring a
  // restart, so newConfig is not inspected. Private and not called in this file; presumably
  // invoked reflectively by the framework — do not rename.
  private ChangesRequiringRestart getChangesRequiringRestart(HuggingFaceEmbedderConfig newConfig) {
    ChangesRequiringRestart changes = new ChangesRequiringRestart("hugging-face-embedder");
    return changes;
  }

  // No field in this definition is flagged with restart. Private and not called in this file;
  // presumably invoked reflectively.
  private static boolean containsFieldsFlaggedWithRestart() {
    return false;
  }

  /**
   * Enum node for hugging-face-embedder.poolingStrategy: one of {@code cls}, {@code mean},
   * {@code none} (definition default {@code mean}).
   */
  public final static class PoolingStrategy extends EnumNode<PoolingStrategy.Enum> {

    /** Creates a node with no value. */
    public PoolingStrategy(){
      this.value = null;
    }

    /** Creates a node holding {@code enumValue}. */
    public PoolingStrategy(Enum enumValue) {
      // NOTE(review): presumably tells EnumNode the node has a value only when enumValue is
      // non-null; confirm against EnumNode.
      super(enumValue != null);
      this.value = enumValue;
    }

    public enum Enum {cls, mean, none}
    // Shorthand aliases for the enum constants.
    public final static Enum cls = Enum.cls;
    public final static Enum mean = Enum.mean;
    public final static Enum none = Enum.none;

    /**
     * Parses {@code name} as an exact constant name.
     *
     * @return true and stores the constant on success; false, leaving the value unchanged, for an
     *         unknown name
     */
    @Override
    protected boolean doSetValue(String name) {
      try {
        value = Enum.valueOf(name);
        return true;
      } catch (IllegalArgumentException e) {
        // Deliberately ignored: failure is reported through the false return value.
      }
      return false;
    }
  }

  /**
   * Enum node for hugging-face-embedder.transformerExecutionMode: one of {@code parallel},
   * {@code sequential} (definition default {@code sequential}).
   *
   * Settings for ONNX model evaluation
   */
  public final static class TransformerExecutionMode extends EnumNode<TransformerExecutionMode.Enum> {

    /** Creates a node with no value. */
    public TransformerExecutionMode(){
      this.value = null;
    }

    /** Creates a node holding {@code enumValue}. */
    public TransformerExecutionMode(Enum enumValue) {
      // NOTE(review): presumably tells EnumNode the node has a value only when enumValue is
      // non-null; confirm against EnumNode.
      super(enumValue != null);
      this.value = enumValue;
    }

    public enum Enum {parallel, sequential}
    // Shorthand aliases for the enum constants.
    public final static Enum parallel = Enum.parallel;
    public final static Enum sequential = Enum.sequential;

    /**
     * Parses {@code name} as an exact constant name.
     *
     * @return true and stores the constant on success; false, leaving the value unchanged, for an
     *         unknown name
     */
    @Override
    protected boolean doSetValue(String name) {
      try {
        value = Enum.valueOf(name);
        return true;
      } catch (IllegalArgumentException e) {
        // Deliberately ignored: failure is reported through the false return value.
      }
      return false;
    }
  }

}
