import React, { CSSProperties, forwardRef, ReactElement } from 'react';
import { Column, flexRender, Table, Row } from '@tanstack/react-table';
import { ChevronDownIcon, ChevronUpIcon } from '@heroicons/react/24/outline';
import { VirtualItem, notUndefined, useVirtualizer } from '@tanstack/react-virtual';
import Skeleton, { SkeletonTheme } from 'react-loading-skeleton';

/** Props accepted by `VirtualizedBaseTable`. */
interface IProps {
  /** TanStack table instance; supplies the header groups and the row model that is rendered. */
  table: Table<Record<string, unknown>>;
  /** Content shown in a full-width body row when `count` is not positive (and not loading). */
  emptyState: string | ReactElement;
  /** Number of rows handed to the virtualizer. */
  count?: number;
  /** Estimated row height in pixels; also used to size the skeleton placeholders. */
  estimateSize?: number;
  /** When true, the body shows a loading skeleton instead of rows. */
  isLoading?: boolean;
  /** Emitted as the `<table>` element's `data-testid`. */
  id?: string;
  /** Whether sortable header cells get a tooltip describing the next sort action. */
  showHeaderTitle?: boolean;
  /** Optional extra class names for individual table elements. */
  styles?: {
    table?: string;
    tHead?: string;
    tBody?: string;
    tHeadTr?: string;
    /** A fixed class name, or a function deriving the class name from the rendered row. */
    tRow?: string | ((row: Row<Record<string, unknown>>) => string);
    th?: string;
    td?: string;
  };
}

/**
 * Inline positioning styles for a header or body cell of `column`.
 *
 * A column that can be pinned is made `sticky`, offset from the left by the
 * summed width of the columns before it (`getStart()`), and layered above its
 * neighbours; any other column keeps normal `relative` positioning.
 *
 * NOTE(review): this keys off `getCanPin()` ("may be pinned"), not
 * `getIsPinned()` ("is currently pinned"), so every pinnable column becomes
 * sticky regardless of pinning state — confirm that is intended.
 */
const getCommonPinningStyles = (column: Column<Record<string, unknown>, unknown>): CSSProperties => {
  if (!column.getCanPin()) {
    return { left: undefined, position: 'relative', zIndex: undefined, opacity: undefined };
  }

  return { left: `${column.getStart()}px`, position: 'sticky', zIndex: 1, opacity: 1 };
};

/**
 * One virtualized body row: a `<tr>` whose height is the virtualizer's size
 * for this item, containing one `<td>` per visible cell of `row`.
 *
 * The caller is responsible for keying this element (a `key` on the inner
 * `<tr>` would have no effect, since it is not an array child here).
 *
 * @param row - TanStack row whose visible cells are rendered.
 * @param virtualRow - Virtualizer item for this row; `size` becomes the row height in px.
 * @param styles - Optional class names; `tRow` may be a string or a function of the row.
 */
const VirtualRow = ({
  row,
  virtualRow,
  styles,
}: {
  row: Row<Record<string, unknown>>;
  virtualRow: VirtualItem;
  styles?: { tRow?: string | ((row: Row<Record<string, unknown>>) => string); td?: string };
}): ReactElement => {
  const tRow = styles?.tRow;
  // Resolve the row class once; a function lets callers style rows by their data.
  const rowClassName = typeof tRow === 'function' ? tRow(row) : tRow ?? '';
  const cellClassName = styles?.td ?? '';

  return (
    <tr className={rowClassName} style={{ height: `${virtualRow.size}px` }}>
      {row.getVisibleCells().map((cell) => {
        const { column } = cell;
        // Pinnable columns get sticky positioning and a fixed width so they line
        // up with their header; other cells keep default layout.
        const pinnedStyle: CSSProperties | undefined = column.getCanPin()
          ? {
              ...getCommonPinningStyles(column),
              width: column.getSize(),
              maxWidth: column.getSize(),
              minWidth: column.getSize(),
            }
          : undefined;

        return (
          <td key={cell.id} className={cellClassName} style={pinnedStyle}>
            {flexRender(column.columnDef.cell, cell.getContext())}
          </td>
        );
      })}
    </tr>
  );
};

/**
 * Tooltip text for a sortable header cell, describing what the next click does.
 *
 * @param column - The header's column; its next sorting order selects the text.
 * @returns 'Sort ascending', 'Sort descending', or 'Clear sort' (when the next click removes sorting).
 */
const getSortTitle = (column: Column<Record<string, unknown>, unknown>): string => {
  const nextOrder = column.getNextSortingOrder();
  if (nextOrder === 'asc') return 'Sort ascending';
  if (nextOrder === 'desc') return 'Sort descending';
  return 'Clear sort';
};

/**
 * Table that renders only the rows the virtualizer reports as visible, filling
 * the off-screen space above and below them with skeleton rows.
 *
 * The forwarded `ref` is not attached to any element rendered here; it is read
 * as the virtualizer's scroll container, which the caller owns. Only object refs
 * are supported — a missing or callback ref leaves the virtualizer without a
 * scroll element rather than throwing.
 *
 * `count` is passed to the virtualizer independently of `table`'s row model;
 * any virtual index with no loaded row is rendered as a skeleton placeholder.
 */
