package org.apache.flink.table.runtime.hashtable;

import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.compression.BlockCompressionFactory;
import org.apache.flink.runtime.io.disk.ChannelReaderInputViewIterator;
import org.apache.flink.runtime.io.disk.iomanager.HeaderlessChannelReaderInputView;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.runtime.hashtable.LongHashPartition;
import org.apache.flink.table.runtime.io.ChannelWithMeta;
import org.apache.flink.table.runtime.io.LongHashPartitionChannelReaderInputViewIterator;
import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer;
import org.apache.flink.table.runtime.util.FileChannelUtil;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.util.MathUtils;

/* loaded from: input_file:org/apache/flink/table/runtime/hashtable/LongHybridHashTable.class */
/**
 * Hybrid hash table for equi-joins whose join key is a single {@code long}.
 *
 * <p>Build rows are hash-partitioned into {@link LongHashPartition}s. Partitions that do not fit in
 * memory are spilled; after the in-memory probe pass they are rebuilt and probed again at a higher
 * recursion level. Spilled partitions whose next recursion level would exceed
 * {@link #MAX_RECURSION_LEVEL} are collected in {@link #getPartitionsPendingForSMJ()} instead.
 *
 * <p>When nothing spilled and the build keys are dense enough, the table switches to "dense
 * mode": probe keys are looked up directly through an array of 8-byte addresses indexed by
 * {@code key - minKey}, instead of hashing.
 *
 * <p>Subclasses supply key extraction ({@link #getBuildLongKey}, {@link #getProbeLongKey}) and the
 * probe-row conversion used when buffering probe rows for spilled partitions
 * ({@link #probeToBinary}).
 */
public abstract class LongHybridHashTable extends BaseHybridHashTable {

    /**
     * Address (2^36 - 1) written into every slot of a newly allocated dense bucket and used for
     * probe keys outside {@code [minKey, maxKey]}; it marks "no build row". Presumably this matches
     * the missing-address sentinel used inside {@link LongHashPartition} (TODO confirm).
     */
    private static final long INVALID_ADDRESS = 0xF_FFFF_FFFFL;

    /** Spilled partitions whose next recursion level exceeds this are deferred to sort-merge join. */
    private static final int MAX_RECURSION_LEVEL = 3;

    /** Upper bound on the fan-out used when re-partitioning a spilled partition. */
    private static final int MAX_SPILLED_PARTITION_FAN_OUT = 127;

    private final BinaryRowDataSerializer buildSideSerializer;
    private final BinaryRowDataSerializer probeSideSerializer;

    /** Partitions of the current build/probe pass. */
    private final ArrayList<LongHashPartition> partitionsBeingBuilt;

    /** Spilled partitions still waiting to be rebuilt and probed. */
    private final ArrayList<LongHashPartition> partitionsPending;

    /** Spilled partitions handed over for sort-merge join. */
    private final List<LongHashPartition> partitionsPendingForSMJ;

    private ProbeIterator probeIterator;
    private LongHashPartition.MatchIterator matchIterator;

    /** Set by {@link #tryDenseMode()} once direct key addressing is in use. */
    private boolean denseMode;

    /** Smallest / largest build key; only assigned when switching to dense mode. */
    private long minKey;
    private long maxKey;

    /** One 8-byte address per key in {@code [minKey, maxKey]}, spread across these segments. */
    private MemorySegment[] denseBuckets;

    /** Owns all build-side record buffers once in dense mode. */
    private LongHashPartition densePartition;

