/*
 * Copyright 2018 - 2024 TridentMC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tridevmc.compound.network.core;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tridevmc.compound.network.marshallers.*;
import com.tridevmc.compound.network.message.Message;
import com.tridevmc.compound.network.message.MessageConcept;
import com.tridevmc.compound.network.message.MessageField;
import com.tridevmc.compound.network.message.RegisteredMessage;
import net.minecraft.resources.ResourceLocation;
import net.neoforged.bus.api.SubscribeEvent;
import net.neoforged.fml.LogicalSide;
import net.neoforged.fml.ModContainer;
import net.neoforged.fml.ModList;
import net.neoforged.fml.loading.modscan.ModAnnotation;
import net.neoforged.neoforge.network.event.RegisterPayloadHandlersEvent;
import net.neoforged.neoforge.network.registration.PayloadRegistrar;
import net.neoforged.neoforgespi.language.ModFileScanData;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.maven.artifact.versioning.ArtifactVersion;
import org.objectweb.asm.Type;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;


/**
 * CompoundNetwork creates and manages networks.
 * <p>
 * Use {@link #createNetwork(ModContainer, String)} to create and register a network for a given channel.
 * <p>
 * Annotate implementations with {@link RegisteredMessage} and {@link RegisteredMarshaller} to register them with a network.
 */
public class CompoundNetwork {

    /**
     * Global index from registered message class to the network that owns it.
     * <p>
     * Concurrent so that networks created from different threads (for example mod constructors) and
     * lookups made later from networking code cannot corrupt the map.
     */
    private static final Map<Class<? extends Message>, CompoundNetwork> NETWORKS = new ConcurrentHashMap<>();

    private final Logger logger;
    private final ResourceLocation networkId;
    /** The channel name (path of {@link #networkId}); matched against the {@code channel} annotation attribute. */
    private final String name;

    /** Message class to concept, kept in discovery order. */
    private final Map<Class<? extends Message>, MessageConcept> messageConcepts;
    /** Marshaller id (every declared id, not just the default one) to marshaller instance. */
    private final Map<String, Marshaller> marshallers;
    /**
     * Accepted type to default marshaller id. Insertion-ordered so the supertype fallback in
     * {@link #getMarshallerIdFor(Field)} does not depend on identity hash codes, which differ between JVMs.
     */
    private final Map<Class<?>, String> marshallerIds;

    /**
     * @param name    the network id; its path is used as the channel name.
     * @param version the owning mod's version. NOTE(review): currently unused — the payload registrar is
     *                given the channel name instead (see onRegisterPayloadHandlerEvent); confirm intent.
     */
    private CompoundNetwork(ResourceLocation name, String version) {
        this.networkId = name;
        this.name = name.getPath();
        this.messageConcepts = Maps.newLinkedHashMap();
        this.marshallers = Maps.newHashMap();
        this.marshallerIds = Maps.newLinkedHashMap();
        this.logger = LogManager.getLogger("CompoundNetwork-" + name);
    }

    /**
     * Create a network with the given name with messages and marshallers loaded from the mod scan data.
     *
     * @param container the mod container of the mod this network belongs to.
     * @param channel   the name to use for the network.
     * @return the created network instance.
     * @throws RuntimeException if any marshaller or message fails to load; the cause carries the details.
     */
    public static CompoundNetwork createNetwork(ModContainer container, String channel) {
        try {
            ArtifactVersion version = container.getModInfo().getVersion();
            CompoundNetwork network = new CompoundNetwork(
                    ResourceLocation.fromNamespaceAndPath(container.getModId(), channel), version.toString());
            container.getEventBus().register(network);
            network.loadDefaultMarshallers();
            network.discoverMarshallers();
            network.discoverMessages();
            return network;
        } catch (Exception e) {
            throw new RuntimeException(String.format(
                    "Failed to create a CompoundNetwork with channel name %s",
                    channel),
                    e);
        }
    }

    /**
     * Finds the network that the given message class is registered to.
     *
     * @param msg the class of the registered message, may be {@code null}.
     * @return the network that the message class is registered to, or {@code null} if there is none.
     */
    public static CompoundNetwork getNetworkFor(Class<? extends Message> msg) {
        // ConcurrentHashMap rejects null keys; keep the "null in, null out" contract of the old HashMap.
        return msg == null ? null : NETWORKS.get(msg);
    }

