
#include "/lib/Surface/BRDF.glsl"

#define RAYTRACE_SAMPLES 16 // [4 8 12 16 24 32 48 64 128 256 512]
//#define REAL_SKY_REFLECTION

#define RAYTRACE_REFINEMENT // Improves ray trace quality by refining the rays with minimal performance overhead.
#define RAYTRACE_REFINEMENT_STEPS 6 // [2 3 4 5 6 7 8 9 10 12 14 16 18 20 22 24 26 28 30 32]

// Marches a ray through the depth buffer (depthtex1) in screen space and reports whether it hit geometry.
//
// viewPos : ray origin in view space.
// viewDir : ray direction in view space (presumably normalized — confirm with callers).
// dither  : offset applied to the first step to break up banding (presumably in [0, 1)).
// steps   : maximum number of adaptive march iterations.
// rayPos  : in  - ray origin in screen space (xy presumably in [0, 1], z = depth-buffer value).
//           out - on a hit, the (refined) hit position with xy in PIXELS and z in depth-buffer units.
//                 On a miss the value is not meaningful.
// Returns true when the ray passes behind the depth buffer within the tolerance band.
// With REAL_SKY_REFLECTION defined, it also returns true when the ray leaves the depth range (z >= 1)
// over a sky pixel while isEyeInWater is 0.
bool ScreenSpaceRayTrace(in vec3 viewPos, in vec3 viewDir, in float dither, in uint steps, inout vec3 rayPos) {
    // Not `const`: both initializers are non-constant expressions, which `const` only accepts in GLSL >= 4.20.
    float maxLength = 1.0 / float(steps);
    float minLength = length(screenPixelSize);

    vec3 position = ViewToScreenSpace(viewDir * abs(viewPos.z) + viewPos);
    vec3 screenDir = normalize(position - rayPos);
    // Converts a depth difference into the distance along screenDir needed to cover it.
    float stepWeight = 1.0 / abs(screenDir.z);

    // Distance from rayPos to the nearest [0, 1] boundary along screenDir, split into `steps` pieces.
    float stepLength = minOf((step(0.0, screenDir) - rayPos) / screenDir) * rcp(float(steps));

    // From here on xy is in pixels so texelFetch can index with rayPos.xy directly.
    screenDir.xy *= screenSize;
    rayPos.xy *= screenSize;

    vec3 rayStep = screenDir * stepLength;
    rayPos += rayStep * dither + screenDir * minLength;

    // Accepted "behind the surface" band is (0, 2 * depthTolerance). From DrDesten.
    float depthTolerance = max(abs(rayStep.z) * 3.0, 0.02 / sqr(viewPos.z));

    bool hit = false;
    // The step that produced the current rayPos; refinement searches back along it.
    vec3 lastStep = rayStep;

    #ifdef REAL_SKY_REFLECTION
        bool hitSky = false;
    #endif

    for (uint i = 0u; i < steps; ++i) {
        if (clamp(rayPos.xy, vec2(0.0), screenSize) != rayPos.xy) return false;
        if (rayPos.z >= 1.0) {
            #ifdef REAL_SKY_REFLECTION
                hitSky = true;
            #endif
            break;
        }

        float depth = texelFetch(depthtex1, ivec2(rayPos.xy), 0).x;
        // Adaptive step: proportional to the distance to the sampled depth, clamped to [minLength, maxLength].
        lastStep = screenDir * clamp(abs(depth - rayPos.z) * stepWeight, minLength, maxLength);
        rayPos += lastStep;

        if (depth < rayPos.z && abs(depthTolerance - (rayPos.z - depth)) < depthTolerance) {
            hit = true;
            break;
        }
    }

    if (hit) {
        #ifdef RAYTRACE_REFINEMENT
            // Binary search over the last step: move back while behind the depth buffer, forward otherwise,
            // halving the step each iteration.
            vec3 refineStep = lastStep;
            for (uint i = 0u; i < uint(RAYTRACE_REFINEMENT_STEPS); ++i) {
                if (clamp(rayPos.xy, vec2(0.0), screenSize) != rayPos.xy) break;
                refineStep *= 0.5;

                float depth = texelFetch(depthtex1, ivec2(rayPos.xy), 0).x;
                if (depth < rayPos.z) {
                    rayPos -= refineStep;
                } else {
                    rayPos += refineStep;
                }
            }
        #endif
        return true;
    }

    #ifdef REAL_SKY_REFLECTION
        // hitSky is only set right after rayPos.xy passed the bounds check, so this fetch is in range.
        // NOTE(review): this samples the depth at the exit position. The original referenced an out-of-scope
        // `depth`; confirm this matches the intent.
        // Because depth-buffer values do not exceed 1.0, any non-zero isEyeInWater disables sky hits.
        return hitSky && texelFetch(depthtex1, ivec2(rayPos.xy), 0).x - float(isEyeInWater) >= 1.0;
    #else
        return false;
    #endif
}
/*
float SignExtract(in float x) {
    return uintBitsToFloat((floatBitsToUint(x) & 0x80000000u) | floatBitsToUint(1.0));
}

mat3 GetRotationMatrix(in vec3 from, in vec3 to) {
    float cosine = dot(from, to);

    float tmp = SignExtract(cosine);
        tmp = 1.0 / (tmp + cosine);

    vec3 axis = cross(to, from);
    vec3 tmpv = axis * tmp;

    return mat3(
        axis.x * tmpv.x + cosine, axis.x * tmpv.y - axis.z, axis.x * tmpv.z + axis.y,
        axis.y * tmpv.x + axis.z, axis.y * tmpv.y + cosine, axis.y * tmpv.z - axis.x,
        axis.z * tmpv.x - axis.y, axis.z * tmpv.y + axis.x, axis.z * tmpv.z + cosine
    );
}
*/

