import React, {useEffect, useRef} from "react";
import * as d3 from "d3";
import {Axis} from "d3";
import {getColorByCat} from "../../style/colors";
import {getParetoFunction} from "./functions";
import './BarChart.scss';
import {appendLine} from "../../utils/d3-utils";
import useTheme from "@material-ui/core/styles/useTheme";

/**
 * A single bar of the chart.
 */
export type BarDataPoint = {
    /** Numeric value; determines the length of the bar along the value axis. */
    value: number;
    /** Category key; used as the band-scale domain entry and for the bar colour lookup. */
    category: string;
    /** Optional text for the category axis tick; falls back to `category` when not set. */
    categoryLabel?: string;
    /** Optional text drawn next to/inside the bar; falls back to the stringified `value` when not set. */
    valueLabel?: string;
}
/** The full data set of a chart: one entry per bar, drawn in array order. */
export type BarData = BarDataPoint[];

/**
 * Properties accepted by the `BarChart` component.
 */
type Props = {
    /** The bars to draw. */
    data: BarData;
    /** Height of the SVG viewBox. */
    height: number;
    /** Width of the SVG viewBox; the component falls back to 255 when undefined. */
    width?: number;
    /** Draw bars vertically. This is the default when neither orientation is given. */
    vertical?: boolean;
    /** Draw bars horizontally. The component throws when combined with `vertical`. */
    horizontal?: boolean;
    /** Extra margin reserved for the category labels (bottom when vertical, left when horizontal). */
    labelMargin?: number;
    /** Width reserved for a value-axis tick label; defaults to 32 for percentage axes and 20 otherwise. */
    valueLabelPxW?: number;
    /**
     * When set, two curves are drawn on top of the bars: one through each data point's
     * `plot_value` field and one produced by `getParetoFunction(80, max)`.
     * NOTE(review): the function itself is never called by the component, it only acts as a flag.
     */
    curve?: (v: number) => number;
    /** Fixes the value-axis domain to [0, 1] and formats its ticks as percentages. */
    valueAxisPercentage?: boolean;
    /** Invoked with the bar's data point when a bar group is clicked. */
    onClicked?: (d: BarDataPoint) => void;
    /** Value at which a highlight line is drawn across the chart. */
    redLine?: number;
    /** Rotation (in degrees) applied to the category tick labels; only used in the vertical orientation. */
    axisLabelAngle?: number;
    /** Title drawn along the value axis; extra margin is reserved when set. */
    valueAxisTitle?: string;
    /** Title drawn along the category axis; extra margin is reserved when set. */
    categoryAxisTitle?: string;
};

/**
 * Base margins (in viewBox units) around the plotting area. The component copies this
 * object and adds room for tick labels, `labelMargin` and axis titles on top of it.
 */
const MARGIN = {
    // Previously used values, kept for reference:
    // left: 34,
    // right: 10,
    // top: 6,
    // bottom: 20,
    left: 1,
    right: 1,
    top: 1,
    bottom: 1,
};
// Additional margin added together with `axisTitleHeight` when an axis title is present.
const axisTitleMargin = 5;
// Height reserved for an axis title; also used to position the title text.
const axisTitleHeight = 20;

// Taken from: https://observablehq.com/@d3/horizontal-bar-chart
// D3+React: https://medium.com/@jeffbutsch/using-d3-in-react-with-hooks-4a6c61f1d102
/**
 * Bar chart drawn with D3 inside a React-managed `<svg>` element.
 *
 * The component computes the margins and the size of the plotting area during render and
 * (re)draws the whole chart in an effect whenever any input that influences the drawing
 * changes. Bars are vertical unless `horizontal` is set.
 *
 * @throws Error when both `horizontal` and `vertical` are set.
 */
