import {AxisBottom, AxisLeft} from '@visx/axis';
import {localPoint} from '@visx/event';
import {Group} from '@visx/group';
import {ParentSize} from '@visx/responsive';
import {scaleBand, scaleLinear} from '@visx/scale';
import * as Shape from '@visx/shape';
import * as Stats from '@visx/stats';
import {Text} from '@visx/text';
import {TooltipWithBounds, useTooltip} from '@visx/tooltip';
import * as globals from '@wandb/common/css/globals.styles';
import _ from 'lodash';
import React from 'react';

import {
  axisTickRotate,
  Bar,
  barOverlay,
  getAngledXAxisMarginHeight,
  getAxisStyleForFontSize,
  getPlotMargin,
  PlotFontSize,
} from '../../util/plotHelpers';

/** Pixel insets between the SVG edge and the plotting area. */
interface Margin {
  top: number;
  right: number;
  bottom: number;
  left: number;
}

/** Props accepted by the public `BarChart` component. */
interface BarChartProps {
  // --- data ---
  /** Bars to draw, in display order. */
  bars: Bar[];
  /** When set, only the first `maxBars` entries of `bars` are drawn. */
  maxBars?: number;
  /**
   * NOTE(review): not read by any component in this file; confirm whether
   * callers still rely on it.
   */
  highlight?: Bar;

  // --- value axis ---
  /**
   * Lower value-axis bound. Defaults to the smallest plotted value (bar
   * values, box-plot quartiles, violin bins), never above 0.
   */
  min?: number;
  /**
   * Upper value-axis bound. Defaults to the largest plotted value, never
   * below 0.
   */
  max?: number;

  // --- layout & appearance ---
  /** Draw columns (vertical layout) instead of rows (horizontal layout). */
  vertical?: boolean;
  /** Fixed CSS height for the chart container; otherwise CSS decides. */
  height?: number;
  /** Font size for axis text and row labels; defaults to 'small'. */
  fontSize?: PlotFontSize;
  /** Vertical layout only: show every bar key under the x axis. */
  showAllLabels?: boolean;
  /** Vertical layout only: show axis ticks and tick labels (default true). */
  enableAxisTicks?: boolean;

  // --- render mode (violinPlot wins over boxPlot) ---
  /** Render each bar as a box plot built from `quartiles`. */
  boxPlot?: boolean;
  /** Render each bar as a violin plot built from `bins`. */
  violinPlot?: boolean;

  // --- interaction ---
  /** Show the built-in hover tooltip (default true). */
  enableTooltip?: boolean;
  /** Called when the pointer enters a bar or box-plot box. */
  mouseOver?: (
    event: React.MouseEvent<SVGRectElement, MouseEvent>,
    bar: Bar
  ) => void;
  /** Called when the pointer leaves a bar or box-plot box. */
  mouseOut?: () => void;
}

/**
 * Props for the internal layout components: everything `BarChart` accepts,
 * plus values `BarChart` precomputes before delegating.
 */
interface BarChartVizProps extends BarChartProps {
  /**
   * Plot margin computed by `BarChart`.
   * NOTE(review): both layout components compute their own margin and do
   * not read this one.
   */
  margin: Margin;
  /**
   * Value-axis domain computed by `BarChart`. The layouts derive their own
   * value domain from `min`/`max` instead of reading this.
   */
  xDomain: number[];
  /** 1-based bar positions; the horizontal layout uses them as its band domain. */
  yDomain: number[];
}

// responsive utils for axis ticks
/**
 * Picks a tick count for the value axis from the chart's pixel height:
 * 3 ticks up to 300px, 5 up to 600px, and 10 beyond that.
 * A NaN height matches no breakpoint and also yields 10.
 */
function numTicksForHeight(h: number) {
  const breakpoints: Array<[number, number]> = [
    [300, 3],
    [600, 5],
  ];
  const match = breakpoints.find(([limit]) => h <= limit);
  return match != null ? match[1] : 10;
}

/**
 * Picks a tick (or label) count for the horizontal axis from the chart's
 * pixel width: 3 up to 300px, 5 up to 400px, and 10 beyond that.
 * A NaN width matches no breakpoint and also yields 10.
 */
function numTicksForWidth(w: number) {
  const breakpoints: Array<[number, number]> = [
    [300, 3],
    [400, 5],
  ];
  const match = breakpoints.find(([limit]) => w <= limit);
  return match != null ? match[1] : 10;
}

