import {
	IUniform,
	LinearFilter,
	MathUtils,
	Mesh,
	Object3D,
	PlaneGeometry,
	ShaderMaterial,
	Texture,
	TextureLoader,
	Vector2,
	WebGLRenderTarget
} from 'three'
import bloomVertexShader from './bloom.vert'
import bloomFragmentShader from './bloom.frag'
import { Easing, Group, Tween } from '@tweenjs/tween.js'
import styles from '../stage/Stage.module.css'
import { App } from '../../app/App'
import { Stage } from './Stage'
import { clamp } from '../../utils/clamp'

/**
 * Mirrors the `<img>` inside `node` as a shader-driven plane in a three.js scene.
 *
 * The DOM image is made fully transparent (opacity 0). A `Mesh` using the
 * `bloom.vert` / `bloom.frag` shaders is positioned and scaled over it instead.
 * The class:
 * - loads the image's `currentSrc` into a `Texture`, reloading whenever the `<img>` fires `load`;
 * - animates the `bloomAmount` / `waveAmount` / `alpha` uniforms in when the node starts
 *   intersecting (via `app`'s shared IntersectionObserver) and hides the mesh when it stops;
 * - tracks the pointer in image-local UV space and raises the `strength` uniform when the
 *   pointer moves over the image.
 *
 * What the uniforms do visually is defined by the shaders, which are not visible here.
 */
export class BloomImage {
	private readonly mesh: Mesh<PlaneGeometry, ShaderMaterial>
	private readonly image: HTMLImageElement
	private readonly uniforms: { [p: string]: IUniform }
	private readonly lastPointer = new Vector2()
	private readonly object: Object3D = new Object3D()
	/** Per-instance tween group; its tweens only advance when `update()` runs (i.e. while in view). */
	private readonly group = new Group()

	/** Pointer in image-local UV space: x in [0, 1] left→right, y in [0, 1] bottom→top when over the image. */
	private pointer = new Vector2()
	/** Current show/hide tween; both animate the same three uniforms. */
	private tween: Tween<{ bloomAmount: number; waveAmount: number; alpha: number }> | null = null
	private spawnTween: Tween<{ strength: number }> | null = null
	/** Document-space top of the image (viewport top + scrollY at the last `resize`), in CSS px. */
	private top = 0
	private left = 0
	private width = 0
	private height = 0
	private targetOffset = 0
	private offset = 0
	private currentSrc = ''
	/** Incremented per texture load so an older, slower load cannot overwrite a newer one. */
	private loadId = 0
	/** Set by `dispose()`; in-flight texture loads check it and discard their result. */
	private disposed = false

	public renderTarget: WebGLRenderTarget
	public texture: Texture | null = null
	public inView = false

	/**
	 * @param node Element containing the `<img>` to mirror. Carrying the `styles.Wobble`
	 *   class sets the `wobble` uniform to 1.
	 * @param parent Scene object the plane is attached to (removed again in `dispose`).
	 * @param stage Provides `loadingManager`, `width` and `height`.
	 * @param app Provides the shared IntersectionObserver, the `'intersect'` event and `scrollY`.
	 * @param onLoadCb Invoked each time a new texture has been applied.
	 * @throws TypeError if `node` contains no `<img>` element.
	 */
	constructor(
		private readonly node: HTMLElement,
		private readonly parent: Object3D,
		private readonly stage: Stage,
		private readonly app: App,
		private onLoadCb: () => void
	) {
		// Bound once so the same references can be passed to add/removeListener.
		this.onIntersect = this.onIntersect.bind(this)
		this.onImageLoad = this.onImageLoad.bind(this)

		const image = node.querySelector('img')
		if (!image) {
			throw new TypeError('BloomImage: expected an <img> element inside the given node')
		}
		this.image = image

		const wobble = node.classList.contains(styles.Wobble)
		this.currentSrc = this.image.currentSrc || ''
		// The WebGL plane replaces the DOM image visually; the <img> stays in the layout for sizing.
		this.image.style.setProperty('opacity', '0')

		this.renderTarget = new WebGLRenderTarget(256, 256, {
			minFilter: LinearFilter,
			generateMipmaps: false
		})

		this.uniforms = {
			offset: { value: 0 },
			time: { value: 0 },
			random: { value: MathUtils.randFloat(-1, 1) },
			bloomAmount: { value: 1 },
			strength: { value: 0 },
			alpha: { value: 0 },
			waveAmount: { value: 1 },
			pointer: { value: new Vector2() },
			map: { value: new Texture() },
			bloomMap: { value: new Texture() },
			resolution: { value: new Vector2() },
			wobble: { value: wobble ? 1 : 0 }
		}

		const material = new ShaderMaterial({
			transparent: true,
			depthTest: false,
			vertexShader: bloomVertexShader,
			fragmentShader: bloomFragmentShader,
			uniforms: this.uniforms
		})

		this.mesh = new Mesh(new PlaneGeometry(1, 1, 50, 50), material)
		this.mesh.visible = false
		this.mesh.frustumCulled = false

		this.object.add(this.mesh)
		parent.add(this.object)

		this.app.on('intersect', this.onIntersect)
		this.app.intersectionObserver.observe(this.node)

		this.image.addEventListener('load', this.onImageLoad)
		// An already-decoded (e.g. cached) image will not fire `load` again.
		if (this.image.complete) {
			void this.onImageLoad()
		}
	}