    /**
     * Registers the built-in marshallers. Each marshaller is stored under every id it declares, and each
     * accepted type is mapped to the marshaller's first (default) id.
     */
    private void loadDefaultMarshallers() {
        for (MarshallerMetadata marshallerMeta : DefaultMarshallers.genDefaultMarshallers()) {
            String defaultId = marshallerMeta.ids[0];
            for (String id : marshallerMeta.ids) {
                this.marshallers.put(id, marshallerMeta.marshaller);
            }
            for (Class<?> type : marshallerMeta.acceptedTypes) {
                this.marshallerIds.put(type, defaultId);
            }
        }
    }

    /**
     * Instantiates every {@link RegisteredMarshaller} on this channel and registers it under its ids and
     * accepted types.
     * <p>
     * Marshallers are processed in ascending priority rank, so a later (higher-ranked) marshaller overwrites
     * an earlier one that shares an id or accepted type. Ties are broken by class name so the outcome does
     * not depend on scan-data iteration order.
     */
    private void discoverMarshallers() {
        List<ModFileScanData.AnnotationData> applicableMarshallers = this.getAnnotationDataOfType(RegisteredMarshaller.class);

        Comparator<ModFileScanData.AnnotationData> byPriority = Comparator.comparingInt(CompoundNetwork::getPriorityRank);
        applicableMarshallers.sort(byPriority.thenComparing(ModFileScanData.AnnotationData::memberName));

        for (ModFileScanData.AnnotationData applicableMarshaller : applicableMarshallers) {
            String className = applicableMarshaller.memberName();
            Marshaller marshaller = instantiateMarshaller(className);

            Map<String, Object> annotationInfo = applicableMarshaller.annotationData();
            List<String> ids = getListValue(annotationInfo, "ids");
            List<Type> acceptedTypes = getListValue(annotationInfo, "acceptedTypes");
            if (ids.isEmpty()) {
                throw new IllegalStateException(String.format(
                        "Class: \"%s\" annotated with RegisteredMarshaller does not declare any ids.",
                        className));
            }

            for (String id : ids) {
                this.marshallers.put(id, marshaller);
            }

            String defaultId = ids.get(0);
            for (Type acceptedType : acceptedTypes) {
                try {
                    this.marshallerIds.put(Class.forName(acceptedType.getClassName()), defaultId);
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException(String.format(
                            "Failed to find class to marshall with name %s",
                            acceptedType.getClassName()),
                            e);
                }
            }
        }
    }

    /**
     * Reads the {@code priority} attribute of a RegisteredMarshaller annotation, defaulting to
     * {@link EnumMarshallerPriority#NORMAL} when the attribute is absent from the scan data.
     */
    private static int getPriorityRank(ModFileScanData.AnnotationData data) {
        ModAnnotation.EnumHolder enumHolder = (ModAnnotation.EnumHolder) data.annotationData().get("priority");
        EnumMarshallerPriority priority = enumHolder == null
                ? EnumMarshallerPriority.NORMAL
                : EnumMarshallerPriority.valueOf(enumHolder.value());
        return priority.getRank();
    }

    /**
     * Loads and instantiates a marshaller class through its no-argument constructor.
     *
     * @param className fully-qualified name of a class annotated with RegisteredMarshaller.
     * @return a new marshaller instance.
     * @throws RuntimeException if the class is missing, is not a Marshaller, or cannot be constructed.
     */
    private static Marshaller instantiateMarshaller(String className) {
        Class<? extends Marshaller> marshallerClass;
        try {
            marshallerClass = Class.forName(className).asSubclass(Marshaller.class);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(String.format(
                    "Unable to find class: \"%s\" for registered marshaller.",
                    className),
                    e);
        } catch (ClassCastException e) {
            throw new RuntimeException(String.format(
                    "Class: \"%s\" annotated with RegisteredMarshaller does not extend Marshaller.",
                    className),
                    e);
        }

        try {
            // Unlike the deprecated Class.newInstance(), this wraps constructor failures in
            // InvocationTargetException instead of rethrowing them unchecked.
            return marshallerClass.getDeclaredConstructor().newInstance();
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new RuntimeException(String.format(
                    "Failed to instantiate %s, is there a public empty constructor?",
                    className),
                    e);
        } catch (InstantiationException | InvocationTargetException e) {
            throw new RuntimeException(String.format(
                    "Failed to instantiate %s",
                    className),
                    e);
        }
    }

