import * as _ from 'lodash';

import {RunsData} from '../../containers/RunsDataLoader';
import {RunWithRunsetInfo} from '../../state/runs/types';
import * as ColorUtil from '../../util/colors';
import {
  configKeysInExpression,
  evaluateExpression,
  Expression,
  expressionToString,
  metricsInExpression,
  summaryKeysInExpression,
} from '../../util/expr';
import {
  legendTemplateInsertCrosshairValues,
  legendTemplateRemoveCrosshairValues,
  legendTemplateToFancyLegendProps,
  parseLegendTemplate,
} from '../../util/legend';
import * as RunHelpers from '../../util/runhelpers';
import * as Run from '../../util/runs';
import {RunColorConfig} from '../../util/section';
import {Key, keyToString} from './../runs';
import {PlotFontSize, YAxisType} from './axis';
import {formatYAxis, getScaleFnFromScaleObject} from './axis';
import {Chart} from './chart';
import {ChartAggOption} from './chart';
import * as PlotMath from './math';
import {prettifyMetricName} from './prettifyMetricName';
import {AggregateCalculation, Bar, Line, RunSetInfo, Scalar} from './types';

// Number of histogram bins requested from aggregatePoints when a violin plot
// is being drawn (pointsFromRunset passes this only when `violinPlot` is set).
const DEFAULT_VIOLIN_PLOT_BINS = 10;

/**
 * Builds a default chart layout from a run's metric names.
 *
 * One chart is created for each "important" metric family (loss, accuracy,
 * n_success) that has at least one matching metric; the chart plots every
 * matching metric. Charts are laid out two per row, each 6 grid columns wide
 * and 2 units tall. If only one chart is produced it is widened to 12 columns.
 *
 * @param metricNames - metric names to choose from (any number of them).
 * @returns the charts, in the order of the families listed above.
 */
export function defaultRunCharts(metricNames: string[]) {
  const importantRegexes = [/loss/, /accuracy|acc$/, /n_success/];
  const charts: Chart[] = [];
  let x = 0;
  let y = 0;

  for (const regex of importantRegexes) {
    const lines = metricNames.filter(name => regex.test(name));
    if (lines.length === 0) {
      continue;
    }
    charts.push({
      config: {lines},
      layout: {x, y, w: 6, h: 2},
    });
    // Advance to the next slot, wrapping to a new row after two charts.
    x += 6;
    if (x === 12) {
      x = 0;
      y += 2;
    }
  }

  if (charts.length === 1) {
    charts[0].layout.w = 12;
  }
  return charts;
}

/**
 * Prepares lines for a log-scale axis by replacing every non-positive y value
 * with NaN (the logarithm of a value <= 0 is undefined).
 *
 * Points are not removed, so x positions are kept; NaN entries presumably
 * render as gaps — TODO confirm with the plotting layer.
 *
 * The input is left untouched: new line objects are returned, and any point
 * whose y is replaced is copied rather than modified.
 */
// TODO: This doesn't handle area graphs
export function filterNegative(lines: Line[]): Line[] {
  return lines.map(line => ({
    ...line,
    data: line.data.map(point =>
      point.y <= 0 ? {...point, y: NaN} : point
    ),
  }));
}

/**
 * Reports whether `arr` never decreases, i.e. every element is >= its
 * predecessor (equal neighbours are allowed).
 *
 * Edge cases: an empty array yields false, while a single-element array yields
 * true. Any neighbour comparison involving NaN fails, so such arrays yield false.
 */
export function isMonotonicIncreasing(arr: number[]) {
  if (arr.length === 0) {
    return false;
  }
  for (let idx = 1; idx < arr.length; idx++) {
    if (!(arr[idx] - arr[idx - 1] >= 0)) {
      return false;
    }
  }
  return true;
}

/**
 * One plotted value for a single run, or for a group of runs once
 * aggregatePoints has collapsed several points into one.
 */
interface RunDataPoint {
  // The run the value was read from. For aggregated points this is the
  // representative run picked by aggregatePoints.
  run: RunWithRunsetInfo;
  // The plotted value. Filled from `Run.getValueSafe(...) as number` or an
  // evaluated expression, so it may not actually be a number at runtime —
  // TODO confirm.
  value: number;
  // Metric display name (Run.keyDisplayName) or the stringified expression.
  metricName: string;
  // Identifier from Run.uniqueId; points sharing it are aggregated together.
  runOrGroupUniqueId: string;
  // The run's display name, or Run.groupedRunDisplayName when grouping.
  runOrGroupDisplayName: string;
}