/**
 * Responsive bar / box / violin chart.
 *
 * Fills its parent via `ParentSize` and delegates drawing to
 * `VerticalBarChart` (when `vertical` is set) or `HorizontalBarChart`.
 * It owns the tooltip state and wraps the caller's `mouseOver`/`mouseOut`
 * callbacks so the tooltip follows the hovered bar. With
 * `enableTooltip={false}` the callbacks still fire but no tooltip is shown.
 *
 * Value bounds default to the extent of everything plotted (bar values,
 * box-plot quartiles, violin-plot bins; NaN counted as 0), always including
 * 0. `min`/`max` override them. When `maxBars` is set, only the first
 * `maxBars` bars are drawn.
 */
const BarChart = (props: BarChartProps) => {
  const {mouseOver, mouseOut, maxBars, enableTooltip = true} = props;

  const bars = maxBars != null ? props.bars.slice(0, maxBars) : props.bars;

  // Typed with `Bar` so `tooltipData` can be handed to `barOverlay` without casts.
  const {
    tooltipData,
    tooltipLeft,
    tooltipTop,
    tooltipOpen,
    showTooltip,
    hideTooltip,
  } = useTooltip<Bar>();

  // Bars are addressed by 1-based position rather than by key, so bars that
  // share a key still get separate bands.
  const yDomain = bars.map((d, i) => i + 1);

  const barValues = bars.map(d => (isNaN(d.value) ? 0 : d.value));
  const xDomain = [
    props.min ?? Math.min(...barValues, 0),
    props.max ?? Math.max(...barValues, 0),
  ];

  const handleMouseOver = (
    event: React.MouseEvent<SVGRectElement, MouseEvent>,
    bar: Bar
  ) => {
    if (enableTooltip) {
      // localPoint converts the DOM event into SVG-local coordinates.
      const coords = localPoint(event);
      showTooltip({
        tooltipLeft: coords?.x ?? 0,
        tooltipTop: coords?.y ?? 0,
        tooltipData: bar,
      });
    }
    mouseOver?.(event, bar);
  };

  const handleMouseOut = () => {
    hideTooltip();
    mouseOut?.();
  };

  const tooltip = (
    <TooltipWithBounds
      // set this to random so it correctly updates with parent bounds key
      key={Math.random()}
      className={'rv-hint'}
      top={tooltipTop}
      left={tooltipLeft}
      style={{backgroundColor: 'transparent'}}>
      <div className={'rv-hint__content'}>
        {tooltipData != null &&
        Object.prototype.hasOwnProperty.call(tooltipData, 'key') ? (
          barOverlay(tooltipData)
        ) : (
          <></>
        )}
      </div>
    </TooltipWithBounds>
  );

  // Every value that can appear on the value axis: violin bins, box-plot
  // quartiles, or the plain bar value.
  const vals = _.flatten(
    bars.map(b =>
      b.bins != null
        ? b.bins.map(d => d.bin)
        : b.quartiles != null
        ? b.quartiles
        : [b.value]
    )
  );

  const min = props.min ?? Math.min(...vals.map(v => (isNaN(v) ? 0 : v)), 0);
  const max = props.max ?? Math.max(...vals.map(v => (isNaN(v) ? 0 : v)), 0);

  // NOTE(review): both layout components compute their own margin; this one
  // only satisfies the required `margin` prop.
  const margin = getPlotMargin({
    axisDomain: {yAxis: yDomain},
    axisType: {
      yAxis: 'linear',
    },
    fontSize: 'small',
  });

  const fillStyle = React.useMemo(() => {
    return props.height != null ? {height: props.height} : {};
  }, [props.height]);

  // Props shared by both layouts. The measured size is added per layout
  // below; for the vertical layout it overrides `props.height`.
  const sharedVizProps = {
    ...props,
    bars,
    margin,
    mouseOver: handleMouseOver,
    mouseOut: handleMouseOut,
    min,
    max,
    xDomain,
    yDomain,
  };

  if (props.vertical) {
    return (
      <div className="bar-chart" style={fillStyle}>
        <ParentSize>
          {parent => (
            <VerticalBarChart
              {...sharedVizProps}
              width={parent.width}
              height={parent.height}
            />
          )}
        </ParentSize>
        {tooltipOpen && tooltip}
      </div>
    );
  }

  return (
    <div className="bar-chart" style={fillStyle}>
      <ParentSize>
        {parent => (
          <div style={{width: '100%', height: '100%', overflow: 'auto'}}>
            <HorizontalBarChart
              {...sharedVizProps}
              parentHeight={parent.height}
              parentWidth={parent.width}
            />
          </div>
        )}
      </ParentSize>
      {tooltipOpen && tooltip}
    </div>
  );
};

