/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.helper;

import java.util.Locale;
import java.util.function.Function;
import java.util.function.Predicate;

import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.transport.client.Client;

/**
 * Static helpers that check the model IDs configured on a memory container.
 *
 * <p>Each validator reports asynchronously through the supplied listener. It reports {@code true}
 * when the ID is {@code null} or refers to a model with an acceptable algorithm. It reports an
 * {@link IllegalArgumentException} when the model lookup fails or the algorithm is not accepted.
 */
public final class MemoryContainerModelValidator {
    @Generated
    private static final Logger log = LogManager.getLogger(MemoryContainerModelValidator.class);

    /**
     * Validates that an optional LLM model ID refers to a {@link FunctionName#REMOTE} model.
     *
     * @param llmId        model ID to check; {@code null} passes immediately with {@code true}
     * @param modelManager used to fetch the model
     * @param client       supplies the thread pool whose thread context is stashed during the lookup
     * @param listener     receives {@code true} on success; otherwise receives an
     *                     {@link IllegalArgumentException}. That happens when the model is not REMOTE,
     *                     or when the lookup fails (the lookup error is attached as the cause).
     */
    public static void validateLlmModel(String llmId, MLModelManager modelManager, Client client, ActionListener<Boolean> listener) {
        validateModel(
            llmId,
            algorithm -> algorithm == FunctionName.REMOTE,
            algorithm -> String.format(Locale.ROOT, "LLM model must be a REMOTE model, found: %s", algorithm),
            "Failed to get LLM model: {}",
            "LLM model with ID %s not found",
            modelManager,
            client,
            listener
        );
    }

    /**
     * Validates that an optional embedding model ID refers to a model of {@code expectedType} or a
     * {@link FunctionName#REMOTE} model.
     *
     * @param embeddingModelId model ID to check; {@code null} passes immediately with {@code true}
     * @param expectedType     the non-remote algorithm that is also accepted
     * @param modelManager     used to fetch the model
     * @param client           supplies the thread pool whose thread context is stashed during the lookup
     * @param listener         receives {@code true} on success; otherwise receives an
     *                         {@link IllegalArgumentException}. That happens when the algorithm is neither
     *                         accepted type, or when the lookup fails (the lookup error is attached as the cause).
     */
    public static void validateEmbeddingModel(String embeddingModelId, FunctionName expectedType, MLModelManager modelManager, Client client, ActionListener<Boolean> listener) {
        validateModel(
            embeddingModelId,
            algorithm -> algorithm == expectedType || algorithm == FunctionName.REMOTE,
            algorithm -> String.format(Locale.ROOT, "Embedding model must be of type %s or REMOTE, found: %s", expectedType, algorithm),
            "Failed to get embedding model: {}",
            "Embedding model with ID %s not found",
            modelManager,
            client,
            listener
        );
    }

    /**
     * Runs the shared fetch-and-check flow used by the public validators.
     *
     * @param modelId          model ID to check; {@code null} short-circuits to {@code true}
     * @param isAcceptable     returns {@code true} for algorithms the caller accepts
     * @param mismatchMessage  builds the error message for an unacceptable algorithm
     * @param lookupFailureLog log4j template with one {@code {}} for the ID; logged when the fetch fails
     * @param notFoundFormat   {@link String#format} template with one {@code %s} for the ID; used as the
     *                         message of the error reported when the fetch fails
     * @param modelManager     used to fetch the model
     * @param client           supplies the thread pool whose thread context is stashed during the lookup
     * @param listener         receives the validation outcome
     */
    private static void validateModel(
        String modelId,
        Predicate<FunctionName> isAcceptable,
        Function<FunctionName, String> mismatchMessage,
        String lookupFailureLog,
        String notFoundFormat,
        MLModelManager modelManager,
        Client client,
        ActionListener<Boolean> listener
    ) {
        if (modelId == null) {
            // Nothing configured, so there is nothing to validate.
            listener.onResponse(true);
            return;
        }
        // The try-with-resources restores the caller's context on this thread once getModel returns.
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            // NOTE(review): ActionListener.wrap routes any exception thrown by the success consumer
            // into the failure consumer. That includes an exception thrown by listener.onResponse,
            // which would then be re-reported as a lookup failure. Confirm listeners never throw.
            ActionListener<MLModel> checkingListener = ActionListener.wrap(model -> {
                FunctionName algorithm = model.getAlgorithm();
                if (!isAcceptable.test(algorithm)) {
                    listener.onFailure(new IllegalArgumentException(mismatchMessage.apply(algorithm)));
                    return;
                }
                listener.onResponse(true);
            }, e -> {
                log.error(lookupFailureLog, modelId, e);
                // Keep the lookup error as the cause. The failure is not always a genuine "not found".
                listener.onFailure(new IllegalArgumentException(String.format(Locale.ROOT, notFoundFormat, modelId), e));
            });
            // runBefore restores the caller's context before either callback above is invoked.
            modelManager.getModel(modelId, ActionListener.runBefore(checkingListener, context::restore));
        }
    }

    @Generated
    private MemoryContainerModelValidator() {
    }
}

