import {
  Edge,
  EdgeChange,
  Node,
  NodeChange,
  OnNodesChange,
  OnEdgesChange,
  XYPosition,
  applyNodeChanges,
  applyEdgeChanges,
} from 'reactflow';
import { create } from 'zustand';
import { nanoid } from 'nanoid';

import { layout } from './services/layout';

/**
 * Stroke styling for connection lines.
 * Used in this file as the type of `RFState.connectionLineStyle`.
 */
type LineStyling = {
  /** Stroke colour value (the store's initial value is the hex string '#c6c6c6'). */
  stroke: string;
  /** Stroke thickness (the store's initial value is 3; unit not stated here — presumably pixels). */
  strokeWidth: number;
};

/**
 * Shape of the store: the graph data (nodes and edges), styling for
 * connection lines, a `thinking` flag, and the actions that update them.
 */
export type RFState = {
  /** Nodes currently in the graph. */
  nodes: Node[];
  /** Edges currently in the graph. */
  edges: Edge[];
  /** Default line styling; its stroke is the fallback colour for new child edges. */
  connectionLineStyle: LineStyling;
  /** Boolean flag set via `setThinking`; its meaning is defined by consumers (presumably a busy indicator — confirm). */
  thinking: boolean;
  /** Applies React Flow node changes to `nodes`. */
  onNodesChange: OnNodesChange;
  /** Applies React Flow edge changes to `edges`. */
  onEdgesChange: OnEdgesChange;
  /** Adds a new 'Subtopic' node under `parentNode` plus the edge connecting them, then re-runs layout. */
  addChildNode: (parentNode: Node, position: XYPosition) => void;
  /** Sets `data.label` on the node with the given id. */
  updateNodeLabel: (nodeId: string, label: string) => void;
  /** Replaces nodes (after layout) and edges (after merging with the current edges). */
  setNodesAndEdges: (nodes: Node[], edges: Edge[]) => void;
  /** Removes a node, nodes parented under a removed node, and every edge touching a removed node. */
  deleteNode: (nodeId: string) => void;
  /** Replaces nodes (after layout), leaving edges untouched. */
  setNodes: (nodes: Node[]) => void;
  /** Changes the stroke colour of `connectionLineStyle`. */
  updateConnectionColor: (color: string) => void;
  /** Changes the stroke colour of the edge with the given id. */
  updateEdgeColor: (edgeId: string, color: string) => void;
  /** Sets the `thinking` flag. */
  setThinking: (thinking: boolean) => void;
};

/**
 * Prepares `nodes` for display: carries over attributes from matching nodes
 * in `existingNodes`, runs the layout service, and assigns each node's `type`.
 *
 * For every incoming node whose id matches an existing node:
 * - if its position is `{ x: Infinity, y: Infinity }` the existing node's
 *   position is reused (presumably a "no position yet" placeholder used by
 *   callers — confirm);
 * - `width` and `height` are always taken from the existing node.
 *
 * The incoming node objects are shallow-copied before merging, so this
 * function itself never writes to the objects passed in by the caller.
 *
 * @param nodes - Nodes to lay out.
 * @param existingNodes - Nodes to merge attributes from, matched by id.
 * @returns The nodes returned by `layout`, each with `type` set to 'branch'
 *   when it has a `parentNode` and 'root' otherwise.
 */
function layoutNodes(nodes: Node[], existingNodes: Node[]): Node[] {
  // A Map (rather than a plain object) so ids such as 'constructor' cannot
  // accidentally match inherited Object.prototype members.
  const existingById = new Map<string, Node>();
  existingNodes.forEach((existing) => existingById.set(existing.id, existing));

  const mergedNodes = nodes.map((node): Node => {
    const existing = existingById.get(node.id);

    if (!existing) {
      return { ...node };
    }

    const hasPlaceholderPosition =
      node.position.x === Infinity && node.position.y === Infinity;

    return {
      ...node,
      position: hasPlaceholderPosition ? existing.position : node.position,
      width: existing.width,
      height: existing.height,
    };
  });

  const nodesWithPosition = layout(mergedNodes);
  return nodesWithPosition.map((node) => ({
    ...node,
    type: node.parentNode ? 'branch' : 'root',
  }));
}