    /**
     * Creates the table and its initial build partitions.
     *
     * @param configuration forwarded to {@link BaseHybridHashTable}
     * @param obj forwarded unchanged to {@link BaseHybridHashTable}
     * @param buildSideSerializer serializer for build-side rows
     * @param probeSideSerializer serializer for probe-side rows
     * @param memoryManager forwarded to {@link BaseHybridHashTable}
     * @param j forwarded unchanged to {@link BaseHybridHashTable} (presumably the reserved memory
     *     size — TODO confirm)
     * @param iOManager forwarded to {@link BaseHybridHashTable}
     * @param i forwarded unchanged to {@link BaseHybridHashTable} (presumably the average record
     *     length — TODO confirm)
     * @param j2 forwarded unchanged to {@link BaseHybridHashTable} (presumably the expected build
     *     row count — TODO confirm)
     */
    public LongHybridHashTable(
            Configuration configuration,
            Object obj,
            BinaryRowDataSerializer buildSideSerializer,
            BinaryRowDataSerializer probeSideSerializer,
            MemoryManager memoryManager,
            long j,
            IOManager iOManager,
            int i,
            long j2) {
        super(configuration, obj, memoryManager, j, iOManager, i, j2, false);
        this.buildSideSerializer = buildSideSerializer;
        this.probeSideSerializer = probeSideSerializer;
        this.partitionsBeingBuilt = new ArrayList<>();
        this.partitionsPending = new ArrayList<>();
        this.partitionsPendingForSMJ = new ArrayList<>();
        createPartitions(this.initPartitionFanOut, 0);
    }

    /**
     * Inserts one build-side row into the partition selected by the hash of its long key.
     *
     * @param binaryRowData build-side row
     * @throws IOException propagated from the partition insert
     */
    public void putBuildRow(BinaryRowData binaryRowData) throws IOException {
        long key = getBuildLongKey(binaryRowData);
        insertIntoTable(key, hashLong(key, 0), binaryRowData);
    }

    /**
     * Ends the build phase: finalizes every build partition, creates the probe iterator and
     * switches to dense mode when possible.
     */
    public void endBuild() throws IOException {
        finalizeBuildPhaseOfAllPartitions();
        this.probeIterator = new ProbeIterator(this.probeSideSerializer.createInstance2());
        tryDenseMode();
    }

    /**
     * Probes the table with one probe-side row.
     *
     * @param rowData probe-side row
     * @return {@code true} if the matching build rows are now available from
     *     {@link #getBuildSideIterator()}; {@code false} if the row's partition has been spilled,
     *     in which case the row was buffered for that partition and is handled later through
     *     {@link #nextMatching()}
     */
    public boolean tryProbe(RowData rowData) throws IOException {
        long probeKey = getProbeLongKey(rowData);
        if (this.denseMode) {
            this.probeIterator.setInstance(rowData);
            this.matchIterator = this.densePartition.valueIter(lookupDenseAddress(probeKey));
            return true;
        }

        if (!this.probeIterator.hasSource()) {
            this.probeIterator.setInstance(rowData);
        }
        int hash = hashLong(probeKey, this.currentRecursionDepth);
        LongHashPartition partition = partitionOfHash(hash);
        if (partition.isInMemory()) {
            this.matchIterator = partition.get(probeKey, hash);
            return true;
        }
        partition.insertIntoProbeBuffer(this.probeSideSerializer, probeToBinary(rowData));
        return false;
    }

    /**
     * Advances to the next probe row whose build partition is in memory. It first drains the
     * current probe input, then moves on to the next pending spilled partition.
     *
     * @return {@code true} if a new match iterator was prepared; always {@code false} in dense mode
     */
    public boolean nextMatching() throws IOException {
        if (this.denseMode) {
            return false;
        }
        return processProbeIter() || prepareNextPartition();
    }

    /** Returns the probe row the current match iterator belongs to. */
    public RowData getCurrentProbeRow() {
        return this.probeIterator.current();
    }

    /** Returns the iterator over build rows matching the current probe row. */
    public LongHashPartition.MatchIterator getBuildSideIterator() {
        return this.matchIterator;
    }

    /**
     * Closes the table. In dense mode only the closed flag is set, because the partition memory
     * has already been moved into the dense structures and is released by {@link #free()}.
     */
    @Override
    public void close() {
        if (this.denseMode) {
            this.closed.compareAndSet(false, true);
        } else {
            super.close();
        }
    }

    /** Returns the dense-mode segments (if any) and then delegates to the base implementation. */
    @Override
    public void free() {
        if (this.denseMode) {
            returnAll(Arrays.asList(this.denseBuckets));
            returnAll(Arrays.asList(this.densePartition.getPartitionBuffers()));
        }
        super.free();
    }

