import type { MouseEvent, ReactElement } from 'react';
import { useCallback, useContext, useMemo } from 'react';

import localPoint from '@visx/event/lib/localPoint';
import type { PositionScale, StackPathConfig } from '@visx/shape/lib/types';
import { getFirstItem, getSecondItem } from '@visx/shape/lib/util/accessors';
import getBandwidth from '@visx/shape/lib/util/getBandwidth';
import stackOffset from '@visx/shape/lib/util/stackOffset';
import stackOrder from '@visx/shape/lib/util/stackOrder';
import type { BaseBarSeriesProps } from '@visx/xychart/lib/components/series/private/BaseBarSeries';
import DataContext from '@visx/xychart/lib/context/DataContext';
import isValidNumber from '@visx/xychart/lib/typeguards/isValidNumber';
import type { DataContextType } from '@visx/xychart/lib/types/data';
import type {
  Bar,
  BarStackDatum,
  CombinedStackData,
  SeriesProps,
} from '@visx/xychart/lib/types/series';
import combineBarStackData, {
  getStackValue,
} from '@visx/xychart/lib/utils/combineBarStackData';
import getChildrenAndGrandchildrenWithProps from '@visx/xychart/lib/utils/getChildrenAndGrandchildrenWithProps';
import type { SeriesPoint } from 'd3-shape';
import { stack as d3stack } from 'd3-shape';

import type { CondorBarsProps } from './CondorAnimatedBars';

/**
 * Props of a `BarSeries` child placed inside the stack. `BarsComponent` is
 * omitted because the stack renders every series through its own
 * `BarsComponent` instead of the child's.
 */
type BarStackChildProps<
  XScale extends PositionScale,
  YScale extends PositionScale,
  Datum extends object,
> = Omit<BaseBarSeriesProps<XScale, YScale, Datum>, 'BarsComponent'>;

/**
 * Props accepted by `CondorBaseBarStack`, apart from the scales and the
 * stack-level mouse callbacks, which the component declares inline.
 *
 * NOTE(review): the event props picked from `SeriesProps` (`enableEvents`,
 * `onBlur`, `onFocus`, `onPointer*`) are not destructured by
 * `CondorBaseBarStack`, so they currently have no effect — confirm intent.
 */
type BaseBarStackProps<
  XScale extends PositionScale,
  YScale extends PositionScale,
  Datum extends object,
> = {
  /**
   * `BarSeries` elements to stack. TypeScript cannot currently enforce that
   * the children really are `BarSeries`.
   */
  children:
    | ReactElement<BarStackChildProps<XScale, YScale, Datum>>
    | Array<ReactElement<BarStackChildProps<XScale, YScale, Datum>>>;
  /**
   * Component rendered once per stacked series. `CondorBaseBarStack` computes
   * the bar geometry and passes it in as `CondorBarsProps`.
   */
  BarsComponent: React.FC<CondorBarsProps<XScale, YScale>>;
} & Pick<StackPathConfig<Datum, string>, 'offset' | 'order'> &
  Pick<
    SeriesProps<XScale, YScale, Datum>,
    | 'enableEvents'
    | 'onBlur'
    | 'onFocus'
    | 'onPointerDown'
    | 'onPointerMove'
    | 'onPointerOut'
    | 'onPointerUp'
  >;

