/*
 * Decompiled with CFR 0.152.
 */
package org.sinytra.adapter.patch.transformer.dynamic;

import com.google.common.collect.ArrayListMultimap;
import com.mojang.datafixers.util.Pair;
import com.mojang.logging.LogUtils;
import it.unimi.dsi.fastutil.ints.Int2IntLinkedOpenHashMap;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.LocalVariableNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.VarInsnNode;
import org.sinytra.adapter.patch.LVTOffsets;
import org.sinytra.adapter.patch.PatchInstance;
import org.sinytra.adapter.patch.analysis.LocalVariableLookup;
import org.sinytra.adapter.patch.analysis.params.EnhancedParamsDiff;
import org.sinytra.adapter.patch.analysis.params.ParamsDiffSnapshot;
import org.sinytra.adapter.patch.analysis.params.SimpleParamsDiffSnapshot;
import org.sinytra.adapter.patch.api.MethodContext;
import org.sinytra.adapter.patch.api.MethodTransform;
import org.sinytra.adapter.patch.api.Patch;
import org.sinytra.adapter.patch.api.PatchContext;
import org.sinytra.adapter.patch.selector.AnnotationHandle;
import org.sinytra.adapter.patch.selector.AnnotationValueHandle;
import org.sinytra.adapter.patch.transformer.param.ParamTransformTarget;
import org.sinytra.adapter.patch.util.AdapterUtil;
import org.slf4j.Logger;

public record DynamicLVTPatch(Supplier<LVTOffsets> lvtOffsets) implements MethodTransform
{
    private static final Logger LOGGER = LogUtils.getLogger();
    private static final Set<String> ANNOTATIONS = Set.of("Lorg/spongepowered/asm/mixin/injection/Inject;", "Lcom/llamalad7/mixinextras/injector/ModifyExpressionValue;", "Lorg/spongepowered/asm/mixin/injection/ModifyVariable;");

    @Override
    public Patch.Result apply(ClassNode classNode, MethodNode methodNode, MethodContext methodContext, PatchContext context) {
        ParamsDiffSnapshot diff;
        Type[] paramTypes;
        List<Pair> localAnnotations;
        AnnotationHandle annotation = methodContext.methodAnnotation();
        if (methodNode.invisibleParameterAnnotations != null && !(localAnnotations = AdapterUtil.getAnnotatedParameters(methodNode, paramTypes = Type.getArgumentTypes((String)methodNode.desc), "Lcom/llamalad7/mixinextras/sugar/Local;", Pair::of)).isEmpty()) {
            Patch.Result result = Patch.Result.PASS;
            for (Pair pair : localAnnotations) {
                result = result.or(this.offsetParameterIndex(classNode, methodNode, new AnnotationHandle((AnnotationNode)pair.getFirst()), (Type)pair.getSecond(), methodContext));
            }
            return result;
        }
        if (!ANNOTATIONS.contains(annotation.getDesc())) {
            return Patch.Result.PASS;
        }
        if (annotation.matchesDesc("Lorg/spongepowered/asm/mixin/injection/ModifyVariable;")) {
            AnnotationValueHandle ordinal;
            Patch.Result result = this.offsetVariableIndex(classNode, methodNode, annotation, methodContext);
            if (result == Patch.Result.PASS && (ordinal = (AnnotationValueHandle)annotation.getValue("ordinal").orElse(null)) == null && annotation.getValue("name").isEmpty()) {
                Type[] args = Type.getArgumentTypes((String)methodNode.desc);
                if (args.length < 1) {
                    return Patch.Result.PASS;
                }
                MethodContext.TargetPair targetPair = methodContext.findDirtyInjectionTarget();
                if (targetPair == null) {
                    return Patch.Result.PASS;
                }
                for (Integer level : methodContext.getLvtCompatLevelsOrdered()) {
                    List<MethodContext.LocalVariable> available = methodContext.getTargetMethodLocals(targetPair, 0, level);
                    if (available == null) {
                        return Patch.Result.PASS;
                    }
                    Type expected = args[0];
                    int count = (int)available.stream().filter(lv -> lv.type().equals((Object)expected)).count();
                    if (count != 1) continue;
                    annotation.appendValue("ordinal", 0);
                    return Patch.Result.APPLY;
                }
            }
            return result;
        }
        if (annotation.matchesDesc("Lorg/spongepowered/asm/mixin/injection/Inject;") && annotation.getValue("locals").isPresent() && (diff = this.compareParameters(classNode, methodNode, methodContext)) != null) {
            MethodTransform transform = diff.asParameterTransformer(ParamTransformTarget.METHOD, true);
            return transform.apply(classNode, methodNode, methodContext, methodContext.patchContext());
        }
        return Patch.Result.PASS;
    }