    /**
     * Switches to dense mode if no partition spilled and the build keys are dense. Keys count as
     * dense when the key range is at most four times the record count, or when one segment of
     * 8-byte slots covers the whole range. Returns without switching if the key range overflows or
     * the dense index cannot be allocated.
     */
    private void tryDenseMode() {
        if (this.numSpillFiles != 0) {
            return;
        }

        long buildMinKey = Long.MAX_VALUE;
        long buildMaxKey = Long.MIN_VALUE;
        long recordCount = 0;
        for (LongHashPartition partition : this.partitionsBeingBuilt) {
            long partitionRecords = partition.getBuildSideRecordCount();
            recordCount += partitionRecords;
            if (partitionRecords > 0) {
                buildMinKey = Math.min(buildMinKey, partition.getMinKey());
                buildMaxKey = Math.max(buildMaxKey, partition.getMaxKey());
            }
        }

        if (this.buildSpillRetBufferNumbers != 0) {
            throw new RuntimeException(
                    "buildSpillRetBufferNumbers should be 0: " + this.buildSpillRetBufferNumbers);
        }

        long range = (buildMaxKey - buildMinKey) + 1;
        // A non-positive range means max - min overflowed: the keys span too wide an interval.
        if (range <= 0) {
            return;
        }
        if (range > recordCount * 4 && range > this.segmentSize / 8) {
            return;
        }

        MemorySegment[] buckets = allocateDenseBuckets(range);
        if (buckets == null) {
            return;
        }

        this.denseMode = true;
        LOG.info("LongHybridHashTable: Use dense mode!");
        this.minKey = buildMinKey;
        this.maxKey = buildMaxKey;

        List<MemorySegment> spillReturned = new ArrayList<>();
        this.buildSpillReturnBuffers.drainTo(spillReturned);
        returnAll(spillReturned);

        // Concatenate all partitions' record buffers into one address space. Each partition's
        // addresses are shifted by the byte size of the buffers preceding it.
        List<MemorySegment> recordBuffers = new ArrayList<>();
        long addressOffset = 0;
        for (LongHashPartition partition : this.partitionsBeingBuilt) {
            partition.iteratorToDenseBucket(buckets, addressOffset, buildMinKey);
            partition.updateDenseAddressOffset(addressOffset);
            MemorySegment[] partitionBuffers = partition.getPartitionBuffers();
            recordBuffers.addAll(Arrays.asList(partitionBuffers));
            // Shift as long: an int shift overflows once a partition holds >= 2 GiB.
            addressOffset += ((long) partitionBuffers.length) << this.segmentSizeBits;
            returnAll(Arrays.asList(partition.getBuckets()));
        }

        this.denseBuckets = buckets;
        this.densePartition =
                new LongHashPartition(
                        this, this.buildSideSerializer, recordBuffers.toArray(new MemorySegment[0]));
        freeCurrent();
    }

    /**
     * Allocates enough segments to hold one 8-byte address per key in a range of {@code range}
     * keys, each slot initialized to {@link #INVALID_ADDRESS}.
     *
     * @return the segments, or {@code null} (after returning those already obtained) if the pool
     *     ran out
     */
    private MemorySegment[] allocateDenseBuckets(long range) {
        long rangeBytes = range * 8;
        // Exact ceiling division; the buckets must cover every key in the range.
        int numBuffers = (int) ((rangeBytes + this.segmentSize - 1) / this.segmentSize);
        MemorySegment[] buckets = new MemorySegment[numBuffers];
        for (int idx = 0; idx < numBuffers; idx++) {
            MemorySegment segment = getNextBuffer();
            if (segment == null) {
                returnAll(Arrays.asList(buckets).subList(0, idx));
                return null;
            }
            buckets[idx] = segment;
            for (int offset = 0; offset < this.segmentSize; offset += 8) {
                segment.putLong(offset, INVALID_ADDRESS);
            }
        }
        return buckets;
    }

    /** Reads the dense-mode address for {@code key}, or {@link #INVALID_ADDRESS} if out of range. */
    private long lookupDenseAddress(long key) {
        if (key < this.minKey || key > this.maxKey) {
            return INVALID_ADDRESS;
        }
        long byteOffset = (key - this.minKey) << 3;
        MemorySegment bucket = this.denseBuckets[(int) (byteOffset >>> this.segmentSizeBits)];
        return bucket.getLong((int) (byteOffset & this.segmentSizeMask));
    }