    /**
     * Reads a list-valued annotation attribute from scan data, returning an empty list when the attribute
     * is absent (absent attributes are possible, as the {@code priority} lookup above also assumes).
     */
    @SuppressWarnings("unchecked")
    private static <T> List<T> getListValue(Map<String, Object> annotationInfo, String key) {
        Object value = annotationInfo.get(key);
        return value == null ? Collections.emptyList() : (List<T>) value;
    }

    /**
     * Creates a concept for every {@link RegisteredMessage} on this channel and records this network as
     * the owner of each message class.
     */
    private void discoverMessages() {
        // getAnnotationDataOfType already restricts results to this network's channel.
        for (ModFileScanData.AnnotationData registeredMessage : this.getAnnotationDataOfType(RegisteredMessage.class)) {
            LogicalSide destination = getDestination(registeredMessage);
            Class<? extends Message> msgClass = loadMessageClass(registeredMessage.memberName());

            this.createConcept(msgClass, destination);
            this.registerMessage(msgClass);
        }
    }

    /**
     * Reads the {@code destination} attribute of a RegisteredMessage annotation.
     *
     * @throws IllegalStateException if the attribute is absent from the scan data.
     */
    private static LogicalSide getDestination(ModFileScanData.AnnotationData registeredMessage) {
        ModAnnotation.EnumHolder destinationHolder =
                (ModAnnotation.EnumHolder) registeredMessage.annotationData().get("destination");
        if (destinationHolder == null) {
            throw new IllegalStateException(String.format(
                    "Class \"%s\" annotated with RegisteredMessage does not declare a destination.",
                    registeredMessage.memberName()));
        }
        return LogicalSide.valueOf(destinationHolder.value());
    }

    /**
     * Loads a message class and verifies that it extends Message and has a public no-argument constructor.
     *
     * @throws RuntimeException if any of those checks fail.
     */
    private static Class<? extends Message> loadMessageClass(String className) {
        Class<? extends Message> msgClass;
        try {
            // asSubclass performs a real runtime check; a generic cast is erased and would never fail.
            msgClass = Class.forName(className).asSubclass(Message.class);
            msgClass.getConstructor();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(String.format(
                    "Unable to find class: %s for registered message.",
                    className),
                    e);
        } catch (ClassCastException e) {
            throw new RuntimeException(String.format(
                    "Class \"%s\" annotated with RegisteredMessage does not extend Message.",
                    className),
                    e);
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(String.format(
                    "Class \"%s\" does not have an empty constructor available, this is required for networking.",
                    className),
                    e);
        }
        return msgClass;
    }