    private Patch.Result offsetParameterIndex(ClassNode classNode, MethodNode methodNode, AnnotationHandle annotation, Type paramType, MethodContext methodContext) {
        Patch.Result result = this.offsetVariableIndex(classNode, methodNode, annotation, methodContext);
        if (result == Patch.Result.PASS && annotation.getAllValues().isEmpty()) {
            List<MethodContext.LocalVariable> oldCompatLocals;
            List<MethodContext.LocalVariable> sameType;
            int compatLevel = methodContext.patchContext().environment().fabricLVTCompatibility();
            if (compatLevel != 10000) {
                return Patch.Result.PASS;
            }
            MethodContext.TargetPair targetPair = methodContext.findDirtyInjectionTarget();
            if (targetPair == null) {
                return Patch.Result.PASS;
            }
            List<MethodContext.LocalVariable> locals = methodContext.getTargetMethodLocals(targetPair, 0, compatLevel);
            if (locals != null && locals.stream().filter(var -> var.type() == paramType).count() > 1L && (sameType = (oldCompatLocals = methodContext.getTargetMethodLocals(targetPair, 0, 9002)).stream().filter(var -> var.type() == paramType).toList()).size() == 1) {
                int index = sameType.get(0).index();
                annotation.appendValue("index", index);
                LOGGER.info(PatchInstance.MIXINPATCH, "Fixing @Local annotation target on {}.{} using index {}", new Object[]{classNode.name, methodNode.name, index});
                return Patch.Result.APPLY;
            }
        }
        return result;
    }

    private Patch.Result offsetVariableIndex(ClassNode classNode, MethodNode methodNode, AnnotationHandle annotation, MethodContext methodContext) {
        AnnotationValueHandle handle = annotation.getValue("index").orElse(null);
        if (handle != null) {
            int index = (Integer)handle.get();
            if (index == -1) {
                return Patch.Result.PASS;
            }
            MethodContext.TargetPair targetPair = methodContext.findDirtyInjectionTarget();
            if (targetPair == null) {
                return Patch.Result.PASS;
            }
            ClassNode targetClass = targetPair.classNode();
            MethodNode targetMethod = targetPair.methodNode();
            OptionalInt reorder = this.lvtOffsets.get().findReorder(targetClass.name, targetMethod.name, targetMethod.desc, index);
            if (reorder.isPresent()) {
                int newIndex = reorder.getAsInt();
                LOGGER.info(PatchInstance.MIXINPATCH, "Swapping {} index in {}.{} from {} for {}", new Object[]{annotation.getDesc(), classNode.name, methodNode.name, index, newIndex});
                handle.set(newIndex);
                return Patch.Result.APPLY;
            }
        }
        return Patch.Result.PASS;
    }

    @Nullable
    private ParamsDiffSnapshot compareParameters(ClassNode classNode, MethodNode methodNode, MethodContext methodContext) {
        int maxInsert;
        ParamsDiffSnapshot offsetDiff;
        AdapterUtil.CapturedLocals capturedLocals = AdapterUtil.getCapturedLocals(methodNode, methodContext);
        if (capturedLocals == null) {
            return null;
        }
        List<MethodContext.LocalVariable> available = methodContext.getTargetMethodLocals(capturedLocals.target());
        if (available == null) {
            return null;
        }
        List<Type> availableTypes = available.stream().map(MethodContext.LocalVariable::type).toList();
        Record diff = EnhancedParamsDiff.createLayered(capturedLocals.expected(), availableTypes);
        if (diff.isEmpty()) {
            return null;
        }
        if (!diff.replacements().isEmpty() && DynamicLVTPatch.areReplacedParamsUsed(diff.replacements(), methodNode)) {
            SimpleParamsDiffSnapshot rearrange = DynamicLVTPatch.rearrangeParameters(capturedLocals.expected(), availableTypes);
            if (rearrange == null) {
                LOGGER.debug("Tried to replace local variables in mixin method {}.{} using {}", new Object[]{classNode.name, methodNode.name + methodNode.desc, diff.replacements()});
                return null;
            }
            diff = rearrange;
        }
        int paramLocalStart = capturedLocals.paramLocalStart();
        if (!diff.removals().isEmpty()) {
            List<LocalVariableNode> lvt = methodNode.localVariables.stream().sorted(Comparator.comparingInt(lvn -> lvn.index)).toList();
            for (int removal : diff.removals()) {
                int removalLocal = removal + capturedLocals.lvtOffset() + paramLocalStart;
                if (removalLocal >= lvt.size()) continue;
                int removalIndex = lvt.get((int)removalLocal).index;
                for (AbstractInsnNode insn : methodNode.instructions) {
                    if (!(insn instanceof VarInsnNode)) continue;
                    VarInsnNode varInsn = (VarInsnNode)insn;
                    if (varInsn.var != removalIndex) continue;
                    LOGGER.debug("Cannot remove parameter {} in mixin method {}.{}", new Object[]{removal, classNode.name, methodNode.name + methodNode.desc});
                    return null;
                }
            }
        }
        if ((offsetDiff = diff.offset(paramLocalStart, maxInsert = DynamicLVTPatch.getMaxLocalIndex(capturedLocals.expected(), diff.insertions()))).isEmpty()) {
            return null;
        }
        return offsetDiff;
    }

