import {urlPrefixed} from '@wandb/common/config';
import dagre from 'dagre';
import React from 'react';
import ReactFlow, {
  ConnectionLineType,
  Controls,
  Edge as FlowEdge,
  MarkerType,
  Node as FlowNode,
  Position,
  ReactFlowProvider,
  useReactFlow,
} from 'react-flow-renderer';

import {Dag, Direction} from '../state/graphql/artifactDagQuery';
import * as S from './ArtifactFlowDag.styles';

// We use dagre as the layout engine, following React Flow's dagre example:
// https://reactflow.dev/docs/examples/layout/dagre/
/**
 * Lays out React Flow nodes with dagre and returns positioned copies.
 *
 * @param nodes - Nodes to lay out. They are not mutated; positioned copies are
 *   returned instead.
 * @param edges - Edges that drive the layout; returned unchanged.
 * @param direction - dagre `rankdir` ('LR', 'RL', 'TB' or 'BT'). Any other
 *   value is handed to dagre as-is and gets top/bottom handles.
 * @param artifactID - Node ID to report back as `selectedNode`.
 * @returns The positioned nodes, the input edges, and the positioned node whose
 *   id equals `artifactID` (the last one if several match; undefined if none).
 */
const getLayoutedElements = (
  nodes: FlowNode[],
  edges: FlowEdge[],
  direction = 'LR',
  artifactID: string
): {nodes: FlowNode[]; edges: FlowEdge[]; selectedNode?: FlowNode} => {
  const nodeWidth = 130;
  const nodeHeight = 40;

  const dagreGraph = new dagre.graphlib.Graph();
  dagreGraph.setDefaultEdgeLabel(() => ({}));
  dagreGraph.setGraph({rankdir: direction});
  nodes.forEach(node => {
    dagreGraph.setNode(node.id, {width: nodeWidth, height: nodeHeight});
  });
  edges.forEach(edge => {
    dagreGraph.setEdge(edge.source, edge.target);
  });
  dagre.layout(dagreGraph);

  // Put each node's incoming (target) handle on the side facing the previous
  // rank and its outgoing (source) handle on the side facing the next rank, for
  // every dagre rank direction rather than only 'LR'.
  let targetPosition: Position;
  let sourcePosition: Position;
  switch (direction) {
    case 'LR':
      targetPosition = Position.Left;
      sourcePosition = Position.Right;
      break;
    case 'RL':
      targetPosition = Position.Right;
      sourcePosition = Position.Left;
      break;
    case 'BT':
      targetPosition = Position.Bottom;
      sourcePosition = Position.Top;
      break;
    default:
      // 'TB' and any unrecognised direction.
      targetPosition = Position.Top;
      sourcePosition = Position.Bottom;
  }

  const layoutedNodes: FlowNode[] = [];
  let selectedNode: FlowNode | undefined;
  for (const node of nodes) {
    const {x, y} = dagreGraph.node(node.id);
    const layoutedNode: FlowNode = {
      ...node,
      targetPosition,
      sourcePosition,
      // dagre reports the node's center; React Flow anchors nodes at their
      // top-left corner, so shift by half the node size.
      position: {x: x - nodeWidth / 2, y: y - nodeHeight / 2},
    };
    layoutedNodes.push(layoutedNode);
    if (node.id === artifactID) {
      selectedNode = layoutedNode;
    }
  }

  return {nodes: layoutedNodes, edges, selectedNode};
};

/**
 * Converts the lineage DAG into React Flow nodes and edges.
 *
 * The result is not laid out yet: every node starts at the origin.
 *
 * @param dag - Artifacts, runs and the edges between them.
 * @param dagMode - Display mode. A mode containing 'collapsed' changes three
 *   things:
 *   - clicking a node no longer opens its page;
 *   - run nodes that carry `runs` show an instance count;
 *   - edges are wired by job type / artifact type instead of run / artifact ID.
 * @param artifactID - ID of the artifact node to highlight as selected.
 * @returns A `[nodes, edges]` tuple; artifact nodes come before run nodes.
 */