export default BarChart;

/**
 * Horizontal layout: one row (band) per bar stacked down the y axis, with
 * values on a linear x axis along the bottom. Each row draws its `key` as a
 * text label just above the row, then one of three renderings:
 * - a violin plot (`violinPlot`, which takes precedence);
 * - a box plot (`boxPlot`);
 * - a plain bar, plus an optional black error bar spanning `range`.
 *
 * `parentWidth`/`parentHeight` are the size measured by `ParentSize` in
 * `BarChart`. `yDomain` holds the 1-based bar positions used as the band
 * domain. The drawing height grows to at least `minHeight` so every row has
 * room for its label. `BarChart` renders this inside an `overflow: 'auto'`
 * wrapper so that extra height can scroll.
 */
const HorizontalBarChart = (
  props: BarChartVizProps & {parentWidth: number; parentHeight: number}
) => {
  const {
    bars,
    parentWidth,
    parentHeight,
    mouseOver,
    mouseOut,
    boxPlot,
    violinPlot,
    yDomain,
  } = props;

  // Value-axis bounds. `BarChart` always passes numeric `min`/`max`, so the
  // fallbacks are only reached if this component is rendered some other way.
  const xDomain = [
    props.min ?? 0,
    props.max ?? Math.max(...bars.map(d => (isNaN(d.value) ? 0 : d.value))),
  ];

  const yKeys = bars.map(d => d.key);

  const margin = getPlotMargin({
    axisDomain: {yAxis: yDomain},
    axisType: {
      yAxis: 'linear',
    },
    axisValues: {yAxis: yKeys},

    fontSize: props.fontSize ?? 'small',
  });

  const axisFontStyles = getAxisStyleForFontSize(props.fontSize ?? 'small');

  // Row labels sit above each row (the <Text> at `barY - 4`), not in a left
  // gutter, so the side margins are fixed.
  margin.left = 32;
  margin.right = 32;

  // Every row needs room for its label plus at least a sliver of bar. With
  // many rows the chart becomes taller than its parent.
  const labelHeight = 24;
  const minimumBarHeight = 2;
  const minWidth = 0;
  const minHeight =
    (labelHeight + minimumBarHeight) * yDomain.length +
    margin.top +
    margin.bottom;

  const width = Math.max(parentWidth, minWidth);
  const height = Math.max(parentHeight, minHeight);

  const xMax = width - margin.left - margin.right;
  const yMax = height - margin.top - margin.bottom;

  // scaleBand `padding` is a fraction of each row's step. Make it large
  // enough that the empty part of the step can hold the row label.
  const spaceReservedPerBar = yMax / yDomain.length;
  const requiredPaddingForLabel = labelHeight / spaceReservedPerBar;
  const paddingBetweenBars = Math.max(0.25, requiredPaddingForLabel);
  const yScale = scaleBand({
    range: [0, yMax],
    round: true,
    domain: yDomain,
    padding: paddingBetweenBars,
  });

  const xScale = scaleLinear({
    range: [0, xMax],
    round: true,
    domain: xDomain,
  });

  // The SVG takes the computed pixel `height` rather than '100%'. When
  // `minHeight` exceeds the parent, a 100%-tall SVG would clip the lower rows
  // and the x axis instead of overflowing (and scrolling) its wrapper.
  return (
    <svg style={{display: 'block', width: '100%', height}}>
      <Group top={margin.top} left={margin.left} key="chart">
        {bars.map((d, i) => {
          const barHeight = Math.max(0, yScale.bandwidth());
          // Bars start at the value-0 position, clamped to the left edge of
          // the plot when 0 lies below the domain.
          let barX = Math.max(xScale(0), xScale(xDomain[0]), 0);

          const barY = yScale(i + 1) ?? 0;
          // Box and violin plots are capped at 40px thick and centred in the band.
          const constrainedHeight = Math.min(40, barHeight);
          const key = d.key + '--' + i.toString();

          if (violinPlot) {
            return (
              <Group key={key}>
                <Text
                  x={barX}
                  y={barY - 4}
                  fill={d.color ?? 'red'}
                  style={axisFontStyles}>
                  {d.key}
                </Text>
                <Stats.ViolinPlot
                  key={key}
                  data={d.bins ?? []}
                  count={data => data.count}
                  value={data => data.bin}
                  top={barY + barHeight / 2 - constrainedHeight / 2}
                  width={constrainedHeight} // not a bug - here width is "bar" width
                  valueScale={xScale}
                  fill={d.color ?? 'red'}
                  fillOpacity={0.5}
                  stroke={d.color ?? 'red'}
                  strokeWidth={1}
                  horizontal
                />
              </Group>
            );
          } else if (boxPlot) {
            return (
              <Group key={key}>
                <Text
                  x={barX}
                  y={barY - 4}
                  fill={d.color ?? 'red'}
                  style={axisFontStyles}>
                  {d.key}
                </Text>
                <Stats.BoxPlot
                  key={key}
                  min={d.quartiles != null ? d.quartiles[0] : d.value}
                  max={d.quartiles != null ? d.quartiles[4] : d.value}
                  median={d.quartiles != null ? d.quartiles[2] : d.value}
                  firstQuartile={d.quartiles != null ? d.quartiles[1] : d.value}
                  thirdQuartile={d.quartiles != null ? d.quartiles[3] : d.value}
                  top={barY + barHeight / 2 - constrainedHeight / 2}
                  fill={d.color ?? 'red'}
                  fillOpacity={0.5}
                  stroke={d.color ?? 'red'}
                  strokeWidth={1}
                  boxWidth={constrainedHeight}
                  valueScale={xScale}
                  horizontal
                  boxProps={{
                    onMouseOver: event => mouseOver && mouseOver(event, d),
                    onMouseLeave: event => mouseOut && mouseOut(),
                  }}
                />
              </Group>
            );
          } else {
            let barWidth = xScale(d.value) - barX;
            if (isNaN(barWidth)) {
              barWidth = 0;
            }
            // Values left of the start point produce a negative width: flip
            // it so the bar extends leftwards, clipping at the plot's left edge.
            if (barWidth < 0) {
              barWidth = Math.abs(barWidth);
              barX = barX - barWidth;
              if (barX < 0) {
                barWidth = Math.max(barWidth + barX, 0);
                barX = 0;
              }
            }
            return (
              <Group key={key}>
                <Text
                  x={barX}
                  y={barY - 4}
                  fill={d.color ?? 'red'}
                  style={axisFontStyles}>
                  {d.key}
                </Text>
                <Shape.Bar
                  key={key}
                  width={barWidth}
                  height={barHeight}
                  x={barX}
                  y={barY}
                  fill={d.color ?? 'red'}
                  onMouseOver={event => mouseOver && mouseOver(event, d)}
                  onMouseOut={event => mouseOut && mouseOut()}
                />
                {d.range != null &&
                  _.isFinite(d.range[0]) &&
                  _.isFinite(d.range[1]) && (
                    <Shape.Bar
                      key={key + 'error-bar'}
                      width={Math.max(
                        0,
                        xScale(d.range[1]) - xScale(d.range[0])
                      )}
                      height={1}
                      x={xScale(d.range[0])}
                      y={barY + barHeight / 2}
                      fill={'black'}
                    />
                  )}
              </Group>
            );
          }
        })}
      </Group>
      <Group key="axis">
        <line
          x1={margin.left}
          y1={height - margin.bottom}
          x2={width - margin.right}
          y2={height - margin.bottom}
          stroke="#888"
          strokeWidth={0.5}
        />

        <AxisBottom
          top={height - margin.bottom}
          left={margin.left}
          scale={xScale}
          stroke={globals.gray500}
          strokeWidth={0.5}
          tickStroke="#b3b3b0"
          labelProps={{
            ...axisFontStyles,
            textAnchor: 'middle',
          }}
          tickLabelProps={() => ({
            ...axisFontStyles,
            textAnchor: 'middle',
            dx: '-0.25em',
            dy: '0.25em',
          })}
          numTicks={numTicksForWidth(width)}></AxisBottom>
      </Group>
    </svg>
  );
};

