import { useMemo } from 'react';

import Skeleton from '@mui/material/Skeleton';
import ParentSize from '@visx/responsive/lib/components/ParentSize';
import { addDays } from 'date-fns/addDays';
import { format } from 'date-fns/format';
import { parse } from 'date-fns/parse';

import CostPerPatientGraph from 'forecasting/components/graphing/cost-per-patient-graph/CostPerPatientGraph';
import { TRIAL_SPEND_HEIGHT } from 'forecasting/components/graphing/trial-spend-forecast-graph/TrialSpendForecastGraph';

import useFinancialForecastSummaryRows from 'forecasting/pages/forecasting/hooks/useFinancialForecastSummaryRows';
import useForecast from 'forecasting/pages/forecasting/hooks/useForecast';
import { addNullableFloats } from 'utils';

import { getSum, hasKey } from './utils';

/**
 * One month of the patient cost graph, before its plotted values are
 * converted to strings for `CostPerPatientGraph`.
 */
type PatientCostGraphDatum = {
  /** Month timestamp, as produced by `Date.prototype.toISOString()`. */
  date: string;
  /** True when the summary rows contain an `actual_month_*` column for this month. */
  actual: boolean;
  /** Enrollment for the month, summed across the "Cumulative enrollment" patient rows. */
  totalEnrollment: number;
  /** Spend for this month alone (not cumulative). */
  totalSpend: number;
  /** As populated in this file: cumulative cost per patient for actual months, otherwise `null`. */
  actualSpend: number | null;
  /** As populated in this file: cumulative cost per patient for forecasted months, otherwise `null`. */
  forecastedSpend: number | null;
};

/**
 * Cumulative cost-per-patient graph for the current forecast.
 *
 * For each month column in the financial forecast summary rows, this sums
 * CRO direct fees, CRO investigator fees, CRO pass-throughs and OCC spend.
 * It uses the `actual_month_*` columns when the month has actuals and the
 * `forecasted_month_*` columns otherwise. It keeps a running total of that
 * spend and divides it by the month's summed "Cumulative enrollment" patient
 * rows. A skeleton is rendered while either data source is loading.
 */
