/* Copyright (c) 2024-2026 LunarG, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cstring>
#include <sstream>
#include <string>

#include <spirv/unified1/GLSL.std.450.h>
#include "generated/spirv_grammar_helper.h"
#include "gpuav/core/gpuav.h"
#include "gpuav/resources/gpuav_state_trackers.h"
#include "gpuav/shaders/gpuav_error_codes.h"
#include "gpuav/shaders/gpuav_error_header.h"

namespace gpuav {

// Returns a trailing "See more at ..." line pointing at the given instruction in the unified SPIR-V specification.
// The URL anchor is the opcode name as produced by string_SpvOpcode().
//
// Background: the Working Group decided not to mint "real" VUIDs for these checks, since that would duplicate the
// SPIR-V spec (https://gitlab.khronos.org/vulkan/vulkan/-/merge_requests/7853). Sanitizer errors are therefore not
// "UNASSIGNED" but "SPIRV" VUs, and they link directly to the instruction in the SPIR-V spec instead.
static std::string GetSpirvSpecLink(const uint32_t opcode) {
    std::string link = "\nSee more at https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#";
    link += string_SpvOpcode(opcode);
    return link;
}

// Registers the shader sanitizer error logger on a command buffer. This is a no-op unless the sanitizer
// instrumentation is enabled in the GPU-AV settings.
//
// The registered callback returns an InstrumentationErrorLogger that decodes error records of group
// kErrorGroup_InstSanitizer. For a recognized sub-code, the logger:
//   - appends a description to out_error_msg,
//   - sets out_vuid_msg to the matching "SPIRV-Sanitizer-*" identifier,
//   - returns true.
// It returns false for records of any other error group and for sanitizer records with an unrecognized sub-code.
void RegisterSanitizer(Validator &gpuav, CommandBufferSubState &cb) {
    if (!gpuav.gpuav_settings.shader_instrumentation.sanitizer) {
        return;
    }

    // The logger reads nothing from the command buffer or the last bound state, so the callback ignores its
    // arguments and hands back a capture-less lambda.
    cb.on_instrumentation_error_logger_register_functions.emplace_back([](Validator &gpuav, CommandBufferSubState &cb,
                                                                          const LastBound &last_bound) {
        CommandBufferSubState::InstrumentationErrorLogger inst_error_logger = [](Validator &gpuav, const Location &loc,
                                                                                 const uint32_t *error_record,
                                                                                 std::string &out_error_msg,
                                                                                 std::string &out_vuid_msg) {
            using namespace glsl;
            bool error_found = false;
            // Records produced by other instrumentation passes are left to their own loggers.
            if (GetErrorGroup(error_record) != kErrorGroup_InstSanitizer) {
                return error_found;
            }
            error_found = true;

            std::ostringstream strm;

            const uint32_t error_sub_code = GetSubError(error_record);
            switch (error_sub_code) {
                case kErrorSubCode_Sanitizer_DivideZero: {
                    // Param 0: opcode of the offending instruction.
                    // Param 1: component count of Operand 2 (0 is reported as a scalar).
                    const uint32_t opcode = error_record[kInst_LogError_ParameterOffset_0];
                    const uint32_t vector_size = error_record[kInst_LogError_ParameterOffset_1];
                    // For OpFMod/OpFRem, the SPIR-V spec only leaves the result *value* undefined on a zero divisor,
                    // so that extra note is appended for the float case.
                    const bool is_float = opcode == spv::OpFMod || opcode == spv::OpFRem;
                    strm << (is_float ? "Float" : "Integer") << " divide by zero. Operand 2 of " << string_SpvOpcode(opcode)
                         << " is ";
                    if (vector_size == 0) {
                        strm << "zero.";
                    } else {
                        strm << "a " << vector_size << "-wide vector which contains a zero value.";
                    }
                    if (is_float) {
                        strm << " The result value is undefined.";
                    }
                    strm << GetSpirvSpecLink(opcode);
                    out_vuid_msg = "SPIRV-Sanitizer-Divide-By-Zero";
                } break;
                case kErrorSubCode_Sanitizer_ImageGather: {
                    // Printed as signed, so an out-of-range negative component (e.g. -1) shows up as such rather than
                    // as a huge unsigned number. Non-negative values are printed unchanged either way.
                    const int32_t component_value = static_cast<int32_t>(error_record[kInst_LogError_ParameterOffset_0]);
                    strm << "OpImageGather has a component value of " << component_value;
                    strm << ", but it must be 0, 1, 2, or 3" << GetSpirvSpecLink(spv::OpImageGather);
                    out_vuid_msg = "SPIRV-Sanitizer-Image-Gather";
                } break;
                case kErrorSubCode_Sanitizer_Pow: {
                    // Pow is only valid with a scalar/vector of 16/32-bit float
                    const uint32_t vector_size = error_record[kInst_LogError_ParameterOffset_0];
                    // Reinterpret the stored bits instead of converting the integer value; a numeric cast would
                    // produce garbage. Assumes the instrumentation stores x and y as 32-bit float bit patterns
                    // (NOTE(review): confirm how 16-bit float operands are written).
                    static_assert(sizeof(float) == sizeof(uint32_t), "error record slots must hold a 32-bit float");
                    float x_value = 0.0f;
                    float y_value = 0.0f;
                    std::memcpy(&x_value, &error_record[kInst_LogError_ParameterOffset_1], sizeof(float));
                    std::memcpy(&y_value, &error_record[kInst_LogError_ParameterOffset_2], sizeof(float));
                    strm << "Pow (from GLSL.std.450) has an undefined result because operand (x < 0) or (x == 0 && y <= 0)\n  ";
                    if (vector_size > 0) {
                        // Only a single (x, y) pair is carried in the record parameters read here, so the individual
                        // components of a vector operand cannot be shown.
                        strm << "Using a vector of size " << vector_size << " but currently only can print out scalar values";
                    } else {
                        strm << "X == " << x_value << ", Y == " << y_value;
                    }
                    out_vuid_msg = "SPIRV-Sanitizer-Pow";
                } break;
                case kErrorSubCode_Sanitizer_Atan2: {
                    // Atan2 is only valid with a scalar/vector of 16/32-bit float
                    strm << "Atan2 (from GLSL.std.450) has an undefined result because both values used are zero.";
                    out_vuid_msg = "SPIRV-Sanitizer-Atan2";
                } break;
                case kErrorSubCode_Sanitizer_Fminmax: {
                    // simple encoding done in inst_sanitizer_fminman (sanitizer.comp):
                    //   bit 0 set -> x operand is NaN, bit 1 set -> y operand is NaN
                    const uint32_t invalid_encode = error_record[kInst_LogError_ParameterOffset_0];
                    const bool x_is_invalid = (invalid_encode & 0x1) != 0;
                    const bool y_is_invalid = (invalid_encode & 0x2) != 0;
                    const uint32_t vector_size = error_record[kInst_LogError_ParameterOffset_1];
                    // Any opcode other than FMin is reported as FMax.
                    const uint32_t glsl_opcode = error_record[kInst_LogError_ParameterOffset_2];
                    strm << (glsl_opcode == GLSLstd450FMin ? "FMin" : "FMax")
                         << " (from GLSL.std.450) has an undefined result because ";
                    if (x_is_invalid && y_is_invalid) {
                        strm << "both the x and y operands are NaN\n";
                    } else if (x_is_invalid) {
                        strm << "the x operand is NaN\n";
                    } else {
                        strm << "the y operand is NaN\n";
                    }
                    if (vector_size > 0) {
                        strm << "Using a vector of size " << vector_size << " but currently only can print out scalar values";
                    }
                    out_vuid_msg = "SPIRV-Sanitizer-Fminmax";
                } break;
                default:
                    // Unknown sanitizer sub-code: report that nothing was decoded.
                    error_found = false;
                    break;
            }
            out_error_msg += strm.str();
            return error_found;
        };

        return inst_error_logger;
    });
}

}  // namespace gpuav