/**
 * A RunDataPoint that summarises several points sharing a runOrGroupUniqueId.
 * The summary statistics are filled in by aggregatePoints.
 */
interface GroupDataPoint extends RunDataPoint {
  // Grouping keys the points were aggregated over.
  groupKeys: Run.Key[];
  // PlotMath.avg of the grouped values.
  mean?: number;
  // PlotMath.stddev / PlotMath.stderr of the grouped values.
  stddev?: number;
  stderr?: number;
  // PlotMath.quartiles output: index 0 is used as the min, 2 as the median
  // and 4 as the max.
  quartiles?: [number, number, number, number, number];
  // Histogram of the grouped values; only computed for violin plots.
  bins?: Array<{bin: number; count: number}>;
  // Shaded band [lower, upper] selected by the group-area option, if any.
  range?: [number, number];
}

/**
 * Type guard: a point counts as a GroupDataPoint exactly when it carries a
 * non-nullish `groupKeys` field.
 */
function isGroupDataPoint(
  dataPoint: GroupDataPoint | RunDataPoint
): dataPoint is GroupDataPoint {
  const {groupKeys} = dataPoint as Partial<GroupDataPoint>;
  return groupKeys !== undefined && groupKeys !== null;
}

/**
 * Collapses the points of one group into a single GroupDataPoint.
 *
 * Mean, stddev, stderr and quartiles are always computed, plus a histogram
 * when `numBins` is given. `aggregateCalculations` is accepted but not read.
 *
 * @param points - the points that make up one group.
 * @param aggregateCalculations - currently unused.
 * @param groupAgg - how the headline `value` is derived: 'mean', 'median',
 *   'max', 'min' or 'sum'. Any other option yields 0.
 * @param groupArea - band stored in `range`: 'minmax' spans the quartile
 *   extremes; 'stddev' / 'stderr' are centred on `value`. Anything else
 *   leaves `range` undefined.
 * @param groupKeys - grouping keys copied onto the result.
 * @param numBins - when set, number of histogram bins to compute.
 * @returns the aggregated point, or null for an empty group. Run, metric name
 *   and identifiers come from the point holding the extreme value for
 *   'min'/'max', otherwise from the first point.
 */
function aggregatePoints(
  points: RunDataPoint[],
  aggregateCalculations: AggregateCalculation[], // currently we do all calculations
  groupAgg: ChartAggOption,
  groupArea: AggregateCalculation,
  groupKeys: Run.Key[],
  numBins?: number
): GroupDataPoint | null {
  if (points.length === 0) {
    return null;
  }

  const values = points.map(p => p.value);
  const mean = PlotMath.avg(values);
  const stddev = PlotMath.stddev(values);
  const stderr = PlotMath.stderr(values);
  const quartiles = PlotMath.quartiles(values);
  const bins = numBins == null ? undefined : PlotMath.bin(values, numBins);

  // Which input point supplies the run/metric identity of the aggregate.
  let representativeIdx = 0;
  let value: number;
  switch (groupAgg) {
    case 'mean':
      value = mean;
      break;
    case 'median':
      value = quartiles[2];
      break;
    case 'max':
      representativeIdx = PlotMath.argMax(values);
      value = quartiles[4];
      break;
    case 'min':
      representativeIdx = PlotMath.argMin(values);
      value = quartiles[0];
      break;
    case 'sum':
      value = _.sum(values);
      break;
    default:
      value = 0;
  }

  let range: [number, number] | undefined;
  switch (groupArea) {
    case 'minmax':
      range = [quartiles[0], quartiles[4]];
      break;
    case 'stddev':
      range = [value - stddev, value + stddev];
      break;
    case 'stderr':
      range = [value - stderr, value + stderr];
      break;
    default:
      range = undefined;
  }

  const representative = points[representativeIdx];
  return {
    run: representative.run,
    metricName: representative.metricName,
    runOrGroupUniqueId: representative.runOrGroupUniqueId,
    runOrGroupDisplayName: representative.runOrGroupDisplayName,
    value,
    mean,
    stddev,
    stderr,
    quartiles,
    range,
    bins,
    groupKeys,
  };
}