    private static boolean areReplacedParamsUsed(List<Pair<Integer, Type>> replacements, MethodNode methodNode) {
        LocalVariableLookup lookup = new LocalVariableLookup(methodNode);
        Set paramLocals = replacements.stream().map(p -> lookup.getByParameterOrdinal((int)((Integer)p.getFirst()).intValue()).index).collect(Collectors.toSet());
        for (AbstractInsnNode insn : methodNode.instructions) {
            if (!(insn instanceof VarInsnNode)) continue;
            VarInsnNode varInsn = (VarInsnNode)insn;
            if (!paramLocals.contains(varInsn.var)) continue;
            return true;
        }
        return false;
    }

    private static int getMaxLocalIndex(List<Type> expected, List<Pair<Integer, Type>> insertions) {
        int maxIndex = expected.size();
        for (Pair<Integer, Type> pair : insertions) {
            int at = (Integer)pair.getFirst();
            if (at >= maxIndex) continue;
            ++maxIndex;
        }
        return maxIndex;
    }

    @VisibleForTesting
    @Nullable
    public static SimpleParamsDiffSnapshot rearrangeParameters(List<Type> parameterTypes, List<Type> newParameterTypes) {
        Object2IntOpenHashMap typeCount = new Object2IntOpenHashMap();
        ArrayListMultimap typeIndices = ArrayListMultimap.create();
        for (int i = 0; i < parameterTypes.size(); ++i) {
            Type type = parameterTypes.get(i);
            typeCount.put((Object)type, typeCount.getInt((Object)type) + 1);
            typeIndices.put((Object)type, (Object)i);
        }
        Object2IntOpenHashMap newTypeCount = new Object2IntOpenHashMap();
        for (Type type : newParameterTypes) {
            newTypeCount.put((Object)type, newTypeCount.getInt((Object)type) + 1);
        }
        for (Object2IntMap.Entry entry : typeCount.object2IntEntrySet()) {
            if (newTypeCount.getInt(entry.getKey()) == entry.getIntValue()) continue;
            return null;
        }
        ArrayList<Pair> insertions = new ArrayList<Pair>();
        for (int i = 0; i < newParameterTypes.size(); ++i) {
            Type type = newParameterTypes.get(i);
            if (typeCount.containsKey((Object)type)) continue;
            insertions.add(Pair.of((Object)i, (Object)type));
        }
        Object2IntOpenHashMap seenTypes = new Object2IntOpenHashMap();
        Int2IntLinkedOpenHashMap swaps = new Int2IntLinkedOpenHashMap();
        for (int i = 0; i < newParameterTypes.size(); ++i) {
            Type type = newParameterTypes.get(i);
            if (!typeIndices.containsKey((Object)type)) continue;
            List indices = typeIndices.get((Object)type);
            int seen = seenTypes.getInt((Object)type);
            int oldIndex = (Integer)indices.get(seen);
            seenTypes.put((Object)type, seen + 1);
            if (oldIndex == i || swaps.containsKey(i)) continue;
            swaps.put(oldIndex, i);
        }
        if (swaps.isEmpty()) {
            return null;
        }
        ArrayList swapsList = new ArrayList();
        swaps.forEach((from, to) -> swapsList.add(Pair.of((Object)from, (Object)to)));
        return ((SimpleParamsDiffSnapshot.Builder)((SimpleParamsDiffSnapshot.Builder)SimpleParamsDiffSnapshot.builder().insertions(insertions)).swaps((List)swapsList)).build();
    }
}