    /** Replaces the current build partitions with {@code numPartitions} fresh ones. */
    private void createPartitions(int numPartitions, int recursionLevel) {
        ensureNumBuffersReturned(numPartitions);
        this.currentEnumerator = this.ioManager.createChannelEnumerator();
        this.partitionsBeingBuilt.clear();
        // Floating-point division: integer division would truncate the per-partition estimate.
        double recordsPerPartition = (double) this.buildRowCount / numPartitions;
        int maxBucketBuffers = maxInitBufferOfBucketArea(numPartitions);
        for (int partNum = 0; partNum < numPartitions; partNum++) {
            this.partitionsBeingBuilt.add(
                    new LongHashPartition(
                            this,
                            partNum,
                            this.buildSideSerializer,
                            recordsPerPartition,
                            maxBucketBuffers,
                            recursionLevel));
        }
    }

    /** Extracts the long join key from a build-side row. */
    public abstract long getBuildLongKey(RowData rowData);

    /** Extracts the long join key from a probe-side row. */
    public abstract long getProbeLongKey(RowData rowData);

    /** Converts a probe-side row into binary form so it can be buffered for a spilled partition. */
    public abstract BinaryRowData probeToBinary(RowData rowData);

    private void insertIntoTable(long key, int hash, BinaryRowData row) throws IOException {
        partitionOfHash(hash).insertIntoTable(key, hash, row);
    }

    /** Maps a non-negative hash to the build partition responsible for it. */
    private LongHashPartition partitionOfHash(int hash) {
        return this.partitionsBeingBuilt.get(hash % this.partitionsBeingBuilt.size());
    }

    /**
     * Finalizes the build phase of every current partition. The buffers they report are added to
     * {@code buildSpillRetBufferNumbers} only after all partitions have been finalized.
     */
    private void finalizeBuildPhaseOfAllPartitions() throws IOException {
        int returnedBuffers = 0;
        for (LongHashPartition partition : this.partitionsBeingBuilt) {
            returnedBuffers += partition.finalizeBuildPhase(this.ioManager, this.currentEnumerator);
        }
        this.buildSpillRetBufferNumbers += returnedBuffers;
    }

    /**
     * Hashes a long key for the given recursion level. The key is multiplied by 2654435769
     * (0x9E3779B9), the high word is XOR-folded into the low word, and the result is passed to
     * {@code BaseHybridHashTable.hash}.
     */
    public static int hashLong(long j, int i) {
        long mixed = j * 2654435769L;
        return BaseHybridHashTable.hash((int) (mixed ^ (mixed >> 32)), i);
    }

    /**
     * Pulls probe rows from the current probe source until one falls into an in-memory partition.
     * Rows for spilled partitions are buffered into those partitions along the way.
     */
    private boolean processProbeIter() throws IOException {
        if (!this.probeIterator.hasSource()) {
            return false;
        }
        ProbeIterator source = this.probeIterator;
        BinaryRowData row;
        while ((row = source.next()) != null) {
            long probeKey = getProbeLongKey(row);
            int hash = hashLong(probeKey, this.currentRecursionDepth);
            LongHashPartition partition = partitionOfHash(hash);
            if (partition.isInMemory()) {
                this.matchIterator = partition.get(probeKey, hash);
                return true;
            }
            partition.insertIntoProbeBuffer(this.probeSideSerializer, row);
        }
        return false;
    }