// Arguments for pointsFromRunset. Note: pointsFromRunset does not read
// `legendTemplate` or `boxPlot`; they are carried along by callers.
type PointsFromRunsetParams = {
  // Runs to turn into points.
  runs: RunWithRunsetInfo[];
  metrics: Key[]; // typically one metric otherwise will aggregate across metrics
  // When non-empty, points are produced per expression instead of per metric.
  expressions?: Expression[];
  // Grouping keys; non-empty keys trigger aggregation by runOrGroupUniqueId.
  groupKeys: Key[];
  // Forwarded to aggregatePoints (which does not currently consult it).
  aggregateCalculations: AggregateCalculation[];
  legendTemplate?: string;
  // Aggregation of grouped values; pointsFromRunset defaults it to 'mean'.
  groupAgg?: ChartAggOption;
  // Band to compute for grouped values; pointsFromRunset defaults it to 'none'.
  groupArea?: AggregateCalculation;
  boxPlot?: boolean;
  // When true, aggregated points also get histogram bins.
  violinPlot?: boolean;
  // Negated and passed to Run.uniqueId as its third argument.
  mergeRunsets?: boolean;
};

/**
 * Converts runs to plot points and, when needed, aggregates them.
 *
 * Each run yields one point per expression when `expressions` is non-empty,
 * otherwise one point per metric key. When grouping keys are present or more
 * than one metric is requested, points sharing a runOrGroupUniqueId are
 * collapsed with aggregatePoints. groupAgg defaults to 'mean' and groupArea to
 * 'none'; violin plots also request histogram bins.
 *
 * The expression and metric paths derive runOrGroupUniqueId identically,
 * forwarding `!mergeRunsets` to Run.uniqueId in both cases.
 * `legendTemplate` and `boxPlot` are not read here.
 */
function pointsFromRunset(props: PointsFromRunsetParams): RunDataPoint[] {
  const {
    runs,
    metrics,
    expressions,
    groupKeys,
    aggregateCalculations,
    groupAgg,
    groupArea,
    violinPlot,
    mergeRunsets,
  } = props;

  // Summary/config keys referenced by any expression. Their per-run values are
  // looked up before the expressions are evaluated.
  const expressionKeys =
    expressions != null
      ? _.flatten(
          expressions.map(expr =>
            _.concat(
              summaryKeysInExpression(expr),
              configKeysInExpression(expr)
            )
          )
        )
      : [];

  // Identity and label shared by every point produced from one run. Both
  // paths below use it, so they treat mergeRunsets consistently.
  const identify = (run: RunWithRunsetInfo) => ({
    runOrGroupUniqueId: Run.uniqueId(run, groupKeys, !mergeRunsets),
    runOrGroupDisplayName:
      groupKeys.length === 0
        ? run.displayName
        : Run.groupedRunDisplayName(run, groupKeys),
  });

  const pointsForRun = (run: RunWithRunsetInfo): RunDataPoint[] => {
    const identity = identify(run);
    if (expressions && expressions.length > 0) {
      const metricsToValue: {[key: string]: number} = {};
      for (const exprKey of expressionKeys) {
        metricsToValue[keyToString(exprKey)] = Run.getValueSafe(
          run,
          exprKey
        ) as number;
      }
      return expressions.map(expr => ({
        metricName: expressionToString(expr),
        value: evaluateExpression(expr, metricsToValue),
        run,
        ...identity,
      }));
    }
    return metrics.map(key => ({
      run,
      value: Run.getValueSafe(run, key) as number,
      metricName: Run.keyDisplayName(key),
      ...identity,
    }));
  };

  const points = runs.flatMap(run => pointsForRun(run));

  // With no grouping keys and at most one metric there is nothing to aggregate.
  if (groupKeys.length === 0 && metrics.length <= 1) {
    return points;
  }

  const numBins = violinPlot ? DEFAULT_VIOLIN_PLOT_BINS : undefined;
  const groups = _.groupBy(points, p => p.runOrGroupUniqueId);
  return _.map(groups, group =>
    aggregatePoints(
      group,
      aggregateCalculations,
      groupAgg ?? 'mean',
      groupArea ?? 'none',
      groupKeys,
      numBins
    )
  ).filter((p): p is GroupDataPoint => p != null);
}

