package org.apache.druid.server.coordinator.rules;

import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.client.DruidServer;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.server.coordination.ServerType;
import org.apache.druid.server.coordinator.CoordinatorDynamicConfig;
import org.apache.druid.server.coordinator.CreateDataSegments;
import org.apache.druid.server.coordinator.DruidCluster;
import org.apache.druid.server.coordinator.DruidCoordinatorRuntimeParams;
import org.apache.druid.server.coordinator.ServerHolder;
import org.apache.druid.server.coordinator.balancer.BalancerStrategy;
import org.apache.druid.server.coordinator.balancer.CachingCostBalancerStrategy;
import org.apache.druid.server.coordinator.balancer.ClusterCostCache;
import org.apache.druid.server.coordinator.balancer.CostBalancerStrategy;
import org.apache.druid.server.coordinator.loading.SegmentLoadQueueManager;
import org.apache.druid.server.coordinator.loading.TestLoadQueuePeon;
import org.apache.druid.server.coordinator.stats.CoordinatorRunStats;
import org.apache.druid.server.coordinator.stats.Stats;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.NoneShardSpec;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(Parameterized.class)
/**
 * Tests for {@link LoadRule} (exercised via {@link ForeverLoadRule}): segment assignment, replica
 * drops, invalid tiers, load-queue limits and decommissioning servers.
 *
 * <p>Every test runs twice, once with round-robin segment assignment enabled and once with it
 * disabled (see {@link #getTestParams()}).
 */
public class LoadRuleTest {
    private static final String DS_WIKI = "wiki";
    private ListeningExecutorService exec;
    private BalancerStrategy balancerStrategy;
    private SegmentLoadQueueManager loadQueueManager;
    private final boolean useRoundRobinAssignment;

    /** Source of unique suffixes for generated historical server names. */
    private final AtomicInteger serverId = new AtomicInteger();

    /** Names of the historical tiers used throughout these tests. */
    private static final class Tier {
        static final String T1 = "tier1";
        static final String T2 = "tier2";

        private Tier() {
        }
    }

    /** Runs every test with round-robin assignment enabled ({@code true}) and disabled ({@code false}). */
    @Parameterized.Parameters(name = "useRoundRobin = {0}")
    public static List<Boolean> getTestParams() {
        return Arrays.asList(true, false);
    }

    public LoadRuleTest(boolean useRoundRobinAssignment) {
        this.useRoundRobinAssignment = useRoundRobinAssignment;
    }

    @Before
    public void setUp() {
        this.exec = MoreExecutors.listeningDecorator(Execs.multiThreaded(1, "LoadRuleTest-%d"));
        this.balancerStrategy = new CostBalancerStrategy(this.exec);
        this.loadQueueManager = new SegmentLoadQueueManager(null, null);
    }

    @After
    public void tearDown() {
        this.exec.shutdown();
    }

