package com.google.cloud.vertexai.generativeai.preview;

import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensRequest;
import com.google.cloud.vertexai.api.CountTokensResponse;
import com.google.cloud.vertexai.api.GenerateContentRequest;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.SafetySetting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;

/* loaded from: input_file:com/google/cloud/vertexai/generativeai/preview/GenerativeModel.class */
/**
 * Client for a Google publisher generative model hosted on Vertex AI.
 *
 * <p>Every request is addressed to the resource name
 * {@code projects/{project}/locations/{location}/publishers/google/models/{model}}. The project and
 * location come from the {@link VertexAI} instance supplied at construction. Requests are sent
 * through either the REST or the gRPC prediction-service client, depending on the configured
 * {@link Transport}.
 *
 * <p>A model may carry a default {@link GenerationConfig} and a default list of
 * {@link SafetySetting}s. When a per-call argument is non-null, it <em>replaces</em> the
 * corresponding default for that call only; the two are never merged.
 *
 * <p>Instances are mutable through the {@code set*} methods and perform no synchronization.
 */
public class GenerativeModel {
    /** Resource path template, formatted with project id, location and model name. */
    private static final String RESOURCE_NAME_FORMAT =
            "projects/%s/locations/%s/publishers/google/models/%s";

    /** Role assigned to content built from a plain text prompt. */
    private static final String USER_ROLE = "user";

    private final String modelName;
    private final String resourceName;
    private final VertexAI vertexAi;
    /** Default generation config used when a call does not supply one; may be null. */
    private GenerationConfig generationConfig;
    /** Private copy of the default safety settings; null means none are configured. */
    private List<SafetySetting> safetySettings;
    private Transport transport;

    /** Creates a model with no defaults, using the transport configured on {@code vertexAI}. */
    public GenerativeModel(String modelName, VertexAI vertexAI) {
        this(modelName, null, null, vertexAI, null);
    }

    /** Creates a model with no defaults, using the given transport. */
    public GenerativeModel(String modelName, VertexAI vertexAI, Transport transport) {
        this(modelName, null, null, vertexAI, transport);
    }

    /** Creates a model with a default generation config. */
    public GenerativeModel(String modelName, GenerationConfig generationConfig, VertexAI vertexAI) {
        this(modelName, generationConfig, null, vertexAI, null);
    }

    /** Creates a model with a default generation config and an explicit transport. */
    public GenerativeModel(
            String modelName, GenerationConfig generationConfig, VertexAI vertexAI, Transport transport) {
        this(modelName, generationConfig, null, vertexAI, transport);
    }

    /** Creates a model with default safety settings. */
    public GenerativeModel(String modelName, List<SafetySetting> safetySettings, VertexAI vertexAI) {
        this(modelName, null, safetySettings, vertexAI, null);
    }

    /** Creates a model with default safety settings and an explicit transport. */
    public GenerativeModel(
            String modelName, List<SafetySetting> safetySettings, VertexAI vertexAI, Transport transport) {
        this(modelName, null, safetySettings, vertexAI, transport);
    }

    /** Creates a model with a default generation config and default safety settings. */
    public GenerativeModel(
            String modelName,
            GenerationConfig generationConfig,
            List<SafetySetting> safetySettings,
            VertexAI vertexAI) {
        this(modelName, generationConfig, safetySettings, vertexAI, null);
    }

    /**
     * Creates a model. This is the constructor that all other constructors delegate to.
     *
     * @param modelName model identifier; must be one of {@code Constants.GENERATIVE_MODEL_NAMES}
     * @param generationConfig default generation config, or null for none
     * @param safetySettings default safety settings (copied), or null for none
     * @param vertexAI Vertex AI handle that supplies project, location and service clients
     * @param transport transport to use, or null to inherit {@code vertexAI.getTransport()}
     * @throws IllegalArgumentException if {@code modelName} is not a supported model name
     * @throws NullPointerException if {@code vertexAI} is null
     */
    public GenerativeModel(
            String modelName,
            GenerationConfig generationConfig,
            List<SafetySetting> safetySettings,
            VertexAI vertexAI,
            Transport transport) {
        validateModelName(modelName);
        Objects.requireNonNull(vertexAI, "vertexAI must not be null");
        this.modelName = modelName;
        this.vertexAi = vertexAI;
        this.resourceName =
                String.format(
                        RESOURCE_NAME_FORMAT, vertexAI.getProjectId(), vertexAI.getLocation(), modelName);
        this.generationConfig = generationConfig;
        this.safetySettings = copyOrNull(safetySettings);
        this.transport = transport != null ? transport : vertexAI.getTransport();
    }