/**
 * Vertical layout: one column (band) per bar along the x axis, with values
 * on a linear y axis on the left. Each column is drawn as one of:
 * - a violin plot (`violinPlot`, which takes precedence);
 * - a box plot (`boxPlot`);
 * - a plain bar, plus an optional black error bar spanning `range`.
 *
 * Bar keys are drawn as tick labels under the x axis, rotated by
 * `axisTickRotate`. When there are more bars than `numTicksForWidth(width)`,
 * only a thinned subset of labels is shown unless `showAllLabels` is set.
 * With `enableAxisTicks === false`, axis ticks and tick labels are hidden.
 */
const VerticalBarChart = (
  props: BarChartVizProps & {width: number; height: number}
) => {
  const {
    bars,
    width,
    height,
    mouseOver,
    mouseOut,
    boxPlot,
    violinPlot,
    enableAxisTicks = true,
  } = props;

  // Value-axis bounds. `BarChart` always passes numeric `min`/`max`, so the
  // fallbacks are only reached if this component is rendered some other way.
  const yDomain = [
    props.min ?? 0,
    props.max ?? Math.max(...bars.map(d => (isNaN(d.value) ? 0 : d.value))),
  ];

  const xKeys = bars.map(d => d.key);
  const margin = getPlotMargin({
    axisDomain: {yAxis: yDomain},
    axisType: {
      yAxis: 'linear',
    },
    fontSize: props.fontSize ?? 'small',
  });

  const axisFontStyles = getAxisStyleForFontSize(props.fontSize ?? 'small');

  // Reserve room below the plot for the angled bar-key labels.
  const xMargin = getAngledXAxisMarginHeight(xKeys);
  margin.bottom = xMargin;

  const xMax = width - margin.left - margin.right;
  const yMax = height - margin.top - margin.bottom;

  // Bands are keyed by 1-based position rather than by `key`. A key-based
  // domain would draw bars that share a key on top of each other.
  const xScale = scaleBand({
    range: [0, xMax],
    round: true,
    domain: bars.map((d, i) => i + 1),
    padding: 0.25,
  });
  const yScale = scaleLinear({
    range: [yMax, 0],
    round: true,
    domain: yDomain,
  });

  // Band geometry and the lower value bound are the same for every bar, so
  // they are computed once rather than inside the map below. Box and violin
  // plots are capped at 40px wide and centred in the band.
  const barWidth = Math.max(0, xScale.bandwidth());
  const constrainedWidth = Math.min(40, barWidth);
  const domainMin = Math.min(yDomain[0], yDomain[1]);

  return (
    <svg style={{width: '100%', height: '100%'}}>
      <Group top={margin.top} left={margin.left} key="chart">
        {bars.map((d, i) => {
          const barX = xScale(i + 1) ?? 0;
          // The '--' separator (as in HorizontalBarChart) keeps React keys
          // unique. Plain concatenation maps key "a1" at index 1 and key "a"
          // at index 11 to the same "a11".
          const key = d.key + '--' + i.toString();
          if (violinPlot) {
            return (
              <Stats.ViolinPlot
                key={key}
                data={d.bins ?? []}
                count={data => data.count}
                value={data => data.bin}
                left={barX + barWidth / 2 - constrainedWidth / 2}
                width={constrainedWidth}
                valueScale={yScale}
                fill={d.color ?? 'red'}
                fillOpacity={0.5}
                stroke={d.color ?? 'red'}
                strokeWidth={1}
              />
            );
          } else if (boxPlot) {
            return (
              <Stats.BoxPlot
                key={key}
                min={d.quartiles != null ? d.quartiles[0] : d.value}
                max={d.quartiles != null ? d.quartiles[4] : d.value}
                median={d.quartiles != null ? d.quartiles[2] : d.value}
                firstQuartile={d.quartiles != null ? d.quartiles[1] : d.value}
                thirdQuartile={d.quartiles != null ? d.quartiles[3] : d.value}
                left={barX + barWidth / 2 - constrainedWidth / 2}
                fill={d.color ?? 'red'}
                fillOpacity={0.5}
                stroke={d.color ?? 'red'}
                strokeWidth={1}
                valueScale={yScale}
                boxWidth={constrainedWidth}
                boxProps={{
                  onMouseOver: event => mouseOver && mouseOver(event, d),
                  onMouseLeave: event => mouseOut && mouseOut(),
                }}
              />
            );
          } else {
            // Bars whose value lies below the visible value range are not drawn.
            if (d.value < domainMin) {
              return null;
            }

            // barY is where the bar starts, either the origin or the minimum y value
            let barY = Math.max(0, Math.min(yScale(0), yScale(yDomain[0])));
            let barHeight = yScale(d.value) - barY;

            if (isNaN(barHeight)) {
              barHeight = 0;
            }
            // A negative height means the bar extends upwards (toward smaller
            // SVG y). Flip it to a positive height, clipping at the plot top.
            if (barHeight < 0) {
              barHeight = Math.abs(barHeight);
              barY = barY - barHeight;
              if (barY < 0) {
                barHeight = Math.max(barHeight + barY, 0);
                barY = 0;
              }
            }

            return (
              <Group key={key}>
                <Shape.Bar
                  key={key}
                  width={barWidth}
                  height={barHeight}
                  x={barX}
                  y={barY}
                  fill={d.color ?? 'red'}
                  onMouseOver={event => mouseOver && mouseOver(event, d)}
                  onMouseOut={event => mouseOut && mouseOut()}
                />
                {d.range != null &&
                  _.isFinite(d.range[0]) &&
                  _.isFinite(d.range[1]) && (
                    <Shape.Bar
                      key={key + 'error-bar'}
                      height={Math.max(
                        0,
                        yScale(d.range[0]) - yScale(d.range[1])
                      )}
                      width={1}
                      y={yScale(d.range[1])}
                      x={barX + barWidth / 2}
                      fill={'black'}
                    />
                  )}
              </Group>
            );
          }
        })}
        {yDomain[0] <= 0 && (
          // origin line
          <Shape.Line
            x1={0}
            y1={yScale(0)}
            x2={xMax}
            y2={yScale(0)}
            stroke={globals.gray500}
            strokeWidth={1.0}
          />
        )}
      </Group>
      <Group key="axis">
        <AxisLeft
          top={margin.top}
          left={margin.left}
          scale={yScale}
          numTicks={numTicksForHeight(height)}
          labelProps={{
            ...axisFontStyles,
            textAnchor: 'middle',
          }}
          stroke={globals.gray500}
          tickStroke={enableAxisTicks ? `#b3b3b0` : `transparent`}
          strokeWidth={0.5}
          tickLabelProps={() => ({
            ...axisFontStyles,
            textAnchor: 'end',
            dx: '-0.25em',
            dy: '0.25em',
          })}
          tickComponent={({formattedValue, ...tickProps}) => {
            return enableAxisTicks ? (
              <text {...(tickProps as any)}>{formattedValue}</text>
            ) : null;
          }}
        />

        <line
          x1={margin.left}
          y1={height - margin.bottom}
          x2={width - margin.right}
          y2={height - margin.bottom}
          stroke={globals.gray500}
          strokeWidth={0.5}
        />

        <AxisBottom
          top={height - margin.bottom}
          left={margin.left}
          scale={xScale}>
          {axis => {
            const tickLabelSize = 8;
            const tickRotate = axisTickRotate;
            const tickColor = globals.gray500;
            const axisCenter = (axis.axisToPoint.x - axis.axisFromPoint.x) / 2;
            const numTicks = numTicksForWidth(width);
            return (
              <g className="my-custom-bottom-axis">
                {axis.ticks
                  // Keep every label when requested or when they all fit.
                  // Otherwise keep an evenly spaced subset plus the first and
                  // last. numTicks >= 3, so the divisor below is always >= 1.
                  .filter(
                    (tick, i) =>
                      props.showAllLabels ||
                      axis.ticks.length <= numTicks ||
                      i % Math.floor(axis.ticks.length / (numTicks - 1)) ===
                        0 ||
                      i === 0 ||
                      i === axis.ticks.length - 1
                  )
                  .map((tick, i) => {
                    const tickX = tick.to.x;
                    const tickY =
                      tick.to.y + tickLabelSize + (axis.tickLength || 0) - 10;

                    if (!enableAxisTicks) {
                      return null;
                    }
                    return (
                      <Group
                        key={`visx-tick-${tick.value}-${i}`}
                        className={'visx-axis-tick'}>
                        <Shape.Line
                          from={tick.from}
                          to={tick.to}
                          stroke={tickColor}
                        />
                        <text
                          transform={`translate(${tickX}, ${tickY}) rotate(${tickRotate})`}
                          fontFamily={axisFontStyles.fontFamily}
                          fontSize={tickLabelSize}
                          textAnchor="end"
                          fill={'#6b6b76'}>
                          {
                            // The tick value is the 1-based band position from
                            // the x-scale domain, not the bar's text, so bars
                            // sharing a key are not stacked. Map it back to
                            // the bar's key; out-of-range values render empty.
                            bars[Number(tick.value) - 1]?.key
                          }
                        </text>
                      </Group>
                    );
                  })}
                <text
                  textAnchor="middle"
                  transform={`translate(${axisCenter}, 20)`}
                  fontSize="8">
                  {axis.label}
                </text>
              </g>
            );
          }}
        </AxisBottom>
      </Group>
    </svg>
  );
};
