/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
import org.opensearch.jobscheduler.spi.schedule.Schedule;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.settings.SettingsChangeListener;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.jobs.MLJobParameter;
import org.opensearch.ml.jobs.MLJobType;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.remote.metadata.client.PutDataObjectRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.UpdateDataObjectRequest;
import org.opensearch.remote.metadata.client.UpdateDataObjectResponse;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.client.Client;
import org.opensearch.transport.client.Requests;

/**
 * Keeps an in-memory cache of ML tasks known to the local node and persists task changes to the
 * {@code .plugins-ml-task} index.
 *
 * <p>Each cached task is wrapped in an {@link MLTaskCache}. A per-{@link MLTaskType} counter
 * enforces the running-task limit in {@link #checkLimitAndAddRunningTask(MLTask, Integer)}.
 * The class also indexes the batch-task polling job and the stats collector job into
 * {@code .plugins-ml-jobs}.
 */
public class MLTaskManager implements SettingsChangeListener {
    @Generated
    private static final Logger log = LogManager.getLogger(MLTaskManager.class);

    /** Semaphore wait (milliseconds) used by {@link #updateTaskStateAsRunning(String, String, boolean)}. */
    public static int TASK_SEMAPHORE_TIMEOUT = 5000;

    /** Cached tasks keyed by task id. */
    private final Map<String, MLTaskCache> taskCaches;
    private final Client client;
    private final SdkClient sdkClient;
    private final ThreadPool threadPool;
    private final MLIndicesHandler mlIndicesHandler;
    /** Number of running tasks per task type, maintained by checkLimitAndAddRunningTask/remove. */
    private final Map<MLTaskType, AtomicInteger> runningTasksCount;
    // Written from an async index callback and read by later callers, so it must be volatile.
    private volatile boolean taskPollingJobStarted;
    private boolean statsCollectorJobStarted;

    /** Task states for which state updates are sent with {@code retryOnConflict(3)}. */
    public static final ImmutableSet<MLTaskState> TASK_DONE_STATES = ImmutableSet.of(
        MLTaskState.COMPLETED,
        MLTaskState.COMPLETED_WITH_ERROR,
        MLTaskState.FAILED,
        MLTaskState.CANCELLED
    );

    public MLTaskManager(Client client, SdkClient sdkClient, ThreadPool threadPool, MLIndicesHandler mlIndicesHandler) {
        this.client = client;
        this.sdkClient = sdkClient;
        this.threadPool = threadPool;
        this.mlIndicesHandler = mlIndicesHandler;
        this.taskCaches = new ConcurrentHashMap<>();
        this.runningTasksCount = new ConcurrentHashMap<>();
    }

    /**
     * Marks a task RUNNING and increments the running counter for its type, unless the counter
     * already reached {@code limit}.
     *
     * @param mlTask task to start; it is added to the cache if it is not cached yet
     * @param limit maximum number of running tasks allowed for the task's type
     * @throws MLLimitExceededException if the running count for the task type is at or above {@code limit}
     * @throws IllegalArgumentException if the task is concurrently added by another caller
     */
    public synchronized void checkLimitAndAddRunningTask(MLTask mlTask, Integer limit) {
        AtomicInteger runningTaskCount = runningTasksCount.computeIfAbsent(mlTask.getTaskType(), type -> new AtomicInteger(0));
        // remove() decrements for every non-CREATED task, including tasks never counted here,
        // so the counter can drift below zero; clamp it before comparing.
        if (runningTaskCount.get() < 0) {
            runningTaskCount.set(0);
        }
        log.debug("Task id: {}, current running task {}: {}", mlTask.getTaskId(), mlTask.getTaskType(), runningTaskCount.get());
        if (runningTaskCount.get() >= limit) {
            String error = "exceed max running task limit";
            log.warn("{} for task {}", error, mlTask.getTaskId());
            throw new MLLimitExceededException(error);
        }
        MLTask cachedTask = getMLTask(mlTask.getTaskId());
        if (cachedTask != null) {
            cachedTask.setState(MLTaskState.RUNNING);
        } else {
            mlTask.setState(MLTaskState.RUNNING);
            add(mlTask);
        }
        runningTaskCount.incrementAndGet();
    }