    /** Counts the tokens in a single text prompt. */
    public CountTokensResponse countTokens(String text) throws IOException {
        return countTokensFromBuilder(
                CountTokensRequest.newBuilder()
                        .addAllContents(Arrays.asList(ContentMaker.fromString(text))));
    }

    /** Counts the tokens in a single content item. */
    public CountTokensResponse countTokens(Content content) throws IOException {
        return countTokens(Arrays.asList(content));
    }

    /** Counts the tokens across the given contents. */
    public CountTokensResponse countTokens(List<Content> contents) throws IOException {
        return countTokensFromBuilder(CountTokensRequest.newBuilder().addAllContents(contents));
    }

    /** Addresses the request to this model and sends it over the configured transport. */
    private CountTokensResponse countTokensFromBuilder(CountTokensRequest.Builder builder)
            throws IOException {
        CountTokensRequest request =
                builder.setEndpoint(this.resourceName).setModel(this.resourceName).build();
        if (this.transport == Transport.REST) {
            return this.vertexAi.getPredictionServiceRestClient().countTokens(request);
        }
        return this.vertexAi.getPredictionServiceClient().countTokens(request);
    }

    /** Generates a response for a text prompt using the model defaults. */
    public GenerateContentResponse generateContent(String text) throws IOException {
        return generateContent(text, (GenerationConfig) null, (List<SafetySetting>) null);
    }

    /** Generates a response for a text prompt with a per-call generation config. */
    public GenerateContentResponse generateContent(String text, GenerationConfig generationConfig)
            throws IOException {
        return generateContent(text, generationConfig, (List<SafetySetting>) null);
    }

    /** Generates a response for a text prompt with per-call safety settings. */
    public GenerateContentResponse generateContent(String text, List<SafetySetting> safetySettings)
            throws IOException {
        return generateContent(text, (GenerationConfig) null, safetySettings);
    }

    /** Generates a response for a text prompt, sent as a single content with role {@code "user"}. */
    public GenerateContentResponse generateContent(
            String text, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
            throws IOException {
        return generateContent(
                Arrays.asList(userTextContent(text)), generationConfig, safetySettings);
    }

    /** Generates a response for the given contents using the model defaults. */
    public GenerateContentResponse generateContent(List<Content> contents) throws IOException {
        return generateContent(contents, (GenerationConfig) null, (List<SafetySetting>) null);
    }

    /** Generates a response for the given contents with a per-call generation config. */
    public GenerateContentResponse generateContent(
            List<Content> contents, GenerationConfig generationConfig) throws IOException {
        return generateContent(contents, generationConfig, (List<SafetySetting>) null);
    }

    /** Generates a response for the given contents with per-call safety settings. */
    public GenerateContentResponse generateContent(
            List<Content> contents, List<SafetySetting> safetySettings) throws IOException {
        return generateContent(contents, (GenerationConfig) null, safetySettings);
    }

    /**
     * Generates a single aggregated response for the given contents.
     *
     * <p>This issues a streaming request and folds the streamed chunks together with
     * {@code ResponseHandler.aggregateStreamIntoResponse}.
     *
     * @param contents request contents
     * @param generationConfig per-call config; if null, the model default (if any) is used
     * @param safetySettings per-call settings; if null, the model defaults (if any) are used
     */
    public GenerateContentResponse generateContent(
            List<Content> contents,
            GenerationConfig generationConfig,
            List<SafetySetting> safetySettings)
            throws IOException {
        return ResponseHandler.aggregateStreamIntoResponse(
                streamGenerateContent(
                        newGenerateContentRequest(contents, generationConfig, safetySettings)));
    }

    /** Generates a response for a single content item using the model defaults. */
    public GenerateContentResponse generateContent(Content content) throws IOException {
        return generateContent(content, (GenerationConfig) null, (List<SafetySetting>) null);
    }

    /** Generates a response for a single content item with a per-call generation config. */
    public GenerateContentResponse generateContent(Content content, GenerationConfig generationConfig)
            throws IOException {
        return generateContent(content, generationConfig, (List<SafetySetting>) null);
    }

    /** Generates a response for a single content item with per-call safety settings. */
    public GenerateContentResponse generateContent(Content content, List<SafetySetting> safetySettings)
            throws IOException {
        return generateContent(content, (GenerationConfig) null, safetySettings);
    }

    /** Generates a response for a single content item. */
    public GenerateContentResponse generateContent(
            Content content, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
            throws IOException {
        return generateContent(Arrays.asList(content), generationConfig, safetySettings);
    }

    /** Streams responses for a text prompt using the model defaults. */
    public ResponseStream<GenerateContentResponse> generateContentStream(String text)
            throws IOException {
        return generateContentStream(text, (GenerationConfig) null, (List<SafetySetting>) null);
    }

    /** Streams responses for a text prompt with a per-call generation config. */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            String text, GenerationConfig generationConfig) throws IOException {
        return generateContentStream(text, generationConfig, (List<SafetySetting>) null);
    }

    /** Streams responses for a text prompt with per-call safety settings. */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            String text, List<SafetySetting> safetySettings) throws IOException {
        return generateContentStream(text, (GenerationConfig) null, safetySettings);
    }