/**
 * Reconciles an incoming edge list with the edges already known.
 *
 * For each incoming edge, the existing edge with the same id is preferred
 * (so attributes already stored on it are kept); otherwise the incoming edge
 * is used. If the chosen edge has no `style.stroke`, a copy is returned whose
 * style is `{ stroke, strokeWidth: 3 }`, where `stroke` is inherited from the
 * first existing edge that targets this edge's source node (or `undefined`
 * when there is no such edge or it has no stroke).
 *
 * Neither the incoming nor the existing edge objects are mutated.
 *
 * @param edges - Incoming edges; the result has one entry per incoming edge, in order.
 * @param existingEdges - Edges to merge attributes from.
 * @returns The merged edge list.
 */
function mergeEdges(edges: Edge[], existingEdges: Edge[]): Edge[] {
  // Maps (rather than plain objects) so ids such as 'constructor' cannot
  // resolve to inherited Object.prototype members.
  const existingById = new Map<string, Edge>();
  const firstExistingByTarget = new Map<string, Edge>();

  for (const existing of existingEdges) {
    existingById.set(existing.id, existing);
    // Keep only the first edge per target to match a first-match search.
    if (!firstExistingByTarget.has(existing.target)) {
      firstExistingByTarget.set(existing.target, existing);
    }
  }

  return edges.map((incoming) => {
    const edge = existingById.get(incoming.id) ?? incoming;

    if (edge.style?.stroke) {
      return edge;
    }

    const inheritedStroke = firstExistingByTarget.get(edge.source)?.style?.stroke;

    // Return a fresh object instead of writing to an edge that may be held in state.
    return { ...edge, style: { stroke: inheritedStroke, strokeWidth: 3 } };
  });
}

/**
 * Zustand store holding the mind-map graph (nodes and edges), the default
 * connection-line styling, a `thinking` flag, and the actions that update them.
 *
 * Actions always publish new node/edge objects for the entries they change
 * rather than mutating objects already held in state.
 */
