import { useCallback, useEffect, useMemo, useState } from 'react';

import Box from '@mui/material/Box';
import { useTheme } from '@mui/material/styles';
import type { TickFormatter } from '@visx/axis/lib/types';
import { curveMonotoneY } from '@visx/curve';
import GridRows from '@visx/grid/lib/grids/GridRows';
import Group from '@visx/group/lib/Group';
import AnimatedAxis from '@visx/react-spring/lib/axis/AnimatedAxis';
import {
  type BandScaleConfig,
  type LinearScaleConfig,
  createScale,
} from '@visx/scale/lib';
import linear from '@visx/scale/lib/scales/linear';
import ordinal from '@visx/scale/lib/scales/ordinal';
import type { AnyD3Scale } from '@visx/scale/lib/types/Scale';
import LinePath from '@visx/shape/lib/shapes/LinePath';
import Text from '@visx/text/lib/Text';
import useTooltip from '@visx/tooltip/lib/hooks/useTooltip';
import type { ScaleBand } from 'd3-scale';
import { format } from 'date-fns/format';
import { parse } from 'date-fns/parse';
import sum from 'lodash/sum';
import { useSelector } from 'react-redux';

import type { GenericGraphProps } from 'shared/lib/periods/withPeriodSpecificGraphWrapper';
import withPeriodSpecificGraphWrapper from 'shared/lib/periods/withPeriodSpecificGraphWrapper';
import { selectPeriod } from 'accruals/state/slices/periodSlice';
import { AIP_SUFFIX, getCurrencySymbol } from 'formatters';
import { humanizeCostCategory } from 'shared/helpers/helpers';
import CondorAnimatedBarSeries from 'shared/lib/graphing/series/CondorAnimatedBarSeries';
import CondorAnimatedBarStack from 'shared/lib/graphing/series/CondorAnimatedBarStack';
import useStackBar from 'shared/lib/graphing/series/useStackBar';
import useTimeAxisLabels from 'shared/lib/graphing/shared/useTimeAxisLabels';
import {
  PeriodGraphBlobType,
  type TrialActivityEnrollmentGridResponse,
  type TrialExpenseSummaryGridResponse,
} from 'shared/lib/types';
import { selectTrial } from 'shared/state/slices/trialSlice';
import { filterUndefined, parseNullableFloat, parseNullableInt } from 'utils';

import {
  useGetTrialActivityEnrollmentQuery,
  useGetTrialExpenseSummaryQuery,
} from 'shared/api/rtkq/periods';

import GraphTooltipDataProvider from '../../../../shared/lib/graphing/graph-tooltip/GraphTooltipDataProvider';
import {
  getFixedMinimumValues,
  getIntervalForEnrollment,
  getMaxTickValueForEnrollment,
  getTickValuesForEnrollment,
} from '../../../../shared/lib/graphing/graphUtils';
import GraphLegend from '../../../../shared/lib/graphing/shared/GraphLegend';
import ZeroLine from '../../../../shared/lib/graphing/shared/ZeroLine';
import TrialSpendGraphTooltip from './TrialSpendGraphTooltip';
import { TRIAL_SPEND_GRAPH_CONFIG } from './config';
import type { TrialSpendConfig, TrialSpendDatum } from './types';

/** Numeric working buckets: month label (e.g. "Jan-2024") -> category/series key -> value. */
type GraphData = Record<string, Record<string, number>>;

/**
 * Reads `key` from a chart row and parses it as a base-10 integer.
 *
 * Any value that does not parse (missing key, empty or non-numeric text)
 * yields `defaultValue`. Because `parseInt` is used, fractional text is
 * truncated ("12.9" -> 12).
 */
function getSafeIntegerFromDatum(
  datum: TrialSpendDatum,
  key: string,
  defaultValue = 0,
): number {
  const parsedValue = Number.parseInt(datum[key as keyof TrialSpendDatum], 10);
  if (!Number.isNaN(parsedValue)) {
    return parsedValue;
  }
  return defaultValue;
}

/**
 * Largest integer stored under `key` across all rows, never below 0.
 *
 * Rows whose value does not parse count as 0, and an undefined or empty
 * `graphData` yields 0.
 */
function getMaxValue(
  graphData: TrialSpendDatum[] | undefined,
  key: keyof TrialSpendDatum,
) {
  let largest = 0;
  for (const datum of graphData ?? []) {
    largest = Math.max(largest, getSafeIntegerFromDatum(datum, key));
  }
  return largest;
}

