import React from "react";

import {
  type GroupingState,
  useReactTable,
  getPaginationRowModel,
  getFilteredRowModel,
  getCoreRowModel,
  getExpandedRowModel,
  type ColumnDef,
  flexRender,
  type OnChangeFn,
  type RowData,
  createRow,
  type Cell,
  type Row,
  type SortingState,
  type Updater,
} from "@tanstack/react-table";
import styled from "styled-components";
import useDeepCompareEffect from "use-deep-compare-effect";

import { colorTheme } from "@utils";

import { Button, ToolTip } from "../";
import gridSpacing from "../../../utils/gridSpacing";
import { numberWithCommas } from "../utilities";

// Styled `<thead>` for PivotTable: bold, uppercase 12px header text on a light
// background with rounded top corners.
// NOTE(review): `align-items` has no effect on a table-header-group element
// unless its display is changed; kept as-is.
const TableHeader = styled.thead`
  width: 100%;
  align-items: center;
  font-weight: bold;
  font-size: 12px;
  color: ${colorTheme("neutralL1")};
  text-transform: uppercase;
  background: #fafafa;
  border-top-right-radius: 5px;
  border-top-left-radius: 5px;
`;

// Styled `<tr>` used for header rows.
// NOTE(review): this rule targets `td`, but PivotTable renders header cells as
// `TableHead` (`th`), so the bottom border never applies to them — confirm intent.
const TableHeadRow = styled.tr`
  width: 100%;

  td {
    border-bottom: 1px solid ${colorTheme("neutralL4")};
  }
`;

// Styled `<tr>` for body rows: bottom border on each cell and a hover highlight.
// NOTE(review): PivotTable sets `background` inline on every `TableData`, and an
// inline style wins over this stylesheet rule, so the hover background is
// likely never visible.
const TableRow = styled.tr`
  width: 100%;

  td {
    border-bottom: 1px solid ${colorTheme("neutralL4")};
  }

  &:hover {
    td {
      background: ${colorTheme("neutralL5")};
    }
  }
`;

// Styled body cell: single line, overflowing text truncated with an ellipsis.
const TableData = styled.td`
  margin: 0;
  padding: 10px;
  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;
  background: ${colorTheme("white")};
`;

// Styled header cell; mirrors `TableData`'s single-line, ellipsis truncation.
// Sticky positioning for pinned columns is applied via inline style by PivotTable.
const TableHead = styled.th`
  margin: 0;
  padding: 10px;
  white-space: nowrap;
  overflow: hidden;
  text-overflow: ellipsis;
  background: ${colorTheme("white")};
`;

// Drag handle for column resizing. It is absolutely positioned against the
// right edge of its nearest positioned ancestor (the `position: relative`
// wrapper inside each header cell). It covers the middle 80% of the cell's
// height and reveals a dashed border plus a col-resize cursor on hover.
// `touch-action: none` stops touch drags from scrolling the page while resizing.
const Resizer = styled.div`
  border-left: 3px dashed transparent;
  right: 0;
  width: 4px;
  height: 80%;
  top: 10%;
  touch-action: none;
  position: absolute;
  z-index: 2;

  &:hover {
    cursor: col-resize;
    border-color: ${colorTheme("neutralL2")};
  }
`;

// Right-aligned footer row holding the "Rows x - y of z" label and the
// previous/next controls.
const Pagination = styled.div`
  display: flex;
  flex-direction: row;
  justify-content: flex-end;
  align-items: center;
  padding-top: 15px;
  color: ${colorTheme("neutral")};
  font-size: 14px;
  gap: ${gridSpacing[3]}px;
`;