    @Test
    public void testLoadRuleAssignsSegments() {
        final DataSegment segment = createDataSegment(DS_WIKI);
        final DruidCluster cluster = DruidCluster.builder()
                .addTier(Tier.T1, createServer(Tier.T1))
                .addTier(Tier.T2, createServer(Tier.T2))
                .build();

        final CoordinatorRunStats stats = runRuleAndGetStats(
                loadForever(ImmutableMap.of(Tier.T1, 1, Tier.T2, 2)),
                segment,
                cluster
        );

        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, DS_WIKI));
        // T2 asks for 2 replicas but has only one server, so only one replica can be assigned there.
        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T2, DS_WIKI));
    }

    /**
     * Runs {@code rule} for {@code segment} against a fresh set of runtime params built for
     * {@code cluster} (with {@code segment} as the only used segment) and returns the run stats.
     */
    private CoordinatorRunStats runRuleAndGetStats(LoadRule rule, DataSegment segment, DruidCluster cluster) {
        return runRuleAndGetStats(rule, segment, makeCoordinatorRuntimeParams(cluster, segment));
    }

    /**
     * Runs {@code rule} for {@code segment} using the segment assigner of {@code params} and returns
     * the coordinator stats held by {@code params}. Reusing the same params across calls therefore
     * reuses the same stats object.
     */
    private CoordinatorRunStats runRuleAndGetStats(
            LoadRule rule,
            DataSegment segment,
            DruidCoordinatorRuntimeParams params
    ) {
        rule.run(segment, params.getSegmentAssigner());
        return params.getCoordinatorStats();
    }

    /**
     * Builds runtime params for {@code cluster} with the given used segments, the current
     * {@link #balancerStrategy}, smart segment loading disabled and round-robin assignment set
     * from the test parameter.
     */
    private DruidCoordinatorRuntimeParams makeCoordinatorRuntimeParams(
            DruidCluster cluster,
            DataSegment... usedSegments
    ) {
        return DruidCoordinatorRuntimeParams.newBuilder(DateTimes.nowUtc())
                .withDruidCluster(cluster)
                .withBalancerStrategy(this.balancerStrategy)
                .withUsedSegments(usedSegments)
                .withDynamicConfigs(
                        CoordinatorDynamicConfig.builder()
                                .withSmartSegmentLoading(false)
                                .withUseRoundRobinSegmentAssignment(this.useRoundRobinAssignment)
                                .build()
                )
                .withSegmentAssignerUsing(this.loadQueueManager)
                .build();
    }

    @Test
    public void testLoadPrimaryAssignDoesNotOverAssign() {
        final ServerHolder server1 = createServer(Tier.T1);
        final ServerHolder server2 = createServer(Tier.T1);
        final DruidCluster cluster = DruidCluster.builder().addTier(Tier.T1, server1, server2).build();

        final LoadRule rule = loadForever(ImmutableMap.of(Tier.T1, 1));
        final DataSegment segment = createDataSegment(DS_WIKI);

        final CoordinatorRunStats firstRunStats = runRuleAndGetStats(rule, segment, cluster);
        Assert.assertEquals(
                1L,
                firstRunStats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, segment.getDataSource())
        );
        Assert.assertEquals(1L, server1.getLoadingSegments().size() + server2.getLoadingSegments().size());

        // A second run must not queue another load: the single required replica is already loading.
        final CoordinatorRunStats secondRunStats = runRuleAndGetStats(rule, segment, cluster);
        Assert.assertFalse(secondRunStats.hasStat(Stats.Segments.ASSIGNED));
        Assert.assertEquals(1L, server1.getLoadingSegments().size() + server2.getLoadingSegments().size());
    }

    @Test
    @Ignore("Enable this test when timeout behaviour is fixed")
    public void testOverAssignForTimedOutSegments() {
        final DruidCluster cluster = DruidCluster.builder()
                .addTier(Tier.T1, createServer(Tier.T1), createServer(Tier.T1))
                .build();
        final LoadRule rule = loadForever(ImmutableMap.of(Tier.T1, 1));
        final DataSegment segment = createDataSegment(DS_WIKI);

        final CoordinatorRunStats stats = runRuleAndGetStats(rule, segment, cluster);
        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, segment.getDataSource()));

        final CoordinatorRunStats statsAfterRerun = runRuleAndGetStats(rule, segment, cluster);
        Assert.assertEquals(1L, statsAfterRerun.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, DS_WIKI));
    }

    @Test
    public void testSkipReplicationForTimedOutSegments() {
        final DruidCluster cluster = DruidCluster.builder()
                .addTier(Tier.T1, createServer(Tier.T1), createServer(Tier.T1))
                .build();
        final LoadRule rule = loadForever(ImmutableMap.of(Tier.T1, 1));
        final DataSegment segment = createDataSegment(DS_WIKI);

        final CoordinatorRunStats stats = runRuleAndGetStats(rule, segment, cluster);
        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, segment.getDataSource()));

        final CoordinatorRunStats statsAfterRerun = runRuleAndGetStats(rule, segment, cluster);
        Assert.assertFalse(statsAfterRerun.hasStat(Stats.Segments.ASSIGNED));
    }

    @Test
    public void testLoadUsedSegmentsForAllSegmentGranularityAndCachingCostBalancerStrategy() {
        final List<DataSegment> segments = CreateDataSegments.ofDatasource(DS_WIKI)
                .forIntervals(1, Granularities.ALL)
                .withNumPartitions(2)
                .eachOfSizeInMb(100L);
        final DruidCluster cluster = DruidCluster.builder().addTier(Tier.T1, createServer(Tier.T1)).build();

        // Must be replaced before the runtime params are built, since they capture the strategy.
        this.balancerStrategy = new CachingCostBalancerStrategy(ClusterCostCache.builder().build(), this.exec);

        final CoordinatorRunStats stats = runRuleAndGetStats(
                loadForever(ImmutableMap.of(Tier.T1, 1)),
                segments.get(1),
                makeCoordinatorRuntimeParams(cluster, segments.toArray(new DataSegment[0]))
        );
        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, DS_WIKI));
    }

    @Test
    public void testSegmentsAreDroppedIfLoadRuleHasZeroReplicas() {
        final DataSegment segment = createDataSegment(DS_WIKI);
        final DruidCluster cluster = DruidCluster.builder()
                .addTier(Tier.T1, createServer(Tier.T1, segment))
                .addTier(Tier.T2, createServer(Tier.T2, segment), createServer(Tier.T2, segment))
                .build();

        final CoordinatorRunStats stats = runRuleAndGetStats(
                loadForever(ImmutableMap.of(Tier.T1, 0, Tier.T2, 0)),
                segment,
                cluster
        );

        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.DROPPED, Tier.T1, DS_WIKI));
        Assert.assertEquals(2L, stats.getSegmentStat(Stats.Segments.DROPPED, Tier.T2, DS_WIKI));
    }

    @Test
    public void testLoadIgnoresInvalidTiers() {
        final DataSegment segment = createDataSegment(DS_WIKI);
        final DruidCluster cluster = DruidCluster.builder().addTier(Tier.T1, createServer(Tier.T1)).build();

        final CoordinatorRunStats stats = runRuleAndGetStats(
                loadForever(ImmutableMap.of("invalidTier", 1, Tier.T1, 1)),
                segment,
                cluster
        );

        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, DS_WIKI));
        Assert.assertEquals(0L, stats.getSegmentStat(Stats.Segments.ASSIGNED, "invalidTier", DS_WIKI));
    }

    @Test
    public void testDropIgnoresInvalidTiers() {
        final DataSegment segment = createDataSegment(DS_WIKI);
        final DruidCluster cluster = DruidCluster.builder()
                .addTier(Tier.T1, createServer(Tier.T1, segment), createServer(Tier.T1, segment))
                .build();

        final CoordinatorRunStats stats = runRuleAndGetStats(
                loadForever(ImmutableMap.of("invalidTier", 1, Tier.T1, 1)),
                segment,
                cluster
        );

        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.DROPPED, Tier.T1, DS_WIKI));
        Assert.assertEquals(0L, stats.getSegmentStat(Stats.Segments.DROPPED, "invalidTier", DS_WIKI));
    }

    @Test
    public void testMaxLoadingQueueSize() {
        // NOTE(review): the trailing (false, 2, 10) presumably mean "not decommissioning", a
        // load-queue limit of 2 and a max load-queue lifetime of 10 -- confirm against ServerHolder.
        final ServerHolder server = new ServerHolder(
                createDruidServer(Tier.T1).toImmutableDruidServer(),
                new TestLoadQueuePeon(),
                false,
                2,
                10
        );
        final DruidCluster cluster = DruidCluster.builder().addTier(Tier.T1, server).build();

        final DataSegment segment1 = createDataSegment("ds1");
        final DataSegment segment2 = createDataSegment("ds2");
        final DataSegment segment3 = createDataSegment("ds3");

        // One shared set of params, so all three runs see the same load queue and limit of 2.
        final DruidCoordinatorRuntimeParams params = DruidCoordinatorRuntimeParams.newBuilder(DateTimes.nowUtc())
                .withDruidCluster(cluster)
                .withBalancerStrategy(this.balancerStrategy)
                .withUsedSegments(segment1, segment2, segment3)
                .withDynamicConfigs(
                        CoordinatorDynamicConfig.builder()
                                .withSmartSegmentLoading(false)
                                .withMaxSegmentsInNodeLoadingQueue(2)
                                .withUseRoundRobinSegmentAssignment(this.useRoundRobinAssignment)
                                .build()
                )
                .withSegmentAssignerUsing(this.loadQueueManager)
                .build();

        final LoadRule rule = loadForever(ImmutableMap.of(Tier.T1, 1));
        final CoordinatorRunStats stats1 = runRuleAndGetStats(rule, segment1, params);
        final CoordinatorRunStats stats2 = runRuleAndGetStats(rule, segment2, params);
        final CoordinatorRunStats stats3 = runRuleAndGetStats(rule, segment3, params);

        Assert.assertEquals(1L, stats1.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, segment1.getDataSource()));
        Assert.assertEquals(1L, stats2.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, segment2.getDataSource()));
        // The third segment does not fit in the load queue.
        Assert.assertEquals(0L, stats3.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, segment3.getDataSource()));
    }

    @Test
    public void testSegmentIsAssignedOnlyToActiveServer() {
        final ServerHolder decommissioningServer = createDecommissioningServer(Tier.T1);
        final ServerHolder activeServer = createServer(Tier.T2);
        final DruidCluster cluster = DruidCluster.builder()
                .addTier(Tier.T1, decommissioningServer)
                .addTier(Tier.T2, activeServer)
                .build();

        final LoadRule rule = loadForever(ImmutableMap.of(Tier.T1, 1, Tier.T2, 1));
        final DataSegment segment = createDataSegment(DS_WIKI);

        final CoordinatorRunStats stats = runRuleAndGetStats(rule, segment, cluster);
        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T2, DS_WIKI));
        Assert.assertEquals(0L, decommissioningServer.getLoadingSegments().size());
        Assert.assertTrue(activeServer.getLoadingSegments().contains(segment));
    }

    @Test
    public void testSegmentIsAssignedOnlyToActiveServers() {
        final ServerHolder decommissioningServer = createDecommissioningServer(Tier.T1);
        final DataSegment segment = createDataSegment(DS_WIKI);
        final DruidCluster cluster = DruidCluster.builder()
                .addTier(Tier.T1, decommissioningServer, createServer(Tier.T1))
                .addTier(Tier.T2, createServer(Tier.T2), createServer(Tier.T2))
                .build();

        final CoordinatorRunStats stats = runRuleAndGetStats(
                loadForever(ImmutableMap.of(Tier.T1, 2, Tier.T2, 2)),
                segment,
                cluster
        );

        // T1 wants 2 replicas but only its single active server receives one.
        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T1, DS_WIKI));
        Assert.assertTrue(decommissioningServer.getLoadingSegments().isEmpty());
        Assert.assertEquals(0L, decommissioningServer.getLoadingSegments().size());
        Assert.assertEquals(2L, stats.getSegmentStat(Stats.Segments.ASSIGNED, Tier.T2, DS_WIKI));
    }

    @Test
    public void testDropDuringDecommissioning() {
        final DataSegment segment1 = createDataSegment("foo1");
        final DataSegment segment2 = createDataSegment(CalciteTests.DATASOURCE2);
        final ServerHolder decommissioningServer = createDecommissioningServer(Tier.T1, segment1);
        final ServerHolder activeServer = createServer(Tier.T1, segment2);

        final DruidCoordinatorRuntimeParams params = makeCoordinatorRuntimeParams(
                DruidCluster.builder().addTier(Tier.T1, decommissioningServer, activeServer).build(),
                segment1,
                segment2
        );
        final LoadRule rule = loadForever(ImmutableMap.of(Tier.T1, 0));

        // With zero required replicas the segment is dropped from the decommissioning server ...
        final CoordinatorRunStats stats1 = runRuleAndGetStats(rule, segment1, params);
        Assert.assertEquals(1L, stats1.getSegmentStat(Stats.Segments.DROPPED, Tier.T1, segment1.getDataSource()));
        Assert.assertTrue(decommissioningServer.getPeon().getSegmentsToDrop().contains(segment1));

        // ... as well as from the active one.
        final CoordinatorRunStats stats2 = runRuleAndGetStats(rule, segment2, params);
        Assert.assertEquals(1L, stats2.getSegmentStat(Stats.Segments.DROPPED, Tier.T1, segment2.getDataSource()));
        Assert.assertTrue(activeServer.getPeon().getSegmentsToDrop().contains(segment2));
    }

    @Test
    public void testExtraReplicasAreDroppedFromDecommissioningServer() {
        final DataSegment segment = createDataSegment(DS_WIKI);

        // Three servers in the same tier, all serving the segment; the middle one is decommissioning.
        final ServerHolder server1 = createServer(Tier.T1, segment);
        final ServerHolder server2 = createDecommissioningServer(Tier.T1, segment);
        final ServerHolder server3 = createServer(Tier.T1, segment);
        final DruidCluster cluster = DruidCluster.builder().addTier(Tier.T1, server1, server2, server3).build();

        // The rule needs only 2 replicas, so exactly one must be dropped.
        final CoordinatorRunStats stats = runRuleAndGetStats(
                loadForever(ImmutableMap.of(Tier.T1, 2)),
                segment,
                makeCoordinatorRuntimeParams(cluster, segment)
        );

        Assert.assertEquals(1L, stats.getSegmentStat(Stats.Segments.DROPPED, Tier.T1, DS_WIKI));
        // The extra replica is dropped from the decommissioning server, not from the active ones.
        Assert.assertEquals(0L, server1.getPeon().getSegmentsToDrop().size());
        Assert.assertEquals(1L, server2.getPeon().getSegmentsToDrop().size());
        Assert.assertEquals(0L, server3.getPeon().getSegmentsToDrop().size());
    }

    /**
     * Creates a segment of {@code dataSource} over interval {@code 0/3000}, versioned with the
     * current time, with a {@link NoneShardSpec}, no dimensions/metrics and size 0.
     */
    private DataSegment createDataSegment(String dataSource) {
        return new DataSegment(
                dataSource,
                Intervals.of("0/3000"),
                DateTimes.nowUtc().toString(),
                Collections.emptyMap(),
                Collections.emptyList(),
                Collections.emptyList(),
                NoneShardSpec.instance(),
                0,
                0L
        );
    }

    /**
     * Creates a {@link ForeverLoadRule} from a map of tier name to required replica count.
     * NOTE(review): the second constructor argument is passed as {@code null}; confirm what it controls.
     */
    private static LoadRule loadForever(Map<String, Integer> tieredReplicants) {
        return new ForeverLoadRule(tieredReplicants, null);
    }

    /**
     * Creates a historical in {@code tier} named {@code hist_<tier>_<n>}, with a unique {@code n}
     * for each call on this test instance, and a max size of 10737418240 (10 GiB, presumably bytes).
     */
    private DruidServer createDruidServer(String tier) {
        final String name = "hist_" + tier + "_" + this.serverId.incrementAndGet();
        return new DruidServer(name, name, null, 10737418240L, ServerType.HISTORICAL, tier, 0);
    }

    /**
     * Creates an active {@link ServerHolder} in {@code tier} serving {@code segments}, backed by a
     * {@link TestLoadQueuePeon}.
     */
    private ServerHolder createServer(String tier, DataSegment... segments) {
        final DruidServer server = createDruidServer(tier);
        for (DataSegment segment : segments) {
            server.addDataSegment(segment);
        }
        return new ServerHolder(server.toImmutableDruidServer(), new TestLoadQueuePeon());
    }

    /**
     * Same as {@link #createServer} but passes {@code true} as the third {@link ServerHolder}
     * argument, which marks the holder as decommissioning.
     */
    private ServerHolder createDecommissioningServer(String tier, DataSegment... segments) {
        final DruidServer server = createDruidServer(tier);
        for (DataSegment segment : segments) {
            server.addDataSegment(segment);
        }
        return new ServerHolder(server.toImmutableDruidServer(), new TestLoadQueuePeon(), true);
    }

    @Test
    public void testEquals() {
        EqualsVerifier.forClass(LoadRule.class)
                .withNonnullFields("tieredReplicants")
                .usingGetClass()
                .verify();
    }
}
