package io.ray.runtime;

import io.ray.api.ActorHandle;
import io.ray.api.BaseActorHandle;
import io.ray.api.CppActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
import io.ray.api.WaitResult;
import io.ray.api.concurrencygroup.ConcurrencyGroup;
import io.ray.api.exception.RuntimeEnvException;
import io.ray.api.function.CppActorClass;
import io.ray.api.function.CppActorMethod;
import io.ray.api.function.CppFunction;
import io.ray.api.function.PyActorClass;
import io.ray.api.function.PyActorMethod;
import io.ray.api.function.PyFunction;
import io.ray.api.function.RayFunc;
import io.ray.api.function.RayFuncR;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
import io.ray.api.options.PlacementGroupCreationOptions;
import io.ray.api.parallelactor.ParallelActorContext;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtime.RayRuntime;
import io.ray.api.runtimecontext.RuntimeContext;
import io.ray.api.runtimeenv.RuntimeEnv;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.config.RunMode;
import io.ray.runtime.context.RuntimeContextImpl;
import io.ray.runtime.context.WorkerContext;
import io.ray.runtime.functionmanager.CppFunctionDescriptor;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.functionmanager.FunctionManager;
import io.ray.runtime.functionmanager.PyFunctionDescriptor;
import io.ray.runtime.functionmanager.RayFunction;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.generated.Common;
import io.ray.runtime.object.ObjectRefImpl;
import io.ray.runtime.object.ObjectStore;
import io.ray.runtime.runtimeenv.RuntimeEnvImpl;
import io.ray.runtime.task.ArgumentsBuilder;
import io.ray.runtime.task.FunctionArg;
import io.ray.runtime.task.TaskExecutor;
import io.ray.runtime.task.TaskSubmitter;
import io.ray.runtime.util.ConcurrencyGroupUtils;
import io.ray.runtime.utils.parallelactor.ParallelActorContextImpl;
import io.ray.shaded.com.fasterxml.jackson.core.JsonProcessingException;
import io.ray.shaded.com.fasterxml.jackson.databind.JsonNode;
import io.ray.shaded.com.fasterxml.jackson.databind.ObjectMapper;
import io.ray.shaded.com.fasterxml.jackson.databind.node.ObjectNode;
import io.ray.shaded.com.google.common.base.Preconditions;
import io.ray.shaded.com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/ray/runtime/AbstractRayRuntime.class */
/**
 * Base implementation of {@link RayRuntime} containing the logic shared by concrete runtimes.
 *
 * <p>Apart from {@code rayConfig} and {@code runtimeContext}, the protected collaborators are not
 * assigned in this class; presumably concrete subclasses set them during construction or
 * {@link #start()} (TODO confirm). Subclasses must also implement {@link #getGcsClient()},
 * {@link #start()}, {@link #run()} and {@link #getCurrentReturnIds(int, ActorId)}.
 */
public abstract class AbstractRayRuntime implements RayRuntime {
    /** Method name used as the constructor when creating Python actors. */
    public static final String PYTHON_INIT_METHOD_NAME = "__init__";

    protected RayConfig rayConfig;
    protected TaskExecutor taskExecutor;
    protected FunctionManager functionManager;
    // Built during field initialization, so it receives a reference to a partially constructed runtime.
    protected RuntimeContext runtimeContext = new RuntimeContextImpl(this);
    protected ObjectStore objectStore;
    protected TaskSubmitter taskSubmitter;
    protected WorkerContext workerContext;

    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
    // Static: a single context instance is returned by every runtime loaded with this class.
    private static final ParallelActorContextImpl PARALLEL_ACTOR_CONTEXT = new ParallelActorContextImpl();
    private static final ObjectMapper MAPPER = new ObjectMapper();

    public AbstractRayRuntime(RayConfig rayConfig) {
        this.rayConfig = rayConfig;
    }