    /** Streams responses for a text prompt, sent as a single content with role {@code "user"}. */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            String text, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
            throws IOException {
        return generateContentStream(
                Arrays.asList(userTextContent(text)), generationConfig, safetySettings);
    }

    /** Streams responses for a single content item using the model defaults. */
    public ResponseStream<GenerateContentResponse> generateContentStream(Content content)
            throws IOException {
        return generateContentStream(content, (GenerationConfig) null, (List<SafetySetting>) null);
    }

    /** Streams responses for a single content item with a per-call generation config. */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            Content content, GenerationConfig generationConfig) throws IOException {
        return generateContentStream(content, generationConfig, (List<SafetySetting>) null);
    }

    /** Streams responses for a single content item with per-call safety settings. */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            Content content, List<SafetySetting> safetySettings) throws IOException {
        return generateContentStream(content, (GenerationConfig) null, safetySettings);
    }

    /** Streams responses for a single content item. */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            Content content, GenerationConfig generationConfig, List<SafetySetting> safetySettings)
            throws IOException {
        return generateContentStream(Arrays.asList(content), generationConfig, safetySettings);
    }

    /** Streams responses for the given contents using the model defaults. */
    public ResponseStream<GenerateContentResponse> generateContentStream(List<Content> contents)
            throws IOException {
        return generateContentStream(contents, (GenerationConfig) null, (List<SafetySetting>) null);
    }

    /** Streams responses for the given contents with a per-call generation config. */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            List<Content> contents, GenerationConfig generationConfig) throws IOException {
        return generateContentStream(contents, generationConfig, (List<SafetySetting>) null);
    }

    /** Streams responses for the given contents with per-call safety settings. */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            List<Content> contents, List<SafetySetting> safetySettings) throws IOException {
        return generateContentStream(contents, (GenerationConfig) null, safetySettings);
    }

    /**
     * Streams responses for the given contents.
     *
     * @param contents request contents
     * @param generationConfig per-call config; if null, the model default (if any) is used
     * @param safetySettings per-call settings; if null, the model defaults (if any) are used
     */
    public ResponseStream<GenerateContentResponse> generateContentStream(
            List<Content> contents,
            GenerationConfig generationConfig,
            List<SafetySetting> safetySettings)
            throws IOException {
        return streamGenerateContent(
                newGenerateContentRequest(contents, generationConfig, safetySettings));
    }

    /**
     * Starts a request builder for {@code contents}.
     *
     * <p>Each per-call argument, if non-null, is used as given. Otherwise the builder falls back to
     * this model's default. Per-call values replace the defaults; they are not merged with them.
     */
    private GenerateContentRequest.Builder newGenerateContentRequest(
            List<Content> contents,
            GenerationConfig generationConfig,
            List<SafetySetting> safetySettings) {
        GenerateContentRequest.Builder builder =
                GenerateContentRequest.newBuilder().addAllContents(contents);

        GenerationConfig effectiveConfig =
                generationConfig != null ? generationConfig : this.generationConfig;
        if (effectiveConfig != null) {
            builder.setGenerationConfig(effectiveConfig);
        }

        List<SafetySetting> effectiveSettings =
                safetySettings != null ? safetySettings : this.safetySettings;
        if (effectiveSettings != null) {
            builder.addAllSafetySettings(effectiveSettings);
        }
        return builder;
    }

    /** Addresses the request to this model and opens a response stream on the configured transport. */
    private ResponseStream<GenerateContentResponse> streamGenerateContent(
            GenerateContentRequest.Builder builder) throws IOException {
        GenerateContentRequest request =
                builder.setEndpoint(this.resourceName).setModel(this.resourceName).build();
        Iterator<GenerateContentResponse> responses;
        if (this.transport == Transport.REST) {
            responses =
                    this.vertexAi
                            .getPredictionServiceRestClient()
                            .streamGenerateContentCallable()
                            .call(request)
                            .iterator();
        } else {
            responses =
                    this.vertexAi
                            .getPredictionServiceClient()
                            .streamGenerateContentCallable()
                            .call(request)
                            .iterator();
        }
        return new ResponseStream<>(new ResponseStreamIteratorWithHistory(responses));
    }

    /** Wraps plain text as a single-part content with role {@code "user"}. */
    private static Content userTextContent(String text) {
        return Content.newBuilder()
                .addParts(Part.newBuilder().setText(text).build())
                .setRole(USER_ROLE)
                .build();
    }

    /** Returns a mutable private copy of {@code source}, or null when {@code source} is null. */
    private static List<SafetySetting> copyOrNull(List<SafetySetting> source) {
        return source == null ? null : new ArrayList<>(source);
    }

    /** Replaces the default generation config; null clears it. */
    public void setGenerationConfig(GenerationConfig generationConfig) {
        this.generationConfig = generationConfig;
    }

    /**
     * Replaces the default safety settings with a copy of {@code safetySettings}.
     *
     * <p>Passing null clears the defaults, which matches how the constructors treat a null list.
     * Previously a null argument threw an unhelpful {@link NullPointerException}.
     */
    public void setSafetySettings(List<SafetySetting> safetySettings) {
        this.safetySettings = copyOrNull(safetySettings);
    }

    /** Sets the transport used for subsequent requests. */
    public void setTransport(Transport transport) {
        this.transport = transport;
    }

    /** Returns the model name given at construction. */
    public String getModelName() {
        return this.modelName;
    }

    /** Returns the transport used for requests. */
    public Transport getTransport() {
        return this.transport;
    }

    /** Returns the default generation config, or null if none is set. */
    public GenerationConfig getGenerationConfig() {
        return this.generationConfig;
    }

    /** Returns an unmodifiable view of the default safety settings, or null if none are set. */
    public List<SafetySetting> getSafetySettings() {
        return this.safetySettings == null ? null : Collections.unmodifiableList(this.safetySettings);
    }

    /** Starts a chat session backed by this model. */
    public ChatSession startChat() {
        return new ChatSession(this);
    }

    /**
     * Rejects model names that are not in {@code Constants.GENERATIVE_MODEL_NAMES}.
     *
     * @throws IllegalArgumentException if {@code modelName} is not a supported model name
     */
    private static void validateModelName(String modelName) {
        if (!Constants.GENERATIVE_MODEL_NAMES.contains(modelName)) {
            throw new IllegalArgumentException(
                    String.format(
                            "Invalid model name: %s. Please choose from: %s",
                            modelName, Constants.GENERATIVE_MODEL_NAMES));
        }
    }
}