/**
 * Trial spend graph for the currently selected period.
 *
 * Loads the expense summary and enrollment grids for the period's trace id,
 * derives the chart rows and the enrollment scale from them, and renders
 * `TrialSpendGraphGeneric` with the result.
 */
function TrialSpendGraph({ width, height }: GenericGraphProps) {
  // The chart rows and the enrollment scale are derived together and kept in
  // a single state object, so every render sees a matching pair. (Deriving
  // them in two chained effects let one render combine the new rows with the
  // scale computed for the previous rows.)
  const [derived, setDerived] = useState<{
    graphData: TrialSpendDatum[] | undefined;
    enrollmentScaleObj: { domain: number[]; range: number[] };
  }>();

  const period = useSelector(selectPeriod);

  const { currentData: expenseSummaryGrid } = useGetTrialExpenseSummaryQuery(
    period.trace_id,
  );
  const { currentData: trialActivityEnrollment } =
    useGetTrialActivityEnrollmentQuery(period.trace_id);

  useEffect(() => {
    const nextGraphData = calculateGraphData(
      expenseSummaryGrid,
      trialActivityEnrollment,
      TRIAL_SPEND_GRAPH_CONFIG.directFeesText,
      TRIAL_SPEND_GRAPH_CONFIG.passThroughText,
      TRIAL_SPEND_GRAPH_CONFIG.investigatorFeesText,
      TRIAL_SPEND_GRAPH_CONFIG.occText,
      TRIAL_SPEND_GRAPH_CONFIG.enrollmentText,
    );
    setDerived({
      graphData: nextGraphData,
      enrollmentScaleObj: getEnrollmentScale(
        nextGraphData,
        TRIAL_SPEND_GRAPH_CONFIG.directFeesText,
        TRIAL_SPEND_GRAPH_CONFIG.passThroughText,
        TRIAL_SPEND_GRAPH_CONFIG.investigatorFeesText,
        TRIAL_SPEND_GRAPH_CONFIG.occText,
        TRIAL_SPEND_GRAPH_CONFIG.enrollmentText,
      ),
    });
  }, [expenseSummaryGrid, trialActivityEnrollment]);

  return (
    <TrialSpendGraphGeneric
      graphData={derived?.graphData}
      height={height}
      width={width}
      enrollmentScaleObj={
        derived?.enrollmentScaleObj ?? { domain: [0, 0], range: [0, 0] }
      }
    />
  );
}

/** Props `TrialSpendGraphGeneric` accepts in addition to the shared graph size props. */
type GenericProps = {
  /** Maps the spend scale onto the enrollment scale. */
  enrollmentScaleObj?: {
    domain: number[];
    range: number[];
  };
  /** Rendering configuration; the component falls back to TRIAL_SPEND_GRAPH_CONFIG when omitted. */
  graphOptions?: TrialSpendConfig;
  /** Chart rows; `undefined` while data is still loading. */
  graphData: TrialSpendDatum[] | undefined;
};

/**
 * Stacked monthly spend bars per cost category (left axis), overlaid with a
 * dashed cumulative-enrollment line (right axis).
 *
 * This component is used as-is for both open and closed periods. It renders
 * only from its props, plus the trial currency read from the store.
 *
 * @param props.graphData - chart rows. `undefined` shows a loading message and
 *   an empty array shows an empty-state message.
 * @param props.enrollmentScaleObj - domain/range pair. `range[1]` is the upper
 *   bound used when converting spend-axis ticks into enrollment labels.
 * @param props.graphOptions - rendering configuration; defaults to
 *   TRIAL_SPEND_GRAPH_CONFIG.
 * @param props.width - available width; the chart never renders narrower than 800px.
 * @param props.height - available height; 48px of it are reserved for labels and the legend.
 */
