package com.redis.om.spring.vectorize;

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.translator.ImageFeatureExtractor;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import com.redis.om.spring.annotations.Vectorize;
import com.redis.om.spring.util.ObjectUtils;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeanWrapper;
import org.springframework.beans.PropertyAccessorFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.data.redis.core.RedisOperations;
import org.springframework.stereotype.Component;

/**
 * Populates {@link Vectorize}-annotated entity fields with embeddings.
 *
 * <p>For each annotated field whose value is non-null, an embedding is computed according to
 * {@link Vectorize#embeddingType()} and written to the bean property named by
 * {@link Vectorize#destination()}:
 * <ul>
 *   <li>{@code IMAGE} – the field value is treated as a Spring resource location; the image is run
 *       through {@code imageEmbeddingModel}.</li>
 *   <li>{@code FACE} – same resource handling; the image is run through {@code faceEmbeddingModel}
 *       and the resulting {@code float[]} is converted to bytes.</li>
 *   <li>{@code SENTENCE} – the field's {@code toString()} is tokenized and the token ids are
 *       serialized to bytes.</li>
 *   <li>{@code WORD} (and any other type) – ignored.</li>
 * </ul>
 */
@Component
public class FeatureExtractor {
    private static final Log logger = LogFactory.getLog(FeatureExtractor.class);

    private final RedisOperations<?, ?> redisOperations;
    private final ZooModel<Image, byte[]> imageEmbeddingModel;
    private final ZooModel<Image, float[]> faceEmbeddingModel;
    private final ImageFactory imageFactory;
    private final ApplicationContext applicationContext;
    private final ImageFeatureExtractor imageFeatureExtractor;
    public final Pipeline imagePipeline;
    public final HuggingFaceTokenizer sentenceTokenizer;

    /**
     * Creates the extractor; the image translator is built from {@code imagePipeline}.
     *
     * @param redisOperations     Redis operations handle (stored, not used by this class's visible code)
     * @param applicationContext  used to resolve resource locations for IMAGE/FACE fields
     * @param imageEmbeddingModel model producing image embeddings as bytes
     * @param faceEmbeddingModel  model producing facial embeddings as floats
     * @param imageFactory        converts input streams to DJL {@link Image}s
     * @param imagePipeline       preprocessing pipeline for the image feature extractor
     * @param sentenceTokenizer   tokenizer used for SENTENCE fields
     */
    public FeatureExtractor(RedisOperations<?, ?> redisOperations, ApplicationContext applicationContext, ZooModel<Image, byte[]> imageEmbeddingModel, ZooModel<Image, float[]> faceEmbeddingModel, ImageFactory imageFactory, Pipeline imagePipeline, HuggingFaceTokenizer sentenceTokenizer) {
        this.redisOperations = redisOperations;
        this.applicationContext = applicationContext;
        this.imageEmbeddingModel = imageEmbeddingModel;
        this.faceEmbeddingModel = faceEmbeddingModel;
        this.imageFactory = imageFactory;
        this.imagePipeline = imagePipeline;
        this.sentenceTokenizer = sentenceTokenizer;
        this.imageFeatureExtractor = ImageFeatureExtractor.builder().setPipeline(imagePipeline).build();
    }

    /**
     * Vectorizes {@code obj}. The {@code key} argument is accepted for API compatibility but is
     * currently unused by this implementation.
     */
    public void processEntity(byte[] key, Object obj) {
        processEntity(obj);
    }

    /**
     * Computes an image embedding for the image read from {@code inputStream}.
     *
     * <p>The stream is NOT closed here; the caller owns it.
     *
     * @return the embedding, or an empty array if decoding or inference fails (the error is logged)
     */
    public byte[] getImageEmbeddingsFor(InputStream inputStream) {
        // Predictor is AutoCloseable; close it so native model resources are released per call.
        try (Predictor<Image, byte[]> predictor = imageEmbeddingModel.newPredictor(imageFeatureExtractor)) {
            return predictor.predict(imageFactory.fromInputStream(inputStream));
        } catch (IOException | TranslateException e) {
            logger.warn("Error generating image embedding", e);
            return new byte[0];
        }
    }