    /**
     * Checks whether the number of CREATED or RUNNING tasks of {@code mlTaskType} in the task index
     * has reached {@code maxTaskLimit}. The task index is refreshed before searching.
     *
     * @param mlTaskType task type to count
     * @param maxTaskLimit limit to compare against
     * @param listener receives {@code true} when the limit is reached, otherwise {@code false};
     *                 receives a failure if the refresh or search fails
     */
    public synchronized void checkMaxBatchJobTask(MLTaskType mlTaskType, Integer maxTaskLimit, ActionListener<Boolean> listener) {
        try {
            BoolQueryBuilder boolQuery = QueryBuilders
                .boolQuery()
                .must(QueryBuilders.termQuery("task_type", mlTaskType.name()))
                .must(
                    QueryBuilders
                        .boolQuery()
                        .should(QueryBuilders.termQuery("state", MLTaskState.CREATED))
                        .should(QueryBuilders.termQuery("state", MLTaskState.RUNNING))
                );
            int limit = maxTaskLimit;
            // The search API returns 10 hits by default, which would cap the count and make any
            // limit above 10 unreachable. Ask for exactly enough hits to reach the limit; sources
            // are not needed because only the hit count matters.
            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
                .query(boolQuery)
                .size(Math.max(0, limit))
                .fetchSource(false);
            SearchRequest searchRequest = new SearchRequest(".plugins-ml-task");
            searchRequest.source(searchSourceBuilder);
            try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
                client
                    .admin()
                    .indices()
                    .refresh(
                        Requests.refreshRequest(".plugins-ml-task"),
                        ActionListener
                            .wrap(
                                refreshResponse -> client
                                    .search(
                                        searchRequest,
                                        ActionListener.runBefore(ActionListener.wrap(searchResponse -> {
                                            long matchedCount = searchResponse.getHits().getHits().length;
                                            listener.onResponse(matchedCount >= limit);
                                        }, listener::onFailure), threadContext::restore)
                                    ),
                                e -> {
                                    log.error("Failed to refresh Task index during search MLTaskType for {}", mlTaskType, e);
                                    threadContext.restore();
                                    listener.onFailure(e);
                                }
                            )
                    );
            } catch (Exception e) {
                listener.onFailure(e);
            }
        } catch (Exception e) {
            log.error("Failed to search ML task for {}", mlTaskType, e);
            listener.onFailure(e);
        }
    }

    /**
     * Adds a task to the cache without worker nodes.
     *
     * @throws IllegalArgumentException if a task with the same id is already cached
     */
    public synchronized void add(MLTask mlTask) {
        add(mlTask, null);
    }

    /**
     * Adds a task and its worker nodes to the cache.
     *
     * @param mlTask task to cache
     * @param workerNodes worker node ids for the task; may be null
     * @throws IllegalArgumentException if a task with the same id is already cached
     */
    public synchronized void add(MLTask mlTask, List<String> workerNodes) {
        String taskId = mlTask.getTaskId();
        if (contains(taskId)) {
            throw new IllegalArgumentException("Duplicate taskId");
        }
        taskCaches.put(taskId, new MLTaskCache(mlTask, workerNodes));
        log.debug("add ML task to cache, taskId: {}, taskType: {} ", taskId, mlTask.getTaskType());
    }

    /** Returns whether a task with {@code taskId} is cached. */
    public boolean contains(String taskId) {
        return taskCaches.containsKey(taskId);
    }

    /**
     * Removes a task from the cache. If the removed task was not in CREATED state, the running
     * counter for its type is decremented.
     */
    public void remove(String taskId) {
        // A single atomic remove avoids the contains-then-remove race with concurrent removals.
        MLTaskCache taskCache = taskCaches.remove(taskId);
        if (taskCache == null) {
            return;
        }
        MLTask mlTask = taskCache.getMlTask();
        if (mlTask.getState() != MLTaskState.CREATED) {
            AtomicInteger runningTaskCount = runningTasksCount.get(mlTask.getTaskType());
            if (runningTaskCount != null) {
                runningTaskCount.decrementAndGet();
            }
        }
        log.debug("remove ML task from cache {}", taskId);
    }

    /** Returns the cached task, or {@code null} if it is not cached. */
    public MLTask getMLTask(String taskId) {
        MLTaskCache taskCache = taskCaches.get(taskId);
        return taskCache == null ? null : taskCache.getMlTask();
    }

    /** Returns the cache entry for the task, or {@code null} if it is not cached. */
    public MLTaskCache getMLTaskCache(String taskId) {
        return taskCaches.get(taskId);
    }

    /** Returns the worker nodes of a cached task, or {@code null} if it is not cached. */
    public Set<String> getWorkNodes(String taskId) {
        MLTaskCache taskCache = taskCaches.get(taskId);
        return taskCache == null ? null : taskCache.getWorkerNodes();
    }

    /** Records an error reported by a worker node for a cached task; ignored if the task is not cached. */
    public void addNodeError(String taskId, String workerNodeId, String error) {
        log.debug("add task error: taskId: {}, workerNodeId: {}, error: {}", taskId, workerNodeId, error);
        MLTaskCache taskCache = taskCaches.get(taskId);
        if (taskCache != null) {
            taskCache.addError(workerNodeId, error);
        }
    }

    /** Returns the ids of all cached tasks. */
    public String[] getAllTaskIds() {
        return Strings.toStringArray(taskCaches.keySet());
    }

    /** Returns how many cached tasks are currently in RUNNING state. */
    public int getRunningTaskCount() {
        int running = 0;
        for (MLTaskCache taskCache : taskCaches.values()) {
            if (taskCache.getMlTask().getState() == MLTaskState.RUNNING) {
                running++;
            }
        }
        return running;
    }

    /** Drops every cached task. The per-type running counters are left unchanged. */
    public void clear() {
        taskCaches.clear();
    }

    /**
     * Creates the task index if needed, then writes {@code mlTask} to it.
     *
     * @param mlTask task document to write
     * @param listener receives the index response, or a failure if index creation or the write fails
     */
    public void createMLTask(MLTask mlTask, ActionListener<IndexResponse> listener) {
        mlIndicesHandler.initMLTaskIndex(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                listener.onFailure(new RuntimeException("No response to create ML task index"));
                return;
            }
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                PutDataObjectRequest putRequest = PutDataObjectRequest
                    .builder()
                    .index(".plugins-ml-task")
                    .tenantId(mlTask.getTenantId())
                    .dataObject((ToXContentObject) mlTask)
                    .build();
                sdkClient.putDataObjectAsync(putRequest).whenComplete((r, throwable) -> {
                    context.restore();
                    if (throwable != null) {
                        Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable);
                        log.error("Failed to index ML task", cause);
                        listener.onFailure(cause);
                        return;
                    }
                    try {
                        IndexResponse indexResponse = r.indexResponse();
                        log.info("Task creation result: {}, Task id: {}", indexResponse.getResult(), indexResponse.getId());
                        listener.onResponse(indexResponse);
                    } catch (Exception e) {
                        listener.onFailure(e);
                    }
                });
            } catch (Exception e) {
                log.error("Failed to create ML task for {}, {}", mlTask.getFunctionName(), mlTask.getTaskType(), e);
                listener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to create ML task index", e);
            listener.onFailure(e);
        }));
    }

    /**
     * Sets the cached task's state to RUNNING. For async tasks, the state is also written to the
     * task index.
     *
     * @throws IllegalArgumentException if the task is not cached
     */
    public void updateTaskStateAsRunning(String taskId, String tenantId, boolean isAsyncTask) {
        MLTask task = getMLTask(taskId);
        if (task == null) {
            throw new IllegalArgumentException("Task not found");
        }
        task.setState(MLTaskState.RUNNING);
        if (isAsyncTask) {
            updateMLTask(taskId, tenantId, ImmutableMap.of("state", MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
        }
    }

    /**
     * Fire-and-forget variant of
     * {@link #updateMLTask(String, String, Map, ActionListener, long, boolean)} that only logs the outcome.
     */
    public void updateMLTask(String taskId, String tenantId, Map<String, Object> updatedFields, long timeoutInMillis, boolean removeFromCache) {
        ActionListener<UpdateResponse> internalListener = ActionListener.wrap(response -> {
            if (response.status() == RestStatus.OK) {
                log.debug("Updated ML task successfully: {}, taskId: {}, updatedFields: {}", response.status(), taskId, updatedFields);
            } else {
                log.error("Failed to update ML task {}, status: {}, updatedFields: {}", taskId, response.status(), updatedFields);
            }
        }, e -> MLExceptionUtils.logException("Failed to update ML task: " + taskId, e, log));
        updateMLTask(taskId, tenantId, updatedFields, internalListener, timeoutInMillis, removeFromCache);
    }

    /**
     * Asynchronously writes {@code updatedFields} plus {@code last_update_time} to a cached task's
     * document, on the {@code opensearch_ml_general} executor.
     *
     * <p>If the cache entry has an update semaphore, it is held for the duration of the write so
     * that updates to the same task do not overlap. The permit is released exactly once,
     * whichever path finishes the request.
     *
     * @param taskId id of a cached task
     * @param tenantId tenant id passed to the SDK client
     * @param updatedFields fields to update; must be non-empty
     * @param listener receives the update response, or a failure. Failures include:
     *                 {@link MLResourceNotFoundException} when the task is not cached,
     *                 {@link MLException} when the semaphore cannot be acquired in time, and
     *                 {@link IllegalArgumentException} when {@code updatedFields} is null or empty
     * @param timeoutInMillis maximum time to wait for the task's update semaphore
     * @param removeFromCache when true, the task is removed from the cache before the update is sent
     */
    public void updateMLTask(String taskId, String tenantId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener, long timeoutInMillis, boolean removeFromCache) {
        MLTaskCache taskCache = taskCaches.get(taskId);
        if (removeFromCache) {
            remove(taskId);
        }
        if (taskCache == null) {
            listener.onFailure(new MLResourceNotFoundException("Can't find task in cache: " + taskId));
            return;
        }
        threadPool.executor("opensearch_ml_general").execute(() -> {
            // Validate before taking the semaphore: returning early while holding a permit would
            // block every later update of this task until the timeout.
            if (updatedFields == null || updatedFields.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
                return;
            }
            Semaphore semaphore = taskCache.getUpdateTaskIndexSemaphore();
            try {
                if (semaphore != null && !semaphore.tryAcquire(timeoutInMillis, TimeUnit.MILLISECONDS)) {
                    listener.onFailure(new MLException("Other updating request not finished yet"));
                    return;
                }
            } catch (InterruptedException e) {
                // Preserve the interrupt for the executor thread.
                Thread.currentThread().interrupt();
                log.error("Failed to acquire semaphore for ML task {}", taskId, e);
                listener.onFailure(e);
                return;
            }
            // Guarantees a single release even if several completion/failure paths run.
            AtomicBoolean released = new AtomicBoolean(false);
            Runnable releaseSemaphore = () -> {
                if (semaphore != null && released.compareAndSet(false, true)) {
                    semaphore.release();
                }
            };
            try {
                Map<String, Object> updatedContent = new HashMap<>(updatedFields);
                updatedContent.put("last_update_time", Instant.now().toEpochMilli());
                UpdateDataObjectRequest.Builder requestBuilder = UpdateDataObjectRequest
                    .builder()
                    .index(".plugins-ml-task")
                    .id(taskId)
                    .tenantId(tenantId)
                    .dataObject(updatedContent);
                // Look up the state value itself; the previous code tested the Boolean result of
                // containsKey against the state set, which was never true.
                if (updatedFields.containsKey("state") && TASK_DONE_STATES.contains(updatedFields.get("state"))) {
                    requestBuilder.retryOnConflict(3);
                }
                UpdateDataObjectRequest updateDataObjectRequest = requestBuilder.build();
                try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                    sdkClient.updateDataObjectAsync(updateDataObjectRequest).whenComplete((r, throwable) -> {
                        context.restore();
                        releaseSemaphore.run();
                        handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(taskId, listener));
                    });
                }
            } catch (Exception e) {
                releaseSemaphore.run();
                log.error("Failed to update ML task {}", taskId, e);
                listener.onFailure(e);
            }
        });
    }

    /** Updates a task document directly through the node client and only logs the outcome. */
    public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedFields) {
        updateMLTaskDirectly(
            taskId,
            updatedFields,
            ActionListener.wrap(r -> log.debug("updated ML task directly: {}", taskId), e -> log.error("Failed to update ML task {}", taskId, e))
        );
    }

    /**
     * Updates a task document directly through the node client with an immediate refresh. This
     * does not require the task to be cached.
     *
     * @param taskId id of the task document; must be non-empty
     * @param updatedFields fields to update; must be non-empty, and a {@code state} entry must be an {@link MLTaskState}
     * @param listener receives the update response or a failure ({@link IllegalArgumentException} for invalid input)
     */
    public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) {
        try {
            if (taskId == null || taskId.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("Task ID is null or empty"));
                return;
            }
            if (updatedFields == null || updatedFields.isEmpty()) {
                listener.onFailure(new IllegalArgumentException("Updated fields is null or empty"));
                return;
            }
            if (updatedFields.containsKey("state") && !(updatedFields.get("state") instanceof MLTaskState)) {
                listener.onFailure(new IllegalArgumentException("Invalid task state"));
                return;
            }
            UpdateRequest updateRequest = new UpdateRequest(".plugins-ml-task", taskId);
            Map<String, Object> updatedContent = new HashMap<>(updatedFields);
            updatedContent.put("last_update_time", Instant.now().toEpochMilli());
            updateRequest.doc(updatedContent);
            updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            if (updatedFields.containsKey("state") && TASK_DONE_STATES.contains(updatedFields.get("state"))) {
                updateRequest.retryOnConflict(3);
            }
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                client.update(updateRequest, ActionListener.runBefore(listener, context::restore));
            } catch (Exception e) {
                listener.onFailure(e);
            }
        } catch (Exception e) {
            log.error("Failed to update ML task {}", taskId, e);
            listener.onFailure(e);
        }
    }

    /** Returns whether any cached task refers to {@code modelId}; false for a null id. */
    public boolean containsModel(String modelId) {
        if (modelId == null) {
            return false;
        }
        for (MLTaskCache taskCache : taskCaches.values()) {
            if (modelId.equals(taskCache.getMlTask().getModelId())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects cached DEPLOY_MODEL tasks whose state is not CREATED.
     *
     * @return a two-element list: index 0 holds the task ids, index 1 the matching model ids
     *         (same order)
     */
    public List<String[]> getLocalRunningDeployModelTasks() {
        List<String> runningDeployModelTaskIds = new ArrayList<>();
        List<String> runningDeployModelIds = new ArrayList<>();
        for (Map.Entry<String, MLTaskCache> entry : taskCaches.entrySet()) {
            MLTask mlTask = entry.getValue().getMlTask();
            if (mlTask.getTaskType() == MLTaskType.DEPLOY_MODEL && mlTask.getState() != MLTaskState.CREATED) {
                runningDeployModelTaskIds.add(entry.getKey());
                runningDeployModelIds.add(mlTask.getModelId());
            }
        }
        return Arrays.asList(runningDeployModelTaskIds.toArray(new String[0]), runningDeployModelIds.toArray(new String[0]));
    }

    /** Forwards an SDK update completion (result or throwable) to an {@link UpdateResponse} listener. */
    private void handleUpdateDataObjectCompletionStage(UpdateDataObjectResponse r, Throwable throwable, ActionListener<UpdateResponse> updateListener) {
        if (throwable != null) {
            updateListener.onFailure(SdkClientUtils.unwrapAndConvertToException(throwable));
            return;
        }
        try {
            updateListener.onResponse(r.updateResponse());
        } catch (Exception e) {
            updateListener.onFailure(e);
        }
    }

    /**
     * Wraps {@code actionListener} to log the update outcome. A non-UPDATED result is logged as an
     * error but still delivered as a response.
     */
    private ActionListener<UpdateResponse> getUpdateResponseListener(String taskId, ActionListener<UpdateResponse> actionListener) {
        return ActionListener.wrap(updateResponse -> {
            if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
                log.error("Failed to update the task with ID: {}", taskId);
            } else {
                log.info("Successfully updated the task with ID: {}", taskId);
            }
            actionListener.onResponse(updateResponse);
        }, exception -> {
            // Pass the exception as the trailing throwable (not a placeholder) so Log4j keeps the stack trace.
            log.error("Failed to update ML task with ID {}", taskId, exception);
            actionListener.onFailure(exception);
        });
    }

    /**
     * Indexes the BATCH_TASK_UPDATE job (1-minute interval) once. The started flag is set only
     * after the job document is indexed successfully.
     */
    public void startTaskPollingJob() {
        if (taskPollingJobStarted) {
            return;
        }
        try {
            MLJobParameter jobParameter = new MLJobParameter(
                MLJobType.BATCH_TASK_UPDATE.name(),
                new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES),
                20L,
                null,
                MLJobType.BATCH_TASK_UPDATE,
                true
            );
            IndexRequest indexRequest = new IndexRequest()
                .index(".plugins-ml-jobs")
                .id(MLJobType.BATCH_TASK_UPDATE.name())
                .source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            indexJob(indexRequest, MLJobType.BATCH_TASK_UPDATE, () -> taskPollingJobStarted = true);
        } catch (IOException e) {
            log.error("Failed to index task polling job", e);
        }
    }

    /** Re-indexes the stats collector job with the new enabled flag when the setting changes. */
    public void onStaticMetricCollectionEnabledChanged(boolean isEnabled) {
        log.info("Static metric collection setting changed to: {}", isEnabled);
        indexStatsCollectorJob(isEnabled);
    }

    /**
     * Indexes (or overwrites) the STATS_COLLECTOR job document, which uses a 5-minute interval.
     *
     * @param enabled value written to the job's enabled flag
     */
    public void indexStatsCollectorJob(boolean enabled) {
        try {
            MLJobParameter jobParameter = new MLJobParameter(
                MLJobType.STATS_COLLECTOR.name(),
                new IntervalSchedule(Instant.now(), 5, ChronoUnit.MINUTES),
                60L,
                null,
                MLJobType.STATS_COLLECTOR,
                enabled
            );
            IndexRequest indexRequest = new IndexRequest()
                .index(".plugins-ml-jobs")
                .id(MLJobType.STATS_COLLECTOR.name())
                .source(jobParameter.toXContent(JsonXContent.contentBuilder(), null))
                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            indexJob(indexRequest, MLJobType.STATS_COLLECTOR, () -> {});
        } catch (IOException e) {
            log.error("Failed to index stats collection job", e);
        }
    }

    /**
     * Ensures the jobs index exists, then indexes {@code indexRequest} with a stashed thread context.
     *
     * @param successCallback run after the job document is indexed; may be null
     */
    private void indexJob(IndexRequest indexRequest, MLJobType jobType, Runnable successCallback) {
        mlIndicesHandler.initMLJobsIndex(ActionListener.wrap(success -> {
            if (!success) {
                // Previously a silent no-op; make the skipped job visible in the logs.
                log.warn("ML jobs index initialization returned false; skipping {} job", jobType.name());
                return;
            }
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                client.index(indexRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
                    log.info("Indexed {} successfully", jobType.name());
                    if (successCallback != null) {
                        successCallback.run();
                    }
                }, e -> log.error("Failed to index {} job", jobType.name(), e)), context::restore));
            }
        }, e -> log.error("Failed to initialize ML jobs index", e)));
    }
}