// https://ggx-research.github.io/publication/2023/06/09/publication-ggx.html
// Samples a microfacet normal from the GGX distribution of visible normals (VNDF),
// using the spherical-cap construction of Dupuy & Benyoub (2023).
//
// viewDir   : view direction in the local shading frame (z presumably along the surface normal — confirm).
// roughness : GGX alpha, applied identically to x and y (isotropic).
// xy        : two random numbers (presumably uniform in [0, 1)).
// Returns the normalized sampled microfacet normal.
vec3 sampleGGXVNDF(in vec3 viewDir, in float roughness, in vec2 xy) {
    #define SPECULAR_TAIL_CLAMP

    #ifdef SPECULAR_TAIL_CLAMP
        // Remap xy.y into [1e-3, rPI]. This keeps the cap sample near its top and drops the
        // wide-angle tail of the distribution.
        xy.y = clamp(xy.y / PI, 1e-3, rPI);
    #endif

    // Stretch the view direction into the unit-roughness (hemisphere) configuration.
    vec3 stretchedView = normalize(vec3(viewDir.xy * roughness, viewDir.z));

    // Uniformly sample the spherical cap whose lower bound is -stretchedView.z.
    float capZ = (1.0 + stretchedView.z) * oneMinus(xy.y) - stretchedView.z;
    float capRadius = sqrt(saturate(1.0 - capZ * capZ));
    vec3 capSample = vec3(capRadius * cossin(TAU * xy.x), capZ);

    // The hemisphere normal is the (unnormalized) half vector between the cap sample and the view
    // direction. Un-stretch it back onto the GGX ellipsoid.
    vec3 hemisphereNormal = capSample + stretchedView;
    return normalize(vec3(hemisphereNormal.xy * roughness, hemisphereNormal.z));
}
//#endif

// Maps a current-frame screen position (xy and depth-buffer z, all in [0, 1]) to the previous frame's
// screen space.
// The inverse and forward projections read only the diagonal/translation terms. This assumes a
// symmetric perspective projection whose other terms are zero.
vec3 Reproject(in vec3 screenPos) {
    vec3 ndcPos = screenPos * 2.0 - 1.0;

    // Current-frame NDC -> current view space.
    float inverseW = gbufferProjectionInverse[2].w * ndcPos.z + gbufferProjectionInverse[3].w;
    vec3 viewPos = vec3(gbufferProjectionInverse[0].x * ndcPos.x, gbufferProjectionInverse[1].y * ndcPos.y, 0.0) + gbufferProjectionInverse[3].xyz;
    viewPos /= inverseW;

    // View space -> player space, shifted by the camera motion -> previous view space.
    vec3 playerPos = mat3(gbufferModelViewInverse) * viewPos + gbufferModelViewInverse[3].xyz;
    vec3 previousViewPos = mat3(gbufferPreviousModelView) * (playerPos + cameraPosition - previousCameraPosition) + gbufferPreviousModelView[3].xyz;

    // Previous view space -> previous clip space -> previous screen space.
    vec3 previousDiagonal = vec3(gbufferPreviousProjection[0].x, gbufferPreviousProjection[1].y, gbufferPreviousProjection[2].z);
    vec3 previousClip = previousDiagonal * previousViewPos + gbufferPreviousProjection[3].xyz;
    return previousClip / -previousViewPos.z * 0.5 + 0.5;
}
