import { useState, useEffect, useRef, useContext } from 'react'
import ForceGraph3D from 'react-force-graph-3d'
import { SNAGraph } from './Network2dPlusRenderer'
import { useTheme } from '@mui/material'
//@ts-ignore
import { UnrealBloomPass } from 'three/examples/jsm/postprocessing/UnrealBloomPass.js'
//@ts-ignore
import { Vector2 } from 'three'
import { RootContext } from 'contexts/RootContext'
import usePreviousValue from 'hooks/usePreviousValue'
import { GLTFExporter } from 'three/examples/jsm/exporters/GLTFExporter'
import * as THREE from 'three'
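/**
 * TODO:
 *  1.  'ForceGraph3D' receives its height and width from the 'GraphShow' root element
 *      ('#graph-show-wrapper'). Those values are read only when this component renders, and
 *      resizing the window after the canvas is painted does not trigger a re-render, which
 *      causes visual bugs and overflows.
 *      Add a resize handler (e.g. a ResizeObserver on the wrapper) that updates ForceGraph3D's size.
 */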

/**
 * Bloom (glow) post-processing pass shared by every Network3dRenderer instance.
 * It is created once at module load, so its resolution is the window size at import time.
 */
const bloomResolution = new Vector2(window.innerWidth, window.innerHeight)
const bloomPass = new UnrealBloomPass(
    bloomResolution,
    3, // strength
    0.5, // radius
    0.35 // threshold
)

/**
 * Triggers a browser download of `blob` under `fileName` via a temporary anchor element.
 * The object URL is revoked afterwards so repeated exports do not leak blob memory; the
 * revocation is delayed to give the browser time to start the download first.
 */
const downloadBlob = (blob: Blob, fileName: string) => {
    const url = URL.createObjectURL(blob)
    const link = document.createElement('a')
    link.download = fileName
    link.href = url
    document.body.appendChild(link)
    link.click()
    document.body.removeChild(link)
    setTimeout(() => URL.revokeObjectURL(url), 10000)
}

/** Local axis a ConeGeometry is built along (its tip points toward +Y before any rotation). */
const CONE_AXIS = new THREE.Vector3(0, 1, 0)

/** Rotates `mesh` so its +Y axis (the cone tip) points from its current position toward `target`. */
const aimAt = (mesh: THREE.Object3D, target: { x: number; y: number; z: number }) => {
    const direction = new THREE.Vector3(target.x, target.y, target.z).sub(mesh.position).normalize()
    mesh.quaternion.setFromUnitVectors(CONE_AXIS, direction)
}

/**
 * 3D force-directed rendering of an SNA graph using `react-force-graph-3d`.
 *
 * - Deep-copies `nodes` / `edges` into local graph data before handing them to ForceGraph3D.
 * - Adds the shared bloom pass in the dark theme and removes it in the light theme.
 * - Responds to `action` commands: `'fit'` zooms to fit; `'export-image'` downloads a PNG
 *   snapshot of the canvas plus a glTF (JSON) export of the scene.
 * - Each link carries two cone arrows: one near the target pointing at the target, and one
 *   near the source pointing at the source.
 */
const Network3dRenderer: React.FC<SNAGraph> = ({
    action,
    nodes,
    onNodeHover,
    edges,
    setSelectedItem: setSelectedNode,
    selectedNode,
}: SNAGraph) => {
    const { theme } = useContext(RootContext)
    const lastActionId = usePreviousValue(action?.id || 0)

    const threedGraphRef = useRef<any>()
    // Scopes the canvas lookup to this instance instead of a document-wide '#threed-wrapper' query.
    const wrapperRef = useRef<HTMLDivElement>(null)
    const [graphData, setGraphData] = useState({ links: [], nodes: [] })

    const muiTheme = useTheme()

    // Toggle the shared bloom pass with the app theme. `threedGraphRef.current` stays in the deps
    // so the effect re-runs once the graph instance becomes available.
    useEffect(() => {
        const composer = threedGraphRef.current?.postProcessingComposer()
        if (!composer) return // graph not ready yet; nothing to toggle

        // Membership check (rather than counting passes) keeps repeated runs idempotent.
        const hasBloom = composer.passes.includes(bloomPass)
        if (theme === 'dark' && !hasBloom) {
            composer.addPass(bloomPass)
        } else if (theme === 'light' && hasBloom) {
            composer.removePass(bloomPass)
        }
    }, [threedGraphRef.current, theme])

    // Hand ForceGraph3D deep copies of the props; presumably so the library's in-place writes
    // (e.g. the internal `__curve` field read in linkPositionUpdate) never touch caller-owned data.
    useEffect(() => {
        if (nodes == null) {
            setGraphData({ links: [], nodes: [] })
            return
        }

        setGraphData({
            links: edges != null ? JSON.parse(JSON.stringify(edges)) : [],
            nodes: JSON.parse(JSON.stringify(nodes)),
        })
    }, [edges, nodes])

    /**
     * Exports the current view: downloads the WebGL canvas as `img.png` and the scene graph
     * as glTF JSON (`my3dgraph.gltf`). No-op until the graph instance is attached.
     */
    const saveImagePNG = () => {
        if (!threedGraphRef.current) return

        const canvas = wrapperRef.current?.querySelector('canvas')
        canvas?.toBlob((blob) => {
            // toBlob yields null when the canvas cannot be encoded (e.g. zero size).
            if (blob) downloadBlob(blob, 'img.png')
        })

        const exporter = new GLTFExporter()
        exporter.parse(
            threedGraphRef.current.scene(),
            (gltf) => {
                // JSON (.gltf) output; passing `{ binary: true }` as the options argument
                // would produce a binary .glb instead.
                const output = JSON.stringify(gltf, null, 2)
                downloadBlob(new Blob([output], { type: 'text/plain' }), 'my3dgraph.gltf')
            },
            (error) => {
                console.log(error)
            }
        )
    }

    // Action listener: each new action id is handled exactly once.
    useEffect(() => {
        if (action && action.id > 0 && lastActionId !== action.id) {
            switch (action.content.type) {
                case 'fit':
                    threedGraphRef?.current?.zoomToFit()
                    break
                case 'export-image':
                    saveImagePNG()
                    break
            }
        }
    }, [action?.id])

    /** Builds one red hexagonal cone (radius 1.5, height 4) used as a link arrowhead. */
    const createArrow = () => {
        const arrowLength = 4
        const arrowWidth = 1.5
        const coneGeometry = new THREE.ConeGeometry(arrowWidth, arrowLength, 6)
        const coneMaterial = new THREE.MeshBasicMaterial({ color: 0xff0000 })

        const arrow = new THREE.Mesh(coneGeometry, coneMaterial)
        arrow.rotation.x = Math.PI / 2

        return arrow
    }

    return (
        <div id="threed-wrapper" ref={wrapperRef}>
            <ForceGraph3D
                ref={threedGraphRef}
                graphData={graphData}
                // 'graph-show-wrapper' is the ID of 'GraphShow' component's root element.
                height={document.getElementById('graph-show-wrapper')?.clientHeight}
                width={document.getElementById('graph-show-wrapper')?.clientWidth}
                // @Theme conditional
                backgroundColor={
                    muiTheme.palette.mode === 'light' ? muiTheme.palette.common.bg_1 : muiTheme.palette.common.bg_3
                }
                nodeResolution={10}
                warmupTicks={0}
                cooldownTicks={0}
                showNavInfo={false}
                linkWidth={(edge: any) => {
                    return edge.lineStyle?.width || 1
                }}
                linkColor={(edge: any) => {
                    return edge.lineStyle?.color || '#c4c4c4'
                }}
                nodeAutoColorBy="id"
                nodeColor={(node: any) => {
                    if (selectedNode === node.id) {
                        return node.select?.itemStyle?.color || '#f00'
                    }
                    return node.itemStyle?.color || '#c4c4c4'
                }}
                nodeVal={(node: any) => {
                    return (node.symbolSize || 4) * 0.2
                }}
                onNodeClick={(node) => {
                    setSelectedNode(node.id as any)
                }}
                onNodeHover={(node) => {
                    onNodeHover(node === null ? null : (node.id as any))
                }}
                linkThreeObjectExtend={true}
                // Two cones per link: children[0] sits near the target, children[1] near the source.
                linkThreeObject={() => {
                    const group = new THREE.Group()
                    group.add(createArrow()) // near-target arrow
                    group.add(createArrow()) // near-source arrow
                    return group
                }}
                linkPositionUpdate={(obj, { start, end }, link) => {
                    const { __curve } = link
                    const [towardTarget, towardSource] = obj.children

                    if (__curve) {
                        // Curved link: place the arrows on the curve at 90% and 10% of its length.
                        towardTarget.position.copy(__curve.getPointAt(0.9))
                        towardSource.position.copy(__curve.getPointAt(0.1))
                    } else {
                        // Straight-line fallback: offset each arrow along the unit start->end direction.
                        const offset = 5
                        const dx = end.x - start.x
                        const dy = end.y - start.y
                        const dz = end.z - start.z
                        // `|| 1` keeps coincident endpoints from producing NaN positions.
                        const length = Math.hypot(dx, dy, dz) || 1
                        const nx = dx / length
                        const ny = dy / length
                        const nz = dz / length

                        towardTarget.position.set(end.x - nx * offset, end.y - ny * offset, end.z - nz * offset)
                        towardSource.position.set(start.x + nx * offset, start.y + ny * offset, start.z + nz * offset)
                    }

                    aimAt(towardTarget, end)
                    aimAt(towardSource, start)
                }}
                linkCurvature={1}
                linkCurveRotation={0}
            />
        </div>
    )
}

export default Network3dRenderer
