import { useCallback, useEffect, useMemo } from 'react';

// Flow diagram libraries
import ReactFlow, { useNodesState, useEdgesState } from 'reactflow';
import 'reactflow/dist/style.css';
import EntityNode from '../entity-node';
import Dagre from '@dagrejs/dagre';

// Deep-clone utility (imported from @mui/x-data-grid/utils/utils)
import { deepClone } from '@mui/x-data-grid/utils/utils';

// Utilities
import uniqBy from 'lodash/uniqBy';

// Types
import { IDatabase } from 'models/store/store-models';
import {
  ITableData,
  IEdge,
  INodeInfo,
  IHeightByRank,
  IHeightByIndex,
} from '../types';

/** Props accepted by the `EntityDiagram` component. */
type IEntityDiagramProps = {
  /** Database whose `tables` are drawn as diagram nodes. */
  database: IDatabase;
};

// Extra vertical space (px) added below each node when stacking nodes that
// share a rank. NOTE(review): presumably room for the node header — confirm
// against EntityNode.
const HEADER_OFFSET = 50;

// Height (px) allotted per column when sizing a table node.
const COLUMN_HEIGHT = 23;

/**
 * Lays out diagram nodes with Dagre, then restacks them vertically per rank.
 *
 * Dagre supplies each node's x position. The y position is then recomputed so
 * that nodes sharing a `rank` are stacked top-to-bottom (sorted by rank; ties
 * keep input order because `Array.prototype.sort` is stable). Each node takes
 * `data.height + HEADER_OFFSET` of vertical space.
 *
 * Each node object is registered as its own Dagre label, so Dagre may write
 * layout results (x/y and, depending on version, `rank`) onto it. The rank
 * sort below reads `rank` from that object after layout.
 * NOTE(review): confirm that the installed @dagrejs/dagre writes `rank` back
 * to input labels. Otherwise every node keeps the rank it was created with.
 *
 * @param nodes - Nodes to position. Dagre may mutate these objects.
 * @param edges - Edges used to derive the layout; returned unchanged.
 * @param options - `direction` is passed to Dagre as `rankdir` (e.g. 'LR').
 * @returns A new node array with `position` set, plus the same edges.
 */
function getLayoutedElements(
  nodes: INodeInfo[],
  edges: IEdge[],
  options: { direction: string },
): { nodes: INodeInfo[]; edges: IEdge[] } {
  const g = new Dagre.graphlib.Graph().setDefaultEdgeLabel(() => ({}));

  // Layout direction and the spacing between ranks.
  g.setGraph({ rankdir: options.direction, ranksep: 250 });

  edges.forEach((edge: IEdge) => g.setEdge(edge.source, edge.target));
  nodes.forEach((node: INodeInfo) => g.setNode(node.id, node));

  Dagre.layout(g);

  // Copy Dagre's coordinates into React Flow's `position`, ordered by rank.
  const layoutedNodes: INodeInfo[] = nodes
    .map((node: INodeInfo) => {
      const { x, y } = g.node(node.id);
      return { ...node, position: { x, y } };
    })
    .sort((a: INodeInfo, b: INodeInfo) => a.rank - b.rank);

  // Running stack height per rank: each node starts where the previous node
  // of the same rank ended.
  const heightByRank: IHeightByRank = {};
  for (const node of layoutedNodes) {
    const top = heightByRank[node.rank] ?? 0;
    node.position.y = top;
    heightByRank[node.rank] = top + node.data.height + HEADER_OFFSET;
  }

  return { nodes: layoutedNodes, edges };
}

/**
 * Finds the tables related to `curTable`.
 *
 * A table is related when its `columns` include one of `curTable`'s
 * `primaryKeyColumns`. Each match is returned as a deep clone with
 * `relatedColumn` set to the first matching primary-key column, so the
 * tables in `tables` are never mutated.
 *
 * @param curTable - Table whose relations are wanted.
 * @param tables - All tables to search; `curTable` itself is skipped by name.
 * @returns The related tables, in the same order as `tables`.
 */