/**
 * Renders a stack of bar series using visx xychart helpers.
 *
 * Collects the `BarSeries`-like children (and grandchildren) and groups their
 * data by stack value. It stacks that data with d3, converts each stacked
 * segment into bar geometry and renders one `BarsComponent` per series.
 *
 * NOTE(review): the geometry getters below lay bars out vertically only.
 * `horizontal` (from `DataContext`) affects data grouping and bar thickness,
 * but not the x/y/width/height computation. Confirm horizontal stacks are not
 * used with this component.
 *
 * @param props.children `BarSeries` elements providing `dataKey`, `data` and
 *   optional `colorAccessor` / `strokeAccessor` / `strokeWidthAccessor`.
 * @param props.order d3 stack order; the d3 default applies when unset.
 * @param props.offset d3 stack offset. When unset and any stack has a
 *   negative sum, `'diverging'` is used.
 * @param props.BarsComponent component that renders each series' bars.
 * @param props.xScale scale positioning stacks along x.
 * @param props.yScale scale mapping stacked values along y.
 * @param props.onMouseMoveOverBarStack called with the hovered stack's datum,
 *   the horizontal centre of that stack, the pointer's local y (0 when
 *   unavailable) and the stack index.
 * @param props.onMouseLeaveBarStack forwarded unchanged to `BarsComponent`.
 */
function CondorBaseBarStack<
  XScale extends PositionScale,
  YScale extends PositionScale,
  Datum extends object,