/**
 * Computes the points for one runset, or for all runs when `q` is undefined.
 *
 * Grouping: with `panelAggregate` the runs are grouped by the config key
 * `groupBy` (empty string when unset); otherwise the runset's own grouping
 * (or none) is used.
 *
 * With `aggregateMetrics` all metric keys and expressions go through a single
 * pointsFromRunset call. Otherwise each metric key is processed on its own,
 * and expressions are not forwarded.
 */
function pointResultsForRunset(
  runs: RunsData['filtered'],
  dataProps: PointsFromDataProps,
  q: RunSetInfo | undefined,
  mergeRunsets: boolean
): RunDataPoint[] {
  const runSetID = q?.id;
  const runsForRunset =
    runSetID == null ? runs : runs.filter(r => r.runsetInfo.id === runSetID);

  const groupKeys = dataProps.panelAggregate
    ? [Run.key('config', dataProps.groupBy ?? '')]
    : q?.grouping ?? [];

  const sharedParams = {
    runs: runsForRunset,
    groupKeys,
    aggregateCalculations: dataProps.aggregateCalculations,
    groupArea: dataProps.groupArea,
    groupAgg: dataProps.groupAgg,
    violinPlot: dataProps.violinPlot,
    legendTemplate: dataProps.legendTemplate,
    mergeRunsets,
  };

  if (!dataProps.aggregateMetrics) {
    return dataProps.metricKeys.flatMap(metricKey =>
      pointsFromRunset({metrics: [metricKey], ...sharedParams})
    );
  }

  return pointsFromRunset({
    metrics: dataProps.metricKeys,
    expressions: dataProps.expressions,
    ...sharedParams,
  });
}

/**
 * Turns data points into bar descriptors for bar, box and violin charts.
 *
 * - The title comes from `legendTemplate`. `key` is that title with the
 *   crosshair-value placeholders removed, then trimmed.
 * - With `colorEachMetricDifferently`, each metric gets a color slot numbered
 *   by first appearance. Otherwise points are colored by run via
 *   ColorUtil.runColor.
 * - Group-only statistics are undefined for plain run points. `bins` is only
 *   forwarded for violin plots.
 *
 * `useRunName`, `useMetricName` and `boxPlot` are accepted but not read.
 */
const convertRunDataPointsToBars = (props: {
  dataPoints: Array<RunDataPoint | GroupDataPoint>;
  useRunName: boolean;
  useMetricName: boolean;
  legendTemplate: string;
  colorEachMetricDifferently: boolean;
  customRunColors?: RunColorConfig;
  boxPlot?: boolean;
  violinPlot?: boolean;
}): Bar[] => {
  const {dataPoints, customRunColors, violinPlot, legendTemplate} = props;

  // Color slot per metric name, assigned in order of first appearance.
  const metricToColorIdx = new Map<string, number>();
  if (props.colorEachMetricDifferently) {
    for (const {metricName} of dataPoints) {
      if (!metricToColorIdx.has(metricName)) {
        metricToColorIdx.set(metricName, metricToColorIdx.size);
      }
    }
  }

  const colorFor = (point: RunDataPoint | GroupDataPoint) => {
    if (props.colorEachMetricDifferently) {
      return ColorUtil.color(metricToColorIdx.get(point.metricName) ?? 0);
    }
    if (!point.run) {
      return '000000';
    }
    return ColorUtil.runColor(
      point.run,
      isGroupDataPoint(point) ? point.groupKeys : [],
      customRunColors
    );
  };

  const bars = dataPoints.map(point => {
    const group = isGroupDataPoint(point) ? point : undefined;
    const title = parseLegendTemplate(
      legendTemplate,
      true,
      point.run,
      group?.groupKeys ?? [],
      prettifyMetricName(point.metricName)
    );

    return {
      key: legendTemplateRemoveCrosshairValues(title).trim(),
      title,
      color: colorFor(point),
      metricName: point.metricName,
      uniqueId: point.runOrGroupUniqueId,
      value: point.value,
      mean: group?.mean,
      stddev: group?.stddev,
      displayName: point.runOrGroupDisplayName,
      quartiles: group?.quartiles,
      bins: violinPlot ? group?.bins : undefined,
      range: group?.range,
    };
  });

  return bars;
};

