/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.prediction;

import java.lang.runtime.SwitchBootstraps;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLPredictTaskRunner;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action serving ML prediction requests ({@code cluster:admin/opensearch/ml/predict}).
 *
 * <p>For each request the target model is resolved from {@link MLModelCacheHelper} when cached, otherwise via
 * {@link MLModelManager#getModel}. The action then:
 * <ul>
 *   <li>rejects deep-learning models while local models are disabled;</li>
 *   <li>validates model-group access, for models that are not hidden only;</li>
 *   <li>rejects models whose cached enabled flag is {@code false};</li>
 *   <li>applies model- and user-level rate limits to deep-learning models;</li>
 *   <li>validates the input against the model's declared {@code input} interface schema;</li>
 *   <li>hands the request to the prediction task runner.</li>
 * </ul>
 */
public class TransportPredictionTaskAction
extends HandledTransportAction<ActionRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportPredictionTaskAction.class);

    /** Name this action is registered under; also the action checked during model-group access validation. */
    private static final String PREDICT_ACTION_NAME = "cluster:admin/opensearch/ml/predict";

    private MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> mlPredictTaskRunner;
    private TransportService transportService;
    private MLModelCacheHelper modelCacheHelper;
    private Client client;
    private SdkClient sdkClient;
    private ClusterService clusterService;
    private NamedXContentRegistry xContentRegistry;
    private MLModelManager mlModelManager;
    private ModelAccessControlHelper modelAccessControlHelper;
    // Tracks ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE dynamically; not read elsewhere in this class.
    private volatile boolean enableAutomaticDeployment;
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;

    /**
     * Creates the action and registers a listener that keeps {@code enableAutomaticDeployment} in sync with
     * the {@code ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE} cluster setting.
     */
    @Inject
    public TransportPredictionTaskAction(TransportService transportService, ActionFilters actionFilters, MLModelCacheHelper modelCacheHelper, MLPredictTaskRunner mlPredictTaskRunner, ClusterService clusterService, Client client, SdkClient sdkClient, NamedXContentRegistry xContentRegistry, MLModelManager mlModelManager, ModelAccessControlHelper modelAccessControlHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting, Settings settings) {
        super(PREDICT_ACTION_NAME, transportService, actionFilters, MLPredictionTaskRequest::new);
        this.mlPredictTaskRunner = mlPredictTaskRunner;
        this.transportService = transportService;
        this.modelCacheHelper = modelCacheHelper;
        this.clusterService = clusterService;
        this.client = client;
        this.sdkClient = sdkClient;
        this.xContentRegistry = xContentRegistry;
        this.mlModelManager = mlModelManager;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.enableAutomaticDeployment = MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> this.enableAutomaticDeployment = it);
    }

    /**
     * Entry point for a prediction request. Validates the tenant, fills in the requesting user when the request
     * carries none, resolves the model and dispatches to the hidden- or visible-model path.
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
        final MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest(request);
        final String modelId = mlPredictionTaskRequest.getModelId();
        String tenantId = mlPredictionTaskRequest.getTenantId();
        if (!TenantAwareHelper.validateTenantId(this.mlFeatureEnabledSetting, tenantId, listener)) {
            return;
        }
        User user = mlPredictionTaskRequest.getUser();
        if (user == null) {
            user = RestActionUtils.getUserContext(this.client);
            mlPredictionTaskRequest.setUser(user);
        }
        final User userInfo = user;
        // The caller's thread context is stashed while the model is resolved. It is restored at the start of
        // onResponse (before access checks) and again right before the caller's listener is notified.
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            final ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, context::restore);
            MLModel cachedMlModel = this.modelCacheHelper.getModelInfo(modelId);
            ActionListener<MLModel> modelActionListener = new ActionListener<>() {
                @Override
                public void onResponse(MLModel mlModel) {
                    context.restore();
                    modelCacheHelper.setModelInfo(modelId, mlModel);
                    FunctionName functionName = mlModel.getAlgorithm();
                    if (FunctionName.isDLModel(functionName) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
                        // Report instead of throwing: when the model was fetched asynchronously, an exception
                        // thrown from this callback is not guaranteed to reach the caller.
                        wrappedListener
                            .onFailure(
                                new OpenSearchStatusException(
                                    "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.",
                                    RestStatus.BAD_REQUEST
                                )
                            );
                        return;
                    }
                    mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
                    // A missing hidden flag is treated as "visible", so model-group access control still applies.
                    if (Boolean.TRUE.equals(mlModel.getIsHidden())) {
                        handleModelAccessForHiddenModel(modelId, mlPredictionTaskRequest, userInfo, wrappedListener, functionName);
                    } else {
                        handleModelAccessForVisibleModel(mlModel, mlPredictionTaskRequest, wrappedListener, functionName, userInfo);
                    }
                }

                @Override
                public void onFailure(Exception e) {
                    log.error("Failed to find model {}", modelId, e);
                    wrappedListener.onFailure(e);
                }
            };
            if (cachedMlModel != null) {
                modelActionListener.onResponse(cachedMlModel);
            } else {
                this.mlModelManager.getModel(modelId, tenantId, modelActionListener);
            }
        }
    }

    /**
     * Handles a prediction for a hidden model. No model-group access validation is performed on this path.
     */
    private void handleModelAccessForHiddenModel(String modelId, MLPredictionTaskRequest mlPredictionTaskRequest, User userInfo, ActionListener<MLTaskResponse> wrappedListener, FunctionName functionName) {
        checkModelAndExecutePredict(modelId, mlPredictionTaskRequest, userInfo, wrappedListener, functionName);
    }

    /**
     * Handles a prediction for a visible model. Validates that {@code userInfo} may run the predict action on
     * the model's group, then runs the shared checks and the prediction. Access-validation failures are mapped
     * by {@link #handleError}.
     */
    private void handleModelAccessForVisibleModel(MLModel mlModel, MLPredictionTaskRequest mlPredictionTaskRequest, ActionListener<MLTaskResponse> wrappedListener, FunctionName functionName, User userInfo) {
        this.modelAccessControlHelper
            .validateModelGroupAccess(
                userInfo,
                this.mlFeatureEnabledSetting,
                mlPredictionTaskRequest.getTenantId(),
                mlModel.getModelGroupId(),
                PREDICT_ACTION_NAME,
                this.client,
                this.sdkClient,
                ActionListener.wrap(access -> {
                    if (!access) {
                        wrappedListener
                            .onFailure(
                                new OpenSearchStatusException(
                                    "User doesn't have privilege to perform this operation on this model",
                                    RestStatus.FORBIDDEN
                                )
                            );
                    } else {
                        checkModelAndExecutePredict(mlModel.getModelId(), mlPredictionTaskRequest, userInfo, wrappedListener, functionName);
                    }
                }, e -> {
                    String safeId = mlPredictionTaskRequest.getModelId() != null
                        ? mlPredictionTaskRequest.getModelId()
                        : mlModel.getModelId();
                    handleError(e, safeId, wrappedListener);
                })
            );
    }

    /**
     * Runs the checks shared by hidden and visible models and, when they all pass, executes the prediction.
     * The checks are: enabled flag, rate limits (deep-learning models only) and input schema. A failed check is
     * reported through {@code wrappedListener} rather than thrown.
     */
    private void checkModelAndExecutePredict(String modelId, MLPredictionTaskRequest mlPredictionTaskRequest, User userInfo, ActionListener<MLTaskResponse> wrappedListener, FunctionName functionName) {
        // Read once so a concurrent cache update cannot null the flag between the check and the unboxing.
        Boolean modelEnabled = this.modelCacheHelper.getIsModelEnabled(modelId);
        if (modelEnabled != null && !modelEnabled) {
            wrappedListener.onFailure(new OpenSearchStatusException("Model is disabled.", RestStatus.FORBIDDEN));
            return;
        }
        if (FunctionName.isDLModel(functionName) && !checkRateLimiting(modelId, userInfo, wrappedListener)) {
            return;
        }
        try {
            validateInputSchema(modelId, mlPredictionTaskRequest.getMlInput());
        } catch (OpenSearchStatusException e) {
            wrappedListener.onFailure(e);
            return;
        }
        executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
    }

    /**
     * Applies the model-level rate limiter, then the per-user limiter when a user is known.
     *
     * @return {@code true} if the request may proceed; {@code false} after {@code wrappedListener} has been
     *     failed with {@code TOO_MANY_REQUESTS}
     */
    private boolean checkRateLimiting(String modelId, User userInfo, ActionListener<MLTaskResponse> wrappedListener) {
        // Each limiter is looked up once so the null check and request() see the same instance.
        var modelRateLimiter = this.modelCacheHelper.getRateLimiter(modelId);
        if (modelRateLimiter != null && !modelRateLimiter.request()) {
            wrappedListener.onFailure(new OpenSearchStatusException("Request is throttled at model level.", RestStatus.TOO_MANY_REQUESTS));
            return false;
        }
        if (userInfo != null) {
            var userRateLimiter = this.modelCacheHelper.getUserRateLimiter(modelId, userInfo.getName());
            if (userRateLimiter != null && !userRateLimiter.request()) {
                wrappedListener
                    .onFailure(
                        new OpenSearchStatusException(
                            "Request is throttled at user level. If you think there's an issue, please contact your cluster admin.",
                            RestStatus.TOO_MANY_REQUESTS
                        )
                    );
                return false;
            }
        }
        return true;
    }

    /**
     * Logs a failure from the access-validation listener and forwards it to {@code wrappedListener}, mapped as
     * follows:
     * <ul>
     *   <li>status exceptions keep their message and status;</li>
     *   <li>{@link MLResourceNotFoundException} becomes {@code NOT_FOUND};</li>
     *   <li>circuit-breaker trips are passed through unchanged;</li>
     *   <li>anything else becomes {@code FORBIDDEN}.</li>
     * </ul>
     */
    private void handleError(Exception e, String modelId, ActionListener<MLTaskResponse> wrappedListener) {
        log.error("Failed to Validate Access for ModelId {}", modelId, e);
        if (e instanceof OpenSearchStatusException statusException) {
            wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), statusException.status()));
        } else if (e instanceof MLResourceNotFoundException) {
            wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
        } else if (e instanceof CircuitBreakingException) {
            wrappedListener.onFailure(e);
        } else {
            wrappedListener.onFailure(new OpenSearchStatusException("Failed to Validate Access for ModelId " + modelId, RestStatus.FORBIDDEN));
        }
    }

    /**
     * Hands the request to the prediction task runner. When the listener completes, the request duration (ms)
     * is recorded and the model's last-access time is refreshed in the cache.
     */
    private void executePredict(MLPredictionTaskRequest mlPredictionTaskRequest, ActionListener<MLTaskResponse> wrappedListener, String modelId) {
        String requestId = mlPredictionTaskRequest.getRequestID();
        log.debug("receive predict request {} for model {}", requestId, mlPredictionTaskRequest.getModelId());
        long startTime = System.nanoTime();
        // Prefer the function name cached for the model; fall back to the algorithm set on the input.
        FunctionName functionName = this.modelCacheHelper
            .getOptionalFunctionName(modelId)
            .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm());
        this.mlPredictTaskRunner.run(functionName, mlPredictionTaskRequest, this.transportService, ActionListener.runAfter(wrappedListener, () -> {
            double durationInMs = (System.nanoTime() - startTime) / 1_000_000.0;
            this.modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
            this.modelCacheHelper.refreshLastAccessTime(modelId);
            log.debug("completed predict request {} for model {}", requestId, modelId);
        }));
    }

    /**
     * Validates {@code mlInput} against the {@code input} schema of the model's cached interface, if one is
     * defined. Batch-predict remote inputs are not validated.
     *
     * @throws OpenSearchStatusException with {@code BAD_REQUEST} if the input cannot be serialized or does not
     *     match the schema
     */
    public void validateInputSchema(String modelId, MLInput mlInput) {
        if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet remoteDataSet
            && remoteDataSet.getActionType() == ConnectorAction.ActionType.BATCH_PREDICT) {
            return;
        }
        // Single lookup: the interface may be evicted concurrently between repeated cache reads.
        var modelInterface = this.modelCacheHelper.getModelInterface(modelId);
        String inputSchemaString = modelInterface == null ? null : modelInterface.get("input");
        if (inputSchemaString == null) {
            return;
        }
        try {
            String inputString = mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString();
            String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(inputString, inputSchemaString);
            MLNodeUtils.validateSchema(inputSchemaString, processedInputString);
        } catch (Exception e) {
            throw new OpenSearchStatusException(
                "Error validating input schema, if you think this is expected, please update your 'input' field in the 'interface' field for this model: "
                    + e.getMessage(),
                RestStatus.BAD_REQUEST
            );
        }
    }
}

