import { useCallback, useMemo } from 'react';

import Box from '@mui/material/Box';
import { useTheme } from '@mui/material/styles';
import type { TickFormatter } from '@visx/axis/lib/types';
import { curveMonotoneY } from '@visx/curve';
import GridRows from '@visx/grid/lib/grids/GridRows';
import Group from '@visx/group/lib/Group';
import AnimatedAxis from '@visx/react-spring/lib/axis/AnimatedAxis';
import type { BandScaleConfig, LinearScaleConfig } from '@visx/scale/lib';
import { createScale } from '@visx/scale/lib';
import linear from '@visx/scale/lib/scales/linear';
import ordinal from '@visx/scale/lib/scales/ordinal';
import type { AnyD3Scale } from '@visx/scale/lib/types/Scale';
import LinePath from '@visx/shape/lib/shapes/LinePath';
import Text from '@visx/text/lib/Text';
import useTooltip from '@visx/tooltip/lib//hooks/useTooltip';
import color from 'color';
import type { ScaleBand } from 'd3-scale';
import sum from 'lodash/sum';

import { humanizeCostCategory } from 'shared/helpers/helpers';
import {
  getFixedMinimumValues,
  getIntervalForEnrollment,
  getMaxTickValueForEnrollment,
  getTickValuesForEnrollment,
} from 'shared/lib/graphing/graphUtils';
import { isAllocatedPeriod } from 'shared/lib/graphing/helper';
import CondorAnimatedBarSeries from 'shared/lib/graphing/series/CondorAnimatedBarSeries';
import CondorAnimatedBarStack from 'shared/lib/graphing/series/CondorAnimatedBarStack';
import useStackBar from 'shared/lib/graphing/series/useStackBar';
import GraphLegend from 'shared/lib/graphing/shared/GraphLegend';
import useTimeAxisLabels from 'shared/lib/graphing/shared/useTimeAxisLabels';
import ZeroLine from 'shared/lib/graphing/shared/ZeroLine';
import { filterUndefined, parseNullableFloat, parseNullableInt } from 'utils';

import GraphTooltipDataProvider from '../../../../shared/lib/graphing/graph-tooltip/GraphTooltipDataProvider';
import { TRIAL_SPEND_FORECAST_GRAPH_CONFIG } from './config';
import TrialSpendForecastGraphTooltip from './TrialSpendForecastGraphTooltip';
import type {
  TrialSpendForecastConfig,
  TrialSpendForecastDatum,
} from './types';

/** Props accepted by TrialSpendForecastGraph. */
type Props = {
  /**
   * Data points to plot. While this is `undefined` the component shows a
   * loading message; an empty array shows a "no expenses yet" message.
   */
  graphData: TrialSpendForecastDatum[] | undefined;
  /** Chart configuration; the component falls back to TRIAL_SPEND_FORECAST_GRAPH_CONFIG when omitted. */
  graphOptions?: TrialSpendForecastConfig;
  /**
   * Passed to `isAllocatedPeriod` together with each bottom-axis tick value to
   * choose between the configured text color and red for that tick's labels.
   */
  latestCloseDate: string | undefined;
  /** Width of the rendered `<svg>`; also used to derive the inner plot width. */
  width: number;
};

/**
 * Returns the largest enrollment count stored under `key` across `graphData`.
 *
 * Values are parsed as base-10 integers, matching the explicit radix used by
 * the other `Number.parseInt` calls in this file. Missing, empty or
 * non-numeric values are skipped, so a single malformed entry cannot turn the
 * whole result into `NaN` (`Math.max(x, NaN)` is `NaN`).
 *
 * @param graphData - Data points to scan; may be `undefined`.
 * @param key - Which enrollment field to read from each data point.
 * @returns The maximum parsed value (never below 0, the fold's seed), or
 *   `undefined` when `graphData` itself is `undefined`.
 */
function getMaxValue(
  graphData: TrialSpendForecastDatum[] | undefined,
  key: keyof Pick<
    TrialSpendForecastDatum,
    'actualEnrollment' | 'forecastedEnrollment'
  >,
) {
  return graphData?.reduce((max, item) => {
    // `?? ''` makes null/undefined parse to NaN, which is then ignored below.
    const parsed = Number.parseInt(item[key] ?? '', 10);
    return Number.isNaN(parsed) ? max : Math.max(max, parsed);
  }, 0);
}