	/**
	 * Starts the reveal tween. Over 2.5–3.5 s it takes `bloomAmount` 1→0, `waveAmount`
	 * (current)→0 and `alpha` 0→1, and makes the mesh visible on the first tick.
	 */
	show(): void {
		// Stop (not merely drop) the previous tween so that a pending hide() promise settles.
		this.tween?.stop()
		this.group.removeAll()

		this.tween = new Tween({ bloomAmount: 1, waveAmount: this.uniforms.waveAmount.value, alpha: 0 }, this.group)
			.to({ bloomAmount: 0, waveAmount: 0, alpha: 1 }, 2500 + Math.random() * 1000)
			.easing(Easing.Quintic.Out)
			.onStart(() => (this.mesh.visible = true))
			.onUpdate(({ bloomAmount, waveAmount, alpha }) => {
				this.uniforms.bloomAmount.value = bloomAmount
				this.uniforms.waveAmount.value = waveAmount
				this.uniforms.alpha.value = alpha
			})
			.start()
	}

	/**
	 * Animates the image out over 750 ms: `bloomAmount`→1, `waveAmount`→0, `alpha`→0.
	 *
	 * Resolves immediately when not in view. Otherwise resolves when the tween
	 * completes or is stopped: by `onIntersect`, a later `show()` / `hide()`, or `dispose()`.
	 * The promise therefore always settles.
	 */
	async hide(): Promise<void> {
		// stop() fires onStop on an in-flight hide tween, settling the earlier promise.
		this.tween?.stop()
		this.group.removeAll()
		if (!this.inView) return

		return new Promise((resolve) => {
			this.tween = new Tween(
				{
					waveAmount: this.uniforms.waveAmount.value,
					bloomAmount: this.uniforms.bloomAmount.value,
					alpha: this.uniforms.alpha.value
				},
				this.group
			)
				.to({ bloomAmount: 1, waveAmount: 0, alpha: 0 }, 750)
				.easing(Easing.Sinusoidal.In)
				.onUpdate(({ bloomAmount, waveAmount, alpha }) => {
					this.uniforms.bloomAmount.value = bloomAmount
					this.uniforms.waveAmount.value = waveAmount
					this.uniforms.alpha.value = alpha
				})
				.onComplete(() => resolve())
				.onStop(() => resolve())
				.start()
		})
	}

	/** Tweens `strength` from its current value up to 1 over 250 ms; `update()` then decays it each frame. */
	spawn(): void {
		this.spawnTween?.stop()

		this.spawnTween = new Tween({ strength: this.uniforms.strength.value }, this.group)
			.to({ strength: 1 }, 250)
			.easing(Easing.Sinusoidal.Out)
			.onUpdate(({ strength }) => {
				this.uniforms.strength.value = strength
			})
			.start()
	}

	/**
	 * Handler for `app`'s `'intersect'` event. Entries for other elements are ignored.
	 * For this node it stops the running tween, records `inView`, and either starts
	 * `show()` or hides the mesh.
	 */
	onIntersect(entries: IntersectionObserverEntry[]): void {
		entries.forEach((entry) => {
			if (entry.target === this.node) {
				this.tween?.stop()

				this.inView = entry.isIntersecting

				if (this.inView) {
					this.show()
				} else {
					this.mesh.visible = false
				}
			}
		})
	}

	/**
	 * Loads the `<img>`'s `currentSrc` into a texture and applies it to the `map` uniform.
	 * It also resizes `renderTarget` and the `resolution` uniform to the image's natural
	 * size (256 fallback). Does nothing when `currentSrc` is empty.
	 *
	 * Overlapping calls are safe: only the most recently started load is applied.
	 * Superseded or post-dispose results are disposed immediately.
	 * Load errors from `TextureLoader.loadAsync` propagate to the caller.
	 */
	async onImageLoad(): Promise<void> {
		this.currentSrc = this.image.currentSrc || ''
		if (this.currentSrc === '') return

		const loadId = ++this.loadId
		const loader = new TextureLoader(this.stage.loadingManager)
		const texture = await loader.loadAsync(this.currentSrc)

		// A newer load started, or we were disposed, while this one was in flight.
		// Drop the result rather than overwriting the newer texture or leaking this one.
		if (this.disposed || loadId !== this.loadId) {
			texture.dispose()
			return
		}

		texture.minFilter = LinearFilter
		texture.generateMipmaps = false

		const width = texture.image.naturalWidth || 256
		const height = texture.image.naturalHeight || 256

		this.renderTarget.setSize(width, height)
		this.uniforms.resolution.value.set(width, height)

		// Swap first, then release the old texture. Disposing it while still bound to `map`
		// would let a render during the load re-upload it, and it would never be freed again.
		const previous = this.texture
		this.texture = texture
		this.uniforms.map.value = texture
		previous?.dispose()

		this.onLoadCb()
	}