// Bordered icon button used for the previous/next page controls.
// `$disabled` is a styled-components transient prop (not forwarded to the DOM).
// When it is set, the icon is greyed out and `pointer-events: none` blocks clicks.
// NOTE(review): because pointer-events is disabled, the `$disabled` branches of
// the `:hover` rule (including the not-allowed cursor) are likely unreachable.
// NOTE(review): an `<i>` with onClick is not keyboard-focusable; consider a button.
const PaginationIcon = styled.i<{ $disabled: boolean }>`
  display: flex;
  flex-direction: row;
  justify-content: center;
  align-items: center;
  padding: 8px;
  border: 1px solid ${colorTheme("neutralL3")};
  color: ${({ $disabled }) =>
    $disabled ? colorTheme("neutralL3") : colorTheme("neutral")};
  width: 14px;
  pointer-events: ${({ $disabled }) => ($disabled ? "none" : "auto")};

  &:hover {
    color: ${({ $disabled }) =>
      $disabled ? colorTheme("neutralL2") : colorTheme("neutralL1")};
    cursor: ${({ $disabled }) => ($disabled ? "not-allowed" : "pointer")};
  }
`;

/**
 * Union of the keys of every member of `T`.
 * `keyof (A | B)` yields only the keys common to all members. The distributive
 * conditional `T extends T ? … : never` instead collects each member's keys
 * separately.
 */
type KeysOfUnion<T> = T extends T ? keyof T : never;

/** Server-side pagination state passed into PivotTable and reported back on change. */
type PaginationProps = {
  /** Total number of rows across all pages (shown as "of N"). */
  totalRows: number;
  /** Zero-based index of the first row on the current page. */
  rowOffset: number;
  /** Number of rows per page; also the step used by previous/next. */
  pageSize: number;
};

/** Props for {@link PivotTable}. */
type PivotTableProps<TData extends RowData> = {
  /** Column definitions for the table. */
  columns: ColumnDef<TData>[];
  /** Rows for the current page, as returned by the server. */
  data: TData[];
  /** Controlled grouping state (column ids). */
  grouping: GroupingState;
  /** Called by the table when grouping changes. */
  setGrouping: OnChangeFn<GroupingState>;
  /** Fetches the child rows of a group row when it is expanded. */
  getSubRows: (originalRow: Row<TData>) => Promise<TData[]>;
  /** Key on each row whose value is the number of children; used for the count label and to disable expansion. */
  childCountProp?: KeysOfUnion<TData>;
  /** Column ids pinned to the left. */
  pinnedColumns?: string[];
  /** Receives the resolved sorting state after every sort change. */
  onSortingChange: (sorting: SortingState) => void;
  /** Current pagination state. */
  pagination: PaginationProps;
  /** Receives the requested pagination state when the user pages. */
  onPaginationChange: (pagination: PaginationProps) => void;
  /** When provided, an "Export" button is rendered that calls this handler. */
  exportAction?: () => void;
};

/**
 * PivotTable is a table component that supports server-side pagination, sorting, and grouping.
 *
 * Sorting, pagination and grouping are all "manual". The component only reports
 * state changes through its callbacks and expects the caller to pass the
 * matching `data` back in. A group row (depth 0, in a grouped column) can be
 * expanded. Expanding lazily fetches its children with `getSubRows` and
 * inserts them directly below the parent row.
 *
 * @param columns - The columns to display in the table, use the ColumnDef type to define the columns
 * @param data - The data to display in the table, needs to be same type used in ColumnDef
 * @param grouping - The current grouping state of the table
 * @param setGrouping - Grouping setter
 * @param getSubRows - Function returning a promise of the subrows for a row
 * @param childCountProp - The property in the data that contains the count of children
 * @param pinnedColumns - The columns to pin to the left of the table
 * @param onSortingChange - Callback for when the sorting changes
 * @param onPaginationChange - Callback for when the pagination changes
 * @param pagination - The current pagination state (`rowOffset` is zero-based)
 * @param exportAction - Optional handler; when provided an "Export" button is rendered above the table
 * @constructor
 */
