// Handle serialization/deserialization of CG, for remote execution
import _ from 'lodash';

import {StaticOpStore} from '../../opStore/static';
import {opDefIsLowLevel} from '../../runtimeHelpers';
import {invertRemap} from '../../util/invertRemap';
import {isSimpleTypeShape} from '../helpers';
import type {Type} from '../types';
import {constDate, constNodeUnsafe, varNode, voidNode} from './construction';
import type {EditingNode, EditingOp, EditingOutputNode} from './editing';
import {MemoizedHasher} from './editing';
import type {ConstNode, OpInputNodes, VarNode, VoidNode} from './types';

// Reference to a serialized node or op: its index in the flat node array
// (`BatchedGraphs.nodes`) that deserialize() indexes into.
type SerializedRef = number;
// Output of GraphNormalizer: every distinct op and node (keyed by object
// identity) mapped to the SerializedRef it was assigned. A single counter
// feeds all five maps, so refs are unique across them.
interface NormExecGraph {
  ops: Map<EditingOp, SerializedRef>;
  constNodes: Map<ConstNode, SerializedRef>;
  varNodes: Map<VarNode, SerializedRef>;
  voidNodes: Map<VoidNode, SerializedRef>;
  outputNodes: Map<EditingOutputNode, SerializedRef>;
}

// Batched variant of NormExecGraph: several graphs normalized into one shared
// table, so a node reachable from more than one root keeps a single identity
// (and a single ref). `roots` holds the root node of each input graph.
export interface BatchedNormExecGraphs {
  ng: NormExecGraph;
  roots: Set<EditingNode>;
}

// The serialized forms of each node type.
// Serialized forms roughly match the non-serialized structure but replace
// node/op references with a number (SerializedRef).

// A constant. `type` is copied from the editing node. `val` is normally the
// raw value, except that serialize() rewrites two object cases:
//  - a Date becomes {type: 'date', val: <the Date>};
//  - an output-node value (the body of a function-typed const) becomes
//    {nodeType: 'output', fromOp: <SerializedRef of its op>}.
interface SerializedConst {
  nodeType: 'const';
  type:
    | Type
    | {
        type: 'function';
        inputTypes: Record<string, Type>;
        outputType: Type;
      }
    | {type: 'list'; objectType: Type};
  val: any;
}

// An op. Unlike the node forms it has no `nodeType`; isSerializedOp() tells
// it apart by the presence of `name`. `inputs` maps each input name to the
// SerializedRef of the node feeding that input.
interface SerializedOp {
  inputs: Record<string, SerializedRef>;
  name: string;
}

// A variable. serialize() emits the editing var node unchanged.
// NOTE(review): `type` is declared `string`, but the value sent is whatever
// the editing var node carries, and deserialize() casts it with `as Type`;
// this annotation may be narrower than what is actually sent — confirm.
interface SerializedVar {
  nodeType: 'var';
  type: string;
  varName: string;
}

// The void node. Declared with only its tag; serialize() emits the editing
// void node unchanged.
interface SerializedVoid {
  nodeType: 'void';
}

// An op's output. `fromOp` is the SerializedRef of the producing op, and `id`
// is the value of MemoizedHasher.typedNodeId at serialization time;
// deserialize() copies `id` and `type` back onto the rebuilt node.
interface SerializedOutput {
  nodeType: 'output';
  fromOp: SerializedRef;
  type:
    | Type
    | {
        type: 'function';
        inputTypes: Record<string, Type>;
        outputType: Type;
      }
    | {type: 'list'; objectType: Type};
  id: string;
}

// Any entry of the flat serialized array.
type SerializedNode =
  | SerializedOp
  | SerializedConst
  | SerializedVar
  | SerializedVoid
  | SerializedOutput;

// Intermediate form built by serialize(): each kind of entry keyed by its
// SerializedRef. The keys are strings because these are plain object records;
// flattenGraph() parses them back to numbers.
type SerializableGraph = {
  constNodes: Record<string, SerializedConst>;
  outputNodes: Record<string, SerializedOutput>;
  varNodes: Record<string, SerializedVar>;
  voidNodes: Record<string, SerializedVoid>;
  ops: Record<string, SerializedOp>;
};

// Ops are the only serialized entries that carry a `name` (node forms carry
// `nodeType` instead), so the presence of `name` is the whole test.
// Note: dereferences `x`, so a null/undefined argument throws.
function isSerializedOp(x: any): x is SerializedOp {
  const opName = x.name;
  return opName !== undefined;
}