    /**
     * Stores {@code obj} in the object store owned by the current worker.
     *
     * @param obj the value to store; may be {@code null}, in which case the ref's type is {@code Object}
     * @return a reference to the stored object
     */
    @Override
    public <T> ObjectRef<T> put(T obj) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Putting Object in Task {}.", this.workerContext.getCurrentTaskId());
        }
        return new ObjectRefImpl<>(this.objectStore.put(obj), typeOf(obj), true);
    }

    /** Returns the GCS client used for placement-group queries. */
    public abstract GcsClient getGcsClient();

    public abstract void start();

    public abstract void run();

    /**
     * Stores {@code obj} in the object store with {@code owner} as the owner of the new object.
     *
     * @param obj the value to store; may be {@code null}, in which case the ref's type is {@code Object}
     * @param owner the actor whose id is passed to the object store as the owner
     * @return a reference to the stored object
     */
    @Override
    public <T> ObjectRef<T> put(T obj, BaseActorHandle owner) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Putting an object in task {} with {} as the owner.",
                this.workerContext.getCurrentTaskId(), owner.getId());
        }
        return new ObjectRefImpl<>(this.objectStore.put(obj, owner.getId()), typeOf(obj), true);
    }

    /** Gets a single object, waiting with a timeout of {@code -1}. */
    @Override
    public <T> T get(ObjectRef<T> objectRef) throws RuntimeException {
        return get(objectRef, -1L);
    }

    /** Gets a single object by delegating to the list form with a one-element list. */
    @Override
    public <T> T get(ObjectRef<T> objectRef, long timeoutMs) throws RuntimeException {
        return get(ImmutableList.of(objectRef), timeoutMs).get(0);
    }

    /** Gets a list of objects, waiting with a timeout of {@code -1}. */
    @Override
    public <T> List<T> get(List<ObjectRef<T>> objectRefs) {
        return get(objectRefs, -1L);
    }

    /**
     * Fetches the given objects from the object store.
     *
     * <p>Every element must be an {@link ObjectRefImpl}. The element type passed to the object store
     * is taken from the last ref in the list ({@code null} when the list is empty).
     *
     * @param objectRefs the refs to resolve
     * @param timeoutMs timeout forwarded to the object store
     * @return the resolved values, as returned by the object store
     */
    @Override
    public <T> List<T> get(List<ObjectRef<T>> objectRefs, long timeoutMs) {
        List<ObjectId> objectIds = new ArrayList<>(objectRefs.size());
        Class<T> objectType = null;
        for (ObjectRef<T> objectRef : objectRefs) {
            ObjectRefImpl<T> impl = (ObjectRefImpl<T>) objectRef;
            objectIds.add(impl.getId());
            objectType = impl.getType();
        }
        LOGGER.debug("Getting Objects {}.", objectIds);
        return this.objectStore.get(objectIds, objectType, timeoutMs);
    }

    /**
     * Deletes the given objects from the object store.
     *
     * @param objectRefs refs to free; every element must be an {@link ObjectRefImpl}
     * @param localOnly forwarded to {@link ObjectStore#delete}
     */
    @Override
    public void free(List<ObjectRef<?>> objectRefs, boolean localOnly) {
        List<ObjectId> objectIds = objectRefs.stream()
            .map(objectRef -> ((ObjectRefImpl<?>) objectRef).getId())
            .collect(Collectors.toList());
        LOGGER.debug("Freeing Objects {}, localOnly = {}.", objectIds, Boolean.valueOf(localOnly));
        this.objectStore.delete(objectIds, localOnly);
    }

    /** Delegates to {@link ObjectStore#wait} with the same arguments. */
    @Override
    public <T> WaitResult<T> wait(List<ObjectRef<T>> waitList, int numReturns, int timeoutMs, boolean fetchLocal) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Waiting Objects {} with minimum number {} within {} ms.",
                waitList, Integer.valueOf(numReturns), Integer.valueOf(timeoutMs));
        }
        return this.objectStore.wait(waitList, numReturns, timeoutMs, fetchLocal);
    }

    /** Submits a normal (non-actor) Java task. */
    @Override
    public ObjectRef call(RayFunc func, Object[] args, CallOptions options) {
        RayFunction rayFunction = this.functionManager.getFunction(func);
        return callNormalFunction(rayFunction.functionDescriptor, args, rayFunction.getReturnType(), options);
    }

    /** Submits a normal task implemented by a Python function. */
    @Override
    public ObjectRef call(PyFunction pyFunction, Object[] args, CallOptions options) {
        PyFunctionDescriptor descriptor =
            new PyFunctionDescriptor(pyFunction.moduleName, "", pyFunction.functionName);
        return callNormalFunction(descriptor, args, Optional.of(pyFunction.returnType), options);
    }

    /** Submits a normal task implemented by a C++ function. */
    @Override
    public ObjectRef call(CppFunction cppFunction, Object[] args, CallOptions options) {
        CppFunctionDescriptor descriptor = new CppFunctionDescriptor(cppFunction.functionName, "JAVA", "");
        return callNormalFunction(descriptor, args, Optional.of(cppFunction.returnType), options);
    }

    /** Submits a task to a Java actor. */
    @Override
    public ObjectRef callActor(ActorHandle<?> actor, RayFunc func, Object[] args, CallOptions options) {
        RayFunction rayFunction = this.functionManager.getFunction(func);
        return callActorFunction(actor, rayFunction.functionDescriptor, args, rayFunction.getReturnType(), options);
    }

    /** Submits a task to a Python actor using default call options. */
    @Override
    public ObjectRef callActor(PyActorHandle pyActor, PyActorMethod pyActorMethod, Object... args) {
        PyFunctionDescriptor descriptor = new PyFunctionDescriptor(
            pyActor.getModuleName(), pyActor.getClassName(), pyActorMethod.methodName);
        return callActorFunction(pyActor, descriptor, args, Optional.of(pyActorMethod.returnType),
            new CallOptions.Builder().build());
    }

    /** Submits a task to a C++ actor using default call options. */
    @Override
    public ObjectRef callActor(CppActorHandle cppActor, CppActorMethod cppActorMethod, Object[] args) {
        CppFunctionDescriptor descriptor =
            new CppFunctionDescriptor(cppActorMethod.methodName, "JAVA", cppActor.getClassName());
        return callActorFunction(cppActor, descriptor, args, Optional.of(cppActorMethod.returnType),
            new CallOptions.Builder().build());
    }

    /** Creates a Java actor. */
    @Override
    @SuppressWarnings("unchecked")
    public <T> ActorHandle<T> createActor(RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) {
        FunctionDescriptor descriptor = this.functionManager.getFunction(actorFactoryFunc).functionDescriptor;
        return (ActorHandle<T>) createActorImpl(descriptor, args, options);
    }

    /** Creates a Python actor whose constructor is {@link #PYTHON_INIT_METHOD_NAME}. */
    @Override
    public PyActorHandle createActor(PyActorClass pyActorClass, Object[] args, ActorCreationOptions options) {
        PyFunctionDescriptor descriptor =
            new PyFunctionDescriptor(pyActorClass.moduleName, pyActorClass.className, PYTHON_INIT_METHOD_NAME);
        return (PyActorHandle) createActorImpl(descriptor, args, options);
    }

    /** Creates a C++ actor. */
    @Override
    public CppActorHandle createActor(CppActorClass cppActorClass, Object[] args, ActorCreationOptions options) {
        CppFunctionDescriptor descriptor =
            new CppFunctionDescriptor(cppActorClass.createFunctionName, "JAVA", cppActorClass.className);
        return (CppActorHandle) createActorImpl(descriptor, args, options);
    }

    /**
     * Creates a placement group through the task submitter.
     *
     * @throws NullPointerException if {@code creationOptions} is {@code null}
     */
    @Override
    public PlacementGroup createPlacementGroup(PlacementGroupCreationOptions creationOptions) {
        Preconditions.checkNotNull(creationOptions,
            "`PlacementGroupCreationOptions` must be specified when creating a new placement group.");
        return this.taskSubmitter.createPlacementGroup(creationOptions);
    }

    @Override
    public void removePlacementGroup(PlacementGroupId id) {
        this.taskSubmitter.removePlacementGroup(id);
    }

    @Override
    public PlacementGroup getPlacementGroup(PlacementGroupId id) {
        return getGcsClient().getPlacementGroupInfo(id);
    }

    /**
     * Looks up a placement group by name.
     *
     * @param name the placement group name
     * @param namespace the namespace to search; when {@code null}, the runtime context's namespace is used
     */
    @Override
    public PlacementGroup getPlacementGroup(String name, String namespace) {
        String effectiveNamespace = namespace != null ? namespace : this.runtimeContext.getNamespace();
        return getGcsClient().getPlacementGroupInfo(name, effectiveNamespace);
    }

    @Override
    public List<PlacementGroup> getAllPlacementGroups() {
        return getGcsClient().getAllPlacementGroupInfo();
    }

    @Override
    public boolean waitPlacementGroupReady(PlacementGroupId id, int timeoutSeconds) {
        return this.taskSubmitter.waitPlacementGroupReady(id, timeoutSeconds);
    }

    /** Returns the actor handle known to the task submitter for {@code actorId}, cast to the caller's type. */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends BaseActorHandle> T getActorHandle(ActorId actorId) {
        return (T) this.taskSubmitter.getActor(actorId);
    }

    @Override
    public ConcurrencyGroup createConcurrencyGroup(String name, int maxConcurrency, List<RayFunc> funcs) {
        return new ConcurrencyGroupImpl(name, maxConcurrency, funcs);
    }

    @Override
    public List<ConcurrencyGroup> extractConcurrencyGroups(RayFuncR<?> actorConstructorLambda) {
        return ConcurrencyGroupUtils.extractConcurrencyGroupsByAnnotations(actorConstructorLambda);
    }

    @Override
    public ParallelActorContext getParallelActorContext() {
        return PARALLEL_ACTOR_CONTEXT;
    }

    @Override
    public RuntimeEnv createRuntimeEnv() {
        return new RuntimeEnvImpl();
    }

    /**
     * Parses a JSON-serialized runtime env.
     *
     * @param serializedRuntimeEnv JSON text whose top level must be an object
     * @return a runtime env backed by the parsed JSON object
     * @throws RuntimeEnvException if the text is not valid JSON or its top level is not a JSON object
     */
    @Override
    public RuntimeEnv deserializeRuntimeEnv(String serializedRuntimeEnv) throws RuntimeEnvException {
        JsonNode parsed;
        try {
            parsed = MAPPER.readTree(serializedRuntimeEnv);
        } catch (JsonProcessingException e) {
            // The content is intentionally left out of the message: it may carry env vars or other secrets.
            throw new RuntimeEnvException("Failed to deserialize the runtime env from JSON.", e);
        }
        if (!(parsed instanceof ObjectNode)) {
            // Arrays, scalars and empty input used to surface as an unhelpful ClassCastException.
            throw new RuntimeEnvException("The serialized runtime env must be a JSON object.");
        }
        RuntimeEnvImpl runtimeEnvImpl = new RuntimeEnvImpl();
        runtimeEnvImpl.runtimeEnvs = (ObjectNode) parsed;
        return runtimeEnvImpl;
    }

    /**
     * Submits a normal task and returns a ref to its single return value.
     *
     * @param returnType present when the function returns a value; empty means zero returns
     * @param options call options; {@code null} is replaced by a default-built {@link CallOptions}
     * @return the return-value ref, or {@code null} when the function has no return value
     */
    @SuppressWarnings("rawtypes")
    private ObjectRef callNormalFunction(FunctionDescriptor functionDescriptor, Object[] args,
            Optional<Class<?>> returnType, CallOptions options) {
        int numReturns = returnType.isPresent() ? 1 : 0;
        List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
        CallOptions effectiveOptions = options != null ? options : new CallOptions.Builder().build();
        ObjectRefImpl returnRef = new ObjectRefImpl();
        List<ObjectId> preparedReturnIds = getCurrentReturnIds(numReturns, ActorId.NIL);
        registerPreparedReturnRef(preparedReturnIds, numReturns, returnRef);
        List<ObjectId> returnIds =
            this.taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, effectiveOptions);
        return completeCall(preparedReturnIds, returnIds, numReturns, returnType, returnRef);
    }

    /**
     * Submits an actor task and returns a ref to its single return value.
     *
     * @param returnType present when the method returns a value; empty means zero returns
     * @param options call options, forwarded unchanged to the task submitter
     * @return the return-value ref, or {@code null} when the method has no return value
     */
    @SuppressWarnings("rawtypes")
    private ObjectRef callActorFunction(BaseActorHandle actor, FunctionDescriptor functionDescriptor,
            Object[] args, Optional<Class<?>> returnType, CallOptions options) {
        int numReturns = returnType.isPresent() ? 1 : 0;
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Submitting Actor Task {}.", functionDescriptor);
        }
        List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
        ObjectRefImpl returnRef = new ObjectRefImpl();
        List<ObjectId> preparedReturnIds = getCurrentReturnIds(numReturns, actor.getId());
        registerPreparedReturnRef(preparedReturnIds, numReturns, returnRef);
        List<ObjectId> returnIds =
            this.taskSubmitter.submitActorTask(actor, functionDescriptor, functionArgs, numReturns, options);
        return completeCall(preparedReturnIds, returnIds, numReturns, returnType, returnRef);
    }

    /**
     * In cluster mode, registers {@code returnRef} under the first prepared return id before the task is
     * submitted. Does nothing in other run modes or when the call has no return value.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void registerPreparedReturnRef(List<ObjectId> preparedReturnIds, int numReturns, ObjectRefImpl returnRef) {
        if (this.rayConfig.runMode == RunMode.CLUSTER && numReturns > 0) {
            ObjectRefImpl.registerObjectRefImpl(preparedReturnIds.get(0), returnRef);
        }
    }

    /**
     * Checks the submitted return ids against the prepared ones and, if there is a return value,
     * initializes {@code returnRef} with it.
     *
     * @return {@code returnRef} after initialization, or {@code null} when there are no return ids
     * @throws IllegalStateException if the number or identity of return ids is unexpected
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private ObjectRef completeCall(List<ObjectId> preparedReturnIds, List<ObjectId> returnIds, int numReturns,
            Optional<Class<?>> returnType, ObjectRefImpl returnRef) {
        Preconditions.checkState(returnIds.size() == numReturns);
        validatePreparedReturnIds(preparedReturnIds, returnIds);
        if (returnIds.isEmpty()) {
            return null;
        }
        returnRef.init(returnIds.get(0), returnType.get(), true);
        return returnRef;
    }

    /**
     * Creates an actor through the task submitter.
     *
     * @throws IllegalArgumentException in local mode when the actor is not a Java actor
     * @throws IllegalStateException when a non-Java actor is given non-empty JVM options
     */
    private BaseActorHandle createActorImpl(FunctionDescriptor functionDescriptor, Object[] args,
            ActorCreationOptions options) {
        if (LOGGER.isDebugEnabled()) {
            if (options == null) {
                LOGGER.debug("Creating Actor {} with default options.", functionDescriptor);
            } else {
                LOGGER.debug("Creating Actor {}, jvmOptions = {}.", functionDescriptor, options.jvmOptions);
            }
        }
        boolean isJavaActor = functionDescriptor.getLanguage() == Common.Language.JAVA;
        if (this.rayConfig.runMode == RunMode.LOCAL && !isJavaActor) {
            throw new IllegalArgumentException("Ray doesn't support cross-language invocation in local mode.");
        }
        List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
        if (!isJavaActor && options != null) {
            // JVM options only make sense for Java actors.
            Preconditions.checkState(options.jvmOptions == null || options.jvmOptions.size() == 0);
        }
        return this.taskSubmitter.createActor(functionDescriptor, functionArgs, options);
    }

    /**
     * Returns the ids that the next submitted task is expected to use for its returns. In cluster mode
     * these are checked against the real ids by {@link #validatePreparedReturnIds}.
     */
    abstract List<ObjectId> getCurrentReturnIds(int numReturns, ActorId actorId);

    public WorkerContext getWorkerContext() {
        return this.workerContext;
    }

    public ObjectStore getObjectStore() {
        return this.objectStore;
    }

    public TaskExecutor getTaskExecutor() {
        return this.taskExecutor;
    }

    public FunctionManager getFunctionManager() {
        return this.functionManager;
    }

    public RayConfig getRayConfig() {
        return this.rayConfig;
    }

    @Override
    public RuntimeContext getRuntimeContext() {
        return this.runtimeContext;
    }

    /**
     * In cluster mode, asserts that the prepared return ids equal the ids actually returned by task
     * submission, position by position. Other run modes skip the check.
     *
     * @throws IllegalStateException if the lists differ in size or in any element
     */
    void validatePreparedReturnIds(List<ObjectId> preparedIds, List<ObjectId> returnIds) {
        if (this.rayConfig.runMode != RunMode.CLUSTER) {
            return;
        }
        Preconditions.checkState(returnIds.size() == preparedIds.size());
        for (int i = 0; i < preparedIds.size(); i++) {
            ObjectId preparedId = preparedIds.get(i);
            ObjectId returnId = returnIds.get(i);
            // Guava's Preconditions substitutes %s, not SLF4J-style {} placeholders.
            Preconditions.checkState(preparedId.equals(returnId),
                "The prepared object id %s is not equal to the real return id %s", preparedId, returnId);
        }
    }

    /** Returns the runtime class of {@code value}, or {@code Object.class} when it is {@code null}. */
    @SuppressWarnings("unchecked")
    private static <T> Class<T> typeOf(T value) {
        return (Class<T>) (value == null ? Object.class : value.getClass());
    }
}