function PatientCostGraph() {
  const { loading: forecastLoading, generatedForecast } = useForecast();
  const { loading: financialForecastLoading, rows } =
    useFinancialForecastSummaryRows();

  const graphData = useMemo(() => {
    // Month columns are discovered from the keys of the first summary row.
    const monthColumnKeys = Object.keys(rows?.[0] ?? {}).filter(
      (key) =>
        key.startsWith('actual_month_') || key.startsWith('forecasted_month_'),
    );

    // Index months by the lowercase `LLL-yyyy` string used for the column
    // lookups below. A month that has both an actual and a forecasted column
    // must be processed only once; otherwise its spend would be added to the
    // running total twice.
    const monthsByKey = new Map<string, Date>();
    for (const columnKey of monthColumnKeys) {
      // Add 5 days so we don't have to worry about timezones. The parsed date
      // is local midnight on the 1st; the offset keeps the UTC date produced
      // by `toISOString()` inside the same month.
      const month = addDays(
        parse(
          columnKey
            .replace('forecasted_month_', '')
            .replace('actual_month_', ''),
          'LLL-yyyy',
          new Date(),
        ),
        5,
      );
      monthsByKey.set(format(month, 'LLL-yyyy').toLowerCase(), month);
    }
    const months = Array.from(monthsByKey.values()).sort(
      (monthA, monthB) => monthA.getTime() - monthB.getTime(),
    );

    const directFeesRows = rows?.filter(
      (row) => row.cost_category === 'DIRECT_FEES' && row.vendor_type === 'CRO',
    );
    const passthroughsRows = rows?.filter(
      (row) =>
        row.cost_category === 'PASS_THROUGH' && row.vendor_type === 'CRO',
    );
    const investigatorFeesRows = rows?.filter(
      (row) =>
        row.cost_category === 'INVESTIGATOR_FEES' && row.vendor_type === 'CRO',
    );
    const occRows = rows?.filter((row) => row.vendor_type === 'OCC');

    const enrollmentRowsByRegion =
      generatedForecast?.patients.filter(
        (row) => row.type === 'Cumulative enrollment',
      ) ?? [];

    // Spend in one month column across the four tracked buckets; a missing
    // sum counts as zero.
    const sumSpend = (columnKey: string) =>
      (getSum(directFeesRows, columnKey) ?? 0) +
      (getSum(investigatorFeesRows, columnKey) ?? 0) +
      (getSum(passthroughsRows, columnKey) ?? 0) +
      (getSum(occRows, columnKey) ?? 0);

    const monthlySpend = months.map((month) => {
      const monthString = format(month, 'LLL-yyyy').toLowerCase();
      const actualMonthString = `actual_month_${monthString}`;
      const forecastedMonthString = `forecasted_month_${monthString}`;

      const totalEnrollment = addNullableFloats(
        getSum(enrollmentRowsByRegion, forecastedMonthString),
        getSum(enrollmentRowsByRegion, actualMonthString),
      );

      // A month is "actual" whenever the summary rows have an actual column
      // for it; only the matching set of columns contributes to its spend.
      const actual = hasKey(rows, actualMonthString) ?? false;
      const actualSpend = actual ? sumSpend(actualMonthString) : null;
      const forecastedSpend = actual ? null : sumSpend(forecastedMonthString);

      return {
        totalEnrollment,
        totalSpend: addNullableFloats(actualSpend, forecastedSpend),
        actual,
        actualSpend,
        forecastedSpend,
        date: month.toISOString(),
      };
    });

    // Replace each month's spend with cumulative spend / enrollment, written
    // to the series (actual or forecasted) that the month belongs to.
    const cumulativeMonths: PatientCostGraphDatum[] = [];
    let cumulativeSpend = 0;
    for (const currentMonth of monthlySpend) {
      const { totalEnrollment, totalSpend } = currentMonth;

      cumulativeSpend += totalSpend;
      const costPerPatient =
        totalEnrollment > 0 ? cumulativeSpend / totalEnrollment : null;

      cumulativeMonths.push({
        ...currentMonth,
        forecastedSpend: currentMonth.actual ? null : costPerPatient,
        actualSpend: currentMonth.actual ? costPerPatient : null,
      });
    }

    // Skip months whose enrollment or monthly spend is exactly zero, then
    // stringify the plotted values for the graph component.
    return cumulativeMonths
      .filter((month) => month.totalEnrollment !== 0 && month.totalSpend !== 0)
      .map((month) => ({
        actual: month.actual,
        actualSpend:
          month.actualSpend === null ? null : month.actualSpend.toString(),
        forecastedSpend:
          month.forecastedSpend === null
            ? null
            : month.forecastedSpend.toString(),
        date: month.date,
      }));
  }, [generatedForecast?.patients, rows]);

  // Total `default_contract_value` divided by contracted patients, as a
  // string; undefined when the patient count is missing or zero.
  const expectedCostPerPatient = useMemo(() => {
    if (
      generatedForecast?.patients_contracted === 0 ||
      generatedForecast?.patients_contracted === undefined
    ) {
      return undefined;
    }

    return (
      (getSum(rows, 'default_contract_value') ?? 0) /
      generatedForecast.patients_contracted
    ).toString();
  }, [generatedForecast, rows]);

  return (
    <ParentSize>
      {(parent) => {
        // Never render narrower than 900px; wider parents get the full width.
        const graphWidth = Math.max(parent.width, 900);

        return forecastLoading || financialForecastLoading ? (
          <Skeleton
            height={TRIAL_SPEND_HEIGHT}
            variant="rectangular"
            width="100%"
          />
        ) : (
          <CostPerPatientGraph
            expectedCostPerPatient={expectedCostPerPatient}
            graphData={graphData}
            latestCloseDate={generatedForecast?.latest_close_date}
            width={graphWidth}
          />
        );
      }}
    </ParentSize>
  );
}

export default PatientCostGraph;