// Options shared by getPointsFromData and pointResultsForRunset.
interface PointsFromDataProps {
  // Metric keys to plot.
  metricKeys: Run.Key[];
  // Passed through to ColorUtil.runColor.
  customRunColors?: RunColorConfig;
  // Config key to group by when `panelAggregate` is set ('' when unset).
  groupBy?: string;
  // When true, group by config `groupBy` instead of the runset's grouping.
  panelAggregate?: boolean;
  // Forwarded to aggregatePoints (pointsFromRunset defaults them to 'mean'/'none').
  groupAgg?: ChartAggOption;
  groupArea?: AggregateCalculation;
  // Runsets to compute separately; absent on the run page.
  runSets?: RunSetInfo[];
  // Forwarded to aggregatePoints, which does not currently consult it.
  aggregateCalculations: AggregateCalculation[];
  // Color bars per metric instead of per run.
  colorEachMetricDifferently: boolean;
  // When true, all metrics (and expressions) are aggregated together; when
  // false, each metric key is handled separately and expressions are ignored.
  aggregateMetrics: boolean;
  // Forwarded to convertRunDataPointsToBars, which does not read it.
  boxPlot?: boolean;
  // When true, aggregated points carry histogram bins.
  violinPlot?: boolean;
  // Template used to build bar titles/keys.
  legendTemplate: string;
  // Expressions to evaluate per run (only used when aggregateMetrics is true).
  expressions?: Expression[];
}

/**
 * Converts filtered run data into bar chart data.
 *
 * With `runSets` (everywhere but the run page), points are computed for each
 * runset in turn and concatenated. Without it, all runs form one implicit set.
 * Runsets are never merged here (mergeRunsets is false). Returns an empty
 * array when there are no runs.
 */
export const getPointsFromData = (
  runs: RunsData['filtered'],
  props: PointsFromDataProps
): Bar[] => {
  if (runs.length === 0) {
    return [];
  }

  const {
    runSets,
    metricKeys,
    customRunColors,
    boxPlot,
    violinPlot,
    legendTemplate,
    colorEachMetricDifferently,
  } = props;

  const barResults: RunDataPoint[] =
    runSets != null
      ? runSets.flatMap(runSet =>
          pointResultsForRunset(runs, props, runSet, false)
        )
      : pointResultsForRunset(runs, props, undefined, false);

  const useMetricName = metricKeys.length > 1;
  const useRunName =
    !useMetricName ||
    _.uniq(barResults.map(br => br.runOrGroupDisplayName)).length > 1;

  return convertRunDataPointsToBars({
    dataPoints: barResults,
    useRunName,
    useMetricName,
    legendTemplate: legendTemplate || '',
    colorEachMetricDifferently,
    customRunColors,
    boxPlot,
    violinPlot,
  });
};

/**
 * Builds the Scalar descriptor for a single data point.
 *
 * The key is produced by legendTemplateToFancyLegendProps with no group keys.
 * When both entityName and projectName are known, `/<entity>/<project>/runs`
 * is passed as its root URL. Group statistics (range, stddev, stderr) are
 * copied only from GroupDataPoints. `useRunName` and `useMetricName` are not
 * read.
 */
const convertRunDataPointsToScalar = (props: {
  dataPoint: RunDataPoint | GroupDataPoint;
  useRunName: boolean;
  useMetricName: boolean;
  legendTemplate: string;
  customRunColors?: RunColorConfig;
  entityName?: string;
  projectName?: string;
}): Scalar => {
  const {
    dataPoint,
    customRunColors,
    legendTemplate,
    entityName,
    projectName,
  } = props;

  const rootUrl =
    entityName == null || projectName == null
      ? undefined
      : `/${entityName}/${projectName}/runs`;

  const group = isGroupDataPoint(dataPoint) ? dataPoint : undefined;

  return {
    key: legendTemplateToFancyLegendProps(
      legendTemplate,
      dataPoint.run,
      [],
      prettifyMetricName(dataPoint.metricName),
      rootUrl
    ),
    color: dataPoint.run
      ? ColorUtil.runColor(dataPoint.run, [], customRunColors)
      : '000000',
    uniqueId: dataPoint.runOrGroupUniqueId,
    value: dataPoint.value,
    range: group?.range,
    stddev: group?.stddev,
    stderr: group?.stderr,
  };
};

