/*
 * Decompiled with CFR 0.152.
 */
package dev.brachtendorf.jimagehash.matcher.categorize.supervised.randomForest;

import com.github.kilianB.pcg.fast.PcgRSFast;
import dev.brachtendorf.ArrayUtil;
import dev.brachtendorf.MathUtil;
import dev.brachtendorf.Require;
import dev.brachtendorf.datastructures.CountHashCollection;
import dev.brachtendorf.datastructures.Pair;
import dev.brachtendorf.datastructures.Triple;
import dev.brachtendorf.jimagehash.hash.FuzzyHash;
import dev.brachtendorf.jimagehash.hash.Hash;
import dev.brachtendorf.jimagehash.hashAlgorithms.AverageHash;
import dev.brachtendorf.jimagehash.hashAlgorithms.HashingAlgorithm;
import dev.brachtendorf.jimagehash.hashAlgorithms.PerceptiveHash;
import dev.brachtendorf.jimagehash.hashAlgorithms.RotAverageHash;
import dev.brachtendorf.jimagehash.matcher.PlainImageMatcher;
import dev.brachtendorf.jimagehash.matcher.categorize.CategoricalImageMatcher;
import dev.brachtendorf.jimagehash.matcher.categorize.CategorizationResult;
import dev.brachtendorf.jimagehash.matcher.categorize.supervised.LabeledImage;
import dev.brachtendorf.jimagehash.matcher.categorize.supervised.randomForest.InnerNode;
import dev.brachtendorf.jimagehash.matcher.categorize.supervised.randomForest.LeafNode;
import dev.brachtendorf.jimagehash.matcher.categorize.supervised.randomForest.TestData;
import dev.brachtendorf.jimagehash.matcher.categorize.supervised.randomForest.TreeNode;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import javax.imageio.ImageIO;

/**
 * Supervised image categorizer that trains a forest of decision trees and
 * classifies an image by majority vote of the trees.
 *
 * <p>Each split variable is a randomly generated {@link FuzzyHash} paired with
 * one of the registered hashing algorithms; an image is routed left or right
 * depending on whether its normalized hamming distance to that hash is below
 * the node's cutoff.
 *
 * <p>Instances are not safe for concurrent mutation: training reads
 * {@link #labeledImages} from worker threads, so labeled images must not be
 * added or cleared while {@link #trainMatcher(int, int, int)} is running.
 */