const VirtualizedBaseTable = forwardRef(
  (
    {
      table,
      emptyState = 'No Data',
      styles,
      id,
      showHeaderTitle = true,
      estimateSize = 48,
      count = 0,
      isLoading,
    }: IProps,
    ref,
  ) => {
    const tableRows = table.getRowModel().rows;
    const headerGroups = table.getHeaderGroups();
    // The last header group has one header per visible leaf column, so its
    // length is the colSpan that makes a body cell span the whole table.
    const columnCount = headerGroups[headerGroups.length - 1]?.headers.length ?? 0;
    const skeletonHeight = estimateSize - 8;
    // Extra height reserved beyond the rows — presumably the header row (TODO confirm).
    const headerHeightPx = 37;

    const virtualizer = useVirtualizer({
      count,
      estimateSize: () => estimateSize,
      // `ref` is null when the caller passes none, and a callback ref has no
      // `.current`; only an object ref can supply the scroll element.
      getScrollElement: () =>
        ref !== null && typeof ref === 'object' ? (ref.current as HTMLDivElement | null) : null,
      overscan: 5,
    });

    const virtualItems = virtualizer.getVirtualItems();

    // Space (px) above the first and below the last rendered item, filled with skeleton rows.
    const [before, after] =
      virtualItems.length > 0
        ? [
            notUndefined(virtualItems[0]).start - virtualizer.options.scrollMargin,
            virtualizer.getTotalSize() - notUndefined(virtualItems[virtualItems.length - 1]).end,
          ]
        : [0, 0];

    return (
      <div style={{ height: `${virtualizer.getTotalSize() + headerHeightPx}px` }}>
        <table className={`table-auto ${styles?.table ? ` ${styles.table}` : ''}`} data-testid={id ?? ''}>
          <thead className={`${styles?.tHead ? ` ${styles.tHead}` : ''}`}>
            {headerGroups.map((headerGroup) => (
              <tr key={headerGroup.id} className={`${styles?.tHeadTr ? ` ${styles.tHeadTr}` : ''}`}>
                {headerGroup.headers.map((header) => (
                  <th
                    key={header.id}
                    className={`${styles?.th ? ` ${styles.th}` : ''}`}
                    style={
                      header.column.getCanPin()
                        ? {
                            ...getCommonPinningStyles(header.column),
                            width: header.getSize(),
                            maxWidth: header.getSize(),
                            minWidth: header.getSize(),
                            zIndex: 2, // Needs to be higher than pinned td beneath it
                          }
                        : undefined
                    }
                  >
                    {header.isPlaceholder ? null : (
                      <div
                        className={`${header.column.getCanSort() ? 'cursor-pointer select-none' : ''} flex flex-row gap-4`}
                        onClick={header.column.getToggleSortingHandler()}
                        title={
                          header.column.getCanSort() && showHeaderTitle ? getSortTitle(header.column) : undefined
                        }
                      >
                        {flexRender(header.column.columnDef.header, header.getContext())}
                        {header.column.getCanSort() && (
                          <div className="flex flex-col">
                            <ChevronUpIcon
                              className={`h-3 w-3 ${header.column.getIsSorted() === 'asc' ? 'text-black' : ''}`}
                            />
                            <ChevronDownIcon
                              className={`h-3 w-3 -mt-1 ${header.column.getIsSorted() === 'desc' ? 'text-black' : ''}`}
                            />
                          </div>
                        )}
                      </div>
                    )}
                  </th>
                ))}
              </tr>
            ))}
          </thead>
          <tbody className={`${styles?.tBody ? ` ${styles.tBody}` : ''}`}>
            {isLoading ? (
              <tr data-testid="loading-skeleton">
                <td colSpan={columnCount}>
                  <SkeletonTheme baseColor="#F8F9F6" highlightColor="#EAEFE8">
                    <Skeleton width="100%" height={skeletonHeight} count={20} />
                  </SkeletonTheme>
                </td>
              </tr>
            ) : (
              <>
                {before > 0 && (
                  <tr data-testid="before-skeleton">
                    <td colSpan={columnCount} style={{ height: before }}>
                      <SkeletonTheme baseColor="#F8F9F6" highlightColor="#EAEFE8">
                        <Skeleton width="100%" height={skeletonHeight} count={Math.ceil(before / estimateSize)} />
                      </SkeletonTheme>
                    </td>
                  </tr>
                )}
                {count > 0 ? (
                  virtualItems.map((virtualRow) => {
                    const row = tableRows[virtualRow.index];
                    // `count` can exceed the rows actually present in the row model;
                    // keep the slot's height with a placeholder instead of crashing.
                    if (!row) {
                      return (
                        <tr key={`placeholder-${virtualRow.index}`} style={{ height: `${virtualRow.size}px` }}>
                          <td colSpan={columnCount}>
                            <SkeletonTheme baseColor="#F8F9F6" highlightColor="#EAEFE8">
                              <Skeleton width="100%" height={skeletonHeight} />
                            </SkeletonTheme>
                          </td>
                        </tr>
                      );
                    }
                    return <VirtualRow key={row.id} row={row} virtualRow={virtualRow} styles={styles} />;
                  })
                ) : (
                  <tr>
                    <td colSpan={columnCount} className="text-center">
                      {emptyState}
                    </td>
                  </tr>
                )}
                {after > 0 && (
                  <tr data-testid="after-skeleton">
                    <td colSpan={columnCount} style={{ height: after }}>
                      <SkeletonTheme baseColor="#EAEFE8" highlightColor="#DFE5DC">
                        <Skeleton width="100%" height={skeletonHeight} count={Math.ceil(after / estimateSize)} />
                      </SkeletonTheme>
                    </td>
                  </tr>
                )}
              </>
            )}
          </tbody>
        </table>
      </div>
    );
  },
);

// The render function passed to forwardRef is an anonymous arrow, so give the
// component an explicit name for tooling such as React DevTools.
VirtualizedBaseTable.displayName = 'VirtualizedBaseTable';

export default VirtualizedBaseTable;
