package io.ray.runtime.functionmanager;

import io.ray.api.function.RayFunc;
import io.ray.runtime.util.LambdaUtils;
import io.ray.shaded.com.google.common.collect.Lists;
import java.io.File;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Stream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.objectweb.asm.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Resolves {@link RayFunc} lambdas and {@link JavaFunctionDescriptor}s to {@link RayFunction}s.
 *
 * <p>All lookups go through a single {@link JobFunctionTable}. When a non-empty code search path
 * is supplied, that table loads classes through a {@link URLClassLoader} built from the path
 * entries; otherwise it uses the class loader that loaded this class.
 */
public class FunctionManager {

    /** Method name used in descriptors for constructors (the JVM's internal constructor name). */
    static final String CONSTRUCTOR_NAME = "<init>";

    private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class);

    /**
     * Per-thread cache from a lambda's class to the descriptor of the method it refers to, so
     * {@link LambdaUtils#getSerializedLambda} runs at most once per lambda class per thread. The
     * keys are held weakly, so the cache itself does not keep lambda classes reachable.
     */
    private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
            RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);

    /** Code search path entries as given to the constructor; may be null or empty. */
    private final List<String> codeSearchPath;

    private final JobFunctionTable jobFunctionTable;

    /**
     * Lazily loads and caches the callable functions (methods and constructors) of classes loaded
     * through one class loader.
     *
     * <p>Per-class function maps are built on first use while holding this table's monitor and are
     * then published through a {@link ConcurrentHashMap}; this class only reads them afterwards.
     */
    public static class JobFunctionTable {

        /** Class loader used to load every class this table resolves. */
        final ClassLoader classLoader;

        /**
         * Class name -> (method name, signature) -> (function, isDefault). A {@code null} function
         * marks a short signature that is ambiguous (see {@link #loadFunctionsForClass}).
         */
        ConcurrentMap<String, Map<Pair<String, String>, Pair<RayFunction, Boolean>>> functions =
                new ConcurrentHashMap<>();

        JobFunctionTable(ClassLoader classLoader) {
            this.classLoader = classLoader;
        }

        /**
         * Returns the function matching {@code descriptor}, loading its class's functions on first
         * use.
         *
         * <p>The descriptor's signature may be a full JVM method descriptor, an empty string, or the
         * decimal argument count. The latter two forms resolve only when they are not ambiguous.
         *
         * @param descriptor class name, method name and signature of the wanted function
         * @return the matching function, never {@code null}
         * @throws RuntimeException if the class cannot be loaded, no function matches, or a short
         *     signature matches several overloads
         */
        RayFunction getFunction(JavaFunctionDescriptor descriptor) {
            String className = descriptor.className;
            Map<Pair<String, String>, Pair<RayFunction, Boolean>> classFunctions =
                    functions.get(className);
            if (classFunctions == null) {
                // Double-checked so concurrent callers don't load the same class's functions twice.
                synchronized (this) {
                    classFunctions = functions.get(className);
                    if (classFunctions == null) {
                        classFunctions = loadFunctionsForClass(className);
                        functions.put(className, classFunctions);
                    }
                }
            }
            Pair<RayFunction, Boolean> entry =
                    classFunctions.get(ImmutablePair.of(descriptor.name, descriptor.signature));
            if (entry == null) {
                // Previously this path dereferenced the missing entry and threw an NPE.
                throw new RuntimeException(
                        String.format("RayFunction %s not found", descriptor.toString()));
            }
            RayFunction function = entry.getLeft();
            if (function == null) {
                throw new RuntimeException(String.format(
                        "RayFunction %s is overloaded, the signature can't be empty.",
                        descriptor.toString()));
            }
            return function;
        }

        /**
         * Loads {@code className} (initializing it) and indexes its callable functions.
         *
         * <p>Indexed executables are:
         * <ul>
         *   <li>the class's declared methods and constructors;
         *   <li>declared methods of its superclasses (excluding {@link Object});
         *   <li>default methods of the interfaces the class directly implements.
         * </ul>
         *
         * <p>Each executable is stored under three keys: its full JVM descriptor, {@code ""}, and
         * its argument count as a decimal string. The value pairs the function with whether it is
         * an interface default method.
         *
         * <p>For a short key, a non-default method replaces a default one. A method with the same
         * full descriptor as the current holder (an override) replaces it. Any other collision
         * stores a {@code (null, false)} marker meaning "ambiguous, signature required".
         *
         * @param className binary name of the class to load
         * @return the function index for the class
         * @throws RuntimeException wrapping any failure to load or inspect the class
         */
        Map<Pair<String, String>, Pair<RayFunction, Boolean>> loadFunctionsForClass(
                String className) {
            Map<Pair<String, String>, Pair<RayFunction, Boolean>> result = new HashMap<>();
            // Full descriptor currently backing each non-marker short-key entry; used to tell an
            // override (same descriptor) apart from a genuine overload.
            Map<Pair<String, String>, String> shortKeySignatures = new HashMap<>();
            try {
                Class<?> clazz = Class.forName(className, true, this.classLoader);
                List<Executable> executables = new ArrayList<>();
                executables.addAll(Arrays.asList(clazz.getDeclaredMethods()));
                executables.addAll(Arrays.asList(clazz.getDeclaredConstructors()));
                for (Class<?> superclass = clazz.getSuperclass();
                        superclass != null && superclass != Object.class;
                        superclass = superclass.getSuperclass()) {
                    executables.addAll(Arrays.asList(superclass.getDeclaredMethods()));
                }
                for (Class<?> iface : clazz.getInterfaces()) {
                    for (Method method : iface.getDeclaredMethods()) {
                        if (method.isDefault()) {
                            executables.add(method);
                        }
                    }
                }

                // Reverse order: default methods first, then superclasses from farthest to
                // nearest, then this class. Later (more specific) declarations therefore win on
                // identical full descriptors.
                for (Executable executable : Lists.reverse(executables)) {
                    executable.setAccessible(true);
                    boolean isMethod = executable instanceof Method;
                    String name = isMethod ? executable.getName() : CONSTRUCTOR_NAME;
                    Type type = isMethod
                            ? Type.getType((Method) executable)
                            : Type.getType((Constructor<?>) executable);
                    String signature = type.getDescriptor();
                    RayFunction rayFunction = new RayFunction(executable, this.classLoader,
                            new JavaFunctionDescriptor(className, name, signature));
                    boolean isDefault = isMethod && ((Method) executable).isDefault();
                    Pair<RayFunction, Boolean> entry = ImmutablePair.of(rayFunction, isDefault);
                    result.put(ImmutablePair.of(name, signature), entry);

                    // Short keys let callers (e.g. cross-language calls) omit the full signature.
                    String argCount = String.valueOf(type.getArgumentTypes().length);
                    for (String shortSignature : new String[] {"", argCount}) {
                        Pair<String, String> shortKey = ImmutablePair.of(name, shortSignature);
                        Pair<RayFunction, Boolean> existing = result.get(shortKey);
                        boolean overridesExisting = existing != null
                                && existing.getLeft() != null
                                && signature.equals(shortKeySignatures.get(shortKey));
                        if (existing == null || existing.getRight() || overridesExisting) {
                            result.put(shortKey, entry);
                            shortKeySignatures.put(shortKey, signature);
                        } else {
                            // Different descriptors share this short key: mark as ambiguous.
                            result.put(shortKey, ImmutablePair.<RayFunction, Boolean>of(null, false));
                        }
                    }
                }
                return result;
            } catch (Exception e) {
                throw new RuntimeException("Failed to load functions from class " + className, e);
            }
        }
    }

    /**
     * Creates a function manager.
     *
     * @param codeSearchPath directories, jar files, or other files whose classes should be
     *     loadable. A directory contributes itself plus every {@code .jar} beneath it, a jar
     *     contributes itself, and any other file contributes its parent directory. Blank or
     *     non-existent entries are skipped. May be {@code null} or empty to use this class's own
     *     class loader.
     */
    public FunctionManager(List<String> codeSearchPath) {
        this.codeSearchPath = codeSearchPath;
        // Must run after codeSearchPath is assigned: createJobFunctionTable() reads it. (A field
        // initializer would run before this constructor body and always see null.)
        this.jobFunctionTable = createJobFunctionTable();
    }

    /** Returns the class loader used to load user classes. */
    public ClassLoader getClassLoader() {
        return this.jobFunctionTable.classLoader;
    }

    /**
     * Resolves the method that a {@link RayFunc} lambda refers to.
     *
     * @param rayFunc a serializable lambda or method reference
     * @return the function implementing the lambda
     * @throws RuntimeException if the implementing function cannot be resolved
     */
    public RayFunction getFunction(RayFunc rayFunc) {
        Map<Class<? extends RayFunc>, JavaFunctionDescriptor> cache = RAY_FUNC_CACHE.get();
        Class<? extends RayFunc> lambdaClass = rayFunc.getClass();
        JavaFunctionDescriptor descriptor = cache.get(lambdaClass);
        if (descriptor == null) {
            SerializedLambda lambda = LambdaUtils.getSerializedLambda(rayFunc);
            // SerializedLambda reports the implementation class in slash-separated internal form.
            descriptor = new JavaFunctionDescriptor(lambda.getImplClass().replace('/', '.'),
                    lambda.getImplMethodName(), lambda.getImplMethodSignature());
            cache.put(lambdaClass, descriptor);
        }
        return getFunction(descriptor);
    }

    /**
     * Resolves a function by descriptor.
     *
     * @throws RuntimeException if the function cannot be resolved
     */
    public RayFunction getFunction(JavaFunctionDescriptor javaFunctionDescriptor) {
        return this.jobFunctionTable.getFunction(javaFunctionDescriptor);
    }

    /** Builds the function table, using a URL class loader when a code search path is set. */
    private JobFunctionTable createJobFunctionTable() {
        ClassLoader classLoader;
        if (codeSearchPath == null || codeSearchPath.isEmpty()) {
            classLoader = getClass().getClassLoader();
        } else {
            URL[] urls = codeSearchPath.stream()
                    .filter(path -> StringUtils.isNotBlank(path) && Files.exists(Paths.get(path)))
                    .flatMap(FunctionManager::resourceUrls)
                    .toArray(URL[]::new);
            classLoader = new URLClassLoader(urls);
            // Format the whole array: passing it as varargs would only fill the single {} with
            // the first URL.
            LOGGER.debug("Resource loaded from path {}.", Arrays.toString(urls));
        }
        return new JobFunctionTable(classLoader);
    }

    /**
     * Returns the class path URLs contributed by one existing code search path entry.
     *
     * <p>A directory yields itself plus every {@code .jar} found beneath it. A {@code .jar} file
     * yields itself. Any other file yields its parent directory.
     */
    private static Stream<URL> resourceUrls(String path) {
        try {
            if (Files.isDirectory(Paths.get(path))) {
                List<URL> urls = new ArrayList<>();
                urls.add(Paths.get(path).toAbsolutePath().toUri().toURL());
                for (File jar : FileUtils.listFiles(
                        new File(path), new RegexFileFilter(".*\\.jar"), DirectoryFileFilter.DIRECTORY)) {
                    urls.add(jar.toPath().toUri().toURL());
                }
                return urls.stream();
            }
            if (path.endsWith(".jar")) {
                return Stream.of(Paths.get(path).toAbsolutePath().toUri().toURL());
            }
            // Resolve before taking the parent: a relative single-segment path has no parent.
            return Stream.of(Paths.get(path).toAbsolutePath().getParent().toUri().toURL());
        } catch (MalformedURLException e) {
            throw new RuntimeException(String.format("Illegal %s resource path", path), e);
        }
    }
}