function getRelation(
  curTable: ITableData,
  tables: ITableData[],
): { relatedNodes: ITableData[] } {
  const relatedNodes: ITableData[] = [];

  for (const table of tables) {
    if (table.name === curTable.name) {
      continue;
    }

    const relatedColumn = curTable.primaryKeyColumns.find((column: string) =>
      table.columns.includes(column),
    );

    if (relatedColumn) {
      // Clone only the matches. Cloning every table for every table costs
      // O(n²) deep copies and produces the same result.
      const related = deepClone(table);
      related.relatedColumn = relatedColumn;
      relatedNodes.push(related);
    }
  }

  return { relatedNodes };
}

/**
 * Appends an edge from `table` to each table in `relatedNodes`. A pair that
 * is already connected in either direction is skipped.
 *
 * Edge ids have the form `${source}-${target}`. `edges` is mutated in place
 * and also returned for convenience.
 */
function getEdges(
  table: ITableData,
  relatedNodes: ITableData[],
  edges: IEdge[],
): IEdge[] {
  for (const related of relatedNodes) {
    const edgeId = `${table.name}-${related.name}`;
    const reverseEdgeId = `${related.name}-${table.name}`;

    const alreadyConnected = edges.some(
      (edge: IEdge) => edge.id === edgeId || edge.id === reverseEdgeId,
    );

    if (!alreadyConnected) {
      edges.push({
        id: edgeId,
        source: table.name,
        target: related.name,
        style: { strokeWidth: 1, stroke: '#ccc', curveSmoothness: 0.5 },
      });
    }
  }

  return edges;
}

/**
 * Renders the tables of `database` as an entity-relationship diagram.
 *
 * Each table becomes an `entity` node. An edge joins two tables when one
 * contains a primary-key column of the other. Nodes are positioned with
 * Dagre (left-to-right) and re-laid out whenever `database.tables` changes.
 */
function EntityDiagram({ database }: IEntityDiagramProps) {
  // React Flow state for nodes and edges.
  const [nodes, setNodes, onNodesChange] = useNodesState([]);
  const [edges, setEdges, onEdgesChange] = useEdgesState([]);

  // Memoized so React Flow receives the same nodeTypes object on every render.
  const nodeTypes = useMemo(
    () => ({
      entity: EntityNode,
    }),
    [],
  );

  // Node height per table index, proportional to the table's column count.
  const heightByIndex: IHeightByIndex = useMemo(
    () =>
      Object.fromEntries(
        database.tables.map((table: ITableData, index: number) => [
          index,
          table.columns.length * COLUMN_HEIGHT,
        ]),
      ),
    [database.tables],
  );

  // Builds nodes and edges from the current tables, lays them out, and
  // pushes the result into React Flow state.
  const updateNodes = useCallback(() => {
    const tables = database.tables;
    const diagramEdges: IEdge[] = [];

    const diagramNodes = tables.map(
      (table: ITableData, index: number): INodeInfo => {
        const { relatedNodes } = getRelation(table, tables);
        getEdges(table, relatedNodes, diagramEdges);

        return {
          id: table.name,
          type: 'entity',
          position: { x: 0, y: 0 },
          data: {
            table,
            relatedNodes,
            height: heightByIndex[index],
          },
          rank: 0,
        };
      },
    );

    // getEdges already skips ids seen in either direction. uniqBy is a final
    // guard so React Flow never receives duplicate edge ids.
    const layouted = getLayoutedElements(
      diagramNodes,
      uniqBy(diagramEdges, (edge: IEdge) => edge.id),
      { direction: 'LR' },
    );

    setNodes(layouted.nodes);
    setEdges(layouted.edges);
  }, [database.tables, heightByIndex, setNodes, setEdges]);

  // Lay out on mount and again whenever the tables change. A mount-only
  // effect would keep showing the first database this component received.
  useEffect(() => {
    updateNodes();
  }, [updateNodes]);

  return (
    <ReactFlow
      nodes={nodes}
      nodeTypes={nodeTypes}
      edges={edges}
      onNodesChange={onNodesChange}
      onEdgesChange={onEdgesChange}
    />
  );
}

export default EntityDiagram;