    /**
     * Finishes the current pass and sets up the next pending spilled partition. The partition's
     * build side is rebuilt and its spilled probe side becomes the probe input. Partitions without
     * probe rows are dropped, and partitions past {@link #MAX_RECURSION_LEVEL} are deferred to SMJ.
     */
    private boolean prepareNextPartition() throws IOException {
        for (LongHashPartition partition : this.partitionsBeingBuilt) {
            partition.finalizeProbePhase(this.partitionsPending);
        }
        this.partitionsBeingBuilt.clear();

        if (this.currentSpilledProbeSide != null) {
            this.currentSpilledProbeSide.getChannel().closeAndDelete();
            this.currentSpilledProbeSide = null;
        }

        while (!this.partitionsPending.isEmpty()) {
            LongHashPartition next = this.partitionsPending.get(0);
            LOG.info(
                    String.format(
                            "Begin to process spilled partition [%d]", next.getPartitionNumber()));

            if (next.probeSideRecordCounter == 0) {
                // No probe rows: nothing can match, skip the partition.
                this.partitionsPending.remove(0);
                continue;
            }

            int nextLevel = next.getRecursionLevel() + 1;
            if (nextLevel == 2) {
                LOG.info("Recursive hash join: partition number is " + next.getPartitionNumber());
            } else if (nextLevel > MAX_RECURSION_LEVEL) {
                LOG.info(
                        "Partition number [{}] recursive level more than {}, process the partition using SortMergeJoin later.",
                        next.getPartitionNumber(),
                        MAX_RECURSION_LEVEL);
                this.partitionsPendingForSMJ.add(next);
                this.partitionsPending.remove(0);
                continue;
            }

            buildTableFromSpilledPartition(next, nextLevel);
            setPartitionProbeReader(next);
            this.partitionsPending.remove(0);
            this.currentRecursionDepth = nextLevel;
            return nextMatching();
        }
        return false;
    }

    /** Opens the spilled probe side of {@code partition} and makes it the probe iterator's source. */
    private void setPartitionProbeReader(LongHashPartition partition) throws IOException {
        this.currentSpilledProbeSide =
                FileChannelUtil.createInputView(
                        this.ioManager,
                        new ChannelWithMeta(
                                partition.probeSideBuffer.getChannel().getChannelID(),
                                partition.probeSideBuffer.getBlockCount(),
                                partition.probeNumBytesInLastSeg),
                        new ArrayList<>(),
                        this.compressionEnable,
                        this.compressionCodecFactory,
                        this.compressionBlockSize,
                        this.segmentSize);
        this.probeIterator.set(
                new ChannelReaderInputViewIterator<>(
                        this.currentSpilledProbeSide, new ArrayList<>(), this.probeSideSerializer));
        this.probeIterator.setReuse(this.probeSideSerializer.createInstance2());
    }

    /**
     * Rebuilds the build side of a spilled partition. If it fits in the free memory it becomes a
     * single in-memory partition; otherwise it is re-partitioned at {@code recursionLevel}.
     *
     * @throws RuntimeException if free pages plus pending spill-return buffers do not add up to
     *     the table's total buffers (a memory leak)
     */
    private void buildTableFromSpilledPartition(LongHashPartition spilled, int recursionLevel)
            throws IOException {
        if (spilled.getBuildSideBlockCount() > spilled.getProbeSideBlockCount()) {
            LOG.info(
                    String.format(
                            "Hash join: Partition(%d) build side block [%d] more than probe side block [%d]",
                            spilled.getPartitionNumber(),
                            spilled.getBuildSideBlockCount(),
                            spilled.getProbeSideBlockCount()));
        }

        int availableBuffers = this.internalPool.freePages() + this.buildSpillRetBufferNumbers;
        if (availableBuffers != this.totalNumBuffers) {
            throw new RuntimeException(
                    String.format(
                            "Hash Join bug in memory management: Memory buffers leaked. availableMemory(%s), buildSpillRetBufferNumbers(%s), reservedNumBuffers(%s)",
                            this.internalPool.freePages(),
                            this.buildSpillRetBufferNumbers,
                            this.totalNumBuffers));
        }

        // Bucket area sized for the partition's records at a 0.5 load factor, 16 bytes per entry.
        int bucketSegments =
                MathUtils.roundUpToPowerOfTwo(
                        (int)
                                Math.max(
                                        1.0d,
                                        Math.ceil(
                                                (Math.ceil(spilled.getBuildSideRecordCount() / 0.5d)
                                                                * 16.0d)
                                                        / this.segmentSize)));
        long neededBuffers = bucketSegments + spilled.getBuildSideBlockCount() + 2;

        if (neededBuffers < availableBuffers) {
            buildInMemoryFromSpilled(spilled, bucketSegments, recursionLevel);
        } else {
            repartitionSpilledBuildSide(spilled, recursionLevel, neededBuffers, availableBuffers);
        }
    }