const useStore = create<RFState>((set, get) => ({
  // The graph starts with a single root node and no edges.
  nodes: [
    {
      id: 'root',
      type: 'root',
      data: { label: 'My main topic' },
      position: { x: 0, y: 0 },
    },
  ],
  edges: [],
  thinking: false,
  connectionLineStyle: {
    stroke: '#c6c6c6',
    strokeWidth: 3,
  },
  /** Replaces the stroke colour of the default connection line style. */
  updateConnectionColor: (color: string) => {
    const { connectionLineStyle } = get();
    set({ connectionLineStyle: { ...connectionLineStyle, stroke: color } });
  },
  /** Lays out `nodes` and merges `edges` against the current state, then stores both. */
  setNodesAndEdges: (nodes: Node[], edges: Edge[]) => {
    // FIXME: Ensure that the nodes are topologically sorted
    set({
      nodes: layoutNodes(nodes, get().nodes),
      edges: mergeEdges(edges, get().edges),
    });
  },
  /** Lays out `nodes` against the current nodes and stores the result; edges are untouched. */
  setNodes: (nodes: Node[]) => {
    // FIXME: Ensure that the nodes are topologically sorted
    set({
      nodes: layoutNodes(nodes, get().nodes),
    });
  },
  /** Applies React Flow node changes to the stored nodes. */
  onNodesChange: (changes: NodeChange[]) => {
    set({
      nodes: applyNodeChanges(changes, get().nodes),
    });
  },
  /** Applies React Flow edge changes to the stored edges. */
  onEdgesChange: (changes: EdgeChange[]) => {
    set({
      edges: applyEdgeChanges(changes, get().edges),
    });
  },
  /**
   * Adds a 'Subtopic' branch node under `parentNode` and an edge from the
   * parent to it, then re-runs layout over all nodes.
   */
  addChildNode: (parentNode: Node, position: XYPosition) => {
    const { nodes, edges, connectionLineStyle } = get();

    // Find the edge that leads into the parent node (the root has none).
    const parentEdge = edges.find((edge) => edge.target === parentNode.id);
    // Inherit that edge's colour; fall back to the default colour when the
    // parent has no incoming edge or its stroke is unset/empty.
    const edgeColor = parentEdge?.style?.stroke || connectionLineStyle.stroke;

    const newNode = {
      id: nanoid(),
      type: 'branch',
      data: { label: 'Subtopic' },
      position,
      parentNode: parentNode.id,
    };

    const newEdge = {
      id: nanoid(),
      source: parentNode.id,
      target: newNode.id,
      style: { stroke: edgeColor, strokeWidth: 3 },
    };

    set({
      nodes: layoutNodes([...nodes, newNode], nodes),
      edges: [...edges, newEdge],
    });
  },
  /** Sets `data.label` on the node with id `nodeId`; other nodes are kept as-is. */
  updateNodeLabel: (nodeId: string, label: string) => {
    set({
      nodes: get().nodes.map((node) =>
        // Create a new node object (not just new data) so the change is
        // visible to reference-equality checks and state is never mutated.
        node.id === nodeId ? { ...node, data: { ...node.data, label } } : node,
      ),
    });
  },
  /**
   * Removes the node with id `nodeId`, every node nested under it (via
   * `parentNode`, at any depth), and every edge touching a removed node.
   * Works regardless of the order of nodes in the array. Does nothing to the
   * graph contents when no node has the given id.
   */
  deleteNode: (nodeId: string) => {
    const { nodes, edges } = get();

    // Index children by parent id so descendants can be found in any order.
    const childIdsByParent = new Map<string, string[]>();
    let targetExists = false;

    for (const node of nodes) {
      if (node.id === nodeId) {
        targetExists = true;
      }
      if (node.parentNode !== undefined) {
        const siblings = childIdsByParent.get(node.parentNode);
        if (siblings) {
          siblings.push(node.id);
        } else {
          childIdsByParent.set(node.parentNode, [node.id]);
        }
      }
    }

    // Walk the subtree rooted at nodeId; the `has` check guards against
    // revisiting ids should parent links ever form a cycle.
    const removedNodeIds = new Set<string>();
    const pending: string[] = targetExists ? [nodeId] : [];
    let currentId = pending.pop();

    while (currentId !== undefined) {
      if (!removedNodeIds.has(currentId)) {
        removedNodeIds.add(currentId);
        pending.push(...(childIdsByParent.get(currentId) ?? []));
      }
      currentId = pending.pop();
    }

    set({
      nodes: nodes.filter((node) => !removedNodeIds.has(node.id)),
      edges: edges.filter(
        (edge) => !removedNodeIds.has(edge.source) && !removedNodeIds.has(edge.target),
      ),
    });
  },
  /**
   * Sets the stroke colour of the edge with id `edgeId`. Missing style
   * properties are filled from the default connection line style.
   */
  updateEdgeColor: (edgeId: string, color: string) => {
    const defaultConnectionLineStyle = get().connectionLineStyle;
    set({
      edges: get().edges.map((edge) =>
        // Create a new edge object so the change is visible to
        // reference-equality checks and state is never mutated.
        edge.id === edgeId
          ? { ...edge, style: { ...defaultConnectionLineStyle, ...edge.style, stroke: color } }
          : edge,
      ),
    });
  },
  /** Sets the `thinking` flag. */
  setThinking: (thinking: boolean) => {
    set({
      thinking,
    });
  },
}));

export default useStore;