export const BarChart: React.FC<Props> = (
    {
        data,
        height,
        width,
        vertical,
        horizontal,
        labelMargin,
        valueLabelPxW,
        curve,
        valueAxisPercentage,
        onClicked,
        redLine,
        axisLabelAngle,
        valueAxisTitle,
        categoryAxisTitle,
    }
) => {
    const theme = useTheme();
    if (horizontal && vertical) {
        throw new Error("BarChart: the 'horizontal' and 'vertical' props are mutually exclusive");
    }
    if (!horizontal && !vertical) {
        vertical = true;
    }
    if (width === undefined) {
        width = 255;
    }

    // --- Margins -------------------------------------------------------------------------
    const margin = {...MARGIN};
    if (labelMargin) {
        if (vertical) {
            margin.bottom += labelMargin;
        } else {
            margin.left += labelMargin;
        }
    }
    // Add some margin to ensure the value labels fit on the graph
    if (valueLabelPxW === undefined) {
        valueLabelPxW = valueAxisPercentage ? 32 : 20;
    }
    const valueLabelPxH = 16;
    const tickSize = 4;
    if (vertical) {
        margin.left += valueLabelPxW + tickSize;
        margin.top += valueLabelPxH / 2;
    } else {
        margin.bottom += valueLabelPxH + tickSize;
        margin.right += valueLabelPxW / 2;
    }
    if (categoryAxisTitle) {
        if (vertical) {
            margin.bottom += axisTitleHeight + axisTitleMargin;
        } else {
            // noinspection JSSuspiciousNameCombination
            margin.left += axisTitleHeight + axisTitleMargin;
        }
    }
    if (valueAxisTitle) {
        if (vertical) {
            // noinspection JSSuspiciousNameCombination
            margin.left += axisTitleHeight + axisTitleMargin;
        } else {
            // Special case, when the title is just below the value axis, then some extra margin is needed
            margin.bottom += axisTitleHeight + 4 + axisTitleMargin;
        }
    }
    const graphWidth = width - margin.left - margin.right;
    const graphHeight = height - margin.top - margin.bottom;
    // Primitive copies, so they can be listed as effect dependencies (the `margin` object
    // itself is recreated on every render).
    const marginLeft = margin.left;
    const marginTop = margin.top;
    const marginBottom = margin.bottom;

    const svgRef = useRef<SVGSVGElement>(null);

    // The click callback is read through a ref: the D3 handler always calls the latest
    // callback, without forcing a full redraw when the parent passes a new function instance.
    const onClickedRef = useRef(onClicked);
    useEffect(() => {
        onClickedRef.current = onClicked;
    });
    const clickable = Boolean(onClicked);
    const showCurve = Boolean(curve);
    const highlightColor = theme.palette.warning.light;

    // --- Upper bound of the value axis ---------------------------------------------------
    let max: number;
    if (valueAxisPercentage) {
        max = 1;
    } else {
        // d3.max returns undefined for an empty data set; fall back to 0 to avoid a NaN domain.
        const maxValue = d3.max(data, d => d.value);
        max = maxValue === undefined ? 0 : maxValue;
        // `plot_value` is an optional, untyped extra field that is drawn as a curve (see below).
        const max2 = d3.max(data, d => (d as any).plot_value) as number;
        if (max2 !== undefined) {
            max = Math.max(max, max2);
        }
    }

    useEffect(() => {
        if (!data || !svgRef.current) {
            console.log('BarChart.render: REJECT', svgRef.current, data);
            return;
        }
        console.log('BarChart.render: ACCEPT', svgRef.current, data);

        // Fraction of each band that is left empty (half on each side of the bar).
        const BAR_PADDING = 0.3;
        const addHover = true;

        const svg = d3.select(svgRef.current as SVGElement);
        svg.html(''); // clear the previous drawing

        const root = svg
            .append("g")
            .attr("transform", "translate(" + marginLeft + "," + marginTop + ")");

        // --- Scales ----------------------------------------------------------------------
        const catRange = vertical ? [0, graphWidth] : [0, graphHeight]
        const catDomainLabels = data.reduce<{ [name: string]: string }>((dict, d) => {
            dict[d.category] = d.categoryLabel ? d.categoryLabel : d.category;
            return dict;
        }, {});
        const catAxis = d3.scaleBand()
            .domain(data.map(d => d.category))
            .range(catRange)
        const BAR_SPACING = BAR_PADDING / 2 * catAxis.bandwidth()

        // For vertical bars the value axis grows upwards, so its range is inverted.
        const valRange = vertical ? [graphHeight, 0] : [0, graphWidth];
        const valDomain = [0, max];
        const valueAxis = d3.scaleLinear()
            .domain(valDomain)
            .range(valRange)

        const originXY = [catRange[0], valRange[0]];
        const catEndpointXY = vertical ? [catRange[1], valRange[0]] : [valRange[0], catRange[1]];

        // --- Bars ------------------------------------------------------------------------
        const barGroups = root
            .append('g')
            .classed('data-bars', true)
            .selectAll('g.bar-wrapper')
            .data(data)
            .join('g')
            .classed('bar-wrapper', true)

        if (addHover) {
            // An overlay covering the whole band, so hovering/clicking also works next to the bar.
            if (vertical) {
                barGroups.append('rect')
                    .classed('hover-overlay', true)
                    .attr('x', d => (catAxis(d.category) as number))
                    .attr('y', 0)
                    .attr('width', catAxis.bandwidth())
                    .attr('height', graphHeight)
            } else {
                barGroups.append('rect')
                    .classed('hover-overlay', true)
                    .attr('x', valRange[0])
                    .attr('y', d => (catAxis(d.category) as number))
                    .attr('width', valRange[1] - valRange[0])
                    .attr('height', catAxis.bandwidth())
            }
        }

        const bars = barGroups
            .append('rect')
            .classed('bar', true)
            .attr('fill', d => getColorByCat(d.category))
        if (vertical) {
            bars.attr("x", d => (catAxis(d.category) as number) + BAR_SPACING)
                .attr("y", d => (valueAxis(d.value) as number))
                .attr("width", catAxis.bandwidth() * (1 - BAR_PADDING))
                .attr("height", d => graphHeight - valueAxis(d.value))
        } else {
            bars.attr("y", d => (catAxis(d.category) as number) + BAR_SPACING)
                .attr("width", d => valueAxis(d.value))
                .attr("height", catAxis.bandwidth() * (1 - BAR_PADDING))
        }

        if (clickable) {
            barGroups
                .classed('clickable', true)
                .on('click', function () {
                    const clicked = d3.select(this).datum() as BarDataPoint;
                    console.log('barGroups.click', clicked);
                    const handler = onClickedRef.current;
                    if (handler) {
                        handler(clicked);
                    }
                })
        }

        // --- Category axis ---------------------------------------------------------------
        {
            const catAxisWrapper = root.append('g')
                .classed('cat-axis axis', true)
            const catAxisGroup = catAxisWrapper.append('g')
            const catAxisFormat = (ax: Axis<string>) => ax.tickFormat(d_name => catDomainLabels[d_name])
            if (vertical) {
                catAxisGroup.attr("transform", "translate(0," + graphHeight + ")")
                catAxisGroup.call(catAxisFormat(d3.axisBottom(catAxis)))
                if (axisLabelAngle) {
                    catAxisGroup.selectAll('text')
                        .style("text-anchor", "end")
                        .attr("dx", "-.8em")
                        .attr("dy", axisLabelAngle === -90 ? '-.6em' : '.15em')
                        .attr("transform", `rotate(${axisLabelAngle})`);
                }
            } else {
                catAxisGroup.call(catAxisFormat(d3.axisLeft(catAxis)))
            }
            catAxisGroup
                .call(g => g.selectAll('text').classed('MuiTypography-body1', true))
                .call(g => g.select('.domain').remove())

            // Replace the removed D3 domain path by a plain line (half a pixel offset for crisp rendering)
            if (vertical) {
                catAxisWrapper.append('line')
                    .attr('x1', originXY[0])
                    .attr('y1', originXY[1] + 0.5)
                    .attr('x2', catEndpointXY[0])
                    .attr('y2', catEndpointXY[1] + 0.5)
            } else {
                catAxisWrapper.append('line')
                    .attr('x1', originXY[0] + 0.5)
                    .attr('y1', originXY[1])
                    .attr('x2', catEndpointXY[0] + 0.5)
                    .attr('y2', catEndpointXY[1])
            }

            // Add the title
            if (categoryAxisTitle) {
                if (vertical) {
                    catAxisWrapper.append('text')
                        .style('text-anchor', 'middle')
                        .text(`${categoryAxisTitle}`)
                        .attr("x", (catRange[1] - catRange[0]) / 2)
                        .attr("y", valRange[0] + marginBottom - axisTitleHeight)
                        .attr("dy", "1em")
                } else {
                    catAxisWrapper.append('g')
                        .attr("transform", `translate(${-marginLeft},${(catRange[1] - catRange[0]) / 2})`)
                        .append('text')
                        .attr("transform", `rotate(-90)`)
                        .attr("dy", "1em") // Beware, this is done before rotation!
                        .style('text-anchor', 'middle')
                        .text(`${categoryAxisTitle}`)
                }
            }
        }

        // --- Value axis ------------------------------------------------------------------
        {
            const valAxisWrapper = root.append('g')
                .classed('val-axis axis', true)
            const valAxisGroup = valAxisWrapper.append('g')
            const valueAxisFormat = (ax: Axis<d3.NumberValue>) =>
                valueAxisPercentage
                    ? ax.ticks(5).tickFormat(v => `${(v as number) * 100}%`)
                    : ax.ticks(5).tickFormat(v => d3.format("~s")(v))
            if (vertical) {
                valAxisGroup.call(valueAxisFormat(d3.axisLeft(valueAxis)))
            } else {
                valAxisGroup.attr("transform", "translate(0," + graphHeight + ")")
                valAxisGroup.call(valueAxisFormat(d3.axisBottom(valueAxis)))
            }
            valAxisGroup
                .call(g => g.selectAll('text').classed('MuiTypography-body1', true))

            // Add the title
            if (valueAxisTitle) {
                if (vertical) {
                    valAxisWrapper.append('g')
                        .attr("transform", `translate(${-marginLeft},${(valRange[0] - valRange[1]) / 2})`)
                        .append('text')
                        .attr("transform", `rotate(-90)`)
                        .attr("dy", "1em") // Beware, this is done before rotation!
                        .style('text-anchor', 'middle')
                        .text(`${valueAxisTitle}`)
                } else {
                    valAxisWrapper.append('text')
                        .style('text-anchor', 'middle')
                        .text(`${valueAxisTitle}`)
                        .attr("x", (valRange[1] - valRange[0]) / 2)
                        .attr("y", graphHeight + marginBottom - axisTitleHeight)
                        .attr("dy", "1em")
                }
            }
        }

        // --- Value labels ----------------------------------------------------------------
        const FONT_SIZE = 12;
        const MAX_WORD_SIZE = FONT_SIZE * 10
        const SPACING = 4;
        const yOffset = 1;

        const showValueLabels = true;
        if (showValueLabels) {
            const valueLabels = barGroups.append('text')
                .classed('value-label', true)
                .text(d => d.valueLabel ? d.valueLabel : `${d.value}`)

            if (vertical) {
                valueLabels
                    .attr('dominant-baseline', 'text-bottom')
                    .attr('text-anchor', 'middle')
                    .attr("x", d => (catAxis(d.category) as number) + catAxis.bandwidth() / 2)
                    .attr("y", d => valueAxis(d.value) - SPACING)
                    .call(text =>
                        // Bars that (almost) reach the top get their label inside the bar
                        text.filter(d => valueAxis(d.value) < FONT_SIZE * 2)
                            .attr("y", d => valueAxis(d.value) + SPACING)
                            .attr('dominant-baseline', 'hanging')
                            .attr('fill', 'white')
                    )
            } else {
                valueLabels
                    .attr('dominant-baseline', 'middle')
                    .attr("x", d => valueAxis(d.value) + SPACING)
                    .attr("y", d => (catAxis(d.category) as number) + catAxis.bandwidth() / 2 + yOffset)
                    .call(text => {
                        // Bars that end close to the right edge get their label inside the bar,
                        // kept to the left of the highlight line when there is one.
                        text.filter(d => valRange[1] - valueAxis(d.value) < MAX_WORD_SIZE)
                            .attr('fill', 'white')
                            .attr("text-anchor", 'end')
                            .attr("x", d => Math.min(
                                valueAxis(d.value) - SPACING,
                                redLine ? valueAxis(redLine) - SPACING : Number.MAX_VALUE,
                            ))
                    })
            }
        }

        // --- Curves ----------------------------------------------------------------------
        type CurveType = {
            name: string;
            cut: number;
            value: number;
        }
        if (showCurve) {
            // One generator serves both curves: points are spread over the category axis by
            // their `cut` fraction and positioned on the value axis by their `value`.
            const curveLine = vertical
                ? d3.line<CurveType>(
                    d => d.cut * graphWidth + catAxis.bandwidth() / 2,
                    d => valueAxis(d.value) as number,
                )
                : d3.line<CurveType>(
                    d => valueAxis(d.value) as number,
                    d => d.cut * graphHeight + catAxis.bandwidth() / 2,
                )

            // Curve through the `plot_value` of every data point
            const curveData = data.map((d, i) => ({
                name: d.category,
                cut: i / data.length,
                value: (d as any).plot_value,
            }));
            console.log('curveData', curveData);
            root.append("path")
                .classed('curve', true)
                .datum(curveData)
                .attr("fill", "none")
                .attr("stroke", "steelblue")
                .attr("stroke-width", 1.5)
                .attr("d", curveLine)

            // Reference curve produced by the Pareto helper
            const p = getParetoFunction(80, max);
            const idealCurve = data.map((d, i) => ({
                name: d.category,
                cut: i / data.length,
                value: p(i / data.length),
            }));
            console.log('idealCurve', idealCurve);
            root.append("path")
                .classed('curve', true)
                .datum(idealCurve)
                .attr("fill", "none")
                .attr("stroke", "purple")
                .attr("stroke-width", 1.5)
                .attr("d", curveLine)
        }

        // --- Highlight line --------------------------------------------------------------
        if (redLine !== undefined) {
            const lineValuePx = valueAxis(redLine) as number
            const line = vertical
                ? appendLine(root, 0, lineValuePx, graphWidth, lineValuePx, 'highlight-line')
                : appendLine(root, lineValuePx, 0, lineValuePx, graphHeight, 'highlight-line')
            line.attr('stroke', highlightColor)
        }

    }, [
        // Everything the drawing reads, so the chart never shows stale props.
        data,
        vertical,
        graphWidth,
        graphHeight,
        marginLeft,
        marginTop,
        marginBottom,
        max,
        valueAxisPercentage,
        clickable,
        redLine,
        axisLabelAngle,
        valueAxisTitle,
        categoryAxisTitle,
        showCurve,
        highlightColor,
    ])

    return <svg
        className="bar-chart"
        ref={svgRef}
        viewBox={`0 0 ${width} ${height}`}
        style={{width: '100%', height: 'auto'}}/>;
};