/**
 * Renders the trial spend forecast chart.
 *
 * The SVG contains:
 * - a bottom time axis whose month/year labels are drawn in red when
 *   `isAllocatedPeriod(tick, latestCloseDate)` is falsy;
 * - a left axis titled "Monthly spend ($)" and a right axis titled
 *   "Patient enrollment", both driven by the spend `yScale`;
 * - one stacked bar series per key in `graphOptions.orderOfData`, where
 *   data points with a falsy `actual` flag use a faded fill plus an outline;
 * - a solid line for `actualEnrollment` and a dashed line for
 *   `forecastedEnrollment`;
 * - an optional zero line, a legend and a tooltip.
 *
 * While `graphData` is `undefined` a loading message is shown; an empty
 * array shows a placeholder message. If `useStackBar` yields no x or y
 * scale, nothing is rendered.
 *
 * @param props - See `Props`.
 * @returns The chart element, a placeholder message, or `null`.
 */
function TrialSpendForecastGraph(props: Props) {
  const {
    graphData,
    latestCloseDate,
    width,
    graphOptions = TRIAL_SPEND_FORECAST_GRAPH_CONFIG,
  } = props;

  // All hooks are called before any early return so their order is stable
  // across renders.
  const themeMode = useTheme().palette.mode;
  const tooltip = useTooltip<TrialSpendForecastDatum>();
  const timeAxisLabels = useTimeAxisLabels(graphData?.length);
  const { xRange, yRange, xScale, yScale } = useStackBar({
    height: TRIAL_SPEND_HEIGHT,
    width,
    graphOptions,
    graphData,
  });

  // A zero line is only needed when the spend domain reaches below zero.
  const showZeroLine = yScale?.domain().some((item) => item < 0);

  // Margins are used in arithmetic below, so parseNullableInt presumably
  // returns a plain number even for missing values — TODO confirm.
  const { margin } = graphOptions;
  const left = parseNullableInt(margin?.left);
  const right = parseNullableInt(margin?.right);
  const top = parseNullableInt(margin?.top);
  const bottom = parseNullableInt(margin?.bottom);

  const innerWidth = width - left - right;
  const innerHeight = TRIAL_SPEND_HEIGHT - top - bottom;
  // Upper bound of the enrollment domain: the larger of the actual and
  // forecasted maxima, each passed through getMaxTickValueForEnrollment.
  const enrollmentMaxValue = useMemo(() => {
    const enrollmentValues = [
      getMaxTickValueForEnrollment(
        getMaxValue(graphData, 'actualEnrollment') ?? 0,
      ),
      getMaxTickValueForEnrollment(
        getMaxValue(graphData, 'forecastedEnrollment') ?? 0,
      ),
    ];

    return Math.max(...enrollmentValues);
  }, [graphData]);

  // Maps the enrollment domain onto the spend value range so enrollment can
  // be expressed in the units of the spend axis.
  const enrollmentScaleDomainAndRange = useMemo(() => {
    // Total of the four cost categories for one data point, parsed as
    // base-10 integers.
    function getSumInMonth(item: TrialSpendForecastDatum) {
      return sum([
        Number.parseInt(item.DIRECT_FEES, 10),
        Number.parseInt(item.PASS_THROUGH, 10),
        Number.parseInt(item.INVESTIGATOR_FEES, 10),
        Number.parseInt(item.OCC, 10),
      ]);
    }

    // NOTE(review): when graphData is undefined or empty the spread is empty
    // and Math.max() returns -Infinity. The early return further down keeps
    // that from being rendered, but the helpers below still receive it.
    return {
      domain: [0, enrollmentMaxValue],
      range: [0, Math.max(...(graphData?.flatMap(getSumInMonth) ?? []))],
    };
  }, [graphData, enrollmentMaxValue]);

  // Largest monthly spend total across the data.
  const [, enrollmentRangeUpperValue] = enrollmentScaleDomainAndRange.range;

  const [xMin, xMax] = xRange;
  const [yMin, yMax] = yRange;
  // Scales used to position the enrollment lines. The config object is spread
  // last, so any key it defines takes precedence over the range/domain above.
  const xScaleEnrollment = createScale({
    range: [xMin ?? 0, xMax ?? 0],
    domain: (graphData ?? []).map((item) => item.date)!,
    ...(graphOptions.xScaleConfig as BandScaleConfig<string>),
  });

  // When a zero line is shown, enrollment 0 is anchored to the pixel position
  // of spend 0 instead of the start of the y range.
  const yScaleEnrollment = createScale({
    range: [showZeroLine ? yScale?.(0) : yMin, yMax],
    domain: [0, enrollmentMaxValue],
    ...(graphOptions.yScaleConfig as LinearScaleConfig<number>),
  });

  // Stable x accessor shared by the bar series and the tooltip provider.
  const accessorsX = useCallback(
    (item: TrialSpendForecastDatum) => item.date,
    [],
  );

  // Bar data passed through getFixedMinimumValues with a threshold of twice
  // one percent (rounded up) of the largest monthly total — presumably to keep
  // very small bars visible; verify against the helper.
  const barStackData = useMemo(() => {
    if (!graphData) {
      return [];
    }

    const onePercent = Math.ceil(enrollmentRangeUpperValue / 100);
    return getFixedMinimumValues(
      onePercent * 2,
      graphData,
      filterUndefined(graphOptions.orderOfData),
    );
  }, [enrollmentRangeUpperValue, graphData, graphOptions.orderOfData]);

  // Reads a bar value by key; boolean fields are treated as 0, everything
  // else goes through parseNullableFloat.
  // NOTE(review): barStackData is listed as a dependency but is not read
  // inside the callback, so it only forces a new function identity.
  const accessorsY = useCallback(
    (item: Record<string, boolean | string | null>, key: string) => {
      const value = item[key];
      if (typeof value === 'boolean') {
        return 0;
      }
      return parseNullableFloat(value);
    },
    [barStackData],
  );

  // Chooses the legend glyph per entry: line, dashed line, or rect.
  // NOTE(review): this compares against graphOptions.enrollmentText while the
  // legend domain below is built from graphOptions.actualEnrollmentText —
  // confirm both hold the same label, otherwise the actual-enrollment entry
  // falls through to 'rect'.
  const legendShapes = useCallback(
    (datum: ReturnType<AnyD3Scale>) => {
      if ([graphOptions.enrollmentText].includes(datum.text)) {
        return { type: 'line' as const };
      }

      if ([graphOptions.forecastedEnrollmentText].includes(datum.text)) {
        return { type: 'dashed-line' as const };
      }

      return {
        type: 'rect' as const,
      };
    },
    [graphOptions.enrollmentText, graphOptions.forecastedEnrollmentText],
  );

  // Fill colors for bars, keyed by data key.
  const barColorScale = ordinal<string, string>({
    domain: filterUndefined(graphOptions.orderOfData),
    range: graphOptions.barColors,
  });

  // Same colors at 25% alpha, used for data points whose `actual` flag is falsy.
  const forecastedBarColorScale = ordinal<string, string>({
    domain: filterUndefined(graphOptions.orderOfData),
    range: filterUndefined(
      graphOptions.barColors?.map((barColor) =>
        color(barColor).alpha(0.25).rgb().string(),
      ),
    ),
  });

  // Legend entries: humanized cost categories followed by the two enrollment
  // labels, paired with the bar colors followed by the two curve colors.
  const legendColorScale = ordinal<string, string>({
    domain: filterUndefined<string>([
      ...filterUndefined(graphOptions.orderOfData).map(humanizeCostCategory),
      graphOptions.actualEnrollmentText,
      graphOptions.forecastedEnrollmentText,
    ]),
    range: filterUndefined<string>([
      ...filterUndefined(graphOptions.barColors),
      graphOptions.actualEnrollmentCurveColor,
      graphOptions.forecastedEnrollmentCurveColor,
    ]),
  });

  // Tick layout for the right-hand (enrollment) axis.
  // NOTE(review): enrollmentMaxValue is already the output of
  // getMaxTickValueForEnrollment, so the helper is applied a second time here.
  const maxTickValue = getMaxTickValueForEnrollment(enrollmentMaxValue);
  const interval = getIntervalForEnrollment(enrollmentMaxValue);
  // Ratio between the spend-side and enrollment-side tick maxima.
  const scaleMultiplier =
    getMaxTickValueForEnrollment(enrollmentRangeUpperValue) / maxTickValue;
  const tickValues = getTickValuesForEnrollment(
    scaleMultiplier,
    maxTickValue,
    interval,
  );
  const enrollmentTicksCount = maxTickValue / interval;
  const isEnrollmentZero = enrollmentScaleDomainAndRange.domain[1] === 0;

  // With no enrollment at all the range collapses to [0, 0].
  const enrollmentRange = isEnrollmentZero
    ? [0, 0]
    : [0, getMaxTickValueForEnrollment(enrollmentRangeUpperValue)];

  // Linear scale from enrollment counts to spend-axis values; inverted below
  // to label spend-axis positions with enrollment numbers.
  const enrollmentScale = linear<number>({
    ...enrollmentScaleDomainAndRange,
    range: enrollmentRange,
    nice: true,
  });

  // Right-axis tick formatter: converts a spend-axis value back to an
  // enrollment count with no decimals; values above the largest monthly total
  // get no label.
  const getPatientEnrollment: TickFormatter<number> = (value: number) => {
    if (enrollmentRangeUpperValue < value) {
      return undefined;
    }

    return enrollmentScale.invert(value).toFixed(0);
  };

  // Placeholder states: undefined means still loading, empty means no data.
  if (graphData === undefined || graphData.length === 0) {
    return (
      <Box
        sx={{
          display: 'flex',
          justifyContent: 'center',
          alignItems: 'center',
          height: 'calc(100% - 48px)',
        }}
      >
        <Box>
          {graphData === undefined
            ? 'Please wait... loading.'
            : 'Trial spend chart will appear when expenses are available.'}
        </Box>
      </Box>
    );
  }

  // Without both spend scales nothing can be positioned.
  if (!xScale || !yScale) {
    return null;
  }

  // Render order inside the <svg>: bottom axis, grid rows, right axis, left
  // axis, stacked bars, actual and forecasted enrollment lines, optional zero
  // line, then the tooltip data provider. The tooltip itself is rendered
  // outside the <svg>.
  return (
    <Box sx={{ display: 'grid', height: '100%' }}>
      <GraphLegend colorScale={legendColorScale} shapes={legendShapes} />
      <Box sx={{ width: '100%' }}>
        <svg height={TRIAL_SPEND_HEIGHT} width={width}>
          <AnimatedAxis
            key="axis-bottom"
            animationTrajectory="min"
            numTicks={timeAxisLabels.numTicks}
            orientation="bottom"
            scale={xScale}
            stroke={graphOptions.horizontalLinesColor}
            strokeWidth={1}
            tickLength={10}
            top={TRIAL_SPEND_HEIGHT - bottom}
            tickComponent={({ formattedValue }) => (
              <g transform="translate(0, 10)">
                <text
                  fontSize={graphOptions.fontSize}
                  textAnchor="middle"
                  fill={
                    isAllocatedPeriod(formattedValue, latestCloseDate)
                      ? graphOptions.textColor
                      : 'red'
                  }
                  transform={
                    timeAxisLabels.autoLabels ? '' : graphOptions.angleLabels
                  }
                >
                  {timeAxisLabels.formatMonthLabel(formattedValue)}
                </text>
                <text
                  dy={timeAxisLabels.yearLabelDy}
                  fontSize={graphOptions.fontSizeYear}
                  fontWeight={graphOptions.fontWeightBold}
                  textAnchor="middle"
                  fill={
                    isAllocatedPeriod(formattedValue, latestCloseDate)
                      ? graphOptions.textColor
                      : 'red'
                  }
                >
                  {timeAxisLabels.formatYearLabel(formattedValue)}
                </text>
              </g>
            )}
          />
          <GridRows
            left={left}
            scale={yScale}
            stroke={graphOptions.horizontalLinesColor}
            width={innerWidth}
          />
          <Group>
            <Text
              angle={270}
              dx={width - 30}
              dy={90}
              fontSize={graphOptions.fontSize}
              fontWeight={graphOptions.fontWeightBold}
              textAnchor="middle"
            >
              Patient enrollment
            </Text>
            <AnimatedAxis
              key="axis-right"
              animationTrajectory="min"
              hideZero={!isEnrollmentZero}
              left={width - right}
              numTicks={enrollmentTicksCount}
              orientation="right"
              scale={yScale}
              tickFormat={getPatientEnrollment}
              tickValues={tickValues}
              tickLabelProps={() => ({
                dx: 20,
                fontSize: graphOptions.fontSize,
                fill: graphOptions.textColor,
                fontWeight: graphOptions.fontWeight,
                textAnchor: 'end',
                verticalAnchor: 'middle',
              })}
              hideAxisLine
              hideTicks
            />
          </Group>
          <Group>
            <Text
              angle={270}
              dx={40}
              dy={90}
              fontSize={graphOptions.fontSize}
              fontWeight={graphOptions.fontWeightBold}
              textAnchor="middle"
            >
              Monthly spend ($)
            </Text>
            <AnimatedAxis
              key="axis-left"
              animationTrajectory="min"
              left={left}
              orientation="left"
              scale={yScale}
              stroke={graphOptions.horizontalLinesColor}
              tickStroke={graphOptions.horizontalLinesColor}
              tickLabelProps={() => ({
                dx: -10,
                fontSize: graphOptions.fontSize,
                fill: graphOptions.textColor,
                fontWeight: graphOptions.fontWeight,
                textAnchor: 'end',
                verticalAnchor: 'middle',
              })}
              hideAxisLine
              hideTicks
            />
          </Group>
          <CondorAnimatedBarStack xScale={xScale} yScale={yScale}>
            {(graphOptions.orderOfData ?? []).map((key) => (
              <CondorAnimatedBarSeries
                key={key}
                data={barStackData}
                dataKey={key}
                strokeWidthAccessor={(item) => (item.actual ? undefined : 0.5)}
                xAccessor={accessorsX}
                colorAccessor={(item) =>
                  item.actual
                    ? barColorScale(key)
                    : forecastedBarColorScale(key)
                }
                strokeAccessor={(item) =>
                  item.actual ? undefined : barColorScale(key)
                }
                yAccessor={(item: TrialSpendForecastDatum) =>
                  accessorsY(item, key)
                }
              />
            ))}
          </CondorAnimatedBarStack>
          <LinePath
            curve={curveMonotoneY}
            data={graphData}
            defined={(datum) => datum.actualEnrollment !== null}
            stroke={graphOptions.actualEnrollmentCurveColor}
            strokeLinecap="round"
            strokeLinejoin="round"
            strokeWidth={1.5}
            x={(datum) => xScaleEnrollment(datum.date) ?? 0}
            style={{
              mixBlendMode: themeMode === 'light' ? 'multiply' : 'normal',
            }}
            y={(datum) =>
              yScaleEnrollment(Number.parseInt(datum.actualEnrollment!, 10)) ??
              0
            }
          />
          <LinePath
            curve={curveMonotoneY}
            data={graphData}
            defined={(datum) => datum.forecastedEnrollment !== null}
            stroke={graphOptions.forecastedEnrollmentCurveColor}
            strokeDasharray="1,4"
            strokeLinecap="round"
            strokeLinejoin="round"
            strokeWidth={1.5}
            x={(datum) => xScaleEnrollment(datum.date) ?? 0}
            style={{
              mixBlendMode: themeMode === 'light' ? 'multiply' : 'normal',
            }}
            y={(datum) =>
              yScaleEnrollment(
                Number.parseInt(datum.forecastedEnrollment!, 10),
              ) ?? 0
            }
          />
          {showZeroLine && (
            <ZeroLine
              from={{ x: left, y: yScale(0) }}
              to={{ x: width - right, y: yScale(0) }}
            />
          )}
          <GraphTooltipDataProvider<TrialSpendForecastDatum, ScaleBand<string>>
            graphData={graphData}
            height={innerHeight}
            margin={{ top, left, right, bottom }}
            tooltip={tooltip}
            width={innerWidth}
            xAccessor={accessorsX}
            xScale={xScale as ScaleBand<string>}
          />
        </svg>
        <TrialSpendForecastGraphTooltip
          actualEnrollment="actualEnrollment"
          actualEnrollmentText={graphOptions.actualEnrollmentText}
          forecastedEnrollment="forecastedEnrollment"
          forecastedEnrollmentText={graphOptions.forecastedEnrollmentText}
          innerHeight={innerHeight}
          legendColorScale={legendColorScale}
          marginTop={top}
          orderOfData={graphOptions.orderOfData}
          tooltip={tooltip}
        />
      </Box>
    </Box>
  );
}

/**
 * Height of the chart's `<svg>`, also passed to `useStackBar` and used to
 * place the bottom axis and derive the inner plot height.
 *
 * It is declared after the component that reads it; that is fine because the
 * component body only runs at render time, once this module has finished
 * evaluating.
 */
export const TRIAL_SPEND_HEIGHT = 500;

export default TrialSpendForecastGraph;
