/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.index.engine.Engine;
import org.opensearch.index.engine.EngineException;
import org.opensearch.index.shard.IllegalIndexShardStateException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.neuralsearch.sparse.accessor.SparseVectorReader;
import org.opensearch.neuralsearch.sparse.cache.CacheGatedForwardIndexReader;
import org.opensearch.neuralsearch.sparse.cache.CacheGatedPostingsReader;
import org.opensearch.neuralsearch.sparse.cache.CacheKey;
import org.opensearch.neuralsearch.sparse.cache.ClusteredPostingCache;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCache;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCacheItem;
import org.opensearch.neuralsearch.sparse.codec.CodecUtilWrapper;
import org.opensearch.neuralsearch.sparse.codec.SparseBinaryDocValuesPassThrough;
import org.opensearch.neuralsearch.sparse.codec.SparseTermsLuceneReader;
import org.opensearch.neuralsearch.sparse.common.PredicateUtils;
import org.opensearch.neuralsearch.sparse.mapper.SparseVectorField;

public class NeuralSparseIndexShard {
    @Generated
    private static final Logger log = LogManager.getLogger(NeuralSparseIndexShard.class);
    @NonNull
    private final IndexShard indexShard;
    private static final String WARM_UP_SEARCHER_SOURCE = "warm-up-searcher-source";
    private static final String CLEAR_CACHE_SEARCHER_SOURCE = "clear-cache-searcher-source";

    public String getIndexName() {
        return this.indexShard.shardId().getIndexName();
    }

    public void warmUp() throws IOException {
        try (Engine.Searcher searcher = this.indexShard.acquireSearcher(WARM_UP_SEARCHER_SOURCE);){
            List<CacheOperationContext> cacheOperationContexts = this.collectCacheOperationContexts(searcher);
            this.warmUpAllForwardIndices(cacheOperationContexts);
            this.warmUpAllClusteredPostings(cacheOperationContexts);
        }
        catch (EngineException | IllegalIndexShardStateException e) {
            log.error("[Neural Sparse] Failed to acquire searcher", e);
            throw e;
        }
        catch (CircuitBreakingException e) {
            log.error("[Neural Sparse] Circuit Breaker reaches limit", (Throwable)e);
            throw e;
        }
        catch (IOException e) {
            log.error("[Neural Sparse] Failed to read data during warm up", (Throwable)e);
            throw e;
        }
    }

    public void clearCache() throws IOException {
        try (Engine.Searcher searcher = this.indexShard.acquireSearcher(CLEAR_CACHE_SEARCHER_SOURCE);){
            List<CacheOperationContext> cacheOperationContexts = this.collectCacheOperationContexts(searcher);
            this.clearAllCaches(cacheOperationContexts);
        }
        catch (EngineException | IllegalIndexShardStateException e) {
            log.error("[Neural Sparse] Failed to acquire searcher", e);
            throw e;
        }
        catch (IOException e) {
            log.error("[Neural Sparse] Failed to read data during cache clearing", (Throwable)e);
            throw e;
        }
    }

    private void warmUpAllForwardIndices(List<CacheOperationContext> contexts) throws IOException, CircuitBreakingException {
        for (CacheOperationContext context : contexts) {
            BinaryDocValues binaryDocValues = context.binaryDocValues;
            SparseVectorReader forwardIndexReader = context.forwardIndexReader;
            if (forwardIndexReader == null) continue;
            int docId = binaryDocValues.nextDoc();
            while (docId != Integer.MAX_VALUE) {
                forwardIndexReader.read(docId);
                docId = binaryDocValues.nextDoc();
            }
        }
    }

    private void warmUpAllClusteredPostings(List<CacheOperationContext> contexts) throws IOException, CircuitBreakingException {
        for (CacheOperationContext context : contexts) {
            CacheGatedPostingsReader postingsReader = context.postingsReader;
            Set<BytesRef> terms = postingsReader.getTerms();
            for (BytesRef term : terms) {
                postingsReader.read(term);
            }
        }
    }

    private void clearAllCaches(List<CacheOperationContext> contexts) {
        for (CacheOperationContext context : contexts) {
            CacheKey cacheKey = context.cacheKey;
            ClusteredPostingCache.getInstance().onIndexRemoval(cacheKey);
            ForwardIndexCache.getInstance().onIndexRemoval(cacheKey);
        }
    }