function dagToFlowEls(
  dag: Dag,
  dagMode: string,
  artifactID: string
): [FlowNode[], FlowEdge[]] {
  const {artifacts, runs, edges} = dag;
  const isCollapsed = dagMode.includes('collapsed');
  const edgeType = 'straight';

  // Run node IDs are split as `entity/project/runID`.
  const extractRunPath = (runNodeID: string): string => {
    const [entity, project, runID] = runNodeID.split('/');
    return ['', entity, project, 'runs', runID, 'overview'].join('/');
  };

  // Artifact node IDs are split as `entity/project/typeName/name:commitHash`.
  const extractArtifactPath = (artifactNodeID: string): string => {
    const [entity, project, artifactTypeName, artifactWithCommitHash] =
      artifactNodeID.split('/');
    const [artifactName, commitHash] = artifactWithCommitHash.split(':');
    return [
      '',
      entity,
      project,
      'artifacts',
      artifactTypeName,
      artifactName,
      commitHash,
    ].join('/');
  };

  // Artifact and run nodes share one card layout: a header with the node kind
  // and a group label, then a content row. The page path is computed only on
  // click, because in collapsed modes the node IDs are presumably type names
  // rather than paths (clicking is disabled there anyway).
  const makeNode = (card: {
    id: string;
    kind: 'artifact' | 'run';
    kindLabel: string;
    selected: boolean;
    groupText: React.ReactNode;
    content: React.ReactNode;
    getPath: (nodeID: string) => string;
  }): FlowNode => ({
    id: card.id,
    data: {
      label: (
        <div
          style={{height: '47px'}}
          onClick={() => {
            if (!isCollapsed) {
              window.open(urlPrefixed(card.getPath(card.id)), '_blank');
            }
          }}>
          <S.NodeHeader type={card.kind} selected={card.selected}>
            <span>{card.kindLabel}</span>
            <S.NodeGroupText>{card.groupText}</S.NodeGroupText>
          </S.NodeHeader>
          <S.NodeContent type={card.kind} selected={card.selected}>
            {card.content}
          </S.NodeContent>
        </div>
      ),
    },
    // A fresh object per node so no two nodes alias the same position.
    position: {x: 0, y: 0},
    style: {
      padding: 'none',
      border: 'none',
    },
  });

  const flowNodes: FlowNode[] = [];

  Object.entries(artifacts).forEach(([aid, artifact]) => {
    if (artifact == null) {
      return;
    }
    flowNodes.push(
      makeNode({
        id: aid,
        kind: 'artifact',
        kindLabel: 'ARTIFACT',
        selected: aid === artifactID,
        groupText: artifact.artifactTypeName,
        content:
          artifact.artifacts != null
            ? `${artifact.artifacts.length} instances`
            : `${artifact.artifactSequenceName}:v${artifact.versionIndex}`,
        getPath: extractArtifactPath,
      })
    );
  });

  Object.entries(runs).forEach(([rid, run]) => {
    if (run == null) {
      return;
    }
    flowNodes.push(
      makeNode({
        id: rid,
        kind: 'run',
        kindLabel: 'RUN',
        // Run nodes are never rendered as selected.
        selected: false,
        // NOTE(review): this exact comparison differs from the
        // `includes('collapsed')` checks used everywhere else. It is kept
        // as-is; confirm whether other collapsed variants should also hide
        // the job type.
        groupText: dagMode !== 'collapsed' ? run.jobType : '',
        content:
          isCollapsed && run.runs != null
            ? `${run.runs.length} instances`
            : run.displayName,
        getPath: extractRunPath,
      })
    );
  });

  const flowEdges: FlowEdge[] = [];
  Object.entries(edges).forEach(([id, edge]) => {
    if (edge == null) {
      return;
    }
    // By default an edge points from the run side to the artifact side.
    // Edges marked AwayFromArtifact point from the artifact to the run.
    const runEnd = isCollapsed ? edge.jobType! : edge.runID;
    const artifactEnd = isCollapsed ? edge.artifactTypeName! : edge.artifactID;
    const awayFromArtifact = edge.dir === Direction.AwayFromArtifact;
    flowEdges.push({
      id,
      source: awayFromArtifact ? artifactEnd : runEnd,
      target: awayFromArtifact ? runEnd : artifactEnd,
      type: edgeType,
      markerEnd: {
        type: MarkerType.Arrow,
        width: 20,
        height: 30,
      },
    });
  });

  return [flowNodes, flowEdges];
}

/**
 * Renders the lineage DAG with React Flow and centres the viewport on the node
 * whose ID equals `artifactID`.
 *
 * Must be rendered inside a `ReactFlowProvider`, because it calls
 * `useReactFlow`.
 */
const LayoutFlow = (props: {dag: Dag; dagMode: string; artifactID: string}) => {
  const {dag, dagMode, artifactID} = props;

  // Building the elements and running dagre depends only on these props.
  // Memoizing avoids redoing that work on unrelated re-renders, and React Flow
  // keeps receiving the same node/edge arrays instead of fresh copies.
  const {
    nodes: layoutedNodes,
    edges: layoutedEdges,
    selectedNode,
  } = React.useMemo(() => {
    const [nodes, edges] = dagToFlowEls(dag, dagMode, artifactID);
    return getLayoutedElements(nodes, edges, 'LR', artifactID);
  }, [dag, dagMode, artifactID]);

  const {setCenter} = useReactFlow();

  // Animates the viewport to the selected node at zoom 1. The +100 vertical
  // offset is carried over unchanged from the original implementation.
  const centerOnSelected = () => {
    if (selectedNode != null) {
      setCenter(selectedNode.position.x, selectedNode.position.y + 100, {
        zoom: 1,
        duration: 1000,
      });
    }
  };

  return (
    <div
      className="layoutflow"
      style={{position: 'relative', height: '100%', cursor: 'grab'}}>
      <ReactFlow
        minZoom={0.1}
        nodes={layoutedNodes}
        edges={layoutedEdges}
        onNodesChange={changes => {
          // A batch that contains a selection change leaves the viewport
          // alone. Any other batch of node changes re-centres on the selected
          // artifact, presumably so the view settles after React Flow measures
          // the nodes (confirm).
          if (!changes.some(change => change.type === 'select')) {
            centerOnSelected();
          }
        }}
        nodesDraggable={false}
        onInit={centerOnSelected}
        connectionLineType={ConnectionLineType.Straight}
        fitView>
        <Controls style={{marginBottom: '50px'}} />
      </ReactFlow>
    </div>
  );
};

/**
 * Public entry point for the artifact lineage graph.
 *
 * Wraps `LayoutFlow` in its own `ReactFlowProvider`, because `LayoutFlow` uses
 * React Flow hooks that need a provider above them.
 */
const LayoutFlowOuter = ({
  dag,
  dagMode,
  artifactID,
}: {
  dag: any;
  dagMode: string;
  artifactID: string;
}) => (
  <ReactFlowProvider>
    <LayoutFlow dag={dag} dagMode={dagMode} artifactID={artifactID} />
  </ReactFlowProvider>
);

export default LayoutFlowOuter;