// Options for getScalarFromData.
interface ScalarFromDataProps {
  // Metric keys aggregated into the scalar.
  metricKeys: Run.Key[];
  // Passed through to ColorUtil.runColor.
  customRunColors?: RunColorConfig;
  // Forwarded to aggregatePoints (defaults 'mean'/'none' in pointsFromRunset).
  groupAgg?: ChartAggOption;
  groupArea?: AggregateCalculation;
  // Used only to look up entity/project names for the first point's runset;
  // runs are not split per runset.
  runSets?: RunSetInfo[];
  // Forwarded to aggregatePoints, which does not currently consult it.
  aggregateCalculations: AggregateCalculation[];
  // Template rendered by legendTemplateToFancyLegendProps.
  legendTemplate: string;
  // Expressions evaluated per run in place of plain metrics, when non-empty.
  expressions?: Expression[];
  // Fallbacks used to build the runs root URL when no runset supplies them.
  entityName?: string;
  projectName?: string;
}

/**
 * Reduces filtered run data to a single scalar value.
 *
 * All runs are pooled (runsets merged), grouped under the config key `''`, and
 * metrics/expressions are aggregated together. Whether this always yields a
 * single group depends on Run.uniqueId — TODO confirm. The first resulting
 * point is converted with convertRunDataPointsToScalar.
 *
 * @returns null when there are no runs, `{value: 0}` when aggregation yields
 *   no points, otherwise the scalar for the first point.
 */
export const getScalarFromData = (
  runs: RunsData['filtered'],
  props: ScalarFromDataProps
): Scalar | null => {
  if (runs.length === 0) {
    return null;
  }

  const points = pointResultsForRunset(
    runs,
    {
      ...props,
      groupBy: '',
      panelAggregate: true,
      aggregateMetrics: true,
      colorEachMetricDifferently: false,
    },
    undefined,
    true
  );
  if (points.length === 0) {
    return {value: 0} as Scalar;
  }

  const {metricKeys, customRunColors, legendTemplate, runSets} = props;
  const useMetricName = metricKeys.length > 1;
  const useRunName =
    !useMetricName ||
    _.uniq(points.map(p => p.runOrGroupDisplayName)).length > 1;

  // Prefer the entity/project of the runset the first point came from.
  const runsetForRun = runSets?.find(
    rs => rs.id === points[0].run.runsetInfo.id
  );

  return convertRunDataPointsToScalar({
    dataPoint: points[0],
    useRunName,
    useMetricName,
    legendTemplate: legendTemplate || '',
    customRunColors,
    entityName: runsetForRun?.entityName ?? props.entityName,
    projectName: runsetForRun?.projectName ?? props.projectName,
  });
};

/**
 * Renders the crosshair overlay legend for a bar. The bar's title template is
 * filled with its value, mean, min, max and stddev.
 *
 * Min/max come from the quartiles when present and otherwise fall back to the
 * range. Quartiles won't exist if the bar was converted from a line.
 */
export function barOverlay(b: Bar) {
  const {quartiles, range} = b;
  const min = quartiles != null ? quartiles[0] : range?.[0];
  const max = quartiles != null ? quartiles[4] : range?.[1];

  const values = {
    x: {xAxis: b.metricName ?? '', val: b.value},
    mean: b.mean,
    min,
    max,
    stddev: b.stddev,
  };
  return legendTemplateInsertCrosshairValues(
    b.key,
    b.title ?? '',
    true,
    values,
    'bar'
  );
}