    private SparseVectorReader getCacheGatedForwardIndexReader(BinaryDocValues binaryDocValues, CacheKey key, int docCount) {
        if (!(binaryDocValues instanceof SparseBinaryDocValuesPassThrough)) {
            return SparseVectorReader.NOOP_READER;
        }
        SparseBinaryDocValuesPassThrough sparseBinaryDocValues = (SparseBinaryDocValuesPassThrough)binaryDocValues;
        ForwardIndexCacheItem cacheItem = ForwardIndexCache.getInstance().getOrCreate(key, docCount);
        return new CacheGatedForwardIndexReader(cacheItem.getReader(), cacheItem.getWriter(this::customizedConsumer), sparseBinaryDocValues);
    }

    private CacheGatedPostingsReader getCacheGatedPostingReader(FieldInfo fieldInfo, CacheKey key, SegmentInfo segmentInfo) throws IOException {
        SparseTermsLuceneReader luceneReader = new SparseTermsLuceneReader(this.createSegmentReadState(segmentInfo), new CodecUtilWrapper());
        return new CacheGatedPostingsReader(fieldInfo.name, ClusteredPostingCache.getInstance().getOrCreate(key).getReader(), ClusteredPostingCache.getInstance().getOrCreate(key).getWriter(this::customizedConsumer), luceneReader);
    }

    private void customizedConsumer(long ramBytesUsed) {
        throw new CircuitBreakingException("Circuit Breaker reaches limit", CircuitBreaker.Durability.PERMANENT);
    }

    private Set<FieldInfo> collectSparseFieldInfos(LeafReader leafReader) {
        return StreamSupport.stream(leafReader.getFieldInfos().spliterator(), false).filter(SparseVectorField::isSparseField).collect(Collectors.toSet());
    }

    private SegmentReadState createSegmentReadState(SegmentInfo segmentInfo) throws IOException {
        Codec codec = segmentInfo.getCodec();
        Object cfsDir = segmentInfo.getUseCompoundFile() ? codec.compoundFormat().getCompoundReader(segmentInfo.dir, segmentInfo) : segmentInfo.dir;
        FieldInfos coreFieldInfos = codec.fieldInfosFormat().read(cfsDir, segmentInfo, "", IOContext.DEFAULT);
        return new SegmentReadState(cfsDir, segmentInfo, coreFieldInfos, IOContext.DEFAULT);
    }

    private List<CacheOperationContext> collectCacheOperationContexts(Engine.Searcher searcher) throws IOException {
        ArrayList<CacheOperationContext> contexts = new ArrayList<CacheOperationContext>();
        for (LeafReaderContext leafReaderContext : searcher.getIndexReader().leaves()) {
            LeafReader leafReader = leafReaderContext.reader();
            Set<FieldInfo> sparseFieldInfos = this.collectSparseFieldInfos(leafReader);
            SegmentReader segmentReader = Lucene.segmentReader((LeafReader)leafReader);
            SegmentInfo segmentInfo = segmentReader.getSegmentInfo().info;
            for (FieldInfo fieldInfo : sparseFieldInfos) {
                SparseVectorReader forwardIndexReader;
                if (!PredicateUtils.shouldRunSeisPredicate.test(segmentInfo, fieldInfo)) continue;
                CacheKey key = new CacheKey(segmentInfo, fieldInfo);
                BinaryDocValues binaryDocValues = leafReader.getBinaryDocValues(fieldInfo.name);
                if (binaryDocValues == null) {
                    log.error("[Neural Sparse] No binary doc values found for field: {}", (Object)fieldInfo.name);
                    forwardIndexReader = null;
                } else {
                    forwardIndexReader = this.getCacheGatedForwardIndexReader(binaryDocValues, key, segmentInfo.maxDoc());
                }
                CacheGatedPostingsReader postingsReader = this.getCacheGatedPostingReader(fieldInfo, key, segmentInfo);
                contexts.add(new CacheOperationContext(binaryDocValues, forwardIndexReader, postingsReader, key));
            }
        }
        return contexts;
    }

    @Generated
    public NeuralSparseIndexShard(@NonNull IndexShard indexShard) {
        Objects.requireNonNull(indexShard, "indexShard is marked non-null but is null");
        this.indexShard = indexShard;
    }

    @NonNull
    @Generated
    public IndexShard getIndexShard() {
        return this.indexShard;
    }

    private static class CacheOperationContext {
        final BinaryDocValues binaryDocValues;
        final SparseVectorReader forwardIndexReader;
        final CacheGatedPostingsReader postingsReader;
        final CacheKey cacheKey;

        CacheOperationContext(BinaryDocValues binaryDocValues, SparseVectorReader forwardIndexReader, CacheGatedPostingsReader postingsReader, CacheKey key) {
            this.binaryDocValues = binaryDocValues;
            this.forwardIndexReader = forwardIndexReader;
            this.postingsReader = postingsReader;
            this.cacheKey = key;
        }
    }
}