// Final data structure before JSON serialization: every serialized node and
// op, ordered by SerializedRef. GraphNormalizer hands out refs 0, 1, 2, ...
// and flattenGraph() sorts by ref, so an entry's array index equals its ref.
export type FlatSerializableGraph = SerializedNode[];

// Wire format for a batch of graphs.
export interface BatchedGraphs {
  // Misleading name: this may actually be multiple disconnected subgraphs,
  // but the naive TS type is indistinguishable.
  nodes: FlatSerializableGraph;
  // Refs (indices into `nodes`) of the roots; serialize() emits one per input
  // graph, in input order.
  rootNodes: SerializedRef[];
}

// TODO(np): Kept here rather than shared because it overlaps with graphNorm
// and nothing outside serialization uses it; extract it if that changes.
//
// Assigns every distinct op and node of an editing graph (by object identity)
// a sequential SerializedRef. Numbering is depth-first pre-order: a node or op
// receives its ref before anything it references is visited. Unlike graphNorm,
// the body of a function-typed constant is walked too, so ops inside a lambda
// are included in the result.
class GraphNormalizer {
  private nextId: SerializedRef = 0;
  private ng: NormExecGraph | null = null;
  private roots: Set<EditingNode> = new Set();

  /** Normalize the single graph rooted at `node` into a fresh table. */
  public normalize(node: EditingNode): NormExecGraph {
    this.reset();
    this.visitNode(node);
    return this.ng!;
  }

  /**
   * Normalize several graphs into one shared table, so anything reachable
   * from more than one root keeps a single ref. `roots` in the result is a
   * copy of the roots recorded during this call.
   */
  public normalizeBatch(nodes: EditingNode[]): BatchedNormExecGraphs {
    this.reset();
    for (const root of nodes) {
      // TODO(np): This should really be an ordered set
      this.roots.add(root);
      this.visitNode(root);
    }
    const snapshot = new Set(this.roots);
    return {ng: this.ng!, roots: snapshot};
  }

  /** Start over: ref counter back to 0, empty lookup tables, no roots. */
  private reset() {
    this.nextId = 0;
    this.roots.clear();
    this.ng = {
      ops: new Map(),
      constNodes: new Map(),
      varNodes: new Map(),
      voidNodes: new Map(),
      outputNodes: new Map(),
    };
  }

  /**
   * Give `key` the next ref in `table` unless it already has one.
   * Returns true only on first sight, i.e. when a new ref was handed out.
   */
  private claim<K>(table: Map<K, SerializedRef>, key: K): boolean {
    if (table.get(key) != null) {
      return false;
    }
    table.set(key, this.nextId++);
    return true;
  }

  /** Number `op` (once), then walk each of its input nodes. */
  private visitOp(op: EditingOp) {
    if (!this.claim(this.ng!.ops, op)) {
      return;
    }
    Object.values(op.inputs).forEach(input => this.visitNode(input));
  }

  /** Number `node` (once) and recurse into whatever it references. */
  private visitNode(node: EditingNode) {
    const ng = this.ng!;
    switch (node.nodeType) {
      case 'const':
        // A function-typed constant holds its body graph in `val`.
        if (
          this.claim(ng.constNodes, node) &&
          typeof node.type === 'object' &&
          node.type.type === 'function'
        ) {
          this.visitNode(node.val);
        }
        return;
      case 'var':
        this.claim(ng.varNodes, node);
        return;
      case 'void':
        this.claim(ng.voidNodes, node);
        return;
      case 'output':
        if (this.claim(ng.outputNodes, node)) {
          this.visitOp(node.fromOp);
        }
        return;
    }
    throw new Error(`invalid node: ${JSON.stringify(node)}`);
  }
}

// Map a node or op back to the SerializedRef it received during
// normalization. Ops are recognized by having a `name`; everything else is
// dispatched on `nodeType`. Yields undefined for anything that was never
// visited or whose nodeType is not one of the four known kinds.
function lookup(
  norm: NormExecGraph,
  nodeOrOp: EditingNode | EditingOp
): SerializedRef | undefined {
  const asOp = nodeOrOp as EditingOp;
  if (asOp.name !== undefined) {
    return norm.ops.get(asOp);
  }
  const node = nodeOrOp as EditingNode;
  if (node.nodeType === 'const') {
    return norm.constNodes.get(node);
  }
  if (node.nodeType === 'output') {
    return norm.outputNodes.get(node);
  }
  if (node.nodeType === 'var') {
    return norm.varNodes.get(node);
  }
  if (node.nodeType === 'void') {
    return norm.voidNodes.get(node);
  }
  return undefined;
}