    /** Reads a spilled build side fully back into memory and indexes it as one partition. */
    private void buildInMemoryFromSpilled(
            LongHashPartition spilled, int bucketSegments, int recursionLevel) throws IOException {
        LOG.info(
                String.format(
                        "Build in memory hash table from spilled partition [%d]",
                        spilled.getPartitionNumber()));
        LongHashPartition rebuilt =
                new LongHashPartition(
                        this,
                        0,
                        this.buildSideSerializer,
                        bucketSegments,
                        recursionLevel,
                        readAllBuffers(
                                spilled.getBuildSideChannel().getChannelID(),
                                spilled.getBuildSideBlockCount()),
                        spilled.getLastSegmentLimit());
        this.partitionsBeingBuilt.add(rebuilt);

        LongHashPartition.PartitionIterator rows = rebuilt.newPartitionIterator();
        while (rows.advanceNext()) {
            long key = getBuildLongKey(rows.getRow());
            rebuilt.insertIntoBucket(
                    key,
                    hashLong(key, recursionLevel),
                    rows.getRow().getSizeInBytes(),
                    (int) rows.getPointer());
        }
    }

    /** Streams a spilled build side into a fresh set of partitions at {@code recursionLevel}. */
    private void repartitionSpilledBuildSide(
            LongHashPartition spilled, int recursionLevel, long neededBuffers, int availableBuffers)
            throws IOException {
        int fanOut =
                Math.min(
                        Math.min(
                                10 * ((int) (neededBuffers / availableBuffers) + 1),
                                MAX_SPILLED_PARTITION_FAN_OUT),
                        maxNumPartition());
        createPartitions(fanOut, recursionLevel);
        LOG.info(
                String.format(
                        "Build hybrid hash table from spilled partition [%d] with recursion level [%d]",
                        spilled.getPartitionNumber(),
                        recursionLevel));

        HeaderlessChannelReaderInputView inView =
                createInputView(
                        spilled.getBuildSideChannel().getChannelID(),
                        spilled.getBuildSideBlockCount(),
                        spilled.getLastSegmentLimit());
        BinaryRowData reuse = this.buildSideSerializer.createInstance2();
        try {
            while (true) {
                LongHashPartition.deserializeFromPages(reuse, inView, this.buildSideSerializer);
                long key = getBuildLongKey(reuse);
                insertIntoTable(key, hashLong(key, recursionLevel), reuse);
            }
        } catch (EOFException ignored) {
            // The loop only ends here: reading past the last spilled page throws EOFException.
        }
        inView.getChannel().closeAndDelete();
        finalizeBuildPhaseOfAllPartitions();
    }

    /**
     * Spills the in-memory partition occupying the most memory segments and returns freed buffers
     * to the pool.
     *
     * @return index of the spilled partition within the current build partitions
     * @throws IllegalStateException if no in-memory partition holds any memory segment
     */
    @Override
    public int spillPartition() throws IOException {
        int largestNumBlocks = 0;
        int largestPartNum = -1;
        for (int idx = 0; idx < this.partitionsBeingBuilt.size(); idx++) {
            LongHashPartition candidate = this.partitionsBeingBuilt.get(idx);
            if (candidate.isInMemory()
                    && candidate.getNumOccupiedMemorySegments() > largestNumBlocks) {
                largestNumBlocks = candidate.getNumOccupiedMemorySegments();
                largestPartNum = idx;
            }
        }
        if (largestPartNum < 0) {
            throw new IllegalStateException(
                    "Cannot spill: no in-memory partition holds any memory segments.");
        }

        LongHashPartition victim = this.partitionsBeingBuilt.get(largestPartNum);
        int freed =
                victim.spillPartition(
                        this.ioManager, this.currentEnumerator.next(), this.buildSpillReturnBuffers);
        victim.releaseBuckets();
        this.buildSpillRetBufferNumbers += freed;
        LOG.info(
                String.format(
                        "Grace hash join: Ran out memory, choosing partition [%d] to spill, %d memory segments being freed",
                        largestPartNum,
                        freed));

        MemorySegment returned;
        while (this.buildSpillRetBufferNumbers > 0
                && (returned = this.buildSpillReturnBuffers.poll()) != null) {
            returnPage(returned);
            this.buildSpillRetBufferNumbers--;
        }

        this.numSpillFiles++;
        // Widen before multiplying: an int product overflows for spills of 2 GiB or more.
        this.spillInBytes += (long) freed * this.segmentSize;
        return largestPartNum;
    }