// Inputs for getPlotMargin. Timestamps have long axis tick labels, so
// getPlotMargin adds extra margin for axes whose key is a time key.
type GetPlotMarginParams = {
  // Axis key strings; each is checked with Run.isTimeKeyString.
  axisKeys?: {
    xAxis?: string;
    yAxis?: string;
    zAxis?: string;
  };
  // Numeric y domain used to generate (and measure) y tick labels.
  axisDomain?: {
    yAxis?: number[];
  };
  // Scale type for the y axis; getPlotMargin falls back to 'linear'.
  axisType?: {
    yAxis?: YAxisType;
  };
  // Explicit y tick labels; when present these are measured instead of axisDomain.
  axisValues?: {
    yAxis?: string[];
  };
  // Tick count hint passed to the y scale's ticks().
  tickTotal?: {
    yAxis?: number;
  };
  // 'medium' and 'large' widen the left margin.
  fontSize?: PlotFontSize;
};

/**
 * Computes plot margins large enough for the axis tick labels.
 *
 * The left margin is sized from the y-axis labels. Explicit `axisValues.yAxis`
 * labels are measured directly. Otherwise, the ticks that a scale built from
 * `axisDomain.yAxis` / `axisType.yAxis` would produce are formatted and
 * measured, never going below the default.
 *
 * Time-valued axes then get larger minimum margins. Finally, the left margin
 * is scaled up for medium (x1.2) and large (x1.5) fonts.
 *
 * @returns `{bottom, left, top, right}` — presumably pixels as consumed by
 *   react-vis; TODO confirm units.
 */
export const getPlotMargin = ({
  axisKeys = {},
  axisDomain = {},
  axisType = {},
  axisValues = {},
  tickTotal = {},
  fontSize = 'small',
}: GetPlotMarginParams) => {
  const {xAxis, yAxis, zAxis} = axisKeys;
  const xIsTime = xAxis && Run.isTimeKeyString(xAxis);
  const yIsTime = yAxis && Run.isTimeKeyString(yAxis);
  const zIsTime = zAxis && Run.isTimeKeyString(zAxis);

  const margin = {bottom: 30, left: 10, top: 5, right: 20};

  if (axisValues?.yAxis != null) {
    // Explicit labels: size the margin to exactly fit them.
    margin.left = RunHelpers.getAxisMarginWidth(axisValues.yAxis);
  } else if (axisDomain?.yAxis) {
    // This is the internal function used by react vis to calculate the axis values
    const scale = getScaleFnFromScaleObject({
      type: axisType.yAxis ?? 'linear',
      domain: axisDomain.yAxis,
    });

    const formatFn = yIsTime ? Run.formatTimestamp : formatYAxis;

    const marginWidth = RunHelpers.getAxisMarginWidth(
      (scale.ticks(tickTotal?.yAxis) as any[]).map(formatFn)
    );
    margin.left = Math.max(marginWidth, margin.left);
  }

  // Timestamp tick labels are long. These are minimums: they must never
  // shrink a left margin that was already widened to fit the y tick labels.
  if (xIsTime) {
    margin.bottom = 55;
    margin.left = Math.max(margin.left, 55);
  }
  if (yIsTime) {
    margin.left = Math.max(margin.left, 100);
  }
  if (zIsTime) {
    margin.top = 20;
  }

  if (fontSize === 'medium') {
    margin.left *= 1.2;
  } else if (fontSize === 'large') {
    margin.left *= 1.5;
  }

  return margin;
};

/**
 * Collects the metric identifiers referenced by the y-axis expressions and by
 * the optional x-axis expression. A missing input yields an empty list.
 */
export function getMetricIdentifiersFromExpressions(
  expressions?: Expression[],
  xExpression?: Expression
): {
  xExpressionMetricIdentifiers: string[];
  expressionMetricIdentifiers: string[];
} {
  return {
    xExpressionMetricIdentifiers:
      xExpression == null ? [] : metricsInExpression(xExpression),
    expressionMetricIdentifiers:
      expressions?.flatMap(expr => metricsInExpression(expr)) ?? [],
  };
}

/**
 * Merges plain metric names with the identifiers referenced by expressions,
 * dropping duplicates.
 *
 * Order: x-expression identifiers, then y-expression identifiers, then plain
 * metrics. Each name keeps its first occurrence.
 */
export function getAllMetrics(
  metrics: string[],
  expressionMetricIdentifiers: string[],
  xExpressionMetricIdentifiers: string[]
): string[] {
  return _.union(
    xExpressionMetricIdentifiers,
    expressionMetricIdentifiers,
    metrics
  );
}