    /**
     * Collects scan data for the given annotation whose {@code channel} attribute equals this network's
     * channel name.
     * <p>
     * NOTE(review): this matches on the channel path only and scans every mod's files, so two mods using
     * the same channel path will pick up each other's annotations — confirm this is intended.
     *
     * @return a mutable list (callers may sort it).
     */
    private List<ModFileScanData.AnnotationData> getAnnotationDataOfType(Class<?> annotation) {
        String annotationName = annotation.getName();
        return ModList.get().getAllScanData().stream()
                .flatMap(scanData -> scanData.getAnnotations().stream())
                .filter(a -> Objects.equals(a.annotationType().getClassName(), annotationName))
                .filter(a -> Objects.equals(a.annotationData().get("channel"), this.name))
                .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Builds and stores the {@link MessageConcept} for a message class from its networked fields.
     */
    private void createConcept(Class<? extends Message> msgClass, LogicalSide destination) {
        ArrayList<MessageField> messageFields = FieldUtils.getAllFieldsList(msgClass).stream()
                .filter(CompoundNetwork::isNetworkedField)
                .map(this::createMessageField)
                .collect(Collectors.toCollection(ArrayList::new));

        MessageConcept msgConcept = new MessageConcept(this, msgClass, messageFields, destination);
        this.messageConcepts.put(msgClass, msgConcept);
    }

    /**
     * Decides whether a field takes part in serialization. Fields inherited from Message or Object are
     * excluded, and so are static fields (class state, not message state) and synthetic
     * (compiler/tool-generated) fields.
     */
    private static boolean isNetworkedField(Field field) {
        Class<?> declarer = field.getDeclaringClass();
        if (declarer.equals(Message.class) || declarer.equals(Object.class)) {
            return false;
        }
        return !Modifier.isStatic(field.getModifiers()) && !field.isSynthetic();
    }

    /**
     * Resolves the marshaller for a field and asks it for the corresponding {@link MessageField}.
     *
     * @throws IllegalStateException if the resolved id has no registered marshaller.
     */
    private MessageField createMessageField(Field field) {
        String marshallerId = this.getMarshallerIdFor(field);
        Marshaller marshaller = this.marshallers.get(marshallerId);
        if (marshaller == null) {
            throw new IllegalStateException(String.format(
                    "No marshaller registered with id \"%s\" for field %s in %s",
                    marshallerId, field.getName(), field.getDeclaringClass().getName()));
        }
        return marshaller.getMessageField(field);
    }

    /**
     * Picks the marshaller id for a field. An explicit {@link SetMarshaller} wins. Otherwise an exact match
     * on the field type is used. Failing that, the most specific registered supertype is used; if several
     * unrelated supertypes match, the first one registered wins.
     *
     * @throws RuntimeException if no registered type accepts the field type.
     */
    private String getMarshallerIdFor(Field field) {
        SetMarshaller explicit = field.getAnnotation(SetMarshaller.class);
        if (explicit != null) {
            return explicit.value();
        }

        Class<?> fieldClass = field.getType();
        String marshallerId = this.marshallerIds.get(fieldClass);
        if (marshallerId != null) {
            return marshallerId;
        }

        return this.marshallerIds.keySet().stream()
                .filter(c -> c.isAssignableFrom(fieldClass))
                // Keep the current best unless the candidate is a strict subtype of it.
                .reduce((best, candidate) -> best.isAssignableFrom(candidate) ? candidate : best)
                .map(this.marshallerIds::get)
                .orElseThrow(() -> new RuntimeException(
                        "Unable to find marshaller id for " + fieldClass.getName()));
    }

    /**
     * Records this network as the owner of the message class, warning if another network owned it.
     */
    private <M extends Message> void registerMessage(Class<M> msgClass) {
        CompoundNetwork previous = NETWORKS.put(msgClass, this);
        if (previous != null && previous != this) {
            this.logger.warn("Message class {} was already registered to network {}; it is now owned by {}",
                    msgClass.getName(), previous.getNetworkId(), this.networkId);
        }
    }

    public Logger getLogger() {
        return this.logger;
    }

    /**
     * @return the concept for the runtime class of {@code msg}, or {@code null} if it is not registered here.
     */
    public MessageConcept getMsgConcept(Message msg) {
        return this.messageConcepts.get(msg.getClass());
    }

    public ResourceLocation getNetworkId() {
        return this.networkId;
    }

    /**
     * @return the concept for {@code msgClass}, or {@code null} if it is not registered here.
     */
    public MessageConcept getMsgConcept(Class<? extends Message> msgClass) {
        return this.messageConcepts.get(msgClass);
    }

    /**
     * Registers one message concept's payload type, codec and handler with the registrar.
     * <p>
     * NOTE(review): concepts whose side is NOT client are registered via commonToClient (and vice versa).
     * If getMessageSide() returns the annotation's destination, this looks inverted — confirm against
     * MessageConcept before changing.
     */
    private void registerMessageConcept(PayloadRegistrar registrar, Class<? extends Message> messageClass, MessageConcept messageConcept) {
        if (!messageConcept.getMessageSide().isClient()) {
            registrar.commonToClient(messageConcept.getMessageType(),
                    messageConcept.getPayloadCodec(),
                    messageConcept.getPayloadHandler());
        } else {
            registrar.commonToServer(messageConcept.getMessageType(),
                    messageConcept.getPayloadCodec(),
                    messageConcept.getPayloadHandler());
        }
    }

    /**
     * Registers every discovered message concept when NeoForge asks for payload handlers.
     * <p>
     * NOTE(review): {@code registrar(String)} takes the payload version; this passes the channel name,
     * while the mod version handed to the constructor is discarded — confirm this is intended.
     */
    @SubscribeEvent
    private void onRegisterPayloadHandlerEvent(final RegisterPayloadHandlersEvent e) {
        var registrar = e.registrar(this.name);

        this.messageConcepts.forEach((msgClass, msgConcept) ->
                this.registerMessageConcept(registrar, msgClass, msgConcept));
    }

}