    /**
     * Computes a facial embedding for the image read from {@code inputStream} and serializes the
     * resulting floats to bytes via {@link ObjectUtils#floatArrayToByteArray(float[])}.
     *
     * <p>The stream is NOT closed here; the caller owns it.
     *
     * @throws IOException        if the image cannot be read
     * @throws TranslateException if model inference fails
     */
    public byte[] getFacialImageEmbeddingsFor(InputStream inputStream) throws IOException, TranslateException {
        try (Predictor<Image, float[]> predictor = faceEmbeddingModel.newPredictor()) {
            float[] embedding = predictor.predict(imageFactory.fromInputStream(inputStream));
            return ObjectUtils.floatArrayToByteArray(embedding);
        }
    }

    /**
     * Tokenizes {@code str} and returns the token ids serialized to bytes.
     *
     * <p>NOTE(review): this is the tokenizer's id sequence, not a pooled model embedding — confirm
     * that is what downstream vector search expects.
     */
    public byte[] getSentenceEmbeddingsFor(String str) {
        return ObjectUtils.longArrayToByteArray(this.sentenceTokenizer.encode(str).getIds());
    }

    /**
     * Writes embeddings into every {@link Vectorize}-annotated field of {@code obj}.
     * Does nothing when {@link #isReady()} is false or no field carries the annotation.
     */
    public void processEntity(Object obj) {
        if (!isReady()) {
            return;
        }
        List<Field> annotatedFields = ObjectUtils.getFieldsWithAnnotation(obj.getClass(), Vectorize.class);
        if (annotatedFields.isEmpty()) {
            return;
        }
        BeanWrapper accessor = PropertyAccessorFactory.forBeanPropertyAccess(obj);
        for (Field field : annotatedFields) {
            vectorizeField(accessor, field);
        }
    }

    /** Computes and stores the embedding for a single annotated field; null source values are skipped. */
    private void vectorizeField(BeanWrapper accessor, Field field) {
        Vectorize vectorize = field.getAnnotation(Vectorize.class);
        Object value = accessor.getPropertyValue(field.getName());
        if (value == null) {
            return;
        }
        switch (vectorize.embeddingType()) {
            case IMAGE:
                // Close the resource stream once the embedding has been computed.
                try (InputStream in = openResource(value)) {
                    accessor.setPropertyValue(vectorize.destination(), getImageEmbeddingsFor(in));
                } catch (IOException e) {
                    logger.warn("Error generating image embedding", e);
                }
                break;
            case FACE:
                try (InputStream in = openResource(value)) {
                    accessor.setPropertyValue(vectorize.destination(), getFacialImageEmbeddingsFor(in));
                } catch (IOException | TranslateException e) {
                    logger.warn("Error generating facial image embedding", e);
                }
                break;
            case SENTENCE:
                accessor.setPropertyValue(vectorize.destination(), getSentenceEmbeddingsFor(value.toString()));
                break;
            case WORD:
            default:
                // No embedding is produced for these types; the destination is left untouched.
                break;
        }
    }

    /** Opens the Spring resource whose location is {@code location.toString()}; caller must close it. */
    private InputStream openResource(Object location) throws IOException {
        return applicationContext.getResource(location.toString()).getInputStream();
    }

    /**
     * Returns whether the face model and the sentence tokenizer are both available.
     *
     * <p>NOTE(review): {@code imageEmbeddingModel} is not checked here, so an IMAGE field with a null
     * image model would still be attempted — confirm whether that model is always present.
     */
    public boolean isReady() {
        return faceEmbeddingModel != null && sentenceTokenizer != null;
    }
}