>({
  children,
  order,
  offset,
  BarsComponent,
  xScale,
  yScale,
  onMouseMoveOverBarStack,
  onMouseLeaveBarStack,
}: BaseBarStackProps<XScale, YScale, Datum> & {
  xScale: XScale;
  yScale: YScale;
  onMouseMoveOverBarStack?: (
    event: MouseEvent<SVGElement>,
    datum: CombinedStackData<XScale, YScale>,
    left: number,
    top: number,
    stackIndex?: number,
  ) => void;
  onMouseLeaveBarStack?: (event: MouseEvent<SVGElement>) => void;
}) {
  type StackBar = SeriesPoint<CombinedStackData<XScale, YScale>>;

  type ChildrenProps = BaseBarSeriesProps<XScale, YScale, Datum> & {
    strokeAccessor: (datum: Datum, index: number) => string | null | undefined;
    strokeWidthAccessor: (
      datum: Datum,
      index: number,
    ) => number | null | undefined;
  };

  const { colorScale, horizontal } = useContext(
    DataContext,
  ) as unknown as DataContextType<
    XScale,
    YScale,
    BarStackDatum<XScale, YScale>
  >;

  const seriesChildren = useMemo(
    () => getChildrenAndGrandchildrenWithProps<ChildrenProps>(children),
    [children],
  );
  // Memoized so that `stackedData` below is not recomputed on every render
  // merely because a fresh keys array was allocated.
  const dataKeys = useMemo(
    () => seriesChildren.map((item) => item.props.dataKey),
    [seriesChildren],
  );

  // group all child data by stack value { [x | y]: { [dataKey]: value } }
  // this format is needed by d3Stack
  const combinedData = useMemo(
    () =>
      combineBarStackData<XScale, YScale, Datum>(seriesChildren, horizontal),
    [horizontal, seriesChildren],
  );

  // stack data
  const stackedData = useMemo(() => {
    // With no explicit offset, fall back to d3's 'diverging' offset when any
    // stack contains negative values, so they stack below zero.
    const useDivergingOffset =
      !offset && combinedData.some((datum) => datum.negativeSum < 0);

    const stack = d3stack<CombinedStackData<XScale, YScale>, string>();
    stack.keys(dataKeys);
    if (order) {
      stack.order(stackOrder(order));
    }
    if (offset || useDivergingOffset) {
      stack.offset(stackOffset(offset ?? 'diverging'));
    }

    return stack(combinedData);
  }, [combinedData, dataKeys, order, offset]);

  const barThickness = getBandwidth(horizontal ? yScale : xScale);
  const halfBarThickness = barThickness / 2;

  const getWidth = () => barThickness;
  const getHeight = (bar: StackBar) =>
    (yScale(getFirstItem(bar)) ?? Number.NaN) -
    (yScale(getSecondItem(bar)) ?? Number.NaN);
  // Band scales return the band's start, so the bar already fills the band.
  // Other scales return the value's position, so shift left by half the
  // thickness to centre the bar on it.
  const getX = (bar: StackBar) =>
    'bandwidth' in xScale
      ? xScale(getStackValue(bar.data))
      : xScale(getStackValue(bar.data)) - halfBarThickness;
  const getY = (bar: StackBar) => yScale(getSecondItem(bar));

  // Horizontal centre of a stack, consistent with getX. For band scales this
  // is the band start plus half the band; otherwise the bar is already centred
  // on the value's position.
  const getDatumLeft = useCallback(
    (data: CombinedStackData<XScale, YScale>) => {
      const stackPosition = Number(xScale(data.stack));
      return 'bandwidth' in xScale
        ? stackPosition + halfBarThickness
        : stackPosition;
    },
    [halfBarThickness, xScale],
  );

  const onMouseMove = useCallback(
    (event: MouseEvent<SVGElement>, stackIndex: number | undefined) => {
      if (stackIndex === undefined) {
        return;
      }

      const data = combinedData[stackIndex];
      // The index may not map to a datum (e.g. the data shrank since the bars
      // were rendered); bail out instead of dereferencing undefined.
      if (data === undefined) {
        return;
      }

      const top = localPoint(event)?.y ?? 0;
      const left = getDatumLeft(data);
      onMouseMoveOverBarStack?.(event, data, left, top, stackIndex);
    },
    [combinedData, getDatumLeft, onMouseMoveOverBarStack],
  );

  const barSeries = stackedData.map((barStack, stackIndex) => {
    // get props from child BarSeries, if available
    const childBarSeries: ReactElement<ChildrenProps> | undefined =
      seriesChildren.find((child) => child.props.dataKey === barStack.key);
    const { colorAccessor, strokeAccessor, strokeWidthAccessor } =
      childBarSeries?.props ?? {};
    // Every per-datum accessor needs the child's own datum, not just
    // colorAccessor. Otherwise a stroke accessor set on its own is ignored.
    const needsSeriesDatum = Boolean(
      colorAccessor || strokeAccessor || strokeWidthAccessor,
    );

    return {
      key: barStack.key,
      bars: barStack
        .map((bar, index) => {
          const barX = getX(bar);
          if (!isValidNumber(barX)) {
            return null;
          }
          const barY = getY(bar);
          if (!isValidNumber(barY)) {
            return null;
          }
          const barWidth = getWidth();
          if (!isValidNumber(barWidth)) {
            return null;
          }
          const barHeight = getHeight(bar);
          if (!isValidNumber(barHeight)) {
            return null;
          }

          // NOTE(review): assumes the child's `data` is index-aligned with the
          // combined stack data — confirm for series with differing data.
          const barSeriesDatum = needsSeriesDatum
            ? childBarSeries?.props.data[index]
            : null;

          return {
            key: `${stackIndex}-${barStack.key}-${index}`,
            x: barX,
            y: barY,
            width: barWidth,
            height: Math.max(0, barHeight - 1), // add vertical space between bars
            fill:
              barSeriesDatum && colorAccessor
                ? colorAccessor(barSeriesDatum, index)
                : colorScale(barStack.key),
            stroke:
              barSeriesDatum && strokeAccessor
                ? strokeAccessor(barSeriesDatum, index)
                : undefined,
            strokeWidth:
              barSeriesDatum && strokeWidthAccessor
                ? strokeWidthAccessor(barSeriesDatum, index)
                : undefined,
          };
        })
        .filter((bar) => bar) as Bar[],
    };
  });

  return (
    <g className="visx-bar-stack">
      {barSeries.map((series) => (
        <BarsComponent
          key={series.key}
          bars={series.bars}
          horizontal={horizontal}
          xScale={xScale}
          yScale={yScale}
          onMouseLeaveBarStack={onMouseLeaveBarStack}
          onMouseMoveOverBarStack={onMouseMove}
        />
      ))}
    </g>
  );
}

export default CondorBaseBarStack;