// Convert a SerializableGraph into the final wire format: one flat array
// holding every entry of every kind, ordered by ascending numeric ref.
function flattenGraph(graph: SerializableGraph): FlatSerializableGraph {
  const groups: Array<Record<string, SerializedNode>> = [
    graph.constNodes,
    graph.outputNodes,
    graph.varNodes,
    graph.voidNodes,
    graph.ops,
  ];
  // Pair each entry with its numeric ref once, instead of re-parsing the
  // string key on every comparison.
  const keyed: Array<[number, SerializedNode]> = [];
  for (const group of groups) {
    for (const [refKey, entry] of Object.entries(group)) {
      keyed.push([parseInt(refKey, 10), entry]);
    }
  }
  keyed.sort(([refA], [refB]) => refA - refB);
  return keyed.map(([, entry]) => entry);
}

// Array of CG -> Normalized Graph -> Serializable Graph + Roots -> Flat Serializable Graph
/**
 * Serialize a batch of editing graphs into the flat wire format.
 *
 * All graphs are normalized together, so a node or op reachable from several
 * roots is emitted once and referenced by the same SerializedRef everywhere.
 *
 * @param graphs - root nodes of the graphs to serialize.
 * @returns the flat node array plus `rootNodes`, where `rootNodes[i]` is the
 *   ref of `graphs[i]`.
 * @throws Error if a root cannot be found in the normalized graph.
 */
export function serialize(graphs: EditingNode[]): BatchedGraphs {
  const hasher = new MemoizedHasher();
  const norm = new GraphNormalizer().normalizeBatch(graphs);

  const localLookup = (nodeOrOp: EditingNode | EditingOp) =>
    lookup(norm.ng, nodeOrOp);

  // GraphNormalizer visits every op input, so each input lookup resolves.
  const serializeOp = (op: EditingOp): SerializedOp => ({
    name: op.name,
    inputs: _.mapValues(op.inputs, input =>
      localLookup(input)
    ) as Record<string, SerializedRef>,
  });

  // Encode a const's value. Two kinds of object value get special wire shapes:
  //  - an output node (the body of a function-typed const) becomes
  //    {nodeType: 'output', fromOp: <ref of its op>};
  //  - a Date becomes {type: 'date', val: <the Date>}.
  // Everything else (primitives, null, arrays, other objects) passes through.
  // NOTE(review): only output-node bodies are converted to refs; a function
  // body of another node kind (e.g. a bare var) is emitted as-is, while
  // deserialize() expects `val.fromOp` for function-typed consts.
  const serializeConstVal = (val: any): any => {
    if (val === null || typeof val !== 'object' || _.isArray(val)) {
      return val;
    }
    if (val.nodeType === 'output') {
      return {nodeType: val.nodeType, fromOp: localLookup(val.fromOp)};
    }
    if (val instanceof Date) {
      return {type: 'date', val};
    }
    return val;
  };

  const serializeConst = (node: ConstNode): SerializedConst => ({
    nodeType: node.nodeType,
    type: node.type,
    val: serializeConstVal(node.val),
  });

  const serializeOutput = (node: EditingOutputNode): SerializedOutput => ({
    nodeType: node.nodeType,
    // GraphNormalizer visits every output node's fromOp, so this resolves.
    fromOp: localLookup(node.fromOp) as SerializedRef,
    type: node.type,
    id: hasher.typedNodeId(node),
  });

  // Var and void nodes have no references to rewrite; emit them unchanged.
  const inverse: SerializableGraph = {
    ops: Object.fromEntries(
      invertRemap(norm.ng.ops, serializeOp)
    ) as Record<string, SerializedOp>,
    constNodes: Object.fromEntries(
      invertRemap(norm.ng.constNodes, serializeConst)
    ) as Record<string, SerializedConst>,
    varNodes: Object.fromEntries(
      invertRemap(norm.ng.varNodes, node => node)
    ) as Record<string, SerializedVar>,
    voidNodes: Object.fromEntries(
      invertRemap(norm.ng.voidNodes, node => node)
    ) as Record<string, SerializedVoid>,
    outputNodes: Object.fromEntries(
      invertRemap(norm.ng.outputNodes, serializeOutput)
    ) as Record<string, SerializedOutput>,
  };

  const rootNodes = graphs.map(rootNode => {
    const ref = localLookup(rootNode);
    if (ref === undefined) {
      throw new Error(`cannot find node ${JSON.stringify(rootNode)}`);
    }
    return ref;
  });

  return {nodes: flattenGraph(inverse), rootNodes};
}