export function TrialSpendGraphGeneric(
  props: GenericGraphProps & GenericProps,
) {
  const {
    graphData,
    enrollmentScaleObj,
    graphOptions = TRIAL_SPEND_GRAPH_CONFIG,
    width: parentWidth,
    height,
  } = props;
  const trial = useSelector(selectTrial);
  const trialCurrency = trial.currency;
  const tooltip = useTooltip<TrialSpendDatum>();
  const [, enrollmentRangeUpperValue] = filterUndefined(
    enrollmentScaleObj?.range,
  );
  const timeAxisLabels = useTimeAxisLabels(graphData?.length);
  const finalHeight = height - 48; // account for labels / legend
  const width = Math.max(parentWidth, 800); // this looks bad at less than 800, so just stop it from shrinking any smaller
  const { xRange, yRange, xScale, yScale } = useStackBar({
    height: finalHeight,
    width,
    graphOptions,
    graphData,
  });
  const themeMode = useTheme().palette.mode;

  const showZeroLine = yScale?.domain().some((item) => item < 0);

  const { margin } = graphOptions;
  const left = parseNullableInt(margin?.left);
  const right = parseNullableInt(margin?.right);
  const top = parseNullableInt(margin?.top);
  const bottom = parseNullableInt(margin?.bottom);

  // Done this way to ensure we have a fallback, but still prefer the configuration
  const enrollmentText =
    graphOptions.enrollmentText ??
    TRIAL_SPEND_GRAPH_CONFIG.enrollmentText ??
    'Actual enrollment';

  const innerWidth = width - left - right;
  const innerHeight = finalHeight - bottom - top;

  // Read enrollment through the resolved key rather than a hard-coded
  // 'Actual enrollment', so a configured enrollmentText is honoured here just
  // as it is for the enrollment line and the right-axis ticks.
  const enrollmentMaxValue = useMemo(
    () =>
      getMaxTickValueForEnrollment(
        getMaxValue(graphData, enrollmentText as keyof TrialSpendDatum),
      ),
    [graphData, enrollmentText],
  );

  const [xMin, xMax] = xRange;
  const [yMin, yMax] = yRange;
  const xScaleEnrollment = createScale({
    range: [xMin ?? 0, xMax ?? 0],
    domain: (graphData ?? []).map((item) => item.date),
    ...(graphOptions.xScaleConfig as BandScaleConfig<string>),
  });

  // When negative spend pushes the zero line up, the enrollment line starts
  // from that zero line instead of the bottom of the chart.
  const yScaleEnrollment = createScale({
    range: [showZeroLine ? yScale?.(0) : yMin, yMax],
    domain: [0, enrollmentMaxValue],
    ...(graphOptions.yScaleConfig as LinearScaleConfig<number>),
  });

  const accessorsX = useCallback((item: TrialSpendDatum) => item.date, []);
  const barStackData = useMemo(() => {
    if (!graphData) {
      return [];
    }

    // NOTE(review): presumably raises small segments to a minimum of ~2% of
    // the spend range upper bound so they stay visible — confirm in graphUtils.
    // Uses the same keys the bar series below render (graphOptions), not the
    // default config's keys.
    const onePercent = Math.ceil(enrollmentRangeUpperValue / 100);
    return getFixedMinimumValues(
      onePercent * 2,
      graphData,
      filterUndefined(graphOptions.orderOfData),
    );
  }, [enrollmentRangeUpperValue, graphData, graphOptions.orderOfData]);

  // Pure function of its arguments, so it needs no dependencies.
  const accessorsY = useCallback(
    (item: Record<string, string>, key: string) =>
      parseNullableFloat(item[key]),
    [],
  );

  // The enrollment series is drawn as a dashed line, so the legend draws it
  // the same way; every other entry is a bar swatch.
  const legendShapes = useCallback(
    (datum: ReturnType<AnyD3Scale>) =>
      datum.text === enrollmentText
        ? { type: 'dashed-line' as const }
        : { type: 'rect' as const },
    [enrollmentText],
  );

  if (graphData === undefined || graphData.length === 0) {
    return (
      <Box
        sx={{
          display: 'flex',
          justifyContent: 'center',
          alignItems: 'center',
          height: 'calc(100% - 48px)',
        }}
      >
        <Box>
          {graphData === undefined
            ? 'Please wait... loading.'
            : 'Trial spend chart will appear when expenses are available.'}
        </Box>
      </Box>
    );
  }

  const barColorScale = ordinal<string, string>({
    domain: filterUndefined(graphOptions.orderOfData),
    range: filterUndefined(graphOptions.barColors),
  });

  const legendColorScale = ordinal<string, string>({
    domain: filterUndefined<string>([
      ...filterUndefined(graphOptions.orderOfData).map(humanizeCostCategory),
      graphOptions.enrollmentText,
    ]),
    range: filterUndefined<string>([
      ...filterUndefined(graphOptions.barColors),
      graphOptions.enrollmentCurveColor,
    ]),
  });

  const uniqEnrollmentsValues = [
    ...new Set(
      filterUndefined(
        graphData.map((item) => item[enrollmentText as keyof TrialSpendDatum]),
      ),
    ),
  ];

  // The right axis reuses the spend scale (yScale) for tick positions; its
  // labels are converted to enrollment counts via enrollmentScale below.
  const max = Math.max(
    ...uniqEnrollmentsValues.map((value) => Number.parseInt(value, 10)),
  );
  const maxTickValue = getMaxTickValueForEnrollment(max);
  const interval = getIntervalForEnrollment(max);
  const scaleMultiplier =
    getMaxTickValueForEnrollment(enrollmentRangeUpperValue) / maxTickValue;
  const tickValues = getTickValuesForEnrollment(
    scaleMultiplier,
    maxTickValue,
    interval,
  );
  const enrollmentTicksCount = maxTickValue / interval;
  const isEnrollmentZero = enrollmentScaleObj?.domain[1] === 0;
  const enrollmentRange = isEnrollmentZero
    ? [0, 0]
    : [0, getMaxTickValueForEnrollment(enrollmentRangeUpperValue)];
  const enrollmentScale = linear<number>({
    ...enrollmentScaleObj,
    range: enrollmentRange,
  });

  // Converts a spend-axis tick value into an enrollment label; ticks above
  // the spend range upper bound get no label.
  const getPatientEnrollment: TickFormatter<number> = (value: number) => {
    if (enrollmentRangeUpperValue < value) {
      return undefined;
    }

    return enrollmentScale.invert(value).toFixed(0);
  };

  if (!yScale || !xScale) {
    return null;
  }

  const trialCurrencySymbol = getCurrencySymbol(trialCurrency);

  return (
    <Box sx={{ display: 'grid', height: '100%' }}>
      <GraphLegend colorScale={legendColorScale} shapes={legendShapes} />
      <Box sx={{ width: '100%' }}>
        <svg height={finalHeight} width={width}>
          <GridRows
            left={left}
            scale={yScale}
            stroke={graphOptions.horizontalLinesColor}
            width={innerWidth}
          />

          <LinePath
            curve={curveMonotoneY}
            data={graphData}
            stroke={graphOptions.enrollmentCurveColor}
            strokeDasharray="1,8"
            strokeLinecap="round"
            strokeLinejoin="round"
            strokeWidth={1.5}
            x={(datum) => xScaleEnrollment(accessorsX(datum)) ?? 0}
            style={{
              mixBlendMode: themeMode === 'light' ? 'multiply' : 'normal',
            }}
            y={(datum) =>
              yScaleEnrollment(getSafeIntegerFromDatum(datum, enrollmentText)) ??
              0
            }
          />

          {graphOptions.orderOfData && (
            <CondorAnimatedBarStack xScale={xScale} yScale={yScale}>
              {graphOptions.orderOfData.map((key) => (
                <CondorAnimatedBarSeries
                  key={key}
                  colorAccessor={() => barColorScale(key)}
                  data={barStackData}
                  dataKey={key}
                  xAccessor={accessorsX}
                  yAccessor={(item: TrialSpendDatum) => accessorsY(item, key)}
                />
              ))}
            </CondorAnimatedBarStack>
          )}
          <Group>
            <Text
              angle={270}
              dx={40}
              dy="15%"
              fontSize={graphOptions.fontSize}
              fontWeight={graphOptions.fontWeightBold}
              textAnchor="end"
            >
              {`Monthly spend (${trialCurrencySymbol})`}
            </Text>
            <AnimatedAxis
              animationTrajectory="min"
              hideZero={!showZeroLine}
              left={left}
              orientation="left"
              scale={yScale}
              stroke={graphOptions.horizontalLinesColor}
              tickStroke={graphOptions.horizontalLinesColor}
              tickLabelProps={() => ({
                dx: -10,
                fontSize: graphOptions.fontSize,
                fill: graphOptions.textColor,
                fontWeight: graphOptions.fontWeight,
                textAnchor: 'end',
                verticalAnchor: 'middle',
              })}
              hideAxisLine
              hideTicks
            />
          </Group>
          <Group>
            <Text
              angle={270}
              dx={width - 30}
              dy="15%"
              fontSize={graphOptions.fontSize}
              fontWeight={graphOptions.fontWeightBold}
              textAnchor="end"
            >
              Actual patient enrollment
            </Text>
            <AnimatedAxis
              key="axis-right"
              animationTrajectory="min"
              hideZero={!isEnrollmentZero}
              left={width - right}
              numTicks={enrollmentTicksCount}
              orientation="right"
              scale={yScale}
              tickFormat={getPatientEnrollment}
              tickValues={tickValues}
              labelProps={{
                dx: '40px',
                dy: '-20px',
                fontSize: graphOptions.fontSize,
                fontWeight: graphOptions.fontWeightBold,
              }}
              tickLabelProps={() => ({
                fontSize: graphOptions.fontSize,
                fill: graphOptions.textColor,
                fontWeight: graphOptions.fontWeight,
              })}
              hideAxisLine
              hideTicks
            />
          </Group>
          <AnimatedAxis
            key="axis-bottom"
            animationTrajectory="min"
            numTicks={timeAxisLabels.numTicks}
            orientation="bottom"
            scale={xScale}
            stroke={graphOptions.horizontalLinesColor}
            strokeWidth={1}
            tickLength={10}
            top={innerHeight + 30}
            tickComponent={({ formattedValue }) => (
              <g transform="translate(0, 10)">
                <text
                  fill={graphOptions.textColor}
                  fontSize={graphOptions.fontSize}
                  textAnchor="middle"
                  transform={
                    timeAxisLabels.autoLabels ? '' : graphOptions.angleLabels
                  }
                >
                  {timeAxisLabels.formatMonthLabel(formattedValue)}
                </text>
                <text
                  dy={timeAxisLabels.yearLabelDy}
                  fill={graphOptions.textColor}
                  fontSize={graphOptions.fontSizeYear}
                  fontWeight={graphOptions.fontWeightBold}
                  textAnchor="middle"
                >
                  {timeAxisLabels.formatYearLabel(formattedValue)}
                </text>
              </g>
            )}
          />
          {showZeroLine && (
            <ZeroLine
              from={{ x: left, y: yScale(0) }}
              to={{ x: width - right, y: yScale(0) }}
            />
          )}
          <GraphTooltipDataProvider<TrialSpendDatum, ScaleBand<string>>
            graphData={graphData}
            height={innerHeight}
            margin={{ top, left, right, bottom }}
            tooltip={tooltip}
            width={innerWidth}
            xAccessor={accessorsX}
            xScale={xScale as ScaleBand<string>}
          />
        </svg>

        <TrialSpendGraphTooltip
          currency={trialCurrency}
          innerHeight={innerHeight}
          labels={filterUndefined(graphOptions.orderOfData)}
          legendColorScale={legendColorScale}
          marginTop={top}
          tooltip={tooltip}
        />
      </Box>
    </Box>
  );
}

