import {
  ColumnDef as _ColumnDef,
  ExpandedState,
  Row,
  flexRender,
  getCoreRowModel,
  getExpandedRowModel,
  useReactTable,
  SortingState,
  OnChangeFn,
  getSortedRowModel,
  FilterFn,
  getFilteredRowModel
} from '@tanstack/react-table';
import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from './Table';
import { useCallback, useEffect, useMemo, useRef, useState, Fragment, ReactNode } from 'react';
import { useInfiniteQuery } from '@tanstack/react-query';
import { useVirtual } from 'react-virtual';
import { VariantProps, cn, getClassNames } from '~/lib/utils';
import LoadingDots from '../v2/feedback/LoadingDots';
import { tv } from 'tailwind-variants';
import { Icon } from '../Icon';
import { isEmpty, isFunction } from 'lodash';

/**
 * One page of a cursor-paginated response, as returned by a `getData` loader.
 * Each row (`node`) travels inside an edge alongside its cursor.
 */
interface PaginatedQueryResult<TData extends { id: string }> {
  /** Cursor bookkeeping for this page; `endCursor` is used to request the following page. */
  pageInfo: {
    startCursor: string;
    endCursor: string;
    hasNextPage: boolean;
    hasPreviousPage: boolean;
  };
  /** Rows of this page paired with their cursors; the server may omit this or send `null`. */
  edges?:
    | {
        cursor: string;
        node: TData;
      }[]
    | null;
}

/** Cursor arguments handed to a `getData` loader to select which page to return. */
export interface PaginatedQueryParams {
  /** Return rows after this cursor; `DataTableVirtual` passes the previous page's `endCursor` here. */
  after?: string;
  /** Return rows before this cursor (backward paging); not set by `DataTableVirtual`. */
  before?: string;
}

/**
 * TanStack Table column definition extended with an `isVisible` flag.
 * `DataTableVirtual` drops every column whose `isVisible` is explicitly `false`.
 *
 * `TValue` is forwarded to TanStack's own second type parameter (the accessor value type).
 * It defaults to `unknown`, so existing `ColumnDef<TData>` usages resolve exactly as before.
 */
export type ColumnDef<TData, TValue = unknown> = _ColumnDef<TData, TValue> & {
  isVisible?: boolean;
};

// TODO @poindexd - complete refactor
/**
 * Slot class names for `DataTableVirtual`, resolved per render through `getClassNames`.
 * Slots are listed outermost-first, following the rendered DOM nesting.
 */
const tableVariants = tv({
  slots: {
    // Scroll container: it is the virtualizer's parent element and the fetch-more scroll source.
    wrapper: 'max-h-full w-max min-w-full overflow-auto rounded-md border border-gray-300 bg-white',
    table: '',
    // Header sticks to the top of the scroll container while rows scroll beneath it.
    header: 'sticky top-0 bg-gray-50 hover:bg-gray-50',
    head: 'px-2 first:pl-4 last:pr-4',
    row: '',
    cell: 'p-2 first:pl-4 last:pr-4'
  }
});

/**
 * Props for `DataTableVirtual`. Any remaining (variant) props are passed to `getClassNames`
 * together with `tableVariants` to derive the per-slot class names.
 */
interface DataTableVirtualProps<TData extends { id: string }>
  extends VariantProps<typeof tableVariants> {
  /** Column definitions; columns with `isVisible: false` are not rendered. */
  columns: ColumnDef<TData>[];
  /**
   * Loads one page of rows. Called with `after` set to the previous page's `endCursor`
   * (nothing for the first page). Optional in the type, but the table cannot load rows without it.
   */
  getData?: (params: PaginatedQueryParams) => Promise<PaginatedQueryResult<TData>>;
  /** Invoked with the TanStack row when a body row is clicked. */
  onRowClick?: (row: Row<TData>) => void;
  /** The row whose `id` equals this item's `id` is highlighted. */
  selectedRow?: TData;
  /** Optional ref forwarded to the table body (`TableBody`). */
  tableBodyRef?: React.RefObject<HTMLTableSectionElement>;
  /** Estimated pixel height of the row at `index`, given to the virtualizer. */
  estimateSize: (index: number) => number;
  // Sorting
  /**
   * Controlled sorting state. The table does not reorder rows itself (`manualSorting`).
   * A change triggers a refetch, so `getData` presumably applies it. TODO confirm with callers.
   */
  sorting?: SortingState;
  /** Called when a sortable header is clicked. */
  onSortingChange?: OnChangeFn<SortingState>;
  // Filtering
  /**
   * Client-side global filter over the loaded rows: either a search value (matched with
   * TanStack's `includesString`) or a custom filter function.
   */
  globalFilter?: string | FilterFn<TData>;
  /**
   * Opaque filter state. It is only compared by reference (a change triggers a refetch and scrolls
   * to the top) and checked with lodash `isEmpty` to pick the empty-state UI; the table never
   * reads its contents.
   * TODO: consider moving filters into the query key instead of refetching manually.
   */
  filters?: unknown;
  /** Cache key for the paginated query. */
  queryKey: string;
  /**
   * Disables the automatic next-page fetch that runs outside of scroll events (for example, when the
   * loaded rows do not fill the container). Scroll-triggered fetching is unaffected.
   */
  disableFetchMore?: boolean;
  /** Rendered instead of the table when there are no rows and `filters` is empty. */
  emptyComponent?: ReactNode;
}