// Serialized Graph -> Flat Serializable Graph -> CG
/**
 * Rebuild editing graphs from the flat wire format produced by `serialize`.
 *
 * Each serialized entry is converted at most once (memoized by object
 * identity), so shared subgraphs stay shared. Output nodes get back the `id`
 * and `type` recorded at serialization time.
 *
 * @returns one editing node per entry of `batch.rootNodes`, in order.
 * @throws Error if a ref points at a missing entry, an output's `fromOp` is
 *   not an op, an op is not low-level, or an entry has an unrecognized shape.
 */
export function deserialize(batch: BatchedGraphs): EditingNode[] {
  const nodeCache = new Map<SerializedNode, EditingNode>();
  const cached = (
    serializedNode: SerializedNode,
    node: EditingNode
  ): EditingNode => {
    nodeCache.set(serializedNode, node);
    return node;
  };

  // Fetch the entry for `ref`. Fails with a descriptive message instead of a
  // TypeError from dereferencing `undefined` further down.
  const nodeAt = (ref: SerializedRef): SerializedNode => {
    const entry = batch.nodes[ref];
    if (entry == null) {
      throw new Error(`invalid graph: no node at index ${ref}`);
    }
    return entry;
  };

  // Build an op's output node through its low-level op definition, with the
  // inputs deserialized recursively. Returns undefined when the op is not
  // low-level; callers then report the entry as unhandled.
  const applyLowLevelOp = (op: SerializedOp) => {
    const opDef = StaticOpStore.getInstance().getOpDef(op.name);
    if (!opDefIsLowLevel(opDef)) {
      return undefined;
    }
    return opDef.op(
      _.mapValues(op.inputs, ref =>
        doDeserialize(nodeAt(ref))
      ) as OpInputNodes<any>
    );
  };

  const doDeserialize = (node: SerializedNode): EditingNode => {
    if (nodeCache.has(node)) {
      return nodeCache.get(node)!;
    }
    if (isSerializedOp(node)) {
      const opOutput = applyLowLevelOp(node);
      if (opOutput !== undefined) {
        return cached(node, opOutput);
      }
    } else if (node.nodeType === 'const') {
      if (isSimpleTypeShape(node.type)) {
        if (node.type === 'date') {
          // serialize() wraps Date values as {type: 'date', val}. After a JSON
          // round trip `val` is an ISO-8601 string, so rebuild the Date;
          // `new Date` also accepts an existing Date.
          return cached(node, constDate(new Date(node.val.val)));
        }
        return cached(node, node);
      } else if (node.type.type === 'function') {
        // serialize() stores the body as {nodeType: 'output', fromOp: <ref>};
        // deserializing the op yields the body's output node.
        const fnBody = doDeserialize(nodeAt(node.val.fromOp));
        return cached(node, constNodeUnsafe(node.type, fnBody));
      } else {
        // other complex type
        return cached(node, constNodeUnsafe(node.type, node.val));
      }
    } else if (node.nodeType === 'var') {
      return cached(node, varNode(node.type as Type, node.varName));
    } else if (node.nodeType === 'void') {
      // Memoized like every other kind, so a shared void entry maps to one node.
      return cached(node, voidNode());
    } else if (node.nodeType === 'output') {
      const fromOp = nodeAt(node.fromOp);
      if (!isSerializedOp(fromOp)) {
        throw new Error(`invalid graph: Expected op at index ${node.fromOp}`);
      }
      const result = applyLowLevelOp(fromOp);
      if (result !== undefined) {
        result.id = node.id;
        result.type = node.type;
        return cached(node, result);
      }
    }
    throw new Error(`Can't handle node: ${JSON.stringify(node)}`);
  };

  return batch.rootNodes.map(ref => doDeserialize(nodeAt(ref)));
}

/**
 * Check whether `maybeGraph` is a flat serialized graph. Rather than
 * validating fields one by one, it tries to deserialize the entries and
 * reports whether that succeeds.
 *
 * Every index is used as a root, not only index 0. That way entries not
 * reachable from the first node (e.g. further roots of a batch) are validated
 * too. Deserialization is memoized per entry, so shared subgraphs are only
 * rebuilt once. An empty array is rejected.
 */
export function isSerializedGraph(
  maybeGraph: any
): maybeGraph is FlatSerializableGraph {
  if (!_.isArray(maybeGraph) || maybeGraph.length === 0) {
    return false;
  }
  try {
    deserialize({nodes: maybeGraph, rootNodes: _.range(maybeGraph.length)});
  } catch {
    return false;
  }
  return true;
}