// The exported functions below are not run for closed periods: their outputs are saved when a period is closed and reused when rendering TrialSpendGraphGeneric.
/**
 * Builds the domain/range pair that maps the spend axis onto the enrollment axis.
 *
 * @param graphData - chart rows (as produced by `calculateGraphData`); may be
 *   undefined while data is loading.
 * @param directFeesText - key of the direct-fees spend value in each row.
 * @param passThroughText - key of the pass-through spend value in each row.
 * @param investigatorFeesText - key of the investigator-fees spend value in each row.
 * @param occText - key of the OCC spend value in each row.
 * @param enrollmentText - key of the cumulative enrollment value in each row.
 * @returns `domain` is `[0, getMaxTickValueForEnrollment(max enrollment)]`;
 *   `range` is `[0, largest monthly spend total]`. The range upper bound is 0
 *   when there are no rows.
 */
export function getEnrollmentScale(
  graphData?: TrialSpendDatum[],
  // Done this way to ensure we have a fallback, but still prefer the configuration
  directFeesText = 'DIRECT_FEES',
  passThroughText = 'PASS_THROUGH',
  investigatorFeesText = 'INVESTIGATOR_FEES',
  occText = 'OCC',
  enrollmentText = 'Actual enrollment',
) {
  const spendKeys = [
    directFeesText,
    passThroughText,
    investigatorFeesText,
    occText,
  ];
  const monthlyTotals = filterUndefined(graphData).map((item) =>
    sum(spendKeys.map((key) => getSafeIntegerFromDatum(item, key))),
  );
  // Math.max() of an empty list is -Infinity, which would leak into the axis
  // math (and JSON.stringify turns it into `null` if this output is persisted
  // as JSON — presumably how closed periods store it; confirm).
  const maxMonthlyTotal =
    monthlyTotals.length > 0 ? Math.max(...monthlyTotals) : 0;

  return {
    domain: [
      0,
      getMaxTickValueForEnrollment(
        getMaxValue(graphData, enrollmentText as keyof TrialSpendDatum),
      ),
    ],
    range: [0, maxMonthlyTotal],
  };
}