    /** Returns the spilled partitions deferred to sort-merge join (the live internal list). */
    public List<LongHashPartition> getPartitionsPendingForSMJ() {
        return this.partitionsPendingForSMJ;
    }

    /**
     * Opens the spilled build side of {@code longHashPartition} as a row iterator. Any previously
     * opened spilled build side is closed and deleted first, best-effort.
     */
    public RowIterator getSpilledPartitionBuildSideIter(LongHashPartition longHashPartition)
            throws IOException {
        if (this.currentSpilledBuildSide != null) {
            try {
                this.currentSpilledBuildSide.getChannel().closeAndDelete();
            } catch (Throwable th) {
                LOG.warn(
                        "Could not close and delete the temp file for the current spilled partition build side.",
                        th);
            }
            this.currentSpilledBuildSide = null;
        }
        this.currentSpilledBuildSide =
                createInputView(
                        longHashPartition.getBuildSideChannel().getChannelID(),
                        longHashPartition.getBuildSideBlockCount(),
                        longHashPartition.getLastSegmentLimit());
        return new WrappedRowIterator(
                new LongHashPartitionChannelReaderInputViewIterator(
                        this.currentSpilledBuildSide, this.buildSideSerializer),
                this.buildSideSerializer.createInstance2());
    }

    /**
     * Opens the spilled probe side of {@code longHashPartition} behind a new probe iterator. Any
     * previously opened spilled probe side is closed and deleted first, best-effort.
     */
    public ProbeIterator getSpilledPartitionProbeSideIter(LongHashPartition longHashPartition)
            throws IOException {
        if (this.currentSpilledProbeSide != null) {
            try {
                this.currentSpilledProbeSide.getChannel().closeAndDelete();
            } catch (Throwable th) {
                LOG.warn(
                        "Could not close and delete the temp file for the current spilled partition probe side.",
                        th);
            }
            this.currentSpilledProbeSide = null;
        }
        this.probeIterator = new ProbeIterator(this.probeSideSerializer.createInstance2());
        setPartitionProbeReader(longHashPartition);
        return this.probeIterator;
    }

    /**
     * Releases the memory of all current, pending and SMJ-deferred partitions and empties those
     * lists. Failures are logged per partition so that one failure does not skip the rest.
     */
    @Override
    protected void clearPartitions() {
        this.probeIterator = null;

        for (int idx = this.partitionsBeingBuilt.size() - 1; idx >= 0; idx--) {
            clearPartitionQuietly(this.partitionsBeingBuilt.get(idx));
        }
        this.partitionsBeingBuilt.clear();

        for (LongHashPartition partition : this.partitionsPending) {
            clearPartitionQuietly(partition);
        }
        this.partitionsPending.clear();

        for (LongHashPartition partition : this.partitionsPendingForSMJ) {
            clearPartitionQuietly(partition);
        }
        this.partitionsPendingForSMJ.clear();
    }

    /** Releases one partition's memory, logging instead of propagating any failure. */
    private void clearPartitionQuietly(LongHashPartition partition) {
        try {
            partition.clearAllMemory(this.internalPool);
        } catch (Exception e) {
            LOG.error("Error during partition cleanup.", e);
        }
    }

    /** Whether spill files are compressed. */
    public boolean compressionEnable() {
        return this.compressionEnable;
    }

    /** Compression codec factory used for spill files. */
    public BlockCompressionFactory compressionCodecFactory() {
        return this.compressionCodecFactory;
    }

    /** Compression block size used for spill files. */
    public int compressionBlockSize() {
        return this.compressionBlockSize;
    }
}