const PivotTable = <TData extends RowData>({
  columns,
  data,
  grouping,
  setGrouping,
  getSubRows,
  childCountProp,
  pinnedColumns,
  onSortingChange: _onSortingChange,
  onPaginationChange,
  pagination,
  exportAction,
}: PivotTableProps<TData>) => {
  const [sorting, setSorting] = React.useState<SortingState>([]);
  // Rows actually rendered: the table's row model plus any fetched subrows
  // inserted below their expanded parent.
  const [rows, setRows] = React.useState<Row<TData>[]>([]);
  // Ids of group rows whose subrows are currently being fetched. This stops a
  // double click from fetching and inserting the same children twice.
  const pendingExpansions = React.useRef(new Set<string>());
  // Bumped whenever `data` changes. A subrow fetch started against older data
  // compares against it and discards its now-stale result.
  const dataVersion = React.useRef(0);

  const table = useReactTable({
    data: data,
    columns: columns,
    defaultColumn: {
      minSize: 60,
      size: 150,
      maxSize: 300,
    },
    state: {
      sorting,
      grouping,
      columnPinning: {
        left: pinnedColumns,
      },
    },
    onGroupingChange: setGrouping,
    getExpandedRowModel: getExpandedRowModel(),
    getCoreRowModel: getCoreRowModel(),
    getPaginationRowModel: getPaginationRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    manualSorting: true,
    // eslint-disable-next-line no-use-before-define
    onSortingChange,
    manualPagination: true,
    manualGrouping: true,
    columnResizeMode: "onChange",
  });

  // New server data replaces all rendered rows (dropping inserted subrows) and
  // collapses every group.
  useDeepCompareEffect(() => {
    dataVersion.current += 1;
    setRows(table.getRowModel().rows);
    table.resetExpanded();
  }, [data]);

  // Resolves TanStack's updater form into a concrete SortingState, collapses
  // expanded groups (their subrows belong to the old order) and reports it.
  function onSortingChange(sortingUpdaterOrNewSorting: Updater<SortingState>) {
    let newSorting;
    if (typeof sortingUpdaterOrNewSorting === "function") {
      newSorting = sortingUpdaterOrNewSorting(sorting);
    } else {
      newSorting = sortingUpdaterOrNewSorting;
    }
    table.resetExpanded();
    _onSortingChange(newSorting);
    setSorting(newSorting);
  }

  // Memoize the column sizes to prevent unnecessary re-renders
  // Taken from doc example: https://tanstack.com/table/latest/docs/framework/react/examples/column-resizing-performant
  // `columnSizing` and `columns` are dependencies too. Otherwise the CSS
  // variables go stale, or are missing entirely, when sizes are set without a
  // drag or when the column set changes.
  const columnSizeVars = React.useMemo(() => {
    const headers = table.getFlatHeaders();
    const colSizes: Record<string, number> = {};
    for (let i = 0; i < headers.length; i++) {
      const header = headers[i]!;
      colSizes[`--header-${header.id}-size`] = header.getSize();
      colSizes[`--col-${header.column.id}-size`] = header.column.getSize();
    }
    return colSizes;
  }, [
    table.getState().columnSizingInfo,
    table.getState().columnSizing,
    columns,
  ]);

  // Only top-level cells of a grouped column act as group expanders.
  const getCellIsGrouped = (cell: Cell<TData, unknown>) =>
    grouping.includes(cell.column.id) && cell.row.depth === 0;

  /**
   * Expands or collapses a group row.
   * Collapsing removes the row's subrows immediately. Expanding fetches the
   * children with `getSubRows` and inserts them directly below the parent. The
   * result is discarded if `data` changed while the fetch was in flight, or if
   * the parent is no longer rendered.
   */
  const toggleGroupRow = async (row: Row<TData>) => {
    if (row.getIsExpanded()) {
      setRows((current) => current.filter((r) => r.parentId !== row.id));
      row.toggleExpanded(false);
      return;
    }
    if (pendingExpansions.current.has(row.id)) {
      return;
    }
    pendingExpansions.current.add(row.id);
    const versionAtRequest = dataVersion.current;
    try {
      const fetchedSubrows = await getSubRows(row);
      // The page/sort changed while we were fetching; row ids are reused across
      // pages, so applying this result would attach children to the wrong row.
      if (versionAtRequest !== dataVersion.current) {
        return;
      }
      const newSubrows = fetchedSubrows.map((subrow, i) =>
        createRow<TData>(
          table,
          `${i}.${row.id}`,
          subrow,
          i,
          1,
          undefined,
          row.id,
        ),
      );
      setRows((current) => {
        const parentIndex = current.findIndex((r) => r.id === row.id);
        // Without this guard, -1 would splice the children at the top.
        if (parentIndex === -1) {
          return current;
        }
        const next = [...current];
        next.splice(parentIndex + 1, 0, ...newSubrows);
        return next;
      });
      row.toggleExpanded(true);
    } finally {
      pendingExpansions.current.delete(row.id);
    }
  };

  const upperPaginationLimit = Math.min(
    pagination.pageSize + pagination.rowOffset,
    pagination.totalRows,
  );
  // rowOffset is zero-based; the label shows the first visible row one-based.
  const lowerPaginationLimit =
    pagination.totalRows === 0 ? 0 : pagination.rowOffset + 1;
  const hasPreviousPage = pagination.rowOffset > 0;
  const hasNextPage = upperPaginationLimit < pagination.totalRows;

  return (
    <div>
      {exportAction && (
        <div style={{ position: "relative", width: "100%" }}>
          <div
            style={{
              position: "absolute",
              right: gridSpacing[2],
              top: -gridSpacing[4],
              display: "flex",
              alignItems: "center",
              zIndex: 4,
            }}
          >
            <ToolTip text="Export">
              <Button
                small
                icon="fa-download"
                type="link"
                onClick={exportAction}
              />
            </ToolTip>
          </div>
        </div>
      )}
      <div style={{ width: "100%", overflowX: "auto" }}>
        <table
          style={{
            ...columnSizeVars,
            width: "100%",
            borderSpacing: 0,
            position: "relative",
          }}
        >
          <TableHeader>
            {table.getHeaderGroups().map((headerGroup) => (
              <TableHeadRow style={{ width: "100%" }} key={headerGroup.id}>
                {headerGroup.headers.map((header) => (
                  <TableHead
                    key={header.id}
                    colSpan={header.colSpan}
                    style={{
                      position: header.column.getIsPinned()
                        ? "sticky"
                        : "relative",
                      left: header.column.getIsPinned()
                        ? header.column.getStart("left")
                        : undefined,
                      minWidth: `calc(var(--header-${header.id}-size) * 1px)`,
                      background: "white",
                      zIndex: header.column.getIsPinned() ? 3 : 2,
                    }}
                  >
                    <div style={{ position: "relative" }}>
                      {header.isPlaceholder ? null : (
                        <div style={{ display: "flex", alignItems: "center" }}>
                          <p
                            style={{
                              // "normal" is not a valid cursor value; "default" is.
                              cursor: header.column.getCanSort()
                                ? "pointer"
                                : "default",
                              fontSize: 12,
                              color: colorTheme("neutralL1"),
                              textTransform: "uppercase",
                            }}
                            data-testid={`pivot-table-sort-${header.id}`}
                            onClick={header.column.getToggleSortingHandler()}
                          >
                            {flexRender(
                              header.column.columnDef.header,
                              header.getContext(),
                            )}
                          </p>
                          <a
                            style={{ marginLeft: gridSpacing[1] }}
                            className={
                              {
                                asc: "fa-solid fa-arrow-up",
                                desc: "fa-solid fa-arrow-down",
                              }[header.column.getIsSorted() as string]
                            }
                          />
                        </div>
                      )}
                      <Resizer
                        onMouseDown={header.getResizeHandler()}
                        onTouchStart={header.getResizeHandler()}
                      />
                    </div>
                  </TableHead>
                ))}
              </TableHeadRow>
            ))}
          </TableHeader>
          <tbody style={{ width: "100%" }}>
            {rows.map((row) => {
              // A group with no (or "0") children cannot be expanded.
              const expandDisabled =
                childCountProp !== undefined &&
                (!row.original[childCountProp] ||
                  row.original[childCountProp] === "0");
              return (
                <TableRow key={row.id}>
                  {row.getVisibleCells().map((cell) => (
                    <TableData
                      key={cell.id}
                      style={{
                        background: getCellIsGrouped(cell)
                          ? colorTheme("infoL3")
                          : cell.getIsAggregated() && row.getIsExpanded()
                            ? colorTheme("infoL5")
                            : "white",
                        width: `calc(var(--col-${cell.column.id}-size) * 1px)`,
                        maxWidth: cell.column.columnDef.size,
                        position: cell.column.getIsPinned()
                          ? "sticky"
                          : "relative",
                        left: cell.column.getIsPinned()
                          ? cell.column.getStart("left")
                          : undefined,
                        zIndex: cell.column.getIsPinned() ? 2 : 1,
                      }}
                    >
                      {getCellIsGrouped(cell) ? (
                        // If it's a grouped cell, add an expander and row count
                        <button
                          style={{
                            cursor: expandDisabled ? "not-allowed" : "pointer",
                            width: "100%",
                            display: "flex",
                            alignItems: "center",
                          }}
                          data-testid={`pivot-table-expand-button-${row.id}`}
                          disabled={expandDisabled}
                          onClick={() => {
                            void toggleGroupRow(row);
                          }}
                        >
                          <span
                            style={{
                              fontSize: 14,
                              cursor: expandDisabled
                                ? "not-allowed"
                                : "pointer",
                              marginRight: gridSpacing[2],
                              color: expandDisabled
                                ? colorTheme("neutralL3")
                                : colorTheme("neutralL2"),
                            }}
                            className={
                              row.getIsExpanded()
                                ? "fa-regular fa-minus"
                                : "fa-regular fa-plus"
                            }
                          />
                          <p
                            style={{
                              whiteSpace: "nowrap",
                              textOverflow: "ellipsis",
                              overflow: "hidden",
                              maxWidth: "100%",
                            }}
                          >
                            {flexRender(
                              cell.column.columnDef.cell,
                              cell.getContext(),
                            )}{" "}
                            {childCountProp !== undefined &&
                              `(${row.original[childCountProp] ?? 0})`}
                          </p>
                        </button>
                      ) : cell.getIsAggregated() ? (
                        // This a normal cell since we are doing server-sided aggregation
                        flexRender(
                          cell.column.columnDef.cell,
                          cell.getContext(),
                        )
                      ) : cell.getIsPlaceholder() ? null : (
                        // Otherwise, just render the regular cell
                        flexRender(
                          cell.column.columnDef.cell,
                          cell.getContext(),
                        )
                      )}
                    </TableData>
                  ))}
                </TableRow>
              );
            })}
          </tbody>
        </table>
      </div>
      <Pagination>
        <span>Rows</span>
        <span>
          {numberWithCommas(lowerPaginationLimit)} -{" "}
          {numberWithCommas(upperPaginationLimit)} of{" "}
          {numberWithCommas(pagination.totalRows)}
        </span>
        <div style={{ display: "flex" }}>
          <PaginationIcon
            style={{ borderTopLeftRadius: 8, borderBottomLeftRadius: 8 }}
            className="fa-regular fa-chevron-left"
            data-testid="pagination-previous"
            onClick={() => {
              if (hasPreviousPage) {
                table.resetExpanded();
                onPaginationChange({
                  ...pagination,
                  // Clamp so an offset that is not a multiple of pageSize
                  // cannot step below the first row.
                  rowOffset: Math.max(
                    0,
                    pagination.rowOffset - pagination.pageSize,
                  ),
                });
              }
            }}
            $disabled={!hasPreviousPage}
          />
          <PaginationIcon
            style={{
              borderLeft: 0,
              borderTopRightRadius: 8,
              borderBottomRightRadius: 8,
            }}
            className="fa-regular fa-chevron-right"
            data-testid="pagination-next"
            onClick={() => {
              if (hasNextPage) {
                table.resetExpanded();
                onPaginationChange({
                  ...pagination,
                  rowOffset: pagination.rowOffset + pagination.pageSize,
                });
              }
            }}
            $disabled={!hasNextPage}
          />
        </div>
      </Pagination>
    </div>
  );
};

export default PivotTable;