/**
 * Converts the expense summary and enrollment grids into one chart row per month.
 *
 * Booked expenses from CRO rows are bucketed by cost category. AIP rows are
 * merged into their base category by removing AIP_SUFFIX. Expenses from any
 * other vendor type are bucketed under `occText`. Each row also carries the
 * month's cumulative enrollment under `enrollmentText` (0 when missing or
 * non-numeric).
 *
 * Expenses for months that are not in `month_list`, or for categories that
 * have no bucket, are ignored.
 *
 * @returns undefined until both grids are available. Otherwise returns one row
 *   per month in `month_list` order: `date` formatted as
 *   `yyyy-MM-dd'T'00:00:00`, and every bucket value stringified.
 */
export function calculateGraphData(
  expenseSummaryGrid?: TrialExpenseSummaryGridResponse,
  trialActivityEnrollment?: TrialActivityEnrollmentGridResponse,
  // Done this way to ensure we have a fallback, but still prefer the configuration
  directFeesText = 'DIRECT_FEES',
  passThroughText = 'PASS_THROUGH',
  investigatorFeesText = 'INVESTIGATOR_FEES',
  occText = 'OCC',
  enrollmentText = 'Actual enrollment',
): TrialSpendDatum[] | undefined {
  if (!expenseSummaryGrid || !trialActivityEnrollment) {
    return;
  }

  const graphDataWithNumbers: GraphData = {};
  for (const month of expenseSummaryGrid.month_list) {
    graphDataWithNumbers[month] = {
      [directFeesText]: 0,
      [passThroughText]: 0,
      [investigatorFeesText]: 0,
      [occText]: 0,
      [enrollmentText]: 0,
    };
  }

  if (expenseSummaryGrid.rows) {
    for (const row of expenseSummaryGrid.rows) {
      // The bucket depends only on the row, so resolve it once per row.
      // AIP rows are combined with their original (non-AIP) category.
      const costCategory =
        row.vendor_type === 'CRO'
          ? row.cost_category.replace(AIP_SUFFIX, '')
          : occText;

      for (const month of Object.keys(row.booked_expenses)) {
        const monthBuckets = graphDataWithNumbers[month];
        // A month outside month_list previously crashed with a TypeError. An
        // unknown category previously wrote NaN to a key that is never
        // emitted. Skip both cases explicitly.
        if (
          monthBuckets === undefined ||
          !Object.prototype.hasOwnProperty.call(monthBuckets, costCategory)
        ) {
          continue;
        }
        monthBuckets[costCategory] += row.booked_expenses[month];
      }
    }
  }

  for (const month of expenseSummaryGrid.month_list) {
    graphDataWithNumbers[month][enrollmentText] =
      Number(trialActivityEnrollment.cumulative_enrollment[month]) || 0;
  }

  const data = Object.keys(graphDataWithNumbers).map((month) => {
    const buckets = graphDataWithNumbers[month];
    const date = format(
      parse(month, 'MMM-yyyy', new Date()),
      "yyyy-MM-dd'T'00:00:00",
    );
    return {
      date,
      [directFeesText]: String(buckets[directFeesText]),
      [passThroughText]: String(buckets[passThroughText]),
      [investigatorFeesText]: String(buckets[investigatorFeesText]),
      [occText]: String(buckets[occText]),
      [enrollmentText]: String(buckets[enrollmentText]),
    };
  });

  return data as TrialSpendDatum[];
}

// NOTE(review): the wrapper receives the live graph plus the blob type it is
// stored under; presumably it substitutes saved output for closed periods —
// confirm in withPeriodSpecificGraphWrapper.
const TrialSpendGraphForPeriod = withPeriodSpecificGraphWrapper(
  TrialSpendGraph,
  PeriodGraphBlobType.TRIAL_EXPENSE_AND_ENROLLMENT,
);

export default TrialSpendGraphForPeriod;