/** Distance in pixels from the bottom of the scroll container at which the next page is requested. */
const FETCH_MORE_THRESHOLD_PX = 300;

/**
 * Virtualized, infinitely scrolling table backed by a cursor-paginated `getData` loader.
 *
 * Pages are loaded with `useInfiniteQuery` under `[queryKey]`. The next page is requested when the
 * user scrolls within `FETCH_MORE_THRESHOLD_PX` of the bottom. Unless `disableFetchMore` is set, it
 * is also requested whenever the loaded rows still leave the container within that threshold.
 * Changing `filters` or `sorting` refetches the loaded pages and scrolls back to the top.
 * Only the virtualizer's visible rows (plus overscan) are rendered; spacer rows above and below
 * keep the scroll height equal to the virtualizer's total size.
 */
export function DataTableVirtual<TData extends { id: string }>({
  columns,
  getData,
  onRowClick,
  selectedRow,
  tableBodyRef,
  estimateSize,
  sorting,
  onSortingChange,
  globalFilter,
  filters,
  queryKey,
  disableFetchMore,
  emptyComponent,
  ...rest
}: DataTableVirtualProps<TData>) {
  const tableContainerRef = useRef<HTMLDivElement>(null);
  // True until the filters/sorting effect has run once (see that effect below).
  const isFirstFilterRunRef = useRef(true);
  const [expanded, setExpanded] = useState<ExpandedState>({});

  const classNames = getClassNames(tableVariants, rest);

  const { fetchNextPage, isLoading, isFetching, refetch, data, hasNextPage, isFetched } =
    useInfiniteQuery<PaginatedQueryResult<TData>>(
      [queryKey],
      async ({ pageParam }) => {
        // `getData` is optional in the props type; fail with a clear message rather than
        // "getData is not a function".
        if (!getData) {
          throw new Error('DataTableVirtual: `getData` is required to load rows');
        }
        return getData({ after: pageParam });
      },
      {
        getNextPageParam: lastGroup =>
          lastGroup.edges?.length && lastGroup.pageInfo?.hasNextPage
            ? lastGroup.pageInfo?.endCursor
            : undefined,
        keepPreviousData: true,
        refetchOnWindowFocus: false
      }
    );

  // Keep `data`/`columns` referentially stable: TanStack Table rebuilds its row models whenever
  // their identity changes, and this component re-renders as the virtualizer tracks scrolling.
  // `edges` may be missing or null per PaginatedQueryResult, so treat it as an empty page.
  const tableData = useMemo(
    () => data?.pages?.flatMap(page => page.edges?.map(edge => edge.node) ?? []) ?? [],
    [data]
  );
  const visibleColumns = useMemo(
    () => columns.filter(column => column.isVisible !== false),
    [columns]
  );

  const table = useReactTable({
    data: tableData,
    columns: visibleColumns,
    state: {
      expanded,
      sorting,
      globalFilter
    },
    manualSorting: true,
    enableGlobalFilter: true,
    onExpandedChange: setExpanded,
    onSortingChange,
    globalFilterFn: isFunction(globalFilter) ? globalFilter : 'includesString',
    getExpandedRowModel: getExpandedRowModel(),
    getCoreRowModel: getCoreRowModel(),
    getSortedRowModel: getSortedRowModel(),
    getFilteredRowModel: getFilteredRowModel()
  });

  const fetchMoreOnBottomReached = useCallback(
    (containerRefElement?: HTMLDivElement | null) => {
      // Skip while any request is in flight; otherwise every scroll event would call
      // fetchNextPage again while the previous page request is still pending.
      if (!containerRefElement || isLoading || isFetching || !hasNextPage) {
        return;
      }
      const { scrollHeight, scrollTop, clientHeight } = containerRefElement;
      if (scrollHeight - scrollTop - clientHeight < FETCH_MORE_THRESHOLD_PX) {
        void fetchNextPage();
      }
    },
    [fetchNextPage, hasNextPage, isFetching, isLoading]
  );

  // Re-runs whenever the callback changes (including each time a fetch settles), so pages keep
  // loading until the container is filled past the threshold or no pages remain.
  useEffect(() => {
    if (!disableFetchMore) {
      fetchMoreOnBottomReached(tableContainerRef.current);
    }
  }, [disableFetchMore, fetchMoreOnBottomReached]);

  useEffect(() => {
    // The query already fetches on mount; calling refetch() here as well would issue a
    // duplicate request, so only react to subsequent filter/sorting changes.
    if (isFirstFilterRunRef.current) {
      isFirstFilterRunRef.current = false;
      return;
    }
    void refetch();
    if (tableContainerRef.current) {
      tableContainerRef.current.scrollTop = 0;
    }
  }, [filters, sorting]);

  const { rows } = table.getRowModel();

  const rowVirtualizer = useVirtual({
    parentRef: tableContainerRef,
    size: rows.length,
    overscan: 10,
    estimateSize: isFunction(estimateSize) ? estimateSize : undefined
  });
  const { virtualItems: virtualRows, totalSize } = rowVirtualizer;
  // Spacer heights standing in for the rows that are scrolled out of view above/below.
  const paddingTop = virtualRows.length > 0 ? virtualRows[0]?.start || 0 : 0;
  const paddingBottom =
    virtualRows.length > 0 ? totalSize - (virtualRows[virtualRows.length - 1]?.end || 0) : 0;

  if (!virtualRows.length && isLoading && !isFetched) {
    return (
      <div className="w-full text-center">
        <LoadingDots />
      </div>
    );
  }
  if (!virtualRows.length && !isLoading && isEmpty(filters) && emptyComponent) {
    return <div>{emptyComponent}</div>;
  }

  return (
    <div
      ref={tableContainerRef}
      onScroll={e => fetchMoreOnBottomReached(e.currentTarget)}
      className={classNames.wrapper}
    >
      <Table className={classNames.table}>
        <TableHeader className={classNames.header}>
          {table.getHeaderGroups().map(headerGroup => (
            <TableRow key={headerGroup.id}>
              {headerGroup.headers.map(header => (
                <TableHead
                  key={header.id}
                  className={cn(
                    header.column.getCanSort() && 'cursor-pointer select-none',
                    classNames.head
                  )}
                  onClick={header.column.getToggleSortingHandler()}
                >
                  <div className="flex items-center space-x-2">
                    {flexRender(header.column.columnDef.header, header.getContext())}

                    {!!header.column.getIsSorted() && (
                      <Icon
                        name="ArrowNarrowRight"
                        className={cn([
                          'h-3.5 w-3.5 transition-all',
                          header.column.getIsSorted() === 'asc' && 'rotate-90',
                          header.column.getIsSorted() === 'desc' && '-rotate-90'
                        ])}
                      />
                    )}
                  </div>
                </TableHead>
              ))}
            </TableRow>
          ))}
        </TableHeader>
        <TableBody ref={tableBodyRef}>
          {paddingTop > 0 && (
            <tr>
              <td style={{ height: `${paddingTop}px` }} />
            </tr>
          )}
          {virtualRows.length ? (
            virtualRows.map(virtualRow => {
              const row = rows[virtualRow.index];
              // The virtualizer is sized from rows.length, so this only guards a stale frame.
              if (!row) {
                return null;
              }
              return (
                <Fragment key={row.id}>
                  <TableRow
                    data-state={row.getIsSelected() ? 'selected' : undefined}
                    onClick={() => onRowClick?.(row)}
                    className={cn(
                      'group/row bg-white hover:bg-indigo-50',
                      row.original.id === selectedRow?.id && 'bg-indigo-100 hover:bg-indigo-100',
                      classNames.row
                    )}
                  >
                    {row.getVisibleCells().map(cell => (
                      <TableCell
                        key={cell.id}
                        className={classNames.cell}
                        style={{ width: cell.column.getSize(), maxWidth: cell.column.getSize() }}
                      >
                        {flexRender(cell.column.columnDef.cell, cell.getContext())}
                      </TableCell>
                    ))}
                  </TableRow>
                </Fragment>
              );
            })
          ) : (
            <TableRow>
              {/* Span every rendered (leaf) column, including those nested under group headers. */}
              <TableCell
                colSpan={table.getVisibleLeafColumns().length}
                className="h-24 bg-white text-center"
              >
                {isLoading ? (
                  <LoadingDots />
                ) : (
                  <span>{`No results${isEmpty(filters) ? '' : ' match the filters'}.`}</span>
                )}
              </TableCell>
            </TableRow>
          )}
          {paddingBottom > 0 && (
            <tr>
              <td style={{ height: `${paddingBottom}px` }} />
            </tr>
          )}
        </TableBody>
      </Table>
    </div>
  );
}
