import React, { PropsWithChildren, useMemo } from "react"
import * as d3 from "d3"
import { sankey, sankeyLinkHorizontal } from "d3-sankey"
import { getDomainAndRange } from "../BarChart/utils"
import { StyledLabelText } from "../style"
import { SankeyChartDataLink, SankeyChartDataNode, SankeyChartProps } from "./types"
import { isValidLink, isValidNode, sortLinks, sortNodes } from "./utils"
import { getConfig } from "./config"
import { StyledSankeyChart } from "./style"

/**
 * Renders a Sankey diagram as an SVG.
 *
 * The layout is computed by d3-sankey on a deep copy of `data`, because the
 * layout function mutates its input by attaching positions and link/node
 * references. The chart draws three things:
 * - each valid node as one rect per attached link, followed by a text label;
 * - each valid link as a horizontal Sankey path, widest links drawn first so
 *   narrower ones paint on top;
 * - a legend of `group` labels, one per distinct node `layer`.
 *
 * @param data - nodes and links fed to the d3-sankey layout (nodes are identified by `name`)
 * @param linkColorScale - color-scale config for link strokes, looked up by `link.score`
 * @param nodeColorScale - color-scale config for node rects, looked up by the `score` of each attached link
 * @param className - forwarded to the root SVG element
 */
export const SankeyChart = ({
  data,
  linkColorScale,
  nodeColorScale,
  className,
}: PropsWithChildren<SankeyChartProps>) => {
  const {
    areaWidth,
    areaHeight,
    graphHeight,
    graphWidth,
    nodeWidth,
    nodePadding,
    nodeLabelFontSize,
    nodeLegendHeight,
    nodeLabelIndent,
  } = getConfig()

  // Memoized on the scale prop so `linkDomain` keeps a stable identity across
  // renders. The layout memo below depends on it; without this, a helper that
  // returns fresh arrays would force a full relayout on every render.
  const [linkDomain, linkRange] = useMemo(() => getDomainAndRange(linkColorScale), [linkColorScale])
  // Ordinal scales are intentionally rebuilt each render: d3 ordinal scales
  // implicitly extend their domain when called with unknown values, so a
  // long-lived instance could assign colors differently across renders.
  const linkColor = d3.scaleOrdinal().domain(linkDomain).range(linkRange)

  const [nodeDomain, nodeRange] = getDomainAndRange(nodeColorScale)
  const nodeColor = d3.scaleOrdinal().domain(nodeDomain).range(nodeRange)

  const sankeyDiagram = useMemo(
    () =>
      sankey<SankeyChartDataNode, SankeyChartDataLink>()
        .nodeId(d => d.name)
        .nodeSort(sortNodes(linkDomain))
        .linkSort(sortLinks(linkDomain))
        .nodeWidth(nodeWidth)
        .nodePadding(nodePadding)
        .size([graphWidth, graphHeight]),
    [graphHeight, graphWidth, linkDomain, nodePadding, nodeWidth],
  )

  const { nodes, links } = useMemo(() => {
    // The sankey layout mutates its input, so lay out a deep copy and leave the prop untouched.
    const clonedData = structuredClone(data)
    return sankeyDiagram(clonedData)
  }, [sankeyDiagram, data])

  // First node of each layout layer, in node order; one legend label per layer.
  // Single pass with a Set instead of filter + findIndex (O(n^2)).
  const legend = useMemo(() => {
    const seenLayers = new Set<unknown>()
    return nodes.filter(node => {
      if (seenLayers.has(node.layer)) return false
      seenLayers.add(node.layer)
      return true
    })
  }, [nodes])

  // One path generator per render rather than one per link.
  const linkPath = sankeyLinkHorizontal()

  return (
    <StyledSankeyChart
      overflow="visible"
      width={areaWidth}
      height={areaHeight}
      viewBox={`0 0 ${areaWidth} ${areaHeight}`}
      className={className}
    >
      {nodes.filter(isValidNode).map(node => {
        // Nodes with outgoing links draw one segment per outgoing link (at link.y0);
        // nodes without them (sinks) draw one segment per incoming link (at link.y1).
        const isSourceNode = node.sourceLinks?.length
        const nodeLinks = (isSourceNode ? node.sourceLinks : node.targetLinks) || []
        return (
          <g key={node.name}>
            {nodeLinks.filter(isValidLink).map((link, i) => {
              const linkY = isSourceNode ? "y0" : "y1"
              return (
                <rect
                  key={`${node.name}-${i}`}
                  x={node.x0}
                  y={link[linkY] - link.width / 2}
                  width={node.x1 - node.x0}
                  height={link.width}
                  fill={String(nodeColor(link.score))}
                />
              )
            })}
            <text x={node.x1 + nodeLabelIndent} y={(node.y1 + node.y0) / 2} fontSize={nodeLabelFontSize}>
              {node.name}
            </text>
          </g>
        )
      })}
      <g fill="none">
        {links
          .filter(isValidLink)
          // filter() returned a new array, so sorting here does not mutate the layout's links.
          .sort((a, b) => b.width - a.width)
          .map(link => {
            const { source, target, score } = link
            return (
              <g key={`${source.name}-${target.name}-${score}`}>
                <path
                  key={`${source.name}-${target.name}`}
                  d={linkPath(link) || undefined}
                  strokeWidth={Math.max(1, link.width)}
                  stroke={String(linkColor(link.score))}
                  width={link.width}
                />
              </g>
            )
          })}
      </g>
      <g transform={`translate(0, -${nodeLegendHeight})`}>
        {legend.filter(isValidNode).map(l => {
          // Keyed by `layer` (unique after deduplication). `depth` can repeat across
          // layers under d3-sankey's default justify alignment, which would duplicate keys.
          // NOTE(review): y={480} is a hard-coded offset independent of graphHeight/areaHeight — confirm intent.
          return (
            <StyledLabelText key={l.layer} y={480} x={l.x0 + (l.x1 - l.x0) / 2} dy="1em" textAnchor="middle">
              {l.group}
            </StyledLabelText>
          )
        })}
      </g>
    </StyledSankeyChart>
  )
}