public class RandomForestCategorizer
extends PlainImageMatcher
implements CategoricalImageMatcher {
    /** Root nodes of the trained trees; empty until {@link #trainMatcher(int, int, int)} has run. */
    protected List<TreeNode> forest = new ArrayList<TreeNode>();
    /** Labeled training images registered through the {@code addTestImages} overloads. */
    protected List<LabeledImage> labeledImages = new ArrayList<LabeledImage>();
    /** Sorted set of every category seen in {@link #labeledImages}. */
    protected TreeSet<Integer> categories = new TreeSet<Integer>();

    /**
     * Registers every labeled image of the collection as training data.
     *
     * @param data the labeled images to add
     */
    public void addTestImages(Collection<LabeledImage> data) {
        for (LabeledImage t : data) {
            this.addTestImages(t);
        }
    }

    /**
     * Registers every given labeled image as training data.
     *
     * @param data the labeled images to add
     */
    public void addTestImages(LabeledImage ... data) {
        for (LabeledImage t : data) {
            this.addTestImages(t);
        }
    }

    /**
     * Registers a single labeled image as training data and records its
     * category.
     *
     * @param lData the labeled image to add
     */
    public void addTestImages(LabeledImage lData) {
        this.categories.add(lData.getCategory());
        this.labeledImages.add(lData);
    }

    /**
     * Removes all registered training images and known categories. An already
     * trained forest is left untouched.
     */
    public void clearTestImages() {
        this.categories.clear();
        this.labeledImages.clear();
    }

    /**
     * Trains the forest on the registered labeled images.
     *
     * <p>Hashes of all training images are computed once per algorithm. Forests
     * are then built for every "variables per node" count in the window
     * {@code [sqrt(v) - numVarsSearchRange, sqrt(v) + numVarsSearchRange)},
     * clamped to {@code [0, v]}, where {@code v} is the number of generated
     * random variables. The error values reported by
     * {@link #createForest} are currently placeholders, so no selection takes
     * place: the forest built last is the one that is kept.
     *
     * @param trees              number of trees per forest; checked with
     *                           {@code Require.oddValue}
     * @param numVarsSearchRange half width of the search window; checked with
     *                           {@code Require.positiveValue}
     * @param numVarsRep         how many times each random variable is added to
     *                           the per-tree variable pool
     */
    public void trainMatcher(int trees, int numVarsSearchRange, int numVarsRep) {
        Require.positiveValue(numVarsSearchRange, "NumVarsSearchRange has to be positive.");
        Require.oddValue(trees, "The number of trees should be odd to prevent ambiguity");
        System.out.println("");
        System.out.println("Hashing algos available: " + this.steps);

        // Hash every training image once per algorithm so tree building only does lookups.
        Map<HashingAlgorithm, Map<BufferedImage, Hash>> preComputedHashes = new HashMap<HashingAlgorithm, Map<BufferedImage, Hash>>();
        for (HashingAlgorithm hashAlgorithm : this.getAlgorithms()) {
            Map<BufferedImage, Hash> hashesByImage = new HashMap<BufferedImage, Hash>();
            for (LabeledImage lImage : this.labeledImages) {
                hashesByImage.put(lImage.getbImage(), hashAlgorithm.hash(lImage.getbImage()));
            }
            preComputedHashes.put(hashAlgorithm, hashesByImage);
        }
        List<Pair<FuzzyHash, HashingAlgorithm>> randomHashes = this.createFuzzyHashes(preComputedHashes);
        System.out.println(randomHashes);

        ExecutorService tPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
        try {
            int maxNumberOfVariables = randomHashes.size();
            int numVars = (int)Math.sqrt(maxNumberOfVariables);
            // Clamp the window up front instead of skipping/aborting inside the loop.
            int firstCandidate = Math.max(0, numVars - numVarsSearchRange);
            int lastCandidate = Math.min(numVars + numVarsSearchRange - 1, maxNumberOfVariables);
            for (int i = firstCandidate; i <= lastCandidate; ++i) {
                System.out.println("Create Forest with number of vars: " + i + "/" + maxNumberOfVariables);
                Object[] packedForestData = this.createForest(trees, i, numVarsRep, randomHashes, preComputedHashes, tPool);
                @SuppressWarnings("unchecked")
                List<TreeNode> forestCandidate = (List<TreeNode>)packedForestData[0];
                double outOfBagClassificationError = (Double)packedForestData[1];
                double classificationErrorAll = (Double)packedForestData[2];
                System.out.println("Out of bag error: " + outOfBagClassificationError);
                System.out.println("Class error all : " + classificationErrorAll);
                this.forest = forestCandidate;
                System.out.println(" ");
            }
        } finally {
            // Always release the worker threads, even when tree building fails.
            tPool.shutdown();
        }
        this.testDecisionTree(this.labeledImages);
    }

    /**
     * Classifies every given image with the current forest and prints the
     * fraction of images whose predicted category differs from their label.
     * Prints {@code NaN} for an empty list.
     *
     * @param labeledImages the images to evaluate
     */
    private void testDecisionTree(List<LabeledImage> labeledImages) {
        int match = 0;
        int mismatch = 0;
        for (LabeledImage lImage : labeledImages) {
            if (lImage.getCategory() == this.categorizeImage(lImage.getbImage()).getCategory()) {
                ++match;
            } else {
                ++mismatch;
            }
        }
        System.out.println("Classification Error: " + (double)mismatch / ((double)match + (double)mismatch));
    }

    /**
     * Builds {@code trees} decision trees in parallel on the given executor.
     *
     * <p>Every tree is trained on a copy of the complete labeled data set (no
     * bootstrapping is performed here); trees differ only through the random
     * variable selection in each node. Each task hands its root back through its
     * {@link Future}, so the result list is only ever touched by the calling
     * thread.
     *
     * <p>A tree whose construction fails is reported on standard error and left
     * out of the result. If the calling thread is interrupted while waiting, the
     * interrupt flag is restored, outstanding tasks are cancelled and the trees
     * collected so far are returned.
     *
     * @param trees                        number of trees to build
     * @param numVars                      number of random variables examined per node
     * @param numVarsRep                   how often each variable is added to a tree's pool
     * @param randomVariables              candidate split variables
     * @param preComputedHashesTestAgainst hashes of the training images per algorithm
     * @param tPool                        executor running the tree building tasks
     * @return a three element array: {@code [0]} the {@code List<TreeNode>} of
     *         roots, {@code [1]} out-of-bag error and {@code [2]} overall
     *         classification error as {@code Double}; both errors are currently
     *         the placeholder {@code 0.0}
     */
    protected Object[] createForest(int trees, int numVars, int numVarsRep, List<Pair<FuzzyHash, HashingAlgorithm>> randomVariables, Map<HashingAlgorithm, Map<BufferedImage, Hash>> preComputedHashesTestAgainst, ExecutorService tPool) {
        List<Future<TreeNode>> tasks = new ArrayList<Future<TreeNode>>();
        for (int i = 0; i < trees; ++i) {
            Future<TreeNode> task = tPool.submit(() -> {
                CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> variablePool = new CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>>();
                for (int j = 0; j < numVarsRep; ++j) {
                    variablePool.addAll(randomVariables);
                }
                List<LabeledImage> trainingData = new ArrayList<LabeledImage>(this.labeledImages);
                return this.buildTree(trainingData, variablePool, numVars, preComputedHashesTestAgainst);
            });
            tasks.add(task);
        }

        List<TreeNode> forestCandidate = new ArrayList<TreeNode>();
        for (Future<TreeNode> task : tasks) {
            try {
                forestCandidate.add(task.get());
            } catch (InterruptedException e) {
                // Preserve the interrupt for the caller and stop the remaining work;
                // any further get() would fail immediately anyway.
                Thread.currentThread().interrupt();
                for (Future<TreeNode> pending : tasks) {
                    pending.cancel(true);
                }
                break;
            } catch (ExecutionException e) {
                e.printStackTrace();
            }
        }
        return new Object[]{forestCandidate, 0.0, 0.0};
    }

    /**
     * Draws a bootstrap sample (same size, with replacement) from the given
     * data. Currently not called from within this class.
     *
     * @param testData the data to sample from
     * @param rng      source of randomness
     * @return a new list holding {@code testData.size()} randomly drawn entries
     */
    private List<TestData> bootstrapDataset(List<TestData> testData, Random rng) {
        int size = testData.size();
        List<TestData> bootstrapped = new ArrayList<TestData>(size);
        for (int i = 0; i < size; ++i) {
            bootstrapped.add(testData.get(rng.nextInt(size)));
        }
        return bootstrapped;
    }

    /**
     * Builds a tree for the given data without an initial quality bound.
     *
     * @param testData          images to split
     * @param cHasher           pool of split variables still available
     * @param numVars           number of variables examined per node
     * @param preComputedHashes hashes of the images per algorithm
     * @return the root of the created (sub)tree
     */
    private TreeNode buildTree(List<LabeledImage> testData, CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> cHasher, int numVars, Map<HashingAlgorithm, Map<BufferedImage, Hash>> preComputedHashes) {
        return this.buildTree(testData, cHasher, numVars, preComputedHashes, Double.MAX_VALUE);
    }

    /**
     * Recursively builds a (sub)tree. A child becomes a leaf when its partition
     * is empty or already pure (gini of about zero); otherwise it is split
     * further, using its own gini impurity as the bound the next split must
     * beat.
     *
     * @param testData          images reaching this node
     * @param cHasher           pool of split variables still available; copied,
     *                          the argument itself is not modified
     * @param numVars           number of variables examined per node
     * @param preComputedHashes hashes of the images per algorithm
     * @param threshold         gini impurity a split has to undercut
     * @return the root of the created (sub)tree
     */
    private TreeNode buildTree(List<LabeledImage> testData, CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> cHasher, int numVars, Map<HashingAlgorithm, Map<BufferedImage, Hash>> preComputedHashes, double threshold) {
        // Work on a copy: computeNode removes the variable it picked, and that removal
        // must only be visible to this node's descendants.
        CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> remainingVariables = new CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>>(cHasher);
        Triple<TreeNode, List<LabeledImage>[], int[]> packed = this.computeNode(testData, remainingVariables, numVars, preComputedHashes, threshold);
        TreeNode created = packed.getFirst();
        if (created instanceof LeafNode) {
            return created;
        }
        InnerNode node = (InnerNode)created;
        List<LabeledImage>[] partitions = packed.getSecond();
        int[] dominantCategories = packed.getThird();

        if (partitions[0].size() > 0 && !MathUtil.isDoubleEquals(node.qualityLeft, 0.0, 1.0E-8)) {
            node.leftNode = this.buildTree(partitions[0], remainingVariables, numVars, preComputedHashes, node.qualityLeft);
        } else {
            System.out.println("Create Leaf node");
            node.leftNode = new LeafNode(dominantCategories[0]);
        }
        if (partitions[1].size() > 0 && !MathUtil.isDoubleEquals(node.qualityRight, 0.0, 1.0E-8)) {
            node.rightNode = this.buildTree(partitions[1], remainingVariables, numVars, preComputedHashes, node.qualityRight);
        } else {
            System.out.println("Create Leaf node");
            node.rightNode = new LeafNode(dominantCategories[1]);
        }
        return node;
    }

    /**
     * Finds the best split for the given data among up to {@code numVars}
     * randomly chosen variables.
     *
     * <p>For each chosen variable the distance of every image to the variable's
     * hash is looked up, the midpoints between consecutive distinct distances
     * serve as candidate cutoffs, and the cutoff with the lowest weighted gini
     * impurity wins (ties are broken in favor of the larger left partition).
     * The chosen variable is removed from {@code randomVariables}.
     *
     * @param testData          images reaching this node
     * @param randomVariables   pool of available split variables; the selected
     *                          variable is removed from it
     * @param numVars           maximum number of variables to examine
     * @param preComputedHashes hashes of the images per algorithm
     * @param qualityThreshold  gini impurity the split has to undercut
     * @return a triple of (node, {left, right} partition, {left, right} dominant
     *         category). The node is an {@link InnerNode} when a sufficiently
     *         good split was found; otherwise it is a {@link LeafNode} carrying
     *         the most frequent category of {@code testData} and the third
     *         component is {@code null}
     */
    private Triple<TreeNode, List<LabeledImage>[], int[]> computeNode(List<LabeledImage> testData, CountHashCollection<Pair<FuzzyHash, HashingAlgorithm>> randomVariables, int numVars, Map<HashingAlgorithm, Map<BufferedImage, Hash>> preComputedHashes, double qualityThreshold) {
        double bestCutoff = 0.0;
        Pair<FuzzyHash, HashingAlgorithm> bestVariable = null;
        double bestGini = Double.MAX_VALUE;
        double giniLeft = Double.MAX_VALUE;
        double giniRight = Double.MAX_VALUE;
        int nodeCategoryLeft = -1;
        int nodeCategoryRight = -1;
        int matchSize = 0;
        @SuppressWarnings("unchecked")
        List<LabeledImage>[] propagatedTestData = new List[]{new ArrayList<LabeledImage>(), new ArrayList<LabeledImage>()};

        // Visit the distinct variables in random order.
        PcgRSFast rng = new PcgRSFast();
        int numVarsAvailable = randomVariables.sizeUnique();
        List<Integer> indices = new ArrayList<Integer>(numVarsAvailable);
        for (int i = 0; i < numVarsAvailable; ++i) {
            indices.add(i);
        }
        Collections.shuffle(indices, (Random)rng);
        @SuppressWarnings("unchecked")
        Pair<FuzzyHash, HashingAlgorithm>[] variablesAvailable = (Pair<FuzzyHash, HashingAlgorithm>[])randomVariables.toArrayUnique(new Pair[numVarsAvailable]);

        int variablesToTry = Math.min(numVars, indices.size());
        for (int i = 0; i < variablesToTry; ++i) {
            Pair<FuzzyHash, HashingAlgorithm> randomVariable = variablesAvailable[indices.get(i)];
            FuzzyHash variableHash = randomVariable.getFirst();
            HashingAlgorithm hashAlgorithm = randomVariable.getSecond();
            Map<BufferedImage, Hash> hashes = preComputedHashes.get(hashAlgorithm);

            // Distance of every image to this variable, plus the sorted distinct distances.
            Map<BufferedImage, Double> distanceByImage = new HashMap<BufferedImage, Double>();
            TreeSet<Double> distinctDistances = new TreeSet<Double>();
            for (LabeledImage tData : testData) {
                double distance = variableHash.normalizedHammingDistance(hashes.get(tData.getbImage()));
                distanceByImage.put(tData.getbImage(), distance);
                distinctDistances.add(distance);
            }

            // Candidate cutoffs are the midpoints between neighbouring distances.
            LinkedHashSet<Double> potentialCutoffValues = new LinkedHashSet<Double>();
            Double previousDistance = null;
            for (Double distance : distinctDistances) {
                if (previousDistance != null) {
                    potentialCutoffValues.add((previousDistance + distance) / 2.0);
                }
                previousDistance = distance;
            }

            for (double cutoff : potentialCutoffValues) {
                List<LabeledImage> leftNode = new ArrayList<LabeledImage>();
                List<LabeledImage> rightNode = new ArrayList<LabeledImage>();
                Map<Integer, Integer> categoryCountLeft = new HashMap<Integer, Integer>();
                Map<Integer, Integer> categoryCountRight = new HashMap<Integer, Integer>();
                for (LabeledImage tData : testData) {
                    if (distanceByImage.get(tData.getbImage()) < cutoff) {
                        leftNode.add(tData);
                        categoryCountLeft.merge(tData.getCategory(), 1, Integer::sum);
                    } else {
                        rightNode.add(tData);
                        categoryCountRight.merge(tData.getCategory(), 1, Integer::sum);
                    }
                }
                int leftCategory = RandomForestCategorizer.dominantCategory(categoryCountLeft);
                int rightCategory = RandomForestCategorizer.dominantCategory(categoryCountRight);

                // Impurity of each side treats it as "dominant category vs. everything else".
                double giniImpurityLeft = RandomForestCategorizer.giniImpurity(RandomForestCategorizer.countCategory(leftNode, leftCategory), leftNode.size());
                double giniImpurityRight = RandomForestCategorizer.giniImpurity(RandomForestCategorizer.countCategory(rightNode, rightCategory), rightNode.size());
                int sum = leftNode.size() + rightNode.size();
                double leftWeight = (double)leftNode.size() / (double)sum;
                double rightWeight = (double)rightNode.size() / (double)sum;
                double giniImpurity = leftWeight * giniImpurityLeft + rightWeight * giniImpurityRight;

                boolean strictlyBetter = giniImpurity < bestGini;
                boolean equalButLargerLeft = giniImpurity == bestGini && leftNode.size() > matchSize;
                if (!strictlyBetter && !equalButLargerLeft) {
                    continue;
                }
                bestCutoff = cutoff;
                bestVariable = randomVariable;
                propagatedTestData[0] = leftNode;
                propagatedTestData[1] = rightNode;
                bestGini = giniImpurity;
                matchSize = leftNode.size();
                nodeCategoryLeft = leftCategory;
                nodeCategoryRight = rightCategory;
                giniLeft = giniImpurityLeft;
                giniRight = giniImpurityRight;
            }
        }

        // The selected variable must not be reused further down this branch.
        if (bestVariable != null) {
            randomVariables.remove(bestVariable);
        }
        if (bestGini < qualityThreshold && !MathUtil.isDoubleEquals(giniLeft, qualityThreshold, 1.0E-8)) {
            InnerNode node = new InnerNode(bestVariable.getFirst(), bestVariable.getSecond(), bestCutoff, bestGini, giniLeft, giniRight);
            int[] dominantCategories = new int[]{nodeCategoryLeft, nodeCategoryRight};
            return new Triple<TreeNode, List<LabeledImage>[], int[]>(node, propagatedTestData, dominantCategories);
        }

        // No worthwhile split: fall back to a leaf voting for the most frequent category.
        Map<Integer, Long> count = testData.stream().collect(Collectors.groupingBy(LabeledImage::getCategory, Collectors.counting()));
        long maxCount = Long.MIN_VALUE;
        int bestCat = -1;
        for (Map.Entry<Integer, Long> entry : count.entrySet()) {
            if (entry.getValue() <= maxCount) {
                continue;
            }
            maxCount = entry.getValue();
            bestCat = entry.getKey();
        }
        return new Triple<TreeNode, List<LabeledImage>[], int[]>(new LeafNode(bestCat), propagatedTestData, null);
    }

    /**
     * Returns the category with the highest count; on ties the category
     * encountered first while iterating the map wins.
     *
     * @param categoryCounts occurrences per category
     * @return the dominant category, or {@code -1} for an empty map
     */
    private static int dominantCategory(Map<Integer, Integer> categoryCounts) {
        int dominant = -1;
        int bestCount = 0;
        for (Map.Entry<Integer, Integer> entry : categoryCounts.entrySet()) {
            int occurrences = entry.getValue();
            // -1 doubles as the "nothing chosen yet" marker.
            if (dominant == -1 || occurrences > bestCount) {
                dominant = entry.getKey();
                bestCount = occurrences;
            }
        }
        return dominant;
    }

    /**
     * Counts the images carrying the given category.
     *
     * @param images   images to inspect
     * @param category category to look for
     * @return number of images whose category equals {@code category}
     */
    private static int countCategory(List<LabeledImage> images, int category) {
        int matches = 0;
        for (LabeledImage image : images) {
            if (image.getCategory() == category) {
                ++matches;
            }
        }
        return matches;
    }

    /**
     * Two-class gini impurity {@code 1 - p^2 - (1 - p)^2} with
     * {@code p = matching / total}. Yields {@code NaN} when {@code total} is 0.
     *
     * @param matching number of elements belonging to the dominant class
     * @param total    number of elements overall
     * @return the gini impurity
     */
    private static double giniImpurity(int matching, int total) {
        return 1.0 - Math.pow((double)matching / (double)total, 2.0) - Math.pow((double)(total - matching) / (double)total, 2.0);
    }

    /**
     * Generates the random split variables: for every algorithm in the
     * inherited {@code steps} collection, one random hash per distinct training
     * category.
     *
     * @param preComputedHashes hashes of the training images per algorithm;
     *                          currently unused
     * @return the generated (hash, algorithm) pairs
     */
    private List<Pair<FuzzyHash, HashingAlgorithm>> createFuzzyHashes(Map<HashingAlgorithm, Map<BufferedImage, Hash>> preComputedHashes) {
        List<Pair<FuzzyHash, HashingAlgorithm>> variableList = new ArrayList<Pair<FuzzyHash, HashingAlgorithm>>();
        HashSet<Integer> cat = new HashSet<Integer>();
        for (LabeledImage lImage : this.labeledImages) {
            cat.add(lImage.getCategory());
        }
        for (HashingAlgorithm algo : this.steps) {
            int keyResolution = algo.getKeyResolution();
            PcgRSFast rng = new PcgRSFast();
            for (int i = 0; i < cat.size(); ++i) {
                BigInteger bInt = new BigInteger(keyResolution, (Random)rng);
                FuzzyHash randomHash = new FuzzyHash(new Hash(bInt, keyResolution, algo.algorithmId()));
                variableList.add(new Pair<FuzzyHash, HashingAlgorithm>(randomHash, algo));
            }
        }
        return variableList;
    }

    /**
     * Counts, for the first tree of the forest, how many leaves vote for each
     * category.
     *
     * @return leaf count per category; empty when no tree has been trained yet
     */
    public Map<Integer, Integer> countLeafCategories() {
        Map<Integer, Integer> categoryTreeCount = new HashMap<Integer, Integer>();
        if (this.forest.isEmpty()) {
            return categoryTreeCount;
        }
        ArrayDeque<TreeNode> queue = new ArrayDeque<TreeNode>();
        queue.add(this.forest.get(0));
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            if (node instanceof InnerNode) {
                InnerNode inner = (InnerNode)node;
                queue.add(inner.rightNode);
                queue.add(inner.leftNode);
            } else {
                categoryTreeCount.merge(((LeafNode)node).category, 1, Integer::sum);
            }
        }
        return categoryTreeCount;
    }

    /**
     * Small demo: trains a forest on a handful of test resources and
     * categorizes one of them.
     *
     * @param args ignored
     * @throws IOException if a test image cannot be read
     */
    public static void main(String[] args) throws IOException {
        RandomForestCategorizer randomForest = new RandomForestCategorizer();
        randomForest.addHashingAlgorithm(new AverageHash(32));
        randomForest.addHashingAlgorithm(new PerceptiveHash(32));
        randomForest.addHashingAlgorithm(new RotAverageHash(32));
        randomForest.addTestImages(new LabeledImage(0, new File("src/test/resources/ballon.jpg")));
        randomForest.addTestImages(new LabeledImage(1, new File("src/test/resources/copyright.jpg")));
        randomForest.addTestImages(new LabeledImage(1, new File("src/test/resources/highQuality.jpg")));
        randomForest.addTestImages(new LabeledImage(1, new File("src/test/resources/lowQuality.jpg")));
        randomForest.addTestImages(new LabeledImage(2, new File("src/test/resources/Lenna.png")));
        randomForest.addTestImages(new LabeledImage(2, new File("src/test/resources/Lenna90.png")));
        randomForest.addTestImages(new LabeledImage(2, new File("src/test/resources/Lenna180.png")));
        randomForest.addTestImages(new LabeledImage(2, new File("src/test/resources/LennaSaltAndPepper.png")));
        randomForest.addTestImages(new LabeledImage(3, new File("src/test/resources/TestShapes.png")));
        randomForest.trainMatcher(3, 2, 1);
        randomForest.forest.get(0).printTree();
        BufferedImage bi = ImageIO.read(new File("src/test/resources/lowQuality.jpg"));
        System.out.println(randomForest.categorizeImage(bi));
    }

    /**
     * Categorizes the image by letting every tree vote and returning the
     * category with the most votes together with the fraction of trees that
     * agreed on it.
     *
     * @param bi the image to categorize
     * @return the winning category and the share of trees voting for it
     */
    @Override
    public CategorizationResult categorizeImage(BufferedImage bi) {
        List<Integer> knownCategories = this.getCategories();
        int[] catCount = new int[knownCategories.size()];
        for (TreeNode root : this.forest) {
            // NOTE(review): element 0 of the prediction is taken to be the category
            // stored in the reached leaf - confirm against TreeNode.
            int cat = root.predictAgainstAll(bi)[0];
            if (cat == -1) {
                continue;
            }
            // Votes are indexed by the category's position in the sorted category list,
            // not by its value, so arbitrary (sparse or negative) labels work.
            int slot = knownCategories.indexOf(cat);
            if (slot < 0) {
                continue;
            }
            catCount[slot] = catCount[slot] + 1;
        }
        int maxIndex = ArrayUtil.maximumIndex(catCount);
        double agree = (double)catCount[maxIndex] / (double)this.forest.size();
        int bestFitCategory = knownCategories.get(maxIndex);
        return new CategorizationResult(bestFitCategory, agree);
    }

    /**
     * @return a new list holding all known categories in ascending order
     */
    @Override
    public List<Integer> getCategories() {
        return new ArrayList<Integer>(this.categories);
    }

    /** Not implemented for this matcher; does nothing. */
    @Override
    public void recomputeCategories() {
    }

    /**
     * Not implemented for this matcher.
     *
     * @param category ignored
     * @return always {@code null}
     */
    @Override
    public List<String> getImagesInCategory(int category) {
        return null;
    }

    /**
     * Not implemented for this matcher.
     *
     * @param uniqueId ignored
     * @return always {@code 0}
     */
    @Override
    public int getCategory(String uniqueId) {
        return 0;
    }

    /**
     * Prints the first tree of the forest.
     *
     * @throws IndexOutOfBoundsException if no tree has been trained yet
     */
    public void printTree() {
        this.forest.get(0).printTree();
    }

    /**
     * Unsupported: adding an image would require rebuilding the forest.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public CategorizationResult categorizeImageAndAdd(BufferedImage bi, String uniqueId) {
        throw new UnsupportedOperationException("Can't add images on the fly. Rebuilding time to expensive");
    }
}