	/**
	 * Points the `bloomMap` uniform at `renderTarget.texture`.
	 * Presumably called by the owner once it has rendered into `renderTarget` — confirm.
	 */
	setBloomMap(): void {
		this.uniforms.bloomMap.value = this.renderTarget.texture
	}

	/**
	 * Converts a viewport-space pointer (CSS px, y down) into image-local UV space (y up)
	 * and writes it into `pointerOut`. Uses the box recorded by the last `resize()`.
	 */
	setPointer(pointerIn: Vector2, pointerOut: Vector2): void {
		pointerOut.set(
			(pointerIn.x - this.left) / this.width,
			1 - (pointerIn.y + this.app.scrollY - this.top) / this.height
		)
	}

	/** True when the last computed UV pointer lies within the image bounds (inclusive). */
	isPointerOver(): boolean {
		return this.pointer.x >= 0 && this.pointer.x <= 1 && this.pointer.y >= 0 && this.pointer.y <= 1
	}

	/**
	 * Re-measures the `<img>` and scales/positions `object` so the unit plane covers it.
	 *
	 * @param fov Multiplier converting a fraction of `stage.width` / `stage.height` into world
	 *   units. Presumably the visible world extent at z = 0 — confirm against Stage's camera.
	 */
	resize(fov: number): void {
		const { top, left, width, height } = this.image.getBoundingClientRect()

		// Stored in document space so scroll() and setPointer() can re-apply the live scroll.
		this.top = top + this.app.scrollY

		this.left = left
		this.width = width
		this.height = height

		const w = (width / this.stage.width) * fov
		const h = (height / this.stage.height) * fov
		const y = ((this.top + height * 0.5) / this.stage.height - 0.5) * fov // initial offset
		const x = ((this.left + width * 0.5) / this.stage.width - 0.5) * fov // initial offset

		this.object.scale.set(w, h, 1)
		// Screen y grows downward, world y grows upward.
		this.object.position.set(x, y * -1, 0)
	}

	/**
	 * Recomputes `targetOffset` from the image's viewport top. The value ramps 0→1 as the
	 * top moves from `stage.height` up to `stage.height / 2`, and is clamped outside that range.
	 */
	scroll(): void {
		const top = this.top - this.app.scrollY
		this.targetOffset = clamp((1 - top / this.stage.height) * 2, 0, 1)
	}

	/**
	 * Per-frame step; a no-op while out of view.
	 *
	 * @param time Forwarded to the tween group and the `time` uniform. NOTE(review): must be
	 *   on the clock tween.js uses for `start()` (performance.now by default) — confirm.
	 * @param pointer Viewport-space pointer in CSS px.
	 */
	update(time: number, pointer: Vector2): void {
		if (!this.inView) return

		this.group.update(time)
		this.uniforms.time.value = time

		// Per-frame easing, so the effective speed depends on frame rate.
		this.offset += (this.targetOffset - this.offset) * 0.025
		this.uniforms.offset.value = 1 - this.offset

		this.setPointer(pointer, this.pointer)
		const dist = this.lastPointer.distanceTo(this.pointer)
		if (dist > 0.01 && this.isPointerOver()) {
			this.spawn()
		}
		this.lastPointer.copy(this.pointer)
		this.uniforms.pointer.value.copy(this.pointer)

		// Per-frame exponential decay (frame-rate dependent). The show/hide tweens overwrite
		// waveAmount on frames where they are active.
		this.uniforms.strength.value *= 0.98
		this.uniforms.waveAmount.value *= 0.98
	}

	/**
	 * Releases GPU resources, listeners and observer registration, and detaches the plane.
	 * Settles any pending `hide()` promise and causes in-flight texture loads to be discarded.
	 */
	dispose(): void {
		this.disposed = true
		this.spawnTween?.stop()
		this.tween?.stop()
		this.group.removeAll()
		this.renderTarget.dispose()
		this.texture?.dispose()
		this.mesh.geometry.dispose()
		this.mesh.material.dispose()
		this.image.removeEventListener('load', this.onImageLoad)
		this.app.removeListener('intersect', this.onIntersect)
		this.app.intersectionObserver.unobserve(this.node)
		this.parent.remove(this.object)
	}
}
